├── .gitignore ├── go.mod ├── go.sum ├── LICENSE ├── coqui_wrap.h ├── README.md ├── cmd └── asticoqui │ └── main.go ├── coqui.cpp └── coqui.go /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | Thumbs.db 3 | .idea/ 4 | tmp/ -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/asticode/go-asticoqui 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927 // indirect 7 | github.com/cryptix/wav v0.0.0-20180415113528-8bdace674401 8 | ) 9 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927 h1:SKI1/fuSdodxmNNyVBR8d7X/HuLnRpvvFO0AgyQk764= 2 | github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927/go.mod h1:h/aW8ynjgkuj+NQRlZcDbAbM1ORAbXjXX77sX7T289U= 3 | github.com/cryptix/wav v0.0.0-20180415113528-8bdace674401 h1:rZ+OHHkwlkYALTEd6AYXSL92K/SEc4fkz+TfweIwu6A= 4 | github.com/cryptix/wav v0.0.0-20180415113528-8bdace674401/go.mod h1:knK8fd+KPlGGqSUWogv1DQzGTwnfUvAi0cIoWyOG7+U= 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Quentin Renard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /coqui_wrap.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | typedef struct TokenMetadata { 5 | const char* text; 6 | const unsigned int timestep; 7 | const float start_time; 8 | } TokenMetadata; 9 | 10 | typedef struct CandidateTranscript { 11 | const TokenMetadata* const tokens; 12 | const unsigned int num_tokens; 13 | const double confidence; 14 | } CandidateTranscript; 15 | 16 | typedef struct Metadata { 17 | const CandidateTranscript* const transcripts; 18 | const unsigned int num_transcripts; 19 | } Metadata; 20 | 21 | typedef void* ModelWrapper; 22 | ModelWrapper* New(const char* aModelPath, int* errorOut); 23 | void Model_Close(ModelWrapper* w); 24 | unsigned int Model_BeamWidth(ModelWrapper* w); 25 | int Model_SetBeamWidth(ModelWrapper* w, unsigned int aBeamWidth); 26 | int Model_SampleRate(ModelWrapper* w); 27 | int Model_EnableExternalScorer(ModelWrapper* w, const char* aScorerPath); 28 | int Model_DisableExternalScorer(ModelWrapper* w); 29 | int Model_SetScorerAlphaBeta(ModelWrapper* w, float aAlpha, float aBeta); 30 | char* Model_STT(ModelWrapper* w, const short* aBuffer, unsigned int aBufferSize); 31 | Metadata* 
Model_STTWithMetadata(ModelWrapper* w, const short* aBuffer, unsigned int aBufferSize, unsigned int aNumResults); 32 | 33 | typedef void* StreamWrapper; 34 | StreamWrapper* Model_NewStream(ModelWrapper* w, int* errorOut); 35 | void Stream_Discard(StreamWrapper* sw); 36 | void Stream_FeedAudioContent(StreamWrapper* sw, const short* aBuffer, unsigned int aBufferSize); 37 | char* Stream_IntermediateDecode(StreamWrapper* sw); 38 | Metadata* Stream_IntermediateDecodeWithMetadata(StreamWrapper* sw, unsigned int aNumResults); 39 | char* Stream_Finish(StreamWrapper* sw); 40 | Metadata* Stream_FinishWithMetadata(StreamWrapper* sw, unsigned int aNumResults); 41 | 42 | const CandidateTranscript* Metadata_Transcripts(Metadata* m); 43 | unsigned int Metadata_NumTranscripts(Metadata* m); 44 | void Metadata_Close(Metadata* m); 45 | 46 | const TokenMetadata* CandidateTranscript_Tokens(CandidateTranscript* ct); 47 | unsigned int CandidateTranscript_NumTokens(CandidateTranscript* ct); 48 | double CandidateTranscript_Confidence(CandidateTranscript* ct); 49 | 50 | const char* TokenMetadata_Text(TokenMetadata* tm); 51 | unsigned int TokenMetadata_Timestep(TokenMetadata* tm); 52 | float TokenMetadata_StartTime(TokenMetadata* tm); 53 | 54 | void FreeString(char* s); 55 | char* Version(); 56 | char* ErrorCodeToErrorMessage(int aErrorCode); 57 | 58 | #ifdef __cplusplus 59 | } 60 | #endif 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![GoReportCard](http://goreportcard.com/badge/github.com/asticode/go-asticoqui)](http://goreportcard.com/report/github.com/asticode/go-asticoqui) 2 | [![GoDoc](https://godoc.org/github.com/asticode/go-asticoqui?status.svg)](https://godoc.org/github.com/asticode/go-asticoqui) 3 | 4 | Golang bindings for Coqui's [:frog:STT](https://github.com/coqui-ai/STT) speech-to-text library. 
5 | 6 | `asticoqui` is compatible with versions `v1.0.0`, `v1.1.0`, and `v1.2.0` of 🐸STT. 7 | 8 | # Installation 9 | 10 | ## Install tflite 11 | 12 | Run the following command: 13 | 14 | ```bash 15 | $ pip3 install --extra-index-url https://google-coral.github.io/py-repo/ tflite_runtime 16 | ``` 17 | 18 | If you're interested in running against your CUDA-enabled GPU (optional), then set the environment variable `STT_TFLITE_DELEGATE=gpu`. 19 | 20 | ## Install Coqui STT 21 | 22 | 1. Fetch an up-to-date `native_client.*.tar.xz` matching your system from [:frog:STT releases](https://github.com/coqui-ai/STT/releases). For example, on macOS: 23 | 24 | ```bash 25 | $ wget https://github.com/coqui-ai/STT/releases/download/v1.2.0/native_client.tflite.macOS.tar.xz 26 | ``` 27 | 28 | 2. Extract its content to `$HOME/.coqui/`. For example, on macOS: 29 | 30 | ```bash 31 | $ mkdir $HOME/.coqui/ 32 | $ tar -xvJf native_client.tflite.macOS.tar.xz -C $HOME/.coqui/ 33 | ``` 34 | 35 | 3. Set environment variables to point to the client: 36 | 37 | ```bash 38 | $ export CGO_LDFLAGS="-L$HOME/.coqui/" 39 | $ export CGO_CXXFLAGS="-I$HOME/.coqui/" 40 | $ export LD_LIBRARY_PATH="$HOME/.coqui/:$LD_LIBRARY_PATH" 41 | ``` 42 | 43 | ## Install asticoqui 44 | ### Install dependencies 45 | 46 | Run the following command: 47 | 48 | ```bash 49 | $ go get -u github.com/asticode/go-asticoqui/... 50 | ``` 51 | 52 | ### Install executables 53 | 54 | Run the following command: 55 | 56 | ```bash 57 | $ go install github.com/asticode/go-asticoqui/cmd/asticoqui 58 | ``` 59 | 60 | # Example Usage 61 | 62 | ## Get the pre-trained model and scorer 63 | 64 | Go to [this page](https://coqui.ai/english/coqui/v1.0.0-huge-vocab) and click `Enter Email to Download` at the bottom of the page. Download `model.tflite` and `huge_vocabulary.scorer`. 
65 | 66 | ## Get the audio files 67 | 68 | Run the following commands: 69 | 70 | ```bash 71 | $ cd $HOME/.coqui 72 | $ wget https://github.com/coqui-ai/STT/releases/download/v1.2.0/audio-1.2.0.tar.gz 73 | $ tar -xvzf audio-1.2.0.tar.gz 74 | ``` 75 | 76 | ## Use this client 77 | 78 | Run the following commands: 79 | 80 | ```bash 81 | $ go run cmd/asticoqui/main.go -model model.tflite -scorer huge_vocabulary.scorer -audio audio/2830-3980-0043.wav 82 | 83 | Text: experience proves this 84 | 85 | $ go run cmd/asticoqui/main.go -model model.tflite -scorer huge_vocabulary.scorer -audio audio/4507-16021-0012.wav 86 | 87 | Text: why should one hall on the way 88 | 89 | $ go run cmd/asticoqui/main.go -model model.tflite -scorer huge_vocabulary.scorer -audio audio/8455-210777-0068.wav 90 | 91 | Text: your power is sufficient i said 92 | ``` 93 | -------------------------------------------------------------------------------- /cmd/asticoqui/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "io" 7 | "log" 8 | "os" 9 | 10 | "github.com/asticode/go-asticoqui" 11 | "github.com/cryptix/wav" 12 | ) 13 | 14 | var model = flag.String("model", "", "Path to the model (protocol buffer binary file)") 15 | var audio = flag.String("audio", "", "Path to the audio file to run (WAV format)") 16 | var scorer = flag.String("scorer", "", "Path to the external scorer") 17 | var version = flag.Bool("version", false, "Print version and exit") 18 | var extended = flag.Bool("extended", false, "Use extended metadata") 19 | var maxResults = flag.Uint("max-results", 5, "Maximum number of results when -extended is true") 20 | var printSampleRate = flag.Bool("sample-rate", false, "Print model sample rate and exit") 21 | 22 | func metadataToStrings(m *asticoqui.Metadata) []string { 23 | results := make([]string, 0, m.NumTranscripts()) 24 | for _, tr := range m.Transcripts() { 25 | var res string 26 | for _, tok := range tr.Tokens() 
{ 27 | res += tok.Text() 28 | } 29 | res += fmt.Sprintf(" [%0.3f]", tr.Confidence()) 30 | results = append(results, res) 31 | } 32 | return results 33 | } 34 | 35 | func main() { 36 | flag.Parse() 37 | log.SetFlags(0) 38 | 39 | if *version { 40 | fmt.Println(asticoqui.Version()) 41 | return 42 | } 43 | 44 | if *model == "" || *audio == "" { 45 | // In case of error print error and print usage 46 | // This can also be done by passing -h or --help flags 47 | fmt.Fprintf(flag.CommandLine.Output(), "Usage of %s:\n", os.Args[0]) 48 | flag.PrintDefaults() 49 | return 50 | } 51 | 52 | // Initialize Coqui 53 | m, err := asticoqui.New(*model) 54 | if err != nil { 55 | log.Fatal("Failed initializing model: ", err) 56 | } 57 | defer m.Close() 58 | 59 | if *printSampleRate { 60 | fmt.Println(m.SampleRate()) 61 | return 62 | } 63 | 64 | if *scorer != "" { 65 | if err := m.EnableExternalScorer(*scorer); err != nil { 66 | log.Fatal("Failed enabling external scorer: ", err) 67 | } 68 | } 69 | 70 | // Stat audio 71 | i, err := os.Stat(*audio) 72 | if err != nil { 73 | log.Fatal(fmt.Errorf("stating %s failed: %w", *audio, err)) 74 | } 75 | 76 | // Open audio 77 | f, err := os.Open(*audio) 78 | if err != nil { 79 | log.Fatal(fmt.Errorf("opening %s failed: %w", *audio, err)) 80 | } 81 | 82 | // Create reader 83 | r, err := wav.NewReader(f, i.Size()) 84 | if err != nil { 85 | log.Fatal(fmt.Errorf("creating new reader failed: %w", err)) 86 | } 87 | 88 | // Read 89 | var d []int16 90 | for { 91 | // Read sample 92 | s, err := r.ReadSample() 93 | if err == io.EOF { 94 | break 95 | } else if err != nil { 96 | log.Fatal(fmt.Errorf("reading sample failed: %w", err)) 97 | } 98 | 99 | // Append 100 | d = append(d, int16(s)) 101 | } 102 | 103 | // Speech to text 104 | var results []string 105 | if *extended { 106 | metadata, err := m.SpeechToTextWithMetadata(d, *maxResults) 107 | if err != nil { 108 | log.Fatal("Failed converting speech to text: ", err) 109 | } 110 | defer metadata.Close() 111 
| results = metadataToStrings(metadata) 112 | } else { 113 | res, err := m.SpeechToText(d) 114 | if err != nil { 115 | log.Fatal("Failed converting speech to text: ", err) 116 | } 117 | results = []string{res} 118 | } 119 | for _, res := range results { 120 | fmt.Println("Text:", res) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /coqui.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | extern "C" { 5 | class ModelWrapper { 6 | private: 7 | ModelState* model; 8 | 9 | public: 10 | ModelWrapper(const char* aModelPath, int *errorOut) 11 | { 12 | model = nullptr; 13 | *errorOut = STT_CreateModel(aModelPath, &model); 14 | } 15 | 16 | ~ModelWrapper() 17 | { 18 | if (model) { 19 | STT_FreeModel(model); 20 | model = nullptr; 21 | } 22 | } 23 | 24 | unsigned int beamWidth() 25 | { 26 | return STT_GetModelBeamWidth(model); 27 | } 28 | 29 | int setBeamWidth(unsigned int aBeamWidth) 30 | { 31 | return STT_SetModelBeamWidth(model, aBeamWidth); 32 | } 33 | 34 | int sampleRate() 35 | { 36 | return STT_GetModelSampleRate(model); 37 | } 38 | 39 | int enableExternalScorer(const char* aScorerPath) 40 | { 41 | return STT_EnableExternalScorer(model, aScorerPath); 42 | } 43 | 44 | int disableExternalScorer() 45 | { 46 | return STT_DisableExternalScorer(model); 47 | } 48 | 49 | int setScorerAlphaBeta(float aAlpha, float aBeta) 50 | { 51 | return STT_SetScorerAlphaBeta(model, aAlpha, aBeta); 52 | } 53 | 54 | char* stt(const short* aBuffer, unsigned int aBufferSize) 55 | { 56 | return STT_SpeechToText(model, aBuffer, aBufferSize); 57 | } 58 | 59 | Metadata* sttWithMetadata(const short* aBuffer, unsigned int aBufferSize, unsigned int aNumResults) 60 | { 61 | return STT_SpeechToTextWithMetadata(model, aBuffer, aBufferSize, aNumResults); 62 | } 63 | 64 | ModelState* getModel() 65 | { 66 | return model; 67 | } 68 | }; 69 | 70 | ModelWrapper* New(const char* aModelPath, 
int* errorOut) 71 | { 72 | auto mw = new ModelWrapper(aModelPath, errorOut); 73 | if (*errorOut != STT_ERR_OK) { 74 | delete mw; 75 | mw = nullptr; 76 | } 77 | return mw; 78 | } 79 | void Model_Close(ModelWrapper* w) 80 | { 81 | delete w; 82 | } 83 | 84 | unsigned int Model_BeamWidth(ModelWrapper* w) 85 | { 86 | return w->beamWidth(); 87 | } 88 | 89 | int Model_SetBeamWidth(ModelWrapper* w, unsigned int aBeamWidth) 90 | { 91 | return w->setBeamWidth(aBeamWidth); 92 | } 93 | 94 | int Model_SampleRate(ModelWrapper* w) 95 | { 96 | return w->sampleRate(); 97 | } 98 | 99 | int Model_EnableExternalScorer(ModelWrapper* w, const char* aScorerPath) 100 | { 101 | return w->enableExternalScorer(aScorerPath); 102 | } 103 | 104 | int Model_DisableExternalScorer(ModelWrapper* w) 105 | { 106 | return w->disableExternalScorer(); 107 | } 108 | 109 | int Model_SetScorerAlphaBeta(ModelWrapper* w, float aAlpha, float aBeta) 110 | { 111 | return w->setScorerAlphaBeta(aAlpha, aBeta); 112 | } 113 | 114 | char* Model_STT(ModelWrapper* w, const short* aBuffer, unsigned int aBufferSize) 115 | { 116 | return w->stt(aBuffer, aBufferSize); 117 | } 118 | 119 | Metadata* Model_STTWithMetadata(ModelWrapper* w, const short* aBuffer, unsigned int aBufferSize, unsigned int aNumResults) 120 | { 121 | return w->sttWithMetadata(aBuffer, aBufferSize, aNumResults); 122 | } 123 | 124 | const CandidateTranscript* Metadata_Transcripts(Metadata* m) 125 | { 126 | return m->transcripts; 127 | } 128 | 129 | unsigned int Metadata_NumTranscripts(Metadata* m) 130 | { 131 | return m->num_transcripts; 132 | } 133 | 134 | void Metadata_Close(Metadata* m) 135 | { 136 | STT_FreeMetadata(m); 137 | } 138 | 139 | const TokenMetadata* CandidateTranscript_Tokens(CandidateTranscript* ct) 140 | { 141 | return ct->tokens; 142 | } 143 | 144 | int CandidateTranscript_NumTokens(CandidateTranscript* ct) 145 | { 146 | return ct->num_tokens; 147 | } 148 | 149 | double CandidateTranscript_Confidence(CandidateTranscript* ct) 150 | { 
151 | return ct->confidence; 152 | } 153 | 154 | const char* TokenMetadata_Text(TokenMetadata* tm) 155 | { 156 | return tm->text; 157 | } 158 | 159 | unsigned int TokenMetadata_Timestep(TokenMetadata* tm) 160 | { 161 | return tm->timestep; 162 | } 163 | 164 | float TokenMetadata_StartTime(TokenMetadata* tm) 165 | { 166 | return tm->start_time; 167 | } 168 | 169 | class StreamWrapper { 170 | private: 171 | StreamingState* s; 172 | 173 | public: 174 | StreamWrapper(ModelWrapper* w, int* errorOut) 175 | { 176 | s = nullptr; 177 | *errorOut = STT_CreateStream(w->getModel(), &s); 178 | } 179 | 180 | ~StreamWrapper() 181 | { 182 | if (s) { 183 | STT_FreeStream(s); 184 | s = nullptr; 185 | } 186 | } 187 | 188 | void feedAudioContent(const short* aBuffer, unsigned int aBufferSize) 189 | { 190 | STT_FeedAudioContent(s, aBuffer, aBufferSize); 191 | } 192 | 193 | char* intermediateDecode() 194 | { 195 | return STT_IntermediateDecode(s); 196 | } 197 | 198 | Metadata* intermediateDecodeWithMetadata(unsigned int aNumResults) 199 | { 200 | return STT_IntermediateDecodeWithMetadata(s, aNumResults); 201 | } 202 | 203 | char* finish() 204 | { 205 | // STT_FinishStream frees the supplied state pointer. 206 | char* res = STT_FinishStream(s); 207 | s = nullptr; 208 | return res; 209 | } 210 | 211 | Metadata* finishWithMetadata(unsigned int aNumResults) 212 | { 213 | // STT_FinishStreamWithMetadata frees the supplied state pointer. 
214 | Metadata* m = STT_FinishStreamWithMetadata(s, aNumResults); 215 | s = nullptr; 216 | return m; 217 | } 218 | 219 | void discard() 220 | { 221 | STT_FreeStream(s); 222 | s = nullptr; 223 | } 224 | }; 225 | 226 | StreamWrapper* Model_NewStream(ModelWrapper* mw, int* errorOut) 227 | { 228 | auto sw = new StreamWrapper(mw, errorOut); 229 | if (*errorOut != STT_ERR_OK) { 230 | delete sw; 231 | sw = nullptr; 232 | } 233 | return sw; 234 | } 235 | void Stream_Discard(StreamWrapper* sw) 236 | { 237 | sw->discard(); 238 | delete sw; 239 | } 240 | 241 | void Stream_FeedAudioContent(StreamWrapper* sw, const short* aBuffer, unsigned int aBufferSize) 242 | { 243 | sw->feedAudioContent(aBuffer, aBufferSize); 244 | } 245 | 246 | char* Stream_IntermediateDecode(StreamWrapper* sw) 247 | { 248 | return sw->intermediateDecode(); 249 | } 250 | 251 | Metadata* Stream_IntermediateDecodeWithMetadata(StreamWrapper* sw, unsigned int aNumResults) 252 | { 253 | return sw->intermediateDecodeWithMetadata(aNumResults); 254 | } 255 | 256 | char* Stream_Finish(StreamWrapper* sw) 257 | { 258 | char* str = sw->finish(); 259 | delete sw; 260 | return str; 261 | } 262 | 263 | Metadata* Stream_FinishWithMetadata(StreamWrapper* sw, unsigned int aNumResults) 264 | { 265 | Metadata* m = sw->finishWithMetadata(aNumResults); 266 | delete sw; 267 | return m; 268 | } 269 | 270 | void FreeString(char* s) 271 | { 272 | STT_FreeString(s); 273 | } 274 | 275 | char* Version() 276 | { 277 | return STT_Version(); 278 | } 279 | 280 | char* ErrorCodeToErrorMessage(int aErrorCode) 281 | { 282 | return STT_ErrorCodeToErrorMessage(aErrorCode); 283 | } 284 | } 285 | -------------------------------------------------------------------------------- /coqui.go: -------------------------------------------------------------------------------- 1 | package asticoqui 2 | 3 | /* 4 | #cgo CXXFLAGS: -std=c++11 5 | #cgo LDFLAGS: -lstt 6 | #include "coqui_wrap.h" 7 | #include "stdlib.h" 8 | */ 9 | import "C" 10 | import ( 11 | 
"errors"
"math"
"unsafe"
)

// Model is a handle to a trained speech-to-text model owned by the C++ wrapper.
type Model struct {
	w *C.ModelWrapper
}

// New loads the model stored at modelPath (the path to the frozen model graph)
// and returns a handle to it. When the native library reports a non-zero status,
// that status is converted to an error and no Model is returned.
func New(modelPath string) (*Model, error) {
	// The C string is only needed for the duration of the native call.
	path := C.CString(modelPath)
	defer C.free(unsafe.Pointer(path))

	var code C.int
	wrapper := C.New(path, &code) // nil whenever code is non-zero
	if code == 0 {
		return &Model{wrapper}, nil
	}
	return nil, errorFromCode(code)
}

// Close destroys the native model object and clears the handle held by m.
func (m *Model) Close() {
	C.Model_Close(m.w) // the wrapper deletes the underlying object
	m.w = nil
}

// BeamWidth reports the beam width currently used by the model.
// If SetBeamWidth was not called before, this is the default value loaded
// from the model file.
func (m *Model) BeamWidth() uint {
	width := C.Model_BeamWidth(m.w)
	return uint(width)
}

// SetBeamWidth changes the beam width used by the model.
// A larger beam width generates better results at the cost of decoding time.
func (m *Model) SetBeamWidth(width uint) error {
	code := C.Model_SetBeamWidth(m.w, C.uint(width))
	return errorFromCode(code)
}

// SampleRate reports the sample rate that was used to produce the model file.
func (m *Model) SampleRate() int {
	rate := C.Model_SampleRate(m.w)
	return int(rate)
}

// EnableExternalScorer turns on decoding with the external scorer file located
// at scorerPath.
func (m *Model) EnableExternalScorer(scorerPath string) error {
	path := C.CString(scorerPath)
	defer C.free(unsafe.Pointer(path))

	code := C.Model_EnableExternalScorer(m.w, path)
	return errorFromCode(code)
}

// DisableExternalScorer disables decoding using an external scorer.
68 | func (m *Model) DisableExternalScorer() error { 69 | return errorFromCode(C.Model_DisableExternalScorer(m.w)) 70 | } 71 | 72 | // SetScorerAlphaBeta sets hyperparameters alpha and beta of the external scorer. 73 | // alpha is the language model weight. beta is the word insertion weight. 74 | func (m *Model) SetScorerAlphaBeta(alpha, beta float32) error { 75 | return errorFromCode(C.Model_SetScorerAlphaBeta(m.w, C.float(alpha), C.float(beta))) 76 | } 77 | 78 | // sliceHeader represents a slice header 79 | type sliceHeader struct { 80 | Data uintptr 81 | Len int 82 | Cap int 83 | } 84 | 85 | // SpeechToText uses the model to convert speech to text. 86 | // buffer is 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). 87 | func (m *Model) SpeechToText(buffer []int16) (string, error) { 88 | hdr := (*sliceHeader)(unsafe.Pointer(&buffer)) 89 | str := C.Model_STT(m.w, (*C.short)(unsafe.Pointer(hdr.Data)), C.uint(hdr.Len)) 90 | if str == nil { 91 | return "", errors.New("conversion failed") 92 | } 93 | defer C.FreeString(str) 94 | return C.GoString(str), nil 95 | } 96 | 97 | // TokenMetadata stores text of an individual token, along with its timing information. 98 | type TokenMetadata C.struct_TokenMetadata 99 | 100 | // Text returns the text corresponding to this token. 101 | func (tm *TokenMetadata) Text() string { 102 | return C.GoString(C.TokenMetadata_Text((*C.TokenMetadata)(unsafe.Pointer(tm)))) 103 | } 104 | 105 | // Timestep returns the position of the token in units of 20ms. 106 | func (tm *TokenMetadata) Timestep() uint { 107 | return uint(C.TokenMetadata_Timestep((*C.TokenMetadata)(unsafe.Pointer(tm)))) 108 | } 109 | 110 | // StartTime returns the position of the token in seconds. 
111 | func (tm *TokenMetadata) StartTime() float32 { 112 | return float32(C.TokenMetadata_StartTime((*C.TokenMetadata)(unsafe.Pointer(tm)))) 113 | } 114 | 115 | // CandidateTranscript is a single transcript computed by the model, 116 | // including a confidence value and the metadata for its constituent tokens. 117 | type CandidateTranscript C.struct_CandidateTranscript 118 | 119 | func (ct *CandidateTranscript) NumTokens() uint { 120 | return uint(C.CandidateTranscript_NumTokens((*C.CandidateTranscript)(unsafe.Pointer(ct)))) 121 | } 122 | 123 | func (ct *CandidateTranscript) Tokens() []TokenMetadata { 124 | numTokens := uint(C.CandidateTranscript_NumTokens((*C.CandidateTranscript)(unsafe.Pointer(ct)))) 125 | allTokens := C.CandidateTranscript_Tokens((*C.CandidateTranscript)(unsafe.Pointer(ct))) 126 | return (*[math.MaxInt32 - 1]TokenMetadata)(unsafe.Pointer(allTokens))[:numTokens:numTokens] 127 | } 128 | 129 | // Confidence returns the approximated confidence value for this transcript. 130 | // This is roughly the sum of the acoustic model logit values for each timestep/character that 131 | // contributed to the creation of this transcript. 132 | func (ct *CandidateTranscript) Confidence() float64 { 133 | return float64(C.CandidateTranscript_Confidence((*C.CandidateTranscript)(unsafe.Pointer(ct)))) 134 | } 135 | 136 | // Metadata holds an array of CandidateTranscript objects computed by the model. 
137 | type Metadata C.struct_Metadata 138 | 139 | func (m *Metadata) NumTranscripts() uint { 140 | return uint(C.Metadata_NumTranscripts((*C.Metadata)(unsafe.Pointer(m)))) 141 | } 142 | 143 | func (m *Metadata) Transcripts() []CandidateTranscript { 144 | numTranscripts := int32(C.Metadata_NumTranscripts((*C.Metadata)(unsafe.Pointer(m)))) 145 | allTranscripts := C.Metadata_Transcripts((*C.Metadata)(unsafe.Pointer(m))) 146 | return (*[math.MaxInt32 - 1]CandidateTranscript)(unsafe.Pointer(allTranscripts))[:numTranscripts:numTranscripts] 147 | } 148 | 149 | // Close frees the Metadata structure properly. 150 | func (m *Metadata) Close() { 151 | C.Metadata_Close((*C.Metadata)(unsafe.Pointer(m))) 152 | } 153 | 154 | // SpeechToTextWithMetadata uses the model to convert speech to text and 155 | // output results including metadata. 156 | // 157 | // buffer is a 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). 158 | // numResults is the maximum number of CandidateTranscript structs to return. Returned value might be smaller than this. 159 | // If an error is not returned, the returned metadata's Close method must be called later to free resources. 160 | func (m *Model) SpeechToTextWithMetadata(buffer []int16, numResults uint) (*Metadata, error) { 161 | hdr := (*sliceHeader)(unsafe.Pointer(&buffer)) 162 | md := (*Metadata)(unsafe.Pointer(C.Model_STTWithMetadata( 163 | m.w, (*C.short)(unsafe.Pointer(hdr.Data)), C.uint(hdr.Len), C.uint(numResults)))) 164 | if md == nil { 165 | return nil, errors.New("conversion failed") 166 | } 167 | return md, nil 168 | } 169 | 170 | // Stream represents a streaming inference state. 171 | type Stream struct { 172 | sw *C.StreamWrapper 173 | } 174 | 175 | // NewStream creates a new streaming inference state. 176 | // If an error is not returned, exactly one of the returned stream's Finish, 177 | // FinishWithMetadata, or Discard methods must be called later to free resources. 
178 | func (m *Model) NewStream() (*Stream, error) { 179 | var ret C.int 180 | sw := C.Model_NewStream(m.w, &ret) // returns nil on error 181 | if ret != 0 { 182 | return nil, errorFromCode(ret) 183 | } 184 | return &Stream{sw}, nil 185 | } 186 | 187 | // FeedAudioContent feeds audio samples to an ongoing streaming inference. 188 | // buffer is an array of 16-bit, mono raw audio samples at the appropriate sample rate 189 | // (matching what the model was trained on). 190 | func (s *Stream) FeedAudioContent(buffer []int16) { 191 | hdr := (*sliceHeader)(unsafe.Pointer(&buffer)) 192 | C.Stream_FeedAudioContent(s.sw, (*C.short)(unsafe.Pointer(hdr.Data)), C.uint(hdr.Len)) 193 | } 194 | 195 | // IntermediateDecode computes the intermediate decoding of an ongoing streaming inference. 196 | // This is an expensive process as the decoder implementation isn't 197 | // currently capable of streaming, so it always starts from the beginning 198 | // of the audio. 199 | func (s *Stream) IntermediateDecode() (string, error) { 200 | // STT_IntermediateDecode isn't documented as returning null, but future-proofing this seems safer. 201 | str := C.Stream_IntermediateDecode(s.sw) 202 | if str == nil { 203 | return "", errors.New("decoding failed") 204 | } 205 | defer C.FreeString(str) 206 | return C.GoString(str), nil 207 | } 208 | 209 | // IntermediateDecodeWithMetadata computes the intermediate decoding of an 210 | // ongoing streaming inference, returning results including metadata. 211 | // numResults is the number of candidate transcripts to return. 212 | // If an error is not returned, the metadata's Close method must be called. 
213 | func (s *Stream) IntermediateDecodeWithMetadata(numResults uint) (*Metadata, error) { 214 | md := (*Metadata)(unsafe.Pointer(C.Stream_IntermediateDecodeWithMetadata(s.sw, C.uint(numResults)))) 215 | if md == nil { 216 | return nil, errors.New("decoding failed") 217 | } 218 | return md, nil 219 | } 220 | 221 | // Finish computes the final decoding of an ongoing streaming inference and returns the result. 222 | // This signals the end of an ongoing streaming inference. 223 | func (s *Stream) Finish() (string, error) { 224 | // STT_FinishStream isn't documented as returning null, but future-proofing this seems safer. 225 | str := C.Stream_Finish(s.sw) // deletes s.sw 226 | s.sw = nil 227 | 228 | if str == nil { 229 | return "", errors.New("decoding failed") 230 | } 231 | defer C.FreeString(str) 232 | return C.GoString(str), nil 233 | } 234 | 235 | // FinishWithMetadata computes the final decoding of an ongoing streaming inference and returns 236 | // results including metadata. This signals the end of an ongoing streaming inference. 237 | // If an error is not returned, the metadata's Close method must be called. 238 | func (s *Stream) FinishWithMetadata(numResults uint) (*Metadata, error) { 239 | md := (*Metadata)(unsafe.Pointer(C.Stream_FinishWithMetadata(s.sw, C.uint(numResults)))) // deletes s.sw 240 | s.sw = nil 241 | 242 | if md == nil { 243 | return nil, errors.New("decoding failed") 244 | } 245 | return md, nil 246 | } 247 | 248 | // Discard destroys a streaming state without decoding the computed logits. 249 | // This can be used if you no longer need the result of an ongoing streaming 250 | // inference and don't want to perform a costly decode operation. 251 | func (s *Stream) Discard() { 252 | C.Stream_Discard(s.sw) // deletes s.sw 253 | s.sw = nil 254 | } 255 | 256 | // Version returns the version of the C library. 257 | // The returned version is a semantic version (SemVer 2.0.0). 
258 | func Version() string { 259 | str := C.Version() 260 | defer C.FreeString(str) 261 | return C.GoString(str) 262 | } 263 | 264 | // errorFromCode converts a C error code into a Go error. 265 | // Returns nil if code is equal to zero, indicating success. 266 | func errorFromCode(code C.int) error { 267 | if code == 0 { 268 | return nil 269 | } 270 | str := C.ErrorCodeToErrorMessage(code) 271 | defer C.FreeString(str) 272 | return errors.New(C.GoString(str)) 273 | } 274 | --------------------------------------------------------------------------------