├── .gitlab-ci.yml ├── AUTHORS ├── COPYING ├── Makefile.am ├── README ├── autogen.sh ├── configure.ac ├── datasets.txt ├── doc ├── Doxyfile.in └── Makefile ├── download_model.sh ├── examples └── rnnoise_demo.c ├── include └── rnnoise.h ├── m4 └── attributes.m4 ├── model_version ├── rnnoise-uninstalled.pc.in ├── rnnoise.pc.in ├── scripts ├── dump_features_parallel.sh ├── rir_deconv.py ├── shrink_model.sh └── sweep.py ├── src ├── _kiss_fft_guts.h ├── arch.h ├── celt_lpc.c ├── celt_lpc.h ├── common.h ├── compile.sh ├── cpu_support.h ├── denoise.c ├── denoise.h ├── dump_features.c ├── dump_rnnoise_tables.c ├── kiss_fft.c ├── kiss_fft.h ├── nnet.c ├── nnet.h ├── nnet_arch.h ├── nnet_default.c ├── opus_types.h ├── parse_lpcnet_weights.c ├── pitch.c ├── pitch.h ├── rnn.c ├── rnn.h ├── rnn_train.py ├── rnnoise_tables.c ├── vec.h ├── vec_avx.h ├── vec_neon.h ├── write_weights.c └── x86 │ ├── dnn_x86.h │ ├── nnet_avx2.c │ ├── nnet_sse4_1.c │ ├── x86_arch_macros.h │ ├── x86_dnn_map.c │ ├── x86cpu.c │ └── x86cpu.h ├── torch ├── rnnoise │ ├── dump_rnnoise_weights.py │ ├── rnnoise.py │ └── train_rnnoise.py ├── sparsification │ ├── __init__.py │ ├── common.py │ └── gru_sparsifier.py └── weight-exchange │ ├── README.md │ ├── requirements.txt │ ├── setup.py │ └── wexchange │ ├── __init__.py │ ├── c_export │ ├── __init__.py │ ├── c_writer.py │ └── common.py │ ├── tf │ ├── __init__.py │ └── tf.py │ └── torch │ ├── __init__.py │ └── torch.py ├── training ├── bin2hdf5.py ├── dump_rnn.py └── rnn_train.py └── update_version /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | default: 2 | tags: 3 | - docker 4 | # Image from https://hub.docker.com/_/gcc/ based on Debian 5 | image: gcc:9 6 | 7 | .autoconf: 8 | stage: build 9 | before_script: 10 | - apt-get update && 11 | apt-get install -y git ${INSTALL_COMPILER} zip ${INSTALL_EXTRA} 12 | script: 13 | - ./autogen.sh 14 | - ./configure --enable-x86-rtcd ${CONFIG_FLAGS} || cat config.log 15 | 
- make 16 | - make ${CHECKTARGET} 17 | - nm $(find . -name librnnoise.a) | awk '/ T / {print $3}' | sort 18 | variables: 19 | INSTALL_COMPILER: gcc g++ 20 | CHECKTARGET: check 21 | 22 | autoconf-gcc: 23 | extends: .autoconf 24 | variables: 25 | CHECKTARGET: distcheck 26 | 27 | autoconf-clang: 28 | extends: .autoconf 29 | variables: 30 | INSTALL_COMPILER: clang 31 | CC: clang 32 | 33 | enable-assertions: 34 | extends: .autoconf 35 | variables: 36 | CONFIG_FLAGS: --enable-assertions 37 | 38 | enable-dnn-debug-float: 39 | extends: .autoconf 40 | variables: 41 | CONFIG_FLAGS: --enable-dnn-debug-float 42 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Jean-Marc Valin 2 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | Copyright (c) 2007-2017, 2024 Jean-Marc Valin 2 | Copyright (c) 2023 Amazon 3 | Copyright (c) 2017, Mozilla 4 | Copyright (c) 2005-2017, Xiph.Org Foundation 5 | Copyright (c) 2003-2004, Mark Borgerding 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions 9 | are met: 10 | 11 | - Redistributions of source code must retain the above copyright 12 | notice, this list of conditions and the following disclaimer. 13 | 14 | - Redistributions in binary form must reproduce the above copyright 15 | notice, this list of conditions and the following disclaimer in the 16 | documentation and/or other materials provided with the distribution. 17 | 18 | - Neither the name of the Xiph.Org Foundation nor the names of its 19 | contributors may be used to endorse or promote products derived from 20 | this software without specific prior written permission. 
21 | 22 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 23 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 25 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION 26 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 27 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 28 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 30 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 31 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 32 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 33 | -------------------------------------------------------------------------------- /Makefile.am: -------------------------------------------------------------------------------- 1 | ACLOCAL_AMFLAGS = -I m4 2 | 3 | AM_CFLAGS = -I$(top_srcdir)/include -I$(top_srcdir)/src $(DEPS_CFLAGS) 4 | 5 | dist_doc_DATA = COPYING AUTHORS README 6 | 7 | include_HEADERS = include/rnnoise.h 8 | 9 | lib_LTLIBRARIES = librnnoise.la 10 | noinst_HEADERS = src/arch.h \ 11 | src/celt_lpc.h \ 12 | src/cpu_support.h \ 13 | src/common.h \ 14 | src/denoise.h \ 15 | src/_kiss_fft_guts.h \ 16 | src/kiss_fft.h \ 17 | src/nnet.h \ 18 | src/nnet_arch.h \ 19 | src/opus_types.h \ 20 | src/pitch.h \ 21 | src/rnn.h \ 22 | src/rnnoise_data.h \ 23 | src/vec_neon.h \ 24 | src/vec.h \ 25 | src/vec_avx.h \ 26 | src/x86/x86_arch_macros.h \ 27 | src/x86/x86cpu.h \ 28 | src/x86/dnn_x86.h 29 | 30 | RNNOISE_SOURCES = \ 31 | src/denoise.c \ 32 | src/rnn.c \ 33 | src/pitch.c \ 34 | src/kiss_fft.c \ 35 | src/celt_lpc.c \ 36 | src/nnet.c \ 37 | src/nnet_default.c \ 38 | src/parse_lpcnet_weights.c \ 39 | src/rnnoise_data.c \ 40 | src/rnnoise_tables.c 41 | 42 | RNNOISE_SOURCES_SSE4_1 = src/x86/nnet_sse4_1.c 43 | 
RNNOISE_SOURCES_AVX2 = src/x86/nnet_avx2.c 44 | 45 | X86_RTCD = src/x86/x86_dnn_map.c \ 46 | src/x86/x86cpu.c 47 | 48 | if RNN_ENABLE_X86_RTCD 49 | RNNOISE_SOURCES += $(X86_RTCD) $(RNNOISE_SOURCES_SSE4_1) $(RNNOISE_SOURCES_AVX2) 50 | endif 51 | 52 | librnnoise_la_SOURCES = $(RNNOISE_SOURCES) 53 | librnnoise_la_LIBADD = $(DEPS_LIBS) $(lrintf_lib) $(LIBM) 54 | librnnoise_la_LDFLAGS = -no-undefined \ 55 | -version-info @OP_LT_CURRENT@:@OP_LT_REVISION@:@OP_LT_AGE@ 56 | 57 | noinst_PROGRAMS = dump_features dump_weights_blob 58 | if OP_ENABLE_EXAMPLES 59 | noinst_PROGRAMS += examples/rnnoise_demo 60 | endif 61 | 62 | examples_rnnoise_demo_SOURCES = examples/rnnoise_demo.c 63 | examples_rnnoise_demo_LDADD = librnnoise.la 64 | 65 | dump_features_SOURCES = src/dump_features.c src/denoise.c src/pitch.c src/celt_lpc.c src/kiss_fft.c src/parse_lpcnet_weights.c src/rnnoise_tables.c 66 | dump_features_LDADD = $(LIBM) 67 | dump_features_CFLAGS = $(AM_CFLAGS) -DTRAINING 68 | 69 | dump_weights_blob_SOURCES = src/write_weights.c 70 | dump_weights_blob_LDADD = $(LIBM) 71 | dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS 72 | 73 | pkgconfigdir = $(libdir)/pkgconfig 74 | pkgconfig_DATA = rnnoise.pc 75 | 76 | debug: 77 | $(MAKE) CFLAGS="${CFLAGS} -O0 -ggdb -DOP_ENABLE_ASSERTIONS" all 78 | 79 | EXTRA_DIST = \ 80 | rnnoise.pc.in \ 81 | rnnoise-uninstalled.pc.in \ 82 | doc/Doxyfile.in \ 83 | doc/Makefile 84 | 85 | # Targets to build and install just the library without the docs 86 | librnnoise install-librnnoise: NO_DOXYGEN = 1 87 | 88 | rnnoise: all 89 | install-rnnoise: install 90 | 91 | # Or just the docs 92 | docs: doc/doxygen-build.stamp 93 | 94 | install-docs: 95 | @if [ -z "$(NO_DOXYGEN)" ]; then \ 96 | ( cd doc && \ 97 | echo "Installing documentation in $(DESTDIR)$(docdir)"; \ 98 | $(INSTALL) -d $(DESTDIR)$(docdir)/html/search; \ 99 | for f in `find html -type f \! 
-name "installdox"` ; do \ 100 | $(INSTALL_DATA) $$f $(DESTDIR)$(docdir)/$$f; \ 101 | done ) \ 102 | fi 103 | 104 | doc/doxygen-build.stamp: doc/Doxyfile \ 105 | $(top_srcdir)/include/*.h 106 | @[ -n "$(NO_DOXYGEN)" ] || ( cd doc && doxygen && touch $(@F) ) 107 | 108 | 109 | if HAVE_DOXYGEN 110 | 111 | # Or everything (by default) 112 | all-local: docs 113 | 114 | install-data-local: install-docs 115 | 116 | clean-local: 117 | $(RM) -r doc/html 118 | $(RM) -r doc/latex 119 | $(RM) doc/doxygen-build.stamp 120 | 121 | uninstall-local: 122 | $(RM) -r $(DESTDIR)$(docdir)/html 123 | 124 | endif 125 | 126 | # We check this every time make is run, with configure.ac being touched to 127 | # trigger an update of the build system files if update_version changes the 128 | # current PACKAGE_VERSION (or if package_version was modified manually by a 129 | # user with either AUTO_UPDATE=no or no update_version script present - the 130 | # latter being the normal case for tarball releases). 131 | # 132 | # We can't just add the package_version file to CONFIGURE_DEPENDENCIES since 133 | # simply running autoconf will not actually regenerate configure for us when 134 | # the content of that file changes (due to autoconf dependency checking not 135 | # knowing about that without us creating yet another file for it to include). 136 | # 137 | # The MAKECMDGOALS check is a gnu-make'ism, but will degrade 'gracefully' for 138 | # makes that don't support it. The only loss of functionality is not forcing 139 | # an update of package_version for `make dist` if AUTO_UPDATE=no, but that is 140 | # unlikely to be a real problem for any real user. 141 | $(top_srcdir)/configure.ac: force 142 | @case "$(MAKECMDGOALS)" in \ 143 | dist-hook) exit 0 ;; \ 144 | dist-* | dist | distcheck | distclean) _arg=release ;; \ 145 | esac; \ 146 | if ! $(top_srcdir)/update_version $$_arg 2> /dev/null; then \ 147 | if [ ! 
-e $(top_srcdir)/package_version ]; then \ 148 | echo 'PACKAGE_VERSION="unknown"' > $(top_srcdir)/package_version; \ 149 | fi; \ 150 | . $(top_srcdir)/package_version || exit 1; \ 151 | [ "$(PACKAGE_VERSION)" != "$$PACKAGE_VERSION" ] || exit 0; \ 152 | fi; \ 153 | touch $@ 154 | 155 | force: 156 | 157 | # Create a minimal package_version file when make dist is run. 158 | dist-hook: 159 | echo 'PACKAGE_VERSION="$(PACKAGE_VERSION)"' > $(top_distdir)/package_version 160 | 161 | 162 | .PHONY: rnnoise install-rnnoise docs install-docs 163 | 164 | if RNN_ENABLE_X86_RTCD 165 | SSE4_1_OBJ = $(RNNOISE_SOURCES_SSE4_1:.c=.lo) 166 | $(SSE4_1_OBJ): CFLAGS += $(OPUS_X86_SSE4_1_CFLAGS) 167 | 168 | AVX2_OBJ = $(RNNOISE_SOURCES_AVX2:.c=.lo) 169 | $(AVX2_OBJ): CFLAGS += $(OPUS_X86_AVX2_CFLAGS) 170 | endif 171 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | RNNoise is a noise suppression library based on a recurrent neural network. 2 | A description of the algorithm is provided in the following paper: 3 | 4 | J.-M. Valin, A Hybrid DSP/Deep Learning Approach to Real-Time Full-Band Speech 5 | Enhancement, Proceedings of IEEE Multimedia Signal Processing (MMSP) Workshop, 6 | arXiv:1709.08243, 2018. 7 | https://arxiv.org/pdf/1709.08243.pdf 8 | 9 | An interactive demo of version 0.1 is available at: https://jmvalin.ca/demo/rnnoise/ 10 | 11 | To compile, just type: 12 | % ./autogen.sh 13 | % ./configure 14 | % make 15 | 16 | Optionally: 17 | % make install 18 | 19 | It is recommended to either set -march= in the CFLAGS to an architecture 20 | with AVX2 support or to add --enable-x86-rtcd to the configure script 21 | so that AVX2 (or SSE4.1) can at least be used as an option. 22 | Note that the autogen.sh script will automatically download the model files 23 | from the Xiph.Org servers, since those are too large to put in Git. 
24 | 25 | While it is meant to be used as a library, a simple command-line tool is 26 | provided as an example. It operates on RAW 16-bit (machine endian) mono 27 | PCM files sampled at 48 kHz. It can be used as: 28 | 29 | % ./examples/rnnoise_demo <noisy speech> <output denoised> 30 | 31 | The output is also a 16-bit raw PCM file. 32 | NOTE AGAIN, THE INPUT and OUTPUT ARE IN RAW FORMAT, NOT WAV. 33 | 34 | The latest version of the source is available from 35 | https://gitlab.xiph.org/xiph/rnnoise . The GitHub repository 36 | is a convenience copy. 37 | 38 | == Training == 39 | 40 | The models distributed with RNNoise are now trained using only the publicly 41 | available datasets listed below and using the training procedure described 42 | here. Exact results will still depend on the exact mix of data used, 43 | on how long the training is performed and on the various random seeds involved. 44 | 45 | To train an RNNoise model, you need both clean speech data, and noise data. 46 | Both need to be sampled at 48 kHz, in 16-bit PCM format (machine endian). 47 | Clean speech data can be obtained from the datasets listed in the datasets.txt 48 | file, or by downloading the already-concatenated version of those files in 49 | https://media.xiph.org/rnnoise/data/tts_speech_48k.sw 50 | For noise data, we suggest the background_noise.sw and foreground_noise.sw 51 | (or later versions) noise files from https://media.xiph.org/rnnoise/data/ 52 | The foreground_noise.sw file contains noise signals that are meant to be added 53 | to the background noise (e.g. keyboard sounds). Optionally, the foreground noise 54 | file can even be denoised with a traditional denoiser (e.g. libspeexdsp) to 55 | keep only the transient components. For background noise, the data from the 56 | original RNNoise noise collection have now been sufficiently filtered to 57 | provide good results -- either alone or in combination with the 58 | background_noise.sw file. 
The dataset can be downloaded (updated Jan 30th 2025) 59 | from: https://media.xiph.org/rnnoise/rnnoise_contributions.tar.gz 60 | 61 | The first step is to take the speech and noise, and mix them in a variety of 62 | ways to simulate real life conditions (including pauses, filtering and more). 63 | Assuming the files are called speech.pcm and noise.pcm, start by generating 64 | the training feature data with: 65 | 66 | % ./dump_features speech.pcm background_noise.pcm foreground_noise.pcm features.f32 <count> 67 | where <count> is the number of sequences to process. The number of sequences 68 | should be at least 10000, but the more the better (200000 or more is 69 | recommended). 70 | 71 | Optionally, training can also simulate reverberation, in which case room impulse 72 | responses (RIR) are also needed. Limited RIR data is available at: 73 | https://media.xiph.org/rnnoise/data/measured_rirs-v2.tar.gz 74 | The format for those is raw 32-bit floating-point (files are little endian). 75 | Assuming a list of all the RIR files is contained in a rir_list.txt file, 76 | the training feature data can be generated with: 77 | 78 | % ./dump_features -rir_list rir_list.txt speech.pcm background_noise.pcm foreground_noise.pcm features.f32 <count> 79 | 80 | To make the feature generation faster, you can use the script provided in 81 | script/dump_features_parallel.sh (you will need to modify the script if you 82 | want to add RIR augmentation). 83 | 84 | To use it: 85 | % script/dump_features_parallel.sh ./dump_features speech.pcm background_noise.pcm foreground_noise.pcm features.f32 rir_list.txt <count> <nb_processes> 86 | which will run <nb_processes> processes, each for <count> sequences, and 87 | concatenate the output to a single file. 88 | 89 | Once the feature file is computed, you can start the training with: 90 | % python3 train_rnnoise.py features.f32 output_directory 91 | 92 | Choose a number of epochs (using --epochs) that leads to about 75000 weight 93 | updates. The training will produce .pth files, e.g. 
rnnoise_50.pth . 94 | The next step is to convert the model to C files using: 95 | 96 | % python3 dump_rnnoise_weights.py --quantize rnnoise_50.pth rnnoise_c 97 | 98 | which will produce the rnnoise_data.c and rnnoise_data.h files in the 99 | rnnoise_c directory. 100 | 101 | Copy these files to src/ and then build RNNoise using the instructions above. 102 | 103 | For slightly better results, a trained model can be used to remove any noise 104 | from the "clean" training speech, before restarting the denoising process 105 | again (no need to do that more than once). 106 | 107 | == Loadable Models == 108 | 109 | The model format has changed since v0.1.1. Models now use a binary 110 | "machine endian" format. To output a model in that format, build RNNoise 111 | with that model and use the dump_weights_blob executable to output a 112 | weights_blob.bin binary file. That file can then be used with the 113 | rnnoise_model_from_file() API call. Note that the model object MUST NOT 114 | be deleted while the RNNoise state is active and the file MUST NOT 115 | be closed. 116 | 117 | To avoid including the default model in the build (e.g. to reduce download 118 | size) and rely only on model loading, add -DUSE_WEIGHTS_FILE to the CFLAGS. 119 | To be able to load different models, the model size (and header file) needs 120 | to match the size used during build. Otherwise the model will not load. 121 | We provide a "little" model with half the size as an alternative. To use the smaller 122 | model, rename rnnoise_data_little.c to rnnoise_data.c. It is possible 123 | to build both the regular and little binary weights and load any of them 124 | at run time since the little model has the same size as the regular one 125 | (except for the increased sparsity). 
126 | -------------------------------------------------------------------------------- /autogen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Run this to set up the build system: configure, makefiles, etc. 3 | set -e 4 | 5 | srcdir=`dirname $0` 6 | test -n "$srcdir" && cd "$srcdir" 7 | 8 | ./download_model.sh 9 | 10 | echo "Updating build configuration files for rnnoise, please wait...." 11 | 12 | autoreconf -isf 13 | -------------------------------------------------------------------------------- /configure.ac: -------------------------------------------------------------------------------- 1 | # autoconf source script for generating configure 2 | 3 | dnl The package_version file will be automatically synced to the git revision 4 | dnl by the update_version script when configured in the repository, but will 5 | dnl remain constant in tarball releases unless it is manually edited. 6 | m4_define([CURRENT_VERSION], 7 | m4_esyscmd([ ./update_version 2>/dev/null || true 8 | if test -e package_version; then 9 | . ./package_version 10 | printf "$PACKAGE_VERSION" 11 | else 12 | printf "unknown" 13 | fi ])) 14 | 15 | AC_INIT([rnnoise],[CURRENT_VERSION],[jmvalin@jmvalin.ca]) 16 | AC_CONFIG_SRCDIR([src/denoise.c]) 17 | AC_CONFIG_MACRO_DIR([m4]) 18 | 19 | AC_USE_SYSTEM_EXTENSIONS 20 | AC_SYS_LARGEFILE 21 | 22 | AM_INIT_AUTOMAKE([1.11 foreign no-define dist-zip subdir-objects]) 23 | AM_MAINTAINER_MODE([enable]) 24 | 25 | AC_C_INLINE 26 | 27 | LT_INIT 28 | 29 | m4_ifdef([AM_SILENT_RULES], [AM_SILENT_RULES([yes])]) 30 | 31 | AC_DEFINE([RNNOISE_BUILD], [], [This is a build of the library]) 32 | 33 | dnl Library versioning for libtool. 34 | dnl Please update these for releases. 
35 | dnl CURRENT, REVISION, AGE 36 | dnl - library source changed -> increment REVISION 37 | dnl - interfaces added/removed/changed -> increment CURRENT, REVISION = 0 38 | dnl - interfaces added -> increment AGE 39 | dnl - interfaces removed -> AGE = 0 40 | 41 | OP_LT_CURRENT=4 42 | OP_LT_REVISION=1 43 | OP_LT_AGE=4 44 | 45 | AC_SUBST(OP_LT_CURRENT) 46 | AC_SUBST(OP_LT_REVISION) 47 | AC_SUBST(OP_LT_AGE) 48 | 49 | CC_CHECK_CFLAGS_APPEND( 50 | [-pedantic -Wall -Wextra -Wno-sign-compare -Wno-parentheses -Wno-long-long]) 51 | 52 | # Platform-specific tweaks 53 | case $host in 54 | *-mingw*) 55 | # -std=c89 causes some warnings under mingw. 56 | CC_CHECK_CFLAGS_APPEND([-U__STRICT_ANSI__]) 57 | # We need WINNT>=0x501 (WindowsXP) for getaddrinfo/freeaddrinfo. 58 | # It's okay to define this even when HTTP support is disabled, as it only 59 | # affects header declarations, not linking (unless we actually use some 60 | # XP-only functions). 61 | AC_DEFINE_UNQUOTED(_WIN32_WINNT,0x501, 62 | [We need at least WindowsXP for getaddrinfo/freeaddrinfo]) 63 | host_mingw=true 64 | ;; 65 | esac 66 | AM_CONDITIONAL(OP_WIN32, test "$host_mingw" = "true") 67 | 68 | AC_ARG_ENABLE([assertions], 69 | AS_HELP_STRING([--enable-assertions], [Enable assertions in code]),, 70 | enable_assertions=no) 71 | 72 | AS_IF([test "$enable_assertions" = "yes"], [ 73 | AC_DEFINE([OP_ENABLE_ASSERTIONS], [1], [Enable assertions in code]) 74 | ]) 75 | 76 | AC_ARG_ENABLE([examples], 77 | AS_HELP_STRING([--disable-examples], [Do not build example applications]),, 78 | enable_examples=yes) 79 | AM_CONDITIONAL([OP_ENABLE_EXAMPLES], [test "$enable_examples" = "yes"]) 80 | 81 | AC_ARG_ENABLE([dnn-debug-float], 82 | AS_HELP_STRING([--enable-dnn-debug-float], [Use floating-point DNN computation everywhere]),, 83 | enable_dnn_debug_float=no) 84 | 85 | AS_IF([test "$enable_dnn_debug_float" = "no"], [ 86 | AC_DEFINE([DISABLE_DEBUG_FLOAT], [1], [Disable DNN debug float]) 87 | ]) 88 | 89 | 
OPUS_X86_SSE4_1_CFLAGS='-msse4.1' 90 | OPUS_X86_AVX2_CFLAGS='-mavx -mfma -mavx2' 91 | AC_SUBST([OPUS_X86_SSE4_1_CFLAGS]) 92 | AC_SUBST([OPUS_X86_AVX2_CFLAGS]) 93 | AC_ARG_ENABLE([x86-rtcd], 94 | AS_HELP_STRING([--enable-x86-rtcd], [x86 rtcd]),, 95 | enable_x86_rtcd=no) 96 | AM_CONDITIONAL([RNN_ENABLE_X86_RTCD], [test "$enable_x86_rtcd" = "yes"]) 97 | 98 | AS_IF([test "$enable_x86_rtcd" = "yes"], [ 99 | AC_DEFINE([RNN_ENABLE_X86_RTCD], [1], [Enable x86 rtcd]) 100 | AC_DEFINE([CPU_INFO_BY_ASM], [1], [RTCD from ASM only for now]) 101 | ]) 102 | 103 | AS_CASE(["$ac_cv_search_lrintf"], 104 | ["no"],[], 105 | ["none required"],[], 106 | [lrintf_lib="$ac_cv_search_lrintf"]) 107 | 108 | LT_LIB_M 109 | 110 | AC_SUBST([lrintf_lib]) 111 | 112 | CC_ATTRIBUTE_VISIBILITY([default], [ 113 | CC_FLAG_VISIBILITY([CFLAGS="${CFLAGS} -fvisibility=hidden"]) 114 | ]) 115 | 116 | dnl Check for doxygen 117 | AC_ARG_ENABLE([doc], 118 | AS_HELP_STRING([--disable-doc], [Do not build API documentation]),, 119 | [enable_doc=yes] 120 | ) 121 | 122 | AS_IF([test "$enable_doc" = "yes"], [ 123 | AC_CHECK_PROG([HAVE_DOXYGEN], [doxygen], [yes], [no]) 124 | AC_CHECK_PROG([HAVE_DOT], [dot], [yes], [no]) 125 | ],[ 126 | HAVE_DOXYGEN=no 127 | ]) 128 | 129 | AM_CONDITIONAL([HAVE_DOXYGEN], [test "$HAVE_DOXYGEN" = "yes"]) 130 | 131 | AC_CONFIG_FILES([ 132 | Makefile 133 | rnnoise.pc 134 | rnnoise-uninstalled.pc 135 | doc/Doxyfile 136 | ]) 137 | AC_CONFIG_HEADERS([config.h]) 138 | AC_OUTPUT 139 | 140 | AC_MSG_NOTICE([ 141 | ------------------------------------------------------------------------ 142 | $PACKAGE_NAME $PACKAGE_VERSION: Automatic configuration OK. 143 | 144 | Assertions ................... ${enable_assertions} 145 | 146 | Hidden visibility ............ ${cc_cv_flag_visibility} 147 | 148 | API code examples ............ ${enable_examples} 149 | API documentation ............ 
${enable_doc} 150 | ------------------------------------------------------------------------ 151 | ]) 152 | -------------------------------------------------------------------------------- /datasets.txt: -------------------------------------------------------------------------------- 1 | The following clean speech datasets can be used to train RNNoise. 2 | A good choice is to include all the data from these datasets, except for 3 | hi_fi_tts for which only a small subset is recommended (since it's very large 4 | but has few speakers). Note that this data typically needs to be resampled 5 | before it can be used. 6 | 7 | https://www.openslr.org/resources/30/si_lk.tar.gz 8 | https://www.openslr.org/resources/32/af_za.tar.gz 9 | https://www.openslr.org/resources/32/st_za.tar.gz 10 | https://www.openslr.org/resources/32/tn_za.tar.gz 11 | https://www.openslr.org/resources/32/xh_za.tar.gz 12 | https://www.openslr.org/resources/37/bn_bd.zip 13 | https://www.openslr.org/resources/37/bn_in.zip 14 | https://www.openslr.org/resources/41/jv_id_female.zip 15 | https://www.openslr.org/resources/41/jv_id_male.zip 16 | https://www.openslr.org/resources/42/km_kh_male.zip 17 | https://www.openslr.org/resources/43/ne_np_female.zip 18 | https://www.openslr.org/resources/44/su_id_female.zip 19 | https://www.openslr.org/resources/44/su_id_male.zip 20 | https://www.openslr.org/resources/61/es_ar_female.zip 21 | https://www.openslr.org/resources/61/es_ar_male.zip 22 | https://www.openslr.org/resources/63/ml_in_female.zip 23 | https://www.openslr.org/resources/63/ml_in_male.zip 24 | https://www.openslr.org/resources/64/mr_in_female.zip 25 | https://www.openslr.org/resources/65/ta_in_female.zip 26 | https://www.openslr.org/resources/65/ta_in_male.zip 27 | https://www.openslr.org/resources/66/te_in_female.zip 28 | https://www.openslr.org/resources/66/te_in_male.zip 29 | https://www.openslr.org/resources/69/ca_es_female.zip 30 | https://www.openslr.org/resources/69/ca_es_male.zip 31 | 
https://www.openslr.org/resources/70/en_ng_female.zip 32 | https://www.openslr.org/resources/70/en_ng_male.zip 33 | https://www.openslr.org/resources/71/es_cl_female.zip 34 | https://www.openslr.org/resources/71/es_cl_male.zip 35 | https://www.openslr.org/resources/72/es_co_female.zip 36 | https://www.openslr.org/resources/72/es_co_male.zip 37 | https://www.openslr.org/resources/73/es_pe_female.zip 38 | https://www.openslr.org/resources/73/es_pe_male.zip 39 | https://www.openslr.org/resources/74/es_pr_female.zip 40 | https://www.openslr.org/resources/75/es_ve_female.zip 41 | https://www.openslr.org/resources/75/es_ve_male.zip 42 | https://www.openslr.org/resources/76/eu_es_female.zip 43 | https://www.openslr.org/resources/76/eu_es_male.zip 44 | https://www.openslr.org/resources/77/gl_es_female.zip 45 | https://www.openslr.org/resources/77/gl_es_male.zip 46 | https://www.openslr.org/resources/78/gu_in_female.zip 47 | https://www.openslr.org/resources/78/gu_in_male.zip 48 | https://www.openslr.org/resources/79/kn_in_female.zip 49 | https://www.openslr.org/resources/79/kn_in_male.zip 50 | https://www.openslr.org/resources/80/my_mm_female.zip 51 | https://www.openslr.org/resources/83/irish_english_male.zip 52 | https://www.openslr.org/resources/83/midlands_english_female.zip 53 | https://www.openslr.org/resources/83/midlands_english_male.zip 54 | https://www.openslr.org/resources/83/northern_english_female.zip 55 | https://www.openslr.org/resources/83/northern_english_male.zip 56 | https://www.openslr.org/resources/83/scottish_english_female.zip 57 | https://www.openslr.org/resources/83/scottish_english_male.zip 58 | https://www.openslr.org/resources/83/southern_english_female.zip 59 | https://www.openslr.org/resources/83/southern_english_male.zip 60 | https://www.openslr.org/resources/83/welsh_english_female.zip 61 | https://www.openslr.org/resources/83/welsh_english_male.zip 62 | https://www.openslr.org/resources/86/yo_ng_female.zip 63 | 
https://www.openslr.org/resources/86/yo_ng_male.zip 64 | https://www.openslr.org/resources/109/hi_fi_tts_v0.tar.gz 65 | 66 | The corresponding citations for all these datasets are: 67 | 68 | @inproceedings{demirsahin-etal-2020-open, 69 | title = {{Open-source Multi-speaker Corpora of the English Accents in the British Isles}}, 70 | author = {Demirsahin, Isin and Kjartansson, Oddur and Gutkin, Alexander and Rivera, Clara}, 71 | booktitle = {Proceedings of The 12th Language Resources and Evaluation Conference (LREC)}, 72 | month = may, 73 | year = {2020}, 74 | pages = {6532--6541}, 75 | address = {Marseille, France}, 76 | publisher = {European Language Resources Association (ELRA)}, 77 | url = {https://www.aclweb.org/anthology/2020.lrec-1.804}, 78 | ISBN = {979-10-95546-34-4}, 79 | } 80 | @inproceedings{kjartansson-etal-2020-open, 81 | title = {{Open-Source High Quality Speech Datasets for Basque, Catalan and Galician}}, 82 | author = {Kjartansson, Oddur and Gutkin, Alexander and Butryna, Alena and Demirsahin, Isin and Rivera, Clara}, 83 | booktitle = {Proceedings of the 1st Joint Workshop on Spoken Language Technologies for Under-resourced languages (SLTU) and Collaboration and Computing for Under-Resourced Languages (CCURL)}, 84 | year = {2020}, 85 | pages = {21--27}, 86 | month = may, 87 | address = {Marseille, France}, 88 | publisher = {European Language Resources association (ELRA)}, 89 | url = {https://www.aclweb.org/anthology/2020.sltu-1.3}, 90 | ISBN = {979-10-95546-35-1}, 91 | } 92 | 93 | 94 | @inproceedings{guevara-rukoz-etal-2020-crowdsourcing, 95 | title = {{Crowdsourcing Latin American Spanish for Low-Resource Text-to-Speech}}, 96 | author = {Guevara-Rukoz, Adriana and Demirsahin, Isin and He, Fei and Chu, Shan-Hui Cathy and Sarin, Supheakmungkol and Pipatsrisawat, Knot and Gutkin, Alexander and Butryna, Alena and Kjartansson, Oddur}, 97 | booktitle = {Proceedings of The 12th Language Resources and Evaluation Conference (LREC)}, 98 | year = {2020}, 99 | 
month = may, 100 | address = {Marseille, France}, 101 | publisher = {European Language Resources Association (ELRA)}, 102 | url = {https://www.aclweb.org/anthology/2020.lrec-1.801}, 103 | pages = {6504--6513}, 104 | ISBN = {979-10-95546-34-4}, 105 | } 106 | @inproceedings{he-etal-2020-open, 107 | title = {{Open-source Multi-speaker Speech Corpora for Building Gujarati, Kannada, Malayalam, Marathi, Tamil and Telugu Speech Synthesis Systems}}, 108 | author = {He, Fei and Chu, Shan-Hui Cathy and Kjartansson, Oddur and Rivera, Clara and Katanova, Anna and Gutkin, Alexander and Demirsahin, Isin and Johny, Cibu and Jansche, Martin and Sarin, Supheakmungkol and Pipatsrisawat, Knot}, 109 | booktitle = {Proceedings of The 12th Language Resources and Evaluation Conference (LREC)}, 110 | month = may, 111 | year = {2020}, 112 | address = {Marseille, France}, 113 | publisher = {European Language Resources Association (ELRA)}, 114 | pages = {6494--6503}, 115 | url = {https://www.aclweb.org/anthology/2020.lrec-1.800}, 116 | ISBN = "{979-10-95546-34-4}", 117 | } 118 | 119 | 120 | @inproceedings{kjartansson-etal-tts-sltu2018, 121 | title = {{A Step-by-Step Process for Building TTS Voices Using Open Source Data and Framework for Bangla, Javanese, Khmer, Nepali, Sinhala, and Sundanese}}, 122 | author = {Keshan Sodimana and Knot Pipatsrisawat and Linne Ha and Martin Jansche and Oddur Kjartansson and Pasindu De Silva and Supheakmungkol Sarin}, 123 | booktitle = {Proc. The 6th Intl. 
Workshop on Spoken Language Technologies for Under-Resourced Languages (SLTU)}, 124 | year = {2018}, 125 | address = {Gurugram, India}, 126 | month = aug, 127 | pages = {66--70}, 128 | URL = {http://dx.doi.org/10.21437/SLTU.2018-14} 129 | } 130 | 131 | 132 | @inproceedings{oo-etal-2020-burmese, 133 | title = {{Burmese Speech Corpus, Finite-State Text Normalization and Pronunciation Grammars with an Application to Text-to-Speech}}, 134 | author = {Oo, Yin May and Wattanavekin, Theeraphol and Li, Chenfang and De Silva, Pasindu and Sarin, Supheakmungkol and Pipatsrisawat, Knot and Jansche, Martin and Kjartansson, Oddur and Gutkin, Alexander}, 135 | booktitle = {Proceedings of The 12th Language Resources and Evaluation Conference (LREC)}, 136 | month = may, 137 | year = {2020}, 138 | pages = "6328--6339", 139 | address = {Marseille, France}, 140 | publisher = {European Language Resources Association (ELRA)}, 141 | url = {https://www.aclweb.org/anthology/2020.lrec-1.777}, 142 | ISBN = {979-10-95546-34-4}, 143 | } 144 | @inproceedings{van-niekerk-etal-2017, 145 | title = {{Rapid development of TTS corpora for four South African languages}}, 146 | author = {Daniel van Niekerk and Charl van Heerden and Marelie Davel and Neil Kleynhans and Oddur Kjartansson and Martin Jansche and Linne Ha}, 147 | booktitle = {Proc. 
Interspeech 2017}, 148 | pages = {2178--2182}, 149 | address = {Stockholm, Sweden}, 150 | month = aug, 151 | year = {2017}, 152 | URL = {http://dx.doi.org/10.21437/Interspeech.2017-1139} 153 | } 154 | 155 | @inproceedings{gutkin-et-al-yoruba2020, 156 | title = {{Developing an Open-Source Corpus of Yoruba Speech}}, 157 | author = {Alexander Gutkin and I{\c{s}}{\i}n Demir{\c{s}}ahin and Oddur Kjartansson and Clara Rivera and K\d{\'o}lá Túb\d{\`o}sún}, 158 | booktitle = {Proceedings of Interspeech 2020}, 159 | pages = {404--408}, 160 | month = {October}, 161 | year = {2020}, 162 | address = {Shanghai, China}, 163 | publisher = {International Speech and Communication Association (ISCA)}, 164 | doi = {10.21437/Interspeech.2020-1096}, 165 | url = {http://dx.doi.org/10.21437/Interspeech.2020-1096}, 166 | } 167 | 168 | @article{bakhturina2021hi, 169 | title={{Hi-Fi Multi-Speaker English TTS Dataset}}, 170 | author={Bakhturina, Evelina and Lavrukhin, Vitaly and Ginsburg, Boris and Zhang, Yang}, 171 | journal={arXiv preprint arXiv:2104.01497}, 172 | year={2021} 173 | } 174 | -------------------------------------------------------------------------------- /doc/Doxyfile.in: -------------------------------------------------------------------------------- 1 | # Process with doxygen to generate API documentation 2 | 3 | PROJECT_NAME = @PACKAGE_NAME@ 4 | PROJECT_NUMBER = @PACKAGE_VERSION@ 5 | PROJECT_BRIEF = "RNN-based noise suppressor." 6 | INPUT = @top_srcdir@/include/rnnoise.h 7 | OPTIMIZE_OUTPUT_FOR_C = YES 8 | 9 | QUIET = YES 10 | WARNINGS = YES 11 | WARN_IF_UNDOCUMENTED = YES 12 | WARN_IF_DOC_ERROR = YES 13 | WARN_NO_PARAMDOC = YES 14 | 15 | JAVADOC_AUTOBRIEF = YES 16 | SORT_MEMBER_DOCS = NO 17 | 18 | HAVE_DOT = @HAVE_DOT@ 19 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | ## GNU makefile for rnnoise documentation. 
2 | 3 | -include ../package_version 4 | 5 | all: doxygen 6 | 7 | doxygen: Doxyfile ../include/rnnoise.h 8 | doxygen 9 | 10 | pdf: doxygen 11 | make -C latex 12 | 13 | clean: 14 | $(RM) -r html 15 | $(RM) -r latex 16 | 17 | distclean: clean 18 | $(RM) Doxyfile 19 | 20 | .PHONY: all clean distclean doxygen pdf 21 | 22 | ../package_version: 23 | @if [ -x ../update_version ]; then \ 24 | ../update_version || true; \ 25 | elif [ ! -e $@ ]; then \ 26 | echo 'PACKAGE_VERSION="unknown"' > $@; \ 27 | fi 28 | 29 | # run autoconf-like replacements to finalize our config 30 | Doxyfile: Doxyfile.in Makefile ../package_version 31 | sed -e 's/@PACKAGE_NAME@/rnnoise/' \ 32 | -e 's/@PACKAGE_VERSION@/$(PACKAGE_VERSION)/' \ 33 | -e 's/@top_srcdir@/../' \ 34 | < $< > $@ 35 | -------------------------------------------------------------------------------- /download_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -e 3 | 4 | hash=`cat model_version` 5 | model=rnnoise_data-$hash.tar.gz 6 | 7 | if [ ! -f $model ]; then 8 | echo "Downloading latest model" 9 | wget https://media.xiph.org/rnnoise/models/$model 10 | fi 11 | 12 | if command -v sha256sum 13 | then 14 | echo "Validating checksum" 15 | checksum="$hash" 16 | checksum2=$(sha256sum $model | awk '{print $1}') 17 | if [ "$checksum" != "$checksum2" ] 18 | then 19 | echo "Aborting due to mismatching checksums. This could be caused by a corrupted download of $model." 20 | echo "Consider deleting local copy of $model and running this script again." 21 | exit 1 22 | else 23 | echo "checksums match" 24 | fi 25 | else 26 | echo "Could not find sha256 sum; skipping verification. Please verify manually that sha256 hash of ${model} matches ${hash}." 
27 | fi 28 | 29 | 30 | tar xvomf $model 31 | 32 | -------------------------------------------------------------------------------- /examples/rnnoise_demo.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018 Gregor Richards 2 | * Copyright (c) 2017 Mozilla */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #include 29 | #include "rnnoise.h" 30 | 31 | #define FRAME_SIZE 480 32 | 33 | int main(int argc, char **argv) { 34 | int i; 35 | int first = 1; 36 | float x[FRAME_SIZE]; 37 | FILE *f1, *fout; 38 | DenoiseState *st; 39 | #ifdef USE_WEIGHTS_FILE 40 | RNNModel *model = rnnoise_model_from_filename("weights_blob.bin"); 41 | st = rnnoise_create(model); 42 | #else 43 | st = rnnoise_create(NULL); 44 | #endif 45 | 46 | if (argc!=3) { 47 | fprintf(stderr, "usage: %s \n", argv[0]); 48 | return 1; 49 | } 50 | f1 = fopen(argv[1], "rb"); 51 | fout = fopen(argv[2], "wb"); 52 | while (1) { 53 | short tmp[FRAME_SIZE]; 54 | fread(tmp, sizeof(short), FRAME_SIZE, f1); 55 | if (feof(f1)) break; 56 | for (i=0;i 32 | 33 | #ifdef __cplusplus 34 | extern "C" { 35 | #endif 36 | 37 | #ifndef RNNOISE_EXPORT 38 | # if defined(WIN32) 39 | # if defined(RNNOISE_BUILD) && defined(DLL_EXPORT) 40 | # define RNNOISE_EXPORT __declspec(dllexport) 41 | # else 42 | # define RNNOISE_EXPORT 43 | # endif 44 | # elif defined(__GNUC__) && defined(RNNOISE_BUILD) 45 | # define RNNOISE_EXPORT __attribute__ ((visibility ("default"))) 46 | # else 47 | # define RNNOISE_EXPORT 48 | # endif 49 | #endif 50 | 51 | typedef struct DenoiseState DenoiseState; 52 | typedef struct RNNModel RNNModel; 53 | 54 | /** 55 | * Return the size of DenoiseState 56 | */ 57 | RNNOISE_EXPORT int rnnoise_get_size(void); 58 | 59 | /** 60 | * Return the number of samples processed by rnnoise_process_frame at a time 61 | */ 62 | RNNOISE_EXPORT int rnnoise_get_frame_size(void); 63 | 64 | /** 65 | * Initializes a pre-allocated DenoiseState 66 | * 67 | * If model is NULL the default model is used. 68 | * 69 | * See: rnnoise_create() and rnnoise_model_from_file() 70 | */ 71 | RNNOISE_EXPORT int rnnoise_init(DenoiseState *st, RNNModel *model); 72 | 73 | /** 74 | * Allocate and initialize a DenoiseState 75 | * 76 | * If model is NULL the default model is used. 
77 | * 78 | * The returned pointer MUST be freed with rnnoise_destroy(). 79 | */ 80 | RNNOISE_EXPORT DenoiseState *rnnoise_create(RNNModel *model); 81 | 82 | /** 83 | * Free a DenoiseState produced by rnnoise_create. 84 | * 85 | * The optional custom model must be freed by rnnoise_model_free() after. 86 | */ 87 | RNNOISE_EXPORT void rnnoise_destroy(DenoiseState *st); 88 | 89 | /** 90 | * Denoise a frame of samples 91 | * 92 | * in and out must be at least rnnoise_get_frame_size() large. 93 | */ 94 | RNNOISE_EXPORT float rnnoise_process_frame(DenoiseState *st, float *out, const float *in); 95 | 96 | /** 97 | * Load a model from a memory buffer 98 | * 99 | * It must be deallocated with rnnoise_model_free() and the buffer must remain 100 | * valid until after the returned object is destroyed. 101 | */ 102 | RNNOISE_EXPORT RNNModel *rnnoise_model_from_buffer(const void *ptr, int len); 103 | 104 | 105 | /** 106 | * Load a model from a file 107 | * 108 | * It must be deallocated with rnnoise_model_free() and the file must not be 109 | * closed until the returned object is destroyed. 110 | */ 111 | RNNOISE_EXPORT RNNModel *rnnoise_model_from_file(FILE *f); 112 | 113 | /** 114 | * Load a model from a file name 115 | * 116 | * It must be deallocated with rnnoise_model_free() 117 | */ 118 | RNNOISE_EXPORT RNNModel *rnnoise_model_from_filename(const char *filename); 119 | 120 | /** 121 | * Free a custom model 122 | * 123 | * It must be called after all the DenoiseStates referring to it are freed. 
124 | */ 125 | RNNOISE_EXPORT void rnnoise_model_free(RNNModel *model); 126 | 127 | #ifdef __cplusplus 128 | } 129 | #endif 130 | 131 | #endif 132 | -------------------------------------------------------------------------------- /model_version: -------------------------------------------------------------------------------- 1 | 0a8755f8e2d834eff6a54714ecc7d75f9932e845df35f8b59bc52a7cfe6e8b37 2 | -------------------------------------------------------------------------------- /rnnoise-uninstalled.pc.in: -------------------------------------------------------------------------------- 1 | # rnnoise uninstalled pkg-config file 2 | 3 | prefix= 4 | exec_prefix= 5 | libdir=${pcfiledir}/.libs 6 | includedir=${pcfiledir}/@top_srcdir@/include 7 | 8 | Name: rnnoise uninstalled 9 | Description: RNN-based noise suppression (not installed) 10 | Version: @PACKAGE_VERSION@ 11 | Conflicts: 12 | Libs: ${libdir}/librnnoise.la @lrintf_lib@ 13 | Cflags: -I${includedir} 14 | -------------------------------------------------------------------------------- /rnnoise.pc.in: -------------------------------------------------------------------------------- 1 | # rnnoise installed pkg-config file 2 | 3 | prefix=@prefix@ 4 | exec_prefix=@exec_prefix@ 5 | libdir=@libdir@ 6 | includedir=@includedir@ 7 | 8 | Name: rnnoise 9 | Description: RNN-based noise suppression 10 | Version: @PACKAGE_VERSION@ 11 | Conflicts: 12 | Libs: -L${libdir} -lrnnoise 13 | Libs.private: @lrintf_lib@ 14 | Cflags: -I${includedir}/ 15 | -------------------------------------------------------------------------------- /scripts/dump_features_parallel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | cmd=$1 4 | speech=$2 5 | noise=$3 6 | fgnoise=$4 7 | output=$5 8 | count=$6 9 | rir=$7 10 | split=400 11 | seq $split | parallel -j +2 "$cmd -rir_list $rir $speech $noise $fgnoise $output.{} $count" 12 | mv $output.1 $output 13 | for i in $output.* 14 | do 15 | cat $i >> 
$output 16 | rm $i 17 | done 18 | -------------------------------------------------------------------------------- /scripts/rir_deconv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import sweep 4 | import numpy as np 5 | from numpy import fft 6 | from scipy import signal 7 | from scipy.io import wavfile 8 | import sys 9 | 10 | def extract_sweep(pilot, y, pilot_len, sweep_len, silence_len): 11 | pilot = np.concatenate([pilot, np.zeros(len(y)-len(pilot))]) 12 | N = fft.rfft(pilot) 13 | Y = fft.rfft(y) 14 | xcorr = fft.irfft(Y * np.conj(N)) 15 | pos = np.argmax(np.abs(xcorr[:sweep_len])) 16 | pilot_offset = sweep_len+pilot_len+2*silence_len 17 | pilot1 = y[pos:pos+pilot_len] 18 | pilot2 = y[pilot_offset+pos:pilot_offset+pos+pilot_len] 19 | drift_xcorr = fft.irfft(fft.rfft(pilot1) * np.conj(fft.rfft(pilot2))) 20 | drift = np.argmax(np.abs(drift_xcorr)) 21 | if drift > pilot_len//2: 22 | drift = drift - pilot_len 23 | print(f"measured drift is {drift} samples ({100*drift/(pilot_len + sweep_len + 2*silence_len)})%"); 24 | return y[pos+pilot_len+silence_len//2 : pos+pilot_len+silence_len+sweep_len-drift+silence_len//2] 25 | 26 | def deconv_rir(pilot, x, y, Fs=48000, duration=60): 27 | pilot_len = Fs 28 | sweep_len=Fs*duration 29 | silence_len = Fs 30 | # Properly synchronize the signal and extract just the sweep 31 | y = extract_sweep(pilot, y, pilot_len, sweep_len, silence_len) 32 | x = np.concatenate([x, np.zeros(sweep_len)]) 33 | y = np.concatenate([y, np.zeros(sweep_len-silence_len)]) 34 | X = fft.rfft(x) 35 | Y = fft.rfft(y) 36 | # Truncate or pad depending on the drift 37 | if len(Y) >= len(X): 38 | Y = Y[:len(X)] 39 | else: 40 | Y = np.concatenate([Y, np.zeros(len(X)-len(Y))]) 41 | # Do the actual deconvolution 42 | rir = fft.irfft(Y*np.conj(X)/(1.+X*np.conj(X))) 43 | # Chopping the non-causal part (before the direct path) 44 | direct = np.max(np.abs(rir)) 45 | direct_pos = np.argmax(np.abs(rir)) 
46 | crop_pos = np.argwhere(np.abs(rir[:direct_pos+1]) > .02*direct)[0][0] 47 | rir = rir[crop_pos:] 48 | # Chopping the everything that's buried in the noise 49 | noise_floor = np.mean(rir[Fs*10:Fs*20]**2) 50 | smoothed = signal.lfilter(np.array([.002]), np.array([1, -.998]), rir[:Fs*10]**2) 51 | rir_length = np.argwhere(smoothed > 15*noise_floor)[-1][0] 52 | rir = rir[:rir_length] 53 | # Normalize 54 | rir = rir/np.sqrt(np.sum(rir**2)) 55 | return rir 56 | 57 | 58 | if __name__ == '__main__': 59 | duration=60 60 | #Re-compute the sweep that was played 61 | sine = sweep.compute_sweep(duration) 62 | #Load recorded signal 63 | _,mic=wavfile.read(sys.argv[1]) 64 | #Re-compute pilot sequence 65 | pilot=sweep.compute_sweep(1.) 66 | 67 | rir = deconv_rir(pilot, sine, mic, duration=duration) 68 | rir.astype('float32').tofile(sys.argv[2]) 69 | -------------------------------------------------------------------------------- /scripts/shrink_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for i in rnnoise_data.c 4 | do 5 | cat src/$i | perl -ne 'if (/DEBUG/ || /#else/) {$skip=1} if (!$skip && !/ifdef DOT_PROD/) {s/^ *//; s/, /,/g; print $_} elsif (/endif/) {$skip=0}' > tmp_data.c 6 | mv tmp_data.c src/$i 7 | done 8 | -------------------------------------------------------------------------------- /scripts/sweep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import scipy.io.wavfile as wav 4 | import numpy as np 5 | import sys 6 | 7 | def compute_sweep(T, Fs=48000, F0=100): 8 | F1=Fs//2 9 | b=np.log((F1+F0)/F0)/T 10 | a=F0/b 11 | n=np.arange(T*Fs) 12 | t = n/Fs 13 | y=0.9*np.sin(2*np.pi*a*(np.exp(b*t)-b*t-1)) 14 | return y 15 | 16 | def compute_sequence(T, Fs=48000, F0=100): 17 | noise = compute_sweep(1, Fs, F0) 18 | zeros = np.zeros(Fs) 19 | sine = compute_sweep(T, Fs, F0) 20 | sequence = np.concatenate([zeros, noise, zeros, sine, zeros, noise, 
zeros]) 21 | return np.round(32768*sequence).astype('int16') 22 | 23 | if __name__ == '__main__': 24 | 25 | filename = sys.argv[1] 26 | Fs = 48000 27 | seq = compute_sequence(60, Fs=Fs) 28 | wav.write(filename, Fs, seq) 29 | -------------------------------------------------------------------------------- /src/_kiss_fft_guts.h: -------------------------------------------------------------------------------- 1 | /*Copyright (c) 2003-2004, Mark Borgerding 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 18 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 19 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 20 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 21 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 22 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 23 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 24 | POSSIBILITY OF SUCH DAMAGE.*/ 25 | 26 | #ifndef KISS_FFT_GUTS_H 27 | #define KISS_FFT_GUTS_H 28 | 29 | #define MIN(a,b) ((a)<(b) ? (a):(b)) 30 | #define MAX(a,b) ((a)>(b) ? 
(a):(b)) 31 | 32 | /* kiss_fft.h 33 | defines kiss_fft_scalar as either short or a float type 34 | and defines 35 | typedef struct { kiss_fft_scalar r; kiss_fft_scalar i; }kiss_fft_cpx; */ 36 | #include "kiss_fft.h" 37 | 38 | /* 39 | Explanation of macros dealing with complex math: 40 | 41 | C_MUL(m,a,b) : m = a*b 42 | C_FIXDIV( c , div ) : if a fixed point impl., c /= div. noop otherwise 43 | C_SUB( res, a,b) : res = a - b 44 | C_SUBFROM( res , a) : res -= a 45 | C_ADDTO( res , a) : res += a 46 | * */ 47 | #ifdef FIXED_POINT 48 | #include "arch.h" 49 | 50 | 51 | #define SAMP_MAX 2147483647 52 | #define TWID_MAX 32767 53 | #define TRIG_UPSCALE 1 54 | 55 | #define SAMP_MIN -SAMP_MAX 56 | 57 | 58 | # define S_MUL(a,b) MULT16_32_Q15(b, a) 59 | 60 | # define C_MUL(m,a,b) \ 61 | do{ (m).r = SUB32_ovflw(S_MUL((a).r,(b).r) , S_MUL((a).i,(b).i)); \ 62 | (m).i = ADD32_ovflw(S_MUL((a).r,(b).i) , S_MUL((a).i,(b).r)); }while(0) 63 | 64 | # define C_MULC(m,a,b) \ 65 | do{ (m).r = ADD32_ovflw(S_MUL((a).r,(b).r) , S_MUL((a).i,(b).i)); \ 66 | (m).i = SUB32_ovflw(S_MUL((a).i,(b).r) , S_MUL((a).r,(b).i)); }while(0) 67 | 68 | # define C_MULBYSCALAR( c, s ) \ 69 | do{ (c).r = S_MUL( (c).r , s ) ;\ 70 | (c).i = S_MUL( (c).i , s ) ; }while(0) 71 | 72 | # define DIVSCALAR(x,k) \ 73 | (x) = S_MUL( x, (TWID_MAX-((k)>>1))/(k)+1 ) 74 | 75 | # define C_FIXDIV(c,div) \ 76 | do { DIVSCALAR( (c).r , div); \ 77 | DIVSCALAR( (c).i , div); }while (0) 78 | 79 | #define C_ADD( res, a,b)\ 80 | do {(res).r=ADD32_ovflw((a).r,(b).r); (res).i=ADD32_ovflw((a).i,(b).i); \ 81 | }while(0) 82 | #define C_SUB( res, a,b)\ 83 | do {(res).r=SUB32_ovflw((a).r,(b).r); (res).i=SUB32_ovflw((a).i,(b).i); \ 84 | }while(0) 85 | #define C_ADDTO( res , a)\ 86 | do {(res).r = ADD32_ovflw((res).r, (a).r); (res).i = ADD32_ovflw((res).i,(a).i);\ 87 | }while(0) 88 | 89 | #define C_SUBFROM( res , a)\ 90 | do {(res).r = ADD32_ovflw((res).r,(a).r); (res).i = SUB32_ovflw((res).i,(a).i); \ 91 | }while(0) 92 | 93 | #if 
defined(OPUS_ARM_INLINE_ASM) 94 | #include "arm/kiss_fft_armv4.h" 95 | #endif 96 | 97 | #if defined(OPUS_ARM_INLINE_EDSP) 98 | #include "arm/kiss_fft_armv5e.h" 99 | #endif 100 | #if defined(MIPSr1_ASM) 101 | #include "mips/kiss_fft_mipsr1.h" 102 | #endif 103 | 104 | #else /* not FIXED_POINT*/ 105 | 106 | # define S_MUL(a,b) ( (a)*(b) ) 107 | #define C_MUL(m,a,b) \ 108 | do{ (m).r = (a).r*(b).r - (a).i*(b).i;\ 109 | (m).i = (a).r*(b).i + (a).i*(b).r; }while(0) 110 | #define C_MULC(m,a,b) \ 111 | do{ (m).r = (a).r*(b).r + (a).i*(b).i;\ 112 | (m).i = (a).i*(b).r - (a).r*(b).i; }while(0) 113 | 114 | #define C_MUL4(m,a,b) C_MUL(m,a,b) 115 | 116 | # define C_FIXDIV(c,div) /* NOOP */ 117 | # define C_MULBYSCALAR( c, s ) \ 118 | do{ (c).r *= (s);\ 119 | (c).i *= (s); }while(0) 120 | #endif 121 | 122 | #ifndef CHECK_OVERFLOW_OP 123 | # define CHECK_OVERFLOW_OP(a,op,b) /* noop */ 124 | #endif 125 | 126 | #ifndef C_ADD 127 | #define C_ADD( res, a,b)\ 128 | do { \ 129 | CHECK_OVERFLOW_OP((a).r,+,(b).r)\ 130 | CHECK_OVERFLOW_OP((a).i,+,(b).i)\ 131 | (res).r=(a).r+(b).r; (res).i=(a).i+(b).i; \ 132 | }while(0) 133 | #define C_SUB( res, a,b)\ 134 | do { \ 135 | CHECK_OVERFLOW_OP((a).r,-,(b).r)\ 136 | CHECK_OVERFLOW_OP((a).i,-,(b).i)\ 137 | (res).r=(a).r-(b).r; (res).i=(a).i-(b).i; \ 138 | }while(0) 139 | #define C_ADDTO( res , a)\ 140 | do { \ 141 | CHECK_OVERFLOW_OP((res).r,+,(a).r)\ 142 | CHECK_OVERFLOW_OP((res).i,+,(a).i)\ 143 | (res).r += (a).r; (res).i += (a).i;\ 144 | }while(0) 145 | 146 | #define C_SUBFROM( res , a)\ 147 | do {\ 148 | CHECK_OVERFLOW_OP((res).r,-,(a).r)\ 149 | CHECK_OVERFLOW_OP((res).i,-,(a).i)\ 150 | (res).r -= (a).r; (res).i -= (a).i; \ 151 | }while(0) 152 | #endif /* C_ADD defined */ 153 | 154 | #ifdef FIXED_POINT 155 | /*# define KISS_FFT_COS(phase) TRIG_UPSCALE*floor(MIN(32767,MAX(-32767,.5+32768 * cos (phase)))) 156 | # define KISS_FFT_SIN(phase) TRIG_UPSCALE*floor(MIN(32767,MAX(-32767,.5+32768 * sin (phase))))*/ 157 | # define KISS_FFT_COS(phase) 
floor(.5+TWID_MAX*cos (phase)) 158 | # define KISS_FFT_SIN(phase) floor(.5+TWID_MAX*sin (phase)) 159 | # define HALF_OF(x) ((x)>>1) 160 | #elif defined(USE_SIMD) 161 | # define KISS_FFT_COS(phase) _mm_set1_ps( cos(phase) ) 162 | # define KISS_FFT_SIN(phase) _mm_set1_ps( sin(phase) ) 163 | # define HALF_OF(x) ((x)*_mm_set1_ps(.5f)) 164 | #else 165 | # define KISS_FFT_COS(phase) (kiss_fft_scalar) cos(phase) 166 | # define KISS_FFT_SIN(phase) (kiss_fft_scalar) sin(phase) 167 | # define HALF_OF(x) ((x)*.5f) 168 | #endif 169 | 170 | #define kf_cexp(x,phase) \ 171 | do{ \ 172 | (x)->r = KISS_FFT_COS(phase);\ 173 | (x)->i = KISS_FFT_SIN(phase);\ 174 | }while(0) 175 | 176 | #define kf_cexp2(x,phase) \ 177 | do{ \ 178 | (x)->r = TRIG_UPSCALE*celt_cos_norm((phase));\ 179 | (x)->i = TRIG_UPSCALE*celt_cos_norm((phase)-32768);\ 180 | }while(0) 181 | 182 | #endif /* KISS_FFT_GUTS_H */ 183 | -------------------------------------------------------------------------------- /src/arch.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2003-2008 Jean-Marc Valin 2 | Copyright (c) 2007-2008 CSIRO 3 | Copyright (c) 2007-2009 Xiph.Org Foundation 4 | Written by Jean-Marc Valin */ 5 | /** 6 | @file arch.h 7 | @brief Various architecture definitions for CELT 8 | */ 9 | /* 10 | Redistribution and use in source and binary forms, with or without 11 | modification, are permitted provided that the following conditions 12 | are met: 13 | 14 | - Redistributions of source code must retain the above copyright 15 | notice, this list of conditions and the following disclaimer. 16 | 17 | - Redistributions in binary form must reproduce the above copyright 18 | notice, this list of conditions and the following disclaimer in the 19 | documentation and/or other materials provided with the distribution. 
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 25 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 26 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 27 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 28 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 29 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 30 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 31 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | */ 33 | 34 | #ifndef ARCH_H 35 | #define ARCH_H 36 | 37 | #include "opus_types.h" 38 | #include "common.h" 39 | 40 | # if !defined(__GNUC_PREREQ) 41 | # if defined(__GNUC__)&&defined(__GNUC_MINOR__) 42 | # define __GNUC_PREREQ(_maj,_min) \ 43 | ((__GNUC__<<16)+__GNUC_MINOR__>=((_maj)<<16)+(_min)) 44 | # else 45 | # define __GNUC_PREREQ(_maj,_min) 0 46 | # endif 47 | # endif 48 | 49 | #define CELT_SIG_SCALE 32768.f 50 | 51 | #define celt_fatal(str) _celt_fatal(str, __FILE__, __LINE__); 52 | #ifdef ENABLE_ASSERTIONS 53 | #include 54 | #include 55 | #ifdef __GNUC__ 56 | __attribute__((noreturn)) 57 | #endif 58 | static OPUS_INLINE void _celt_fatal(const char *str, const char *file, int line) 59 | { 60 | fprintf (stderr, "Fatal (internal) error in %s, line %d: %s\n", file, line, str); 61 | abort(); 62 | } 63 | #define celt_assert(cond) {if (!(cond)) {celt_fatal("assertion failed: " #cond);}} 64 | #define celt_assert2(cond, message) {if (!(cond)) {celt_fatal("assertion failed: " #cond "\n" message);}} 65 | #else 66 | #define celt_assert(cond) 67 | #define celt_assert2(cond, message) 68 | #endif 69 | 70 | #define IMUL32(a,b) ((a)*(b)) 71 | 72 | #define MIN16(a,b) 
((a) < (b) ? (a) : (b)) /**< Minimum 16-bit value. */ 73 | #define MAX16(a,b) ((a) > (b) ? (a) : (b)) /**< Maximum 16-bit value. */ 74 | #define MIN32(a,b) ((a) < (b) ? (a) : (b)) /**< Minimum 32-bit value. */ 75 | #define MAX32(a,b) ((a) > (b) ? (a) : (b)) /**< Maximum 32-bit value. */ 76 | #define IMIN(a,b) ((a) < (b) ? (a) : (b)) /**< Minimum int value. */ 77 | #define IMAX(a,b) ((a) > (b) ? (a) : (b)) /**< Maximum int value. */ 78 | #define UADD32(a,b) ((a)+(b)) 79 | #define USUB32(a,b) ((a)-(b)) 80 | 81 | /* Set this if opus_int64 is a native type of the CPU. */ 82 | /* Assume that all LP64 architectures have fast 64-bit types; also x86_64 83 | (which can be ILP32 for x32) and Win64 (which is LLP64). */ 84 | #if defined(__x86_64__) || defined(__LP64__) || defined(_WIN64) 85 | #define OPUS_FAST_INT64 1 86 | #else 87 | #define OPUS_FAST_INT64 0 88 | #endif 89 | 90 | #define PRINT_MIPS(file) 91 | 92 | #ifdef FIXED_POINT 93 | 94 | typedef opus_int16 opus_val16; 95 | typedef opus_int32 opus_val32; 96 | typedef opus_int64 opus_val64; 97 | 98 | typedef opus_val32 celt_sig; 99 | typedef opus_val16 celt_norm; 100 | typedef opus_val32 celt_ener; 101 | 102 | #define Q15ONE 32767 103 | 104 | #define SIG_SHIFT 12 105 | /* Safe saturation value for 32-bit signals. Should be less than 106 | 2^31*(1-0.85) to avoid blowing up on DC at deemphasis.*/ 107 | #define SIG_SAT (300000000) 108 | 109 | #define NORM_SCALING 16384 110 | 111 | #define DB_SHIFT 10 112 | 113 | #define EPSILON 1 114 | #define VERY_SMALL 0 115 | #define VERY_LARGE16 ((opus_val16)32767) 116 | #define Q15_ONE ((opus_val16)32767) 117 | 118 | #define SCALEIN(a) (a) 119 | #define SCALEOUT(a) (a) 120 | 121 | #define ABS16(x) ((x) < 0 ? (-(x)) : (x)) 122 | #define ABS32(x) ((x) < 0 ? (-(x)) : (x)) 123 | 124 | static OPUS_INLINE opus_int16 SAT16(opus_int32 x) { 125 | return x > 32767 ? 32767 : x < -32768 ? 
-32768 : (opus_int16)x; 126 | } 127 | 128 | #ifdef FIXED_DEBUG 129 | #include "fixed_debug.h" 130 | #else 131 | 132 | #include "fixed_generic.h" 133 | 134 | #ifdef OPUS_ARM_PRESUME_AARCH64_NEON_INTR 135 | #include "arm/fixed_arm64.h" 136 | #elif OPUS_ARM_INLINE_EDSP 137 | #include "arm/fixed_armv5e.h" 138 | #elif defined (OPUS_ARM_INLINE_ASM) 139 | #include "arm/fixed_armv4.h" 140 | #elif defined (BFIN_ASM) 141 | #include "fixed_bfin.h" 142 | #elif defined (TI_C5X_ASM) 143 | #include "fixed_c5x.h" 144 | #elif defined (TI_C6X_ASM) 145 | #include "fixed_c6x.h" 146 | #endif 147 | 148 | #endif 149 | 150 | #else /* FIXED_POINT */ 151 | 152 | typedef float opus_val16; 153 | typedef float opus_val32; 154 | typedef float opus_val64; 155 | 156 | typedef float celt_sig; 157 | typedef float celt_norm; 158 | typedef float celt_ener; 159 | 160 | #ifdef FLOAT_APPROX 161 | /* This code should reliably detect NaN/inf even when -ffast-math is used. 162 | Assumes IEEE 754 format. */ 163 | static OPUS_INLINE int celt_isnan(float x) 164 | { 165 | union {float f; opus_uint32 i;} in; 166 | in.f = x; 167 | return ((in.i>>23)&0xFF)==0xFF && (in.i&0x007FFFFF)!=0; 168 | } 169 | #else 170 | #ifdef __FAST_MATH__ 171 | #error Cannot build libopus with -ffast-math unless FLOAT_APPROX is defined. This could result in crashes on extreme (e.g. NaN) input 172 | #endif 173 | #define celt_isnan(x) ((x)!=(x)) 174 | #endif 175 | 176 | #define Q15ONE 1.0f 177 | 178 | #define NORM_SCALING 1.f 179 | 180 | #define EPSILON 1e-15f 181 | #define VERY_SMALL 1e-30f 182 | #define VERY_LARGE16 1e15f 183 | #define Q15_ONE ((opus_val16)1.f) 184 | 185 | /* This appears to be the same speed as C99's fabsf() but it's more portable. 
*/ 186 | #define ABS16(x) ((float)fabs(x)) 187 | #define ABS32(x) ((float)fabs(x)) 188 | 189 | #define QCONST16(x,bits) (x) 190 | #define QCONST32(x,bits) (x) 191 | 192 | #define NEG16(x) (-(x)) 193 | #define NEG32(x) (-(x)) 194 | #define NEG32_ovflw(x) (-(x)) 195 | #define EXTRACT16(x) (x) 196 | #define EXTEND32(x) (x) 197 | #define SHR16(a,shift) (a) 198 | #define SHL16(a,shift) (a) 199 | #define SHR32(a,shift) (a) 200 | #define SHL32(a,shift) (a) 201 | #define PSHR32(a,shift) (a) 202 | #define VSHR32(a,shift) (a) 203 | 204 | #define PSHR(a,shift) (a) 205 | #define SHR(a,shift) (a) 206 | #define SHL(a,shift) (a) 207 | #define SATURATE(x,a) (x) 208 | #define SATURATE16(x) (x) 209 | 210 | #define ROUND16(a,shift) (a) 211 | #define SROUND16(a,shift) (a) 212 | #define HALF16(x) (.5f*(x)) 213 | #define HALF32(x) (.5f*(x)) 214 | 215 | #define ADD16(a,b) ((a)+(b)) 216 | #define SUB16(a,b) ((a)-(b)) 217 | #define ADD32(a,b) ((a)+(b)) 218 | #define SUB32(a,b) ((a)-(b)) 219 | #define ADD32_ovflw(a,b) ((a)+(b)) 220 | #define SUB32_ovflw(a,b) ((a)-(b)) 221 | #define MULT16_16_16(a,b) ((a)*(b)) 222 | #define MULT16_16(a,b) ((opus_val32)(a)*(opus_val32)(b)) 223 | #define MAC16_16(c,a,b) ((c)+(opus_val32)(a)*(opus_val32)(b)) 224 | 225 | #define MULT16_32_Q15(a,b) ((a)*(b)) 226 | #define MULT16_32_Q16(a,b) ((a)*(b)) 227 | 228 | #define MULT32_32_Q31(a,b) ((a)*(b)) 229 | 230 | #define MAC16_32_Q15(c,a,b) ((c)+(a)*(b)) 231 | #define MAC16_32_Q16(c,a,b) ((c)+(a)*(b)) 232 | 233 | #define MULT16_16_Q11_32(a,b) ((a)*(b)) 234 | #define MULT16_16_Q11(a,b) ((a)*(b)) 235 | #define MULT16_16_Q13(a,b) ((a)*(b)) 236 | #define MULT16_16_Q14(a,b) ((a)*(b)) 237 | #define MULT16_16_Q15(a,b) ((a)*(b)) 238 | #define MULT16_16_P15(a,b) ((a)*(b)) 239 | #define MULT16_16_P13(a,b) ((a)*(b)) 240 | #define MULT16_16_P14(a,b) ((a)*(b)) 241 | #define MULT16_32_P16(a,b) ((a)*(b)) 242 | 243 | #define DIV32_16(a,b) (((opus_val32)(a))/(opus_val16)(b)) 244 | #define DIV32(a,b) 
(((opus_val32)(a))/(opus_val32)(b)) 245 | 246 | #define SCALEIN(a) ((a)*CELT_SIG_SCALE) 247 | #define SCALEOUT(a) ((a)*(1/CELT_SIG_SCALE)) 248 | 249 | #define SIG2WORD16(x) (x) 250 | 251 | #endif /* !FIXED_POINT */ 252 | 253 | #ifndef GLOBAL_STACK_SIZE 254 | #ifdef FIXED_POINT 255 | #define GLOBAL_STACK_SIZE 120000 256 | #else 257 | #define GLOBAL_STACK_SIZE 120000 258 | #endif 259 | #endif 260 | 261 | #endif /* ARCH_H */ 262 | -------------------------------------------------------------------------------- /src/celt_lpc.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2009-2010 Xiph.Org Foundation 2 | Written by Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include "celt_lpc.h" 33 | #include "arch.h" 34 | #include "common.h" 35 | #include "pitch.h" 36 | #include "denoise.h" 37 | 38 | void rnn_lpc( 39 | opus_val16 *_lpc, /* out: [0...p-1] LPC coefficients */ 40 | const opus_val32 *ac, /* in: [0...p] autocorrelation values */ 41 | int p 42 | ) 43 | { 44 | int i, j; 45 | opus_val32 r; 46 | opus_val32 error = ac[0]; 47 | #ifdef FIXED_POINT 48 | opus_val32 lpc[LPC_ORDER]; 49 | #else 50 | float *lpc = _lpc; 51 | #endif 52 | 53 | RNN_CLEAR(lpc, p); 54 | if (ac[0] != 0) 55 | { 56 | for (i = 0; i < p; i++) { 57 | /* Sum up this iteration's reflection coefficient */ 58 | opus_val32 rr = 0; 59 | for (j = 0; j < i; j++) 60 | rr += MULT32_32_Q31(lpc[j],ac[i - j]); 61 | rr += SHR32(ac[i + 1],3); 62 | r = -SHL32(rr,3)/error; 63 | /* Update LPC coefficients and total error */ 64 | lpc[i] = SHR32(r,3); 65 | for (j = 0; j < (i+1)>>1; j++) 66 | { 67 | opus_val32 tmp1, tmp2; 68 | tmp1 = lpc[j]; 69 | tmp2 = lpc[i-1-j]; 70 | lpc[j] = tmp1 + MULT32_32_Q31(r,tmp2); 71 | lpc[i-1-j] = tmp2 + MULT32_32_Q31(r,tmp1); 72 | } 73 | 74 | error = error - MULT32_32_Q31(MULT32_32_Q31(r,r),error); 75 | /* Bail out once we get 30 dB gain */ 76 | #ifdef FIXED_POINT 77 | if (error0); 107 | celt_assert(n<=PITCH_BUF_SIZE/2) 108 | celt_assert(overlap>=0); 109 | if (overlap == 0) 110 | { 111 | xptr = x; 112 | } else { 113 | for (i=0;i0) 
137 | { 138 | for(i=0;i= 536870912) 163 | { 164 | int shift2=1; 165 | if (ac[0] >= 1073741824) 166 | shift2++; 167 | for (i=0;i<=lag;i++) 168 | ac[i] = SHR32(ac[i], shift2); 169 | shift += shift2; 170 | } 171 | #endif 172 | 173 | return shift; 174 | } 175 | -------------------------------------------------------------------------------- /src/celt_lpc.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2009-2010 Xiph.Org Foundation 2 | Written by Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifndef PLC_H 29 | #define PLC_H 30 | 31 | #include "arch.h" 32 | #include "common.h" 33 | 34 | #if defined(OPUS_X86_MAY_HAVE_SSE4_1) 35 | #include "x86/celt_lpc_sse.h" 36 | #endif 37 | 38 | #define LPC_ORDER 24 39 | 40 | void rnn_lpc(opus_val16 *_lpc, const opus_val32 *ac, int p); 41 | 42 | int rnn_autocorr(const opus_val16 *x, opus_val32 *ac, 43 | const opus_val16 *window, int overlap, int lag, int n); 44 | 45 | #endif /* PLC_H */ 46 | -------------------------------------------------------------------------------- /src/common.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef COMMON_H 4 | #define COMMON_H 5 | 6 | #include "stdlib.h" 7 | #include "string.h" 8 | 9 | #define RNN_INLINE inline 10 | #define OPUS_INLINE inline 11 | 12 | 13 | /** RNNoise wrapper for malloc(). To do your own dynamic allocation, all you need t 14 | o do is replace this function and rnnoise_free */ 15 | #ifndef OVERRIDE_RNNOISE_ALLOC 16 | static RNN_INLINE void *rnnoise_alloc (size_t size) 17 | { 18 | return malloc(size); 19 | } 20 | #endif 21 | 22 | /** RNNoise wrapper for free(). To do your own dynamic allocation, all you need to do is replace this function and rnnoise_alloc */ 23 | #ifndef OVERRIDE_RNNOISE_FREE 24 | static RNN_INLINE void rnnoise_free (void *ptr) 25 | { 26 | free(ptr); 27 | } 28 | #endif 29 | 30 | /** Copy n elements from src to dst. The 0* term provides compile-time type checking */ 31 | #ifndef OVERRIDE_RNN_COPY 32 | #define RNN_COPY(dst, src, n) (memcpy((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) )) 33 | #endif 34 | 35 | /** Copy n elements from src to dst, allowing overlapping regions. 
The 0* term 36 | provides compile-time type checking */ 37 | #ifndef OVERRIDE_RNN_MOVE 38 | #define RNN_MOVE(dst, src, n) (memmove((dst), (src), (n)*sizeof(*(dst)) + 0*((dst)-(src)) )) 39 | #endif 40 | 41 | /** Set n elements of dst to zero */ 42 | #ifndef OVERRIDE_RNN_CLEAR 43 | #define RNN_CLEAR(dst, n) (memset((dst), 0, (n)*sizeof(*(dst)))) 44 | #endif 45 | 46 | # if !defined(OPUS_GNUC_PREREQ) 47 | # if defined(__GNUC__)&&defined(__GNUC_MINOR__) 48 | # define OPUS_GNUC_PREREQ(_maj,_min) \ 49 | ((__GNUC__<<16)+__GNUC_MINOR__>=((_maj)<<16)+(_min)) 50 | # else 51 | # define OPUS_GNUC_PREREQ(_maj,_min) 0 52 | # endif 53 | # endif 54 | 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /src/compile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | gcc -DTRAINING=1 -Wall -W -O3 -g -I../include denoise.c kiss_fft.c pitch.c celt_lpc.c rnn.c rnn_data.c -o denoise_training -lm 4 | -------------------------------------------------------------------------------- /src/cpu_support.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2010 Xiph.Org Foundation 2 | * Copyright (c) 2013 Parrot */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifndef CPU_SUPPORT_H 29 | #define CPU_SUPPORT_H 30 | 31 | #include "opus_types.h" 32 | #include "common.h" 33 | 34 | #ifdef RNN_ENABLE_X86_RTCD 35 | 36 | #include "x86/x86cpu.h" 37 | /* We currently support 5 x86 variants: 38 | * arch[0] -> sse2 39 | * arch[1] -> sse4.1 40 | * arch[2] -> avx2 41 | */ 42 | #define OPUS_ARCHMASK 3 43 | int rnn_select_arch(void); 44 | 45 | #else 46 | #define OPUS_ARCHMASK 0 47 | 48 | static OPUS_INLINE int rnn_select_arch(void) 49 | { 50 | return 0; 51 | } 52 | #endif 53 | #endif 54 | -------------------------------------------------------------------------------- /src/denoise.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2017 Mozilla */ 2 | /* 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 7 | - Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 
9 | 10 | - Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 18 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 19 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 20 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 21 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | */ 26 | 27 | #include "rnnoise.h" 28 | #include "kiss_fft.h" 29 | #include "nnet.h" 30 | 31 | #define FRAME_SIZE 480 32 | #define WINDOW_SIZE (2*FRAME_SIZE) 33 | #define FREQ_SIZE (FRAME_SIZE + 1) 34 | #define NB_BANDS 32 35 | #define NB_FEATURES (2*NB_BANDS+1) 36 | 37 | 38 | #define PITCH_MIN_PERIOD 60 39 | #define PITCH_MAX_PERIOD 768 40 | #define PITCH_FRAME_SIZE 960 41 | #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE) 42 | 43 | extern const WeightArray rnnoise_arrays[]; 44 | 45 | extern const int eband20ms[]; 46 | 47 | 48 | void rnn_biquad(float *y, float mem[2], const float *x, const float *b, const float *a, int N); 49 | 50 | void rnn_pitch_filter(kiss_fft_cpx *X, const kiss_fft_cpx *P, const float *Ex, const float *Ep, 51 | const float *Exp, const float *g); 52 | 53 | void rnn_frame_analysis(DenoiseState *st, kiss_fft_cpx *X, float *Ex, const float *in); 54 | 55 | int rnn_compute_frame_features(DenoiseState *st, kiss_fft_cpx *X, kiss_fft_cpx *P, 56 | float *Ex, float *Ep, float *Exp, float *features, const float *in); 57 | -------------------------------------------------------------------------------- /src/dump_rnnoise_tables.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2017-2018 Mozilla 2 | Copyright (c) 2023 Amazon */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include 33 | #include 34 | #include "denoise.h" 35 | #include "kiss_fft.h" 36 | 37 | #define OVERLAP_SIZE FRAME_SIZE 38 | 39 | int main(void) { 40 | int i; 41 | FILE *file; 42 | kiss_fft_state *kfft; 43 | float half_window[OVERLAP_SIZE]; 44 | float dct_table[NB_BANDS*NB_BANDS]; 45 | 46 | file=fopen("rnnoise_tables.c", "wb"); 47 | fprintf(file, "/* The contents of this file was automatically generated by dump_rnnoise_tables.c*/\n\n"); 48 | fprintf(file, "#ifdef HAVE_CONFIG_H\n"); 49 | fprintf(file, "#include \"config.h\"\n"); 50 | fprintf(file, "#endif\n"); 51 | 52 | fprintf(file, "#include \"kiss_fft.h\"\n\n"); 53 | 54 | kfft = rnn_fft_alloc_twiddles(WINDOW_SIZE, NULL, NULL, NULL, 0); 55 | 56 | fprintf(file, "static const arch_fft_state arch_fft = {0, NULL};\n\n"); 57 | 58 | fprintf (file, "static const opus_int32 fft_bitrev[%d] = {\n", kfft->nfft); 59 | for (i=0;infft;i++) 60 | fprintf (file, "%d,%c", kfft->bitrev[i],(i+16)%15==0?'\n':' '); 61 | fprintf (file, "};\n\n"); 62 | 63 | fprintf (file, "static const kiss_twiddle_cpx fft_twiddles[%d] = {\n", kfft->nfft); 64 | for 
(i=0;infft;i++) 65 | fprintf (file, "{%#0.9gf, %#0.9gf},%c", kfft->twiddles[i].r, kfft->twiddles[i].i,(i+3)%2==0?'\n':' '); 66 | fprintf (file, "};\n\n"); 67 | 68 | 69 | fprintf(file, "const kiss_fft_state rnn_kfft = {\n"); 70 | fprintf(file, "%d, /* nfft */\n", kfft->nfft); 71 | fprintf(file, "%#0.8gf, /* scale */\n", kfft->scale); 72 | fprintf(file, "%d, /* shift */\n", kfft->shift); 73 | fprintf(file, "{"); 74 | for (i=0;i<2*MAXFACTORS;i++) { 75 | fprintf(file, "%d, ", kfft->factors[i]); 76 | } 77 | fprintf(file, "}, /* factors */\n"); 78 | fprintf(file, "fft_bitrev, /* bitrev*/\n"); 79 | fprintf(file, "fft_twiddles, /* twiddles*/\n"); 80 | fprintf(file, "(arch_fft_state *)&arch_fft, /* arch_fft*/\n"); 81 | 82 | fprintf(file, "};\n\n"); 83 | 84 | for (i=0;i 33 | #include 34 | #include "arch.h" 35 | 36 | #include 37 | #define opus_alloc(x) malloc(x) 38 | #define opus_free(x) free(x) 39 | 40 | #ifdef __cplusplus 41 | extern "C" { 42 | #endif 43 | 44 | #ifdef USE_SIMD 45 | # include 46 | # define kiss_fft_scalar __m128 47 | #define KISS_FFT_MALLOC(nbytes) memalign(16,nbytes) 48 | #else 49 | #define KISS_FFT_MALLOC opus_alloc 50 | #endif 51 | 52 | #ifdef FIXED_POINT 53 | #include "arch.h" 54 | 55 | # define kiss_fft_scalar opus_int32 56 | # define kiss_twiddle_scalar opus_int16 57 | 58 | 59 | #else 60 | # ifndef kiss_fft_scalar 61 | /* default is float */ 62 | # define kiss_fft_scalar float 63 | # define kiss_twiddle_scalar float 64 | # define KF_SUFFIX _celt_single 65 | # endif 66 | #endif 67 | 68 | typedef struct { 69 | kiss_fft_scalar r; 70 | kiss_fft_scalar i; 71 | }kiss_fft_cpx; 72 | 73 | typedef struct { 74 | kiss_twiddle_scalar r; 75 | kiss_twiddle_scalar i; 76 | }kiss_twiddle_cpx; 77 | 78 | #define MAXFACTORS 8 79 | /* e.g. 
an fft of length 128 has 4 factors 80 | as far as kissfft is concerned 81 | 4*4*4*2 82 | */ 83 | 84 | typedef struct arch_fft_state{ 85 | int is_supported; 86 | void *priv; 87 | } arch_fft_state; 88 | 89 | typedef struct kiss_fft_state{ 90 | int nfft; 91 | opus_val16 scale; 92 | #ifdef FIXED_POINT 93 | int scale_shift; 94 | #endif 95 | int shift; 96 | opus_int16 factors[2*MAXFACTORS]; 97 | const opus_int32 *bitrev; 98 | const kiss_twiddle_cpx *twiddles; 99 | arch_fft_state *arch_fft; 100 | } kiss_fft_state; 101 | 102 | #if defined(HAVE_ARM_NE10) 103 | #include "arm/fft_arm.h" 104 | #endif 105 | 106 | /*typedef struct kiss_fft_state* kiss_fft_cfg;*/ 107 | 108 | /** 109 | * opus_fft_alloc 110 | * 111 | * Initialize a FFT (or IFFT) algorithm's cfg/state buffer. 112 | * 113 | * typical usage: kiss_fft_cfg mycfg=opus_fft_alloc(1024,0,NULL,NULL); 114 | * 115 | * The return value from fft_alloc is a cfg buffer used internally 116 | * by the fft routine or NULL. 117 | * 118 | * If lenmem is NULL, then opus_fft_alloc will allocate a cfg buffer using malloc. 119 | * The returned value should be free()d when done to avoid memory leaks. 120 | * 121 | * The state can be placed in a user supplied buffer 'mem': 122 | * If lenmem is not NULL and mem is not NULL and *lenmem is large enough, 123 | * then the function places the cfg in mem and the size used in *lenmem 124 | * and returns mem. 125 | * 126 | * If lenmem is not NULL and ( mem is NULL or *lenmem is not large enough), 127 | * then the function returns NULL and places the minimum cfg 128 | * buffer size in *lenmem. 129 | * */ 130 | 131 | kiss_fft_state *rnn_fft_alloc_twiddles(int nfft,void * mem,size_t * lenmem, const kiss_fft_state *base, int arch); 132 | 133 | kiss_fft_state *rnn_fft_alloc(int nfft,void * mem,size_t * lenmem, int arch); 134 | 135 | /** 136 | * opus_fft(cfg,in_out_buf) 137 | * 138 | * Perform an FFT on a complex input buffer. 139 | * for a forward FFT, 140 | * fin should be f[0] , f[1] , ... 
,f[nfft-1] 141 | * fout will be F[0] , F[1] , ... ,F[nfft-1] 142 | * Note that each element is complex and can be accessed like 143 | f[k].r and f[k].i 144 | * */ 145 | void rnn_fft_c(const kiss_fft_state *cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout); 146 | void rnn_ifft_c(const kiss_fft_state *cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout); 147 | 148 | void rnn_fft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout); 149 | void rnn_ifft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout); 150 | 151 | void rnn_fft_free(const kiss_fft_state *cfg, int arch); 152 | 153 | 154 | void rnn_fft_free_arch_c(kiss_fft_state *st); 155 | int rnn_fft_alloc_arch_c(kiss_fft_state *st); 156 | 157 | #if !defined(OVERRIDE_OPUS_FFT) 158 | /* Is run-time CPU detection enabled on this platform? */ 159 | #if defined(OPUS_HAVE_RTCD) && (defined(HAVE_ARM_NE10)) 160 | 161 | extern int (*const OPUS_FFT_ALLOC_ARCH_IMPL[OPUS_ARCHMASK+1])( 162 | kiss_fft_state *st); 163 | 164 | #define opus_fft_alloc_arch(_st, arch) \ 165 | ((*OPUS_FFT_ALLOC_ARCH_IMPL[(arch)&OPUS_ARCHMASK])(_st)) 166 | 167 | extern void (*const OPUS_FFT_FREE_ARCH_IMPL[OPUS_ARCHMASK+1])( 168 | kiss_fft_state *st); 169 | #define opus_fft_free_arch(_st, arch) \ 170 | ((*OPUS_FFT_FREE_ARCH_IMPL[(arch)&OPUS_ARCHMASK])(_st)) 171 | 172 | extern void (*const OPUS_FFT[OPUS_ARCHMASK+1])(const kiss_fft_state *cfg, 173 | const kiss_fft_cpx *fin, kiss_fft_cpx *fout); 174 | #define opus_fft(_cfg, _fin, _fout, arch) \ 175 | ((*OPUS_FFT[(arch)&OPUS_ARCHMASK])(_cfg, _fin, _fout)) 176 | 177 | extern void (*const OPUS_IFFT[OPUS_ARCHMASK+1])(const kiss_fft_state *cfg, 178 | const kiss_fft_cpx *fin, kiss_fft_cpx *fout); 179 | #define opus_ifft(_cfg, _fin, _fout, arch) \ 180 | ((*OPUS_IFFT[(arch)&OPUS_ARCHMASK])(_cfg, _fin, _fout)) 181 | 182 | #else /* else for if defined(OPUS_HAVE_RTCD) && (defined(HAVE_ARM_NE10)) */ 183 | 184 | #define rnn_fft_alloc_arch(_st, arch) \ 185 | ((void)(arch), rnn_fft_alloc_arch_c(_st)) 186 | 187 | #define 
rnn_fft_free_arch(_st, arch) \ 188 | ((void)(arch), rnn_fft_free_arch_c(_st)) 189 | 190 | #define rnn_fft(_cfg, _fin, _fout, arch) \ 191 | ((void)(arch), rnn_fft_c(_cfg, _fin, _fout)) 192 | 193 | #define rnn_ifft(_cfg, _fin, _fout, arch) \ 194 | ((void)(arch), rnn_ifft_c(_cfg, _fin, _fout)) 195 | 196 | #endif /* end if defined(OPUS_HAVE_RTCD) && (defined(HAVE_ARM_NE10)) */ 197 | #endif /* end if !defined(OVERRIDE_OPUS_FFT) */ 198 | 199 | #ifdef __cplusplus 200 | } 201 | #endif 202 | 203 | #endif 204 | -------------------------------------------------------------------------------- /src/nnet.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018 Mozilla 2 | 2008-2011 Octasic Inc. 3 | 2012-2017 Jean-Marc Valin */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE FOUNDATION OR 20 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | 29 | #ifdef HAVE_CONFIG_H 30 | #include "config.h" 31 | #endif 32 | 33 | #include 34 | #include 35 | #include "opus_types.h" 36 | #include "arch.h" 37 | #include "nnet.h" 38 | #include "common.h" 39 | #include "vec.h" 40 | 41 | #ifdef ENABLE_OSCE 42 | #include "osce.h" 43 | #endif 44 | 45 | #ifdef NO_OPTIMIZATIONS 46 | #if defined(_MSC_VER) 47 | #pragma message ("Compiling without any vectorization. This code will be very slow") 48 | #else 49 | #warning Compiling without any vectorization. 
This code will be very slow 50 | #endif 51 | #endif 52 | 53 | 54 | #define SOFTMAX_HACK 55 | 56 | 57 | void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation, int arch) 58 | { 59 | compute_linear(layer, output, input, arch); 60 | compute_activation(output, output, layer->nb_outputs, activation, arch); 61 | } 62 | 63 | #define MAX_RNN_NEURONS_ALL 1024 64 | 65 | void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in, int arch) 66 | { 67 | int i; 68 | int N; 69 | float zrh[3*MAX_RNN_NEURONS_ALL]; 70 | float recur[3*MAX_RNN_NEURONS_ALL]; 71 | float *z; 72 | float *r; 73 | float *h; 74 | celt_assert(3*recurrent_weights->nb_inputs == recurrent_weights->nb_outputs); 75 | celt_assert(input_weights->nb_outputs == recurrent_weights->nb_outputs); 76 | N = recurrent_weights->nb_inputs; 77 | z = zrh; 78 | r = &zrh[N]; 79 | h = &zrh[2*N]; 80 | celt_assert(recurrent_weights->nb_outputs <= 3*MAX_RNN_NEURONS_ALL); 81 | celt_assert(in != state); 82 | compute_linear(input_weights, zrh, in, arch); 83 | compute_linear(recurrent_weights, recur, state, arch); 84 | for (i=0;i<2*N;i++) 85 | zrh[i] += recur[i]; 86 | compute_activation(zrh, zrh, 2*N, ACTIVATION_SIGMOID, arch); 87 | for (i=0;inb_inputs == layer->nb_outputs); 101 | compute_linear(layer, act2, input, arch); 102 | compute_activation(act2, act2, layer->nb_outputs, ACTIVATION_SIGMOID, arch); 103 | if (input == output) { 104 | /* Give a vectorization hint to the compiler for the in-place case. 
*/ 105 | for (i=0;inb_outputs;i++) output[i] = output[i]*act2[i]; 106 | } else { 107 | for (i=0;inb_outputs;i++) output[i] = input[i]*act2[i]; 108 | } 109 | } 110 | 111 | #define MAX_CONV_INPUTS_ALL 1024 112 | 113 | void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation, int arch) 114 | { 115 | float tmp[MAX_CONV_INPUTS_ALL]; 116 | celt_assert(input != output); 117 | celt_assert(layer->nb_inputs <= MAX_CONV_INPUTS_ALL); 118 | if (layer->nb_inputs!=input_size) RNN_COPY(tmp, mem, layer->nb_inputs-input_size); 119 | RNN_COPY(&tmp[layer->nb_inputs-input_size], input, input_size); 120 | compute_linear(layer, output, tmp, arch); 121 | compute_activation(output, output, layer->nb_outputs, activation, arch); 122 | if (layer->nb_inputs!=input_size) RNN_COPY(mem, &tmp[input_size], layer->nb_inputs-input_size); 123 | } 124 | -------------------------------------------------------------------------------- /src/nnet.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018 Mozilla 2 | Copyright (c) 2017 Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifndef NNET_H_ 29 | #define NNET_H_ 30 | 31 | #include 32 | #include "opus_types.h" 33 | 34 | #define ACTIVATION_LINEAR 0 35 | #define ACTIVATION_SIGMOID 1 36 | #define ACTIVATION_TANH 2 37 | #define ACTIVATION_RELU 3 38 | #define ACTIVATION_SOFTMAX 4 39 | #define ACTIVATION_SWISH 5 40 | 41 | #define WEIGHT_BLOB_VERSION 0 42 | #define WEIGHT_BLOCK_SIZE 64 43 | typedef struct { 44 | const char *name; 45 | int type; 46 | int size; 47 | const void *data; 48 | } WeightArray; 49 | 50 | #define WEIGHT_TYPE_float 0 51 | #define WEIGHT_TYPE_int 1 52 | #define WEIGHT_TYPE_qweight 2 53 | #define WEIGHT_TYPE_int8 3 54 | 55 | typedef struct { 56 | char head[4]; 57 | int version; 58 | int type; 59 | int size; 60 | int block_size; 61 | char name[44]; 62 | } WeightHead; 63 | 64 | /* Generic sparse affine transformation. */ 65 | typedef struct { 66 | const float *bias; 67 | const float *subias; 68 | const opus_int8 *weights; 69 | const float *float_weights; 70 | const int *weights_idx; 71 | const float *diag; 72 | const float *scale; 73 | int nb_inputs; 74 | int nb_outputs; 75 | } LinearLayer; 76 | 77 | /* Generic sparse affine transformation. */ 78 | typedef struct { 79 | const float *bias; 80 | const float *float_weights; 81 | int in_channels; 82 | int out_channels; 83 | int ktime; 84 | int kheight; 85 | } Conv2dLayer; 86 | 87 | 88 | /* Changes some symbol names to add the rnn_ prefix so we don't get conflicts with Opus. 
*/ 89 | #define linear_init rnn_linear_init 90 | #define conv2d_init rnn_conv2d_init 91 | #define compute_generic_dense rnn_compute_generic_dense 92 | #define compute_generic_gru rnn_compute_generic_gru 93 | #define compute_generic_conv1d rnn_compute_generic_conv1d 94 | #define compute_glu rnn_compute_glu 95 | 96 | #define parse_weights rnn_parse_weights 97 | 98 | #define compute_linear_c rnn_compute_linear_c 99 | #define compute_activation_c rnn_compute_activation_c 100 | #define compute_conv2d_c rnn_compute_conv2d_c 101 | #define compute_linear_sse4_1 rnn_compute_linear_sse4_1 102 | #define compute_activation_sse4_1 rnn_compute_activation_sse4_1 103 | #define compute_conv2d_sse4_1 rnn_compute_conv2d_sse4_1 104 | #define compute_linear_avx2 rnn_compute_linear_avx2 105 | #define compute_activation_avx2 rnn_compute_activation_avx2 106 | #define compute_conv2d_avx2 rnn_compute_conv2d_avx2 107 | 108 | 109 | void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation, int arch); 110 | void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in, int arch); 111 | void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation, int arch); 112 | void compute_glu(const LinearLayer *layer, float *output, const float *input, int arch); 113 | 114 | 115 | int parse_weights(WeightArray **list, const void *data, int len); 116 | 117 | 118 | 119 | int linear_init(LinearLayer *layer, const WeightArray *arrays, 120 | const char *bias, 121 | const char *subias, 122 | const char *weights, 123 | const char *float_weights, 124 | const char *weights_idx, 125 | const char *diag, 126 | const char *scale, 127 | int nb_inputs, 128 | int nb_outputs); 129 | 130 | int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays, 131 | const char *bias, 132 | const char *float_weights, 133 | int in_channels, 134 | int 
out_channels, 135 | int ktime, 136 | int kheight); 137 | 138 | 139 | void compute_linear_c(const LinearLayer *linear, float *out, const float *in); 140 | void compute_activation_c(float *output, const float *input, int N, int activation); 141 | void compute_conv2d_c(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation); 142 | 143 | #ifdef RNN_ENABLE_X86_RTCD 144 | #include "x86/dnn_x86.h" 145 | #endif 146 | 147 | #ifndef OVERRIDE_COMPUTE_LINEAR 148 | #define compute_linear(linear, out, in, arch) ((void)(arch),compute_linear_c(linear, out, in)) 149 | #endif 150 | 151 | #ifndef OVERRIDE_COMPUTE_ACTIVATION 152 | #define compute_activation(output, input, N, activation, arch) ((void)(arch),compute_activation_c(output, input, N, activation)) 153 | #endif 154 | 155 | #ifndef OVERRIDE_COMPUTE_CONV2D 156 | #define compute_conv2d(conv, out, mem, in, height, hstride, activation, arch) ((void)(arch),compute_conv2d_c(conv, out, mem, in, height, hstride, activation)) 157 | #endif 158 | 159 | #if defined(__x86_64__) && !defined(RNN_ENABLE_X86_RTCD) && !defined(__AVX2__) 160 | #if defined(_MSC_VER) 161 | #pragma message ("Only SSE and SSE2 are available. On newer machines, enable SSSE3/AVX/AVX2 to get better performance") 162 | #else 163 | #warning "Only SSE and SSE2 are available. 
On newer machines, enable SSSE3/AVX/AVX2 using -march= to get better performance" 164 | #endif 165 | #endif 166 | 167 | 168 | 169 | #endif /* NNET_H_ */ 170 | -------------------------------------------------------------------------------- /src/nnet_arch.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018-2019 Mozilla 2 | 2023 Amazon */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifndef NNET_ARCH_H 29 | #define NNET_ARCH_H 30 | 31 | #include "nnet.h" 32 | #include "arch.h" 33 | #include "common.h" 34 | #include "vec.h" 35 | 36 | #define CAT_SUFFIX2(a,b) a ## b 37 | #define CAT_SUFFIX(a,b) CAT_SUFFIX2(a, b) 38 | 39 | #define RTCD_SUF(name) CAT_SUFFIX(name, RTCD_ARCH) 40 | 41 | # if !defined(OPUS_GNUC_PREREQ) 42 | # if defined(__GNUC__)&&defined(__GNUC_MINOR__) 43 | # define OPUS_GNUC_PREREQ(_maj,_min) \ 44 | ((__GNUC__<<16)+__GNUC_MINOR__>=((_maj)<<16)+(_min)) 45 | # else 46 | # define OPUS_GNUC_PREREQ(_maj,_min) 0 47 | # endif 48 | # endif 49 | 50 | 51 | /* Force vectorization on for DNN code because some of the loops rely on 52 | compiler vectorization rather than explicitly using intrinsics. */ 53 | #if OPUS_GNUC_PREREQ(5,1) 54 | #define GCC_POP_OPTIONS 55 | #pragma GCC push_options 56 | #pragma GCC optimize("tree-vectorize") 57 | #endif 58 | 59 | 60 | #define MAX_ACTIVATIONS (4096) 61 | 62 | static OPUS_INLINE void vec_swish(float *y, const float *x, int N) 63 | { 64 | int i; 65 | float tmp[MAX_ACTIVATIONS]; 66 | celt_assert(N <= MAX_ACTIVATIONS); 67 | vec_sigmoid(tmp, x, N); 68 | for (i=0;ibias; 136 | M = linear->nb_inputs; 137 | N = linear->nb_outputs; 138 | if (linear->float_weights != NULL) { 139 | if (linear->weights_idx != NULL) sparse_sgemv8x4(out, linear->float_weights, linear->weights_idx, N, in); 140 | else sgemv(out, linear->float_weights, N, M, N, in); 141 | } else if (linear->weights != NULL) { 142 | if (linear->weights_idx != NULL) sparse_cgemv8x4(out, linear->weights, linear->weights_idx, linear->scale, N, M, in); 143 | else cgemv8x4(out, linear->weights, linear->scale, N, M, in); 144 | /* Only use SU biases on for integer matrices on SU archs. */ 145 | #ifdef USE_SU_BIAS 146 | bias = linear->subias; 147 | #endif 148 | } 149 | else RNN_CLEAR(out, N); 150 | if (bias != NULL) { 151 | for (i=0;idiag) { 154 | /* Diag is only used for GRU recurrent weights. 
*/ 155 | celt_assert(3*M == N); 156 | for (i=0;idiag[i]*in[i]; 158 | out[i+M] += linear->diag[i+M]*in[i]; 159 | out[i+2*M] += linear->diag[i+2*M]*in[i]; 160 | } 161 | } 162 | } 163 | 164 | /* Computes non-padded convolution for input [ ksize1 x in_channels x (len2+ksize2) ], 165 | kernel [ out_channels x in_channels x ksize1 x ksize2 ], 166 | storing the output as [ out_channels x len2 ]. 167 | We assume that the output dimension along the ksize1 axis is 1, 168 | i.e. processing one frame at a time. */ 169 | static void conv2d_float(float *out, const float *weights, int in_channels, int out_channels, int ktime, int kheight, const float *in, int height, int hstride) 170 | { 171 | int i; 172 | int in_stride; 173 | in_stride = height+kheight-1; 174 | for (i=0;iin_channels*(height+conv->kheight-1); 233 | celt_assert(conv->ktime*time_stride <= MAX_CONV2D_INPUTS); 234 | RNN_COPY(in_buf, mem, (conv->ktime-1)*time_stride); 235 | RNN_COPY(&in_buf[(conv->ktime-1)*time_stride], in, time_stride); 236 | RNN_COPY(mem, &in_buf[time_stride], (conv->ktime-1)*time_stride); 237 | bias = conv->bias; 238 | if (conv->kheight == 3 && conv->ktime == 3) 239 | conv2d_3x3_float(out, conv->float_weights, conv->in_channels, conv->out_channels, in_buf, height, hstride); 240 | else 241 | conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, height, hstride); 242 | if (bias != NULL) { 243 | for (i=0;iout_channels;i++) { 244 | int j; 245 | for (j=0;jout_channels;i++) { 249 | RTCD_SUF(compute_activation_)(&out[i*hstride], &out[i*hstride], height, activation); 250 | } 251 | } 252 | 253 | #ifdef GCC_POP_OPTIONS 254 | #pragma GCC pop_options 255 | #endif 256 | 257 | #endif 258 | -------------------------------------------------------------------------------- /src/nnet_default.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018-2019 Mozilla 2 | 2023 Amazon */ 3 | /* 4 | Redistribution 
and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | 33 | #define RTCD_ARCH c 34 | 35 | #include "nnet_arch.h" 36 | -------------------------------------------------------------------------------- /src/opus_types.h: -------------------------------------------------------------------------------- 1 | /* (C) COPYRIGHT 1994-2002 Xiph.Org Foundation */ 2 | /* Modified by Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | /* opus_types.h based on ogg_types.h from libogg */ 28 | 29 | /** 30 | @file opus_types.h 31 | @brief Opus reference implementation types 32 | */ 33 | #ifndef OPUS_TYPES_H 34 | #define OPUS_TYPES_H 35 | 36 | /* Use the real stdint.h if it's there (taken from Paul Hsieh's pstdint.h) */ 37 | #if (defined(__STDC__) && __STDC__ && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L) || (defined(__GNUC__) && (defined(_STDINT_H) || defined(_STDINT_H_)) || defined (HAVE_STDINT_H)) 38 | #include 39 | 40 | typedef int16_t opus_int16; 41 | typedef uint16_t opus_uint16; 42 | typedef int32_t opus_int32; 43 | typedef uint32_t opus_uint32; 44 | #elif defined(_WIN32) 45 | 46 | # if defined(__CYGWIN__) 47 | # include <_G_config.h> 48 | typedef _G_int32_t opus_int32; 49 | typedef _G_uint32_t opus_uint32; 50 | typedef _G_int16 opus_int16; 51 | typedef _G_uint16 opus_uint16; 52 | # elif defined(__MINGW32__) 53 | typedef short opus_int16; 54 | typedef unsigned short opus_uint16; 55 | typedef int opus_int32; 56 | typedef unsigned int opus_uint32; 57 | # elif defined(__MWERKS__) 58 | typedef int opus_int32; 59 | typedef unsigned int opus_uint32; 60 | typedef short opus_int16; 61 | typedef unsigned short opus_uint16; 62 | # else 63 | /* MSVC/Borland */ 64 | typedef __int32 opus_int32; 65 | typedef unsigned __int32 opus_uint32; 66 | typedef __int16 opus_int16; 67 | typedef unsigned __int16 opus_uint16; 68 | # endif 69 | 70 | #elif defined(__MACOS__) 71 | 72 | # include 73 | typedef SInt16 opus_int16; 74 | typedef UInt16 opus_uint16; 75 | typedef SInt32 opus_int32; 76 | typedef UInt32 opus_uint32; 77 | 78 | #elif (defined(__APPLE__) && defined(__MACH__)) /* MacOS X Framework build */ 79 | 80 | # include 81 | typedef int16_t opus_int16; 82 | typedef u_int16_t opus_uint16; 83 | typedef int32_t opus_int32; 84 | typedef u_int32_t opus_uint32; 85 | 86 | #elif defined(__BEOS__) 87 | 88 | /* Be */ 89 | # include 90 | typedef int16 opus_int16; 91 | typedef u_int16 opus_uint16; 92 
| typedef int32_t opus_int32; 93 | typedef u_int32_t opus_uint32; 94 | 95 | #elif defined (__EMX__) 96 | 97 | /* OS/2 GCC */ 98 | typedef short opus_int16; 99 | typedef unsigned short opus_uint16; 100 | typedef int opus_int32; 101 | typedef unsigned int opus_uint32; 102 | 103 | #elif defined (DJGPP) 104 | 105 | /* DJGPP */ 106 | typedef short opus_int16; 107 | typedef unsigned short opus_uint16; 108 | typedef int opus_int32; 109 | typedef unsigned int opus_uint32; 110 | 111 | #elif defined(R5900) 112 | 113 | /* PS2 EE */ 114 | typedef int opus_int32; 115 | typedef unsigned opus_uint32; 116 | typedef short opus_int16; 117 | typedef unsigned short opus_uint16; 118 | 119 | #elif defined(__SYMBIAN32__) 120 | 121 | /* Symbian GCC */ 122 | typedef signed short opus_int16; 123 | typedef unsigned short opus_uint16; 124 | typedef signed int opus_int32; 125 | typedef unsigned int opus_uint32; 126 | 127 | #elif defined(CONFIG_TI_C54X) || defined (CONFIG_TI_C55X) 128 | 129 | typedef short opus_int16; 130 | typedef unsigned short opus_uint16; 131 | typedef long opus_int32; 132 | typedef unsigned long opus_uint32; 133 | 134 | #elif defined(CONFIG_TI_C6X) 135 | 136 | typedef short opus_int16; 137 | typedef unsigned short opus_uint16; 138 | typedef int opus_int32; 139 | typedef unsigned int opus_uint32; 140 | 141 | #else 142 | 143 | /* Give up, take a reasonable guess */ 144 | typedef short opus_int16; 145 | typedef unsigned short opus_uint16; 146 | typedef int opus_int32; 147 | typedef unsigned int opus_uint32; 148 | 149 | #endif 150 | 151 | #define opus_int int /* used for counters etc; at least 16 bits */ 152 | #define opus_int64 long long 153 | #define opus_int8 signed char 154 | 155 | #define opus_uint unsigned int /* used for counters etc; at least 16 bits */ 156 | #define opus_uint64 unsigned long long 157 | #define opus_uint8 unsigned char 158 | 159 | #endif /* OPUS_TYPES_H */ 160 | -------------------------------------------------------------------------------- 
/src/parse_lpcnet_weights.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Amazon */ 2 | /* 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 7 | - Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | - Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 18 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 19 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 20 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 21 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | */ 26 | 27 | #ifdef HAVE_CONFIG_H 28 | #include "config.h" 29 | #endif 30 | 31 | #include 32 | #include 33 | #include "nnet.h" 34 | 35 | #define SPARSE_BLOCK_SIZE 32 36 | 37 | static int parse_record(const void **data, int *len, WeightArray *array) { 38 | WeightHead *h = (WeightHead *)*data; 39 | if (*len < WEIGHT_BLOCK_SIZE) return -1; 40 | if (h->block_size < h->size) return -1; 41 | if (h->block_size > *len-WEIGHT_BLOCK_SIZE) return -1; 42 | if (h->name[sizeof(h->name)-1] != 0) return -1; 43 | if (h->size < 0) return -1; 44 | array->name = h->name; 45 | array->type = h->type; 46 | array->size = h->size; 47 | array->data = (void*)((unsigned char*)(*data)+WEIGHT_BLOCK_SIZE); 48 | 49 | *data = (void*)((unsigned char*)*data + h->block_size+WEIGHT_BLOCK_SIZE); 50 | *len -= h->block_size+WEIGHT_BLOCK_SIZE; 51 | return array->size; 52 | } 53 | 54 | int parse_weights(WeightArray **list, const void *data, int len) 55 | { 56 | int nb_arrays=0; 57 | int capacity=20; 58 | *list = calloc(capacity*sizeof(WeightArray), 1); 59 | while (len > 0) { 60 | int ret; 61 | WeightArray array = {NULL, 0, 0, 0}; 62 | ret = parse_record(&data, &len, &array); 63 | if (ret > 0) { 64 | if (nb_arrays+1 >= capacity) { 65 | /* Make sure there's room for the ending NULL element too. 
*/ 66 | capacity = capacity*3/2; 67 | *list = realloc(*list, capacity*sizeof(WeightArray)); 68 | } 69 | (*list)[nb_arrays++] = array; 70 | } else { 71 | free(*list); 72 | *list = NULL; 73 | return -1; 74 | } 75 | } 76 | (*list)[nb_arrays].name=NULL; 77 | return nb_arrays; 78 | } 79 | 80 | static const void *find_array_entry(const WeightArray *arrays, const char *name) { 81 | while (arrays->name && strcmp(arrays->name, name) != 0) arrays++; 82 | return arrays; 83 | } 84 | 85 | static const void *find_array_check(const WeightArray *arrays, const char *name, int size) { 86 | const WeightArray *a = find_array_entry(arrays, name); 87 | if (a->name && a->size == size) return a->data; 88 | else return NULL; 89 | } 90 | 91 | static const void *opt_array_check(const WeightArray *arrays, const char *name, int size, int *error) { 92 | const WeightArray *a = find_array_entry(arrays, name); 93 | *error = (a->name != NULL && a->size != size); 94 | if (a->name && a->size == size) return a->data; 95 | else return NULL; 96 | } 97 | 98 | static const void *find_idx_check(const WeightArray *arrays, const char *name, int nb_in, int nb_out, int *total_blocks) { 99 | int remain; 100 | const int *idx; 101 | const WeightArray *a = find_array_entry(arrays, name); 102 | *total_blocks = 0; 103 | if (a == NULL) return NULL; 104 | idx = a->data; 105 | remain = a->size/sizeof(int); 106 | while (remain > 0) { 107 | int nb_blocks; 108 | int i; 109 | nb_blocks = *idx++; 110 | if (remain < nb_blocks+1) return NULL; 111 | for (i=0;i= nb_in || (pos&0x3)) return NULL; 114 | } 115 | nb_out -= 8; 116 | remain -= nb_blocks+1; 117 | *total_blocks += nb_blocks; 118 | } 119 | if (nb_out != 0) return NULL; 120 | return a->data; 121 | } 122 | 123 | int linear_init(LinearLayer *layer, const WeightArray *arrays, 124 | const char *bias, 125 | const char *subias, 126 | const char *weights, 127 | const char *float_weights, 128 | const char *weights_idx, 129 | const char *diag, 130 | const char *scale, 131 | int 
nb_inputs, 132 | int nb_outputs) 133 | { 134 | int err; 135 | layer->bias = NULL; 136 | layer->subias = NULL; 137 | layer->weights = NULL; 138 | layer->float_weights = NULL; 139 | layer->weights_idx = NULL; 140 | layer->diag = NULL; 141 | layer->scale = NULL; 142 | if (bias != NULL) { 143 | if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1; 144 | } 145 | if (subias != NULL) { 146 | if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1; 147 | } 148 | if (weights_idx != NULL) { 149 | int total_blocks; 150 | if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_inputs, nb_outputs, &total_blocks)) == NULL) return 1; 151 | if (weights != NULL) { 152 | if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1; 153 | } 154 | if (float_weights != NULL) { 155 | layer->float_weights = opt_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]), &err); 156 | if (err) return 1; 157 | } 158 | } else { 159 | if (weights != NULL) { 160 | if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1; 161 | } 162 | if (float_weights != NULL) { 163 | layer->float_weights = opt_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]), &err); 164 | if (err) return 1; 165 | } 166 | } 167 | if (diag != NULL) { 168 | if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1; 169 | } 170 | if (weights != NULL) { 171 | if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1; 172 | } 173 | layer->nb_inputs = nb_inputs; 174 | layer->nb_outputs = nb_outputs; 175 | return 0; 176 | } 177 | 178 | int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays, 179 | const char 
*bias, 180 | const char *float_weights, 181 | int in_channels, 182 | int out_channels, 183 | int ktime, 184 | int kheight) 185 | { 186 | int err; 187 | layer->bias = NULL; 188 | layer->float_weights = NULL; 189 | if (bias != NULL) { 190 | if ((layer->bias = find_array_check(arrays, bias, out_channels*sizeof(layer->bias[0]))) == NULL) return 1; 191 | } 192 | if (float_weights != NULL) { 193 | layer->float_weights = opt_array_check(arrays, float_weights, in_channels*out_channels*ktime*kheight*sizeof(layer->float_weights[0]), &err); 194 | if (err) return 1; 195 | } 196 | layer->in_channels = in_channels; 197 | layer->out_channels = out_channels; 198 | layer->ktime = ktime; 199 | layer->kheight = kheight; 200 | return 0; 201 | } 202 | 203 | 204 | 205 | #if 0 206 | #include 207 | #include 208 | #include 209 | #include 210 | #include 211 | 212 | int main() 213 | { 214 | int fd; 215 | void *data; 216 | int len; 217 | int nb_arrays; 218 | int i; 219 | WeightArray *list; 220 | struct stat st; 221 | const char *filename = "weights_blob.bin"; 222 | stat(filename, &st); 223 | len = st.st_size; 224 | fd = open(filename, O_RDONLY); 225 | data = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0); 226 | printf("size is %d\n", len); 227 | nb_arrays = parse_weights(&list, data, len); 228 | for (i=0;i=3); 56 | y_3=0; /* gcc doesn't realize that y_3 can't be used uninitialized */ 57 | y_0=*y++; 58 | y_1=*y++; 59 | y_2=*y++; 60 | for (j=0;j 33 | #include "opus_types.h" 34 | #include "common.h" 35 | #include "arch.h" 36 | #include "rnn.h" 37 | #include "rnnoise_data.h" 38 | #include 39 | 40 | 41 | #define INPUT_SIZE 42 42 | 43 | 44 | void compute_rnn(const RNNoise *model, RNNState *rnn, float *gains, float *vad, const float *input, int arch) { 45 | float tmp[MAX_NEURONS]; 46 | float cat[CONV2_OUT_SIZE + GRU1_OUT_SIZE + GRU2_OUT_SIZE + GRU3_OUT_SIZE]; 47 | /*for (int i=0;iconv1, tmp, rnn->conv1_state, input, CONV1_IN_SIZE, ACTIVATION_TANH, arch); 49 | compute_generic_conv1d(&model->conv2, 
cat, rnn->conv2_state, tmp, CONV2_IN_SIZE, ACTIVATION_TANH, arch); 50 | compute_generic_gru(&model->gru1_input, &model->gru1_recurrent, rnn->gru1_state, cat, arch); 51 | compute_generic_gru(&model->gru2_input, &model->gru2_recurrent, rnn->gru2_state, rnn->gru1_state, arch); 52 | compute_generic_gru(&model->gru3_input, &model->gru3_recurrent, rnn->gru3_state, rnn->gru2_state, arch); 53 | RNN_COPY(&cat[CONV2_OUT_SIZE], rnn->gru1_state, GRU1_OUT_SIZE); 54 | RNN_COPY(&cat[CONV2_OUT_SIZE+GRU1_OUT_SIZE], rnn->gru2_state, GRU2_OUT_SIZE); 55 | RNN_COPY(&cat[CONV2_OUT_SIZE+GRU1_OUT_SIZE+GRU2_OUT_SIZE], rnn->gru3_state, GRU3_OUT_SIZE); 56 | compute_generic_dense(&model->dense_out, gains, cat, ACTIVATION_SIGMOID, arch); 57 | compute_generic_dense(&model->vad_dense, vad, cat, ACTIVATION_SIGMOID, arch); 58 | /*for (int i=0;i<22;i++) printf("%f ", gains[i]);printf("\n");*/ 59 | /*printf("%f\n", *vad);*/ 60 | } 61 | -------------------------------------------------------------------------------- /src/rnn.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2017 Jean-Marc Valin */ 2 | /* 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 7 | - Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | - Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE FOUNDATION OR 18 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 19 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 20 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 21 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | */ 26 | 27 | #ifndef RNN_H_ 28 | #define RNN_H_ 29 | 30 | #include "rnnoise.h" 31 | #include "rnnoise_data.h" 32 | 33 | #include "opus_types.h" 34 | 35 | #define WEIGHTS_SCALE (1.f/256) 36 | 37 | #define MAX_NEURONS 1024 38 | 39 | 40 | typedef struct { 41 | float conv1_state[CONV1_STATE_SIZE]; 42 | float conv2_state[CONV2_STATE_SIZE]; 43 | float gru1_state[GRU1_STATE_SIZE]; 44 | float gru2_state[GRU2_STATE_SIZE]; 45 | float gru3_state[GRU3_STATE_SIZE]; 46 | } RNNState; 47 | void compute_rnn(const RNNoise *model, RNNState *rnn, float *gains, float *vad, const float *input, int arch); 48 | 49 | #endif /* RNN_H_ */ 50 | -------------------------------------------------------------------------------- /src/rnn_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import print_function 4 | 5 | from keras.models import Sequential 6 | from keras.models import Model 7 | from keras.layers import Input 8 | from keras.layers import Dense 9 | from keras.layers import LSTM 10 | from keras.layers import GRU 11 | from keras.layers import SimpleRNN 12 | from keras.layers import Dropout 13 | from keras import losses 14 | import h5py 15 | 16 | from keras import backend as K 17 | import numpy as np 18 | 19 | print('Build model...') 20 | main_input = Input(shape=(None, 22), name='main_input') 21 | #x = Dense(44, activation='relu')(main_input) 22 | #x = GRU(44, 
dropout=0.0, recurrent_dropout=0.0, activation='tanh', recurrent_activation='sigmoid', return_sequences=True)(x) 23 | x=main_input 24 | x = GRU(128, activation='tanh', recurrent_activation='sigmoid', return_sequences=True)(x) 25 | #x = GRU(128, return_sequences=True)(x) 26 | #x = GRU(22, activation='relu', return_sequences=True)(x) 27 | x = Dense(22, activation='sigmoid')(x) 28 | #x = Dense(22, activation='softplus')(x) 29 | model = Model(inputs=main_input, outputs=x) 30 | 31 | batch_size = 32 32 | 33 | print('Loading data...') 34 | with h5py.File('denoise_data.h5', 'r') as hf: 35 | all_data = hf['denoise_data'][:] 36 | print('done.') 37 | 38 | window_size = 500 39 | 40 | nb_sequences = len(all_data)//window_size 41 | print(nb_sequences, ' sequences') 42 | x_train = all_data[:nb_sequences*window_size, :-22] 43 | x_train = np.reshape(x_train, (nb_sequences, window_size, 22)) 44 | 45 | y_train = np.copy(all_data[:nb_sequences*window_size, -22:]) 46 | y_train = np.reshape(y_train, (nb_sequences, window_size, 22)) 47 | 48 | #y_train = -20*np.log10(np.add(y_train, .03)); 49 | 50 | all_data = 0; 51 | x_train = x_train.astype('float32') 52 | y_train = y_train.astype('float32') 53 | 54 | print(len(x_train), 'train sequences. 
x shape =', x_train.shape, 'y shape = ', y_train.shape) 55 | 56 | # try using different optimizers and different optimizer configs 57 | model.compile(loss='mean_squared_error', 58 | optimizer='adam', 59 | metrics=['binary_accuracy']) 60 | 61 | print('Train...') 62 | model.fit(x_train, y_train, 63 | batch_size=batch_size, 64 | epochs=200, 65 | validation_data=(x_train, y_train)) 66 | model.save("newweights.hdf5") 67 | -------------------------------------------------------------------------------- /src/write_weights.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Amazon */ 2 | /* 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 7 | - Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | - Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 18 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 19 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 20 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 21 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | */ 26 | 27 | #ifdef HAVE_CONFIG_H 28 | #include "config.h" 29 | #endif 30 | 31 | #include 32 | #include 33 | #include 34 | #include "nnet.h" 35 | #include "arch.h" 36 | #include "nnet.h" 37 | 38 | /* This is a bit of a hack because we need to build nnet_data.c and plc_data.c without USE_WEIGHTS_FILE, 39 | but USE_WEIGHTS_FILE is defined in config.h. */ 40 | #undef HAVE_CONFIG_H 41 | #ifdef USE_WEIGHTS_FILE 42 | #undef USE_WEIGHTS_FILE 43 | #endif 44 | #include "rnnoise_data.c" 45 | 46 | void write_weights(const WeightArray *list, FILE *fout) 47 | { 48 | int i=0; 49 | unsigned char zeros[WEIGHT_BLOCK_SIZE] = {0}; 50 | while (list[i].name != NULL) { 51 | WeightHead h; 52 | if (strlen(list[i].name) >= sizeof(h.name) - 1) { 53 | printf("[write_weights] warning: name %s too long\n", list[i].name); 54 | } 55 | memcpy(h.head, "DNNw", 4); 56 | h.version = WEIGHT_BLOB_VERSION; 57 | h.type = list[i].type; 58 | h.size = list[i].size; 59 | h.block_size = (h.size+WEIGHT_BLOCK_SIZE-1)/WEIGHT_BLOCK_SIZE*WEIGHT_BLOCK_SIZE; 60 | RNN_CLEAR(h.name, sizeof(h.name)); 61 | strncpy(h.name, list[i].name, sizeof(h.name)); 62 | h.name[sizeof(h.name)-1] = 0; 63 | celt_assert(sizeof(h) == WEIGHT_BLOCK_SIZE); 64 | fwrite(&h, 1, WEIGHT_BLOCK_SIZE, fout); 65 | fwrite(list[i].data, 1, h.size, fout); 66 | fwrite(zeros, 1, h.block_size-h.size, fout); 67 | i++; 68 | } 69 | } 70 | 71 | int main(void) 72 | { 73 | FILE *fout = fopen("weights_blob.bin", "w"); 74 | write_weights(rnnoise_arrays, fout); 75 | fclose(fout); 76 | return 0; 77 | } 78 | -------------------------------------------------------------------------------- /src/x86/dnn_x86.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2011-2019 Mozilla 2 | 2023 Amazon */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the 
above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifndef DNN_X86_H 29 | #define DNN_X86_H 30 | 31 | #include "cpu_support.h" 32 | #include "opus_types.h" 33 | 34 | void compute_linear_sse4_1(const LinearLayer *linear, float *out, const float *in); 35 | void compute_activation_sse4_1(float *output, const float *input, int N, int activation); 36 | void compute_conv2d_sse4_1(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation); 37 | 38 | void compute_linear_avx2(const LinearLayer *linear, float *out, const float *in); 39 | void compute_activation_avx2(float *output, const float *input, int N, int activation); 40 | void compute_conv2d_avx2(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation); 41 | 42 | 43 | 44 | #ifdef RNN_ENABLE_X86_RTCD 45 | 46 | extern void (*const RNN_COMPUTE_LINEAR_IMPL[OPUS_ARCHMASK + 1])( 47 | const LinearLayer *linear, 48 | float *out, 49 | const float *in 50 | ); 51 | #define OVERRIDE_COMPUTE_LINEAR 52 | #define compute_linear(linear, out, in, arch) \ 53 | ((*RNN_COMPUTE_LINEAR_IMPL[(arch) & OPUS_ARCHMASK])(linear, out, in)) 54 | 55 | 56 | extern void (*const RNN_COMPUTE_ACTIVATION_IMPL[OPUS_ARCHMASK + 1])( 57 | float *output, 58 | const float *input, 59 | int N, 60 | int activation 61 | ); 62 | #define OVERRIDE_COMPUTE_ACTIVATION 63 | #define compute_activation(output, input, N, activation, arch) \ 64 | ((*RNN_COMPUTE_ACTIVATION_IMPL[(arch) & OPUS_ARCHMASK])(output, input, N, activation)) 65 | 66 | 67 | extern void (*const RNN_COMPUTE_CONV2D_IMPL[OPUS_ARCHMASK + 1])( 68 | const Conv2dLayer *conv, 69 | float *out, 70 | float *mem, 71 | const float *in, 72 | int height, 73 | int hstride, 74 | int activation 75 | ); 76 | #define OVERRIDE_COMPUTE_CONV2D 77 | #define compute_conv2d(conv, out, mem, in, height, hstride, activation, arch) \ 78 | ((*RNN_COMPUTE_CONV2D_IMPL[(arch) & OPUS_ARCHMASK])(conv, out, mem, in, height, hstride, activation)) 79 | 80 | 81 | #endif 82 | 83 
| 84 | 85 | #endif /* DNN_X86_H */ 86 | -------------------------------------------------------------------------------- /src/x86/nnet_avx2.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018-2019 Mozilla 2 | 2023 Amazon */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include "x86/x86_arch_macros.h" 33 | 34 | #ifndef __AVX2__ 35 | #error nnet_avx2.c is being compiled without AVX2 enabled 36 | #endif 37 | 38 | #define RTCD_ARCH avx2 39 | 40 | #include "nnet_arch.h" 41 | -------------------------------------------------------------------------------- /src/x86/nnet_sse4_1.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018-2019 Mozilla 2 | 2023 Amazon */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include "x86/x86_arch_macros.h" 33 | 34 | #ifndef __SSE4_1__ 35 | #error nnet_sse4_1.c is being compiled without SSE4.1 enabled 36 | #endif 37 | 38 | #define RTCD_ARCH sse4_1 39 | 40 | #include "nnet_arch.h" 41 | -------------------------------------------------------------------------------- /src/x86/x86_arch_macros.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2023 Amazon */ 2 | /* 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 7 | - Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | 10 | - Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 18 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 19 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 20 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 21 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 22 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 23 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 | */ 26 | 27 | #ifdef _MSC_VER 28 | 29 | # ifdef OPUS_X86_MAY_HAVE_SSE 30 | # ifndef __SSE__ 31 | # define __SSE__ 32 | # endif 33 | # endif 34 | 35 | # ifdef OPUS_X86_MAY_HAVE_SSE2 36 | # ifndef __SSE2__ 37 | # define __SSE2__ 38 | # endif 39 | # endif 40 | 41 | # ifdef OPUS_X86_MAY_HAVE_SSE4_1 42 | # ifndef __SSE4_1__ 43 | # define __SSE4_1__ 44 | # endif 45 | # endif 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /src/x86/x86_dnn_map.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018-2019 Mozilla 2 | 2023 Amazon */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include "x86/x86cpu.h" 33 | #include "nnet.h" 34 | 35 | #ifdef RNN_ENABLE_X86_RTCD 36 | 37 | 38 | void (*const RNN_COMPUTE_LINEAR_IMPL[OPUS_ARCHMASK + 1])( 39 | const LinearLayer *linear, 40 | float *out, 41 | const float *in 42 | ) = { 43 | compute_linear_c, /* non-sse */ 44 | MAY_HAVE_SSE4_1(compute_linear), /* sse4.1 */ 45 | MAY_HAVE_AVX2(compute_linear) /* avx */ 46 | }; 47 | 48 | void (*const RNN_COMPUTE_ACTIVATION_IMPL[OPUS_ARCHMASK + 1])( 49 | float *output, 50 | const float *input, 51 | int N, 52 | int activation 53 | ) = { 54 | compute_activation_c, /* non-sse */ 55 | MAY_HAVE_SSE4_1(compute_activation), /* sse4.1 */ 56 | MAY_HAVE_AVX2(compute_activation) /* avx */ 57 | }; 58 | 59 | void (*const RNN_COMPUTE_CONV2D_IMPL[OPUS_ARCHMASK + 1])( 60 | const Conv2dLayer *conv, 61 | float *out, 62 | float *mem, 63 | const float *in, 64 | int height, 65 | int hstride, 66 | int activation 67 | ) = { 68 | compute_conv2d_c, /* non-sse */ 69 | MAY_HAVE_SSE4_1(compute_conv2d), /* sse4.1 */ 70 | MAY_HAVE_AVX2(compute_conv2d) /* avx */ 71 | }; 72 | 73 | 74 | #endif 75 | -------------------------------------------------------------------------------- /src/x86/x86cpu.c: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2014, Cisco Systems, INC 2 | Written by XiangMingZhu WeiZhou MinPeng YanWang 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #ifdef HAVE_CONFIG_H 29 | #include "config.h" 30 | #endif 31 | 32 | #include "cpu_support.h" 33 | #include "pitch.h" 34 | #include "x86cpu.h" 35 | 36 | #ifdef RNN_ENABLE_X86_RTCD 37 | 38 | #if defined(_MSC_VER) 39 | 40 | #include 41 | static _inline void cpuid(unsigned int CPUInfo[4], unsigned int InfoType) 42 | { 43 | __cpuid((int*)CPUInfo, InfoType); 44 | } 45 | 46 | #else 47 | 48 | #if defined(CPU_INFO_BY_C) 49 | #include 50 | #endif 51 | 52 | static void cpuid(unsigned int CPUInfo[4], unsigned int InfoType) 53 | { 54 | #if defined(CPU_INFO_BY_ASM) 55 | #if defined(__i386__) && defined(__PIC__) 56 | /* %ebx is PIC register in 32-bit, so mustn't clobber it. */ 57 | __asm__ __volatile__ ( 58 | "xchg %%ebx, %1\n" 59 | "cpuid\n" 60 | "xchg %%ebx, %1\n": 61 | "=a" (CPUInfo[0]), 62 | "=r" (CPUInfo[1]), 63 | "=c" (CPUInfo[2]), 64 | "=d" (CPUInfo[3]) : 65 | /* We clear ECX to avoid a valgrind false-positive prior to v3.17.0. 
*/ 66 | "0" (InfoType), "2" (0) 67 | ); 68 | #else 69 | __asm__ __volatile__ ( 70 | "cpuid": 71 | "=a" (CPUInfo[0]), 72 | "=b" (CPUInfo[1]), 73 | "=c" (CPUInfo[2]), 74 | "=d" (CPUInfo[3]) : 75 | /* We clear ECX to avoid a valgrind false-positive prior to v3.17.0. */ 76 | "0" (InfoType), "2" (0) 77 | ); 78 | #endif 79 | #elif defined(CPU_INFO_BY_C) 80 | /* We use __get_cpuid_count to clear ECX to avoid a valgrind false-positive 81 | prior to v3.17.0.*/ 82 | if (!__get_cpuid_count(InfoType, 0, &(CPUInfo[0]), &(CPUInfo[1]), &(CPUInfo[2]), &(CPUInfo[3]))) { 83 | /* Our function cannot fail, but __get_cpuid{_count} can. 84 | Returning all zeroes will effectively disable all SIMD, which is 85 | what we want on CPUs that don't support CPUID. */ 86 | CPUInfo[3] = CPUInfo[2] = CPUInfo[1] = CPUInfo[0] = 0; 87 | } 88 | #else 89 | # error "Configured to use x86 RTCD, but no CPU detection method available. " \ 90 | "Reconfigure with --disable-rtcd (or send patches)." 91 | #endif 92 | } 93 | 94 | #endif 95 | 96 | typedef struct CPU_Feature{ 97 | /* SIMD: 128-bit */ 98 | int HW_SSE; 99 | int HW_SSE2; 100 | int HW_SSE41; 101 | /* SIMD: 256-bit */ 102 | int HW_AVX2; 103 | } CPU_Feature; 104 | 105 | static void rnn_cpu_feature_check(CPU_Feature *cpu_feature) 106 | { 107 | unsigned int info[4]; 108 | unsigned int nIds = 0; 109 | 110 | cpuid(info, 0); 111 | nIds = info[0]; 112 | 113 | if (nIds >= 1){ 114 | cpuid(info, 1); 115 | cpu_feature->HW_SSE = (info[3] & (1 << 25)) != 0; 116 | cpu_feature->HW_SSE2 = (info[3] & (1 << 26)) != 0; 117 | cpu_feature->HW_SSE41 = (info[2] & (1 << 19)) != 0; 118 | cpu_feature->HW_AVX2 = (info[2] & (1 << 28)) != 0 && (info[2] & (1 << 12)) != 0; 119 | if (cpu_feature->HW_AVX2 && nIds >= 7) { 120 | cpuid(info, 7); 121 | cpu_feature->HW_AVX2 = cpu_feature->HW_AVX2 && (info[1] & (1 << 5)) != 0; 122 | } else { 123 | cpu_feature->HW_AVX2 = 0; 124 | } 125 | } 126 | else { 127 | cpu_feature->HW_SSE = 0; 128 | cpu_feature->HW_SSE2 = 0; 129 | cpu_feature->HW_SSE41 
= 0; 130 | cpu_feature->HW_AVX2 = 0; 131 | } 132 | } 133 | 134 | static int rnn_select_arch_impl(void) 135 | { 136 | CPU_Feature cpu_feature; 137 | int arch; 138 | 139 | rnn_cpu_feature_check(&cpu_feature); 140 | 141 | arch = 0; 142 | if (!cpu_feature.HW_SSE41) 143 | { 144 | return arch; 145 | } 146 | arch++; 147 | 148 | if (!cpu_feature.HW_AVX2) 149 | { 150 | return arch; 151 | } 152 | arch++; 153 | 154 | return arch; 155 | } 156 | 157 | int rnn_select_arch(void) { 158 | int arch = rnn_select_arch_impl(); 159 | #ifdef FUZZING 160 | /* Randomly downgrade the architecture. */ 161 | arch = rand()%(arch+1); 162 | #endif 163 | return arch; 164 | } 165 | 166 | #endif 167 | -------------------------------------------------------------------------------- /src/x86/x86cpu.h: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2014, Cisco Systems, INC 2 | Written by XiangMingZhu WeiZhou MinPeng YanWang 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | 28 | #if !defined(X86CPU_H) 29 | # define X86CPU_H 30 | 31 | # define MAY_HAVE_SSE4_1(name) name ## _sse4_1 32 | 33 | # define MAY_HAVE_AVX2(name) name ## _avx2 34 | 35 | # ifdef RNN_ENABLE_X86_RTCD 36 | int opus_select_arch(void); 37 | # endif 38 | 39 | # if defined(__SSE2__) 40 | # include "common.h" 41 | 42 | /*MOVD should not impose any alignment restrictions, but the C standard does, 43 | and UBSan will report errors if we actually make unaligned accesses. 44 | Use this to work around those restrictions (which should hopefully all get 45 | optimized to a single MOVD instruction). 46 | GCC implemented _mm_loadu_si32() since GCC 11; HOWEVER, there is a bug! 47 | https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99754 48 | LLVM implemented _mm_loadu_si32() since Clang 8.0, however the 49 | __clang_major__ version number macro is unreliable, as vendors 50 | (specifically, Apple) will use different numbering schemes than upstream. 51 | Clang's advice is "use feature detection", but they do not provide feature 52 | detection support for specific SIMD functions. 
53 | We follow the approach from the SIMDe project and instead detect unrelated 54 | features that should be available in the version we want (see 55 | ).*/ 56 | # if defined(__clang__) 57 | # if __has_warning("-Wextra-semi-stmt") || \ 58 | __has_builtin(__builtin_rotateleft32) 59 | # define OPUS_CLANG_8 (1) 60 | # endif 61 | # endif 62 | # if !defined(_MSC_VER) && !OPUS_GNUC_PREREQ(11,3) && !defined(OPUS_CLANG_8) 63 | # include 64 | # include 65 | 66 | # ifdef _mm_loadu_si32 67 | # undef _mm_loadu_si32 68 | # endif 69 | # define _mm_loadu_si32 WORKAROUND_mm_loadu_si32 70 | static inline __m128i WORKAROUND_mm_loadu_si32(void const* mem_addr) { 71 | int val; 72 | memcpy(&val, mem_addr, sizeof(val)); 73 | return _mm_cvtsi32_si128(val); 74 | } 75 | # elif defined(_MSC_VER) 76 | /* MSVC needs this for _mm_loadu_si32 */ 77 | # include 78 | # endif 79 | 80 | # define OP_CVTEPI8_EPI32_M32(x) \ 81 | (_mm_cvtepi8_epi32(_mm_loadu_si32(x))) 82 | 83 | # define OP_CVTEPI16_EPI32_M64(x) \ 84 | (_mm_cvtepi16_epi32(_mm_loadl_epi64((__m128i *)(void*)(x)))) 85 | 86 | # endif 87 | 88 | #endif 89 | -------------------------------------------------------------------------------- /torch/rnnoise/dump_rnnoise_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | sys.path.append(os.path.join(os.path.split(__file__)[0], '../weight-exchange')) 10 | import wexchange.torch 11 | 12 | import rnnoise 13 | #from models import model_dict 14 | 15 | unquantized = [ 'conv1', 'dense_out', 'vad_dense' ] 16 | 17 | description=f""" 18 | This is an unsafe dumping script for RNNoise models. It assumes that all weights are included in Linear, Conv1d or GRU layer 19 | and will fail to export any other weights. 20 | 21 | Furthermore, the quanitze option relies on the following explicit list of layers to be excluded: 22 | {unquantized}. 
23 | 24 | Modify this script manually if adjustments are needed. 25 | """ 26 | 27 | parser = argparse.ArgumentParser(description=description) 28 | parser.add_argument('weightfile', type=str, help='weight file path') 29 | parser.add_argument('export_folder', type=str) 30 | parser.add_argument('--export-filename', type=str, default='rnnoise_data', help='filename for source and header file (.c and .h will be added), defaults to rnnoise_data') 31 | parser.add_argument('--struct-name', type=str, default='RNNoise', help='name for C struct, defaults to RNNoise') 32 | parser.add_argument('--quantize', action='store_true', help='apply quantization') 33 | 34 | if __name__ == "__main__": 35 | args = parser.parse_args() 36 | 37 | print(f"loading weights from {args.weightfile}...") 38 | saved_gen= torch.load(args.weightfile, map_location='cpu') 39 | saved_gen['model_args'] = () 40 | #saved_gen['model_kwargs'] = {'cond_size': 256, 'gamma': 0.9} 41 | 42 | model = rnnoise.RNNoise(*saved_gen['model_args'], **saved_gen['model_kwargs']) 43 | model.load_state_dict(saved_gen['state_dict'], strict=False) 44 | def _remove_weight_norm(m): 45 | try: 46 | torch.nn.utils.remove_weight_norm(m) 47 | except ValueError: # this module didn't have weight norm 48 | return 49 | model.apply(_remove_weight_norm) 50 | 51 | 52 | print("dumping model...") 53 | quantize_model=args.quantize 54 | 55 | output_folder = args.export_folder 56 | os.makedirs(output_folder, exist_ok=True) 57 | 58 | writer = wexchange.c_export.c_writer.CWriter(os.path.join(output_folder, args.export_filename), model_struct_name=args.struct_name, add_typedef=True) 59 | 60 | for name, module in model.named_modules(): 61 | 62 | if quantize_model: 63 | quantize=name not in unquantized 64 | scale = None if quantize else 1/128 65 | else: 66 | quantize=False 67 | scale=1/128 68 | 69 | if isinstance(module, nn.Linear): 70 | print(f"dumping linear layer {name}...") 71 | wexchange.torch.dump_torch_dense_weights(writer, module, 
name.replace('.', '_'), quantize=quantize, scale=scale) 72 | 73 | elif isinstance(module, nn.Conv1d): 74 | print(f"dumping conv1d layer {name}...") 75 | wexchange.torch.dump_torch_conv1d_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale) 76 | 77 | elif isinstance(module, nn.GRU): 78 | print(f"dumping GRU layer {name}...") 79 | wexchange.torch.dump_torch_gru_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale, input_sparse=True, recurrent_sparse=True) 80 | 81 | elif isinstance(module, nn.GRUCell): 82 | print(f"dumping GRUCell layer {name}...") 83 | wexchange.torch.dump_torch_grucell_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale, recurrent_scale=scale) 84 | 85 | elif isinstance(module, nn.Embedding): 86 | print(f"dumping Embedding layer {name}...") 87 | wexchange.torch.dump_torch_embedding_weights(writer, module, name.replace('.', '_'), quantize=quantize, scale=scale) 88 | #wexchange.torch.dump_torch_embedding_weights(writer, module) 89 | 90 | else: 91 | print(f"Ignoring layer {name}...") 92 | 93 | writer.close() 94 | -------------------------------------------------------------------------------- /torch/rnnoise/rnnoise.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2024 Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 
14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | */ 27 | """ 28 | 29 | import torch 30 | from torch import nn 31 | import torch.nn.functional as F 32 | import sys 33 | import os 34 | 35 | sys.path.append(os.path.join(os.path.split(__file__)[0], '..')) 36 | from sparsification import GRUSparsifier 37 | 38 | sparsify_start = 6000 39 | sparsify_stop = 20000 40 | sparsify_interval = 100 41 | sparsify_exponent = 3 42 | 43 | sparse_params1 = { 44 | 'W_hr' : (0.3, [8, 4], True), 45 | 'W_hz' : (0.2, [8, 4], True), 46 | 'W_hn' : (0.5, [8, 4], True), 47 | 'W_ir' : (0.3, [8, 4], False), 48 | 'W_iz' : (0.2, [8, 4], False), 49 | 'W_in' : (0.5, [8, 4], False) 50 | } 51 | 52 | def init_weights(module): 53 | if isinstance(module, nn.GRU): 54 | for p in module.named_parameters(): 55 | if p[0].startswith('weight_hh_'): 56 | nn.init.orthogonal_(p[1]) 57 | 58 | class RNNoise(nn.Module): 59 | def __init__(self, input_dim=65, output_dim=32, cond_size=128, gru_size=256): 60 | super(RNNoise, self).__init__() 61 | 62 | self.input_dim = input_dim 63 | self.output_dim = output_dim 64 | self.cond_size = cond_size 65 | self.gru_size = gru_size 66 | self.conv1 = nn.Conv1d(input_dim, cond_size, kernel_size=3, padding='valid') 67 | 
self.conv2 = nn.Conv1d(cond_size, gru_size, kernel_size=3, padding='valid') 68 | self.gru1 = nn.GRU(self.gru_size, self.gru_size, batch_first=True) 69 | self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True) 70 | self.gru3 = nn.GRU(self.gru_size, self.gru_size, batch_first=True) 71 | self.dense_out = nn.Linear(4*self.gru_size, self.output_dim) 72 | self.vad_dense = nn.Linear(4*self.gru_size, 1) 73 | nb_params = sum(p.numel() for p in self.parameters()) 74 | print(f"model: {nb_params} weights") 75 | self.apply(init_weights) 76 | self.sparsifier = [] 77 | self.sparsifier.append(GRUSparsifier([(self.gru1, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) 78 | self.sparsifier.append(GRUSparsifier([(self.gru2, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) 79 | self.sparsifier.append(GRUSparsifier([(self.gru3, sparse_params1)], sparsify_start, sparsify_stop, sparsify_interval, sparsify_exponent)) 80 | 81 | 82 | def sparsify(self): 83 | for sparsifier in self.sparsifier: 84 | sparsifier.step() 85 | 86 | def forward(self, features, states=None): 87 | #print(states) 88 | device = features.device 89 | batch_size = features.size(0) 90 | if states is None: 91 | gru1_state = torch.zeros((1, batch_size, self.gru_size), device=device) 92 | gru2_state = torch.zeros((1, batch_size, self.gru_size), device=device) 93 | gru3_state = torch.zeros((1, batch_size, self.gru_size), device=device) 94 | else: 95 | gru1_state = states[0] 96 | gru2_state = states[1] 97 | gru3_state = states[2] 98 | tmp = features.permute(0, 2, 1) 99 | tmp = torch.tanh(self.conv1(tmp)) 100 | tmp = torch.tanh(self.conv2(tmp)) 101 | tmp = tmp.permute(0, 2, 1) 102 | 103 | gru1_out, gru1_state = self.gru1(tmp, gru1_state) 104 | gru2_out, gru2_state = self.gru2(gru1_out, gru2_state) 105 | gru3_out, gru3_state = self.gru3(gru2_out, gru3_state) 106 | out_cat = torch.cat([tmp, gru1_out, gru2_out, gru3_out], dim=-1) 107 | gain = 
torch.sigmoid(self.dense_out(out_cat)) 108 | vad = torch.sigmoid(self.vad_dense(out_cat)) 109 | return gain, vad, [gru1_state, gru2_state, gru3_state] 110 | -------------------------------------------------------------------------------- /torch/rnnoise/train_rnnoise.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2024 Jean-Marc Valin */ 3 | /* 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | - Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | - Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 19 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | */ 27 | """ 28 | 29 | import numpy as np 30 | import torch 31 | from torch import nn 32 | import torch.nn.functional as F 33 | import tqdm 34 | import os 35 | import rnnoise 36 | import argparse 37 | 38 | parser = argparse.ArgumentParser() 39 | 40 | parser.add_argument('features', type=str, help='path to feature file in .f32 format') 41 | parser.add_argument('output', type=str, help='path to output folder') 42 | 43 | parser.add_argument('--suffix', type=str, help="model name suffix", default="") 44 | parser.add_argument('--cuda-visible-devices', type=str, help="comma separates list of cuda visible device indices, default: CUDA_VISIBLE_DEVICES", default=None) 45 | 46 | 47 | model_group = parser.add_argument_group(title="model parameters") 48 | model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 128", default=128) 49 | model_group.add_argument('--gru-size', type=int, help="first conditioning size, default: 384", default=384) 50 | 51 | training_group = parser.add_argument_group(title="training parameters") 52 | training_group.add_argument('--batch-size', type=int, help="batch size, default: 128", default=128) 53 | training_group.add_argument('--lr', type=float, help='learning rate, default: 1e-3', default=1e-3) 54 | training_group.add_argument('--epochs', type=int, help='number of training epochs, default: 200', default=200) 55 | training_group.add_argument('--sequence-length', type=int, help='sequence length, default: 2000', default=2000) 56 | training_group.add_argument('--lr-decay', type=float, help='learning rate decay factor, default: 5e-5', default=5e-5) 57 | training_group.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None) 58 | training_group.add_argument('--gamma', type=float, help='perceptual exponent (default 0.25)', default=0.25) 59 | training_group.add_argument('--sparse', action='store_true') 60 | 61 | args = parser.parse_args() 62 | 63 | 
64 | 65 | class RNNoiseDataset(torch.utils.data.Dataset): 66 | def __init__(self, 67 | features_file, 68 | sequence_length=2000): 69 | 70 | self.sequence_length = sequence_length 71 | 72 | self.data = np.memmap(features_file, dtype='float32', mode='r') 73 | dim = 98 74 | 75 | self.nb_sequences = self.data.shape[0]//self.sequence_length//dim 76 | self.data = self.data[:self.nb_sequences*self.sequence_length*dim] 77 | 78 | self.data = np.reshape(self.data, (self.nb_sequences, self.sequence_length, dim)) 79 | 80 | def __len__(self): 81 | return self.nb_sequences 82 | 83 | def __getitem__(self, index): 84 | return self.data[index, :, :65].copy(), self.data[index, :, 65:-1].copy(), self.data[index, :, -1:].copy() 85 | 86 | def mask(g): 87 | return torch.clamp(g+1, max=1) 88 | 89 | adam_betas = [0.8, 0.98] 90 | adam_eps = 1e-8 91 | batch_size = args.batch_size 92 | lr = args.lr 93 | epochs = args.epochs 94 | sequence_length = args.sequence_length 95 | lr_decay = args.lr_decay 96 | 97 | cond_size = args.cond_size 98 | gru_size = args.gru_size 99 | 100 | checkpoint_dir = os.path.join(args.output, 'checkpoints') 101 | os.makedirs(checkpoint_dir, exist_ok=True) 102 | checkpoint = dict() 103 | 104 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 105 | 106 | checkpoint['model_args'] = () 107 | checkpoint['model_kwargs'] = {'cond_size': cond_size, 'gru_size': gru_size} 108 | model = rnnoise.RNNoise(*checkpoint['model_args'], **checkpoint['model_kwargs']) 109 | 110 | if type(args.initial_checkpoint) != type(None): 111 | checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') 112 | model.load_state_dict(checkpoint['state_dict'], strict=False) 113 | 114 | checkpoint['state_dict'] = model.state_dict() 115 | 116 | dataset = RNNoiseDataset(args.features) 117 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4) 118 | 119 | 120 | optimizer = 
torch.optim.AdamW(model.parameters(), lr=lr, betas=adam_betas, eps=adam_eps) 121 | 122 | 123 | # learning rate scheduler 124 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda x : 1 / (1 + lr_decay * x)) 125 | 126 | gamma = args.gamma 127 | 128 | if __name__ == '__main__': 129 | model.to(device) 130 | states = None 131 | for epoch in range(1, epochs + 1): 132 | 133 | running_gain_loss = 0 134 | running_vad_loss = 0 135 | running_loss = 0 136 | 137 | print(f"training epoch {epoch}...") 138 | with tqdm.tqdm(dataloader, unit='batch') as tepoch: 139 | for i, (features, gain, vad) in enumerate(tepoch): 140 | optimizer.zero_grad() 141 | features = features.to(device) 142 | gain = gain.to(device) 143 | vad = vad.to(device) 144 | 145 | pred_gain, pred_vad, states = model(features, states=states) 146 | states = [state.detach() for state in states] 147 | gain = gain[:,3:-1,:] 148 | vad = vad[:,3:-1,:] 149 | target_gain = torch.clamp(gain, min=0) 150 | target_gain = target_gain*(torch.tanh(8*target_gain)**2) 151 | 152 | e = pred_gain**gamma - target_gain**gamma 153 | gain_loss = torch.mean((1+5.*vad)*mask(gain)*(e**2)) 154 | #vad_loss = torch.mean(torch.abs(2*vad-1)*(vad-pred_vad)**2) 155 | vad_loss = torch.mean(torch.abs(2*vad-1)*(-vad*torch.log(.01+pred_vad) - (1-vad)*torch.log(1.01-pred_vad))) 156 | loss = gain_loss + .001*vad_loss 157 | 158 | loss.backward() 159 | optimizer.step() 160 | if args.sparse: 161 | model.sparsify() 162 | 163 | scheduler.step() 164 | 165 | running_gain_loss += gain_loss.detach().cpu().item() 166 | running_vad_loss += vad_loss.detach().cpu().item() 167 | running_loss += loss.detach().cpu().item() 168 | tepoch.set_postfix(loss=f"{running_loss/(i+1):8.5f}", 169 | gain_loss=f"{running_gain_loss/(i+1):8.5f}", 170 | vad_loss=f"{running_vad_loss/(i+1):8.5f}", 171 | ) 172 | 173 | # save checkpoint 174 | checkpoint_path = os.path.join(checkpoint_dir, f'rnnoise{args.suffix}_{epoch}.pth') 175 | checkpoint['state_dict'] = 
model.state_dict() 176 | checkpoint['loss'] = running_loss / len(dataloader) 177 | checkpoint['epoch'] = epoch 178 | torch.save(checkpoint, checkpoint_path) 179 | -------------------------------------------------------------------------------- /torch/sparsification/__init__.py: -------------------------------------------------------------------------------- 1 | from .gru_sparsifier import GRUSparsifier 2 | from .common import sparsify_matrix, calculate_gru_flops_per_step -------------------------------------------------------------------------------- /torch/sparsification/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | """ 29 | 30 | import torch 31 | 32 | def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False): 33 | """ sparsifies matrix with specified block size 34 | 35 | Parameters: 36 | ----------- 37 | matrix : torch.tensor 38 | matrix to sparsify 39 | density : int 40 | target density 41 | block_size : [int, int] 42 | block size dimensions 43 | keep_diagonal : bool 44 | If true, the diagonal will be kept. 
This option requires block_size[0] == block_size[1] and defaults to False 45 | """ 46 | 47 | m, n = matrix.shape 48 | m1, n1 = block_size 49 | 50 | if m % m1 or n % n1: 51 | raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}") 52 | 53 | # extract diagonal if keep_diagonal = True 54 | if keep_diagonal: 55 | if m != n: 56 | raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True") 57 | 58 | to_spare = torch.diag(torch.diag(matrix)) 59 | matrix = matrix - to_spare 60 | else: 61 | to_spare = torch.zeros_like(matrix) 62 | 63 | # calculate energy in sub-blocks 64 | x = torch.reshape(matrix, (m // m1, m1, n // n1, n1)) 65 | x = x ** 2 66 | block_energies = torch.sum(torch.sum(x, dim=3), dim=1) 67 | 68 | number_of_blocks = (m * n) // (m1 * n1) 69 | number_of_survivors = round(number_of_blocks * density) 70 | 71 | # masking threshold 72 | if number_of_survivors == 0: 73 | threshold = 0 74 | else: 75 | threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors] 76 | 77 | # create mask 78 | mask = torch.ones_like(block_energies) 79 | mask[block_energies < threshold] = 0 80 | mask = torch.repeat_interleave(mask, m1, dim=0) 81 | mask = torch.repeat_interleave(mask, n1, dim=1) 82 | 83 | # perform masking 84 | masked_matrix = mask * matrix + to_spare 85 | 86 | if return_mask: 87 | return masked_matrix, mask 88 | else: 89 | return masked_matrix 90 | 91 | def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False): 92 | input_size = gru.input_size 93 | hidden_size = gru.hidden_size 94 | flops = 0 95 | 96 | input_density = ( 97 | sparsification_dict.get('W_ir', [1])[0] 98 | + sparsification_dict.get('W_in', [1])[0] 99 | + sparsification_dict.get('W_iz', [1])[0] 100 | ) / 3 101 | 102 | recurrent_density = ( 103 | sparsification_dict.get('W_hr', [1])[0] 104 | + sparsification_dict.get('W_hn', [1])[0] 105 | + sparsification_dict.get('W_hz', [1])[0] 106 | ) / 3 107 | 108 | # input matrix 
vector multiplications 109 | if not drop_input: 110 | flops += 2 * 3 * input_size * hidden_size * input_density 111 | 112 | # recurrent matrix vector multiplications 113 | flops += 2 * 3 * hidden_size * hidden_size * recurrent_density 114 | 115 | # biases 116 | flops += 6 * hidden_size 117 | 118 | # activations estimated by 10 flops per activation 119 | flops += 30 * hidden_size 120 | 121 | return flops 122 | -------------------------------------------------------------------------------- /torch/sparsification/gru_sparsifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | """ 29 | 30 | import torch 31 | 32 | from .common import sparsify_matrix 33 | 34 | 35 | class GRUSparsifier: 36 | def __init__(self, task_list, start, stop, interval, exponent=3): 37 | """ Sparsifier for torch.nn.GRUs 38 | 39 | Parameters: 40 | ----------- 41 | task_list : list 42 | task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance 43 | of torch.nn.GRU and sparsify_dic is a dictionary with keys in {'W_ir', 'W_iz', 'W_in', 44 | 'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset, 45 | update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal), 46 | where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which 47 | sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal 48 | should be kept. 49 | 50 | start : int 51 | training step after which sparsification will be started. 52 | 53 | stop : int 54 | training step after which sparsification will be completed. 55 | 56 | interval : int 57 | sparsification interval for steps between start and stop. After stop sparsification will be 58 | carried out after every call to GRUSparsifier.step() 59 | 60 | exponent : float 61 | Interpolation exponent for sparsification interval. 
In step i sparsification will be carried out 62 | with density (alpha + target_density * (1 * alpha)), where 63 | alpha = ((stop - i) / (start - stop)) ** exponent 64 | 65 | Example: 66 | -------- 67 | >>> import torch 68 | >>> gru = torch.nn.GRU(10, 20) 69 | >>> sparsify_dict = { 70 | ... 'W_ir' : (0.5, [2, 2], False), 71 | ... 'W_iz' : (0.6, [2, 2], False), 72 | ... 'W_in' : (0.7, [2, 2], False), 73 | ... 'W_hr' : (0.1, [4, 4], True), 74 | ... 'W_hz' : (0.2, [4, 4], True), 75 | ... 'W_hn' : (0.3, [4, 4], True), 76 | ... } 77 | >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50) 78 | >>> for i in range(100): 79 | ... sparsifier.step() 80 | """ 81 | # just copying parameters... 82 | self.start = start 83 | self.stop = stop 84 | self.interval = interval 85 | self.exponent = exponent 86 | self.task_list = task_list 87 | 88 | # ... and setting counter to 0 89 | self.step_counter = 0 90 | 91 | self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']} 92 | 93 | def step(self, verbose=False): 94 | """ carries out sparsification step 95 | 96 | Call this function after optimizer.step in your 97 | training loop. 
98 | 99 | Parameters: 100 | ---------- 101 | verbose : bool 102 | if true, densities are printed out 103 | 104 | Returns: 105 | -------- 106 | None 107 | 108 | """ 109 | # compute current interpolation factor 110 | self.step_counter += 1 111 | 112 | if self.step_counter < self.start: 113 | return 114 | elif self.step_counter < self.stop: 115 | # update only every self.interval-th interval 116 | if self.step_counter % self.interval: 117 | return 118 | 119 | alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent 120 | else: 121 | alpha = 0 122 | 123 | 124 | with torch.no_grad(): 125 | for gru, params in self.task_list: 126 | hidden_size = gru.hidden_size 127 | 128 | # input weights 129 | for i, key in enumerate(['W_ir', 'W_iz', 'W_in']): 130 | if key in params: 131 | density = alpha + (1 - alpha) * params[key][0] 132 | if verbose: 133 | print(f"[{self.step_counter}]: {key} density: {density}") 134 | 135 | gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( 136 | gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ], 137 | density, # density 138 | params[key][1], # block_size 139 | params[key][2], # keep_diagonal (might want to set this to False) 140 | return_mask=True 141 | ) 142 | 143 | if type(self.last_masks[key]) != type(None): 144 | if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop: 145 | print(f"sparsification mask {key} changed for gru {gru}") 146 | 147 | self.last_masks[key] = new_mask 148 | 149 | # recurrent weights 150 | for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']): 151 | if key in params: 152 | density = alpha + (1 - alpha) * params[key][0] 153 | if verbose: 154 | print(f"[{self.step_counter}]: {key} density: {density}") 155 | gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix( 156 | gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ], 157 | density, 158 | params[key][1], # block_size 159 | 
params[key][2], # keep_diagonal (might want to set this to False) 160 | return_mask=True 161 | ) 162 | 163 | if type(self.last_masks[key]) != type(None): 164 | if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop: 165 | print(f"sparsification mask {key} changed for gru {gru}") 166 | 167 | self.last_masks[key] = new_mask 168 | 169 | 170 | 171 | if __name__ == "__main__": 172 | print("Testing sparsifier") 173 | 174 | gru = torch.nn.GRU(10, 20) 175 | sparsify_dict = { 176 | 'W_ir' : (0.5, [2, 2], False), 177 | 'W_iz' : (0.6, [2, 2], False), 178 | 'W_in' : (0.7, [2, 2], False), 179 | 'W_hr' : (0.1, [4, 4], True), 180 | 'W_hz' : (0.2, [4, 4], True), 181 | 'W_hn' : (0.3, [4, 4], True), 182 | } 183 | 184 | sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10) 185 | 186 | for i in range(100): 187 | sparsifier.step(verbose=True) 188 | -------------------------------------------------------------------------------- /torch/weight-exchange/README.md: -------------------------------------------------------------------------------- 1 | # weight-exchange 2 | 3 | 4 | 5 | ## Weight Exchange 6 | Repo wor exchanging weights betweeen torch an tensorflow.keras modules, using an intermediate numpy format. 7 | 8 | Routines for loading/dumping torch weights are located in exchange/torch and can be loaded with 9 | ``` 10 | import exchange.torch 11 | ``` 12 | and routines for loading/dumping tensorflow weights are located in exchange/tf and can be loaded with 13 | ``` 14 | import exchange.tf 15 | ``` 16 | 17 | Note that `exchange.torch` requires torch to be installed and `exchange.tf` requires tensorflow. To avoid the necessity of installing both torch and tensorflow in the working environment, none of these submodules is imported when calling `import exchange`. Similarly, the requirements listed in `requirements.txt` do include neither Tensorflow or Pytorch. 
18 | 19 | 20 | ## C export 21 | The module `exchange.c_export` contains routines to export weights to C files. On the long run it will be possible to call all `dump_...` functions with either a path string or a `CWriter` instance based on which the export format is chosen. This is currently only implemented for `torch.nn.GRU`, `torch.nn.Linear` and `torch.nn.Conv1d`. -------------------------------------------------------------------------------- /torch/weight-exchange/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy -------------------------------------------------------------------------------- /torch/weight-exchange/setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | """ 29 | 30 | #!/usr/bin/env/python 31 | import os 32 | from setuptools import setup 33 | 34 | lib_folder = os.path.dirname(os.path.realpath(__file__)) 35 | 36 | with open(os.path.join(lib_folder, 'requirements.txt'), 'r') as f: 37 | install_requires = list(f.read().splitlines()) 38 | 39 | print(install_requires) 40 | 41 | setup(name='wexchange', 42 | version='1.6', 43 | author='Jan Buethe', 44 | author_email='jbuethe@amazon.de', 45 | description='Weight-exchange library between Pytorch and Tensorflow', 46 | packages=['wexchange', 'wexchange.tf', 'wexchange.torch', 'wexchange.c_export'], 47 | install_requires=install_requires 48 | ) 49 | -------------------------------------------------------------------------------- /torch/weight-exchange/wexchange/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 
11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | """ 29 | 30 | from . import c_export -------------------------------------------------------------------------------- /torch/weight-exchange/wexchange/c_export/__init__.py: -------------------------------------------------------------------------------- 1 | from .c_writer import CWriter 2 | """ 3 | /* Copyright (c) 2023 Amazon 4 | Written by Jan Buethe */ 5 | /* 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions 8 | are met: 9 | 10 | - Redistributions of source code must retain the above copyright 11 | notice, this list of conditions and the following disclaimer. 12 | 13 | - Redistributions in binary form must reproduce the above copyright 14 | notice, this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 
16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 21 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 22 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 23 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 24 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 25 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 26 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | */ 29 | """ 30 | 31 | from .common import print_gru_layer, print_dense_layer, print_conv1d_layer, print_tconv1d_layer, print_conv2d_layer, print_vector -------------------------------------------------------------------------------- /torch/weight-exchange/wexchange/c_export/c_writer.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 
15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | """ 29 | 30 | import os 31 | from collections import OrderedDict 32 | 33 | class CWriter: 34 | def __init__(self, 35 | filename_without_extension, 36 | message=None, 37 | header_only=False, 38 | create_state_struct=False, 39 | enable_binary_blob=True, 40 | model_struct_name="Model", 41 | nnet_header="nnet.h", 42 | add_typedef=False): 43 | """ 44 | Writer class for creating souce and header files for weight exports to C 45 | 46 | Parameters: 47 | ----------- 48 | 49 | filename_without_extension: str 50 | filename from which .c and .h files are created 51 | 52 | message: str, optional 53 | if given and not None, this message will be printed as comment in the header file 54 | 55 | header_only: bool, optional 56 | if True, only a header file is created; defaults to False 57 | 58 | enable_binary_blob: bool, optional 59 | if True, export is done in binary blob format and a model type is created; defaults to False 60 | 61 | create_state_struct: bool, optional 62 | if True, a state struct type is created in the header file; if False, state sizes are defined as macros; defaults to False 63 | 64 | model_struct_name: str, optional 65 | name used for the model 
struct type; only relevant when enable_binary_blob is True; defaults to "Model" 66 | 67 | nnet_header: str, optional 68 | name of header nnet header file; defaults to nnet.h 69 | 70 | """ 71 | 72 | 73 | self.header_only = header_only 74 | self.enable_binary_blob = enable_binary_blob 75 | self.create_state_struct = create_state_struct 76 | self.model_struct_name = model_struct_name 77 | self.add_typedef = add_typedef 78 | 79 | # for binary blob format, format is key=, value=(, ) 80 | self.layer_dict = OrderedDict() 81 | 82 | # for binary blob format, format is key=, value= 83 | self.weight_arrays = [] 84 | 85 | # form model struct, format is key=, value= 86 | self.state_dict = OrderedDict() 87 | 88 | self.header = open(filename_without_extension + ".h", "w") 89 | header_name = os.path.basename(filename_without_extension) + '.h' 90 | 91 | if message is not None: 92 | self.header.write(f"/* {message} */\n\n") 93 | 94 | self.header_guard = os.path.basename(filename_without_extension).upper() + "_H" 95 | self.header.write( 96 | f''' 97 | #ifndef {self.header_guard} 98 | #define {self.header_guard} 99 | 100 | #include "{nnet_header}" 101 | 102 | ''' 103 | ) 104 | 105 | if not self.header_only: 106 | self.source = open(filename_without_extension + ".c", "w") 107 | if message is not None: 108 | self.source.write(f"/* {message} */\n\n") 109 | 110 | self.source.write( 111 | f""" 112 | #ifdef HAVE_CONFIG_H 113 | #include "config.h" 114 | #endif 115 | 116 | """) 117 | self.source.write(f'#include "{header_name}"\n\n') 118 | 119 | 120 | def _finalize_header(self): 121 | 122 | # create model type 123 | if self.enable_binary_blob: 124 | if self.add_typedef: 125 | self.header.write(f"\ntypedef struct {{") 126 | else: 127 | self.header.write(f"\nstruct {self.model_struct_name} {{") 128 | for name, data in self.layer_dict.items(): 129 | layer_type = data[0] 130 | self.header.write(f"\n {layer_type} {name};") 131 | if self.add_typedef: 132 | self.header.write(f"\n}} 
{self.model_struct_name};\n") 133 | else: 134 | self.header.write(f"\n}};\n") 135 | 136 | init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)" 137 | self.header.write(f"\n{init_prototype};\n") 138 | 139 | self.header.write(f"\n#endif /* {self.header_guard} */\n") 140 | 141 | def _finalize_source(self): 142 | 143 | if self.enable_binary_blob: 144 | # create weight array 145 | if len(set(self.weight_arrays)) != len(self.weight_arrays): 146 | raise ValueError("error: detected duplicates in weight arrays") 147 | self.source.write("\n#ifndef USE_WEIGHTS_FILE\n") 148 | self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n") 149 | for name in self.weight_arrays: 150 | self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n") 151 | self.source.write(f' {{"{name}", WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n') 152 | self.source.write(f"#endif\n") 153 | self.source.write(" {NULL, 0, 0, NULL}\n") 154 | self.source.write("};\n") 155 | 156 | self.source.write("#endif /* USE_WEIGHTS_FILE */\n") 157 | 158 | # create init function definition 159 | init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)" 160 | self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n") 161 | self.source.write(f"{init_prototype} {{\n") 162 | for name, data in self.layer_dict.items(): 163 | self.source.write(f" if ({data[1]}) return 1;\n") 164 | self.source.write(" return 0;\n") 165 | self.source.write("}\n") 166 | self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n") 167 | 168 | 169 | def close(self): 170 | 171 | if not self.header_only: 172 | self._finalize_source() 173 | self.source.close() 174 | 175 | self._finalize_header() 176 | self.header.close() 177 | 178 | def __del__(self): 179 | try: 180 | self.close() 181 | except: 182 | pass -------------------------------------------------------------------------------- 
/torch/weight-exchange/wexchange/tf/__init__.py: -------------------------------------------------------------------------------- 1 | from .tf import dump_tf_conv1d_weights, load_tf_conv1d_weights 2 | from .tf import dump_tf_dense_weights, load_tf_dense_weights 3 | from .tf import dump_tf_embedding_weights, load_tf_embedding_weights 4 | from .tf import dump_tf_gru_weights, load_tf_gru_weights 5 | from .tf import dump_tf_weights, load_tf_weights -------------------------------------------------------------------------------- /torch/weight-exchange/wexchange/tf/tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | """ 29 | 30 | import os 31 | 32 | import tensorflow as tf 33 | import numpy as np 34 | 35 | from wexchange.c_export import CWriter, print_gru_layer, print_dense_layer, print_conv1d_layer 36 | 37 | def dump_tf_gru_weights(where, gru, name='gru', input_sparse=False, recurrent_sparse=False, quantize=False, scale=1/128, recurrent_scale=1/128): 38 | 39 | 40 | assert gru.activation == tf.keras.activations.tanh 41 | assert gru.recurrent_activation == tf.keras.activations.sigmoid 42 | assert gru.reset_after == True 43 | 44 | w_ih = gru.weights[0].numpy().transpose().copy() 45 | w_hh = gru.weights[1].numpy().transpose().copy() 46 | b_ih = gru.weights[2].numpy()[0].copy() 47 | b_hh = gru.weights[2].numpy()[1].copy() 48 | 49 | if isinstance(where, CWriter): 50 | return print_gru_layer(where, name, w_ih, w_hh, b_ih, b_hh, format='tf', input_sparse=input_sparse, recurrent_sparse=recurrent_sparse, quantize=quantize, scale=scale, recurrent_scale=recurrent_scale) 51 | else: 52 | os.makedirs(where, exist_ok=True) 53 | 54 | # zrn => rzn 55 | N = w_ih.shape[0] // 3 56 | for x in [w_ih, w_hh, b_ih, b_hh]: 57 | tmp = x[0:N].copy() 58 | x[0:N] = x[N:2*N] 59 | x[N:2*N] = tmp 60 | 61 | np.save(os.path.join(where, 'weight_ih_rzn.npy'), w_ih) 62 | np.save(os.path.join(where, 'weight_hh_rzn.npy'), w_hh) 63 | np.save(os.path.join(where, 'bias_ih_rzn.npy'), b_ih) 64 | np.save(os.path.join(where, 'bias_hh_rzn.npy'), b_hh) 65 | 66 | 67 | def 
load_tf_gru_weights(path, gru): 68 | 69 | assert gru.activation == tf.keras.activations.tanh 70 | assert gru.recurrent_activation == tf.keras.activations.sigmoid 71 | assert gru.reset_after == True 72 | 73 | w_ih = np.load(os.path.join(path, 'weight_ih_rzn.npy')) 74 | w_hh = np.load(os.path.join(path, 'weight_hh_rzn.npy')) 75 | b_ih = np.load(os.path.join(path, 'bias_ih_rzn.npy')) 76 | b_hh = np.load(os.path.join(path, 'bias_hh_rzn.npy')) 77 | 78 | # rzn => zrn 79 | N = w_ih.shape[0] // 3 80 | for x in [w_ih, w_hh, b_ih, b_hh]: 81 | tmp = x[0:N].copy() 82 | x[0:N] = x[N:2*N] 83 | x[N:2*N] = tmp 84 | 85 | gru.weights[0].assign(tf.convert_to_tensor(w_ih.transpose())) 86 | gru.weights[1].assign(tf.convert_to_tensor(w_hh.transpose())) 87 | gru.weights[2].assign(tf.convert_to_tensor(np.vstack((b_ih, b_hh)))) 88 | 89 | 90 | def dump_tf_dense_weights(where, dense, name='dense', scale=1/128, sparse=False, diagonal=False, quantize=False): 91 | 92 | w = dense.weights[0].numpy() 93 | if dense.bias is None: 94 | b = np.zeros(dense.units, dtype=w.dtype) 95 | else: 96 | b = dense.bias.numpy() 97 | 98 | 99 | 100 | if isinstance(where, CWriter): 101 | return print_dense_layer(where, name, w, b, scale=scale, format='tf', sparse=sparse, diagonal=diagonal, quantize=quantize) 102 | 103 | else: 104 | os.makedirs(where, exist_ok=True) 105 | 106 | np.save(os.path.join(where, 'weight.npy'), w.transpose()) 107 | np.save(os.path.join(where, 'bias.npy'), b) 108 | 109 | 110 | def load_tf_dense_weights(path, dense): 111 | 112 | w = np.load(os.path.join(path, 'weight.npy')).transpose() 113 | b = np.load(os.path.join(path, 'bias.npy')) 114 | 115 | dense.weights[0].assign(tf.convert_to_tensor(w)) 116 | if dense.bias is not None: 117 | dense.weights[1].assign(tf.convert_to_tensor(b)) 118 | 119 | 120 | def dump_tf_conv1d_weights(where, conv, name='conv', scale=1/128, quantize=False): 121 | 122 | assert conv.data_format == 'channels_last' 123 | 124 | w = conv.weights[0].numpy().copy() 125 | if 
conv.bias is None: 126 | b = np.zeros(conv.filters, dtype=w.dtype) 127 | else: 128 | b = conv.bias.numpy() 129 | 130 | if isinstance(where, CWriter): 131 | return print_conv1d_layer(where, name, w, b, scale=scale, format='tf', quantize=quantize) 132 | else: 133 | os.makedirs(where, exist_ok=True) 134 | 135 | w = np.transpose(w, (2, 1, 0)) 136 | np.save(os.path.join(where, 'weight_oik.npy'), w) 137 | np.save(os.path.join(where, 'bias.npy'), b) 138 | 139 | 140 | def load_tf_conv1d_weights(path, conv): 141 | 142 | w = np.load(os.path.join(path, 'weight_oik.npy')) 143 | b = np.load(os.path.join(path, 'bias.npy')) 144 | 145 | w = np.transpose(w, (2, 1, 0)) 146 | 147 | conv.weights[0].assign(tf.convert_to_tensor(w)) 148 | if conv.bias is not None: 149 | conv.weights[1].assign(tf.convert_to_tensor(b)) 150 | 151 | 152 | def dump_tf_embedding_weights(path, emb): 153 | os.makedirs(path, exist_ok=True) 154 | 155 | w = emb.weights[0].numpy() 156 | np.save(os.path.join(path, 'weight.npy'), w) 157 | 158 | 159 | 160 | def load_tf_embedding_weights(path, emb): 161 | 162 | w = np.load(os.path.join(path, 'weight.npy')) 163 | emb.weights[0].assign(tf.convert_to_tensor(w)) 164 | 165 | 166 | def dump_tf_weights(path, module): 167 | if isinstance(module, tf.keras.layers.Dense): 168 | dump_tf_dense_weights(path, module) 169 | elif isinstance(module, tf.keras.layers.GRU): 170 | dump_tf_gru_weights(path, module) 171 | elif isinstance(module, tf.keras.layers.Conv1D): 172 | dump_tf_conv1d_weights(path, module) 173 | elif isinstance(module, tf.keras.layers.Embedding): 174 | dump_tf_embedding_weights(path, module) 175 | else: 176 | raise ValueError(f'dump_tf_weights: layer of type {type(module)} not supported') 177 | 178 | def load_tf_weights(path, module): 179 | if isinstance(module, tf.keras.layers.Dense): 180 | load_tf_dense_weights(path, module) 181 | elif isinstance(module, tf.keras.layers.GRU): 182 | load_tf_gru_weights(path, module) 183 | elif isinstance(module, tf.keras.layers.Conv1D): 
184 | load_tf_conv1d_weights(path, module) 185 | elif isinstance(module, tf.keras.layers.Embedding): 186 | load_tf_embedding_weights(path, module) 187 | else: 188 | raise ValueError(f'dump_tf_weights: layer of type {type(module)} not supported') -------------------------------------------------------------------------------- /torch/weight-exchange/wexchange/torch/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* Copyright (c) 2023 Amazon 3 | Written by Jan Buethe */ 4 | /* 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions 7 | are met: 8 | 9 | - Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | - Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in the 14 | documentation and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 17 | ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 18 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 19 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER 20 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 24 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | */ 28 | """ 29 | 30 | from .torch import dump_torch_conv1d_weights, load_torch_conv1d_weights 31 | from .torch import dump_torch_conv2d_weights, load_torch_conv2d_weights 32 | from .torch import dump_torch_dense_weights, load_torch_dense_weights 33 | from .torch import dump_torch_gru_weights, load_torch_gru_weights 34 | from .torch import dump_torch_grucell_weights 35 | from .torch import dump_torch_embedding_weights, load_torch_embedding_weights 36 | from .torch import dump_torch_weights, load_torch_weights 37 | from .torch import dump_torch_adaptive_conv1d_weights -------------------------------------------------------------------------------- /training/bin2hdf5.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import h5py 7 | import sys 8 | 9 | data = np.fromfile(sys.argv[1], dtype='float32'); 10 | data = np.reshape(data, (int(sys.argv[2]), int(sys.argv[3]))); 11 | h5f = h5py.File(sys.argv[4], 'w'); 12 | h5f.create_dataset('data', data=data) 13 | h5f.close() 14 | -------------------------------------------------------------------------------- /training/dump_rnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import print_function 4 | 5 | from keras.models import Sequential 6 | from keras.layers import Dense 7 | from keras.layers import LSTM 8 | from keras.layers import GRU 9 | from keras.models import load_model 10 | from keras import backend as K 11 | import sys 12 | import re 13 | import numpy as np 14 | 15 | def printVector(f, ft, vector, name): 16 | v = np.reshape(vector, (-1)); 17 | #print('static const float ', name, '[', len(v), '] = \n', file=f) 18 | f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(v))) 19 | for i in range(0, len(v)): 20 | f.write('{}'.format(min(127, int(round(256*v[i]))))) 21 | 
ft.write('{}'.format(min(127, int(round(256*v[i]))))) 22 | if (i!=len(v)-1): 23 | f.write(',') 24 | else: 25 | break; 26 | ft.write(" ") 27 | if (i%8==7): 28 | f.write("\n ") 29 | else: 30 | f.write(" ") 31 | #print(v, file=f) 32 | f.write('\n};\n\n') 33 | ft.write("\n") 34 | return; 35 | 36 | def printLayer(f, ft, layer): 37 | weights = layer.get_weights() 38 | activation = re.search('function (.*) at', str(layer.activation)).group(1).upper() 39 | if len(weights) > 2: 40 | ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]/3)) 41 | else: 42 | ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1])) 43 | if activation == 'SIGMOID': 44 | ft.write('1\n') 45 | elif activation == 'RELU': 46 | ft.write('2\n') 47 | else: 48 | ft.write('0\n') 49 | printVector(f, ft, weights[0], layer.name + '_weights') 50 | if len(weights) > 2: 51 | printVector(f, ft, weights[1], layer.name + '_recurrent_weights') 52 | printVector(f, ft, weights[-1], layer.name + '_bias') 53 | name = layer.name 54 | if len(weights) > 2: 55 | f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' 56 | .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation)) 57 | else: 58 | f.write('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n' 59 | .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation)) 60 | 61 | def structLayer(f, layer): 62 | weights = layer.get_weights() 63 | name = layer.name 64 | if len(weights) > 2: 65 | f.write(' {},\n'.format(weights[0].shape[1]/3)) 66 | else: 67 | f.write(' {},\n'.format(weights[0].shape[1])) 68 | f.write(' &{},\n'.format(name)) 69 | 70 | 71 | def foo(c, name): 72 | return None 73 | 74 | def mean_squared_sqrt_error(y_true, y_pred): 75 | return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1) 76 | 77 | 78 | model = load_model(sys.argv[1], custom_objects={'msse': 
mean_squared_sqrt_error, 'mean_squared_sqrt_error': mean_squared_sqrt_error, 'my_crossentropy': mean_squared_sqrt_error, 'mycost': mean_squared_sqrt_error, 'WeightClip': foo}) 79 | 80 | weights = model.get_weights() 81 | 82 | f = open(sys.argv[2], 'w') 83 | ft = open(sys.argv[3], 'w') 84 | 85 | f.write('/*This file is automatically generated from a Keras model*/\n\n') 86 | f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n') 87 | ft.write('rnnoise-nu model file version 1\n') 88 | 89 | layer_list = [] 90 | for i, layer in enumerate(model.layers): 91 | if len(layer.get_weights()) > 0: 92 | printLayer(f, ft, layer) 93 | if len(layer.get_weights()) > 2: 94 | layer_list.append(layer.name) 95 | 96 | f.write('const struct RNNModel rnnoise_model_{} = {{\n'.format(sys.argv[4])) 97 | for i, layer in enumerate(model.layers): 98 | if len(layer.get_weights()) > 0: 99 | structLayer(f, layer) 100 | f.write('};\n') 101 | 102 | #hf.write('struct RNNState {\n') 103 | #for i, name in enumerate(layer_list): 104 | # hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper())) 105 | #hf.write('};\n') 106 | 107 | f.close() 108 | -------------------------------------------------------------------------------- /training/rnn_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import print_function 4 | 5 | import keras 6 | from keras.models import Sequential 7 | from keras.models import Model 8 | from keras.layers import Input 9 | from keras.layers import Dense 10 | from keras.layers import LSTM 11 | from keras.layers import GRU 12 | from keras.layers import SimpleRNN 13 | from keras.layers import Dropout 14 | from keras.layers import concatenate 15 | from keras import losses 16 | from keras import regularizers 17 | from keras.constraints import min_max_norm 18 | import h5py 19 | 20 | from keras.constraints import Constraint 21 | from keras import 
backend as K 22 | import numpy as np 23 | 24 | #import tensorflow as tf 25 | #from keras.backend.tensorflow_backend import set_session 26 | #config = tf.ConfigProto() 27 | #config.gpu_options.per_process_gpu_memory_fraction = 0.42 28 | #set_session(tf.Session(config=config)) 29 | 30 | 31 | def my_crossentropy(y_true, y_pred): 32 | return K.mean(2*K.abs(y_true-0.5) * K.binary_crossentropy(y_pred, y_true), axis=-1) 33 | 34 | def mymask(y_true): 35 | return K.minimum(y_true+1., 1.) 36 | 37 | def msse(y_true, y_pred): 38 | return K.mean(mymask(y_true) * K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1) 39 | 40 | def mycost(y_true, y_pred): 41 | return K.mean(mymask(y_true) * (10*K.square(K.square(K.sqrt(y_pred) - K.sqrt(y_true))) + K.square(K.sqrt(y_pred) - K.sqrt(y_true)) + 0.01*K.binary_crossentropy(y_pred, y_true)), axis=-1) 42 | 43 | def my_accuracy(y_true, y_pred): 44 | return K.mean(2*K.abs(y_true-0.5) * K.equal(y_true, K.round(y_pred)), axis=-1) 45 | 46 | class WeightClip(Constraint): 47 | '''Clips the weights incident to each hidden unit to be inside a range 48 | ''' 49 | def __init__(self, c=2): 50 | self.c = c 51 | 52 | def __call__(self, p): 53 | return K.clip(p, -self.c, self.c) 54 | 55 | def get_config(self): 56 | return {'name': self.__class__.__name__, 57 | 'c': self.c} 58 | 59 | reg = 0.000001 60 | constraint = WeightClip(0.499) 61 | 62 | print('Build model...') 63 | main_input = Input(shape=(None, 42), name='main_input') 64 | tmp = Dense(24, activation='tanh', name='input_dense', kernel_constraint=constraint, bias_constraint=constraint)(main_input) 65 | vad_gru = GRU(24, activation='tanh', recurrent_activation='sigmoid', return_sequences=True, name='vad_gru', kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg), kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(tmp) 66 | vad_output = Dense(1, activation='sigmoid', name='vad_output', kernel_constraint=constraint, 
bias_constraint=constraint)(vad_gru) 67 | noise_input = keras.layers.concatenate([tmp, vad_gru, main_input]) 68 | noise_gru = GRU(48, activation='relu', recurrent_activation='sigmoid', return_sequences=True, name='noise_gru', kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg), kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(noise_input) 69 | denoise_input = keras.layers.concatenate([vad_gru, noise_gru, main_input]) 70 | 71 | denoise_gru = GRU(96, activation='tanh', recurrent_activation='sigmoid', return_sequences=True, name='denoise_gru', kernel_regularizer=regularizers.l2(reg), recurrent_regularizer=regularizers.l2(reg), kernel_constraint=constraint, recurrent_constraint=constraint, bias_constraint=constraint)(denoise_input) 72 | 73 | denoise_output = Dense(22, activation='sigmoid', name='denoise_output', kernel_constraint=constraint, bias_constraint=constraint)(denoise_gru) 74 | 75 | model = Model(inputs=main_input, outputs=[denoise_output, vad_output]) 76 | 77 | model.compile(loss=[mycost, my_crossentropy], 78 | metrics=[msse], 79 | optimizer='adam', loss_weights=[10, 0.5]) 80 | 81 | 82 | batch_size = 32 83 | 84 | print('Loading data...') 85 | with h5py.File('training.h5', 'r') as hf: 86 | all_data = hf['data'][:] 87 | print('done.') 88 | 89 | window_size = 2000 90 | 91 | nb_sequences = len(all_data)//window_size 92 | print(nb_sequences, ' sequences') 93 | x_train = all_data[:nb_sequences*window_size, :42] 94 | x_train = np.reshape(x_train, (nb_sequences, window_size, 42)) 95 | 96 | y_train = np.copy(all_data[:nb_sequences*window_size, 42:64]) 97 | y_train = np.reshape(y_train, (nb_sequences, window_size, 22)) 98 | 99 | noise_train = np.copy(all_data[:nb_sequences*window_size, 64:86]) 100 | noise_train = np.reshape(noise_train, (nb_sequences, window_size, 22)) 101 | 102 | vad_train = np.copy(all_data[:nb_sequences*window_size, 86:87]) 103 | vad_train = np.reshape(vad_train, 
(nb_sequences, window_size, 1)) 104 | 105 | all_data = 0; 106 | #x_train = x_train.astype('float32') 107 | #y_train = y_train.astype('float32') 108 | 109 | print(len(x_train), 'train sequences. x shape =', x_train.shape, 'y shape = ', y_train.shape) 110 | 111 | print('Train...') 112 | model.fit(x_train, [y_train, vad_train], 113 | batch_size=batch_size, 114 | epochs=120, 115 | validation_split=0.1) 116 | model.save("weights.hdf5") 117 | -------------------------------------------------------------------------------- /update_version: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Creates and updates the package_version information used by configure.ac 4 | # (or other makefiles). When run inside a git repository it will use the 5 | # version information that can be queried from it unless AUTO_UPDATE is set 6 | # to 'no'. If no version is currently known it will be set to 'unknown'. 7 | # 8 | # If called with the argument 'release', the PACKAGE_VERSION will be updated 9 | # even if AUTO_UPDATE=no, but the value of AUTO_UPDATE shall be preserved. 10 | # This is used to force a version update whenever `make dist` is run. 11 | # 12 | # The exit status is 1 if package_version is not modified, else 0 is returned. 13 | # 14 | # This script should NOT be included in distributed tarballs, because if a 15 | # parent directory contains a git repository we do not want to accidentally 16 | # retrieve the version information from it instead. Tarballs should ship 17 | # with only the package_version file. 18 | # 19 | # Ron , 2012. 20 | 21 | SRCDIR=$(dirname $0) 22 | 23 | if [ -e "$SRCDIR/package_version" ]; then 24 | . 
"$SRCDIR/package_version" 25 | fi 26 | 27 | if [ "$AUTO_UPDATE" = no ]; then 28 | [ "$1" = release ] || exit 1 29 | else 30 | AUTO_UPDATE=yes 31 | fi 32 | 33 | # We run `git status` before describe here to ensure that we don't get a false 34 | # -dirty from files that have been touched but are not actually altered in the 35 | # working dir. 36 | GIT_VERSION=$(cd "$SRCDIR" && git status > /dev/null 2>&1 \ 37 | && git describe --tags --match 'v*' --dirty 2> /dev/null) 38 | GIT_VERSION=${GIT_VERSION#v} 39 | 40 | if [ -n "$GIT_VERSION" ]; then 41 | 42 | [ "$GIT_VERSION" != "$PACKAGE_VERSION" ] || exit 1 43 | PACKAGE_VERSION="$GIT_VERSION" 44 | 45 | elif [ -z "$PACKAGE_VERSION" ]; then 46 | # No current package_version and no git ... 47 | # We really shouldn't ever get here, because this script should only be 48 | # included in the git repository, and should usually be export-ignored. 49 | PACKAGE_VERSION="unknown" 50 | else 51 | exit 1 52 | fi 53 | 54 | cat > "$SRCDIR/package_version" <<-EOF 55 | # Automatically generated by update_version. 56 | # This file may be sourced into a shell script or makefile. 57 | 58 | # Set this to 'no' if you do not wish the version information 59 | # to be checked and updated for every build. Most people will 60 | # never want to change this, it is an option for developers 61 | # making frequent changes that they know will not be released. 62 | AUTO_UPDATE=$AUTO_UPDATE 63 | 64 | PACKAGE_VERSION="$PACKAGE_VERSION" 65 | EOF 66 | --------------------------------------------------------------------------------