├── .editorconfig ├── .gitattributes ├── .gitignore ├── Makefile ├── audio.py ├── cpp ├── CMakeLists.txt ├── libs │ ├── cpm.cmake │ ├── deepfilternet.cmake │ ├── npy.cmake │ ├── onnxruntime.cmake │ ├── pocketfft.cmake │ ├── stftpitchshift.cmake │ ├── xtensor.cmake │ └── xtl.cmake └── src │ ├── DeepAssert.h │ ├── DeepFilter.cpp │ ├── DeepFilter.h │ ├── DeepFilterInference.cpp │ ├── DeepFilterInference.h │ ├── FFT.h │ ├── STFT.h │ ├── main.cpp │ └── xt.h ├── cpx.py ├── erb.py ├── requirements.txt ├── spectrum.py ├── stft.py ├── test.py ├── test_ort.py ├── x.wav └── y.wav /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | insert_final_newline = true 6 | trim_trailing_whitespace = true 7 | 8 | [*.py] 9 | indent_style = space 10 | indent_size = 4 11 | 12 | [*.{h,c,cpp}] 13 | indent_style = space 14 | indent_size = 2 15 | 16 | [{*.cmake,CMakeLists.txt}] 17 | indent_style = space 18 | indent_size = 2 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto 2 | *.bat eol=crlf 3 | *.sh eol=lf 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .*/ 2 | 3 | build/ 4 | models/ 5 | 6 | __pycache__/ 7 | *.py[cod] 8 | 9 | .DS_Store 10 | Thumbs.db 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help build clean run 2 | 3 | CONFIG = Release 4 | INPUT = cpp 5 | OUTPUT = build 6 | 7 | help: 8 | @echo build 9 | @echo clean 10 | @echo run 11 | 12 | build: 13 | @cmake -DCMAKE_BUILD_TYPE=$(CONFIG) -S $(INPUT) -B $(OUTPUT) 14 | @cmake 
--build $(OUTPUT) 15 | 16 | clean: 17 | @rm -rf $(OUTPUT) 18 | 19 | run: 20 | build/debug_cpp 21 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import resampy 3 | import soundfile 4 | 5 | from numpy.typing import NDArray 6 | 7 | 8 | def read(path: str, sr: int) -> NDArray: 9 | 10 | samples, samplerate = soundfile.read(path, always_2d=True) 11 | 12 | if samplerate != sr: 13 | samples = resampy.resample(samples, samplerate, sr) 14 | samplerate = sr 15 | 16 | samples = samples.T 17 | 18 | return samples, samplerate 19 | 20 | 21 | def write(path: str, sr: int, samples: NDArray): 22 | 23 | samples = np.squeeze(samples.T) 24 | soundfile.write(path, samples, sr) 25 | -------------------------------------------------------------------------------- /cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.21) 2 | 3 | project(debug_cpp) 4 | 5 | include("${CMAKE_CURRENT_LIST_DIR}/libs/cpm.cmake") 6 | include("${CMAKE_CURRENT_LIST_DIR}/libs/deepfilternet.cmake") 7 | include("${CMAKE_CURRENT_LIST_DIR}/libs/npy.cmake") 8 | include("${CMAKE_CURRENT_LIST_DIR}/libs/onnxruntime.cmake") 9 | include("${CMAKE_CURRENT_LIST_DIR}/libs/pocketfft.cmake") 10 | include("${CMAKE_CURRENT_LIST_DIR}/libs/stftpitchshift.cmake") 11 | include("${CMAKE_CURRENT_LIST_DIR}/libs/xtensor.cmake") 12 | include("${CMAKE_CURRENT_LIST_DIR}/libs/xtl.cmake") 13 | 14 | file(GLOB_RECURSE HDR "${CMAKE_CURRENT_LIST_DIR}/src/*.h") 15 | file(GLOB_RECURSE SRC "${CMAKE_CURRENT_LIST_DIR}/src/*.cpp") 16 | 17 | add_executable(${PROJECT_NAME}) 18 | target_sources(${PROJECT_NAME} PRIVATE "${HDR}" "${SRC}") 19 | target_include_directories(${PROJECT_NAME} PRIVATE "${CMAKE_CURRENT_LIST_DIR}/src") 20 | target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_20) 21 | 
target_link_libraries(${PROJECT_NAME} PRIVATE deepfilternet npy onnxruntime pocketfft stftpitchshift xtensor) 22 | 23 | if(MSVC) 24 | target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE /fp:fast) 25 | target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE /W3 /WX) 26 | else() 27 | target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE -ffast-math) 28 | target_compile_options(${CMAKE_PROJECT_NAME} PRIVATE -Wall -Werror) 29 | endif() 30 | 31 | add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD 32 | COMMAND ${CMAKE_COMMAND} -E copy -t $ $ 33 | COMMAND_EXPAND_LISTS) 34 | -------------------------------------------------------------------------------- /cpp/libs/cpm.cmake: -------------------------------------------------------------------------------- 1 | # https://github.com/cpm-cmake/CPM.cmake 2 | 3 | set(CPMSRC "https://github.com/cpm-cmake/CPM.cmake/releases/download/v0.39.0/CPM.cmake") 4 | set(CPMDST "${CMAKE_BINARY_DIR}/CPM.cmake") 5 | 6 | if(NOT EXISTS "${CPMDST}") 7 | file(DOWNLOAD "${CPMSRC}" "${CPMDST}") 8 | endif() 9 | 10 | include("${CPMDST}") 11 | -------------------------------------------------------------------------------- /cpp/libs/deepfilternet.cmake: -------------------------------------------------------------------------------- 1 | # https://github.com/Rikorose/DeepFilterNet 2 | 3 | CPMAddPackage( 4 | NAME deepfilternet 5 | VERSION 0.5.6 6 | GIT_REPOSITORY https://github.com/Rikorose/DeepFilterNet 7 | DOWNLOAD_ONLY YES) 8 | 9 | if(deepfilternet_ADDED) 10 | 11 | add_library(deepfilternet INTERFACE) 12 | 13 | file(ARCHIVE_EXTRACT 14 | INPUT "${deepfilternet_SOURCE_DIR}/models/DeepFilterNet3_onnx.tar.gz" 15 | DESTINATION "${CMAKE_BINARY_DIR}/DeepFilterNet3") 16 | 17 | target_compile_definitions(deepfilternet 18 | INTERFACE -DDeepFilterNetOnnx="${CMAKE_BINARY_DIR}/DeepFilterNet3/tmp/export") 19 | 20 | endif() 21 | -------------------------------------------------------------------------------- /cpp/libs/npy.cmake: 
-------------------------------------------------------------------------------- 1 | # https://github.com/llohse/libnpy 2 | 3 | CPMAddPackage( 4 | NAME npy 5 | VERSION 1.0.1 6 | GIT_REPOSITORY https://github.com/llohse/libnpy 7 | DOWNLOAD_ONLY YES) 8 | 9 | if(npy_ADDED) 10 | 11 | add_library(npy INTERFACE) 12 | 13 | target_include_directories(npy 14 | INTERFACE "${npy_SOURCE_DIR}/include") 15 | 16 | endif() 17 | -------------------------------------------------------------------------------- /cpp/libs/onnxruntime.cmake: -------------------------------------------------------------------------------- 1 | # https://github.com/microsoft/onnxruntime 2 | 3 | set(VERSION 1.18.0) 4 | set(GITHUB https://github.com/microsoft/onnxruntime/releases/download/v${VERSION}) 5 | set(MAVEN https://repo.maven.apache.org/maven2/com/microsoft/onnxruntime/onnxruntime-mobile/${VERSION}) 6 | set(URL "") 7 | 8 | if(CMAKE_SYSTEM_NAME STREQUAL "Android") 9 | message(STATUS "ONNX Runtime ${CMAKE_SYSTEM_NAME} ${ANDROID_ABI}") 10 | set(URL ${MAVEN}/onnxruntime-mobile-${VERSION}.aar) 11 | endif() 12 | 13 | if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 14 | message(STATUS "ONNX Runtime ${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") 15 | set(URL ${GITHUB}/onnxruntime-osx-${CMAKE_SYSTEM_PROCESSOR}-${VERSION}.tgz) 16 | endif() 17 | 18 | if(CMAKE_SYSTEM_NAME STREQUAL "Linux") 19 | message(STATUS "ONNX Runtime ${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") 20 | set(URL ${GITHUB}/onnxruntime-linux-x64-${VERSION}.tgz) 21 | endif() 22 | 23 | if(CMAKE_SYSTEM_NAME STREQUAL "Windows") 24 | message(STATUS "ONNX Runtime ${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") 25 | set(URL ${GITHUB}/onnxruntime-win-x64-${VERSION}.zip) 26 | endif() 27 | 28 | string(COMPARE EQUAL "${URL}" "" NOK) 29 | if(NOK) 30 | message(FATAL_ERROR "Unable to determine the ONNX Runtime prebuilt binary url!") 31 | endif() 32 | 33 | CPMAddPackage( 34 | NAME onnxruntime 35 | VERSION ${VERSION} 36 | URL ${URL}) 37 | 38 | 
if(onnxruntime_ADDED) 39 | 40 | add_library(onnxruntime SHARED IMPORTED) 41 | 42 | if(CMAKE_SYSTEM_NAME STREQUAL "Android") 43 | set_target_properties(onnxruntime PROPERTIES 44 | IMPORTED_LOCATION "${onnxruntime_SOURCE_DIR}/jni/${ANDROID_ABI}/libonnxruntime.so" 45 | INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/headers") 46 | endif() 47 | 48 | if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 49 | set_target_properties(onnxruntime PROPERTIES 50 | IMPORTED_LOCATION "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.${VERSION}.dylib" 51 | IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.dylib" 52 | INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include") 53 | endif() 54 | 55 | if(CMAKE_SYSTEM_NAME STREQUAL "Linux") 56 | set_target_properties(onnxruntime PROPERTIES 57 | IMPORTED_LOCATION "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so.${VERSION}" 58 | IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/lib/libonnxruntime.so" 59 | INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include") 60 | endif() 61 | 62 | if(CMAKE_SYSTEM_NAME STREQUAL "Windows") 63 | set_target_properties(onnxruntime PROPERTIES 64 | IMPORTED_LOCATION "${onnxruntime_SOURCE_DIR}/lib/onnxruntime.dll" 65 | IMPORTED_IMPLIB "${onnxruntime_SOURCE_DIR}/lib/onnxruntime.lib" 66 | INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_SOURCE_DIR}/include") 67 | endif() 68 | 69 | endif() 70 | -------------------------------------------------------------------------------- /cpp/libs/pocketfft.cmake: -------------------------------------------------------------------------------- 1 | # https://gitlab.mpcdf.mpg.de/mtr/pocketfft 2 | 3 | CPMAddPackage( 4 | NAME pocketfft 5 | VERSION 2024.05.05 6 | GIT_TAG b557a3519ccc1e36b74dc0901a073dd7872c0af2 7 | GIT_REPOSITORY https://gitlab.mpcdf.mpg.de/mtr/pocketfft 8 | DOWNLOAD_ONLY YES) 9 | 10 | if(pocketfft_ADDED) 11 | 12 | add_library(pocketfft INTERFACE) 13 | 14 | target_include_directories(pocketfft 15 | INTERFACE "${pocketfft_SOURCE_DIR}") 16 | 17 | 
target_compile_definitions(pocketfft 18 | INTERFACE -DPOCKETFFT_NO_MULTITHREADING) 19 | 20 | target_compile_definitions(pocketfft 21 | INTERFACE -DPOCKETFFT_CACHE_SIZE=10) 22 | 23 | if(UNIX) 24 | target_link_libraries(pocketfft 25 | INTERFACE pthread) 26 | endif() 27 | 28 | endif() 29 | -------------------------------------------------------------------------------- /cpp/libs/stftpitchshift.cmake: -------------------------------------------------------------------------------- 1 | # https://github.com/jurihock/stftPitchShift 2 | 3 | CPMAddPackage( 4 | NAME stftpitchshift 5 | VERSION 2.0 6 | GIT_TAG 8e55dc639811a8ac5b3f5bc3d1f90a11c59ca582 7 | GITHUB_REPOSITORY jurihock/stftPitchShift 8 | DOWNLOAD_ONLY YES) 9 | 10 | if(stftpitchshift_ADDED) 11 | 12 | add_library(stftpitchshift INTERFACE) 13 | 14 | target_include_directories(stftpitchshift 15 | INTERFACE "${stftpitchshift_SOURCE_DIR}/cpp") 16 | 17 | target_compile_definitions(stftpitchshift 18 | INTERFACE -DENABLE_ARCTANGENT_APPROXIMATION) 19 | 20 | endif() 21 | -------------------------------------------------------------------------------- /cpp/libs/xtensor.cmake: -------------------------------------------------------------------------------- 1 | # https://github.com/xtensor-stack/xtensor 2 | 3 | CPMAddPackage( 4 | NAME xtensor 5 | VERSION 0.25.0 6 | GIT_TAG 0.25.0 7 | GITHUB_REPOSITORY xtensor-stack/xtensor 8 | DOWNLOAD_ONLY YES) 9 | 10 | if(xtensor_ADDED) 11 | 12 | add_library(xtensor INTERFACE) 13 | 14 | target_include_directories(xtensor 15 | INTERFACE "${xtensor_SOURCE_DIR}/include") 16 | 17 | target_link_libraries(xtensor 18 | INTERFACE xtl) 19 | 20 | endif() 21 | -------------------------------------------------------------------------------- /cpp/libs/xtl.cmake: -------------------------------------------------------------------------------- 1 | # https://github.com/xtensor-stack/xtl 2 | 3 | CPMAddPackage( 4 | NAME xtl 5 | VERSION 0.7.7 6 | GIT_TAG 0.7.7 7 | GITHUB_REPOSITORY xtensor-stack/xtl 8 | 
DOWNLOAD_ONLY YES) 9 | 10 | if(xtl_ADDED) 11 | 12 | add_library(xtl INTERFACE) 13 | 14 | target_include_directories(xtl 15 | INTERFACE "${xtl_SOURCE_DIR}/include") 16 | 17 | endif() 18 | -------------------------------------------------------------------------------- /cpp/src/DeepAssert.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define deep_assert(condition, ...) \ 9 | do \ 10 | { \ 11 | deep_detailed_throw_if( \ 12 | !(condition), \ 13 | (#condition), \ 14 | __FILE__, \ 15 | __LINE__, \ 16 | __VA_ARGS__); \ 17 | } \ 18 | while(0) // 19 | 20 | #define deep_throw_if(condition, ...) \ 21 | do \ 22 | { \ 23 | deep_detailed_throw_if( \ 24 | (condition), \ 25 | (#condition), \ 26 | __FILE__, \ 27 | __LINE__, \ 28 | __VA_ARGS__); \ 29 | } \ 30 | while(0) // 31 | 32 | template 33 | static void deep_detailed_throw_if( 34 | const bool condition, 35 | const char* expression, 36 | const char* file, 37 | const int line, 38 | const char* message, 39 | Args&&... args) 40 | { 41 | if (!condition) 42 | { 43 | return; 44 | } 45 | 46 | auto format = [](const char* format, ...) 47 | { 48 | va_list args; 49 | va_list temp; 50 | 51 | va_start(args, format); 52 | 53 | va_copy(temp, args); 54 | auto size = std::vsnprintf(nullptr, 0, format, temp); 55 | va_end(temp); 56 | 57 | std::string buffer(size + 1, '\0'); 58 | std::vsnprintf(buffer.data(), buffer.size(), format, args); 59 | if (!buffer.empty()) { buffer.pop_back(); } 60 | 61 | va_end(args); 62 | 63 | return buffer; 64 | }; 65 | 66 | auto filename = std::filesystem::path(file).filename(); 67 | 68 | throw std::runtime_error( 69 | format(message, std::forward(args)...) 
+ 70 | format(" (%s in %s at %d)", expression, filename.c_str(), line)); 71 | } 72 | -------------------------------------------------------------------------------- /cpp/src/DeepFilter.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | DeepFilter::DeepFilter() 4 | { 5 | } 6 | 7 | void DeepFilter::operator()( 8 | const std::span> input, 9 | const std::span> output) 10 | { 11 | inference(); 12 | } 13 | -------------------------------------------------------------------------------- /cpp/src/DeepFilter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | class DeepFilter final : public DeepFilterInference 9 | { 10 | 11 | public: 12 | 13 | DeepFilter(); 14 | 15 | void operator()( 16 | const std::span> input, 17 | const std::span> output); 18 | 19 | private: 20 | 21 | }; 22 | -------------------------------------------------------------------------------- /cpp/src/DeepFilterInference.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | DeepFilterInference::DeepFilterInference() : 6 | tensors(get_tensors()), 7 | sessions(get_sessions()) 8 | { 9 | } 10 | 11 | void DeepFilterInference::inference() const 12 | { 13 | Ort::MemoryInfo cpu(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)); 14 | Ort::RunOptions opt; 15 | 16 | auto tensor = [&](const std::string& name) 17 | { 18 | auto tensor = tensors.at(name); 19 | 20 | return Ort::Value::CreateTensor( 21 | cpu, 22 | tensor->value.data(), 23 | tensor->value.size(), 24 | tensor->shape.data(), 25 | tensor->shape.size()); 26 | }; 27 | 28 | auto session = [&](const std::string& name, 29 | const std::initializer_list& inputs, 30 | const std::initializer_list& outputs) 31 | { 32 | std::vector input_names; 33 | std::vector output_names; 34 | std::vector input_values; 35 | std::vector 
output_values; 36 | 37 | input_names.reserve(inputs.size()); 38 | output_names.reserve(outputs.size()); 39 | input_values.reserve(inputs.size()); 40 | output_values.reserve(outputs.size()); 41 | 42 | for (const std::string& input : inputs) 43 | { 44 | input_names.emplace_back(input.c_str()); 45 | } 46 | 47 | for (const std::string& output : outputs) 48 | { 49 | output_names.emplace_back(output.c_str()); 50 | } 51 | 52 | for (const std::string& input : inputs) 53 | { 54 | input_values.emplace_back(tensor(input)); 55 | } 56 | 57 | for (const std::string& output : outputs) 58 | { 59 | output_values.emplace_back(tensor(output)); 60 | } 61 | 62 | sessions.at(name)->Run( 63 | opt, 64 | input_names.data(), 65 | input_values.data(), 66 | inputs.size(), 67 | output_names.data(), 68 | output_values.data(), 69 | outputs.size()); 70 | }; 71 | 72 | session("enc", { "feat_erb", "feat_spec" }, { "e0", "e1", "e2", "e3", "c0", "emb" }); 73 | session("erb_dec", { "e0", "e1", "e2", "e3", "emb" }, { "m" }); 74 | session("df_dec", { "c0", "emb" }, { "coefs" }); 75 | } 76 | 77 | std::string DeepFilterInference::probe() const 78 | { 79 | auto session2str = [](const std::map>& sessions, const std::string name) 80 | { 81 | auto type2str = [](ONNXTensorElementDataType type) 82 | { 83 | switch (type) 84 | { 85 | case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: 86 | return std::string("float"); 87 | default: 88 | return std::string("not float"); 89 | } 90 | }; 91 | 92 | auto shape2str = [](std::vector shape) 93 | { 94 | auto value2str = [](int64_t value) 95 | { 96 | switch (value) 97 | { 98 | case -1: 99 | return std::string("*"); 100 | default: 101 | return std::to_string(value); 102 | } 103 | }; 104 | 105 | std::ostringstream result; 106 | 107 | result << "("; 108 | if (!shape.empty()) 109 | result << value2str(shape.front()); 110 | for (size_t i = 1; i < shape.size(); ++i) 111 | result << "," << value2str(shape.at(i)); 112 | result << ")"; 113 | 114 | return 
result.str(); 115 | }; 116 | 117 | auto shape2infer = [](std::vector shape, int64_t infervalue) 118 | { 119 | std::vector values; 120 | 121 | for (int64_t value : shape) 122 | { 123 | values.push_back((value < 0) ? infervalue : value); 124 | } 125 | 126 | return values; 127 | }; 128 | 129 | std::ostringstream result; 130 | 131 | Ort::AllocatorWithDefaultOptions allocator; 132 | Ort::RunOptions opt; 133 | 134 | std::shared_ptr session(sessions.at(name)); 135 | std::vector inputs(session->GetInputCount()); 136 | std::vector outputs(session->GetOutputCount()); 137 | std::vector input_names(session->GetInputCount()); 138 | std::vector output_names(session->GetOutputCount()); 139 | std::vector input_values, output_values; 140 | 141 | for (size_t i = 0; i < inputs.size(); ++i) 142 | { 143 | inputs.at(i) = session->GetInputNameAllocated(i, allocator).get(); 144 | input_names.at(i) = inputs.at(i).c_str(); 145 | } 146 | 147 | for (size_t i = 0; i < outputs.size(); ++i) 148 | { 149 | outputs.at(i) = session->GetOutputNameAllocated(i, allocator).get(); 150 | output_names.at(i) = outputs.at(i).c_str(); 151 | } 152 | 153 | result << name << " inputs" << std::endl; 154 | 155 | for (size_t i = 0; i < inputs.size(); ++i) 156 | { 157 | auto info = session->GetInputTypeInfo(i); 158 | auto name = std::string(session->GetInputNameAllocated(i, allocator).get()); 159 | auto type = info.GetTensorTypeAndShapeInfo().GetElementType(); 160 | auto shape = info.GetTensorTypeAndShapeInfo().GetShape(); 161 | auto infer = shape2infer(shape, 1); 162 | 163 | result 164 | << (i+1) << ") " << name << " " 165 | << type2str(type) << " " 166 | << shape2str(shape) << " -> " 167 | << shape2str(infer) 168 | << std::endl; 169 | 170 | input_values.emplace_back( 171 | Ort::Value::CreateTensor( 172 | allocator, 173 | infer.data(), 174 | infer.size(), 175 | type)); 176 | } 177 | 178 | output_values = sessions.at(name)->Run( 179 | opt, 180 | input_names.data(), 181 | input_values.data(), 182 | inputs.size(), 183 
| output_names.data(), 184 | outputs.size()); 185 | 186 | result << name << " outputs" << std::endl; 187 | 188 | for (size_t i = 0; i < outputs.size(); ++i) 189 | { 190 | auto info = session->GetOutputTypeInfo(i); 191 | auto name = std::string(session->GetOutputNameAllocated(i, allocator).get()); 192 | auto type = info.GetTensorTypeAndShapeInfo().GetElementType(); 193 | auto shape = info.GetTensorTypeAndShapeInfo().GetShape(); 194 | auto infer = output_values.at(i).GetTensorTypeAndShapeInfo().GetShape(); 195 | 196 | result 197 | << (i+1) << ") " 198 | << name << " " 199 | << type2str(type) << " " 200 | << shape2str(shape) << " -> " 201 | << shape2str(infer) 202 | << std::endl; 203 | } 204 | 205 | return result.str(); 206 | }; 207 | 208 | return 209 | session2str(sessions, "enc") + 210 | session2str(sessions, "erb_dec") + 211 | session2str(sessions, "df_dec"); 212 | } 213 | 214 | std::map> DeepFilterInference::get_tensors() 215 | { 216 | const std::map> shapes 217 | { 218 | { "feat_erb", {1,1,1,32} }, 219 | { "feat_spec", {1,2,1,96} }, 220 | { "e0", {1,64,1,32} }, 221 | { "e1", {1,64,1,16} }, 222 | { "e2", {1,64,1,8} }, 223 | { "e3", {1,64,1,8} }, 224 | { "c0", {1,64,1,96} }, 225 | { "emb", {1,1,512} }, 226 | { "m", {1,1,1,32} }, 227 | { "coefs", {1,1,96,10} } 228 | }; 229 | 230 | auto size = [](const std::vector& shape) 231 | { 232 | size_t product = 1; 233 | 234 | for (auto value : shape) 235 | { 236 | deep_assert(value >= 0, 237 | "Negative tensor shape element!"); 238 | 239 | product *= static_cast(std::abs(value)); 240 | } 241 | 242 | return product; 243 | }; 244 | 245 | std::map> tensors; 246 | 247 | for (const auto& [name, shape] : shapes) 248 | { 249 | auto tensor = new DeepFilterInference::Tensor(); 250 | 251 | tensor->shape = shape; 252 | tensor->value.resize(size(shape)); 253 | 254 | tensors[name] = std::shared_ptr(tensor); 255 | } 256 | 257 | return tensors; 258 | } 259 | 260 | std::map> DeepFilterInference::get_sessions() 261 | { 262 | const 
std::filesystem::path onnxpath = DeepFilterNetOnnx; 263 | 264 | const std::map onnxpaths 265 | { 266 | { "enc", onnxpath / "enc.onnx" }, 267 | { "erb_dec", onnxpath / "erb_dec.onnx" }, 268 | { "df_dec", onnxpath / "df_dec.onnx" } 269 | }; 270 | 271 | Ort::Env env; 272 | Ort::SessionOptions opt; 273 | 274 | auto session = [&](const std::string& name) 275 | { 276 | return std::make_shared( 277 | env, onnxpaths.at(name).c_str(), opt); 278 | }; 279 | 280 | return 281 | { 282 | { "enc", session("enc") }, 283 | { "erb_dec", session("erb_dec") }, 284 | { "df_dec", session("df_dec") } 285 | }; 286 | } 287 | -------------------------------------------------------------------------------- /cpp/src/DeepFilterInference.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | 11 | class DeepFilterInference 12 | { 13 | 14 | public: 15 | 16 | DeepFilterInference(); 17 | virtual ~DeepFilterInference() = default; 18 | 19 | float samplerate() const { return 48000; } 20 | size_t framesize() const { return 960; } 21 | size_t hopsize() const { return 480; } 22 | size_t erbsize() const { return 32; } 23 | size_t cpxsize() const { return 96; } 24 | 25 | std::string probe() const; 26 | 27 | protected: 28 | 29 | struct Tensor 30 | { 31 | std::vector shape; 32 | std::vector value; 33 | }; 34 | 35 | const std::map> tensors; 36 | const std::map> sessions; 37 | 38 | void inference() const; 39 | 40 | private: 41 | 42 | static std::map> get_tensors(); 43 | static std::map> get_sessions(); 44 | 45 | }; 46 | -------------------------------------------------------------------------------- /cpp/src/FFT.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | 7 | class FFT final : public stftpitchshift::FFT 8 | { 9 | 10 | public: 11 | 12 | void fft(const std::span frame, 13 | const 
std::span> dft) override 14 | { 15 | pocketfft::r2c( 16 | {frame.size()}, 17 | {sizeof(float)}, 18 | {sizeof(std::complex)}, 19 | 0, 20 | true, 21 | frame.data(), 22 | dft.data(), 23 | float(1) / frame.size()); 24 | } 25 | 26 | void fft(const std::span frame, 27 | const std::span> dft) override 28 | { 29 | pocketfft::r2c( 30 | {frame.size()}, 31 | {sizeof(double)}, 32 | {sizeof(std::complex)}, 33 | 0, 34 | true, 35 | frame.data(), 36 | dft.data(), 37 | double(1) / frame.size()); 38 | } 39 | 40 | void ifft(const std::span> dft, 41 | const std::span frame) override 42 | { 43 | pocketfft::c2r( 44 | {frame.size()}, 45 | {sizeof(std::complex)}, 46 | {sizeof(float)}, 47 | 0, 48 | false, 49 | dft.data(), 50 | frame.data(), 51 | float(1)); 52 | } 53 | 54 | void ifft(const std::span> dft, 55 | const std::span frame) override 56 | { 57 | pocketfft::c2r( 58 | {frame.size()}, 59 | {sizeof(std::complex)}, 60 | {sizeof(double)}, 61 | 0, 62 | false, 63 | dft.data(), 64 | frame.data(), 65 | double(1)); 66 | } 67 | 68 | }; 69 | -------------------------------------------------------------------------------- /cpp/src/STFT.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | using STFT = stftpitchshift::STFT; 6 | -------------------------------------------------------------------------------- /cpp/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | int main() 8 | { 9 | DeepFilter filter; 10 | 11 | // std::cout << filter.probe() << std::endl; 12 | 13 | auto samplerate = filter.samplerate(); 14 | auto framesize = filter.framesize(); 15 | auto hopsize = filter.hopsize(); 16 | auto chronometry = true; 17 | auto samples = static_cast(10 * samplerate); 18 | 19 | auto fft = std::make_shared(); 20 | auto stft = std::make_shared(fft, framesize, hopsize, chronometry); 21 | 22 | std::vector x(samples); 23 | 
std::vector y(samples); 24 | 25 | (*stft)(x, y, [&](std::span> dft) 26 | { 27 | filter(dft, dft); 28 | }); 29 | } 30 | -------------------------------------------------------------------------------- /cpp/src/xt.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | // #include 14 | #include 15 | #include 16 | 17 | namespace xt 18 | { 19 | template 20 | inline xt::xarray adapt_vector(const std::span span) 21 | { 22 | return xt::adapt( 23 | span.data(), 24 | span.size(), 25 | xt::no_ownership(), 26 | std::vector{span.size()}); 27 | } 28 | 29 | template 30 | inline xt::xarray adapt_matrix(const std::span span, const std::initializer_list& shape) 31 | { 32 | const auto size = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies{}); 33 | 34 | if (size != span.size()) 35 | { 36 | throw std::runtime_error("Invalid matrix shape!"); 37 | } 38 | 39 | return xt::adapt( 40 | span.data(), 41 | span.size(), 42 | xt::no_ownership(), 43 | std::vector{shape}); 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /cpx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class CPX: 5 | 6 | def __init__(self, cpxsize: int, alpha: float): 7 | 8 | self.cpxsize = cpxsize 9 | self.alpha = alpha 10 | 11 | def __call__(self, dfts): 12 | 13 | y = np.copy(dfts[..., :self.cpxsize]) 14 | 15 | # TODO ISSUE #100 16 | mean = np.full(y.shape[-1], y[..., 0, :]) 17 | alpha = self.alpha 18 | for i in range(y.shape[-2]): 19 | mean = np.absolute(y[..., i, :]) * (1 - alpha) + mean * alpha # orig: norm 20 | y[..., i, :] /= np.sqrt(mean) 21 | 22 | return y 23 | -------------------------------------------------------------------------------- /erb.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def hz2erb(hz): 5 | """ 6 | Converts frequency value in Hz to human-defined ERB band index, 7 | using the formula of Glasberg and Moore. 8 | """ 9 | return 9.265 * np.log(1 + hz / (24.7 * 9.265)) 10 | 11 | def erb2hz(erb): 12 | """ 13 | Converts human-defined ERB band index to frequency value in Hz, 14 | using the formula of Glasberg and Moore. 15 | """ 16 | return 24.7 * 9.265 * (np.exp(erb / 9.265) - 1) 17 | 18 | 19 | class ERB: 20 | 21 | def __init__(self, samplerate: int, fftsize: int, erbsize: int, minwidth: int, alpha: float): 22 | 23 | self.samplerate = samplerate 24 | self.fftsize = fftsize 25 | self.erbsize = erbsize 26 | self.minwidth = minwidth 27 | self.alpha = alpha 28 | 29 | self.widths = ERB.get_band_widths(samplerate, fftsize, erbsize, minwidth) 30 | self.weights = ERB.get_band_weights(samplerate, self.widths) 31 | 32 | def __call__(self, dfts): 33 | 34 | x = np.abs(dfts) # TODO try np.absolute with 10*log10 instead 35 | y = np.matmul(x, self.weights) 36 | y = 20 * np.log10(y + np.finfo(dfts.dtype).eps) 37 | 38 | # TODO ISSUE #100 39 | mean = np.full(y.shape[-1], y[..., 0, :]) 40 | alpha = self.alpha 41 | for i in range(y.shape[-2]): 42 | mean = y[..., i, :] * (1 - alpha) + mean * alpha 43 | y[..., i, :] -= mean 44 | y /= 40 45 | 46 | return y 47 | 48 | @staticmethod 49 | def get_band_widths(samplerate: int, fftsize: int, erbsize: int, minwidth: int): 50 | 51 | dftsize = fftsize / 2 + 1 52 | nyquist = samplerate / 2 53 | bandwidth = samplerate / fftsize 54 | 55 | erbmin = hz2erb(0) 56 | erbmax = hz2erb(nyquist) 57 | erbinc = (erbmax - erbmin) / erbsize 58 | 59 | bands = np.arange(1, erbsize + 1) 60 | freqs = erb2hz(erbmin + erbinc * bands) 61 | widths = np.round(freqs / bandwidth).astype(int) 62 | 63 | prev = 0 64 | over = 0 65 | 66 | for i in range(erbsize): 67 | 68 | next = widths[i] 69 | width = next - prev - over 70 | prev = next 
71 | 72 | over = max(minwidth - width, 0) 73 | width = max(minwidth, width) 74 | 75 | widths[i] = width 76 | 77 | widths[erbsize - 1] += 1 78 | assert np.sum(widths) == dftsize 79 | 80 | return widths 81 | 82 | @staticmethod 83 | def get_band_weights(samplerate: int, widths: np.ndarray, normalized: bool = True, inverse: bool = False): 84 | 85 | n_freqs = int(np.sum(widths)) 86 | all_freqs = np.linspace(0, samplerate // 2, n_freqs + 1)[:-1] 87 | 88 | b_pts = np.cumsum([0] + widths.tolist()).astype(int)[:-1] 89 | 90 | fb = np.zeros((all_freqs.shape[0], b_pts.shape[0])) 91 | 92 | for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())): 93 | fb[b : b + w, i] = 1 94 | 95 | if inverse: 96 | fb = fb.t() 97 | if not normalized: 98 | fb /= np.sum(fb, axis=1, keepdim=True) 99 | else: 100 | if normalized: 101 | fb /= np.sum(fb, axis=0) 102 | 103 | return fb 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | deepfilternet 2 | matplotlib 3 | numpy 4 | onnx 5 | onnxruntime 6 | resampy 7 | sdft 8 | soundfile 9 | torch 10 | torchaudio 11 | -------------------------------------------------------------------------------- /spectrum.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from numpy.typing import ArrayLike 3 | 4 | import matplotlib.pyplot as plot 5 | import numpy as np 6 | 7 | 8 | def spectrogram(x: ArrayLike, *, 9 | name: str = 'Spectrogram', 10 | xlim: Tuple[float, float] = (None, None), 11 | ylim: Tuple[float, float] = (None, None), 12 | clim: Tuple[float, float] = (-120, 0)): 13 | 14 | if not np.any(np.iscomplex(x)): 15 | 16 | x = np.atleast_1d(x) 17 | assert x.ndim == 1 18 | 19 | # X = STFT(framesize=..., hopsize=..., window=...).stft(x) 20 | # assert X.ndim == 2 21 | raise NotImplementedError('TODO STFT') 22 | 23 | else: 24 | 25 | X = np.atleast_2d(x) 26 | 
assert X.ndim == 2 27 | 28 | epsilon = np.finfo(X.dtype).eps 29 | 30 | spectrum = np.abs(X) 31 | spectrum = 20 * np.log10(spectrum + epsilon) 32 | 33 | timestamps = np.arange(spectrum.shape[0]) # TODO np.arange(len(X)) * hopsize / samplerate 34 | frequencies = np.arange(spectrum.shape[1]) # TODO np.fft.rfftfreq(framesize, 1 / samplerate) 35 | 36 | extent = (timestamps[0], timestamps[-1], frequencies[0], frequencies[-1]) 37 | args = dict(aspect='auto', cmap='inferno', extent=extent, interpolation='nearest', origin='lower') 38 | 39 | plot.figure(name) 40 | plot.imshow(spectrum.T, **args) 41 | colorbar = plot.colorbar() 42 | 43 | plot.xlabel('time [s]') 44 | plot.ylabel('frequency [Hz]') 45 | colorbar.set_label('magnitude [dB]') 46 | 47 | plot.xlim(*xlim) 48 | plot.ylim(*ylim) 49 | plot.clim(*clim) 50 | 51 | return plot 52 | 53 | 54 | def erbgram(x: ArrayLike, *, 55 | name: str = 'ERB', 56 | xlim: Tuple[float, float] = (None, None), 57 | ylim: Tuple[float, float] = (None, None), 58 | clim: Tuple[float, float] = (None, None)): 59 | 60 | X = np.atleast_2d(x) 61 | assert X.ndim == 2 62 | 63 | spectrum = np.abs(X) 64 | 65 | timestamps = np.arange(spectrum.shape[0]) # TODO np.arange(len(X)) * hopsize / samplerate 66 | frequencies = np.arange(spectrum.shape[1]) # TODO np.fft.rfftfreq(framesize, 1 / samplerate) 67 | 68 | extent = (timestamps[0], timestamps[-1], frequencies[0], frequencies[-1]) 69 | args = dict(aspect='auto', cmap='inferno', extent=extent, interpolation='nearest', origin='lower') 70 | 71 | plot.figure(name) 72 | plot.imshow(spectrum.T, **args) 73 | colorbar = plot.colorbar() 74 | 75 | plot.xlabel('time [s]') 76 | plot.ylabel('frequency [Hz]') 77 | colorbar.set_label('magnitude [dB]') 78 | 79 | plot.xlim(*xlim) 80 | plot.ylim(*ylim) 81 | plot.clim(*clim) 82 | 83 | return plot 84 | -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | from 
class STFT:
    """
    Short-Time Fourier Transform (STFT).

    The window is applied in both directions, once in `fft` (analysis)
    and once in `ifft` (synthesis); `istft` compensates with a constant gain.
    """

    def __init__(self, framesize: int, *, hopsize: Union[int, None] = None, padsize: int = 0, center: bool = False, window: Union[bool, str, None] = True):
        """
        Create a new STFT plan.

        Parameters
        ----------
        framesize: int
            Time domain segment length in samples.
        hopsize: int, optional
            Distance between consecutive segments in samples.
            Defaults to `framesize // 4`, but at least 1.
        padsize: int, optional
            Number of zeros to pad the segments with.
        center: bool, optional
            Circularly shift each (padded) segment by half the frame size before the FFT.
        window: bool, str, none, optional
            Window function name or a boolean, to enable the default hann window.
            Currently, only hann and rect window functions are supported.
            An unknown name raises a KeyError.
        """

        assert framesize > 0

        # Resolve the default before validating it: comparing None with an
        # integer raises a TypeError, which made the documented default unusable.
        if hopsize is None:
            hopsize = max(framesize // 4, 1)

        assert hopsize > 0
        assert padsize >= 0

        if False:  # TODO power of two warnings

            def is_power_of_two(n: int) -> bool:
                return (n != 0) and (n & (n-1) == 0)

            if not is_power_of_two(framesize):
                warnings.warn('The frame size should be a power of two for optimal performance!', UserWarning)

            if not is_power_of_two(framesize + padsize):
                warnings.warn('The sum of frame and pad sizes should be a power of two for optimal performance!', UserWarning)

        # Keys are the lower-cased str() of the `window` argument,
        # which is how True/False/None end up as 'true'/'false'/'none'.
        windows = {
            'rect':  lambda n: np.ones(n),
            'none':  lambda n: np.ones(n),
            'false': lambda n: np.ones(n),
            'true':  lambda n: np.hanning(n+1)[:-1],
            'hann':  lambda n: np.hanning(n+1)[:-1],
        }

        self.framesize = framesize
        self.hopsize = hopsize
        self.padsize = padsize
        self.center = center
        self.window = windows[str(window).lower()](self.framesize)

    def freqs(self, samplerate: Union[int, None] = None) -> NDArray:
        """
        Returns an array of DFT bin center frequency values in hertz.
        If no sample rate is specified, then the frequency unit is cycles/sample.
        """

        return np.fft.rfftfreq(self.framesize + self.padsize, 1 / (samplerate or 1))

    def stft(self, samples: ArrayLike) -> NDArray:
        """
        Estimates the DFT matrix for the given sample array.

        Parameters
        ----------
        samples: ndarray
            Array of time domain signal values.

        Returns
        -------
        dfts: ndarray
            Estimated DFT matrix of shape (frames, frequencies).
        """

        samples = np.atleast_1d(samples)

        assert samples.ndim == 1, f'Expected 1D array (samples,), got {samples.shape}!'

        # Every possible window position, thinned out to one frame per hop.
        frames = sliding_window_view(samples, self.framesize, writeable=False)[::self.hopsize]
        dfts = self.fft(frames)

        return dfts

    def istft(self, dfts: ArrayLike) -> NDArray:
        """
        Synthesizes the sample array from the given DFT matrix.

        Parameters
        ----------
        dfts: ndarray
            DFT matrix of shape (frames, frequencies).

        Returns
        -------
        samples: ndarray
            Synthesized array of time domain signal values,
            of length `frames * hopsize + framesize`.
        """

        dfts = np.atleast_2d(dfts)

        assert dfts.ndim == 2, f'Expected 2D array (samples, frequencies), got {dfts.shape}!'

        # Compensates the squared window that is summed up by the overlap-add.
        gain = self.hopsize / np.sum(np.square(self.window))
        size = dfts.shape[0] * self.hopsize + self.framesize

        samples = np.zeros(size, float)

        # Writable views into `samples`, so the in-place add below is the overlap-add.
        frames0 = sliding_window_view(samples, self.framesize, writeable=True)[::self.hopsize]
        frames1 = self.ifft(dfts) * gain

        for i in range(min(len(frames0), len(frames1))):

            frames0[i] += frames1[i]

        return samples

    def fft(self, data: ArrayLike) -> NDArray:
        """
        Performs the forward FFT on a 2D array of frames (frames, framesize).
        """

        assert len(np.shape(data)) == 2

        data = np.atleast_2d(data) * self.window

        if self.padsize:

            data = np.pad(data, ((0, 0), (0, self.padsize)))

        if self.center:

            # Exactly undone by the shift in ifft, also for odd frame sizes
            # (framesize // -2 rounds away from zero and would not match).
            data = np.roll(data, -(self.framesize // 2), axis=-1)

        return np.fft.rfft(data, axis=-1, norm='forward')

    def ifft(self, data: ArrayLike) -> NDArray:
        """
        Performs the backward FFT on a 2D array of DFTs (frames, frequencies).
        """

        assert len(np.shape(data)) == 2

        # The output length must be given explicitly, otherwise irfft assumes
        # an even length and breaks odd frame (plus pad) sizes.
        data = np.fft.irfft(data, n=self.framesize + self.padsize, axis=-1, norm='forward')

        if self.center:

            data = np.roll(data, self.framesize // 2, axis=-1)

        if self.padsize:

            data = data[..., :self.framesize]

        data *= self.window

        return data
def filter0(model, state, x):
    """
    Reference path: hand the samples to the stock `enhance` routine.

    `x` is converted to a float32 tensor first; the result is detached,
    moved to the CPU and returned as a NumPy array.
    """

    samples = torch.from_numpy(x.astype(np.float32))
    enhanced = enhance(model, state, samples)

    return enhanced.detach().cpu().numpy()


def filter1(model, state, x):
    """
    Experimental path: like `filter0`, but the ERB and complex features fed
    to the model are recomputed by the local `ERB` and `CPX` classes.

    The `if False:` sections are switchable debug views; each one of them
    shows its plots and then terminates the process.
    """

    cpu = torch.device('cpu')

    if hasattr(model, 'reset_h0'):
        bs = x.shape[0]
        print(f'reset_h0 bs={bs}')
        model.reset_h0(batch_size=bs, device=cpu)
        assert False, 'TODO reset_h0'

    params = ModelParams()

    param_sr = params.sr  # 48000
    param_fft_size = params.fft_size  # 960
    param_hop_size = params.hop_size  # 480
    param_fft_bins = params.fft_size // 2 + 1  # 481
    param_erb_bins = params.nb_erb  # 32
    param_erb_min_width = params.min_nb_freqs  # 2
    param_deep_filter_bins = params.nb_df  # 96
    param_norm_alpha = get_norm_alpha(False)  # 0.99

    # The model and the state have to agree with the parameters read above.
    assert getattr(model, 'freq_bins', param_fft_bins) == param_fft_bins
    assert getattr(model, 'erb_bins', param_erb_bins) == param_erb_bins
    assert getattr(model, 'nb_df', getattr(model, 'df_bins', param_deep_filter_bins)) == param_deep_filter_bins
    assert state.sr() == param_sr
    assert len(state.erb_widths()) == param_erb_bins

    print(dict(
        sr=param_sr,
        fft_size=param_fft_size,
        hop_size=param_hop_size,
        fft_bins=param_fft_bins,
        erb_bins=param_erb_bins,
        erb_min_width=param_erb_min_width,
        deep_filter_bins=param_deep_filter_bins,
        norm_alpha=param_norm_alpha))
    print()

    stft = STFT(
        framesize=param_fft_size,
        hopsize=param_hop_size,
        window='hann')

    erb = ERB(
        samplerate=param_sr,
        fftsize=param_fft_size,
        erbsize=param_erb_bins,
        minwidth=param_erb_min_width,
        alpha=param_norm_alpha)

    cpx = CPX(
        cpxsize=param_deep_filter_bins,
        alpha=param_norm_alpha)

    spec, erb_feat, spec_feat = df_features(
        torch.from_numpy(x.astype(np.float32)),
        state,
        param_deep_filter_bins,
        device=cpu)

    for label, tensor in (('spec', spec), ('erb_feat', erb_feat), ('spec_feat', spec_feat)):
        print(label, tensor.shape, tensor.dtype)
    print()

    if False:  # debug: spectrum from df_features vs. local STFT

        dfts0 = np.squeeze(torch.view_as_complex(spec).numpy())
        print(dfts0.shape, dfts0.dtype)
        spectrogram(dfts0, name='dfts0')

        dfts1 = stft.stft(x[0])
        print(dfts1.shape, dfts1.dtype)
        spectrogram(dfts1, name='dfts1')

        plot.show()
        exit()

    if False:  # debug: ERB widths of the state vs. local ERB, plus its weights

        x = state.erb_widths()
        y = erb.widths
        print(x)
        print(y)
        assert np.allclose(x, y)

        weights = erb.weights
        print(weights.shape)
        plot.figure()
        for i in range(weights.shape[-1]):
            plot.plot(weights[..., i])

        plot.show()
        exit()

    if False:  # debug: ERB features from df_features vs. local ERB

        x = torch.view_as_complex(spec).numpy()
        y = erb(x)

        foo = np.squeeze(erb_feat.numpy())
        bar = np.squeeze(y)

        erbgram(foo, name='foo', clim=(0, 1))
        erbgram(bar, name='bar', clim=(0, 1))

        plot.show()
        exit()

    if False:  # debug: complex features from df_features vs. local CPX

        x = torch.view_as_complex(spec_feat).numpy()
        x = np.squeeze(np.abs(x))

        y = torch.view_as_complex(spec).numpy()
        y = cpx(y)
        y = np.squeeze(np.abs(y))

        erbgram(x, name='x', clim=(0, 2))
        erbgram(y, name='y', clim=(0, 2))

        plot.show()
        exit()

    # Overwrite both features returned by df_features with the results
    # of the local ERB and CPX instances.
    x = torch.view_as_complex(spec).numpy()
    erb_feat = torch.from_numpy(erb(x).astype(np.float32))

    x = torch.view_as_complex(spec).numpy()
    y = cpx(x)
    y = np.stack((y.real, y.imag), axis=-1)
    spec_feat = torch.from_numpy(y.astype(np.float32))

    output = model(spec, erb_feat, spec_feat)  # orig: spec.clone()

    enhanced = output[0].cpu()
    print('enhanced', enhanced.shape, enhanced.dtype)

    enhanced = enhanced.squeeze(1)
    print('enhanced squeeze', enhanced.shape, enhanced.dtype)

    enhanced = torch.view_as_complex(enhanced)  # orig: as_complex
    print('enhanced complex', enhanced.shape, enhanced.dtype)
    print()

    return state.synthesis(enhanced.detach().numpy())


if __name__ == '__main__':

    model, state, _ = init_df()

    model.eval()

    x, sr = read('x.wav', state.sr())

    # Flip the constant to run the reference path instead.
    y = (filter1(model, state, x)
         if True else
         filter0(model, state, x))

    write('y.wav', sr, y)
# Directory the exported ONNX graphs are read from.
models = os.path.join('build', 'DeepFilterNet3', 'tmp', 'export')

enc = ort.InferenceSession(os.path.join(models, 'enc.onnx'))
erb_dec = ort.InferenceSession(os.path.join(models, 'erb_dec.onnx'))
df_dec = ort.InferenceSession(os.path.join(models, 'df_dec.onnx'))

# Print name, shape and type of every input and output of each session.
for name, model in {'enc': enc, 'erb_dec': erb_dec, 'df_dec': df_dec}.items():

    print(name)

    for title, nodes in (('INPUTS', model.get_inputs), ('OUTPUTS', model.get_outputs)):

        print(title)
        for node in nodes():
            print(node.name, node.shape, node.type)

    print()

# feat_erb = np.zeros((1, 1, 1, 32), np.float32)
# feat_spec = np.zeros((1, 2, 1, 96), np.float32)

# enc_output = enc.run(
#     ['e0', 'e1', 'e2', 'e3', 'c0', 'emb'],
#     {'feat_erb': feat_erb, 'feat_spec': feat_spec})

# print(type(enc_output))