├── .gitignore ├── ComLightLib ├── ComLightLib.vcxproj ├── ComLightLib.vcxproj.filters ├── Exception.hpp ├── Readme.txt ├── client │ └── CComPtr.hpp ├── comLightClient.h ├── comLightCommon.h ├── comLightServer.h ├── hresult.h ├── pal │ ├── guiddef.h │ └── hresult.h ├── server │ ├── Object.hpp │ ├── ObjectRoot.hpp │ ├── RefCounter.hpp │ ├── freeThreadedMarshaller.cpp │ ├── freeThreadedMarshaller.h │ └── interfaceMap.h ├── streams.h ├── unknwn.h └── utils │ ├── guid_parse.hpp │ └── typeTraits.hpp ├── ComputeShaders ├── ComputeShaders.cpp ├── ComputeShaders.vcxproj ├── ComputeShaders.vcxproj.filters ├── Readme.txt ├── add.hlsl ├── addInPlace.hlsl ├── addRepeat.hlsl ├── addRepeat64.hlsl ├── addRepeatEx.hlsl ├── addRepeatGelu.hlsl ├── addRepeatGelu64.hlsl ├── addRepeatScale.hlsl ├── addRows.hlsl ├── componentwiseBinaryOp.hlsli ├── convolutionMain.hlsl ├── convolutionMain2.hlsl ├── convolutionMain2Fixed.hlsl ├── convolutionPrep1.hlsl ├── convolutionPrep2.hlsl ├── copyConvert.hlsl ├── copyTranspose.hlsl ├── dbgFindNaN.hlsl ├── diagMaskInf.hlsl ├── flashAttention.hlsl ├── flashAttentionCommon.hlsli ├── flashAttentionCompat1.hlsl ├── flashAttentionCompat2.hlsl ├── flashAttentionCompat3.hlsl ├── fmaRepeat1.hlsl ├── fmaRepeat164.hlsl ├── fmaRepeat2.hlsl ├── fp64Utils.hlsli ├── groupReduce.hlsli ├── groupReduce64.hlsli ├── matReshapePanels.hlsl ├── miscUtils.hlsli ├── mulMatByRow.hlsl ├── mulMatByRow64.hlsl ├── mulMatByRowTiled.hlsl ├── mulMatByRowTiledEx.hlsl ├── mulMatByScalar.hlsl ├── mulMatDotMain.hlsl ├── mulMatDotReshape.hlsl ├── mulMatMadMain.hlsl ├── mulMatTiled.hlsl ├── mulMatTiledEx.hlsl ├── norm.hlsl ├── normCompat.hlsl ├── normFixed.hlsl ├── normFixed64.hlsl ├── repeatUtils.hlsli ├── scaleInPlace.hlsl ├── softMax.hlsl ├── softMax64.hlsl ├── softMaxCompat.hlsl ├── softMaxFixed.hlsl ├── softMaxLong.hlsl └── zeroMemory.hlsl ├── Examples ├── MicrophoneCS │ ├── CaptureThread.cs │ ├── CommandLineArgs.cs │ ├── MicrophoneCS.cs │ ├── MicrophoneCS.csproj │ ├── Readme.txt │ 
└── TranscribeCallbacks.cs ├── OldMain │ ├── OldMain.vcxproj │ ├── OldMain.vcxproj.filters │ ├── Readme.txt │ ├── Utils │ │ ├── Logger.cpp │ │ └── Logger.h │ ├── dr_wav.h │ └── main.cpp ├── TranscribeCS │ ├── AnsiCodes.cs │ ├── CommandLineArgs.cs │ ├── Readme.txt │ ├── Transcribe.cs │ ├── TranscribeCS.cs │ └── TranscribeCS.csproj ├── WhisperDesktop │ ├── AppState.cpp │ ├── AppState.h │ ├── CaptureDlg.cpp │ ├── CaptureDlg.h │ ├── CircleIndicator.cpp │ ├── CircleIndicator.h │ ├── LoadModelDlg.cpp │ ├── LoadModelDlg.h │ ├── ModelAdvancedDlg.cpp │ ├── ModelAdvancedDlg.h │ ├── Readme.txt │ ├── Resource.h │ ├── TranscribeDlg.cpp │ ├── TranscribeDlg.h │ ├── Utils │ │ ├── DebugConsole.cpp │ │ ├── DebugConsole.h │ │ ├── LanguageDropdown.cpp │ │ ├── LanguageDropdown.h │ │ ├── PendingState.cpp │ │ ├── PendingState.h │ │ ├── TranslateCheckbox.cpp │ │ ├── TranslateCheckbox.h │ │ ├── WTL │ │ │ ├── MS-PL.txt │ │ │ ├── ReadMe.html │ │ │ ├── atlapp.h │ │ │ ├── atlcrack.h │ │ │ ├── atlctrls.h │ │ │ ├── atlddx.h │ │ │ ├── atlgdi.h │ │ │ ├── atlres.h │ │ │ ├── atluser.h │ │ │ └── atlwinx.h │ │ ├── logger.cpp │ │ ├── logger.h │ │ ├── miscUtils.cpp │ │ └── miscUtils.h │ ├── WhisperDesktop.cpp │ ├── WhisperDesktop.manifest │ ├── WhisperDesktop.rc │ ├── WhisperDesktop.vcxproj │ ├── WhisperDesktop.vcxproj.filters │ ├── framework.h │ ├── stdafx.cpp │ ├── stdafx.h │ ├── sunflower.ico │ ├── targetver.h │ └── useDiscreteGpu.c └── main │ ├── Readme.txt │ ├── main.cpp │ ├── main.vcxproj │ ├── main.vcxproj.filters │ ├── miscUtils.cpp │ ├── miscUtils.h │ ├── params.cpp │ ├── params.h │ ├── textWriter.cpp │ └── textWriter.h ├── LICENSE ├── Readme.md ├── SampleClips ├── Readme.txt ├── columbia-large-1080ti.txt ├── columbia-large-1650.txt ├── columbia-large-vega7.txt ├── columbia-large-vega8.txt ├── columbia-medium-1080ti.txt ├── columbia-medium-1650.txt ├── columbia-medium-vega7.txt ├── columbia-medium-vega8.txt ├── columbia.wma ├── jfk-large-1080ti.txt ├── jfk-large-1650.txt ├── jfk-large-vega7.txt 
├── jfk-large-vega8.txt ├── jfk-medium-1080ti.txt ├── jfk-medium-1650.txt ├── jfk-medium-vega7.txt ├── jfk-medium-vega8.txt ├── jfk.wav └── summary.tsv ├── Tools ├── CompressShaders │ ├── Cabinet.cs │ ├── CompressShaders.cs │ ├── CompressShaders.csproj │ ├── DetectFp64.cs │ ├── LZ4.cs │ ├── LanguageCodes.cs │ ├── Readme.txt │ └── ShaderNames.cs ├── CompressTables │ ├── CompressTables.cs │ └── CompressTables.csproj ├── PerfSummary │ ├── LogParser.cs │ ├── PerfSummary.cs │ ├── PerfSummary.csproj │ └── Summary.cs ├── compareTraces │ ├── CommandLineArgs.cpp │ ├── CommandLineArgs.h │ ├── Readme.txt │ ├── TraceReader.cpp │ ├── TraceReader.h │ ├── compare.cpp │ ├── compare.h │ ├── compareTraces.cpp │ ├── compareTraces.vcxproj │ ├── compareTraces.vcxproj.filters │ ├── stdafx.cpp │ ├── stdafx.h │ └── testUtils.cpp └── copy-binaries.cmd ├── Whisper ├── API │ ├── MfStructs.h │ ├── Readme.txt │ ├── SpecialTokens.h │ ├── TranscribeStructs.h │ ├── iContext.cl.h │ ├── iContext.h │ ├── iMediaFoundation.cl.h │ ├── iMediaFoundation.h │ ├── iTranscribeResult.cl.h │ ├── iTranscribeResult.h │ ├── loggerApi.h │ ├── sFullParams.h │ ├── sLanguageList.h │ ├── sLoadModelCallbacks.h │ ├── sModelSetup.h │ ├── whisperComLight.h │ └── whisperWindows.h ├── CPU │ ├── BufferAllocator.cpp │ ├── BufferAllocator.h │ ├── DecoderTensors.cpp │ ├── DecoderTensors.h │ ├── HybridLoader.cpp │ ├── HybridLoader.h │ ├── KvTensors.h │ ├── KvTensorsCpu.cpp │ ├── LargeBuffer.cpp │ ├── LargeBuffer.h │ ├── MlContext.h │ ├── MlContextCpu.cpp │ ├── ParallelForRunner.cpp │ ├── ParallelForRunner.h │ ├── Readme.txt │ ├── Tensor.h │ ├── TensorCpu.cpp │ ├── mulMat.cpp │ ├── mulMat.h │ ├── mulMat.kernel.hpp │ ├── mulMatImpl.avx2.cpp │ ├── mulMatImpl.cpp │ ├── mulMatImpl.h │ ├── mulMatImpl.panel.cpp │ ├── mulMatUtils.hpp │ ├── simdUtils.cpp │ └── simdUtils.h ├── D3D │ ├── Binder.cpp │ ├── Binder.h │ ├── MappedResource.cpp │ ├── MappedResource.h │ ├── RenderDoc │ │ ├── renderDoc.cpp │ │ ├── renderDoc.h │ │ └── 
renderdoc_app.h │ ├── createBuffer.cpp │ ├── createBuffer.h │ ├── createDevice.cpp │ ├── createDevice.h │ ├── device.h │ ├── downloadBuffer.cpp │ ├── downloadBuffer.h │ ├── enums.cpp │ ├── enums.h │ ├── listGPUs.cpp │ ├── listGPUs.h │ ├── sGpuInfo.h │ ├── shaderNames.cpp │ ├── shaderNames.h │ ├── shaders.cpp │ └── shaders.h ├── DllMain.cpp ├── Hybrid │ ├── HybridContext.cpp │ ├── HybridContext.h │ ├── KeyValueDownloader.cpp │ ├── KeyValueDownloader.h │ └── Readme.txt ├── MF │ ├── AudioBuffer.cpp │ ├── AudioBuffer.h │ ├── AudioCapture.cpp │ ├── AudioCapture.h │ ├── MediaFoundation.cpp │ ├── PcmReader.cpp │ ├── PcmReader.h │ ├── loadAudioFile.cpp │ ├── loadAudioFile.h │ ├── mfStartup.cpp │ ├── mfStartup.h │ ├── mfUtils.cpp │ └── mfUtils.h ├── ML │ ├── ConstantBuffer.cpp │ ├── ConstantBuffer.h │ ├── Context.ops.cpp │ ├── DbgNanTest.cpp │ ├── DbgNanTest.h │ ├── Device.cpp │ ├── Device.h │ ├── LookupTables.cpp │ ├── LookupTables.h │ ├── LookupTablesData.cpp │ ├── LookupTablesData.h │ ├── LookupTablesData.inl │ ├── MlContext.cpp │ ├── MlContext.dbg.cpp │ ├── MlContext.h │ ├── Reshaper.cpp │ ├── Reshaper.h │ ├── TempBuffers.cpp │ ├── TempBuffers.h │ ├── Tensor.cpp │ ├── Tensor.h │ ├── TensorEx.cpp │ ├── TensorEx.h │ ├── TensorGpuViews.cpp │ ├── TensorGpuViews.h │ ├── TensorShape.cpp │ ├── TensorShape.h │ ├── TensorsArena.cpp │ ├── TensorsArena.h │ ├── mlUtils.cpp │ ├── mlUtils.h │ ├── reshapedMultiply.h │ ├── tensorOpsTests.cpp │ ├── tensorOpsTests.h │ ├── testUtils.cpp │ ├── testUtils.h │ └── testUtilsC.h ├── Readme.txt ├── Resource.rc ├── Utils │ ├── CpuProfiler.cpp │ ├── CpuProfiler.h │ ├── DelayExecution.cpp │ ├── DelayExecution.h │ ├── GpuProfiler.cpp │ ├── GpuProfiler.h │ ├── GpuProfilerSimple.h │ ├── LZ4 │ │ ├── LICENSE │ │ ├── lz4.c │ │ └── lz4.h │ ├── Logger.cpp │ ├── Logger.h │ ├── MurmurHash3.cpp │ ├── MurmurHash3.h │ ├── ProfileCollection.cpp │ ├── ProfileCollection.h │ ├── ReadStream.h │ ├── Trace │ │ ├── TraceStructures.cpp │ │ ├── TraceStructures.h │ │ ├── 
TraceWriter.cpp │ │ ├── TraceWriter.h │ │ ├── tracing.cpp │ │ └── tracing.h │ ├── miscUtils.cpp │ ├── miscUtils.h │ ├── parallelFor.cpp │ └── parallelFor.h ├── Whisper.vcxproj ├── Whisper.vcxproj.filters ├── Whisper │ ├── ContextImpl.capture.cpp │ ├── ContextImpl.cpp │ ├── ContextImpl.diarize.cpp │ ├── ContextImpl.h │ ├── ContextImpl.misc.cpp │ ├── DecoderInputBuffers.cpp │ ├── DecoderInputBuffers.h │ ├── DecoderResultBuffer.cpp │ ├── DecoderResultBuffer.h │ ├── KeyValueBuffers.cpp │ ├── KeyValueBuffers.h │ ├── Languages.cpp │ ├── Languages.h │ ├── MelInputTensor.cpp │ ├── MelInputTensor.h │ ├── MelStreamer.cpp │ ├── MelStreamer.h │ ├── ModelBuffers.clone.cpp │ ├── ModelBuffers.cpp │ ├── ModelBuffers.h │ ├── ModelImpl.cpp │ ├── ModelImpl.h │ ├── ModelLoader.h │ ├── Spectrogram.cpp │ ├── Spectrogram.h │ ├── TranscribeResult.h │ ├── Vocabulary.cpp │ ├── Vocabulary.h │ ├── WhisperContext.cpp │ ├── WhisperContext.h │ ├── WhisperModel.cpp │ ├── WhisperModel.h │ ├── audioConstants.h │ ├── iSpectrogram.h │ ├── languageCodez.inl │ ├── languageCodez.tsv │ ├── loaderUtils.h │ ├── melSpectrogram.cpp │ ├── melSpectrogram.h │ ├── sEncodeParams.h │ ├── sModelParams.h │ ├── sTokenData.h │ ├── voiceActivityDetection.cpp │ └── voiceActivityDetection.h ├── misc.natvis ├── modelFactory.cpp ├── modelFactory.h ├── resource.h ├── source.compat │ ├── Readme.txt │ ├── convertThings.cpp │ ├── convertThings.h │ └── ggmlMsvc.c ├── source │ ├── LICENSE │ ├── Readme.txt │ ├── ggml.c │ ├── ggml.h │ ├── whisper.cpp │ └── whisper.h ├── stdafx.cpp ├── stdafx.h ├── whisper.def └── whisperCom.cpp ├── WhisperCpp.sln ├── WhisperNet ├── API │ ├── CaptureDeviceId.cs │ ├── Parameters.cs │ ├── SpecialTokens.cs │ ├── eCaptureStatus.cs │ ├── eGpuModelFlags.cs │ ├── eLanguage.cs │ ├── eLogLevel.cs │ ├── eModelImplementation.cs │ ├── eResultFlags.cs │ ├── eSpeakerChannel.cs │ ├── iAudioBuffer.cs │ ├── iAudioReader.cs │ ├── iMediaFoundation.cs │ ├── iModel.cs │ └── sCaptureParams.cs ├── AssemblyInfo.cs ├── 
AssemblyTitle.cs ├── Callbacks.cs ├── CaptureCallbacks.cs ├── Context.cs ├── ExtensionMethods.cs ├── Internal │ ├── NativeLogger.cs │ ├── iContext.cs │ ├── iTranscribeResult.cs │ ├── sCaptureCallbacks.cs │ ├── sCaptureDevice.cs │ ├── sFullParams.cs │ ├── sLoadModelCallbacks.cs │ ├── sLoggerSetup.cs │ ├── sModelSetup.cs │ └── sProgressSink.cs ├── Library.cs ├── Readme.md ├── WhisperNet.csproj └── WhisperNet.nuspec ├── WhisperPS ├── Commands │ ├── ExportBase.cs │ ├── ExportSubrip.cs │ ├── ExportText.cs │ ├── ExportWebVtt.cs │ ├── FormatSegments.cs │ ├── ListAdapters.cs │ ├── LoadModel.cs │ ├── TranscribeBase.cs │ └── TranscribeFile.cs ├── Internal │ ├── MarshalEx.cs │ ├── NativeLogger.cs │ ├── iTranscribeResult.cs │ ├── sCaptureDevice.cs │ ├── sFullParams.cs │ ├── sLoadModelCallbacks.cs │ ├── sModelSetup.cs │ └── sProgressSink.cs ├── Library.cs ├── Properties │ └── AssemblyTitle.cs ├── Readme.md ├── Types │ ├── Model.cs │ ├── Segment.cs │ └── Transcription.cs ├── Utils │ ├── CommandLogger.cs │ └── MiscUtils.cs ├── WhisperPS.csproj ├── WhisperPS.psd1 ├── app.config └── packages.config ├── gui-capture.png ├── gui-load-model.png └── gui-transcribe.png /.gitignore: -------------------------------------------------------------------------------- 1 | .vs/ 2 | ComLightLib/x64/ 3 | Whisper/x64/ 4 | x64/ 5 | Tools/CompressShaders/bin/ 6 | Tools/CompressShaders/obj/ 7 | Whisper/D3D/shaderData-Debug.inl 8 | Whisper/D3D/shaderData-Release.inl 9 | WhisperNet/bin/ 10 | WhisperNet/obj/ 11 | Examples/TranscribeCS/bin/ 12 | Examples/TranscribeCS/obj/ 13 | *.aps 14 | *.json 15 | *.user 16 | Examples/MicrophoneCS/obj/ 17 | Examples/MicrophoneCS/bin/ 18 | Tools/PerfSummary/bin/ 19 | Tools/PerfSummary/obj/ 20 | packages/ 21 | WhisperPS/obj/ 22 | WhisperPS/bin/ 23 | Tools/CompressTables/bin/ 24 | Tools/CompressTables/obj/ -------------------------------------------------------------------------------- /ComLightLib/ComLightLib.vcxproj.filters: 
-------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /ComLightLib/Exception.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace ComLight 4 | { 5 | class Exception : public std::runtime_error 6 | { 7 | // I don't like C++ exceptions too much, but for some cases they are useful. 8 | // You can throw ComLight::Exception from constructor, or from FinalConstruct() method, the library will catch & return the code from the class factory function. 9 | // Unfortunately, for interface methods this doesn't work, the C++ parts of the library can't catch them without very complex trickery like code generation. 10 | // You can still use this class in methods, but you'll need to catch them manually near the API boundary or the app will crash. 11 | // C++ doesn't have an ABI, the framework can't catch C++ exception across the modules. 12 | const HRESULT m_code; 13 | 14 | public: 15 | 16 | Exception( HRESULT hr ) : runtime_error( "ComLight HRESULT exception" ), m_code( hr ) { } 17 | 18 | HRESULT code() const { return m_code; } 19 | }; 20 | } -------------------------------------------------------------------------------- /ComLightLib/Readme.txt: -------------------------------------------------------------------------------- 1 | Copy-pasted from there: 2 | https://github.com/Const-me/ComLightInterop/tree/master/ComLightLib 3 | With only a few minor changes. 
-------------------------------------------------------------------------------- /ComLightLib/comLightClient.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "comLightCommon.h" 3 | #include "client/CComPtr.hpp" 4 | #include "utils/typeTraits.hpp" 5 | // Client-side helpers: a replacement for the IID_PPV_ARGS macro which checks the pointer type at compile time. 6 | namespace ComLight 7 | { 8 | namespace details 9 | { 10 | template 11 | inline constexpr void** castDoublePointerToVoid( T** pp ) 12 | { // Casts the interface pointer-to-pointer to void**; the static_assert rejects types which are not IUnknown interfaces. 13 | static_assert( pointersAssignable(), "IID_PPV_ARGS macro should be used with IUnknown interfaces" ); 14 | return reinterpret_cast( pp ); 15 | } 16 | } 17 | } 18 | // Replace the IID_PPV_ARGS macro from the Windows SDK, if one was defined, with the ComLight version below. 19 | #ifdef IID_PPV_ARGS 20 | #undef IID_PPV_ARGS 21 | #endif 22 | // Expands to two arguments: the interface ID taken from the type of **pp, and pp cast to void**. NOTE(review): DEFINE_INTERFACE_ID declares iid as a static function, and decltype( **pp ) is a reference type, so this expansion may not compile where it is used; confirm whether the macro is used anywhere. 23 | #define IID_PPV_ARGS( pp ) decltype( **pp )::iid, ::ComLight::details::castDoublePointerToVoid( pp ) -------------------------------------------------------------------------------- /ComLightLib/comLightCommon.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hresult.h" 3 | // GUID: presumably from a Windows SDK header on MSVC, and from the portable pal/guiddef.h elsewhere. 4 | #ifdef _MSC_VER 5 | #include 6 | #else 7 | #include "pal/guiddef.h" 8 | using LPCTSTR = const char*; 9 | #endif 10 | 11 | #include "unknwn.h" -------------------------------------------------------------------------------- /ComLightLib/comLightServer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "comLightCommon.h" 3 | #include "client/CComPtr.hpp" 4 | 5 | #include "server/ObjectRoot.hpp" 6 | #include "server/interfaceMap.h" 7 | #include "server/Object.hpp" 8 | #include "server/freeThreadedMarshaller.h" 9 | // DLLEXPORT gives exported functions C linkage; with non-MSVC compilers it also sets default symbol visibility. 10 | #ifdef _MSC_VER 11 | // On Windows, it's controlled by library.def module definition file. There's __declspec(dllexport), but it adds underscore, I don't like that.
12 | #define DLLEXPORT extern "C" 13 | #else 14 | #define DLLEXPORT extern "C" __attribute__((visibility("default"))) 15 | #endif -------------------------------------------------------------------------------- /ComLightLib/hresult.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #ifdef _MSC_VER 4 | #include 5 | #include 6 | #else 7 | #include "pal/hresult.h" 8 | #endif 9 | // CHECK( hr ): evaluates hr once, and returns it from the enclosing function when it's a failure code. 10 | #define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } 11 | // Portable replacements for Windows SDK definitions, only compiled when the compiler is not MSVC. 12 | #ifndef _MSC_VER 13 | inline constexpr HRESULT HRESULT_FROM_WIN32( int c ) 14 | { // Same mapping as the Windows SDK macro: values <= 0, including ERROR_SUCCESS = 0, are returned unchanged; positive Win32 codes map to 0x8007xxxx failure codes. 15 | return c <= 0 ? c : ( ( 0xFFFF & c ) | 0x80070000 ); 16 | } 17 | 18 | constexpr HRESULT OLE_E_BLANK = _HRESULT_TYPEDEF_( 0x80040007 ); 19 | constexpr HRESULT E_BOUNDS = _HRESULT_TYPEDEF_( 0x8000000BL ); 20 | 21 | constexpr int ERROR_HANDLE_EOF = 38; 22 | constexpr int ERROR_ALREADY_INITIALIZED = 1247; 23 | #endif 24 | // HRESULT codes for these Win32 errors, available on all platforms. 25 | constexpr HRESULT E_EOF = HRESULT_FROM_WIN32( ERROR_HANDLE_EOF ); 26 | constexpr HRESULT E_ALREADY_INITIALIZED = HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); -------------------------------------------------------------------------------- /ComLightLib/pal/guiddef.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #ifndef GUID_DEFINED 5 | #define GUID_DEFINED 6 | #endif 7 | // Portable GUID with the same field names as the Windows SDK structure. 8 | struct GUID 9 | { 10 | uint32_t Data1; 11 | uint16_t Data2; 12 | uint16_t Data3; 13 | std::array Data4; 14 | 15 | constexpr inline bool operator==( const GUID& that ) const 16 | { 17 | return Data1 == that.Data1 && Data2 == that.Data2 && Data3 == that.Data3 && Data4 == that.Data4; 18 | } 19 | 20 | // Without this, `a != b` only compiles in C++20, where it is rewritten in terms of operator==. 21 | constexpr inline bool operator!=( const GUID& that ) const 22 | { 23 | return !( *this == that ); 24 | } 25 | }; 26 | 27 | using REFIID = const GUID&; -------------------------------------------------------------------------------- /ComLightLib/server/ObjectRoot.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include
"RefCounter.hpp" 3 | #include "../comLightCommon.h" 4 | #include "../utils/typeTraits.hpp" 5 | 6 | namespace ComLight 7 | { 8 | // Base class of objects, implements reference counting, also a few lifetime methods. 9 | // The template argument is the interface you want clients to get when they ask for IID_IUnknown. By convention, that pointer defines object's identity. 10 | template 11 | class ObjectRoot : public RefCounter, public I 12 | { 13 | protected: 14 | // Default lifetime hooks; both return S_FALSE, which is a success code. DECLARE_FREE_THREADED_MARSHALLER provides its own internalFinalConstruct. 15 | inline HRESULT internalFinalConstruct() 16 | { 17 | return S_FALSE; 18 | } 19 | 20 | inline HRESULT FinalConstruct() 21 | { 22 | return S_FALSE; 23 | } 24 | 25 | inline void FinalRelease() { } 26 | // Returns this object as IUnknown*; by convention that pointer defines object identity. 27 | IUnknown* getUnknown() 28 | { 29 | static_assert( details::pointersAssignable(), "The interface doesn't derive from IUnknown" ); 30 | return static_cast( this ); 31 | } 32 | // Default: no extra interfaces. DECLARE_FREE_THREADED_MARSHALLER provides a version which also answers IMarshal. 33 | bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const 34 | { 35 | return false; 36 | } 37 | 38 | // Implement query interface with 2 entries, IUnknown and I. 39 | bool implQueryInterface( REFIID riid, void** ppvObject ) 40 | { 41 | if( riid == I::iid() || riid == IUnknown::iid() ) 42 | { 43 | I* const result = this; 44 | result->AddRef(); 45 | *ppvObject = result; 46 | return true; 47 | } 48 | return false; 49 | } 50 | }; 51 | } -------------------------------------------------------------------------------- /ComLightLib/server/RefCounter.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | 6 | namespace ComLight 7 | { 8 | // Very base class of objects, implements reference counting.
9 | class RefCounter 10 | { 11 | std::atomic_uint referenceCounter; // Starts at zero, see the constructor. 12 | 13 | public: 14 | 15 | RefCounter() : referenceCounter( 0 ) { } 16 | 17 | inline virtual ~RefCounter() { } 18 | 19 | RefCounter( const RefCounter &that ) = delete; 20 | RefCounter( RefCounter &&that ) = delete; 21 | 22 | protected: 23 | // Both methods return the counter value after the atomic increment / decrement. 24 | uint32_t implAddRef() 25 | { 26 | return ++referenceCounter; 27 | } 28 | 29 | uint32_t implRelease() 30 | { 31 | // Might be a good idea to use locks, at least in debug builds. They're much slower than atomics, but with locks it's possible to detect when 2 threads call release at the same time, for object with counter = 1. 32 | // It's a memory management bug, but it would be nice if debug builds would handle that case gracefully. 33 | const uint32_t rc = --referenceCounter; // Wraps around to UINT_MAX when the counter was already zero, i.e. one Release() too many. 34 | assert( rc != UINT_MAX ); 35 | return rc; 36 | } 37 | }; 38 | } -------------------------------------------------------------------------------- /ComLightLib/server/freeThreadedMarshaller.cpp: -------------------------------------------------------------------------------- 1 | #include "freeThreadedMarshaller.h" 2 | #ifdef _MSC_VER 3 | #include 4 | // Windows only: wrappers around the COM free-threaded marshaller. 5 | HRESULT ComLight::details::createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal ) 6 | { 7 | return ::CoCreateFreeThreadedMarshaler( (LPUNKNOWN)pUnkOuter, (LPUNKNOWN *)ppUnkMarshal ); 8 | } 9 | // Returns true, with the result of QueryInterface in *ppvObject, only when riid is IID_IMarshal, the marshaller exists, and QueryInterface succeeds. 10 | bool ComLight::details::queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller ) 11 | { 12 | if( riid != IID_IMarshal || nullptr == marshaller ) 13 | return false; 14 | const HRESULT hr = marshaller->QueryInterface( IID_IMarshal, ppvObject ); 15 | return SUCCEEDED( hr );
16 | } 17 | #endif -------------------------------------------------------------------------------- /ComLightLib/server/freeThreadedMarshaller.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifdef _MSC_VER 3 | #include "../comLightCommon.h" 4 | 5 | namespace ComLight 6 | { 7 | namespace details 8 | { 9 | HRESULT createFreeThreadedMarshaller( IUnknown* pUnkOuter, IUnknown** ppUnkMarshal ); 10 | bool queryMarshallerInterface( REFIID riid, void **ppvObject, IUnknown* marshaller ); 11 | } 12 | } 13 | // Use inside a class derived from ObjectRoot: creates the free-threaded marshaller in internalFinalConstruct(), and exposes IMarshal from queryExtraInterfaces(). 14 | #define DECLARE_FREE_THREADED_MARSHALLER() \ 15 | private: \ 16 | ComLight::CComPtr m_freeThreadedMarshaller; \ 17 | protected: \ 18 | HRESULT internalFinalConstruct() \ 19 | { \ 20 | return ComLight::details::createFreeThreadedMarshaller( getUnknown(), &m_freeThreadedMarshaller ); \ 21 | } \ 22 | bool queryExtraInterfaces( REFIID riid, void **ppvObject ) const \ 23 | { \ 24 | return ComLight::details::queryMarshallerInterface( riid, ppvObject, m_freeThreadedMarshaller ); \ 25 | } 26 | // On other platforms the macro expands to nothing. 27 | #else 28 | #define DECLARE_FREE_THREADED_MARSHALLER() 29 | #endif -------------------------------------------------------------------------------- /ComLightLib/server/interfaceMap.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../utils/typeTraits.hpp" 3 | 4 | // Unlike ATL, the interface map is optional for ComLight. 5 | // If you won't declare a map, the object will support 2 interfaces: IUnknown, and whatever template argument was passed to ObjectRoot class.
6 | #define BEGIN_COM_MAP() \ 7 | protected: \ 8 | bool implQueryInterface( REFIID iid, void** ppvObject ) { 9 | 10 | #define END_COM_MAP() return false; } 11 | 12 | namespace ComLight 13 | { 14 | namespace details 15 | { 16 | template 17 | inline bool tryReturnInterface( REFIID iid, C* pThis, void** ppvResult ) 18 | { 19 | static_assert( pointersAssignable(), "Trying to implement an interface that doesn't derive from IUnknown" ); 20 | static_assert( pointersAssignable(), "Declared support for an interface, but the class doesn't implement it" ); 21 | if( I::iid() != iid ) 22 | return false; 23 | I* const result = pThis; 24 | result->AddRef(); 25 | *ppvResult = result; 26 | return true; 27 | } 28 | } 29 | } 30 | 31 | #define COM_INTERFACE_ENTRY( I ) if( ComLight::details::tryReturnInterface( iid, this, ppvObject ) ) return true; -------------------------------------------------------------------------------- /ComLightLib/unknwn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | // Calling conventions 5 | #ifdef _MSC_VER 6 | #define COMLIGHTCALL __stdcall 7 | #define DECLSPEC_NOVTABLE __declspec(novtable) 8 | #elif defined(__GNUC__) || defined(__clang__) 9 | #if defined(__i386__) 10 | #define COMLIGHTCALL __attribute__((stdcall)) 11 | #else 12 | #define COMLIGHTCALL 13 | #endif 14 | #define DECLSPEC_NOVTABLE 15 | #else 16 | #error Unsupported C++ compiler 17 | #endif 18 | 19 | #include "utils/guid_parse.hpp" 20 | 21 | #define DEFINE_INTERFACE_ID( guidString ) static constexpr GUID iid() { return ::ComLight::make_guid( guidString ); } 22 | 23 | namespace ComLight 24 | { 25 | // This thing is binary compatible with IUnknown from Windows SDK. See DesktopClient demo project, it uses normal COM interop in .NET framework 4.7 to call my implementation. 
26 | struct DECLSPEC_NOVTABLE IUnknown 27 | { 28 | DEFINE_INTERFACE_ID( "00000000-0000-0000-c000-000000000046" ); 29 | 30 | virtual HRESULT COMLIGHTCALL QueryInterface( REFIID riid, void **ppvObject ) = 0; 31 | 32 | virtual uint32_t COMLIGHTCALL AddRef() = 0; 33 | 34 | virtual uint32_t COMLIGHTCALL Release() = 0; 35 | }; 36 | } -------------------------------------------------------------------------------- /ComputeShaders/ComputeShaders.cpp: -------------------------------------------------------------------------------- 1 | void fnComputeShaders() 2 | { 3 | } -------------------------------------------------------------------------------- /ComputeShaders/Readme.txt: -------------------------------------------------------------------------------- 1 | This project compiles all the compute shaders which implement the model. 2 | 3 | Many shaders come in 2 versions, something.hlsl and something64.hlsl 4 | 5 | The version with the `64` suffix is used on AMD GPUs, the version without suffix is used on nVidia and Intel GPUs. 6 | 7 | Not all of these shaders are actually used for anything. 8 | Some of them are implementing binary compatibility for the reference CPU version, and not used unless messing with the `constexpr` flags in MlContext C++ class. 9 | Such shaders often require FP64 support, which is an optional feature in D3D11. 10 | CompressShaders tool detects such shaders by looking at the SFI0 chunk in the binary, and outputs a bitmap of the FP64 shaders. 11 | This way, missing FP64 hardware support shouldn’t break the library. 
-------------------------------------------------------------------------------- /ComputeShaders/add.hlsl: -------------------------------------------------------------------------------- 1 | inline float compute( float a, float b ) 2 | { // Element-wise operation for componentwiseBinaryOp.hlsli, which calls compute( a, b ) for every element. 3 | return a + b; 4 | } 5 | // The include provides the resources, the constant buffer, and the main() entry point. 6 | #include "componentwiseBinaryOp.hlsli" -------------------------------------------------------------------------------- /ComputeShaders/addInPlace.hlsl: -------------------------------------------------------------------------------- 1 | #ifndef THREADS 2 | #define THREADS 512 3 | #endif 4 | // result += arg0, element-wise; one thread group per row of result, selected by SV_GroupID. 5 | Buffer arg0: register( t0 ); 6 | RWBuffer result: register( u0 ); 7 | 8 | cbuffer Constants: register( b0 ) 9 | { 10 | uint4 size: packoffset( c0 ); 11 | uint4 strides: packoffset( c1 ); 12 | uint4 argStrides: packoffset( c3 ); // c2 is not read here. NOTE(review): presumably the C++ side stores arg0 sizes there; confirm. 13 | } 14 | // Offset of the first element of row idx; strides[ 0 ] is the element stride and is not used here. 15 | inline uint rowOffset( uint3 idx, uint4 strides ) 16 | { 17 | return idx[ 0 ] * strides[ 1 ] + idx[ 1 ] * strides[ 2 ] + idx[ 2 ] * strides[ 3 ]; 18 | } 19 | // Each thread handles row elements thread, thread + THREADS, thread + 2 * THREADS, and so on. 20 | [ numthreads( THREADS, 1, 1 ) ] 21 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 22 | { 23 | uint rdi = rowOffset( group, strides ); 24 | uint rsi = rowOffset( group, argStrides ); 25 | 26 | const uint rdiEnd = rdi + size[ 0 ] * strides[ 0 ]; 27 | rdi += thread * strides[ 0 ]; 28 | rsi += thread * argStrides[ 0 ]; 29 | 30 | const uint rdiInc = THREADS * strides[ 0 ]; 31 | const uint rsiInc = THREADS * argStrides[ 0 ]; 32 | 33 | for( ; rdi < rdiEnd; rdi += rdiInc, rsi += rsiInc ) 34 | { 35 | float f = result[ rdi ]; 36 | f += arg0[ rsi ]; 37 | result[ rdi ] = f; 38 | } 39 | } -------------------------------------------------------------------------------- /ComputeShaders/addRepeat64.hlsl: -------------------------------------------------------------------------------- 1 | #define THREADS 64 2 | #include "addRepeat.hlsl" -------------------------------------------------------------------------------- /ComputeShaders/addRepeatGelu64.hlsl:
-------------------------------------------------------------------------------- 1 | #define THREADS 64 2 | #include "addRepeatGelu.hlsl" -------------------------------------------------------------------------------- /ComputeShaders/addRows.hlsl: -------------------------------------------------------------------------------- 1 | #ifndef THREADS 2 | #define THREADS 256 3 | #endif 4 | // result[ row ] = tokenEmbedding[ embd[ row ] ] + positionalEmbedding[ row + pastTokensCount ], computed row by row. 5 | // dec.tokenEmbedding tensor 6 | Buffer tokenEmbedding: register( t0 ); 7 | // dec.positionalEmbedding tensor 8 | Buffer positionalEmbedding: register( t1 ); 9 | // R32_UINT buffer with the input tokens 10 | Buffer embd: register( t2 ); 11 | // Output tensor 12 | RWBuffer result: register( u0 ); 13 | 14 | cbuffer Constants: register( b0 ) 15 | { 16 | uint rowLength: packoffset( c0.x ); 17 | uint pastTokensCount: packoffset( c0.y ); 18 | uint outputRowStride: packoffset( c0.z ); 19 | uint2 embStrides: packoffset( c1.x ); 20 | uint2 posStrides: packoffset( c1.z ); 21 | } 22 | // group.x selects the row; output elements within a row are written with unit stride. 23 | [ numthreads( THREADS, 1, 1 ) ] 24 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 25 | { 26 | const uint row = group.x; 27 | const uint rowTok = embd[ row ]; 28 | const uint rowPos = row + pastTokensCount; 29 | 30 | uint rdi = row * outputRowStride; 31 | const uint rdiEnd = rdi + rowLength; 32 | rdi += thread; 33 | 34 | uint rsiTok = rowTok * embStrides.y; 35 | rsiTok += thread * embStrides.x; 36 | 37 | uint rsiPos = rowPos * posStrides.y; 38 | rsiPos += thread * posStrides.x; 39 | 40 | for( ; rdi < rdiEnd; rdi += THREADS, rsiTok += THREADS * embStrides.x, rsiPos += THREADS * posStrides.x ) 41 | { 42 | float a = tokenEmbedding[ rsiTok ]; 43 | float b = positionalEmbedding[ rsiPos ]; 44 | result[ rdi ] = a + b; 45 | } 46 | } -------------------------------------------------------------------------------- /ComputeShaders/componentwiseBinaryOp.hlsli: -------------------------------------------------------------------------------- 1 | Buffer arg0: register( t0 ); 2 | Buffer arg1:
register( t1 ); 3 | RWBuffer result: register( u0 ); 4 | // The including file must define float compute( float a, float b ) before including this one. 5 | cbuffer Constants: register( b0 ) 6 | { 7 | uint4 src0_elements: packoffset( c0 ); 8 | uint4 src0_strides: packoffset( c1 ); 9 | uint4 src1_elements: packoffset( c2 ); 10 | uint4 src1_strides: packoffset( c3 ); 11 | uint4 result_elements: packoffset( c4 ); 12 | uint4 result_strides: packoffset( c5 ); 13 | } 14 | // One thread group per row j. arg0 and result are accessed with unit element stride, arg1 with src1_strides[ 0 ]. NOTE(review): src1_elements, result_elements, src0_strides[ 0 ] and result_strides[ 0 ] are not read. 15 | [ numthreads( 32, 1, 1 ) ] 16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 17 | { 18 | const uint j = group.x; 19 | const uint nb1 = result_strides[ 1 ]; 20 | const uint nb01 = src0_strides[ 1 ]; 21 | 22 | const uint nb10 = src1_strides[ 0 ]; 23 | const uint nb11 = src1_strides[ 1 ]; 24 | const uint nc = src0_elements[ 0 ]; 25 | 26 | uint rsi0 = j * nb01; 27 | uint rsi1 = j * nb11; 28 | uint rdi = j * nb1; 29 | const uint rsi0End = rsi0 + nc; 30 | 31 | rsi0 += thread; 32 | rsi1 += thread * nb10; 33 | rdi += thread; 34 | 35 | const uint rsi1Inc = 32 * nb10; 36 | for( ; rsi0 < rsi0End; rsi0 += 32, rsi1 += rsi1Inc, rdi += 32 ) 37 | { 38 | const float a = arg0[ rsi0 ]; 39 | const float b = arg1[ rsi1 ]; 40 | const float res = compute( a, b ); 41 | result[ rdi ] = res; 42 | } 43 | } -------------------------------------------------------------------------------- /ComputeShaders/convolutionPrep1.hlsl: -------------------------------------------------------------------------------- 1 | // ggml_compute_forward_conv_1d_1s_f16_f32, prepare kernel data (src0) 2 | // Dispatch [ ne01, ne02, 1 ] thread groups 3 | Buffer arg0: register( t0 ); 4 | RWBuffer result: register( u0 ); 5 | 6 | cbuffer Constants: register( b0 ) 7 | { 8 | uint4 src0_elements: packoffset( c0 ); 9 | uint4 src0_strides: packoffset( c1 ); 10 | } 11 | // Rounds x up to a multiple of 32. 12 | inline uint roundUp32( uint x ) 13 | { 14 | return ( x + 31 ) & ( ~31u ); 15 | } 16 | // Writes element i00 of kernel row ( i01, i02 ) to result[ i02 * ew0 * ne00 + i00 * ew0 + i01 ], i.e. transposed, with output rows ew0 elements long. 17 | [ numthreads( 32, 1, 1 ) ] 18 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 19 | { 20 | const uint nb01 = src0_strides[ 1 ]; 21 | const uint nb02 =
src0_strides[ 2 ]; 22 | 23 | const uint ne00 = src0_elements[ 0 ]; 24 | const uint ne01 = src0_elements[ 1 ]; 25 | const uint ew0 = roundUp32( ne01 ); 26 | 27 | const uint i02 = group.y; 28 | const uint i01 = group.x; 29 | 30 | uint rsi = i02 * nb02 + i01 * nb01; 31 | const uint rsiEnd = rsi + ne00; 32 | uint rdi = i02 * ew0 * ne00 + i01; 33 | rsi += thread; 34 | rdi += thread * ew0; 35 | const uint rdiInc = 32 * ew0; 36 | 37 | for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc ) 38 | result[ rdi ] = arg0[ rsi ]; 39 | } -------------------------------------------------------------------------------- /ComputeShaders/convolutionPrep2.hlsl: -------------------------------------------------------------------------------- 1 | // ggml_compute_forward_conv_1d_1s_f16_f32, prepare source data (src1) 2 | // Dispatch [ ne11, 1, 1 ] thread groups 3 | Buffer arg1: register( t0 ); 4 | RWBuffer result: register( u0 ); 5 | 6 | cbuffer Constants: register( b0 ) 7 | { 8 | uint4 src0_elements: packoffset( c0 ); 9 | uint4 src1_elements: packoffset( c2 ); 10 | uint4 src1_strides: packoffset( c3 ); 11 | } 12 | // NOTE(review): roundUp32() and adjustFp16() used below presumably come from this include; confirm. 13 | #include "miscUtils.hlsli" 14 | // Writes row i11 of src1 transposed into column i11 of the output, starting at output row nh = ne00 / 2; output rows below nh are not written by this shader. 15 | [ numthreads( 32, 1, 1 ) ] 16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 17 | { 18 | const uint i11 = group.x; 19 | 20 | const uint ne00 = src0_elements[ 0 ]; 21 | const uint ne01 = src0_elements[ 1 ]; 22 | const uint ne10 = src1_elements[ 0 ]; 23 | const uint nb11 = src1_strides[ 1 ]; 24 | 25 | const uint nk = ne00; 26 | const uint nh = nk / 2; 27 | const uint ew0 = roundUp32( ne01 ); // uint, like convolutionPrep1.hlsl, so the index math below doesn't mix signed and unsigned values. 28 | 29 | uint rsi = i11 * nb11; 30 | uint rdi = nh * ew0 + i11; 31 | const uint rdiInc = ew0 * 32; 32 | const uint rsiEnd = rsi + ne10; 33 | 34 | rsi += thread; 35 | rdi += thread * ew0; 36 | 37 | for( ; rsi < rsiEnd; rsi += 32, rdi += rdiInc ) 38 | { 39 | float f = arg1[ rsi ]; 40 | f = adjustFp16( f ); 41 | result[ rdi ] = f; 42 | } 43 | } --------------------------------------------------------------------------------
/ComputeShaders/copyConvert.hlsl: -------------------------------------------------------------------------------- 1 | // ggml_compute_forward_dup_f32 when we only need to convert types, but not reshape the tensor 2 | // Dispatch [ ne01, ne02, ne03 ] thread groups of this shader 3 | Buffer arg0: register( t0 ); 4 | RWBuffer result: register( u0 ); 5 | 6 | cbuffer Constants: register( b0 ) 7 | { 8 | uint4 src0_elements: packoffset( c0 ); 9 | uint4 src0_strides: packoffset( c1 ); 10 | bool downcastFp32 : packoffset( c2.x ); 11 | } 12 | 13 | #include "miscUtils.hlsli" 14 | /* One thread group per source row ( i01, i02, i03 ). The row is copied into a densely packed output tensor, passing every value through adjustFp16() when downcastFp32 is set. */ 15 | [ numthreads( 32, 1, 1 ) ] 16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 17 | { 18 | const uint nb00 = src0_strides[ 0 ]; 19 | const uint nb01 = src0_strides[ 1 ]; 20 | const uint nb02 = src0_strides[ 2 ]; 21 | const uint nb03 = src0_strides[ 3 ]; 22 | 23 | const uint ne00 = src0_elements[ 0 ]; 24 | const uint ne01 = src0_elements[ 1 ]; 25 | const uint ne02 = src0_elements[ 2 ]; 26 | const uint ne03 = src0_elements[ 3 ]; 27 | 28 | const uint i01 = group.x; 29 | const uint i02 = group.y; 30 | const uint i03 = group.z; 31 | 32 | const uint rs = ne00 * nb00; 33 | // Linear index of the row in the densely packed output: i01 + ne01 * ( i02 + ne02 * i03 ) 34 | const uint id = ( i03 * ne02 + i02 ) * ne01 + i01; 35 | /* The loop below moves rsi by one element per column, so this assumes nb00 == 1, which makes rs == ne00 */ 36 | uint rsi = i01 * nb01 + i02 * nb02 + i03 * nb03; 37 | uint rdi = id * rs; 38 | 39 | const uint rsiEnd = rsi + rs; 40 | rsi += thread; 41 | rdi += thread; 42 | for( ; rsi < rsiEnd; rsi += 32, rdi += 32 ) 43 | { 44 | float f = arg0[ rsi ]; 45 | [branch] 46 | if( downcastFp32 ) 47 | f = adjustFp16( f ); 48 | result[ rdi ] = f; 49 | } 50 | } -------------------------------------------------------------------------------- /ComputeShaders/diagMaskInf.hlsl: -------------------------------------------------------------------------------- 1 | // ggml_compute_forward_diag_mask_inf_f32 2 | RWBuffer result: register( u0 ); 3 | 4 | cbuffer Constants: register( b0 ) 5 | { 6 | uint4 elements: packoffset( c0
); 7 | uint4 strides: packoffset( c1 ); 8 | uint n_past : packoffset( c2.x ); 9 | } 10 | 11 | static const float negativeInfinity = asfloat( 0xff800000 ); 12 | /* Thread group ( j, k ) = ( group.x, group.y ) writes -inf into every element of row j in slice k whose column index exceeds n_past + j */ 13 | [numthreads( 32, 1, 1 )] 14 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 15 | { 16 | const uint k = group.y; 17 | const uint j = group.x; 18 | 19 | // Start of the row 20 | uint rdi = k * strides[ 2 ] + j * strides[ 1 ]; 21 | // End of the row 22 | const uint rdiEnd = rdi + elements[ 0 ] * strides[ 0 ]; 23 | // First index to write in this thread 24 | rdi += ( n_past + j + thread + 1 ) * strides[ 0 ]; 25 | // Index increment 26 | const uint rdiInc = 32 * strides[ 0 ]; 27 | 28 | for( ; rdi < rdiEnd; rdi += rdiInc ) 29 | result[ rdi ] = negativeInfinity; 30 | } -------------------------------------------------------------------------------- /ComputeShaders/fmaRepeat164.hlsl: -------------------------------------------------------------------------------- 1 | #define THREADS 64 2 | #include "fmaRepeat1.hlsl" -------------------------------------------------------------------------------- /ComputeShaders/fmaRepeat2.hlsl: -------------------------------------------------------------------------------- 1 | // Implementation of fmaRepeat() when source arguments have different shape or VRAM layout 2 | // Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor 3 | RWBuffer tensor: register( u0 ); 4 | Buffer patternMul: register( t0 ); 5 | Buffer patternAdd: register( t1 ); 6 | 7 | cbuffer Constants: register( b0 ) 8 | { 9 | uint4 tensorSize: packoffset( c0 ); 10 | uint4 tensorStrides: packoffset( c1 ); 11 | uint4 patternSizeMul: packoffset( c2 ); 12 | uint4 patternStridesMul: packoffset( c3 ); 13 | uint4 patternSizeAdd: packoffset( c4 ); 14 | uint4 patternStridesAdd: packoffset( c5 ); 15 | } 16 | 17 | #ifndef THREADS 18 | #define THREADS 32 19 | #endif 20 | 21 | #include "repeatUtils.hlsli" 22 | /* Element i of the pattern row that starts at rowStart; the pattern repeats along the row with period size.x */ 23 | inline float loadPattern( Buffer
buffer, uint rowStart, uint i, uint4 size, uint4 stride ) 24 | { 25 | i %= size.x; 26 | return buffer[ i * stride.x + rowStart ]; 27 | } 28 | /* Computes tensor = tensor * patternMul + patternAdd in place, one thread group per destination row; pattern rows wrap around via group % patternSize.yzw */ 29 | [ numthreads( THREADS, 1, 1 ) ] 30 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 31 | { 32 | uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides ); 33 | const uint rsiMul = rowOffset( group % patternSizeMul.yzw, patternStridesMul ); 34 | const uint rsiAdd = rowOffset( group % patternSizeAdd.yzw, patternStridesAdd ); 35 | /* Thread t visits row elements t, t + THREADS, t + 2 * THREADS, ... because it.x advances by it.y = THREADS * stride[ 0 ]; the element index i must advance by THREADS as well */ 36 | for( uint i = thread; it.x < it.z; it.x += it.y, i += THREADS ) 37 | { 38 | precise float f = tensor[ it.x ]; 39 | float mul = loadPattern( patternMul, rsiMul, i, patternSizeMul, patternStridesMul ); 40 | float add = loadPattern( patternAdd, rsiAdd, i, patternSizeAdd, patternStridesAdd ); 41 | f *= mul; 42 | f += add; 43 | tensor[ it.x ] = f; 44 | } 45 | } -------------------------------------------------------------------------------- /ComputeShaders/fp64Utils.hlsli: -------------------------------------------------------------------------------- 1 | // TODO: compile another version of these shaders, and use it on GPUs with ExtendedDoublesShaderInstructions flag, will become slightly faster 2 | // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_feature_data_d3d11_options 3 | #ifndef ExtendedDoublesShaderInstructions 4 | #define ExtendedDoublesShaderInstructions 0 5 | #endif 6 | 7 | // Compute num/den in FP64 precision 8 | inline double div64( double num, double den ) 9 | { 10 | #if ExtendedDoublesShaderInstructions 11 | return num / den; 12 | #else 13 | // https://en.wikipedia.org/wiki/Division_algorithm#Newton%E2%80%93Raphson_division 14 | double x = 1.0f / (float)den; 15 | x += x * ( 1.0 - den * x ); 16 | x += x * ( 1.0 - den * x ); 17 | return num * x; 18 | #endif 19 | } 20 | 21 | // Compute sqrt(x) in FP64 precision 22 | inline double sqrt64( double x ) 23 | { 24 | double root = sqrt( (float)x ); 25 | if( root == 0.0 ) return root; /* div64( x, 0 ) returns NaN, which would poison the Newton steps below, so a zero root is returned directly */ 26 | root = 0.5 * ( root + div64( x, root
) ); 27 | root = 0.5 * ( root + div64( x, root ) ); 28 | return root; 29 | } -------------------------------------------------------------------------------- /ComputeShaders/groupReduce64.hlsli: -------------------------------------------------------------------------------- 1 | groupshared float sharedAccumulators[ 64 ]; 2 | /* The helpers below use one shared slot per thread and start reducing at offset 32, so they expect thread groups of 64 threads */ 3 | // Compute horizontal sum of the numbers. The result is only correct on the thread #0 of the group. 4 | void horizontalSum( const uint thread, inout float sum ) 5 | { 6 | sharedAccumulators[ thread ] = sum; 7 | for( uint i = 32; i > 1; i /= 2 ) 8 | { 9 | GroupMemoryBarrierWithGroupSync(); 10 | if( thread < i ) 11 | { 12 | sum += sharedAccumulators[ thread + i ]; 13 | sharedAccumulators[ thread ] = sum; 14 | } 15 | } 16 | GroupMemoryBarrierWithGroupSync(); 17 | if( 0 == thread ) 18 | sum += sharedAccumulators[ 1 ]; 19 | } 20 | /* NOTE(review): the broadcast helpers finish by reading sharedAccumulators[ 0 ] after their last barrier, while every helper here starts by writing sharedAccumulators[ thread ] before any barrier. Two back-to-back calls could therefore let thread 0 overwrite slot 0 before slower threads have read it; confirm callers separate consecutive calls with GroupMemoryBarrierWithGroupSync() */ 21 | // Compute horizontal sum of the numbers, and broadcast to all threads of the group. 22 | void horizontalSumBroadcast( const uint thread, inout float sum ) 23 | { 24 | horizontalSum( thread, sum ); 25 | if( 0 == thread ) 26 | sharedAccumulators[ 0 ] = sum; 27 | GroupMemoryBarrierWithGroupSync(); 28 | sum = sharedAccumulators[ 0 ]; 29 | } 30 | 31 | // Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
32 | void horizontalMaxBroadcast( const uint thread, inout float ax ) 33 | { 34 | sharedAccumulators[ thread ] = ax; 35 | for( uint i = 32; i > 0; i /= 2 ) 36 | { 37 | GroupMemoryBarrierWithGroupSync(); 38 | if( thread < i ) 39 | { 40 | ax = max( ax, sharedAccumulators[ thread + i ] ); 41 | sharedAccumulators[ thread ] = ax; 42 | } 43 | } 44 | GroupMemoryBarrierWithGroupSync(); 45 | ax = sharedAccumulators[ 0 ]; 46 | } -------------------------------------------------------------------------------- /ComputeShaders/mulMatByRow.hlsl: -------------------------------------------------------------------------------- 1 | // Matrix * row product, like [ E0, E1, E2, E3 ] * [ E0, 1, E2, E3 ] = [ E1, 1, E2, E3 ] 2 | // Dispatch [ E1, E2, E3 ] groups of this shader 3 | Buffer arg0: register( t0 ); 4 | Buffer arg1: register( t1 ); 5 | RWBuffer result: register( u0 ); 6 | 7 | cbuffer Constants: register( b0 ) 8 | { 9 | uint4 arg0Size: packoffset( c0 ); 10 | uint4 arg0Strides: packoffset( c1 ); 11 | uint4 arg1Size: packoffset( c2 ); 12 | uint4 arg1Strides: packoffset( c3 ); 13 | uint4 resultSize: packoffset( c4 ); 14 | uint4 resultStrides: packoffset( c5 ); 15 | } 16 | 17 | #include "groupReduce.hlsli" 18 | /* Horizontal sum of the vector components */ 19 | inline uint hadd( uint3 vec ) 20 | { 21 | return vec.x + vec.y + vec.z; 22 | } 23 | inline uint hadd( uint2 vec ) 24 | { 25 | return vec.x + vec.y; 26 | } 27 | /* Each thread group computes one output element: its 32 threads accumulate interleaved partial dot products of arg0 row ( group.x, group.y, group.z ) with arg1 row ( group.y, group.z ), then horizontalSum() combines them and thread 0 stores the result */ 28 | [ numthreads( 32, 1, 1 ) ] 29 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 30 | { 31 | uint s0 = hadd( group * arg0Strides.yzw ); 32 | uint s1 = hadd( group.yz * arg1Strides.zw ); 33 | const uint s0End = s0 + arg0Size.x * arg0Strides.x; 34 | const uint s0Inc = 32 * arg0Strides.x; 35 | const uint s1Inc = 32 * arg1Strides.x; 36 | 37 | s0 += thread * arg0Strides.x; 38 | s1 += thread * arg1Strides.x; 39 | float dp = 0; 40 | for( ; s0 < s0End; s0 += s0Inc, s1 += s1Inc ) 41 | dp = mad( arg0[ s0 ], arg1[ s1 ], dp ); 42 | 43 | horizontalSum( thread, dp ); 44 | if( 0 != thread ) 45 | return; 46
| 47 | const uint rdi = group.x + hadd( group.yz * resultStrides.zw ); 48 | result[ rdi ] = dp; 49 | } -------------------------------------------------------------------------------- /ComputeShaders/mulMatByScalar.hlsl: -------------------------------------------------------------------------------- 1 | // Matrix * scalar product, like [ 1, E1, E2, E3 ] * [ 1, 1, E2, E3 ] = [ E1, 1, E2, E3 ] 2 | // Dispatch [ E2, E3, 1 ] thread groups of this shader 3 | Buffer arg0: register( t0 ); 4 | Buffer arg1: register( t1 ); 5 | RWBuffer result: register( u0 ); 6 | 7 | cbuffer Constants: register( b0 ) 8 | { 9 | uint4 arg0Size: packoffset( c0 ); 10 | uint4 arg0Strides: packoffset( c1 ); 11 | uint4 arg1Size: packoffset( c2 ); 12 | uint4 arg1Strides: packoffset( c3 ); 13 | uint4 resultSize: packoffset( c4 ); 14 | uint4 resultStrides: packoffset( c5 ); 15 | } 16 | 17 | inline uint hadd( uint2 vec ) 18 | { 19 | return vec.x + vec.y; 20 | } 21 | /* Thread group ( group.x, group.y ) multiplies arg0Size.y elements of arg0, read with stride arg0Strides.y, by the single arg1 value of that slice, and writes them to consecutive result elements */ 22 | [ numthreads( 32, 1, 1 ) ] 23 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 24 | { 25 | const float scalarValue = arg1[ hadd( group.xy * arg1Strides.zw ) ]; 26 | 27 | uint s0 = hadd( group.xy * arg0Strides.zw ); 28 | const uint s0Inc = 32 * arg0Strides.y; 29 | s0 += thread * arg0Strides.y; 30 | 31 | uint rdi = hadd( group.xy * resultStrides.zw ); 32 | const uint rdiEnd = rdi + arg0Size.y; 33 | rdi += thread; 34 | 35 | for( ; rdi < rdiEnd; rdi += 32, s0 += s0Inc ) 36 | { 37 | float f = arg0[ s0 ]; 38 | f *= scalarValue; 39 | result[ rdi ] = f; 40 | } 41 | } -------------------------------------------------------------------------------- /ComputeShaders/mulMatDotReshape.hlsl: -------------------------------------------------------------------------------- 1 | // GGML_TASK_INIT step for matrix*matrix product, where nb01 >= nb00; 2 | // Dispatch with [ ne11, ne12 ] groups 3 | Buffer arg0: register( t0 ); 4 | RWBuffer result: register( u0 ); 5 | 6 | cbuffer Constants: register( b0 ) 7 | { 8 | uint4 src0_elements:
packoffset( c0 ); 9 | uint4 src0_strides: packoffset( c1 ); 10 | } 11 | 12 | #include "miscUtils.hlsli" 13 | 14 | // Each thread group of this shader copies a single row of the matrix 15 | [ numthreads( 32, 1, 1 ) ] 16 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 17 | { 18 | const uint i12 = group.y; 19 | const uint i11 = group.x; 20 | const uint ne10 = src0_elements.x; 21 | const uint ne11 = src0_elements.y; 22 | const uint nb12 = src0_strides.z; 23 | const uint nb11 = src0_strides.y; 24 | /* Destination rows are densely packed, ne10 elements each; the source row is read with a unit step and every value goes through adjustFp16() */ 25 | uint rdi = i11 * ne10 + i12 * ne10 * ne11; 26 | const uint rdiEnd = rdi + ne10; 27 | uint rsi = i12 * nb12 + i11 * nb11; 28 | rdi += thread; 29 | rsi += thread; 30 | 31 | for( ; rdi < rdiEnd; rdi += 32, rsi += 32 ) 32 | result[ rdi ] = adjustFp16( arg0[ rsi ] ); 33 | } -------------------------------------------------------------------------------- /ComputeShaders/normFixed64.hlsl: -------------------------------------------------------------------------------- 1 | #define THREADS 64 2 | #include "normFixed.hlsl" -------------------------------------------------------------------------------- /ComputeShaders/repeatUtils.hlsli: -------------------------------------------------------------------------------- 1 | inline uint rowOffset( uint3 idx, uint4 strides ) 2 | { 3 | return idx[ 0 ] * strides[ 1 ] + idx[ 1 ] * strides[ 2 ] + idx[ 2 ] * strides[ 3 ]; 4 | } 5 | /* rowOffset() above maps row coordinates idx to an element offset using strides.yzw */ 6 | // Initial iterator state for a row of the output tensor 7 | // x = current index, y = index increment, z = end of the index 8 | inline uint3 tensorIteratorState( uint3 group, uint thread, uint4 size, uint4 stride ) 9 | { 10 | uint3 res; 11 | res.x = rowOffset( group, stride ); 12 | res.y = THREADS * stride[ 0 ]; 13 | res.z = res.x + size[ 0 ] * stride[ 0 ]; 14 | res.x += thread * stride[ 0 ]; 15 | return res; 16 | } 17 | 18 | // Handle a complete row of output tensor, using the iterator made by tensorIteratorState() function 19 | #define ROW_LOOP( ts ) for( ; ts.x < ts.z; ts.x += ts.y )
20 | // Same as above, but advancing by len elements per iteration instead of the step stored in the iterator 21 | #define ROW_LOOP_EX( ts, len, stride ) for( ; ts.x < ts.z; ts.x += ( len ) * ( stride )[ 0 ] ) -------------------------------------------------------------------------------- /ComputeShaders/scaleInPlace.hlsl: -------------------------------------------------------------------------------- 1 | RWBuffer buffer: register( u0 ); 2 | 3 | cbuffer Constants: register( b0 ) 4 | { 5 | uint4 src0_elements: packoffset( c0 ); 6 | uint4 src0_strides: packoffset( c1 ); 7 | float multiplier: packoffset( c2.x ); 8 | } 9 | /* Multiplies the first src0_elements[ 0 ] elements of row group.x by multiplier, in place */ 10 | [ numthreads( 32, 1, 1 ) ] 11 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 12 | { 13 | const uint nc0 = src0_elements[ 0 ]; 14 | uint i = group.x * src0_strides[ 1 ]; 15 | const uint iEnd = i + nc0; 16 | const float mul = multiplier; 17 | for( i += thread; i < iEnd; i += 32 ) 18 | { 19 | float f = buffer[ i ]; 20 | f *= mul; 21 | buffer[ i ] = f; 22 | } 23 | } -------------------------------------------------------------------------------- /ComputeShaders/softMax64.hlsl: -------------------------------------------------------------------------------- 1 | #define THREADS 64 2 | #include "softMax.hlsl" -------------------------------------------------------------------------------- /ComputeShaders/softMaxCompat.hlsl: -------------------------------------------------------------------------------- 1 | // ggml_compute_forward_soft_max_f32 2 | // Dispatch [ ( nr + 31 ) / 32, 1, 1 ] thread groups of this shader 3 | RWBuffer result: register( u0 ); 4 | 5 | // table_exp_f16 6 | Buffer lookupTable: register( t0 ); 7 | 8 | cbuffer Constants: register( b0 ) 9 | { 10 | uint4 elements: packoffset( c0 ); 11 | uint4 strides: packoffset( c1 ); 12 | uint nr: packoffset( c2.x ); 13 | } 14 | 15 | #include "miscUtils.hlsli" 16 | #include "fp64Utils.hlsli" 17 | 18 | static const float negativeInfinity = asfloat( 0xff800000 ); 19 | /* One thread per row: thread dtid.x processes row dtid.x serially, in place */ 20 | [ numthreads( 32, 1, 1 ) ] 21 | void main( uint3 dtid:
SV_DispatchThreadID ) 22 | { 23 | if( dtid.x >= nr ) 24 | return; 25 | 26 | const uint p = dtid.x * strides[ 1 ]; 27 | const uint nc = elements[ 0 ]; 28 | const uint pEnd = p + nc; 29 | uint i; 30 | 31 | float m = negativeInfinity; 32 | for( i = p; i < pEnd; i++ ) 33 | m = max( m, result[ i ] ); 34 | /* Replace every element x by exp( x - m ) from the FP16 lookup table and accumulate the sum in FP64; -inf elements become 0 */ 35 | double sum = 0; 36 | for( i = p; i < pEnd; i++ ) 37 | { 38 | float f = result[ i ]; 39 | 40 | [branch] 41 | if( f != negativeInfinity ) 42 | { 43 | uint s = fp16Rounded( f - m ); 44 | s = lookupTable[ s ]; 45 | f = f16tof32( s ); 46 | sum += f; 47 | } 48 | else 49 | f = 0; 50 | 51 | result[ i ] = f; 52 | } 53 | 54 | const float scale = (float)div64( 1.0, sum ); 55 | // ggml_vec_scale_f32 56 | for( i = p; i < pEnd; i++ ) 57 | { 58 | float f = result[ i ]; 59 | f *= scale; 60 | result[ i ] = f; 61 | } 62 | } -------------------------------------------------------------------------------- /ComputeShaders/softMaxLong.hlsl: -------------------------------------------------------------------------------- 1 | // This version is for the "dec.probs" shader tag 2 | // The input tensor has a size [ 51865, 3 ], a very long tensor with just 3 rows. 3 | // Even though the shader only runs on 3 GPU cores, a large count of threads helps substantially: this shader is about 50% faster. 4 | #define THREADS 1024 5 | 6 | #include "softMax.hlsl" -------------------------------------------------------------------------------- /ComputeShaders/zeroMemory.hlsl: -------------------------------------------------------------------------------- 1 | RWBuffer result: register( u0 ); 2 | 3 | cbuffer Constants: register( b0 ) 4 | { 5 | uint elements: packoffset( c0.x ); 6 | bool writeNan: packoffset( c0.y ); 7 | } 8 | 9 | // Thread group index is 16 bits per coordinate: 10 | // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/nf-d3d11-id3d11devicecontext-dispatch 11 | // We want this shader to support buffers up to 2 GB.
12 | #ifndef THREADS 13 | static const uint THREADS = 512; 14 | #endif 15 | #ifndef ITERATIONS 16 | static const uint ITERATIONS = 128; 17 | #endif 18 | 19 | static const uint itemsPerGroup = THREADS * ITERATIONS; 20 | /* Thread group g fills elements [ g * itemsPerGroup, min( ( g + 1 ) * itemsPerGroup, elements ) ) with zeros, or with the NaN bit pattern 0x7FFFFFFF when writeNan is set; its threads interleave over that range */ 21 | [numthreads( THREADS, 1, 1 )] 22 | void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) 23 | { 24 | uint rdi = group.x * itemsPerGroup; 25 | const uint rdiEnd = min( rdi + itemsPerGroup, elements ); 26 | // https://www.h-schmidt.net/FloatConverter/IEEE754.html 27 | const float pattern = writeNan ? asfloat( 0x7FFFFFFFu ) : 0.0; 28 | for( rdi += thread; rdi < rdiEnd; rdi += THREADS ) 29 | result[ rdi ] = pattern; 30 | } -------------------------------------------------------------------------------- /Examples/MicrophoneCS/CaptureThread.cs: -------------------------------------------------------------------------------- 1 | using System.Runtime.ExceptionServices; 2 | using Whisper; 3 | /* Runs context.runCapture() on a dedicated thread until it returns; a key press on the console makes shouldCancel() return true */ 4 | namespace MicrophoneCS 5 | { 6 | sealed class CaptureThread: CaptureCallbacks 7 | { 8 | public CaptureThread( CommandLineArgs args, Context context, iAudioCapture source ) 9 | { 10 | callbacks = new TranscribeCallbacks( args ); 11 | this.context = context; 12 | this.source = source; 13 | 14 | thread = new Thread( threadMain ) { Name = "Capture Thread" }; 15 | Console.WriteLine( "Press any key to quit" ); 16 | thread.Start(); 17 | } 18 | /* Thread-pool callback: blocks in Console.ReadKey(), then sets shouldQuit. NOTE(review): Console.ReadKey() throws InvalidOperationException when console input is redirected; confirm this sample only runs interactively */ 19 | static void readKeyCallback( object? state ) 20 | { 21 | CaptureThread ct = ( state as CaptureThread ) ??
throw new ApplicationException(); 22 | Console.ReadKey(); 23 | ct.shouldQuit = true; 24 | } 25 | /* Queues the key watcher, waits for the capture thread to exit, then rethrows any exception captured in threadMain() with its original stack trace */ 26 | public void join() 27 | { 28 | ThreadPool.QueueUserWorkItem( readKeyCallback, this ); 29 | thread.Join(); 30 | edi?.Throw(); 31 | } 32 | /* Set by readKeyCallback() on a thread-pool thread and polled through shouldCancel(), hence volatile */ 33 | volatile bool shouldQuit = false; 34 | 35 | protected override bool shouldCancel( Context sender ) => 36 | shouldQuit; 37 | 38 | protected override void captureStatusChanged( Context sender, eCaptureStatus status ) 39 | { 40 | Console.WriteLine( $"CaptureStatusChanged: {status}" ); 41 | } 42 | 43 | readonly TranscribeCallbacks callbacks; 44 | readonly Thread thread; 45 | readonly Context context; 46 | readonly iAudioCapture source; 47 | ExceptionDispatchInfo? edi = null; 48 | /* Capture thread entry point: exceptions are stored in edi so join() can rethrow them on the calling thread */ 49 | void threadMain() 50 | { 51 | try 52 | { 53 | context.runCapture( source, callbacks, this ); 54 | } 55 | catch( Exception ex ) 56 | { 57 | edi = ExceptionDispatchInfo.Capture( ex ); 58 | } 59 | } 60 | } 61 | } -------------------------------------------------------------------------------- /Examples/MicrophoneCS/MicrophoneCS.csproj: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | Exe 5 | net6.0-windows 6 | enable 7 | enable 8 | true 9 | false 10 | x64 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | PreserveNewest 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /Examples/MicrophoneCS/Readme.txt: -------------------------------------------------------------------------------- 1 | This example builds .NET 6 console application which shows how to use audio capture API of the .NET wrapper.
-------------------------------------------------------------------------------- /Examples/OldMain/OldMain.vcxproj.filters: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /Examples/OldMain/Readme.txt: -------------------------------------------------------------------------------- 1 | This project builds the original whisper.cpp command-line sample -------------------------------------------------------------------------------- /Examples/OldMain/Utils/Logger.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "Logger.h" 5 | /* Each log function below writes one line to stderr: the level name, a colon, then the message built from the printf-style format string */ 6 | namespace 7 | { 8 | void logMessage( const char* lvl, const char8_t* pszFormat, std::va_list va ) 9 | { 10 | fprintf( stderr, "%s: ", lvl ); 11 | vfprintf( stderr, (const char*)pszFormat, va ); 12 | fprintf( stderr, "\n" ); 13 | } 14 | } 15 | /* Shared body of the variadic log functions; expects the last named parameter of the enclosing function to be called pszFormat */ 16 | #define LOG_MESSAGE_IMPL( lvl ) \ 17 | std::va_list args; \ 18 | va_start( args, pszFormat ); \ 19 | logMessage( lvl, pszFormat, args ); \ 20 | va_end( args ); 21 | 22 | void logError( const char8_t* pszFormat, ... ) 23 | { 24 | LOG_MESSAGE_IMPL( "Error" ); 25 | } 26 | 27 | void logWarning( const char8_t* pszFormat, ... ) 28 | { 29 | LOG_MESSAGE_IMPL( "Warning" ); 30 | } 31 | 32 | void logInfo( const char8_t* pszFormat, ... ) 33 | { 34 | LOG_MESSAGE_IMPL( "Info" ); 35 | } 36 | 37 | void logDebug( const char8_t* pszFormat, ...
) 38 | { 39 | LOG_MESSAGE_IMPL( "Debug" ); 40 | } -------------------------------------------------------------------------------- /Examples/OldMain/Utils/Logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | struct ggml_tensor; 8 | /* printf-style loggers; the implementation in Logger.cpp writes to stderr */ 9 | void logError( const char8_t* pszFormat, ... ); 10 | void logWarning( const char8_t* pszFormat, ... ); 11 | void logInfo( const char8_t* pszFormat, ... ); 12 | void logDebug( const char8_t* pszFormat, ... ); 13 | 14 | #ifdef __cplusplus 15 | } 16 | /* No-op tracing stubs: every constructor and function below has an empty body */ 17 | namespace Tracing 18 | { 19 | struct ItemName 20 | { 21 | ItemName( const char* str ) { } 22 | ItemName( const char* str, uint32_t a0 ) { } 23 | ItemName( const char* str, int a0 ) { } 24 | }; 25 | 26 | inline void tensor( const ItemName& name, const ggml_tensor* tensor ) { } 27 | inline void delayTensor( const ItemName& name, const ggml_tensor* tensor ) { } 28 | inline void vector( const ItemName& name, const std::vector& vec ) { } 29 | inline void writeDelayedTensors() { } 30 | } 31 | #endif -------------------------------------------------------------------------------- /Examples/TranscribeCS/Readme.txt: -------------------------------------------------------------------------------- 1 | This example builds .NET 6 console application which shows how to transcribe or translate audio files with the .NET wrapper.
-------------------------------------------------------------------------------- /Examples/TranscribeCS/TranscribeCS.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | Exe 4 | net6.0-windows 5 | enable 6 | enable 7 | true 8 | false 9 | x64 10 | 11 | 12 | 13 | PreserveNewest 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /Examples/WhisperDesktop/AppState.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Utils/DebugConsole.h" 3 | /* Application-wide state of WhisperDesktop: persisted settings, the loaded model, and the debug console */ 4 | class AppState 5 | { 6 | bool coInit = false; 7 | CRegKey registryKey; 8 | CIcon appIcon; 9 | public: 10 | /* Location and properties of the model file */ 11 | struct ModelSource 12 | { 13 | CString path; 14 | bool found = false; 15 | Whisper::eModelImplementation impl = (Whisper::eModelImplementation)0; 16 | uint64_t sizeInBytes = 0; 17 | }; 18 | ModelSource source; 19 | 20 | DebugConsole console; 21 | CComPtr mediaFoundation; 22 | CComPtr model; 23 | 24 | ~AppState(); 25 | 26 | // Set up the initial application state 27 | HRESULT startup(); 28 | 29 | HRESULT findModelSource(); 30 | 31 | HRESULT saveModelSource(); 32 | 33 | uint32_t languageRead(); 34 | void languageWrite( uint32_t key ); 35 | /* Load and store named settings values, presumably under registryKey; confirm in AppState.cpp */ 36 | CString stringLoad( LPCTSTR name ); 37 | void stringStore( LPCTSTR name, LPCTSTR value ); 38 | uint32_t dwordLoad( LPCTSTR name, uint32_t fallback ); 39 | void dwordStore( LPCTSTR name, uint32_t value ); 40 | bool boolLoad( LPCTSTR name ); 41 | void boolStore( LPCTSTR name, bool val ); 42 | 43 | bool automaticallyLoadModel = true; 44 | 45 | void lastScreenSave( HRESULT code ); 46 | HRESULT lastScreenLoad(); 47 | 48 | void setupIcon( CWindow* wnd ); 49 | 50 | uint32_t gpuFlagsLoad(); 51 | void gpuFlagsStore( uint32_t flags ); 52 | }; 53 | /* Screen identifiers; presumably the codes passed to lastScreenSave() and returned by lastScreenLoad(), confirm in AppState.cpp */ 54 | constexpr HRESULT SCREEN_MODEL = 1; 55 | constexpr HRESULT SCREEN_TRANSCRIBE = 2; 56 | constexpr HRESULT SCREEN_CAPTURE = 3;
-------------------------------------------------------------------------------- /Examples/WhisperDesktop/CaptureDlg.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Examples/WhisperDesktop/CaptureDlg.cpp -------------------------------------------------------------------------------- /Examples/WhisperDesktop/CircleIndicator.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Utils/miscUtils.h" 3 | #include "Utils/WTL/atlcrack.h" 4 | 5 | // This control renders a black circle, and in the active state, the circle is filled with a bright color. 6 | class CircleIndicator: public CWindowImpl 7 | { 8 | public: 9 | static ATL::CWndClassInfo& GetWndClassInfo(); 10 | 11 | BEGIN_MSG_MAP( CircleIndicator ) 12 | MSG_WM_PAINT( onPaint ) 13 | MSG_WM_DESTROY( onDestroy ) 14 | END_MSG_MAP() 15 | 16 | // Class registration 17 | static HRESULT registerClass(); 18 | 19 | void setActive( bool nowActive ); 20 | 21 | void setActiveColor( uint32_t col ) 22 | { 23 | activeColor = col; 24 | } 25 | CircleIndicator(); 26 | 27 | private: 28 | bool isActive = false; 29 | uint32_t activeColor; /* NOTE(review): no default member initializer; presumably set by the constructor, confirm in CircleIndicator.cpp */ 30 | int fontHeight = 0; 31 | CFont font; 32 | HRESULT createFont( int height ); 33 | 34 | void onDestroy(); 35 | void onPaint( CDCHandle dc ); 36 | }; -------------------------------------------------------------------------------- /Examples/WhisperDesktop/ModelAdvancedDlg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "AppState.h" 3 | #include "Utils/WTL/atlddx.h" 4 | #include "Utils/miscUtils.h" 5 | /* Modal dialog IDD_MODEL_ADV with advanced model options; OK is handled by onOk(), Cancel closes the dialog with IDCANCEL */ 6 | class ModelAdvancedDlg : 7 | public CDialogImpl 8 | { 9 | CComboBox cbWave, cbReshapedMatMul, cbAdapter; 10 | AppState& appState; 11 | 12 | public: 13 | static constexpr UINT IDD = IDD_MODEL_ADV; 14 | 15 | ModelAdvancedDlg( AppState& app ) :
appState( app ) { } 16 | 17 | BEGIN_MSG_MAP( ModelAdvancedDlg ) 18 | MESSAGE_HANDLER( WM_INITDIALOG, onInitDialog ) 19 | ON_BUTTON_CLICK( IDOK, onOk ) 20 | ON_BUTTON_CLICK( IDCANCEL, onCancel ) 21 | END_MSG_MAP() 22 | 23 | bool show( HWND owner ); 24 | 25 | private: 26 | 27 | LRESULT onInitDialog( UINT nMessage, WPARAM wParam, LPARAM lParam, BOOL& bHandled ); 28 | 29 | void onOk(); 30 | 31 | void onCancel() 32 | { 33 | EndDialog( IDCANCEL ); 34 | } 35 | }; -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Readme.txt: -------------------------------------------------------------------------------- 1 | This example shows how to consume the DLL from a C++ GUI application. 2 | 3 | The GUI is implemented with ATL and WTL libraries. -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/DebugConsole.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | /* Console window which displays log messages; shown and hidden with show() / hide() */ 6 | class AppState; 7 | class DebugConsole 8 | { 9 | using eLogLevel = Whisper::eLogLevel; 10 | 11 | struct Entry 12 | { 13 | eLogLevel level; 14 | CStringA message; 15 | HRESULT print( HANDLE hConsole, CString& tempString ) const; 16 | }; 17 | 18 | CComAutoCriticalSection critSec; 19 | std::deque buffer; 20 | CString tempString; 21 | CHandle output; 22 | 23 | inline void logSink( eLogLevel lvl, const char* message ); 24 | static void __stdcall logSinkStatic( void* context, eLogLevel lvl, const char* message ); 25 | 26 | static BOOL __stdcall consoleHandlerRoutine( DWORD dwCtrlType ); 27 | 28 | static DebugConsole* pGlobalInstance; 29 | void windowClosed(); 30 | 31 | std::unordered_set checkboxes; 32 | 33 | CStringA tempStringA; 34 | void log( eLogLevel lvl, const char* pszFormat, va_list args ); 35 | 36 | public: 37 | HRESULT initialize( eLogLevel level = eLogLevel::Debug ); 38 | ~DebugConsole(); 39 | 40 | HRESULT show();
41 | HRESULT hide(); 42 | bool isVisible() const { return output; } 43 | 44 | void addCheckbox( CButton& cb ); 45 | void removeCheckbox( CButton& cb ); 46 | 47 | static void logMessage( eLogLevel lvl, const char* pszFormat, va_list args ); 48 | }; 49 | /* Checkbox linked to the debug console; when console is set, the destructor unregisters the control from it */ 50 | class ConsoleCheckbox 51 | { 52 | CButton control; 53 | DebugConsole* console = nullptr; 54 | 55 | public: 56 | HRESULT initialize( HWND dialog, int idc, AppState& state ); 57 | void click(); 58 | ~ConsoleCheckbox() 59 | { 60 | if( nullptr != console ) 61 | console->removeCheckbox( control ); 62 | } 63 | void ensureChecked(); 64 | }; -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/LanguageDropdown.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../AppState.h" 3 | 4 | // Dropdown list which implements language selector control 5 | class LanguageDropdown 6 | { 7 | HWND m_hWnd = nullptr; 8 | std::vector keys; 9 | int getInitialSelection( AppState& state ) const; 10 | 11 | public: 12 | operator HWND() const 13 | { 14 | return m_hWnd; 15 | } 16 | 17 | // Query language list from the native library, populate the combo box 18 | // Then load the last saved language selection from registry, and preselect an item.
19 | void initialize( HWND owner, int idc, AppState& state ); 20 | 21 | // Get the ID of the currently selected language, or UINT_MAX if nothing's selected 22 | uint32_t selectedLanguage(); 23 | 24 | // Get the ID of the currently selected language, and store in registry 25 | void saveSelection( AppState& state ); 26 | }; -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/PendingState.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "PendingState.h" 3 | /* Remember the controls to toggle; wasEnabled gets one slot per editor window */ 4 | void PendingState::initialize( std::initializer_list editors, std::initializer_list pending ) 5 | { 6 | editorsWindows = editors; 7 | wasEnabled.resize( editorsWindows.size() ); 8 | pendingWindows = pending; 9 | } 10 | /* Editors are captured and disabled only on the idle -> pending transition, and re-enabled only on the pending -> idle one, so repeated calls with the same value cannot overwrite wasEnabled with the disabled state this method set itself. Pending windows are shown or hidden on every call. */ 11 | void PendingState::setPending( bool nowPending ) 12 | { 13 | if( nowPending && !isPending ) 14 | { 15 | for( size_t i = 0; i < editorsWindows.size(); i++ ) 16 | { 17 | BOOL e = IsWindowEnabled( editorsWindows[ i ] ); 18 | if( e ) 19 | { 20 | wasEnabled[ i ] = true; 21 | EnableWindow( editorsWindows[ i ], FALSE ); 22 | } 23 | else 24 | wasEnabled[ i ] = false; 25 | } 26 | } 27 | else if( !nowPending && isPending ) 28 | { 29 | for( size_t i = 0; i < editorsWindows.size(); i++ ) 30 | { 31 | if( !wasEnabled[ i ] ) 32 | continue; 33 | EnableWindow( editorsWindows[ i ], TRUE ); 34 | } 35 | } 36 | isPending = nowPending; 37 | const int show = nowPending ?
SW_NORMAL : SW_HIDE; 38 | for( HWND w : pendingWindows ) 39 | ::ShowWindow( w, show ); 40 | } -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/PendingState.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Utility class to switch visual state of dialog controls between idle and pending 4 | class PendingState 5 | { 6 | std::vector editorsWindows; 7 | std::vector wasEnabled; 8 | std::vector pendingWindows; 9 | bool isPending = false; /* true from setPending( true ) until the next setPending( false ) */ 10 | public: 11 | void initialize( std::initializer_list editors, std::initializer_list pending ); 12 | void setPending( bool nowPending ); 13 | }; -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/TranslateCheckbox.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "TranslateCheckbox.h" 3 | /* Settings value name holding the checkbox state, read with boolLoad() and written with boolStore() */ 4 | static const LPCTSTR regValTranslate = L"translate"; 5 | 6 | void TranslateCheckbox::initialize( HWND owner, int idc, AppState& state ) 7 | { 8 | m_hWnd = GetDlgItem( owner, idc ); 9 | assert( nullptr != m_hWnd ); 10 | 11 | if( state.boolLoad( regValTranslate ) ) 12 | ::SendMessage( m_hWnd, BM_SETCHECK, BST_CHECKED, 0L ); 13 | } 14 | 15 | bool TranslateCheckbox::checked() 16 | { 17 | assert( nullptr != m_hWnd ); 18 | const int state = ( int )::SendMessage( m_hWnd, BM_GETCHECK, 0, 0 ); 19 | return state == BST_CHECKED; 20 | } 21 | 22 | void TranslateCheckbox::saveSelection( AppState& state ) 23 | { 24 | state.boolStore( regValTranslate, checked() ); 25 | } -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/TranslateCheckbox.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../AppState.h" 3 | 4 | class TranslateCheckbox 5 | { 6 | HWND m_hWnd = nullptr; 7 | public: 8 | operator HWND() const 9 | { 10
| return m_hWnd; 11 | } 12 | 13 | void initialize( HWND owner, int idc, AppState& state ); 14 | 15 | bool checked(); 16 | 17 | void saveSelection( AppState& state ); 18 | }; -------------------------------------------------------------------------------- /Examples/WhisperDesktop/Utils/logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | void logMessage( Whisper::eLogLevel lvl, const char8_t* pszFormat, va_list args ); 6 | 7 | #define LOG_MESSAGE_IMPL( lvl ) \ 8 | std::va_list args; \ 9 | va_start( args, pszFormat ); \ 10 | logMessage( lvl, pszFormat, args ); \ 11 | va_end( args ) 12 | 13 | inline void logError( const char8_t* pszFormat, ... ) 14 | { 15 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Error ); 16 | } 17 | inline void logWarning( const char8_t* pszFormat, ... ) 18 | { 19 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Warning ); 20 | } 21 | inline void logInfo( const char8_t* pszFormat, ... ) 22 | { 23 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Info ); 24 | } 25 | inline void logDebug( const char8_t* pszFormat, ... 
) 26 | { 27 | LOG_MESSAGE_IMPL( Whisper::eLogLevel::Debug ); 28 | } 29 | 30 | #undef LOG_MESSAGE_IMPL 31 | 32 | HRESULT logNewSegments( const Whisper::iTranscribeResult* results, size_t newSegments, bool printSpecial = false ); 33 | 34 | void clearLastError(); 35 | bool getLastError( CString& rdi ); 36 | 37 | void printTime( CStringA& rdi, Whisper::sTimeSpan time, bool comma = false ); -------------------------------------------------------------------------------- /Examples/WhisperDesktop/WhisperDesktop.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "AppState.h" 3 | #include "Utils/miscUtils.h" 4 | #include "LoadModelDlg.h" 5 | #include "TranscribeDlg.h" 6 | #include "CaptureDlg.h" 7 | 8 | static HRESULT dialogLoadModel( AppState& appState ) 9 | { 10 | LoadModelDlg loadDialog{ appState }; 11 | HRESULT hr = loadDialog.show(); 12 | if( FAILED( hr ) ) 13 | { 14 | reportFatalError( "Error loading the model", hr ); 15 | return hr; 16 | } 17 | appState.automaticallyLoadModel = false; 18 | return hr; 19 | } 20 | 21 | static HRESULT dialogTranscribe( AppState& appState ) 22 | { 23 | TranscribeDlg dialog{ appState }; 24 | return dialog.show(); 25 | } 26 | 27 | static HRESULT dialogCapture( AppState& appState ) 28 | { 29 | CaptureDlg dialog{ appState }; 30 | return dialog.show(); 31 | } 32 | 33 | using pfnDialog = HRESULT( * )( AppState& appState ); 34 | static const std::array s_dialogs = 35 | { 36 | nullptr, // S_OK 37 | &dialogLoadModel, // SCREEN_MODEL 38 | &dialogTranscribe, // SCREEN_TRANSCRIBE 39 | &dialogCapture, // SCREEN_CAPTURE 40 | }; 41 | 42 | int __stdcall wWinMain( HINSTANCE hInstance, HINSTANCE hPrevInstance, LPWSTR lpCmdLine, int nCmdShow ) 43 | { 44 | AppState appState; 45 | HRESULT hr = appState.startup(); 46 | if( FAILED( hr ) ) 47 | return hr; 48 | 49 | appState.findModelSource(); 50 | 51 | hr = SCREEN_MODEL; 52 | while( true ) 53 | { 54 | pfnDialog pfn = s_dialogs[ hr ]; 55 | 
if( nullptr == pfn ) 56 | return S_OK; 57 | hr = pfn( appState ); 58 | if( FAILED( hr ) ) 59 | return hr; 60 | if( hr == SCREEN_MODEL ) 61 | appState.model = nullptr; 62 | } 63 | } -------------------------------------------------------------------------------- /Examples/WhisperDesktop/WhisperDesktop.manifest: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Your application description here. 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | true 13 | PerMonitorV2 14 | 15 | 16 | -------------------------------------------------------------------------------- /Examples/WhisperDesktop/WhisperDesktop.rc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Examples/WhisperDesktop/WhisperDesktop.rc -------------------------------------------------------------------------------- /Examples/WhisperDesktop/framework.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers 3 | #define NOMINMAX 4 | // Windows Header Files 5 | #include "targetver.h" 6 | #include 7 | // ATL header files 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | // C RunTime Header Files 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | // C++ headers 20 | #include 21 | #include 22 | #include -------------------------------------------------------------------------------- /Examples/WhisperDesktop/stdafx.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" -------------------------------------------------------------------------------- /Examples/WhisperDesktop/stdafx.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "framework.h" 3 | 4 | #include 5 | 6 | #include 
"resource.h" 7 | #include "Utils/WTL/atlapp.h" 8 | #include "Utils/WTL/atlctrls.h" -------------------------------------------------------------------------------- /Examples/WhisperDesktop/sunflower.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Examples/WhisperDesktop/sunflower.ico -------------------------------------------------------------------------------- /Examples/WhisperDesktop/targetver.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | // Setup Windows SDK to only enable features available since Windows 8.0 3 | #include 4 | #define _WIN32_WINNT _WIN32_WINNT_WIN8 5 | #define NTDDI_VERSION NTDDI_WIN8 6 | #include 7 | -------------------------------------------------------------------------------- /Examples/WhisperDesktop/useDiscreteGpu.c: -------------------------------------------------------------------------------- 1 | __declspec( dllexport ) int NvOptimusEnablement = 1; 2 | __declspec( dllexport ) int AmdPowerXpressRequestHighPerformance = 1; -------------------------------------------------------------------------------- /Examples/main/Readme.txt: -------------------------------------------------------------------------------- 1 | This example shows how to consume the DLL from a C++ console application. 2 | 3 | The command-line interface matches the corresponding example from whisper.cpp project. 
-------------------------------------------------------------------------------- /Examples/main/main.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /Examples/main/miscUtils.cpp: -------------------------------------------------------------------------------- 1 | #include "miscUtils.h" 2 | #define WIN32_LEAN_AND_MEAN 3 | #include 4 | 5 | std::string utf8( const std::wstring& utf16 ) 6 | { 7 | int count = WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), (int)utf16.length(), nullptr, 0, nullptr, nullptr ); 8 | std::string str( count, 0 ); 9 | WideCharToMultiByte( CP_UTF8, 0, utf16.c_str(), -1, &str[ 0 ], count, nullptr, nullptr ); 10 | return str; 11 | } 12 | 13 | std::wstring utf16( const std::string& u8 ) 14 | { 15 | int count = MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), nullptr, 0 ); 16 | std::wstring str( count, 0 ); 17 | MultiByteToWideChar( CP_UTF8, 0, u8.c_str(), (int)u8.length(), &str[ 0 ], count ); 18 | return str; 19 | } 20 | 21 | namespace 22 | { 23 | wchar_t* formatMessage( HRESULT hr ) 24 | { 25 | wchar_t* err; 26 | if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, 27 | NULL, 28 | hr, 29 | MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), 30 | (LPTSTR)&err, 31 | 0, 32 | NULL ) ) 33 | return err; 34 | return nullptr; 35 | } 36 | } 37 | 38 | void printError( const char* what, HRESULT hr ) 39 | { 40 | const wchar_t* err = formatMessage( hr ); 41 | if( nullptr != err ) 42 | { 43 | fwprintf( stderr, L"%S: %s\n", what, err ); 44 | LocalFree( (HLOCAL)err ); 45 | } 46 | else 47 | fprintf( stderr, "%s: error code %i (0x%08X)\n", what, hr, hr ); 48 | } -------------------------------------------------------------------------------- /Examples/main/miscUtils.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | std::string utf8( const std::wstring& utf16 ); 5 | 6 | std::wstring utf16( const std::string& u8 ); 7 | 8 | using HRESULT = long; 9 | void printError( const char* what, HRESULT hr ); -------------------------------------------------------------------------------- /Examples/main/params.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | // command-line parameters 6 | struct whisper_params 7 | { 8 | uint32_t n_threads; 9 | uint32_t n_processors = 1; 10 | uint32_t offset_t_ms = 0; 11 | uint32_t offset_n = 0; 12 | uint32_t duration_ms = 0; 13 | uint32_t max_context = UINT_MAX; 14 | uint32_t max_len = 0; 15 | 16 | float word_thold = 0.01f; 17 | 18 | bool speed_up = false; 19 | bool translate = false; 20 | bool diarize = false; 21 | bool output_txt = false; 22 | bool output_vtt = false; 23 | bool output_srt = false; 24 | bool output_wts = false; 25 | bool print_special = false; 26 | bool print_colors = true; 27 | bool no_timestamps = false; 28 | 29 | std::string language = "en"; 30 | std::wstring model = L"models/ggml-base.en.bin"; 31 | std::wstring gpu; 32 | std::string prompt; 33 | std::vector fname_inp; 34 | 35 | whisper_params(); 36 | 37 | bool parse( int argc, wchar_t* argv[] ); 38 | }; 39 | 40 | void whisper_print_usage( int argc, wchar_t** argv, const whisper_params& params ); -------------------------------------------------------------------------------- /Examples/main/textWriter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../../Whisper/API/iContext.cl.h" 3 | 4 | // These functions print output segments into text files of various formats 5 | HRESULT writeText( Whisper::iContext* context, LPCTSTR audioPath, bool timestamps ); 6 | HRESULT writeSubRip( Whisper::iContext* context, LPCTSTR audioPath 
); 7 | HRESULT writeWebVTT( Whisper::iContext* context, LPCTSTR audioPath ); -------------------------------------------------------------------------------- /SampleClips/columbia.wma: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/SampleClips/columbia.wma -------------------------------------------------------------------------------- /SampleClips/jfk.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/SampleClips/jfk.wav -------------------------------------------------------------------------------- /SampleClips/summary.tsv: -------------------------------------------------------------------------------- 1 | Audio Clip Model GPU Total, sec Relative speed Encode, sec Decode, sec RAM, MB VRAM, MB 2 | jfk.wav medium GeForce 1080Ti 1.13909 9.6568 0.599964 0.451832 2.84049 2185.90208 3 | jfk.wav medium GeForce 1650 3.16174 3.4791 1.95373 0.987441 2.84049 2185.90208 4 | jfk.wav medium Ryzen 5 5600U 8.79853 1.2502 7.34021 1.30925 2.84049 2233.35424 5 | jfk.wav medium Ryzen 7 5700G 4.95485 2.2200 3.82894 1.03424 2.84049 2233.35424 6 | jfk.wav large GeForce 1080Ti 2.19628 5.0085 1.29615 0.7739170000000001 2.85543 4052.89984 7 | jfk.wav large GeForce 1650 8.33686 1.3194 3.98133 4.07331 2.85543 4052.89984 8 | jfk.wav large Ryzen 5 5600U 14.2729 0.7707 11.9404 2.12941 2.85543 4112.35328 9 | jfk.wav large Ryzen 7 5700G 9.46787 1.1618 7.35144 2.01739 2.85543 4112.35328 10 | columbia.wma medium GeForce 1080Ti 14.9475 13.2973 6.01034 8.78676 91.929 2247.34208 11 | columbia.wma medium GeForce 1650 48.7479 4.0773 19.2258 29.233 91.929 2247.34208 12 | columbia.wma medium Ryzen 5 5600U 81.256 2.4461 51.1656 29.6502 91.929 2295.52128 13 | columbia.wma medium Ryzen 7 5700G 62.1145 3.1999 37.8702 23.9262 91.929 2295.52128 14 | 
columbia.wma large GeForce 1080Ti 27.5329 7.2191 11.3967 15.9412 93.1329 4118.28224 15 | columbia.wma large GeForce 1650 109.423 1.8165 36.3141 72.8459 93.1329 4118.28224 16 | columbia.wma large Ryzen 5 5600U 140.747 1.4122 88.7441 51.4306 93.1329 4178.78016 17 | columbia.wma large Ryzen 7 5700G 110.474 1.7992 65.7998 44.3232 93.1329 4178.78016 18 | -------------------------------------------------------------------------------- /Tools/CompressShaders/CompressShaders.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | Exe 4 | net6.0 5 | enable 6 | enable 7 | true 8 | false 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /Tools/CompressShaders/DetectFp64.cs: -------------------------------------------------------------------------------- 1 | #pragma warning disable CS0649 2 | using System.Runtime.InteropServices; 3 | 4 | namespace CompressShaders 5 | { 6 | static class DetectFp64 7 | { 8 | struct DXBCHeader 9 | { 10 | public uint FourCC; // Four character code "DXBC" 11 | public uint Hash0; // 32-bit hash of the DXBC file 12 | public uint Hash1; // 32-bit hash of the DXBC file 13 | public uint Hash2; // 32-bit hash of the DXBC file 14 | public uint Hash3; // 32-bit hash of the DXBC file 15 | public uint unknownOne; 16 | public uint TotalSize; // Total size of the DXBC file in bytes 17 | public int NumChunks; // Number of chunks in the DXBC file 18 | }; 19 | 20 | public static bool usesFp64( ReadOnlySpan dxbc ) 21 | { 22 | ReadOnlySpan dxbcHeaderSpan = MemoryMarshal.Cast( dxbc ); 23 | DXBCHeader dxbcHeader = dxbcHeaderSpan[ 0 ]; 24 | 25 | int cbHeader = Marshal.SizeOf(); 26 | int nChunks = dxbcHeader.NumChunks; 27 | ReadOnlySpan chunkOffsets = MemoryMarshal.Cast( dxbc.Slice( cbHeader, nChunks * 4 ) ); 28 | foreach( int off in chunkOffsets ) 29 | { 30 | uint id = MemoryMarshal.Cast( dxbc.Slice( off, 4 ) )[ 0 ]; 31 | const uint SFI0 = 0x30494653; 32 | if( id != 
SFI0 ) 33 | continue; 34 | int size = MemoryMarshal.Cast( dxbc.Slice( off + 4, 4 ) )[ 0 ]; 35 | if( size < 4 ) 36 | throw new ApplicationException(); 37 | uint data = MemoryMarshal.Cast( dxbc.Slice( off + 8, 4 ) )[ 0 ]; 38 | return 0 != ( data & 1u ); 39 | } 40 | return false; 41 | } 42 | } 43 | } -------------------------------------------------------------------------------- /Tools/CompressShaders/LZ4.cs: -------------------------------------------------------------------------------- 1 | namespace CompressShaders; 2 | using K4os.Compression.LZ4; 3 | 4 | /// Lossless data compressor which uses LZ4-HC compressor 5 | /// 6 | /// 7 | static class LZ4 8 | { 9 | // compression speed drops rapidly when not using FAST mode, while decompression speed stays the same 10 | // Actually, it is usually faster for high compression levels as there is less data to process 11 | // https://github.com/MiloszKrajewski/K4os.Compression.LZ4#compression-levels 12 | const LZ4Level compressionLevel = LZ4Level.L12_MAX; 13 | 14 | public static byte[] compressBuffer( byte[] src ) 15 | { 16 | int maxLength = LZ4Codec.MaximumOutputSize( src.Length ); 17 | byte[] output = new byte[ maxLength ]; 18 | int cb = LZ4Codec.Encode( src, output, compressionLevel ); 19 | if( cb > 0 ) 20 | { 21 | Array.Resize( ref output, cb ); 22 | return output; 23 | } 24 | throw new ApplicationException( $"LZ4Codec.Encode failed with status {cb}" ); 25 | } 26 | } -------------------------------------------------------------------------------- /Tools/CompressShaders/Readme.txt: -------------------------------------------------------------------------------- 1 | This project builds a C# console app which serves as a code generator for a few pieces of Whisper.dll and WhisperNet.dll. 2 | 3 | Specifically, it generates two things. 4 | 5 | 1. It compresses the compiled DXBC shaders into a blob of bytes, and prints std::array with these bytes into shaderData-Release.inl and shaderData-Debug.inl C++ files. 6 | 7 | 2. 
It parses the `languageCodez.tsv`, and generates both C++ and C# code with the data from that table. 8 | 9 | The tool uses relative paths across source files. 10 | These paths will break if you move the source of the tool, or the source data of the tool. -------------------------------------------------------------------------------- /Tools/CompressShaders/ShaderNames.cs: -------------------------------------------------------------------------------- 1 | static class ShaderNames 2 | { 3 | public static void write( string path, IEnumerable names ) 4 | { 5 | string[] arr = names.ToArray(); 6 | using var stream = File.CreateText( path ); 7 | stream.WriteLine( @"// This source file is generated by a tool 8 | #include ""stdafx.h"" 9 | #include ""shaderNames.h"" 10 | " ); 11 | 12 | stream.WriteLine( "static const std::array s_shaderNames = ", arr.Length ); 13 | stream.WriteLine( "{" ); 14 | foreach( string name in arr ) 15 | stream.WriteLine( @" ""{0}"",", name ); 16 | 17 | stream.Write( @"}; 18 | 19 | const char* DirectCompute::computeShaderName( eComputeShader cs ) 20 | { 21 | const uint16_t i = (uint16_t)cs; 22 | if( i < s_shaderNames.size() ) 23 | return s_shaderNames[ i ]; 24 | return nullptr; 25 | }" ); 26 | } 27 | } -------------------------------------------------------------------------------- /Tools/CompressTables/CompressTables.cs: -------------------------------------------------------------------------------- 1 | namespace CompressTables; 2 | using CompressShaders; 3 | using System.IO; 4 | using System.Runtime.CompilerServices; 5 | 6 | /// Utility app to compress lookup tables data with LZ4-HC, and generate C++ source with the data 7 | internal class Program 8 | { 9 | static string getSolutionRoot( [CallerFilePath] string? path = null ) 10 | { 11 | string? dir = Path.GetDirectoryName( path ); 12 | dir = Path.GetDirectoryName( dir ); 13 | dir = Path.GetDirectoryName( dir ); 14 | return dir ?? 
throw new ApplicationException(); 15 | } 16 | 17 | static void writeArray( byte[] compressed, string path ) 18 | { 19 | using var stream = File.CreateText( path ); 20 | stream.WriteLine( "// This source file is generated by a tool" ); 21 | stream.Write( @"static const std::array s_tableData = {{", compressed.Length ); 22 | 23 | for( int i = 0; i < compressed.Length; i++ ) 24 | { 25 | if( 0 == i % 32 ) 26 | stream.Write( "\r\n\t" ); 27 | else 28 | stream.Write( ' ' ); 29 | stream.Write( "0x{0:X02},", compressed[ i ] ); 30 | } 31 | stream.Write( @" 32 | };" ); 33 | } 34 | 35 | static void Main( string[] args ) 36 | { 37 | byte[] source = File.ReadAllBytes( @"C:\Temp\2remove\Whisper\tables.bin" ); 38 | byte[] result = LZ4.compressBuffer( source ); 39 | 40 | string root = getSolutionRoot(); 41 | string path = Path.Combine( root, "Whisper", "ML", "LookupTablesData.inl" ); 42 | writeArray( result, path ); 43 | } 44 | } -------------------------------------------------------------------------------- /Tools/CompressTables/CompressTables.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | Exe 4 | net7.0 5 | enable 6 | enable 7 | true 8 | false 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /Tools/PerfSummary/PerfSummary.cs: -------------------------------------------------------------------------------- 1 | using System.Runtime.CompilerServices; 2 | 3 | namespace PerfSummary 4 | { 5 | internal class Program 6 | { 7 | static string getSolutionRoot( [CallerFilePath] string? path = null ) 8 | { 9 | string? dir = Path.GetDirectoryName( path ); 10 | dir = Path.GetDirectoryName( dir ); 11 | dir = Path.GetDirectoryName( dir ); 12 | return dir ?? 
throw new ApplicationException(); 13 | } 14 | 15 | static void Main( string[] args ) 16 | { 17 | string root = getSolutionRoot(); 18 | root = Path.Combine( root, "SampleClips" ); 19 | 20 | LogData[] logs = LogParser.parse( root ) 21 | .OrderBy( x => x.name.clip ) 22 | .ThenBy( x => x.name.model ) 23 | .ThenBy( x => x.name.gpu ) 24 | .ToArray(); 25 | 26 | Summary.print( logs, root ); 27 | } 28 | } 29 | } -------------------------------------------------------------------------------- /Tools/PerfSummary/PerfSummary.csproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | Exe 4 | net6.0 5 | enable 6 | enable 7 | true 8 | false 9 | 10 | -------------------------------------------------------------------------------- /Tools/compareTraces/CommandLineArgs.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "CommandLineArgs.h" 3 | #include 4 | 5 | static bool printUsage() 6 | { 7 | fprintf( stderr, "Usage: compareTraces.exe trace1.bin trace2.bin [-diff N]\n" ); 8 | return false; 9 | } 10 | 11 | bool CommandLineArgs::parse( int argc, wchar_t* argv[] ) 12 | { 13 | size_t idx = 0; 14 | CString sw; 15 | CStringA tmp; 16 | for( int i = 1; i < argc; i++ ) 17 | { 18 | if( argv[ i ][ 0 ] != L'-' ) 19 | { 20 | if( idx >= 2 ) 21 | return printUsage(); 22 | inputs[ idx ] = argv[ i ]; 23 | idx++; 24 | continue; 25 | } 26 | sw = argv[ i ]; 27 | if( 0 == sw.CompareNoCase( L"-diff" ) ) 28 | { 29 | i++; 30 | if( i >= argc ) 31 | return printUsage(); 32 | tmp.Format( "%S", argv[ i ] ); 33 | tmp.Trim(); 34 | uint64_t v; 35 | auto res = std::from_chars( tmp, cstr( tmp ) + tmp.GetLength(), v ); 36 | if( res.ec != (std::errc)0 ) 37 | { 38 | fprintf( stderr, "Unable to parse string into number\n" ); 39 | return false; 40 | } 41 | printDiff = v; 42 | continue; 43 | } 44 | return printUsage(); 45 | } 46 | 47 | if( idx != 2 ) 48 | return printUsage(); 49 | 50 | return true; 51 | } 
-------------------------------------------------------------------------------- /Tools/compareTraces/CommandLineArgs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | struct CommandLineArgs 4 | { 5 | int64_t printDiff = -1; 6 | std::array inputs; 7 | 8 | bool parse( int argc, wchar_t* argv[] ); 9 | }; -------------------------------------------------------------------------------- /Tools/compareTraces/Readme.txt: -------------------------------------------------------------------------------- 1 | This project builds a C++ console tool which compares debug traces of the model. 2 | 3 | Tracing files easily exceed 1GB, and by default they’re disabled with a preprocessor macro in stdafx.h of the Whisper project. 4 | 5 | When enabled, the main GPU implementation saves a trace into C:\Temp\2remove\Whisper\gpu.bin 6 | 7 | The reference CPU implementation saves a trace into C:\Temp\2remove\Whisper\ref.bin 8 | 9 | This code in this project is optimized for development speed. For this reason it requires AVX2 CPU, uses memory-mapped IO instead of proper parsing, and checks little to no errors. 
-------------------------------------------------------------------------------- /Tools/compareTraces/TraceReader.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "TraceReader.h" 3 | using namespace Tracing; 4 | 5 | const sTraceItem& TraceReader::operator[]( size_t idx ) const 6 | { 7 | if( idx >= countItems ) 8 | throw E_BOUNDS; 9 | return items[ idx ]; 10 | } 11 | 12 | CStringA TraceReader::getName( const sTraceItem& item ) const 13 | { 14 | const size_t idx = item.stringIndex; 15 | if( idx >= countStrings ) 16 | throw E_BOUNDS; 17 | const char* const source = stringData + stringIndex[ idx ]; 18 | CStringA res; 19 | res.Format( source, item.formatArgs[ 0 ], item.formatArgs[ 1 ], item.formatArgs[ 2 ], item.formatArgs[ 3 ] ); 20 | return res; 21 | } 22 | 23 | HRESULT TraceReader::open( LPCTSTR path ) 24 | { 25 | CHECK( file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING ) ); 26 | CHECK( mapping.MapFile( file ) ); 27 | 28 | const uint8_t* rsi = mapping; 29 | const sFileHeader& header = *(const sFileHeader*)rsi; 30 | if( header.magic != header.correctMagic ) 31 | return E_INVALIDARG; 32 | countItems = header.countItems; 33 | countStrings = header.countStrings; 34 | 35 | rsi += sizeof( sFileHeader ); 36 | payloadPointer = rsi; 37 | 38 | rsi += header.bytesPayload; 39 | stringIndex = (const uint32_t*)( rsi ); 40 | stringData = (const char*)( rsi + countStrings * 4 ); 41 | 42 | rsi += header.bytesStrings; 43 | items = (const sTraceItem*)rsi; 44 | 45 | return S_OK; 46 | } -------------------------------------------------------------------------------- /Tools/compareTraces/TraceReader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../../Whisper/Utils/Trace/TraceStructures.h" 3 | #include 4 | #include 5 | 6 | namespace Tracing 7 | { 8 | class TraceReader 9 | { 10 | const uint8_t* payloadPointer = nullptr; 11 | const 
sTraceItem* items = nullptr; 12 | size_t countItems = 0; 13 | size_t countStrings = 0; 14 | const uint32_t* stringIndex = nullptr; 15 | const char* stringData = nullptr; 16 | 17 | CAtlFile file; 18 | CAtlFileMapping mapping; 19 | 20 | public: 21 | 22 | TraceReader() = default; 23 | ~TraceReader() = default; 24 | 25 | HRESULT open( LPCTSTR path ); 26 | size_t size() const { return countItems; } 27 | const sTraceItem& operator[]( size_t idx ) const; 28 | CStringA getName( const sTraceItem& item ) const; 29 | 30 | const void* payload( const sTraceItem& item ) const 31 | { 32 | return payloadPointer + item.payloadOffset; 33 | } 34 | }; 35 | } -------------------------------------------------------------------------------- /Tools/compareTraces/compare.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "CommandLineArgs.h" 3 | 4 | HRESULT compareTraces( const CommandLineArgs& arguments ); -------------------------------------------------------------------------------- /Tools/compareTraces/compareTraces.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include 3 | #include "compare.h" 4 | #include "CommandLineArgs.h" 5 | 6 | int wmain( int argc, wchar_t* argv[] ) 7 | { 8 | CommandLineArgs cla; 9 | if( !cla.parse( argc, argv ) ) 10 | return 1; 11 | 12 | HRESULT hr = compareTraces( cla ); 13 | if( SUCCEEDED( hr ) ) 14 | return 0; 15 | return hr; 16 | } -------------------------------------------------------------------------------- /Tools/compareTraces/compareTraces.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /Tools/compareTraces/stdafx.cpp: -------------------------------------------------------------------------------- 1 | 
#include "stdafx.h" 2 | 3 | namespace 4 | { 5 | wchar_t* formatMessage( HRESULT hr ) 6 | { 7 | wchar_t* err; 8 | if( FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM, 9 | NULL, 10 | hr, 11 | MAKELANGID( LANG_NEUTRAL, SUBLANG_DEFAULT ), 12 | (LPTSTR)&err, 13 | 0, 14 | NULL ) ) 15 | return err; 16 | return nullptr; 17 | } 18 | } 19 | 20 | void printError( HRESULT hr ) 21 | { 22 | const wchar_t* err = formatMessage( hr ); 23 | if( nullptr != err ) 24 | { 25 | fwprintf( stderr, L"%s\n", err ); 26 | LocalFree( (HLOCAL)err ); 27 | } 28 | else 29 | fprintf( stderr, "Error code %i (0x%08X)\n", hr, hr ); 30 | } -------------------------------------------------------------------------------- /Tools/compareTraces/stdafx.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define WIN32_LEAN_AND_MEAN 6 | #define NOMINMAX 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } 17 | 18 | inline __m128i load16( const int* rsi ) 19 | { 20 | return _mm_loadu_si128( ( const __m128i* )rsi ); 21 | } 22 | inline __m128i load16( const uint32_t* rsi ) 23 | { 24 | return _mm_loadu_si128( ( const __m128i* )rsi ); 25 | } 26 | inline __m128i load( const std::array& arr ) 27 | { 28 | return load16( arr.data() ); 29 | } 30 | 31 | inline bool vectorEqual( __m128i a, __m128i b ) 32 | { 33 | __m128i xx = _mm_xor_si128( a, b ); 34 | return (bool)_mm_testz_si128( xx, xx ); 35 | } 36 | 37 | void printError( HRESULT hr ); 38 | 39 | inline const char* cstr( const CStringA& s ) { return s; } 40 | inline const wchar_t* cstr( const CString& s ) { return s; } -------------------------------------------------------------------------------- /Whisper/API/MfStructs.h: -------------------------------------------------------------------------------- 1 | #pragma once 
2 | 3 | namespace Whisper 4 | { 5 | struct sCaptureDevice 6 | { 7 | // The display name is suitable for showing to the user, but might not be unique. 8 | const wchar_t* displayName; 9 | 10 | // Endpoint ID for an audio capture device 11 | // It uniquely identifies the device on the system, but is not a readable string. 12 | const wchar_t* endpoint; 13 | }; 14 | 15 | using pfnFoundCaptureDevices = HRESULT( __stdcall* )( int len, const sCaptureDevice* buffer, void* pv ); // Receives an array of len capture devices 16 | 17 | // Flags for the audio capture 18 | enum struct eCaptureFlags : uint32_t 19 | { 20 | // When the capture device supports stereo, keep stereo PCM samples in addition to mono 21 | Stereo = 1, 22 | }; 23 | 24 | // Parameters for audio capture 25 | struct sCaptureParams 26 | { 27 | float minDuration = 2.0f; 28 | float maxDuration = 3.0f; 29 | float dropStartSilence = 0.25f; 30 | float pauseDuration = 0.333f; 31 | // Flags for the audio capture, bits of the eCaptureFlags enum 32 | uint32_t flags = 0; 33 | }; 34 | 35 | enum struct eCaptureStatus : uint8_t // Status of the capture session; the values are distinct bits, presumably combinable — confirm 36 | { 37 | Listening = 1, 38 | Voice = 2, 39 | Transcribing = 4, 40 | Stalled = 0x80, 41 | }; 42 | 43 | // Return S_OK to continue, or S_FALSE to stop the capture session 44 | using pfnShouldCancel = HRESULT( __stdcall* )( void* pv ) noexcept; 45 | 46 | using pfnCaptureStatus = HRESULT( __stdcall* )( void* pv, eCaptureStatus status ) noexcept; // Receives status updates of the capture session 47 | 48 | struct sCaptureCallbacks 49 | { 50 | pfnShouldCancel shouldCancel; 51 | pfnCaptureStatus captureStatus; 52 | void* pv; 53 | }; 54 | } -------------------------------------------------------------------------------- /Whisper/API/Readme.txt: -------------------------------------------------------------------------------- 1 | The headers in this folder define the complete public API of Whisper.dll. 2 | 3 | To consume the library in your C++ software, include exactly one of the following headers. 4 | 5 | 1.
If you’re building a Windows app, include the whisperWindows.h header, and you'll get the traditional Win32 COM projection of the API. 6 | 7 | 2. If you’re porting to another OS or to a different C++ compiler, or are already using the ComLight support library, include the whisperComLight.h header. 8 | If you do that, in addition to this "Whisper/API" folder you will also need the "ComLightLib" dependency. 9 | This will get you the ComLight flavor of these COM interfaces. 10 | 11 | Internally, the actual implementation uses the ComLight flavor of the interfaces, but that’s fine because they are binary compatible. 12 | 13 | The reason for the difference between these flavors: Visual Studio’s CComPtr and other related utilities expect interface IDs specified with the __declspec(uuid) directive. 14 | 15 | That language extension is specific to Visual C++, and is supported by neither GCC nor Clang. -------------------------------------------------------------------------------- /Whisper/API/SpecialTokens.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace Whisper 4 | { 5 | struct SpecialTokens // IDs of the special tokens in the model's vocabulary 6 | { 7 | // The end of a transcription, token_eot 8 | int TranscriptionEnd; 9 | // Start of a transcription, token_sot 10 | int TranscriptionStart; 11 | // Special token, token_prev. NOTE(review): presumably <|startofprev|> of the upstream Whisper vocabulary, which marks the start of previously transcribed text given to the decoder as context; confirm
12 | int PreviousWord; // token_prev 13 | // Special token, token_solm. NOTE(review): in the upstream Whisper vocabulary this is <|startoflm|>, not a start-of-sentence marker; confirm before relying on it 14 | int SentenceStart; // token_solm 15 | // Special token, token_not. NOTE(review): in the upstream Whisper vocabulary this is <|notimestamps|>, which asks the model not to predict timestamps; it is not the English word "not" 16 | int Not; // token_not 17 | // Special token, token_beg. NOTE(review): in the upstream Whisper vocabulary this is the first timestamp token <|0.00|>, token IDs from it onwards are timestamps; confirm 18 | int TranscriptionBegin; // token_beg 19 | 20 | // Task token for translation, token_translate 21 | int TaskTranslate; 22 | // Task token for transcription, token_transcribe 23 | int TaskTranscribe; 24 | }; 25 | } -------------------------------------------------------------------------------- /Whisper/API/iTranscribeResult.cl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "TranscribeStructs.h" 3 | #include "../../ComLightLib/comLightCommon.h" 4 | 5 | namespace Whisper 6 | { 7 | struct iTranscribeResult : public ComLight::IUnknown 8 | { 9 | DEFINE_INTERFACE_ID( "{2871a73f-5ce3-48f8-8779-6582ee11935e}" ); 10 | 11 | virtual HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const = 0; 12 | virtual const sSegment* COMLIGHTCALL getSegments() const = 0; 13 | virtual const sToken* COMLIGHTCALL getTokens() const = 0; 14 | }; 15 | } -------------------------------------------------------------------------------- /Whisper/API/iTranscribeResult.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "TranscribeStructs.h" 3 | 4 | namespace Whisper 5 | { 6 | __interface __declspec( novtable, uuid( "2871a73f-5ce3-48f8-8779-6582ee11935e" ) ) iTranscribeResult : public IUnknown 7 | { 8 | HRESULT __stdcall getSize( sTranscribeLength& rdi ) const; 9 | const sSegment* __stdcall getSegments() const; 10 | const sToken* __stdcall getTokens() const; 11 | }; 12 | } -------------------------------------------------------------------------------- /Whisper/API/loggerApi.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace Whisper 5 | { 6 | // Log level for messages 7 | enum struct eLogLevel : uint8_t 8 | { 9 | Error = 0, 10 | Warning
= 1, 11 | Info = 2, 12 | Debug = 3 13 | }; 14 | enum struct eLoggerFlags : uint8_t // Flags for sLoggerSetup.flags 15 | { 16 | UseStandardError = 1, 17 | SkipFormatMessage = 2, 18 | }; 19 | 20 | // C function pointer to receive log messages from the library. The messages are encoded in UTF-8. 21 | using pfnLoggerSink = void( __stdcall* )( void* context, eLogLevel lvl, const char* message ); 22 | 23 | // A sink to receive log messages produced by Whisper.dll 24 | struct sLoggerSetup 25 | { 26 | // C function pointer to receive log messages from the library 27 | pfnLoggerSink sink = nullptr; 28 | // Optional context parameter for the sink function; when consuming from C# you don't need that, pass IntPtr.Zero, delegates can capture things. 29 | void* context = nullptr; 30 | // Maximum log level to produce. Unlike the other fields it has no default value, always set it explicitly. 31 | eLogLevel level; 32 | // Flags about the logger 33 | eLoggerFlags flags = (eLoggerFlags)0; 34 | }; 35 | } -------------------------------------------------------------------------------- /Whisper/API/sLanguageList.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace Whisper 5 | { 6 | struct sLanguageEntry 7 | { 8 | uint32_t key; 9 | int id; 10 | const char* name; 11 | }; 12 | 13 | struct sLanguageList // Array of `length` language entries 14 | { 15 | uint32_t length; 16 | const sLanguageEntry* pointer; 17 | }; 18 | } -------------------------------------------------------------------------------- /Whisper/API/sLoadModelCallbacks.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace Whisper 4 | { 5 | using pfnLoadProgress = HRESULT( __stdcall* )( double val, void* pv ) noexcept; // Receives the model loading progress; NOTE(review): presumably val is in [0, 1], confirm 6 | // Return S_OK to continue, or S_FALSE to fail with "The operation was canceled by the user" status code 7 | using pfnCancel = HRESULT( __stdcall* )( void* pv ) noexcept; 8 | 9 | struct sLoadModelCallbacks 10 | { 11 | pfnLoadProgress progress; 12 | pfnCancel cancel; 13 | void* pv; 14 | }; 15 | }
-------------------------------------------------------------------------------- /Whisper/API/sModelSetup.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace Whisper 5 | { 6 | enum struct eModelImplementation : uint32_t // Which implementation of the model to use 7 | { 8 | // GPGPU implementation based on Direct3D 11.0 compute shaders 9 | GPU = 1, 10 | 11 | // A hybrid implementation which uses DirectCompute for encode, and decodes on CPU 12 | // Not implemented in the published builds of the DLL. To enable, change BUILD_HYBRID_VERSION macro to 1 13 | Hybrid = 2, 14 | 15 | // A reference implementation which uses the original GGML CPU-running code 16 | // Not implemented in the published builds of the DLL. To enable, change BUILD_BOTH_VERSIONS macro to 1 17 | Reference = 3, 18 | }; 19 | // Flags for the GPU model. NOTE(review): presumably combined into sModelSetup.flags, confirm 20 | enum struct eGpuModelFlags : uint32_t 21 | { 22 | Wave32 = 1, 23 | Wave64 = 2, 24 | NoReshapedMatMul = 4, 25 | UseReshapedMatMul = 8, 26 | Cloneable = 0x10, 27 | }; 28 | // Model setup parameters: the implementation, flags, and GPU adapter 29 | struct sModelSetup 30 | { 31 | eModelImplementation impl = eModelImplementation::GPU; // GPU implementation by default 32 | uint32_t flags = 0; // NOTE(review): presumably a combination of eGpuModelFlags values, confirm 33 | const wchar_t* adapter = nullptr; // GPU adapter name, presumably one of the names reported through pfnListAdapters; nullptr presumably selects the default adapter, confirm 34 | }; 35 | 36 | // Function pointer to enumerate GPUs 37 | using pfnListAdapters = void( __stdcall* )( const wchar_t* name, void* pv ); 38 | 39 | // Function pointer to receive array of tokens from iModel.tokenize() API method 40 | using pfnDecodedTokens = void( __stdcall* )( const int* tokens, int tokensLength, void* pv ); 41 | } -------------------------------------------------------------------------------- /Whisper/API/whisperComLight.h: -------------------------------------------------------------------------------- 1 | #pragma once // ComLight flavor of the public API; include either this header or whisperWindows.h, not both 2 | #include "iMediaFoundation.cl.h" 3 | #include "iContext.cl.h" 4 | #include "iTranscribeResult.cl.h" -------------------------------------------------------------------------------- /Whisper/API/whisperWindows.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "iMediaFoundation.h" 3 | #include "iContext.h" 4 | #include "iTranscribeResult.h" -------------------------------------------------------------------------------- /Whisper/CPU/HybridLoader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "DecoderTensors.h" 3 | #include 4 | #include 5 | #include "../../ComLightLib/streams.h" 6 | 7 | namespace CpuCompute 8 | { 9 | __interface iLoaderProgressSink 10 | { 11 | HRESULT gotBytes( int64_t cb ); 12 | }; 13 | 14 | class HybridLoader 15 | { 16 | DecoderTensors& destination; 17 | CAtlMap map; 18 | size_t bufferBytes = 0; 19 | 20 | struct alignas( 32 ) PendingTensor 21 | { 22 | Tensor* destPointer = nullptr; 23 | int64_t streamOffset = 0; 24 | size_t bufferOffset = 0; 25 | size_t payloadBytes = 0; 26 | }; 27 | std::vector pending; 28 | 29 | public: 30 | 31 | HybridLoader( DecoderTensors& m, int countLayers ); 32 | 33 | HRESULT setupTensor( const CStringA& name, int n_dims, int ftype, const std::array& ne, ComLight::iReadStream* stream, int64_t& postponedBytes ); 34 | 35 | HRESULT completeLoad( ComLight::iReadStream* stream, iLoaderProgressSink& progressSink ); 36 | }; 37 | } -------------------------------------------------------------------------------- /Whisper/CPU/KvTensors.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Tensor.h" 3 | #include "LargeBuffer.h" 4 | #include "../Whisper/sModelParams.h" 5 | 6 | namespace CpuCompute 7 | { 8 | class KvTensors 9 | { 10 | uint16_t* keys = nullptr; 11 | uint16_t* values = nullptr; 12 | uint32_t size = 0; 13 | 14 | CpuCompute::LargeBuffer memory; 15 | 16 | public: 17 | // Create these two large tensors, FP16 precision 18 | HRESULT create( const Whisper::sModelParams& mp ); 19 | 20 | // A slice of model.memory_cross_k tensor: len FP16 elements starting at element off. Throws E_BOUNDS when the slice doesn't fit. 21 | Tensor keysView( 
uint32_t len, uint32_t off ) const 22 | { 23 | if( off <= size && len <= size - off ) // Unlike len + off <= size, this can't overflow uint32_t 24 | return Tensor::fromData( keys + off, eDataType::FP16, len ); 25 | throw E_BOUNDS; 26 | } 27 | 28 | // A slice of model.memory_cross_v tensor: len FP16 elements starting at element off. Throws E_BOUNDS when the slice doesn't fit. 29 | Tensor valuesView( uint32_t len, uint32_t off ) const 30 | { 31 | if( off <= size && len <= size - off ) // Unlike len + off <= size, this can't overflow uint32_t 32 | return Tensor::fromData( values + off, eDataType::FP16, len ); 33 | throw E_BOUNDS; 34 | } 35 | }; 36 | } -------------------------------------------------------------------------------- /Whisper/CPU/KvTensorsCpu.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "KvTensors.h" 3 | using namespace CpuCompute; 4 | 5 | // Create these two large tensors, FP16 precision 6 | HRESULT KvTensors::create( const Whisper::sModelParams& mp ) 7 | { 8 | const uint32_t n_mem = mp.n_text_layer * mp.n_text_ctx; 9 | const uint32_t n_elements = mp.n_text_state * n_mem; 10 | 11 | const size_t cb = sizeof( uint16_t ) * (size_t)n_elements * 2; 12 | CHECK( memory.allocate( cb ) ); 13 | 14 | uint16_t* pointer = (uint16_t*)memory.pointer(); 15 | keys = pointer; 16 | values = pointer + n_elements; 17 | size = n_elements; 18 | return S_OK; 19 | } -------------------------------------------------------------------------------- /Whisper/CPU/LargeBuffer.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "LargeBuffer.h" 3 | using namespace CpuCompute; 4 | 5 | void LargeBuffer::deallocate() 6 | { 7 | if( nullptr == pv ) 8 | return; 9 | VirtualFree( pv, 0, MEM_RELEASE ); 10 | pv = nullptr; 11 | } 12 | 13 | HRESULT LargeBuffer::allocate( size_t cb ) 14 | { 15 | deallocate(); 16 | 17 | pv = VirtualAlloc( nullptr, cb, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE ); 18 | if( nullptr != pv ) 19 | return S_OK; 20 | return HRESULT_FROM_WIN32( GetLastError() ); 21 | } 22 | 23 | HRESULT LargeBuffer::setReadOnly( size_t cb ) 24 | { 25 | if( nullptr != pv ) 26 
| { 27 | DWORD op = 0; 28 | if( VirtualProtect( pv, cb, PAGE_READONLY, &op ) ) 29 | return S_OK; 30 | return HRESULT_FROM_WIN32( GetLastError() ); 31 | } 32 | else 33 | return OLE_E_BLANK; 34 | } -------------------------------------------------------------------------------- /Whisper/CPU/LargeBuffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace CpuCompute 4 | { 5 | // A large memory buffer allocated with VirtualAlloc kernel API, bypassing the heap. 6 | class LargeBuffer 7 | { 8 | void* pv = nullptr; 9 | public: 10 | LargeBuffer() = default; 11 | LargeBuffer( const LargeBuffer& ) = delete; 12 | LargeBuffer( LargeBuffer&& that ) noexcept 13 | { 14 | pv = that.pv; 15 | that.pv = nullptr; 16 | } 17 | ~LargeBuffer() 18 | { 19 | deallocate(); 20 | } 21 | void operator=( LargeBuffer&& that ) noexcept 22 | { 23 | std::swap( pv, that.pv ); 24 | } 25 | void operator=( const LargeBuffer& that ) = delete; 26 | 27 | // Allocate buffer with specified count of bytes, and read+write memory protection 28 | // The OS kernel guarantees zero-initialization of that memory.
29 | HRESULT allocate( size_t cb ); 30 | 31 | // Change memory protection of the buffer to read only 32 | HRESULT setReadOnly( size_t cb ); 33 | 34 | // Unless the pointer is nullptr, deallocate the buffer 35 | void deallocate(); 36 | 37 | // Pointer to the start of the buffer, aligned by memory page = 4 kilobytes 38 | uint8_t* pointer() const 39 | { 40 | assert( nullptr != pv ); 41 | return (uint8_t*)pv; 42 | } 43 | }; 44 | } -------------------------------------------------------------------------------- /Whisper/CPU/ParallelForRunner.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "LargeBuffer.h" 3 | 4 | namespace CpuCompute 5 | { 6 | // Callback interface for the parallel `for` 7 | __interface iComputeRange 8 | { 9 | // The implementation calls this method on multiple thread pool threads in parallel, and aggregates status codes. 10 | HRESULT __stdcall compute( size_t begin, size_t end ) const; 11 | }; 12 | 13 | // Similar to ThreadPoolWork in parallelFor.h, optimized to be used as a direct replacement of OpenMP pool. 14 | class alignas( 64 ) ParallelForRunner 15 | { 16 | public: 17 | ParallelForRunner( int threads ); 18 | ~ParallelForRunner(); 19 | 20 | HRESULT setThreadsCount( int threads ); 21 | 22 | HRESULT parallelFor( iComputeRange& compute, size_t length, size_t minBatch = 1 ); 23 | 24 | // Allocate a temporary buffer for the calling thread. 25 | // The pointer is guaranteed to be aligned by page size = 4kb 26 | void* threadLocalBuffer( size_t cb ); 27 | 28 | private: 29 | 30 | int maxThreads; 31 | PTP_WORK work = nullptr; 32 | iComputeRange* computeRange = nullptr; 33 | size_t countItems = 0; 34 | size_t countThreads = 0; 35 | 36 | // Aligning by cache lines. 37 | // Avoiding cache line sharing between CPU cores improves performance, despite wasting a few bytes of memory.
38 | struct alignas( 64 ) ThreadBuffer 39 | { 40 | LargeBuffer memory; 41 | size_t cb = 0; 42 | }; 43 | std::vector threadBuffers; 44 | 45 | alignas( 64 ) volatile long threadIndex = 0; 46 | volatile HRESULT status = S_OK; 47 | 48 | void runBatch( size_t ith ) noexcept; 49 | 50 | static void __stdcall workCallbackStatic( PTP_CALLBACK_INSTANCE Instance, void* pv, PTP_WORK Work ) noexcept; 51 | }; 52 | } -------------------------------------------------------------------------------- /Whisper/CPU/Readme.txt: -------------------------------------------------------------------------------- 1 | The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_HYBRID_VERSION macro in stdafx.h -------------------------------------------------------------------------------- /Whisper/CPU/mulMat.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "mulMat.h" 3 | #include "mulMatImpl.h" 4 | using namespace CpuCompute; 5 | 6 | namespace 7 | { 8 | template 9 | static HRESULT mulMatImpl( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) 10 | { 11 | MulMatImpl impl{ result, a, b, pfor }; 12 | return impl.run( pfor ); 13 | } 14 | } 15 | 16 | HRESULT CpuCompute::mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ) 17 | { 18 | if( a.type() != eDataType::FP16 ) 19 | return E_NOTIMPL; 20 | if( b.type() != eDataType::FP32 ) 21 | return E_NOTIMPL; 22 | 23 | // return mulMatImpl<1, 1>( result, a, b, pfor ); 24 | 25 | if( b.ne[ 1 ] == 1 ) 26 | { 27 | // Multiplying by a single row 28 | if( a.ne[ 1 ] >= 32 ) 29 | return mulMatImpl<4, 1>( result, a, b, pfor ); 30 | else 31 | return mulMatImpl<1, 1>( result, a, b, pfor ); 32 | } 33 | else if( b.ne[ 1 ] == 2 ) 34 | { 35 | if( a.ne[ 1 ] >= 32 ) 36 | return mulMatImpl<4, 2>( result, a, b, pfor ); 37 | else 38 | return mulMatImpl<1, 2>( result, a, b, pfor ); 39 | } 40 | 
else if( b.ne[ 1 ] == 3 ) 41 | { 42 | if( a.ne[ 1 ] >= 16 ) 43 | return mulMatImpl<2, 3>( result, a, b, pfor ); 44 | else 45 | return mulMatImpl<1, 3>( result, a, b, pfor ); 46 | } 47 | else 48 | { 49 | if( a.ne[ 1 ] >= 16 ) 50 | return mulMatImpl<2, 4>( result, a, b, pfor ); 51 | else 52 | return mulMatImpl<1, 4>( result, a, b, pfor ); 53 | } 54 | } -------------------------------------------------------------------------------- /Whisper/CPU/mulMat.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "ParallelForRunner.h" 3 | #include "Tensor.h" 4 | 5 | namespace CpuCompute 6 | { 7 | HRESULT mulMat( Tensor& result, const Tensor& a, const Tensor& b, ParallelForRunner& pfor ); 8 | } 9 | 10 | #if TENSOR_GGML_COMPAT 11 | #include "../source/ggml.h" 12 | inline HRESULT mulMat( ggml_tensor* result, const ggml_tensor* a, const ggml_tensor* b, CpuCompute::ParallelForRunner& pfor ) 13 | { 14 | CpuCompute::Tensor r{ result }, lhs{ a }, rhs{ b }; 15 | return CpuCompute::mulMat( r, lhs, rhs, pfor ); 16 | } 17 | #endif -------------------------------------------------------------------------------- /Whisper/CPU/mulMatImpl.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Whisper/CPU/mulMatImpl.cpp -------------------------------------------------------------------------------- /Whisper/D3D/Binder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "device.h" 3 | 4 | namespace DirectCompute 5 | { 6 | class Binder 7 | { 8 | uint8_t maxSrv = 0; 9 | uint8_t maxUav = 0; 10 | 11 | public: 12 | Binder() = default; 13 | Binder( const Binder& ) = delete; 14 | 15 | void bind( ID3D11ShaderResourceView* srv0, ID3D11UnorderedAccessView* uav0 ); 16 | void bind( ID3D11ShaderResourceView* srv0, ID3D11ShaderResourceView* srv1, 
ID3D11UnorderedAccessView* uav0 ); 17 | void bind( std::initializer_list srvs, std::initializer_list uavs ); 18 | void bind( ID3D11UnorderedAccessView* uav0 ); 19 | ~Binder(); 20 | }; 21 | } -------------------------------------------------------------------------------- /Whisper/D3D/MappedResource.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "MappedResource.h" 3 | using namespace DirectCompute; 4 | #define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } 5 | 6 | MappedResource::MappedResource() 7 | { 8 | mapped.pData = nullptr; 9 | mapped.RowPitch = mapped.DepthPitch = 0; 10 | resource = nullptr; 11 | } 12 | 13 | HRESULT MappedResource::map( ID3D11Resource* res, bool reading ) 14 | { 15 | if( nullptr == resource ) 16 | { 17 | D3D11_MAP mt = reading ? D3D11_MAP_READ : D3D11_MAP_WRITE_DISCARD; 18 | CHECK( context()->Map( res, 0, mt, 0, &mapped ) ); 19 | resource = res; 20 | return S_OK; 21 | } 22 | return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED ); 23 | } 24 | 25 | MappedResource::~MappedResource() 26 | { 27 | if( nullptr != resource ) 28 | { 29 | context()->Unmap( resource, 0 ); 30 | resource = nullptr; 31 | mapped.pData = nullptr; 32 | } 33 | } -------------------------------------------------------------------------------- /Whisper/D3D/MappedResource.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "device.h" 3 | #include 4 | 5 | namespace DirectCompute 6 | { 7 | // Maps subresource 0 of a resource with the immediate context in map(), and unmaps it in the destructor 8 | class MappedResource 9 | { 10 | D3D11_MAPPED_SUBRESOURCE mapped; 11 | ID3D11Resource* resource; 12 | public: 13 | MappedResource(); 14 | // Non-copyable: a copy would call Unmap a second time in its destructor 15 | MappedResource( const MappedResource& ) = delete; 16 | void operator=( const MappedResource& ) = delete; 17 | HRESULT map( ID3D11Resource* res, bool reading ); 18 | ~MappedResource(); 19 | 20 | void* data() const 21 | { 22 | assert( nullptr != mapped.pData ); 23 | return mapped.pData; 24 | } 25 | }; 26 | } -------------------------------------------------------------------------------- 
/Whisper/D3D/RenderDoc/renderDoc.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "renderDoc.h" 3 | #include "renderdoc_app.h" 4 | #include "../device.h" 5 | 6 | // Set to 0 to compile out the RenderDoc integration 7 | #define ENABLE_RENDERDOC_DEBUGGER 1 8 | 9 | #if ENABLE_RENDERDOC_DEBUGGER 10 | namespace 11 | { 12 | // Module handle of renderdoc.dll, when that library is already loaded into this process 13 | HMODULE hmRenderDoc = nullptr; 14 | // RenderDoc in-application API, set by initializeRenderDoc() 15 | RENDERDOC_API_1_6_0* api = nullptr; 16 | 17 | // True while the key is held down: GetAsyncKeyState sets the most significant bit of the result in that case 18 | inline bool isKeyPressed( int vKey ) 19 | { 20 | const SHORT state = GetAsyncKeyState( vKey ); 21 | return ( state & 0x8000 ) != 0; 22 | } 23 | } 24 | 25 | // Query the RenderDoc API from renderdoc.dll; GetModuleHandleW only finds an already loaded module, it never loads the DLL. 26 | // Returns true when the API was obtained. 27 | bool DirectCompute::initializeRenderDoc() 28 | { 29 | hmRenderDoc = GetModuleHandleW( L"renderdoc.dll" ); 30 | if( hmRenderDoc == nullptr ) 31 | return false; 32 | 33 | const auto getApi = (pRENDERDOC_GetAPI)GetProcAddress( hmRenderDoc, "RENDERDOC_GetAPI" ); 34 | if( getApi == nullptr ) 35 | return false; 36 | 37 | // RENDERDOC_GetAPI returns 1 on success 38 | const int status = getApi( eRENDERDOC_API_Version_1_6_0, (void**)&api ); 39 | return status == 1 && api != nullptr; 40 | } 41 | 42 | // Starts a frame capture of the current device, but only when the API is available and F12 is held down at that moment 43 | DirectCompute::CaptureRaii::CaptureRaii() : capturing( false ) 44 | { 45 | if( api == nullptr || !isKeyPressed( VK_F12 ) ) 46 | return; 47 | 48 | ID3D11Device* const dev = device(); 49 | if( dev != nullptr ) 50 | { 51 | api->StartFrameCapture( dev, nullptr ); 52 | capturing = true; 53 | } 54 | } 55 | 56 | // Ends the capture started by the constructor, if any 57 | DirectCompute::CaptureRaii::~CaptureRaii() 58 | { 59 | if( capturing ) 60 | api->EndFrameCapture( device(), nullptr ); 61 | } 62 | #else // !ENABLE_RENDERDOC_DEBUGGER 63 | // RenderDoc integration compiled out: initialization always fails, and CaptureRaii does nothing 64 | bool DirectCompute::initializeRenderDoc() 65 | { 66 | return false; 67 | } 68 | DirectCompute::CaptureRaii::CaptureRaii() : capturing( false ) 69 | { 70 | } 71 | DirectCompute::CaptureRaii::~CaptureRaii() 72 | { 73 | } 74 | #endif -------------------------------------------------------------------------------- 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace DirectCompute 4 | { 5 | bool initializeRenderDoc(); 6 | 7 | class CaptureRaii 8 | { 9 | bool capturing; 10 | public: 11 | CaptureRaii(); 12 | CaptureRaii( const CaptureRaii& ) = delete; 13 | ~CaptureRaii(); 14 | }; 15 | } -------------------------------------------------------------------------------- /Whisper/D3D/createBuffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "enums.h" 3 | #include "device.h" 4 | 5 | namespace DirectCompute 6 | { 7 | HRESULT createBuffer( eBufferUse use, size_t totalBytes, ID3D11Buffer** ppGpuBuffer, const void* rsi, ID3D11Buffer** ppStagingBuffer, bool shared = false ); 8 | } -------------------------------------------------------------------------------- /Whisper/D3D/createDevice.h: -------------------------------------------------------------------------------- 1 | // Low-level functions to create and initialize D3D11 device 2 | #pragma once 3 | #include 4 | #include "sGpuInfo.h" 5 | 6 | namespace DirectCompute 7 | { 8 | HRESULT createDevice( const std::wstring& adapter, ID3D11Device** dev, ID3D11DeviceContext** context ); 9 | 10 | HRESULT validateFlags( uint32_t flags ); 11 | 12 | HRESULT queryDeviceInfo( sGpuInfo& rdi, ID3D11Device* dev, uint32_t flags ); 13 | 14 | // Create another device and context, on the same hardware GPU 15 | HRESULT cloneDevice( ID3D11Device* source, ID3D11Device** dev, ID3D11DeviceContext** context ); 16 | } -------------------------------------------------------------------------------- /Whisper/D3D/device.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "sGpuInfo.h" 5 | 6 | namespace DirectCompute 7 | { 8 | ID3D11Device* device(); 9 | ID3D11DeviceContext* context(); 10 | const sGpuInfo& gpuInfo(); 11 | 12 | inline void csSetCB( 
ID3D11Buffer* cb ) 13 | { 14 | context()->CSSetConstantBuffers( 0, 1, &cb ); 15 | } 16 | 17 | __m128i bufferMemoryUsage( ID3D11Buffer* buffer ); 18 | 19 | __m128i resourceMemoryUsage( ID3D11ShaderResourceView* srv ); 20 | } -------------------------------------------------------------------------------- /Whisper/D3D/downloadBuffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace DirectCompute 4 | { 5 | // Download a buffer from VRAM into std::vector 6 | // The function is relatively expensive, creates a temporary staging buffer on each call, and only used to test things. 7 | template 8 | HRESULT downloadBuffer( ID3D11ShaderResourceView* srv, std::vector& vec ); 9 | } -------------------------------------------------------------------------------- /Whisper/D3D/enums.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "enums.h" 3 | 4 | static const alignas( 16 ) std::array s_tensorViewFormats = { DXGI_FORMAT_R16_FLOAT, DXGI_FORMAT_R32_FLOAT, DXGI_FORMAT_R32_UINT }; 5 | 6 | DXGI_FORMAT DirectCompute::viewFormat( eDataType dt ) 7 | { 8 | return s_tensorViewFormats[ (uint8_t)dt ]; 9 | } -------------------------------------------------------------------------------- /Whisper/D3D/enums.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace DirectCompute 6 | { 7 | enum struct eDataType : uint8_t 8 | { 9 | FP16, 10 | FP32, 11 | U32, 12 | }; 13 | 14 | inline size_t elementSize( eDataType dt ) 15 | { 16 | assert( dt == eDataType::FP16 || dt == eDataType::FP32 || dt == eDataType::U32 ); 17 | 18 | return ( dt == eDataType::FP16 ) ? 
2 : 4; 19 | } 20 | 21 | DXGI_FORMAT viewFormat( eDataType dt ); 22 | 23 | enum struct eBufferUse : uint8_t 24 | { 25 | // Immutable tensor, readable from GPU 26 | Immutable, 27 | // Read+write tensor, readable and writable on GPU 28 | ReadWrite, 29 | // Read+write tensor, readable and writable on GPU, which supports downloads from GPU 30 | ReadWriteDownload, 31 | // The tensor is accessible by both GPU (read only) and CPU (write only). Optimized for resources frequently updated from CPU. 32 | Dynamic, 33 | }; 34 | } -------------------------------------------------------------------------------- /Whisper/D3D/listGPUs.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace DirectCompute 5 | { 6 | CComPtr selectAdapter( const std::wstring& requestedName ); 7 | } -------------------------------------------------------------------------------- /Whisper/D3D/sGpuInfo.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace DirectCompute 5 | { 6 | // DXGI_ADAPTER_DESC.VendorId magic numbers; they come from that database: https://pcisig.com/membership/member-companies 7 | enum struct eGpuVendor : uint16_t 8 | { 9 | AMD = 0x1002, 10 | NVidia = 0x10de, 11 | Intel = 0x8086, 12 | VMWare = 0x15ad, 13 | }; 14 | 15 | enum struct eGpuEffectiveFlags : uint8_t 16 | { 17 | Wave64 = 1, 18 | ReshapedMatMul = 2, 19 | Cloneable = 4, 20 | }; 21 | 22 | struct sGpuInfo 23 | { 24 | eGpuEffectiveFlags flags; 25 | eGpuVendor vendor; 26 | uint16_t device, revision; 27 | uint32_t subsystem; 28 | size_t vramDedicated, ramDedicated, ramShared; 29 | std::wstring description; 30 | 31 | inline bool wave64() const 32 | { 33 | return 0 != ( (uint8_t)flags & (uint8_t)eGpuEffectiveFlags::Wave64 ); 34 | } 35 | 36 | // On nVidia 1080Ti that approach is much slower, by a factor of 2.4 37 | // On AMD Cezanne that approach is faster by a factor of 0.69, i.e. 
30% faster. 38 | // Dunno why that is, maybe 'coz on that AMD complete panels fit in L3 cache. 39 | // Anyway, we do want extra 30% perf on AMD Cezanne, so only using that code on AMD GPUs. 40 | // Dunno how it gonna behave on other GPUs, need to test. 41 | inline bool useReshapedMatMul() const 42 | { 43 | return 0 != ( (uint8_t)flags & (uint8_t)eGpuEffectiveFlags::ReshapedMatMul ); 44 | } 45 | 46 | inline bool cloneableModel() const 47 | { 48 | return 0 != ( (uint8_t)flags & (uint8_t)eGpuEffectiveFlags::Cloneable ); 49 | } 50 | }; 51 | } -------------------------------------------------------------------------------- /Whisper/D3D/shaderNames.cpp: -------------------------------------------------------------------------------- 1 | // This source file is generated by a tool 2 | #include "stdafx.h" 3 | #include "shaderNames.h" 4 | 5 | static const std::array s_shaderNames = 6 | { 7 | "add", 8 | "addInPlace", 9 | "addRepeat", 10 | "addRepeatEx", 11 | "addRepeatGelu", 12 | "addRepeatScale", 13 | "addRows", 14 | "convolutionMain", 15 | "convolutionMain2", 16 | "convolutionMain2Fixed", 17 | "convolutionPrep1", 18 | "convolutionPrep2", 19 | "copyConvert", 20 | "copyTranspose", 21 | "dbgFindNaN", 22 | "diagMaskInf", 23 | "flashAttention", 24 | "flashAttentionCompat1", 25 | "flashAttentionCompat2", 26 | "flashAttentionCompat3", 27 | "fmaRepeat1", 28 | "fmaRepeat2", 29 | "matReshapePanels", 30 | "mulMatByRow", 31 | "mulMatByRowTiled", 32 | "mulMatByRowTiledEx", 33 | "mulMatByScalar", 34 | "mulMatDotMain", 35 | "mulMatDotReshape", 36 | "mulMatMadMain", 37 | "mulMatTiled", 38 | "mulMatTiledEx", 39 | "norm", 40 | "normCompat", 41 | "normFixed", 42 | "scaleInPlace", 43 | "softMax", 44 | "softMaxCompat", 45 | "softMaxFixed", 46 | "softMaxLong", 47 | "zeroMemory", 48 | }; 49 | 50 | const char* DirectCompute::computeShaderName( eComputeShader cs ) 51 | { 52 | const uint16_t i = (uint16_t)cs; 53 | if( i < s_shaderNames.size() ) 54 | return s_shaderNames[ i ]; 55 | return nullptr; 56 
| } -------------------------------------------------------------------------------- /Whisper/D3D/shaderNames.h: -------------------------------------------------------------------------------- 1 | // This header is generated by a tool 2 | #pragma once 3 | #include 4 | 5 | namespace DirectCompute 6 | { 7 | enum struct eComputeShader: uint16_t 8 | { 9 | add = 0, 10 | addInPlace = 1, 11 | addRepeat = 2, 12 | addRepeatEx = 3, 13 | addRepeatGelu = 4, 14 | addRepeatScale = 5, 15 | addRows = 6, 16 | convolutionMain = 7, 17 | convolutionMain2 = 8, 18 | convolutionMain2Fixed = 9, 19 | convolutionPrep1 = 10, 20 | convolutionPrep2 = 11, 21 | copyConvert = 12, 22 | copyTranspose = 13, 23 | dbgFindNaN = 14, 24 | diagMaskInf = 15, 25 | flashAttention = 16, 26 | flashAttentionCompat1 = 17, 27 | flashAttentionCompat2 = 18, 28 | flashAttentionCompat3 = 19, 29 | fmaRepeat1 = 20, 30 | fmaRepeat2 = 21, 31 | matReshapePanels = 22, 32 | mulMatByRow = 23, 33 | mulMatByRowTiled = 24, 34 | mulMatByRowTiledEx = 25, 35 | mulMatByScalar = 26, 36 | mulMatDotMain = 27, 37 | mulMatDotReshape = 28, 38 | mulMatMadMain = 29, 39 | mulMatTiled = 30, 40 | mulMatTiledEx = 31, 41 | norm = 32, 42 | normCompat = 33, 43 | normFixed = 34, 44 | scaleInPlace = 35, 45 | softMax = 36, 46 | softMaxCompat = 37, 47 | softMaxFixed = 38, 48 | softMaxLong = 39, 49 | zeroMemory = 40, 50 | }; 51 | 52 | const char* computeShaderName( eComputeShader cs ); 53 | } -------------------------------------------------------------------------------- /Whisper/D3D/shaders.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "shaderNames.h" 3 | 4 | namespace DirectCompute 5 | { 6 | HRESULT createComputeShaders( std::vector>& shaders ); 7 | 8 | void bindShader( eComputeShader shader ); 9 | } -------------------------------------------------------------------------------- /Whisper/DllMain.cpp: -------------------------------------------------------------------------------- 
1 | #include "stdafx.h" 2 | 3 | BOOL __stdcall DllMain( HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved ) 4 | { 5 | // Perform actions based on the reason for calling. 6 | switch( fdwReason ) 7 | { 8 | case DLL_PROCESS_ATTACH: 9 | // Initialize once for each new process. Return FALSE to fail DLL load. 10 | DisableThreadLibraryCalls( (HMODULE)hinstDLL ); 11 | break; 12 | case DLL_THREAD_ATTACH: 13 | // Do thread-specific initialization. 14 | break; 15 | case DLL_THREAD_DETACH: 16 | // Do thread-specific cleanup. 17 | break; 18 | case DLL_PROCESS_DETACH: 19 | if( lpvReserved != nullptr ) 20 | { 21 | break; // do not do cleanup if process termination scenario 22 | } 23 | // Perform any necessary cleanup 24 | break; 25 | } 26 | return TRUE; // Successful DLL_PROCESS_ATTACH. 27 | } -------------------------------------------------------------------------------- /Whisper/Hybrid/HybridContext.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../Whisper/WhisperModel.h" 3 | #include "../CPU/MlContext.h" 4 | #include "../CPU/BufferAllocator.h" 5 | #include "KeyValueDownloader.h" 6 | #include "../CPU/KvTensors.h" 7 | 8 | // This version of the hybrid context uses the new, custom-built kernels 9 | class HybridContext 10 | { 11 | CpuCompute::MlContext ml; 12 | CpuCompute::VirtualAllocator allocCompute, allocComputeLayer; 13 | 14 | class AllocSingle : public CpuCompute::iArenaAllocator 15 | { 16 | CpuCompute::LargeBuffer buffer; 17 | size_t capacity = 0; 18 | bool allocated = false; 19 | // Inherited via iArenaAllocator 20 | virtual void* allocate( size_t cb, size_t align ) override final; 21 | 22 | public: 23 | virtual void resetArena() override final; 24 | }; 25 | AllocSingle allocLayerOutput; 26 | 27 | const CpuCompute::DecoderTensors& model; 28 | const Whisper::WhisperModel& whisperModel; 29 | KeyValueDownloader kvCross; 30 | CpuCompute::KvTensors kv; 31 | 32 | class SetAllocatorRaii; 33 | 34 | public: 
35 | 36 | HybridContext( const Whisper::WhisperModel& wm ); 37 | 38 | HRESULT create(); 39 | 40 | HRESULT downloadKeyValues( const DirectCompute::KeyValueBuffers& source ) 41 | { 42 | return kvCross.download( source ); 43 | } 44 | 45 | struct sDecParams 46 | { 47 | int n_threads; 48 | int M; 49 | }; 50 | 51 | HRESULT decode( const int* tokens, const int n_tokens, const int n_past, const sDecParams& dp, std::vector& probs_out ); 52 | }; -------------------------------------------------------------------------------- /Whisper/Hybrid/KeyValueDownloader.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "KeyValueDownloader.h" 3 | 4 | HRESULT KeyValueDownloader::create( const Whisper::sModelParams& mp ) 5 | { 6 | const uint32_t n_audio_ctx = mp.n_audio_ctx; 7 | const uint32_t n_mem = mp.n_text_layer * mp.n_audio_ctx; 8 | const uint32_t n_elements = mp.n_text_state * n_mem; 9 | 10 | CD3D11_BUFFER_DESC desc{ n_elements * 2, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ }; 11 | ID3D11Device* dev = DirectCompute::device(); 12 | CHECK( dev->CreateBuffer( &desc, nullptr, &keys ) ); 13 | CHECK( dev->CreateBuffer( &desc, nullptr, &values ) ); 14 | 15 | length = n_elements; 16 | return S_OK; 17 | } 18 | 19 | HRESULT KeyValueDownloader::download( const DirectCompute::KeyValueBuffers& source ) 20 | { 21 | ID3D11DeviceContext* ctx = DirectCompute::context(); 22 | ctx->CopyResource( keys, source.keys.getBuffer() ); 23 | ctx->CopyResource( values, source.values.getBuffer() ); 24 | return S_OK; 25 | } 26 | 27 | KeyValueDownloader::ReadMap::ReadMap( KeyValueDownloader& owner ) : 28 | length( owner.length ) 29 | { 30 | check( mappedKeys.map( owner.keys, true ) ); 31 | check( mappedValues.map( owner.values, true ) ); 32 | } -------------------------------------------------------------------------------- /Whisper/Hybrid/Readme.txt: -------------------------------------------------------------------------------- 1 | 
The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change the BUILD_HYBRID_VERSION macro in stdafx.h -------------------------------------------------------------------------------- /Whisper/MF/AudioBuffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace Whisper 5 | { 6 | struct AudioBuffer 7 | { 8 | std::vector mono; 9 | std::vector stereo; 10 | 11 | void appendMono( const float* rsi, size_t countFloats ); 12 | void appendDownmixedStereo( const float* rsi, size_t countFloats ); 13 | void appendStereo( const float* rsi, size_t countFloats ); 14 | 15 | using pfnAppendSamples = void( AudioBuffer::* )( const float* rsi, size_t countFloats ); 16 | 17 | inline static pfnAppendSamples appendSamplesFunc( bool sourceMono, bool wantStereo ) 18 | { 19 | if( sourceMono ) 20 | return &AudioBuffer::appendMono; 21 | else if( !wantStereo ) 22 | return &AudioBuffer::appendDownmixedStereo; 23 | else 24 | return &AudioBuffer::appendStereo; 25 | } 26 | 27 | void clear() 28 | { 29 | mono.clear(); 30 | stereo.clear(); 31 | } 32 | 33 | void swap( AudioBuffer& that ) 34 | { 35 | mono.swap( that.mono ); 36 | stereo.swap( that.stereo ); 37 | } 38 | 39 | void resize( size_t len ) 40 | { 41 | assert( len <= mono.size() ); 42 | mono.resize( len ); 43 | if( !stereo.empty() ) 44 | stereo.resize( len * 2 ); /* when present, stereo is kept at two floats per mono sample */ 45 | } 46 | }; 47 | } -------------------------------------------------------------------------------- /Whisper/MF/AudioCapture.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../API/MfStructs.h" 3 | 4 | namespace Whisper 5 | { 6 | struct iAudioCapture; 7 | struct iMediaFoundation; 8 | 9 | HRESULT __stdcall captureDeviceList( pfnFoundCaptureDevices pfn, void* pv ); 10 | 11 | HRESULT __stdcall captureOpen( iMediaFoundation* owner, const wchar_t* endpoint, const sCaptureParams& captureParams,
iAudioCapture** pp ) noexcept; 12 | } -------------------------------------------------------------------------------- /Whisper/MF/loadAudioFile.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../API/iMediaFoundation.cl.h" 3 | 4 | namespace Whisper 5 | { 6 | HRESULT COMLIGHTCALL loadAudioFile( LPCTSTR path, bool stereo, iAudioBuffer** pp ); 7 | } -------------------------------------------------------------------------------- /Whisper/MF/mfStartup.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace Whisper 4 | { 5 | class MfStartupRaii 6 | { 7 | uint8_t successFlags = 0; 8 | public: 9 | MfStartupRaii() = default; 10 | ~MfStartupRaii(); 11 | MfStartupRaii( const MfStartupRaii& ) = delete; 12 | // Copy assignment is deleted too: the implicitly generated one would copy successFlags into a second owner, and presumably both destructors would then undo the same startup 13 | void operator=( const MfStartupRaii& ) = delete; 14 | 15 | HRESULT startup(); 16 | }; 17 | } -------------------------------------------------------------------------------- /Whisper/MF/mfUtils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "../Whisper/audioConstants.h" 7 | 8 | namespace Whisper 9 | { 10 | HRESULT createMediaType( bool stereo, IMFMediaType** pp ); 11 | 12 | HRESULT getStreamDuration( IMFSourceReader* reader, int64_t& duration ); 13 | 14 | HRESULT validateCurrentMediaType( IMFSourceReader* reader, uint32_t expectedChannels ); 15 | 16 | struct iAudioReader; 17 | void setPreciseSamplesCount( const iAudioReader* ar, int64_t count ); 18 | } -------------------------------------------------------------------------------- /Whisper/ML/ConstantBuffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../D3D/device.h" 3 | #include "TensorShape.h" 4 | 5 | namespace DirectCompute 6 | { 7 | // 96 bytes dynamic constant buffers, with dimensions and VRAM layout of 2-3 tensors 8 | class ConstantBuffer 9 | { 10 |
CComPtr buffer; 11 | 12 | public: 13 | HRESULT create(); 14 | HRESULT update( const TensorShape& t0 ); 15 | HRESULT update( const TensorShape& t0, const TensorShape& t1 ); 16 | HRESULT update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 ); 17 | 18 | void bind() const; 19 | 20 | __m128i getMemoryUse() const 21 | { 22 | return bufferMemoryUsage( buffer ); 23 | } 24 | }; 25 | } -------------------------------------------------------------------------------- /Whisper/ML/DbgNanTest.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "DbgNanTest.h" 3 | #include "../D3D/MappedResource.h" 4 | using namespace DirectCompute; 5 | 6 | HRESULT DbgNanTest::create() 7 | { 8 | ID3D11Device* const dev = DirectCompute::device(); 9 | 10 | CD3D11_BUFFER_DESC desc{ 4, D3D11_BIND_UNORDERED_ACCESS }; 11 | CHECK( dev->CreateBuffer( &desc, nullptr, &bufferDefault ) ); 12 | 13 | desc.Usage = D3D11_USAGE_STAGING; 14 | desc.BindFlags = 0; 15 | desc.CPUAccessFlags = D3D11_CPU_ACCESS_READ; 16 | CHECK( dev->CreateBuffer( &desc, nullptr, &bufferStaging ) ); 17 | 18 | CD3D11_UNORDERED_ACCESS_VIEW_DESC viewDesc{ D3D11_UAV_DIMENSION_BUFFER, DXGI_FORMAT_R32_UINT, 0, 1 }; 19 | CHECK( dev->CreateUnorderedAccessView( bufferDefault, &viewDesc, &uav ) ); 20 | 21 | return S_OK; 22 | } 23 | 24 | void DbgNanTest::destroy() 25 | { 26 | uav = nullptr; 27 | bufferStaging = nullptr; 28 | bufferDefault = nullptr; 29 | } 30 | 31 | bool DbgNanTest::test() const 32 | { 33 | context()->CopyResource( bufferStaging, bufferDefault ); 34 | MappedResource mapped; 35 | check( mapped.map( bufferStaging, true ) ); 36 | const BOOL val = *(const BOOL*)mapped.data(); /* the 4-byte flag buffer, read back as a BOOL */ 37 | return val != 0; 38 | } -------------------------------------------------------------------------------- /Whisper/ML/DbgNanTest.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace DirectCompute 4 | {
5 | class DbgNanTest 6 | { 7 | CComPtr bufferDefault, bufferStaging; 8 | CComPtr uav; 9 | public: 10 | HRESULT create(); 11 | void destroy(); 12 | operator ID3D11UnorderedAccessView* ( ) const 13 | { 14 | return uav; 15 | } 16 | bool test() const; 17 | }; 18 | 19 | #if DBG_TEST_NAN 20 | const DbgNanTest& getNanTestBuffers(); 21 | #endif 22 | } -------------------------------------------------------------------------------- /Whisper/ML/Device.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "../D3D/sGpuInfo.h" 4 | #include "LookupTables.h" 5 | #include "DbgNanTest.h" 6 | 7 | namespace DirectCompute 8 | { 9 | struct Device 10 | { 11 | CComPtr device; 12 | CComPtr context; 13 | 14 | std::vector> shaders; 15 | CComPtr smallCb; 16 | sGpuInfo gpuInfo; 17 | LookupTables lookupTables; 18 | #if DBG_TEST_NAN 19 | DbgNanTest nanTestBuffers; 20 | #endif 21 | 22 | HRESULT create( uint32_t flags, const std::wstring& adapter ); 23 | HRESULT createClone( const Device& source ); 24 | void destroy(); 25 | 26 | class ThreadSetupRaii 27 | { 28 | bool setup; 29 | public: 30 | ThreadSetupRaii( const Device* dev ); 31 | ~ThreadSetupRaii(); 32 | ThreadSetupRaii( ThreadSetupRaii&& that ) noexcept 33 | { 34 | setup = that.setup; 35 | that.setup = false; 36 | } 37 | ThreadSetupRaii( const ThreadSetupRaii& ) = delete; 38 | void operator=( const ThreadSetupRaii& ) = delete; 39 | }; 40 | 41 | ThreadSetupRaii setForCurrentThread() const 42 | { 43 | return ThreadSetupRaii{ this }; 44 | } 45 | }; 46 | } -------------------------------------------------------------------------------- /Whisper/ML/LookupTables.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../D3D/device.h" 3 | 4 | namespace DirectCompute 5 | { 6 | class LookupTables 7 | { 8 | CComPtr m_gelu, m_exponent; 9 | 10 | public: 11 | 12 | HRESULT create(); 13 | HRESULT createClone( const
LookupTables& source ); 14 | void clear(); 15 | ID3D11ShaderResourceView* gelu() const { return m_gelu; } 16 | ID3D11ShaderResourceView* exponent() const { return m_exponent; } 17 | 18 | __m128i getMemoryUsage() const; 19 | }; 20 | 21 | const LookupTables& lookupTables(); 22 | } -------------------------------------------------------------------------------- /Whisper/ML/LookupTablesData.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace DirectCompute 6 | { 7 | struct LookupTablesData 8 | { 9 | std::array gelu; 10 | std::array exponent; 11 | 12 | LookupTablesData(); 13 | }; 14 | } -------------------------------------------------------------------------------- /Whisper/ML/Reshaper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Tensor.h" 3 | 4 | namespace DirectCompute 5 | { 6 | // This class reshapes some of the model’s tensors, immediately after they’re loaded. 7 | // That feature is used on all AMD GPUs.
8 | class Reshaper 9 | { 10 | CComPtr constantBuffer; 11 | HRESULT createConstants(); 12 | 13 | public: 14 | ~Reshaper(); 15 | HRESULT makePanels( Tensor& tensor, eDataType dataType ); 16 | }; 17 | } -------------------------------------------------------------------------------- /Whisper/ML/TempBuffers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "TensorGpuViews.h" 3 | 4 | namespace DirectCompute 5 | { 6 | class TempBuffers 7 | { 8 | class Buffer : public TensorGpuViews 9 | { 10 | size_t capacity = 0; 11 | 12 | public: 13 | 14 | void clear() 15 | { 16 | TensorGpuViews::clear(); 17 | capacity = 0; 18 | } 19 | 20 | HRESULT resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory ); 21 | 22 | size_t getCapacity() const { return capacity; } 23 | }; 24 | 25 | Buffer m_fp16; 26 | Buffer m_fp16_2; 27 | Buffer m_fp32; 28 | 29 | public: 30 | 31 | const TensorGpuViews& fp16( size_t countElements, bool zeroMemory = false ); 32 | const TensorGpuViews& fp16_2( size_t countElements, bool zeroMemory = false ); 33 | const TensorGpuViews& fp32( size_t countElements, bool zeroMemory = false ); 34 | 35 | void clear() 36 | { 37 | m_fp16.clear(); 38 | m_fp16_2.clear(); 39 | m_fp32.clear(); 40 | } 41 | 42 | __m128i getMemoryUse() const; 43 | }; 44 | } -------------------------------------------------------------------------------- /Whisper/ML/TensorEx.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "Tensor.h" 3 | 4 | namespace DirectCompute 5 | { 6 | // A tensor which supports dynamic updates from CPU, or downloads from VRAM to system RAM 7 | class TensorEx : public Tensor 8 | { 9 | protected: 10 | CComPtr buffer; 11 | CComPtr stagingBuffer; 12 | 13 | HRESULT getViewSize( uint32_t& cbElement, uint32_t& countElements ) const; 14 | 15 | public: 16 | 17 | HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData ); 18 |
HRESULT create( eDataType type, eBufferUse usage, const std::array& sizeElements ); 19 | 20 | HRESULT download( void* rdi, size_t cb ) const; 21 | 22 | HRESULT download( void* rdi ) const; 23 | 24 | template 25 | HRESULT download( std::vector& vec ) const 26 | { 27 | uint32_t cbElement, numElements; 28 | CHECK( getViewSize( cbElement, numElements ) ); 29 | // The download below writes cbElement * numElements bytes into a vector of numElements items; refuse element types narrower than cbElement, they would overflow the vector 30 | if( sizeof( *vec.data() ) < cbElement ) 31 | return E_INVALIDARG; 32 | 33 | try 34 | { 35 | vec.resize( numElements ); 36 | } 37 | catch( const std::bad_alloc& ) 38 | { 39 | return E_OUTOFMEMORY; 40 | } 41 | 42 | return download( vec.data(), (size_t)cbElement * numElements ); 43 | } 44 | }; 45 | } -------------------------------------------------------------------------------- /Whisper/ML/TensorGpuViews.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "TensorGpuViews.h" 3 | using namespace DirectCompute; 4 | 5 | HRESULT TensorGpuViews::create( ID3D11Buffer* gpuBuffer, DXGI_FORMAT format, size_t countElements, bool makeUav ) 6 | { 7 | srv = nullptr; 8 | uav = nullptr; 9 | 10 | if( countElements > UINT_MAX ) /* the D3D11 view descriptors below take UINT element counts */ 11 | return DISP_E_OVERFLOW; 12 | 13 | CD3D11_SHADER_RESOURCE_VIEW_DESC viewDesc{ D3D11_SRV_DIMENSION_BUFFER, format, 0, (UINT)countElements }; 14 | CHECK( device()->CreateShaderResourceView( gpuBuffer, &viewDesc, &srv ) ); 15 | 16 | if( makeUav ) 17 | { 18 | CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, format , 0, (UINT)countElements }; 19 | CHECK( device()->CreateUnorderedAccessView( gpuBuffer, &uavDesc, &uav ) ); 20 | } 21 | 22 | return S_OK; 23 | } -------------------------------------------------------------------------------- /Whisper/ML/TensorGpuViews.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include "../D3D/device.h" 4 | 5 | namespace DirectCompute 6 | { 7 | class TensorGpuViews 8 | { 9 | protected: 10 | CComPtr srv; 11 | CComPtr uav; 12 | 13 | public: 14 | 15 | operator ID3D11ShaderResourceView*
( ) const { return srv; } 16 | operator ID3D11UnorderedAccessView* ( ) const { return uav; } 17 | 18 | HRESULT create( ID3D11Buffer* buffer, DXGI_FORMAT format, size_t countElements, bool makeUav ); 19 | 20 | void clear() 21 | { 22 | srv = nullptr; 23 | uav = nullptr; 24 | } 25 | 26 | void setGpuViews( ID3D11ShaderResourceView* read, ID3D11UnorderedAccessView* write = nullptr ) 27 | { 28 | srv = read; 29 | uav = write; 30 | } 31 | }; 32 | } -------------------------------------------------------------------------------- /Whisper/ML/mlUtils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace DirectCompute 4 | { 5 | // Update the small dynamic constant buffer 6 | ID3D11Buffer* __vectorcall updateSmallCb( __m128i cbData ); 7 | 8 | // Fill the tensor with either 0.0 or NaN values 9 | void zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, bool fillWithNaN = false ); 10 | 11 | // Fill the complete UAV with NaN values 12 | void fillTensorWithNaN( ID3D11UnorderedAccessView* uav ); 13 | 14 | // true when the tensor contains at least 1 NaN value 15 | bool scanTensorForNaN( ID3D11ShaderResourceView* tensor, uint32_t length ); 16 | 17 | // Create SRV on another device, reusing the resource 18 | HRESULT cloneResourceView( ID3D11ShaderResourceView* rsi, ID3D11ShaderResourceView** rdi ); 19 | } -------------------------------------------------------------------------------- /Whisper/ML/reshapedMultiply.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace DirectCompute 5 | { 6 | namespace ReshapedMultiply 7 | { 8 | constexpr uint32_t TILE_SIZE = 32; 9 | } 10 | } -------------------------------------------------------------------------------- /Whisper/ML/tensorOpsTests.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../source/ggml.h" 3 | 4 | namespace
DirectCompute 5 | { 6 | // void testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer ); 7 | void testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer ); 8 | void computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst ); 9 | 10 | void testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst ); 11 | void computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst ); 12 | 13 | void testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst ); 14 | void computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst ); 15 | } -------------------------------------------------------------------------------- /Whisper/ML/testUtils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../D3D/downloadBuffer.h" 3 | #include "../D3D/RenderDoc/renderDoc.h" 4 | #include 5 | #include 6 | 7 | // Funfact: this code written by ChatGPT 8 | namespace std 9 | { 10 | template<> 11 | struct hash> 12 | { 13 | size_t operator()( const array& arr ) const 14 | { 15 | size_t result = 0; 16 | for( uint32_t element : arr ) 17 | result = ( result * 31 ) ^ element; 18 | return result; 19 | } 20 | }; 21 | } 22 | 23 | namespace DirectCompute 24 | { 25 | struct sTensorDiff 26 | { 27 | // maximum( absolute( a - b ) ) 28 | float maxAbsDiff; 29 | // average( ( a - b )^2 ) 30 | float avgDiffSquared; 31 | size_t length; 32 | 33 | void print() const; 34 | void print( const char* what ) const; 35 | }; 36 | 37 | // Compute difference between 2 FP32 vectors 38 | sTensorDiff computeDiff( const float* a, const float* b, size_t length ); 39 | 40 | // Compute difference between 2 FP16 vectors 41 | sTensorDiff computeDiff( const uint16_t* a, const uint16_t* b, size_t length ); 42 |
43 | class Tensor; 44 | sTensorDiff computeDiff( const Tensor& a, const Tensor& b ); 45 | 46 | HRESULT dbgWriteBinaryFile( LPCTSTR fileName, const void* rsi, size_t cb ); 47 | 48 | // Print unique sizes of the two tensors 49 | class PrintUniqueTensorSizes 50 | { 51 | std::unordered_set> set; 52 | const char* const what; 53 | void printImpl( const std::array& a ); 54 | 55 | public: 56 | PrintUniqueTensorSizes( const char* w ) : what( w ) { } 57 | 58 | void print( const Tensor& lhs, const Tensor& rhs ); 59 | void print( const Tensor& lhs ); 60 | void print( const int* lhs, const int* rhs ); 61 | }; 62 | } -------------------------------------------------------------------------------- /Whisper/ML/testUtilsC.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" 5 | { 6 | #endif 7 | void printUniqueTensorSize( const char* name, const int* lhs, const int* rhs ); 8 | #ifdef __cplusplus 9 | } 10 | #endif -------------------------------------------------------------------------------- /Whisper/Readme.txt: -------------------------------------------------------------------------------- 1 | This C++ project builds a DLL which actually does the heavy lifting of this project. 2 | 3 | It implements the ML model, handles multimedia files with Media Foundation, captures audio (also with MF), does voice activity detection (custom code running on CPU), and a few smaller things. 4 | 5 | The code requires C++20, and has only been tested with Visual Studio 2022. 6 | 7 | When running the pure GPGPU model, the DLL requires the SSE 4.1 instruction set. 8 | 9 | When running a hybrid model, the DLL requires AVX1, FMA3, F16C, and BMI1 instruction set extensions.
-------------------------------------------------------------------------------- /Whisper/Resource.rc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Whisper/Resource.rc -------------------------------------------------------------------------------- /Whisper/Utils/CpuProfiler.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "CpuProfiler.h" 3 | 4 | namespace 5 | { 6 | using namespace Whisper; 7 | 8 | inline int64_t qpcNow() 9 | { 10 | int64_t res; 11 | QueryPerformanceCounter( (LARGE_INTEGER*)&res ); 12 | return res; 13 | } 14 | 15 | class CpuTimescale 16 | { 17 | uint64_t frequency = 0; 18 | const int64_t tscStart; 19 | const int64_t qpcStart; 20 | 21 | uint64_t computeTscFrequency(); 22 | 23 | public: 24 | 25 | CpuTimescale() : 26 | tscStart( tscNow() ), 27 | qpcStart( qpcNow() ) 28 | { } 29 | 30 | inline uint64_t computeTicks( uint64_t tsc ) 31 | { 32 | uint64_t freq = frequency; 33 | if( freq == 0 ) 34 | freq = computeTscFrequency(); 35 | 36 | return makeTime( tsc, freq ); 37 | } 38 | }; 39 | 40 | uint64_t __declspec( noinline ) CpuTimescale::computeTscFrequency() 41 | { 42 | // Measure how far both clocks advanced since construction. QPC is monotonic, so re-sample until at least one QPC tick elapsed; otherwise the division below would be by zero. 43 | int64_t tsc, qpc; 44 | do 45 | { 46 | tsc = tscNow() - tscStart; 47 | qpc = qpcNow() - qpcStart; 48 | } while( qpc <= 0 ); 49 | 50 | uint64_t qpcFreq; 51 | QueryPerformanceFrequency( (LARGE_INTEGER*)&qpcFreq ); 52 | 53 | // Seconds = qpc / qpcFreq 54 | // ticks per second = tsc / seconds = tsc * qpcFreq / qpc, rounded to nearest 55 | // The plain product tsc * qpcFreq overflows 64 bits when this first runs long after startup, e.g. after ~10 minutes with a 3 GHz TSC and a 10 MHz QPC. 56 | // Split tsc into quotient and remainder of the division by qpc: the quotient term is exact in integers, the remainder term is below qpcFreq and safe in FP64. 57 | const uint64_t tscElapsed = (uint64_t)tsc; 58 | const uint64_t qpcElapsed = (uint64_t)qpc; 59 | const uint64_t whole = tscElapsed / qpcElapsed; 60 | const uint64_t rem = tscElapsed % qpcElapsed; 61 | const uint64_t res = whole * qpcFreq + (uint64_t)( (double)rem * (double)qpcFreq / (double)qpcElapsed + 0.5 ); 62 | frequency = res; 63 | const double GHz = (double)(int64_t)res * 1.0E-9; 64 | logDebug( u8"Computed CPU base frequency: %g GHz", GHz ); 65 | return res; 66 | } 67 | 68 | static CpuTimescale timescale; 69 | } 70 | 71 | uint64_t Whisper::ticksFromTsc( uint64_t tscDiff ) 72 | { 73 | return timescale.computeTicks( tscDiff
); 74 | } -------------------------------------------------------------------------------- /Whisper/Utils/CpuProfiler.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace Whisper 4 | { 5 | // Get current time in CPU clock 6 | // More specifically, each CPU core has a timestamp counter which runs at CPU's base frequency, regardless of the frequency scaling of that core. 7 | inline int64_t tscNow() 8 | { 9 | return __rdtsc(); 10 | } 11 | 12 | // Scale the time interval from CPU time stamp counter clock into 100-nanosecond ticks, rounding to nearest 13 | uint64_t ticksFromTsc( uint64_t tscDiff ); 14 | 15 | class CpuProfiler 16 | { 17 | const int64_t started = tscNow(); 18 | 19 | public: 20 | 21 | uint64_t elapsed() const 22 | { 23 | return ticksFromTsc( (uint64_t)( tscNow() - started ) ); 24 | } 25 | }; 26 | } -------------------------------------------------------------------------------- /Whisper/Utils/DelayExecution.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "DelayExecution.h" 3 | 4 | namespace 5 | { 6 | constexpr bool useHighRezTimer = false; 7 | 8 | constexpr int64_t sleepMicroseconds = 200; 9 | 10 | inline HRESULT sleepImpl( HANDLE timer ) 11 | { 12 | constexpr int64_t sleepTicks = sleepMicroseconds * 10; 13 | 14 | LARGE_INTEGER li; 15 | // Negative values indicate relative time 16 | li.QuadPart = -sleepTicks; 17 | if( !SetWaitableTimerEx( timer, &li, 0, nullptr, nullptr, nullptr, 0 ) ) 18 | return getLastHr(); 19 | const DWORD res = WaitForSingleObject( timer, 50 ); 20 | if( res == WAIT_OBJECT_0 ) 21 | return S_OK; 22 | if( res == WAIT_FAILED ) 23 | return getLastHr(); 24 | return E_FAIL; 25 | } 26 | } 27 | 28 | void DelayExecution::sleepOnTheTimer( const DelayExecution& delay ) 29 | { 30 | HRESULT hr = sleepImpl( delay.timer ); 31 | if( SUCCEEDED( hr ) ) 32 | return; 33 | logWarningHr( hr,
u8"DelayExecution.sleepOnTheTimer" ); 34 | } 35 | 36 | void DelayExecution::spinWait( const DelayExecution& ) 37 | { 38 | for( size_t i = 0; i < 1024; i++ ) 39 | _mm_pause(); 40 | } 41 | 42 | void DelayExecution::sleep( const DelayExecution& ) 43 | { 44 | Sleep( 0 ); 45 | } 46 | 47 | DelayExecution::DelayExecution() 48 | { 49 | if constexpr( useHighRezTimer ) 50 | { 51 | constexpr DWORD flags = CREATE_WAITABLE_TIMER_HIGH_RESOLUTION; 52 | HANDLE h = CreateWaitableTimerEx( nullptr, nullptr, flags, TIMER_ALL_ACCESS ); 53 | if( nullptr != h ) 54 | { 55 | timer.Attach( h ); 56 | pfn = &sleepOnTheTimer; 57 | return; 58 | } 59 | 60 | const HRESULT hr = getLastHr(); 61 | logWarningHr( hr, u8"CreateWaitableTimerEx" ); 62 | } 63 | 64 | pfn = &spinWait; 65 | // pfn = &sleep; 66 | } -------------------------------------------------------------------------------- /Whisper/Utils/DelayExecution.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | // Utility class implementing a high-resolution Sleep() function 5 | class DelayExecution 6 | { 7 | using pfnDelay = void( * )( const DelayExecution& de ); 8 | pfnDelay pfn = nullptr; 9 | CHandle timer; 10 | 11 | static void sleepOnTheTimer( const DelayExecution& delay ); 12 | static void spinWait( const DelayExecution& ); 13 | static void sleep( const DelayExecution& ); 14 | 15 | public: 16 | DelayExecution(); 17 | DelayExecution( const DelayExecution& ) = delete; 18 | ~DelayExecution() = default; 19 | 20 | void delay() const 21 | { 22 | pfn( *this ); 23 | } 24 | }; -------------------------------------------------------------------------------- /Whisper/Utils/GpuProfiler.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/Whisper/Utils/GpuProfiler.h --------------------------------------------------------------------------------
/Whisper/Utils/GpuProfilerSimple.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../D3D/device.h" 3 | #include "DelayExecution.h" 4 | 5 | namespace DirectCompute 6 | { 7 | // A simple profiler which doesn't collect anything, used to measure time it took to load the model 8 | class GpuProfilerSimple 9 | { 10 | DelayExecution delay; /* presumably used to wait between polls of the GPU queries; TODO confirm in GpuProfilerSimple.cpp */ 11 | CComPtr disjoint, begin, end; /* names suggest a timestamp-disjoint query plus begin/end timestamp queries; verify against GpuProfilerSimple.cpp */ 12 | public: 13 | HRESULT create(); 14 | HRESULT time( uint64_t& rdi ) const; /* presumably writes the measured duration into rdi; units are not visible here, TODO confirm */ 15 | }; 16 | } -------------------------------------------------------------------------------- /Whisper/Utils/LZ4/LICENSE: -------------------------------------------------------------------------------- 1 | LZ4 Library 2 | Copyright (c) 2011-2020, Yann Collet 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, this 12 | list of conditions and the following disclaimer in the documentation and/or 13 | other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /Whisper/Utils/Logger.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../API/loggerApi.h" 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void logError( const char8_t* pszFormat, ... ); 9 | void logError16( const wchar_t* pszFormat, ... ); 10 | void logErrorHr( long hr, const char8_t* pszFormat, ... ); 11 | void logWarning( const char8_t* pszFormat, ... ); 12 | void logWarning16( const wchar_t* pszFormat, ... ); 13 | void logWarningHr( long hr, const char8_t* pszFormat, ... ); 14 | void logInfo( const char8_t* pszFormat, ... ); 15 | void logInfo16( const wchar_t* pszFormat, ... ); 16 | void logDebug( const char8_t* pszFormat, ... ); 17 | void logDebug16( const wchar_t* pszFormat, ... 
); 18 | 19 | bool willLogMessage( Whisper::eLogLevel lvl ); 20 | 21 | #ifdef __cplusplus 22 | } 23 | #endif 24 | -------------------------------------------------------------------------------- /Whisper/Utils/MurmurHash3.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | void MurmurHash3_x86_32( const void* key, int len, uint32_t seed, void* out ); 5 | void MurmurHash3_x86_128( const void* key, int len, uint32_t seed, void* out ); 6 | void MurmurHash3_x64_128( const void* key, int len, uint32_t seed, void* out ); 7 | 8 | #include 9 | 10 | // Traits class for `CAtlMap` which does not copy nor owns these strings 11 | struct StringPtrTraits : public ATL::CDefaultElementTraits 12 | { 13 | using INARGTYPE = const char*; 14 | 15 | static inline bool CompareElements( const char* a, const char* b ) 16 | { 17 | return 0 == strcmp( a, b ); 18 | } 19 | 20 | static inline int CompareElementsOrdered( const char* a, const char* b ) 21 | { 22 | return strcmp( a, b ); 23 | } 24 | 25 | static inline ULONG Hash( const char* ptr ) 26 | { 27 | uint32_t hash = UINT_MAX; 28 | if( nullptr != ptr ) 29 | { 30 | const int len = (int)strlen( ptr ); 31 | constexpr uint32_t seed = 0; 32 | MurmurHash3_x86_32( ptr, len, seed, &hash ); 33 | } 34 | return hash; 35 | } 36 | }; -------------------------------------------------------------------------------- /Whisper/Utils/ReadStream.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../ComLightLib/streams.h" 3 | #include "../ComLightLib/comLightServer.h" 4 | #define WIN32_LEAN_AND_MEAN 5 | #include 6 | 7 | class ReadStream : public ComLight::ObjectRoot 8 | { 9 | CAtlFile file; 10 | // TODO: implement a buffer in this class, at least 256kb 11 | 12 | HRESULT COMLIGHTCALL read( void* lpBuffer, int nNumberOfBytesToRead, int& lpNumberOfBytesRead ) override final 13 | { 14 | return file.Read( lpBuffer, 
(DWORD)nNumberOfBytesToRead, *(DWORD*)&lpNumberOfBytesRead ); 15 | } 16 | HRESULT COMLIGHTCALL seek( int64_t offset, ComLight::eSeekOrigin origin ) override final 17 | { 18 | return file.Seek( offset, (uint8_t)origin ); 19 | } 20 | HRESULT COMLIGHTCALL getPosition( int64_t& position ) override final 21 | { 22 | return file.GetPosition( *(ULONGLONG*)&position ); 23 | } 24 | HRESULT COMLIGHTCALL getLength( int64_t& length ) override final 25 | { 26 | return file.GetSize( *(ULONGLONG*)&length ); 27 | } 28 | 29 | public: 30 | 31 | HRESULT open( const wchar_t* path ) 32 | { 33 | if( file ) 34 | return HRESULT_CODE( ERROR_ALREADY_INITIALIZED ); 35 | return file.Create( path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_SEQUENTIAL_SCAN ); 36 | } 37 | }; -------------------------------------------------------------------------------- /Whisper/Utils/Trace/TraceStructures.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "TraceStructures.h" 3 | using namespace Tracing; 4 | 5 | uint64_t sTraceItem::buffer( uint64_t off, size_t length, eDataType type ) 6 | { 7 | payloadOffset = off; 8 | payloadSize = length * DirectCompute::elementSize( type ); 9 | *(uint64_t*)( &size[ 0 ] ) = length; 10 | *(uint64_t*)( &size[ 2 ] ) = 0; 11 | _mm_storeu_si128( ( __m128i* )stride.data(), _mm_setzero_si128() ); 12 | itemType = eItemType::Buffer; 13 | dataType = type; 14 | return payloadSize; 15 | } 16 | 17 | uint64_t sTraceItem::tensor( uint64_t off, __m128i ne, __m128i nb, eDataType type ) 18 | { 19 | payloadOffset = off; 20 | _mm_storeu_si128( ( __m128i* )size.data(), ne ); 21 | _mm_storeu_si128( ( __m128i* )stride.data(), nb ); 22 | uint64_t count = 1; 23 | for( uint32_t i : size ) 24 | if( i != 0 ) 25 | count *= i; 26 | 27 | payloadSize = count * DirectCompute::elementSize( type ); 28 | itemType = eItemType::Tensor; 29 | dataType = type; 30 | return payloadSize; 31 | } 
-------------------------------------------------------------------------------- /Whisper/Utils/parallelFor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace Whisper 4 | { 5 | // A callback to offload to the thread pool 6 | using pfnParallelForCallback = HRESULT( * )( int ith, void* ctx ) noexcept; 7 | 8 | // A simple parallel for implementation; Windows includes a decent thread pool since Vista (2006) 9 | HRESULT parallelFor( pfnParallelForCallback pfn, int threadsCount, void* ctx ); 10 | 11 | // Use this version when you wanna use the thread pool repeatedly, for the same work. 12 | // This class caches native work handle, saving a couple of WinAPI calls. 13 | class alignas( 64 ) ThreadPoolWork 14 | { 15 | PTP_WORK work = nullptr; 16 | 17 | // We want these volatile fields in another cache line from the rest of the data of this class. 18 | // threadIndex field is concurrently modified by different CPU cores, and these cache coherency protocols are slow. 19 | // OTOH, work and callback fields of this class only change when created / destroyed, that cache line is shared by CPU cores without any performance penalty. 
20 | alignas( 64 ) volatile long threadIndex = 0; 21 | volatile HRESULT status = E_UNEXPECTED; 22 | 23 | static void __stdcall callbackStatic( PTP_CALLBACK_INSTANCE Instance, PVOID pv, PTP_WORK Work ); 24 | 25 | protected: 26 | virtual HRESULT threadPoolCallback( int ith ) noexcept = 0; 27 | 28 | public: 29 | ThreadPoolWork() = default; 30 | ThreadPoolWork( const ThreadPoolWork& ) = delete; 31 | 32 | ~ThreadPoolWork(); 33 | 34 | HRESULT create(); 35 | 36 | HRESULT parallelFor( int threadsCount ) noexcept; 37 | }; 38 | } -------------------------------------------------------------------------------- /Whisper/Whisper/DecoderInputBuffers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../ML/Tensor.h" 3 | 4 | namespace DirectCompute 5 | { 6 | // A dynamic buffer 7 | class DecoderInputBuffers 8 | { 9 | CComPtr embd; 10 | uint32_t m_size = 0; 11 | uint32_t m_capacity = 0; 12 | 13 | public: 14 | 15 | void resize( uint32_t size ); 16 | 17 | // Create 1D tensor with R32_UINT elements, upload the source data 18 | Tensor embedding( const int* rsi ) const; 19 | 20 | void clear(); 21 | 22 | __m128i getMemoryUse() const 23 | { 24 | size_t i = m_capacity; 25 | i *= sizeof( uint32_t ); 26 | return _mm_set_epi64x( (int64_t)i, 0 ); 27 | } 28 | 29 | HRESULT zeroMemory() const; 30 | }; 31 | } -------------------------------------------------------------------------------- /Whisper/Whisper/DecoderResultBuffer.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "DecoderResultBuffer.h" 3 | #include "../D3D/MappedResource.h" 4 | using namespace DirectCompute; 5 | 6 | void DecoderResultBuffer::copyFromVram( const Tensor& rsi ) 7 | { 8 | ID3D11ShaderResourceView* srv = rsi; 9 | if( nullptr == srv ) 10 | throw OLE_E_BLANK; 11 | if( !rsi.isContinuous() ) 12 | throw E_INVALIDARG; 13 | 14 | const uint32_t len = rsi.countElements(); 15 | if( len > m_capacity 
) 16 | { 17 | buffer = nullptr; 18 | CD3D11_BUFFER_DESC desc{ len * 4, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ }; 19 | check( device()->CreateBuffer( &desc, nullptr, &buffer ) ); 20 | m_capacity = len; 21 | } 22 | 23 | CComPtr source; 24 | srv->GetResource( &source ); 25 | // Coordinates of a box are in bytes for buffers 26 | D3D11_BOX box; 27 | store16( &box, _mm_setr_epi32( 0, 0, 0, (int)( len * 4 ) ) ); 28 | *(uint64_t*)&box.bottom = 0x100000001ull; 29 | context()->CopySubresourceRegion( buffer, 0, 0, 0, 0, source, 0, &box ); 30 | m_size = len; 31 | } 32 | 33 | void DecoderResultBuffer::copyToVector( std::vector& vec ) const 34 | { 35 | vec.resize( m_size ); 36 | if( vec.empty() ) 37 | throw OLE_E_BLANK; 38 | 39 | MappedResource mapped; 40 | check( mapped.map( buffer, true ) ); 41 | memcpy( vec.data(), mapped.data(), (size_t)4 * m_size ); 42 | } 43 | 44 | void DecoderResultBuffer::clear() 45 | { 46 | buffer = nullptr; 47 | m_size = m_capacity = 0; 48 | } -------------------------------------------------------------------------------- /Whisper/Whisper/DecoderResultBuffer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../ML/Tensor.h" 3 | 4 | namespace DirectCompute 5 | { 6 | class DecoderResultBuffer 7 | { 8 | CComPtr buffer; 9 | uint32_t m_size = 0; 10 | uint32_t m_capacity = 0; 11 | 12 | public: 13 | 14 | void copyFromVram( const Tensor& rsi ); 15 | 16 | void copyToVector( std::vector& vec ) const; 17 | 18 | uint32_t size() const 19 | { 20 | return m_size; 21 | } 22 | void clear(); 23 | 24 | __m128i getMemoryUse() const 25 | { 26 | return bufferMemoryUsage( buffer ); 27 | } 28 | }; 29 | } -------------------------------------------------------------------------------- /Whisper/Whisper/KeyValueBuffers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../ML/Tensor.h" 3 | 4 | namespace DirectCompute 5 | { 6 | // FP16 buffer for 
self-attention and cross-attention layers 7 | class AttentionBuffer 8 | { 9 | CComPtr buffer; 10 | uint32_t m_size = 0; 11 | 12 | public: 13 | // Create buffer for the specified count of elements 14 | void resize( uint32_t size ); 15 | 16 | // Create an 1D tensor which references a slice of that buffer 17 | Tensor view( uint32_t length, uint32_t offset ) const; 18 | 19 | void clear() 20 | { 21 | buffer = nullptr; 22 | m_size = 0; 23 | } 24 | 25 | ID3D11Buffer* getBuffer() const { return buffer; } 26 | 27 | uint32_t getSize() const { return m_size; } 28 | 29 | HRESULT zeroMemory() const; 30 | }; 31 | 32 | struct KeyValueBuffers 33 | { 34 | AttentionBuffer keys, values; 35 | 36 | void resize( uint32_t size ); 37 | 38 | void clear() 39 | { 40 | keys.clear(); 41 | values.clear(); 42 | } 43 | 44 | __m128i getMemoryUse() const 45 | { 46 | size_t i = keys.getSize(); 47 | i += values.getSize(); 48 | i *= sizeof( uint16_t ); 49 | return setHigh_size( (int64_t)i ); // They both are in VRAM 50 | } 51 | 52 | HRESULT zeroMemory() const; 53 | }; 54 | } -------------------------------------------------------------------------------- /Whisper/Whisper/Languages.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../../ComLightLib/comLightCommon.h" 3 | 4 | namespace Whisper 5 | { 6 | int lookupLanguageId( const char* code ); 7 | int lookupLanguageId( uint32_t key ); 8 | 9 | const char* lookupLanguageName( const char* code ); 10 | 11 | int COMLIGHTCALL getLanguageId( const char* lang ); 12 | } -------------------------------------------------------------------------------- /Whisper/Whisper/MelInputTensor.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../ML/TensorEx.h" 3 | #include "sEncodeParams.h" 4 | #include "iSpectrogram.h" 5 | 6 | namespace DirectCompute 7 | { 8 | // Input tensor in VRAM, in a dynamic FP32 buffer 9 | class MelInputTensor : public 
TensorEx 10 | { 11 | uint32_t capacity; 12 | 13 | public: 14 | 15 | HRESULT create( Whisper::iSpectrogram& spectrogram, const sEncodeParams& encParams ); 16 | 17 | __m128i getMemoryUse() const 18 | { 19 | return setHigh_size( (size_t)capacity * 4 ); 20 | } 21 | }; 22 | } -------------------------------------------------------------------------------- /Whisper/Whisper/ModelLoader.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "ModelBuffers.h" 3 | #include 4 | 5 | namespace DirectCompute 6 | { 7 | struct ModelLoader 8 | { 9 | ModelLoader( int encoderLayers, int decoderLayers ); 10 | 11 | void add( const ggml_tensor* ggml, Tensor& gpu ); 12 | 13 | void add( const ggml_tensor* w, const ggml_tensor* b, TensorPair& gpu ) 14 | { 15 | add( w, gpu.w ); 16 | add( b, gpu.b ); 17 | } 18 | 19 | bool tryLoad( const ggml_tensor* ggml ); 20 | 21 | ModelBuffers& model; 22 | 23 | private: 24 | 25 | Tensor* lookup( const ggml_tensor* ggml ) const; 26 | 27 | std::map map; 28 | }; 29 | } -------------------------------------------------------------------------------- /Whisper/Whisper/Spectrogram.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "WhisperModel.h" 3 | #include "iSpectrogram.h" 4 | #include "audioConstants.h" 5 | 6 | namespace Whisper 7 | { 8 | struct iAudioBuffer; 9 | 10 | // This implementation of iSpectrogram interface converts complete audio into MEL spectrogram 11 | // Used for unbuffered audio, and capture: iContext.runFull and runCapture methods. 
12 | class Spectrogram: public iSpectrogram 13 | { 14 | uint32_t length = 0; 15 | static constexpr uint32_t mel = N_MEL; 16 | std::vector data; 17 | std::vector stereo; 18 | 19 | HRESULT makeBuffer( size_t off, size_t len, const float** buffer, size_t& stride ) noexcept override final 20 | { 21 | if( off + len > length ) 22 | return E_BOUNDS; 23 | *buffer = &data[ off ]; 24 | stride = length; 25 | return S_OK; 26 | } 27 | 28 | class MelContext; 29 | 30 | HRESULT copyStereoPcm( size_t offset, size_t length, std::vector& buffer ) const override final; 31 | 32 | public: 33 | size_t getLength() const noexcept override final 34 | { 35 | return length; 36 | } 37 | HRESULT pcmToMel( const iAudioBuffer* buffer, const Filters& filters, int threads = 1 ); 38 | 39 | size_t memoryUsage() const 40 | { 41 | return data.size() * 4; 42 | } 43 | }; 44 | 45 | // average the fabs of the signal 46 | void computeSignalEnergy( std::vector& result, const iAudioBuffer* buffer, int n_samples_per_half_window ); 47 | } -------------------------------------------------------------------------------- /Whisper/Whisper/TranscribeResult.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../API/iTranscribeResult.cl.h" 3 | #include "../ComLightLib/comLightServer.h" 4 | 5 | namespace Whisper 6 | { 7 | class TranscribeResult : public ComLight::ObjectRoot 8 | { 9 | HRESULT COMLIGHTCALL getSize( sTranscribeLength& rdi ) const noexcept override final 10 | { 11 | rdi.countSegments = (uint32_t)segments.size(); 12 | rdi.countTokens = (uint32_t)tokens.size(); 13 | return S_OK; 14 | } 15 | const sSegment* COMLIGHTCALL getSegments() const noexcept override final 16 | { 17 | if( !segments.empty() ) 18 | return segments.data(); 19 | return nullptr; 20 | } 21 | const sToken* COMLIGHTCALL getTokens() const noexcept override final 22 | { 23 | if( !tokens.empty() ) 24 | return tokens.data(); 25 | return nullptr; 26 | } 27 | 28 | public: 29 | 
std::vector segments; 30 | std::vector tokens; 31 | std::vector segmentsText; 32 | }; 33 | 34 | class TranscribeResultStatic : public ComLight::Object 35 | { 36 | uint32_t COMLIGHTCALL Release() override final 37 | { 38 | // When the ref.counter reaches zero, Object.Release() method calls `delete this`. 39 | // We don't want that for the aggregated object. 40 | // Instead we only decrement the ref.counter, but do not delete the object even when the counter reaches zero. 41 | return RefCounter::implRelease(); 42 | } 43 | }; 44 | } -------------------------------------------------------------------------------- /Whisper/Whisper/Vocabulary.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../../ComLightLib/streams.h" 3 | #include "../API/SpecialTokens.h" 4 | #include "../Utils/MurmurHash3.h" 5 | 6 | namespace Whisper 7 | { 8 | class Vocabulary 9 | { 10 | std::vector tokens; 11 | std::vector stringData; 12 | using THashMap = CAtlMap; 13 | THashMap idFromToken; 14 | 15 | void addExtra( int index, const char* format, int i ); 16 | 17 | void completeBuild(); 18 | public: 19 | Vocabulary(); 20 | 21 | int n_vocab = 51864; 22 | 23 | HRESULT load( ComLight::iReadStream* stm, int lengthInHeader ); 24 | 25 | using id = int; 26 | 27 | id token_eot = 50256; 28 | id token_sot = 50257; 29 | id token_prev = 50360; 30 | id token_solm = 50361; // ?? 
31 | id token_not = 50362; // no timestamps 32 | id token_beg = 50363; 33 | 34 | // available tasks 35 | static const id token_translate = 50358; 36 | static const id token_transcribe = 50359; 37 | 38 | bool is_multilingual() const 39 | { 40 | return n_vocab == 51865; 41 | } 42 | 43 | const char* string( int id ) const 44 | { 45 | if( id >= 0 && id < (int)tokens.size() ) 46 | return tokens[ id ]; 47 | return nullptr; 48 | } 49 | 50 | int findId( const char* token ) const; 51 | int findId( const std::string& token ) const 52 | { 53 | return findId( token.c_str() ); 54 | } 55 | 56 | size_t size() const 57 | { 58 | return tokens.size(); 59 | } 60 | 61 | void getSpecialTokens( SpecialTokens& rdi ) const; 62 | 63 | size_t getMemoryUse() const 64 | { 65 | return vectorMemoryUse( tokens ) + vectorMemoryUse( stringData ); 66 | } 67 | 68 | HRESULT tokenize( const std::string& text, std::vector& tokens ) const; 69 | }; 70 | } -------------------------------------------------------------------------------- /Whisper/Whisper/audioConstants.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace Whisper 5 | { 6 | // WHISPER_SAMPLE_RATE, 16 kHz 7 | constexpr uint32_t SAMPLE_RATE = 16000; 8 | // WHISPER_N_FFT, 25 milliseconds 9 | constexpr uint32_t FFT_SIZE = 400; 10 | // WHISPER_HOP_LENGTH, 10 milliseconds 11 | constexpr uint32_t FFT_STEP = 160; 12 | // WHISPER_N_MEL 13 | constexpr uint32_t N_MEL = 80; 14 | } -------------------------------------------------------------------------------- /Whisper/Whisper/iSpectrogram.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "audioConstants.h" 3 | 4 | namespace Whisper 5 | { 6 | struct alignas( 8 ) StereoSample 7 | { 8 | float left, right; 9 | }; 10 | 11 | __interface iSpectrogram 12 | { 13 | // Make a buffer with length * N_MEL floats, starting at the specified offset 14 | // An implementation of 
this interface may visualize the spectrogram, making pieces on demand 15 | HRESULT makeBuffer( size_t offset, size_t length, const float** buffer, size_t& stride ); 16 | 17 | // Apparently, the length unit is 160 input samples = 10 milliseconds of audio 18 | size_t getLength() const; 19 | 20 | // If the source data is stereo, copy the specified slice of the data into the provided vector 21 | HRESULT copyStereoPcm( size_t offset, size_t length, std::vector& buffer ) const; 22 | }; 23 | 24 | // RAII class to deal with iSpectrogram's makeBuffer method. 25 | // Throws exceptions when things fail. 26 | class MelBufferRaii 27 | { 28 | const float* pointer; 29 | size_t stride; 30 | public: 31 | 32 | HRESULT make( iSpectrogram& mel, size_t off, size_t len ) 33 | { 34 | return mel.makeBuffer( off, len, &pointer, stride ); 35 | } 36 | 37 | const float* operator[]( size_t idx ) const 38 | { 39 | assert( idx < N_MEL ); 40 | return pointer + idx * stride; 41 | } 42 | 43 | const BYTE* bytePtr() const { return (const BYTE*)pointer; } 44 | LONG strideBytes() const { return (LONG)stride * 4; } 45 | }; 46 | } -------------------------------------------------------------------------------- /Whisper/Whisper/loaderUtils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../../ComLightLib/streams.h" 3 | 4 | namespace Whisper 5 | { 6 | inline HRESULT readBytes( ComLight::iReadStream* stm, void* rdi, size_t cb ) 7 | { 8 | if( cb > INT_MAX ) 9 | return DISP_E_OVERFLOW; 10 | if( cb == 0 ) 11 | return S_FALSE; 12 | int n; 13 | CHECK( stm->read( rdi, (int)cb, n ) ); 14 | if( n != (int)cb ) 15 | return E_EOF; 16 | return S_OK; 17 | } 18 | 19 | template 20 | inline HRESULT readStruct( ComLight::iReadStream* stm, T& dest ) 21 | { 22 | return readBytes( stm, &dest, sizeof( T ) ); 23 | } 24 | } -------------------------------------------------------------------------------- /Whisper/Whisper/melSpectrogram.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "audioConstants.h" 3 | #include "WhisperModel.h" 4 | #include 5 | 6 | namespace Whisper 7 | { 8 | class HanningWindow 9 | { 10 | std::array hann; 11 | public: 12 | HanningWindow(); 13 | 14 | float operator[]( size_t i ) const 15 | { 16 | return hann[ i ]; 17 | } 18 | }; 19 | 20 | extern const HanningWindow s_hanning; 21 | 22 | class SpectrogramContext 23 | { 24 | const Filters& filters; 25 | static float* fftRecursion( float* temp, const float* const rsi, const size_t len ); 26 | std::unique_ptr tempBuffer; 27 | 28 | public: 29 | SpectrogramContext( const Filters& flt ); 30 | 31 | // First step of the MEL algorithm, and recursively compute the FFT 32 | void fft( std::array& rdi, const float* pcm, size_t length ); 33 | }; 34 | } -------------------------------------------------------------------------------- /Whisper/Whisper/sEncodeParams.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace DirectCompute 5 | { 6 | struct sEncodeParams 7 | { 8 | uint32_t n_ctx, n_mels, mel_offset; 9 | uint32_t layersCount, n_state, n_head; 10 | uint32_t n_audio_ctx, n_text_state, n_text_layer, n_text_ctx; 11 | }; 12 | 13 | struct sDecodeParams 14 | { 15 | uint32_t n_state, n_head; 16 | uint32_t n_ctx, n_past, M; 17 | uint32_t n_text_layer; 18 | uint32_t n_vocab; 19 | }; 20 | } -------------------------------------------------------------------------------- /Whisper/Whisper/sModelParams.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | namespace Whisper 3 | { 4 | // default hparams (Whisper tiny) 5 | struct sModelParams 6 | { 7 | int n_vocab = 51864; 8 | int n_audio_ctx = 1500; 9 | int n_audio_state = 384; 10 | int n_audio_head = 6; 11 | int n_audio_layer = 4; 12 | int n_text_ctx = 448; 13 | int n_text_state = 384; 14 | int n_text_head = 6; 15 
| int n_text_layer = 4; 16 | int n_mels = 80; 17 | int f16 = 1; 18 | }; 19 | } -------------------------------------------------------------------------------- /Whisper/Whisper/sTokenData.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | namespace Whisper 5 | { 6 | using whisper_token = int; 7 | 8 | struct sTokenData 9 | { 10 | whisper_token id; // token id 11 | whisper_token tid; // forced timestamp token id 12 | float p; // probability of the token 13 | float pt; // probability of the timestamp token 14 | float ptsum; // sum of probabilities of all timestamp tokens 15 | float vlen; // voice length of the token 16 | // token-level timestamp data 17 | // do not use if you haven't computed token-level timestamps 18 | int64_t t0; // start time of the token 19 | int64_t t1; // end time of the token 20 | }; 21 | 22 | 23 | } -------------------------------------------------------------------------------- /Whisper/Whisper/voiceActivityDetection.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include "audioConstants.h" 5 | 6 | namespace Whisper 7 | { 8 | class VAD 9 | { 10 | using cplx = std::complex; 11 | std::unique_ptr fft_signal; 12 | 13 | struct Feature 14 | { 15 | float energy; 16 | float F; 17 | float SFM; 18 | }; 19 | const Feature primThresh; 20 | static Feature defaultPrimaryThresholds(); 21 | 22 | struct State 23 | { 24 | Feature currThresh; 25 | Feature minFeature; 26 | Feature curr; 27 | 28 | uint32_t lastSpeech; 29 | float silenceRun; 30 | uint32_t i; 31 | }; 32 | State state; 33 | 34 | static inline void fft( cplx* buf, cplx* out, size_t n, size_t step ); 35 | void fft() const; 36 | 37 | static float computeEnergy( const float* rsi ); 38 | static float computeDominant( const cplx* spectrum ); 39 | static float computreSpectralFlatnessMeasure( const cplx* spectrum ); 40 | 41 | public: 42 | 43 | VAD(); 44 
| 45 | // When no speech is detected, returns 0 46 | // When speech is detected, returns sample position for the end of the speech 47 | size_t detect( const float* rsi, size_t length ); 48 | 49 | void clear(); 50 | 51 | static constexpr uint32_t FFT_POINTS = 256; 52 | static constexpr float FFT_STEP = (float)SAMPLE_RATE / (float)FFT_POINTS; 53 | }; 54 | } -------------------------------------------------------------------------------- /Whisper/modelFactory.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" 2 | #include "modelFactory.h" 3 | #include "API/iContext.cl.h" 4 | 5 | HRESULT COMLIGHTCALL Whisper::loadModel( const wchar_t* path, const sModelSetup& setup, const sLoadModelCallbacks* callbacks, iModel** pp ) 6 | { 7 | switch( setup.impl ) 8 | { 9 | case eModelImplementation::GPU: 10 | case eModelImplementation::Hybrid: 11 | return loadGpuModel( path, setup, callbacks, pp ); 12 | case eModelImplementation::Reference: 13 | if( 0 != setup.flags ) 14 | logWarning( u8"The reference model doesn’t currently use any flags, argument ignored" ); 15 | return loadReferenceCpuModel( path, pp ); 16 | } 17 | 18 | logError( u8"Unknown model implementation 0x%X", (int)setup.impl ); 19 | return E_INVALIDARG; 20 | } -------------------------------------------------------------------------------- /Whisper/modelFactory.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "API/sLoadModelCallbacks.h" 3 | #include "API/sModelSetup.h" 4 | 5 | namespace Whisper 6 | { 7 | struct iModel; 8 | 9 | HRESULT __stdcall loadGpuModel( const wchar_t* path, const sModelSetup& setup, const sLoadModelCallbacks* callbacks, iModel** pp ); 10 | 11 | HRESULT __stdcall loadReferenceCpuModel( const wchar_t* path, iModel** pp ); 12 | } -------------------------------------------------------------------------------- /Whisper/resource.h: 
-------------------------------------------------------------------------------- 1 | //{{NO_DEPENDENCIES}} 2 | // Microsoft Visual C++ generated include file. 3 | // Used by Resource.rc 4 | 5 | // Next default values for new objects 6 | // 7 | #ifdef APSTUDIO_INVOKED 8 | #ifndef APSTUDIO_READONLY_SYMBOLS 9 | #define _APS_NEXT_RESOURCE_VALUE 101 10 | #define _APS_NEXT_COMMAND_VALUE 40001 11 | #define _APS_NEXT_CONTROL_VALUE 1001 12 | #define _APS_NEXT_SYMED_VALUE 101 13 | #endif 14 | #endif 15 | -------------------------------------------------------------------------------- /Whisper/source.compat/Readme.txt: -------------------------------------------------------------------------------- 1 | The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_BOTH_VERSIONS macro in stdafx.h -------------------------------------------------------------------------------- /Whisper/source.compat/convertThings.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "../source/whisper.h" 3 | #include "../API/sFullParams.h" 4 | #include "../API/iTranscribeResult.cl.h" 5 | 6 | Whisper::sFullParams makeNewParams( const whisper_full_params& rsi ); 7 | 8 | whisper_full_params makeOldParams( const Whisper::sFullParams& rsi, Whisper::iContext* context ); 9 | 10 | HRESULT makeNewResults( whisper_context* ctx, Whisper::eResultFlags flags, Whisper::iTranscribeResult** pp ); -------------------------------------------------------------------------------- /Whisper/source.compat/ggmlMsvc.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "../source/ggml.h" 6 | 7 | __forceinline float _cvtsh_ss( uint16_t f16 ) 8 | { 9 | __m128i i = _mm_cvtsi32_si128( f16 ); 10 | __m128 f = _mm_cvtph_ps( i ); 11 | return _mm_cvtss_f32( f ); 12 | } 13 | 14 | __forceinline uint16_t _cvtss_sh( float f, 
int rounding ) 15 | { 16 | assert( 0 == rounding ); 17 | __m128 v = _mm_set_ss( f ); 18 | __m128i i = _mm_cvtps_ph( v, 0 ); 19 | return (uint16_t)(uint32_t)_mm_cvtsi128_si32( i ); 20 | } 21 | 22 | FILE* fopen_msvc( const char* filename, const char* mode ) 23 | { 24 | FILE* stream; 25 | errno_t err = fopen_s( &stream, filename, mode ); 26 | if( err == 0 ) 27 | return stream; 28 | return NULL; 29 | } 30 | 31 | #define fopen fopen_msvc 32 | 33 | #include "../ML/testUtilsC.h" 34 | 35 | #define __F16C__ 36 | #define __FMA__ 37 | #include "../source/ggml.c" -------------------------------------------------------------------------------- /Whisper/source/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Georgi Gerganov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Whisper/source/Readme.txt: -------------------------------------------------------------------------------- 1 | The code in this folder is dropped by the linker’s dead code elimination optimization pass, unless you change BUILD_BOTH_VERSIONS macro in stdafx.h -------------------------------------------------------------------------------- /Whisper/stdafx.cpp: -------------------------------------------------------------------------------- 1 | #include "stdafx.h" -------------------------------------------------------------------------------- /Whisper/whisper.def: -------------------------------------------------------------------------------- 1 | LIBRARY 2 | EXPORTS setupLogger 3 | EXPORTS loadModel 4 | EXPORTS initMediaFoundation 5 | EXPORTS findLanguageKeyW 6 | EXPORTS findLanguageKeyA 7 | EXPORTS getSupportedLanguages 8 | EXPORTS listGPUs -------------------------------------------------------------------------------- /WhisperNet/API/CaptureDeviceId.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using Whisper.Internal; 3 | 4 | namespace Whisper 5 | { 6 | /// Identifiers for an audio capture device 7 | public struct CaptureDeviceId 8 | { 9 | /// The display name is suitable for showing to the user, but might not be unique. 10 | public string displayName; 11 | 12 | /// Endpoint ID for an audio capture device.
13 | /// It uniquely identifies the device on the system, but is not a readable string.
14 | public string endpoint; 15 | 16 | internal CaptureDeviceId( in sCaptureDevice rsi ) 17 | { 18 | displayName = rsi.displayName ?? ""; 19 | endpoint = rsi.endpoint ?? throw new ApplicationException( "The device has no endpoint ID" ); 20 | } 21 | 22 | /// Returns a String which represents the object instance 23 | public override string ToString() => $"Capture device: \"{displayName}\""; 24 | } 25 | } -------------------------------------------------------------------------------- /WhisperNet/API/SpecialTokens.cs: -------------------------------------------------------------------------------- 1 | namespace Whisper 2 | { 3 | /// Special tokens defined in the model 4 | public readonly struct SpecialTokens 5 | { 6 | /// The end of a transcription 7 | public readonly int TranscriptionEnd; // token_eot 8 | /// Start of a transcription 9 | public readonly int TranscriptionStart; // token_sot 10 | /// Start-of-previous-text token (startofprev). It precedes the prompt, i.e. the text of earlier segments which the model receives as context for the current one.
11 | public readonly int PreviousWord; // token_prev 12 | /// Start-of-language-model token (startoflm) 13 | public readonly int SentenceStart; // token_solm 14 | /// The no-timestamps token (notimestamps); it tells the model not to emit timestamp tokens 15 | public readonly int Not; // token_not 16 | /// The first timestamp token; timestamp tokens are numbered starting from this value 17 | public readonly int TranscriptionBegin; // token_beg 18 | /// token_translate 19 | public readonly int TaskTranslate; 20 | /// token_transcribe 21 | public readonly int TaskTranscribe; 22 | } 23 | } -------------------------------------------------------------------------------- /WhisperNet/API/eCaptureStatus.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// Status of the voice capture 6 | [Flags] 7 | public enum eCaptureStatus: byte 8 | { 9 | /// Doing nothing 10 | None = 0, 11 | /// Capturing the audio 12 | Listening = 1, 13 | /// A voice is detected in the captured audio, recording 14 | Voice = 2, 15 | /// Transcribing a recorded piece of the audio 16 | Transcribing = 4, 17 | /// The computer is unable to transcribe the audio quickly enough,
18 | /// and the capture is dropping the incoming audio samples.
19 | Stalled = 0x80, 20 | } 21 | } -------------------------------------------------------------------------------- /WhisperNet/API/eGpuModelFlags.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// 6 | /// These flags affect the performance of compute shaders (which ones are faster depends on the GPU model),
7 | /// and VRAM usage (UseReshapedMatMul needs slightly more VRAM).
8 | ///
9 | [Flags] 10 | public enum eGpuModelFlags: uint 11 | { 12 | /// Equivalent to Wave32 | NoReshapedMatMul on Intel and nVidia GPUs,
13 | /// and Wave64 | UseReshapedMatMul on AMD GPUs
14 | None = 0, 15 | 16 | /// Use Wave32 version of compute shaders even on AMD GPUs 17 | /// Incompatible with 18 | Wave32 = 1, 19 | 20 | /// Use Wave64 version of compute shaders even on nVidia and Intel GPUs 21 | /// Incompatible with 22 | Wave64 = 2, 23 | 24 | /// Do not use reshaped matrix multiplication shaders on AMD GPUs 25 | /// Incompatible with 26 | NoReshapedMatMul = 4, 27 | 28 | /// Use reshaped matrix multiplication shaders even on nVidia and Intel GPUs 29 | /// Incompatible with 30 | UseReshapedMatMul = 8, 31 | 32 | /// Create GPU tensors in a way which allows sharing across D3D devices 33 | Cloneable = 0x10, 34 | } 35 | } -------------------------------------------------------------------------------- /WhisperNet/API/eLogLevel.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// Message log level 6 | public enum eLogLevel: byte 7 | { 8 | /// Error message 9 | Error = 0, 10 | /// Warning message 11 | Warning = 1, 12 | /// Informational message 13 | Info = 2, 14 | /// Debug message 15 | Debug = 3 16 | } 17 | 18 | /// A delegate to receive log messages from the library 19 | public delegate void pfnLogMessage( eLogLevel level, string message ); 20 | 21 | /// Log destination flags 22 | [Flags] 23 | public enum eLoggerFlags: byte 24 | { 25 | /// No special flags 26 | None = 0, 27 | 28 | /// In addition to calling the delegate, print messages to standard error 29 | UseStandardError = 1, 30 | 31 | /// Don’t format error codes into messages 32 | /// It’s recommended to use this flag in .NET.
33 | /// The standard library already formats these messages automatically, as needed.
34 | SkipFormatMessage = 2, 35 | } 36 | } -------------------------------------------------------------------------------- /WhisperNet/API/eModelImplementation.cs: -------------------------------------------------------------------------------- 1 | namespace Whisper 2 | { 3 | /// Implementation value for the factory function 4 | public enum eModelImplementation: uint 5 | { 6 | /// GPGPU implementation based on Direct3D 11.0 compute shaders 7 | GPU = 1, 8 | 9 | /// A hybrid implementation which uses DirectCompute for encode, and decodes on CPU 10 | /// 11 | /// The build of the native DLL included into this nuget package doesn’t implement this version.
12 | /// To enable, edit stdafx.h in Whisper project, change the value of BUILD_HYBRID_VERSION macro from zero to one, and build.
13 | /// This implementation requires a CPU with AVX1, FMA3, F16C and BMI1 instruction set extensions. 14 | ///
15 | Hybrid = 2, 16 | 17 | /// A reference implementation which uses the original GGML CPU-running code. 18 | /// 19 | /// The build of the native DLL included into this nuget package doesn’t implement this version either.
20 | /// To enable, edit stdafx.h in Whisper project, change the value of BUILD_BOTH_VERSIONS macro from zero to one, and build the project.
21 | /// This implementation requires a CPU with AVX1, FMA3, and F16C instruction set extensions. 22 | ///
23 | Reference = 3, 24 | } 25 | } -------------------------------------------------------------------------------- /WhisperNet/API/eResultFlags.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// Flags for method 6 | [Flags] 7 | public enum eResultFlags: uint 8 | { 9 | /// No flags 10 | None = 0, 11 | 12 | /// Return individual tokens in addition to the segments 13 | Tokens = 1, 14 | 15 | /// Return timestamps 16 | Timestamps = 2, 17 | 18 | /// Create a new COM object for the results. 19 | /// Without this flag, the context returns a pointer to the COM object stored in the context.
20 | /// The content of that object is replaced every time you call method.
21 | NewObject = 0x100, 22 | } 23 | } -------------------------------------------------------------------------------- /WhisperNet/API/eSpeakerChannel.cs: -------------------------------------------------------------------------------- 1 | namespace Whisper 2 | { 3 | /// Output value for iContext.detectSpeaker method 4 | public enum eSpeakerChannel: byte 5 | { 6 | /// Unable to detect 7 | Unsure = 0, 8 | /// The speech was mostly in the left channel 9 | Left = 1, 10 | /// The speech was mostly in the right channel 11 | Right = 2, 12 | /// The audio only has 1 channel 13 | NoStereoData = 0xFF, 14 | } 15 | } -------------------------------------------------------------------------------- /WhisperNet/API/iAudioReader.cs: -------------------------------------------------------------------------------- 1 | using ComLight; 2 | using System; 3 | 4 | namespace Whisper 5 | { 6 | /// Audio stream reader object 7 | /// The implementation is forward-only, and these objects aren’t reusable.
8 | /// To read an audio file multiple times, dispose this object, and create a new one from the same source file.
9 | [ComInterface( "35b988da-04a6-476a-a193-d8891d5dc390", eMarshalDirection.ToManaged )] 10 | public interface iAudioReader: IDisposable 11 | { 12 | /// Get duration of the media file 13 | [RetValIndex] 14 | TimeSpan getDuration(); 15 | } 16 | 17 | /// Audio capture reader object 18 | /// This interface has no public methods callable from C#.
19 | /// It’s only here to pass data between different functions implemented in C++.
20 | [ComInterface( "747752c2-d9fd-40df-8847-583c781bf013", eMarshalDirection.ToManaged )] 21 | public interface iAudioCapture: IDisposable 22 | { 23 | } 24 | } -------------------------------------------------------------------------------- /WhisperNet/API/sCaptureParams.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// Flags for the audio capture 6 | [Flags] 7 | public enum eCaptureFlags: uint 8 | { 9 | /// No special flags 10 | None = 0, 11 | /// When the capture device supports stereo, keep stereo PCM samples in addition to mono 12 | Stereo = 1, 13 | } 14 | 15 | /// Parameters for audio capture 16 | public struct sCaptureParams 17 | { 18 | /// Minimum transcribe duration in seconds 19 | public float minDuration; 20 | /// Maximum transcribe duration in seconds 21 | public float maxDuration; 22 | /// 23 | public float dropStartSilence; 24 | /// 25 | public float pauseDuration; 26 | /// Flags for the audio capture 27 | public eCaptureFlags flags; 28 | 29 | /// Initialize the structure with some reasonable default values 30 | public sCaptureParams( bool unused ) 31 | { 32 | minDuration = 7.0f; // 7 seconds 33 | maxDuration = 11.0f; // 11 seconds 34 | dropStartSilence = 0.25f; // 250 ms 35 | pauseDuration = 0.333f; // 333 ms 36 | flags = eCaptureFlags.None; 37 | } 38 | } 39 | } -------------------------------------------------------------------------------- /WhisperNet/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.InteropServices; 3 | 4 | [assembly: AssemblyCopyright( "Copyright © const.me, 2022-2023" )] 5 | [assembly: ComVisible( false )] 6 | [assembly: AssemblyVersion( "1.12.0.0" )] -------------------------------------------------------------------------------- /WhisperNet/AssemblyTitle.cs: 
-------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.InteropServices; 3 | 4 | [assembly: AssemblyTitle( "WhisperNet" )] 5 | [assembly: AssemblyDescription( "DirectCompute port of whisper.cpp library, C# bindings" )] 6 | [assembly: Guid( "ced6cdb7-e040-4398-bae8-3417e5fa35f1" )] -------------------------------------------------------------------------------- /WhisperNet/Callbacks.cs: -------------------------------------------------------------------------------- 1 | using Whisper.Internal; 2 | 3 | namespace Whisper 4 | { 5 | /// Implement this abstract class to receive callbacks from the native code 6 | public abstract class Callbacks 7 | { 8 | /// The callback is called before every encoder run. 9 | /// If it returns false, the processing is aborted. 10 | protected virtual bool onEncoderBegin( Context sender ) { return true; } 11 | 12 | /// This callback is called on each new segment 13 | protected virtual void onNewSegment( Context sender, int countNew ) { } 14 | 15 | const int S_OK = 0; 16 | const int S_FALSE = 1; 17 | internal int encoderBegin( Context sender ) 18 | { 19 | try 20 | { 21 | return onEncoderBegin( sender ) ? 
S_OK : S_FALSE; 22 | } 23 | catch( Exception ex ) 24 | { 25 | NativeLogger.captureException( ex ); 26 | return ex.HResult; 27 | } 28 | } 29 | 30 | internal int newSegment( Context sender, int countNew ) 31 | { 32 | try 33 | { 34 | onNewSegment( sender, countNew ); 35 | return S_OK; 36 | } 37 | catch( Exception ex ) 38 | { 39 | NativeLogger.captureException( ex ); 40 | return ex.HResult; 41 | } 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /WhisperNet/CaptureCallbacks.cs: -------------------------------------------------------------------------------- 1 | using Whisper.Internal; 2 | 3 | namespace Whisper 4 | { 5 | /// Implement this abstract class to provide callbacks for audio capture method 6 | public abstract class CaptureCallbacks 7 | { 8 | /// Override this method to support cancellation 9 | protected virtual bool shouldCancel( Context sender ) { return false; } 10 | 11 | /// Override this method to get notified about status changes 12 | protected virtual void captureStatusChanged( Context sender, eCaptureStatus status ) { } 13 | 14 | internal pfnShouldCancel cancel( Context sender ) 15 | { 16 | const int S_OK = 0; 17 | const int S_FALSE = 1; 18 | return delegate ( IntPtr pv ) 19 | { 20 | try 21 | { 22 | return shouldCancel( sender ) ? 
S_FALSE : S_OK; 23 | } 24 | catch( Exception ex ) 25 | { 26 | NativeLogger.captureException( ex ); 27 | return ex.HResult; 28 | } 29 | }; 30 | } 31 | 32 | internal pfnCaptureStatus status( Context sender ) 33 | { 34 | return delegate ( IntPtr pv, eCaptureStatus status ) 35 | { 36 | try 37 | { 38 | captureStatusChanged( sender, status ); 39 | return 0; 40 | } 41 | catch( Exception ex ) 42 | { 43 | NativeLogger.captureException( ex ); 44 | return ex.HResult; 45 | } 46 | }; 47 | } 48 | } 49 | } -------------------------------------------------------------------------------- /WhisperNet/Internal/sCaptureCallbacks.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Runtime.InteropServices; 3 | 4 | namespace Whisper.Internal 5 | { 6 | /// Unmanaged code calls this to check for cancellation 7 | /// Return 0 to proceed, or 1 to stop the process and return from Context.runCapture method 8 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 9 | public delegate int pfnShouldCancel( IntPtr pv ); 10 | 11 | /// Unmanaged code calls this to notify about the status 12 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 13 | public delegate int pfnCaptureStatus( IntPtr pv, eCaptureStatus status ); 14 | 15 | /// Capture callbacks for unmanaged code 16 | public struct sCaptureCallbacks 17 | { 18 | /// Cancellation function pointer 19 | public pfnShouldCancel shouldCancel; 20 | /// Capture status function pointer 21 | public pfnCaptureStatus captureStatus; 22 | /// Context pointer, only needed for C++ compatibility 23 | public IntPtr pv; 24 | } 25 | } -------------------------------------------------------------------------------- /WhisperNet/Internal/sCaptureDevice.cs: -------------------------------------------------------------------------------- 1 | #pragma warning disable CS0649 // Field is never assigned to 2 | using System.Runtime.InteropServices; 3 | 4 | namespace Whisper.Internal 5 | { 6 | /// 
Identifiers for an audio capture device 7 | public struct sCaptureDevice 8 | { 9 | readonly IntPtr m_displayName; 10 | /// The display name is suitable for showing to the user, but might not be unique. 11 | public string? displayName => Marshal.PtrToStringUni( m_displayName ); 12 | 13 | readonly IntPtr m_endpoint; 14 | /// Endpoint ID for an audio capture device.
15 | /// It uniquely identifies the device on the system, but is not a readable string.
16 | public string? endpoint => Marshal.PtrToStringUni( m_endpoint ); 17 | } 18 | 19 | /// Function pointer to consume a list of audio capture device IDs 20 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 21 | public delegate int pfnFoundCaptureDevices( int len, [In, MarshalAs( UnmanagedType.LPArray, SizeParamIndex = 0 )] sCaptureDevice[]? arr, IntPtr pv ); 22 | } -------------------------------------------------------------------------------- /WhisperNet/Internal/sFullParams.cs: -------------------------------------------------------------------------------- 1 | #pragma warning disable CS0649 // Field is never assigned to 2 | 3 | // Missing XML comment for publicly visible type or member 4 | // TODO: remove this line and document them. 5 | #pragma warning disable CS1591 6 | 7 | using System.Runtime.InteropServices; 8 | 9 | namespace Whisper.Internals 10 | { 11 | /// This callback is called on each new segment 12 | [UnmanagedFunctionPointer( CallingConvention.Cdecl )] 13 | delegate int pfnNewSegment( IntPtr ctx, int countNew, IntPtr userData ); 14 | 15 | /// The callback is called before every encoder run. If it returns S_FALSE, the processing is aborted. 16 | [UnmanagedFunctionPointer( CallingConvention.Cdecl )] 17 | delegate int pfnEncoderBegin( IntPtr ctx, IntPtr userData ); 18 | 19 | /// Transcribe parameters 20 | public struct sFullParams 21 | { 22 | internal Parameters publicParams; 23 | // The rest of these parameters are not exposed to the user-friendly public API of this DLL 24 | 25 | internal IntPtr prompt_tokens; 26 | internal int prompt_n_tokens; 27 | 28 | /// This callback is called on each new segment 29 | [MarshalAs( UnmanagedType.FunctionPtr )] 30 | internal pfnNewSegment? newSegmentCallback; 31 | /// Parameter for the above, not needed in C# 32 | internal IntPtr newSegmentCallbackData; 33 | 34 | /// The callback is called before every encoder run. 
If it returns false, the processing is aborted 35 | [MarshalAs( UnmanagedType.FunctionPtr )] 36 | internal pfnEncoderBegin? encoderBeginCallback; 37 | /// Parameter for the above, not needed in C# 38 | internal IntPtr encoderBeginCallbackData; 39 | } 40 | } -------------------------------------------------------------------------------- /WhisperNet/Internal/sLoggerSetup.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Runtime.InteropServices; 3 | 4 | namespace Whisper.Internal 5 | { 6 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 7 | delegate void pfnLoggerSink( IntPtr context, eLogLevel lvl, [MarshalAs( UnmanagedType.LPUTF8Str )] string message ); 8 | 9 | struct sLoggerSetup 10 | { 11 | [MarshalAs( UnmanagedType.FunctionPtr )] 12 | public pfnLoggerSink sink; 13 | IntPtr context; 14 | public eLogLevel level; 15 | public eLoggerFlags flags; 16 | } 17 | } -------------------------------------------------------------------------------- /WhisperNet/Internal/sModelSetup.cs: -------------------------------------------------------------------------------- 1 | using System.Runtime.InteropServices; 2 | 3 | namespace Whisper.Internal 4 | { 5 | struct sModelSetup 6 | { 7 | eModelImplementation impl; 8 | eGpuModelFlags flags; 9 | [MarshalAs( UnmanagedType.LPWStr )] 10 | string? adapter; 11 | 12 | public sModelSetup( eGpuModelFlags flags, eModelImplementation impl, string? 
adapter ) 13 | { 14 | this.impl = impl; 15 | this.flags = flags; 16 | this.adapter = adapter; 17 | } 18 | } 19 | 20 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 21 | delegate void pfnListAdapters( [In, MarshalAs( UnmanagedType.LPWStr )] string name, IntPtr pv ); 22 | } -------------------------------------------------------------------------------- /WhisperNet/Internal/sProgressSink.cs: -------------------------------------------------------------------------------- 1 | #pragma warning disable CS0649 // Field is never assigned to 2 | using System.Runtime.InteropServices; 3 | 4 | namespace Whisper.Internal 5 | { 6 | /// A callback to get notified about the progress 7 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 8 | delegate int pfnReportProgress( double value, IntPtr context, IntPtr pv ); 9 | 10 | /// C structure with a progress reporting function pointer 11 | public struct sProgressSink 12 | { 13 | /// A callback to get notified about the progress 14 | [MarshalAs( UnmanagedType.FunctionPtr )] 15 | internal pfnReportProgress? pfn; 16 | 17 | /// Last parameter to the callback 18 | internal IntPtr pv; 19 | } 20 | } -------------------------------------------------------------------------------- /WhisperNet/Readme.md: -------------------------------------------------------------------------------- 1 | This library implements high-performance GPGPU inference of OpenAI's Whisper automatic speech recognition (ASR) model. 2 | 3 | The library requires a hardware GPU which supports Direct3D 11.0, a 64-bit Windows OS, only works within 64-bit processes, and requires a 64-bit CPU which supports [AVX1](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) and [F16C](https://en.wikipedia.org/wiki/F16C) extensions. 4 | 5 | The main entry point of the library is the `Whisper.Library` static class. 6 | Call the `loadModel` function from that class to load an ML model from a binary file.
7 | 8 | These binary files are available for free download on [Hugging Face](https://huggingface.co/ggerganov/whisper.cpp/tree/main). 9 | I recommend `ggml-medium.bin` (1.42GB in size, but that web page says 1.53 GB), because I’ve mostly tested the software with that model. 10 | Compressed models in ZIP format with `mlmodelc` in the file name are not supported. 11 | 12 | Once the model is loaded, create a context by calling `createContext` extension method, 13 | then use that object to transcribe or translate multimedia files or realtime audio captured from microphones. -------------------------------------------------------------------------------- /WhisperNet/WhisperNet.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | net6.0-windows 4 | enable 5 | enable 6 | true 7 | false 8 | True 9 | True 10 | Whisper 11 | false 12 | x64 13 | 14 | 15 | True 16 | WhisperNet.nuspec 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /WhisperNet/WhisperNet.nuspec: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | WhisperNet 5 | 1.12.0 6 | Konstantin, const.me 7 | MPL-2.0 8 | https://github.com/Const-me/Whisper 9 | High-performance GPGPU inference of OpenAI's Whisper automatic speech recognition (ASR) model 10 | 11 | Updated models source URL in the documentation. 12 | Reliability enhancement, microphone capture less likely to transition to “Stalled” state and discard the audio. 
13 | 14 | Copyright © const.me, 2022-2023 15 | whisper, gpgpu, speech recognition 16 | 17 | 18 | 19 | 20 | 21 | 22 | docs/Readme.md 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /WhisperPS/Commands/ExportBase.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | using System.Management.Automation; 4 | 5 | namespace Whisper 6 | { 7 | /// Base class for commands which export results into some text-based format 8 | public abstract class ExportBase: PSCmdlet 9 | { 10 | /// 11 | /// Transcribe result produced by 12 | /// It requires the value of the correct type 13 | /// 14 | [Parameter( Mandatory = true, ValueFromPipeline = true )] 15 | public Transcription source { get; set; } 16 | 17 | /// 18 | /// Output file to write 19 | /// 20 | [Parameter( Mandatory = true, Position = 0 ), ValidateNotNullOrEmpty] 21 | public string path { get; set; } 22 | 23 | /// Performs execution of the command 24 | protected override void ProcessRecord() 25 | { 26 | string path = this.absolutePath( this.path ); 27 | string dir = Path.GetDirectoryName( path ); 28 | Directory.CreateDirectory( dir ); 29 | if( File.Exists( path ) ) 30 | if( !ShouldContinue( $"Overwrite \"{path}\" ?", "The output file already exists" ) ) 31 | return; 32 | 33 | var results = source.getResult(); 34 | using( var stream = File.CreateText( path ) ) 35 | write( stream, results ); 36 | } 37 | 38 | /// Actual implementation 39 | protected abstract void write( StreamWriter stream, TranscribeResult transcribeResult ); 40 | } 41 | } -------------------------------------------------------------------------------- /WhisperPS/Commands/ExportSubrip.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Globalization; 3 | using System.IO; 4 | using System.Management.Automation; 5 | 6 | namespace Whisper 
7 | { 8 | /// 9 | /// Write transcribe results into SubRip format. 10 | /// The format is documented there: https://en.wikipedia.org/wiki/SubRip#SubRip_file_format 11 | /// 12 | /// Export-SubRip $transcribeResults -path transcript.srt 13 | [Cmdlet( VerbsData.Export, "SubRip" )] 14 | public sealed class ExportSubrip: ExportBase 15 | { 16 | /// 17 | /// Optional integer offset to the indices 18 | /// 19 | [Parameter] 20 | public int offset { get; set; } = 0; 21 | 22 | static string printTimeWithComma( TimeSpan ts ) => 23 | ts.ToString( "hh':'mm':'ss','fff", CultureInfo.InvariantCulture ); 24 | 25 | /// Write that text 26 | protected override void write( StreamWriter stream, TranscribeResult transcribeResult ) 27 | { 28 | var segments = transcribeResult.segments; 29 | 30 | for( int i = 0; i < segments.Length; i++ ) 31 | { 32 | stream.WriteLine( i + 1 + offset ); 33 | sSegment seg = segments[ i ]; 34 | string begin = printTimeWithComma( seg.time.begin ); 35 | string end = printTimeWithComma( seg.time.end ); 36 | stream.WriteLine( "{0} --> {1}", begin, end ); 37 | stream.WriteLine( seg.text.Trim() ); 38 | stream.WriteLine(); 39 | } 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /WhisperPS/Commands/ExportText.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Globalization; 3 | using System.IO; 4 | using System.Management.Automation; 5 | 6 | namespace Whisper 7 | { 8 | /// 9 | /// Write transcribe results into a plain text file, one segment per line.
10 | /// With the -timestamps switch, each line is prefixed with the segment's time range, formatted as [hh:mm:ss.fff --> hh:mm:ss.fff] 11 | /// 12 | /// Export-Text $transcribeResults -path transcript.txt -timestamps 13 | [Cmdlet( VerbsData.Export, "Text" )] 14 | public sealed class ExportText: ExportBase 15 | { 16 | /// 17 | /// Specify this switch to include timestamps 18 | /// 19 | [Parameter] 20 | public SwitchParameter timestamps { get; set; } 21 | 22 | static string printTime( TimeSpan ts ) => 23 | ts.ToString( "hh':'mm':'ss'.'fff", CultureInfo.InvariantCulture ); 24 | 25 | /// Write that text 26 | protected override void write( StreamWriter stream, TranscribeResult transcribeResult ) 27 | { 28 | foreach( sSegment seg in transcribeResult.segments ) 29 | { 30 | if( timestamps ) 31 | { 32 | string begin = printTime( seg.time.begin ); 33 | string end = printTime( seg.time.end ); 34 | stream.Write( "[{0} --> {1}] ", begin, end ); 35 | } 36 | stream.WriteLine( seg.text.Trim() ); 37 | } 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /WhisperPS/Commands/ExportWebVtt.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Globalization; 3 | using System.IO; 4 | using System.Management.Automation; 5 | 6 | namespace Whisper 7 | { 8 | /// 9 | /// Write transcribe results into WebVTT format.
10 | /// The format is documented there: https://en.wikipedia.org/wiki/WebVTT 11 | /// 12 | /// Export-WebVTT $transcribeResults -path transcript.vtt 13 | [Cmdlet( VerbsData.Export, "WebVTT" )] 14 | public sealed class ExportWebVTT: ExportBase 15 | { 16 | static string printTime( TimeSpan ts ) => 17 | ts.ToString( "hh':'mm':'ss'.'fff", CultureInfo.InvariantCulture ); 18 | 19 | /// Write that text 20 | protected override void write( StreamWriter stream, TranscribeResult transcribeResult ) 21 | { 22 | var segments = transcribeResult.segments; 23 | 24 | stream.WriteLine( "WEBVTT" ); 25 | stream.WriteLine(); 26 | 27 | foreach( sSegment seg in segments ) 28 | { 29 | string begin = printTime( seg.time.begin ); 30 | string end = printTime( seg.time.end ); 31 | stream.WriteLine( "{0} --> {1}", begin, end ); 32 | stream.WriteLine( seg.text ); 33 | stream.WriteLine(); 34 | } 35 | } 36 | } 37 | } -------------------------------------------------------------------------------- /WhisperPS/Commands/FormatSegments.cs: -------------------------------------------------------------------------------- 1 | using System.Management.Automation; 2 | 3 | namespace Whisper 4 | { 5 | /// 6 | /// Format transcribe results as a sequence of segments. 
7 | /// Each segment has a pair of timestamps, and the text 8 | /// 9 | /// Format-Segments $transcribeResults 10 | [Cmdlet( VerbsCommon.Format, "Segments" )] 11 | public sealed class FormatSegments: Cmdlet 12 | { 13 | /// 14 | /// Transcribe result produced by 15 | /// It requires the value of the correct type 16 | /// 17 | [Parameter( Mandatory = true, ValueFromPipeline = true )] 18 | public Transcription source { get; set; } 19 | 20 | /// Performs execution of the command 21 | protected override void ProcessRecord() 22 | { 23 | var res = source.getResult(); 24 | foreach( var seg in res.segments ) 25 | { 26 | Segment obj = new Segment( seg ); 27 | WriteObject( obj ); 28 | } 29 | } 30 | } 31 | } -------------------------------------------------------------------------------- /WhisperPS/Commands/ListAdapters.cs: -------------------------------------------------------------------------------- 1 | using System.Management.Automation; 2 | 3 | namespace Whisper 4 | { 5 | /// 6 | /// Produces list of the names of the GPUs available to Direct3D 11 7 | /// You can pass any of these strings to the adapter argument of the Import-WhisperModel cmdlet. 
8 | /// 9 | /// Get-Adapters 10 | 11 | [Cmdlet( VerbsCommon.Get, "Adapters" )] 12 | public sealed class ListAdapters: Cmdlet 13 | { 14 | /// Performs execution of the command 15 | protected override void ProcessRecord() 16 | { 17 | string[] arr = Library.listGraphicAdapters(); 18 | foreach( string item in arr ) 19 | WriteObject( item ); 20 | } 21 | } 22 | } -------------------------------------------------------------------------------- /WhisperPS/Internal/MarshalEx.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Text; 3 | 4 | namespace Whisper.Internal 5 | { 6 | static class MarshalEx 7 | { 8 | /// Workaround for the missing Marshal.PtrToStringUTF8 method 9 | public static string PtrToStringUTF8( IntPtr ptr ) 10 | { 11 | if( ptr != IntPtr.Zero ) 12 | { 13 | unsafe 14 | { 15 | byte* stringStart = (byte*)ptr; 16 | byte* stringEnd = stringStart; 17 | while( true ) 18 | { 19 | if( 0 == *stringEnd ) 20 | break; 21 | stringEnd++; 22 | } 23 | 24 | int len = (int)( stringEnd - stringStart ); 25 | return Encoding.UTF8.GetString( stringStart, len ); 26 | } 27 | } 28 | return null; 29 | } 30 | } 31 | } -------------------------------------------------------------------------------- /WhisperPS/Internal/sCaptureDevice.cs: -------------------------------------------------------------------------------- 1 | #pragma warning disable CS0649 // Field is never assigned to 2 | using System; 3 | using System.Runtime.InteropServices; 4 | 5 | namespace Whisper.Internal 6 | { 7 | /// Identifiers for an audio capture device 8 | public struct sCaptureDevice 9 | { 10 | readonly IntPtr m_displayName; 11 | /// The display name is suitable for showing to the user, but might not be unique. 12 | public string displayName => Marshal.PtrToStringUni( m_displayName ); 13 | 14 | readonly IntPtr m_endpoint; 15 | /// Endpoint ID for an audio capture device.
16 | /// It uniquely identifies the device on the system, but is not a readable string.
17 | public string endpoint => Marshal.PtrToStringUni( m_endpoint ); 18 | } 19 | 20 | /// Function pointer to consume a list of audio capture device IDs 21 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 22 | public delegate int pfnFoundCaptureDevices( int len, [In, MarshalAs( UnmanagedType.LPArray, SizeParamIndex = 0 )] sCaptureDevice[] arr, IntPtr pv ); 23 | } -------------------------------------------------------------------------------- /WhisperPS/Internal/sModelSetup.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Runtime.InteropServices; 3 | 4 | namespace Whisper.Internal 5 | { 6 | struct sModelSetup 7 | { 8 | eModelImplementation impl; 9 | eGpuModelFlags flags; 10 | [MarshalAs( UnmanagedType.LPWStr )] 11 | string adapter; 12 | 13 | public sModelSetup( eGpuModelFlags flags, eModelImplementation impl, string adapter ) 14 | { 15 | this.impl = impl; 16 | this.flags = flags; 17 | this.adapter = adapter; 18 | } 19 | } 20 | 21 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 22 | delegate void pfnListAdapters( [In, MarshalAs( UnmanagedType.LPWStr )] string name, IntPtr pv ); 23 | } -------------------------------------------------------------------------------- /WhisperPS/Internal/sProgressSink.cs: -------------------------------------------------------------------------------- 1 | #pragma warning disable CS0649 // Field is never assigned to 2 | using System; 3 | using System.Runtime.InteropServices; 4 | 5 | namespace Whisper.Internal 6 | { 7 | /// A callback to get notified about the progress 8 | [UnmanagedFunctionPointer( CallingConvention.StdCall )] 9 | delegate int pfnReportProgress( double value, IntPtr context, IntPtr pv ); 10 | 11 | /// C structure with a progress reporting function pointer 12 | public struct sProgressSink 13 | { 14 | /// A callback to get notified about the progress 15 | [MarshalAs( UnmanagedType.FunctionPtr )] 16 | internal pfnReportProgress pfn; 17
| 18 | /// Last parameter to the callback 19 | internal IntPtr pv; 20 | } 21 | } -------------------------------------------------------------------------------- /WhisperPS/Properties/AssemblyTitle.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.InteropServices; 3 | 4 | [assembly: AssemblyTitle( "WhisperPS" )] 5 | [assembly: AssemblyDescription( "DirectCompute port of whisper.cpp library, PowerShell 5.1 bindings" )] 6 | [assembly: Guid( "6909b760-ff72-48f5-8493-2e956cbb7cec" )] -------------------------------------------------------------------------------- /WhisperPS/Types/Model.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// 6 | /// This object holds a Whisper model, loaded from disk to VRAM on the GPU. 7 | /// For large models, the data size may exceed 4GB of video memory 8 | /// 9 | public sealed class Model: IDisposable 10 | { 11 | internal iMediaFoundation mf { get; private set; } 12 | internal iModel model { get; private set; } 13 | 14 | internal Model( iMediaFoundation mf, iModel model ) 15 | { 16 | this.mf = mf; 17 | this.model = model; 18 | } 19 | 20 | /// Dispose the iMediaFoundation and iModel objects, and suppress finalization. Repeated calls are harmless because both references are cleared. 21 | public void Dispose() 22 | { 23 | mf?.Dispose(); 24 | mf = null; 25 | model?.Dispose(); 26 | model = null; 27 | GC.SuppressFinalize( this ); 28 | } 29 | 30 | // Finalizer, runs only for instances which were never disposed. 31 | // NOTE(review): it calls Dispose on other managed objects from the finalizer thread, and those may already be finalized; confirm their Dispose tolerates this. 32 | ~Model() 33 | { 34 | mf?.Dispose(); 35 | mf = null; 36 | model?.Dispose(); 37 | model = null; 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /WhisperPS/Types/Segment.cs: 
-------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace Whisper 4 | { 5 | /// One text segment of a transcription 6 | public sealed class Segment 7 | { 8 | internal Segment( in sSegment seg ) 9 | { 10 | Begin = seg.time.begin; 11 | End = seg.time.end; 12 | Text = seg.text.Trim(); 13 | } 14
| 15 | /// Timestamp of the start of the segment, since the start of the media 16 | public TimeSpan Begin { get; } 17 | /// Timestamp of the end of the segment, since the start of the media 18 | public TimeSpan End { get; } 19 | 20 | /// Text of the segment 21 | public string Text { get; } 22 | 23 | /// A string representation of this object 24 | public override string ToString() => Text; 25 | } 26 | } -------------------------------------------------------------------------------- /WhisperPS/Utils/CommandLogger.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Management.Automation; 3 | 4 | namespace Whisper 5 | { 6 | /// Routes log messages received through Library.setLogSink to the cmdlet which set up logging on the calling thread 7 | static class CommandLogger 8 | { 9 | static CommandLogger() 10 | { 11 | // NOTE(review): the delegate created from the `sink` method group is not stored in this class. 12 | // Confirm Library.setLogSink keeps it referenced, so the GC cannot collect it while it may still be called. 13 | Library.setLogSink( eLogLevel.Debug, eLoggerFlags.SkipFormatMessage, sink ); 14 | } 15 | 16 | /// Forward a log message to the cmdlet registered on the current thread; messages are dropped when there is none 17 | static void sink( eLogLevel level, string message ) 18 | { 19 | switch( level ) 20 | { 21 | case eLogLevel.Warning: 22 | cmdlet?.WriteWarning( message ); 23 | break; 24 | case eLogLevel.Info: 25 | cmdlet?.WriteInformation( message, null ); 26 | break; 27 | case eLogLevel.Debug: 28 | cmdlet?.WriteDebug( message ); 29 | break; 30 | // Errors usually become C# exceptions 31 | } 32 | } 33 | 34 | /// Destination of the log messages on the current thread, or null when no cmdlet has set up logging 35 | [ThreadStatic] 36 | static Cmdlet cmdlet; 37 | 38 | /// Installs a cmdlet as the log destination of the current thread, and restores the previous destination when disposed 39 | sealed class Impl: 
IDisposable 40 | { 41 | readonly Cmdlet previous; 42 | 43 | public Impl( Cmdlet c ) 44 | { 45 | // These scopes can nest on one thread: PowerShell may run a downstream command's ProcessRecord from within WriteObject. 46 | // Save the outer destination, so disposing the inner scope doesn't silence the outer cmdlet. 47 | previous = cmdlet; 48 | cmdlet = c; 49 | } 50 | 51 | void IDisposable.Dispose() 52 | { 53 | cmdlet = previous; 54 | } 55 | } 56 | 57 | /// Route the log messages emitted on the calling thread to this cmdlet, until the returned object is disposed 58 | public static IDisposable setupLog( this Cmdlet cmd ) => 59 | new Impl( cmd ); 60 | } 61 | } -------------------------------------------------------------------------------- /WhisperPS/app.config: -------------------------------------------------------------------------------- 1 | ﻿ 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /WhisperPS/packages.config:
-------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /gui-capture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/gui-capture.png -------------------------------------------------------------------------------- /gui-load-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/gui-load-model.png -------------------------------------------------------------------------------- /gui-transcribe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Const-me/Whisper/306aadd1fce4b168cd38659236f4ba7c1603cebd/gui-transcribe.png --------------------------------------------------------------------------------