├── .github
│   └── workflows
│       ├── linux-arm64.yml
│       ├── linux.yml
│       └── windows.yml
├── CMakeLists.txt
├── LICENSE
├── README.md
└── source
    ├── config.h.in
    └── wnnm.cpp

/.github/workflows/linux-arm64.yml:
--------------------------------------------------------------------------------
1 | name: Build (Linux, ARM64)
2 | 
3 | on:
4 |   push:
5 |     paths:
6 |       - 'source/*'
7 |       - 'CMakeLists.txt'
8 |       - '.github/workflows/linux-arm64.yml'
9 |   workflow_dispatch:
10 | 
11 | jobs:
12 |   build-linux:
13 |     runs-on: ubuntu-24.04-arm
14 |     steps:
15 |     - name: Checkout repo
16 |       uses: actions/checkout@v4
17 |       with:
18 |         fetch-depth: 0
19 | 
20 |     - name: Setup ArmPL
21 |       run: |
22 |         wget -q -O armpl.tar https://developer.arm.com/-/cdn-downloads/permalink/Arm-Performance-Libraries/Version_24.10/arm-performance-libraries_24.10_deb_gcc.tar
23 |         tar -xf armpl.tar
24 |         sudo arm-performance-libraries_24.10_deb/arm-performance-libraries_24.10_deb.sh -a
25 |         /opt/arm/armpl_24.10_gcc/bin/armpl-info
26 | 
27 |     - name: Download VapourSynth headers
28 |       run: |
29 |         wget -q -O vs.zip https://github.com/vapoursynth/vapoursynth/archive/refs/tags/R57.zip
30 |         unzip -q vs.zip
31 |         mv vapoursynth*/ vapoursynth
32 | 
33 |     - name: Setup Ninja
34 |       run: pip install ninja
35 | 
36 |     - name: Configure
37 |       run: cmake -S . -B build -G Ninja -LA
38 |         -D CMAKE_BUILD_TYPE=Release
39 |         -D CMAKE_CXX_FLAGS="-Wall -ffast-math"
40 |         -D ARMPL_HOME=/opt/arm/armpl_24.10_gcc
41 |         -D VS_INCLUDE_DIR="`pwd`/vapoursynth/include"
42 | 
43 |     - name: Build
44 |       run: cmake --build build --verbose
45 | 
46 |     - name: Install
47 |       run: cmake --install build --prefix install
48 | 
49 |     - name: Upload
50 |       uses: actions/upload-artifact@v4
51 |       if: true
52 |       with:
53 |         name: wnnm-linux-armv8
54 |         path: install/lib/*.so
55 | 
56 |     - name: Configure (SVE)
57 |       run: cmake -S . -B build_sve -G Ninja -LA
58 |         -D CMAKE_BUILD_TYPE=Release
59 |         -D CMAKE_CXX_FLAGS="-Wall -ffast-math -march=armv8-a+sve"
60 |         -D ARMPL_HOME=/opt/arm/armpl_24.10_gcc
61 |         -D VS_INCLUDE_DIR="`pwd`/vapoursynth/include"
62 | 
63 |     - name: Build (SVE)
64 |       run: cmake --build build_sve --verbose
65 | 
66 |     - name: Install (SVE)
67 |       run: cmake --install build_sve --prefix install_sve
68 | 
69 |     - name: Upload (SVE)
70 |       uses: actions/upload-artifact@v4
71 |       if: true
72 |       with:
73 |         name: wnnm-linux-armv8+sve
74 |         path: install_sve/lib/*.so
75 | 

--------------------------------------------------------------------------------
/.github/workflows/linux.yml:
--------------------------------------------------------------------------------
1 | name: Build (Linux)
2 | 
3 | on:
4 |   push:
5 |     paths:
6 |       - 'source/*'
7 |       - 'CMakeLists.txt'
8 |       - '.github/workflows/linux.yml'
9 |   workflow_dispatch:
10 | 
11 | jobs:
12 |   build-linux:
13 |     runs-on: ubuntu-24.04
14 |     steps:
15 |     - name: Checkout repo
16 |       uses: actions/checkout@v4
17 |       with:
18 |         fetch-depth: 0
19 | 
20 |     - name: Setup oneMKL
21 |       run: |
22 |         wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | sudo tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null
23 |         echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
24 |         sudo apt update
25 |         sudo apt install intel-oneapi-mkl-devel
26 | 
27 |     - name: Download VapourSynth headers
28 |       run: |
29 |         wget -q -O vs.zip https://github.com/vapoursynth/vapoursynth/archive/refs/tags/R57.zip
30 |         unzip -q vs.zip
31 |         mv vapoursynth*/ vapoursynth
32 | 
33 |     - name: Download Vector Class Library
34 |       run: |
35 |         wget -q -O vcl.zip https://github.com/vectorclass/version2/archive/refs/tags/v2.01.04.zip
36 |         unzip -q vcl.zip
37 |         mv version2*/ vectorclass
38 | 
39 |     - name: Setup Ninja
40 |       run: pip install ninja
41 | 
42 |     - name: Configure
43 |       run: cmake -S . -B build -G Ninja -LA
44 |         -D CMAKE_BUILD_TYPE=Release
45 |         -D CMAKE_CXX_COMPILER=g++
46 |         -D CMAKE_CXX_FLAGS="-Wall -mavx2 -mfma -ffast-math"
47 |         -D MKL_DIR=/opt/intel/oneapi/mkl/latest/lib/cmake/mkl
48 |         -D MKL_LINK=static -D MKL_THREADING=sequential -D MKL_INTERFACE=lp64
49 |         -D VCL_HOME="`pwd`/vectorclass"
50 |         -D VS_INCLUDE_DIR="`pwd`/vapoursynth/include"
51 |       env:
52 |         CXX: g++-11
53 | 
54 |     - name: Build
55 |       run: cmake --build build --verbose
56 | 
57 |     - name: Install
58 |       run: cmake --install build --prefix install
59 | 
60 |     - name: Upload
61 |       uses: actions/upload-artifact@v4
62 |       if: true
63 |       with:
64 |         name: wnnm-linux
65 |         path: install/lib/*.so
66 | 

--------------------------------------------------------------------------------
/.github/workflows/windows.yml:
--------------------------------------------------------------------------------
1 | name: Build (Windows)
2 | 
3 | on:
4 |   push:
5 |     paths:
6 |       - 'source/*'
7 |       - 'CMakeLists.txt'
8 |       - '.github/workflows/windows.yml'
9 |   workflow_dispatch:
10 |     inputs:
11 |       tag:
12 |         description: 'which tag to upload to'
13 |         default: ''
14 | 
15 | jobs:
16 |   build-windows:
17 |     runs-on: windows-2022
18 |     defaults:
19 |       run:
20 |         shell: cmd
21 | 
22 |     steps:
23 |     - name: Checkout repo
24 |       uses: actions/checkout@v4
25 |       with:
26 |         fetch-depth: 0
27 | 
28 |     - name: Setup MSVC
29 |       uses: ilammy/msvc-dev-cmd@v1
30 | 
31 |     - name: Cache oneMKL
32 |       id: cache-onemkl
33 |       uses: actions/cache@v4
34 |       with:
35 |         path: C:\Program Files (x86)\Intel\oneAPI\mkl
36 |         key: ${{ runner.os }}-onemkl-2022.2.0
37 | 
38 |     - name: Setup oneMKL
39 |       if: steps.cache-onemkl.outputs.cache-hit != 'true'
40 |       run: |
41 |         curl -s -o onemkl.exe -L https://registrationcenter-download.intel.com/akdlm/IRC_NAS/18899/w_onemkl_p_2022.2.0.9563_offline.exe
42 |         onemkl.exe -s -a --silent --eula accept
43 | 
44 |     - name: Download VapourSynth headers
45 |       run: |
46 |         curl -s -o vs.zip -L https://github.com/vapoursynth/vapoursynth/archive/refs/tags/R57.zip
47 |         unzip -q vs.zip
48 |         mv vapoursynth-*/ vapoursynth/
49 | 
50 |     - name: Download Vector Class Library
51 |       run: |
52 |         curl -s -o vcl.zip -L https://github.com/vectorclass/version2/archive/refs/tags/v2.01.04.zip
53 |         unzip -q vcl.zip
54 |         mv version2*/ vectorclass
55 | 
56 |     - name: Setup Ninja
57 |       run: pip install ninja
58 | 
59 |     - name: Setup LLVM
60 |       shell: bash
61 |       run: |
62 |         curl -s -o llvm-win64.exe -LJO https://github.com/llvm/llvm-project/releases/download/llvmorg-19.1.7/LLVM-19.1.7-win64.exe
63 |         7z x -ollvm llvm-win64.exe
64 | 
65 |     - name: Configure
66 |       shell: bash
67 |       run: cmake -S . -B build -G Ninja -LA
68 |         -D CMAKE_BUILD_TYPE=Release
69 |         -D CMAKE_CXX_COMPILER="$(pwd)/llvm/bin/clang++.exe"
70 |         -D CMAKE_CXX_FLAGS="-Wall -mavx2 -mfma -ffast-math"
71 |         -D CMAKE_MSVC_RUNTIME_LIBRARY=MultiThreaded
72 |         -D MKL_DIR="/c/Program Files (x86)/Intel/oneAPI/mkl/latest/lib/cmake/mkl"
73 |         -D MKL_LINK=static -D MKL_THREADING=sequential -D MKL_INTERFACE=lp64
74 |         -D VCL_HOME="$(pwd)/vectorclass"
75 |         -D VS_INCLUDE_DIR="$(pwd)/vapoursynth/include"
76 | 
77 |     - name: Build
78 |       run: cmake --build build --verbose
79 | 
80 |     - name: Install
81 |       run: |
82 |         cmake --install build --prefix install
83 |         mkdir artifact
84 |         copy install\bin\wnnm.dll artifact\
85 | 
86 |     - name: Upload
87 |       uses: actions/upload-artifact@v3
88 |       with:
89 |         name: wnnm-windows-x64
90 |         path: artifact
91 | 
92 |     - name: Compress artifact for release
93 |       if: github.event_name == 'workflow_dispatch' && github.event.inputs.tag != ''
94 |       run: |
95 |         cd artifact
96 | 
97 |         mkdir VapourSynth-WNNM-${{ github.event.inputs.tag }}
98 |         7z a -t7z -mx=9 ../VapourSynth-WNNM-${{ github.event.inputs.tag }}.7z wnnm.dll
99 | 
100 |     - name: Release
101 |       uses: softprops/action-gh-release@v1
102 |       if: github.event_name == 'workflow_dispatch' && github.event.inputs.tag != ''
103 |       with:
104 |         tag_name: ${{ github.event.inputs.tag }}
105 |         files: VapourSynth-WNNM-${{ github.event.inputs.tag }}.7z
106 |         fail_on_unmatched_files: true
107 |         generate_release_notes: false
108 |         prerelease: true

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.22.0)
2 | 
3 | project(vs-wnnm VERSION 0.1 LANGUAGES CXX)
4 | 
5 | set(VCL_HOME "" CACHE PATH "Path to vector class v2 headers")
6 | set(ARMPL_HOME "" CACHE PATH "Path to Arm Performance Libraries")
7 | 
8 | add_library(wnnm SHARED source/wnnm.cpp)
9 | 
10 | if(CMAKE_SYSTEM_PROCESSOR MATCHES "amd64.*|x86_64.*|AMD64.*")
11 |     find_package(MKL CONFIG REQUIRED)
12 | 
13 |     target_include_directories(wnnm PRIVATE ${VCL_HOME})
14 | 
15 |     target_link_libraries(wnnm PRIVATE $<LINK_ONLY:MKL::MKL>)
16 | elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64.*|AARCH64.*|arm64.*|ARM64.*|armv8.*)")
17 |     cmake_path(APPEND TEMP_PATH "${ARMPL_HOME}" "include")
18 |     target_include_directories(wnnm PRIVATE "${TEMP_PATH}")
19 | 
20 |     cmake_path(APPEND TEMP_PATH "${ARMPL_HOME}" "lib")
21 |     target_link_directories(wnnm PRIVATE "${TEMP_PATH}")
22 | 
23 |     target_link_libraries(wnnm PRIVATE armpl_lp64)
24 | else()
25 |     message(SEND_ERROR "unknown target: ${CMAKE_SYSTEM_PROCESSOR}")
26 | endif()
27 | 
28 | set_target_properties(wnnm PROPERTIES
29 |     CXX_EXTENSIONS OFF
30 |     CXX_STANDARD 17
31 |     CXX_STANDARD_REQUIRED ON
32 | )
33 | 
34 | find_package(PkgConfig QUIET MODULE)
35 | 
36 | if(PKG_CONFIG_FOUND)
37 |     pkg_search_module(VS vapoursynth)
38 | 
39 |     if(VS_FOUND)
40 |         message(STATUS "Found VapourSynth r${VS_VERSION}")
41 | 
42 |         cmake_path(APPEND install_dir ${VS_LIBDIR} vapoursynth)
43 |         target_include_directories(wnnm PRIVATE ${VS_INCLUDE_DIRS})
44 | 
45 |         install(TARGETS wnnm LIBRARY DESTINATION ${install_dir})
46 |     endif()
47 | endif()
48 | 
49 | if(NOT VS_FOUND)
50 |     set(VS_INCLUDE_DIR "" CACHE PATH "Path to VapourSynth headers")
51 | 
52 |     if(VS_INCLUDE_DIR STREQUAL "")
53 |         message(WARNING "VapourSynth not found")
54 |     endif()
55 | 
56 |     target_include_directories(wnnm PRIVATE ${VS_INCLUDE_DIR})
57 | 
58 |     install(TARGETS wnnm LIBRARY RUNTIME)
59 | endif()
60 | 
61 | find_package(Git QUIET)
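# The block below derives the plugin version string from
# `git describe --tags --long --always` and embeds it into the build
# through config.h; see configure_file() near the end of this file.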
62 | 
63 | if(GIT_FOUND)
64 |     execute_process(
65 |         COMMAND ${GIT_EXECUTABLE} describe --tags --long --always
66 |         WORKING_DIRECTORY "${CMAKE_SOURCE_DIR}"
67 |         OUTPUT_VARIABLE VCS_TAG
68 |     )
69 |     if(VCS_TAG)
70 |         string(STRIP ${VCS_TAG} VCS_TAG)
71 |     endif()
72 | endif()
73 | 
74 | if(VCS_TAG)
75 |     message(STATUS "vs-wnnm ${VCS_TAG}")
76 | else()
77 |     message(WARNING "unknown plugin version")
78 |     set(VCS_TAG "unknown")
79 | endif()
80 | 
81 | configure_file(source/config.h.in config.h)
82 | 
83 | target_include_directories(wnnm PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 WolframRhodium
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VapourSynth-WNNM
2 | [Weighted Nuclear Norm Minimization](https://ieeexplore.ieee.org/document/6909762) Denoiser for VapourSynth.
3 | 
4 | ## Description
5 | `WNNM` is a denoising algorithm based on block matching and weighted nuclear norm minimization.
6 | 
7 | Block matching, popularized by `BM3D`, finds similar blocks and stacks them together into a 3-D group. The similarity between these blocks allows details to be preserved during denoising.
8 | 
9 | In contrast to `BM3D`, which denoises the 3-D group by frequency-domain filtering, `WNNM` utilizes weighted nuclear norm minimization, a kind of low-rank matrix approximation. Because of this, `WNNM` exhibits fewer blocking and ringing artifacts than `BM3D`, but its computational complexity is much higher. In `BM3D`, this denoising stage is called collaborative filtering.
10 | 
11 | ## Usage
12 | Prototype:
13 | 
14 | `core.wnnm.WNNM(clip clip[, float[] sigma = 3.0, int block_size = 8, int block_step = 8, int group_size = 8, int bm_range = 7, int radius = 0, int ps_num = 2, int ps_range = 4, bool residual = false, bool adaptive_aggregation = true, clip rclip = None])`
15 | 
16 | - clip:
17 | 
18 |     The input clip. Must be in 32-bit float format. Each plane is denoised separately.
19 | 
20 | - sigma:
21 | 
22 |     Denoising strength of each plane. If fewer values than planes are given, the last value is repeated for the remaining planes; a value of 0 disables processing for that plane.
23 | 
24 | - block_size, block_step, group_size, bm_range, radius, ps_num, ps_range:
25 | 
26 |     Same as those in [VapourSynth-BM3D](https://github.com/HomeOfVapourSynthEvolution/VapourSynth-BM3D).
27 | 
28 | - residual:
29 | 
30 |     Whether to center blocks (subtract the group mean) before collaborative filtering. Default: `False`.
31 | 
32 | - adaptive_aggregation:
33 | 
34 |     Whether to aggregate blocks adaptively. Default: `True`.
35 | 
36 | - rclip:
37 | 
38 |     Reference clip for block matching. Must be of the same dimensions and format as `clip`.
39 | 
40 | ## Implementation
41 | The default values of `block_size`, `block_step` and `group_size` are modified for acceleration.
42 | 
43 | For spatial denoising, the block matching is the same as in the official implementation: it is similar to that of `BM3D`, but without a threshold deciding whether dissimilar blocks are included in the 3-D group. This is the same strategy as implemented in [VapourSynth-BM3DCUDA](https://github.com/WolframRhodium/VapourSynth-BM3DCUDA), but not in [VapourSynth-BM3D](https://github.com/HomeOfVapourSynthEvolution/VapourSynth-BM3D).
44 | 
45 | For temporal denoising, this implementation utilizes the same predictive search proposed by `V-BM3D`, which is closer to [VapourSynth-BM3D](https://github.com/HomeOfVapourSynthEvolution/VapourSynth-BM3D) (without dissimilar block thresholding) than [VapourSynth-BM3DCUDA](https://github.com/WolframRhodium/VapourSynth-BM3DCUDA). The latter implements a modified temporal predictive search that may find multiple instances of the same similar block, for acceleration.
46 | 
47 | During collaborative filtering, the official WNNM implementation centers blocks in the 3-D group. This is controlled by the `residual` parameter and is off by default. The largest singular value is left untouched when `residual` is off.
48 | 
49 | **Note**: Because of WNNM and this modification, the maximum denoising effect achievable is the best rank-one approximation of the 3-D group when `residual` is off, or the mean of the group when `residual` is on, which may not be enough for strong noise. The official implementation uses iterative regularization, which can easily be implemented as
50 | ```python
51 | for i in range(num_iterations):
52 |     if i == 0:
53 |         previous = source
54 |     elif i == 1:
55 |         previous = denoised
56 |     else:
57 |         previous = core.std.Expr([source, previous, denoised], "x y - {factor} * z +".format(factor=0.1))  # (source - previous) * factor + denoised
58 |     denoised = WNNM(previous)
59 | # output: `denoised`
60 | ```
61 | 
62 | The similar blocks are aggregated with weights given by the inverse of the number of non-zero singular values after WNNM, inspired by `BM3D`. This is controlled by the `adaptive_aggregation` parameter and is on by default.
63 | 
64 | The block matching can be guided by an oracle reference clip `rclip`, in the same manner as `ref` for `BM3D`. Unlike `BM3D`, the collaborative filtering is not guided.
65 | 
66 | ## Compilation
67 | - On x86_64, [oneMKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) is required. [Vector class library](https://github.com/vectorclass/version2) is also required when compiling with AVX2.
68 | 
69 | ```bash
70 | cmake -S . -B build -D CMAKE_BUILD_TYPE=Release \
71 |     -D MKL_LINK=static -D MKL_THREADING=sequential -D MKL_INTERFACE=lp64
72 | 
73 | cmake --build build
74 | 
75 | cmake --install build
76 | ```
77 | 
78 | - On AArch64, [ArmPL](https://developer.arm.com/Tools%20and%20Software/Arm%20Performance%20Libraries) is required; an example configure step is shown below.
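A minimal sketch (the `ARMPL_HOME` path below assumes ArmPL 24.10 installed to its default location, as in the Linux ARM64 workflow; adjust it and `VS_INCLUDE_DIR` to your installation):

```bash
# configure against ArmPL and the VapourSynth headers
cmake -S . -B build -D CMAKE_BUILD_TYPE=Release \
    -D ARMPL_HOME=/opt/arm/armpl_24.10_gcc \
    -D VS_INCLUDE_DIR="`pwd`/vapoursynth/include"

cmake --build build

cmake --install build
```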
79 | 
80 | An example build process can be found in the [workflows](https://github.com/WolframRhodium/VapourSynth-WNNM/tree/master/.github/workflows).
81 | 
82 | ## Reference
83 | 1. S. Gu, L. Zhang, W. Zuo and X. Feng, "[Weighted Nuclear Norm Minimization with Application to Image Denoising](https://ieeexplore.ieee.org/document/6909762)," 2014 IEEE Conference on Computer Vision and Pattern Recognition, 2014, pp. 2862-2869.
84 | 
85 | 2. K. Dabov, A. Foi, V. Katkovnik and K. Egiazarian, "[Image Denoising by Sparse 3-D Transform-Domain Collaborative Filtering](https://ieeexplore.ieee.org/document/4271520)," in IEEE Transactions on Image Processing, vol. 16, no. 8, pp. 2080-2095, Aug. 2007.
86 | 
87 | 3. K. Dabov, A. Foi and K. Egiazarian, "[Video denoising by sparse 3D transform-domain collaborative filtering](https://ieeexplore.ieee.org/document/7098781)," 2007 15th European Signal Processing Conference, 2007, pp. 145-149.
88 | 
89 | 4. [Official implementation](https://www4.comp.polyu.edu.hk/~cslzhang/code/WNNM_code.zip)
90 | 
91 | 5. [VapourSynth-BM3D](https://github.com/HomeOfVapourSynthEvolution/VapourSynth-BM3D)
92 | 
93 | 6. [VapourSynth-BM3DCUDA](https://github.com/WolframRhodium/VapourSynth-BM3DCUDA)
94 | 

--------------------------------------------------------------------------------
/source/config.h.in:
--------------------------------------------------------------------------------
1 | #define VERSION "@VCS_TAG@"

--------------------------------------------------------------------------------
/source/wnnm.cpp:
--------------------------------------------------------------------------------
1 | #include <algorithm>
2 | #include <cassert>
3 | #include <cmath>
4 | #include <cstdint>
5 | #include <cstdlib>
6 | #include <cstring>
7 | #include <iterator>
8 | #include <limits>
9 | #include <memory>
10 | #include <mutex>
11 | #include <shared_mutex>
12 | #include <stdexcept>
13 | #include <string>
14 | #include <thread>
15 | #include <tuple>
16 | #include <type_traits>
17 | #include <unordered_map>
18 | #include <vector>
19 | 
20 | #if defined(__x86_64__) || defined(_M_AMD64)
21 | // MKL
22 | #include <mkl_blas.h>
23 | #include <mkl_lapack.h>
24 | #include <mkl_types.h>
25 | #include <mkl_version.h>
26 | 
27 | #ifdef __AVX2__
28 | #include <immintrin.h>
29 | #include <vectorclass.h>
30 | #endif // __AVX2__
31 | 
32 | #elif defined(__aarch64__) || defined(_M_ARM64)
33 | #include <armpl.h>
34 | 
35 | #ifdef __ARM_FEATURE_SVE
36 | #include <arm_sve.h>
37 | #endif // __ARM_FEATURE_SVE
38 | 
39 | #else
40 | #error "unknown target"
41 | #endif
42 | 
43 | #include <VapourSynth.h>
44 | #include <VSHelper.h>
45 | 
46 | #include <config.h>
47 | 
48 | static VSPlugin * myself = nullptr;
49 | 
50 | template <typename T>
51 | static inline T square(T const & x) noexcept {
52 |     return x * x;
53 | }
54 | 
55 | static inline int m16(int x) noexcept {
56 |     assert(x > 0);
57 |     return ((x - 1) / 16 + 1) * 16;
58 | }
59 | 
60 | namespace {
61 | struct Workspace {
62 |     float * intermediate; // [radius == 0] shape: (2, height, width)
63 |     float * denoising_patch; // shape: (group_size, svd_lda) + pad (simd_lanes - 1)
64 |     float * mean_patch; // [residual] shape: (block_size, block_size) + pad (simd_lanes - 1)
65 |     float * current_patch; // shape: (block_size, block_size) + pad (simd_lanes - 1)
66 |     float * svd_s; // shape: (min(square(block_size), group_size),)
67 |     float * svd_u; // shape: (min(square(block_size), group_size), svd_ldu)
68 |     float * svd_vt; // shape: (min(square(block_size), group_size), svd_ldvt)
69 |     float * svd_work; // shape: (svd_lwork,)
70 |     int * svd_iwork; // shape: (8 * min(square(block_size), group_size),)
71 |     std::vector<std::tuple<float, int, int, int>> * errors; // shape: dynamic
72 |     std::vector<std::tuple<float, int, int>> * center_errors; // shape: dynamic
73 |     std::vector<std::tuple<int, int>> * search_locations; // shape: dynamic
74 |     std::vector<std::tuple<int, int>> * new_locations; // shape: dynamic
75 |     std::vector<std::tuple<int, int>> * locations_copy; // shape: dynamic
76 |     std::vector<std::tuple<float, int, int>> * temporal_errors; // shape: dynamic
77 | 
78 |     void init(
79 |         int width, int height,
80 |         int block_size, int group_size, int radius,
81 |         bool residual,
82 |         int svd_lda, int svd_ldu, int svd_ldvt, int svd_lwork
83 |     ) noexcept {
84 | 
85 | #ifdef __AVX2__
86 |         constexpr int pad = 7;
87 | #elif defined(__ARM_FEATURE_SVE)
88 |         const int pad = static_cast<int>(svlen(svfloat32_t{})) - 1;
89 | #else
90 |         constexpr int pad = 0;
91 | #endif
92 | 
93 |         if (residual) {
94 |             mean_patch = vs_aligned_malloc<float>((square(block_size) + pad) * sizeof(float), 64);
95 |         } else {
96 |             mean_patch = nullptr;
97 |         }
98 | 
99 |         current_patch = vs_aligned_malloc<float>((square(block_size) + pad) * sizeof(float), 64);
100 | 
101 |         if (radius == 0) {
102 |             intermediate = reinterpret_cast<float *>(std::malloc(2 * height * width * sizeof(float)));
103 |         } else {
104 |             intermediate = nullptr;
105 |         }
106 | 
107 |         int m = square(block_size);
108 |         int n = group_size;
109 | 
110 |         denoising_patch = vs_aligned_malloc<float>((svd_lda * n + pad) * sizeof(float), 64);
111 | 
112 |         svd_s = vs_aligned_malloc<float>(std::min(m, n) * sizeof(float), 64);
113 | 
114 |         svd_u = vs_aligned_malloc<float>(svd_ldu * std::min(m, n) * sizeof(float), 64);
115 | 
116 |         svd_vt = vs_aligned_malloc<float>(svd_ldvt * n * sizeof(float), 64);
117 | 
118 |         svd_work = vs_aligned_malloc<float>(svd_lwork * sizeof(float), 64);
119 | 
120 |         svd_iwork = vs_aligned_malloc<int>(8 * std::min(m, n) * sizeof(int), 64);
121 | 
122 |         errors = new std::remove_pointer_t<decltype(errors)>;
123 |         center_errors = new std::remove_pointer_t<decltype(center_errors)>;
124 |         search_locations = new std::remove_pointer_t<decltype(search_locations)>;
125 |         new_locations = new std::remove_pointer_t<decltype(new_locations)>;
126 |         locations_copy = new std::remove_pointer_t<decltype(locations_copy)>;
127 |         temporal_errors = new std::remove_pointer_t<decltype(temporal_errors)>;
128 |     }
129 | 
130 |     void release() noexcept {
131 |         vs_aligned_free(mean_patch);
132 |         mean_patch = nullptr;
133 | 
134 |         vs_aligned_free(current_patch);
135 |         current_patch = nullptr;
136 | 
137 |         std::free(intermediate);
138 |         intermediate = nullptr;
139 | 
140 |         vs_aligned_free(denoising_patch);
141 |         denoising_patch = nullptr;
142 | 
143 |         vs_aligned_free(svd_s);
144 |         svd_s = nullptr;
145 | 
146 |         vs_aligned_free(svd_u);
147 |         svd_u = nullptr;
148 | 
149 |         vs_aligned_free(svd_vt);
150 |         svd_vt = nullptr;
151 | 
152 |         vs_aligned_free(svd_work);
153 |         svd_work = nullptr;
154 | 
155 |         vs_aligned_free(svd_iwork);
156 |         svd_iwork = nullptr;
157 | 
158 |         delete errors;
159 |         errors = nullptr;
160 | 
161 |         delete center_errors;
162 |         center_errors = nullptr;
163 | 
164 |         delete search_locations;
165 |         search_locations = nullptr;
166 | 
167 |         delete new_locations;
168 |         new_locations = nullptr;
169 | 
170 |         delete locations_copy;
171 |         locations_copy = nullptr;
172 | 
173 |         delete temporal_errors;
174 |         temporal_errors = nullptr;
175 |     }
176 | };
177 | 
178 | struct WNNMData {
179 |     VSNodeRef * node;
180 |     float sigma[3];
181 |     int block_size, block_step, group_size, bm_range;
182 |     int radius, ps_num, ps_range;
183 |     bool process[3];
184 |     bool residual, adaptive_aggregation;
185 |     VSNodeRef * ref_node; // rclip
186 |     int svd_lwork, svd_lda, svd_ldu, svd_ldvt;
187 | 
188 |     std::unordered_map<std::thread::id, Workspace> workspaces;
189 |     std::shared_mutex workspaces_lock;
190 | };
191 | 
192 | enum class WnnmInfo { SUCCESS, FAILURE };
193 | } // namespace
194 | 
195 | #ifdef __AVX2__
196 | static inline Vec8i make_mask(int block_size_m8) noexcept {
197 |     static constexpr int temp[16] {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0};
198 | 
199 |     return Vec8i().load(temp + 8 - block_size_m8);
200 | }
201 | #endif
202 | 
203 | #ifdef __AVX2__
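// The AVX2 kernels below accumulate the squared L2 distance between the
// current block and each candidate block, 8 floats per step. For block sizes
// that are not a multiple of 8, make_mask() above yields a mask with
// (block_size % 8) active lanes, so each row's tail is loaded with
// _mm256_maskload_ps without reading past the end of the block.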
204 | namespace {
205 |     enum class BlockSizeInfo { Is8, Mod8, General };
206 | 
207 |     struct Empty {};
208 | }
209 | 
210 | template <BlockSizeInfo dispatch>
211 | static inline void compute_block_distances_avx2(
212 |     std::vector<std::tuple<float, int, int>> & errors,
213 |     const float * VS_RESTRICT current_patch,
214 |     const float * VS_RESTRICT neighbour_patch,
215 |     int top, int bottom, int left, int right,
216 |     int stride, int block_size
217 | ) noexcept {
218 | 
219 |     if constexpr (dispatch == BlockSizeInfo::Is8) {
220 |         block_size = 8;
221 |     }
222 | 
223 |     [[maybe_unused]] std::conditional_t<dispatch == BlockSizeInfo::General, Vec8i, Empty> mask;
224 |     if constexpr (dispatch == BlockSizeInfo::General) {
225 |         mask = make_mask(block_size % 8);
226 |     }
227 | 
228 |     for (int bm_y = top; bm_y <= bottom; ++bm_y) {
229 |         for (int bm_x = left; bm_x <= right; ++bm_x) {
230 |             Vec8f vec_error {0.f};
231 | 
232 |             const float * VS_RESTRICT current_patchp = current_patch;
233 |             const float * VS_RESTRICT neighbour_patchp = neighbour_patch;
234 | 
235 |             for (int patch_y = 0; patch_y < block_size; ++patch_y) {
236 |                 if constexpr (dispatch == BlockSizeInfo::Is8) {
237 |                     Vec8f vec_current = Vec8f().load_a(current_patchp);
238 |                     Vec8f vec_neighbour = Vec8f().load(neighbour_patchp);
239 | 
240 |                     Vec8f diff = vec_current - vec_neighbour;
241 |                     vec_error = mul_add(diff, diff, vec_error);
242 | 
243 |                     current_patchp += 8;
244 |                     neighbour_patchp += stride;
245 |                 } else if constexpr (dispatch == BlockSizeInfo::Mod8) {
246 |                     for (int patch_x = 0; patch_x < block_size; patch_x += 8) {
247 |                         Vec8f vec_current = Vec8f().load_a(current_patchp);
248 |                         Vec8f vec_neighbour = Vec8f().load(neighbour_patchp);
249 | 
250 |                         Vec8f diff = vec_current - vec_neighbour;
251 |                         vec_error = mul_add(diff, diff, vec_error);
252 | 
253 |                         current_patchp += 8;
254 |                         neighbour_patchp += 8;
255 |                     }
256 | 
257 |                     neighbour_patchp += stride - block_size;
258 |                 } else if constexpr (dispatch == BlockSizeInfo::General) {
259 |                     for (int patch_x = 0; patch_x < (block_size & (-8)); patch_x += 8) {
260 |                         Vec8f vec_current = Vec8f().load(current_patchp);
261 |                         Vec8f vec_neighbour = Vec8f().load(neighbour_patchp);
262 | 
263 |                         Vec8f diff = vec_current - vec_neighbour;
264 |                         vec_error = mul_add(diff, diff, vec_error);
265 | 
266 |                         current_patchp += 8;
267 |                         neighbour_patchp += 8;
268 |                     }
269 | 
270 |                     {
271 |                         Vec8f vec_current = _mm256_maskload_ps(current_patchp, mask);
272 |                         Vec8f vec_neighbour = _mm256_maskload_ps(neighbour_patchp, mask);
273 | 
274 |                         Vec8f diff = vec_current - vec_neighbour;
275 |                         vec_error = mul_add(diff, diff, vec_error);
276 | 
277 |                         current_patchp += block_size % 8;
278 |                         neighbour_patchp += stride - (block_size & (-8));
279 |                     }
280 |                 }
281 |             }
282 | 
283 |             float error { horizontal_add(vec_error) };
284 | 
285 |             errors.emplace_back(error, bm_x, bm_y);
286 | 
287 |             neighbour_patch++;
288 |         }
289 | 
290 |         neighbour_patch += stride - (right - left + 1);
291 |     }
292 | }
293 | 
294 | template <BlockSizeInfo dispatch>
295 | static inline void compute_block_distances_avx2(
296 |     std::vector<std::tuple<float, int, int>> & errors,
297 |     const float * VS_RESTRICT current_patch,
298 |     const float * VS_RESTRICT refp,
299 |     const std::vector<std::tuple<int, int>> & search_positions,
300 |     int stride, int block_size
301 | ) noexcept {
302 | 
303 |     if constexpr (dispatch == BlockSizeInfo::Is8) {
304 |         block_size = 8;
305 |     }
306 | 
307 |     [[maybe_unused]] std::conditional_t<dispatch == BlockSizeInfo::General, Vec8i, Empty> mask;
308 |     if constexpr (dispatch == BlockSizeInfo::General) {
309 |         mask = make_mask(block_size % 8);
310 |     }
311 | 
312 |     for (const auto & [bm_x, bm_y]: search_positions) {
313 |         Vec8f vec_error {0.f};
314 | 
315 |         const float * VS_RESTRICT current_patchp = current_patch;
316 |         const float * VS_RESTRICT neighbour_patchp = &refp[bm_y * stride + bm_x];
317 | 
318 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
319 |             if constexpr (dispatch == BlockSizeInfo::Is8) {
320 |                 Vec8f vec_current = Vec8f().load_a(current_patchp);
321 |                 Vec8f vec_neighbour = Vec8f().load(neighbour_patchp);
322 | 
323 |                 Vec8f diff = vec_current - vec_neighbour;
324 |                 vec_error = mul_add(diff, diff, vec_error);
325 | 
326 |                 current_patchp += 8;
327 |                 neighbour_patchp += stride;
328 |             } else if constexpr (dispatch == BlockSizeInfo::Mod8) {
329 |                 for (int patch_x = 0; patch_x < block_size; patch_x += 8) {
330 |                     Vec8f vec_current = Vec8f().load_a(current_patchp);
331 |                     Vec8f vec_neighbour = Vec8f().load(neighbour_patchp);
332 | 
333 |                     Vec8f diff = vec_current - vec_neighbour;
334 |                     vec_error = mul_add(diff, diff, vec_error);
335 | 
336 |                     current_patchp += 8;
337 |                     neighbour_patchp += 8;
338 |                 }
339 | 
340 |                 neighbour_patchp += stride - block_size;
341 |             } else if constexpr (dispatch == BlockSizeInfo::General) {
342 |                 for (int patch_x = 0; patch_x < (block_size & (-8)); patch_x += 8) {
343 |                     Vec8f vec_current = Vec8f().load(current_patchp);
344 |                     Vec8f vec_neighbour = Vec8f().load(neighbour_patchp);
345 | 
346 |                     Vec8f diff = vec_current - vec_neighbour;
347 |                     vec_error = mul_add(diff, diff, vec_error);
348 | 
349 |                     current_patchp += 8;
350 |                     neighbour_patchp += 8;
351 |                 }
352 | 
353 |                 {
354 |                     Vec8f vec_current = _mm256_maskload_ps(current_patchp, mask);
355 |                     Vec8f vec_neighbour = _mm256_maskload_ps(neighbour_patchp, mask);
356 | 
357 |                     Vec8f diff = vec_current - vec_neighbour;
358 |                     vec_error = mul_add(diff, diff, vec_error);
359 | 
360 |                     current_patchp += block_size % 8;
361 |                     neighbour_patchp += stride - (block_size & (-8));
362 |                 }
363 |             }
364 |         }
365 | 
366 |         float error { horizontal_add(vec_error) };
367 | 
368 |         errors.emplace_back(error, bm_x, bm_y);
369 |     }
370 | }
371 | #endif // __AVX2__
372 | 
373 | static inline void generate_search_locations(
374 |     const std::tuple<float, int, int> * center_positions, int num_center_positions,
375 |     int block_size, int width, int height, int bm_range,
376 |     std::vector<std::tuple<int, int>> & search_locations,
377 |     std::vector<std::tuple<int, int>> & new_locations,
378 |     std::vector<std::tuple<int, int>> & locations_copy
379 | ) noexcept {
380 | 
381 |     search_locations.clear();
382 | 
383 |     for (int i = 0; i < num_center_positions; i++) {
384 |         const auto & [_, x, y] = center_positions[i];
385 |         int left = std::max(x - bm_range, 0);
386 |         int right = std::min(x + bm_range, width - block_size);
387 |         int top = std::max(y - bm_range, 0);
388 |         int bottom = std::min(y + bm_range, height - block_size);
389 | 
390 |         new_locations.clear();
391 |         new_locations.reserve((bottom - top + 1) * (right - left + 1));
392 |         for (int j = top; j <= bottom; j++) {
393 |             for (int k = left; k <= right; k++) {
394 |                 new_locations.emplace_back(k, j);
395 |             }
396 |         }
397 | 
398 |         locations_copy = search_locations;
399 | 
400 |         search_locations.reserve(std::size(search_locations) + std::size(new_locations));
401 | 
402 |         search_locations.clear();
403 | 
404 |         std::set_union(
405 |             std::cbegin(locations_copy), std::cend(locations_copy),
406 |             std::cbegin(new_locations), std::cend(new_locations),
407 |             std::back_inserter(search_locations),
408 |             [](const std::tuple<int, int> & a, const std::tuple<int, int> & b) -> bool {
409 |                 auto [ax, ay] = a;
410 |                 auto [bx, by] = b;
411 |                 return ay < by || (ay == by && ax < bx);
412 |             }
413 |         );
414 |     }
415 | }
416 | 
417 | static inline void compute_block_distances(
418 |     std::vector<std::tuple<float, int, int>> & errors,
418 |     const float * VS_RESTRICT current_patch,
419 |     const float * VS_RESTRICT neighbour_patch,
420 |     int top, int bottom, int left, int right,
421 |     int stride,
422 |     int block_size
423 | ) noexcept {
424 | 
425 | #ifdef __AVX2__
426 |     if (block_size == 8) {
427 |         return compute_block_distances_avx2<BlockSizeInfo::Is8>(errors, current_patch, neighbour_patch, top, bottom, left, right, stride, block_size);
428 |     } else if ((block_size % 8) == 0) {
429 |         return compute_block_distances_avx2<BlockSizeInfo::Mod8>(errors, current_patch, neighbour_patch, top, bottom, left, right, stride, block_size);
430 |     } else {
431 |         return compute_block_distances_avx2<BlockSizeInfo::General>(errors, current_patch, neighbour_patch, top, bottom, left, right, stride, block_size);
432 |     }
433 | #elif defined(__ARM_FEATURE_SVE)
434 |     const int step = static_cast<int>(svlen(svfloat32_t{}));
435 | 
436 |     for (int bm_y = top; bm_y <= bottom; ++bm_y) {
437 |         for (int bm_x = left; bm_x <= right; ++bm_x) {
438 |             svfloat32_t error {};
439 | 
440 |             const float * VS_RESTRICT current_patchp = current_patch;
441 |             const float * VS_RESTRICT neighbour_patchp = neighbour_patch;
442 | 
443 |             for (int patch_y = 0; patch_y < block_size; ++patch_y) {
444 |                 for (int patch_x = 0; patch_x < block_size; patch_x += step) {
445 |                     auto predicate = svwhilelt_b32(patch_x, block_size);
446 | 
447 |                     auto current = svld1(predicate, &current_patchp[patch_x]);
448 |                     auto neighbour = svld1(predicate, &neighbour_patchp[patch_x]);
449 | 
450 |                     auto diff = svsub_z(predicate, current, neighbour);
451 |                     error = svmad_z(predicate, diff, diff, error);
452 |                 }
453 | 
454 |                 current_patchp += block_size;
455 |                 neighbour_patchp += stride;
456 |             }
457 | 
458 |             errors.emplace_back(svaddv(svptrue_b32(), error), bm_x, bm_y);
459 | 
460 |             neighbour_patch++;
461 |         }
462 | 
463 |         neighbour_patch += stride - (right - left + 1);
464 |     }
465 | #else // __AVX2__
466 |     for (int bm_y = top; bm_y <= bottom; ++bm_y) {
467 |         for (int bm_x = left; bm_x <= right; ++bm_x) {
468 |             float error = 0.f;
469 | 
470 |             const float * VS_RESTRICT current_patchp = current_patch;
471 |             const float * VS_RESTRICT neighbour_patchp = neighbour_patch;
472 | 
473 |             for (int patch_y = 0; patch_y < block_size; ++patch_y) {
474 |                 for (int patch_x = 0; patch_x < block_size; ++patch_x) {
475 |                     error += square(current_patchp[patch_x] - neighbour_patchp[patch_x]);
476 |                 }
477 | 
478 |                 current_patchp += block_size;
479 |                 neighbour_patchp += stride;
480 |             }
481 | 
482 |             errors.emplace_back(error, bm_x, bm_y);
483 | 
484 |             neighbour_patch++;
485 |         }
486 | 
487 |         neighbour_patch += stride - (right - left + 1);
488 |     }
489 | #endif // __AVX2__
490 | }
491 | 
492 | static inline void compute_block_distances(
493 |     std::vector<std::tuple<float, int, int>> & errors,
494 |     const float * VS_RESTRICT current_patch,
495 |     const float * VS_RESTRICT refp,
496 |     const std::vector<std::tuple<int, int>> & search_positions,
497 |     int stride,
498 |     int block_size
499 | ) noexcept {
500 | 
501 | #ifdef __AVX2__
502 |     if (block_size == 8) {
503 |         return compute_block_distances_avx2<BlockSizeInfo::Is8>(
504 |             errors,
505 |             current_patch, refp, search_positions, stride, block_size
506 |         );
507 |     } else if ((block_size % 8) == 0) {
508 |         return compute_block_distances_avx2<BlockSizeInfo::Mod8>(
509 |             errors,
510 |             current_patch, refp, search_positions, stride, block_size
511 |         );
512 |     } else {
513 |         return compute_block_distances_avx2<BlockSizeInfo::General>(
514 |             errors,
515 |             current_patch, refp, search_positions, stride, block_size
516 |         );
517 |     }
518 | #elif defined(__ARM_FEATURE_SVE)
519 |     const int step = static_cast<int>(svlen(svfloat32_t{}));
520 | 
521 |     for (const auto & [bm_x, bm_y]: search_positions) {
522 |         svfloat32_t error {};
523 | 
524 |         const float * VS_RESTRICT current_patchp = current_patch;
525 |         const float * VS_RESTRICT neighbour_patchp = &refp[bm_y * stride + bm_x];
526 | 
527 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
528 |             for (int patch_x = 0; patch_x < block_size; patch_x += step) {
529 |                 auto predicate = svwhilelt_b32(patch_x, block_size);
530 | 
531 |                 auto current = svld1(predicate, &current_patchp[patch_x]);
532 |                 auto neighbour = svld1(predicate, &neighbour_patchp[patch_x]);
533 | 
534 |                 auto diff = svsub_z(predicate, current, neighbour);
535 |                 error = svmad_z(predicate, diff, diff, error);
536 |             }
537 | 
538 |             current_patchp += block_size;
539 |             neighbour_patchp += stride;
540 |         }
541 | 
542 |         errors.emplace_back(svaddv(svptrue_b32(), error), bm_x, bm_y);
543 |     }
544 | #else
545 |     for (const auto & [bm_x, bm_y]: search_positions) {
546 |         float error = 0.f;
547 | 
548 |         const float * VS_RESTRICT current_patchp = current_patch;
549 |         const float * VS_RESTRICT neighbour_patchp = &refp[bm_y * stride + bm_x];
550 | 
551 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
552 |             for (int patch_x = 0; patch_x < block_size; ++patch_x) {
553 |                 error += square(current_patchp[patch_x] - neighbour_patchp[patch_x]);
554 |             }
555 | 
556 |             current_patchp += block_size;
557 |             neighbour_patchp += stride;
558 |         }
559 | 
560 |         errors.emplace_back(error, bm_x, bm_y);
561 |     }
562 | #endif // __AVX2__
563 | }
564 | 
565 | #ifdef __AVX2__
566 | template <BlockSizeInfo dispatch, bool residual>
567 | static inline void load_patches_avx2(
568 |     float * VS_RESTRICT denoising_patch, int svd_lda,
569 |     std::conditional_t<residual, float * VS_RESTRICT, Empty> mean_patch,
570 |     const std::vector<const float *> & srcps,
571 |     const std::vector<std::tuple<float, int, int, int>> & errors,
572 |     int stride,
573 |     int active_group_size,
574 |     int block_size
575 | ) noexcept {
576 | 
577 |     if constexpr (dispatch == BlockSizeInfo::Is8) {
578 |         block_size = 8;
579 |     }
580 | 
581 |     [[maybe_unused]] std::conditional_t<dispatch == BlockSizeInfo::General, Vec8i, Empty> mask;
582 |     if constexpr (dispatch == BlockSizeInfo::General) {
583 |         mask = make_mask(block_size % 8);
584 |     }
585 | 
586 |     assert(stride % 8 == 0);
587 | 
588 |     for (int i = 0; i < active_group_size; ++i) {
589 |         auto [error, bm_x, bm_y, bm_t] = errors[i];
590 | 
591 |         const float * VS_RESTRICT src_patchp = &srcps[bm_t][bm_y * stride + bm_x];
592 | 
593 |         [[maybe_unused]] std::conditional_t<residual, float * VS_RESTRICT, Empty> mean_patchp { mean_patch };
594 | 
595 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
596 |             if constexpr (dispatch == BlockSizeInfo::Is8) {
597 |                 Vec8f vec_src = Vec8f().load(src_patchp);
598 |                 vec_src.store_a(denoising_patch);
599 |                 src_patchp += stride;
600 |                 denoising_patch += 8;
601 | 
602 |                 if constexpr (residual) {
603 |                     Vec8f vec_mean = Vec8f().load_a(mean_patchp);
604 |                     vec_mean += vec_src;
605 |                     vec_mean.store_a(mean_patchp);
606 |                     mean_patchp += 8;
607 |                 }
608 |             } else if constexpr (dispatch == BlockSizeInfo::Mod8) {
609 |                 for (int patch_x = 0; patch_x < block_size; patch_x += 8) {
610 |                     Vec8f vec_src = Vec8f().load(src_patchp);
611 |                     vec_src.store_a(denoising_patch);
612 |                     src_patchp += 8;
613 |                     denoising_patch += 8;
614 | 
615 |                     if constexpr (residual) {
616 |                         Vec8f vec_mean = Vec8f().load_a(mean_patchp);
617 |                         vec_mean += vec_src;
618 |                         vec_mean.store_a(mean_patchp);
619 |                         mean_patchp += 8;
620 |                     }
621 |                 }
622 | 
623 |                 src_patchp += stride - block_size;
624 |             } else if constexpr (dispatch == BlockSizeInfo::General) {
625 |                 for (int patch_x = 0; patch_x < (block_size & (-8)); patch_x += 8) {
626 |                     Vec8f vec_src = Vec8f().load(src_patchp);
627 |                     vec_src.store(denoising_patch);
628 |                     src_patchp += 8;
629 |                     denoising_patch += 8;
630 | 
631 |                     if constexpr (residual) {
632 |                         Vec8f vec_mean = Vec8f().load(mean_patchp);
633 |                         vec_mean += vec_src;
634 |                         vec_mean.store(mean_patchp);
635 |                         mean_patchp += 8;
636 |                     }
637 |                 }
638 | 
639 |                 {
640 |                     Vec8f vec_src = _mm256_maskload_ps(src_patchp, mask);
641 |                     vec_src.store(denoising_patch); // denoising_patch is padded
642 |                     src_patchp += stride - (block_size & (-8));
643 |                     denoising_patch += block_size % 8;
644 | 
645 |                     if constexpr (residual) {
646 |                         Vec8f vec_mean = Vec8f().load(mean_patchp);
647 |                         vec_mean += vec_src;
648 |                         vec_mean.store(mean_patchp); // mean_patch is padded
649 |                         mean_patchp += block_size % 8;
650 |                     }
651 |                 }
652 |             }
653 |         }
654 | 
655 |         if constexpr (dispatch == BlockSizeInfo::General) {
656 |             denoising_patch += svd_lda - square(block_size);
657 |         } else {
658 |             assert(svd_lda - square(block_size) == 0);
659 |         }
660 |     }
661 | }
662 | #endif // __AVX2__
663 | 
664 | template <bool residual>
665 | static inline void load_patches(
666 |     float * VS_RESTRICT denoising_patch, int svd_lda,
667 |     std::conditional_t<residual, float * VS_RESTRICT, Empty> mean_patch,
668 |     const std::vector<const float *> & srcps,
669 |     const std::vector<std::tuple<float, int, int, int>> & errors,
670 |     int stride,
671 |     int active_group_size,
672 |     int block_size
673 | ) noexcept {
674 | 
675 | #ifdef __AVX2__
676 |     if (block_size == 8) {
677 |         return load_patches_avx2<BlockSizeInfo::Is8, residual>(
678 |             denoising_patch, svd_lda, mean_patch,
679 |             srcps, errors, stride,
680 |             active_group_size, block_size
681 |         );
682 |     } else if ((block_size % 8) == 0) { // block_size % 8 == 0
683 |         return load_patches_avx2<BlockSizeInfo::Mod8, residual>(
684 |             denoising_patch, svd_lda, mean_patch,
685 |             srcps, errors, stride,
686 |             active_group_size, block_size
687 |         );
688 |     } else { // block_size % 8 != 0
689 |         return load_patches_avx2<BlockSizeInfo::General, residual>(
690 |             denoising_patch, svd_lda, mean_patch,
691 |             srcps, errors, stride,
692 |             active_group_size, block_size
693 |         );
694 |     }
695 | #elif defined(__ARM_FEATURE_SVE)
696 |     const int step = static_cast<int>(svlen(svfloat32_t{}));
697 | 
698 |     for (int i = 0; i < active_group_size; ++i) {
699 |         auto [error, bm_x, bm_y, bm_t] = errors[i];
700 | 
701 |         const float * VS_RESTRICT src_patchp = &srcps[bm_t][bm_y * stride + bm_x];
702 | 
703 |         float * VS_RESTRICT mean_patchp {nullptr};
704 |         if constexpr (residual) {
705 |             mean_patchp = mean_patch;
706 |         }
707 | 
708 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
709 |             for (int patch_x = 0; patch_x < block_size; patch_x += step) {
710 |                 auto predicate = svwhilelt_b32(patch_x, block_size);
711 | 
712 |                 auto src_val = svld1(predicate, &src_patchp[patch_x]);
713 | 
714 |                 svst1(predicate, &denoising_patch[patch_x], src_val);
715 | 
716 |                 if constexpr (residual) {
717 |                     auto mean_val = svld1(predicate, &mean_patchp[patch_x]);
718 |                     svst1(predicate, &mean_patchp[patch_x], svadd_z(predicate, mean_val, src_val));
719 |                 }
720 |             }
721 | 
722 |             src_patchp += stride;
723 |             denoising_patch += block_size;
724 |         }
725 | 
726 |         denoising_patch += svd_lda - square(block_size);
727 |     }
728 | #else
729 |     for (int i = 0; i < active_group_size; ++i) {
730 |         auto [error, bm_x, bm_y, bm_t] = errors[i];
731 | 
732 |         const float * VS_RESTRICT src_patchp = &srcps[bm_t][bm_y * stride + bm_x];
733 | 
734 |         float * VS_RESTRICT mean_patchp {nullptr};
735 |         if constexpr (residual) {
736 |             mean_patchp = mean_patch;
737 |         }
738 | 
739 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
740 |             for (int patch_x = 0; patch_x < block_size; ++patch_x) {
741 |                 float src_val = src_patchp[patch_x];
742 | 
743 |                 denoising_patch[patch_x] = src_val;
744 | 
745 |                 if constexpr (residual) {
746 |                     mean_patchp[patch_x] += src_val;
747 |                 }
748 |             }
749 | 
750 |             src_patchp += stride;
751 |             denoising_patch += block_size;
752 |         }
753 | 
754 |         denoising_patch += svd_lda - square(block_size);
755 |     }
756 | #endif // __AVX2__
757 | }
758 | 
759 | static inline void bm_post(float * VS_RESTRICT mean_patch, float * VS_RESTRICT denoising_patch,
760 |         int block_size, int active_group_size, int svd_lda) noexcept {
761 | 
762 |     // subtract group mean
763 | 
764 |     for (int i = 0; i < square(block_size); ++i) {
765 |         mean_patch[i] /= active_group_size;
766 |     }
767 | 
768 |     for (int i = 0; i < active_group_size; ++i) {
769 |         for (int j = 0; j < square(block_size); ++j) {
770 |             denoising_patch[j] -= mean_patch[j];
771 |         }
772 | 
773 |         denoising_patch += svd_lda;
774 |     }
775 | }
776 | 
777 | static inline void extend_errors(
778 |     std::vector<std::tuple<float, int, int, int>> & errors,
779 |     const std::vector<std::tuple<float, int, int>> & spatial_errors,
780 |     int temporal_index
781 | ) noexcept {
782 | 
783 |     errors.reserve(std::size(errors) + std::size(spatial_errors));
784 |     for (const auto & [error, x, y] : spatial_errors) {
785 |         errors.emplace_back(error, x, y, temporal_index);
786 |     }
787 | }
788 | 
789 | template <bool residual>
790 | static inline int block_matching(
791 |     float * VS_RESTRICT denoising_patch, int svd_lda,
792 |     std::vector<std::tuple<float, int, int, int>> & errors,
793 |     float * VS_RESTRICT current_patch,
794 |     std::conditional_t<residual, float * VS_RESTRICT, Empty> mean_patch,
795 |     const std::vector<const float *> & srcps, // length: 2 * radius + 1
796 |     const std::vector<const float *> & refps, // length: 2 * radius + 1
797 |     int width, int height, int stride,
798 |     int x, int y,
799 |     int block_size, int group_size, int bm_range,
800 |     int ps_num, int ps_range,
801 |     std::vector<std::tuple<float, int, int>> & center_errors,
802 |     std::vector<std::tuple<int, int>> & search_locations,
803 |     std::vector<std::tuple<int, int>> & new_locations,
804 |     std::vector<std::tuple<int, int>> & locations_copy,
805 |     std::vector<std::tuple<float, int, int>> & temporal_errors
806 | ) noexcept {
807 | 
808 |     errors.clear();
809 |     center_errors.clear();
810 | 
811 |     auto radius = (static_cast<int>(std::size(srcps)) - 1) / 2;
812 | 
813 |     vs_bitblt(
814 |         current_patch, block_size * sizeof(float),
815 |         &refps[radius][y * stride + x], stride * sizeof(float),
816 |         block_size * sizeof(float), block_size
817 |     );
818 | 
819 |     int top = std::max(y - bm_range, 0);
820 |     int bottom = std::min(y + bm_range, height - block_size);
821 |     int left = std::max(x - bm_range, 0);
822 |     int right = std::min(x + bm_range, width - block_size);
823 | 
824 |     compute_block_distances(
825 |         center_errors,
826 |         current_patch,
827 |         &refps[radius][top * stride + left],
828 |         top, bottom, left, right,
829 |         stride, block_size
830 |     );
831 | 
832 |     if (radius == 0) {
833 |         extend_errors(errors, center_errors, radius);
834 |     } else {
835 |         int active_ps_num = std::min(
836 |             ps_num,
837 |             static_cast<int>(std::size(center_errors))
838 |         );
839 | 
840 |         int active_num = std::min(
841 |             std::max(group_size, ps_num),
842 |             static_cast<int>(std::size(center_errors))
843 |         );
844 | 
845 |         std::partial_sort(
846 |             center_errors.begin(),
847 |             center_errors.begin() + active_num,
848 |             center_errors.end(),
849 |             [](auto a, auto b) { return std::get<0>(a) < std::get<0>(b); }
850 |         );
851 |         center_errors.resize(active_num);
852 |         extend_errors(errors, center_errors, radius);
853 | 
854 |         for (int direction = -1; direction <= 1; direction += 2) {
855 |             temporal_errors = center_errors; // mutable
856 | 
857 |             for (int i = 1; i <= radius; i++) {
858 |                 auto temporal_index = radius + direction * i;
859 | 
860 |                 generate_search_locations(
861 |                     std::data(temporal_errors), active_ps_num,
862 |                     block_size, width, height, ps_range,
863 |                     search_locations, new_locations, locations_copy
864 |                 );
865 | 
866 |                 temporal_errors.clear();
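                // V-BM3D-style predictive search: the next frame along this
                // direction is scanned only within ps_range windows centred on
                // the best matches from the previous step (their union was
                // built by generate_search_locations above), rather than the
                // full bm_range window around (x, y).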
867 | 
868 |                 compute_block_distances(
869 |                     temporal_errors,
870 |                     current_patch,
871 |                     refps[temporal_index],
872 |                     search_locations,
873 |                     stride, block_size
874 |                 );
875 | 
876 |                 auto active_temporal_num = std::min(
877 |                     std::max(group_size, ps_num),
878 |                     static_cast<int>(std::size(temporal_errors))
879 |                 );
880 | 
881 |                 std::partial_sort(
882 |                     temporal_errors.begin(),
883 |                     temporal_errors.begin() + active_temporal_num,
884 |                     temporal_errors.end(),
885 |                     [](auto a, auto b) { return std::get<0>(a) < std::get<0>(b); }
886 |                 );
887 |                 temporal_errors.resize(active_temporal_num);
888 |                 extend_errors(errors, temporal_errors, temporal_index);
889 |             }
890 |         }
891 |     }
892 | 
893 |     int active_group_size = std::min(group_size, static_cast<int>(std::size(errors)));
894 |     std::partial_sort(
895 |         errors.begin(),
896 |         errors.begin() + active_group_size,
897 |         errors.end(),
898 |         [](auto a, auto b) { return std::get<0>(a) < std::get<0>(b); }
899 |     );
900 |     errors.resize(active_group_size);
901 |     bool center = false;
902 |     for (int i = 0; i < active_group_size; i++) {
903 |         const auto & [_, bm_x, bm_y, bm_t] = errors[i];
904 |         if (bm_x == x && bm_y == y && bm_t == radius) {
905 |             center = true;
906 |         }
907 |     }
908 |     if (!center) {
909 |         errors[0] = std::make_tuple(0.0f, x, y, radius);
910 |     }
911 | 
912 |     load_patches<residual>(
913 |         denoising_patch, svd_lda, mean_patch,
914 |         srcps, errors, stride, active_group_size, block_size);
915 | 
916 |     if constexpr (residual) {
917 |         bm_post(mean_patch, denoising_patch, block_size, active_group_size, svd_lda);
918 |     }
919 | 
920 |     return active_group_size;
921 | }
922 | 
923 | template <bool residual>
924 | static inline WnnmInfo patch_estimation(
925 |     float * VS_RESTRICT denoising_patch, int svd_lda,
926 |     float & adaptive_weight,
927 |     float sigma,
928 |     int block_size, int active_group_size,
929 |     const float * VS_RESTRICT mean_patch,
930 |     bool adaptive_aggregation,
931 |     float * VS_RESTRICT svd_s,
932 |     float * VS_RESTRICT svd_u, int svd_ldu,
933 |     float * VS_RESTRICT svd_vt, int svd_ldvt,
934 |     float * VS_RESTRICT svd_work, int svd_lwork, int * VS_RESTRICT svd_iwork
935 | ) noexcept {
936 | 
937 |     int m = square(block_size);
938 |     int n = active_group_size;
939 | 
940 |     int svd_info;
941 | #if defined(__INTEL_MKL__)
942 |     sgesdd(
943 |         "S", &m, &n,
944 |         denoising_patch, &svd_lda,
945 |         svd_s,
946 |         svd_u, &svd_ldu,
947 |         svd_vt, &svd_ldvt,
948 |         svd_work, &svd_lwork, svd_iwork, &svd_info
949 |     );
950 | #else
951 |     sgesdd_(
952 |         "S", &m, &n,
953 |         denoising_patch, &svd_lda,
954 |         svd_s,
955 |         svd_u, &svd_ldu,
956 |         svd_vt, &svd_ldvt,
957 |         svd_work, &svd_lwork, svd_iwork, &svd_info
958 |     );
959 | #endif // defined(__INTEL_MKL__)
960 | 
961 |     if (svd_info != 0) {
962 |         return WnnmInfo::FAILURE;
963 |     }
964 | 
965 |     // WNNP with parameter epsilon ignored
966 |     const float constant = 8.f * sqrtf(2.0f * n) * square(sigma);
967 | 
968 |     int k = 1;
969 |     if constexpr (residual) {
970 |         k = 0;
971 |     }
972 | 
973 |     for ( ; k < std::min(m, n); ++k) {
974 |         float s = svd_s[k];
975 |         float tmp = square(s) - constant;
976 |         if (tmp > 0.f) {
977 |             svd_s[k] = (s + sqrtf(tmp)) * 0.5f;
978 |         } else {
979 |             break;
980 |         }
981 |     }
982 | 
983 |     if (adaptive_aggregation) {
984 |         adaptive_weight = (k > 0) ? (1.f / k) : 1.0f;
985 |     }
986 | 
987 |     // gemm
988 |     if (m < n) {
989 |         float * VS_RESTRICT svd_up {svd_u};
990 | 
991 |         for (int i = 0; i < k; ++i) {
992 |             for (int j = 0; j < m; ++j) {
993 |                 svd_up[j] *= svd_s[i];
994 |             }
995 |             svd_up += svd_ldu;
996 |         }
997 |     } else {
998 |         float * VS_RESTRICT svd_vtp {svd_vt};
999 | 
1000 |         for (int i = 0; i < n; ++i) {
1001 |             for (int j = 0; j < k; ++j) {
1002 |                 svd_vtp[j] *= svd_s[j];
1003 |             }
1004 |             svd_vtp += svd_ldvt;
1005 |         }
1006 |     }
1007 | 
1008 |     constexpr float alpha = 1.0f;
1009 |     constexpr float beta = 0.0f;
1010 | #if defined(__INTEL_MKL__)
1011 |     sgemm("N", "N", &m, &n, &k, &alpha, svd_u, &svd_ldu, svd_vt, &svd_ldvt, &beta, denoising_patch, &svd_lda);
1012 | #else
1013 |     sgemm_("N", "N", &m, &n, &k, &alpha, svd_u, &svd_ldu, svd_vt, &svd_ldvt, &beta, denoising_patch, &svd_lda);
1014 | #endif // defined(__INTEL_MKL__)
1015 | 
1016 |     if constexpr (residual) {
1017 |         for (int i = 0; i < active_group_size; ++i) {
1018 |             for (int patch_i = 0; patch_i < square(block_size); ++patch_i) {
1019 |                 denoising_patch[patch_i] += mean_patch[patch_i];
1020 |             }
1021 | 
1022 |             denoising_patch += svd_lda;
1023 |         }
1024 |     }
1025 | 
1026 |     return WnnmInfo::SUCCESS;
1027 | }
1028 | 
1029 | static inline void col2im(
1030 |     float * VS_RESTRICT intermediate,
1031 |     const float * VS_RESTRICT denoising_patch, int svd_lda,
1032 |     const std::vector<std::tuple<float, int, int, int>> & errors,
1033 |     int height, int intermediate_stride,
1034 |     int block_size, int active_group_size,
1035 |     float adaptive_weight
1036 | ) noexcept {
1037 | 
1038 |     for (int i = 0; i < active_group_size; ++i) {
1039 |         auto [error, bm_x, bm_y, bm_t] = errors[i];
1040 | 
1041 |         float * VS_RESTRICT wdstp = &intermediate[(bm_t * 2 * height + bm_y) * intermediate_stride + bm_x];
1042 |         float * VS_RESTRICT weightp = &intermediate[((bm_t * 2 + 1) * height + bm_y) * intermediate_stride + bm_x];
1043 | 
1044 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
1045 |             for (int patch_x = 0; patch_x < block_size; ++patch_x) {
1046 |                 wdstp[patch_x] += denoising_patch[patch_x] * adaptive_weight;
1047 |                 weightp[patch_x] += adaptive_weight;
1048 |             }
1049 | 
1050 |             wdstp += intermediate_stride;
1051 |             weightp += intermediate_stride;
1052 |             denoising_patch += block_size;
1053 |         }
1054 | 
1055 |         denoising_patch += svd_lda - square(block_size);
1056 |     }
1057 | }
1058 | 
1059 | static void patch_estimation_skip(
1060 |     float * VS_RESTRICT intermediate,
1061 |     const std::vector<const float *> & srcps,
1062 |     const std::vector<std::tuple<float, int, int, int>> & errors,
1063 |     int height, int stride, int intermediate_stride,
1064 |     int block_size, int active_group_size
1065 | ) noexcept {
1066 | 
1067 |     for (int i = 0; i < active_group_size; ++i) {
1068 |         auto [error, bm_x, bm_y, bm_t] = errors[i];
1069 | 
1070 |         const float * VS_RESTRICT srcp = &srcps[bm_t][bm_y * stride + bm_x];
1071 |         float * VS_RESTRICT wdstp = &intermediate[(bm_t * 2 * height + bm_y) * intermediate_stride + bm_x];
1072 |         float * VS_RESTRICT weightp = &intermediate[((bm_t * 2 + 1) * height + bm_y) * intermediate_stride + bm_x];
1073 | 
1074 |         for (int patch_y = 0; patch_y < block_size; ++patch_y) {
1075 |             for (int patch_x = 0; patch_x < block_size; ++patch_x) {
1076 |                 wdstp[patch_x] += srcp[patch_x];
1077 |                 weightp[patch_x] += 1.f;
1078 |             }
1079 | 
1080 |             srcp += stride;
1081 |             wdstp += intermediate_stride;
1082 |             weightp += intermediate_stride;
1083 |         }
1084 |     }
1085 | }
1086 | 
1087 | static inline void aggregation(
1088 |     float * VS_RESTRICT dstp,
1089 |     const float * VS_RESTRICT intermediate,
1090 |     int width, int height, int stride
1091 | ) noexcept {
1092 | 
1093 |     const float * VS_RESTRICT wdst = intermediate;
1094 |     const float * VS_RESTRICT weight = &intermediate[height * width];
1095 | 
1096 |     for (int y = 0; y < height; ++y) {
1097 |         int x = 0;
1098 | 
1099 | #ifdef __AVX2__
1100 |         const float * VS_RESTRICT vec_wdstp { wdst };
1101 |         const float * VS_RESTRICT vec_weightp { weight };
1102 |         float * VS_RESTRICT vec_dstp {dstp};
1103 | 
1104 |         for ( ; x < (width & (-8)); x += 8) {
1105 |             Vec8f vec_wdst = Vec8f().load(vec_wdstp);
1106 |             Vec8f vec_weight = Vec8f().load(vec_weightp);
1107 |             Vec8f vec_dst = vec_wdst * approx_recipr(vec_weight);
1108 |             vec_dst.store_a(vec_dstp);
1109 | 
1110 |             vec_wdstp += 8;
1111 |             vec_weightp += 8;
1112 |             vec_dstp += 8;
1113 |         }
1114 | #endif
1115 | 
1116 |         for ( ; x < width; ++x) {
1117 |             dstp[x] = wdst[x] / weight[x];
1118 |         }
1119 | 
1120 |         dstp += stride;
1121 |         wdst += width;
1122 |         weight += width;
1123 |     }
1124 | }
1125 | 
1126 | template <bool residual>
1127 | static void process(
1128 |     const std::vector<const VSFrameRef *> & srcs,
1129 |     const std::vector<const VSFrameRef *> & refs,
1130 |     VSFrameRef * dst,
1131 |     WNNMData * d,
1132 |     const VSAPI * vsapi
1133 | ) noexcept {
1134 | 
1135 |     const auto threadId = std::this_thread::get_id();
1136 | 
1137 | #ifdef __AVX2__
1138 |     auto control_word = get_control_word();
1139 |     no_subnormals();
1140 | #endif
1141 | 
1142 |     Workspace workspace {};
1143 |     bool init = true;
1144 | 
1145 |     d->workspaces_lock.lock_shared();
1146 | 
1147 |     try {
1148 |         const auto & const_workspaces = d->workspaces;
1149 |         workspace = const_workspaces.at(threadId);
1150 |     } catch (const std::out_of_range &) {
1151 |         init = false;
1152 |     }
1153 | 
1154 |     d->workspaces_lock.unlock_shared();
1155 | 
1156 |     auto vi = vsapi->getVideoInfo(d->node);
1157 | 
1158 |     if (!init) {
1159 |         workspace.init(
1160 |             vi->width, vi->height,
1161 |             d->block_size, d->group_size, d->radius,
1162 |             d->residual,
1163 |             d->svd_lda, d->svd_ldu, d->svd_ldvt, d->svd_lwork
1164 |         );
1165 | 
1166 |         d->workspaces_lock.lock();
1167 |         d->workspaces.emplace(threadId, workspace);
1168 |         d->workspaces_lock.unlock();
1169 |     }
1170 | 
1171 |     std::conditional_t<residual, float * VS_RESTRICT, Empty> mean_patch {};
1172 |     if constexpr (residual) {
1173 |         mean_patch = workspace.mean_patch;
1174 |     }
1175 | 
1176 |     std::vector<std::tuple<float, int, int, int>> & errors = *workspace.errors;
1177 | 
1178 |     for (int plane = 0; plane < vi->format->numPlanes; plane++) {
1179 |         if (!d->process[plane]) {
1180 |             continue;
1181 |         }
1182 | 
1183 |         const int width = vsapi->getFrameWidth(srcs[0], plane);
1184 |         const int height = vsapi->getFrameHeight(srcs[0], plane);
1185 |         const int stride = vsapi->getStride(srcs[0], plane) / static_cast<int>(sizeof(float));
1186 |         std::vector<const float *> srcps;
1187 |         srcps.reserve(std::size(srcs));
1188 |         for (const auto & src : srcs) {
1189 |             srcps.emplace_back(reinterpret_cast<const float *>(vsapi->getReadPtr(src, plane)));
1190 |         }
1191 |         std::vector<const float *> refps;
1192 |         refps.reserve(std::size(refs));
1193 |         for (const auto & ref : refs) {
1194 |             refps.emplace_back(reinterpret_cast<const float *>(vsapi->getReadPtr(ref, plane)));
1195 |         }
1196 |         float * const VS_RESTRICT dstp = reinterpret_cast<float *>(vsapi->getWritePtr(dst, plane));
1197 | 
1198 |         if (d->radius == 0) {
1199 |             std::memset(workspace.intermediate, 0, 2 * height * width * sizeof(float));
1200 |         } else {
1201 |             std::memset(dstp, 0, 2 * (2 * d->radius + 1) * height * stride * sizeof(float));
1202 |         }
1203 | 
1204 |         int temp_r = height - d->block_size;
1205 |         int temp_c = width - d->block_size;
1206 | 
1207 |         for (int _y = 0; _y < temp_r + d->block_step; _y += d->block_step) {
1208 |             int y = std::min(_y, temp_r); // clamp
1209 | 
1210 |             for (int _x = 0; _x < temp_c + d->block_step; _x += d->block_step) {
1211 |                 int x = std::min(_x, temp_c); // clamp
1212 | 
1213 |                 if constexpr (residual) {
1214 |                     std::memset(mean_patch, 0, sizeof(float) * square(d->block_size));
1215 |                 }
1216 | 
1217 |                 int active_group_size = block_matching<residual>(
1218 |                     // outputs
1219 |                     workspace.denoising_patch, d->svd_lda,
1220 |                     errors,
1221 |                     workspace.current_patch, mean_patch,
1222 |                     // inputs
1223 |                     srcps, refps, width, height, stride,
1224 |                     x, y,
1225 |                     d->block_size, d->group_size, d->bm_range,
1226 |                     d->ps_num, d->ps_range,
1227 |                     *workspace.center_errors,
1228 |                     *workspace.search_locations,
1229 |                     *workspace.new_locations,
1230 |                     *workspace.locations_copy,
1231 |                     *workspace.temporal_errors
1232 |                 );
1233 | 
1234 |                 // patch_estimation with early skipping on SVD exception
1235 |                 float adaptive_weight = 1.f;
1236 |                 WnnmInfo info = patch_estimation<residual>(
1237 |                     // outputs
1238 |                     workspace.denoising_patch, d->svd_lda,
1239 |                     adaptive_weight,
1240 |                     // inputs
1241 |                     d->sigma[plane],
1242 |                     d->block_size, active_group_size, mean_patch, d->adaptive_aggregation,
1243 |                     // temporaries
1244 |                     workspace.svd_s, workspace.svd_u, d->svd_ldu, workspace.svd_vt, d->svd_ldvt,
1245 |                     workspace.svd_work, d->svd_lwork, workspace.svd_iwork
1246 |                 );
1247 | 
1248 |                 switch (info) {
1249 |                     case WnnmInfo::SUCCESS: {
1250 |                         if (d->radius == 0) {
1251 |                             col2im(
1252 |                                 // output
1253 |                                 workspace.intermediate,
1254 |                                 // inputs
1255 |                                 workspace.denoising_patch, d->svd_lda,
1256 |                                 errors, height, width,
1257 |                                 d->block_size, active_group_size, adaptive_weight
1258 |                             );
1259 |                         } else {
1260 |                             col2im(
1261 |                                 // output
1262 |                                 dstp,
1263 |                                 // inputs
1264 |                                 workspace.denoising_patch, d->svd_lda,
1265 |                                 errors, height, stride,
1266 |                                 d->block_size, active_group_size, adaptive_weight
1267 |                             );
1268 |                         }
1269 |                         break;
1270 |                     }
1271 |                     case WnnmInfo::FAILURE: {
1272 |                         if (d->radius == 0) {
1273 |                             patch_estimation_skip(
1274 |                                 // output
1275 |                                 workspace.intermediate, srcps,
1276 |                                 errors, height, stride, width,
1277 |                                 d->block_size, active_group_size
1278 |                             );
1279 |                         } else {
1280 |                             patch_estimation_skip(
1281 |                                 // output
1282 |                                 dstp,
1283 |                                 // inputs
1284 |                                 srcps,
1285 |                                 errors, height, stride, stride,
1286 |                                 d->block_size, active_group_size
1287 |                             );
1288 |                         }
1289 |                         break;
1290 |                     }
1291 |                 }
1292 |             }
1293 |         }
1294 | 
1295 |         if (d->radius == 0) {
1296 |             aggregation(dstp, workspace.intermediate, width, height, stride);
1297 |         }
1298 |     }
1299 | 
1300 | #ifdef __AVX2__
1301 |     set_control_word(control_word);
1302 | #endif
1303 | }
1304 | 
1305 | static void VS_CC WNNMRawInit(
1306 |     VSMap *in, VSMap *out, void **instanceData, VSNode *node,
1307 |     VSCore *core, const VSAPI *vsapi
1308 | ) noexcept {
1309 | 
1310 |     WNNMData * d = static_cast<WNNMData *>(*instanceData);
1311 | 
1312 |     if (d->radius > 0) {
1313 |         auto vi = *vsapi->getVideoInfo(d->node);
1314 |         vi.height *= 2 * (2 * d->radius + 1);
1315 |         vsapi->setVideoInfo(&vi, 1, node);
1316 |     } else {
1317 |         auto vi = vsapi->getVideoInfo(d->node);
1318 |         vsapi->setVideoInfo(vi, 1, node);
1319 |     }
1320 | }
1321 | 
1322 | static const VSFrameRef *VS_CC WNNMRawGetFrame(
1323 |     int n, int activationReason, void **instanceData, void **frameData,
1324 |     VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi
1325 | ) noexcept {
1326 | 
1327 |     auto * d = static_cast<WNNMData *>(*instanceData);
1328 | 
1329 |     if (activationReason == arInitial) {
1330 |         auto vi = vsapi->getVideoInfo(d->node);
1331 | 
1332 |         int start_frame = std::max(n - d->radius, 0);
static const VSFrameRef *VS_CC WNNMRawGetFrame(
    int n, int activationReason, void **instanceData, void **frameData,
    VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi
) noexcept {

    auto * d = static_cast<WNNMData *>(*instanceData);

    if (activationReason == arInitial) {
        auto vi = vsapi->getVideoInfo(d->node);

        int start_frame = std::max(n - d->radius, 0);
        int end_frame = std::min(n + d->radius, vi->numFrames - 1);

        for (int i = start_frame; i <= end_frame; ++i) {
            vsapi->requestFrameFilter(i, d->node, frameCtx);
        }
        if (d->ref_node) {
            for (int i = start_frame; i <= end_frame; ++i) {
                vsapi->requestFrameFilter(i, d->ref_node, frameCtx);
            }
        }
    } else if (activationReason == arAllFramesReady) {
        auto vi = vsapi->getVideoInfo(d->node);

        std::vector<const VSFrameRef *> srcs;
        srcs.reserve(2 * d->radius + 1);
        for (int i = -d->radius; i <= d->radius; i++) {
            auto frame_id = std::clamp(n + i, 0, vi->numFrames - 1);
            srcs.emplace_back(vsapi->getFrameFilter(frame_id, d->node, frameCtx));
        }

        std::vector<const VSFrameRef *> refs;
        if (d->ref_node) {
            refs.reserve(2 * d->radius + 1);
            for (int i = -d->radius; i <= d->radius; i++) {
                auto frame_id = std::clamp(n + i, 0, vi->numFrames - 1);
                refs.emplace_back(vsapi->getFrameFilter(frame_id, d->ref_node, frameCtx));
            }
        } else {
            refs = srcs;
        }

        const auto & center_src = srcs[d->radius];
        VSFrameRef * dst;
        if (d->radius == 0) {
            const VSFrameRef * fr[] {
                d->process[0] ? nullptr : center_src,
                d->process[1] ? nullptr : center_src,
                d->process[2] ? nullptr : center_src
            };
            const int pl[] { 0, 1, 2 };
            dst = vsapi->newVideoFrame2(vi->format, vi->width, vi->height, fr, pl, center_src, core);
        } else {
            dst = vsapi->newVideoFrame(vi->format, vi->width, 2 * (2 * d->radius + 1) * vi->height, center_src, core);
        }

        if (d->residual) {
            process<true>(srcs, refs, dst, d, vsapi);
        } else {
            process<false>(srcs, refs, dst, d, vsapi);
        }

        for (const auto & src : srcs) {
            vsapi->freeFrame(src);
        }

        if (d->ref_node) {
            for (const auto & ref : refs) {
                vsapi->freeFrame(ref);
            }
        }

        return dst;
    }

    return nullptr;
}

static void VS_CC WNNMRawFree(
    void *instanceData, VSCore *core, const VSAPI *vsapi
) noexcept {

    auto d = static_cast<WNNMData *>(instanceData);

    vsapi->freeNode(d->node);

    if (d->ref_node) {
        vsapi->freeNode(d->ref_node);
    }

    for (auto & [_, workspace] : d->workspaces) {
        workspace.release();
    }

    delete d;
}
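
// Defaults applied below when an argument is absent:
//   sigma = 3.0 per plane (then scaled by 1/255 to the 32-bit float range),
//   block_size = 8, block_step = block_size, group_size = 8, bm_range = 7,
//   radius = 0, ps_num = 2, ps_range = 4,
//   residual = false, adaptive_aggregation = true.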
static void VS_CC WNNMRawCreate(
    const VSMap *in, VSMap *out, void *userData,
    VSCore *core, const VSAPI *vsapi
) noexcept {

    auto d = std::make_unique<WNNMData>();

    d->node = vsapi->propGetNode(in, "clip", 0, nullptr);

    auto set_error = [&](const std::string & error) -> void {
        vsapi->setError(out, ("WNNM: " + error).c_str());
        vsapi->freeNode(d->node);
        return;
    };

    auto vi = vsapi->getVideoInfo(d->node);

    if (!isConstantFormat(vi) || vi->format->sampleType == stInteger ||
        (vi->format->sampleType == stFloat && vi->format->bitsPerSample != 32)
    ) {
        return set_error("only constant format 32 bit float input supported");
    }

    int error;

    for (unsigned i = 0; i < std::size(d->sigma); i++) {
        d->sigma[i] = static_cast<float>(vsapi->propGetFloat(in, "sigma", i, &error));
        if (error) {
            d->sigma[i] = (i == 0) ? 3.0f : d->sigma[i - 1];
        }
        if (d->sigma[i] < 0.0f) {
            return set_error("\"sigma\" must be positive");
        }
    }

    for (unsigned i = 0; i < std::size(d->sigma); ++i) {
        if (d->sigma[i] < std::numeric_limits<float>::epsilon()) {
            d->process[i] = false;
        } else {
            d->process[i] = true;
            d->sigma[i] /= 255.f;
        }
    }

    d->block_size = int64ToIntS(vsapi->propGetInt(in, "block_size", 0, &error));
    if (error) {
        // d->block_size = 6;
        d->block_size = 8; // more optimized
    } else if (d->block_size <= 0) {
        return set_error("\"block_size\" must be positive");
    }

    d->block_step = int64ToIntS(vsapi->propGetInt(in, "block_step", 0, &error));
    if (error) {
        // d->block_step = 6;
        d->block_step = d->block_size;
    } else if (d->block_step <= 0 || d->block_step > d->block_size) {
        return set_error("\"block_step\" must be positive and no larger than \"block_size\"");
    }

    d->group_size = int64ToIntS(vsapi->propGetInt(in, "group_size", 0, &error));
    if (error) {
        d->group_size = 8;
    } else if (d->group_size <= 0) {
        return set_error("\"group_size\" must be positive");
    }

    d->bm_range = int64ToIntS(vsapi->propGetInt(in, "bm_range", 0, &error));
    if (error) {
        d->bm_range = 7;
    } else if (d->bm_range < 0) {
        return set_error("\"bm_range\" must be non-negative");
    }

    d->radius = int64ToIntS(vsapi->propGetInt(in, "radius", 0, &error));
    if (error) {
        d->radius = 0;
    } else if (d->radius < 0) {
        return set_error("\"radius\" must be non-negative");
    }

    d->ps_num = int64ToIntS(vsapi->propGetInt(in, "ps_num", 0, &error));
    if (error) {
        d->ps_num = 2;
    } else if (d->ps_num <= 0) {
        return set_error("\"ps_num\" must be positive");
    }

    d->ps_range = int64ToIntS(vsapi->propGetInt(in, "ps_range", 0, &error));
    if (error) {
        d->ps_range = 4;
    } else if (d->ps_range < 0) {
        return set_error("\"ps_range\" must be non-negative");
    }

    // leading dimensions for LAPACK, padded via m16 (presumably rounding up
    // to a multiple of 16 for aligned access)
    d->svd_lda = m16(square(d->block_size));
    d->svd_ldu = m16(square(d->block_size));
    d->svd_ldvt = m16(std::min(square(d->block_size), d->group_size));

    d->residual = !!vsapi->propGetInt(in, "residual", 0, &error);
    if (error) {
        d->residual = false;
    }

    d->adaptive_aggregation = !!vsapi->propGetInt(in, "adaptive_aggregation", 0, &error);
    if (error) {
        d->adaptive_aggregation = true;
    }

    d->ref_node = vsapi->propGetNode(in, "rclip", 0, &error);
    if (error) {
        d->ref_node = nullptr;
    } else {
        auto ref_vi = vsapi->getVideoInfo(d->ref_node);
        if (!isSameFormat(vi, ref_vi) || vi->numFrames != ref_vi->numFrames) {
            return set_error("\"rclip\" must be of the same format and number of frames as \"clip\"");
        }
    }

    VSCoreInfo core_info;
    vsapi->getCoreInfo2(core, &core_info);
    auto numThreads = core_info.numThreads;
    d->workspaces.reserve(numThreads);

    // workspace size for the SVD of an (svd_m x svd_n) patch matrix
    int svd_m = square(d->block_size);
    int svd_n = d->group_size;
    d->svd_lwork = std::min(svd_m, svd_n) * (6 + 4 * std::min(svd_m, svd_n)) + std::max(svd_m, svd_n);

    vsapi->createFilter(in, out, "WNNMRaw", WNNMRawInit, WNNMRawGetFrame, WNNMRawFree, fmParallel, 0, d.release(), core);
}
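
// WNNMRaw's temporal output stacks 2 * (2 * radius + 1) planes per source
// plane, so height_raw / height_src = 4 * radius + 2 and VAggregateCreate
// can recover radius = (height_raw / height_src - 2) / 4 without an
// explicit parameter.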
struct VAggregateData {
    VSNodeRef * node;

    VSNodeRef * src_node;
    const VSVideoInfo * src_vi;

    std::array<bool, 3> process; // sigma != 0

    int radius;

    std::unordered_map<std::thread::id, float *> buffer;
    std::shared_mutex buffer_lock;
};

static void VS_CC VAggregateInit(
    VSMap *in, VSMap *out, void **instanceData, VSNode *node,
    VSCore *core, const VSAPI *vsapi
) noexcept {

    auto * d = static_cast<VAggregateData *>(*instanceData);

    vsapi->setVideoInfo(d->src_vi, 1, node);
}
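
// Each worker thread lazily allocates a scratch buffer of 2 * max_width
// floats (one row of accumulated numerators followed by one row of
// accumulated weights); the map lookup below mirrors the per-thread
// workspace handling in process().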
static const VSFrameRef *VS_CC VAggregateGetFrame(
    int n, int activationReason, void **instanceData, void **frameData,
    VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi
) noexcept {

    auto * d = static_cast<VAggregateData *>(*instanceData);

    if (activationReason == arInitial) {
        int start_frame = std::max(n - d->radius, 0);
        int end_frame = std::min(n + d->radius, d->src_vi->numFrames - 1);

        for (int i = start_frame; i <= end_frame; ++i) {
            vsapi->requestFrameFilter(i, d->node, frameCtx);
        }
        vsapi->requestFrameFilter(n, d->src_node, frameCtx);
    } else if (activationReason == arAllFramesReady) {
        const VSFrameRef * src_frame = vsapi->getFrameFilter(n, d->src_node, frameCtx);

        std::vector<const VSFrameRef *> frames;
        frames.reserve(2 * d->radius + 1);
        for (int i = n - d->radius; i <= n + d->radius; ++i) {
            auto frame_id = std::clamp(i, 0, d->src_vi->numFrames - 1);
            frames.emplace_back(vsapi->getFrameFilter(frame_id, d->node, frameCtx));
        }

        float * buffer {};
        {
            const auto thread_id = std::this_thread::get_id();
            bool init = true;

            d->buffer_lock.lock_shared();

            try {
                const auto & const_buffer = d->buffer;
                buffer = const_buffer.at(thread_id);
            } catch (const std::out_of_range &) {
                init = false;
            }

            d->buffer_lock.unlock_shared();

            if (!init) {
                assert(d->process[0] || d->src_vi->format->numPlanes > 1);

                const int max_width {
                    d->process[0] ?
                    vsapi->getFrameWidth(src_frame, 0) :
                    vsapi->getFrameWidth(src_frame, 1)
                };

                buffer = reinterpret_cast<float *>(std::malloc(2 * max_width * sizeof(float)));

                std::lock_guard _ { d->buffer_lock };
                d->buffer.emplace(thread_id, buffer);
            }
        }

        const VSFrameRef * fr[] {
            d->process[0] ? nullptr : src_frame,
            d->process[1] ? nullptr : src_frame,
            d->process[2] ? nullptr : src_frame
        };
        constexpr int pl[] { 0, 1, 2 };
        auto dst_frame = vsapi->newVideoFrame2(
            d->src_vi->format,
            d->src_vi->width, d->src_vi->height,
            fr, pl, src_frame, core);

        for (int plane = 0; plane < d->src_vi->format->numPlanes; ++plane) {
            if (d->process[plane]) {
                int plane_width = vsapi->getFrameWidth(src_frame, plane);
                int plane_height = vsapi->getFrameHeight(src_frame, plane);
                int plane_stride = vsapi->getStride(src_frame, plane) / static_cast<int>(sizeof(float));

                std::vector<const float *> srcps;
                srcps.reserve(2 * d->radius + 1);
                for (int i = 0; i < 2 * d->radius + 1; ++i) {
                    srcps.emplace_back(reinterpret_cast<const float *>(vsapi->getReadPtr(frames[i], plane)));
                }

                auto dstp = reinterpret_cast<float *>(vsapi->getWritePtr(dst_frame, plane));

                for (int y = 0; y < plane_height; ++y) {
                    std::memset(buffer, 0, 2 * plane_width * sizeof(float));
                    for (int i = 0; i < 2 * d->radius + 1; ++i) {
                        auto agg_src = srcps[i];
                        // bm3d.VAggregate implements zero padding in the temporal dimension;
                        // here we implement replication padding
                        agg_src += (
                            std::clamp(2 * d->radius - i, n - d->src_vi->numFrames + 1 + d->radius, n + d->radius)
                            * 2 * plane_height + y) * plane_stride;
                        for (int x = 0; x < plane_width; ++x) {
                            buffer[x] += agg_src[x]; // numerator row
                        }
                        agg_src += plane_height * plane_stride;
                        for (int x = 0; x < plane_width; ++x) {
                            buffer[plane_width + x] += agg_src[x]; // weight row
                        }
                    }
                    for (int x = 0; x < plane_width; ++x) {
                        dstp[x] = buffer[x] / buffer[plane_width + x];
                    }
                    dstp += plane_stride;
                }
            }
        }

        for (const auto & frame : frames) {
            vsapi->freeFrame(frame);
        }
        vsapi->freeFrame(src_frame);

        return dst_frame;
    }

    return nullptr;
}

static void VS_CC VAggregateFree(
    void *instanceData, VSCore *core, const VSAPI *vsapi
) noexcept {

    auto * d = static_cast<VAggregateData *>(instanceData);

    for (const auto & [_, ptr] : d->buffer) {
        std::free(ptr);
    }

    vsapi->freeNode(d->src_node);
    vsapi->freeNode(d->node);

    delete d;
}

static void VS_CC VAggregateCreate(
    const VSMap *in, VSMap *out, void *userData,
    VSCore *core, const VSAPI *vsapi
) noexcept {

    {
        int error;
        bool internal = !!vsapi->propGetInt(in, "internal", 0, &error);
        if (error) {
            internal = false;
        }
        if (!internal) {
            vsapi->setError(
                out,
                "this interface is for internal use only, please use \"wnnm.WNNM()\" directly"
            );
            return;
        }
    }

    auto d = std::make_unique<VAggregateData>();

    d->node = vsapi->propGetNode(in, "clip", 0, nullptr);
    auto vi = vsapi->getVideoInfo(d->node);
    d->src_node = vsapi->propGetNode(in, "src", 0, nullptr);
    d->src_vi = vsapi->getVideoInfo(d->src_node);

    // recover radius from the stacked-plane height ratio (see note above)
    d->radius = (vi->height / d->src_vi->height - 2) / 4;

    d->process.fill(false);
    int num_planes_args = vsapi->propNumElements(in, "planes");
    for (int i = 0; i < num_planes_args; ++i) {
        int plane = int64ToIntS(vsapi->propGetInt(in, "planes", i, nullptr));
        d->process[plane] = true;
    }

    VSCoreInfo core_info;
    vsapi->getCoreInfo2(core, &core_info);
    d->buffer.reserve(core_info.numThreads);

    vsapi->createFilter(
        in, out, "VAggregate",
        VAggregateInit, VAggregateGetFrame, VAggregateFree,
        fmParallel, 0, d.release(), core);
}
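
// WNNM is the public entry point: for radius == 0 it simply forwards to
// WNNMRaw (the spatial path aggregates internally); for radius > 0 it
// chains WNNMRaw with VAggregate, passing the original clip as "src" and
// the set of processed planes.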
static void VS_CC WNNMCreate(
    const VSMap *in, VSMap *out, void *userData,
    VSCore *core, const VSAPI *vsapi
) noexcept {

    std::array<bool, 3> process;
    process.fill(true);

    int num_sigma_args = vsapi->propNumElements(in, "sigma");
    for (int i = 0; i < std::min(3, num_sigma_args); ++i) {
        auto sigma = vsapi->propGetFloat(in, "sigma", i, nullptr);
        if (sigma < std::numeric_limits<double>::epsilon()) {
            process[i] = false;
        }
    }
    if (num_sigma_args > 0) { // num_sigma_args may be -1
        for (int i = num_sigma_args; i < 3; ++i) {
            process[i] = process[i - 1];
        }
    }

    bool skip = true;
    auto src = vsapi->propGetNode(in, "clip", 0, nullptr);
    auto src_vi = vsapi->getVideoInfo(src);
    for (int i = 0; i < src_vi->format->numPlanes; ++i) {
        skip &= !process[i];
    }
    if (skip) {
        // all sigmas are zero: return the input unchanged
        vsapi->propSetNode(out, "clip", src, paReplace);
        vsapi->freeNode(src);
        return;
    }

    auto map = vsapi->invoke(myself, "WNNMRaw", in);
    if (auto error = vsapi->getError(map); error) {
        vsapi->setError(out, error);
        vsapi->freeMap(map);
        vsapi->freeNode(src);
        return;
    }

    int err;
    int radius = int64ToIntS(vsapi->propGetInt(in, "radius", 0, &err));
    if (err) {
        radius = 0;
    }
    if (radius == 0) {
        // spatial WNNM should handle everything itself
        auto node = vsapi->propGetNode(map, "clip", 0, nullptr);
        vsapi->freeMap(map);
        vsapi->propSetNode(out, "clip", node, paReplace);
        vsapi->freeNode(node);
        vsapi->freeNode(src);
        return;
    }

    vsapi->propSetNode(map, "src", src, paReplace);
    vsapi->freeNode(src);

    for (int i = 0; i < 3; ++i) {
        if (process[i]) {
            vsapi->propSetInt(map, "planes", i, paAppend);
        }
    }

    vsapi->propSetInt(map, "internal", 1, paReplace);

    auto map2 = vsapi->invoke(myself, "VAggregate", map);
    vsapi->freeMap(map);
    if (auto error = vsapi->getError(map2); error) {
        vsapi->setError(out, error);
        vsapi->freeMap(map2);
        return;
    }

    auto node = vsapi->propGetNode(map2, "clip", 0, nullptr);
    vsapi->freeMap(map2);
    vsapi->propSetNode(out, "clip", node, paReplace);
    vsapi->freeNode(node);
}
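
// Typical invocation from a VapourSynth script, given the signatures
// registered below (parameter values are illustrative, not prescriptive):
//
//   # Python
//   denoised = core.wnnm.WNNM(clip, sigma=[3.0], block_size=8, radius=1)
//
// With radius=1 this expands internally to WNNMRaw followed by VAggregate.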
VS_EXTERNAL_API(void) VapourSynthPluginInit(
    VSConfigPlugin configFunc, VSRegisterFunction registerFunc, VSPlugin *plugin
) noexcept {

    myself = plugin;

    configFunc(
        "com.wolframrhodium.wnnm",
        "wnnm", "Weighted Nuclear Norm Minimization Denoiser",
        VAPOURSYNTH_API_VERSION, 1, plugin
    );

    constexpr auto wnnm_args {
        "clip:clip;"
        "sigma:float[]:opt;"
        "block_size:int:opt;"
        "block_step:int:opt;"
        "group_size:int:opt;"
        "bm_range:int:opt;"
        "radius:int:opt;"
        "ps_num:int:opt;"
        "ps_range:int:opt;"
        "residual:int:opt;"
        "adaptive_aggregation:int:opt;"
        "rclip:clip:opt;"
    };

    registerFunc("WNNMRaw", wnnm_args, WNNMRawCreate, nullptr, plugin);

    registerFunc(
        "VAggregate",
        "clip:clip;"
        "src:clip;"
        "planes:int[];"
        "internal:int:opt;",
        VAggregateCreate, nullptr, plugin);

    registerFunc("WNNM", wnnm_args, WNNMCreate, nullptr, plugin);

    auto getVersion = [](const VSMap *, VSMap * out, void *, VSCore *, const VSAPI *vsapi) {
        vsapi->propSetData(out, "version", VERSION, -1, paReplace);

#if defined(__INTEL_MKL__)
        std::ostringstream mkl_version_build_str;
        mkl_version_build_str << __INTEL_MKL__ << '.' << __INTEL_MKL_MINOR__ << '.' << __INTEL_MKL_UPDATE__;

        vsapi->propSetData(out, "mkl_version_build", mkl_version_build_str.str().c_str(), -1, paReplace);

        MKLVersion version;
        mkl_get_version(&version);

        vsapi->propSetData(out, "mkl_processor", version.Processor, -1, paReplace);

        std::ostringstream mkl_version_str;
        mkl_version_str << version.MajorVersion << '.' << version.MinorVersion << '.' << version.UpdateVersion;

        vsapi->propSetData(out, "mkl_version", mkl_version_str.str().c_str(), -1, paReplace);
#else
        int major, minor, patch;
        const char * tag;
        armplversion(&major, &minor, &patch, &tag);

        std::ostringstream armpl_version;
        armpl_version << major << '.' << minor << '.' << patch << '.' << tag;
        vsapi->propSetData(out, "armpl_version", armpl_version.str().c_str(), -1, paReplace);
#endif
    };
    registerFunc("Version", "", getVersion, nullptr, plugin);
}

--------------------------------------------------------------------------------