├── .gitignore ├── README.md ├── Whisper ├── .gitignore ├── Whisper.xcodeproj │ ├── project.pbxproj │ ├── project.xcworkspace │ │ ├── contents.xcworkspacedata │ │ └── xcshareddata │ │ │ ├── IDEWorkspaceChecks.plist │ │ │ └── swiftpm │ │ │ └── Package.resolved │ └── xcshareddata │ │ └── xcschemes │ │ └── Whisper.xcscheme └── Whisper │ ├── Assets.xcassets │ ├── AccentColor.colorset │ │ └── Contents.json │ ├── AppIcon.appiconset │ │ └── Contents.json │ └── Contents.json │ ├── AudioLoader.swift │ ├── AudioRecorder.swift │ ├── ContentView.swift │ ├── Info.plist │ ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json │ ├── Whisper │ ├── FFT.swift │ ├── MLMultiArray+Utils.swift │ ├── MatftShapedArrayExtensions.swift │ ├── Math.swift │ ├── MelSpectrogram.swift │ ├── STFT.swift │ ├── Whisper.swift │ └── WhisperTokenizer.swift │ ├── WhisperApp.swift │ ├── gpt2-merges.txt │ ├── gpt2-vocab.json │ ├── mel_filters.data │ ├── multilingual-merges.txt │ ├── multilingual-vocab.json │ └── python_log_mel.raw ├── export_m80.py └── whisper_to_cml.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Whisper CoreML 3 | 4 | A port of OpenAI's Whisper Speech Transcription model to CoreML 5 | 6 | The goal of this project is to natively port, and optimize Whisper for use on Apple Silicon including optimization for the Apple Neural Engine, and match the incredible WhisperCPP project on features. 7 | 8 | --- 9 | 10 | ***Please note this repo is currently under development, so there will be bumps in the road.*** 11 | 12 | ***Community input is welcome!*** 13 | 14 | --- 15 | 16 | You can: 17 | 18 | Create a Whipser instance `whisper = try Whisper()` 19 | 20 | And run transcription on a Quicktime compatible asset via: `await whisper.transcribe(assetURL:URL, options:WhisperOptions)` 21 | 22 | You can choose options via the `WhisperOptions` struct. 23 | 24 | Whipser CoreML will load an asset using AVFoundation and convert the audio to the appropriate format for transcription. 25 | 26 | Alternatively, for realtime usage, you can call start a whisper session via `startWhisperSession(options:WhisperOptions)`, and then send sample buffers to `accrueSamplesFromSampleBuffer(sampleBuffer:CMSampleBuffer)` from say an AVCaptureSession or AVAudioSession, or any other source. 27 | 28 | Note, we accrue a 30 second sample for now as that is the expected number of samples required. 29 | 30 | ## Status 31 | * [X] Working Multi Lingual Transcription 32 | * [ ] [Optimize the CoreML models for ANE](https://machinelearning.apple.com/research/neural-engine-transformers) using [Apples ANE Transformers sample code found at this repository](https://github.com/apple/ml-ane-transformers) 33 | * [ ] Port Log Mel Spectrogram to native vDSP and ditch RosaKit package dependency. 34 | * [ ] Decode Special Tokens for time stamps. 35 | * [ ] Decide on API design 36 | 37 | ## Performance 38 | 39 | * Base model gets roughly 4x realtime using a single core on an M1 Mac Book Pro. 40 | 41 | 42 | ## Getting Models: 43 | 44 | [For ease of use, you can use this Google Colab to convert models](https://colab.research.google.com/drive/1IiBx6-hipt3ER3VjkjuUEAObwipHy1mL 45 | ). Note that if you convert Medium or larger models you may run into memory issues on Google Colab. 46 | 47 | This repository assumes youre converting multilingual models. If you need 'en' models you'll need to adjust the special token values by negative 1. 48 | -------------------------------------------------------------------------------- /Whisper/.gitignore: -------------------------------------------------------------------------------- 1 | xcuserdata/ 2 | *.mlpackage 3 | *.mlmodel 4 | *.a 5 | -------------------------------------------------------------------------------- /Whisper/Whisper.xcodeproj/project.pbxproj: -------------------------------------------------------------------------------- 1 | // !$*UTF8*$! 2 | { 3 | archiveVersion = 1; 4 | classes = { 5 | }; 6 | objectVersion = 56; 7 | objects = { 8 | 9 | /* Begin PBXBuildFile section */ 10 | 01AF98F928E28BF2002DAC53 /* WhisperApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01AF98F828E28BF2002DAC53 /* WhisperApp.swift */; }; 11 | 01AF98FB28E28BF2002DAC53 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01AF98FA28E28BF2002DAC53 /* ContentView.swift */; }; 12 | 01AF98FD28E28BF3002DAC53 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 01AF98FC28E28BF3002DAC53 /* Assets.xcassets */; }; 13 | 01AF990028E28BF3002DAC53 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 01AF98FF28E28BF3002DAC53 /* Preview Assets.xcassets */; }; 14 | 01AF991F28E2A97E002DAC53 /* Whisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01AF991E28E2A97E002DAC53 /* Whisper.swift */; }; 15 | 01AF992528E2D90F002DAC53 /* AudioRecorder.swift in Sources */ = {isa = PBXBuildFile; fileRef = 01AF992428E2D90F002DAC53 /* AudioRecorder.swift */; }; 16 | 1B03910729760795007E945E /* encoder_base.mlpackage in Sources */ = {isa = PBXBuildFile; fileRef = 1B03910629760795007E945E /* encoder_base.mlpackage */; }; 17 | 1B58A60E296A3B40006E0969 /* Math.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1B58A60D296A3B40006E0969 /* Math.swift */; }; 18 | 1B58A610296A3BAB006E0969 /* MLMultiArray+Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1B58A60F296A3BAB006E0969 /* MLMultiArray+Utils.swift */; }; 19 | 1B58A613296BD704006E0969 /* multilingual-vocab.json in Resources */ = {isa = PBXBuildFile; fileRef = 1B58A611296BD704006E0969 /* multilingual-vocab.json */; }; 20 | 1B58A614296BD704006E0969 /* multilingual-merges.txt in Resources */ = {isa = PBXBuildFile; fileRef = 1B58A612296BD704006E0969 /* multilingual-merges.txt */; }; 21 | 1B58A618296E03FF006E0969 /* python_log_mel.raw in Resources */ = {isa = PBXBuildFile; fileRef = 1B58A617296E03FE006E0969 /* python_log_mel.raw */; }; 22 | 1B58A61A2970C127006E0969 /* FFT.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1B58A6192970C127006E0969 /* FFT.swift */; }; 23 | 1B58A61C29733174006E0969 /* STFT.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1B58A61B29733173006E0969 /* STFT.swift */; }; 24 | 1B58A6382975C3B1006E0969 /* RosaKit in Frameworks */ = {isa = PBXBuildFile; productRef = 1B58A6372975C3B1006E0969 /* RosaKit */; }; 25 | 1B58A63B2975F07A006E0969 /* decoder_base.mlpackage in Sources */ = {isa = PBXBuildFile; fileRef = 1B58A6392975F07A006E0969 /* decoder_base.mlpackage */; }; 26 | 1B6A59D929634F2A004C510D /* MelSpectrogram.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1B6A59D829634F2A004C510D /* MelSpectrogram.swift */; }; 27 | 1B6A59DF2963A1A6004C510D /* mel_filters.data in Resources */ = {isa = PBXBuildFile; fileRef = 1B6A59DE2963A1A6004C510D /* mel_filters.data */; }; 28 | 1B6A59E12963A670004C510D /* AudioLoader.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1B6A59E02963A670004C510D /* AudioLoader.swift */; }; 29 | 1BF960EA298B063200AA2990 /* Matft in Frameworks */ = {isa = PBXBuildFile; productRef = 1BF960E9298B063200AA2990 /* Matft */; }; 30 | 1BF960EC298B071300AA2990 /* MatftShapedArrayExtensions.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1BF960EB298B071300AA2990 /* MatftShapedArrayExtensions.swift */; }; 31 | 1BFAC9DD2968F0A900A5A099 /* WhisperTokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1BFAC9DC2968F0A900A5A099 /* WhisperTokenizer.swift */; }; 32 | 1BFAC9E02968F11100A5A099 /* gpt2-merges.txt in Resources */ = {isa = PBXBuildFile; fileRef = 1BFAC9DE2968F11100A5A099 /* gpt2-merges.txt */; }; 33 | 1BFAC9E12968F11100A5A099 /* gpt2-vocab.json in Resources */ = {isa = PBXBuildFile; fileRef = 1BFAC9DF2968F11100A5A099 /* gpt2-vocab.json */; }; 34 | /* End PBXBuildFile section */ 35 | 36 | /* Begin PBXFileReference section */ 37 | 01AF98F528E28BF2002DAC53 /* Whisper.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Whisper.app; sourceTree = BUILT_PRODUCTS_DIR; }; 38 | 01AF98F828E28BF2002DAC53 /* WhisperApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperApp.swift; sourceTree = ""; }; 39 | 01AF98FA28E28BF2002DAC53 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; 40 | 01AF98FC28E28BF3002DAC53 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; 41 | 01AF98FF28E28BF3002DAC53 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; 42 | 01AF990928E28C2E002DAC53 /* libresolv.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libresolv.tbd; path = usr/lib/libresolv.tbd; sourceTree = SDKROOT; }; 43 | 01AF991E28E2A97E002DAC53 /* Whisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Whisper.swift; sourceTree = ""; }; 44 | 01AF992428E2D90F002DAC53 /* AudioRecorder.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AudioRecorder.swift; sourceTree = ""; }; 45 | 01AF992628E2D9B9002DAC53 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = ""; }; 46 | 1B03910629760795007E945E /* encoder_base.mlpackage */ = {isa = PBXFileReference; lastKnownFileType = folder.mlpackage; path = encoder_base.mlpackage; sourceTree = ""; }; 47 | 1B58A60D296A3B40006E0969 /* Math.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Math.swift; sourceTree = ""; }; 48 | 1B58A60F296A3BAB006E0969 /* MLMultiArray+Utils.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "MLMultiArray+Utils.swift"; sourceTree = ""; }; 49 | 1B58A611296BD704006E0969 /* multilingual-vocab.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = "multilingual-vocab.json"; sourceTree = ""; }; 50 | 1B58A612296BD704006E0969 /* multilingual-merges.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = "multilingual-merges.txt"; sourceTree = ""; }; 51 | 1B58A617296E03FE006E0969 /* python_log_mel.raw */ = {isa = PBXFileReference; lastKnownFileType = file; path = python_log_mel.raw; sourceTree = ""; }; 52 | 1B58A6192970C127006E0969 /* FFT.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = FFT.swift; sourceTree = ""; }; 53 | 1B58A61B29733173006E0969 /* STFT.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = STFT.swift; sourceTree = ""; }; 54 | 1B58A6392975F07A006E0969 /* decoder_base.mlpackage */ = {isa = PBXFileReference; lastKnownFileType = folder.mlpackage; path = decoder_base.mlpackage; sourceTree = ""; }; 55 | 1B6A59D829634F2A004C510D /* MelSpectrogram.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MelSpectrogram.swift; sourceTree = ""; }; 56 | 1B6A59DA29635145004C510D /* decoder.mlpackage */ = {isa = PBXFileReference; lastKnownFileType = folder.mlpackage; path = decoder.mlpackage; sourceTree = ""; }; 57 | 1B6A59DB29635145004C510D /* encoder.mlpackage */ = {isa = PBXFileReference; lastKnownFileType = folder.mlpackage; path = encoder.mlpackage; sourceTree = ""; }; 58 | 1B6A59DE2963A1A6004C510D /* mel_filters.data */ = {isa = PBXFileReference; lastKnownFileType = file; path = mel_filters.data; sourceTree = ""; }; 59 | 1B6A59E02963A670004C510D /* AudioLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AudioLoader.swift; sourceTree = ""; }; 60 | 1BF960EB298B071300AA2990 /* MatftShapedArrayExtensions.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MatftShapedArrayExtensions.swift; sourceTree = ""; }; 61 | 1BFAC9DC2968F0A900A5A099 /* WhisperTokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperTokenizer.swift; sourceTree = ""; }; 62 | 1BFAC9DE2968F11100A5A099 /* gpt2-merges.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = "gpt2-merges.txt"; sourceTree = ""; }; 63 | 1BFAC9DF2968F11100A5A099 /* gpt2-vocab.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = "gpt2-vocab.json"; sourceTree = ""; }; 64 | /* End PBXFileReference section */ 65 | 66 | /* Begin PBXFrameworksBuildPhase section */ 67 | 01AF98F228E28BF2002DAC53 /* Frameworks */ = { 68 | isa = PBXFrameworksBuildPhase; 69 | buildActionMask = 2147483647; 70 | files = ( 71 | 1B58A6382975C3B1006E0969 /* RosaKit in Frameworks */, 72 | 1BF960EA298B063200AA2990 /* Matft in Frameworks */, 73 | ); 74 | runOnlyForDeploymentPostprocessing = 0; 75 | }; 76 | /* End PBXFrameworksBuildPhase section */ 77 | 78 | /* Begin PBXGroup section */ 79 | 01AF98EC28E28BF2002DAC53 = { 80 | isa = PBXGroup; 81 | children = ( 82 | 01AF98F728E28BF2002DAC53 /* Whisper */, 83 | 01AF98F628E28BF2002DAC53 /* Products */, 84 | 01AF990828E28C2E002DAC53 /* Frameworks */, 85 | ); 86 | sourceTree = ""; 87 | }; 88 | 01AF98F628E28BF2002DAC53 /* Products */ = { 89 | isa = PBXGroup; 90 | children = ( 91 | 01AF98F528E28BF2002DAC53 /* Whisper.app */, 92 | ); 93 | name = Products; 94 | sourceTree = ""; 95 | }; 96 | 01AF98F728E28BF2002DAC53 /* Whisper */ = { 97 | isa = PBXGroup; 98 | children = ( 99 | 01AF992628E2D9B9002DAC53 /* Info.plist */, 100 | 01AF98F828E28BF2002DAC53 /* WhisperApp.swift */, 101 | 01AF992428E2D90F002DAC53 /* AudioRecorder.swift */, 102 | 1B6A59E02963A670004C510D /* AudioLoader.swift */, 103 | 01AF98FA28E28BF2002DAC53 /* ContentView.swift */, 104 | 1B03910829788B35007E945E /* Whisper */, 105 | 1B58A6392975F07A006E0969 /* decoder_base.mlpackage */, 106 | 1B03910629760795007E945E /* encoder_base.mlpackage */, 107 | 1B6A59DA29635145004C510D /* decoder.mlpackage */, 108 | 1B6A59DB29635145004C510D /* encoder.mlpackage */, 109 | 1B6A59DE2963A1A6004C510D /* mel_filters.data */, 110 | 1BFAC9DE2968F11100A5A099 /* gpt2-merges.txt */, 111 | 1BFAC9DF2968F11100A5A099 /* gpt2-vocab.json */, 112 | 1B58A612296BD704006E0969 /* multilingual-merges.txt */, 113 | 1B58A611296BD704006E0969 /* multilingual-vocab.json */, 114 | 1B58A617296E03FE006E0969 /* python_log_mel.raw */, 115 | 01AF98FC28E28BF3002DAC53 /* Assets.xcassets */, 116 | 01AF98FE28E28BF3002DAC53 /* Preview Content */, 117 | ); 118 | path = Whisper; 119 | sourceTree = ""; 120 | }; 121 | 01AF98FE28E28BF3002DAC53 /* Preview Content */ = { 122 | isa = PBXGroup; 123 | children = ( 124 | 01AF98FF28E28BF3002DAC53 /* Preview Assets.xcassets */, 125 | ); 126 | path = "Preview Content"; 127 | sourceTree = ""; 128 | }; 129 | 01AF990828E28C2E002DAC53 /* Frameworks */ = { 130 | isa = PBXGroup; 131 | children = ( 132 | 01AF990928E28C2E002DAC53 /* libresolv.tbd */, 133 | ); 134 | name = Frameworks; 135 | sourceTree = ""; 136 | }; 137 | 1B03910829788B35007E945E /* Whisper */ = { 138 | isa = PBXGroup; 139 | children = ( 140 | 01AF991E28E2A97E002DAC53 /* Whisper.swift */, 141 | 1BFAC9DC2968F0A900A5A099 /* WhisperTokenizer.swift */, 142 | 1B58A6192970C127006E0969 /* FFT.swift */, 143 | 1B58A61B29733173006E0969 /* STFT.swift */, 144 | 1B6A59D829634F2A004C510D /* MelSpectrogram.swift */, 145 | 1B58A60F296A3BAB006E0969 /* MLMultiArray+Utils.swift */, 146 | 1B58A60D296A3B40006E0969 /* Math.swift */, 147 | 1BF960EB298B071300AA2990 /* MatftShapedArrayExtensions.swift */, 148 | ); 149 | path = Whisper; 150 | sourceTree = ""; 151 | }; 152 | /* End PBXGroup section */ 153 | 154 | /* Begin PBXNativeTarget section */ 155 | 01AF98F428E28BF2002DAC53 /* Whisper */ = { 156 | isa = PBXNativeTarget; 157 | buildConfigurationList = 01AF990328E28BF3002DAC53 /* Build configuration list for PBXNativeTarget "Whisper" */; 158 | buildPhases = ( 159 | 01AF98F128E28BF2002DAC53 /* Sources */, 160 | 01AF98F228E28BF2002DAC53 /* Frameworks */, 161 | 01AF98F328E28BF2002DAC53 /* Resources */, 162 | ); 163 | buildRules = ( 164 | ); 165 | dependencies = ( 166 | ); 167 | name = Whisper; 168 | packageProductDependencies = ( 169 | 1B58A6372975C3B1006E0969 /* RosaKit */, 170 | 1BF960E9298B063200AA2990 /* Matft */, 171 | ); 172 | productName = Whisper; 173 | productReference = 01AF98F528E28BF2002DAC53 /* Whisper.app */; 174 | productType = "com.apple.product-type.application"; 175 | }; 176 | /* End PBXNativeTarget section */ 177 | 178 | /* Begin PBXProject section */ 179 | 01AF98ED28E28BF2002DAC53 /* Project object */ = { 180 | isa = PBXProject; 181 | attributes = { 182 | BuildIndependentTargetsInParallel = 1; 183 | LastSwiftUpdateCheck = 1400; 184 | LastUpgradeCheck = 1400; 185 | TargetAttributes = { 186 | 01AF98F428E28BF2002DAC53 = { 187 | CreatedOnToolsVersion = 14.0; 188 | }; 189 | }; 190 | }; 191 | buildConfigurationList = 01AF98F028E28BF2002DAC53 /* Build configuration list for PBXProject "Whisper" */; 192 | compatibilityVersion = "Xcode 14.0"; 193 | developmentRegion = en; 194 | hasScannedForEncodings = 0; 195 | knownRegions = ( 196 | en, 197 | Base, 198 | ); 199 | mainGroup = 01AF98EC28E28BF2002DAC53; 200 | packageReferences = ( 201 | 1B58A6362975C3B1006E0969 /* XCRemoteSwiftPackageReference "RosaKit" */, 202 | 1BF960E8298B063200AA2990 /* XCRemoteSwiftPackageReference "Matft" */, 203 | ); 204 | productRefGroup = 01AF98F628E28BF2002DAC53 /* Products */; 205 | projectDirPath = ""; 206 | projectRoot = ""; 207 | targets = ( 208 | 01AF98F428E28BF2002DAC53 /* Whisper */, 209 | ); 210 | }; 211 | /* End PBXProject section */ 212 | 213 | /* Begin PBXResourcesBuildPhase section */ 214 | 01AF98F328E28BF2002DAC53 /* Resources */ = { 215 | isa = PBXResourcesBuildPhase; 216 | buildActionMask = 2147483647; 217 | files = ( 218 | 1B6A59DF2963A1A6004C510D /* mel_filters.data in Resources */, 219 | 01AF990028E28BF3002DAC53 /* Preview Assets.xcassets in Resources */, 220 | 1B58A618296E03FF006E0969 /* python_log_mel.raw in Resources */, 221 | 1B58A614296BD704006E0969 /* multilingual-merges.txt in Resources */, 222 | 1B58A613296BD704006E0969 /* multilingual-vocab.json in Resources */, 223 | 1BFAC9E02968F11100A5A099 /* gpt2-merges.txt in Resources */, 224 | 1BFAC9E12968F11100A5A099 /* gpt2-vocab.json in Resources */, 225 | 01AF98FD28E28BF3002DAC53 /* Assets.xcassets in Resources */, 226 | ); 227 | runOnlyForDeploymentPostprocessing = 0; 228 | }; 229 | /* End PBXResourcesBuildPhase section */ 230 | 231 | /* Begin PBXSourcesBuildPhase section */ 232 | 01AF98F128E28BF2002DAC53 /* Sources */ = { 233 | isa = PBXSourcesBuildPhase; 234 | buildActionMask = 2147483647; 235 | files = ( 236 | 01AF992528E2D90F002DAC53 /* AudioRecorder.swift in Sources */, 237 | 1BFAC9DD2968F0A900A5A099 /* WhisperTokenizer.swift in Sources */, 238 | 01AF98FB28E28BF2002DAC53 /* ContentView.swift in Sources */, 239 | 1B6A59D929634F2A004C510D /* MelSpectrogram.swift in Sources */, 240 | 1B58A610296A3BAB006E0969 /* MLMultiArray+Utils.swift in Sources */, 241 | 01AF98F928E28BF2002DAC53 /* WhisperApp.swift in Sources */, 242 | 1B6A59E12963A670004C510D /* AudioLoader.swift in Sources */, 243 | 1B58A61C29733174006E0969 /* STFT.swift in Sources */, 244 | 1B03910729760795007E945E /* encoder_base.mlpackage in Sources */, 245 | 1B58A61A2970C127006E0969 /* FFT.swift in Sources */, 246 | 1BF960EC298B071300AA2990 /* MatftShapedArrayExtensions.swift in Sources */, 247 | 1B58A63B2975F07A006E0969 /* decoder_base.mlpackage in Sources */, 248 | 01AF991F28E2A97E002DAC53 /* Whisper.swift in Sources */, 249 | 1B58A60E296A3B40006E0969 /* Math.swift in Sources */, 250 | ); 251 | runOnlyForDeploymentPostprocessing = 0; 252 | }; 253 | /* End PBXSourcesBuildPhase section */ 254 | 255 | /* Begin XCBuildConfiguration section */ 256 | 01AF990128E28BF3002DAC53 /* Debug */ = { 257 | isa = XCBuildConfiguration; 258 | buildSettings = { 259 | ALWAYS_SEARCH_USER_PATHS = NO; 260 | CLANG_ANALYZER_NONNULL = YES; 261 | CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; 262 | CLANG_CXX_LANGUAGE_STANDARD = "gnu++17"; 263 | CLANG_ENABLE_MODULES = YES; 264 | CLANG_ENABLE_OBJC_ARC = YES; 265 | CLANG_ENABLE_OBJC_WEAK = YES; 266 | CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; 267 | CLANG_WARN_BOOL_CONVERSION = YES; 268 | CLANG_WARN_COMMA = YES; 269 | CLANG_WARN_CONSTANT_CONVERSION = YES; 270 | CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; 271 | CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; 272 | CLANG_WARN_DOCUMENTATION_COMMENTS = YES; 273 | CLANG_WARN_EMPTY_BODY = YES; 274 | CLANG_WARN_ENUM_CONVERSION = YES; 275 | CLANG_WARN_INFINITE_RECURSION = YES; 276 | CLANG_WARN_INT_CONVERSION = YES; 277 | CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; 278 | CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; 279 | CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; 280 | CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; 281 | CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; 282 | CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; 283 | CLANG_WARN_STRICT_PROTOTYPES = YES; 284 | CLANG_WARN_SUSPICIOUS_MOVE = YES; 285 | CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; 286 | CLANG_WARN_UNREACHABLE_CODE = YES; 287 | CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; 288 | COPY_PHASE_STRIP = NO; 289 | DEBUG_INFORMATION_FORMAT = dwarf; 290 | ENABLE_STRICT_OBJC_MSGSEND = YES; 291 | ENABLE_TESTABILITY = YES; 292 | GCC_C_LANGUAGE_STANDARD = gnu11; 293 | GCC_DYNAMIC_NO_PIC = NO; 294 | GCC_NO_COMMON_BLOCKS = YES; 295 | GCC_OPTIMIZATION_LEVEL = 0; 296 | GCC_PREPROCESSOR_DEFINITIONS = ( 297 | "DEBUG=1", 298 | "$(inherited)", 299 | ); 300 | GCC_WARN_64_TO_32_BIT_CONVERSION = YES; 301 | GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; 302 | GCC_WARN_UNDECLARED_SELECTOR = YES; 303 | GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; 304 | GCC_WARN_UNUSED_FUNCTION = YES; 305 | GCC_WARN_UNUSED_VARIABLE = YES; 306 | IPHONEOS_DEPLOYMENT_TARGET = 16.0; 307 | MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; 308 | MTL_FAST_MATH = YES; 309 | ONLY_ACTIVE_ARCH = YES; 310 | SDKROOT = iphoneos; 311 | SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; 312 | SWIFT_OPTIMIZATION_LEVEL = "-Onone"; 313 | }; 314 | name = Debug; 315 | }; 316 | 01AF990228E28BF3002DAC53 /* Release */ = { 317 | isa = XCBuildConfiguration; 318 | buildSettings = { 319 | ALWAYS_SEARCH_USER_PATHS = NO; 320 | CLANG_ANALYZER_NONNULL = YES; 321 | CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; 322 | CLANG_CXX_LANGUAGE_STANDARD = "gnu++17"; 323 | CLANG_ENABLE_MODULES = YES; 324 | CLANG_ENABLE_OBJC_ARC = YES; 325 | CLANG_ENABLE_OBJC_WEAK = YES; 326 | CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; 327 | CLANG_WARN_BOOL_CONVERSION = YES; 328 | CLANG_WARN_COMMA = YES; 329 | CLANG_WARN_CONSTANT_CONVERSION = YES; 330 | CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; 331 | CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; 332 | CLANG_WARN_DOCUMENTATION_COMMENTS = YES; 333 | CLANG_WARN_EMPTY_BODY = YES; 334 | CLANG_WARN_ENUM_CONVERSION = YES; 335 | CLANG_WARN_INFINITE_RECURSION = YES; 336 | CLANG_WARN_INT_CONVERSION = YES; 337 | CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; 338 | CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; 339 | CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; 340 | CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; 341 | CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; 342 | CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; 343 | CLANG_WARN_STRICT_PROTOTYPES = YES; 344 | CLANG_WARN_SUSPICIOUS_MOVE = YES; 345 | CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; 346 | CLANG_WARN_UNREACHABLE_CODE = YES; 347 | CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; 348 | COPY_PHASE_STRIP = NO; 349 | DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; 350 | ENABLE_NS_ASSERTIONS = NO; 351 | ENABLE_STRICT_OBJC_MSGSEND = YES; 352 | GCC_C_LANGUAGE_STANDARD = gnu11; 353 | GCC_NO_COMMON_BLOCKS = YES; 354 | GCC_WARN_64_TO_32_BIT_CONVERSION = YES; 355 | GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; 356 | GCC_WARN_UNDECLARED_SELECTOR = YES; 357 | GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; 358 | GCC_WARN_UNUSED_FUNCTION = YES; 359 | GCC_WARN_UNUSED_VARIABLE = YES; 360 | IPHONEOS_DEPLOYMENT_TARGET = 16.0; 361 | MTL_ENABLE_DEBUG_INFO = NO; 362 | MTL_FAST_MATH = YES; 363 | SDKROOT = iphoneos; 364 | SWIFT_COMPILATION_MODE = wholemodule; 365 | SWIFT_OPTIMIZATION_LEVEL = "-O"; 366 | VALIDATE_PRODUCT = YES; 367 | }; 368 | name = Release; 369 | }; 370 | 01AF990428E28BF3002DAC53 /* Debug */ = { 371 | isa = XCBuildConfiguration; 372 | buildSettings = { 373 | ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; 374 | ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; 375 | CODE_SIGN_STYLE = Automatic; 376 | CURRENT_PROJECT_VERSION = 1; 377 | DEVELOPMENT_ASSET_PATHS = "\"Whisper/Preview Content\""; 378 | DEVELOPMENT_TEAM = SHG3AW6YV7; 379 | ENABLE_PREVIEWS = YES; 380 | GENERATE_INFOPLIST_FILE = YES; 381 | HEADER_SEARCH_PATHS = "$(PROJECT_DIR)/Whisper"; 382 | INFOPLIST_FILE = Whisper/Info.plist; 383 | INFOPLIST_KEY_NSMicrophoneUsageDescription = "Language detection"; 384 | INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; 385 | INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; 386 | INFOPLIST_KEY_UILaunchScreen_Generation = YES; 387 | INFOPLIST_KEY_UISupportedInterfaceOrientations = UIInterfaceOrientationPortrait; 388 | INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight UIInterfaceOrientationPortraitUpsideDown"; 389 | IPHONEOS_DEPLOYMENT_TARGET = 15.5; 390 | LD_RUNPATH_SEARCH_PATHS = ( 391 | "$(inherited)", 392 | "@executable_path/Frameworks", 393 | ); 394 | LIBRARY_SEARCH_PATHS = ( 395 | "$(inherited)", 396 | "$(PROJECT_DIR)/Whisper", 397 | ); 398 | MACOSX_DEPLOYMENT_TARGET = 12.4; 399 | MARKETING_VERSION = 1.0; 400 | PRODUCT_BUNDLE_IDENTIFIER = tbss.Whisper; 401 | PRODUCT_NAME = "$(TARGET_NAME)"; 402 | SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; 403 | SUPPORTS_MACCATALYST = NO; 404 | SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; 405 | SWIFT_EMIT_LOC_STRINGS = YES; 406 | SWIFT_OPTIMIZATION_LEVEL = "-Onone"; 407 | SWIFT_VERSION = 5.0; 408 | TARGETED_DEVICE_FAMILY = "1,2"; 409 | }; 410 | name = Debug; 411 | }; 412 | 01AF990528E28BF3002DAC53 /* Release */ = { 413 | isa = XCBuildConfiguration; 414 | buildSettings = { 415 | ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; 416 | ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; 417 | CODE_SIGN_STYLE = Automatic; 418 | CURRENT_PROJECT_VERSION = 1; 419 | DEVELOPMENT_ASSET_PATHS = "\"Whisper/Preview Content\""; 420 | DEVELOPMENT_TEAM = SHG3AW6YV7; 421 | ENABLE_PREVIEWS = YES; 422 | GCC_OPTIMIZATION_LEVEL = fast; 423 | GENERATE_INFOPLIST_FILE = YES; 424 | HEADER_SEARCH_PATHS = "$(PROJECT_DIR)/Whisper"; 425 | INFOPLIST_FILE = Whisper/Info.plist; 426 | INFOPLIST_KEY_NSMicrophoneUsageDescription = "Language detection"; 427 | INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; 428 | INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; 429 | INFOPLIST_KEY_UILaunchScreen_Generation = YES; 430 | INFOPLIST_KEY_UISupportedInterfaceOrientations = UIInterfaceOrientationPortrait; 431 | INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight UIInterfaceOrientationPortraitUpsideDown"; 432 | IPHONEOS_DEPLOYMENT_TARGET = 15.5; 433 | LD_RUNPATH_SEARCH_PATHS = ( 434 | "$(inherited)", 435 | "@executable_path/Frameworks", 436 | ); 437 | LIBRARY_SEARCH_PATHS = ( 438 | "$(inherited)", 439 | "$(PROJECT_DIR)/Whisper", 440 | ); 441 | MACOSX_DEPLOYMENT_TARGET = 12.4; 442 | MARKETING_VERSION = 1.0; 443 | PRODUCT_BUNDLE_IDENTIFIER = tbss.Whisper; 444 | PRODUCT_NAME = "$(TARGET_NAME)"; 445 | SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; 446 | SUPPORTS_MACCATALYST = NO; 447 | SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; 448 | SWIFT_EMIT_LOC_STRINGS = YES; 449 | SWIFT_OPTIMIZATION_LEVEL = "-O"; 450 | SWIFT_VERSION = 5.0; 451 | TARGETED_DEVICE_FAMILY = "1,2"; 452 | }; 453 | name = Release; 454 | }; 455 | /* End XCBuildConfiguration section */ 456 | 457 | /* Begin XCConfigurationList section */ 458 | 01AF98F028E28BF2002DAC53 /* Build configuration list for PBXProject "Whisper" */ = { 459 | isa = XCConfigurationList; 460 | buildConfigurations = ( 461 | 01AF990128E28BF3002DAC53 /* Debug */, 462 | 01AF990228E28BF3002DAC53 /* Release */, 463 | ); 464 | defaultConfigurationIsVisible = 0; 465 | defaultConfigurationName = Release; 466 | }; 467 | 01AF990328E28BF3002DAC53 /* Build configuration list for PBXNativeTarget "Whisper" */ = { 468 | isa = XCConfigurationList; 469 | buildConfigurations = ( 470 | 01AF990428E28BF3002DAC53 /* Debug */, 471 | 01AF990528E28BF3002DAC53 /* Release */, 472 | ); 473 | defaultConfigurationIsVisible = 0; 474 | defaultConfigurationName = Release; 475 | }; 476 | /* End XCConfigurationList section */ 477 | 478 | /* Begin XCRemoteSwiftPackageReference section */ 479 | 1B58A6362975C3B1006E0969 /* XCRemoteSwiftPackageReference "RosaKit" */ = { 480 | isa = XCRemoteSwiftPackageReference; 481 | repositoryURL = "https://github.com/dhrebeniuk/RosaKit"; 482 | requirement = { 483 | branch = main; 484 | kind = branch; 485 | }; 486 | }; 487 | 1BF960E8298B063200AA2990 /* XCRemoteSwiftPackageReference "Matft" */ = { 488 | isa = XCRemoteSwiftPackageReference; 489 | repositoryURL = "https://github.com/jjjkkkjjj/Matft"; 490 | requirement = { 491 | kind = upToNextMajorVersion; 492 | minimumVersion = 0.3.2; 493 | }; 494 | }; 495 | /* End XCRemoteSwiftPackageReference section */ 496 | 497 | /* Begin XCSwiftPackageProductDependency section */ 498 | 1B58A6372975C3B1006E0969 /* RosaKit */ = { 499 | isa = XCSwiftPackageProductDependency; 500 | package = 1B58A6362975C3B1006E0969 /* XCRemoteSwiftPackageReference "RosaKit" */; 501 | productName = RosaKit; 502 | }; 503 | 1BF960E9298B063200AA2990 /* Matft */ = { 504 | isa = XCSwiftPackageProductDependency; 505 | package = 1BF960E8298B063200AA2990 /* XCRemoteSwiftPackageReference "Matft" */; 506 | productName = Matft; 507 | }; 508 | /* End XCSwiftPackageProductDependency section */ 509 | }; 510 | rootObject = 01AF98ED28E28BF2002DAC53 /* Project object */; 511 | } 512 | -------------------------------------------------------------------------------- /Whisper/Whisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Whisper/Whisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /Whisper/Whisper.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "matft", 5 | "kind" : "remoteSourceControl", 6 | "location" : "https://github.com/jjjkkkjjj/Matft", 7 | "state" : { 8 | "revision" : "6372049074c17dfd8e6cccd83af6228c601c7c63", 9 | "version" : "0.3.2" 10 | } 11 | }, 12 | { 13 | "identity" : "plain-pocketfft", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/dhrebeniuk/plain-pocketfft.git", 16 | "state" : { 17 | "revision" : "b9633be41ad61e40a93f58abafed9fb5e281ed18", 18 | "version" : "0.0.9" 19 | } 20 | }, 21 | { 22 | "identity" : "pocketfft", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/dhrebeniuk/pocketfft.git", 25 | "state" : { 26 | "revision" : "f7f2ac6085d9123ab130bb8daebcc11acd72d52c", 27 | "version" : "0.0.1" 28 | } 29 | }, 30 | { 31 | "identity" : "rosakit", 32 | "kind" : "remoteSourceControl", 33 | "location" : "https://github.com/dhrebeniuk/RosaKit", 34 | "state" : { 35 | "branch" : "main", 36 | "revision" : "a00984a4c320cef88effa68574318ec34d9f165c" 37 | } 38 | }, 39 | { 40 | "identity" : "swift-collections", 41 | "kind" : "remoteSourceControl", 42 | "location" : "https://github.com/apple/swift-collections", 43 | "state" : { 44 | "revision" : "937e904258d22af6e447a0b72c0bc67583ef64a2", 45 | "version" : "1.0.4" 46 | } 47 | } 48 | ], 49 | "version" : 2 50 | } 51 | -------------------------------------------------------------------------------- /Whisper/Whisper.xcodeproj/xcshareddata/xcschemes/Whisper.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 8 | 9 | 15 | 21 | 22 | 23 | 24 | 25 | 30 | 31 | 32 | 33 | 45 | 47 | 53 | 54 | 55 | 56 | 62 | 64 | 70 | 71 | 72 | 73 | 75 | 76 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /Whisper/Whisper/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Whisper/Whisper/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | } 8 | ], 9 | "info" : { 10 | "author" : "xcode", 11 | "version" : 1 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /Whisper/Whisper/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Whisper/Whisper/AudioLoader.swift: -------------------------------------------------------------------------------- 1 | // 2 | // AudioLoader.swift 3 | // Whisper 4 | // 5 | // Created by vade on 1/2/23. 6 | // 7 | 8 | import Foundation 9 | import Cocoa 10 | 11 | class AudioLoader: NSObject, ObservableObject 12 | { 13 | func selectFileURL() -> URL { 14 | 15 | let openpanel = NSOpenPanel() 16 | 17 | openpanel.runModal() 18 | 19 | return openpanel.url! 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /Whisper/Whisper/AudioRecorder.swift: -------------------------------------------------------------------------------- 1 | // 2 | // AudioRecorder.swift 3 | // Whisper 4 | // 5 | // Created by Tanmay Bakshi on 2022-09-27. 6 | // 7 | 8 | import AVFoundation 9 | import SwiftUI 10 | 11 | #if os(iOS) 12 | 13 | func getDocumentsDirectory() -> URL { 14 | let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask) 15 | return paths[0] 16 | } 17 | 18 | class AudioRecorder: NSObject, ObservableObject, AVAudioRecorderDelegate { 19 | static let audioURL = getDocumentsDirectory().appendingPathComponent("query.wav") 20 | 21 | private var recordingSession: AVAudioSession! 22 | private var audioRecorder: AVAudioRecorder! 23 | 24 | @Published private(set) var canRecord = false 25 | @Published private(set) var recording = false 26 | 27 | enum RecordingError: Error { 28 | case invalidFormat 29 | case noBuffer 30 | } 31 | 32 | func setup() { 33 | recordingSession = AVAudioSession.sharedInstance() 34 | do { 35 | try recordingSession.setCategory(.record) 36 | try recordingSession.setMode(.measurement) 37 | } catch let error { 38 | print(error) 39 | } 40 | do { 41 | try recordingSession.setCategory(.playAndRecord, mode: .default) 42 | try recordingSession.setActive(true) 43 | recordingSession.requestRecordPermission { allowed in 44 | DispatchQueue.main.async { 45 | if allowed { 46 | self.canRecord = true 47 | } else { 48 | fatalError("User did not allow access to microphone") 49 | } 50 | } 51 | } 52 | } catch let error { 53 | fatalError("\(error)") 54 | } 55 | } 56 | 57 | func startRecording() throws { 58 | let settings = [ 59 | AVFormatIDKey: Int(kAudioFormatLinearPCM), 60 | AVSampleRateKey: 16000, 61 | AVNumberOfChannelsKey: 1, 62 | AVEncoderAudioQualityKey: AVAudioQuality.high.rawValue 63 | ] 64 | 65 | audioRecorder = try AVAudioRecorder(url: Self.audioURL, settings: settings) 66 | audioRecorder.delegate = self 67 | audioRecorder.record() 68 | recording = true 69 | } 70 | 71 | func finishRecording() throws -> URL { 72 | audioRecorder.stop() 73 | audioRecorder = nil 74 | recording = false 75 | 76 | return AudioRecorder.audioURL 77 | } 78 | 79 | func audioRecorderDidFinishRecording(_ recorder: AVAudioRecorder, successfully flag: Bool) { 80 | if !flag { 81 | fatalError("Error in recording") 82 | } 83 | } 84 | } 85 | 86 | #endif 87 | -------------------------------------------------------------------------------- /Whisper/Whisper/ContentView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ContentView.swift 3 | // Whisper 4 | // 5 | // Created by Tanmay Bakshi on 2022-09-26. 6 | // 7 | 8 | import SwiftUI 9 | import AVFoundation 10 | 11 | struct ContentView: View { 12 | let whisper: Whisper 13 | 14 | #if os(iOS) 15 | 16 | @ObservedObject var recorder = AudioRecorder() 17 | 18 | #elseif os(macOS) 19 | 20 | @ObservedObject var loader = AudioLoader() 21 | 22 | #endif 23 | 24 | init() throws { 25 | whisper = try Whisper() 26 | } 27 | 28 | var body: some View { 29 | #if os(iOS) 30 | 31 | ZStack { 32 | RoundedRectangle(cornerRadius: 20) 33 | .frame(height: 60) 34 | .padding() 35 | .foregroundColor(!recorder.canRecord || recorder.recording ? .gray : .blue) 36 | 37 | Text(!recorder.canRecord ? "Waiting for permissions..." : (recorder.recording ? "Recording..." : "Record")) 38 | .font(.title2) 39 | .bold() 40 | .foregroundColor(.white) 41 | } 42 | .onTapGesture { 43 | if recorder.canRecord && !recorder.recording { 44 | 45 | do { 46 | try recorder.startRecording() 47 | } catch let error { 48 | fatalError("Couldn't record. Error: \(error)") 49 | } 50 | 51 | let audioAssetURL: URL 52 | do { 53 | audioAssetURL = try recorder.finishRecording() 54 | } catch let error { 55 | fatalError("Couldn't finish recording. Error: \(error)") 56 | } 57 | 58 | getAudioPredict(url: audioAssetURL) 59 | } 60 | } 61 | .onAppear { 62 | recorder.setup() 63 | } 64 | 65 | #elseif os(macOS) 66 | ZStack { 67 | RoundedRectangle(cornerRadius: 20) 68 | .frame(height: 60) 69 | .padding() 70 | .foregroundColor(.blue) 71 | 72 | Text("Load Audio File") 73 | .font(.title2) 74 | .bold() 75 | .foregroundColor(.white) 76 | } 77 | .onTapGesture { 78 | 79 | getAudioPredict(url: loader.selectFileURL()) 80 | 81 | } 82 | #endif 83 | 84 | } 85 | 86 | func getAudioPredict(url:URL) { 87 | 88 | Task { 89 | let start = Date().timeIntervalSince1970 90 | 91 | let options:Whisper.WhisperOptions = Whisper.WhisperOptions(task: .Transcribe, 92 | format: .Text, 93 | verbose: true) 94 | 95 | let transcription = await whisper.transcribe(assetURL: url, options: options) 96 | 97 | print(transcription) 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /Whisper/Whisper/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Whisper/Whisper/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/FFT.swift: -------------------------------------------------------------------------------- 1 | // 2 | // FFTImplementations.swift 3 | // Whisper 4 | // 5 | // Created by Anton Marini on 1/12/23. 6 | // 7 | 8 | import Foundation 9 | import Accelerate 10 | 11 | //figure out if i need to remove the dc offset and nyquist from the returned arrays because maybe im loosing a bin? 12 | 13 | 14 | public class WhisperFFT 15 | { 16 | 17 | var window:[Double] 18 | var numFFT:Int = 400 19 | 20 | init(numFFT:Int) 21 | { 22 | self.numFFT = numFFT 23 | 24 | self.window = vDSP.window(ofType: Double.self, 25 | usingSequence: .hanningDenormalized, 26 | count: self.numFFT, 27 | isHalfWindow: false) 28 | } 29 | } 30 | 31 | public class NumpyRFFT : WhisperFFT 32 | { 33 | // var fft : vDSP.FFT 34 | var nOver2:Int! 35 | var logSize:vDSP_Length! 36 | 37 | override init(numFFT: Int) 38 | { 39 | super.init(numFFT:numFFT) 40 | 41 | self.logSize = vDSP_Length(floor(log2(Float(self.numFFT)))) 42 | self.nOver2 = (numFFT / 2) 43 | 44 | // self.fft = vDSP.FFT(log2n: UInt(self.nOver2), 45 | // radix: .radix2, 46 | // ofType: DSPDoubleSplitComplex.self)! 47 | 48 | } 49 | 50 | public func forward(_ audioFrame:[Double]) -> ([Double], [Double]) 51 | { 52 | precondition(audioFrame.count == self.numFFT, "FFT Size and Audio frame size doesnt match") 53 | 54 | var windowedAudioFrame = [Double](repeating: 0, count: self.numFFT) 55 | 56 | vDSP.multiply(audioFrame, 57 | self.window, 58 | result: &windowedAudioFrame) 59 | 60 | var sampleReal:[Double] = [Double](repeating: 0, count: self.nOver2) 61 | var sampleImaginary:[Double] = [Double](repeating: 0, count: self.nOver2) 62 | 63 | var resultReal:[Double] = [Double](repeating: 0, count: self.numFFT) 64 | var resultImaginary:[Double] = [Double](repeating: 0, count:self.numFFT) 65 | 66 | let fftSetup:FFTSetupD = vDSP_create_fftsetupD(self.logSize, FFTRadix(kFFTRadix2))!; 67 | 68 | sampleReal.withUnsafeMutableBytes { unsafeReal in 69 | sampleImaginary.withUnsafeMutableBytes { unsafeImaginary in 70 | 71 | resultReal.withUnsafeMutableBytes { unsafeResultReal in 72 | resultImaginary.withUnsafeMutableBytes { unsafeResultImaginary in 73 | 74 | var complexSignal = DSPDoubleSplitComplex(realp: unsafeReal.bindMemory(to: Double.self).baseAddress!, 75 | imagp: unsafeImaginary.bindMemory(to: Double.self).baseAddress!) 76 | 77 | let complexResult = DSPDoubleSplitComplex(realp: unsafeResultReal.bindMemory(to: Double.self).baseAddress!, 78 | imagp: unsafeResultImaginary.bindMemory(to: Double.self).baseAddress!) 79 | 80 | // Treat our windowed audio as a Interleaved Complex 81 | // And convert it into a split complex Signal 82 | windowedAudioFrame.withUnsafeBytes { unsafeAudioBytes in 83 | let letInterleavedComplexAudio = [DSPDoubleComplex](unsafeAudioBytes.bindMemory(to: DSPDoubleComplex.self)) 84 | 85 | vDSP_ctozD(letInterleavedComplexAudio, 2, &complexSignal, 1, vDSP_Length(self.nOver2)) ; 86 | 87 | vDSP_fft_zripD (fftSetup, &complexSignal, 1, self.logSize, FFTDirection(kFFTDirection_Forward)); 88 | } 89 | 90 | // Scale by 1/2 : https://stackoverflow.com/questions/51804365/why-is-fft-different-in-swift-than-in-python 91 | var scaleFactor = Double( 1.0/2.0 ) // * 1.165 ?? 92 | vDSP_vsmulD(complexSignal.realp, 1, &scaleFactor, complexSignal.realp, 1, vDSP_Length(self.nOver2)) 93 | vDSP_vsmulD(complexSignal.imagp, 1, &scaleFactor, complexSignal.imagp, 1, vDSP_Length(self.nOver2)) 94 | 95 | // Borrowed from https://github.com/jseales/numpy-style-fft-in-obj-c 96 | complexResult.realp[0] = complexSignal.realp[0]; 97 | complexResult.imagp[0] = 0; 98 | complexResult.realp[self.nOver2] = complexSignal.imagp[0]; 99 | complexResult.imagp[self.nOver2] = 0; 100 | 101 | for (i) in 1 ..< self.nOver2 102 | { 103 | complexResult.realp[i] = complexSignal.realp[i]; 104 | complexResult.imagp[i] = complexSignal.imagp[i]; 105 | 106 | // Complex conjugate is mirrored (?) 107 | complexResult.realp[self.numFFT - i] = complexSignal.realp[i]; 108 | complexResult.imagp[self.numFFT - i] = complexSignal.imagp[i]; 109 | } 110 | } 111 | } 112 | } 113 | } 114 | 115 | return (resultReal, resultImaginary) 116 | } 117 | 118 | } 119 | 120 | public class NumpyFFT : WhisperFFT 121 | { 122 | // var fft : vDSP.FFT 123 | var n:Int! 124 | var nOver2:Int! 125 | var logSize:vDSP_Length! 126 | 127 | override init(numFFT: Int) 128 | { 129 | super.init(numFFT:numFFT) 130 | 131 | self.logSize = vDSP_Length(floor(log2(Float(self.numFFT)))) 132 | self.n = numFFT 133 | self.nOver2 = (numFFT / 2) 134 | 135 | // self.fft = vDSP.FFT(log2n: UInt(self.nOver2), 136 | // radix: .radix2, 137 | // ofType: DSPDoubleSplitComplex.self)! 138 | 139 | } 140 | 141 | public func forward(_ audioFrame:[Double]) -> ([Double], [Double]) 142 | { 143 | var sampleReal = [Double](repeating: 0, count: self.numFFT) 144 | 145 | vDSP.multiply(audioFrame, 146 | self.window, 147 | result: &sampleReal) 148 | 149 | var sampleImaginary:[Double] = [Double](repeating: 0, count: self.numFFT) 150 | 151 | let fftSetup:FFTSetupD = vDSP_create_fftsetupD(self.logSize, FFTRadix(kFFTRadix2))!; 152 | 153 | sampleReal.withUnsafeMutableBytes { unsafeReal in 154 | sampleImaginary.withUnsafeMutableBytes { unsafeImaginary in 155 | 156 | var complexSignal = DSPDoubleSplitComplex(realp: unsafeReal.bindMemory(to: Double.self).baseAddress!, 157 | imagp: unsafeImaginary.bindMemory(to: Double.self).baseAddress!) 158 | 159 | vDSP_fft_zipD(fftSetup, &complexSignal, 1, self.logSize, FFTDirection(kFFTDirection_Forward)); 160 | } 161 | } 162 | 163 | return (sampleReal, sampleImaginary) 164 | } 165 | 166 | } 167 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/MLMultiArray+Utils.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MLMultiArray+Utils.swift 3 | // CoreMLBert 4 | // 5 | // Created by Julien Chaumond on 27/06/2019. 6 | // Copyright © 2019 Hugging Face. All rights reserved. 7 | // 8 | 9 | import Foundation 10 | import CoreML 11 | 12 | extension MLMultiArray { 13 | /// All values will be stored in the last dimension of the MLMultiArray (default is dims=1) 14 | static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray { 15 | var shape = Array(repeating: 1, count: dims) 16 | shape[shape.count - 1] = arr.count 17 | /// Examples: 18 | /// dims=1 : [arr.count] 19 | /// dims=2 : [1, arr.count] 20 | /// 21 | let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32) 22 | let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) 23 | for (i, item) in arr.enumerated() { 24 | ptr[i] = Int32(item) 25 | } 26 | return o 27 | } 28 | 29 | /// This will concatenate all dimensions into one one-dim array. 30 | static func toIntArray(_ o: MLMultiArray) -> [Int] { 31 | var arr = Array(repeating: 0, count: o.count) 32 | let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) 33 | for i in 0.. [Double] { 41 | var arr: [Double] = Array(repeating: 0, count: o.count) 42 | let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) 43 | for i in 0.. MLMultiArray { 61 | let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double) 62 | let ptr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) 63 | for i in 0.. MLMultiArray { 83 | assert( 84 | indexing.count == o.shape.count 85 | ) 86 | assert( 87 | indexing.filter { $0 == Indexing.slice }.count == 1 88 | ) 89 | var selectDims: [Int: Int] = [:] 90 | for (i, idx) in indexing.enumerated() { 91 | if case .select(let select) = idx { 92 | selectDims[i] = select 93 | } 94 | } 95 | return slice( 96 | o, 97 | sliceDim: indexing.firstIndex { $0 == Indexing.slice }!, 98 | selectDims: selectDims 99 | ) 100 | } 101 | 102 | /// Slice an array according to a list, according to `sliceDim` (which dimension to slice on) 103 | /// and a dictionary of `dim` to `index`. 104 | /// 105 | /// You must select all other dimensions than the slice dimension (cf. the assert). 106 | static func slice(_ o: MLMultiArray, sliceDim: Int, selectDims: [Int: Int]) -> MLMultiArray { 107 | assert( 108 | selectDims.count + 1 == o.shape.count 109 | ) 110 | var shape: [NSNumber] = Array(repeating: 1, count: o.shape.count) 111 | shape[sliceDim] = o.shape[sliceDim] 112 | /// print("About to slice ndarray of shape \(o.shape) into ndarray of shape \(shape)") 113 | let arr = try! MLMultiArray(shape: shape, dataType: .double) 114 | 115 | /// let srcPtr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) 116 | /// TODO: use srcPtr instead of array subscripting. 117 | let dstPtr = UnsafeMutablePointer(OpaquePointer(arr.dataPointer)) 118 | for i in 0.. String { 145 | func indent(_ x: Int) -> String { 146 | return String(repeating: " ", count: x) 147 | } 148 | 149 | // This function is called recursively for every dimension. 150 | // Add an entry for this dimension to the end of the array. 151 | var indices = indices + [0] 152 | 153 | let d = indices.count - 1 // the current dimension 154 | let N = shape[d].intValue // how many elements in this dimension 155 | var s = "[" 156 | if indices.count < shape.count { // not last dimension yet? 157 | for i in 0.. MfArray 15 | //// { 16 | //// let shape = self.shape 17 | //// let type = self. 18 | //// 19 | //// 20 | //// self.withUnsafeShapedBufferPointer { ptr, shape, strides in 21 | //// 22 | //// let mat = Matft< 23 | //// 24 | //// } 25 | //// } 26 | //} 27 | // 28 | //extension MfArray 29 | //{ 30 | // public func toShapedArray() -> any MLShapedArrayProtocol 31 | // { 32 | // let flatArray = self.data 33 | //// let shapedArray = MLShapedArray(scalars: <#T##Sequence#>, shape:self.shape) 34 | // } 35 | //} 36 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/Math.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Math.swift 3 | // CoreMLBert 4 | // 5 | // Created by Julien Chaumond on 27/06/2019. 6 | // Copyright © 2019 Hugging Face. All rights reserved. 7 | // 8 | 9 | import Foundation 10 | import Accelerate 11 | import CoreML 12 | 13 | /// 14 | /// From M.I. Hollemans 15 | /// 16 | /// https://github.com/hollance/CoreMLHelpers 17 | /// 18 | struct Math { 19 | 20 | /** 21 | Returns the index and value of the largest element in the array. 22 | 23 | - Parameters: 24 | - ptr: Pointer to the first element in memory. 25 | - count: How many elements to look at. 26 | - stride: The distance between two elements in memory. 27 | */ 28 | static func argmax(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Float) { 29 | var maxValue: Float = 0 30 | var maxIndex: vDSP_Length = 0 31 | vDSP_maxvi(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) 32 | return (Int(maxIndex), maxValue) 33 | } 34 | 35 | /** 36 | Returns the index and value of the largest element in the array. 37 | - Parameters: 38 | - ptr: Pointer to the first element in memory. 39 | - count: How many elements to look at. 40 | - stride: The distance between two elements in memory. 41 | */ 42 | static func argmax(_ ptr: UnsafePointer, count: Int, stride: Int = 1) -> (Int, Double) { 43 | var maxValue: Double = 0 44 | var maxIndex: vDSP_Length = 0 45 | vDSP_maxviD(ptr, vDSP_Stride(stride), &maxValue, &maxIndex, vDSP_Length(count)) 46 | return (Int(maxIndex), maxValue) 47 | } 48 | 49 | 50 | /// MLMultiArray helper. 51 | /// Works in our specific use case. 52 | static func argmax(_ multiArray: MLMultiArray) -> (Int, Double) { 53 | assert(multiArray.dataType == .double) 54 | let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) 55 | return Math.argmax(ptr, count: multiArray.count) 56 | } 57 | 58 | /// MLMultiArray helper. 59 | /// Works in our specific use case. 60 | static func argmax32(_ multiArray: MLMultiArray) -> (Int, Float) { 61 | assert(multiArray.dataType == .float32) 62 | let ptr = UnsafeMutablePointer(OpaquePointer(multiArray.dataPointer)) 63 | let count = multiArray.count 64 | var maxValue: Float = 0 65 | var maxIndex: vDSP_Length = 0 66 | vDSP_maxvi(ptr, vDSP_Stride(1), &maxValue, &maxIndex, vDSP_Length(count)) 67 | return (Int(maxIndex), maxValue) 68 | } 69 | 70 | /// Top-K. 71 | /// Select the k most-probable elements indices from `arr` 72 | /// and return both the indices (from the original array) 73 | /// and their softmaxed probabilities. 74 | /// 75 | static func topK(arr: [Double], k: Int) -> (indexes: [Int], probs: [Float]) { 76 | let x = Array(arr.enumerated().map { ($0, $1) } 77 | .sorted(by: { a, b -> Bool in a.1 > b.1 }) 78 | .prefix(through: min(k, arr.count) - 1)) 79 | let indexes = x.map { $0.0 } 80 | let logits = x.map { Float($0.1) } 81 | let probs = softmax(logits) 82 | return (indexes: indexes, probs: probs) 83 | } 84 | 85 | /// Multinomial sampling from an array of probs. Works well with topK 86 | static func sample(indexes: [Int], probs: [Float]) -> Int { 87 | let i = randomNumber(probabilities: probs) 88 | return indexes[i] 89 | } 90 | 91 | /** 92 | Computes the "softmax" function over an array. 93 | Based on code from https://github.com/nikolaypavlov/MLPNeuralNet/ 94 | This is what softmax looks like in "pseudocode" (actually using Python 95 | and numpy): 96 | x -= np.max(x) 97 | exp_scores = np.exp(x) 98 | softmax = exp_scores / np.sum(exp_scores) 99 | First we shift the values of x so that the highest value in the array is 0. 100 | This ensures numerical stability with the exponents, so they don't blow up. 101 | */ 102 | static func softmax(_ x: [Float]) -> [Float] { 103 | var x = x 104 | let len = vDSP_Length(x.count) 105 | 106 | // Find the maximum value in the input array. 107 | var max: Float = 0 108 | vDSP_maxv(x, 1, &max, len) 109 | 110 | // Subtract the maximum from all the elements in the array. 111 | // Now the highest value in the array is 0. 112 | max = -max 113 | vDSP_vsadd(x, 1, &max, &x, 1, len) 114 | 115 | // Exponentiate all the elements in the array. 116 | var count = Int32(x.count) 117 | vvexpf(&x, x, &count) 118 | 119 | // Compute the sum of all exponentiated values. 120 | var sum: Float = 0 121 | vDSP_sve(x, 1, &sum, len) 122 | 123 | // Divide each element by the sum. This normalizes the array contents 124 | // so that they all add up to 1. 125 | vDSP_vsdiv(x, 1, &sum, &x, 1, len) 126 | 127 | return x 128 | } 129 | 130 | /// Multinomial sampling 131 | /// 132 | /// From https://stackoverflow.com/questions/30309556/generate-random-numbers-with-a-given-distribution 133 | /// 134 | static func randomNumber(probabilities: [Float]) -> Int { 135 | // Sum of all probabilities (so that we don't have to require that the sum is 1.0): 136 | let sum = probabilities.reduce(0, +) 137 | // Random number in the range 0.0 <= rnd < sum : 138 | let rnd = sum * Float(arc4random_uniform(UInt32.max)) / Float(UInt32.max) 139 | // Find the first interval of accumulated probabilities into which `rnd` falls: 140 | var accum: Float = 0.0 141 | for (i, p) in probabilities.enumerated() { 142 | accum += p 143 | if rnd < accum { 144 | return i 145 | } 146 | } 147 | // This point might be reached due to floating point inaccuracies: 148 | return (probabilities.count - 1) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/MelSpectrogram.swift: -------------------------------------------------------------------------------- 1 | // 2 | // stft.swift 3 | // Whisper 4 | // 5 | // Created by Tanmay Bakshi on 2022-09-26. 6 | // 7 | import Accelerate 8 | 9 | // Reference implementation we are attempting to match 10 | // https://github.com/openai/whisper/blob/main/whisper/audio.py#L92 11 | 12 | // See 13 | // https://colab.research.google.com/drive/1r9ghakH8__jGqGiYHC2DXtKaW_ozdSrV?usp=sharing 14 | // For simple isolated code to test this implementation 15 | 16 | /* 17 | window = torch.hann_window(N_FFT).to(audio.device) 18 | 1, 2, 3 - stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 19 | 20 | 4 - magnitudes = stft[:, :-1].abs() ** 2 21 | 22 | 5 - filters = mel_filters(audio.device, n_mels) 23 | 6 - mel_spec = filters @ magnitudes 24 | 25 | 7 - log_spec = torch.clamp(mel_spec, min=1e-10).log10() 26 | 8 - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 27 | 9 - log_spec = (log_spec + 4.0) / 4.0 28 | 29 | // Reference shapes - note - we dont match perfectly 30 | stft torch.Size([201, 3001]) 31 | magnitudes torch.Size([201, 3000]) 32 | mel filters torch.Size([80, 201]) 33 | mel spec torch.Size([80, 3000]) 34 | log spec torch.Size([80, 3000]) 35 | 36 | */ 37 | 38 | // https://pytorch.org/docs/stable/generated/torch.stft.html 39 | // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/SpectralOps.cpp#L820 40 | // SEE https://dsp.stackexchange.com/questions/49184/stft-amplitude-normalization-librosa-library 41 | // See https://github.com/Jounce/Surge/issues/94 42 | // see https://github.com/abokhalel2/istft/blob/main/swift/istft/ViewController.swift 43 | 44 | // Some notes 45 | // We do not calculate a 3001 mel, we skip the last since it wont be used anyway and is dropped later, saving 1/3000th of work. 46 | // 47 | 48 | 49 | // alternatively 50 | // http://www.ml-illustrated.com/2020/06/01/deploy-pytorch-model-with-coreml-convert-issues.html 51 | // https://github.com/ml-illustrated/Pytorch-CoreML-Spectrogram/blob/d0dd6c55eaf5fdcfaf00b1f036b258bd144b1ac4/python/model.py#L142 52 | 53 | public class MelSpectrogram 54 | { 55 | // MARK: Properties 56 | 57 | /// windows for each mel frequency 58 | /// Our 80 x 201 sized matrix of 16080 float values of precomputed filters. 59 | var melFilterMatrix:[[Double]] 60 | 61 | /// Tthe width of the spectrogram. 62 | var melSampleCount:Int = 3000 63 | 64 | /// The height of the spectrogram. 65 | var melFilterBankCount:Int = 80 66 | 67 | /// The number of audio samples per chunk. 68 | var sampleCount:Int = 480000 69 | 70 | /// Determines the overlap between samples for an FFT. 71 | var hopCount:Int = 160 72 | 73 | /// The forward fast Fourier transform object. 74 | var stft:STFT 75 | 76 | 77 | init(sampleCount:Int, hopCount:Int, melCount:Int, numFFT:Int) 78 | { 79 | self.sampleCount = sampleCount 80 | self.hopCount = hopCount 81 | self.melFilterBankCount = melCount 82 | 83 | self.melSampleCount = self.sampleCount / self.hopCount 84 | 85 | self.melFilterMatrix = MelSpectrogram.makeFilterBankWithNumpyData() 86 | 87 | self.stft = STFT(fftLength: numFFT, windowType: .hanningDenormalized, windowLength: numFFT, sampleCount: sampleCount, hopCount: hopCount, center: true, padding: .Reflect) 88 | } 89 | 90 | // This method DOES NOT WORK 91 | // I'm keeping it here because portions do work correcty, but our FFT / STFT implementation appears to be incorrect? 92 | // See https://www.reddit.com/r/DSP/comments/10dpyzx/help_with_stft_log_mel_spectrogram/ 93 | 94 | func processData(audio: [Int16]) -> [Float] 95 | { 96 | // Calculate STFT 97 | var (allSampleReal, allSampleImaginary) = self.stft.calculateSTFT(audio: audio) 98 | 99 | // drop the 3001'st column as per Python 100 | allSampleReal = Array(allSampleReal.prefix(upTo: Whisper.kWhisperNumSamplesInMel)) // 3000 101 | allSampleImaginary = Array(allSampleImaginary.prefix(upTo: Whisper.kWhisperNumSamplesInMel)) // 3000 102 | 103 | // Unroll matrices into flat arrays for vDSP 104 | var flattnedReal:[Double] = allSampleReal.flatMap { $0 } 105 | var flattnedImaginary:[Double] = allSampleImaginary.flatMap { $0 } 106 | 107 | let flattenedMelMatrix = self.melFilterMatrix.flatMap{ $0 } 108 | // let flattenedMelMatrix = [Double].createMelFilter(sampleRate: Whisper.kWhisperSampleRate, FTTCount:Whisper.kWhisperNumFFTs, melsCount:Whisper.kWhisperNumMels).flatMap { $0 } 109 | 110 | // print("Swift 0 - complex real min", vDSP.minimum(flattnedReal), "max", vDSP.maximum(flattnedReal)) 111 | // print("Swift 0 - complex imag min", vDSP.minimum(flattnedImaginary), "max", vDSP.maximum(flattnedImaginary)) 112 | 113 | // Take the magnitude squared of the matrix, which results in a Result flat array of 3000 x 200 of real floats 114 | // Then multiply it with our mel filter bank 115 | var magnitudes = [Double](repeating: 0, count: flattnedReal.count) 116 | var melSpectroGram = [Double](repeating: 0, count: Whisper.kWhisperNumMels * Whisper.kWhisperNumSamplesInMel) // 80 x 3000 117 | 118 | flattnedReal.withUnsafeMutableBytes { unsafeFlatReal in 119 | flattnedImaginary.withUnsafeMutableBytes { unsafeFlatImaginary in 120 | 121 | // We create a Split Complex representation of our flattened real and imaginary component 122 | let complexMatrix = DSPDoubleSplitComplex(realp: unsafeFlatReal.bindMemory(to: Double.self).baseAddress!, 123 | imagp: unsafeFlatImaginary.bindMemory(to: Double.self).baseAddress!) 124 | 125 | vDSP.squareMagnitudes(complexMatrix, result: &magnitudes) 126 | 127 | 128 | // print("Swift 1 - magnitudes min ", vDSP.minimum(magnitudes), "max", vDSP.maximum(magnitudes)) 129 | // print("Swift 1 - magnitudes min ", vDSP.minimum(flattenedMagnitudes), "max", vDSP.maximum(flattenedMagnitudes)) 130 | 131 | // transpose magnitudes from 3000 X 201, to 201 x 3000 132 | vDSP_mtransD(magnitudes, 1, &magnitudes, 1, 201, vDSP_Length(Whisper.kWhisperNumSamplesInMel)) // verified correct 133 | 134 | 135 | // Step 5 & 6 (filters loaded earlier) 136 | 137 | // MATRIX A, a MxK sized matrix 138 | // MATRIX B, a KxN sized matrix 139 | // MATRIX C, a MxN sized matrix 140 | 141 | // MATRIX A mel filters is 80 rows x 200 columns 142 | // MATRIX B magnitudes is 3000 x 200 143 | // MATRIX B is TRANSPOSED to be 200 rows x 3000 columns 144 | // MATRIX C melSpectroGram is 80 rows x 3000 columns 145 | 146 | let M: Int32 = Int32(Whisper.kWhisperNumMels) // number of rows in matrix A 147 | let N: Int32 = Int32(Whisper.kWhisperNumSamplesInMel) // number of columns in matrix B 148 | let K: Int32 = 201 // number of columns in matrix A and number of rows in 149 | 150 | // matrix multiply magitude squared matrix with our filter bank 151 | // see https://www.advancedswift.com/matrix-math/ 152 | cblas_dgemm(CblasRowMajor, 153 | CblasNoTrans, // Transpose A 154 | CblasNoTrans, // 155 | M, // M Number of rows in matrices A and C. 156 | N, // N Number of columns in matrices B and C. 157 | K, // K Number of columns in matrix A; number of rows in matrix B. 158 | 1.0, // Alpha Scaling factor for the product of matrices A and B. 159 | flattenedMelMatrix, // Matrix A 160 | K, // LDA The size of the first dimension of matrix A; if you are passing a matrix A[m][n], the value should be m. 161 | magnitudes, // Matrix B 162 | N, // LDB The size of the first dimension of matrix B; if you are passing a matrix B[m][n], the value should be m. 163 | 0, // Beta Scaling factor for matrix C. 164 | &melSpectroGram, // Matrix C 165 | N) // LDC The size of the first dimension of matrix C; if you are passing a matrix C[m][n], the value should be m. 166 | 167 | 168 | 169 | // print("Swift 2 - mel min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 170 | 171 | // Step 7 - clamp / clip the min to 1e-10 172 | vDSP.threshold(melSpectroGram, to: 1e-10, with: .clampToThreshold, result: &melSpectroGram) 173 | 174 | // print("Swift 3 - mel clip min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 175 | 176 | // Step 7 - Take the log base 10 - vDSP_vdbcon and power:toDecibels seems to fuck things up here and isnt right, even though its what everyone else uses? 177 | vForce.log10(melSpectroGram, result: &melSpectroGram) 178 | // print("Swift 4 - mel log min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 179 | 180 | 181 | // Step 8 - 182 | // Clip to new max and updated min 183 | let newMin = vDSP.maximum(melSpectroGram) - 8.0 184 | vDSP.maximum(melSpectroGram, [Double](repeating: newMin, count: melSpectroGram.count), result: &melSpectroGram) 185 | 186 | // print("Swift 5 - mel log min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 187 | 188 | // Step 9 - Add 4 and Divide by 4 189 | vDSP.add(4.0, melSpectroGram, result: &melSpectroGram) 190 | vDSP.divide(melSpectroGram, 4.0, result: &melSpectroGram) 191 | 192 | 193 | 194 | // print("Swift 6 - mel log norm min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 195 | 196 | 197 | // print("--------------") 198 | // 199 | // print("Torch 0 - complex real min -11.8792142868 max 12.0689258575") 200 | // print("Torch 0 - complex imag min -10.5751876831 max 11.5213479996") 201 | // print("Torch 1 - magnitudes min 0.0000000000 max 165.6671142578") 202 | // print("Torch 2 - mel min 0.0000000036 max 4.2800636292") 203 | // print("Torch 3 - mel clip min 0.0000000036 max 4.2800636292") 204 | // print("Torch 4 - mel log min -8.4495277405 max 0.6314502358") 205 | // print("Torch 5 - mel log min -7.3685498238 max 0.6314502358") 206 | // print("Torch 6 - mel log norm min -0.8421374559 max 1.1578625441") 207 | } 208 | } 209 | 210 | return vDSP.doubleToFloat(melSpectroGram) 211 | } 212 | 213 | // This method works, but doesnt use the vDSP as fully as possible. 214 | func processDataRosa(audio: [Int16]) -> [Float] 215 | { 216 | // COnvert to a normalized Float representation 217 | var audioFloat:[Double] = [Double](repeating: 0, count: audio.count) 218 | vDSP.convertElements(of: audio, to: &audioFloat) 219 | vDSP.divide(audioFloat, 32768.0, result: &audioFloat) 220 | 221 | // Modify the default spectrogram to produce magnitudes squared 222 | // Note, accelerated true produces incorrect results! 223 | var spectrogram = audioFloat.stft(nFFT: Whisper.kWhisperNumFFTs, hopLength:Whisper.kWhisperHopLength, isAccelerated: false).map { $0.map { pow(sqrt(pow($0.real, 2.0) + pow($0.imagine, 2.0)), 2.0) } } 224 | 225 | // Remove the 3001st row. 226 | spectrogram = spectrogram.transposed 227 | spectrogram.removeLast() 228 | spectrogram = spectrogram.transposed 229 | 230 | // Calculate the spectrogram 231 | let melBasis = [Double].createMelFilter(sampleRate: Whisper.kWhisperSampleRate, FTTCount:Whisper.kWhisperNumFFTs, melsCount:Whisper.kWhisperNumMels) 232 | 233 | let melDouble = melBasis.dot(matrix: spectrogram) 234 | 235 | var melSpectroGram:[Double] = melDouble.flatMap( { $0 } ) 236 | 237 | // Normalize the Mel Spectrogram into a Log Mel in the format Whisper expects: 238 | // print("Swift 2 - mel min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 239 | 240 | // Step 7 - clamp / clip the min to 1e-10 241 | vDSP.threshold(melSpectroGram, to: 1e-10, with: .clampToThreshold, result: &melSpectroGram) 242 | 243 | // print("Swift 3 - mel clip min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 244 | 245 | // Step 7 - Take the log base 10 - vDSP_vdbcon and power:toDecibels seems to fuck things up here and isnt right, even though its what everyone else uses? 246 | vForce.log10(melSpectroGram, result: &melSpectroGram) 247 | // print("Swift 4 - mel log min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 248 | 249 | // Step 8 - 250 | // Clip to new max and updated min 251 | let newMin = vDSP.maximum(melSpectroGram) - 8.0 252 | vDSP.maximum(melSpectroGram, [Double](repeating: newMin, count: melSpectroGram.count), result: &melSpectroGram) 253 | 254 | // print("Swift 5 - mel log min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 255 | 256 | // Step 9 - Add 4 and Divide by 4 257 | vDSP.add(4.0, melSpectroGram, result: &melSpectroGram) 258 | vDSP.divide(melSpectroGram, 4.0, result: &melSpectroGram) 259 | 260 | // print("Swift 6 - mel log norm min ", vDSP.minimum(melSpectroGram), "max", vDSP.maximum(melSpectroGram)) 261 | 262 | // print("--------------") 263 | 264 | // print("Torch 0 - complex real min -11.8792142868 max 12.0689258575") 265 | // print("Torch 0 - complex imag min -10.5751876831 max 11.5213479996") 266 | // print("Torch 1 - magnitudes min 0.0000000000 max 165.6671142578") 267 | // print("Torch 2 - mel min 0.0000000036 max 4.2800636292") 268 | // print("Torch 3 - mel clip min 0.0000000036 max 4.2800636292") 269 | // print("Torch 4 - mel log min -8.4495277405 max 0.6314502358") 270 | // print("Torch 5 - mel log min -7.3685498238 max 0.6314502358") 271 | // print("Torch 6 - mel log norm min -0.8421374559 max 1.1578625441") 272 | 273 | return vDSP.doubleToFloat(melSpectroGram) 274 | } 275 | 276 | static func makeFilterBankWithNumpyData() -> [[Double]] { 277 | // let numpyFloatArrayLength = 16080 278 | let fileURL = Bundle.main.url(forResource: "mel_filters", withExtension:"data") 279 | let fileHandle = try! FileHandle(forReadingFrom: fileURL!) 280 | 281 | let floatData = fileHandle.readDataToEndOfFile() 282 | let floatArray = floatData.withUnsafeBytes { unsafeFloatArray in 283 | return Array(UnsafeBufferPointer(start: unsafeFloatArray.bindMemory(to: Float.self).baseAddress!, count: floatData.count / MemoryLayout.stride) ) 284 | } 285 | 286 | let doubleArray = vDSP.floatToDouble(floatArray); 287 | 288 | return doubleArray.chunked(into: 201) 289 | } 290 | 291 | static func loadReferencePythonRawMelToDebugShit() -> [Float] { 292 | // let numpyFloatArrayLength = 16080 293 | let fileURL = Bundle.main.url(forResource: "python_log_mel", withExtension:"raw") 294 | let fileHandle = try! FileHandle(forReadingFrom: fileURL!) 295 | 296 | let floatData = fileHandle.readDataToEndOfFile() 297 | let floatArray = floatData.withUnsafeBytes { unsafeFloatArray in 298 | return Array(UnsafeBufferPointer(start: unsafeFloatArray.bindMemory(to: Float.self).baseAddress!, count: floatData.count / MemoryLayout.stride /*(80 * 3000)*/) ) 299 | } 300 | 301 | return floatArray; 302 | } 303 | 304 | // func power_to_db( 305 | // _ S: Matrix, 306 | // ref: Float = 1.0, 307 | // amin: Float = 1e-10, 308 | // top_db: Float? = 80.0 309 | // ) -> Matrix { 310 | // precondition(amin > 0) 311 | // let magnitude = S 312 | // let ref_value = abs(ref) 313 | // var log_spec = 10.0 * magnitude.vect { np.log10(np.maximum($0, amin)) } 314 | // log_spec = log_spec.vect { $0 - 10.0 * log10(max(amin, ref_value)) } 315 | // if let _top_db = top_db { 316 | // precondition(_top_db >= 0) 317 | // log_spec = log_spec.vect { np.maximum($0, max($0) - _top_db) } 318 | // } 319 | // return log_spec 320 | // } 321 | } 322 | 323 | 324 | 325 | 326 | extension Array { 327 | func chunked(into size: Int) -> [[Element]] { 328 | return stride(from: 0, to: count, by: size).map { 329 | Array(self[$0 ..< Swift.min($0 + size, count)]) 330 | } 331 | } 332 | } 333 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/STFT.swift: -------------------------------------------------------------------------------- 1 | // 2 | // STFT.swift 3 | // Whisper 4 | // 5 | // Created by Anton Marini on 1/14/23. 6 | // 7 | 8 | import Foundation 9 | import Accelerate 10 | import RosaKit 11 | 12 | class STFT 13 | { 14 | /// An attempt at mimiking Torch STFT 15 | /// Consume 'sampleCount' SInt16 single channel audio buffers 16 | /// Produce a complex STFT output 17 | /// 18 | /// Note, audioFrames we produce are not padded, or reflected or centered 19 | /// 20 | 21 | enum Padding { 22 | case Reflect 23 | case Zero 24 | // case 25 | } 26 | 27 | 28 | // MARK: FFT 29 | 30 | /// length of the FFT 31 | var fftLength:Int! 32 | 33 | /// FFT Window - should match one Frame of audio we process 34 | var fftWindowLength:Int! 35 | var fftWindowType:vDSP.WindowSequence! 36 | 37 | private var fft:NumpyRFFT! 38 | 39 | // MARK: STFT 40 | 41 | /// Total number of expected audio samples we process 42 | /// Our sample count should be divisible by our fftLength / 2 43 | var sampleCount:Int! 44 | 45 | /// Number of samples we shift forward when constructing a new audio frame out of our input audio 46 | var hopCount:Int! 47 | 48 | var padding:Padding! 49 | var center:Bool! 50 | 51 | // Calculate the number of iteractions we need to do 52 | // typically sampleCount / hopCount 53 | private var stftIterationCount:Int! 54 | 55 | 56 | init(fftLength:Int, windowType:vDSP.WindowSequence, windowLength:Int, sampleCount:Int, hopCount:Int, center:Bool = true, padding:Padding = .Reflect ) 57 | { 58 | self.fft = NumpyRFFT(numFFT: fftLength) 59 | 60 | self.fftLength = fftLength 61 | self.fftWindowType = windowType 62 | self.fftWindowLength = windowLength 63 | 64 | 65 | self.sampleCount = sampleCount 66 | self.hopCount = hopCount 67 | self.stftIterationCount = self.sampleCount / self.hopCount 68 | 69 | self.padding = padding 70 | self.center = center 71 | } 72 | 73 | /// Calculate STFT and return matrix of real and imaginary components calculated 74 | public func calculateSTFT(audio:[Int16]) -> ([[Double]], [[Double]]) 75 | { 76 | // Step 1 77 | assert(self.sampleCount == audio.count) 78 | 79 | var audioFloat:[Double] = [Double](repeating: 0, count: audio.count) 80 | 81 | vDSP.convertElements(of: audio, to: &audioFloat) 82 | // Audio now in Float, at Signed Int ranges - matches Pytorch Exactly 83 | 84 | vDSP.divide(audioFloat, 32768.0, result: &audioFloat) 85 | // Audio now in -1.0 to 1.0 Float ranges - matches Pytorch exactly 86 | 87 | // Center pad, reflect mode 88 | 89 | if (self.center) 90 | { 91 | switch ( self.padding ) 92 | { 93 | case .Reflect, .none: 94 | let reflectStart = audioFloat[0 ..< self.fftLength/2] 95 | let reflectEnd = audioFloat[audioFloat.count - 1 - self.fftLength/2 ..< audioFloat.count] 96 | 97 | audioFloat.insert(contentsOf:reflectStart.reversed(), at: 0) 98 | audioFloat.append(contentsOf:reflectEnd.reversed()) 99 | case .Zero: 100 | let zero:[Double] = [Double](repeating: 0, count: self.fftLength/2 ) 101 | 102 | audioFloat.insert(contentsOf:zero, at: 0) 103 | audioFloat.append(contentsOf:zero) 104 | } 105 | } 106 | else 107 | { 108 | // Alternatively all at the end? 109 | audioFloat.append(contentsOf: [Double](repeating: 0, count: self.fftLength)) 110 | } 111 | 112 | // Split Complex arrays holding the FFT results 113 | var allSampleReal:[[Double]] = [] 114 | var allSampleImaginary:[[Double]] = [] 115 | // var allSampleMagnitudes:[[Double]] = [] 116 | 117 | // Step 2 - we need to create 3001 x 201 matrix of windowed FFTs 118 | // Pytorch outputs complex numbers 119 | for (m) in 0 ... self.stftIterationCount 120 | { 121 | // Slice numFFTs every hop count (barf) and make a mel spectrum out of it 122 | // audioFrame ends up holding split complex numbers 123 | 124 | // TODO: Handle Pytorch STFT Defaults: 125 | // TODO: Handle Centering = True 126 | // TODO: Handle Padding = Reflect 127 | let audioFrame = Array( audioFloat[ (m * self.hopCount) ..< ( (m * self.hopCount) + self.fftLength ) ] ) 128 | 129 | assert(audioFrame.count == self.fftLength) 130 | 131 | var (real, imaginary) = self.fft.forward(audioFrame) 132 | 133 | // We divide our half our FFT output, 134 | // because the Pytorch `onesized` is true by default for real valued signals 135 | // See https://pytorch.org/docs/stable/generated/torch.stft.html 136 | 137 | // if (real.count == self.fftLength ) 138 | // { 139 | // real = Array(real.prefix(upTo:1 + self.fftLength / 2)) 140 | // imaginary = Array(imaginary.prefix(upTo:1 + self.fftLength / 2)) 141 | //// magnitudes = Array(magnitudes.prefix(upTo:1 + self.fftLength / 2)) 142 | // 143 | // } 144 | // 145 | // assert(real.count == 1 + self.fft.numFFT / 2) 146 | // assert(imaginary.count == 1 + self.fft.numFFT / 2) 147 | 148 | real = Array(real[0 ..< 201]) 149 | imaginary = Array(imaginary[0 ..< 201]) 150 | 151 | allSampleReal.append(real) 152 | allSampleImaginary.append(imaginary) 153 | // allSampleMagnitudes.append(magnitudes) 154 | } 155 | 156 | return (allSampleReal, allSampleImaginary) 157 | } 158 | 159 | // func calculateSTFTRosa(audio:[Int16], nFFT: Int = 256, hopLength: Int = 1024, isAccelerated: Bool = false) -> [[(real: Double, imagine: Double)]] { 160 | // 161 | // var audioFloat:[Double] = [Double](repeating: 0, count: audio.count) 162 | // 163 | // vDSP.convertElements(of: audio, to: &audioFloat) 164 | // 165 | // if (self.center) 166 | // { 167 | // switch ( self.padding ) 168 | // { 169 | // case .Reflect, .none: 170 | // let reflectStart = audioFloat[0 ..< self.fftLength/2] 171 | // let reflectEnd = audioFloat[audioFloat.count - self.fftLength/2 ..< audioFloat.count] 172 | // 173 | // audioFloat.insert(contentsOf:reflectStart.reversed(), at: 0) 174 | // audioFloat.append(contentsOf:reflectEnd.reversed()) 175 | // case .Zero: 176 | // let zero:[Double] = [Double](repeating: 0, count: self.fftLength/2 ) 177 | // 178 | // audioFloat.insert(contentsOf:zero, at: 0) 179 | // audioFloat.append(contentsOf:zero) 180 | // } 181 | // } 182 | // else 183 | // { 184 | // // Alternatively all at the end? 185 | // audioFloat.append(contentsOf: [Double](repeating: 0, count: self.fftLength)) 186 | // } 187 | // 188 | // 189 | //// let FFTWindow = [Double].getHannWindow(frameLength: (nFFT)).map { [$0] } 190 | // 191 | //// let FFTWWindow 192 | // 193 | //// let centered = audioDouble.reflectPad(fftSize: nFFT) 194 | // 195 | // let yFrames = audioFloat.frame(frameLength: nFFT, hopLength: hopLength) 196 | // 197 | // let matrix = FFTWindow.multiplyVector(matrix: yFrames) 198 | // 199 | // let rfftMatrix = isAccelerated ? matrix.acceleratedRFFT : matrix.rfft 200 | // 201 | // let result = rfftMatrix 202 | // 203 | // return result 204 | // } 205 | 206 | } 207 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/Whisper.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Whisper.swift 3 | // Whisper 4 | // 5 | // Created by Tanmay Bakshi on 2022-09-26. 6 | // 7 | 8 | import Foundation 9 | import CoreML 10 | import AVFoundation 11 | import Accelerate 12 | import RosaKit 13 | 14 | protocol WhisperLogitFilter 15 | { 16 | // Parameters 17 | // ---------- 18 | // logits : Tensor, shape = (n_batch, vocab_size) 19 | // per-token logits of the probability distribution at the current step 20 | // tokens : Tensor, shape = (n_batch, current_sequence_length) 21 | // all tokens in the context so far, including the prefix and sot_sequence tokens 22 | func apply(logits: inout MLShapedArray, tokens: inout MLShapedArray) 23 | } 24 | 25 | public class Whisper { 26 | 27 | // MARK: Public Constants Enums and Structs 28 | // hard-coded audio hyperparameters 29 | static let kWhisperSampleRate:Int = 16000; 30 | static let kWhisperNumFFTs:Int = 400; 31 | static let kWhisperNumMels:Int = 80; 32 | static let kWhisperHopLength:Int = 160; 33 | static let kWhisperChunkTimeSeconds:Int = 30; 34 | // kWhisperChunkTimeSeconds * kWhisperSampleRate # 480000: number of samples in a chunk 35 | static let kWhisperNumSamplesInChunk:Int = 480000; // Raw audio chunks we convert to MEL 36 | // exact_div(kWhisperNumSamplesInChunk, kWhisperHopLength) # 3000: number of frames in a mel spectrogram input 37 | static let kWhisperNumSamplesInMel:Int = 3000; // frames of Mel spectrograms 38 | 39 | enum WhisperError:Error 40 | { 41 | case notImplementedYet // Just havent gotten there hang tight. 42 | case unrecoverableError 43 | } 44 | 45 | /// Basic tasks types 46 | enum WhisperTask 47 | { 48 | case Transcribe 49 | case Translate 50 | } 51 | 52 | /// Transcript format - this is the string format of the returned transcript or translation task. 53 | enum WhisperTranscriptFormat 54 | { 55 | /// Output text only - Transcription or Translation 56 | case Text 57 | /// Output text with timestamps - suitable for Transcription only 58 | case TextAndTimestamps 59 | /// Soon - Transcript as VTT formatted text 60 | case VTT 61 | /// Soon - Transcript as SRT formatted text 62 | case SRT 63 | } 64 | 65 | /// Options to initialize a session with a task, language, 66 | /// See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L19 67 | struct WhisperOptions 68 | { 69 | var task:WhisperTask! 70 | var format:WhisperTranscriptFormat! 71 | 72 | var verbose = false 73 | 74 | // Below are WIP 75 | 76 | /// Temperature for sampling. It can be a tuple of temperatures, which will be successfully used 77 | /// upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. 78 | var temperatureSchedule:[Float] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] 79 | 80 | /// If the gzip compression ratio is above this value, treat as failed 81 | var compressionRatioTresh:Float = 2.4 82 | /// If the average log probability over sampled tokens is below this value, treat as failed 83 | var logProbThresh:Float = -1.0 84 | /// If the no_speech probability is higher than this value AND the average log probability 85 | /// over sampled tokens is below `logprob_threshold`, consider the segment as silent 86 | var noSpeechThresh:Float = 0.6 87 | 88 | /// if True, the previous output of the model is provided as a prompt for the next window; 89 | /// disabling may make the text inconsistent across windows, but the model becomes less prone to 90 | /// getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. 91 | var conditionOnPrevText = true 92 | } 93 | 94 | // MARK: Private Constants Enums and Structs 95 | 96 | // All of these are major WIP 97 | 98 | // See https://github.com/openai/whisper/blob/12e1089462a2ea92e9ade6145e7be5b6883676ff/whisper/decoding.py#L383 99 | private struct SupressBlank: WhisperLogitFilter 100 | { 101 | let tokenizer:WhisperTokenizer! 102 | let encodedBlank:[Int]! 103 | let sampleBegin:Int! 104 | 105 | init(tokenizer: WhisperTokenizer!, sampleBegin: Int) { 106 | self.tokenizer = tokenizer 107 | self.encodedBlank = tokenizer.encode(text: " ") 108 | self.sampleBegin = sampleBegin 109 | } 110 | 111 | func apply(logits: inout MLShapedArray, tokens: inout MLShapedArray) 112 | { 113 | print("Not Yet Implemented") 114 | 115 | // https://www.geeksforgeeks.org/how-to-slice-a-3d-tensor-in-pytorch/ 116 | // tensor[tensor_position_start:tensor_position_end, tensor_dimension_start:tensor_dimension_end , tensor_value_start:tensor_value_end] 117 | // logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf 118 | 119 | // Need to - for all batch dimensions, access tokens at the encoded " " int value, and the to value and set to -Inf 120 | // Float.infinity.negate() 121 | } 122 | } 123 | 124 | private struct SupressTokens: WhisperLogitFilter 125 | { 126 | let suppress:[Int] 127 | 128 | init(suppress: [Int]) { 129 | self.suppress = suppress 130 | } 131 | func apply(logits: inout MLShapedArray, tokens: inout MLShapedArray) 132 | { 133 | print("Not Yet Implemented") 134 | } 135 | } 136 | 137 | private struct ApplyTimestampRules: WhisperLogitFilter 138 | { 139 | let tokenizer:WhisperTokenizer! 140 | let sampleBegin:Int! 141 | let maxInitialTimestampIdx:Int? 142 | 143 | init(tokenizer: WhisperTokenizer!, sampleBegin: Int!, maxInitialTimestampIdx: Int?) { 144 | self.tokenizer = tokenizer 145 | self.sampleBegin = sampleBegin 146 | self.maxInitialTimestampIdx = maxInitialTimestampIdx 147 | } 148 | 149 | func apply(logits: inout MLShapedArray, tokens: inout MLShapedArray) 150 | { 151 | print("Not Yet Implemented") 152 | } 153 | 154 | } 155 | 156 | 157 | // WhisperSegment internal state tracking for our Whisper session 158 | // See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L153 159 | private struct WhisperSegment 160 | { 161 | var id:Int! 162 | var seek:Int! 163 | 164 | // Segment times in rational time base units 165 | // Time base is in standard 600 units 166 | var startTime:CMTime! 167 | var endTime:CMTime! 168 | 169 | // Tokens predicted for this segment 170 | var textTokens:[Int]! 171 | // Text resulting from decoded tokens 172 | var decodedText:String! 173 | 174 | // Todo: 175 | var temperature:Float! 176 | var avgLogProb:Float! 177 | var compressionRatio:Float! 178 | var noSpeechProb:Float! 179 | } 180 | 181 | private enum WhisperDecodingStrategy 182 | { 183 | case Greedy 184 | case BeamSearch // Not implemented yet 185 | } 186 | 187 | // Vended by the decode method and used internally 188 | // See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py 189 | private struct WhisperDecodingOptions 190 | { 191 | var task:WhisperTask 192 | var languageCode:String? 193 | 194 | var decodingStetegy:WhisperDecodingStrategy = .Greedy 195 | 196 | // FYI Semantics from Python are these values can be 197 | // None 198 | // Zero 199 | // Some value 200 | // each has specific meaning, specifically None. 201 | // We treat optional / nil as none here. 202 | 203 | // Sampling Related Options 204 | 205 | var temperature:Float = 0.0 206 | // Maximum number of tokens to sample 207 | var maxSampleLen:Int? 208 | // Number of independent samples to collect, when t > 0 209 | var bestOf:Int? 210 | // number of beams in beam search, when t == 0 211 | var beamSize:Int? 212 | // patience in beam search (https://arxiv.org/abs/2204.05424) 213 | var patience:Float? 214 | 215 | // Options for ranking generations (either beams or best-of-N samples) 216 | 217 | // "alpha" in Google NMT, None defaults to length norm 218 | var lengthPenalty:Float? 219 | 220 | // Prompt, prefix, and token suppression 221 | 222 | // Text or tokens for the previous context 223 | var prompt:String? 224 | var promptTokens:[Int]? 225 | // text or tokens to prefix the current context 226 | var prefix:String? 227 | var prefixTokens:[Int]? 228 | // this will suppress blank outputs 229 | var suppressBlank:Bool = true 230 | // list of tokens ids (or comma-separated token ids) to suppress 231 | // nil will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` 232 | // and empty array will do no suppression 233 | var suppresTokens:[Int]? 234 | 235 | // timestamp sampling options 236 | var withoutTimestamps:Bool = false 237 | var maxInitialTimestampL:Float = 1.0 238 | } 239 | 240 | // https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L104 241 | private struct WhisperDecodingResult 242 | { 243 | var tokens:[Int] 244 | var text:String = "" 245 | 246 | var languageCode:String? 247 | var langProbs:[String:Float]? 248 | 249 | var avgLogProbs:Float = Float.nan 250 | var noSpeechProbs:Float = Float.nan 251 | var temperature:Float = Float.nan 252 | var compressionRatio:Float = Float.nan 253 | } 254 | 255 | // MARK: Whisper Properties 256 | 257 | let decoderModel: decoder_base 258 | let encoderModel: encoder_base 259 | let tokenizer = WhisperTokenizer() 260 | 261 | let melGenerator:MelSpectrogram = MelSpectrogram(sampleCount: kWhisperNumSamplesInChunk, hopCount: kWhisperHopLength, melCount: kWhisperNumMels, numFFT: kWhisperNumFFTs) 262 | 263 | // These are variables which cache our current session, tasks and option 264 | var sessionOptions:WhisperOptions! 265 | private var sessionAccruedAudioSamples:[Int16] = [] 266 | private var sessionNumAccruedAudioSamples:Int = 0 267 | private var sessionTranscription:[String] = [] 268 | 269 | private var sessionSegments:[WhisperSegment] = [] 270 | 271 | init() throws { 272 | let config = MLModelConfiguration() 273 | config.computeUnits = .all 274 | 275 | self.decoderModel = try decoder_base(configuration: config) 276 | self.encoderModel = try encoder_base(configuration: config) 277 | 278 | self.sessionAccruedAudioSamples.reserveCapacity( Whisper.kWhisperNumSamplesInChunk ) 279 | 280 | self.resetState() 281 | } 282 | 283 | 284 | // MARK: Public Methods 285 | 286 | /// Call this method whenever you have a new asset, or wish to start a new realtime transcription session 287 | /// This method resets internal counters, tokens, accrued transcriptions, time stamps, etc 288 | func startWhisperSession(options:WhisperOptions) 289 | { 290 | self.sessionOptions = options 291 | self.resetState() 292 | } 293 | 294 | // this function accrues 295 | func accrueSamplesFromSampleBuffer(sampleBuffer:CMSampleBuffer) 296 | { 297 | var audioBufferListSize:Int = 0 298 | 299 | CMSampleBufferGetAudioBufferListWithRetainedBlockBuffer(sampleBuffer, bufferListSizeNeededOut: &audioBufferListSize, bufferListOut: nil, bufferListSize:0, blockBufferAllocator: kCFAllocatorNull, blockBufferMemoryAllocator: kCFAllocatorNull, flags: kCMSampleBufferFlag_AudioBufferList_Assure16ByteAlignment, blockBufferOut: nil) 300 | 301 | var audioBufferList = AudioBufferList(mNumberBuffers: 1, mBuffers: AudioBuffer(mNumberChannels: 1, mDataByteSize: UInt32(audioBufferListSize), mData: nil)) 302 | 303 | var blockBuffer:CMBlockBuffer? 304 | 305 | CMSampleBufferGetAudioBufferListWithRetainedBlockBuffer(sampleBuffer, bufferListSizeNeededOut: nil, bufferListOut: &audioBufferList, bufferListSize: audioBufferListSize, blockBufferAllocator: kCFAllocatorNull, blockBufferMemoryAllocator: kCFAllocatorNull, flags: kCMSampleBufferFlag_AudioBufferList_Assure16ByteAlignment, blockBufferOut: &blockBuffer) 306 | 307 | // Determine the number of samples we need from our audio 308 | 309 | let numAvailableSamples = Int( CMSampleBufferGetNumSamples(sampleBuffer) ) 310 | 311 | // Calculate the number of samples we have to acrrue to get a full chunk 312 | let remainingSampleCount = Whisper.kWhisperNumSamplesInChunk - self.sessionAccruedAudioSamples.count; 313 | 314 | let samplesToAccrue = min(numAvailableSamples, remainingSampleCount); 315 | 316 | let remainingCurrentSamplesInBuffer = numAvailableSamples - samplesToAccrue; 317 | 318 | let unsafeAudioBufferList = UnsafeMutableAudioBufferListPointer(&audioBufferList) 319 | 320 | for (buffer) in unsafeAudioBufferList 321 | { 322 | let audioSampleArray:[Int16] = buffer.convertInt16() 323 | 324 | let samplesWeNeedToAccrueForAProperChunk = audioSampleArray[0 ... samplesToAccrue - 1] 325 | 326 | self.sessionAccruedAudioSamples.insert(contentsOf: samplesWeNeedToAccrueForAProperChunk, at: self.sessionNumAccruedAudioSamples) 327 | 328 | self.sessionNumAccruedAudioSamples = self.sessionNumAccruedAudioSamples + samplesWeNeedToAccrueForAProperChunk.count 329 | 330 | if (self.sessionAccruedAudioSamples.count == Whisper.kWhisperNumSamplesInChunk) 331 | { 332 | self.mainDeccodeLogicFromTranscribe(audio: self.sessionAccruedAudioSamples) 333 | 334 | self.sessionAccruedAudioSamples = [] 335 | self.sessionNumAccruedAudioSamples = 0 336 | } 337 | 338 | // Accrue whatever remaining Samples in our current samples buffer we have.. 339 | if (remainingCurrentSamplesInBuffer > 0) 340 | { 341 | let numSamplesWeHaveAccruedFromThisSampleBuffer = samplesWeNeedToAccrueForAProperChunk.count - 1 342 | 343 | let remainingSampleCount = Whisper.kWhisperNumSamplesInChunk - self.sessionNumAccruedAudioSamples 344 | 345 | let samplesToAccrue = min(remainingCurrentSamplesInBuffer, remainingSampleCount); 346 | 347 | let remainingSamplesWeNeedToAccrueForAProperChunk = audioSampleArray[numSamplesWeHaveAccruedFromThisSampleBuffer ... (numSamplesWeHaveAccruedFromThisSampleBuffer + samplesToAccrue - 1)] 348 | 349 | self.sessionAccruedAudioSamples.insert(contentsOf: remainingSamplesWeNeedToAccrueForAProperChunk, at: self.sessionNumAccruedAudioSamples) 350 | self.sessionNumAccruedAudioSamples = self.sessionNumAccruedAudioSamples + remainingSamplesWeNeedToAccrueForAProperChunk.count 351 | } 352 | } 353 | 354 | // TODO: 355 | // We might have residual audio samples that dont quite fill a full audio chunk (ie num frames is not equal to Whisper.kWhisperNumSamplesInChunk 356 | // Handle that here. 357 | 358 | // .... 359 | } 360 | 361 | func transcribe(assetURL:URL, options:WhisperOptions) async -> String 362 | { 363 | self.startWhisperSession(options: options) 364 | 365 | let asset = AVURLAsset(url:assetURL) 366 | 367 | do { 368 | let assetReader = try AVAssetReader(asset: asset) 369 | 370 | let audioTracks = try await asset.loadTracks(withMediaType: .audio) 371 | 372 | // Output SInt 16 373 | let audioOutputSettings = [ AVFormatIDKey : kAudioFormatLinearPCM, 374 | AVSampleRateKey : Whisper.kWhisperSampleRate, 375 | AVLinearPCMBitDepthKey: 16, 376 | AVNumberOfChannelsKey: 1, 377 | AVLinearPCMIsFloatKey : false, 378 | AVLinearPCMIsNonInterleaved: false, 379 | AVLinearPCMIsBigEndianKey: false 380 | 381 | ] as [String : Any] 382 | 383 | let audioOutput = AVAssetReaderAudioMixOutput(audioTracks: audioTracks, audioSettings: audioOutputSettings) 384 | audioOutput.alwaysCopiesSampleData = false 385 | 386 | if ( assetReader.canAdd(audioOutput) ) 387 | { 388 | assetReader.add(audioOutput) 389 | } 390 | 391 | assetReader.startReading() 392 | 393 | let startTime = NSDate.timeIntervalSinceReferenceDate 394 | 395 | while ( assetReader.status == .reading ) 396 | { 397 | guard let audioSampleBuffer = audioOutput.copyNextSampleBuffer(), CMSampleBufferIsValid(audioSampleBuffer) else { 398 | 399 | // Some media formats can have weird decode issues. 400 | // Unless our asset reader EXPLICITELT tells us its done, keep trying to decode. 401 | // We just skip bad samples 402 | if ( assetReader.status == .reading) 403 | { 404 | continue 405 | } 406 | 407 | else if (assetReader.status == .completed) 408 | { 409 | break; 410 | } 411 | 412 | else 413 | { 414 | // something went wrong 415 | print(assetReader.error as Any) 416 | return "" 417 | } 418 | } 419 | 420 | self.accrueSamplesFromSampleBuffer(sampleBuffer: audioSampleBuffer) 421 | 422 | } 423 | 424 | let processingTime = NSDate.timeIntervalSinceReferenceDate - startTime 425 | 426 | print("Decode and Predict took", processingTime, "seconds") 427 | 428 | let assetDuration = try await asset.load(.duration).seconds 429 | 430 | print("Movie is", assetDuration) 431 | print("Realtime Factor is", assetDuration / processingTime) 432 | 433 | return self.sessionTranscription.joined(separator: " ") 434 | 435 | } 436 | catch let error 437 | { 438 | print("Unable to process asset:") 439 | print(error) 440 | exit(0) 441 | } 442 | } 443 | 444 | // MARK: Private Methods 445 | 446 | private func encode(audio: [Int16]) throws -> MLShapedArray { 447 | // TODO: Fix our vDSP based mel processor 448 | 449 | let mel:[Float] = melGenerator.processData(audio: audio) 450 | let melRosa:[Float] = melGenerator.processDataRosa(audio: audio) 451 | let melPreProcessed = MelSpectrogram.loadReferencePythonRawMelToDebugShit() 452 | 453 | self.saveNormalizedMelToDisk(mel: mel, url: URL(fileURLWithPath: "/Users/vade/Downloads/rawMel-normalized.raw")) 454 | self.saveNormalizedMelToDisk(mel: melRosa, url: URL(fileURLWithPath: "/Users/vade/Downloads/rawMel-rosa-normalized.raw")) 455 | self.saveNormalizedMelToDisk(mel: melPreProcessed, url: URL(fileURLWithPath: "/Users/vade/Downloads/rawMel-python-normalized.raw")) 456 | 457 | let array = MLShapedArray(scalars: mel, shape: [1, 80, 3000]) 458 | 459 | let encoded = try encoderModel.prediction(logmel_data:array).var_719ShapedArray 460 | return encoded 461 | // return array 462 | } 463 | 464 | 465 | // https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L102 466 | private func decodeWithFallback(audio:[Int16]) -> WhisperDecodingResult? 467 | { 468 | do { 469 | let audioFeatures = try self.encode(audio: audio) 470 | 471 | var decodingOptions = WhisperDecodingOptions(task: self.sessionOptions.task) 472 | 473 | var decodeResult:WhisperDecodingResult? = nil 474 | 475 | for (t) in self.sessionOptions.temperatureSchedule 476 | { 477 | // Current pass decoding options 478 | if ( t > 0.0) 479 | { 480 | // disable beam_size and patience when t > 0 481 | decodingOptions.beamSize = nil 482 | decodingOptions.patience = nil 483 | } 484 | else 485 | { 486 | decodingOptions.bestOf = nil 487 | } 488 | 489 | // Set the current temperature from our temperature schedule 490 | decodingOptions.temperature = t 491 | 492 | decodeResult = try self.decode(audioFeatures: audioFeatures, 493 | decodingOptions: decodingOptions) 494 | 495 | var needsFallback = false 496 | 497 | if let decodeResult = decodeResult 498 | { 499 | if decodeResult.compressionRatio > self.sessionOptions.compressionRatioTresh 500 | { 501 | needsFallback = true 502 | } 503 | 504 | if (decodeResult.avgLogProbs < self.sessionOptions.logProbThresh) 505 | { 506 | needsFallback = true 507 | } 508 | } 509 | 510 | if ( needsFallback == false) 511 | { 512 | return decodeResult 513 | } 514 | } 515 | 516 | return decodeResult 517 | } 518 | catch let error 519 | { 520 | print("Unable to process audio frames", error) 521 | return nil 522 | } 523 | } 524 | 525 | // See https://github.com/openai/whisper/blob/12e1089462a2ea92e9ade6145e7be5b6883676ff/whisper/decoding.py#L616 526 | private func decode(audioFeatures: MLShapedArray, decodingOptions:WhisperDecodingOptions) throws -> Whisper.WhisperDecodingResult { 527 | 528 | // SOT Initialize sequence 529 | var tokens:[Int] = [] 530 | var timestampTokens:[Int] = [] 531 | 532 | // create sot sequence - multilingual model always needs a task and 533 | // https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L325 534 | // https://github.com/huggingface/transformers/blob/main/tests/models/whisper/test_tokenization_whisper.py 535 | tokens.append(WhisperTokenizer.sotToken) 536 | tokens.append(WhisperTokenizer.langToken) 537 | tokens.append(WhisperTokenizer.transcribeToken) 538 | 539 | // No Time Stamps 540 | if ( self.sessionOptions.format == WhisperTranscriptFormat.Text ) 541 | { 542 | tokens.append(WhisperTokenizer.notToken) 543 | } 544 | 545 | var nextToken = 0 546 | var nextTSToken = WhisperTokenizer.begToken 547 | 548 | // More or less main loop https://github.com/openai/whisper/blob/12e1089462a2ea92e9ade6145e7be5b6883676ff/whisper/decoding.py#L584 549 | while ( nextToken != WhisperTokenizer.eotToken ) 550 | { 551 | autoreleasepool { 552 | 553 | let tokensArray = self.tokenizer.tokensToMultiArray(tokens) 554 | 555 | let logits = try! decoderModel.prediction(token_data: tokensArray, audio_data: audioFeatures).var_1131 556 | 557 | // Get logit filters 558 | // apply them 559 | 560 | 561 | let (textToken, tsToken) = self.tokenizer.nextTokenGreedy(decoded: logits) 562 | 563 | nextToken = textToken 564 | nextTSToken = tsToken 565 | 566 | timestampTokens.append(nextTSToken) 567 | tokens.append(nextToken) 568 | 569 | 570 | // Verbose debugging as we iterate 571 | // let transcription = self.tokenizer.decode(tokens: tokens)// 572 | // print(transcription) 573 | } 574 | } 575 | 576 | 577 | 578 | // TODO: Implement calculation of other decodingResult requirements 579 | var decodingResult = WhisperDecodingResult(tokens: tokens, text: self.tokenizer.decode(tokens: tokens)) 580 | 581 | 582 | return decodingResult 583 | } 584 | 585 | // See https://github.com/openai/whisper/blob/12e1089462a2ea92e9ade6145e7be5b6883676ff/whisper/decoding.py#L199 586 | // Beam or Greedy sampling logic goes here 587 | private func decodeTokenUpdate(decodeOptions:WhisperDecodingOptions, tokens:[Int], logits:MLShapedArray, sumLogProbs:[Int]) throws -> (tokens:[Int], completed:Bool) 588 | { 589 | switch (decodeOptions.decodingStetegy) 590 | { 591 | case .Greedy: 592 | throw WhisperError.notImplementedYet 593 | // return self.decodeGreedyStrategy(decodeOptions:decodeOptions, tokens: tokens, logits: logits, sumLogProbs: sumLogProbs) 594 | 595 | case .BeamSearch: 596 | throw WhisperError.notImplementedYet 597 | } 598 | } 599 | 600 | // See "Greedy Decoder" 601 | // https://github.com/openai/whisper/blob/12e1089462a2ea92e9ade6145e7be5b6883676ff/whisper/decoding.py#L249 602 | // private func decodeGreedyStrategy(decodeOptions:WhisperDecodingOptions, tokens:[Int], logits:MLMultiArray, sumLogProbs:[Int]) -> (tokens:[Int], completed:Bool) 603 | // { 604 | // let temp = decodeOptions.temperature 605 | // 606 | // if (temp == 0) 607 | // { 608 | // let next_tokens = 609 | // } 610 | // } 611 | 612 | // See BeamSearch 613 | // https://github.com/openai/whisper/blob/12e1089462a2ea92e9ade6145e7be5b6883676ff/whisper/decoding.py#L277 614 | // private func decodeBeamSearchStrategy(tdecodeOptions:WhisperDecodingOptions, okens:[Int], logits:MLMultiArray, sumLogProbs:[Int]) -> (tokens:[Int], completed:Bool) 615 | // { 616 | // 617 | // } 618 | 619 | // https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L175 620 | private func mainDeccodeLogicFromTranscribe(audio:[Int16]) 621 | { 622 | 623 | // Timestamp shit 624 | 625 | // Today, we dont support audio frames other than the full 3000 Mel frame count 626 | // So seek is count of a mel chunk (3000) * hop length / sample rate 627 | // See : for reference https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L176 628 | let idForSegment = self.sessionSegments.count 629 | let segmentFrameCount = Whisper.kWhisperNumSamplesInMel 630 | let seek = self.sessionSegments.count * segmentFrameCount 631 | 632 | let timestampOffset = Float64(seek * Whisper.kWhisperHopLength / Whisper.kWhisperSampleRate) 633 | let segmentDuration = Float64(segmentFrameCount * Whisper.kWhisperHopLength / Whisper.kWhisperSampleRate) 634 | 635 | print("segment start, segment duration", timestampOffset, segmentDuration) 636 | 637 | if let result:WhisperDecodingResult = self.decodeWithFallback(audio: audio) 638 | { 639 | if ( self.sessionOptions.verbose ) 640 | { 641 | print (result.text) 642 | } 643 | 644 | let currentSegment = Whisper.WhisperSegment(id: idForSegment, 645 | seek: seek, 646 | startTime: CMTimeMakeWithSeconds(timestampOffset, preferredTimescale: 600), 647 | endTime: CMTimeMakeWithSeconds(segmentDuration, preferredTimescale: 600), 648 | textTokens: result.tokens, 649 | decodedText: result.text) 650 | 651 | self.sessionSegments.append(currentSegment) 652 | 653 | } 654 | 655 | } 656 | 657 | private func resetState() 658 | { 659 | // Reset our state 660 | self.sessionSegments = [] 661 | self.sessionAccruedAudioSamples = [] 662 | self.sessionNumAccruedAudioSamples = 0 663 | } 664 | 665 | // MARK: Debug Methods 666 | 667 | // Internal Helper to just test and visualize the output of our Log Mel processing 668 | private func normalize(array: [Float]) -> [Float] { 669 | var normalizedArray = array 670 | var min = Float.greatestFiniteMagnitude 671 | var max = -Float.greatestFiniteMagnitude 672 | var shift: Float = 0.0 673 | var scale: Float = 0.0 674 | 675 | vDSP_minv(array, 1, &min, vDSP_Length(array.count)) 676 | vDSP_maxv(array, 1, &max, vDSP_Length(array.count)) 677 | shift = abs(min) 678 | vDSP_vsadd(array, 1, &shift, &normalizedArray, 1, vDSP_Length(array.count)) 679 | scale = 1 / (max + shift) 680 | vDSP_vsmul(normalizedArray, 1, &scale, &normalizedArray, 1, vDSP_Length(array.count)) 681 | return normalizedArray 682 | } 683 | 684 | private func saveNormalizedMelToDisk(mel:[Float], url:URL) 685 | { 686 | let normalizedFloatMel = self.normalize(array: mel ) 687 | 688 | normalizedFloatMel.withUnsafeBufferPointer { unsafeMel in 689 | 690 | let data = Data(buffer: unsafeMel) 691 | do { 692 | try data.write(to: url) 693 | } 694 | catch { 695 | } 696 | } 697 | } 698 | } 699 | 700 | 701 | // Taken from : https://gist.github.com/tion-low/47e9fc4082717078dff4d6259b6ffbc9 702 | 703 | //extension AudioBufferList { 704 | // public mutating func convert() -> [AudioBuffer] { 705 | // 706 | // self.mBuffers 707 | // 708 | // let buf = UnsafeMutableAudioBufferListPointer(UnsafeMutablePointer(start: &(self.mBuffers), count: Int(self.mNumberBuffers)) ) 709 | // 710 | // return Array(buf) 711 | // 712 | // 713 | //// let buf: UnsafeBufferPointer = UnsafeBufferPointer(start: &(self.mBuffers), count: Int(self.mNumberBuffers)) 714 | //// return 715 | // } 716 | //} 717 | 718 | extension AudioBuffer { 719 | public func convertFloat() -> [Float] { 720 | if let mdata = self.mData { 721 | let ump = mdata.bindMemory(to: Float.self, capacity: Int(mDataByteSize)) 722 | let usp = UnsafeBufferPointer(start: ump, count: Int(mDataByteSize) / MemoryLayout.size) 723 | return [Float](usp) 724 | } else { 725 | return [] 726 | } 727 | } 728 | 729 | public func convertInt16() -> [Int16] { 730 | if let mdata = self.mData { 731 | let ump = mdata.bindMemory(to: Int16.self, capacity: Int(mDataByteSize)) 732 | let usp = UnsafeBufferPointer(start: ump, count: Int(mDataByteSize) / MemoryLayout.size) 733 | return [Int16](usp) 734 | } else { 735 | return [] 736 | } 737 | } 738 | 739 | } 740 | -------------------------------------------------------------------------------- /Whisper/Whisper/Whisper/WhisperTokenizer.swift: -------------------------------------------------------------------------------- 1 | // 2 | // WhisperTokenizer.swift 3 | // Whisper 4 | // 5 | // Created by Anton Marini on 1/6/23. 6 | // 7 | 8 | // Based heavily on https://github.com/huggingface/swift-coreml-transformers/blob/master/Sources/GPT2Tokenizer.swift 9 | // GPT2Tokenizer.swift by Created by Julien Chaumond on 18/07/2019. 10 | 11 | import Foundation 12 | import CoreML 13 | import Accelerate 14 | 15 | struct Utils { 16 | /// Invert a (k, v) dictionary 17 | static func invert(_ dict: Dictionary) -> Dictionary { 18 | var inverted: [V: K] = [:] 19 | for (k, v) in dict { 20 | inverted[v] = k 21 | } 22 | return inverted 23 | } 24 | 25 | } 26 | 27 | struct BytePair: Hashable { 28 | let a: String 29 | let b: String 30 | init(_ a: String, _ b: String) { 31 | self.a = a 32 | self.b = b 33 | } 34 | init(tuple: [String]) { 35 | self.a = tuple[0] 36 | self.b = tuple[1] 37 | } 38 | 39 | static func == (lhs: BytePair, rhs: BytePair) -> Bool { 40 | return lhs.a == rhs.a && lhs.b == rhs.b 41 | } 42 | func hash(into hasher: inout Hasher) { 43 | hasher.combine(a) 44 | hasher.combine(b) 45 | } 46 | } 47 | 48 | fileprivate extension String { 49 | func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range] { 50 | var result: [Range] = [] 51 | var start = startIndex 52 | while let range = range(of: string, options: options, range: start.. = [ 62 | 33: "!", 63 | 34: "\"", 64 | 35: "#", 65 | 36: "$", 66 | 37: "%", 67 | 38: "&", 68 | 39: "'", 69 | 40: "(", 70 | 41: ")", 71 | 42: "*", 72 | 43: "+", 73 | 44: ",", 74 | 45: "-", 75 | 46: ".", 76 | 47: "/", 77 | 48: "0", 78 | 49: "1", 79 | 50: "2", 80 | 51: "3", 81 | 52: "4", 82 | 53: "5", 83 | 54: "6", 84 | 55: "7", 85 | 56: "8", 86 | 57: "9", 87 | 58: ":", 88 | 59: ";", 89 | 60: "<", 90 | 61: "=", 91 | 62: ">", 92 | 63: "?", 93 | 64: "@", 94 | 65: "A", 95 | 66: "B", 96 | 67: "C", 97 | 68: "D", 98 | 69: "E", 99 | 70: "F", 100 | 71: "G", 101 | 72: "H", 102 | 73: "I", 103 | 74: "J", 104 | 75: "K", 105 | 76: "L", 106 | 77: "M", 107 | 78: "N", 108 | 79: "O", 109 | 80: "P", 110 | 81: "Q", 111 | 82: "R", 112 | 83: "S", 113 | 84: "T", 114 | 85: "U", 115 | 86: "V", 116 | 87: "W", 117 | 88: "X", 118 | 89: "Y", 119 | 90: "Z", 120 | 91: "[", 121 | 92: "\\", 122 | 93: "]", 123 | 94: "^", 124 | 95: "_", 125 | 96: "`", 126 | 97: "a", 127 | 98: "b", 128 | 99: "c", 129 | 100: "d", 130 | 101: "e", 131 | 102: "f", 132 | 103: "g", 133 | 104: "h", 134 | 105: "i", 135 | 106: "j", 136 | 107: "k", 137 | 108: "l", 138 | 109: "m", 139 | 110: "n", 140 | 111: "o", 141 | 112: "p", 142 | 113: "q", 143 | 114: "r", 144 | 115: "s", 145 | 116: "t", 146 | 117: "u", 147 | 118: "v", 148 | 119: "w", 149 | 120: "x", 150 | 121: "y", 151 | 122: "z", 152 | 123: "{", 153 | 124: "|", 154 | 125: "}", 155 | 126: "~", 156 | 161: "\u{00a1}", 157 | 162: "\u{00a2}", 158 | 163: "\u{00a3}", 159 | 164: "\u{00a4}", 160 | 165: "\u{00a5}", 161 | 166: "\u{00a6}", 162 | 167: "\u{00a7}", 163 | 168: "\u{00a8}", 164 | 169: "\u{00a9}", 165 | 170: "\u{00aa}", 166 | 171: "\u{00ab}", 167 | 172: "\u{00ac}", 168 | 174: "\u{00ae}", 169 | 175: "\u{00af}", 170 | 176: "\u{00b0}", 171 | 177: "\u{00b1}", 172 | 178: "\u{00b2}", 173 | 179: "\u{00b3}", 174 | 180: "\u{00b4}", 175 | 181: "\u{00b5}", 176 | 182: "\u{00b6}", 177 | 183: "\u{00b7}", 178 | 184: "\u{00b8}", 179 | 185: "\u{00b9}", 180 | 186: "\u{00ba}", 181 | 187: "\u{00bb}", 182 | 188: "\u{00bc}", 183 | 189: "\u{00bd}", 184 | 190: "\u{00be}", 185 | 191: "\u{00bf}", 186 | 192: "\u{00c0}", 187 | 193: "\u{00c1}", 188 | 194: "\u{00c2}", 189 | 195: "\u{00c3}", 190 | 196: "\u{00c4}", 191 | 197: "\u{00c5}", 192 | 198: "\u{00c6}", 193 | 199: "\u{00c7}", 194 | 200: "\u{00c8}", 195 | 201: "\u{00c9}", 196 | 202: "\u{00ca}", 197 | 203: "\u{00cb}", 198 | 204: "\u{00cc}", 199 | 205: "\u{00cd}", 200 | 206: "\u{00ce}", 201 | 207: "\u{00cf}", 202 | 208: "\u{00d0}", 203 | 209: "\u{00d1}", 204 | 210: "\u{00d2}", 205 | 211: "\u{00d3}", 206 | 212: "\u{00d4}", 207 | 213: "\u{00d5}", 208 | 214: "\u{00d6}", 209 | 215: "\u{00d7}", 210 | 216: "\u{00d8}", 211 | 217: "\u{00d9}", 212 | 218: "\u{00da}", 213 | 219: "\u{00db}", 214 | 220: "\u{00dc}", 215 | 221: "\u{00dd}", 216 | 222: "\u{00de}", 217 | 223: "\u{00df}", 218 | 224: "\u{00e0}", 219 | 225: "\u{00e1}", 220 | 226: "\u{00e2}", 221 | 227: "\u{00e3}", 222 | 228: "\u{00e4}", 223 | 229: "\u{00e5}", 224 | 230: "\u{00e6}", 225 | 231: "\u{00e7}", 226 | 232: "\u{00e8}", 227 | 233: "\u{00e9}", 228 | 234: "\u{00ea}", 229 | 235: "\u{00eb}", 230 | 236: "\u{00ec}", 231 | 237: "\u{00ed}", 232 | 238: "\u{00ee}", 233 | 239: "\u{00ef}", 234 | 240: "\u{00f0}", 235 | 241: "\u{00f1}", 236 | 242: "\u{00f2}", 237 | 243: "\u{00f3}", 238 | 244: "\u{00f4}", 239 | 245: "\u{00f5}", 240 | 246: "\u{00f6}", 241 | 247: "\u{00f7}", 242 | 248: "\u{00f8}", 243 | 249: "\u{00f9}", 244 | 250: "\u{00fa}", 245 | 251: "\u{00fb}", 246 | 252: "\u{00fc}", 247 | 253: "\u{00fd}", 248 | 254: "\u{00fe}", 249 | 255: "\u{00ff}", 250 | 0: "\u{0100}", 251 | 1: "\u{0101}", 252 | 2: "\u{0102}", 253 | 3: "\u{0103}", 254 | 4: "\u{0104}", 255 | 5: "\u{0105}", 256 | 6: "\u{0106}", 257 | 7: "\u{0107}", 258 | 8: "\u{0108}", 259 | 9: "\u{0109}", 260 | 10: "\u{010a}", 261 | 11: "\u{010b}", 262 | 12: "\u{010c}", 263 | 13: "\u{010d}", 264 | 14: "\u{010e}", 265 | 15: "\u{010f}", 266 | 16: "\u{0110}", 267 | 17: "\u{0111}", 268 | 18: "\u{0112}", 269 | 19: "\u{0113}", 270 | 20: "\u{0114}", 271 | 21: "\u{0115}", 272 | 22: "\u{0116}", 273 | 23: "\u{0117}", 274 | 24: "\u{0118}", 275 | 25: "\u{0119}", 276 | 26: "\u{011a}", 277 | 27: "\u{011b}", 278 | 28: "\u{011c}", 279 | 29: "\u{011d}", 280 | 30: "\u{011e}", 281 | 31: "\u{011f}", 282 | 32: "\u{0120}", 283 | 127: "\u{0121}", 284 | 128: "\u{0122}", 285 | 129: "\u{0123}", 286 | 130: "\u{0124}", 287 | 131: "\u{0125}", 288 | 132: "\u{0126}", 289 | 133: "\u{0127}", 290 | 134: "\u{0128}", 291 | 135: "\u{0129}", 292 | 136: "\u{012a}", 293 | 137: "\u{012b}", 294 | 138: "\u{012c}", 295 | 139: "\u{012d}", 296 | 140: "\u{012e}", 297 | 141: "\u{012f}", 298 | 142: "\u{0130}", 299 | 143: "\u{0131}", 300 | 144: "\u{0132}", 301 | 145: "\u{0133}", 302 | 146: "\u{0134}", 303 | 147: "\u{0135}", 304 | 148: "\u{0136}", 305 | 149: "\u{0137}", 306 | 150: "\u{0138}", 307 | 151: "\u{0139}", 308 | 152: "\u{013a}", 309 | 153: "\u{013b}", 310 | 154: "\u{013c}", 311 | 155: "\u{013d}", 312 | 156: "\u{013e}", 313 | 157: "\u{013f}", 314 | 158: "\u{0140}", 315 | 159: "\u{0141}", 316 | 160: "\u{0142}", 317 | 173: "\u{0143}", 318 | ] 319 | 320 | let byteDecoder = Utils.invert(byteEncoder) 321 | 322 | class GPT2Tokenizer { 323 | let bpeRanks: Dictionary 324 | internal let encoder: [String: Int] 325 | internal let decoder: [Int: String] 326 | 327 | init() { 328 | let url = Bundle.main.url(forResource: "multilingual-merges", withExtension: "txt")! 329 | let bpeMergesTxt = try! String(contentsOf: url) 330 | let arr = bpeMergesTxt.split(separator: "\n").map { String($0) } 331 | var bpeRanks: Dictionary = [:] 332 | for i in 1.. [String] { 350 | let RE = #"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"# 351 | let tokens = text.ranges(of: RE).map { String(text[$0]) } 352 | return tokens.map { (token) -> String in 353 | return Array(token.utf8).map { byteEncoder[$0]! }.joined() 354 | } 355 | } 356 | 357 | private func getPairs(word: [String]) -> Set { 358 | var s = Set() 359 | for i in 0.. String { 370 | if token.count <= 1 { 371 | return token 372 | } 373 | 374 | var word = Array(token).map { String($0) } 375 | var pairs = Array(getPairs(word: word)) 376 | 377 | while true { 378 | let bigrams = pairs.filter { (bp) -> Bool in bpeRanks[bp] != nil } 379 | if bigrams.count == 0 { 380 | break 381 | } 382 | let bigram = bigrams.min { (bp1, bp2) -> Bool in 383 | return bpeRanks[bp1]! < bpeRanks[bp2]! 384 | }! 385 | let first = bigram.a 386 | let second = bigram.b 387 | var newWord: [String] = [] 388 | var i = 0 389 | while i < word.count { 390 | if let j = word[i.. [String] { 417 | var tokens: [String] = [] 418 | for token in self.byteEncode(text: text) { 419 | let xx = self.bpe(token: token).split(separator: " ").map { String($0) } 420 | tokens.append(contentsOf: xx) 421 | } 422 | return tokens 423 | } 424 | 425 | /// Main entry point 426 | func encode(text: String) -> [Int] { 427 | return tokenize(text: text).map { encoder[$0]! } 428 | } 429 | 430 | /// Decode 431 | func decode(tokens: [Int]) -> String { 432 | let text = tokens.map { decoder[$0]! }.joined(separator: "") 433 | let utfCodepoints = text.map { byteDecoder[String($0)]! } 434 | return String(decoding: utfCodepoints, as: UTF8.self) 435 | } 436 | } 437 | 438 | 439 | class WhisperTokenizer:GPT2Tokenizer 440 | { 441 | // https://github.com/huggingface/transformers/pull/19921 442 | // Tokens is Vocab + 443 | static let eotToken = 50257 444 | static let sotToken = 50258 445 | static let langToken = 50259 446 | // sotToken + 1 + langIdx for a specific language, ie en is (langToken since it is index 0) 447 | // .. language tokens length of lang array (99) 448 | static let translateToken = 50358 449 | static let transcribeToken = 50359 450 | static let prevToken = 50361 451 | static let solmToken = 50362 452 | static let notToken = 50363 453 | static let begToken = 50364 454 | 455 | // https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/tokenizer.py#L279 456 | static let eotString = "<|endoftranscript|>" 457 | static let sotString = "<|startoftranscript|>" 458 | static let languageStrings:[String] = WhisperTokenizer.LANGUAGES.map{ "|<" + $0 + ">|" } 459 | static let translateString = "<|translate|>" 460 | static let transcribeString = "<|transcribe|>" 461 | static let solmString = "<|startoflm|>" // start of language model (??) 462 | static let prevString = "<|startofprev|>" 463 | static let noSpeechString = "<|nospeech|>" 464 | static let noTimestampsString = "<|notimestamps|>" 465 | 466 | static let LANGUAGES = ["en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", "ca", "nl", "ar", "sv", "it", "id", "hi", "fi", "vi", "iw", "uk", "el", "ms", "cs", "ro", "da", "hu", "ta", "no", "th", "ur", "hr", "bg", "lt", "la", "mi", "ml", "cy", "sk", "te", "fa", "lv", "bn", "sr", "az", "sl", "kn", "et", "mk", "br", "eu", "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", "gl", "mr", "pa", "si", "km", "sn", "yo", "so", "af", "oc", "ka", "be", "tg", "sd", "gu", "am", "yi", "lo", "uz", "fo", "ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", "mg", "as", "tt", "haw", "ln", "ha", "ba", "jw", "su"] 467 | 468 | var specialTokens:[Int:String]! 469 | 470 | override init() 471 | { 472 | super.init() 473 | 474 | self.specialTokens = self.generateSpecialTokenDict() 475 | } 476 | 477 | // MARK: - Token Accessors 478 | 479 | func timestampBeginToken() -> Int 480 | { 481 | return Self.begToken 482 | } 483 | 484 | // MARK: - Helper Methods 485 | 486 | func generateSpecialTokenDict() -> [Int : String] 487 | { 488 | let lastVocabIndex = self.decoder.count; 489 | 490 | // Start from our decoder vocabularies last entry 491 | // let lastVocab:[Int:String] = self.decoder.endIndex 492 | // 493 | var specialTokenArray:[String] = [Self.eotString, Self.sotString] 494 | specialTokenArray.append(contentsOf:Self.languageStrings) 495 | specialTokenArray.append(contentsOf:[Self.translateString, Self.transcribeString, Self.solmString, Self.prevString]) 496 | 497 | let specialTokenDict = specialTokenArray.reduce(into: [Int : String]() ) { specialTokenDict, tokenString in 498 | 499 | let index = specialTokenArray.firstIndex(of: tokenString)! 500 | 501 | specialTokenDict[index + lastVocabIndex] = tokenString 502 | } 503 | 504 | return specialTokenDict 505 | } 506 | 507 | func tokenToMultiArray(token:Int) -> MLShapedArray 508 | { 509 | let array = MLShapedArray(scalar: Int32(token)) 510 | 511 | // let array = try! MLMultiArray(shape: [1, 1], dataType: .int32) 512 | // let ptr = UnsafeMutablePointer(OpaquePointer(array.dataPointer)) 513 | // ptr[0] = Int32(token) 514 | 515 | return array 516 | } 517 | 518 | func tokensToMultiArray(_ tokens: [Int] ) -> MLShapedArray 519 | { 520 | let array = MLShapedArray(scalars: tokens.map{ Int32( $0) }, shape: [1, tokens.count]) 521 | 522 | return array 523 | // var shape = Array(repeating: 1, count: dims) 524 | // shape[shape.count - 1] = tokens.count 525 | // /// Examples: 526 | // /// dims=1 : [arr.count] 527 | // /// dims=2 : [1, arr.count] 528 | // /// 529 | // let o = try! MLMultiArray(shape: shape as [NSNumber], dataType: .int32) 530 | // let ptr = UnsafeMutablePointer(OpaquePointer(o.dataPointer)) 531 | // for (i, item) in tokens.enumerated() { 532 | // ptr[i] = Int32(item) 533 | // } 534 | // return o 535 | } 536 | 537 | func simdMaxIndexForRange(startToken:Int, endToken:Int, decoded:MLMultiArray) -> (Int, Float) 538 | { 539 | // we need to look at the shape, and extract the latest 1 x 51865 logits. 540 | // for example, if i have 23 tokens, i'll have a 1 x 23 x 51865 dim array 541 | // we need the LATEST (ie, the 22nd) 51865 logits 542 | 543 | let numPredictedTokens = decoded.shape[1].intValue - 1 544 | let numTokenIDs = decoded.shape[2].intValue 545 | 546 | var maxValue: Float = 0.0 547 | var maxIndex: vDSP_Length = 0 548 | 549 | // This is the offset into the entire 550 | let offsetIntoLogits = (numTokenIDs * numPredictedTokens) 551 | 552 | // This is slow, and should be optimized via raw pointer access 553 | var confidence:[Float] = (startToken..(OpaquePointer(decoded.dataPointer)) 562 | // vDSP_maxvi(ptr + (numTokenIDs * numPredictedTokens), 1, &maxValue, &maxIndex, vDSP_Length( numTokenIDs ) ) 563 | 564 | return (Int(maxIndex) + startToken, maxValue) 565 | } 566 | 567 | func predictLangToken(decoded:MLMultiArray) -> Int 568 | { 569 | let (token, _) = self.simdMaxIndexForRange(startToken: Self.langToken, 570 | endToken: WhisperTokenizer.sotToken, 571 | decoded: decoded) 572 | return token 573 | } 574 | 575 | func langFromToken(token:Int) -> String 576 | { 577 | return Self.LANGUAGES[token] 578 | } 579 | 580 | func nextTokenGreedy(decoded:MLMultiArray) -> (Int, Int) 581 | { 582 | 583 | let (token, _) = self.simdMaxIndexForRange(startToken: 0, endToken: WhisperTokenizer.sotToken, decoded: decoded) 584 | 585 | let (timestamp_token, _) = self.simdMaxIndexForRange(startToken: WhisperTokenizer.begToken, endToken:Int(truncating: decoded.shape[2]), decoded: decoded) 586 | 587 | // print(timestamp_token) 588 | 589 | return (token, timestamp_token) 590 | } 591 | 592 | override func decode(tokens:[Int]) -> String 593 | { 594 | // We need a method to not let our custom tokens hit the vocab 595 | // We also likely need to match some special processing done 596 | let pruned_tokens = tokens.filter{ $0 < WhisperTokenizer.eotToken} 597 | 598 | return super.decode(tokens: pruned_tokens) 599 | } 600 | 601 | // Our token can be decoded by our standard GPT2 vocab 602 | func tokenIsVocab(token:Int) -> Bool 603 | { 604 | return token < self.decoder.count 605 | } 606 | 607 | func tokenIsTimestamp(token:Int) -> Bool 608 | { 609 | return token > (self.decoder.count + self.specialTokens.count - 1) 610 | } 611 | 612 | // https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/tokenizer.py#L143 613 | // Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. 614 | // This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 615 | func decodeWithTimestamps(tokens:[Int], timestampTokens:[Int]) -> String 616 | { 617 | // let timeStampOutput:[String] = [String]() 618 | // var tokensForDecode:[Int] = [Int]() 619 | 620 | for (token) in timestampTokens 621 | { 622 | let timestampValue:Float = 0.02 * Float(token - self.timestampBeginToken()) 623 | let timestamp = String(format:"<| %.2f |>", timestampValue) 624 | 625 | print(timestamp) 626 | } 627 | 628 | return self.decode(tokens: tokens) 629 | 630 | 631 | } 632 | } 633 | -------------------------------------------------------------------------------- /Whisper/Whisper/WhisperApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // WhisperApp.swift 3 | // Whisper 4 | // 5 | // Created by Tanmay Bakshi on 2022-09-26. 6 | // 7 | 8 | import SwiftUI 9 | 10 | @main 11 | struct WhisperApp: App { 12 | var body: some Scene { 13 | WindowGroup { 14 | try! ContentView() 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /Whisper/Whisper/mel_filters.data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vade/OpenAI-Whisper-CoreML/2c60f5cb67dee05adbfaf1f31a12595dfbe75789/Whisper/Whisper/mel_filters.data -------------------------------------------------------------------------------- /Whisper/Whisper/python_log_mel.raw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vade/OpenAI-Whisper-CoreML/2c60f5cb67dee05adbfaf1f31a12595dfbe75789/Whisper/Whisper/python_log_mel.raw -------------------------------------------------------------------------------- /export_m80.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from whisper import audio 3 | 4 | m80 = np.load(os.path.join(os.path.dirname(inspect.getfile(audio)), "assets", "mel_filters.npz"))["mel_80"].flatten() 5 | np.save("m80.npy", m80) 6 | -------------------------------------------------------------------------------- /whisper_to_cml.py: -------------------------------------------------------------------------------- 1 | import whisper 2 | import numpy as np 3 | import torch 4 | import coremltools as ct 5 | 6 | def load_models(): 7 | model = whisper.load_model("small").cpu() 8 | return model.encoder, model.decoder 9 | 10 | def convert_encoder_to_tvm(model): 11 | model.eval() 12 | 13 | input_shape = (1, 80, 3000) 14 | input_data = torch.randn(input_shape) 15 | traced_model = torch.jit.trace(model, input_data) 16 | 17 | model = ct.convert( 18 | traced_model, 19 | convert_to="mlprogram", 20 | inputs=[ct.TensorType(name="logmel_data", shape=input_shape)] 21 | ) 22 | 23 | return model 24 | 25 | def convert_decoder_to_tvm(model): 26 | model.eval() 27 | 28 | tokens_shape = (1, 1) 29 | audio_shape = (1, 1500, 768) 30 | token_data = torch.randn(tokens_shape).long() 31 | audio_data = torch.randn(audio_shape) 32 | traced_model = torch.jit.trace(model, (token_data, audio_data)) 33 | 34 | token_flexible_shape = ct.Shape(shape=(1, 35 | ct.RangeDim(lower_bound=1, upper_bound=-1, default=1))) 36 | 37 | 38 | model = ct.convert( 39 | traced_model, 40 | convert_to="mlprogram", 41 | inputs=[ 42 | ct.TensorType(name="token_data", shape=token_flexible_shape, dtype=int), 43 | ct.TensorType(name="audio_data", shape=audio_shape) 44 | ] 45 | ) 46 | 47 | return model 48 | 49 | def main(): 50 | encoder, decoder = load_models() 51 | 52 | decoder = convert_decoder_to_tvm(decoder) 53 | decoder.save("decoder.mlpackage") 54 | 55 | encoder = convert_encoder_to_tvm(encoder) 56 | encoder.save("encoder.mlpackage") 57 | 58 | if __name__ == "__main__": 59 | main() 60 | --------------------------------------------------------------------------------