├── .gitattributes ├── .gitignore ├── .vscode ├── c_cpp_properties.json ├── launch.json ├── settings.json └── tasks.json ├── CMakeLists.txt ├── LICENSE ├── README.md ├── bintoh5test.ipynb ├── dump_percepnet.py ├── requirements.txt ├── rnn_train.py ├── sampledata ├── noise │ └── noise.pcm └── speech │ └── speech.pcm ├── src ├── CMakeLists.txt ├── _kiss_fft_guts.h ├── arch ├── arch.h ├── celt_lpc.cpp ├── celt_lpc.h ├── common.h ├── denoise.cpp ├── erbband.h ├── kiss_fft.cpp ├── kiss_fft.h ├── main.cpp ├── nnet.cpp ├── nnet.h ├── nnet_data.h ├── opus_types.h ├── pitch.cpp ├── pitch.h ├── rnn.cpp ├── rnnoise.h ├── tansig_table.h └── vec.h ├── tests ├── CMakeLists.txt ├── Untitled.ipynb ├── main.cpp ├── moduletest.py ├── nnet_data_test.h └── testnnet.cpp └── utils ├── DNS_Challenge.yaml ├── __pycache__ └── filterbanks.cpython-35.pyc ├── bin2h5.py ├── filterbanks.py ├── parse_options.sh ├── path.sh ├── run.sh └── split_feature_dataset.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.wav filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # python cache 141 | __pycache__/ 142 | 143 | *.exe 144 | src/*.pcm 145 | *.wav 146 | 147 | #binary folder 148 | bin/ 149 | 150 | #dataset folder 151 | sampledata/ 152 | test_input.pcm 153 | test_output.pcm 154 | test.output 155 | training_set_sept12_500h 156 | sampledata_vctk_DEMAND 157 | training.h5 158 | 159 | DNS-Challenge 160 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/**" 7 | ], 8 | "defines": [], 9 | "compilerPath": "/usr/bin/gcc", 10 | "cStandard": "gnu17", 11 | "cppStandard": "gnu++14", 12 | "intelliSenseMode": "linux-gcc-x64" 13 | } 14 | ], 15 | "version": 4 16 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | }, 14 | { 15 | "name": "g++ - Build and debug active file", 16 | "type": "cppdbg", 17 | "request": "launch", 18 | "program": "${workspaceFolder}/src/main", 19 | "args": ["${fileDirname}/../sampledata/speech/speech.pcm", "${fileDirname}/../sampledata/noise/noise.pcm", "20000", "${fileDirname}/../test.output"], 20 | "stopAtEntry": false, 21 | "cwd": "${fileDirname}", 22 | "environment": [ { "name": "ENABLE_ASSERTIONS", "value": "1" }], 23 | "externalConsole": false, 24 | "MIMode": "gdb", 25 | "setupCommands": [ 26 | { 27 | "description": "Enable pretty-printing for gdb", 28 | "text": "-enable-pretty-printing", 29 | "ignoreFailures": true 30 | } 31 | ], 32 | "preLaunchTask": "C/C++: g++ build active file", 33 | "miDebuggerPath": "/usr/bin/gdb", 34 | 35 | }, 36 | { 37 | "name": "C/C++: g++ build active file(PercepNet_run)", 38 | "type": "cppdbg", 39 | "request": "launch", 40 | "program": "${workspaceFolder}/src/PercepNet_run", 41 | "args": ["${fileDirname}/../test_input.pcm", "${fileDirname}/../test_dnn_output.pcm"], 42 | "stopAtEntry": false, 43 | "cwd": "${fileDirname}", 44 | "environment": [ { "name": "ENABLE_ASSERTIONS", "value": "1" }], 45 | "externalConsole": false, 46 | "MIMode": "gdb", 47 | "setupCommands": [ 48 | { 49 | "description": "Enable pretty-printing for gdb", 50 | "text": "-enable-pretty-printing", 51 | "ignoreFailures": true 52 | } 53 | ], 54 | "preLaunchTask": "C/C++: g++ build active file(PercepNet_run)", 55 | "miDebuggerPath": "/usr/bin/gdb", 56 | 57 | } 58 | ] 59 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"python.pythonPath": "/usr/bin/python3", 3 | "files.associations": { 4 | "stdio.h": "c", 5 | "nnet_data.h": "c", 6 | "array": "cpp", 7 | "atomic": "cpp", 8 | "bit": "cpp", 9 | "*.tcc": "cpp", 10 | "cctype": "cpp", 11 | "clocale": "cpp", 12 | "cmath": "cpp", 13 | "cstdarg": "cpp", 14 | "cstddef": "cpp", 15 | "cstdint": "cpp", 16 | "cstdio": "cpp", 17 | "cstdlib": "cpp", 18 | "cstring": "cpp", 19 | "cwchar": "cpp", 20 | "cwctype": "cpp", 21 | "deque": "cpp", 22 | "unordered_map": "cpp", 23 | "vector": "cpp", 24 | "exception": "cpp", 25 | "algorithm": "cpp", 26 | "functional": "cpp", 27 | "iterator": "cpp", 28 | "memory": "cpp", 29 | "memory_resource": "cpp", 30 | "numeric": "cpp", 31 | "optional": "cpp", 32 | "random": "cpp", 33 | "string": "cpp", 34 | "string_view": "cpp", 35 | "system_error": "cpp", 36 | "tuple": "cpp", 37 | "type_traits": "cpp", 38 | "utility": "cpp", 39 | "fstream": "cpp", 40 | "initializer_list": "cpp", 41 | "iosfwd": "cpp", 42 | "iostream": "cpp", 43 | "istream": "cpp", 44 | "limits": "cpp", 45 | "new": "cpp", 46 | "ostream": "cpp", 47 | "sstream": "cpp", 48 | "stdexcept": "cpp", 49 | "streambuf": "cpp", 50 | "cinttypes": "cpp", 51 | "typeinfo": "cpp", 52 | "any": "cpp", 53 | "chrono": "cpp", 54 | "ctime": "cpp", 55 | "forward_list": "cpp", 56 | "list": "cpp", 57 | "map": "cpp", 58 | "set": "cpp", 59 | "unordered_set": "cpp", 60 | "ratio": "cpp", 61 | "iomanip": "cpp", 62 | "variant": "cpp" 63 | } 64 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": [ 3 | { 4 | "type": "cppbuild", 5 | "label": "C/C++: g++ build active file", 6 | "command": "/usr/bin/g++", 7 | "args": [ 8 | "-g", 9 | "${workspaceFolder}/src/**.cpp", 10 | "${workspaceFolder}/src/**.h", 11 | "-o", 12 | "${workspaceFolder}/src/main", 13 | ], 14 | "options": { 15 | "cwd": "${fileDirname}" 16 | }, 17 | "problemMatcher": [ 18 | 
"$gcc" 19 | ], 20 | "group": { 21 | "kind": "build", 22 | "isDefault": true 23 | }, 24 | "detail": "Task generated by Debugger." 25 | }, 26 | { 27 | "type": "cppbuild", 28 | "label": "C/C++: g++ build active file(PercepNet_run)", 29 | "command": "/usr/bin/g++", 30 | "args": [ 31 | "-g", 32 | "${workspaceFolder}/src/**.cpp", 33 | "${workspaceFolder}/src/**.h", 34 | "-o", 35 | "${workspaceFolder}/src/PercepNet_run", 36 | "-DTRAINING=0", 37 | ], 38 | "options": { 39 | "cwd": "${fileDirname}" 40 | }, 41 | "problemMatcher": [ 42 | "$gcc" 43 | ], 44 | "group": { 45 | "kind": "build", 46 | "isDefault": true 47 | }, 48 | "detail": "Task generated by Debugger." 49 | } 50 | ], 51 | "version": "2.0.0" 52 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | project(percepNet) 3 | 4 | # GoogleTest requires at least C++11 5 | set(CMAKE_CXX_STANDARD 11) 6 | set(CMAKE_CXX_FLAGS "-O3 -Wall -Wextra") 7 | 8 | include(FetchContent) 9 | FetchContent_Declare( 10 | googletest 11 | URL https://github.com/google/googletest/archive/609281088cfefc76f9d0ce82e1ff6c30cc3591e5.zip 12 | ) 13 | # For Windows: Prevent overriding the parent project's compiler/linker settings 14 | set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) 15 | FetchContent_MakeAvailable(googletest) 16 | 17 | enable_testing() 18 | 19 | include_directories(src) 20 | 21 | include(GoogleTest) 22 | add_subdirectory(src) 23 | add_subdirectory(tests) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Mozilla 2 | Copyright (c) 2007-2017, Jean-Marc Valin 3 | Copyright (c) 2005-2017, Xiph.Org Foundation 4 | Copyright (c) 2003-2004, Mark Borgerding 5 | 6 | Redistribution and use in source and binary forms, with or 
without 7 | modification, are permitted provided that the following conditions 8 | are met: 9 | 10 | - Redistributions of source code must retain the above copyright 11 | notice, this list of conditions and the following disclaimer. 12 | 13 | - Redistributions in binary form must reproduce the above copyright 14 | notice, this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | - Neither the name of the Xiph.Org Foundation nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION 25 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PercepNet 2 | Unofficial implementation of PercepNet: A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech described in https://arxiv.org/abs/2008.04259 3 | 4 | https://www.researchgate.net/publication/343568932_A_Perceptually-Motivated_Approach_for_Low-Complexity_Real-Time_Enhancement_of_Fullband_Speech 5 | 6 | ## Todo 7 | 8 | - [X] pitch estimation 9 | - [X] Comb filter 10 | - [X] ERBBand c++ implementation 11 | - [X] Feature(r,g,pitch,corr) Generator(c++) for pytorch 12 | - [X] DNNModel pytorch 13 | - [X] DNNModel c++ implementation 14 | - [ ] Pretrained model 15 | - [X] Postfiltering (done by [@TeaPoly](https://github.com/TeaPoly ) ) 16 | 17 | 18 | ## Requirements 19 | - CMake 20 | - Sox 21 | - Python>=3.6 22 | - Pytorch 23 | 24 | ## Prepare sampledata 25 | 1. download and sythesize data DNS-Challenge 2020 Dataset before excute utils/run.sh for training. 26 | ```shell 27 | git clone -b interspeech2020/master https://github.com/microsoft/DNS-Challenge.git 28 | ``` 29 | 2. Follow the Usage instruction in DNS Challenge repo(https://github.com/microsoft/DNS-Challenge) at interspeech2020/master branch. please modify save directories at DNS-Challenge/noisyspeech_synthesizer.cfg sampledata/speech and sampledata/noise each. 30 | 31 | ## Build & Training 32 | This repository is tested on Ubuntu 20.04(WSL2) 33 | 34 | 1. setup CMake build environments 35 | ``` 36 | sudo apt-get install cmake 37 | ``` 38 | 2. make binary directory & build 39 | ``` 40 | mkdir bin && cd bin 41 | cmake .. 42 | make -j 43 | cd .. 44 | ``` 45 | 46 | 3. feature generation for training with sampleData 47 | ``` 48 | bin/src/percepNet sampledata/speech/speech.pcm sampledata/noise/noise.pcm 4000 test.output 49 | ``` 50 | 51 | 4. 
Convert output binary to h5 52 | ``` 53 | python3 utils/bin2h5.py test.output training.h5 54 | ``` 55 | 56 | 5. Training 57 | run utils/run.sh 58 | ```shell 59 | cd utils 60 | ./run.sh 61 | ``` 62 | 63 | 6. Dump weight from pytorch to c++ header 64 | ``` 65 | python3 dump_percepnet.py model.pt 66 | ``` 67 | 68 | 7. Inference 69 | ``` 70 | cd bin 71 | cmake .. 72 | make -j1 73 | cd .. 74 | bin/src/percepNet_run test_input.pcm percepnet_output.pcm 75 | ``` 76 | 77 | 78 | 79 | ## Acknowledgements 80 | [@jasdasdf]( https://github.com/jasdasdf ), [@sTarAnna]( https://github.com/sTarAnna ), [@cookcodes]( https://github.com/cookcodes ), [@xyx361100238]( https://github.com/xyx361100238 ), [@zhangyutf]( https://github.com/zhangyutf ), [@TeaPoly](https://github.com/TeaPoly ), [@rameshkunasi]( https://github.com/rameshkunasi ), [@OscarLiau]( https://github.com/OscarLiau ), [@YangangCao]( https://github.com/YangangCao ), [Jaeyoung Yang]( https://www.linkedin.com/in/jaeyoung-yang-354b21146 ) 81 | 82 | [IIP Lab. 
Sogang Univ]( http://iip.sogang.ac.kr/) 83 | 84 | 85 | 86 | ## Reference 87 | https://github.com/wil-j-wil/py_bank 88 | 89 | https://github.com/dgaspari/pyrapt 90 | 91 | https://github.com/xiph/rnnoise 92 | 93 | https://github.com/mozilla/LPCNet 94 | -------------------------------------------------------------------------------- /bintoh5test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "96bf793d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 86, 16 | "id": "b865244f", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "torch.Size([20, 8, 128])\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "fc = torch.nn.Linear(70, 128)\n", 29 | "m = torch.nn.Conv1d(128, 512, 5, stride=1)\n", 30 | "m2 = torch.nn.Conv1d(512, 512, 3, stride=1)\n", 31 | "rnn = torch.nn.GRU(512, 512, 3, batch_first=True)\n", 32 | "input = torch.randn(20, 8, 70) # B, T, D\n", 33 | "output = fc(input)\n", 34 | "print(output.shape)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 87, 40 | "id": "8697ece9", 41 | "metadata": {}, 42 | "outputs": [ 43 | { 44 | "name": "stdout", 45 | "output_type": "stream", 46 | "text": [ 47 | "torch.Size([20, 512, 2])\n", 48 | "torch.Size([20, 2, 512])\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "output=output.permute([0,2,1]) # B, D, T \n", 54 | "output = m(output)\n", 55 | "output = m2(output)\n", 56 | "print(output.shape)\n", 57 | "output=output.permute([0,2,1]) # B, T, D\n", 58 | "output= rnn(output)\n", 59 | "\n", 60 | "\n", 61 | "print(output[0].shape)\n", 62 | "input = torch.randn(5, 3, 10)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 15, 68 | "id": "361571d9", 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | 
"name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "Loading data...\n", 76 | "done.\n", 77 | "2 sequences\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "import numpy as np\n", 83 | "import h5py\n", 84 | "import sys\n", 85 | "\n", 86 | "\n", 87 | "input_and_output_dim=138\n", 88 | "bin_file_name=\"test.output\"\n", 89 | "data = np.fromfile(bin_file_name, dtype='float32')\n", 90 | "data = np.reshape(data, (len(data)//input_and_output_dim, input_and_output_dim))\n", 91 | "h5f = h5py.File(\"training.h5\", 'w');\n", 92 | "h5f.create_dataset('data', data=data)\n", 93 | "h5f.close()\n", 94 | "\n", 95 | "\n", 96 | "print('Loading data...')\n", 97 | "with h5py.File('training.h5', 'r') as hf:\n", 98 | " all_data = hf['data'][:]\n", 99 | "print('done.')\n", 100 | "\n", 101 | "window_size = 2000\n", 102 | "\n", 103 | "nb_sequences = len(all_data)//window_size\n", 104 | "print(nb_sequences, ' sequences')\n", 105 | "x_train = all_data[:nb_sequences*window_size, :70]\n", 106 | "x_train = np.reshape(x_train, (nb_sequences, window_size, 70))\n", 107 | "\n", 108 | "y_train = np.copy(all_data[:nb_sequences*window_size, 70:])\n", 109 | "y_train = np.reshape(y_train, (nb_sequences, window_size, 68))\n", 110 | "\n", 111 | "all_data = 0\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 34, 117 | "id": "19d76b37", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "data": { 122 | "text/plain": [ 123 | "(2, 2000, 70)" 124 | ] 125 | }, 126 | "execution_count": 34, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "x_train.shape" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 41, 138 | "id": "c78d9775", 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/plain": [ 144 | "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 145 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 146 | " 
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 147 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 148 | " 1., 0.], dtype=float32)" 149 | ] 150 | }, 151 | "execution_count": 41, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "x_train[0,1]" 158 | ] 159 | } 160 | ], 161 | "metadata": { 162 | "kernelspec": { 163 | "display_name": "Python 3", 164 | "language": "python", 165 | "name": "python3" 166 | }, 167 | "language_info": { 168 | "codemirror_mode": { 169 | "name": "ipython", 170 | "version": 3 171 | }, 172 | "file_extension": ".py", 173 | "mimetype": "text/x-python", 174 | "name": "python", 175 | "nbconvert_exporter": "python", 176 | "pygments_lexer": "ipython3", 177 | "version": "3.8.5" 178 | } 179 | }, 180 | "nbformat": 4, 181 | "nbformat_minor": 5 182 | } 183 | -------------------------------------------------------------------------------- /dump_percepnet.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/python3 3 | '''Copyright (c) 2017-2018 Mozilla 4 | 2020-2021 Seonghun Noh 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | - Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 14 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 15 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 16 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE FOUNDATION OR 17 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 18 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 19 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 20 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 21 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 22 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | ''' 25 | 26 | import torch 27 | import sys 28 | import rnn_train 29 | from torch.nn import Sequential, GRU, Conv1d, Linear 30 | import numpy as np 31 | 32 | def printVector(f, vector, name, dtype='float'): 33 | #torch.transpose(vector, 0, 1) 34 | v = np.reshape(vector.detach().numpy(), (-1)) 35 | #print('static const float ', name, '[', len(v), '] = \n', file=f) 36 | f.write('static const {} {}[{}] = {{\n '.format(dtype, name, len(v))) 37 | for i in range(0, len(v)): 38 | f.write('{}'.format(v[i])) 39 | if (i!=len(v)-1): 40 | f.write(',') 41 | else: 42 | break 43 | if (i%8==7): 44 | f.write("\n ") 45 | else: 46 | f.write(" ") 47 | #print(v, file=f) 48 | f.write('\n};\n\n') 49 | return 50 | 51 | def dump_sequential_module(self, f, name): 52 | activation = self[1].__class__.__name__.upper() 53 | self[0].dump_data(f,name,activation) 54 | Sequential.dump_data = dump_sequential_module 55 | 56 | def dump_linear_module(self, f, name, activation): 57 | print("printing layer " + name) 58 | weight = self.weight 59 | bias = self.bias 60 | #print("weight:", weight) 61 | #activation = self[1].__class__.__name__.upper() 62 | printVector(f, torch.transpose(weight, 0, 1), name + '_weights') 63 | printVector(f, bias, name + '_bias') 64 | f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' 65 | .format(name, name, name, weight.shape[1], weight.shape[0], activation)) 66 | Linear.dump_data = 
dump_linear_module 67 | 68 | def convert_gru_input_kernel(kernel): 69 | kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3) 70 | kernels = [kernel_z, kernel_r, kernel_h] 71 | return torch.tensor(np.hstack([k.T for k in kernels])) 72 | 73 | def convert_gru_recurrent_kernel(kernel): 74 | kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3) 75 | kernels = [kernel_z, kernel_r, kernel_h] 76 | return torch.tensor(np.hstack([k.T for k in kernels])) 77 | 78 | def convert_bias(bias): 79 | bias = bias.reshape(2, 3, -1) 80 | return torch.tensor(bias[:, [1, 0, 2], :].reshape(-1)) 81 | 82 | def dump_gru_module(self, f, name): 83 | print("printing layer " + name ) 84 | weights = convert_gru_input_kernel(self.weight_ih_l0.detach().numpy()) 85 | recurrent_weights = convert_gru_recurrent_kernel(self.weight_hh_l0.detach().numpy()) 86 | bias = torch.cat((self.bias_ih_l0, self.bias_hh_l0)) 87 | bias = convert_bias(bias.detach().numpy()) 88 | printVector(f, weights, name + '_weights') 89 | printVector(f, recurrent_weights, name + '_recurrent_weights') 90 | printVector(f, bias, name + '_bias') 91 | if hasattr(self, 'activation'): 92 | activation = self.activation.__name__.upper() 93 | else: 94 | activation = 'TANH' 95 | if hasattr(self, 'reset_after') and not self.reset_after: 96 | reset_after = 0 97 | else: 98 | reset_after = 1 99 | neurons = weights.shape[0]//3 100 | #max_rnn_neurons = max(max_rnn_neurons, neurons) 101 | print('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}, {}\n}};\n\n' 102 | .format(name, name, name, name, weights.shape[0], weights.shape[1]//3, activation, reset_after)) 103 | f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}, {}\n}};\n\n' 104 | .format(name, name, name, name, weights.shape[0], weights.shape[1]//3, activation, reset_after)) 105 | #hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights[0].shape[1]//3)) 106 | #hf.write('#define {}_STATE_SIZE 
{}\n'.format(name.upper(), weights[0].shape[1]//3)) 107 | #hf.write('extern const GRULayer {};\n\n'.format(name)) 108 | GRU.dump_data = dump_gru_module 109 | 110 | def dump_conv1d_module(self, f, name, activation): 111 | print("printing layer " + name ) 112 | weights = self.weight 113 | printVector(f, weights.permute(2,1,0), name + '_weights') 114 | printVector(f, self.bias, name + '_bias') 115 | #activation = self.activation.__name__.upper() 116 | #max_conv_inputs = max(max_conv_inputs, weights[0].shape[1]*weights[0].shape[0]) 117 | #warn! activation hard codedW 118 | print('const Conv1DLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, {}, ACTIVATION_{}\n}};\n\n' 119 | .format(name, name, name, weights.shape[1], weights.shape[2], weights.shape[0], activation)) 120 | f.write('const Conv1DLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, {}, ACTIVATION_{}\n}};\n\n' 121 | .format(name, name, name, weights.shape[1], weights.shape[2], weights.shape[0], activation)) 122 | #hf.write('#define {}_OUT_SIZE {}\n'.format(name.upper(), weights[0].shape[2])) 123 | #hf.write('#define {}_STATE_SIZE ({}*{})\n'.format(name.upper(), weights[0].shape[1], (weights[0].shape[0]-1))) 124 | #hf.write('#define {}_DELAY {}\n'.format(name.upper(), (weights[0].shape[0]-1)//2)) 125 | #hf.write('extern const Conv1DLayer {};\n\n'.format(name)); 126 | Conv1d.dump_data = dump_conv1d_module 127 | 128 | if __name__ == '__main__': 129 | model = rnn_train.PercepNet() 130 | #model = ( 131 | model.load_state_dict(torch.load(sys.argv[1], map_location="cpu")) 132 | 133 | if len(sys.argv) > 2: 134 | cfile = sys.argv[2] 135 | #hfile = sys.argv[3]; 136 | else: 137 | cfile = 'src/nnet_data.cpp' 138 | #hfile = 'nnet_data.h' 139 | 140 | f = open(cfile, 'w') 141 | #hf = open(hfile, 'w') 142 | 143 | f.write('/*This file is automatically generated from a Pytorch model*/\n\n') 144 | f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n#include "nnet_data.h"\n\n') 145 | 146 | for name, 
module in model.named_children(): 147 | module.dump_data(f, name) 148 | 149 | f.write('extern const RNNModel percepnet_model_orig = {\n') 150 | for name, module in model.named_children(): 151 | f.write(' &{},\n'.format(name)) 152 | f.write('};\n') 153 | 154 | f.close() 155 | print("done") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa 2 | numpy 3 | torch 4 | tensorboardX 5 | matplotlib 6 | h5py 7 | yaml 8 | tqdm 9 | glob -------------------------------------------------------------------------------- /rnn_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin python3 2 | 3 | # Copyright 2021 Seonghun Noh 4 | 5 | import argparse 6 | import logging 7 | import os 8 | import io 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader, Dataset 13 | import numpy as np 14 | import h5py 15 | import argparse 16 | from tensorboardX import SummaryWriter 17 | import matplotlib.pyplot as plt 18 | import torch.nn.utils.rnn as rnn_utils 19 | 20 | from collections import defaultdict 21 | import yaml 22 | import glob 23 | from tqdm import tqdm 24 | import PIL.Image 25 | from torchvision.transforms import ToTensor 26 | 27 | plt.switch_backend('agg') 28 | class CppRawListDataset(Dataset): 29 | def __init__(self, filelist_path, train_length_size=500): 30 | self.train_length_size = train_length_size 31 | self.filelist_path = filelist_path 32 | self.x_dim = 70 33 | self.y_dim = 68 34 | 35 | with open(filelist_path, "r") as f: 36 | self.filelist = [filepath.rstrip('\n') for filepath in f.readlines()] 37 | self.nb_sequences = len(self.filelist) 38 | 39 | print(self.nb_sequences, ' sequences') 40 | 41 | def __len__(self): 42 | return self.nb_sequences 43 | 44 | def __getitem__(self, index): 45 | with open(self.filelist[index], 'rb') as cpp_out: 46 | all_data = 
np.fromfile(cpp_out, np.float32) 47 | all_data = np.reshape(all_data, (self.train_length_size,138)) 48 | #make it band energy 30 times bigger for compansating low energy 49 | all_data[:,:68] = all_data[:,:68]*30 50 | 51 | x = all_data[:,:self.x_dim] 52 | y = all_data[:,self.x_dim:] 53 | return (x,y) 54 | 55 | class h5DirDataset(Dataset): 56 | def __init__(self, h5_dir_path, train_length_size=500): 57 | self.train_length_size = train_length_size 58 | self.h5_dir_path = h5_dir_path 59 | self.x_dim = 70 60 | self.y_dim = 68 61 | 62 | self.h5_filelist = glob.glob(os.path.join(h5_dir_path, "*.h5")) 63 | self.nb_sequences = len(self.h5_filelist) 64 | 65 | print(self.nb_sequences, ' sequences') 66 | 67 | def __len__(self): 68 | return self.nb_sequences 69 | 70 | def __getitem__(self, index): 71 | with h5py.File(self.h5_filelist[index], 'r') as hf: 72 | all_data = hf['data'][:] 73 | x = all_data[:,:self.x_dim] 74 | y = all_data[:,self.x_dim:] 75 | return (x,y) 76 | 77 | class h5Dataset(Dataset): 78 | 79 | def __init__(self, h5_filename="training.h5", window_size=500): 80 | self.window_size = window_size 81 | self.h5_filename = h5_filename 82 | self.x_dim = 70 83 | self.y_dim = 68 84 | 85 | #read h5file 86 | with h5py.File(self.h5_filename, 'r') as hf: 87 | all_data = hf['data'][:] 88 | 89 | self.nb_sequences = len(all_data)//window_size 90 | print(self.nb_sequences, ' sequences') 91 | x_train = all_data[:self.nb_sequences*self.window_size, :self.x_dim] 92 | self.x_train = np.reshape(x_train, (self.nb_sequences, self.window_size, self.x_dim)) 93 | #pad 3 for each batch .. 
not sure it's right 94 | #self.x_train = np.pad(self.x_train,[(0,0),(3,3),(0,0)],'constant') 95 | 96 | y_train = np.copy(all_data[:self.nb_sequences*self.window_size, self.x_dim:self.x_dim+self.y_dim]) 97 | self.y_train = np.reshape(y_train, (self.nb_sequences, self.window_size, self.y_dim)) 98 | 99 | def __len__(self): 100 | return self.nb_sequences 101 | 102 | def __getitem__(self, index): 103 | return (self.x_train[index], self.y_train[index]) 104 | 105 | class PercepNet(nn.Module): 106 | def __init__(self, input_dim=70): 107 | super(PercepNet, self).__init__() 108 | #self.hidden_dim = hidden_dim 109 | #self.n_layers = n_layers 110 | 111 | self.fc = nn.Sequential(nn.Linear(input_dim, 128), nn.ReLU()) 112 | self.conv1 = nn.Sequential(nn.Conv1d(128, 512, 5, stride=1, padding=4), nn.ReLU())#padding for align with c++ dnn 113 | self.conv2 = nn.Sequential(nn.Conv1d(512, 512, 3, stride=1, padding=2), nn.Tanh()) 114 | #self.gru = nn.GRU(512, 512, 3, batch_first=True) 115 | self.gru1 = nn.GRU(512, 512, 1, batch_first=True) 116 | self.gru2 = nn.GRU(512, 512, 1, batch_first=True) 117 | self.gru3 = nn.GRU(512, 512, 1, batch_first=True) 118 | self.gru_gb = nn.GRU(512, 512, 1, batch_first=True) 119 | self.gru_rb = nn.GRU(1024, 128, 1, batch_first=True) 120 | self.fc_gb = nn.Sequential(nn.Linear(512*5, 34), nn.Sigmoid()) 121 | self.fc_rb = nn.Sequential(nn.Linear(128, 34), nn.Sigmoid()) 122 | 123 | def forward(self, x): 124 | x = self.fc(x) 125 | x = x.permute([0,2,1]) # B, D, T 126 | x = self.conv1(x) 127 | x = x[:,:,:-4] 128 | convout = self.conv2(x) 129 | convout = convout[:,:,:-2]#align with c++ dnn 130 | convout = convout.permute([0,2,1]) # B, T, D 131 | 132 | gru1_out, gru1_state = self.gru1(convout) 133 | gru2_out, gru2_state = self.gru2(gru1_out) 134 | gru3_out, gru3_state = self.gru3(gru2_out) 135 | gru_gb_out, gru_gb_state = self.gru_gb(gru3_out) 136 | concat_gb_layer = torch.cat((convout,gru1_out,gru2_out,gru3_out,gru_gb_out),-1) 137 | gb = 
self.fc_gb(concat_gb_layer) 138 | 139 | #concat rb need fix 140 | concat_rb_layer = torch.cat((gru3_out,convout),-1) 141 | rnn_rb_out, gru_rb_state = self.gru_rb(concat_rb_layer) 142 | rb = self.fc_rb(rnn_rb_out) 143 | 144 | output = torch.cat((gb,rb),-1) 145 | return output 146 | 147 | def test(): 148 | model = PercepNet() 149 | x = torch.randn(20, 8, 70) 150 | out = model(x) 151 | print(out.shape) 152 | 153 | class CustomLoss(nn.Module): 154 | def __init__(self, weight=None, size_average=True): 155 | super(CustomLoss, self).__init__() 156 | 157 | def forward(self, outputs, targets): 158 | gamma = 0.5 159 | C4 = 10 160 | epsi = 1e-10 161 | gb_hat = outputs[:,:,:34] 162 | rb_hat = outputs[:,:,34:68] 163 | gb = targets[:,:,:34] 164 | rb = targets[:,:,34:68] 165 | 166 | ''' 167 | total_loss=0 168 | for i in range(500): 169 | total_loss += (torch.sum(torch.pow((torch.pow(gb[:,i,:],gamma) - torch.pow(gb_hat[:,i,:],gamma)),2))) \ 170 | + C4*torch.sum(torch.pow(torch.pow(gb[:,i,:],gamma) - torch.pow(gb_hat[:,i,:],gamma),4)) \ 171 | + torch.sum(torch.pow(torch.pow((1-rb[:,i,:]),gamma)-torch.pow((1-rb_hat[:,i,:]),gamma),2)) 172 | return total_loss 173 | ''' 174 | return (torch.mean(torch.pow((torch.pow(gb,gamma) - torch.pow(gb_hat,gamma)),2))) \ 175 | + C4*torch.mean(torch.pow(torch.pow(gb,gamma) - torch.pow(gb_hat,gamma),4)) \ 176 | + torch.mean(torch.pow(torch.pow((1-rb),gamma)-torch.pow((1-rb_hat),gamma),2)) 177 | 178 | 179 | 180 | def train(): 181 | parser = argparse.ArgumentParser() 182 | writer = SummaryWriter() 183 | 184 | UseCustomLoss = True 185 | dataset = h5Dataset("training.h5") 186 | trainset_ratio = 1 # 1 - validation set ration 187 | train_size = int(trainset_ratio * len(dataset)) 188 | test_size = len(dataset) - train_size 189 | batch_size=10 190 | train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size]) 191 | 192 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 193 | 
#validation_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) 194 | 195 | model = PercepNet() 196 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 197 | if UseCustomLoss: 198 | #CustomLoss cause Nan error need fix 199 | criterion = CustomLoss() 200 | else: 201 | criterion = nn.MSELoss() 202 | num_epochs = 10000 203 | for epoch in range(num_epochs): # loop over the dataset multiple times 204 | 205 | running_loss = 0.0 206 | for i, data in enumerate(train_loader, 0): 207 | # get the inputs; data is a list of [inputs, labels] 208 | inputs, targets = data 209 | 210 | # zero the parameter gradients 211 | optimizer.zero_grad() 212 | 213 | # forward + backward + optimize 214 | outputs = model(inputs) 215 | #outputs = torch.cat(outputs,-1) 216 | loss = criterion(outputs, targets) 217 | loss.backward() 218 | optimizer.step() 219 | 220 | # print statistics 221 | running_loss += loss.item() 222 | 223 | # for testing 224 | print('[%d, %5d] loss: %.3f' % 225 | (epoch + 1, i + 1, loss.item())) 226 | 227 | if i % 2000 == 1999: # print every 2000 mini-batches 228 | print('[%d, %5d] loss: %.3f' % 229 | (epoch + 1, i + 1, running_loss / 2000)) 230 | running_loss = 0.0 231 | 232 | model.eval() 233 | tmp_output = model(torch.tensor(dataset[0][0]).unsqueeze(0)) 234 | model.train() 235 | fig = plt.figure() 236 | plt.plot(tmp_output[0].squeeze(0).T.detach().numpy()) 237 | writer.add_figure('output gb', fig, global_step=epoch) 238 | fig = plt.figure() 239 | plt.plot(dataset[0][1][:,:].T) 240 | writer.add_figure('target gb', fig, global_step=epoch) 241 | writer.add_scalar('loss', loss.item(), global_step=epoch) 242 | print('Finished Training') 243 | print('save model') 244 | writer.close() 245 | torch.save(model.state_dict(), 'model.pt') 246 | 247 | def gen_plot(y, y_hat): 248 | # Create a figure to contain the plot. 249 | plt.figure(figsize=(10,5)) 250 | 251 | # Start next subplot. 
252 | plt.subplot(1, 2, 1) 253 | plt.imshow(y_hat.T,interpolation='none',cmap=plt.cm.jet,origin='lower',aspect='auto') 254 | plt.subplot(1, 2, 2) 255 | plt.imshow(y.T,interpolation='none',cmap=plt.cm.jet,origin='lower',aspect='auto') 256 | buf = io.BytesIO() 257 | plt.savefig(buf, format='png') 258 | buf.seek(0) 259 | return buf 260 | 261 | class Trainer(object): 262 | """Customized trainer module for PercepNet training.""" 263 | 264 | def __init__( 265 | self, 266 | steps, 267 | epochs, 268 | data_loader, 269 | sampler, 270 | model, 271 | criterion, 272 | optimizer, 273 | args, 274 | config, 275 | device=torch.device("cpu"), 276 | ): 277 | """Initialize trainer. 278 | Args: 279 | steps (int): Initial global steps. 280 | epochs (int): Initial global epochs. 281 | data_loader (dict): Dict of data loaders. It must contrain "train" and "dev" loaders. 282 | model (nn.Module): Model. Instance of nn.Module 283 | criterion (nn.Module): criterions. 284 | optimizer (torch.optim): optimizers. 285 | args (parser.parse_args()): Instance of argparse parse_args() 286 | device (torch.deive): Pytorch device instance. 
287 | """ 288 | self.steps = steps 289 | self.epochs = epochs 290 | self.data_loader = data_loader 291 | self.sampler = sampler 292 | self.model = model 293 | self.criterion = criterion 294 | self.args = args 295 | self.optimizer = optimizer 296 | self.device = device 297 | self.config = config 298 | 299 | self.writer = SummaryWriter(config["out_dir"]) 300 | self.finish_train = False 301 | self.total_train_loss = defaultdict(float) 302 | self.total_eval_loss = defaultdict(float) 303 | 304 | def run(self): 305 | """Run training.""" 306 | self.tqdm = tqdm( 307 | initial=self.steps, total=self.args.train_max_steps, desc="[train]" 308 | ) 309 | while True: 310 | # train one epoch 311 | self._train_epoch() 312 | 313 | # check whether training is finished 314 | if self.finish_train: 315 | break 316 | 317 | self.tqdm.close() 318 | logging.info("Finished training.") 319 | 320 | def save_checkpoint(self, checkpoint_path): 321 | """Save checkpoint.""" 322 | torch.save(self.model.state_dict(), checkpoint_path) 323 | 324 | def load_checkpoint(self, checkpoint_path): 325 | """Load checkpoint. 326 | Args: 327 | checkpoint_path (str): Checkpoint path to be loaded. 
328 | """ 329 | state_dict = torch.load(checkpoint_path, map_location="cpu") 330 | if self.args.distributed: 331 | self.model.module.load_state_dict(state_dict) 332 | else: 333 | self.model.load_state_dict(state_dict) 334 | 335 | def _train_step(self, batch): 336 | """Train model one step.""" 337 | # get the inputs; data is a list of [inputs, labels] 338 | inputs, targets = batch 339 | inputs = inputs.to(self.device) 340 | targets = targets.to(self.device) 341 | # zero the parameter gradients 342 | self.optimizer.zero_grad() 343 | 344 | # forward + backward + optimize 345 | outputs = self.model(inputs) 346 | #outputs = torch.cat(outputs,-1) 347 | loss = self.criterion(outputs, targets) 348 | loss.backward() 349 | self.optimizer.step() 350 | 351 | self.total_train_loss["train/total_loss"] += loss.item() 352 | # update counts 353 | self.steps += 1 354 | self.tqdm.update(1) 355 | self._check_train_finish() 356 | 357 | def _train_epoch(self): 358 | """Train model one epoch.""" 359 | for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1): 360 | # train one step 361 | self._train_step(batch) 362 | 363 | # check interval 364 | if self.args.rank == 0: 365 | self._check_log_interval() 366 | self._check_eval_interval() 367 | self._check_save_interval() 368 | 369 | # check whether training is finished 370 | if self.finish_train: 371 | return 372 | 373 | # update 374 | self.epochs += 1 375 | self.train_steps_per_epoch = train_steps_per_epoch 376 | logging.info( 377 | f"(Steps: {self.steps}) Finished {self.epochs} epoch training " 378 | f"({self.train_steps_per_epoch} steps per epoch)." 
379 | ) 380 | 381 | @torch.no_grad() 382 | def _eval_step(self, batch): 383 | """Evaluate model one step.""" 384 | # parse batch 385 | inputs, targets = batch 386 | inputs = inputs.to(self.device) 387 | targets = targets.to(self.device) 388 | # forward + backward + optimize 389 | outputs = self.model(inputs) 390 | 391 | loss = self.criterion(outputs, targets) 392 | self.total_eval_loss["eval/total_loss"] += loss.item() 393 | 394 | def _eval_epoch(self): 395 | """Evaluate model one epoch.""" 396 | logging.info(f"(Steps: {self.steps}) Start evaluation.") 397 | # change mode 398 | self.model.eval() 399 | 400 | # calculate loss for each batch 401 | for eval_steps_per_epoch, batch in enumerate( 402 | tqdm(self.data_loader["dev"], desc="[eval]"), 1 403 | ): 404 | # eval one step 405 | self._eval_step(batch) 406 | 407 | # save intermediate result 408 | if eval_steps_per_epoch == 1: 409 | self._genearete_and_save_intermediate_result(batch) 410 | 411 | logging.info( 412 | f"(Steps: {self.steps}) Finished evaluation " 413 | f"({eval_steps_per_epoch} steps per epoch)." 414 | ) 415 | # average loss 416 | for key in self.total_eval_loss.keys(): 417 | self.total_eval_loss[key] /= eval_steps_per_epoch 418 | logging.info( 419 | f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}." 
420 | ) 421 | 422 | # record 423 | self._write_to_tensorboard(self.total_eval_loss) 424 | 425 | # reset 426 | self.total_eval_loss = defaultdict(float) 427 | 428 | # restore mode 429 | self.model.train() 430 | 431 | @torch.no_grad() 432 | def _genearete_and_save_intermediate_result(self, batch): 433 | """Generate and save intermediate result.""" 434 | # delayed import to avoid error related backend error 435 | import matplotlib.pyplot as plt 436 | 437 | # generate 438 | x_batch, y_batch = batch 439 | x_batch = x_batch.to(self.device) 440 | y_batch = y_batch.to(self.device) 441 | y_batch_ = self.model(x_batch) 442 | 443 | 444 | # check directory 445 | #dirname = os.path.join(self.config["outdir"], f"predictions/{self.steps}steps") 446 | #if not os.path.exists(dirname): 447 | # os.makedirs(dirname) 448 | 449 | for idx, (y, y_) in enumerate(zip(y_batch, y_batch_), 1): 450 | if idx==1: 451 | # convert to ndarray 452 | y, y_ = y.cpu().numpy(), y_.cpu().numpy() 453 | plot_buf=gen_plot(y, y_) 454 | image = PIL.Image.open(plot_buf) 455 | image = ToTensor()(image) 456 | self.writer.add_image('rb,gb hat', image, self.steps) 457 | print("writeimage") 458 | 459 | def _write_to_tensorboard(self, loss): 460 | """Write to tensorboard.""" 461 | for key, value in loss.items(): 462 | self.writer.add_scalar(key, value, self.steps) 463 | 464 | def _check_save_interval(self): 465 | if self.steps % self.config["save_interval_steps"] == 0: 466 | self.save_checkpoint( 467 | os.path.join(self.config["out_dir"], f"checkpoint-{self.steps}steps.pkl") 468 | ) 469 | logging.info(f"Successfully saved checkpoint @ {self.steps} steps.") 470 | 471 | def _check_eval_interval(self): 472 | if self.steps % self.config["eval_interval_steps"] == 0: 473 | self._eval_epoch() 474 | 475 | def _check_log_interval(self): 476 | if self.steps % self.config["log_interval_steps"] == 0: 477 | for key in self.total_train_loss.keys(): 478 | self.total_train_loss[key] /= self.config["log_interval_steps"] 479 | 
logging.info( 480 | f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}." 481 | ) 482 | self._write_to_tensorboard(self.total_train_loss) 483 | 484 | # reset 485 | self.total_train_loss = defaultdict(float) 486 | 487 | def _check_train_finish(self): 488 | if self.steps >= self.config["train_max_steps"]: 489 | self.finish_train = True 490 | 491 | def main(): 492 | """Run training process.""" 493 | parser = argparse.ArgumentParser( 494 | description="Train PercepNet (See detail in rnn_train.py)." 495 | ) 496 | parser.add_argument( 497 | "--train_length_size", 498 | default=2000, 499 | type=int, 500 | help="RNN network train length size.", 501 | ) 502 | parser.add_argument( 503 | "--train_max_steps", 504 | default=100000, 505 | type=int, 506 | help="max train steps.", 507 | ) 508 | parser.add_argument( 509 | "--train_filelist_path", 510 | type=str, 511 | required=True, 512 | help="cpp generated feature train filelist path.", 513 | ) 514 | parser.add_argument( 515 | "--dev_filelist_path", 516 | type=str, 517 | required=True, 518 | help="cpp generated feature dev filelist path", 519 | ) 520 | parser.add_argument( 521 | "--pretrain", 522 | default="", 523 | type=str, 524 | nargs="?", 525 | help='checkpoint file path to load pretrained params. (default="")', 526 | ) 527 | parser.add_argument( 528 | "--rank", 529 | "--local_rank", 530 | default=0, 531 | type=int, 532 | help="rank for distributed training. 
no need to explictly specify.", 533 | ) 534 | parser.add_argument( 535 | "--out_dir", 536 | type=str, 537 | required=True, 538 | help="directory to save checkpoints.", 539 | ) 540 | parser.add_argument( 541 | "--config", 542 | type=str, 543 | required=True, 544 | help="yaml format configuration file.", 545 | ) 546 | 547 | args = parser.parse_args() 548 | 549 | args.distributed = False 550 | if not torch.cuda.is_available(): 551 | device = torch.device("cpu") 552 | else: 553 | device = torch.device("cuda") 554 | # effective when using fixed size inputs 555 | # see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936 556 | torch.backends.cudnn.benchmark = True 557 | torch.cuda.set_device(args.rank) 558 | # setup for distributed training 559 | # see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed 560 | if "WORLD_SIZE" in os.environ: 561 | args.world_size = int(os.environ["WORLD_SIZE"]) 562 | args.distributed = args.world_size > 1 563 | if args.distributed: 564 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 565 | 566 | # load and save config 567 | with open(args.config) as f: 568 | config = yaml.load(f, Loader=yaml.Loader) 569 | config.update(vars(args)) 570 | with open(os.path.join(args.out_dir, "config.yml"), "w") as f: 571 | yaml.dump(config, f, Dumper=yaml.Dumper) 572 | for key, value in config.items(): 573 | logging.info(f"{key} = {value}") 574 | 575 | model = PercepNet().to(device) 576 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 577 | criterion = CustomLoss() 578 | 579 | train_dataset = CppRawListDataset( 580 | args.train_filelist_path, train_length_size=args.train_length_size) 581 | dev_dataset = CppRawListDataset( 582 | args.dev_filelist_path, train_length_size=args.train_length_size) 583 | 584 | logging.info(f"The number of training files = {len(train_dataset)}.") 585 | logging.info(f"The number of training files = {len(dev_dataset)}.") 586 | 587 | dataset 
= { 588 | "train": train_dataset, 589 | "dev": dev_dataset, 590 | } 591 | 592 | sampler = {"train": None, "dev": None} 593 | if args.distributed: 594 | # setup sampler for distributed training 595 | from torch.utils.data.distributed import DistributedSampler 596 | 597 | sampler["train"] = DistributedSampler( 598 | dataset=dataset["train"], 599 | num_replicas=args.world_size, 600 | rank=args.rank, 601 | shuffle=True, 602 | ) 603 | sampler["dev"] = DistributedSampler( 604 | dataset=dataset["dev"], 605 | num_replicas=args.world_size, 606 | rank=args.rank, 607 | shuffle=False, 608 | ) 609 | 610 | data_loader = { 611 | "train" : torch.utils.data.DataLoader( 612 | dataset["train"], 613 | batch_size=config["batch_size"], 614 | num_workers=config["num_workers"], 615 | shuffle=True 616 | ), 617 | "dev": torch.utils.data.DataLoader( 618 | dataset["dev"], 619 | batch_size=config["batch_size"], 620 | num_workers=config["num_workers"], 621 | shuffle=False 622 | ) 623 | } 624 | # define trainer 625 | trainer = Trainer( 626 | steps=0, 627 | epochs=0, 628 | model=model, 629 | data_loader=data_loader, 630 | criterion=criterion, 631 | optimizer=optimizer, 632 | config=config, 633 | args=args, 634 | sampler=sampler, 635 | device=device, 636 | ) 637 | 638 | # load pretrained parameters from checkpoint 639 | if len(args.pretrain) != 0: 640 | trainer.load_checkpoint(args.pretrain) 641 | logging.info(f"Successfully load parameters from {args.pretrain}.") 642 | # run training loop 643 | 644 | try: 645 | trainer.run() 646 | finally: 647 | trainer.save_checkpoint( 648 | os.path.join(config["out_dir"], f"checkpoint-{trainer.steps}steps.pt") 649 | ) 650 | logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.") 651 | 652 | if __name__ == '__main__': 653 | main() 654 | #train() 655 | -------------------------------------------------------------------------------- /sampledata/noise/noise.pcm: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jzi040941/PercepNet/8ffae4337d23f920176ac2a7426e84610fa338ab/sampledata/noise/noise.pcm -------------------------------------------------------------------------------- /sampledata/speech/speech.pcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jzi040941/PercepNet/8ffae4337d23f920176ac2a7426e84610fa338ab/sampledata/speech/speech.pcm -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(BINARY ${CMAKE_PROJECT_NAME}) 2 | 3 | file(GLOB_RECURSE SOURCES LIST_DIRECTORIES true *.h *.cpp) 4 | 5 | set(SOURCES ${SOURCES}) 6 | 7 | add_executable(${BINARY} ${SOURCES}) 8 | set_target_properties( 9 | ${BINARY} 10 | PROPERTIES 11 | COMPILE_DEFINITIONS TRAINING=1 12 | ) 13 | if(EXISTS "${PROJECT_SOURCE_DIR}/src/nnet_data.cpp") 14 | add_executable(${BINARY}_run ${SOURCES}) 15 | else() 16 | message([WARNING] "nnet_data.cpp is not exist. Do not generate inference executable" ...) 17 | endif() 18 | 19 | 20 | 21 | add_library(${BINARY}_lib STATIC ${SOURCES}) 22 | -------------------------------------------------------------------------------- /src/_kiss_fft_guts.h: -------------------------------------------------------------------------------- 1 | /*Copyright (c) 2003-2004, Mark Borgerding 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 
13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 18 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | POSSIBILITY OF SUCH DAMAGE.*/ 25 | 26 | #ifndef KISS_FFT_GUTS_H 27 | #define KISS_FFT_GUTS_H 28 | 29 | #define MIN(a,b) ((a)<(b) ? (a):(b)) 30 | #define MAX(a,b) ((a)>(b) ? (a):(b)) 31 | 32 | /* kiss_fft.h 33 | defines kiss_fft_scalar as either short or a float type 34 | and defines 35 | typedef struct { kiss_fft_scalar r; kiss_fft_scalar i; }kiss_fft_cpx; */ 36 | #include "kiss_fft.h" 37 | 38 | /* 39 | Explanation of macros dealing with complex math: 40 | 41 | C_MUL(m,a,b) : m = a*b 42 | C_FIXDIV( c , div ) : if a fixed point impl., c /= div. 
noop otherwise 43 | C_SUB( res, a,b) : res = a - b 44 | C_SUBFROM( res , a) : res -= a 45 | C_ADDTO( res , a) : res += a 46 | * */ 47 | #ifdef FIXED_POINT 48 | #include "arch.h" 49 | 50 | 51 | #define SAMP_MAX 2147483647 52 | #define TWID_MAX 32767 53 | #define TRIG_UPSCALE 1 54 | 55 | #define SAMP_MIN -SAMP_MAX 56 | 57 | 58 | # define S_MUL(a,b) MULT16_32_Q15(b, a) 59 | 60 | # define C_MUL(m,a,b) \ 61 | do{ (m).r = SUB32_ovflw(S_MUL((a).r,(b).r) , S_MUL((a).i,(b).i)); \ 62 | (m).i = ADD32_ovflw(S_MUL((a).r,(b).i) , S_MUL((a).i,(b).r)); }while(0) 63 | 64 | # define C_MULC(m,a,b) \ 65 | do{ (m).r = ADD32_ovflw(S_MUL((a).r,(b).r) , S_MUL((a).i,(b).i)); \ 66 | (m).i = SUB32_ovflw(S_MUL((a).i,(b).r) , S_MUL((a).r,(b).i)); }while(0) 67 | 68 | # define C_MULBYSCALAR( c, s ) \ 69 | do{ (c).r = S_MUL( (c).r , s ) ;\ 70 | (c).i = S_MUL( (c).i , s ) ; }while(0) 71 | 72 | # define DIVSCALAR(x,k) \ 73 | (x) = S_MUL( x, (TWID_MAX-((k)>>1))/(k)+1 ) 74 | 75 | # define C_FIXDIV(c,div) \ 76 | do { DIVSCALAR( (c).r , div); \ 77 | DIVSCALAR( (c).i , div); }while (0) 78 | 79 | #define C_ADD( res, a,b)\ 80 | do {(res).r=ADD32_ovflw((a).r,(b).r); (res).i=ADD32_ovflw((a).i,(b).i); \ 81 | }while(0) 82 | #define C_SUB( res, a,b)\ 83 | do {(res).r=SUB32_ovflw((a).r,(b).r); (res).i=SUB32_ovflw((a).i,(b).i); \ 84 | }while(0) 85 | #define C_ADDTO( res , a)\ 86 | do {(res).r = ADD32_ovflw((res).r, (a).r); (res).i = ADD32_ovflw((res).i,(a).i);\ 87 | }while(0) 88 | 89 | #define C_SUBFROM( res , a)\ 90 | do {(res).r = ADD32_ovflw((res).r,(a).r); (res).i = SUB32_ovflw((res).i,(a).i); \ 91 | }while(0) 92 | 93 | #if defined(OPUS_ARM_INLINE_ASM) 94 | #include "arm/kiss_fft_armv4.h" 95 | #endif 96 | 97 | #if defined(OPUS_ARM_INLINE_EDSP) 98 | #include "arm/kiss_fft_armv5e.h" 99 | #endif 100 | #if defined(MIPSr1_ASM) 101 | #include "mips/kiss_fft_mipsr1.h" 102 | #endif 103 | 104 | #else /* not FIXED_POINT*/ 105 | 106 | # define S_MUL(a,b) ( (a)*(b) ) 107 | #define C_MUL(m,a,b) \ 108 | do{ (m).r = 
(a).r*(b).r - (a).i*(b).i;\ 109 | (m).i = (a).r*(b).i + (a).i*(b).r; }while(0) 110 | #define C_MULC(m,a,b) \ 111 | do{ (m).r = (a).r*(b).r + (a).i*(b).i;\ 112 | (m).i = (a).i*(b).r - (a).r*(b).i; }while(0) 113 | 114 | #define C_MUL4(m,a,b) C_MUL(m,a,b) 115 | 116 | # define C_FIXDIV(c,div) /* NOOP */ 117 | # define C_MULBYSCALAR( c, s ) \ 118 | do{ (c).r *= (s);\ 119 | (c).i *= (s); }while(0) 120 | #endif 121 | 122 | #ifndef CHECK_OVERFLOW_OP 123 | # define CHECK_OVERFLOW_OP(a,op,b) /* noop */ 124 | #endif 125 | 126 | #ifndef C_ADD 127 | #define C_ADD( res, a,b)\ 128 | do { \ 129 | CHECK_OVERFLOW_OP((a).r,+,(b).r)\ 130 | CHECK_OVERFLOW_OP((a).i,+,(b).i)\ 131 | (res).r=(a).r+(b).r; (res).i=(a).i+(b).i; \ 132 | }while(0) 133 | #define C_SUB( res, a,b)\ 134 | do { \ 135 | CHECK_OVERFLOW_OP((a).r,-,(b).r)\ 136 | CHECK_OVERFLOW_OP((a).i,-,(b).i)\ 137 | (res).r=(a).r-(b).r; (res).i=(a).i-(b).i; \ 138 | }while(0) 139 | #define C_ADDTO( res , a)\ 140 | do { \ 141 | CHECK_OVERFLOW_OP((res).r,+,(a).r)\ 142 | CHECK_OVERFLOW_OP((res).i,+,(a).i)\ 143 | (res).r += (a).r; (res).i += (a).i;\ 144 | }while(0) 145 | 146 | #define C_SUBFROM( res , a)\ 147 | do {\ 148 | CHECK_OVERFLOW_OP((res).r,-,(a).r)\ 149 | CHECK_OVERFLOW_OP((res).i,-,(a).i)\ 150 | (res).r -= (a).r; (res).i -= (a).i; \ 151 | }while(0) 152 | #endif /* C_ADD defined */ 153 | 154 | #ifdef FIXED_POINT 155 | /*# define KISS_FFT_COS(phase) TRIG_UPSCALE*floor(MIN(32767,MAX(-32767,.5+32768 * cos (phase)))) 156 | # define KISS_FFT_SIN(phase) TRIG_UPSCALE*floor(MIN(32767,MAX(-32767,.5+32768 * sin (phase))))*/ 157 | # define KISS_FFT_COS(phase) floor(.5+TWID_MAX*cos (phase)) 158 | # define KISS_FFT_SIN(phase) floor(.5+TWID_MAX*sin (phase)) 159 | # define HALF_OF(x) ((x)>>1) 160 | #elif defined(USE_SIMD) 161 | # define KISS_FFT_COS(phase) _mm_set1_ps( cos(phase) ) 162 | # define KISS_FFT_SIN(phase) _mm_set1_ps( sin(phase) ) 163 | # define HALF_OF(x) ((x)*_mm_set1_ps(.5f)) 164 | #else 165 | # define KISS_FFT_COS(phase) 
(kiss_fft_scalar) cos(phase) 166 | # define KISS_FFT_SIN(phase) (kiss_fft_scalar) sin(phase) 167 | # define HALF_OF(x) ((x)*.5f) 168 | #endif 169 | 170 | #define kf_cexp(x,phase) \ 171 | do{ \ 172 | (x)->r = KISS_FFT_COS(phase);\ 173 | (x)->i = KISS_FFT_SIN(phase);\ 174 | }while(0) 175 | 176 | #define kf_cexp2(x,phase) \ 177 | do{ \ 178 | (x)->r = TRIG_UPSCALE*celt_cos_norm((phase));\ 179 | (x)->i = TRIG_UPSCALE*celt_cos_norm((phase)-32768);\ 180 | }while(0) 181 | 182 | #endif /* KISS_FFT_GUTS_H */ 183 | -------------------------------------------------------------------------------- /src/arch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jzi040941/PercepNet/8ffae4337d23f920176ac2a7426e84610fa338ab/src/arch -------------------------------------------------------------------------------- /src/arch.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2003-2008 Jean-Marc Valin 2 | Copyright (c) 2007-2008 CSIRO 3 | Copyright (c) 2007-2009 Xiph.Org Foundation 4 | Written by Jean-Marc Valin */ 5 | /** 6 | @file arch.h 7 | @brief Various architecture definitions for CELT 8 | */ 9 | /* 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions 12 | are met: 13 | 14 | - Redistributions of source code must retain the above copyright 15 | notice, this list of conditions and the following disclaimer. 16 | 17 | - Redistributions in binary form must reproduce the above copyright 18 | notice, this list of conditions and the following disclaimer in the 19 | documentation and/or other materials provided with the distribution. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 25 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | 34 | #ifndef ARCH_H 35 | #define ARCH_H 36 | 37 | #include "opus_types.h" 38 | #include "common.h" 39 | 40 | # if !defined(__GNUC_PREREQ) 41 | # if defined(__GNUC__)&&defined(__GNUC_MINOR__) 42 | # define __GNUC_PREREQ(_maj,_min) \ 43 | ((__GNUC__<<16)+__GNUC_MINOR__>=((_maj)<<16)+(_min)) 44 | # else 45 | # define __GNUC_PREREQ(_maj,_min) 0 46 | # endif 47 | # endif 48 | 49 | #define CELT_SIG_SCALE 32768.f 50 | 51 | #define celt_fatal(str) _celt_fatal(str, __FILE__, __LINE__); 52 | #ifdef ENABLE_ASSERTIONS 53 | #include 54 | #include 55 | #ifdef __GNUC__ 56 | __attribute__((noreturn)) 57 | #endif 58 | static OPUS_INLINE void _celt_fatal(const char *str, const char *file, int line) 59 | { 60 | fprintf (stderr, "Fatal (internal) error in %s, line %d: %s\n", file, line, str); 61 | abort(); 62 | } 63 | #define celt_assert(cond) {if (!(cond)) {celt_fatal("assertion failed: " #cond);}} 64 | #define celt_assert2(cond, message) {if (!(cond)) {celt_fatal("assertion failed: " #cond "\n" message);}} 65 | #else 66 | #define celt_assert(cond) 67 | #define celt_assert2(cond, message) 68 | #endif 69 | 70 | #define IMUL32(a,b) ((a)*(b)) 71 | 72 | #define MIN16(a,b) 
((a) < (b) ? (a) : (b)) /**< Minimum 16-bit value. */ 73 | #define MAX16(a,b) ((a) > (b) ? (a) : (b)) /**< Maximum 16-bit value. */ 74 | #define MIN32(a,b) ((a) < (b) ? (a) : (b)) /**< Minimum 32-bit value. */ 75 | #define MAX32(a,b) ((a) > (b) ? (a) : (b)) /**< Maximum 32-bit value. */ 76 | #define IMIN(a,b) ((a) < (b) ? (a) : (b)) /**< Minimum int value. */ 77 | #define IMAX(a,b) ((a) > (b) ? (a) : (b)) /**< Maximum int value. */ 78 | #define UADD32(a,b) ((a)+(b)) 79 | #define USUB32(a,b) ((a)-(b)) 80 | 81 | /* Set this if opus_int64 is a native type of the CPU. */ 82 | /* Assume that all LP64 architectures have fast 64-bit types; also x86_64 83 | (which can be ILP32 for x32) and Win64 (which is LLP64). */ 84 | #if defined(__x86_64__) || defined(__LP64__) || defined(_WIN64) 85 | #define OPUS_FAST_INT64 1 86 | #else 87 | #define OPUS_FAST_INT64 0 88 | #endif 89 | 90 | #define PRINT_MIPS(file) 91 | 92 | #ifdef FIXED_POINT 93 | 94 | typedef opus_int16 opus_val16; 95 | typedef opus_int32 opus_val32; 96 | typedef opus_int64 opus_val64; 97 | 98 | typedef opus_val32 celt_sig; 99 | typedef opus_val16 celt_norm; 100 | typedef opus_val32 celt_ener; 101 | 102 | #define Q15ONE 32767 103 | 104 | #define SIG_SHIFT 12 105 | /* Safe saturation value for 32-bit signals. Should be less than 106 | 2^31*(1-0.85) to avoid blowing up on DC at deemphasis.*/ 107 | #define SIG_SAT (300000000) 108 | 109 | #define NORM_SCALING 16384 110 | 111 | #define DB_SHIFT 10 112 | 113 | #define EPSILON 1 114 | #define VERY_SMALL 0 115 | #define VERY_LARGE16 ((opus_val16)32767) 116 | #define Q15_ONE ((opus_val16)32767) 117 | 118 | #define SCALEIN(a) (a) 119 | #define SCALEOUT(a) (a) 120 | 121 | #define ABS16(x) ((x) < 0 ? (-(x)) : (x)) 122 | #define ABS32(x) ((x) < 0 ? (-(x)) : (x)) 123 | 124 | static OPUS_INLINE opus_int16 SAT16(opus_int32 x) { 125 | return x > 32767 ? 32767 : x < -32768 ? 
-32768 : (opus_int16)x; 126 | } 127 | 128 | #ifdef FIXED_DEBUG 129 | #include "fixed_debug.h" 130 | #else 131 | 132 | #include "fixed_generic.h" 133 | 134 | #ifdef OPUS_ARM_PRESUME_AARCH64_NEON_INTR 135 | #include "arm/fixed_arm64.h" 136 | #elif OPUS_ARM_INLINE_EDSP 137 | #include "arm/fixed_armv5e.h" 138 | #elif defined (OPUS_ARM_INLINE_ASM) 139 | #include "arm/fixed_armv4.h" 140 | #elif defined (BFIN_ASM) 141 | #include "fixed_bfin.h" 142 | #elif defined (TI_C5X_ASM) 143 | #include "fixed_c5x.h" 144 | #elif defined (TI_C6X_ASM) 145 | #include "fixed_c6x.h" 146 | #endif 147 | 148 | #endif 149 | 150 | #else /* FIXED_POINT */ 151 | 152 | typedef float opus_val16; 153 | typedef float opus_val32; 154 | typedef float opus_val64; 155 | 156 | typedef float celt_sig; 157 | typedef float celt_norm; 158 | typedef float celt_ener; 159 | 160 | #ifdef FLOAT_APPROX 161 | /* This code should reliably detect NaN/inf even when -ffast-math is used. 162 | Assumes IEEE 754 format. */ 163 | static OPUS_INLINE int celt_isnan(float x) 164 | { 165 | union {float f; opus_uint32 i;} in; 166 | in.f = x; 167 | return ((in.i>>23)&0xFF)==0xFF && (in.i&0x007FFFFF)!=0; 168 | } 169 | #else 170 | #ifdef __FAST_MATH__ 171 | #error Cannot build libopus with -ffast-math unless FLOAT_APPROX is defined. This could result in crashes on extreme (e.g. NaN) input 172 | #endif 173 | #define celt_isnan(x) ((x)!=(x)) 174 | #endif 175 | 176 | #define Q15ONE 1.0f 177 | 178 | #define NORM_SCALING 1.f 179 | 180 | #define EPSILON 1e-15f 181 | #define VERY_SMALL 1e-30f 182 | #define VERY_LARGE16 1e15f 183 | #define Q15_ONE ((opus_val16)1.f) 184 | 185 | /* This appears to be the same speed as C99's fabsf() but it's more portable. 
*/ 186 | #define ABS16(x) ((float)fabs(x)) 187 | #define ABS32(x) ((float)fabs(x)) 188 | 189 | #define QCONST16(x,bits) (x) 190 | #define QCONST32(x,bits) (x) 191 | 192 | #define NEG16(x) (-(x)) 193 | #define NEG32(x) (-(x)) 194 | #define NEG32_ovflw(x) (-(x)) 195 | #define EXTRACT16(x) (x) 196 | #define EXTEND32(x) (x) 197 | #define SHR16(a,shift) (a) 198 | #define SHL16(a,shift) (a) 199 | #define SHR32(a,shift) (a) 200 | #define SHL32(a,shift) (a) 201 | #define PSHR32(a,shift) (a) 202 | #define VSHR32(a,shift) (a) 203 | 204 | #define PSHR(a,shift) (a) 205 | #define SHR(a,shift) (a) 206 | #define SHL(a,shift) (a) 207 | #define SATURATE(x,a) (x) 208 | #define SATURATE16(x) (x) 209 | 210 | #define ROUND16(a,shift) (a) 211 | #define SROUND16(a,shift) (a) 212 | #define HALF16(x) (.5f*(x)) 213 | #define HALF32(x) (.5f*(x)) 214 | 215 | #define ADD16(a,b) ((a)+(b)) 216 | #define SUB16(a,b) ((a)-(b)) 217 | #define ADD32(a,b) ((a)+(b)) 218 | #define SUB32(a,b) ((a)-(b)) 219 | #define ADD32_ovflw(a,b) ((a)+(b)) 220 | #define SUB32_ovflw(a,b) ((a)-(b)) 221 | #define MULT16_16_16(a,b) ((a)*(b)) 222 | #define MULT16_16(a,b) ((opus_val32)(a)*(opus_val32)(b)) 223 | #define MAC16_16(c,a,b) ((c)+(opus_val32)(a)*(opus_val32)(b)) 224 | 225 | #define MULT16_32_Q15(a,b) ((a)*(b)) 226 | #define MULT16_32_Q16(a,b) ((a)*(b)) 227 | 228 | #define MULT32_32_Q31(a,b) ((a)*(b)) 229 | 230 | #define MAC16_32_Q15(c,a,b) ((c)+(a)*(b)) 231 | #define MAC16_32_Q16(c,a,b) ((c)+(a)*(b)) 232 | 233 | #define MULT16_16_Q11_32(a,b) ((a)*(b)) 234 | #define MULT16_16_Q11(a,b) ((a)*(b)) 235 | #define MULT16_16_Q13(a,b) ((a)*(b)) 236 | #define MULT16_16_Q14(a,b) ((a)*(b)) 237 | #define MULT16_16_Q15(a,b) ((a)*(b)) 238 | #define MULT16_16_P15(a,b) ((a)*(b)) 239 | #define MULT16_16_P13(a,b) ((a)*(b)) 240 | #define MULT16_16_P14(a,b) ((a)*(b)) 241 | #define MULT16_32_P16(a,b) ((a)*(b)) 242 | 243 | #define DIV32_16(a,b) (((opus_val32)(a))/(opus_val16)(b)) 244 | #define DIV32(a,b) 
(((opus_val32)(a))/(opus_val32)(b)) 245 | 246 | #define SCALEIN(a) ((a)*CELT_SIG_SCALE) 247 | #define SCALEOUT(a) ((a)*(1/CELT_SIG_SCALE)) 248 | 249 | #define SIG2WORD16(x) (x) 250 | 251 | #endif /* !FIXED_POINT */ 252 | 253 | #ifndef GLOBAL_STACK_SIZE 254 | #ifdef FIXED_POINT 255 | #define GLOBAL_STACK_SIZE 120000 256 | #else 257 | #define GLOBAL_STACK_SIZE 120000 258 | #endif 259 | #endif 260 | 261 | #endif /* ARCH_H */ 262 | -------------------------------------------------------------------------------- /src/celt_lpc.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2009-2010 Xiph.Org Foundation 2 | Written by Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include "celt_lpc.h" 33 | #include "arch.h" 34 | #include "common.h" 35 | #include "pitch.h" 36 | 37 | void _celt_lpc( 38 | opus_val16 *_lpc, /* out: [0...p-1] LPC coefficients */ 39 | const opus_val32 *ac, /* in: [0...p] autocorrelation values */ 40 | int p 41 | ) 42 | { 43 | int i, j; 44 | opus_val32 r; 45 | opus_val32 error = ac[0]; 46 | #ifdef FIXED_POINT 47 | opus_val32 lpc[LPC_ORDER]; 48 | #else 49 | float *lpc = _lpc; 50 | #endif 51 | 52 | RNN_CLEAR(lpc, p); 53 | if (ac[0] != 0) 54 | { 55 | for (i = 0; i < p; i++) { 56 | /* Sum up this iteration's reflection coefficient */ 57 | opus_val32 rr = 0; 58 | for (j = 0; j < i; j++) 59 | rr += MULT32_32_Q31(lpc[j],ac[i - j]); 60 | rr += SHR32(ac[i + 1],3); 61 | r = -SHL32(rr,3)/(error+0.00001); 62 | /* Update LPC coefficients and total error */ 63 | lpc[i] = SHR32(r,3); 64 | for (j = 0; j < (i+1)>>1; j++) 65 | { 66 | opus_val32 tmp1, tmp2; 67 | tmp1 = lpc[j]; 68 | tmp2 = lpc[i-1-j]; 69 | lpc[j] = tmp1 + MULT32_32_Q31(r,tmp2); 70 | lpc[i-1-j] = tmp2 + MULT32_32_Q31(r,tmp1); 71 | } 72 | 73 | error = error - MULT32_32_Q31(MULT32_32_Q31(r,r),error); 74 | /* Bail out once we get 30 dB gain */ 75 | #ifdef FIXED_POINT 76 | if (error=1;j--) 141 | { 142 | mem[j]=mem[j-1]; 143 | } 144 | mem[0] = SROUND16(sum, SIG_SHIFT); 145 | _y[i] = sum; 146 | } 147 | #else 148 | int i,j; 149 | 
celt_assert((ord&3)==0); 150 | opus_val16 rden[ord]; 151 | opus_val16 y[N+ord]; 152 | for(i=0;i0); 213 | celt_assert(overlap>=0); 214 | if (overlap == 0) 215 | { 216 | xptr = x; 217 | } else { 218 | for (i=0;i0) 242 | { 243 | for(i=0;i= 536870912) 268 | { 269 | int shift2=1; 270 | if (ac[0] >= 1073741824) 271 | shift2++; 272 | for (i=0;i<=lag;i++) 273 | ac[i] = SHR32(ac[i], shift2); 274 | shift += shift2; 275 | } 276 | #endif 277 | 278 | return shift; 279 | } 280 | -------------------------------------------------------------------------------- /src/celt_lpc.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2009-2010 Xiph.Org Foundation 2 | Written by Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifndef PLC_H 29 | #define PLC_H 30 | 31 | #include "arch.h" 32 | #include "common.h" 33 | 34 | #if defined(OPUS_X86_MAY_HAVE_SSE4_1) 35 | #include "x86/celt_lpc_sse.h" 36 | #endif 37 | 38 | #define LPC_ORDER 24 39 | 40 | void _celt_lpc(opus_val16 *_lpc, const opus_val32 *ac, int p); 41 | 42 | void celt_fir( 43 | const opus_val16 *x, 44 | const opus_val16 *num, 45 | opus_val16 *y, 46 | int N, 47 | int ord); 48 | 49 | void celt_iir(const opus_val32 *x, 50 | const opus_val16 *den, 51 | opus_val32 *y, 52 | int N, 53 | int ord, 54 | opus_val16 *mem); 55 | 56 | int _celt_autocorr(const opus_val16 *x, opus_val32 *ac, 57 | const opus_val16 *window, int overlap, int lag, int n); 58 | 59 | #endif /* PLC_H */ 60 | -------------------------------------------------------------------------------- /src/common.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef COMMON_H 4 | #define COMMON_H 5 | 6 | #include "stdlib.h" 7 | #include "string.h" 8 | 9 | #define RNN_INLINE inline 10 | #define OPUS_INLINE inline 11 | 12 | 13 | /** RNNoise wrapper for malloc(). To do your own dynamic allocation, all you need t 14 | o do is replace this function and rnnoise_free */ 15 | #ifndef OVERRIDE_RNNOISE_ALLOC 16 | static RNN_INLINE void *rnnoise_alloc (size_t size) 17 | { 18 | return malloc(size); 19 | } 20 | #endif 21 | 22 | /** RNNoise wrapper for free(). 
To do your own dynamic allocation, all you need to do is replace this function and rnnoise_alloc */ 23 | #ifndef OVERRIDE_RNNOISE_FREE 24 | static RNN_INLINE void rnnoise_free (void *ptr) 25 | { 26 | free(ptr); 27 | } 28 | #endif 29 | 30 | /** Copy n elements from src to dst. The 0* term provides compile-time type checking */ 31 | #ifndef OVERRIDE_RNN_COPY 32 | #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) )) 33 | #endif 34 | 35 | /** Copy n elements from src to dst, allowing overlapping regions. The 0* term 36 | provides compile-time type checking */ 37 | #ifndef OVERRIDE_RNN_MOVE 38 | #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) )) 39 | #endif 40 | 41 | /** Set n elements of dst to zero */ 42 | #ifndef OVERRIDE_RNN_CLEAR 43 | #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst)))) 44 | #endif 45 | 46 | 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /src/erbband.h: -------------------------------------------------------------------------------- 1 | #include "stdio.h" 2 | #include 3 | #include 4 | #include 5 | 6 | template 7 | std::vector linspace(T start_in, T end_in, int num_in) 8 | { 9 | 10 | std::vector linspaced; 11 | 12 | float start = static_cast(start_in); 13 | float end = static_cast(end_in); 14 | float num = static_cast(num_in); 15 | 16 | if (num == 0) { return linspaced; } 17 | if (num == 1) 18 | { 19 | linspaced.push_back(start); 20 | return linspaced; 21 | } 22 | 23 | float delta = (end - start) / (num - 1); 24 | 25 | for(int i=0; i < num-1; ++i) 26 | { 27 | linspaced.push_back(start + delta * i); 28 | } 29 | linspaced.push_back(end); // I want to ensure that start and end 30 | // are exactly the same as the input 31 | return linspaced; 32 | } 33 | 34 | class ERBBand{ 35 | float erb_low,erb_high; 36 | int bandN; 37 | std::vector cutoffs,erb_lims; 38 | 39 | 40 | public: 41 | std::vector 
nfftborder,centerfreqs; 42 | std::vector,std::vector>> filters; 43 | ERBBand(int window_size, int N, float low_lim, float high_lim){ 44 | cutoffs.assign(N+2,0); 45 | int i; 46 | erb_low = freq2erb(low_lim); 47 | erb_high = freq2erb(high_lim); 48 | erb_lims = linspace(erb_low, erb_high, N+2); 49 | for(i=0; i,std::vector>> make_filters(int N){ 64 | std::vector,std::vector>> cos_filter; 65 | float freqRangePerBin = 50;//for 48000 smaplerate and 960 window_size fft 66 | float l_k, h_k, avg, rnge; 67 | int l_nfftind, h_nfftind; 68 | for(int k=0; k,std::vector> kthfilter = {{l_nfftind,h_nfftind+1},{}}; 88 | //cos_filter.push_back() 89 | for(int i=l_nfftind; itwiddles; 143 | /* m is guaranteed to be a multiple of 4. */ 144 | for (j=0;jtwiddles[fstride*m]; 195 | #endif 196 | for (i=0;itwiddles; 200 | /* For non-custom modes, m is guaranteed to be a multiple of 4. */ 201 | k=m; 202 | do { 203 | 204 | C_MUL(scratch[1],Fout[m] , *tw1); 205 | C_MUL(scratch[2],Fout[m2] , *tw2); 206 | 207 | C_ADD(scratch[3],scratch[1],scratch[2]); 208 | C_SUB(scratch[0],scratch[1],scratch[2]); 209 | tw1 += fstride; 210 | tw2 += fstride*2; 211 | 212 | Fout[m].r = SUB32_ovflw(Fout->r, HALF_OF(scratch[3].r)); 213 | Fout[m].i = SUB32_ovflw(Fout->i, HALF_OF(scratch[3].i)); 214 | 215 | C_MULBYSCALAR( scratch[0] , epi3.i ); 216 | 217 | C_ADDTO(*Fout,scratch[3]); 218 | 219 | Fout[m2].r = ADD32_ovflw(Fout[m].r, scratch[0].i); 220 | Fout[m2].i = SUB32_ovflw(Fout[m].i, scratch[0].r); 221 | 222 | Fout[m].r = SUB32_ovflw(Fout[m].r, scratch[0].i); 223 | Fout[m].i = ADD32_ovflw(Fout[m].i, scratch[0].r); 224 | 225 | ++Fout; 226 | } while(--k); 227 | } 228 | } 229 | 230 | 231 | #ifndef OVERRIDE_kf_bfly5 232 | static void kf_bfly5( 233 | kiss_fft_cpx * Fout, 234 | const size_t fstride, 235 | const kiss_fft_state *st, 236 | int m, 237 | int N, 238 | int mm 239 | ) 240 | { 241 | kiss_fft_cpx *Fout0,*Fout1,*Fout2,*Fout3,*Fout4; 242 | int i, u; 243 | kiss_fft_cpx scratch[13]; 244 | const kiss_twiddle_cpx *tw; 245 | 
kiss_twiddle_cpx ya,yb; 246 | kiss_fft_cpx * Fout_beg = Fout; 247 | 248 | #ifdef FIXED_POINT 249 | ya.r = 10126; 250 | ya.i = -31164; 251 | yb.r = -26510; 252 | yb.i = -19261; 253 | #else 254 | ya = st->twiddles[fstride*m]; 255 | yb = st->twiddles[fstride*2*m]; 256 | #endif 257 | tw=st->twiddles; 258 | 259 | for (i=0;ir = ADD32_ovflw(Fout0->r, ADD32_ovflw(scratch[7].r, scratch[8].r)); 283 | Fout0->i = ADD32_ovflw(Fout0->i, ADD32_ovflw(scratch[7].i, scratch[8].i)); 284 | 285 | scratch[5].r = ADD32_ovflw(scratch[0].r, ADD32_ovflw(S_MUL(scratch[7].r,ya.r), S_MUL(scratch[8].r,yb.r))); 286 | scratch[5].i = ADD32_ovflw(scratch[0].i, ADD32_ovflw(S_MUL(scratch[7].i,ya.r), S_MUL(scratch[8].i,yb.r))); 287 | 288 | scratch[6].r = ADD32_ovflw(S_MUL(scratch[10].i,ya.i), S_MUL(scratch[9].i,yb.i)); 289 | scratch[6].i = NEG32_ovflw(ADD32_ovflw(S_MUL(scratch[10].r,ya.i), S_MUL(scratch[9].r,yb.i))); 290 | 291 | C_SUB(*Fout1,scratch[5],scratch[6]); 292 | C_ADD(*Fout4,scratch[5],scratch[6]); 293 | 294 | scratch[11].r = ADD32_ovflw(scratch[0].r, ADD32_ovflw(S_MUL(scratch[7].r,yb.r), S_MUL(scratch[8].r,ya.r))); 295 | scratch[11].i = ADD32_ovflw(scratch[0].i, ADD32_ovflw(S_MUL(scratch[7].i,yb.r), S_MUL(scratch[8].i,ya.r))); 296 | scratch[12].r = SUB32_ovflw(S_MUL(scratch[9].i,ya.i), S_MUL(scratch[10].i,yb.i)); 297 | scratch[12].i = SUB32_ovflw(S_MUL(scratch[10].r,yb.i), S_MUL(scratch[9].r,ya.i)); 298 | 299 | C_ADD(*Fout2,scratch[11],scratch[12]); 300 | C_SUB(*Fout3,scratch[11],scratch[12]); 301 | 302 | ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4; 303 | } 304 | } 305 | } 306 | #endif /* OVERRIDE_kf_bfly5 */ 307 | 308 | 309 | #endif 310 | 311 | 312 | #ifdef CUSTOM_MODES 313 | 314 | static 315 | void compute_bitrev_table( 316 | int Fout, 317 | opus_int16 *f, 318 | const size_t fstride, 319 | int in_stride, 320 | opus_int16 * factors, 321 | const kiss_fft_state *st 322 | ) 323 | { 324 | const int p=*factors++; /* the radix */ 325 | const int m=*factors++; /* stage's fft length/p */ 326 | 327 | 
/*printf ("fft %d %d %d %d %d %d\n", p*m, m, p, s2, fstride*in_stride, N);*/ 328 | if (m==1) 329 | { 330 | int j; 331 | for (j=0;j32000 || (opus_int32)p*(opus_int32)p > n) 368 | p = n; /* no more factors, skip to end */ 369 | } 370 | n /= p; 371 | #ifdef RADIX_TWO_ONLY 372 | if (p!=2 && p != 4) 373 | #else 374 | if (p>5) 375 | #endif 376 | { 377 | return 0; 378 | } 379 | facbuf[2*stages] = p; 380 | if (p==2 && stages > 1) 381 | { 382 | facbuf[2*stages] = 4; 383 | facbuf[2] = 2; 384 | } 385 | stages++; 386 | } while (n > 1); 387 | n = nbak; 388 | /* Reverse the order to get the radix 4 at the end, so we can use the 389 | fast degenerate case. It turns out that reversing the order also 390 | improves the noise behaviour. */ 391 | for (i=0;i= memneeded) 444 | st = (kiss_fft_state*)mem; 445 | *lenmem = memneeded; 446 | } 447 | if (st) { 448 | opus_int16 *bitrev; 449 | kiss_twiddle_cpx *twiddles; 450 | 451 | st->nfft=nfft; 452 | #ifdef FIXED_POINT 453 | st->scale_shift = celt_ilog2(st->nfft); 454 | if (st->nfft == 1<scale_shift) 455 | st->scale = Q15ONE; 456 | else 457 | st->scale = (1073741824+st->nfft/2)/st->nfft>>(15-st->scale_shift); 458 | #else 459 | st->scale = 1.f/nfft; 460 | #endif 461 | if (base != NULL) 462 | { 463 | st->twiddles = base->twiddles; 464 | st->shift = 0; 465 | while (st->shift < 32 && nfft<shift != base->nfft) 466 | st->shift++; 467 | if (st->shift>=32) 468 | goto fail; 469 | } else { 470 | st->twiddles = twiddles = (kiss_twiddle_cpx*)KISS_FFT_MALLOC(sizeof(kiss_twiddle_cpx)*nfft); 471 | compute_twiddles(twiddles, nfft); 472 | st->shift = -1; 473 | } 474 | if (!kf_factor(nfft,st->factors)) 475 | { 476 | goto fail; 477 | } 478 | 479 | /* bitrev */ 480 | st->bitrev = bitrev = (opus_int16*)KISS_FFT_MALLOC(sizeof(opus_int16)*nfft); 481 | if (st->bitrev==NULL) 482 | goto fail; 483 | compute_bitrev_table(0, bitrev, 1,1, st->factors,st); 484 | 485 | /* Initialize architecture specific fft parameters */ 486 | if (opus_fft_alloc_arch(st, arch)) 487 | goto 
fail; 488 | } 489 | return st; 490 | fail: 491 | opus_fft_free(st, arch); 492 | return NULL; 493 | } 494 | 495 | kiss_fft_state *opus_fft_alloc(int nfft,void * mem,size_t * lenmem, int arch) 496 | { 497 | return opus_fft_alloc_twiddles(nfft, mem, lenmem, NULL, arch); 498 | } 499 | 500 | void opus_fft_free_arch_c(kiss_fft_state *st) { 501 | (void)st; 502 | } 503 | 504 | void opus_fft_free(const kiss_fft_state *cfg, int arch) 505 | { 506 | if (cfg) 507 | { 508 | opus_fft_free_arch((kiss_fft_state *)cfg, arch); 509 | opus_free((opus_int16*)cfg->bitrev); 510 | if (cfg->shift < 0) 511 | opus_free((kiss_twiddle_cpx*)cfg->twiddles); 512 | opus_free((kiss_fft_state*)cfg); 513 | } 514 | } 515 | 516 | #endif /* CUSTOM_MODES */ 517 | 518 | void opus_fft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout) 519 | { 520 | int m2, m; 521 | int p; 522 | int L; 523 | int fstride[MAXFACTORS]; 524 | int i; 525 | int shift; 526 | 527 | /* st->shift can be -1 */ 528 | shift = st->shift>0 ? st->shift : 0; 529 | 530 | fstride[0] = 1; 531 | L=0; 532 | do { 533 | p = st->factors[2*L]; 534 | m = st->factors[2*L+1]; 535 | fstride[L+1] = fstride[L]*p; 536 | L++; 537 | } while(m!=1); 538 | m = st->factors[2*L-1]; 539 | for (i=L-1;i>=0;i--) 540 | { 541 | if (i!=0) 542 | m2 = st->factors[2*i-1]; 543 | else 544 | m2 = 1; 545 | switch (st->factors[2*i]) 546 | { 547 | case 2: 548 | kf_bfly2(fout, m, fstride[i]); 549 | break; 550 | case 4: 551 | kf_bfly4(fout,fstride[i]<scale_shift-1; 574 | #endif 575 | scale = st->scale; 576 | 577 | celt_assert2 (fin != fout, "In-place FFT not supported"); 578 | /* Bit-reverse the input */ 579 | for (i=0;infft;i++) 580 | { 581 | kiss_fft_cpx x = fin[i]; 582 | fout[st->bitrev[i]].r = SHR32(MULT16_32_Q16(scale, x.r), scale_shift); 583 | fout[st->bitrev[i]].i = SHR32(MULT16_32_Q16(scale, x.i), scale_shift); 584 | } 585 | opus_fft_impl(st, fout); 586 | } 587 | 588 | 589 | void opus_ifft_c(const kiss_fft_state *st,const kiss_fft_cpx *fin,kiss_fft_cpx *fout) 590 | { 591 | 
int i; 592 | celt_assert2 (fin != fout, "In-place FFT not supported"); 593 | /* Bit-reverse the input */ 594 | for (i=0;infft;i++) 595 | fout[st->bitrev[i]] = fin[i]; 596 | for (i=0;infft;i++) 597 | fout[i].i = -fout[i].i; 598 | opus_fft_impl(st, fout); 599 | for (i=0;infft;i++) 600 | fout[i].i = -fout[i].i; 601 | } 602 | -------------------------------------------------------------------------------- /src/kiss_fft.h: -------------------------------------------------------------------------------- 1 | /*Copyright (c) 2003-2004, Mark Borgerding 2 | Lots of modifications by Jean-Marc Valin 3 | Copyright (c) 2005-2007, Xiph.Org Foundation 4 | Copyright (c) 2008, Xiph.Org Foundation, CSIRO 5 | 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, 12 | this list of conditions and the following disclaimer. 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | POSSIBILITY OF SUCH DAMAGE.*/ 28 | 29 | #ifndef KISS_FFT_H 30 | #define KISS_FFT_H 31 | 32 | #include 33 | #include 34 | #include "arch.h" 35 | 36 | #include 37 | #define opus_alloc(x) malloc(x) 38 | #define opus_free(x) free(x) 39 | 40 | #ifdef __cplusplus 41 | extern "C" { 42 | #endif 43 | 44 | #ifdef USE_SIMD 45 | # include 46 | # define kiss_fft_scalar __m128 47 | #define KISS_FFT_MALLOC(nbytes) memalign(16,nbytes) 48 | #else 49 | #define KISS_FFT_MALLOC opus_alloc 50 | #endif 51 | 52 | #ifdef FIXED_POINT 53 | #include "arch.h" 54 | 55 | # define kiss_fft_scalar opus_int32 56 | # define kiss_twiddle_scalar opus_int16 57 | 58 | 59 | #else 60 | # ifndef kiss_fft_scalar 61 | /* default is float */ 62 | # define kiss_fft_scalar float 63 | # define kiss_twiddle_scalar float 64 | # define KF_SUFFIX _celt_single 65 | # endif 66 | #endif 67 | 68 | typedef struct { 69 | kiss_fft_scalar r; 70 | kiss_fft_scalar i; 71 | }kiss_fft_cpx; 72 | 73 | typedef struct { 74 | kiss_twiddle_scalar r; 75 | kiss_twiddle_scalar i; 76 | }kiss_twiddle_cpx; 77 | 78 | #define MAXFACTORS 8 79 | /* e.g. 
an fft of length 128 has 4 factors 80 | as far as kissfft is concerned 81 | 4*4*4*2 82 | */ 83 | 84 | typedef struct arch_fft_state{ 85 | int is_supported; 86 | void *priv; 87 | } arch_fft_state; 88 | 89 | typedef struct kiss_fft_state{ 90 | int nfft; 91 | opus_val16 scale; 92 | #ifdef FIXED_POINT 93 | int scale_shift; 94 | #endif 95 | int shift; 96 | opus_int16 factors[2*MAXFACTORS]; 97 | const opus_int16 *bitrev; 98 | const kiss_twiddle_cpx *twiddles; 99 | arch_fft_state *arch_fft; 100 | } kiss_fft_state; 101 | 102 | #if defined(HAVE_ARM_NE10) 103 | #include "arm/fft_arm.h" 104 | #endif 105 | 106 | /*typedef struct kiss_fft_state* kiss_fft_cfg;*/ 107 | 108 | /** 109 | * opus_fft_alloc 110 | * 111 | * Initialize a FFT (or IFFT) algorithm's cfg/state buffer. 112 | * 113 | * typical usage: kiss_fft_cfg mycfg=opus_fft_alloc(1024,0,NULL,NULL); 114 | * 115 | * The return value from fft_alloc is a cfg buffer used internally 116 | * by the fft routine or NULL. 117 | * 118 | * If lenmem is NULL, then opus_fft_alloc will allocate a cfg buffer using malloc. 119 | * The returned value should be free()d when done to avoid memory leaks. 120 | * 121 | * The state can be placed in a user supplied buffer 'mem': 122 | * If lenmem is not NULL and mem is not NULL and *lenmem is large enough, 123 | * then the function places the cfg in mem and the size used in *lenmem 124 | * and returns mem. 125 | * 126 | * If lenmem is not NULL and ( mem is NULL or *lenmem is not large enough), 127 | * then the function returns NULL and places the minimum cfg 128 | * buffer size in *lenmem. 129 | * */ 130 | 131 | kiss_fft_state *opus_fft_alloc_twiddles(int nfft,void * mem,size_t * lenmem, const kiss_fft_state *base, int arch); 132 | 133 | kiss_fft_state *opus_fft_alloc(int nfft,void * mem,size_t * lenmem, int arch); 134 | 135 | /** 136 | * opus_fft(cfg,in_out_buf) 137 | * 138 | * Perform an FFT on a complex input buffer. 139 | * for a forward FFT, 140 | * fin should be f[0] , f[1] , ... 
,f[nfft-1] 141 | * fout will be F[0] , F[1] , ... ,F[nfft-1] 142 | * Note that each element is complex and can be accessed like 143 | f[k].r and f[k].i 144 | * */ 145 | void opus_fft_c(const kiss_fft_state *cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout); 146 | void opus_ifft_c(const kiss_fft_state *cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout); 147 | 148 | void opus_fft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout); 149 | void opus_ifft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout); 150 | 151 | void opus_fft_free(const kiss_fft_state *cfg, int arch); 152 | 153 | 154 | void opus_fft_free_arch_c(kiss_fft_state *st); 155 | int opus_fft_alloc_arch_c(kiss_fft_state *st); 156 | 157 | #if !defined(OVERRIDE_OPUS_FFT) 158 | /* Is run-time CPU detection enabled on this platform? */ 159 | #if defined(OPUS_HAVE_RTCD) && (defined(HAVE_ARM_NE10)) 160 | 161 | extern int (*const OPUS_FFT_ALLOC_ARCH_IMPL[OPUS_ARCHMASK+1])( 162 | kiss_fft_state *st); 163 | 164 | #define opus_fft_alloc_arch(_st, arch) \ 165 | ((*OPUS_FFT_ALLOC_ARCH_IMPL[(arch)&OPUS_ARCHMASK])(_st)) 166 | 167 | extern void (*const OPUS_FFT_FREE_ARCH_IMPL[OPUS_ARCHMASK+1])( 168 | kiss_fft_state *st); 169 | #define opus_fft_free_arch(_st, arch) \ 170 | ((*OPUS_FFT_FREE_ARCH_IMPL[(arch)&OPUS_ARCHMASK])(_st)) 171 | 172 | extern void (*const OPUS_FFT[OPUS_ARCHMASK+1])(const kiss_fft_state *cfg, 173 | const kiss_fft_cpx *fin, kiss_fft_cpx *fout); 174 | #define opus_fft(_cfg, _fin, _fout, arch) \ 175 | ((*OPUS_FFT[(arch)&OPUS_ARCHMASK])(_cfg, _fin, _fout)) 176 | 177 | extern void (*const OPUS_IFFT[OPUS_ARCHMASK+1])(const kiss_fft_state *cfg, 178 | const kiss_fft_cpx *fin, kiss_fft_cpx *fout); 179 | #define opus_ifft(_cfg, _fin, _fout, arch) \ 180 | ((*OPUS_IFFT[(arch)&OPUS_ARCHMASK])(_cfg, _fin, _fout)) 181 | 182 | #else /* else for if defined(OPUS_HAVE_RTCD) && (defined(HAVE_ARM_NE10)) */ 183 | 184 | #define opus_fft_alloc_arch(_st, arch) \ 185 | ((void)(arch), opus_fft_alloc_arch_c(_st)) 186 | 187 | #define 
opus_fft_free_arch(_st, arch) \ 188 | ((void)(arch), opus_fft_free_arch_c(_st)) 189 | 190 | #define opus_fft(_cfg, _fin, _fout, arch) \ 191 | ((void)(arch), opus_fft_c(_cfg, _fin, _fout)) 192 | 193 | #define opus_ifft(_cfg, _fin, _fout, arch) \ 194 | ((void)(arch), opus_ifft_c(_cfg, _fin, _fout)) 195 | 196 | #endif /* end if defined(OPUS_HAVE_RTCD) && (defined(HAVE_ARM_NE10)) */ 197 | #endif /* end if !defined(OVERRIDE_OPUS_FFT) */ 198 | 199 | #ifdef __cplusplus 200 | } 201 | #endif 202 | 203 | #endif 204 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "rnnoise.h" 3 | #include 4 | #define FRAME_SIZE 480 5 | 6 | //using namespace std; 7 | #ifndef TRAINING 8 | #define TRAINING 0 9 | #endif 10 | 11 | int main(int argc, char **argv) 12 | { 13 | if(TRAINING){ 14 | train(argc, argv); 15 | return 0; 16 | } 17 | int i; 18 | int first = 1; 19 | float x[FRAME_SIZE]; 20 | FILE *f1, *fout, *f_feature; 21 | DenoiseState *st; 22 | st = rnnoise_create(NULL); 23 | if (argc!=3) { 24 | fprintf(stderr, "usage: %s \n", argv[0]); 25 | return 1; 26 | } 27 | f1 = fopen(argv[1], "rb"); 28 | fout = fopen(argv[2], "wb"); 29 | f_feature = fopen("feature_test.raw", "wb"); 30 | while (1) { 31 | short tmp[FRAME_SIZE]; 32 | fread(tmp, sizeof(short), FRAME_SIZE, f1); 33 | if (feof(f1)) break; 34 | for (i=0;i 34 | #include 35 | #include "opus_types.h" 36 | #include "arch.h" 37 | #include "common.h" 38 | #include "tansig_table.h" 39 | #include "nnet.h" 40 | #include "nnet_data.h" 41 | 42 | #define SOFTMAX_HACK 43 | 44 | #ifdef __AVX__ 45 | #include "vec_avx.h" 46 | #elif __ARM_NEON__ 47 | #include "vec_neon.h" 48 | #else 49 | #warning Compiling without any vectorization. This code will be very slow 50 | #include "vec.h" 51 | #endif 52 | 53 | static OPUS_INLINE float relu(float x) 54 | { 55 | return x < 0 ? 
0 : x; 56 | } 57 | 58 | 59 | static void sgemv_accum(float *out, const float *weights, int rows, int cols, int col_stride, const float *x) 60 | { 61 | int i, j; 62 | if (rows % 16 == 0) 63 | { 64 | sgemv_accum16(out, weights, rows, cols, col_stride, x); 65 | } else { 66 | for (i=0;inb_inputs; 111 | N = layer->nb_neurons; 112 | stride = N; 113 | celt_assert(input != output); 114 | for (i=0;ibias[i]; 116 | sgemv_accum(output, layer->input_weights, N, M, stride, input); 117 | compute_activation(output, output, N, layer->activation); 118 | } 119 | 120 | void compute_gru(const GRULayer *gru, float *state, const float *input) 121 | { 122 | int i; 123 | int N, M; 124 | int stride; 125 | float tmp[MAX_NEURONS]; 126 | float z[MAX_NEURONS]; 127 | float r[MAX_NEURONS]; 128 | float h[MAX_NEURONS]; 129 | celt_assert(gru->nb_neurons <= MAX_NEURONS); 130 | celt_assert(input != state); 131 | M = gru->nb_inputs; 132 | N = gru->nb_neurons; 133 | stride = 3*N; 134 | /* Compute update gate. */ 135 | for (i=0;ibias[i]; 137 | if (gru->reset_after) 138 | { 139 | for (i=0;ibias[3*N + i]; 141 | } 142 | sgemv_accum(z, gru->input_weights, N, M, stride, input); 143 | sgemv_accum(z, gru->recurrent_weights, N, N, stride, state); 144 | compute_activation(z, z, N, ACTIVATION_SIGMOID); 145 | 146 | /* Compute reset gate. */ 147 | for (i=0;ibias[N + i]; 149 | if (gru->reset_after) 150 | { 151 | for (i=0;ibias[4*N + i]; 153 | } 154 | sgemv_accum(r, &gru->input_weights[N], N, M, stride, input); 155 | sgemv_accum(r, &gru->recurrent_weights[N], N, N, stride, state); 156 | compute_activation(r, r, N, ACTIVATION_SIGMOID); 157 | 158 | /* Compute output. 
*/ 159 | for (i=0;ibias[2*N + i]; 161 | if (gru->reset_after) 162 | { 163 | for (i=0;ibias[5*N + i]; 165 | sgemv_accum(tmp, &gru->recurrent_weights[2*N], N, N, stride, state); 166 | for (i=0;iinput_weights[2*N], N, M, stride, input); 169 | } else { 170 | for (i=0;iinput_weights[2*N], N, M, stride, input); 173 | sgemv_accum(h, &gru->recurrent_weights[2*N], N, N, stride, tmp); 174 | } 175 | compute_activation(h, h, N, gru->activation); 176 | for (i=0;inb_inputs*layer->kernel_size <= MAX_CONV_INPUTS); 190 | RNN_COPY(tmp, mem, layer->nb_inputs*(layer->kernel_size-1)); 191 | RNN_COPY(&tmp[layer->nb_inputs*(layer->kernel_size-1)], input, layer->nb_inputs); 192 | M = layer->nb_inputs*layer->kernel_size; 193 | N = layer->nb_neurons; 194 | stride = N; 195 | for (i=0;ibias[i]; 197 | sgemv_accum(output, layer->input_weights, N, M, stride, tmp); 198 | compute_activation(output, output, N, layer->activation); 199 | RNN_COPY(mem, &tmp[layer->nb_inputs], layer->nb_inputs*(layer->kernel_size-1)); 200 | } 201 | 202 | void compute_embedding(const EmbeddingLayer *layer, float *output, int input) 203 | { 204 | int i; 205 | celt_assert(input >= 0); 206 | celt_assert(input < layer->nb_inputs); 207 | /*if (layer->dim == 64) printf("%d\n", input);*/ 208 | for (i=0;idim;i++) 209 | { 210 | output[i] = layer->embedding_weights[input*layer->dim + i]; 211 | } 212 | } 213 | 214 | void accum_embedding(const EmbeddingLayer *layer, float *output, int input) 215 | { 216 | int i; 217 | celt_assert(input >= 0); 218 | celt_assert(input < layer->nb_inputs); 219 | /*if (layer->dim == 64) printf("%d\n", input);*/ 220 | for (i=0;idim;i++) 221 | { 222 | output[i] += layer->embedding_weights[input*layer->dim + i]; 223 | } 224 | } 225 | -------------------------------------------------------------------------------- /src/nnet.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018 Mozilla 2 | Copyright (c) 2017 Jean-Marc Valin */ 3 | /* 4 | Redistribution 
and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifndef _NNET_H_ 29 | #define _NNET_H_ 30 | 31 | #define WEIGHTS_SCALE (1.f/256) 32 | #define MAX_NEURONS 512 33 | #define MAX_CONV_INPUTS 1536 34 | 35 | #define ACTIVATION_LINEAR 0 36 | #define ACTIVATION_SIGMOID 1 37 | #define ACTIVATION_TANH 2 38 | #define ACTIVATION_RELU 3 39 | #define ACTIVATION_SOFTMAX 4 40 | 41 | #define INPUT_SIZE 70 42 | #define CONV_DIM 512 43 | #define CONVOUT_BUF_SIZE CONV_DIM*3 44 | typedef struct { 45 | const float *bias; 46 | const float *input_weights; 47 | int nb_inputs; 48 | int nb_neurons; 49 | int activation; 50 | } DenseLayer; 51 | 52 | typedef struct { 53 | const float *bias; 54 | const float *input_weights; 55 | const float *factor; 56 | int nb_inputs; 57 | int nb_neurons; 58 | int nb_channels; 59 | int activation; 60 | } MDenseLayer; 61 | 62 | typedef struct { 63 | const float *bias; 64 | const float *input_weights; 65 | const float *recurrent_weights; 66 | int nb_inputs; 67 | int nb_neurons; 68 | int activation; 69 | int reset_after; 70 | } GRULayer; 71 | 72 | typedef struct { 73 | const float *bias; 74 | const float *diag_weights; 75 | const float *recurrent_weights; 76 | const int *idx; 77 | int nb_neurons; 78 | int activation; 79 | int reset_after; 80 | } SparseGRULayer; 81 | 82 | typedef struct { 83 | const float *bias; 84 | const float *input_weights; 85 | int nb_inputs; 86 | int kernel_size; 87 | int nb_neurons; 88 | int activation; 89 | } Conv1DLayer; 90 | 91 | typedef struct { 92 | const float *embedding_weights; 93 | int nb_inputs; 94 | int dim; 95 | } EmbeddingLayer; 96 | 97 | void compute_activation(float *output, float *input, int N, int activation); 98 | 99 | void compute_dense(const DenseLayer *layer, float *output, const float *input); 100 | 101 | void compute_mdense(const MDenseLayer *layer, float *output, const float *input); 102 | 103 | void compute_gru(const GRULayer *gru, float *state, const float *input); 104 | 105 | void compute_gru2(const GRULayer *gru, float *state, const float 
*input); 106 | 107 | void compute_gru3(const GRULayer *gru, float *state, const float *input); 108 | 109 | void compute_sparse_gru(const SparseGRULayer *gru, float *state, const float *input); 110 | 111 | void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const float *input); 112 | 113 | void compute_embedding(const EmbeddingLayer *layer, float *output, int input); 114 | 115 | void accum_embedding(const EmbeddingLayer *layer, float *output, int input); 116 | 117 | int sample_from_pdf(const float *pdf, int N, float exp_boost, float pdf_floor); 118 | 119 | #endif /* _MLP_H_ */ 120 | -------------------------------------------------------------------------------- /src/nnet_data.h: -------------------------------------------------------------------------------- 1 | #ifndef NNET_DATA_H 2 | #define NNET_DATA_H 3 | 4 | #include "nnet.h" 5 | 6 | typedef struct RNNModel { 7 | const DenseLayer *fc; 8 | 9 | const Conv1DLayer *conv1; 10 | 11 | const Conv1DLayer *conv2; 12 | 13 | const GRULayer *gru1; 14 | 15 | const GRULayer *gru2; 16 | 17 | const GRULayer *gru3; 18 | 19 | const GRULayer *gru_gb; 20 | 21 | const GRULayer *gru_rb; 22 | 23 | const DenseLayer *fc_gb; 24 | 25 | const DenseLayer *fc_rb; 26 | }RNNModel; 27 | 28 | typedef struct RNNState { 29 | const RNNModel *model; 30 | float *first_conv1d_state; 31 | float *second_conv1d_state; 32 | float *gru1_state; 33 | float *gru2_state; 34 | float *gru3_state; 35 | float *gb_gru_state; 36 | float *rb_gru_state; 37 | float convout_buf[CONV_DIM*3]; 38 | } RNNState; 39 | 40 | 41 | #endif -------------------------------------------------------------------------------- /src/opus_types.h: -------------------------------------------------------------------------------- 1 | /* (C) COPYRIGHT 1994-2002 Xiph.Org Foundation */ 2 | /* Modified by Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | 
are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | /* opus_types.h based on ogg_types.h from libogg */ 28 | 29 | /** 30 | @file opus_types.h 31 | @brief Opus reference implementation types 32 | */ 33 | #ifndef OPUS_TYPES_H 34 | #define OPUS_TYPES_H 35 | 36 | /* Use the real stdint.h if it's there (taken from Paul Hsieh's pstdint.h) */ 37 | #if (defined(__STDC__) && __STDC__ && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L) || (defined(__GNUC__) && (defined(_STDINT_H) || defined(_STDINT_H_)) || defined (HAVE_STDINT_H)) 38 | #include 39 | 40 | typedef int16_t opus_int16; 41 | typedef uint16_t opus_uint16; 42 | typedef int32_t opus_int32; 43 | typedef uint32_t opus_uint32; 44 | #elif defined(_WIN32) 45 | 46 | # if defined(__CYGWIN__) 47 | # include <_G_config.h> 48 | typedef _G_int32_t opus_int32; 49 | typedef _G_uint32_t opus_uint32; 50 | typedef _G_int16 opus_int16; 51 | typedef _G_uint16 opus_uint16; 52 | # elif defined(__MINGW32__) 53 | typedef short opus_int16; 54 | typedef unsigned short opus_uint16; 55 | typedef int opus_int32; 56 | typedef unsigned int opus_uint32; 57 | # elif defined(__MWERKS__) 58 | typedef int opus_int32; 59 | typedef unsigned int opus_uint32; 60 | typedef short opus_int16; 61 | typedef unsigned short opus_uint16; 62 | # else 63 | /* MSVC/Borland */ 64 | typedef __int32 opus_int32; 65 | typedef unsigned __int32 opus_uint32; 66 | typedef __int16 opus_int16; 67 | typedef unsigned __int16 opus_uint16; 68 | # endif 69 | 70 | #elif defined(__MACOS__) 71 | 72 | # include 73 | typedef SInt16 opus_int16; 74 | typedef UInt16 opus_uint16; 75 | typedef SInt32 opus_int32; 76 | typedef UInt32 opus_uint32; 77 | 78 | #elif (defined(__APPLE__) && defined(__MACH__)) /* MacOS X Framework build */ 79 | 80 | # include 81 | typedef int16_t opus_int16; 82 | typedef u_int16_t opus_uint16; 83 | typedef int32_t opus_int32; 84 | typedef u_int32_t opus_uint32; 85 | 86 | #elif defined(__BEOS__) 87 | 88 | /* Be */ 89 | # include 90 | typedef int16 opus_int16; 91 | typedef u_int16 opus_uint16; 92 
| typedef int32_t opus_int32; 93 | typedef u_int32_t opus_uint32; 94 | 95 | #elif defined (__EMX__) 96 | 97 | /* OS/2 GCC */ 98 | typedef short opus_int16; 99 | typedef unsigned short opus_uint16; 100 | typedef int opus_int32; 101 | typedef unsigned int opus_uint32; 102 | 103 | #elif defined (DJGPP) 104 | 105 | /* DJGPP */ 106 | typedef short opus_int16; 107 | typedef unsigned short opus_uint16; 108 | typedef int opus_int32; 109 | typedef unsigned int opus_uint32; 110 | 111 | #elif defined(R5900) 112 | 113 | /* PS2 EE */ 114 | typedef int opus_int32; 115 | typedef unsigned opus_uint32; 116 | typedef short opus_int16; 117 | typedef unsigned short opus_uint16; 118 | 119 | #elif defined(__SYMBIAN32__) 120 | 121 | /* Symbian GCC */ 122 | typedef signed short opus_int16; 123 | typedef unsigned short opus_uint16; 124 | typedef signed int opus_int32; 125 | typedef unsigned int opus_uint32; 126 | 127 | #elif defined(CONFIG_TI_C54X) || defined (CONFIG_TI_C55X) 128 | 129 | typedef short opus_int16; 130 | typedef unsigned short opus_uint16; 131 | typedef long opus_int32; 132 | typedef unsigned long opus_uint32; 133 | 134 | #elif defined(CONFIG_TI_C6X) 135 | 136 | typedef short opus_int16; 137 | typedef unsigned short opus_uint16; 138 | typedef int opus_int32; 139 | typedef unsigned int opus_uint32; 140 | 141 | #else 142 | 143 | /* Give up, take a reasonable guess */ 144 | typedef short opus_int16; 145 | typedef unsigned short opus_uint16; 146 | typedef int opus_int32; 147 | typedef unsigned int opus_uint32; 148 | 149 | #endif 150 | 151 | #define opus_int int /* used for counters etc; at least 16 bits */ 152 | #define opus_int64 long long 153 | #define opus_int8 signed char 154 | 155 | #define opus_uint unsigned int /* used for counters etc; at least 16 bits */ 156 | #define opus_uint64 unsigned long long 157 | #define opus_uint8 unsigned char 158 | 159 | #endif /* OPUS_TYPES_H */ 160 | -------------------------------------------------------------------------------- 
/src/pitch.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2007-2008 CSIRO 2 | Copyright (c) 2007-2009 Xiph.Org Foundation 3 | Written by Jean-Marc Valin */ 4 | /** 5 | @file pitch.c 6 | @brief Pitch analysis 7 | */ 8 | 9 | /* 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions 12 | are met: 13 | 14 | - Redistributions of source code must retain the above copyright 15 | notice, this list of conditions and the following disclaimer. 16 | 17 | - Redistributions in binary form must reproduce the above copyright 18 | notice, this list of conditions and the following disclaimer in the 19 | documentation and/or other materials provided with the distribution. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 25 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
32 | */ 33 | 34 | #ifdef HAVE_CONFIG_H 35 | #include "config.h" 36 | #endif 37 | 38 | #include "pitch.h" 39 | #include "common.h" 40 | //#include "modes.h" 41 | //#include "stack_alloc.h" 42 | //#include "mathops.h" 43 | #include "celt_lpc.h" 44 | #include "math.h" 45 | 46 | static void find_best_pitch(opus_val32 *xcorr, opus_val16 *y, int len, 47 | int max_pitch, int *best_pitch 48 | #ifdef FIXED_POINT 49 | , int yshift, opus_val32 maxcorr 50 | #endif 51 | ) 52 | { 53 | int i, j; 54 | opus_val32 Syy=1; 55 | opus_val16 best_num[2]; 56 | opus_val32 best_den[2]; 57 | #ifdef FIXED_POINT 58 | int xshift; 59 | 60 | xshift = celt_ilog2(maxcorr)-14; 61 | #endif 62 | 63 | best_num[0] = -1; 64 | best_num[1] = -1; 65 | best_den[0] = 0; 66 | best_den[1] = 0; 67 | best_pitch[0] = 0; 68 | best_pitch[1] = 1; 69 | for (j=0;j0) 74 | { 75 | opus_val16 num; 76 | opus_val32 xcorr16; 77 | xcorr16 = EXTRACT16(VSHR32(xcorr[i], xshift)); 78 | #ifndef FIXED_POINT 79 | /* Considering the range of xcorr16, this should avoid both underflows 80 | and overflows (inf) when squaring xcorr16 */ 81 | xcorr16 *= 1e-12f; 82 | #endif 83 | num = MULT16_16_Q15(xcorr16,xcorr16); 84 | if (MULT16_32_Q15(num,best_den[1]) > MULT16_32_Q15(best_num[1],Syy)) 85 | { 86 | if (MULT16_32_Q15(num,best_den[0]) > MULT16_32_Q15(best_num[0],Syy)) 87 | { 88 | best_num[1] = best_num[0]; 89 | best_den[1] = best_den[0]; 90 | best_pitch[1] = best_pitch[0]; 91 | best_num[0] = num; 92 | best_den[0] = Syy; 93 | best_pitch[0] = i; 94 | } else { 95 | best_num[1] = num; 96 | best_den[1] = Syy; 97 | best_pitch[1] = i; 98 | } 99 | } 100 | } 101 | Syy += SHR32(MULT16_16(y[i+len],y[i+len]),yshift) - SHR32(MULT16_16(y[i],y[i]),yshift); 102 | Syy = MAX32(1, Syy); 103 | } 104 | } 105 | 106 | static void celt_fir5(const opus_val16 *x, 107 | const opus_val16 *num, 108 | opus_val16 *y, 109 | int N, 110 | opus_val16 *mem) 111 | { 112 | int i; 113 | opus_val16 num0, num1, num2, num3, num4; 114 | opus_val32 mem0, mem1, mem2, mem3, mem4; 115 | 
num0=num[0]; 116 | num1=num[1]; 117 | num2=num[2]; 118 | num3=num[3]; 119 | num4=num[4]; 120 | mem0=mem[0]; 121 | mem1=mem[1]; 122 | mem2=mem[2]; 123 | mem3=mem[3]; 124 | mem4=mem[4]; 125 | for (i=0;i>1;i++) 174 | x_lp[i] = SHR32(HALF32(HALF32(x[0][(2*i-1)]+x[0][(2*i+1)])+x[0][2*i]), shift); 175 | x_lp[0] = SHR32(HALF32(HALF32(x[0][1])+x[0][0]), shift); 176 | if (C==2) 177 | { 178 | for (i=1;i>1;i++) 179 | x_lp[i] += SHR32(HALF32(HALF32(x[1][(2*i-1)]+x[1][(2*i+1)])+x[1][2*i]), shift); 180 | x_lp[0] += SHR32(HALF32(HALF32(x[1][1])+x[1][0]), shift); 181 | } 182 | 183 | _celt_autocorr(x_lp, ac, NULL, 0, 184 | 4, len>>1); 185 | 186 | /* Noise floor -40 dB */ 187 | #ifdef FIXED_POINT 188 | ac[0] += SHR32(ac[0],13); 189 | #else 190 | ac[0] *= 1.0001f; 191 | #endif 192 | /* Lag windowing */ 193 | for (i=1;i<=4;i++) 194 | { 195 | /*ac[i] *= exp(-.5*(2*M_PI*.002*i)*(2*M_PI*.002*i));*/ 196 | #ifdef FIXED_POINT 197 | ac[i] -= MULT16_32_Q15(2*i*i, ac[i]); 198 | #else 199 | ac[i] -= ac[i]*(.008f*i)*(.008f*i); 200 | #endif 201 | } 202 | 203 | _celt_lpc(lpc, ac, 4); 204 | for (i=0;i<4;i++) 205 | { 206 | tmp = MULT16_16_Q15(QCONST16(.9f,15), tmp); 207 | lpc[i] = MULT16_16_Q15(lpc[i], tmp); 208 | } 209 | /* Add a zero */ 210 | lpc2[0] = lpc[0] + QCONST16(.8f,SIG_SHIFT); 211 | lpc2[1] = lpc[1] + MULT16_16_Q15(c1,lpc[0]); 212 | lpc2[2] = lpc[2] + MULT16_16_Q15(c1,lpc[1]); 213 | lpc2[3] = lpc[3] + MULT16_16_Q15(c1,lpc[2]); 214 | lpc2[4] = MULT16_16_Q15(c1,lpc[3]); 215 | celt_fir5(x_lp, lpc2, x_lp, len>>1, mem); 216 | } 217 | 218 | void celt_pitch_xcorr(const opus_val16 *_x, const opus_val16 *_y, 219 | opus_val32 *xcorr, int len, int max_pitch) 220 | { 221 | 222 | #if 0 /* This is a simple version of the pitch correlation that should work 223 | well on DSPs like Blackfin and TI C5x/C6x */ 224 | int i, j; 225 | #ifdef FIXED_POINT 226 | opus_val32 maxcorr=1; 227 | #endif 228 | for (i=0;i0); 251 | celt_assert((((unsigned char *)_x-(unsigned char *)NULL)&3)==0); 252 | for (i=0;i0); 297 | 
celt_assert(max_pitch>0); 298 | lag = len+max_pitch; 299 | 300 | opus_val16 x_lp4[len>>2]; 301 | opus_val16 y_lp4[lag>>2]; 302 | opus_val32 xcorr[max_pitch>>1]; 303 | 304 | /* Downsample by 2 again */ 305 | for (j=0;j>2;j++) 306 | x_lp4[j] = x_lp[2*j]; 307 | for (j=0;j>2;j++) 308 | y_lp4[j] = y[2*j]; 309 | 310 | #ifdef FIXED_POINT 311 | xmax = celt_maxabs16(x_lp4, len>>2); 312 | ymax = celt_maxabs16(y_lp4, lag>>2); 313 | shift = celt_ilog2(MAX32(1, MAX32(xmax, ymax)))-11; 314 | if (shift>0) 315 | { 316 | for (j=0;j>2;j++) 317 | x_lp4[j] = SHR16(x_lp4[j], shift); 318 | for (j=0;j>2;j++) 319 | y_lp4[j] = SHR16(y_lp4[j], shift); 320 | /* Use double the shift for a MAC */ 321 | shift *= 2; 322 | } else { 323 | shift = 0; 324 | } 325 | #endif 326 | 327 | /* Coarse search with 4x decimation */ 328 | 329 | #ifdef FIXED_POINT 330 | maxcorr = 331 | #endif 332 | celt_pitch_xcorr(x_lp4, y_lp4, xcorr, len>>2, max_pitch>>2); 333 | 334 | find_best_pitch(xcorr, y_lp4, len>>2, max_pitch>>2, best_pitch 335 | #ifdef FIXED_POINT 336 | , 0, maxcorr 337 | #endif 338 | ); 339 | 340 | /* Finer search with 2x decimation */ 341 | #ifdef FIXED_POINT 342 | maxcorr=1; 343 | #endif 344 | for (i=0;i>1;i++) 345 | { 346 | opus_val32 sum; 347 | xcorr[i] = 0; 348 | if (abs(i-2*best_pitch[0])>2 && abs(i-2*best_pitch[1])>2) 349 | continue; 350 | #ifdef FIXED_POINT 351 | sum = 0; 352 | for (j=0;j>1;j++) 353 | sum += SHR32(MULT16_16(x_lp[j],y[i+j]), shift); 354 | #else 355 | sum = celt_inner_prod(x_lp, y+i, len>>1); 356 | #endif 357 | xcorr[i] = MAX32(-1, sum); 358 | #ifdef FIXED_POINT 359 | maxcorr = MAX32(maxcorr, sum); 360 | #endif 361 | } 362 | find_best_pitch(xcorr, y, len>>1, max_pitch>>1, best_pitch 363 | #ifdef FIXED_POINT 364 | , shift+1, maxcorr 365 | #endif 366 | ); 367 | 368 | /* Refine by pseudo-interpolation */ 369 | if (best_pitch[0]>0 && best_pitch[0]<(max_pitch>>1)-1) 370 | { 371 | opus_val32 a, b, c; 372 | a = xcorr[best_pitch[0]-1]; 373 | b = xcorr[best_pitch[0]]; 374 | c = 
xcorr[best_pitch[0]+1]; 375 | if ((c-a) > MULT16_32_Q15(QCONST16(.7f,15),b-a)) 376 | offset = 1; 377 | else if ((a-c) > MULT16_32_Q15(QCONST16(.7f,15),b-c)) 378 | offset = -1; 379 | else 380 | offset = 0; 381 | } else { 382 | offset = 0; 383 | } 384 | *pitch = 2*best_pitch[0]-offset; 385 | *pitch_corr = xcorr[best_pitch[0]]; 386 | } 387 | 388 | #ifdef FIXED_POINT 389 | static opus_val16 compute_pitch_gain(opus_val32 xy, opus_val32 xx, opus_val32 yy) 390 | { 391 | opus_val32 x2y2; 392 | int sx, sy, shift; 393 | opus_val32 g; 394 | opus_val16 den; 395 | if (xy == 0 || xx == 0 || yy == 0) 396 | return 0; 397 | sx = celt_ilog2(xx)-14; 398 | sy = celt_ilog2(yy)-14; 399 | shift = sx + sy; 400 | x2y2 = SHR32(MULT16_16(VSHR32(xx, sx), VSHR32(yy, sy)), 14); 401 | if (shift & 1) { 402 | if (x2y2 < 32768) 403 | { 404 | x2y2 <<= 1; 405 | shift--; 406 | } else { 407 | x2y2 >>= 1; 408 | shift++; 409 | } 410 | } 411 | den = celt_rsqrt_norm(x2y2); 412 | g = MULT16_32_Q15(den, xy); 413 | g = VSHR32(g, (shift>>1)-1); 414 | return EXTRACT16(MIN32(g, Q15ONE)); 415 | } 416 | #else 417 | static opus_val16 compute_pitch_gain(opus_val32 xy, opus_val32 xx, opus_val32 yy) 418 | { 419 | return xy/sqrt(1+xx*yy); 420 | } 421 | #endif 422 | 423 | static const int second_check[16] = {0, 0, 3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}; 424 | opus_val16 remove_doubling(opus_val16 *x, int maxperiod, int minperiod, 425 | int N, int *T0_, int prev_period, opus_val16 prev_gain) 426 | { 427 | int k, i, T, T0; 428 | opus_val16 g, g0; 429 | opus_val16 pg; 430 | opus_val32 xy,xx,yy,xy2; 431 | opus_val32 xcorr[3]; 432 | opus_val32 best_xy, best_yy; 433 | int offset; 434 | int minperiod0; 435 | 436 | minperiod0 = minperiod; 437 | maxperiod /= 2; 438 | minperiod /= 2; 439 | *T0_ /= 2; 440 | prev_period /= 2; 441 | N /= 2; 442 | x += maxperiod; 443 | if (*T0_>=maxperiod) 444 | *T0_=maxperiod-1; 445 | 446 | T = T0 = *T0_; 447 | opus_val32 yy_lookup[maxperiod+1]; 448 | dual_inner_prod(x, x, x-T0, N, &xx, &xy); 449 
| yy_lookup[0] = xx; 450 | yy=xx; 451 | for (i=1;i<=maxperiod;i++) 452 | { 453 | yy = yy+MULT16_16(x[-i],x[-i])-MULT16_16(x[N-i],x[N-i]); 454 | yy_lookup[i] = MAX32(0, yy); 455 | } 456 | yy = yy_lookup[T0]; 457 | best_xy = xy; 458 | best_yy = yy; 459 | g = g0 = compute_pitch_gain(xy, xx, yy); 460 | /* Look for any pitch at T/k */ 461 | for (k=2;k<=15;k++) 462 | { 463 | int T1, T1b; 464 | opus_val16 g1; 465 | opus_val16 cont=0; 466 | opus_val16 thresh; 467 | T1 = (2*T0+k)/(2*k); 468 | if (T1 < minperiod) 469 | break; 470 | /* Look for another strong correlation at T1b */ 471 | if (k==2) 472 | { 473 | if (T1+T0>maxperiod) 474 | T1b = T0; 475 | else 476 | T1b = T0+T1; 477 | } else 478 | { 479 | T1b = (2*second_check[k]*T0+k)/(2*k); 480 | } 481 | dual_inner_prod(x, &x[-T1], &x[-T1b], N, &xy, &xy2); 482 | xy = HALF32(xy + xy2); 483 | yy = HALF32(yy_lookup[T1] + yy_lookup[T1b]); 484 | g1 = compute_pitch_gain(xy, xx, yy); 485 | if (abs(T1-prev_period)<=1) 486 | cont = prev_gain; 487 | else if (abs(T1-prev_period)<=2 && 5*k*k < T0) 488 | cont = HALF16(prev_gain); 489 | else 490 | cont = 0; 491 | thresh = MAX16(QCONST16(.3f,15), MULT16_16_Q15(QCONST16(.7f,15),g0)-cont); 492 | /* Bias against very high pitch (very short period) to avoid false-positives 493 | due to short-term correlation */ 494 | if (T1<3*minperiod) 495 | thresh = MAX16(QCONST16(.4f,15), MULT16_16_Q15(QCONST16(.85f,15),g0)-cont); 496 | else if (T1<2*minperiod) 497 | thresh = MAX16(QCONST16(.5f,15), MULT16_16_Q15(QCONST16(.9f,15),g0)-cont); 498 | if (g1 > thresh) 499 | { 500 | best_xy = xy; 501 | best_yy = yy; 502 | T = T1; 503 | g = g1; 504 | } 505 | } 506 | best_xy = MAX32(0, best_xy); 507 | if (best_yy <= best_xy) 508 | pg = Q15ONE; 509 | else 510 | pg = best_xy/(best_yy+1); 511 | 512 | for (k=0;k<3;k++) 513 | xcorr[k] = celt_inner_prod(x, x-(T+k-1), N); 514 | if ((xcorr[2]-xcorr[0]) > MULT16_32_Q15(QCONST16(.7f,15),xcorr[1]-xcorr[0])) 515 | offset = 1; 516 | else if ((xcorr[0]-xcorr[2]) > 
MULT16_32_Q15(QCONST16(.7f,15),xcorr[1]-xcorr[2])) 517 | offset = -1; 518 | else 519 | offset = 0; 520 | if (pg > g) 521 | pg = g; 522 | *T0_ = 2*T+offset; 523 | 524 | if (*T0_=3); 58 | y_3=0; /* gcc doesn't realize that y_3 can't be used uninitialized */ 59 | y_0=*y++; 60 | y_1=*y++; 61 | y_2=*y++; 62 | for (j=0;j 33 | #include "opus_types.h" 34 | #include "common.h" 35 | #include "arch.h" 36 | #include "tansig_table.h" 37 | #include "nnet.h" 38 | #include "nnet_data.h" 39 | #include 40 | 41 | 42 | void compute_rnn(RNNState *rnn, float *gains, float *strengths, const float *input) { 43 | int i; 44 | float dense_out[MAX_NEURONS]; 45 | float first_conv1d_out[CONV_DIM]; 46 | float second_conv1d_out[CONV_DIM]; 47 | float gb_dense_input[CONV_DIM*5]; 48 | float rb_gru_input[CONV_DIM*2]; 49 | 50 | compute_dense(rnn->model->fc, dense_out, input); 51 | compute_conv1d(rnn->model->conv1, first_conv1d_out/*512*/, rnn->first_conv1d_state, dense_out); 52 | compute_conv1d(rnn->model->conv2, second_conv1d_out/*512*/, rnn->second_conv1d_state, first_conv1d_out); 53 | 54 | //align 3 conv data 55 | //RNN_MOVE(rnn->convout_buf, &rnn->convout_buf[CONV_DIM], CONVOUT_BUF_SIZE-CONV_DIM); 56 | //RNN_COPY(&rnn->convout_buf[CONVOUT_BUF_SIZE-CONV_DIM], rnn->second_conv1d_state, CONV_DIM); 57 | //T-3 convout input for gru1 58 | compute_gru(rnn->model->gru1, rnn->gru1_state, second_conv1d_out); 59 | compute_gru(rnn->model->gru2, rnn->gru2_state, rnn->gru1_state); 60 | compute_gru(rnn->model->gru3, rnn->gru3_state, rnn->gru2_state); 61 | 62 | //for temporary input for gb_gru and rb_gru is gru3_state 63 | //but it might be need concat convout_buf through gru1,2,3_state 64 | compute_gru(rnn->model->gru_gb, rnn->gb_gru_state, rnn->gru3_state); 65 | 66 | //concat for rb gru 67 | for (i=0;igru3_state[i]; 68 | for (i=0;imodel->gru_rb, rnn->rb_gru_state, rb_gru_input); 70 | 71 | //concat for gb denseW 72 | for (i=0;igru1_state[i]; 74 | for (i=0;igru2_state[i]; 75 | for (i=0;igru3_state[i]; 76 | for 
(i=0;igb_gru_state[i]; 77 | compute_dense(rnn->model->fc_gb, gains, gb_dense_input); 78 | 79 | compute_dense(rnn->model->fc_rb, strengths, rnn->rb_gru_state); 80 | 81 | } 82 | -------------------------------------------------------------------------------- /src/rnnoise.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018 Gregor Richards 2 | * Copyright (c) 2017 Mozilla */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifndef RNNOISE_H 29 | #define RNNOISE_H 1 30 | 31 | #include 32 | #include "nnet_data.h" 33 | 34 | 35 | #ifndef RNNOISE_EXPORT 36 | # if defined(WIN32) 37 | # if defined(RNNOISE_BUILD) && defined(DLL_EXPORT) 38 | # define RNNOISE_EXPORT __declspec(dllexport) 39 | # else 40 | # define RNNOISE_EXPORT 41 | # endif 42 | # elif defined(__GNUC__) && defined(RNNOISE_BUILD) 43 | # define RNNOISE_EXPORT __attribute__ ((visibility ("default"))) 44 | # else 45 | # define RNNOISE_EXPORT 46 | # endif 47 | #endif 48 | 49 | typedef struct DenoiseState DenoiseState; 50 | typedef struct RNNModel RNNModel; 51 | 52 | RNNOISE_EXPORT int rnnoise_get_size(); 53 | 54 | RNNOISE_EXPORT int rnnoise_init(DenoiseState *st, RNNModel *model); 55 | 56 | RNNOISE_EXPORT DenoiseState *rnnoise_create(RNNModel *model); 57 | 58 | RNNOISE_EXPORT void rnnoise_destroy(DenoiseState *st); 59 | 60 | RNNOISE_EXPORT float rnnoise_process_frame(DenoiseState *st, float *out, const float *in, FILE *f_feature); 61 | 62 | RNNOISE_EXPORT RNNModel *rnnoise_model_from_file(FILE *f); 63 | 64 | RNNOISE_EXPORT void rnnoise_model_free(RNNModel *model); 65 | 66 | int train(int argc, char **argv); 67 | 68 | void compute_rnn(RNNState *rnn, float *gains, float *strengths, const float *input); 69 | #endif 70 | -------------------------------------------------------------------------------- /src/tansig_table.h: -------------------------------------------------------------------------------- 1 | /* This file is auto-generated by gen_tables */ 2 | 3 | #ifndef TANSIGTABLE_H 4 | #define TANSIGTABLE_H 1 5 | static const float tansig_table[201] = { 6 | 0.000000f, 0.039979f, 0.079830f, 0.119427f, 0.158649f, 7 | 0.197375f, 0.235496f, 0.272905f, 0.309507f, 0.345214f, 8 | 0.379949f, 0.413644f, 0.446244f, 0.477700f, 0.507977f, 9 | 0.537050f, 0.564900f, 0.591519f, 0.616909f, 0.641077f, 10 | 0.664037f, 0.685809f, 0.706419f, 0.725897f, 0.744277f, 11 | 0.761594f, 0.777888f, 0.793199f, 0.807569f, 0.821040f, 12 | 0.833655f, 
0.845456f, 0.856485f, 0.866784f, 0.876393f, 13 | 0.885352f, 0.893698f, 0.901468f, 0.908698f, 0.915420f, 14 | 0.921669f, 0.927473f, 0.932862f, 0.937863f, 0.942503f, 15 | 0.946806f, 0.950795f, 0.954492f, 0.957917f, 0.961090f, 16 | 0.964028f, 0.966747f, 0.969265f, 0.971594f, 0.973749f, 17 | 0.975743f, 0.977587f, 0.979293f, 0.980869f, 0.982327f, 18 | 0.983675f, 0.984921f, 0.986072f, 0.987136f, 0.988119f, 19 | 0.989027f, 0.989867f, 0.990642f, 0.991359f, 0.992020f, 20 | 0.992631f, 0.993196f, 0.993718f, 0.994199f, 0.994644f, 21 | 0.995055f, 0.995434f, 0.995784f, 0.996108f, 0.996407f, 22 | 0.996682f, 0.996937f, 0.997172f, 0.997389f, 0.997590f, 23 | 0.997775f, 0.997946f, 0.998104f, 0.998249f, 0.998384f, 24 | 0.998508f, 0.998623f, 0.998728f, 0.998826f, 0.998916f, 25 | 0.999000f, 0.999076f, 0.999147f, 0.999213f, 0.999273f, 26 | 0.999329f, 0.999381f, 0.999428f, 0.999472f, 0.999513f, 27 | 0.999550f, 0.999585f, 0.999617f, 0.999646f, 0.999673f, 28 | 0.999699f, 0.999722f, 0.999743f, 0.999763f, 0.999781f, 29 | 0.999798f, 0.999813f, 0.999828f, 0.999841f, 0.999853f, 30 | 0.999865f, 0.999875f, 0.999885f, 0.999893f, 0.999902f, 31 | 0.999909f, 0.999916f, 0.999923f, 0.999929f, 0.999934f, 32 | 0.999939f, 0.999944f, 0.999948f, 0.999952f, 0.999956f, 33 | 0.999959f, 0.999962f, 0.999965f, 0.999968f, 0.999970f, 34 | 0.999973f, 0.999975f, 0.999977f, 0.999978f, 0.999980f, 35 | 0.999982f, 0.999983f, 0.999984f, 0.999986f, 0.999987f, 36 | 0.999988f, 0.999989f, 0.999990f, 0.999990f, 0.999991f, 37 | 0.999992f, 0.999992f, 0.999993f, 0.999994f, 0.999994f, 38 | 0.999994f, 0.999995f, 0.999995f, 0.999996f, 0.999996f, 39 | 0.999996f, 0.999997f, 0.999997f, 0.999997f, 0.999997f, 40 | 0.999997f, 0.999998f, 0.999998f, 0.999998f, 0.999998f, 41 | 0.999998f, 0.999998f, 0.999999f, 0.999999f, 0.999999f, 42 | 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 43 | 0.999999f, 0.999999f, 0.999999f, 0.999999f, 0.999999f, 44 | 1.000000f, 1.000000f, 1.000000f, 1.000000f, 1.000000f, 45 | 1.000000f, 1.000000f, 
1.000000f, 1.000000f, 1.000000f, 46 | 1.000000f, 47 | }; 48 | #endif -------------------------------------------------------------------------------- /src/vec.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018 Mozilla 2 | 2008-2011 Octasic Inc. 3 | 2012-2017 Jean-Marc Valin */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 20 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | */ 28 | /* No AVX2/FMA support */ 29 | //#include "opus_types.h" 30 | #include "_kiss_fft_guts.h" 31 | #include "tansig_table.h" 32 | #ifndef LPCNET_TEST 33 | static float celt_exp2(float x) 34 | { 35 | int integer; 36 | float frac; 37 | union { 38 | float f; 39 | opus_uint32 i; 40 | } res; 41 | integer = floor(x); 42 | if (integer < -50) 43 | return 0; 44 | frac = x-integer; 45 | /* K0 = 1, K1 = log(2), K2 = 3-4*log(2), K3 = 3*log(2) - 2 */ 46 | res.f = 0.99992522f + frac * (0.69583354f 47 | + frac * (0.22606716f + 0.078024523f*frac)); 48 | res.i = (res.i + (integer<<23)) & 0x7fffffff; 49 | return res.f; 50 | } 51 | #define celt_exp(x) celt_exp2((x)*1.44269504f) 52 | 53 | static float tansig_approx(float x) 54 | { 55 | int i; 56 | float y, dy; 57 | float sign=1; 58 | if (x<0) 59 | { 60 | x=-x; 61 | sign=-1; 62 | } 63 | i = (int)floor(.5f+25*x); 64 | i = IMAX(0, IMIN(200, i)); 65 | x -= .04f*i; 66 | y = tansig_table[i]; 67 | dy = 1-y*y; 68 | y = y + x*dy*(1 - y*x); 69 | return sign*y; 70 | } 71 | 72 | static OPUS_INLINE float sigmoid_approx(float x) 73 | { 74 | return .5f + .5f*tansig_approx(.5f*x); 75 | } 76 | 77 | static void softmax(float *y, const float *x, int N) 78 | { 79 | int i; 80 | for (i=0;i),\n", 39 | " tensor([[[-0.4103, -0.0571, -0.0257, 0.0596]]], grad_fn=))" 40 | ] 41 | }, 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "gru1_output" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 44, 54 | "id": "5be1c088", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])" 61 | ] 62 | }, 63 | "execution_count": 44, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "gru1.bias_ih_l0.data.fill_(0)\n", 70 | "gru1.bias_hh_l0.data.fill_(0)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 22, 76 | "id": 
"5156b364", 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/plain": [ 82 | "torch.Size([1, 3, 4])" 83 | ] 84 | }, 85 | "execution_count": 22, 86 | "metadata": {}, 87 | "output_type": "execute_result" 88 | } 89 | ], 90 | "source": [ 91 | "gru1_output = gru1(gru_test_input)\n", 92 | "gru1_output[0].shape" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 32, 98 | "id": "31dfa209", 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "y\n" 106 | ] 107 | } 108 | ], 109 | "source": [ 110 | "if isinstance(gru1,nn.GRU):\n", 111 | " print('y')" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 5, 117 | "id": "037748c9", 118 | "metadata": {}, 119 | "outputs": [ 120 | { 121 | "data": { 122 | "text/plain": [ 123 | "torch.Size([12, 2])" 124 | ] 125 | }, 126 | "execution_count": 5, 127 | "metadata": {}, 128 | "output_type": "execute_result" 129 | } 130 | ], 131 | "source": [ 132 | "gru1.weight_ih_l0.shape" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 14, 138 | "id": "c1652b6a", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "def convert_gru_input_kernel(kernel):\n", 143 | " kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)\n", 144 | " kernels = [kernel_z, kernel_r, kernel_h]\n", 145 | " return torch.tensor(np.hstack([k.reshape(k.T.shape) for k in kernels]))\n", 146 | "\n", 147 | "def convert_gru_recurrent_kernel(kernel):\n", 148 | " kernel_r, kernel_z, kernel_h = np.vsplit(kernel, 3)\n", 149 | " kernels = [kernel_z, kernel_r, kernel_h]\n", 150 | " return torch.tensor(np.hstack(kernels))" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 18, 156 | "id": "2c5bbe5f", 157 | "metadata": {}, 158 | "outputs": [ 159 | { 160 | "name": "stdout", 161 | "output_type": "stream", 162 | "text": [ 163 | "torch.Size([2, 12])\n" 164 | ] 165 | } 166 | ], 167 | "source": [ 
168 | "new_weight_ih_l0 = convert_gru_input_kernel(gru1.weight_ih_l0.detach().numpy())\n", 169 | "print(new_weight_ih_l0.shape)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 6, 175 | "id": "6bb3c098", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "import keras\n", 180 | "import tensorflow\n", 181 | "from keras import layers" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 13, 187 | "id": "a310a67f", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "keras_conv1 = layers.Conv1D(3, 2, padding='same', activation='tanh', name='feature_conv1')" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 14, 197 | "id": "9afb23f3", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "input_shape = (1,3,4)\n", 202 | "x = tensorflow.random.normal(input_shape)\n", 203 | "y = keras_conv1(x)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 15, 209 | "id": "72c5296e", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "weights = keras_conv1.get_weights()" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 16, 219 | "id": "18315a2e", 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "text/plain": [ 225 | "(2, 4, 3)" 226 | ] 227 | }, 228 | "execution_count": 16, 229 | "metadata": {}, 230 | "output_type": "execute_result" 231 | } 232 | ], 233 | "source": [ 234 | "weights[0].shape\n", 235 | "#input 1\n", 236 | "#kernal_size 0\n", 237 | "#output 2" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 17, 243 | "id": "757e6b12", 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "text/plain": [ 249 | "(3,)" 250 | ] 251 | }, 252 | "execution_count": 17, 253 | "metadata": {}, 254 | "output_type": "execute_result" 255 | } 256 | ], 257 | "source": [ 258 | "weights[1].shape" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 
263 | "execution_count": 7, 264 | "id": "6cd8bedc", 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "keras_gru1 = layers.GRU(4,return_sequences=True, recurrent_activation=\"sigmoid\", reset_after='true', name='gru_a',)\n", 269 | "input_shape = (1,3,2)\n", 270 | "x = tensorflow.random.normal(input_shape)\n", 271 | "y = keras_gru1(x)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 13, 277 | "id": "c3b4b238", 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "data": { 282 | "text/plain": [ 283 | "(2, 12)" 284 | ] 285 | }, 286 | "execution_count": 13, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "weights = keras_gru1.get_weights()\n", 293 | "weights[0].shape" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 12, 299 | "id": "c5ece6c4", 300 | "metadata": {}, 301 | "outputs": [ 302 | { 303 | "name": "stdout", 304 | "output_type": "stream", 305 | "text": [ 306 | "[[ 0.50630426 -0.18125176 -0.22455446 -0.22690475 -0.21649115 -0.19073492\n", 307 | " -0.01370192 -0.47307855 0.05561786 0.21302928 0.0155199 0.50292975]\n", 308 | " [-0.05305615 0.01354556 0.02428324 -0.10544971 0.5824456 0.23123857\n", 309 | " 0.18460734 -0.3195278 -0.5989489 -0.13466732 -0.20463274 0.19403724]\n", 310 | " [-0.03756262 0.24331398 -0.7099939 -0.00343641 0.21649931 -0.02412327\n", 311 | " -0.04613936 -0.11119071 -0.07995344 0.18405285 0.4935046 -0.2992046 ]\n", 312 | " [ 0.30607358 0.029637 -0.01524288 -0.06086059 0.20136751 -0.4949669\n", 313 | " -0.19576669 -0.22203046 0.10759834 -0.52205724 -0.24321698 -0.4301926 ]]\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "print(weights[1])" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 19, 324 | "id": "b44c780b", 325 | "metadata": {}, 326 | "outputs": [ 327 | { 328 | "data": { 329 | "text/plain": [ 330 | "TensorShape([1, 3, 4])" 331 | ] 332 | }, 333 | "execution_count": 19, 334 | 
"metadata": {}, 335 | "output_type": "execute_result" 336 | } 337 | ], 338 | "source": [ 339 | "y.shape" 340 | ] 341 | } 342 | ], 343 | "metadata": { 344 | "kernelspec": { 345 | "display_name": "Python 3", 346 | "language": "python", 347 | "name": "python3" 348 | }, 349 | "language_info": { 350 | "codemirror_mode": { 351 | "name": "ipython", 352 | "version": 3 353 | }, 354 | "file_extension": ".py", 355 | "mimetype": "text/x-python", 356 | "name": "python", 357 | "nbconvert_exporter": "python", 358 | "pygments_lexer": "ipython3", 359 | "version": "3.8.5" 360 | } 361 | }, 362 | "nbformat": 4, 363 | "nbformat_minor": 5 364 | } 365 | -------------------------------------------------------------------------------- /tests/main.cpp: -------------------------------------------------------------------------------- 1 | #include "gtest/gtest.h" 2 | 3 | int main(int argc, char **argv) 4 | { 5 | ::testing::InitGoogleTest(&argc, argv); 6 | int ret = RUN_ALL_TESTS(); 7 | return ret; 8 | } -------------------------------------------------------------------------------- /tests/moduletest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from torch import nn 4 | sys.path.append("../") 5 | import dump_percepnet 6 | from dump_percepnet import printVector 7 | class PercepNet(nn.Module): 8 | def __init__(self, input_dim=70): 9 | super(PercepNet, self).__init__() 10 | 11 | self.fc = nn.Sequential(nn.Linear(2, 3), nn.Sigmoid()) 12 | self.conv1 = nn.Sequential(nn.Conv1d(2, 3, 3, stride=1, padding=1), nn.Sigmoid()) 13 | self.gru1 = nn.GRU(2, 3, 1, batch_first=True) 14 | 15 | if __name__ == '__main__': 16 | model = PercepNet() 17 | 18 | cfile = 'nnet_data_test.h' 19 | 20 | f = open(cfile, 'w') 21 | 22 | f.write('/*This file is automatically generated from a Pytorch model*/\n\n') 23 | f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n//#include "nnet_data.h"\n\n') 24 | 25 | testdataset = 
[torch.Tensor([0.5,0.5]),torch.zeros([1,2,3])+0.5,torch.zeros([1,3,2])+0.5] 26 | for children, testdata in zip(model.named_children(),testdataset): 27 | name, module = children 28 | module.dump_data(f, name) 29 | output = module(testdata) 30 | if isinstance(output, tuple) : 31 | output = output[0] 32 | if len(output.size())>2 and not isinstance(module,nn.GRU): 33 | output = torch.transpose(output, 1, 2) 34 | printVector(f, output, name+"_output") 35 | 36 | 37 | #f.write('extern const struct RNNModel percepnet_model_orig = {\n') 38 | #for name, module in model.named_children(): 39 | # f.write(' &{},\n'.format(name)) 40 | #f.write('};\n') 41 | 42 | -------------------------------------------------------------------------------- /tests/nnet_data_test.h: -------------------------------------------------------------------------------- 1 | /*This file is automatically generated from a Pytorch model*/ 2 | 3 | #ifdef HAVE_CONFIG_H 4 | #include "config.h" 5 | #endif 6 | 7 | #include "nnet.h" 8 | //#include "nnet_data.h" 9 | 10 | static const float fc_weights[6] = { 11 | -0.3736399710178375, -0.5867775082588196, 0.35544341802597046, 0.48470452427864075, -0.28879988193511963, 0.029497651383280754 12 | }; 13 | 14 | static const float fc_bias[3] = { 15 | -0.33606868982315063, -0.6151031255722046, -0.11614929139614105 16 | }; 17 | 18 | const DenseLayer fc = { 19 | fc_bias, 20 | fc_weights, 21 | 2, 3, ACTIVATION_SIGMOID 22 | }; 23 | 24 | static const float fc_output[3] = { 25 | 0.4303222894668579, 0.2586701810359955, 0.519071102142334 26 | }; 27 | 28 | static const float conv1_weights[18] = { 29 | -0.20590366423130035, -0.04676629975438118, 0.06895061582326889, -0.1175355538725853, -0.0025675243232399225, 0.31224867701530457, 0.0243507232517004, -0.1226673424243927, 30 | -0.32207781076431274, -0.26902732253074646, 0.34255385398864746, -0.23909498751163483, 0.2627784311771393, -0.2598373591899872, 0.13830594718456268, -0.22439751029014587, 31 | 0.2666202485561371, 
0.20408029854297638 32 | }; 33 | 34 | static const float conv1_bias[3] = { 35 | -0.1918720006942749, -0.22945673763751984, -0.0022818492725491524 36 | }; 37 | 38 | const Conv1DLayer conv1 = { 39 | conv1_bias, 40 | conv1_weights, 41 | 2, 3, 3, ACTIVATION_SIGMOID 42 | }; 43 | 44 | static const float conv1_output[9] = { 45 | 0.42677536606788635, 0.47100207209587097, 0.47211021184921265, 0.38775959610939026, 0.4648607671260834, 0.5197209119796753, 0.38321366906166077, 0.46401721239089966, 46 | 0.4769492447376251 47 | }; 48 | 49 | static const float gru1_weights[18] = { 50 | 0.05608205869793892, 0.4327813684940338, -0.14592474699020386, 0.3059828281402588, 0.44043174386024475, 0.4284402132034302, -0.5530121922492981, -0.19666491448879242, 51 | -0.47842809557914734, 0.4116198718547821, -0.48748236894607544, -0.433973491191864, -0.12620672583580017, 0.06710749119520187, 0.38061898946762085, 0.3667631149291992, 52 | -0.21762359142303467, 0.49776896834373474 53 | }; 54 | 55 | static const float gru1_recurrent_weights[27] = { 56 | -0.5226409435272217, 0.31022316217422485, -0.06384880840778351, 0.4340868592262268, -0.5519341230392456, -0.03121725656092167, -0.14920800924301147, -0.47921422123908997, 57 | -0.34647974371910095, -0.15125687420368195, 0.059751834720373154, -0.2908184826374054, 0.09929904341697693, 0.5218473076820374, -0.4606817960739136, 0.4674922823905945, 58 | 0.2380998730659485, -0.5730081796646118, 0.15210899710655212, 0.3517436683177948, 0.5022947788238525, -0.3818601965904236, 0.07916372269392014, -0.3516411781311035, 59 | 0.2250480055809021, 0.2854323387145996, -0.3876718282699585 60 | }; 61 | 62 | static const float gru1_bias[18] = { 63 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 64 | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 65 | 0.0, 0.0 66 | }; 67 | 68 | const GRULayer gru1 = { 69 | gru1_bias, 70 | gru1_weights, 71 | gru1_recurrent_weights, 72 | 2, 3, ACTIVATION_TANH, 1 73 | }; 74 | 75 | static const float gru1_output[9] = { 76 | -0.041024111211299896, 
-0.10351210832595825, 0.005531159229576588, -0.07247771322727203, -0.1559143215417862, 0.03240172564983368, -0.09284657984972, -0.17878594994544983, 77 | 0.05450616776943207 78 | }; 79 | 80 | -------------------------------------------------------------------------------- /tests/testnnet.cpp: -------------------------------------------------------------------------------- 1 | #include "nnet.h" 2 | #include "nnet_data_test.h" 3 | #include 4 | #include 5 | #include 6 | 7 | // Demonstrate some basic assertions. 8 | /* 9 | TEST(TestNnet, BasicAssertions) { 10 | // Expect two strings not to be equal. 11 | EXPECT_STRNE("hello", "world"); 12 | // Expect equality. 13 | EXPECT_EQ(7 * 6, 42); 14 | } 15 | */ 16 | 17 | 18 | 19 | TEST(TestNnet, fcCheck) { 20 | float eps = 1e-5; 21 | std::vector fc_input(fc.nb_inputs, 0.5); 22 | std::vector fc_output_c(fc.nb_neurons, 0); 23 | compute_dense(&fc, &fc_output_c[0], &fc_input[0]); 24 | 25 | for(int i=0; i first_conv1d_state(conv1.kernel_size*conv1.nb_inputs,0); 33 | float eps = 1e-5; 34 | std::vector conv1_input(conv1.nb_inputs*3, 0.5); 35 | std::vector conv1_output_c(conv1.nb_neurons, 0); 36 | compute_conv1d(&conv1, &conv1_output_c[0], &first_conv1d_state[0], &conv1_input[0]); 37 | compute_conv1d(&conv1, &conv1_output_c[0], &first_conv1d_state[0], &conv1_input[0]); 38 | //EXPECT_EQ(conv1_output_c.size(), sizeof(conv1_output)/sizeof(float)); 39 | for(int i=0; i gru1_state(gru1.nb_neurons,0); 53 | std::vector gru1_input(gru1.nb_inputs, 0.5); 54 | //std::vector gru1_output_c(gr.nb_neurons, 0); 55 | compute_gru(&gru1,&gru1_state[0], &gru1_input[0]); 56 | 57 | for(int i=0; i fc_input(fc.nb_inputs, 0.5); 72 | std::vector fc_output_c(fc.nb_neurons, 0); 73 | compute_dense(&fc, &fc_output_c[0], &fc_input[0]); 74 | 75 | for(int i=0; imodel->gb_gru, rnn->gb_gru_state, rnn->gru3_state); 83 | 84 | 85 | return 0; 86 | } 87 | */ -------------------------------------------------------------------------------- /utils/DNS_Challenge.yaml: 
-------------------------------------------------------------------------------- 1 | ########################################################### 2 | # DATA LOADER SETTING # 3 | ########################################################### 4 | batch_size: 64 # Batch size. 5 | pin_memory: true # Whether to pin memory in Pytorch DataLoader. 6 | num_workers: 4 # Number of workers in Pytorch DataLoader. 7 | allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. 8 | 9 | ########################################################### 10 | # INTERVAL SETTING # 11 | ########################################################### 12 | train_max_steps: 100000 # Number of training steps. 13 | save_interval_steps: 1000 # Interval steps to save checkpoint. 14 | eval_interval_steps: 1000 # Interval steps to evaluate the network. 15 | log_interval_steps: 1000 # Interval steps to record the training log. 16 | -------------------------------------------------------------------------------- /utils/__pycache__/filterbanks.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jzi040941/PercepNet/8ffae4337d23f920176ac2a7426e84610fa338ab/utils/__pycache__/filterbanks.cpython-35.pyc -------------------------------------------------------------------------------- /utils/bin2h5.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import sys 4 | 5 | if(len(sys.argv)<2): 6 | print("wrong usage: bin2h5.py [binary_filename] [h5_filename] ") 7 | 8 | input_and_output_dim=138 9 | bin_file_name=sys.argv[1] 10 | data = np.fromfile(bin_file_name, dtype='float32') 11 | data = np.reshape(data, (len(data)//input_and_output_dim, input_and_output_dim)) 12 | h5f = h5py.File(sys.argv[2], 'w'); 13 | h5f.create_dataset('data', data=data) 14 | h5f.close() -------------------------------------------------------------------------------- 
'''
Created 06/03/2018
@author: Will Wilkinson
'''

import numpy as np


class FilterBank(object):
    """
    Based on Josh McDermott's Matlab filterbank code:
    http://mcdermottlab.mit.edu/Sound_Texture_Synthesis_Toolbox_v1.7.zip

    leny = filter bank length in samples
    fs = sample rate
    N = number of frequency channels / subbands (excluding high-&low-pass
        which are added for perfect reconstruction)
    low_lim = centre frequency of first (lowest) channel
    high_lim = centre frequency of last (highest) channel
    """
    def __init__(self, leny, fs, N, low_lim, high_lim):
        self.leny = leny
        self.fs = fs
        self.N = N
        self.low_lim = low_lim
        self.high_lim, self.freqs, self.nfreqs = self.check_limits(leny, fs, high_lim)

    def check_limits(self, leny, fs, high_lim):
        """Clamp high_lim to Nyquist and build the positive-frequency axis.

        Returns (high_lim, freqs, nfreqs) where freqs has nfreqs + 1 points
        covering [0, max_freq] (DC bin included).
        """
        if np.remainder(leny, 2) == 0:
            # Integer division: np.linspace below needs an int point count
            # (the old float `leny / 2` raises TypeError on modern NumPy).
            nfreqs = leny // 2
            max_freq = fs / 2
        else:
            nfreqs = (leny - 1) // 2
            max_freq = fs * (leny - 1) / 2 / leny
        freqs = np.linspace(0, max_freq, nfreqs + 1)
        if high_lim > fs / 2:
            high_lim = max_freq
        return high_lim, freqs, int(nfreqs)

    def generate_subbands(self, signal):
        """Filter *signal* through every channel of self.filters.

        Stores the result in self.subbands, one column per subband
        (including the low-/high-pass reconstruction channels).
        """
        if signal.shape[0] == 1:  # turn into column vector
            signal = np.transpose(signal)
        N = self.filters.shape[1] - 2
        signal_length = signal.shape[0]
        filt_length = self.filters.shape[0]
        # watch out: numpy fft acts on rows, whereas Matlab fft acts on columns
        fft_sample = np.transpose(np.asmatrix(np.fft.fft(signal)))
        # generate negative frequencies in right place; filters are column vectors
        if np.remainder(signal_length, 2) == 0:  # even length
            fft_filts = np.concatenate([self.filters, np.flipud(self.filters[1:filt_length - 1, :])])
        else:  # odd length
            fft_filts = np.concatenate([self.filters, np.flipud(self.filters[1:filt_length, :])])
        # multiply by array of column replicas of fft_sample
        tile = np.dot(fft_sample, np.ones([1, N + 2]))
        fft_subbands = np.multiply(fft_filts, tile)
        # ifft works on rows; imag part is small, probably discretization error?
        self.subbands = np.transpose(np.real(np.fft.ifft(np.transpose(fft_subbands))))


class EqualRectangularBandwidth(FilterBank):
    """Cosine filters with cutoffs evenly spaced on the ERB scale."""
    def __init__(self, leny, fs, N, low_lim, high_lim):
        super(EqualRectangularBandwidth, self).__init__(leny, fs, N, low_lim, high_lim)
        # make cutoffs evenly spaced on an erb scale
        erb_low = self.freq2erb(self.low_lim)
        erb_high = self.freq2erb(self.high_lim)
        erb_lims = np.linspace(erb_low, erb_high, self.N + 2)
        self.cutoffs = self.erb2freq(erb_lims)
        self.filters = self.make_filters(self.N, self.nfreqs, self.freqs, self.cutoffs)

    def freq2erb(self, freq_Hz):
        """Hz -> ERB-number (Glasberg & Moore formula)."""
        n_erb = 9.265 * np.log(1 + np.divide(freq_Hz, 24.7 * 9.265))
        return n_erb

    def erb2freq(self, n_erb):
        """ERB-number -> Hz (exact inverse of freq2erb)."""
        freq_Hz = 24.7 * 9.265 * (np.exp(np.divide(n_erb, 9.265)) - 1)
        return freq_Hz

    def make_filters(self, N, nfreqs, freqs, cutoffs):
        """Build (nfreqs + 1, N + 2) filter matrix; squared responses sum to 1."""
        cos_filts = np.zeros([nfreqs + 1, N])
        for k in range(N):
            l_k = cutoffs[k]
            h_k = cutoffs[k + 2]  # adjacent filters overlap by 50%
            l_ind = np.min(np.where(freqs > l_k))
            h_ind = np.max(np.where(freqs < h_k))
            avg = (self.freq2erb(l_k) + self.freq2erb(h_k)) / 2
            rnge = self.freq2erb(h_k) - self.freq2erb(l_k)
            # map cutoffs to -pi/2, pi/2 interval
            cos_filts[l_ind:h_ind + 1, k] = np.cos((self.freq2erb(freqs[l_ind:h_ind + 1]) - avg) / rnge * np.pi)
        # add lowpass and highpass to get perfect reconstruction
        filters = np.zeros([nfreqs + 1, N + 2])
        filters[:, 1:N + 1] = cos_filts
        # lowpass filter goes up to peak of first cos filter
        h_ind = np.max(np.where(freqs < cutoffs[1]))
        filters[:h_ind + 1, 0] = np.sqrt(1 - np.power(filters[:h_ind + 1, 1], 2))
        # highpass filter goes down to peak of last cos filter
        l_ind = np.min(np.where(freqs > cutoffs[N]))
        filters[l_ind:nfreqs + 1, N + 1] = np.sqrt(1 - np.power(filters[l_ind:nfreqs + 1, N], 2))
        return filters


class Linear(FilterBank):
    """Cosine filters with cutoffs evenly spaced on a linear (Hz) scale."""
    def __init__(self, leny, fs, N, low_lim, high_lim):
        super(Linear, self).__init__(leny, fs, N, low_lim, high_lim)
        self.cutoffs = np.linspace(self.low_lim, self.high_lim, self.N + 2)
        self.filters = self.make_filters(self.N, self.nfreqs, self.freqs, self.cutoffs)

    def make_filters(self, N, nfreqs, freqs, cutoffs):
        """Same construction as the ERB bank, but in linear frequency."""
        cos_filts = np.zeros([nfreqs + 1, N])
        for k in range(N):
            l_k = cutoffs[k]
            h_k = cutoffs[k + 2]  # adjacent filters overlap by 50%
            l_ind = np.min(np.where(freqs > l_k))
            h_ind = np.max(np.where(freqs < h_k))
            avg = (l_k + h_k) / 2
            rnge = h_k - l_k
            # map cutoffs to -pi/2, pi/2 interval
            cos_filts[l_ind:h_ind + 1, k] = np.cos((freqs[l_ind:h_ind + 1] - avg) / rnge * np.pi)
        # add lowpass and highpass to get perfect reconstruction
        filters = np.zeros([nfreqs + 1, N + 2])
        filters[:, 1:N + 1] = cos_filts
        # lowpass filter goes up to peak of first cos filter
        h_ind = np.max(np.where(freqs < cutoffs[1]))
        filters[:h_ind + 1, 0] = np.sqrt(1 - np.power(filters[:h_ind + 1, 1], 2))
        # highpass filter goes down to peak of last cos filter
        l_ind = np.min(np.where(freqs > cutoffs[N]))
        filters[l_ind:nfreqs + 1, N + 1] = np.sqrt(1 - np.power(filters[l_ind:nfreqs + 1, N], 2))
        return filters
# License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.


# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).


###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###

# Now import all the configs specified by command-line, in left-to-right order
# (${!argpos} is bash indirect expansion: the positional parameter at index
# $argpos).
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config # source the config file.
  fi
done


###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2 ; fi;
      exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined-- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
    *) break;
  esac
done


# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;


true; # so this script returns exit code 0.
-------------------------------------------------------------------------------- /utils/path.sh: -------------------------------------------------------------------------------- 1 | # path related 2 | export PRJ_ROOT="${PWD}/.." -------------------------------------------------------------------------------- /utils/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2021 Seonghun Noh 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | 8 | # - Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | 11 | # - Redistributions in binary form must reproduce the above copyright 12 | # notice, this list of conditions and the following disclaimer in the 13 | # documentation and/or other materials provided with the distribution. 14 | 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | # ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | 28 | . ./path.sh || exit 1; 29 | . 
./parse_options.sh || exit 1; 30 | 31 | dataset_dir="training_set_sept12_500h" 32 | noisy_wav_dir="noisy" 33 | clean_wav_dir="clean" 34 | noisy_pcm_dir="noisy_pcm" 35 | clean_pcm_dir="clean_pcm" 36 | feature_dir="features" 37 | h5_train_dir="h5_train" 38 | h5_dev_dir="h5_dev" 39 | out_dir="exp_erbfix_x30_snr45_rmax99" 40 | #out_dir="exp_test" 41 | model_filename="model.pt" 42 | train_size_per_batch=2000 43 | config="DNS_Challenge.yaml" 44 | #pretrain="/home/seonghun/develop/PercepNet/training_set_sept12_500h/exp_erbfix_band30times_nopretrain/checkpoint-10000steps.pkl" 45 | 46 | stage=4 #stage to start 47 | stop_stage=4 #stop stage 48 | 49 | NR_CPUS=8 #TODO: automatically detect how many cpu have 50 | 51 | 52 | ################################################### 53 | # mkdir related dir if not exist # 54 | ################################################### 55 | mkdir -p ${PRJ_ROOT}/${dataset_dir}/{${noisy_pcm_dir},${clean_pcm_dir},${h5_train_dir},${h5_dev_dir},${feature_dir},${out_dir}} 56 | 57 | ################################################### 58 | # stage1 :resample to 48khz and convert wav to pcm# 59 | ################################################### 60 | if [ "${stage}" -le 1 ] && [ "${stop_stage}" -ge 1 ]; then 61 | i=0 62 | mkdir -p ${PRJ_ROOT}/${dataset_dir}/${noisy_wav_dir}/fileid 63 | mkdir -p ${PRJ_ROOT}/${dataset_dir}/${clean_wav_dir}/fileid 64 | for wavfilepath in ${PRJ_ROOT}/${dataset_dir}/${noisy_wav_dir}/*.wav; do 65 | ((i=i%NR_CPUS)); ((i++==0)) && wait 66 | ( 67 | # pcmfilename="`basename "${wavfilepath##*fileid_}" .wav`.pcm" 68 | # pcmfilepath=${PRJ_ROOT}/${dataset_dir}/${noisy_pcm_dir}/${pcmfilename} 69 | # sox ${wavfilepath} -b 16 -e signed-integer -c 1 -r 48k -t raw ${pcmfilepath} 70 | 71 | # reduce disk usage 72 | newwavfilename="`basename "${wavfilepath##*fileid_}" .wav`.wav" 73 | mv ${wavfilepath} ${PRJ_ROOT}/${dataset_dir}/${noisy_pcm_dir}/${newwavfilename} 74 | ) & 75 | done 76 | wait 77 | 78 | i=0 79 | for wavfilepath in 
${PRJ_ROOT}/${dataset_dir}/${clean_wav_dir}/*.wav; do 80 | ((i=i%NR_CPUS)); ((i++==0)) && wait 81 | ( 82 | # pcmfilename="`basename "${wavfilepath##*fileid_}" .wav`.pcm" 83 | # pcmfilepath=${PRJ_ROOT}/${dataset_dir}/${clean_pcm_dir}/${pcmfilename} 84 | # sox ${wavfilepath} -b 16 -e signed-integer -c 1 -r 48k -t raw ${pcmfilepath} 85 | newwavfilename="`basename "${wavfilepath##*fileid_}" .wav`.wav" 86 | mv ${wavfilepath} ${PRJ_ROOT}/${dataset_dir}/${clean_pcm_dir}/${newwavfilename} 87 | ) & 88 | done 89 | wait 90 | fi 91 | 92 | ################################################### 93 | #Generate c++ feature data for each noisy and clean data # 94 | ################################################### 95 | if [ "${stage}" -le 2 ] && [ "${stop_stage}" -ge 2 ]; then 96 | i=0 97 | 98 | for noisy_wav_filepath in ${PRJ_ROOT}/${dataset_dir}/${noisy_pcm_dir}/*.wav; do 99 | ((i=i%NR_CPUS)); ((i++==0)) && wait 100 | ( 101 | fileid=`basename "${noisy_wav_filepath%%.wav*}" .wav` 102 | noisy_pcm_filepath="${PRJ_ROOT}/${dataset_dir}/${noisy_pcm_dir}/${fileid}.pcm" 103 | sox ${noisy_wav_filepath} -b 16 -e signed-integer -c 1 -r 48k -t raw ${noisy_pcm_filepath} 104 | clean_wav_filepath="${PRJ_ROOT}/${dataset_dir}/${clean_pcm_dir}/${fileid}.wav" 105 | clean_pcm_filepath="${PRJ_ROOT}/${dataset_dir}/${clean_pcm_dir}/${fileid}.pcm" 106 | sox ${clean_wav_filepath} -b 16 -e signed-integer -c 1 -r 48k -t raw ${clean_pcm_filepath} 107 | ${PRJ_ROOT}/bin/src/percepNet ${clean_pcm_filepath} \ 108 | ${noisy_pcm_filepath} \ 109 | ${train_size_per_batch} \ 110 | ${PRJ_ROOT}/${dataset_dir}/${feature_dir}/${fileid}.out 111 | rm ${noisy_pcm_filepath} 112 | rm ${clean_pcm_filepath} 113 | echo "genereated ${fileid}.out" 114 | ) & 115 | done 116 | wait 117 | fi 118 | 119 | ################################################### 120 | #Convert features to h5 files & split dataset # 121 | ################################################### 122 | if [ "${stage}" -le 3 ] && [ "${stop_stage}" -ge 3 ]; then 123 | 
import glob
import os
import random
import sys

# Fraction of the feature files held out as the dev set.
_DEV_RATIO = 0.2


def main(path):
    """Randomly split the ``*.out`` feature files under *path*.

    Shuffles the file list, keeps 80% for training and 20% for validation,
    and writes the two newline-separated lists to ``train.txt`` and
    ``dev.txt`` inside the same directory.
    """
    candidates = glob.glob(os.path.join(path, "*.out"))
    random.shuffle(candidates)
    split_at = int(len(candidates) * (1 - _DEV_RATIO))
    partitions = {
        "train.txt": candidates[:split_at],
        "dev.txt": candidates[split_at:],
    }
    for list_name, file_names in partitions.items():
        with open(os.path.join(path, list_name), "w") as listing:
            listing.write("\n".join(file_names))

if __name__ == '__main__':
    main(sys.argv[1])