├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app.py ├── bert_punctuator ├── bert.py ├── modules.py └── tokenizer.py ├── config.ini ├── data_loader.py ├── decoder.py ├── docker-compose.yml ├── docs ├── Руководство оператора SOVA Speech (RU).pdf └── Руководство програмиста SOVA Speech (RU).pdf ├── file_handler.py ├── number_utils ├── russian_numbers.py └── text2numbers.py ├── punctuator.py ├── requirements.txt ├── speech_recognizer.py ├── static └── speech │ ├── css │ └── styles.css │ └── js │ ├── speech_recorder.js │ └── web-audio-recording-tests-simpler-master │ └── js │ ├── RecorderService.js │ ├── WebAudioPeakMeter.js │ ├── app.js │ └── encoder-wav-worker.js ├── templates ├── _formhelpers.html ├── base.html └── speech_recognition.html ├── train.py ├── trie_decoder ├── __init__.py ├── _common.cpython-36m-x86_64-linux-gnu.so ├── _criterion.cpython-36m-x86_64-linux-gnu.so ├── _decoder.cpython-36m-x86_64-linux-gnu.so ├── _feature.cpython-36m-x86_64-linux-gnu.so ├── common.py ├── criterion.py ├── decoder.py └── feature.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ 2 | records/ 3 | checkpoints/ 4 | __pycache__/ 5 | .idea/ 6 | .DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | records/ 3 | checkpoints/ 4 | __pycache__/ 5 | .idea/ 6 | .DS_Store -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # If using CPU only replace the following line with: 2 | # FROM ubuntu:18.04 3 | FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 4 | 5 | ARG DEBIAN_FRONTEND=noninteractive 6 | 7 | RUN apt-get update && apt-get upgrade -y && apt-get autoremove && apt-get autoclean 8 | RUN apt-get install -y python3-dev 
python3-pip ffmpeg 9 | 10 | ARG PROJECT=sova-asr 11 | ARG PROJECT_DIR=/$PROJECT 12 | RUN mkdir -p $PROJECT_DIR 13 | WORKDIR $PROJECT_DIR 14 | 15 | COPY requirements.txt . 16 | RUN pip3 install --upgrade pip 17 | RUN pip3 install -r requirements.txt 18 | 19 | # If using CPU only replace the following two lines with: 20 | # RUN pip3 install PuzzleLib 21 | RUN ln -s /usr/local/cuda/targets/x86_64-linux/lib/ /usr/local/cuda/lib64/ 22 | RUN pip3 install PuzzleLib==1.0.3a0 --install-option="--backend=cuda" 23 | 24 | RUN rm -rf $PROJECT_DIR/* 25 | 26 | RUN apt-get install -y locales && locale-gen en_US.UTF-8 27 | ENV LANG en_US.UTF-8 28 | ENV LANGUAGE en_US:en 29 | ENV LC_ALL en_US.UTF-8 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright (c) 2021, Virtual Assistants, LLC 190 | All rights reserved. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SOVA ASR 2 | 3 | SOVA ASR is a fast speech recognition solution based on [Wav2Letter](https://arxiv.org/abs/1609.03193) architecture. It is designed as a REST API service and it can be customized (both code and models) for your needs. 4 | 5 | ## Installation 6 | 7 | The easiest way to deploy the service is via docker-compose, so you have to install Docker and docker-compose first. 
Here's a brief instruction for Ubuntu: 8 | 9 | #### Docker installation 10 | 11 | * Install Docker: 12 | ```bash 13 | $ sudo apt-get update 14 | $ sudo apt-get install \ 15 | apt-transport-https \ 16 | ca-certificates \ 17 | curl \ 18 | gnupg-agent \ 19 | software-properties-common 20 | $ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add - 21 | $ sudo apt-key fingerprint 0EBFCD88 22 | $ sudo add-apt-repository \ 23 | "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ 24 | $(lsb_release -cs) \ 25 | stable" 26 | $ sudo apt-get update 27 | $ sudo apt-get install docker-ce docker-ce-cli containerd.io 28 | $ sudo usermod -aG docker $(whoami) 29 | ``` 30 | In order to run docker commands without sudo you might need to relogin. 31 | * Install docker-compose: 32 | ``` 33 | $ sudo curl -L "https://github.com/docker/compose/releases/download/1.25.5/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose 34 | $ sudo chmod +x /usr/local/bin/docker-compose 35 | ``` 36 | 37 | * (Optional) If you're planning on using CUDA run these commands: 38 | ``` 39 | $ curl -s -L https://nvidia.github.io/nvidia-container-runtime/gpgkey | \ 40 | sudo apt-key add - 41 | $ distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) 42 | $ curl -s -L https://nvidia.github.io/nvidia-container-runtime/$distribution/nvidia-container-runtime.list | \ 43 | sudo tee /etc/apt/sources.list.d/nvidia-container-runtime.list 44 | $ sudo apt-get update 45 | $ sudo apt-get install nvidia-container-runtime 46 | ``` 47 | Add the following content to the file **/etc/docker/daemon.json**: 48 | ```json 49 | { 50 | "runtimes": { 51 | "nvidia": { 52 | "path": "nvidia-container-runtime", 53 | "runtimeArgs": [] 54 | } 55 | }, 56 | "default-runtime": "nvidia" 57 | } 58 | ``` 59 | Restart the service: 60 | ```bash 61 | $ sudo systemctl restart docker.service 62 | ``` 63 | 64 | #### Build and deploy 65 | 66 | **In order to run service with pretrained models you will have to download http://dataset.sova.ai/SOVA-ASR/data.tar.gz.** 67 | 68 | * Clone the repository, download the pretrained models archive and extract the contents into the project folder: 69 | ```bash 70 | $ git clone --recursive https://github.com/sovaai/sova-asr.git 71 | $ cd sova-asr/ 72 | $ wget http://dataset.sova.ai/SOVA-ASR/data.tar.gz 73 | $ tar -xvf data.tar.gz && rm data.tar.gz 74 | ``` 75 | 76 | * Build docker image 77 | * If you're planning on using GPU (it is required for training and can be used for inference): build *sova-asr* image using the following command: 78 | ```bash 79 | $ sudo docker-compose build 80 | ``` 81 | * If you're planning on using CPU only: modify `Dockerfile`, `docker-compose.yml` (remove the runtime and environment sections) and `config.ini` (*cpu* should be set to 0) and build *sova-asr* image: 82 | ```bash 83 | $ sudo docker-compose build 84 | ``` 85 | 86 | * Run web service in a docker container 87 | ```bash 88 | $ sudo docker-compose up -d sova-asr 89 | ``` 90 | 91 | ## Testing 92 | 93 | To test the service you can send a POST request: 94 | ```bash 95 | $ curl --request POST 'http://localhost:8888/asr' --form 'audio_blob=@"data/test.wav"' 96 | ``` 97 | 98 | ## Finetuning acoustic 
model 99 | 100 | If you want to finetune the acoustic model you can set hyperparameters and paths to your own train and validation manifest files and run the training service. 101 | 102 | * Set training options in *Train* section of **config.ini**. Train and validation csv manifest files should contain comma-separated audio file paths and reference texts in each line. For instance: 103 | ```bash 104 | data/audio/000000.wav,добрый день 105 | data/audio/000001.wav,как ваши дела 106 | ... 107 | ``` 108 | * Run training in docker container: 109 | ```bash 110 | $ sudo docker-compose up -d sova-asr-train 111 | ``` 112 | 113 | ## Customizations 114 | 115 | If you want to train your own acoustic model refer to [PuzzleLib tutorials](https://puzzlelib.org/tutorials/Wav2Letter/). Check [KenLM documentation](https://kheafield.com/code/kenlm/) for building your own language model. This repository was tested on Ubuntu 18.04 and has pre-built .so Trie decoder files for Python 3.6 running inside the Docker container, for modifications you can get your own .so files using [Wav2Letter++](https://github.com/facebookresearch/wav2letter) code for building Python bindings. Otherwise you can use a standard Greedy decoder (set in config.ini). 
116 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, render_template, request, send_from_directory, url_for 2 | from file_handler import FileHandler 3 | import json 4 | 5 | 6 | app = Flask(__name__) 7 | 8 | 9 | @app.route('/', methods=['GET']) 10 | def index(): 11 | return render_template('speech_recognition.html') 12 | 13 | 14 | @app.route('/asr', methods=['POST']) 15 | def asr(): 16 | res = [] 17 | for f in request.files: 18 | if f.startswith('audio_blob') and FileHandler.check_format(request.files[f]): 19 | 20 | response_code, filename, response = FileHandler.get_recognized_text(request.files[f]) 21 | 22 | if response_code == 0: 23 | response_audio_url = url_for('media_file', filename=filename) 24 | else: 25 | response_audio_url = None 26 | 27 | res.append({ 28 | 'response_audio_url': response_audio_url, 29 | 'response_code': response_code, 30 | 'response': response, 31 | }) 32 | return json.dumps({'r': res}, ensure_ascii=False) 33 | 34 | 35 | @app.route('/media/', methods=['GET']) 36 | def media_file(filename): 37 | return send_from_directory('./records', filename, as_attachment=False) 38 | -------------------------------------------------------------------------------- /bert_punctuator/bert.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import math 4 | import numpy as np 5 | from bert_punctuator.modules import Embedder, Linear 6 | from PuzzleLib.Backend.Blas import mulTensorBatch 7 | from PuzzleLib.Backend import gpuarray 8 | from PuzzleLib.Modules import Module, Activation, SwapAxes, Mul, SoftMax, ModuleError, Tile, Gelu, BatchNorm, InstanceNorm2D 9 | from PuzzleLib.Backend.Kernels import MatVec 10 | from PuzzleLib.Variable import Variable 11 | from PuzzleLib.Containers import Container, Sequential 12 | 13 | 14 | class 
BertConfig(object): 15 | def __init__(self, 16 | vocab_size_or_json, 17 | hidden_size=768, 18 | num_hidden_layers=12, 19 | num_attention_heads=12, 20 | intermediate_size=3072, 21 | max_position_embeddings=512, 22 | type_vocab_size=2, 23 | segment_size=32): 24 | 25 | if isinstance(vocab_size_or_json, str): 26 | with open(vocab_size_or_json, "r") as reader: 27 | json_config = json.loads(reader.read()) 28 | for key, value in json_config.items(): 29 | self.__dict__[key] = value 30 | elif isinstance(vocab_size_or_json, int): 31 | self.vocab_size = vocab_size_or_json 32 | self.hidden_size = hidden_size 33 | self.num_hidden_layers = num_hidden_layers 34 | self.num_attention_heads = num_attention_heads 35 | self.intermediate_size = intermediate_size 36 | self.max_position_embeddings = max_position_embeddings 37 | self.type_vocab_size = type_vocab_size 38 | self.segment_size = segment_size 39 | self.output_size = output_size 40 | else: 41 | raise ValueError("First argument must be either a vocabulary size (int)" 42 | "or the path to a pretrained model config file (str)") 43 | 44 | def __repr__(self): 45 | return str(self.to_json()) 46 | 47 | def to_dict(self): 48 | """Serializes this instance to a Python dictionary.""" 49 | output = copy.deepcopy(self.__dict__) 50 | return output 51 | 52 | def to_json(self): 53 | """Serializes this instance to a JSON string.""" 54 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 55 | 56 | 57 | class BertLayerNorm(Module): 58 | def __init__(self, config, epsilon=1e-12, name=None): 59 | super().__init__(name) 60 | self.registerBlueprint(locals()) 61 | 62 | self.epsilon = epsilon 63 | self.scale = None 64 | self.setVar("scale", Variable(gpuarray.to_gpu(np.ones(config.hidden_size, dtype=np.float32)))) 65 | self.bias = None 66 | self.setVar("bias", Variable(gpuarray.to_gpu(np.zeros(config.hidden_size, dtype=np.float32)))) 67 | self.mul = Mul() 68 | 69 | def updateData(self, data): 70 | batchsize, maps, h = data.shape 71 | 72 
| data = data.reshape((batchsize, maps, h, 1)) 73 | norm = InstanceNorm2D(maps, epsilon=self.epsilon) 74 | norm.calcMode(self.calctype) 75 | data = norm(data) 76 | data = data.reshape((batchsize * maps, h)) 77 | 78 | tile = Tile(axis=0, times=data.shape[0]) 79 | tile.calcMode(self.calctype) 80 | scale = tile(self.scale.reshape(tuple([1]) + self.scale.shape)) 81 | 82 | data = self.mul([data, scale]) 83 | data = MatVec.addVecToMat(self.bias, data, axis=1, out=data) 84 | self.data = data.reshape(batchsize, maps, h) 85 | 86 | def checkDataType(self, dtype): 87 | if dtype != self.calctype: 88 | raise ModuleError("Expected {} (got dtype {})".format(self.calctype, dtype)) 89 | 90 | def calcMode(self, T): 91 | if self.calctype == T: 92 | return 93 | 94 | self.mul.calcMode(T) 95 | 96 | variables = self.vars 97 | self.vars = {} 98 | 99 | for varName, var in variables.items(): 100 | self.setVar(varName, Variable(var.data.astype(T), name=var.name, grad=var.grad.astype(T))) 101 | 102 | attrs = self.attrs 103 | self.attrs = {} 104 | 105 | for attrName, attr in attrs.items(): 106 | self.setAttr(attrName, attr.astype(T)) 107 | 108 | self.mul.calctype = T 109 | self.calctype = T 110 | 111 | 112 | class BertEmbeddings(Container): 113 | def __init__(self, config, name=None): 114 | super().__init__(name) 115 | self.registerBlueprint(locals()) 116 | 117 | self.append(Embedder(config.vocab_size, config.hidden_size, name='wordEmbedder')) 118 | self.append(Embedder(config.max_position_embeddings, config.hidden_size, name='positionEmbedder')) 119 | self.append(Embedder(config.type_vocab_size, config.hidden_size, name='tokenTypeEmbedder')) 120 | 121 | self.append(BertLayerNorm(config, name='LayerNorm')) 122 | 123 | def updateData(self, data): 124 | if self.acquireDtypesFrom(data) == np.int32: 125 | inputIds = data 126 | tokenTypeIds = gpuarray.zeros(inputIds.shape, dtype=np.int32) 127 | else: 128 | inputIds, tokenTypeIds = data 129 | 130 | seqlength = inputIds.shape[1] 131 | positionIds = 
gpuarray.to_gpu(np.array([range(seqlength)]*inputIds.shape[0]).astype(np.int32)) 132 | 133 | wordsEmbeddings = self.modules['wordEmbedder'](inputIds) 134 | positionEmbeddings = self.modules['positionEmbedder'](positionIds) 135 | tokenTypeEmbeddings = self.modules['tokenTypeEmbedder'](tokenTypeIds) 136 | 137 | embeddings = wordsEmbeddings + positionEmbeddings + tokenTypeEmbeddings 138 | embeddings = self.modules['LayerNorm'](embeddings) 139 | self.data = embeddings 140 | 141 | def checkDataType(self, dtype): 142 | if dtype != np.int32 and dtype != [np.int32, np.int32]: 143 | raise ModuleError("Expected int32-tensor or [int32-tensor, int32-tensor] (got dtype %s)" % dtype) 144 | 145 | 146 | class BertSelfAttention(Container): 147 | def __init__(self, config, name=None): 148 | super().__init__(name) 149 | self.registerBlueprint(locals()) 150 | 151 | if config.hidden_size % config.num_attention_heads != 0: 152 | raise ValueError( 153 | "The hidden size (%d) is not a multiple of the number of attention " 154 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 155 | 156 | self.num_attention_heads = config.num_attention_heads 157 | self.attentionHeadSize = int(config.hidden_size / config.num_attention_heads) 158 | self.allHeadSize = self.num_attention_heads * self.attentionHeadSize 159 | 160 | self.append(Linear(config.hidden_size, self.allHeadSize, name='query')) 161 | self.append(Linear(config.hidden_size, self.allHeadSize, name='key')) 162 | self.append(Linear(config.hidden_size, self.allHeadSize, name='value')) 163 | self.append(Mul(name='mul')) 164 | 165 | def transpose(self, x): 166 | x = x.reshape(x.shape[:-1] + (self.num_attention_heads, self.attentionHeadSize)) 167 | swap = SwapAxes(axis1=1, axis2=2) 168 | swap.calcMode(self.calctype) 169 | x = swap(x) 170 | return x 171 | 172 | def updateData(self, data): 173 | hiddenStates, attentionMask = data 174 | 175 | mixedQueryLayer = self.modules['query'](hiddenStates) 176 | mixedKeyLayer = 
self.modules['key'](hiddenStates) 177 | mixedValueLayer = self.modules['value'](hiddenStates) 178 | 179 | queryLayer = self.transpose(mixedQueryLayer) 180 | keyLayer = self.transpose(mixedKeyLayer) 181 | valueLayer = self.transpose(mixedValueLayer) 182 | 183 | batchsize, maps, h, w = queryLayer.shape 184 | 185 | swap = SwapAxes(axis1=2, axis2=1) 186 | swap.calcMode(self.calctype) 187 | 188 | A = queryLayer.reshape((batchsize * maps, h, w)) 189 | B = swap(keyLayer.reshape((batchsize * maps, h, w))) 190 | attentionScores = mulTensorBatch(A, B, formatA="gbp", formatB="gbp", formatOut="gbp") 191 | attentionScores = attentionScores.reshape((batchsize, maps, h, h)) 192 | 193 | a = gpuarray.empty(attentionScores.shape, self.calctype).fill(1/math.sqrt(self.attentionHeadSize)) 194 | attentionScores = self.modules['mul']([attentionScores, a]) 195 | attentionScores = attentionScores + attentionMask 196 | 197 | softmax = SoftMax() 198 | softmax.calcMode(self.calctype) 199 | swap2 = SwapAxes(axis1=1, axis2=3) 200 | swap2.calcMode(self.calctype) 201 | attentionProbs = swap2(softmax(swap2(attentionScores))) 202 | 203 | contextLayer = mulTensorBatch(attentionProbs.reshape((batchsize * maps, h, h)), \ 204 | valueLayer.reshape((batchsize * maps, h, w)), \ 205 | formatA="gbp", formatB="gbp", formatOut="gbp") 206 | 207 | contextLayer = swap(contextLayer.reshape((batchsize, maps, h, w))).reshape((batchsize, h, self.allHeadSize)) 208 | self.data = contextLayer 209 | 210 | def calcMode(self, T): 211 | for mod in self.modules.values(): 212 | try: 213 | mod.calcMode(T) 214 | 215 | except Exception as e: 216 | self.handleError(mod, e) 217 | self.calctype = T 218 | 219 | 220 | class BertSelfOutput(Container): 221 | def __init__(self, config, name=None): 222 | super().__init__(name) 223 | self.registerBlueprint(locals()) 224 | 225 | self.append(Linear(config.hidden_size, config.hidden_size, name='dense')) 226 | self.append(BertLayerNorm(config, name='LayerNorm')) 227 | 228 | def 
# bert_punctuator/bert.py — PuzzleLib container graph implementing a BERT-based
# punctuation-restoration model. All classes follow the PuzzleLib Container
# protocol: children are registered with append(), forward pass is updateData().
# (The BertSelfOutput.updateData fragment preceding BertAttention is cut at the
# chunk boundary and left untouched.)

class BertAttention(Container):
    """Full attention block: self-attention ('self') followed by the projection /
    residual / LayerNorm sub-block ('output')."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())

        self.append(BertSelfAttention(config, name='self'))
        self.append(BertSelfOutput(config, name='output'))

    def updateData(self, data):
        # data: (hidden states, additive attention mask) pair.
        inputTensor, attentionMask = data
        selfOutput = self.modules['self']((inputTensor, attentionMask))
        # 'output' also receives the original input so it can add the residual.
        attentionOutput = self.modules['output']((selfOutput, inputTensor))
        self.data = attentionOutput


class BertIntermediate(Container):
    """Position-wise feed-forward expansion: Linear(hidden -> intermediate) + GELU."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())

        self.append(Linear(config.hidden_size, config.intermediate_size, name='dense'))
        self.append(Gelu(name='gelu'))

    def updateData(self, hiddenStates):
        hiddenStates = self.modules['dense'](hiddenStates)
        hiddenStates = self.modules['gelu'](hiddenStates)
        self.data = hiddenStates


class BertOutput(Container):
    """Feed-forward contraction: Linear(intermediate -> hidden), residual add, LayerNorm."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())

        self.append(Linear(config.intermediate_size, config.hidden_size, name='dense'))
        self.append(BertLayerNorm(config, name='LayerNorm'))

    def updateData(self, data):
        hiddenStates, inputTensor = data
        hiddenStates = self.modules['dense'](hiddenStates)
        # Residual connection applied before LayerNorm, as in standard BERT.
        hiddenStates = self.modules['LayerNorm'](hiddenStates + inputTensor)
        self.data = hiddenStates


class BertLayer(Container):
    """One transformer encoder layer: attention -> intermediate -> output."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())

        self.append(BertAttention(config, name='attention'))
        self.append(BertIntermediate(config, name='intermediate'))
        self.append(BertOutput(config, name='output'))

    def updateData(self, data):
        hiddenStates, attentionMask = data
        attentionOutput = self.modules['attention']((hiddenStates, attentionMask))
        intermediateOutput = self.modules['intermediate'](attentionOutput)
        # The second residual path goes around the feed-forward sub-block.
        layerOutput = self.modules['output']((intermediateOutput, attentionOutput))
        self.data = layerOutput


class BertEncoder(Container):
    """Stack of config.num_hidden_layers BertLayer modules, keyed by integer index."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())
        for i in range(config.num_hidden_layers):
            # Each layer is registered under its integer index as the module name.
            self.append(BertLayer(config, name=i))

    def updateData(self, data):
        hiddenStates, attentionMask = data
        # The same (already broadcast) attention mask is fed to every layer.
        for i in self.modules:
            hiddenStates = self.modules[i]((hiddenStates, attentionMask))

        self.data = hiddenStates


class BertPooler(Container):
    """Pools a sequence by pushing the first ([CLS]) token through Linear + tanh."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())
        self.append(Linear(config.hidden_size, config.hidden_size, name='dense'))
        # Plain module attribute (not appended): tanh has no parameters to track.
        self.activation = Activation('tanh')

    def updateData(self, data):
        # .copy() detaches the first-token slice from the full sequence tensor.
        firstTokenTensor = data[:, 0].copy()
        pooledOutput = self.modules['dense'](firstTokenTensor)
        pooledOutput = self.activation(pooledOutput)
        self.data = pooledOutput


class BertModel(Container):
    """Embeddings + encoder (+ pooler). updateData returns the full sequence
    output; the pooler is registered but not used in this forward pass."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())
        self.append(BertEmbeddings(config, name='embeddings'))
        self.append(BertEncoder(config, name='encoder'))
        self.append(BertPooler(config, name='pooler'))
        self.num_attention_heads = config.num_attention_heads

    def updateData(self, data):
        inputIds = data

        # Mask is all ones, so the additive mask below is all zeros: no padding
        # positions are actually masked out in this inference path.
        attentionMask = np.ones(inputIds.shape)

        attentionMask = (1.0 - attentionMask) * -10000.0
        # Broadcast the (batch, seq) mask to (batch, heads, seq, seq).
        extendedAttentionMask = np.repeat(np.expand_dims(attentionMask, axis=1), inputIds.shape[1], axis=1)
        extendedAttentionMask = np.repeat(np.expand_dims(extendedAttentionMask, axis=1), self.num_attention_heads, axis=1)
        # Match the dtype (e.g. float16) the rest of the model is running in.
        calctype = self.modules['embeddings'].modules['LayerNorm'].calctype
        extendedAttentionMask = gpuarray.to_gpu(extendedAttentionMask.astype(calctype))

        embeddingOutput = self.modules['embeddings'](inputIds)
        sequenceOutput = self.modules['encoder']((embeddingOutput, extendedAttentionMask))

        self.data = sequenceOutput

    def checkDataType(self, dtype):
        # Token ids must be int32 for the embedding kernel.
        if dtype != np.int32:
            raise ModuleError("Expected int32-tensor (got dtype %s)" % dtype)


class BertLMPredictionHead(Container):
    """Masked-LM head: dense + GELU + LayerNorm, then a decoder Linear whose
    weights are tied to the (transposed) word-embedding matrix."""

    def __init__(self, config, bertModelEmbeddingWeights, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())

        self.append(Linear(config.hidden_size, config.hidden_size, name='dense'))
        self.append(Gelu(name='gelu'))
        self.append(BertLayerNorm(config, name='LayerNorm'))
        self.append(Linear(bertModelEmbeddingWeights.shape[1], bertModelEmbeddingWeights.shape[0], name='decoder'))
        # Weight tying: reuse the embedding matrix as the output projection.
        self.modules['decoder'].setVar('W', Variable(bertModelEmbeddingWeights))
        # NOTE(review): bias length uses shape[1] — presumably the vocabulary
        # size after the upstream transpose; confirm against BertForMaskedLM.
        self.modules['decoder'].setVar('b', Variable(gpuarray.zeros((bertModelEmbeddingWeights.shape[1],), dtype=np.float32)))

    def updateData(self, hiddenStates):
        hiddenStates = self.modules['dense'](hiddenStates)
        hiddenStates = self.modules['gelu'](hiddenStates)
        hiddenStates = self.modules['LayerNorm'](hiddenStates)
        hiddenStates = self.modules['decoder'](hiddenStates)
        self.data = hiddenStates


class BertForMaskedLM(Container):
    """BERT body plus the tied masked-LM prediction head ('cls')."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())
        self.append(BertModel(config, name='bert'))
        # Transpose the embedding weights so they fit the decoder projection.
        swap = SwapAxes(axis1=0, axis2=1)
        bertModelEmbeddingWeights = swap(self.modules['bert'].modules['embeddings'].modules['wordEmbedder'].W)
        self.append(BertLMPredictionHead(config, bertModelEmbeddingWeights, name='cls'))

    def updateData(self, inputIds):
        sequenceOutput = self.modules['bert'](inputIds)
        predictionScores = self.modules['cls'](sequenceOutput)
        self.data = predictionScores

    def checkDataType(self, dtype):
        if dtype != np.int32:
            raise ModuleError("Expected int32-tensor (got dtype %s)" % dtype)


class BertPunc(Container):
    """Punctuation classifier on top of the masked-LM logits: the per-token
    vocab logits over a fixed segment are flattened, normalized by a frozen
    BatchNorm, and projected to output_size punctuation classes."""

    def __init__(self, config, name=None):
        super().__init__(name)
        self.registerBlueprint(locals())
        self.append(BertForMaskedLM(config, name='lm'))
        self.bert_vocab_size = config.vocab_size
        self.segment_size = config.segment_size
        self.output_size = config.output_size
        self.append(BatchNorm(self.segment_size*self.bert_vocab_size, affine=False, name='bn'))
        self.append(Linear(self.segment_size*self.bert_vocab_size, self.output_size, name='dense'))
        # BatchNorm runs with frozen statistics — inference-only usage.
        self.modules['bn'].evalMode()

    def updateData(self, data):
        data = self.modules['lm'](data)
        # Flatten (batch, segment, vocab) -> (batch, segment*vocab).
        data = self.modules['lm'](data) if False else data  # (kept structure below)
        data = data.reshape((data.shape[0], int(np.prod(data.shape[1:]))))
        # Cast to fp16 before the final projection to save memory/bandwidth.
        data = self.modules['bn'](data).astype(np.float16)
        self.data = self.modules['dense'](data)

    def checkDataType(self, dtype):
        if dtype != np.int32:
            raise ModuleError("Expected int32-tensor (got dtype %s)" % dtype)

    def calcMode(self, T):
        # Propagate the compute dtype to every child; tolerate children that
        # cannot switch (handleError decides whether to re-raise).
        for mod in self.modules.values():
            try:
                mod.calcMode(T)

            except Exception as e:
                self.handleError(mod, e)
        self.calctype = T
# bert_punctuator/modules.py — custom PuzzleLib modules used by the punctuator:
# a vocabulary-aware Embedder and a Linear layer that accepts N-d inputs.
import h5py
import numpy as np
from PuzzleLib import Config
from PuzzleLib.Backend import gpuarray, Blas
from PuzzleLib.Backend.Kernels.Embedder import embed, embedBackwardParams
from PuzzleLib.Backend.Kernels import MatVec
from PuzzleLib.Variable import Variable
from PuzzleLib.Modules.Module import ModuleError, Module
from PuzzleLib.Modules import Reshape


class Embedder(Module):
    """Token-id -> dense-vector lookup with an attached (optional) vocabulary.

    The vocabulary can be given either as a dict {word: index} (words are kept
    as an HDF5-serializable attribute) or as a plain int vocabulary size.
    """

    def __init__(self, vocabulary, embsize, onVocabulary=None, initscheme="uniform", wscale=1.0,
                 learnable=True, name=None):
        super().__init__(name)
        # Snapshot ctor args for the blueprint before we mutate 'vocabulary'.
        args = dict(locals())

        self.embsize = embsize

        self.wgrad = None
        self.learnable = learnable
        self.outgrad = None

        # Variable-length string dtype so the word list can be stored in HDF5.
        dt = h5py.special_dtype(vlen=str)

        if isinstance(vocabulary, dict):
            vocabsize = len(vocabulary)
            vocab = np.empty(shape=(vocabsize, ), dtype=dt)

            for word, idx in vocabulary.items():
                vocab[int(idx)] = word

        elif isinstance(vocabulary, int):
            # Size-only vocabulary: no word list is stored.
            vocabsize = vocabulary
            vocab = np.empty(shape=(0, ), dtype=dt)

        else:
            raise ModuleError("Unrecognized vocabulary parameter type")

        self.vocab = None
        self.setAttr("vocab", vocab)

        # Blueprint stores the size, not the dict; callbacks are not serializable.
        args["vocabulary"] = vocabsize
        self.registerBlueprint(args, exclude=["onVocabulary"])

        Wshape = (vocabsize, embsize)
        W = self.createTensorWithScheme(initscheme, Wshape, wscale, (embsize, vocabsize))
        if W is None:
            # Scheme returned nothing: leave weights uninitialized for the hook.
            W = np.empty(Wshape, dtype=np.float32)

        if onVocabulary is not None:
            # Caller-supplied initializer, e.g. loading pretrained embeddings.
            onVocabulary(W)

        self.W = None
        self.setVar("W", Variable(gpuarray.to_gpu(W)))

        self.loadVarHook = self.checkVarOnLoad
        self.loadAttrHook = self.checkAttrOnLoad

    def checkVarOnLoad(self, paramName, dataset):
        # Validate checkpoint shape before accepting the weights.
        if paramName == "W":
            if dataset.shape[1] != self.embsize:
                raise ModuleError("Expected embedding size %s, was given %s" % (self.embsize, dataset.shape[1]))

            self.setVar("W", Variable(gpuarray.to_gpu(dataset)))

        else:
            raise ModuleError("Unknown parameter name '%s' for embedder" % paramName)

    def checkAttrOnLoad(self, attrName, dataset):
        if attrName == "vocab":
            self.setAttr("vocab", dataset)

        else:
            raise ModuleError("Unknown attribute name '%s' for embedder" % attrName)

    def getVocabulary(self):
        """Rebuild the {word: index} dict from the stored word array."""
        voc = {}

        if self.hasAttr("vocab"):
            for i in range(self.vocab.shape[0]):
                voc[self.vocab[i]] = i

        return voc

    def verifyData(self, data):
        # NOTE(review): index -1 is deliberately allowed (the check is < -1) —
        # presumably a padding/ignore index handled by the embed kernel; verify.
        mn, mx = gpuarray.minimum(data).get(), gpuarray.maximum(data).get()
        if mn < -1:
            raise ModuleError("Embedder data verification failed, found index %s (< -1)" % mn)

        if mx >= self.W.shape[0]:
            raise ModuleError("Embedder data verification failed, found index %s (vocabulary size is %s)" %
                              (mx, self.W.shape[0]))

    def updateData(self, data):
        if Config.verifyData:
            self.verifyData(data)
        self.data = embed(data, self.W)

    def updateGrad(self, grad):
        # Gradient does not propagate through integer token ids.
        self.grad = None

    def accGradParams(self, grad, scale=1.0, momentum=0.0):
        self.outgrad = grad
        self.vars["W"].grad.fill(0.0)

        if self.learnable:
            embedBackwardParams(self.inData, grad, self.vars["W"].grad, scale)

    def updateParams(self, learnRate):
        # NOTE(review): this writes into self.W directly using learnRate as the
        # scale — i.e. a raw SGD step bypassing the optimizer; confirm intended.
        if self.learnable:
            embedBackwardParams(self.inData, self.outgrad, self.W, learnRate)

    def dataShapeFrom(self, shape):
        batchsize, sentlen = shape
        return batchsize, sentlen, self.embsize

    def gradShapeFrom(self, shape):
        raise ModuleError("Gradient propagation is undefined")

    def checkDataShape(self, shape):
        if len(shape) != 2:
            raise ModuleError("Data must be 2d matrix")

    def checkGradShape(self, shape):
        if len(shape) != 3:
            raise ModuleError("Grad must be 3d tensor")

        batchsize, sentlen, embsize = shape
        if embsize != self.embsize:
            raise ModuleError("Expected %d grad embedding size, %d was given" % (self.embsize, embsize))

        if batchsize != self.inData.shape[0]:
            raise ModuleError("Expected %d grad batch size, %d was given" % (self.inData.shape[0], batchsize))

    def checkDataType(self, dtype):
        if dtype != np.int32:
            raise ModuleError("Expected int32-tensor (got dtype %s)" % dtype)

    def reset(self):
        super().reset()
        self.outgrad = None

    def calcMode(self, T):
        """Switch compute dtype by re-creating every variable in dtype T."""
        if self.calctype == T:
            return

        variables = self.vars
        self.vars = {}

        for varName, var in variables.items():
            self.setVar(varName, Variable(var.data.astype(T), name=var.name, grad=var.grad.astype(T)))

        self.calctype = T


class Linear(Module):
    """Fully-connected layer; unlike the stock PuzzleLib Linear it accepts N-d
    input by flattening all leading axes before the matmul and restoring them
    afterwards."""

    def __init__(self, insize, outsize, wscale=1.0, useBias=True, initscheme=None, name=None,
                 empty=False, transpose=False):
        super().__init__(name)
        self.registerBlueprint(locals())

        self.transpose = transpose
        self.useBias = useBias

        self.W = None
        self.b = None

        if empty:
            # Weights will be supplied later via setVar (e.g. tied weights).
            return

        Wshape, bshape = ((outsize, insize), (insize, )) if transpose else ((insize, outsize), (outsize, ))
        W = self.createTensorWithScheme(initscheme, Wshape, wscale, factorShape=Wshape)

        self.setVar("W", Variable(gpuarray.empty(Wshape, dtype=self.calctype) if W is None else gpuarray.to_gpu(W)))

        if useBias:
            self.setVar("b", Variable(gpuarray.zeros(bshape, dtype=self.calctype)))

    def updateData(self, data):
        # N-d input support: flatten leading axes, matmul, then restore shape.
        reshape = len(data.shape) > 2
        if reshape:
            reshape2d = Reshape((int(np.prod(data.shape[:-1])), data.shape[-1]))
            reshape2d.calcMode(self.calctype)
            reshapeNd = Reshape(data.shape[:-1] + tuple([self.W.shape[1]]))
            reshapeNd.calcMode(self.calctype)
            data = reshape2d(data)
        self.data = Blas.mulMatrixOnMatrix(data, self.W, transpB=self.transpose)
        if self.useBias:
            # In-place broadcast add of the bias row.
            MatVec.addVecToMat(self.b, self.data, axis=1, out=self.data)
        if reshape:
            self.data = reshapeNd(self.data)

    def updateGrad(self, grad):
        self.grad = Blas.mulMatrixOnMatrix(grad, self.W, transpB=not self.transpose)

    def accGradParams(self, grad, scale=1.0, momentum=0.0):
        if not self.transpose:
            Blas.mulMatrixOnMatrix(self.inData, grad, out=self.vars["W"].grad, transpA=True, alpha=scale, beta=momentum)
        else:
            Blas.mulMatrixOnMatrix(grad, self.inData, out=self.vars["W"].grad, transpA=True, alpha=scale, beta=momentum)

        if self.useBias:
            Blas.sumOnMatrix(grad, out=self.vars["b"].grad, alpha=scale, beta=momentum)

    def dataShapeFrom(self, shape):
        return (shape[0], self.W.shape[1]) if not self.transpose else (shape[0], self.W.shape[0])

    def checkDataShape(self, shape):
        # Only the innermost axis is validated; leading axes are free-form.
        if not self.transpose:
            if shape[-1] != self.W.shape[0]:
                raise ModuleError("Expected %d data dimensions, %d were given" % (self.W.shape[0], shape[1]))
        else:
            if shape[-1] != self.W.shape[1]:
                raise ModuleError("Expected %d data dimensions, %d were given" % (self.W.shape[1], shape[1]))

    def gradShapeFrom(self, shape):
        return (shape[0], self.W.shape[0]) if not self.transpose else (shape[0], self.W.shape[1])

    def checkGradShape(self, shape):
        if len(shape) != 2:
            raise ModuleError("Grad must be 2d matrix")

        if not self.transpose:
            if shape[1] != self.W.shape[1]:
                raise ModuleError("Expected %d grad dimensions, %d were given" % (self.W.shape[1], shape[1]))
        else:
            if shape[1] != self.W.shape[0]:
                raise ModuleError("Expected %d grad dimensions, %d were given" % (self.W.shape[0], shape[1]))

    def calcMode(self, T):
        """Switch compute dtype by re-creating every variable in dtype T."""
        if self.calctype == T:
            return

        variables = self.vars
        self.vars = {}

        for varName, var in variables.items():
            self.setVar(varName, Variable(var.data.astype(T), name=var.name, grad=var.grad.astype(T)))

        self.calctype = T
# bert_punctuator/tokenizer.py — standalone BERT tokenizer: basic (whitespace /
# punctuation / accent) tokenization followed by greedy longest-match-first
# WordPiece segmentation against a fixed vocabulary.
import collections
import unicodedata
import os


def load_vocab(vocab_file):
    """Read a vocabulary file (one token per line) into an OrderedDict token -> index."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab


def whitespace_tokenize(text):
    """Strip and split *text* on runs of whitespace; empty input yields []."""
    stripped = text.strip()
    return stripped.split() if stripped else []


class BertTokenizer(object):
    """End-to-end tokenizer: basic tokenization, then WordPiece, plus id mapping."""

    def __init__(self, vocab_file, lower_case=True):
        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path {}".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            (idx, tok) for tok, idx in self.vocab.items())
        self.basic_tokenizer = BasicTokenizer(lower_case=lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        """Split *text* into WordPiece sub-tokens."""
        return [piece
                for word in self.basic_tokenizer.tokenize(text)
                for piece in self.wordpiece_tokenizer.tokenize(word)]

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to vocabulary ids (KeyError on out-of-vocabulary tokens)."""
        return [self.vocab[token] for token in tokens]

    def convert_ids_to_tokens(self, ids):
        """Map vocabulary ids back to token strings."""
        return [self.ids_to_tokens[i] for i in ids]


class BasicTokenizer(object):
    """Whitespace/punctuation splitter with optional lower-casing + accent stripping."""

    def __init__(self, lower_case=True):
        self.lower_case = lower_case

    def tokenize(self, text):
        cleaned = self.clean_text(text)

        pieces = []
        for token in whitespace_tokenize(cleaned):
            if self.lower_case:
                # Accents are stripped only in the lower-casing mode.
                token = self.run_strip_accents(token.lower())
            pieces.extend(self.run_split_on_punc(token))

        # Re-split to normalize any whitespace introduced by the steps above.
        return whitespace_tokenize(" ".join(pieces))

    def run_strip_accents(self, text):
        """Remove combining marks (Unicode category Mn) after NFD decomposition."""
        decomposed = unicodedata.normalize("NFD", text)
        return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

    def run_split_on_punc(self, text):
        """Split *text* so every punctuation character becomes its own token."""
        pieces = []
        word_open = False
        for ch in text:
            if is_punctuation(ch):
                pieces.append([ch])
                word_open = False
            else:
                if not word_open:
                    pieces.append([])
                    word_open = True
                pieces[-1].append(ch)
        return ["".join(piece) for piece in pieces]

    def clean_text(self, text):
        """Drop NUL/replacement/control characters and canonicalize whitespace."""
        kept = []
        for ch in text:
            code = ord(ch)
            if code == 0 or code == 0xfffd or is_control(ch):
                continue
            kept.append(" " if is_whitespace(ch) else ch)
        return "".join(kept)


class WordpieceTokenizer(object):
    """Greedy longest-match-first WordPiece segmentation against *vocab*."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                # Overlong words are mapped to [UNK] without segmentation.
                output_tokens.append(self.unk_token)
                continue

            pieces = self._greedy_pieces(chars)
            if pieces is None:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(pieces)
        return output_tokens

    def _greedy_pieces(self, chars):
        """Segment *chars* greedily; None if some position has no vocab match."""
        pieces, start = [], 0
        while start < len(chars):
            end = len(chars)
            match = None
            while start < end:
                candidate = "".join(chars[start:end])
                if start > 0:
                    # Continuation pieces carry the '##' prefix.
                    candidate = "##" + candidate
                if candidate in self.vocab:
                    match = candidate
                    break
                end -= 1
            if match is None:
                return None
            pieces.append(match)
            start = end
        return pieces


def is_whitespace(char):
    """True for space/tab/newline/CR and Unicode space separators (Zs)."""
    if char in (" ", "\t", "\n", "\r"):
        return True
    return unicodedata.category(char) == "Zs"


def is_control(char):
    """True for control characters; tab/newline/CR count as whitespace instead."""
    if char in ("\t", "\n", "\r"):
        return False
    return unicodedata.category(char).startswith("C")


def is_punctuation(char):
    """True for ASCII punctuation ranges and any Unicode 'P*' category."""
    cp = ord(char)
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(char).startswith("P")
decoder 21 | greedy = 0 22 | 23 | # Beam threshold for CTC decoder 24 | beam_threshold = 10 25 | 26 | # Audio sample rate, should match the acoustic model sample rate 27 | sample_rate = 16000 28 | 29 | # Window size in seconds for acoustic model samples 30 | window_size = 0.02 31 | 32 | # Window stride in seconds for acoustic model samples 33 | window_stride = 0.01 34 | 35 | 36 | [Train] 37 | # Path to train manifest csv 38 | train_manifest = data/train.csv 39 | 40 | # Path to validation manifest csv 41 | val_manifest = data/val.csv 42 | 43 | # Number of training epochs 44 | epochs = 100 45 | 46 | # Batch size for training 47 | batch_size = 8 48 | 49 | # Initial learning rate 50 | learning_rate = 1e-5 51 | 52 | # Run training with half-precision for memory use optimization 53 | fp16 = 1 54 | 55 | # Name prefix for checkpoints 56 | checkpoint_name = w2l 57 | 58 | # Save checkpoints per batch, 0 means never save 59 | checkpoint_per_batch = 1000 60 | 61 | # Location to save checkpoints 62 | save_folder = Checkpoints/ 63 | 64 | # Continue from checkpoint model 65 | continue_from = data/w2l-16khz.hdf -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | from utils import stft, magphase 4 | from pydub import AudioSegment 5 | 6 | 7 | def pcen2(e, sr=16000, hop_length=512, t=0.395, eps=0.000001, alpha=0.98, delta=2.0, r=0.5): 8 | s = 1 - np.exp(-float(hop_length) / (t * sr)) 9 | m = scipy.signal.lfilter([s], [1, s - 1], e) 10 | smooth = (eps + m) ** (-alpha) 11 | 12 | return (e * smooth + delta) ** r - delta ** r 13 | 14 | 15 | def load_audio(path, sample_rate): 16 | sound = AudioSegment.from_wav(path) 17 | sound = sound.set_frame_rate(sample_rate) 18 | sound = sound.set_channels(1) 19 | sound = sound.set_sample_width(2) 20 | 21 | return np.array(sound.get_array_of_samples()).astype(float) 22 | 23 | 24 | 
def preprocess(audio_path, sample_rate=16000, window_size=0.02, window_stride=0.01, window='hamming'):
    """Load audio and compute a mean/std-normalized PCEN spectrogram.

    :param audio_path: path to a WAV file
    :param sample_rate: target sample rate for loading
    :param window_size: STFT window length in seconds
    :param window_stride: STFT hop in seconds
    :param window: window function name passed to stft
    :return: normalized PCEN spectrogram (freq x time)
    """
    audio = load_audio(audio_path, sample_rate)
    nfft = int(sample_rate * window_size)
    win_length = nfft
    hop_length = int(sample_rate * window_stride)

    d = stft(audio, n_fft=nfft, hop_length=hop_length,
             win_length=win_length, window=window)

    spect, phase = magphase(d)
    pcen_result = pcen2(e=spect, sr=sample_rate, hop_length=hop_length)

    # Zero-mean / unit-variance normalization over the whole spectrogram.
    mean_pcen = pcen_result.mean()
    std_pcen = pcen_result.std()
    pcen_result = np.add(pcen_result, -mean_pcen)
    pcen_result = pcen_result / std_pcen

    return pcen_result


def get_batch(batch):
    """Collate a list of (spectrogram, transcript, path, text) samples.

    Spectrograms are right-padded ('wrap') to the longest sample in the batch;
    transcripts are concatenated flat with per-sample lengths in target_sizes.

    :return: (inputs, input_percentages, targets, target_sizes, [path, text] pairs)
    """
    longest_sample = max(batch, key=lambda p: p[0].shape[1])[0]
    freq_size = longest_sample.shape[0]
    mini_batch_size = len(batch)
    max_seq_length = longest_sample.shape[1]
    inputs = np.zeros((mini_batch_size, freq_size, max_seq_length))
    target_sizes = np.zeros(shape=(mini_batch_size,), dtype=int)
    input_percentages = np.zeros(shape=(mini_batch_size,), dtype=float)
    targets = []
    input_file_path_and_transcription = []

    for x in range(mini_batch_size):
        tensor, target, tensor_path, original_transcription = batch[x]
        seq_length = tensor.shape[1]
        # 'wrap' padding repeats the sample from its start to fill the batch width.
        tensor_new = np.pad(tensor, ((0, 0), (0, abs(seq_length - max_seq_length))), 'wrap')
        inputs[x] = tensor_new
        # Fraction of the padded width that is real data (used to trim outputs).
        input_percentages[x] = seq_length / float(max_seq_length)
        target_sizes[x] = len(target)
        targets.extend(target)
        input_file_path_and_transcription.append([tensor_path, original_transcription])

    targets = np.array(targets)

    return inputs, input_percentages, targets, target_sizes, input_file_path_and_transcription


class DataLoader(object):
    """Iterates over dataset batches in the order produced by *batch_sampler*."""

    def __init__(self, dataset, batch_sampler):
        self.dataset = dataset
        self.batch_sampler = batch_sampler
        self.sample_iter = iter(self.batch_sampler)

    def __next__(self):
        try:
            # BucketingSampler.__next__ is written as a generator function, so
            # next() returns a one-shot generator yielding a single index list;
            # unwrap it here.
            indices = list(next(self.sample_iter))[0]
        except (StopIteration, RuntimeError, IndexError):
            # FIX: sampler exhaustion previously fell into a blanket
            # `except Exception` and printed a spurious "Encountered exception"
            # message at the end of every epoch. Exhaustion surfaces as
            # StopIteration (Python < 3.7) or a PEP 479 RuntimeError (3.7+).
            raise StopIteration()

        try:
            return get_batch([self.dataset[i] for i in indices])
        except Exception as e:
            # Best-effort behaviour preserved from the original: a broken
            # sample ends the epoch instead of crashing training.
            print("Encountered exception {}".format(e))
            raise StopIteration()

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.batch_sampler)

    def reset(self):
        self.batch_sampler.reset()


class SpectrogramDataset(object):
    """CSV-manifest-backed dataset yielding (spectrogram, transcript ids, path, text)."""

    def __init__(self, labels, sample_rate, window_size, window_stride, manifest_file_path):
        self.manifest_file_path = manifest_file_path
        with open(self.manifest_file_path) as f:
            lines = f.readlines()
        # NOTE(review): naive comma split — transcripts containing commas would
        # be truncated; manifests are assumed to be "path,transcript" lines.
        self.ids = [x.strip().split(',') for x in lines]
        self.size = len(lines)
        self.labels_map = dict([(labels[i], i) for i in range(len(labels))])
        self.sample_rate = sample_rate
        self.window_size = window_size
        self.window_stride = window_stride

    def __getitem__(self, index):
        sample = self.ids[index]
        audio_path, transcript_loaded = sample[0], sample[1]
        spectrogram = preprocess(audio_path, self.sample_rate, self.window_size, self.window_stride)
        # Characters missing from the label set are silently dropped.
        transcript = list(filter(None, [self.labels_map.get(x) for x in list(transcript_loaded)]))
        return spectrogram, transcript, audio_path, transcript_loaded

    def __len__(self):
        return self.size


class BucketingSampler(object):
    """Yields fixed-size bins of dataset indices, optionally shuffled per epoch."""

    def __init__(self, data_source, batch_size=1, shuffle=False):
        self.data_source = data_source
        self.batch_size = batch_size
        self.ids = list(range(0, len(data_source)))
        self.batch_id = 0
        self.bins = []
        self.shuffle = shuffle
        self.reset()

    def __iter__(self):
        return self

    # NOTE(review): written as a generator function, so each next() call on the
    # sampler returns a generator rather than the bin itself; DataLoader.__next__
    # compensates for this. Kept as-is for compatibility with external callers.
    def __next__(self):
        if self.batch_id < len(self):
            ids = self.bins[self.batch_id]
            self.batch_id += 1
            yield ids
        else:
            raise StopIteration()

    def __len__(self):
        return len(self.bins)

    def get_bins(self):
        if self.shuffle:
            np.random.shuffle(self.ids)
        self.bins = [self.ids[i:i + self.batch_size] for i in range(0, len(self.ids), self.batch_size)]

    def reset(self):
        self.get_bins()
        self.batch_id = 0
# decoder.py — greedy and lexicon (trie + KenLM) CTC decoders.
np.seterr(divide='ignore')  # log(0) from zero-probability entries is expected


class DecodeResult:
    """Decoded utterance: overall confidence score and per-word dicts
    (word/start/end/confidence); .text joins the words with spaces."""

    def __init__(self, score, words):
        self.score, self.words = score, words
        self.text = " ".join(word["word"] for word in words)


class GreedyDecoder:
    """Best-path CTC decoder over per-frame label scores.

    :param labels: indexable label set; must contain the word delimiter "|"
    :param blank_idx: index of the CTC blank label
    """

    def __init__(self, labels, blank_idx=0):
        self.labels, self.blank_idx = labels, blank_idx
        self.delim_idx = self.labels.index("|")

    def decode(self, output, start_timestamp=0, frame_time=0.02):
        """Greedy-decode a (time x labels) score matrix into timed words.

        :param output: per-frame label scores (log-like; higher is better)
        :param start_timestamp: absolute time of the first frame, seconds
        :param frame_time: duration of one frame, seconds
        :return: DecodeResult (score 0.0 and no words if nothing was decoded)
        """
        best_path = np.argmax(output.astype(np.float32, copy=False), axis=1)
        score = None

        words, new_word, i = [], True, 0
        # FIX: start_idx initialized (was unbound when a delimiter preceded any word).
        current_word, current_timestamp, start_idx, end_idx = None, start_timestamp, 0, 0
        words_len = 0

        # groupby collapses CTC repeats; i tracks the frame index of each group start.
        for k, g in itertools.groupby(best_path):
            if k != self.blank_idx:
                if new_word and k != self.delim_idx:
                    new_word, start_idx = False, i
                    current_word, current_timestamp = self.labels[k], frame_time * i + start_timestamp

                elif k == self.delim_idx:
                    new_word, end_idx = True, i
                    # FIX: only emit a word if one is actually open — a leading
                    # delimiter used to crash, and a "word | blank |" pattern
                    # used to append the same word twice from the stale buffer.
                    if current_word is not None:
                        end_timestamp = frame_time * i + start_timestamp
                        word_score = output[range(start_idx, end_idx), best_path[range(start_idx, end_idx)]] - np.max(output)
                        if score is not None:
                            score = np.hstack([score, word_score])
                        else:
                            score = word_score
                        word_confidence = np.round(np.exp(word_score.mean() / max(1, end_idx - start_idx)) * 100.0, 2)
                        words_len += end_idx - start_idx
                        words.append({
                            "word": current_word,
                            "start": np.round(current_timestamp, 2),
                            "end": np.round(end_timestamp, 2),
                            "confidence": word_confidence
                        })
                        current_word = None

                else:
                    current_word += self.labels[k]

            i += sum(1 for _ in g)

        # FIX: with no completed words the original called score.mean() on None.
        if score is None:
            return DecodeResult(0.0, words)

        score = np.round(np.exp(score.mean() / max(1, words_len)) * 100.0, 2)

        return DecodeResult(score, words)


class TrieDecoder:
    """Lexicon-constrained beam-search CTC decoder backed by a KenLM model.

    Imports from the compiled trie_decoder extension are deferred to call time
    so this module can be imported without the native library present.
    """

    def __init__(self, lexicon, tokens, lm_path, beam_threshold=30):
        from trie_decoder.common import Dictionary, create_word_dict, load_words
        from trie_decoder.decoder import CriterionType, DecoderOptions, KenLM, LexiconDecoder
        lexicon = load_words(lexicon)
        self.wordDict = create_word_dict(lexicon)
        self.tokenDict = Dictionary(tokens)
        self.lm = KenLM(lm_path, self.wordDict)

        trie, self.sil_idx, self.blank_idx, self.unk_idx = self.get_trie(lexicon)
        transitions = np.zeros((self.tokenDict.index_size(), self.tokenDict.index_size())).flatten()

        # (beam size, token beam size, beam threshold, lm weight, word score,
        #  unk score, sil score, log_add, criterion)
        opts = DecoderOptions(
            2000, 100, beam_threshold, 1.4, 1.0, -math.inf, -1, 0, False, CriterionType.CTC
        )

        self.trieDecoder = LexiconDecoder(
            opts, trie, self.lm, self.sil_idx, self.blank_idx, self.unk_idx, transitions, False
        )
        self.delim_idx = self.tokenDict.get_index("|")

    def get_trie(self, lexicon):
        """Build the lexicon trie with LM-smeared scores; returns
        (trie, sil_idx, blank_idx, unk_idx)."""
        from trie_decoder.common import tkn_to_idx
        from trie_decoder.decoder import SmearingMode, Trie
        unk_idx = self.wordDict.get_index("")
        # '#' serves as both silence and blank token in this token set.
        sil_idx = blank_idx = self.tokenDict.get_index("#")

        trie = Trie(self.tokenDict.index_size(), sil_idx)
        start_state = self.lm.start(False)

        for word, spellings in lexicon.items():
            usr_idx = self.wordDict.get_index(word)
            _, score = self.lm.score(start_state, usr_idx)
            score = np.round(score, 2)

            for spelling in spellings:
                spelling_indices = tkn_to_idx(spelling, self.tokenDict, 0)
                trie.insert(spelling_indices, usr_idx, score)

        trie.smear(SmearingMode.MAX)

        return trie, sil_idx, blank_idx, unk_idx

    def decode(self, output, start_timestamp=0, frame_time=0.02):
        """Beam-search decode a (time x tokens) score matrix into timed words."""
        output = np.log(softmax(output[:, :].astype(np.float32, copy=False), axis=-1))

        t, n = output.shape
        result = self.trieDecoder.decode(output.ctypes.data, t, n)[0]
        tokens = result.tokens

        words, new_word = [], True
        current_word, current_timestamp, start_idx, end_idx = None, start_timestamp, 0, 0
        lm_state = self.lm.start(False)
        words_len = 0

        for i, k in enumerate(tokens):
            if k != self.blank_idx:
                if i > 0 and k == tokens[i - 1]:
                    # CTC repeat of the previous token — skip.
                    pass

                elif k == self.sil_idx:
                    new_word = True

                else:
                    if new_word and k != self.delim_idx:
                        new_word = False
                        current_word, current_timestamp = self.tokenDict.get_entry(k), frame_time * i + start_timestamp
                        start_idx = i

                    elif k == self.delim_idx:
                        new_word, end_idx = True, i
                        # Per-word confidence comes from the LM score chain.
                        lm_state, word_lm_score = self.lm.score(lm_state, self.wordDict.get_index(current_word))
                        end_timestamp = frame_time * i + start_timestamp
                        words_len += end_idx - start_idx
                        words.append({
                            "word": current_word,
                            "start": np.round(current_timestamp, 2),
                            "end": np.round(end_timestamp, 2),
                            "confidence": np.round(np.exp(word_lm_score / max(1, end_idx - start_idx)) * 100, 2)
                        })

                    else:
                        current_word += self.tokenDict.get_entry(k)

        score = np.round(np.exp(result.score / max(1, words_len)), 2)

        return DecodeResult(score, words)
class FileHandler:
    """Web-upload helper: saves the posted blob, converts it to 16 kHz mono WAV
    with ffmpeg, and runs recognition + punctuation + number conversion.

    (Module-level speech_recognizer / punctuator / text2numbers instances are
    defined alongside this class and are used by get_models_result.)
    """

    @staticmethod
    def get_recognized_text(blob):
        """Persist *blob*, convert it to WAV and recognize it.

        :return: (0, wav_filename, results) on success, (1, None, error_text) on failure
        """
        try:
            filename = str(uuid.uuid4())
            os.makedirs('./records', exist_ok=True)
            new_record_path = os.path.join('./records', filename + '.webm')
            blob.save(new_record_path)
            new_filename = filename + '.wav'
            converted_record_path = FileHandler.convert_to_wav(new_record_path, new_filename)
            response_models_result = FileHandler.get_models_result(converted_record_path)
            return 0, new_filename, response_models_result
        except Exception as e:
            logging.exception(e)
            return 1, None, str(e)

    @staticmethod
    def convert_to_wav(webm_full_filepath, new_filename):
        """Convert the uploaded file to 16 kHz mono s16 WAV; removes the source.

        :return: path of the converted WAV file
        """
        converted_record_path = os.path.join('./records', new_filename)
        # Argument list instead of a shell string: no shell involved, so file
        # names with spaces or metacharacters cannot break or inject the command.
        subprocess.call(
            ['ffmpeg', '-i', webm_full_filepath, '-ar', '16000', '-b:a', '256k',
             '-ac', '1', '-sample_fmt', 's16', converted_record_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        os.remove(webm_full_filepath)
        return converted_record_path

    @staticmethod
    def check_format(files):
        """Return True if the upload looks like audio (by mimetype or extension).

        BUG FIX: the original returned `mimetype.startswith(...) or [list
        comprehension]` — a non-empty list is always truthy, so every file
        passed the check; the trailing `return True` was unreachable.
        """
        return files.mimetype.startswith('audio/') or any(
            files.filename.endswith(audio_format) for audio_format in [
                'mp3', 'ogg', 'acc', 'flac', 'au', 'm4a', 'mp4', 'mov', 'avi', 'wmv', '3gp', 'flv', 'mkv'
            ]
        )

    @staticmethod
    def get_models_result(converted_record_path, delimiter=' '):
        """Run ASR, punctuation restoration and text2numbers on a WAV file.

        :param delimiter: kept for interface compatibility; currently unused
        :return: list with one dict: text, processing time, confidence, timed words
        """
        results = []
        start = time.time()
        decoder_result = speech_recognizer.recognize(converted_record_path)
        text = punctuator.predict(decoder_result.text)
        text = text2numbers.convert(text)
        end = time.time()
        results.append(
            {
                'text': text,
                'time': round(end - start, 3),
                'confidence': decoder_result.score,
                'words': decoder_result.words
            }
        )
        return results
49 | # "половиной": Numeral(0.5, -1, False), 50 | } 51 | 52 | self.tokens = { 53 | "ноль": Numeral(0, 1, False), 54 | "нулю": Numeral(0, 1, False), 55 | "нолю": Numeral(0, 1, False), 56 | "ноля": Numeral(0, 1, False), 57 | 58 | "полтора": Numeral(1.5, 1, False), 59 | "полторы": Numeral(1.5, 1, False), 60 | 61 | "один": Numeral(1, 1, False), 62 | "одна": Numeral(1, 1, False), 63 | "одной": Numeral(1, 1, False), 64 | "первое": Numeral(1, 1, False), 65 | "первый": Numeral(1, 1, False), 66 | "первая": Numeral(1, 1, False), 67 | "первого": Numeral(1, 1, False), 68 | "первой": Numeral(1, 1, False), 69 | "первом": Numeral(1, 1, False), 70 | 71 | "два": Numeral(2, 1, False), 72 | "две": Numeral(2, 1, False), 73 | "двум": Numeral(2, 1, False), 74 | "двух": Numeral(2, 1, False), 75 | "второе": Numeral(2, 1, False), 76 | "второй": Numeral(2, 1, False), 77 | "вторая": Numeral(2, 1, False), 78 | "второго": Numeral(2, 1, False), 79 | "втором": Numeral(2, 1, False), 80 | 81 | "три": Numeral(3, 1, False), 82 | "трем": Numeral(3, 1, False), 83 | "трех": Numeral(3, 1, False), 84 | "трёх": Numeral(3, 1, False), 85 | "трём": Numeral(3, 1, False), 86 | "третье": Numeral(3, 1, False), 87 | "третий": Numeral(3, 1, False), 88 | "третья": Numeral(3, 1, False), 89 | "третьего": Numeral(3, 1, False), 90 | "третьей": Numeral(3, 1, False), 91 | "третьем": Numeral(3, 1, False), 92 | 93 | "четыре": Numeral(4, 1, False), 94 | "четырем": Numeral(4, 1, False), 95 | "четырех": Numeral(4, 1, False), 96 | "четырём": Numeral(4, 1, False), 97 | "четырёх": Numeral(4, 1, False), 98 | "четвертое": Numeral(4, 1, False), 99 | "четвертый": Numeral(4, 1, False), 100 | "четвертая": Numeral(4, 1, False), 101 | "четвертого": Numeral(4, 1, False), 102 | "четвертой": Numeral(4, 1, False), 103 | "четвертом": Numeral(4, 1, False), 104 | 105 | "пять": Numeral(5, 1, False), 106 | "пяти": Numeral(5, 1, False), 107 | "пятое": Numeral(5, 1, False), 108 | "пятый": Numeral(5, 1, False), 109 | "пятая": Numeral(5, 1, False), 
110 | "пятого": Numeral(5, 1, False), 111 | "пятой": Numeral(5, 1, False), 112 | "пятом": Numeral(5, 1, False), 113 | 114 | "шесть": Numeral(6, 1, False), 115 | "шести": Numeral(6, 1, False), 116 | "шестое": Numeral(6, 1, False), 117 | "шестой": Numeral(6, 1, False), 118 | "шестая": Numeral(6, 1, False), 119 | "шестого": Numeral(6, 1, False), 120 | "шестом": Numeral(6, 1, False), 121 | 122 | "семь": Numeral(7, 1, False), 123 | "семи": Numeral(7, 1, False), 124 | "седьмое": Numeral(7, 1, False), 125 | "седьмой": Numeral(7, 1, False), 126 | "седьмая": Numeral(7, 1, False), 127 | "седьмого": Numeral(7, 1, False), 128 | "седьмом": Numeral(7, 1, False), 129 | 130 | "восемь": Numeral(8, 1, False), 131 | "восьми": Numeral(8, 1, False), 132 | "восьмое": Numeral(8, 1, False), 133 | "восьмой": Numeral(8, 1, False), 134 | "восьмая": Numeral(8, 1, False), 135 | "восьмого": Numeral(8, 1, False), 136 | "восьмом": Numeral(8, 1, False), 137 | 138 | "девять": Numeral(9, 1, False), 139 | "девяти": Numeral(9, 1, False), 140 | "девятое": Numeral(9, 1, False), 141 | "девятый": Numeral(9, 1, False), 142 | "девятая": Numeral(9, 1, False), 143 | "девятого": Numeral(9, 1, False), 144 | "девятой": Numeral(9, 1, False), 145 | "девятом": Numeral(9, 1, False), 146 | 147 | "десять": Numeral(10, 1, False), 148 | "десятью": Numeral(10, 1, False), 149 | "десяти": Numeral(10, 1, False), 150 | "десятое": Numeral(10, 1, False), 151 | "десятый": Numeral(10, 1, False), 152 | "десятая": Numeral(10, 1, False), 153 | "десятой": Numeral(10, 1, False), 154 | "десятого": Numeral(10, 1, False), 155 | "десятом": Numeral(10, 1, False), 156 | 157 | "одиннадцать": Numeral(11, 1, False, True), 158 | "одиннадцатью": Numeral(11, 1, False, True), 159 | "одиннадцати": Numeral(11, 1, False, True), 160 | "одиннадцатое": Numeral(11, 1, False, True), 161 | "одиннадцатый": Numeral(11, 1, False, True), 162 | "одиннадцатая": Numeral(11, 1, False, True), 163 | "одиннадцатого": Numeral(11, 1, False, True), 164 | 
"одиннадцатой": Numeral(11, 1, False, True), 165 | "одиннадцатом": Numeral(11, 1, False, True), 166 | 167 | "двенадцать": Numeral(12, 1, False, True), 168 | "двенадцатью": Numeral(12, 1, False, True), 169 | "двенадцати": Numeral(12, 1, False, True), 170 | "двенадцатое": Numeral(12, 1, False, True), 171 | "двенадцатый": Numeral(12, 1, False, True), 172 | "двенадцатая": Numeral(12, 1, False, True), 173 | "двенадцатого": Numeral(12, 1, False, True), 174 | "двенадцатой": Numeral(12, 1, False, True), 175 | "двенадцатом": Numeral(12, 1, False, True), 176 | 177 | "тринадцать": Numeral(13, 1, False, True), 178 | "тринадцатью": Numeral(13, 1, False, True), 179 | "тринадцати": Numeral(13, 1, False, True), 180 | "тринадцатое": Numeral(13, 1, False, True), 181 | "тринадцатый": Numeral(13, 1, False, True), 182 | "тринадцатая": Numeral(13, 1, False, True), 183 | "тринадцатого": Numeral(13, 1, False, True), 184 | "тринадцатой": Numeral(13, 1, False, True), 185 | "тринадцатом": Numeral(13, 1, False, True), 186 | 187 | "четырнадцать": Numeral(14, 1, False, True), 188 | "четырнадцатью": Numeral(14, 1, False, True), 189 | "четырнадцати": Numeral(14, 1, False, True), 190 | "четырнадцатое": Numeral(14, 1, False, True), 191 | "четырнадцатый": Numeral(14, 1, False, True), 192 | "четырнадцатая": Numeral(14, 1, False, True), 193 | "четырнадцатого": Numeral(14, 1, False, True), 194 | "четырнадцатой": Numeral(14, 1, False, True), 195 | "четырнадцатом": Numeral(14, 1, False, True), 196 | 197 | "пятнадцать": Numeral(15, 1, False, True), 198 | "пятнадцатью": Numeral(15, 1, False, True), 199 | "пятнадцати": Numeral(15, 1, False, True), 200 | "пятнадцатое": Numeral(15, 1, False, True), 201 | "пятнадцатый": Numeral(15, 1, False, True), 202 | "пятнадцатая": Numeral(15, 1, False, True), 203 | "пятнадцатого": Numeral(15, 1, False, True), 204 | "пятнадцатой": Numeral(15, 1, False, True), 205 | "пятнадцатом": Numeral(15, 1, False, True), 206 | 207 | "шестнадцать": Numeral(16, 1, False, True), 208 | 
"шестнадцатью": Numeral(16, 1, False, True), 209 | "шестнадцати": Numeral(16, 1, False, True), 210 | "шестнадцатое": Numeral(16, 1, False, True), 211 | "шестнадцатый": Numeral(16, 1, False, True), 212 | "шестнадцатая": Numeral(16, 1, False, True), 213 | "шестнадцатого": Numeral(16, 1, False, True), 214 | "шестнадцатой": Numeral(16, 1, False, True), 215 | "шестнадцатом": Numeral(16, 1, False, True), 216 | 217 | "семнадцать": Numeral(17, 1, False, True), 218 | "семнадцатью": Numeral(17, 1, False, True), 219 | "семнадцати": Numeral(17, 1, False, True), 220 | "семнадцатое": Numeral(17, 1, False, True), 221 | "семнадцатый": Numeral(17, 1, False, True), 222 | "семнадцатая": Numeral(17, 1, False, True), 223 | "семнадцатого": Numeral(17, 1, False, True), 224 | "семнадцатой": Numeral(17, 1, False, True), 225 | "семнадцатом": Numeral(17, 1, False, True), 226 | 227 | "восемнадцать": Numeral(18, 1, False, True), 228 | "восемнадцатью": Numeral(18, 1, False, True), 229 | "восемнадцати": Numeral(18, 1, False, True), 230 | "восемнадцатое": Numeral(18, 1, False, True), 231 | "восемнадцатый": Numeral(18, 1, False, True), 232 | "восемнадцатая": Numeral(18, 1, False, True), 233 | "восемнадцатого": Numeral(18, 1, False, True), 234 | "восемнадцатой": Numeral(18, 1, False, True), 235 | "восемнадцатом": Numeral(18, 1, False, True), 236 | 237 | "девятнадцать": Numeral(19, 1, False, True), 238 | "девятнадцатью": Numeral(19, 1, False, True), 239 | "девятнадцати": Numeral(19, 1, False, True), 240 | "девятнадцатое": Numeral(19, 1, False, True), 241 | "девятнадцатый": Numeral(19, 1, False, True), 242 | "девятнадцатая": Numeral(19, 1, False, True), 243 | "девятнадцатого": Numeral(19, 1, False, True), 244 | "девятнадцатой": Numeral(19, 1, False, True), 245 | "девятнадцатом": Numeral(19, 1, False, True), 246 | 247 | "двадцать": Numeral(20, 2, False), 248 | "двадцатью": Numeral(20, 2, False), 249 | "двадцати": Numeral(20, 2, False), 250 | "двадцатое": Numeral(20, 2, False), 251 | "двадцатый": 
Numeral(20, 2, False), 252 | "двадцатая": Numeral(20, 2, False), 253 | "двадцатого": Numeral(20, 2, False), 254 | "двадцатой": Numeral(20, 2, False), 255 | "двадцатом": Numeral(20, 2, False), 256 | 257 | "тридцать": Numeral(30, 2, False), 258 | "тридцатью": Numeral(30, 2, False), 259 | "тридцати": Numeral(30, 2, False), 260 | "тридцатое": Numeral(30, 2, False), 261 | "тридцатый": Numeral(30, 2, False), 262 | "тридцатая": Numeral(30, 2, False), 263 | "тридцатого": Numeral(30, 2, False), 264 | "тридцатой": Numeral(30, 2, False), 265 | "тридцатом": Numeral(30, 2, False), 266 | 267 | "сорок": Numeral(40, 2, False), 268 | "сорока": Numeral(40, 2, False), 269 | "сороковое": Numeral(40, 2, False), 270 | "сороковой": Numeral(40, 2, False), 271 | "сороковая": Numeral(40, 2, False), 272 | "сорокового": Numeral(40, 2, False), 273 | "сороковом": Numeral(40, 2, False), 274 | 275 | "пятьдесят": Numeral(50, 2, False), 276 | "пятидесяти": Numeral(50, 2, False), 277 | "пятьюдесятью": Numeral(50, 2, False), 278 | "пятидесятое": Numeral(50, 2, False), 279 | "пятидесятый": Numeral(50, 2, False), 280 | "пятидесятая": Numeral(50, 2, False), 281 | "пятидесятого": Numeral(50, 2, False), 282 | "пятидесятой": Numeral(50, 2, False), 283 | "пятидесятом": Numeral(50, 2, False), 284 | 285 | "шестьдесят": Numeral(60, 2, False), 286 | "шестидесяти": Numeral(60, 2, False), 287 | "шестьюдесятью": Numeral(60, 2, False), 288 | "шестидесятое": Numeral(60, 2, False), 289 | "шестидесятый": Numeral(60, 2, False), 290 | "шестидесятая": Numeral(60, 2, False), 291 | "шестидесятого": Numeral(60, 2, False), 292 | "шестидесятой": Numeral(60, 2, False), 293 | "шестидесятом": Numeral(60, 2, False), 294 | 295 | "семьдесят": Numeral(70, 2, False), 296 | "семидесяти": Numeral(70, 2, False), 297 | "семьюдесятью": Numeral(70, 2, False), 298 | "семидесятое": Numeral(70, 2, False), 299 | "семидесятый": Numeral(70, 2, False), 300 | "семидесятая": Numeral(70, 2, False), 301 | "семидесятого": Numeral(70, 2, False), 302 | 
"семидесятой": Numeral(70, 2, False), 303 | "семидесятом": Numeral(70, 2, False), 304 | 305 | "восемьдесят": Numeral(80, 2, False), 306 | "восьмидесяти": Numeral(80, 2, False), 307 | "восемьюдесятью": Numeral(80, 2, False), 308 | "восьмидесятое": Numeral(80, 2, False), 309 | "восьмидесятый": Numeral(80, 2, False), 310 | "восьмидесятая": Numeral(80, 2, False), 311 | "восьмидесятого": Numeral(80, 2, False), 312 | "восьмидесятой": Numeral(80, 2, False), 313 | "восьмидесятом": Numeral(80, 2, False), 314 | 315 | "девяносто": Numeral(90, 2, False), 316 | "девяноста": Numeral(90, 2, False), 317 | "девяностое": Numeral(90, 2, False), 318 | "девяностый": Numeral(90, 2, False), 319 | "девяностая": Numeral(90, 2, False), 320 | "девяностого": Numeral(90, 2, False), 321 | "девяностой": Numeral(90, 2, False), 322 | "девяностом": Numeral(90, 2, False), 323 | 324 | "сто": Numeral(100, 3, False), 325 | "ста": Numeral(100, 3, False), 326 | "сотое": Numeral(100, 3, False), 327 | "сотый": Numeral(100, 3, False), 328 | # "сотая": Numeral(100, 3, False), 329 | "сотого": Numeral(100, 3, False), 330 | "сотой": Numeral(100, 3, False), 331 | "сотом": Numeral(100, 3, False), 332 | 333 | "двести": Numeral(200, 3, False), 334 | "двухсот": Numeral(200, 3, False), 335 | "двумстам": Numeral(200, 3, False), 336 | "двухстах": Numeral(200, 3, False), 337 | "двустам": Numeral(200, 3, False), 338 | "двухсотый": Numeral(200, 3, False), 339 | "двухсотая": Numeral(200, 3, False), 340 | "двухсотого": Numeral(200, 3, False), 341 | "двухсотой": Numeral(200, 3, False), 342 | "двухсотом": Numeral(200, 3, False), 343 | 344 | "триста": Numeral(300, 3, False), 345 | "трехсот": Numeral(300, 3, False), 346 | "тремстам": Numeral(300, 3, False), 347 | "трехстах": Numeral(300, 3, False), 348 | "трёхстах": Numeral(300, 3, False), 349 | "трехсотое": Numeral(300, 3, False), 350 | "трехсотый": Numeral(300, 3, False), 351 | "трехсотая": Numeral(300, 3, False), 352 | "трехсотого": Numeral(300, 3, False), 353 | "трехсотой": 
Numeral(300, 3, False), 354 | "трехсотом": Numeral(300, 3, False), 355 | "трёхсот": Numeral(300, 3, False), 356 | "трёмстам": Numeral(300, 3, False), 357 | "трёхсотое": Numeral(300, 3, False), 358 | "трёхсотый": Numeral(300, 3, False), 359 | "трёхсотая": Numeral(300, 3, False), 360 | "трёхсотого": Numeral(300, 3, False), 361 | "трёхсотой": Numeral(300, 3, False), 362 | "трёхсотом": Numeral(300, 3, False), 363 | 364 | "четыреста": Numeral(400, 3, False), 365 | "четырехсот": Numeral(400, 3, False), 366 | "четыремстам": Numeral(400, 3, False), 367 | "четырехстах": Numeral(400, 3, False), 368 | "четырёхстах": Numeral(400, 3, False), 369 | "четырехсотое": Numeral(400, 3, False), 370 | "четырехсотый": Numeral(400, 3, False), 371 | "четырехсотая": Numeral(400, 3, False), 372 | "четырехсотого": Numeral(400, 3, False), 373 | "четырехсотой": Numeral(400, 3, False), 374 | "четырехсотом": Numeral(400, 3, False), 375 | "четырёхсот": Numeral(400, 3, False), 376 | "четырёмстам": Numeral(400, 3, False), 377 | "четырёхсотое": Numeral(400, 3, False), 378 | "четырёхсотый": Numeral(400, 3, False), 379 | "четырёхсотая": Numeral(400, 3, False), 380 | "четырёхсотого": Numeral(400, 3, False), 381 | "четырёхсотой": Numeral(400, 3, False), 382 | "четырёхсотом": Numeral(400, 3, False), 383 | 384 | "пятьсот": Numeral(500, 3, False), 385 | "пятистах": Numeral(500, 3, False), 386 | "пятисот": Numeral(500, 3, False), 387 | "пятистам": Numeral(500, 3, False), 388 | "пятьсотое": Numeral(500, 3, False), 389 | "пятьсотый": Numeral(500, 3, False), 390 | "пятьсотая": Numeral(500, 3, False), 391 | "пятьсотого": Numeral(500, 3, False), 392 | "пятьсотой": Numeral(500, 3, False), 393 | "пятьсотом": Numeral(500, 3, False), 394 | 395 | "шестьсот": Numeral(600, 3, False), 396 | "шестистах": Numeral(600, 3, False), 397 | "шестисот": Numeral(600, 3, False), 398 | "шестистам": Numeral(600, 3, False), 399 | "шестисотое": Numeral(600, 3, False), 400 | "шестисотый": Numeral(600, 3, False), 401 | "шестисотая": 
Numeral(600, 3, False), 402 | "шестисотого": Numeral(600, 3, False), 403 | "шестисотой": Numeral(600, 3, False), 404 | "шестисотом": Numeral(600, 3, False), 405 | 406 | "семьсот": Numeral(700, 3, False), 407 | "семистах": Numeral(700, 3, False), 408 | "семисот": Numeral(700, 3, False), 409 | "семистам": Numeral(700, 3, False), 410 | "семисотое": Numeral(700, 3, False), 411 | "семисотый": Numeral(700, 3, False), 412 | "семисотая": Numeral(700, 3, False), 413 | "семисотого": Numeral(700, 3, False), 414 | "семисотой": Numeral(700, 3, False), 415 | "семисотом": Numeral(700, 3, False), 416 | 417 | "восемьсот": Numeral(800, 3, False), 418 | "восьмистах": Numeral(800, 3, False), 419 | "восьмисот": Numeral(800, 3, False), 420 | "восьмистам": Numeral(800, 3, False), 421 | "восьмисотое": Numeral(800, 3, False), 422 | "восьмисотый": Numeral(800, 3, False), 423 | "восьмисотая": Numeral(800, 3, False), 424 | "восьмисотого": Numeral(800, 3, False), 425 | "восьмисотой": Numeral(800, 3, False), 426 | "восьмисотом": Numeral(800, 3, False), 427 | 428 | "девятьсот": Numeral(900, 3, False), 429 | "девятистах": Numeral(900, 3, False), 430 | "девятисот": Numeral(900, 3, False), 431 | "девятистам": Numeral(900, 3, False), 432 | "девятьсотое": Numeral(900, 3, False), 433 | "девятьсотый": Numeral(900, 3, False), 434 | "девятьсотая": Numeral(900, 3, False), 435 | "девятьсотого": Numeral(900, 3, False), 436 | "девятьсотой": Numeral(900, 3, False), 437 | "девятьсотом": Numeral(900, 3, False), 438 | 439 | "тысяч": Numeral(1000, 4, True), 440 | "тысяча": Numeral(1000, 4, True), 441 | "тысячи": Numeral(1000, 4, True), 442 | "тысячном": Numeral(1000, 4, True), 443 | 444 | "миллион": Numeral(1000000, 5, True), 445 | "миллиона": Numeral(1000000, 5, True), 446 | "миллионов": Numeral(1000000, 5, True), 447 | "миллионом": Numeral(1000000, 5, True), 448 | 449 | "миллиард": Numeral(1000000000, 6, True), 450 | "миллиарда": Numeral(1000000000, 6, True), 451 | "миллиардов": Numeral(1000000000, 6, True), 
452 | "миллиардном": Numeral(1000000000, 6, True), 453 | 454 | "триллион": Numeral(1000000000000, 7, True), 455 | "триллиона": Numeral(1000000000000, 7, True), 456 | "триллионов": Numeral(1000000000000, 7, True), 457 | "триллионом": Numeral(1000000000000, 7, True) 458 | } 459 | 460 | self.max_token_error = 0.3 461 | 462 | def get_token_sum_error_from_lists(self, token): 463 | token_sum_error = 0 464 | 465 | if isinstance(token, NumericToken): 466 | token_sum_error += token.error 467 | 468 | else: 469 | for sub_token in token: 470 | token_sum_error += self.get_token_sum_error_from_lists(sub_token) 471 | 472 | return token_sum_error 473 | 474 | def parse_tokens(self, text_line, matrix_d, level, fraction=False): 475 | if text_line in self.tokens.keys(): 476 | return [NumericToken(self.tokens[text_line])] 477 | elif fraction: 478 | return [NumericToken(self.tokens_fractions[text_line])] 479 | else: 480 | return [None] 481 | 482 | def parse(self, text): 483 | text = text.strip().lower() 484 | 485 | if len(text) == 0: 486 | return ParserResult(value=0, error=1) 487 | 488 | # Массив для рассчета расстояния Левенштейна 489 | max_token_length = 13 490 | matrix_d = np.zeros((2, max_token_length), dtype=np.float32) 491 | 492 | # Разбиваем текст на токены 493 | raw_token_list = re.split(r"\s+", text) 494 | all_token_list = [] 495 | token_list = [] 496 | result_text_list = [] 497 | left_space_for_number = False 498 | current_level = 0 499 | 500 | # Обрабатываем токены 501 | for token_idx, raw_token in enumerate(raw_token_list): 502 | clean_token = raw_token.strip(string.punctuation) 503 | current_token_list = self.parse_tokens(clean_token, matrix_d, 0) 504 | 505 | # Определение дробного числа: 506 | if clean_token in ["целых", "целой", "целым", "целая"] and token_idx != 0: 507 | 508 | try: 509 | if raw_token_list[token_idx - 1] in self.tokens \ 510 | and raw_token_list[token_idx + 1] in self.tokens \ 511 | or raw_token_list[token_idx + 1] == "и": 512 | if 
raw_token_list[token_idx + 2] in self.tokens_fractions \ 513 | or raw_token_list[token_idx + 3] in self.tokens_fractions \ 514 | or raw_token_list[token_idx + 4] in self.tokens_fractions: 515 | current_token_list = self.parse_tokens(clean_token, matrix_d, 0, fraction=True) 516 | 517 | except IndexError: 518 | pass 519 | 520 | # Обработка первого порядка: 521 | if clean_token in ["десятых", "десятой", "десятым", "десятая"] and token_idx != 0: 522 | 523 | if raw_token_list[token_idx - 1] in self.tokens: 524 | current_token_list = self.parse_tokens(clean_token, matrix_d, 0, fraction=True) 525 | 526 | # Обработка второго порядка: 527 | if clean_token in ["сотых", "сотой", "сотым", "сотая"] and token_idx != 0: 528 | 529 | if raw_token_list[token_idx - 1] in self.tokens: 530 | current_token_list = self.parse_tokens(clean_token, matrix_d, 0, fraction=True) 531 | 532 | # Обработка третьего порядка: 533 | if clean_token in ["тысячных", "тысячной", "тысячным", "тысячная"] and token_idx != 0: 534 | 535 | if raw_token_list[token_idx - 1] in self.tokens: 536 | current_token_list = self.parse_tokens(clean_token, matrix_d, 0, fraction=True) 537 | 538 | # Обработка четвёртого порядка: 539 | if clean_token in ["десятитысячных", "десятитысячной", "десятитысячным", "десятитысячная"] and token_idx != 0: 540 | 541 | if raw_token_list[token_idx - 1] in self.tokens: 542 | current_token_list = self.parse_tokens(clean_token, matrix_d, 0, fraction=True) 543 | 544 | if raw_token == "тысяча": 545 | if token_idx == 0 or raw_token_list[token_idx - 1] != "одна": 546 | current_token_list = [NumericToken(numeral=Numeral(1, 1, False)), current_token_list[0]] 547 | current_level = 1 548 | 549 | if len(token_list) > 0: 550 | all_token_list.append(token_list) 551 | token_list = [] 552 | 553 | left_space_for_number = False 554 | 555 | if raw_token == "тысячная": 556 | if token_idx == 0 or raw_token_list[token_idx - 1] != "одна": 557 | current_token_list = [NumericToken(numeral=Numeral(1000, 1, False))] 
558 | current_level = 0 559 | 560 | if len(token_list) > 0: 561 | all_token_list.append(token_list) 562 | token_list = [] 563 | 564 | left_space_for_number = False 565 | 566 | if raw_token == "десятая": 567 | if token_idx == 0 or raw_token_list[token_idx - 1] != "одна": 568 | current_token_list = [NumericToken(numeral=Numeral(10, 1, False))] 569 | current_level = 0 570 | 571 | if len(token_list) > 0: 572 | all_token_list.append(token_list) 573 | token_list = [] 574 | 575 | left_space_for_number = False 576 | 577 | if raw_token == "сотая": 578 | if token_idx == 0 or raw_token_list[token_idx - 1] != "одна": 579 | current_token_list = [NumericToken(numeral=Numeral(100, 1, False))] 580 | current_level = 0 581 | 582 | if len(token_list) > 0: 583 | all_token_list.append(token_list) 584 | token_list = [] 585 | 586 | left_space_for_number = False 587 | 588 | if raw_token == "ноль": 589 | if token_idx != 0 or raw_token_list[token_idx - 1] == "ноль": 590 | current_token_list = [NumericToken(numeral=Numeral(0, 1, False)), current_token_list[0]] 591 | current_level = 0 592 | 593 | if len(token_list) > 0: 594 | all_token_list.append(token_list) 595 | token_list = [] 596 | 597 | left_space_for_number = False 598 | 599 | if current_token_list[0] is not None: 600 | previous_level = current_level 601 | current_level = current_token_list[len(current_token_list) - 1].numeral.level 602 | is_eleven_to_nineteen = current_token_list[0].numeral.is_eleven_to_nineteen 603 | 604 | if current_level != 0 and previous_level != 0 and ( 605 | (previous_level < current_level <= 3) or current_level == previous_level or (current_level < previous_level <= 2 and is_eleven_to_nineteen)): 606 | all_token_list.append(token_list) 607 | token_list = [] 608 | left_space_for_number = False 609 | 610 | bad_tokens = True 611 | 612 | for current_token in current_token_list: 613 | if current_token is None: 614 | break 615 | 616 | if current_token.error <= self.max_token_error: 617 | bad_tokens = False 618 | 
break 619 | 620 | if bad_tokens is True: 621 | del current_token_list 622 | current_token_list = [None] 623 | 624 | if current_token_list[0] is not None: 625 | if left_space_for_number is False: 626 | result_text_list.append("") 627 | left_space_for_number = True 628 | 629 | for current_token in current_token_list: 630 | token_list.append(current_token) 631 | 632 | else: 633 | result_text_list.append(raw_token) 634 | left_space_for_number = False 635 | current_level = 0 636 | 637 | if len(token_list) > 0: 638 | all_token_list.append(token_list) 639 | token_list = [] 640 | 641 | if len(token_list) > 0: 642 | all_token_list.append(token_list) 643 | token_list = [] 644 | 645 | parser_result_list = [] 646 | 647 | for token_list in all_token_list: 648 | global_level = None 649 | local_level = None 650 | global_value = None 651 | local_value = None 652 | critical_error = False 653 | 654 | token_count = len(token_list) 655 | 656 | for current_token in token_list: 657 | current_error = self.get_token_sum_error_from_lists(current_token) 658 | 659 | if current_error > self.max_token_error: 660 | continue 661 | 662 | value = current_token.numeral.value 663 | level = current_token.numeral.level 664 | multiplier = current_token.numeral.is_multiplier 665 | 666 | if multiplier: 667 | if global_level is None: 668 | if local_level is None: 669 | global_value = value 670 | else: 671 | global_value = np.round(local_value * value, 5) 672 | 673 | global_level = level 674 | local_value = None 675 | local_level = None 676 | 677 | current_token.is_significant = True 678 | 679 | elif global_level > level: 680 | if local_level is None: 681 | global_value = global_value + value 682 | else: 683 | if value == 0.1: 684 | global_value = np.round((global_value + local_value * value), 1) 685 | elif value == 0.01: 686 | global_value = np.round((global_value + local_value * value), 2) 687 | elif value == 0.001: 688 | global_value = np.round((global_value + local_value * value), 3) 689 | elif value 
== 0.0001: 690 | global_value = np.round((global_value + local_value * value), 4) 691 | 692 | global_level = level 693 | local_value = None 694 | local_level = None 695 | 696 | current_token.is_significant = True 697 | 698 | else: 699 | # Ошибка несоответствия уровней 700 | current_token.error = 1 701 | current_token.is_significant = True 702 | critical_error = True 703 | 704 | else: 705 | # Простое числительное 706 | if local_level is None: 707 | local_value = value 708 | local_level = level 709 | 710 | current_token.is_significant = True 711 | 712 | elif local_level > level: 713 | local_value = local_value + value 714 | local_level = level 715 | 716 | current_token.is_significant = True 717 | 718 | else: 719 | # Ошибка несоответствия уровней 720 | current_token.error = 1 721 | current_token.is_significant = True 722 | critical_error = True 723 | 724 | # Считаем общий уровень ошибки 725 | if token_count == 0: 726 | total_error = 1 727 | 728 | else: 729 | total_error = 0 730 | significant_token_count = 0 731 | 732 | for current_token in token_list: 733 | if current_token.is_significant: 734 | total_error += current_token.error 735 | significant_token_count += 1 736 | 737 | total_error /= significant_token_count 738 | 739 | if critical_error: 740 | # Имела место критическая ошибка 741 | if total_error >= 0.5: 742 | total_error = 1 743 | 744 | else: 745 | total_error *= 2 746 | 747 | result_value = 0 748 | 749 | if global_value is not None: 750 | result_value += global_value 751 | 752 | if local_value is not None: 753 | result_value += local_value 754 | parser_result_list.append(ParserResult(result_value, total_error)) 755 | 756 | return parser_result_list, result_text_list 757 | -------------------------------------------------------------------------------- /number_utils/text2numbers.py: -------------------------------------------------------------------------------- 1 | from number_utils.russian_numbers import RussianNumbers 2 | import numpy as np 3 | 4 | 5 | 
class TextToNumbers:
    """Replace spelled-out Russian numerals in a text line with digits.

    Numeral parsing is delegated to ``RussianNumbers``; this class stitches
    the parsed values back into the text and applies heuristic
    post-processing for spoken floating-point forms.
    """

    def __init__(self):
        # Parser that splits a line into numeral tokens and plain words.
        self.russian_numbers = RussianNumbers()

    def convert(self, text_line):
        """Return ``text_line`` with recognized numerals replaced by values.

        Empty or ``None`` input is returned unchanged.  ``parse()`` yields the
        parsed numbers plus a token list in which an empty string marks the
        slot where the next parsed number belongs.
        """
        if not text_line:
            return text_line

        parsed_list, result_text_list = self.russian_numbers.parse(text=text_line)
        converted_text = ""
        parsed_idx = 0
        result_text_list_len = len(result_text_list)

        for i, element in enumerate(result_text_list):
            if element == "":
                # Placeholder slot: substitute the next parsed numeric value.
                converted_text += str(parsed_list[parsed_idx].value)
                parsed_idx += 1

            else:
                converted_text += element

            # Single space between tokens; none after the last token.
            if i < result_text_list_len - 1:
                converted_text += " "

        converted_text = self.float_postprocessing(converted_text)

        return converted_text

    @staticmethod
    def float_postprocessing(converted_text):
        """Heuristically merge number words around spoken decimal separators.

        Handles the spoken Russian forms "минус" (minus), "точка" (point),
        "запятая" (comma) and "N и M" ("N and M") by rewriting them into a
        single numeric literal, then normalizes the result with numpy
        rounding.  The character-index probes below are position-sensitive:
        an ``IndexError`` from an out-of-range probe is treated as
        "pattern absent".
        """
        # "минус <digits>" -> "-<digits>" (only when no "и" join is present).
        # str.replace is idempotent here, so repeated hits are harmless.
        if "минус" in converted_text and " и " not in converted_text:
            for x in ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]:
                if x in converted_text:
                    converted_text = converted_text.replace("минус ", "-")

        # "<digits> точка <digits>" -> "<digits>.<digits>".  The idx offsets
        # assume "т" starts the literal word "точка" surrounded by single
        # spaces; the trailing "0." strip handles fractions spoken as
        # "ноль точка ..." -- NOTE(review): offsets look hand-tuned, confirm
        # against the parser's spacing before changing.
        if "точка" in converted_text:
            try:
                for idx, val in enumerate(converted_text):
                    if val == "т":
                        if converted_text[idx + 4] == "а" \
                                and converted_text[idx - 2] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and converted_text[idx + 6] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and "0." in converted_text[idx + 6:]:
                            converted_text = converted_text.replace(" точка ", ".").replace("0.", "")
                        if converted_text[idx + 4] == "а" \
                                and converted_text[idx - 2] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and converted_text[idx + 6] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and "0." not in converted_text[idx + 6:]:
                            converted_text = converted_text.replace(" точка ", ".")

            except IndexError:
                pass

        # Same rewriting for "запятая" (comma); offsets differ because the
        # word is longer ("з...я" spans idx..idx+6).
        if "запятая" in converted_text:
            try:
                for idx, val in enumerate(converted_text):
                    if val == "з":
                        if converted_text[idx + 6] == "я" \
                                and converted_text[idx - 2] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and converted_text[idx + 8] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and "0." in converted_text[idx + 8:]:
                            converted_text = converted_text.replace(" запятая ", ".").replace("0.", "")
                        if converted_text[idx + 6] == "я" \
                                and converted_text[idx - 2] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and converted_text[idx + 8] in ["0", "1", "2", "3", "4",
                                                                "5", "6", "7", "8", "9"] \
                                and "0." not in converted_text[idx + 8:]:
                            converted_text = converted_text.replace(" запятая ", ".")

            except IndexError:
                pass

        # "<digits> и <digits>" -> decimal join, e.g. "5 и 2" -> "5.2",
        # optionally stripping a spoken "после запятой" ("after the comma").
        if " и " in converted_text:
            try:
                for idx, val in enumerate(converted_text):
                    if val == "и" \
                            and converted_text[idx + 2] in ["0", "1", "2", "3", "4",
                                                            "5", "6", "7", "8", "9"] \
                            and converted_text[idx - 1] == " " \
                            and converted_text[idx - 2] in ["0", "1", "2", "3", "4",
                                                            "5", "6", "7", "8", "9"]:

                        if " 0." in converted_text:
                            converted_text = converted_text.replace("0.", ".").replace(" и ", "")

                            if "после запятой" in converted_text:
                                converted_text = converted_text.replace(" после запятой", "")

                            # NOTE(review): after the replace above this
                            # condition appears unreachable -- confirm whether
                            # an elif/else was intended.
                            if "после запятой" in converted_text:
                                converted_text = converted_text.replace(" и ", ".").replace(" после запятой", "")

                        else:
                            converted_text = converted_text.replace(" и ", ".")
            except IndexError:
                # Unlike the blocks above, an index miss here aborts the whole
                # post-processing and returns the text as-is.
                return converted_text

        # If the whole string is now a single number, normalize its
        # representation (rounded to 5 decimals); otherwise leave it alone.
        try:
            converted_text = str(np.round(np.float32(converted_text), 5))
        except (ValueError, TypeError):
            pass

        return converted_text


if __name__ == "__main__":
    # Interactive smoke test: type a line, see the converted output.
    text2numbers = TextToNumbers()

    while True:
        text_line = input("Введите ваш текст:\n")
        converted_line = text2numbers.convert(text_line)
        print(f"\nРаспознанное: {converted_line}\n\n")
": 3 27 | } 28 | self.punctuation_dec = {i:key for key, i in self.punctuation_enc.items()} 29 | 30 | self.bert_punctuator = BertPunc(conf) 31 | self.bert_punctuator.evalMode() 32 | self.bert_punctuator.calcMode(np.float16) 33 | self.bert_punctuator.load(os.path.join(model_path, "bert16.hdf")) 34 | 35 | 36 | def segment(self, ids): 37 | x = [] 38 | x_pad = ids[-((self.segment_size - 1) // 2 - 1):] + ids + ids[:self.segment_size // 2] 39 | 40 | for i in range(len(x_pad) - self.segment_size + 2): 41 | segment = x_pad[i:i + self.segment_size - 1] 42 | segment.insert((self.segment_size - 1) // 2, 0) 43 | x.append(segment) 44 | 45 | return np.array(x) 46 | 47 | 48 | def preprocess_data(self, txt): 49 | data = txt.split() 50 | token_count = [] 51 | x = [] 52 | for word in data: 53 | tokens = self.tokenizer.tokenize(word) 54 | ids = self.tokenizer.convert_tokens_to_ids(tokens) 55 | if len(ids) > 0: 56 | x += ids 57 | token_count.append([word, len(ids)]) 58 | x = self.segment(x) 59 | return x, token_count 60 | 61 | 62 | def get_predictions(self, batches): 63 | y_pred = [] 64 | for batch in batches: 65 | inputs = gpuarray.to_gpu(batch.astype(np.int32)) 66 | output = self.bert_punctuator(inputs).get() 67 | y_pred += list(output.argmax(axis=1).flatten()) 68 | return y_pred 69 | 70 | 71 | def convert_predictions(self, token_count, y): 72 | i = 0 73 | s = "" 74 | for word, k in token_count: 75 | i += k 76 | punc = self.punctuation_dec[y[i - 1]] 77 | if i == len(y) and punc not in [". ", "? "]: 78 | punc = ". " 79 | s = s + word + punc 80 | 81 | pred = s[0].upper() + s[1] 82 | for i in range(2, len(s)): 83 | if s[i - 2] in [".", "?"]: 84 | pred += s[i].upper() 85 | else: 86 | pred += s[i] 87 | 88 | return pred[:-1] 89 | 90 | 91 | def predict(self, txt): 92 | txt = "берт расставляет знаки препинания в строке предсказывая токены знаков препинания. 
" + txt 93 | x, token_count = self.preprocess_data(txt) 94 | 95 | batches = np.array_split(x, math.ceil(x.shape[0] / 8), axis=0) 96 | 97 | y_pred = self.get_predictions(batches) 98 | 99 | pred = self.convert_predictions(token_count, y_pred)[83:] 100 | 101 | return pred 102 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==1.1.2 2 | gunicorn==20.0.4 3 | numpy==1.19.5 4 | scipy==1.5.4 5 | pydub==0.25.1 6 | h5py==3.1.0 7 | colorama==0.4.4 8 | Levenshtein==0.12.0 -------------------------------------------------------------------------------- /speech_recognizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import configparser 4 | from data_loader import preprocess 5 | from decoder import GreedyDecoder 6 | 7 | 8 | class SpeechRecognizer(object): 9 | def __init__(self, config_path='config.ini'): 10 | if config_path is None: 11 | raise Exception('Path to config file is None') 12 | self.config = configparser.ConfigParser() 13 | self.config.read(config_path, encoding='UTF-8') 14 | self.labels = self.config['Wav2Letter']['labels'][1:-1] 15 | self.sample_rate = int(self.config['Wav2Letter']['sample_rate']) 16 | self.window_size = float(self.config['Wav2Letter']['window_size']) 17 | self.window_stride = float(self.config['Wav2Letter']['window_stride']) 18 | self.greedy = int(self.config['Wav2Letter']['greedy']) 19 | self.cpu = int(self.config['Wav2Letter']['cpu']) 20 | 21 | if self.cpu: 22 | from PuzzleLib import Config 23 | Config.backend = Config.Backend.cpu 24 | 25 | from PuzzleLib.Models.Nets.WaveToLetter import loadW2L 26 | from PuzzleLib.Modules import MoveAxis 27 | 28 | nfft = int(self.sample_rate * self.window_size) 29 | self.w2l = loadW2L(modelpath=self.config['Wav2Letter']['model_path'], inmaps=(1 + nfft // 2), 30 | nlabels=len(self.labels)) 
31 | self.w2l.append(MoveAxis(src=2, dst=0)) 32 | 33 | if not self.cpu: 34 | self.w2l.calcMode(np.float16) 35 | 36 | self.w2l.evalMode() 37 | 38 | if not self.greedy: 39 | from decoder import TrieDecoder 40 | lexicon = self.config['Wav2Letter']['lexicon'] 41 | tokens = self.config['Wav2Letter']['tokens'] 42 | lm_path = self.config['Wav2Letter']['lm_path'] 43 | beam_threshold = float(self.config['Wav2Letter']['beam_threshold']) 44 | self.decoder = TrieDecoder(lexicon, tokens, lm_path, beam_threshold) 45 | else: 46 | self.decoder = GreedyDecoder(self.labels) 47 | 48 | def recognize(self, audio_path): 49 | preprocessed_audio = preprocess(audio_path, self.sample_rate, self.window_size, self.window_stride) 50 | if self.cpu: 51 | from PuzzleLib.CPU.CPUArray import CPUArray 52 | inputs = CPUArray.toDevice(np.array([preprocessed_audio]).astype(np.float32)) 53 | else: 54 | from PuzzleLib.Backend import gpuarray 55 | inputs = gpuarray.to_gpu(np.array([preprocessed_audio]).astype(np.float16)) 56 | 57 | output = self.w2l(inputs).get() 58 | output = np.vstack(output).astype(np.float32) 59 | result = self.decoder.decode(output) 60 | 61 | if not self.cpu: 62 | from PuzzleLib.Backend.gpuarray import memoryPool 63 | memoryPool.freeHeld() 64 | 65 | del inputs, output 66 | 67 | return result 68 | 69 | 70 | def test(): 71 | parser = argparse.ArgumentParser(description='Pipeline') 72 | parser.add_argument('--audio', default='data/test.wav', metavar='DIR', help='Path to wav file') 73 | parser.add_argument('--config', default='config.ini', help='Path to config') 74 | args = parser.parse_args() 75 | 76 | recognizer = SpeechRecognizer(args.config) 77 | 78 | print(recognizer.recognize(args.audio)) 79 | 80 | 81 | if __name__ == "__main__": 82 | test() 83 | -------------------------------------------------------------------------------- /static/speech/css/styles.css: -------------------------------------------------------------------------------- 1 | /************* GENERAL ************/ 2 | a 
{ 3 | text-decoration: none; 4 | } 5 | 6 | * { 7 | font-family: 'Vollkorn', Arial, Geneva, Helvetica, serif; 8 | } 9 | 10 | html { 11 | width: 100%; 12 | } 13 | 14 | body { 15 | width: 100%; 16 | margin: 0; 17 | height: 100%; 18 | display: flex; 19 | flex-direction: column; 20 | } 21 | 22 | 23 | /************** HEADER ****************/ 24 | header { 25 | text-align: justify; 26 | letter-spacing: 0.6px; 27 | padding: 2em 10em; 28 | color: #000; 29 | line-height: 0.6em; 30 | height: 3em; 31 | } 32 | 33 | header #logo, #login-button, 34 | header nav { 35 | display: inline-block; 36 | } 37 | 38 | header nav a { 39 | color: #4d4a4a; 40 | } 41 | 42 | header nav a:hover { 43 | color: #858080; 44 | } 45 | 46 | header #login-button { 47 | background-color: #4fb2ff; 48 | padding: 1em 3em 1em 3em; 49 | color: #fff; 50 | border-radius: 5px; 51 | font-weight: bold; 52 | } 53 | 54 | header #login-button:hover { 55 | background-color: #4891df; 56 | } 57 | 58 | header::after { 59 | content: ''; 60 | display: inline-block; 61 | width: 100%; 62 | } 63 | 64 | header #logo { 65 | height: 100%; 66 | } 67 | 68 | header #logo::before { 69 | content: ''; 70 | display: inline-block; 71 | vertical-align: middle; 72 | height: 100%; 73 | } 74 | 75 | @media screen and (max-width: 820px) { 76 | 77 | header { 78 | height: auto; 79 | } 80 | 81 | header > div, 82 | header > div #logo, 83 | header nav { 84 | height: auto; 85 | width: auto; 86 | display: block; 87 | text-align: center; 88 | } 89 | 90 | } 91 | 92 | /********* CONTENT ********/ 93 | #content { 94 | width: 100%; 95 | min-height: calc(100vh - 7em - 4.6em); 96 | box-sizing: border-box; 97 | } 98 | 99 | 100 | /* FORM SECTION */ 101 | #id-form-section { 102 | width: 100%; 103 | text-align: center; 104 | margin-top: 0%; 105 | } 106 | 107 | #id-form-section #info { 108 | color: #969696; 109 | width: 36%; 110 | margin-left: 32%; 111 | } 112 | 113 | #id-form-section h1 { 114 | font-size: 1.9em; 115 | letter-spacing: 1px; 116 | } 117 | 118 | 
119 | #id-form-section form { 120 | margin-top: 2%; 121 | display: inline-block; 122 | width: 50%; 123 | } 124 | 125 | #id-form-section form * { 126 | box-sizing: border-box; 127 | } 128 | 129 | 130 | /*** CHECK_PAGE ***/ 131 | /* Style the search field */ 132 | #id-form-section #check-page-form input[type=text], 133 | #id-form-section #load-new-form input[type=text] { 134 | border: none; 135 | padding: 1.3em; 136 | font-size: 1.1em; 137 | float: left; 138 | width: 70%; 139 | display: block; 140 | } 141 | 142 | /* Style the search submit button */ 143 | #id-form-section form button { 144 | float: left; 145 | display: block; 146 | width: 13%; 147 | color: #969696; 148 | font-size: 1.6em; 149 | border: none; 150 | padding: 0.7em 0; 151 | background-color: #fff; 152 | cursor: pointer; 153 | } 154 | 155 | #id-form-section form::after { 156 | content: ""; 157 | } 158 | 159 | #id-form-section .element { 160 | margin-right: 4%; 161 | float: left; 162 | display: block; 163 | width: 13%; 164 | } 165 | 166 | #id-form-section #file-input { 167 | display: none; 168 | } 169 | 170 | /************ FOOTER ************/ 171 | 172 | footer { 173 | flex-shrink: 0; 174 | background-image: linear-gradient(-90deg, #4891df, #4fb2ff); 175 | width: 80%; 176 | text-align: justify; 177 | letter-spacing: 0.6px; 178 | height: 1em; 179 | padding: 1.8em 10%; 180 | color: #000; 181 | } 182 | 183 | footer #copyright, 184 | footer nav { 185 | display: inline-block; 186 | } 187 | 188 | footer nav a { 189 | color: #fff; 190 | } 191 | 192 | footer nav a:hover { 193 | color: #dbdbdb; 194 | } 195 | 196 | footer::after { 197 | content: ''; 198 | display: inline-block; 199 | width: 100%; 200 | } 201 | 202 | footer #copyright { 203 | height: 100%; 204 | color: #fff; 205 | } 206 | 207 | footer #copyright::before { 208 | content: ''; 209 | display: inline-block; 210 | vertical-align: middle; 211 | height: 100%; 212 | } 213 | 214 | 215 | /* LOADING SCREEN */ 216 | #uploading_message_info { 217 | display: 
/* Loading spinner container: 12 bars rotated around the center, each fading
   via the lds-spinner keyframes with a staggered animation-delay.
   Fix: removed `color: official` -- "official" is not a valid CSS color
   value, so browsers discarded the declaration anyway; dropping it makes the
   rule valid without changing rendering. */
.lds-spinner {
  display: inline-block;
  position: relative;
  width: 64px;
  height: 64px;
}
311 | transform: rotate(270deg); 312 | animation-delay: -0.2s; 313 | } 314 | 315 | .lds-spinner div:nth-child(11) { 316 | transform: rotate(300deg); 317 | animation-delay: -0.1s; 318 | } 319 | 320 | .lds-spinner div:nth-child(12) { 321 | transform: rotate(330deg); 322 | animation-delay: 0s; 323 | } 324 | 325 | @keyframes lds-spinner { 326 | 0% { 327 | opacity: 1; 328 | } 329 | 100% { 330 | opacity: 0; 331 | } 332 | } 333 | 334 | #form-error-block { 335 | width: 100%; 336 | text-align: center; 337 | color: darkred; 338 | float: left; 339 | } 340 | 341 | 342 | /* BTN */ 343 | 344 | #rec_btn { 345 | background-color: #f8f6f1; 346 | padding: 1.5% 2%; 347 | font-size: 260%; 348 | border-radius: 50%; 349 | margin-top: 2%; 350 | border: #acacac solid 2px; 351 | } 352 | 353 | #rec_btn:hover { 354 | opacity: 0.7; 355 | } 356 | 357 | .wait_btn { 358 | cursor: wait; 359 | opacity: 0.5; 360 | pointer-events: none; 361 | } 362 | 363 | .exe_btn { 364 | background-color: #b92a2a !important; 365 | } 366 | 367 | .exe_btn:hover { 368 | background: #e20000 !important; 369 | cursor: pointer; 370 | } 371 | 372 | .act_btn:hover { 373 | background-color: #dedede !important; 374 | cursor: pointer; 375 | } 376 | 377 | #error_field { 378 | color: #B22222; 379 | margin-top: 1%; 380 | } 381 | 382 | 383 | #model-select-id { 384 | text-align: center; 385 | background: transparent; 386 | font-size: 20px; 387 | font-weight: bold; 388 | padding: 2px 10px; 389 | width: 20%; 390 | margin: 1% 40% 1% 40%; 391 | text-align-last: center; 392 | } 393 | 394 | #recognition_result { 395 | text-align: center; 396 | width: 90%; 397 | margin-left: 5%; 398 | margin-top: 4%; 399 | font-size: 1.1em; 400 | } 401 | 402 | #recognition_result_title { 403 | float: left; 404 | width: 18%; 405 | text-align: right; 406 | padding: 1%; 407 | font-weight: bold; 408 | } 409 | 410 | #recognition_result_table { 411 | float: left; 412 | width: 90%; 413 | margin-left: 5%; 414 | text-align: left; 415 | padding: 1%; 416 | } 417 | 
/* Result-table cell styling.
   Fix: the original selector `#table_result th, td` left the `td` part
   unscoped, so these paddings/borders applied to EVERY <td> on the page.
   Each selector in a comma list is independent and must be scoped
   individually. */
#table_result th,
#table_result td {
  padding: 15px;
  text-align: left;
  border-bottom: 1px solid #ddd;
}
// Wrap a MediaRecorder around the given mic stream and expose start/stop
// controls. start() clears the chunk buffer and begins recording; stop()
// resolves once the recorder has flushed its data, after the captured audio
// has been assembled into BLOB and posted via send_records().
//
// Fixes vs. original:
//  - the inner stop() Promise never called resolve, so `await recorder.stop()`
//    stayed pending forever (a Promise leak on every recording);
//  - the "stop" listener is registered with { once: true } so repeated stop()
//    calls cannot stack duplicate handlers;
//  - the outer executor is no longer `async` (anti-pattern: rejections inside
//    an async executor are silently lost);
//  - a failing mediaRecorder.stop() (e.g. recorder already inactive) now
//    settles the Promise instead of being swallowed by an empty finally.
const recordAudio = (stream) =>
  new Promise((resolve) => {
    const mediaRecorder = new MediaRecorder(stream);
    let audioChunks = [];

    mediaRecorder.addEventListener("dataavailable", (event) => {
      audioChunks.push(event.data);
    });

    const start = () => {
      STOP = true;
      audioChunks = [];
      mediaRecorder.start();
    };

    const stop = () =>
      new Promise((resolveStop) => {
        mediaRecorder.addEventListener(
          "stop",
          () => {
            STOP = false;
            BLOB = new Blob(audioChunks);
            send_records();
            resolveStop();
          },
          { once: true }
        );

        try {
          mediaRecorder.stop();
        } catch (e) {
          // Recorder may already be inactive; still settle the promise.
          resolveStop();
        }
      });

    resolve({ start, stop });
  });
REC_BTN.className = "fa fa-pause exe_btn"; 113 | }; 114 | 115 | document.addEventListener("DOMContentLoaded", () => { 116 | navigator.getUserMedia = 117 | navigator.mediaDevices.getUserMedia || 118 | navigator.getUserMedia || 119 | navigator.webkitGetUserMedia || 120 | navigator.mozGetUserMedia || 121 | navigator.msGetUserMedia; 122 | navigator.mediaDevices 123 | .getUserMedia({ 124 | audio: { sampleSize: 24, channelCount: 1, sampleRate: 44100 }, 125 | }) 126 | .then((r) => (stream = r)); 127 | }); 128 | 129 | var recorder, 130 | stream, 131 | BLOB = {}; 132 | var MULTIFILE = false, 133 | STOP = false; 134 | var TR_PATTERN = ` 135 | 136 | 137 | 141 | 142 | {1} 143 | {2} 144 | 145 | `; 146 | -------------------------------------------------------------------------------- /static/speech/js/web-audio-recording-tests-simpler-master/js/RecorderService.js: -------------------------------------------------------------------------------- 1 | class RecorderService { 2 | constructor () { 3 | window.AudioContext = window.AudioContext || window.webkitAudioContext 4 | 5 | this.em = document.createDocumentFragment() 6 | 7 | this.state = 'inactive' 8 | 9 | this.chunks = [] 10 | this.chunkType = '' 11 | 12 | this.encoderMimeType = 'audio/wav' 13 | 14 | this.config = { 15 | manualEncoderId: 'wav', 16 | micGain: 1.0, 17 | processorBufferSize: 2048, 18 | stopTracksAndCloseCtxWhenFinished: true, 19 | usingMediaRecorder: typeof window.MediaRecorder !== 'undefined', 20 | userMediaConstraints: { audio: { echoCancellation: false } } 21 | } 22 | } 23 | 24 | /* Returns promise */ 25 | startRecording () { 26 | if (this.state !== 'inactive') { 27 | return 28 | } 29 | 30 | // This is the case on ios/chrome, when clicking links from within ios/slack (sometimes), etc. 
31 | if (!navigator || !navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) { 32 | alert('Missing support for navigator.mediaDevices.getUserMedia') // temp: helps when testing for strange issues on ios/safari 33 | return 34 | } 35 | 36 | this.audioCtx = new AudioContext() 37 | this.micGainNode = this.audioCtx.createGain() 38 | this.outputGainNode = this.audioCtx.createGain() 39 | 40 | // If not using MediaRecorder(i.e. safari and edge), then a script processor is required. It's optional 41 | // on browsers using MediaRecorder and is only useful if you want to do custom analysis or manipulation of 42 | // recorded audio data. 43 | if (!this.config.usingMediaRecorder) { 44 | this.processorNode = this.audioCtx.createScriptProcessor(this.config.processorBufferSize, 1, 1) // TODO: Get the number of channels from mic 45 | } 46 | 47 | // Create stream destination on chrome/firefox because, AFAICT, we have no other way of feeding audio graph output 48 | // in to MediaRecorder. Safari/Edge don't have this method as of 2018-04. 
49 | if (this.audioCtx.createMediaStreamDestination) { 50 | this.destinationNode = this.audioCtx.createMediaStreamDestination() 51 | } 52 | else { 53 | this.destinationNode = this.audioCtx.destination 54 | } 55 | 56 | // Create web worker for doing the encoding 57 | if (!this.config.usingMediaRecorder) { 58 | this.encoderWorker = new Worker('static/speech/js/web-audio-recording-tests-simpler-master/js/encoder-wav-worker.js') 59 | this.encoderMimeType = 'audio/wav' 60 | 61 | this.encoderWorker.addEventListener('message', (e) => { 62 | let event = new Event('dataavailable') 63 | event.data = new Blob(e.data, { type: this.encoderMimeType }) 64 | this._onDataAvailable(event) 65 | }) 66 | } 67 | 68 | // This will prompt user for permission if needed 69 | return navigator.mediaDevices.getUserMedia(this.config.userMediaConstraints) 70 | .then((stream) => { 71 | this._startRecordingWithStream(stream) 72 | }) 73 | .catch((error) => { 74 | alert('Error with getUserMedia: ' + error.message) // temp: helps when testing for strange issues on ios/safari 75 | console.log(error) 76 | }) 77 | } 78 | 79 | _startRecordingWithStream (stream) { 80 | this.micAudioStream = stream 81 | 82 | this.inputStreamNode = this.audioCtx.createMediaStreamSource(this.micAudioStream) 83 | this.audioCtx = this.inputStreamNode.context 84 | 85 | // Allow optionally hooking in to audioGraph inputStreamNode, useful for meters 86 | if (this.onGraphSetupWithInputStream) { 87 | this.onGraphSetupWithInputStream(this.inputStreamNode) 88 | } 89 | 90 | this.inputStreamNode.connect(this.micGainNode) 91 | this.micGainNode.gain.setValueAtTime(this.config.micGain, this.audioCtx.currentTime) 92 | 93 | let nextNode = this.micGainNode 94 | 95 | this.state = 'recording' 96 | 97 | if (this.processorNode) { 98 | nextNode.connect(this.processorNode) 99 | this.processorNode.connect(this.outputGainNode) 100 | this.processorNode.onaudioprocess = (e) => this._onAudioProcess(e) 101 | } 102 | else { 103 | 
nextNode.connect(this.outputGainNode) 104 | } 105 | 106 | this.outputGainNode.connect(this.destinationNode) 107 | 108 | if (this.config.usingMediaRecorder) { 109 | this.mediaRecorder = new MediaRecorder(this.destinationNode.stream) 110 | this.mediaRecorder.addEventListener('dataavailable', (evt) => this._onDataAvailable(evt)) 111 | this.mediaRecorder.addEventListener('error', (evt) => this._onError(evt)) 112 | 113 | this.mediaRecorder.start() 114 | } 115 | else { 116 | // Output gain to zero to prevent feedback. Seems to matter only on Edge, though seems like should matter 117 | // on iOS too. Matters on chrome when connecting graph to directly to audioCtx.destination, but we are 118 | // not able to do that when using MediaRecorder. 119 | this.outputGainNode.gain.setValueAtTime(0, this.audioCtx.currentTime) 120 | } 121 | } 122 | 123 | _onAudioProcess (e) { 124 | if (this.config.broadcastAudioProcessEvents) { 125 | this.em.dispatchEvent(new CustomEvent('onaudioprocess', { 126 | detail: { 127 | inputBuffer: e.inputBuffer, 128 | outputBuffer: e.outputBuffer 129 | } 130 | })) 131 | } 132 | 133 | // Safari and Edge require manual encoding via web worker. Single channel only for now. 
134 | // Example stereo encoderWav: https://github.com/MicrosoftEdge/Demos/blob/master/microphone/scripts/recorderworker.js 135 | if (!this.config.usingMediaRecorder) { 136 | if (this.state === 'recording') { 137 | if (this.config.broadcastAudioProcessEvents) { 138 | this.encoderWorker.postMessage(['encode', e.outputBuffer.getChannelData(0)]) 139 | } 140 | else { 141 | this.encoderWorker.postMessage(['encode', e.inputBuffer.getChannelData(0)]) 142 | } 143 | } 144 | } 145 | } 146 | 147 | stopRecording () { 148 | if (this.state === 'inactive') { 149 | return 150 | } 151 | 152 | if (this.config.usingMediaRecorder) { 153 | this.state = 'inactive' 154 | this.mediaRecorder.stop() 155 | } 156 | else { 157 | this.state = 'inactive' 158 | this.encoderWorker.postMessage(['dump', this.audioCtx.sampleRate]) 159 | } 160 | } 161 | 162 | _onDataAvailable (evt) { 163 | this.chunks.push(evt.data) 164 | this.chunkType = evt.data.type 165 | 166 | if (this.state !== 'inactive') { 167 | return 168 | } 169 | 170 | let blob = new Blob(this.chunks, { 'type': this.chunkType }) 171 | let blobUrl = URL.createObjectURL(blob) 172 | const recording = { 173 | ts: new Date().getTime(), 174 | blobUrl: blobUrl, 175 | mimeType: blob.type, 176 | size: blob.size, 177 | blob: blob 178 | } 179 | 180 | this.chunks = [] 181 | this.chunkType = null 182 | 183 | if (this.destinationNode) { 184 | this.destinationNode.disconnect() 185 | this.destinationNode = null 186 | } 187 | if (this.outputGainNode) { 188 | this.outputGainNode.disconnect() 189 | this.outputGainNode = null 190 | } 191 | 192 | if (this.processorNode) { 193 | this.processorNode.disconnect() 194 | this.processorNode = null 195 | } 196 | 197 | if (this.encoderWorker) { 198 | this.encoderWorker.postMessage(['close']) 199 | this.encoderWorker = null 200 | } 201 | 202 | if (this.micGainNode) { 203 | this.micGainNode.disconnect() 204 | this.micGainNode = null 205 | } 206 | if (this.inputStreamNode) { 207 | this.inputStreamNode.disconnect() 208 | 
this.inputStreamNode = null 209 | } 210 | 211 | if (this.config.stopTracksAndCloseCtxWhenFinished) { 212 | // This removes the red bar in iOS/Safari 213 | this.micAudioStream.getTracks().forEach((track) => track.stop()) 214 | this.micAudioStream = null 215 | 216 | this.audioCtx.close() 217 | this.audioCtx = null 218 | } 219 | 220 | this.em.dispatchEvent(new CustomEvent('recording', { detail: { recording: recording } })) 221 | } 222 | 223 | _onError (evt) { 224 | console.log('error', evt) 225 | this.em.dispatchEvent(new Event('error')) 226 | alert('error:' + evt) // for debugging purposes 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /static/speech/js/web-audio-recording-tests-simpler-master/js/WebAudioPeakMeter.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copied from https://github.com/esonderegger/web-audio-peak-meter 3 | * Modified to class form to allow multiple instances on a page. 
4 | */ 5 | class WebAudioPeakMeter { 6 | constructor () { 7 | this.options = { 8 | borderSize: 2, 9 | fontSize: 9, 10 | backgroundColor: 'white', 11 | tickColor: '#000', 12 | gradient: ['red 1%', '#ff0 16%', 'lime 45%', '#080 100%'], 13 | dbRange: 48, 14 | dbTickSize: 6, 15 | maskTransition: 'height 0.1s' 16 | } 17 | 18 | this.vertical = true 19 | this.channelCount = 1 20 | this.channelMasks = [] 21 | this.channelPeaks = [] 22 | this.channelPeakLabels = [] 23 | } 24 | 25 | getBaseLog (x, y) { 26 | return Math.log(y) / Math.log(x) 27 | } 28 | 29 | dbFromFloat (floatVal) { 30 | return this.getBaseLog(10, floatVal) * 20 31 | } 32 | 33 | setOptions (userOptions) { 34 | for (var k in userOptions) { 35 | this.options[k] = userOptions[k] 36 | } 37 | this.tickWidth = this.options.fontSize * 2.0 38 | this.meterTop = this.options.fontSize * 1.5 + this.options.borderSize 39 | } 40 | 41 | createMeterNode (sourceNode, audioCtx) { 42 | var c = sourceNode.channelCount 43 | var meterNode = audioCtx.createScriptProcessor(2048, c, c) 44 | sourceNode.connect(meterNode) 45 | meterNode.connect(audioCtx.destination) 46 | return meterNode 47 | } 48 | 49 | createContainerDiv (parent) { 50 | var meterElement = document.createElement('div') 51 | meterElement.style.position = 'relative' 52 | meterElement.style.width = this.elementWidth + 'px' 53 | meterElement.style.height = this.elementHeight + 'px' 54 | meterElement.style.backgroundColor = this.options.backgroundColor 55 | parent.appendChild(meterElement) 56 | return meterElement 57 | } 58 | 59 | createMeter (domElement, meterNode, optionsOverrides) { 60 | this.setOptions(optionsOverrides) 61 | this.elementWidth = domElement.clientWidth 62 | this.elementHeight = domElement.clientHeight 63 | var meterElement = this.createContainerDiv(domElement) 64 | if (this.elementWidth > this.elementHeight) { 65 | this.vertical = false 66 | } 67 | this.meterHeight = this.elementHeight - this.meterTop - this.options.borderSize 68 | this.meterWidth = 
this.elementWidth - this.tickWidth - this.options.borderSize 69 | this.createTicks(meterElement) 70 | this.createRainbow(meterElement, this.meterWidth, this.meterHeight, 71 | this.meterTop, this.tickWidth) 72 | this.channelCount = meterNode.channelCount 73 | var channelWidth = this.meterWidth / this.channelCount 74 | var channelLeft = this.tickWidth 75 | for (var i = 0; i < this.channelCount; i++) { 76 | this.createChannelMask(meterElement, this.options.borderSize, 77 | this.meterTop, channelLeft, false) 78 | this.channelMasks[i] = this.createChannelMask(meterElement, channelWidth, 79 | this.meterTop, channelLeft, 80 | this.options.maskTransition) 81 | this.channelPeaks[i] = 0.0 82 | this.channelPeakLabels[i] = this.createPeakLabel(meterElement, channelWidth, 83 | channelLeft) 84 | channelLeft += channelWidth 85 | } 86 | meterNode.onaudioprocess = (e) => this.updateMeter(e) 87 | meterElement.addEventListener('click', function () { 88 | for (var i = 0; i < this.channelCount; i++) { 89 | this.channelPeaks[i] = 0.0 90 | this.channelPeakLabels[i].textContent = '-∞' 91 | } 92 | }, false) 93 | } 94 | 95 | createTicks (parent) { 96 | var numTicks = Math.floor(this.options.dbRange / this.options.dbTickSize) 97 | var dbTickLabel = 0 98 | var dbTickTop = this.options.fontSize + this.options.borderSize 99 | for (var i = 0; i < numTicks; i++) { 100 | var dbTick = document.createElement('div') 101 | parent.appendChild(dbTick) 102 | dbTick.style.width = this.tickWidth + 'px' 103 | dbTick.style.textAlign = 'right' 104 | dbTick.style.color = this.options.tickColor 105 | dbTick.style.fontSize = this.options.fontSize + 'px' 106 | dbTick.style.position = 'absolute' 107 | dbTick.style.top = dbTickTop + 'px' 108 | dbTick.textContent = dbTickLabel + '' 109 | dbTickLabel -= this.options.dbTickSize 110 | dbTickTop += this.meterHeight / numTicks 111 | } 112 | } 113 | 114 | createRainbow (parent, width, height, top, left) { 115 | var rainbow = document.createElement('div') 116 | 
parent.appendChild(rainbow)
    rainbow.style.width = width + 'px'
    rainbow.style.height = height + 'px'
    rainbow.style.position = 'absolute'
    rainbow.style.top = top + 'px'
    rainbow.style.left = left + 'px'
    var gradientStyle = 'linear-gradient(' + this.options.gradient.join(', ') + ')'
    rainbow.style.backgroundImage = gradientStyle
    return rainbow
  }

  // Create the per-channel text label that shows the held peak value in dB
  // (starts at '-∞' until a peak is observed).
  createPeakLabel (parent, width, left) {
    var label = document.createElement('div')
    parent.appendChild(label)
    label.style.width = width + 'px'
    label.style.textAlign = 'center'
    label.style.color = this.options.tickColor
    label.style.fontSize = this.options.fontSize + 'px'
    label.style.position = 'absolute'
    label.style.top = this.options.borderSize + 'px'
    label.style.left = left + 'px'
    label.textContent = '-∞'
    return label
  }

  // Create a background-colored div that covers the gradient bar from the
  // top; shrinking its height reveals the gradient, acting as the level bar.
  // `transition` toggles the CSS height animation.
  createChannelMask (parent, width, top, left, transition) {
    var channelMask = document.createElement('div')
    parent.appendChild(channelMask)
    channelMask.style.width = width + 'px'
    channelMask.style.height = this.meterHeight + 'px'
    channelMask.style.position = 'absolute'
    channelMask.style.top = top + 'px'
    channelMask.style.left = left + 'px'
    channelMask.style.backgroundColor = this.options.backgroundColor
    if (transition) {
      channelMask.style.transition = this.options.maskTransition
    }
    return channelMask
  }

  // Map a linear sample magnitude (0..1) to a mask height in pixels:
  // silence -> full mask (meter empty), louder -> smaller mask, clamped so
  // signals below -dbRange never exceed the meter height.
  maskSize (floatVal) {
    if (floatVal === 0.0) {
      return this.meterHeight
    }
    else {
      var d = this.options.dbRange * -1
      var returnVal = Math.floor(this.dbFromFloat(floatVal) * this.meterHeight / d)
      if (returnVal > this.meterHeight) {
        return this.meterHeight
      }
      else {
        return returnVal
      }
    }
  }

  // ScriptProcessor callback: scan the block for per-channel absolute maxima
  // and update masks/peak labels accordingly.
  updateMeter (audioProcessingEvent) {
    var inputBuffer = audioProcessingEvent.inputBuffer
    var i
175 | var channelData = [] 176 | var channelMaxes = [] 177 | for (i = 0; i < this.channelCount; i++) { 178 | channelData[i] = inputBuffer.getChannelData(i) 179 | channelMaxes[i] = 0.0 180 | } 181 | for (var sample = 0; sample < inputBuffer.length; sample++) { 182 | for (i = 0; i < this.channelCount; i++) { 183 | if (Math.abs(channelData[i][sample]) > channelMaxes[i]) { 184 | channelMaxes[i] = Math.abs(channelData[i][sample]) 185 | } 186 | } 187 | } 188 | for (i = 0; i < this.channelCount; i++) { 189 | var thisMaskSize = this.maskSize(channelMaxes[i], this.meterHeight) 190 | this.channelMasks[i].style.height = thisMaskSize + 'px' 191 | if (channelMaxes[i] > this.channelPeaks[i]) { 192 | this.channelPeaks[i] = channelMaxes[i] 193 | var labelText = this.dbFromFloat(this.channelPeaks[i]).toFixed(1) 194 | this.channelPeakLabels[i].textContent = labelText 195 | } 196 | } 197 | } 198 | } 199 | -------------------------------------------------------------------------------- /static/speech/js/web-audio-recording-tests-simpler-master/js/app.js: -------------------------------------------------------------------------------- 1 | "use strict"; 2 | 3 | class App { 4 | constructor () { 5 | this.recBtn = document.getElementById('rec_btn'); 6 | 7 | this.isRecording = false 8 | this.saveNextRecording = false 9 | 10 | this.stop = false 11 | } 12 | 13 | init () { 14 | this._initEventListeners() 15 | } 16 | 17 | _initEventListeners () { 18 | 19 | this.recBtn.addEventListener('click', evt => { 20 | if (this.stop) { 21 | 22 | this.stop = false; 23 | 24 | this._stopAllRecording(); 25 | this.recBtn.className = 'fa fa-microphone act_btn'; 26 | 27 | } else { 28 | 29 | this.stop = true 30 | 31 | this._stopAllRecording() 32 | this.saveNextRecording = true 33 | this._startRecording() 34 | 35 | this.recBtn.className = "fa fa-pause exe_btn"; 36 | } 37 | }) 38 | } 39 | 40 | _startRecording () { 41 | if (!this.recorderSrvc) { 42 | this.recorderSrvc = new RecorderService() 43 | 
this.recorderSrvc.em.addEventListener('recording', (evt) => this._onNewRecording(evt)) 44 | } 45 | 46 | if (!this.webAudioPeakMeter) { 47 | this.webAudioPeakMeter = new WebAudioPeakMeter() 48 | this.meterEl = document.getElementById('recording-meter') 49 | } 50 | 51 | this.recorderSrvc.onGraphSetupWithInputStream = (inputStreamNode) => { 52 | this.meterNodeRaw = this.webAudioPeakMeter.createMeterNode(inputStreamNode, this.recorderSrvc.audioCtx) 53 | this.webAudioPeakMeter.createMeter(this.meterEl, this.meterNodeRaw, {}) 54 | } 55 | 56 | this.recorderSrvc.startRecording() 57 | this.isRecording = true 58 | } 59 | 60 | _stopAllRecording () { 61 | if (this.recorderSrvc && this.isRecording) { 62 | 63 | this.recorderSrvc.stopRecording() 64 | this.isRecording = false 65 | 66 | if (this.meterNodeRaw) { 67 | this.meterNodeRaw.disconnect() 68 | this.meterNodeRaw = null 69 | this.meterEl.innerHTML = '' 70 | } 71 | } 72 | } 73 | 74 | _onNewRecording (evt) { 75 | if (!this.saveNextRecording) { 76 | return 77 | } 78 | send_records(evt.detail.recording.blob) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /static/speech/js/web-audio-recording-tests-simpler-master/js/encoder-wav-worker.js: -------------------------------------------------------------------------------- 1 | // Parts copied from https://github.com/chris-rudmin/Recorderjs 2 | let BYTES_PER_SAMPLE = 2 3 | let recorded = [] 4 | 5 | function encode (buffer) { 6 | let length = buffer.length 7 | let data = new Uint8Array(length * BYTES_PER_SAMPLE) 8 | for (let i = 0; i < length; i++) { 9 | let index = i * BYTES_PER_SAMPLE 10 | let sample = buffer[i] 11 | if (sample > 1) { 12 | sample = 1 13 | } 14 | else if (sample < -1) { 15 | sample = -1 16 | } 17 | sample = sample * 32768 18 | data[index] = sample 19 | data[index + 1] = sample >> 8 20 | } 21 | recorded.push(data) 22 | } 23 | 24 | function dump (sampleRate) { 25 | let bufferLength = recorded.length ? 
recorded[0].length : 0 26 | let length = recorded.length * bufferLength 27 | let wav = new Uint8Array(44 + length) 28 | 29 | let view = new DataView(wav.buffer) 30 | 31 | // RIFF identifier 'RIFF' 32 | view.setUint32(0, 1380533830, false) 33 | // file length minus RIFF identifier length and file description length 34 | view.setUint32(4, 36 + length, true) 35 | // RIFF type 'WAVE' 36 | view.setUint32(8, 1463899717, false) 37 | // format chunk identifier 'fmt ' 38 | view.setUint32(12, 1718449184, false) 39 | // format chunk length 40 | view.setUint32(16, 16, true) 41 | // sample format (raw) 42 | view.setUint16(20, 1, true) 43 | // channel count 44 | view.setUint16(22, 1, true) 45 | // sample rate 46 | view.setUint32(24, sampleRate, true) 47 | // byte rate (sample rate * block align) 48 | view.setUint32(28, sampleRate * BYTES_PER_SAMPLE, true) 49 | // block align (channel count * bytes per sample) 50 | view.setUint16(32, BYTES_PER_SAMPLE, true) 51 | // bits per sample 52 | view.setUint16(34, 8 * BYTES_PER_SAMPLE, true) 53 | // data chunk identifier 'data' 54 | view.setUint32(36, 1684108385, false) 55 | // data chunk length 56 | view.setUint32(40, length, true) 57 | 58 | for (var i = 0; i < recorded.length; i++) { 59 | wav.set(recorded[i], i * bufferLength + 44) 60 | } 61 | 62 | recorded = [] 63 | let msg = [wav.buffer] 64 | postMessage(msg, [msg[0]]) 65 | } 66 | 67 | onmessage = function (e) { 68 | if (e.data[0] === 'encode') { 69 | encode(e.data[1]) 70 | } 71 | else if (e.data[0] === 'dump') { 72 | dump(e.data[1]) 73 | } 74 | else if (e.data[0] === 'close') { 75 | self.close() 76 | } 77 | } 78 | 79 | -------------------------------------------------------------------------------- /templates/_formhelpers.html: -------------------------------------------------------------------------------- 1 | {% macro render_field(field) %} 2 |
{{ field.label }} 3 |
{{ field(**kwargs)|safe }} 4 | {% if field.errors %} 5 |
    6 | {% for error in field.errors %} 7 |
  • {{ error }}
  • 8 | {% endfor %} 9 |
10 | {% endif %} 11 |
12 | {% endmacro %} 13 | -------------------------------------------------------------------------------- /templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {% block title %}{% endblock %} 5 | 6 | 7 | 10 | 11 | {% block head %}{% endblock %} 12 | 13 | 14 | 15 |
16 |
17 | 18 |
19 | {% block sections %} 20 | {% endblock %} 21 |
22 | 23 |
24 | 29 | 32 |
33 | 34 | -------------------------------------------------------------------------------- /templates/speech_recognition.html: -------------------------------------------------------------------------------- 1 | {% extends 'base.html' %} 2 | {% from "_formhelpers.html" import render_field %} 3 | 4 | {% block title %}SOVA ASR{% endblock %} 5 | 6 | {% block head %} 7 | 8 | {% endblock %} 9 | 10 | {% block sections %} 11 | 12 | 13 | 17 | 18 |
19 | 20 |

21 | Speech Recognition 22 |

23 |
24 | Click on the microphone icon, record audio and wait for the ASR results. You can start recording by pressing "space" or "enter". You can also upload multiple audio files via "Choose Files" button. 25 |
26 | 27 |
28 | 29 | 30 | 31 |
32 |
33 |
34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 |
AudioRecognition timeRecognized text
45 |
46 |
47 | 48 | 49 | 50 | 51 | 59 | {% endblock %} 60 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | from shutil import copyfile 5 | import configparser 6 | import Levenshtein 7 | from data_loader import DataLoader, SpectrogramDataset, BucketingSampler 8 | from decoder import GreedyDecoder 9 | from PuzzleLib.Models.Nets.WaveToLetter import loadW2L 10 | from PuzzleLib.Backend import gpuarray 11 | from PuzzleLib.Cost.CTC import CTC 12 | from PuzzleLib.Optimizers.Adam import Adam 13 | from PuzzleLib.Modules import MoveAxis 14 | from PuzzleLib.Modules.Cast import Cast 15 | 16 | 17 | def get_data_loader(manifest_file_path, labels, sample_rate, window_size, window_stride, batch_size): 18 | dataset = SpectrogramDataset(labels, sample_rate, window_size, window_stride, manifest_file_path) 19 | sampler = BucketingSampler(dataset, batch_size=batch_size) 20 | return DataLoader(dataset, batch_sampler=sampler) 21 | 22 | 23 | def calculate_wer(s1, s2): 24 | b = set(s1.split() + s2.split()) 25 | word2char = dict(zip(b, range(len(b)))) 26 | w1 = [chr(word2char[w]) for w in s1.split()] 27 | w2 = [chr(word2char[w]) for w in s2.split()] 28 | return Levenshtein.distance(''.join(w1), ''.join(w2)) / len(''.join(w2)) 29 | 30 | 31 | def calculate_cer(s1, s2): 32 | s1, s2, = s1.replace(' ', ''), s2.replace(' ', '') 33 | return Levenshtein.distance(s1, s2) / len(s2) 34 | 35 | 36 | def train(model, ctc, optimizer, loader, checkpoint_per_batch, save_folder, save_name, fp16): 37 | model.reset() 38 | model.trainMode() 39 | loader.reset() 40 | 41 | if not os.path.exists(save_folder): 42 | os.mkdir(save_folder) 43 | 44 | for i, (data) in enumerate(loader): 45 | inputs, input_percentages, targets, target_sizes, _ = data 46 | if fp16: 47 | inputs = gpuarray.to_gpu(inputs.astype(np.float16)) 48 | else: 49 | inputs = 
gpuarray.to_gpu(inputs.astype(np.float32))

        # Forward pass. Axis 0 of `out` is time: main() appends
        # MoveAxis(src=2, dst=0) to the net, and out_len below scales
        # out.shape[0] by how much of the padded input each sample occupies.
        out = model(inputs)
        out_len = gpuarray.to_gpu((out.shape[0] * input_percentages).astype(np.int32))
        target_sizes = target_sizes.astype(np.int32)
        targets = gpuarray.to_gpu(targets.astype(np.int32))

        # CTC returns both the loss value and the gradient w.r.t. the output.
        error, grad = ctc([out, out_len], [targets, target_sizes])

        print('Training iter {} of {}, CTC: {}'.format(i + 1, len(loader), error))

        optimizer.zeroGradParams()
        # updGrad=False: no gradient w.r.t. the network input is needed,
        # only parameter gradients for the optimizer step.
        model.backward(grad.astype(np.float32), updGrad=False)
        optimizer.update()

        # Mid-epoch checkpoint every `checkpoint_per_batch` iterations;
        # also refresh the rolling 'last.hdf' copy.
        if checkpoint_per_batch and i % checkpoint_per_batch == 0 and i > 0:
            save_path = os.path.join(save_folder, '{}_iter_{}.hdf'.format(save_name, i))
            model.save(hdf=save_path)
            copyfile(save_path, os.path.join(save_folder, 'last.hdf'))

    # End-of-epoch checkpoint.
    save_path = os.path.join(save_folder, '{}.hdf'.format(save_name))
    model.save(hdf=save_path)
    copyfile(save_path, os.path.join(save_folder, 'last.hdf'))

    return model


def validate(model, loader, decoder, fp16):
    """Greedy-decode the validation set and accumulate WER/CER totals.

    model:   Wav2Letter net (same object train() uses).
    loader:  DataLoader yielding
             (inputs, input_percentages, targets, target_sizes, input_file).
    decoder: GreedyDecoder used to turn per-frame outputs into text.
    fp16:    feed half-precision inputs when True.
    """
    loader.reset()
    model.evalMode()
    total_cer, total_wer = 0, 0

    for i, (data) in enumerate(loader):
        inputs, input_percentages, targets, target_sizes, input_file = data
        if fp16:
            inputs = gpuarray.to_gpu(inputs.astype(np.float16))
        else:
            inputs = gpuarray.to_gpu(inputs.astype(np.float32))

        out = model(inputs)
        out_len = (out.shape[0] * input_percentages).astype(np.int32)

        # `out` is time-major; move the batch axis first before iterating
        # over the utterances in the batch.
        decoded_output = [
            decoder.decode(output, max_len=out_len[j]) for j, output in enumerate(np.moveaxis(out.get(), 0, 1))
        ]

        print('\nValidation iter {} of {}'.format(i + 1, len(loader)))

        wer, cer = 0, 0
        for x in range(len(decoded_output)):
            # input_file[x] holds (filepath, reference_transcript).
            transcript, reference = decoded_output[x], input_file[x][1]
            print('Transcript: {}\nReference: {}\nFilepath: {}'.format(transcript, reference, input_file[x][0]))
            try:
                wer += calculate_wer(transcript, reference)
cer += calculate_cer(transcript, reference) 104 | except Exception as e: 105 | print('Encountered exception {}'.format(e)) 106 | total_cer += cer 107 | total_wer += wer 108 | 109 | wer = total_wer / len(loader.dataset) * 100 110 | cer = total_cer / len(loader.dataset) * 100 111 | print('WER: {}'.format(wer)) 112 | print('CER: {}'.format(cer)) 113 | 114 | 115 | def main(): 116 | parser = argparse.ArgumentParser(description='Training') 117 | parser.add_argument('--config', metavar='DIR', help='Path to train config', default='config.ini') 118 | args = parser.parse_args() 119 | 120 | config_path = args.config 121 | if config_path is None: 122 | raise Exception('Path to config file is None') 123 | config = configparser.ConfigParser() 124 | config.read(config_path, encoding='UTF-8') 125 | 126 | sample_rate = int(config['Wav2Letter'].get('sample_rate')) 127 | window_size = float(config['Wav2Letter'].get('window_size')) 128 | window_stride = float(config['Wav2Letter'].get('window_stride')) 129 | labels = config['Wav2Letter'].get('labels')[1:-1] 130 | 131 | train_manifest = config['Train'].get('train_manifest', None) 132 | val_manifest = config['Train'].get('val_manifest', None) 133 | epochs = int(config['Train'].get('epochs')) 134 | batch_size = int(config['Train'].get('batch_size')) 135 | learning_rate = float(config['Train'].get('learning_rate')) 136 | fp16 = bool(config['Train'].get('fp16')) 137 | checkpoint_name = config['Train'].get('checkpoint_name') 138 | checkpoint_per_batch = int(config['Train'].get('checkpoint_per_batch')) 139 | save_folder = config['Train'].get('save_folder') 140 | continue_from = config['Train'].get('continue_from') 141 | 142 | train_loader, val_loader = None, None 143 | 144 | if train_manifest is not None: 145 | train_loader = get_data_loader(train_manifest, labels, sample_rate, window_size, window_stride, batch_size) 146 | 147 | if val_manifest is not None: 148 | val_loader = get_data_loader(val_manifest, labels, sample_rate, window_size, 
window_stride, batch_size) 149 | 150 | nfft = int(sample_rate * window_size) 151 | w2l = loadW2L(modelpath=continue_from, inmaps=(1 + nfft // 2), nlabels=len(labels)) 152 | 153 | if fp16: 154 | w2l.calcMode(np.float16) 155 | w2l.append(Cast(np.float16, np.float32)) 156 | 157 | w2l.append(MoveAxis(src=2, dst=0)) 158 | 159 | blank_index = [i for i in range(len(labels)) if labels[i] == '_'][0] 160 | ctc = CTC(blank_index) 161 | 162 | adam = Adam(alpha=learning_rate) 163 | adam.setupOn(w2l, useGlobalState=True) 164 | 165 | decoder = GreedyDecoder(labels, blank_index) 166 | 167 | for epoch in range(epochs): 168 | if train_manifest is not None: 169 | print('Epoch {} of {}'.format(epoch + 1, epochs)) 170 | w2l = train(w2l, ctc, adam, train_loader, checkpoint_per_batch, save_folder, 171 | '{}_{}'.format(checkpoint_name, epoch), fp16) 172 | 173 | if val_manifest is not None: 174 | validate(w2l, val_loader, decoder, fp16) 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /trie_decoder/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | __all__ = ["Common", "Criterion", "Decoder", "Feature"] 4 | -------------------------------------------------------------------------------- /trie_decoder/_common.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sovaai/sova-asr/d6c257555a225c9c1e1bb3e3cebf9b7ce8d302d7/trie_decoder/_common.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /trie_decoder/_criterion.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sovaai/sova-asr/d6c257555a225c9c1e1bb3e3cebf9b7ce8d302d7/trie_decoder/_criterion.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /trie_decoder/_decoder.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sovaai/sova-asr/d6c257555a225c9c1e1bb3e3cebf9b7ce8d302d7/trie_decoder/_decoder.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /trie_decoder/_feature.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sovaai/sova-asr/d6c257555a225c9c1e1bb3e3cebf9b7ce8d302d7/trie_decoder/_feature.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /trie_decoder/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from trie_decoder._common import * 4 | -------------------------------------------------------------------------------- /trie_decoder/criterion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from trie_decoder._criterion import * 4 | -------------------------------------------------------------------------------- /trie_decoder/decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from trie_decoder._decoder import * 4 | -------------------------------------------------------------------------------- /trie_decoder/feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from trie_decoder._feature import * 4 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from scipy import signal
import numpy as np
import scipy
from numpy import fft
from numpy.lib.stride_tricks import as_strided


# Upper bound (in bytes) on the temporary block of windowed frames processed
# per iteration inside stft().
MAX_MEM_BLOCK = 2 ** 8 * 2 ** 10


def get_window(window, nx, fftbins=True):
    """Return an analysis window of length `nx`.

    `window` may be a callable (invoked with `nx`), a scipy window
    specification (name string, (name, param) tuple, or scalar), or a
    precomputed ndarray/list whose length must equal `nx` exactly.
    Raises a plain Exception on a length mismatch or an unknown spec.
    """
    if callable(window):
        return window(nx)

    elif isinstance(window, (str, tuple)) or np.isscalar(window):
        return scipy.signal.get_window(window, nx, fftbins=fftbins)

    elif isinstance(window, (np.ndarray, list)):
        if len(window) == nx:
            return np.asarray(window)

        raise Exception('Window size mismatch: '
                        '{:d} != {:d}'.format(len(window), nx))
    else:
        raise Exception('Invalid window specification: {}'.format(window))


def pad_center(data, size, axis=-1, **kwargs):
    """Pad `data` symmetrically along `axis` up to length `size`.

    Extra keyword arguments are forwarded to np.pad (mode defaults to
    'constant'). Raises if `size` is smaller than the current length.
    """
    kwargs.setdefault('mode', 'constant')

    n = data.shape[axis]

    lpad = int((size - n) // 2)

    lengths = [(0, 0)] * data.ndim
    lengths[axis] = (lpad, int(size - n - lpad))

    if lpad < 0:
        raise Exception(('Target size ({:d}) must be '
                         'at least input size ({:d})').format(size, n))

    return np.pad(data, lengths, **kwargs)


def frame(x, frame_length, hop_length, axis=-1):
    """Slice `x` into overlapping frames of `frame_length` samples taken
    every `hop_length` samples, returned as a zero-copy strided view with
    the frame index along `axis` (only axis -1 or 0 is supported).
    """
    if not isinstance(x, np.ndarray):
        raise Exception('Input must be of type numpy.ndarray, '
                        'given type(x)={}'.format(type(x)))

    if x.shape[axis] < frame_length:
        raise Exception('Input is too short (n={:d})'
                        ' for frame_length={:d}'.format(x.shape[axis], frame_length))

    if hop_length < 1:
        raise Exception('Invalid hop_length: {:d}'.format(hop_length))

    # as_strided below needs the framed axis laid out contiguously.
    if axis == -1 and not x.flags['F_CONTIGUOUS']:
        x = np.asfortranarray(x)
    elif axis == 0 and not x.flags['C_CONTIGUOUS']:
        x = np.ascontiguousarray(x)

    n_frames = 1 +
(x.shape[axis] - frame_length) // hop_length
    strides = np.asarray(x.strides)

    # Byte stride of one element along the framed axis; for a contiguous
    # 1-D input this reduces to x.itemsize.
    new_stride = np.prod(strides[strides > 0] // x.itemsize) * x.itemsize

    if axis == -1:
        shape = list(x.shape)[:-1] + [frame_length, n_frames]
        strides = list(strides) + [hop_length * new_stride]

    elif axis == 0:
        shape = [n_frames, frame_length] + list(x.shape)[1:]
        strides = [hop_length * new_stride] + list(strides)

    else:
        raise Exception('Frame axis={} must be either 0 or -1'.format(axis))

    # Zero-copy view: overlapping frames share memory with x.
    return as_strided(x, shape=shape, strides=strides)


def stft(y, n_fft=2048, hop_length=None, win_length=None, window='hann',
         center=True, dtype=np.complex64, pad_mode='reflect'):
    """Short-time Fourier transform of the signal `y`.

    Returns a complex matrix of shape (1 + n_fft // 2, n_frames); column t
    holds the spectrum of the frame starting at sample t * hop_length
    (frames are centered on those samples when `center` is True).
    """
    # By default, use the entire frame
    if win_length is None:
        win_length = n_fft

    # Set the default hop, if it's not already specified
    if hop_length is None:
        hop_length = int(win_length // 4)

    fft_window = get_window(window, win_length, fftbins=True)

    # Pad the window out to n_fft size
    fft_window = pad_center(fft_window, n_fft)

    # Reshape so that the window can be broadcast
    fft_window = fft_window.reshape((-1, 1))

    # Pad the time series so that frames are centered
    if center:
        y = np.pad(y, int(n_fft // 2), mode=pad_mode)

    # Window the time series.
    y_frames = frame(y, frame_length=n_fft, hop_length=hop_length)

    # Pre-allocate the STFT matrix
    stft_matrix = np.empty((int(1 + n_fft // 2), y_frames.shape[1]),
                           dtype=dtype,
                           order='F')

    # how many columns can we fit within MAX_MEM_BLOCK?
    # Process the STFT in blocks of columns so the temporary windowed-frame
    # array never exceeds MAX_MEM_BLOCK bytes.
    n_columns = int(MAX_MEM_BLOCK / (stft_matrix.shape[0] *
                                     stft_matrix.itemsize))

    for bl_s in range(0, stft_matrix.shape[1], n_columns):
        bl_t = min(bl_s + n_columns, stft_matrix.shape[1])

        # RFFT and Conjugate here to match phase from DPWE code
        stft_matrix[:, bl_s:bl_t] = fft.fft(fft_window *
                                            y_frames[:, bl_s:bl_t],
                                            axis=0)[:stft_matrix.shape[0]].conj()

    return stft_matrix


def magphase(d, power=1):
    """Split a complex spectrogram `d` into magnitude and phase.

    Returns (mag, phase) where mag = |d| ** power and phase is the
    unit-modulus complex phase, so d == mag * phase when power == 1.
    """
    mag = np.abs(d)
    mag **= power
    phase = np.exp(1.j * np.angle(d))

    return mag, phase
--------------------------------------------------------------------------------