├── .github └── workflows │ └── cmake-multi-platform.yml ├── .gitignore ├── CMakeLists.txt ├── LICENCE ├── README.md ├── build-standalone.sh ├── demo ├── README.txt ├── demo_mp.m ├── demo_mp.py ├── input.bin ├── input.cfg ├── input.xml ├── input_generator.py └── input_smp.db ├── include ├── Array.h ├── Atom.h ├── Block.h ├── BlockAtom.h ├── BlockAtomBase.h ├── BlockAtomCache.h ├── BlockAtomObjective.h ├── BlockAtomProductCalculator.h ├── BlockDictionary.h ├── BlockDictionaryStructure.h ├── BlockHelper.h ├── BlockInterface.h ├── BlockStructure.h ├── BookWriter.h ├── BufferedWriter.h ├── CUDA.h ├── Configuration.h ├── Corrector.h ├── DeltaAtom.h ├── DeltaDictionary.h ├── Dictionary.h ├── EpochIndex.h ├── ExportedAtom.h ├── ExtenderLoop.h ├── ExtraData.h ├── Extractor.h ├── Family.h ├── File.h ├── GaussianFamily.h ├── IndexRange.h ├── Logger.h ├── OptimizationMode.h ├── PinnedArray.h ├── Progress.h ├── ProtoRequest.h ├── Semaphore.h ├── SignalReader.h ├── SpectrogramCalculator.h ├── SpectrogramCalculatorCUDA.h ├── SpectrogramCalculatorDummy.h ├── SpectrogramCalculatorFFTW.h ├── SpectrogramLoop.h ├── SpectrogramRequest.h ├── SpectrumCalculator.h ├── TaskQueue.h ├── Testing.h ├── Thread.h ├── Timer.h ├── TriangularFamily.h ├── Types.h ├── Worker.h └── WorkerLoop.h ├── run-benchmark.sh ├── src ├── Block.cpp ├── BlockAtom.cpp ├── BlockAtomBase.cpp ├── BlockAtomCache.cpp ├── BlockAtomObjective.cpp ├── BlockAtomProductCalculator.cpp ├── BlockDictionary.cpp ├── BlockDictionaryStructure.cpp ├── BlockHelper.cpp ├── BookWriter.cpp ├── BufferedWriter.cpp ├── Configuration.cpp ├── Corrector.cpp ├── DeltaAtom.cpp ├── DeltaDictionary.cpp ├── ExportedAtom.cpp ├── Extractor.cpp ├── Family.cpp ├── File.cpp ├── GaussianFamily.cpp ├── IndexRange.cpp ├── Logger.cpp ├── Progress.cpp ├── SignalReader.cpp ├── SpectrogramCalculatorCUDA.cu ├── SpectrogramCalculatorCUDACallback.cu ├── SpectrogramCalculatorDummy.cpp ├── SpectrogramCalculatorFFTW.cpp ├── Thread.cpp ├── Timer.cpp ├── 
TriangularFamily.cpp ├── Types.cpp ├── Worker.cpp └── special │ ├── alloc.cpp │ ├── alloc.cu │ └── empi.cpp ├── tests ├── special │ ├── benchmark.cpp │ └── test-cuda.cpp ├── test-allocation.cpp ├── test-best-match.cpp ├── test-block-atom.cpp ├── test-block.cpp ├── test-book-writer.cpp ├── test-corrector.cpp ├── test-dictionary.cpp ├── test-envelope.cpp ├── test-fftw.cpp ├── test-index-range.cpp ├── test-measure.cpp ├── test-move-semantics.cpp ├── test-product-calculator.cpp ├── test-rounding.cpp ├── test-signal-reader.cpp ├── test-subsample.cpp ├── test-task-queue.cpp └── test-worker.cpp └── vendor ├── include ├── CLI11.hpp ├── nelder_mead.h └── sqlite3.h └── sqlite3.c /.github/workflows/cmake-multi-platform.yml: -------------------------------------------------------------------------------- 1 | # This starter workflow is for a CMake project running on multiple platforms. There is a different starter workflow if you just want a single platform. 2 | # See: https://github.com/actions/starter-workflows/blob/main/ci/cmake-single-platform.yml 3 | name: CMake on multiple platforms 4 | 5 | on: 6 | push: 7 | branches: [ "master" ] 8 | pull_request: 9 | branches: [ "master" ] 10 | 11 | jobs: 12 | build: 13 | runs-on: ${{ matrix.os }} 14 | 15 | strategy: 16 | # Set fail-fast to false to ensure that feedback is delivered for all matrix combinations. Consider changing this to true when your workflow is stable. 17 | fail-fast: false 18 | 19 | # Set up a matrix to run the following 2 configurations: 20 | # 1. 21 | # 2. 22 | # 23 | # To add more build types (Release, Debug, RelWithDebInfo, etc.) customize the build_type list. 24 | matrix: 25 | os: [ubuntu-latest, macos-12, macos-14, windows-latest] 26 | build_type: [Release] 27 | 28 | steps: 29 | - uses: actions/checkout@v4 30 | with: 31 | fetch-depth: 0 32 | 33 | - name: Set reusable strings 34 | # Turn repeated input strings (such as the build output directory) into step outputs. 
These step outputs can be used throughout the workflow file. 35 | id: strings 36 | shell: bash 37 | run: | 38 | echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" 39 | echo "empi=empi-`git describe --tags`-`echo ${{ runner.os }}-${{ runner.arch }} | tr [:upper:] [:lower:]`" >> "$GITHUB_OUTPUT" 40 | 41 | - name: Install dependencies (Linux) 42 | if: runner.os == 'Linux' 43 | run: sudo apt-get update && sudo apt-get install -y libfftw3-dev 44 | 45 | - name: Install dependencies (macOS) 46 | if: runner.os == 'macOS' 47 | run: brew install gcc fftw && rm -vf `brew --prefix fftw`/lib/*.dylib 48 | 49 | - name: Install dependencies (Windows) 50 | if: runner.os == 'Windows' 51 | uses: johnwason/vcpkg-action@v5 52 | id: vcpkg 53 | with: 54 | pkgs: fftw3 pthread 55 | triplet: x64-windows-release 56 | cache-key: ${{ matrix.os }} 57 | revision: master 58 | token: ${{ github.token }} 59 | 60 | - name: Configure CMake 61 | # Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make. 62 | # See https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html?highlight=cmake_build_type 63 | run: > 64 | cmake ${{ steps.vcpkg.outputs.vcpkg-cmake-config }} 65 | -B ${{ steps.strings.outputs.build-output-dir }} 66 | -DCMAKE_CXX_COMPILER=${{ runner.os == 'Windows' && 'cl' || 'g++' }} 67 | -DCMAKE_C_COMPILER=${{ runner.os == 'Windows' && 'cl' || 'gcc' }} 68 | -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} 69 | ${{ runner.os == 'Windows' && '-DCMAKE_MSVC_RUNTIME_LIBRARY=MultiThreaded' || '' }} 70 | ${{ runner.os == 'Linux' && '-DSTANDALONE=TRUE' || '' }} 71 | -S ${{ github.workspace }} 72 | 73 | - name: Build 74 | # Build your program with the given configuration. Note that --config is needed because the default Windows generator is a multi-config generator (Visual Studio generator). 
75 | run: cmake --build ${{ steps.strings.outputs.build-output-dir }} --config ${{ matrix.build_type }} --target empi 76 | 77 | - name: Prepare production artifacts (Non-Windows) 78 | if: runner.os != 'Windows' 79 | run: | 80 | mkdir ${{ steps.strings.outputs.empi }} 81 | cp ${{ steps.strings.outputs.build-output-dir }}/empi ${{ steps.strings.outputs.empi }} 82 | cp LICENCE ${{ steps.strings.outputs.empi }} 83 | cp README.md ${{ steps.strings.outputs.empi }} 84 | chmod +x ${{ steps.strings.outputs.empi }}/empi 85 | 86 | - name: Prepare production artifacts (Windows) 87 | if: runner.os == 'Windows' 88 | run: | 89 | mkdir ${{ steps.strings.outputs.empi }} 90 | copy ${{ steps.strings.outputs.build-output-dir }}\${{ matrix.build_type }}\empi.exe ${{ steps.strings.outputs.empi }} 91 | copy ${{ github.workspace }}\vcpkg\packages\fftw3_x64-windows-release\bin\fftw3.dll ${{ steps.strings.outputs.empi }} 92 | copy LICENCE ${{ steps.strings.outputs.empi }} 93 | copy README.md ${{ steps.strings.outputs.empi }} 94 | 95 | - name: Archive production artifacts 96 | uses: actions/upload-artifact@v4 97 | with: 98 | name: ${{ steps.strings.outputs.empi }} 99 | path: ${{ steps.strings.outputs.empi }}*/ 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | /benchmark 3 | /test-* 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | project(empi CXX C) 3 | set(CMAKE_CXX_STANDARD 17) 4 | 5 | enable_testing() 6 | 7 | if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) 8 | set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Choose the type of build, defaults to Release." 
FORCE) 9 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release") 10 | endif() 11 | 12 | if(APPLE) 13 | execute_process( 14 | COMMAND brew --prefix fftw 15 | OUTPUT_VARIABLE FFTW_PREFIX 16 | OUTPUT_STRIP_TRAILING_WHITESPACE 17 | ) 18 | include_directories("${FFTW_PREFIX}/include") 19 | link_directories("${FFTW_PREFIX}/lib") 20 | endif() 21 | 22 | if(NOT UNIX) 23 | find_package(FFTW3 CONFIG REQUIRED) 24 | find_path(FFTW_INCLUDE_DIR NAMES fftw3.h) 25 | add_library(fftw3 ALIAS FFTW3::fftw3) 26 | include_directories(${FFTW_INCLUDE_DIR}) 27 | endif() 28 | 29 | find_package(Git) 30 | set(VERSION_FILE "${CMAKE_SOURCE_DIR}/version") 31 | 32 | if(EXISTS ${VERSION_FILE}) 33 | file(READ ${VERSION_FILE} FILE_VERSION) 34 | string(STRIP ${FILE_VERSION} APP_VERSION) 35 | endif() 36 | if(NOT DEFINED APP_VERSION AND GIT_FOUND AND EXISTS "${PROJECT_SOURCE_DIR}/.git") 37 | execute_process( 38 | COMMAND ${GIT_EXECUTABLE} describe --tags 39 | WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} 40 | OUTPUT_VARIABLE GIT_TAG 41 | OUTPUT_STRIP_TRAILING_WHITESPACE 42 | RESULT_VARIABLE GIT_TAG_RESULT 43 | ) 44 | if(GIT_TAG_RESULT EQUAL "0") 45 | set(APP_VERSION ${GIT_TAG}) 46 | endif() 47 | endif() 48 | if(NOT DEFINED APP_VERSION) 49 | set(APP_VERSION "1.?.?") 50 | endif() 51 | add_definitions("-DAPP_VERSION=\"${APP_VERSION}\"") 52 | add_definitions("-D_USE_MATH_DEFINES") 53 | 54 | include(CheckLanguage) 55 | check_language(CUDA) 56 | if(CMAKE_CUDA_COMPILER) 57 | set(WITH_CUDA TRUE CACHE BOOL "Compile with CUDA support?") 58 | elseif(WITH_CUDA) 59 | message(SEND_ERROR "CUDA is not available on this system") 60 | endif() 61 | if(WITH_CUDA) 62 | enable_language(CUDA) 63 | set(CMAKE_CUDA_STANDARD 11) 64 | add_compile_definitions(HAVE_CUDA) 65 | # link all executables with nvcc 66 | string(REPLACE nvcc CMAKE_CXX_LINK_EXECUTABLE ${CMAKE_CXX_LINK_EXECUTABLE}) 67 | endif() 68 | 69 | file(GLOB EMPI_CPU_SOURCES src/*.cpp) 70 | file(GLOB EMPI_GPU_SOURCES src/*.cu) 71 | file(GLOB EMPI_TEST_SOURCES 
tests/*.cpp) 72 | 73 | add_library(empi-cpu STATIC ${EMPI_CPU_SOURCES} vendor/sqlite3.c) 74 | target_include_directories(empi-cpu PUBLIC include vendor/include) 75 | set_property(SOURCE vendor/sqlite3.c PROPERTY COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) 76 | 77 | add_library(empi-fake-gpu STATIC src/special/alloc.cpp) 78 | target_include_directories(empi-fake-gpu PUBLIC include) 79 | 80 | if(UNIX) 81 | # on UNIX, you can compile statically with cmake -DSTANDALONE=1 82 | set(STANDALONE FALSE CACHE BOOL "Build a standalone binary?") 83 | elseif(STANDALONE) 84 | message(SEND_ERROR "Standalone compilation is only available on UNIX") 85 | endif() 86 | 87 | if(STANDALONE AND WITH_CUDA) 88 | message(SEND_ERROR "Standalone compilation cannot use CUDA") 89 | endif() 90 | 91 | if(WITH_CUDA) 92 | add_library(empi-gpu STATIC ${EMPI_GPU_SOURCES} src/special/alloc.cu) 93 | target_include_directories(empi-gpu PUBLIC include) 94 | set_property(SOURCE src/SpectrogramCalculatorCUDACallback.cu PROPERTY COMPILE_FLAGS --relocatable-device-code=true) 95 | elseif(STANDALONE) 96 | set(CMAKE_FIND_LIBRARY_SUFFIXES ".a") 97 | set(CMAKE_EXE_LINKER_FLAGS "-static -Wl,--whole-archive -lrt -lpthread -Wl,--no-whole-archive") 98 | else() 99 | set(CMAKE_EXE_LINKER_FLAGS "-pthread") 100 | endif() 101 | 102 | add_executable(empi src/special/empi.cpp) 103 | install(TARGETS empi) 104 | target_link_libraries(empi empi-cpu fftw3) 105 | if(WITH_CUDA) 106 | target_link_libraries(empi empi-gpu cufft_static culibos) 107 | else() 108 | target_link_libraries(empi empi-fake-gpu) 109 | endif() 110 | 111 | foreach(TEST_SOURCE ${EMPI_TEST_SOURCES}) 112 | get_filename_component(TEST ${TEST_SOURCE} NAME_WE) 113 | add_executable(${TEST} tests/${TEST}.cpp) 114 | target_link_libraries(${TEST} empi-cpu) 115 | add_test(${TEST} ${TEST}) 116 | endforeach() 117 | 118 | target_link_libraries(test-best-match empi-fake-gpu fftw3) 119 | 120 | target_link_libraries(test-block empi-fake-gpu) 121 | 122 | 
target_link_libraries(test-block-atom empi-fake-gpu) 123 | 124 | if(WITH_CUDA) 125 | add_executable(test-cuda tests/special/test-cuda.cpp) 126 | target_link_libraries(test-cuda empi-cpu empi-gpu cufft_static culibos) 127 | add_test(test-cuda test-cuda) 128 | endif() 129 | 130 | target_link_libraries(test-dictionary empi-fake-gpu) 131 | 132 | target_link_libraries(test-fftw fftw3) 133 | 134 | target_link_libraries(test-measure empi-fake-gpu fftw3) 135 | 136 | target_link_libraries(test-subsample empi-fake-gpu fftw3) 137 | -------------------------------------------------------------------------------- /build-standalone.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -o errexit 3 | DIR=`mktemp -d` 4 | trap "rm -rf $DIR" 0 2 5 | 6 | OS=`uname -s | tr '[A-Z]' '[a-z]'` 7 | if [ "$OS" = darwin ] ; then 8 | OS=macos 9 | STANDALONE=FALSE 10 | else 11 | STANDALONE=TRUE 12 | fi 13 | 14 | VER=empi-`git describe --tags --exact-match`-$OS 15 | mkdir -p $DIR/$VER 16 | cp README.md LICENCE $DIR/$VER 17 | cmake -DSTANDALONE=$STANDALONE -DWITH_CUDA=FALSE -S. -B$DIR 18 | ( cd $DIR \ 19 | && make -j8 empi \ 20 | && mv empi $VER/empi-$OS \ 21 | && zip -9r $VER.zip $VER/ 22 | ) 23 | mv $DIR/$VER.zip . 24 | -------------------------------------------------------------------------------- /demo/README.txt: -------------------------------------------------------------------------------- 1 | ################################ 2 | # MATCHING PURSUIT DEMO TOOLKIT 3 | ################################ 4 | 5 | This toolkit demonstrates how to use the SQLite decomposition files generated 6 | by the matching pursuit implementation «empi» from Python and Matlab. 7 | 8 | The main part consists of two scripts: «demo_mp.py» and «demo_mp.m».
9 | 10 | Both scripts read the matching pursuit decomposition stored in 11 | «input_smp.db», plotting the original signal and its reconstruction, 12 | as well as printing the atoms' parameters to the standard output. 13 | Atom filtering may be enabled by un-commenting parts of the source code. 14 | 15 | Additional files help to understand the procedure in more detail: 16 | 17 | 1. The example input signal file «input.bin» consists of a synthetic 18 | multi-channel signal with one Gabor atom per channel plus low-amplitude noise, and can be 19 | re-generated with the Python script «input_generator.py». The script accepts 20 | an optional "--plot" command-line parameter to plot the generated signal. 21 | 22 | 2. Decomposition results «input_smp.db» can be computed by running empi as 23 | 24 | path/to/empi -c5 -f512 -i10 -o local --gabor input.bin input_smp.db 25 | 26 | Thanks to the attached «input.xml» specification, the input signal itself 27 | may also be displayed and manipulated in Svarog. 28 | -------------------------------------------------------------------------------- /demo/demo_mp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy 3 | import pylab 4 | import sqlite3 5 | import sys 6 | 7 | ################################################## 8 | # PLEASE REFER TO README.txt FOR MORE INFORMATION 9 | ################################################## 10 | 11 | 12 | def demo_mp(): 13 | cursor = sqlite3.connect('input_smp.db').cursor() 14 | cursor.row_factory = sqlite3.Row 15 | 16 | # number of channels for each segment 17 | channel_count = int(fetch_metadata(cursor, 'channel_count')) 18 | 19 | # sampling frequency in hertz 20 | sampling_frequency = float(fetch_metadata(cursor, 'sampling_frequency_Hz')) 21 | 22 | # segment_id (starting from 0) can be given as command line parameter; 23 | # it does not matter for the example file as there is only one segment 24 | segment_id = int(sys.argv[1]) if
len(sys.argv) > 1 else 0 25 | 26 | # number of samples in segment 27 | sample_count = fetch_sample_count(cursor, segment_id) 28 | 29 | for channel_id in range(channel_count): 30 | # each channel forms a separate sub-plot 31 | pylab.subplot(channel_count, 1, 1+channel_id) 32 | original_signal = fetch_original_signal(cursor, segment_id, channel_id) 33 | reconstruction = numpy.zeros(sample_count) 34 | 35 | t = numpy.arange(sample_count) / sampling_frequency 36 | query_for_atoms(cursor, segment_id, channel_id) # results will be iterated through cursor 37 | for atom in cursor: 38 | if atom['envelope'] != 'gauss': 39 | raise Exception('only Gabor atoms are supported') 40 | 41 | amplitude = atom['amplitude'] 42 | energy = atom['energy'] 43 | f = atom['f_Hz'] # frequency in hertz 44 | phase = atom['phase'] # phase in radians 45 | s = atom['scale_s'] # scale in seconds 46 | t0 = atom['t0_s'] # position (centre) in seconds 47 | t0_abs = atom['t0_abs_s'] # absolute position 48 | 49 | # lines below may be un-commented and edited to exclude 50 | # selected atoms from reconstruction 51 | ''' 52 | if f < 2.5: 53 | continue 54 | ''' 55 | g = amplitude * gabor(t, s, t0, f, phase) 56 | energy = numpy.sum(g**2) / sampling_frequency 57 | reconstruction += g 58 | print('\n-- ATOM IN CHANNEL {} --'.format(channel_id)) 59 | print('amplitude = %.3f' % amplitude) 60 | print('scale = %.3f s' % s) 61 | print('position in segment = %.3f s' % t0) 62 | print('position in signal = %.3f s' % t0_abs) 63 | print('frequency = %.3f Hz' % f) 64 | print('energy = %.6f' % energy) 65 | 66 | pylab.plot(t, original_signal) 67 | pylab.plot(t, reconstruction, 'red') 68 | 69 | pylab.show() 70 | 71 | 72 | def gabor(t, s, t0, f0, phase=0.0): 73 | """ 74 | Generates values for Gabor atom with unit amplitude. 
75 | """ 76 | return numpy.exp(-numpy.pi*((t-t0)/s)**2) \ 77 | * numpy.cos(2*numpy.pi*f0*(t-t0) + phase) 78 | 79 | 80 | def fetch_metadata(cursor, name): 81 | return cursor.execute( 82 | 'SELECT value FROM metadata WHERE param=?', 83 | [name] 84 | ).fetchone()[0] 85 | 86 | 87 | def fetch_sample_count(cursor, segment_id): 88 | return cursor.execute( 89 | 'SELECT sample_count FROM segments WHERE segment_id=?', 90 | [segment_id] 91 | ).fetchone()[0] 92 | 93 | 94 | def fetch_original_signal(cursor, segment_id, channel_id): 95 | blob = cursor.execute( 96 | 'SELECT samples_float32 FROM samples WHERE segment_id=? AND channel_id=?', 97 | [segment_id, channel_id] 98 | ).fetchone()[0] 99 | return numpy.frombuffer(blob, numpy.dtype('float32').newbyteorder('>')) 100 | 101 | 102 | def query_for_atoms(cursor, segment_id, channel_id): 103 | cursor.execute( 104 | 'SELECT * FROM atoms WHERE segment_id=? AND channel_id=? ORDER BY iteration', 105 | [segment_id, channel_id] 106 | ) 107 | 108 | 109 | if __name__ == '__main__': 110 | demo_mp() 111 | -------------------------------------------------------------------------------- /demo/input.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/develancer/empi/b406ca501957d5979a7ea7e31a50d461076e33cc/demo/input.bin -------------------------------------------------------------------------------- /demo/input.cfg: -------------------------------------------------------------------------------- 1 | energyError 0.01 2 | maximalNumberOfIterations 10 3 | energyPercent 99.0 4 | 5 | MP SMP 6 | 7 | nameOfDataFile input.bin 8 | nameOfOutputDirectory . 
9 | samplingFrequency 512.0 10 | 11 | numberOfChannels 5 12 | selectedChannels 1-5 13 | 14 | numberOfSamplesInEpoch 5120 15 | selectedEpochs 1-1 16 | -------------------------------------------------------------------------------- /demo/input.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | input.bin 4 | input.bin 5 | 6 | 7 | 8 | 512.0 9 | 5 10 | 5120 11 | FLOAT 12 | LITTLE_ENDIAN 13 | 10.0 14 | 1 15 | 16 | G1 17 | G2 18 | G3 19 | G4 20 | G5 21 | 22 | 23 | 1.0 24 | 1.0 25 | 1.0 26 | 1.0 27 | 1.0 28 | 29 | 30 | 0.0 31 | 0.0 32 | 0.0 33 | 0.0 34 | 0.0 35 | 36 | 37 | -------------------------------------------------------------------------------- /demo/input_generator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy 3 | import sys 4 | 5 | sampling_frequency = 512. # Hz 6 | 7 | ################################################## 8 | # PLEASE REFER TO README.txt FOR MORE INFORMATION 9 | ################################################## 10 | 11 | 12 | def gabor_norm(s, t0, f0, phase): 13 | """ 14 | Calculates a correct normalization factor for Gabor atom. 15 | """ 16 | nyquist_frequency = 0.5 * sampling_frequency 17 | if f0 > 0.5 * nyquist_frequency: 18 | f0 = nyquist_frequency - f0 19 | phase = 2 * numpy.pi * nyquist_frequency * t0 - phase 20 | return numpy.sqrt(2**1.5 / ( 21 | s * (1 + numpy.cos(2*phase) * numpy.exp(-2*numpy.pi*s*s*f0*f0)) 22 | )) 23 | 24 | 25 | def gabor(t, s, t0, f0, phase=0.0): 26 | """ 27 | Generates values for Gabor atom normalized to unit energy. 
28 | """ 29 | return gabor_norm(s, t0, f0, phase) \ 30 | * numpy.exp(-numpy.pi*((t-t0)/s)**2) \ 31 | * numpy.cos(2*numpy.pi*f0*(t-t0) + phase) 32 | 33 | 34 | if __name__ == '__main__': 35 | numpy.random.seed(42) # for deterministic results 36 | 37 | # shape of the resulting signal 38 | channel_count = 5 39 | sample_count = 5120 40 | 41 | t = numpy.arange(sample_count) / sampling_frequency 42 | data = numpy.zeros((channel_count, sample_count)) 43 | for i in range(channel_count): 44 | data[i] = 10 * gabor(t, s=3.0, t0=5.0, f0=i+1) + 0.1 * numpy.random.randn(sample_count) 45 | 46 | # save to a single binary file with interleaved channels 47 | data.T.reshape((-1)).astype('float32').tofile('input.bin') 48 | 49 | if len(sys.argv) > 1 and sys.argv[1] == '--plot': 50 | # optionally, plot all channels 51 | import pylab 52 | for i in range(channel_count): 53 | pylab.subplot(channel_count, 1, 1+i) 54 | pylab.plot(t, data[i]) 55 | pylab.show() 56 | -------------------------------------------------------------------------------- /demo/input_smp.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/develancer/empi/b406ca501957d5979a7ea7e31a50d461076e33cc/demo/input_smp.db -------------------------------------------------------------------------------- /include/Atom.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_ATOM_H 7 | #define EMPI_ATOM_H 8 | 9 | #include 10 | #include 11 | #include "Array.h" 12 | #include "ExportedAtom.h" 13 | #include "IndexRange.h" 14 | 15 | /** 16 | * Base class for all atom representations. 
17 | * Each atom stores a set of parameters in a dictionary parameter space, 18 | * has a well-defined energy and can be used to generate a waveform. 19 | */ 20 | class Atom { 21 | protected: 22 | explicit Atom(Array2D data, double energy) : data(std::move(data)), energy(energy) {} 23 | 24 | public: 25 | /** 26 | * reference to multi-channel data of the analysed signal 27 | */ 28 | const Array2D data; 29 | 30 | /** 31 | * energy of the atom, computed as sum of all samples squared 32 | */ 33 | const double energy; 34 | 35 | /** 36 | * Compares atoms according to their energies. 37 | * @return true if atom has smaller energy than the other atom, false otherwise 38 | */ 39 | bool operator<(const Atom &other) const { 40 | return energy < other.energy; 41 | } 42 | 43 | virtual ~Atom() = default; 44 | }; 45 | 46 | /** 47 | * Base class for all atom representations for which all parameters are already known. 48 | * Instances of this class can be obtained from BasicAtoms' method extend(). 49 | */ 50 | class ExtendedAtom : public Atom { 51 | protected: 52 | explicit ExtendedAtom(Array2D data, double energy) : Atom(std::move(data), energy) {} 53 | 54 | public: 55 | /** 56 | * Export all parameters of this atom to a uniform representation. 57 | * 58 | * @param atoms array of lists of atoms (one list per channel) to which atoms' parameters should be appended 59 | */ 60 | virtual void export_atom(std::list atoms[]) = 0; 61 | 62 | /** 63 | * Subtract this atom from the multi-channel signal being analyzed. 64 | * 65 | * @return range of the signal samples that are being changed 66 | * (can be later passed to fetch_requests in all dictionaries) 67 | */ 68 | virtual IndexRange subtract_from_signal() const = 0; 69 | }; 70 | 71 | using ExtendedAtomPointer = std::shared_ptr; 72 | 73 | /** 74 | * Base class for all atom representations for which some of the parameters may not be known yet. 75 | * Instance of ExtendedAtom can be obtained by calling the method extend(). 
76 | */ 77 | class BasicAtom : public Atom { 78 | public: 79 | explicit BasicAtom(Array2D data, double energy) : Atom(std::move(data), energy) {} 80 | 81 | /** 82 | * Create a new ExtendedAtom instance as a fully-defined representation of this atom. 83 | * 84 | * @return smart pointer to a newly created ExtendedAtom 85 | */ 86 | [[nodiscard]] virtual ExtendedAtomPointer extend(bool allow_optimization) = 0; 87 | 88 | /** 89 | * @return estimate for maximum possible energy of the atom that can be obtained by locally optimizing the coefficients 90 | */ 91 | [[nodiscard]] virtual double get_energy_upper_bound() const = 0; 92 | }; 93 | 94 | using BasicAtomPointer = std::shared_ptr; 95 | 96 | #endif //EMPI_ATOM_H 97 | -------------------------------------------------------------------------------- /include/Block.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_H 7 | #define EMPI_BLOCK_H 8 | 9 | #include 10 | #include 11 | #include "BlockAtom.h" 12 | #include "BlockAtomBase.h" 13 | #include "BlockAtomCache.h" 14 | #include "BlockInterface.h" 15 | #include "Extractor.h" 16 | #include "PinnedArray.h" 17 | #include "SpectrogramRequest.h" 18 | #include "Types.h" 19 | 20 | /** 21 | * Basic building block for all dictionaries. 22 | * Each block includes a well-defined envelope waveform of a given width, 23 | * and is responsible for computing scalar products between signal and all atoms 24 | * that can be obtained by time-shifting and modulating that waveform. 
25 | */ 26 | class Block : public BlockInterface { 27 | 28 | PinnedArray2D data; 29 | std::shared_ptr family; 30 | double scale; 31 | double envelope_center_offset; 32 | PinnedArray1D envelope; 33 | PinnedArray1D correctors; 34 | std::shared_ptr converter; 35 | 36 | PinnedArray1D maxima; 37 | Array1D booster; 38 | SpectrogramRequest total_request; 39 | 40 | int best_index; 41 | 42 | std::shared_ptr extended_atom_cache; 43 | 44 | public: 45 | /** 46 | * Create a new block for the dictionary. 47 | * 48 | * @param data reference to multi-channel data of the analysed signal 49 | * @param family Family object describing properties of the envelope function 50 | * @param envelope L²-normalized samples of the particular realization of the envelope function 51 | * @param correctors array of corrections to be applied on the computed spectrum to account for normalization issues 52 | * @param scale scale of the envelope (presumably in samples, like BlockAtomParams::scale — TODO confirm) * @param converter converter between actual atom parameters and the scaled values used internally by the minimizer (presumably a BlockAtomParamsConverter — TODO confirm) * @param booster NOTE(review): meaning is not documented in this header — TODO confirm 53 | * @param window_length number of samples of the FFT to which data will be zero-padded, should be a power of 2 54 | * @param input_shift shift (in samples) between consecutive time-shifted envelope realizations * @param envelope_center_offset NOTE(review): presumably the position of the envelope's center relative to its first sample, in samples — TODO confirm 55 | * @param extractor function that will be used to extract information from multi-channel results, 56 | * according to a particular mode of multi-channel operation (constant vs variable phase) 57 | * @param allow_overstep whether we can assume that samples before and after the actual signal range are equal to zero 58 | */ 59 | Block(PinnedArray2D data, std::shared_ptr family, double scale, PinnedArray1D envelope, 60 | PinnedArray1D correctors, std::shared_ptr converter, double booster, 61 | int window_length, int input_shift, double envelope_center_offset, Extractor extractor, bool allow_overstep = true); 62 | 63 | /** 64 | * Create a SpectrogramRequest that can be used to recompute the entire block. 65 | */ 66 | SpectrogramRequest buildRequest(); 67 | 68 | /** 69 | * Create a SpectrogramRequest that can be used to recompute part of this block.
70 | * 71 | * @param first_sample_index zero-based index of the first sample in the interval that needs to be updated 72 | * @param end_sample_index zero-based index of the first sample _following_ the interval that needs to be updated (last sample index + 1) 73 | * @return 74 | */ 75 | SpectrogramRequest buildRequest(index_t first_sample_index, index_t end_sample_index); 76 | 77 | /** 78 | * This method is called internally by Computer instances 79 | * to notify that the request obtained by the last call to buildRequest is completed. 80 | * It is then the block's responsibility to update its internal cache. 81 | */ 82 | void notify() final; 83 | 84 | /** 85 | * Calculate how many atoms are represented by this block. 86 | * 87 | * @return number of atoms in this block 88 | */ 89 | [[nodiscard]] size_t get_atom_count() const; 90 | 91 | /** 92 | * Compute and return the atom that is currently a best match for the analyzed signal. 93 | * 94 | * @return best matching atom 95 | */ 96 | [[nodiscard]] BlockAtom get_best_match() const; 97 | 98 | [[nodiscard]] std::list get_candidate_matches(double energy_to_exceed) const; 99 | 100 | private: 101 | [[nodiscard]] BlockAtom get_atom_from_index(int index) const; 102 | }; 103 | 104 | #endif //EMPI_BLOCK_H 105 | -------------------------------------------------------------------------------- /include/BlockAtom.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_ATOM_H 7 | #define EMPI_BLOCK_ATOM_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include "Array.h" 16 | #include "Atom.h" 17 | #include "BlockAtomBase.h" 18 | #include "BlockAtomCache.h" 19 | #include "Extractor.h" 20 | #include "Family.h" 21 | 22 | /** 23 | * Extended atom implementation for block dictionaries. 24 | * It defines all parameters, including amplitudes, phases and energies for all channels. 25 | */ 26 | class BlockExtendedAtom : public ExtendedAtom { 27 | std::shared_ptr family; 28 | 29 | public: 30 | const BlockAtomParams params; 31 | 32 | /** 33 | * Channel-specific data (amplitudes, phases and energies) for all channels of the analysed signal. 34 | */ 35 | Array1D extra; 36 | 37 | /** 38 | * Create a new extended atom for block dictionary. 39 | * 40 | * @param data reference to multi-channel data of the analysed signal 41 | * @param energy energy of the atom, computed as sum of all samples squared 42 | * @param frequency frequency in range between 0 and 0.5 (1 would be sampling frequency) 43 | * @param position center position in samples 44 | * @param scale atom's scale (related to its duration) in samples 45 | * @param extra channel-specific data (amplitudes, phases and energies) for all channels of the analysed signal 46 | */ 47 | BlockExtendedAtom(Array2D data, double energy, std::shared_ptr family, double frequency, double position, double scale, 48 | Array1D extra); 49 | 50 | /** 51 | * Export all parameters of this atom to a uniform representation. 52 | * 53 | * @param atoms array of lists of atoms (one list per channel) to which atoms' parameters should be appended 54 | */ 55 | void export_atom(std::list atoms[]) final; 56 | 57 | /** 58 | * Subtract this atom from the multi-channel signal being analyzed. 
59 | * 60 | * @return range of the signal samples that are being changed 61 | * (can be later passed to fetch_requests in all dictionaries) 62 | */ 63 | IndexRange subtract_from_signal() const final; 64 | }; 65 | 66 | /** 67 | * Basic atom implementation for block dictionaries. 68 | * It defines position, frequency and scale, but neither amplitude nor phase. 69 | */ 70 | class BlockAtom : public BasicAtom { 71 | double energy_upper_bound; 72 | Extractor extractor; 73 | std::shared_ptr family; 74 | std::shared_ptr converter; 75 | 76 | std::optional cache_slot; 77 | 78 | static std::atomic failed_optimization_count; 79 | static std::atomic total_optimization_count; 80 | 81 | public: 82 | const BlockAtomParams params; 83 | /** @return percentage of optimizations that failed, or 0.0 if no optimization has been attempted yet */ 84 | static double get_failed_optimization_percent() 85 | { 86 | const auto total = total_optimization_count.load(); const auto failed = failed_optimization_count.load(); /* read each counter exactly once and avoid 0/0 (NaN) before any optimization has run */ return total > 0 ? 100.0 * failed / total : 0.0; 87 | } 88 | 89 | /** 90 | * Create a new basic atom for block dictionary. 91 | * 92 | * @param data reference to multi-channel data of the analysed signal 93 | * @param energy energy of the atom, computed as sum of all samples squared * @param energy_upper_bound estimate for the maximum energy obtainable by locally optimizing the coefficients (presumably the value later returned by get_energy_upper_bound() — TODO confirm) 94 | * @param family Family object describing properties of the envelope function 95 | * @param frequency frequency in range between 0 and 0.5 (1 would be sampling frequency) 96 | * @param position center position in samples 97 | * @param scale scale in samples 98 | * @param extractor function that will be used to extract information from multi-channel results, 99 | * according to a particular mode of multi-channel operation (constant vs variable phase) * @param converter converter between actual atom parameters and the scaled values used internally by the minimizer 100 | */ 101 | BlockAtom(Array2D data, double energy, double energy_upper_bound, std::shared_ptr family, 102 | double frequency, double position, double scale, 103 | Extractor extractor, std::shared_ptr converter); 104 | 105 | void connect_cache(std::shared_ptr cache, size_t key); 106 | 107 | /** 108 | * Create a new BlockExtendedAtom instance as a fully-defined representation of this atom.
109 | * 110 | * @return smart pointer to a newly created BlockExtendedAtom 111 | */ 112 | [[nodiscard]] ExtendedAtomPointer extend(bool allow_optimization) final; 113 | 114 | /** 115 | * @return estimate for maximum possible energy of the atom that can be obtained by locally optimizing the coefficients 116 | */ 117 | [[nodiscard]] double get_energy_upper_bound() const final; 118 | }; 119 | 120 | #endif //EMPI_BLOCK_ATOM_H 121 | -------------------------------------------------------------------------------- /include/BlockAtomBase.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_ATOM_BASE_H 7 | #define EMPI_BLOCK_ATOM_BASE_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | /** 14 | * Simple structure with three fields, used in both BlockAtom and BlockExtendedAtom. 15 | */ 16 | struct BlockAtomParams { 17 | double frequency; // between 0 and 0.5 (1 would be sampling frequency) 18 | double position; // in samples 19 | double scale; // in samples 20 | 21 | BlockAtomParams(double frequency, double position, double scale) 22 | : frequency(frequency), position(position), scale(scale) {} 23 | }; 24 | 25 | /** 26 | * Auxiliary class converting atom parameters between actual values and scaled values used internally by the minimizer.
27 | */ 28 | class BlockAtomParamsConverter { 29 | protected: 30 | double frequency_step; 31 | double position_step; 32 | double log_scale_step; 33 | 34 | double max_scaled_frequency; 35 | 36 | public: 37 | BlockAtomParamsConverter() : frequency_step(1.0), position_step(1.0), log_scale_step(1.0), max_scaled_frequency(0.5) { } 38 | 39 | BlockAtomParamsConverter(double frequency_step, double position_step, double log_scale_step, double frequency_max); 40 | 41 | /** 42 | * Convert from actual values to scaled values used internally by the minimizer. 43 | * @param params in actual units (@see BlockAtomParams) 44 | * @return array of 3 parameters in unit related to dictionary steps 45 | */ 46 | [[nodiscard]] std::array arrayFromParams(const BlockAtomParams &params) const; 47 | 48 | /** 49 | * Convert from scaled values used internally by the minimizer to actual values. 50 | * @param array of 3 parameters in unit related to dictionary steps 51 | * @return params in actual units (@see BlockAtomParams) 52 | */ 53 | [[nodiscard]] std::pair paramsFromArray(const std::array &array) const; 54 | 55 | private: 56 | double fix_frequency(double &scaled_frequency) const; 57 | 58 | protected: 59 | virtual double fix_log_scale(double &scaled_log_scale) const; 60 | 61 | /** 62 | * Fix the value given as a reference so it will be between min and max, inclusive. 63 | * 64 | * @return if given value has already been between min and max, 0; 65 | * otherwise, its distance from this interval 66 | */ 67 | static double fix_scaled_argument(double &value, double min, double max); 68 | }; 69 | 70 | /** 71 | * Variant of BlockAtomParamsConverter that additionally restricts the scale to the range given by min_scale and max_scale.
72 | */ 73 | class BlockAtomParamsConverterBounded : public BlockAtomParamsConverter { 74 | double min_scaled_log_scale; 75 | double max_scaled_log_scale; 76 | 77 | public: 78 | BlockAtomParamsConverterBounded(double frequency_step, double position_step, double log_scale_step, double frequency_max, double min_scale, double max_scale); 79 | 80 | double fix_log_scale(double &scaled_log_scale) const final; 81 | }; 82 | 83 | #endif //EMPI_BLOCK_ATOM_BASE_H 84 | -------------------------------------------------------------------------------- /include/BlockAtomCache.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_ATOM_CACHE_H 7 | #define EMPI_BLOCK_ATOM_CACHE_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include "IndexRange.h" 14 | 15 | // forward declaration to avoid cyclic dependency between BlockAtomCache.h and BlockAtom.h 16 | class BlockExtendedAtom; 17 | 18 | /** 19 | * Internal data object used in BlockAtomCache. 20 | */ 21 | struct BlockAtomCacheItem { 22 | IndexRange index_range; 23 | std::shared_ptr atom; 24 | }; 25 | 26 | /** 27 | * Cache for BlockExtendedAtom instances used internally in Block class. 28 | */ 29 | //using BlockAtomCache = std::map; 30 | class BlockAtomCache 31 | { 32 | std::map items; 33 | std::shared_mutex items_mutex; 34 | 35 | public: 36 | [[nodiscard]] std::shared_ptr get(size_t key); 37 | 38 | void set(size_t key, IndexRange index_range, std::shared_ptr atom); 39 | 40 | void remove_overlapping(IndexRange range_to_overlap); 41 | }; 42 | 43 | /** 44 | * Reference to one of the items in the cache, identified by a unique key. 
45 | */ 46 | class BlockAtomCacheSlot { 47 | std::weak_ptr cache; 48 | size_t key; 49 | 50 | public: 51 | BlockAtomCacheSlot(const std::shared_ptr& cache, size_t key); 52 | 53 | [[nodiscard]] std::shared_ptr get() const; 54 | 55 | void set(IndexRange index_range, std::shared_ptr atom); 56 | }; 57 | 58 | #endif //EMPI_BLOCK_ATOM_CACHE_H 59 | -------------------------------------------------------------------------------- /include/BlockAtomObjective.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_ATOM_OBJECTIVE_H 7 | #define EMPI_BLOCK_ATOM_OBJECTIVE_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include "Array.h" 13 | #include "BlockAtomBase.h" 14 | #include "Extractor.h" 15 | #include "Family.h" 16 | 17 | /** 18 | * Objective function used by multidimensional optimization to find optimal atom parameters. 
19 | */ 20 | class BlockAtomObjective { 21 | 22 | std::shared_ptr family; 23 | const int channel_count; 24 | Array2D data; 25 | Array2D products; 26 | Extractor extractor; 27 | std::vector envelope; 28 | std::vector oscillating; 29 | std::shared_ptr converter; 30 | 31 | public: 32 | BlockAtomObjective(std::shared_ptr family, Array2D data, Extractor extractor, std::shared_ptr converter); 33 | 34 | double calculate_energy(const std::array &array, double *out_norm, ExtraData *out_extra_data); 35 | 36 | double calculate_energy(const BlockAtomParams &params, double *out_norm, ExtraData *out_extra_data); 37 | 38 | double operator()(const std::array &array) { 39 | return -calculate_energy(array, nullptr, nullptr); 40 | } 41 | }; 42 | 43 | #endif //EMPI_BLOCK_ATOM_OBJECTIVE_H 44 | -------------------------------------------------------------------------------- /include/BlockAtomProductCalculator.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_ATOM_PRODUCT_CALCULATOR_H 7 | #define EMPI_BLOCK_ATOM_PRODUCT_CALCULATOR_H 8 | 9 | #include 10 | #include "BlockAtomBase.h" 11 | #include "Family.h" 12 | 13 | class BlockAtomProductCalculator { 14 | std::shared_ptr family; 15 | 16 | public: 17 | explicit BlockAtomProductCalculator(std::shared_ptr family); 18 | 19 | [[nodiscard]] double calculate_squared_product(const BlockAtomParams &p, const BlockAtomParams &q) const; 20 | 21 | [[nodiscard]] double calculate_squared_product(const BlockAtomParams &p, const BlockAtomParams &q, double q_phase) const; 22 | }; 23 | 24 | #endif //EMPI_BLOCK_ATOM_PRODUCT_CALCULATOR_H 25 | -------------------------------------------------------------------------------- /include/BlockDictionary.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_DICTIONARY_H 7 | #define EMPI_BLOCK_DICTIONARY_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include "Atom.h" 14 | #include "Block.h" 15 | #include "BlockDictionaryStructure.h" 16 | #include "Family.h" 17 | #include "Dictionary.h" 18 | #include "IndexRange.h" 19 | #include "PinnedArray.h" 20 | #include "SpectrogramRequest.h" 21 | #include "SpectrumCalculator.h" 22 | 23 | /** 24 | * Dictionary for envelope-based oscillating atoms, consisting of one or more Block instances. 25 | */ 26 | class BlockDictionary : public Dictionary { 27 | 28 | std::list blocks; 29 | 30 | public: 31 | /** 32 | * Create a new dictionary for a given envelope function. 
33 | * 34 | * @param structure structure of the dictionary (envelope family, scale range and steps, block structures) 35 | * @param data reference to multi-channel data of the analysed signal 36 | * @param extractor function that will be used to extract information from multi-channel results, 37 | * according to a particular mode of multi-channel operation (constant vs variable phase) 38 | * @param calculator needed for calculating corrections to account for normalization issues 39 | * @param allow_overstep flag presumably forwarded to BlockHelper::create_block when building blocks (TODO: document its exact meaning) 40 | */ 41 | BlockDictionary(const BlockDictionaryStructure& structure, const PinnedArray2D& data, 42 | Extractor extractor, SpectrumCalculator &calculator, bool allow_overstep); 43 | 44 | explicit BlockDictionary(Block&& block); 45 | 46 | /** 47 | * Calculate how many atoms are represented by this dictionary. 48 | * 49 | * @return number of atoms in this dictionary 50 | */ 51 | size_t get_atom_count() final; 52 | 53 | /** 54 | * Compute and return the atom that is currently a best match for the analyzed signal. 55 | * 56 | * @return best matching atom 57 | */ 58 | BasicAtomPointer get_best_match() final; 59 | 60 | std::list get_candidate_matches(double energy_to_exceed) final; 61 | 62 | /** 63 | * Create all request templates that can be requested by this dictionary. 64 | * The generated proto-requests will be added to the given list. 65 | * 66 | * @param requests list for the requests to be appended to 67 | */ 68 | void fetch_proto_requests(std::list &requests) final; 69 | 70 | /** 71 | * Create all actual recalculation requests for given updated signal range. 72 | * The generated requests will be added to the given list. 73 | * 74 | * @param signal_range range of the signal samples that have been changed 75 | * (e.g.
as returned from the last call to subtract_from_signal) 76 | * @param requests list for the requests to be appended to 77 | */ 78 | void fetch_requests(IndexRange signal_range, std::list &requests) final; 79 | }; 80 | 81 | #endif //EMPI_BLOCK_DICTIONARY_H 82 | -------------------------------------------------------------------------------- /include/BlockDictionaryStructure.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_DICTIONARY_STRUCTURE_H 7 | #define EMPI_BLOCK_DICTIONARY_STRUCTURE_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include "BlockStructure.h" 13 | #include "Family.h" 14 | 15 | struct BlockDictionaryStructure { 16 | 17 | const std::shared_ptr family; 18 | const double energy_error, scale_min, scale_max, frequency_max, log_scale_step, dt_scale, df_scale; 19 | const std::list block_structures; 20 | 21 | /** 22 | * @param family Family object describing properties of the envelope function 23 | */ 24 | BlockDictionaryStructure(std::shared_ptr family, double energy_error, double scale_min, double scale_max, double frequency_max); 25 | 26 | [[nodiscard]] std::set get_transform_sizes() const; 27 | }; 28 | 29 | #endif //EMPI_BLOCK_DICTIONARY_STRUCTURE_H 30 | -------------------------------------------------------------------------------- /include/BlockHelper.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_HELPER_H 7 | #define EMPI_BLOCK_HELPER_H 8 | 9 | #include 10 | #include 11 | #include "Array.h" 12 | #include "Block.h" 13 | #include "BlockAtomBase.h" 14 | #include "BlockStructure.h" 15 | #include "Corrector.h" 16 | #include "Family.h" 17 | #include "PinnedArray.h" 18 | #include "SpectrumCalculator.h" 19 | 20 | class BlockHelper { 21 | static void add_block_structures_to_list(std::list& result, const Family* family, double scale, double df_scale, double dt_scale); 22 | 23 | public: 24 | static Block create_block(PinnedArray2D data, std::shared_ptr family, double scale, 25 | std::shared_ptr converter, double booster, 26 | int window_length, int output_bins, int input_shift, double subsample_offset, Extractor extractor, SpectrumCalculator& calculator, bool allow_overstep = true); 27 | 28 | static std::pair, double> generate_envelope(const Family* family, double scale, double subsample_offset); 29 | 30 | static PinnedArray1D generate_correctors(const Array1D& envelope, int window_length, int output_bins, SpectrumCalculator& calculator); 31 | 32 | static std::list compute_block_structures(const Family* family, double scale_min, double scale_max, double log_scale_step, double df_scale, double dt_scale); 33 | 34 | static int round_transform_size(double min_transform_size_as_float); 35 | }; 36 | 37 | #endif //EMPI_BLOCK_HELPER_H 38 | -------------------------------------------------------------------------------- /include/BlockInterface.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_INTERFACE_H 7 | #define EMPI_BLOCK_INTERFACE_H 8 | 9 | /** 10 | * Simple interface so that Computer instances can notify Block instances about completed computation requests. 11 | */ 12 | class BlockInterface { 13 | public: 14 | /** 15 | * This method is called internally by Computer instances 16 | * to notify that the last computation request has been completed. 17 | * It is then the block's responsibility to update its internal cache. 18 | */ 19 | virtual void notify() = 0; 20 | }; 21 | 22 | #endif //EMPI_BLOCK_INTERFACE_H 23 | -------------------------------------------------------------------------------- /include/BlockStructure.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_BLOCK_STRUCTURE_H 7 | #define EMPI_BLOCK_STRUCTURE_H 8 | 9 | struct BlockStructure { 10 | double scale; 11 | int envelope_length; 12 | int transform_size; 13 | double input_shift; 14 | }; 15 | 16 | #endif //EMPI_BLOCK_STRUCTURE_H 17 | -------------------------------------------------------------------------------- /include/BufferedWriter.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #ifndef EMPI_BUFFERED_WRITER_H 7 | #define EMPI_BUFFERED_WRITER_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "Array.h" 15 | #include "BookWriter.h" 16 | #include "Semaphore.h" 17 | 18 | struct BufferedResult { 19 | std::mutex mutex; 20 | EpochIndex index; 21 | Array2D data; 22 | std::vector> atoms; 23 | }; 24 | 25 | /** 26 | * Wrapper for book writer implementations to allow a thread-safe access to it. 27 | * Instances of this class can be used by multiple threads and epochs may be passed to write in any sequence. 28 | * Moreover, in case of the single-channel decomposition, individual channels can be passed as well. 29 | */ 30 | class BufferedWriter : public BookWriter { 31 | const int total_channel_count; 32 | const int real_epoch_count; 33 | std::unique_ptr actual_writer; 34 | 35 | std::vector semaphores; 36 | std::vector results; 37 | 38 | BufferedResult& get_result_ref(int epoch_counter, index_t sample_count); 39 | 40 | public: 41 | BufferedWriter(int total_channel_count, int real_epoch_count, std::unique_ptr actual_writer); 42 | 43 | void finalize() final; 44 | 45 | void write(Array2D data, EpochIndex index, const std::vector> &atoms) final; 46 | }; 47 | 48 | #endif //EMPI_BUFFERED_WRITER_H 49 | -------------------------------------------------------------------------------- /include/CUDA.h: -------------------------------------------------------------------------------- 1 | #ifndef EMPI_CUDA_H 2 | #define EMPI_CUDA_H 3 | 4 | #ifndef __CUDACC__ 5 | #error this file can be included only when compiling with nvcc 6 | #endif 7 | 8 | #include 9 | #include 10 | #include 11 | #include "Types.h" 12 | 13 | using cucomplex = cuDoubleComplex; // TODO #ifdef SINGLE_PRECISION 14 | 15 | struct CudaCallbackInfo { 16 | real *envelope; 17 | void *correctors; 18 | uint32_t window_length; 19 | uint32_t spectrum_length; 20 | uint32_t envelope_length; 21 | uint32_t 
input_shift; 22 | uint32_t output_bins; 23 | unsigned window_length_bits; 24 | size_t window_length_mask; 25 | }; 26 | 27 | class CudaCallback { 28 | void *hostCopyOfInputCallback = nullptr; 29 | void *hostCopyOfOutputCallback = nullptr; 30 | 31 | public: 32 | void initialize(); 33 | 34 | void associate(cufftHandle plan, CudaCallbackInfo *dev_info); 35 | }; 36 | 37 | class CudaException : public std::runtime_error { 38 | static char buffer[256]; 39 | 40 | template 41 | const char *prepare(const char *format, Args... args) { 42 | snprintf(buffer, sizeof buffer, format, args...); 43 | return buffer; 44 | } 45 | 46 | public: 47 | explicit CudaException(cudaError_t error) 48 | : std::runtime_error(prepare("GPU error (CUDA) %s (%s)", cudaGetErrorName(error), cudaGetErrorString(error))) {} 49 | 50 | explicit CudaException(cufftResult_t result) 51 | : std::runtime_error(prepare("GPU error (cuFFT) #%d", static_cast(result))) {} 52 | }; 53 | 54 | class CudaMemoryException : public std::exception { 55 | const char *what() const noexcept override { 56 | return "GPU ran out of memory"; 57 | } 58 | }; 59 | 60 | void cuda_check(cudaError_t error); 61 | 62 | void cufft_check(cufftResult_t result); 63 | 64 | #endif //EMPI_CUDA_H 65 | -------------------------------------------------------------------------------- /include/Corrector.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #ifndef EMPI_CORRECTOR_H 7 | #define EMPI_CORRECTOR_H 8 | 9 | #include "ExtraData.h" 10 | #include "Types.h" 11 | 12 | #ifdef _WIN32 13 | __declspec(align(16)) 14 | #endif 15 | class Corrector { 16 | complex ft; 17 | complex re_factor; 18 | complex im_factor; 19 | 20 | public: 21 | explicit Corrector(); 22 | 23 | explicit Corrector(complex envelope_ft); 24 | 25 | [[nodiscard]] inline double estimate_energy(complex value) const { 26 | return std::norm(value.real() * re_factor + value.imag() * im_factor); 27 | } 28 | 29 | double compute(complex value, ExtraData* extra = nullptr) const; 30 | 31 | } 32 | #ifndef _WIN32 33 | __attribute__ ((aligned (16))) 34 | #endif 35 | ; 36 | 37 | #endif //EMPI_CORRECTOR_H 38 | -------------------------------------------------------------------------------- /include/DeltaAtom.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_DELTA_ATOM_H 7 | #define EMPI_DELTA_ATOM_H 8 | 9 | #include 10 | #include "Atom.h" 11 | 12 | /** 13 | * Basic atom implementation for delta (single sample) dictionaries. 14 | */ 15 | class DeltaAtom : public BasicAtom { 16 | index_t position; 17 | 18 | public: 19 | /** 20 | * Create a new basic atom for delta (single sample) dictionary. 21 | * @param data reference to multi-channel data of the analysed signal 22 | * @param energy energy of the atom, computed as sum of all samples squared 23 | * @param position position of the sample 24 | */ 25 | DeltaAtom(Array2D data, double energy, index_t position); 26 | 27 | /** 28 | * Create a new DeltaExtendedAtom instance as a fully-defined representation of this atom.
29 | * 30 | * @return smart pointer to a newly created DeltaExtendedAtom 31 | */ 32 | [[nodiscard]] ExtendedAtomPointer extend(bool allow_optimization) final; 33 | 34 | /** 35 | * @return estimate for maximum possible energy of the atom that can be obtained by locally optimizing the coefficients 36 | */ 37 | [[nodiscard]] double get_energy_upper_bound() const final; 38 | }; 39 | 40 | /** 41 | * Extended atom implementation for delta (single sample) dictionaries. 42 | * It defines position and amplitudes for all channels. 43 | * Some of the amplitudes may be negative. 44 | */ 45 | class DeltaExtendedAtom : public ExtendedAtom { 46 | index_t position; 47 | std::vector amplitudes; 48 | 49 | public: 50 | /** 51 | * Create a new extended atom for delta (single sample) dictionary. 52 | * 53 | * @param position position of the sample 54 | * @param amplitudes amplitudes for the sample for all channels 55 | */ 56 | DeltaExtendedAtom(Array2D data, double energy, index_t position, std::vector&& amplitudes); 57 | 58 | /** 59 | * Export all parameters of this atom to a uniform representation. 60 | * 61 | * @param atoms array of lists of atoms (one list per channel) to which atoms' parameters should be appended 62 | */ 63 | void export_atom(std::list atoms[]) final; 64 | 65 | /** 66 | * Subtract this atom from the multi-channel signal being analyzed. 67 | * 68 | * @return range of the signal samples that are being changed 69 | * (can be later passed to fetch_requests in all dictionaries) 70 | */ 71 | IndexRange subtract_from_signal() const final; 72 | }; 73 | 74 | #endif //EMPI_DELTA_ATOM_H 75 | -------------------------------------------------------------------------------- /include/DeltaDictionary.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T.
Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_DELTA_DICTIONARY_H 7 | #define EMPI_DELTA_DICTIONARY_H 8 | 9 | #include 10 | #include "Dictionary.h" 11 | #include "PinnedArray.h" 12 | 13 | /** 14 | * Dictionary of delta (single sample) atoms, see DeltaAtom. 15 | * It issues no spectrogram requests (fetch_proto_requests is a no-op) 16 | * and returns no candidate matches from get_candidate_matches. 17 | */ 18 | class DeltaDictionary : public Dictionary { 19 | PinnedArray2D data; 20 | 21 | public: 22 | explicit DeltaDictionary(PinnedArray2D data); 23 | 24 | /** 25 | * Calculate how many atoms are represented by this dictionary. 26 | * 27 | * @return number of atoms in this dictionary 28 | */ 29 | size_t get_atom_count() final; 30 | 31 | /** 32 | * Compute and return the atom that is currently a best match for the analyzed signal. 33 | * 34 | * @return best matching atom 35 | */ 36 | BasicAtomPointer get_best_match() final; 37 | 38 | std::list get_candidate_matches(double energy_to_exceed) final { 39 | return {}; 40 | } 41 | 42 | /** 43 | * Create all request templates that can be requested by this dictionary. 44 | * The generated proto-requests will be added to the given list. 45 | * 46 | * @param requests list for the requests to be appended to 47 | */ 48 | void fetch_proto_requests(std::list &requests) final {}; 49 | 50 | /** 51 | * Create all actual recalculation requests for given updated signal range. 52 | * The generated requests will be added to the given list. 53 | * 54 | * @param signal_range range of the signal samples that have been changed 55 | * (e.g.
as returned from the last call to subtract_from_signal) 56 | * @param requests list for the requests to be appended to 57 | */ 58 | void fetch_requests(IndexRange signal_range, std::list &requests) final; 59 | }; 60 | 61 | #endif //EMPI_DELTA_DICTIONARY_H 62 | -------------------------------------------------------------------------------- /include/Dictionary.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_DICTIONARY_H 7 | #define EMPI_DICTIONARY_H 8 | 9 | #include 10 | #include "Atom.h" 11 | #include "IndexRange.h" 12 | #include "SpectrogramRequest.h" 13 | 14 | /** 15 | * Base class for dictionaries of all types. 16 | * One or more instances of any class derived from Dictionary 17 | * can be associated with a single Computer instance. 18 | */ 19 | class Dictionary { 20 | public: 21 | /** 22 | * Calculate how many atoms are represented by this dictionary. 23 | * 24 | * @return number of atoms in this dictionary 25 | */ 26 | virtual size_t get_atom_count() = 0; 27 | 28 | /** 29 | * Compute and return the atom that is currently a best match for the analyzed signal. 30 | * 31 | * @return best matching atom 32 | */ 33 | virtual BasicAtomPointer get_best_match() = 0; 34 | 35 | virtual std::list get_candidate_matches(double energy_to_exceed) = 0; 36 | 37 | /** 38 | * Create all request templates that can be requested by this dictionary. 39 | * The generated proto-requests will be added to the given list. 40 | * 41 | * @param requests list for the requests to be appended to 42 | */ 43 | virtual void fetch_proto_requests(std::list &requests) = 0; 44 | 45 | /** 46 | * Create all actual recalculation requests for given updated signal range. 
47 | * The generated requests will be added to the given list. 48 | * 49 | * @param signal_range range of the signal samples that have been changed 50 | * (e.g. as returned from the last call to subtract_from_signal) 51 | * @param requests list for the requests to be appended to 52 | */ 53 | virtual void fetch_requests(IndexRange signal_range, std::list &requests) = 0; 54 | 55 | virtual ~Dictionary() = default; 56 | }; 57 | 58 | #endif //EMPI_DICTIONARY_H 59 | -------------------------------------------------------------------------------- /include/EpochIndex.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015–2018 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_EPOCH_INDEX_H 7 | #define EMPI_EPOCH_INDEX_H 8 | 9 | struct EpochIndex { 10 | int epoch_counter; // continuous, counting from 0 11 | int epoch_offset; // epoch offset (not necessarily continuous), the first epoch being 0 12 | int channel_offset; 13 | 14 | EpochIndex() 15 | : epoch_counter(0), epoch_offset(0), channel_offset(0) {} 16 | 17 | EpochIndex(int epoch_counter, int epoch_offset, int channel_offset = 0) 18 | : epoch_counter(epoch_counter), epoch_offset(epoch_offset), channel_offset(channel_offset) {} 19 | 20 | bool operator<(const EpochIndex& other) const { 21 | if (epoch_counter < other.epoch_counter) { 22 | return true; 23 | } 24 | if (epoch_counter > other.epoch_counter) { 25 | return false; 26 | } 27 | return (channel_offset < other.channel_offset); 28 | } 29 | }; 30 | 31 | //------------------------------------------------------------------------------ 32 | 33 | #endif //EMPI_EPOCH_INDEX_H 34 | -------------------------------------------------------------------------------- /include/ExportedAtom.h: 
-------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_EXPORTED_ATOM_H 7 | #define EMPI_EXPORTED_ATOM_H 8 | 9 | #include 10 | 11 | /** 12 | * Uniform plain data representation for all different types of dictionary atoms. 13 | */ 14 | class ExportedAtom { 15 | public: 16 | /** atom's envelope type (e.g. "gauss") */ 17 | std::string envelope; 18 | 19 | /** atom's amplitude */ 20 | double amplitude; 21 | 22 | /** atom's energy, computed as sum of all samples squared */ 23 | double energy; 24 | 25 | /** atom's frequency between 0 and 0.5 (1 would be sampling frequency) */ 26 | double frequency; 27 | 28 | /** atom's phase between ±π */ 29 | double phase; 30 | 31 | /** atom's scale in samples */ 32 | double scale; 33 | 34 | /** atom's central position in samples */ 35 | double position; 36 | 37 | /** 38 | * Create a new ExportedAtom with a given energy. 39 | * All other fields will be initialized to NAN. 40 | * 41 | * @param energy atom's energy, computed as sum of all samples squared 42 | */ 43 | explicit ExportedAtom(double energy); 44 | }; 45 | 46 | #endif //EMPI_EXPORTED_ATOM_H 47 | -------------------------------------------------------------------------------- /include/ExtenderLoop.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #ifndef EMPI_EXTENDER_LOOP_H 7 | #define EMPI_EXTENDER_LOOP_H 8 | 9 | #include 10 | #include "Atom.h" 11 | #include "TaskQueue.h" 12 | 13 | /** 14 | * Runnable object used internally in Computer to form separate consumer threads for BlockAtom::extend() calculations. 15 | */ 16 | class ExtenderLoop { 17 | std::shared_ptr> task_queue; 18 | 19 | public: 20 | explicit ExtenderLoop(std::shared_ptr> task_queue) 21 | : task_queue(std::move(task_queue)) {} 22 | 23 | void operator()(bool wait = true) { 24 | BasicAtomPointer atom; 25 | while (task_queue->get(atom, wait)) { 26 | (void) atom->extend(true); // ExtendedAtomPointer will be put into cache 27 | task_queue->notify(); 28 | } 29 | } 30 | }; 31 | 32 | #endif //EMPI_EXTENDER_LOOP_H 33 | -------------------------------------------------------------------------------- /include/ExtraData.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_EXTRA_DATA_H 7 | #define EMPI_EXTRA_DATA_H 8 | 9 | /** 10 | * Plain data structure for channel-related additional atom data. 
11 | */ 12 | struct ExtraData { 13 | /** atom's amplitude, must be multiplied by family->value(0.0) / std::sqrt(scale) */ 14 | double amplitude; 15 | 16 | /** atom's energy, computed as sum of all samples squared */ 17 | double energy; 18 | 19 | /** atom's phase between ±π */ 20 | double phase; 21 | }; 22 | 23 | #endif //EMPI_EXTRA_DATA_H 24 | -------------------------------------------------------------------------------- /include/Extractor.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_EXTRACTOR_H 7 | #define EMPI_EXTRACTOR_H 8 | 9 | #include "ExtraData.h" 10 | #include "Corrector.h" 11 | #include "Types.h" 12 | 13 | /** 14 | * Plain data structure storing a frequency index (bin) of the best-match atom as well as its energy. 15 | */ 16 | struct ExtractedMaximum { 17 | /** atom's energy, computed as sum of all samples squared */ 18 | double energy; 19 | 20 | /** index of the bin in discrete Fourier transform where 0 corresponds to zero frequency and so forth */ 21 | int bin_index; 22 | /** orders maxima by energy only; bin_index does not take part in the comparison */ 23 | bool operator<(const ExtractedMaximum &other) const { 24 | return energy < other.energy; 25 | } 26 | }; 27 | /** function-pointer type shared by the extractor implementations declared below; reduces one spectrum to its ExtractedMaximum (NOTE(review): atom_data presumably receives the per-channel ExtraData of that maximum — confirm in Extractor.cpp) */ 28 | using Extractor = ExtractedMaximum (*)(int channel_count, int output_bins, complex *const *channels, const Corrector *correctors, 29 | double *bins_buffer, ExtraData *atom_data); 30 | 31 | /** 32 | * Extractor implementation for single-channel decomposition.
33 | */ 34 | ExtractedMaximum extractorSingleChannel(int channel_count, int output_bins, complex *const *channels, const Corrector *correctors, 35 | double *bins_buffer, ExtraData *atom_data); 36 | 37 | /** 38 | * Extractor implementation for multi-channel decomposition where atoms in all signals share position, frequency and scale, 39 | * but amplitudes as well as phases may differ across channels. 40 | */ 41 | ExtractedMaximum extractorVariablePhase(int channel_count, int output_bins, complex *const *channels, const Corrector *correctors, 42 | double *bins_buffer, ExtraData *atom_data); 43 | 44 | /** 45 | * Extractor implementation for multi-channel decomposition where atoms in all signals share position, frequency, scale as well as phase, 46 | * but amplitudes may differ across channels. 47 | */ 48 | ExtractedMaximum extractorConstantPhase(int channel_count, int output_bins, complex *const *channels, const Corrector *correctors, 49 | double *bins_buffer, ExtraData *atom_data); 50 | 51 | #endif //EMPI_EXTRACTOR_H 52 | -------------------------------------------------------------------------------- /include/File.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_FILE_H 7 | #define EMPI_FILE_H 8 | 9 | #include 10 | #include 11 | 12 | /** 13 | * Simple smart pointer to FILE based on std::shared_ptr. 14 | */ 15 | class File : public std::shared_ptr { 16 | public: 17 | /** 18 | * Create a null file handle. 19 | */ 20 | File() = default; 21 | 22 | /** 23 | * Create a new file handle. 24 | * If fopen() call fails, throw an exception.
25 | * 26 | * @param path path to the file 27 | * @param mode as in fopen() 28 | */ 29 | File(const char *path, const char *mode); 30 | }; 31 | 32 | /** 33 | * Smart pointer to binary file in read mode. 34 | * Includes a method to get the size of the file. 35 | */ 36 | class FileToRead : public File { 37 | const size_t file_size; 38 | 39 | static size_t read_file_size(FILE* file); 40 | 41 | public: 42 | /** 43 | * Create a new file handle opened in read-only binary mode. 44 | * If fopen() call fails, throw an exception. 45 | * 46 | * @param path path to the file 47 | */ 48 | explicit FileToRead(const char* path); 49 | 50 | /** 51 | * @return file size in bytes or 0 if file size could not be queried 52 | */ 53 | [[nodiscard]] size_t get_file_size() const; 54 | }; 55 | 56 | #endif //EMPI_FILE_H 57 | -------------------------------------------------------------------------------- /include/GaussianFamily.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_GAUSSIAN_FAMILY_H 7 | #define EMPI_GAUSSIAN_FAMILY_H 8 | 9 | #include "Family.h" 10 | 11 | /** 12 | * Family implementation for Gaussian envelope (Gabor atoms). 13 | */ 14 | class GaussianFamily : public FamilyTemplate { 15 | const double min_max; 16 | 17 | public: 18 | static inline const double DEFAULT_MIN_MAX = 1.5; 19 | 20 | /** 21 | * Create a new instance representing a Gaussian envelope. 22 | * 23 | * @param min_max half-width of envelope function i.e.
for |t| larger than this value, f(t) will be assumed to be zero 24 | */ 25 | explicit GaussianFamily(double min_max = DEFAULT_MIN_MAX); 26 | /* Family interface implementation: every method below is a final override of the base class */ 27 | double max_arg() const final; 28 | 29 | double min_arg() const final; 30 | 31 | const char *name() const final; 32 | 33 | double value(double t) const final; 34 | 35 | double scale_integral(double log_scale) const final; 36 | 37 | double inv_scale_integral(double value) const final; 38 | 39 | double freq_integral(double x) const final; 40 | 41 | double inv_freq_integral(double value) const final; 42 | 43 | double skew_integral(double x) const final; 44 | 45 | double time_integral(double x) const final; 46 | 47 | double inv_time_integral(double value) const final; 48 | 49 | double optimality_factor_e2(double epsilon2) const final; 50 | 51 | double optimality_factor_sf(double scale_frequency) const final; 52 | }; 53 | 54 | #endif //EMPI_GAUSSIAN_FAMILY_H 55 | -------------------------------------------------------------------------------- /include/IndexRange.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_INDEX_RANGE_H 7 | #define EMPI_INDEX_RANGE_H 8 | 9 | #include "Types.h" 10 | 11 | /** 12 | * Simple data object representing a range of samples in a signal. 13 | * Represented range is from first_index (inclusive) to end_index (exclusive), 14 | * which means the last sample in the range has an index of end_index-1. 15 | */ 16 | struct IndexRange { 17 | /** 18 | * Index of the first sample in the range. 19 | */ 20 | index_t first_index; 21 | 22 | /** 23 | * Index of the sample following the last sample in the range.
24 | */ 25 | index_t end_index; 26 | 27 | /** 28 | * Create a new range starting at the beginning of the signal, with length (in samples) equal to end_index. 29 | * 30 | * @param end_index length of the range = index of the sample following the last sample in the range 31 | */ 32 | IndexRange(index_t end_index = 0) : first_index(0), end_index(end_index) {} // NOLINT 33 | 34 | /** 35 | * Create a new range. 36 | * 37 | * @param first_index index of the first sample in the range 38 | * @param end_index index of the sample following the last sample in the range 39 | */ 40 | IndexRange(index_t first_index, index_t end_index) : first_index(first_index), end_index(end_index) {} 41 | 42 | /** 43 | * Calculate the intersection between two ranges. 44 | * 45 | * @param other the other range 46 | * @return intersection of the two ranges, or IndexRange(0) if the ranges don't overlap 47 | */ 48 | [[nodiscard]] IndexRange overlap(const IndexRange &other) const; 49 | /** @return true if index lies inside the range (NOTE(review): presumably first_index <= index < end_index, matching the inclusive/exclusive convention in the struct doc; the implementation is not in this header) */ 50 | [[nodiscard]] bool includes(index_t index) const; 51 | 52 | /** 53 | * @return true if range is empty, false otherwise 54 | */ 55 | bool operator!() const; 56 | }; 57 | 58 | #endif //EMPI_INDEX_RANGE_H 59 | -------------------------------------------------------------------------------- /include/Logger.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_LOGGER_H 7 | #define EMPI_LOGGER_H 8 | 9 | #include 10 | /** Non-instantiable logging helper; the templated message() calls header(type), then writes the printf-style formatted text and a newline to stderr */ 11 | class Logger 12 | { 13 | static void header(const char* type); 14 | 15 | static void message(const char* type, const char* string); 16 | 17 | template 18 | static void message(const char* type, const char* format, VALUES...
parameters) { 19 | header(type); 20 | fprintf(stderr, format, parameters...); 21 | fputc('\n', stderr); 22 | } 23 | 24 | public: 25 | Logger() = delete; 26 | 27 | template 28 | static void error(const char* format, VALUES... parameters) { 29 | message("ERROR", format, parameters...); 30 | } 31 | 32 | template 33 | static void info(const char* format, VALUES... parameters) { 34 | message("INFO", format, parameters...); 35 | } 36 | 37 | template 38 | static void internal_error(const char* format, VALUES... parameters) { 39 | message("INTERNAL ERROR", format, parameters...); 40 | } 41 | }; 42 | 43 | #endif //EMPI_LOGGER_H 44 | -------------------------------------------------------------------------------- /include/OptimizationMode.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_OPTIMIZATION_MODE_H 7 | #define EMPI_OPTIMIZATION_MODE_H 8 | 9 | enum OptimizationMode { 10 | OPTIMIZATION_DISABLED = 0, 11 | OPTIMIZATION_LOCAL = 1, 12 | OPTIMIZATION_GLOBAL = 2 13 | }; 14 | 15 | #endif //EMPI_OPTIMIZATION_MODE_H 16 | -------------------------------------------------------------------------------- /include/PinnedArray.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_PINNED_ARRAY_H 7 | #define EMPI_PINNED_ARRAY_H 8 | 9 | #include "Array.h" 10 | 11 | /** 12 | * Disable CUDA-specific routines for memory allocation.
13 | * Call to this method should be performed before the first call to cuda_host_malloc. 14 | */ 15 | void cuda_host_disable(); 16 | 17 | /** 18 | * Allocate pinned memory area of a given size. 19 | * This defaults to ordinary malloc in case of CUDA-less compilation. 20 | * 21 | * @param length requested size of the memory area in bytes 22 | * @return pointer to memory area 23 | * @throw std::bad_alloc if allocation fails 24 | */ 25 | void *cuda_host_malloc(size_t length); 26 | 27 | /** 28 | * Free given area of pinned memory, previously allocated with cuda_host_malloc. 29 | * This defaults to ordinary free in case of CUDA-less compilation. 30 | * 31 | * @param pointer pointer to memory area 32 | */ 33 | void cuda_host_free(void *pointer); 34 | /** Typed convenience wrapper: allocates room for `length` elements of T via cuda_host_malloc(sizeof(T) * length). NOTE(review): the multiplication is not checked for overflow. */ 35 | template 36 | T *cuda_host_alloc(size_t length) { 37 | return reinterpret_cast(cuda_host_malloc(sizeof(T) * length)); 38 | } 39 | 40 | /** 41 | * Subclass of Array1D template for allocating one-dimensional arrays 42 | * using page-locked (pinned) memory which can be safely used with a wider range of CUDA calls. 43 | * 44 | * @tparam T type of elements to be stored in the array 45 | */ 46 | template 47 | class PinnedArray1D : public Array1D { 48 | public: 49 | /** 50 | * Create an empty 1-D array. No memory will be allocated. 51 | */ 52 | PinnedArray1D() : Array1D() {} 53 | 54 | /** 55 | * Create a 1-D array of the requested size, using pinned memory allocators. 56 | * Array's values won't be initialized. 57 | * 58 | * @param length number of elements to be stored in array 59 | */ 60 | explicit PinnedArray1D(index_t length) 61 | : Array1D(length, cuda_host_alloc, cuda_host_free) {} 62 | }; 63 | 64 | /** 65 | * Subclass of Array2D template for allocating two-dimensional arrays 66 | * using page-locked (pinned) memory which can be safely used with a wider range of CUDA calls.
67 | * 68 | * @tparam T type of elements to be stored in the array 69 | */ 70 | template 71 | class PinnedArray2D : public Array2D { 72 | public: 73 | /** 74 | * Create an empty 2-D array. No memory will be allocated. 75 | */ 76 | PinnedArray2D() : Array2D() {} 77 | 78 | /** 79 | * Create a 2-D array of the requested dimensions, using pinned memory allocators. 80 | * Array's values won't be initialized. 81 | * 82 | * @param height first dimension of the array (number of sub-arrays) 83 | * @param length number of elements to be stored in each sub-array 84 | */ 85 | explicit PinnedArray2D(int height, index_t length) 86 | : Array2D(height, length, cuda_host_alloc, cuda_host_free) {} 87 | }; 88 | 89 | #endif //EMPI_PINNED_ARRAY_H 90 | -------------------------------------------------------------------------------- /include/Progress.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #ifndef EMPI_PROGRESS_H 7 | #define EMPI_PROGRESS_H 8 | 9 | #include 10 | #include 11 | #include "EpochIndex.h" 12 | /** Tracks progress across total_epoch_count epochs (per-epoch state kept in progress_map; NOTE(review): `mutex` presumably guards it — the methods are implemented in Progress.cpp, not shown here) */ 13 | class Progress { 14 | const int total_epoch_count; 15 | int epochs_completed; 16 | int last_progress; 17 | 18 | std::map progress_map; 19 | std::mutex mutex; 20 | 21 | void print_progress(); 22 | 23 | public: 24 | explicit Progress(int total_epoch_count); 25 | 26 | void epoch_started(EpochIndex index); 27 | 28 | void epoch_progress(EpochIndex index, double progress); 29 | 30 | void epoch_finished(EpochIndex index); 31 | }; 32 | 33 | #endif //EMPI_PROGRESS_H 34 | -------------------------------------------------------------------------------- /include/ProtoRequest.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_PROTO_REQUEST_H 7 | #define EMPI_PROTO_REQUEST_H 8 | 9 | /** 10 | * Plain data object consisting of the subset of fields from SpectrogramRequest 11 | * which can be known and obtained in advance from Dictionary instances. 12 | * This way, SpectrogramCalculator instances may pre-allocate all needed internal buffers.
13 | */ 14 | struct ProtoRequest { 15 | 16 | /** number of channels in input data */ 17 | int channel_count; 18 | 19 | /** shift (in samples) between inputs for consecutive transforms */ 20 | int input_shift; 21 | 22 | /** number of transforms to perform */ 23 | int how_many; 24 | 25 | /** length (in samples) of the envelope */ 26 | int envelope_length; 27 | 28 | /** number of samples for calculating each transform, can be larger than envelope_length (zero-padding) */ 29 | int window_length; 30 | 31 | /** number of meaningful output bins from each transform (must not exceed window_length/2+1) */ 32 | int output_bins; 33 | 34 | }; 35 | 36 | #endif //EMPI_PROTO_REQUEST_H 37 | -------------------------------------------------------------------------------- /include/Semaphore.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by piotr on 09.02.2022. 3 | // 4 | 5 | #ifndef EMPI_SEMAPHORE_H 6 | #define EMPI_SEMAPHORE_H 7 | 8 | #include 9 | #include 10 | /** Counting semaphore built from a mutex and a condition variable (comparable to C++20 std::counting_semaphore); copy and move are disabled */ 11 | class Semaphore { 12 | std::mutex mutex; 13 | std::condition_variable condition; 14 | unsigned long state = 0; // NOTE(review): the public API takes int, so a negative argument would wrap around when mixed with this unsigned counter 15 | 16 | public: 17 | explicit Semaphore(int state = 0) : state(state) {} 18 | 19 | Semaphore(const Semaphore &) = delete; 20 | 21 | Semaphore(Semaphore &&) = delete; 22 | 23 | void operator=(const Semaphore &) = delete; 24 | 25 | void release(int increment = 1) { 26 | std::lock_guard lock(mutex); 27 | state += increment; 28 | if (increment == 1) { 29 | condition.notify_one(); // NOTE(review): if waiters acquire() different amounts, the single woken waiter may go back to sleep while another that could proceed is never woken; notify_all() would avoid that 30 | } else if (increment > 1) { 31 | condition.notify_all(); 32 | } 33 | } 34 | 35 | void acquire(int decrement = 1) { 36 | std::unique_lock lock(mutex); 37 | while (state < decrement) { 38 | condition.wait(lock); 39 | } 40 | state -= decrement; 41 | } 42 | }; 43 | 44 | #endif //EMPI_SEMAPHORE_H 45 | -------------------------------------------------------------------------------- /include/SpectrogramCalculator.h: 
-------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTROGRAM_CALCULATOR_H 7 | #define EMPI_SPECTROGRAM_CALCULATOR_H 8 | 9 | #include "SpectrogramRequest.h" 10 | 11 | /** 12 | * Interface for objects calculating spectrograms (short-time Fourier transforms) of real signals. 13 | */ 14 | class SpectrogramCalculator { 15 | public: 16 | /** 17 | * Compute a spectrogram of a given signal, according to the given specification. 18 | * 19 | * @param request specification for the spectrogram to be computed 20 | */ 21 | virtual void compute(const SpectrogramRequest &request) = 0; 22 | 23 | virtual ~SpectrogramCalculator() = default; // virtual so implementations can be destroyed through a base-class pointer (e.g. the calculator owned by SpectrogramLoop) 24 | }; 25 | 26 | #endif //EMPI_SPECTROGRAM_CALCULATOR_H 27 | -------------------------------------------------------------------------------- /include/SpectrogramCalculatorCUDA.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTROGRAM_CALCULATOR_CUDA_H 7 | #define EMPI_SPECTROGRAM_CALCULATOR_CUDA_H 8 | 9 | #include 10 | #include 11 | #include "SpectrogramCalculator.h" 12 | 13 | class CudaTask; // opaque implementation type (NOTE(review): presumably defined in SpectrogramCalculatorCUDA.cu) 14 | 15 | /** 16 | * Short-time Fourier transform calculator based on CUDA and CuFFT.
17 | */ 18 | class SpectrogramCalculatorCUDA : public SpectrogramCalculator { 19 | std::shared_ptr task; 20 | 21 | public: 22 | /** 23 | * Create a new CUDA calculator for multi-channel signals with given number of channels, 24 | * able to serve pre-defined set of request templates. 25 | * 26 | * @param channel_count number of channels in the signal 27 | * @param proto_requests list of request templates that will be passed to compute() later on 28 | * @param device ID of the device on which all calculations should be performed 29 | */ 30 | SpectrogramCalculatorCUDA(int channel_count, const std::list& proto_requests, int device); 31 | 32 | /** 33 | * Compute a spectrogram of a given signal, according to the given specification. 34 | * 35 | * @param request specification for the spectrogram to be computed 36 | */ 37 | void compute(const SpectrogramRequest &request) final; 38 | }; 39 | 40 | #endif //EMPI_SPECTROGRAM_CALCULATOR_CUDA_H 41 | -------------------------------------------------------------------------------- /include/SpectrogramCalculatorDummy.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTROGRAM_CALCULATOR_DUMMY_H 7 | #define EMPI_SPECTROGRAM_CALCULATOR_DUMMY_H 8 | 9 | #include "SpectrogramCalculator.h" 10 | 11 | /** 12 | * Dummy short-time Fourier transform calculator, used only for testing purposes. 13 | */ 14 | class SpectrogramCalculatorDummy : public SpectrogramCalculator { 15 | public: 16 | /** 17 | * Compute a spectrogram of a given signal, according to the given specification.
18 | * 19 | * @param request specification for the spectrogram to be computed 20 | */ 21 | void compute(const SpectrogramRequest &request) final; 22 | }; 23 | 24 | #endif //EMPI_SPECTROGRAM_CALCULATOR_DUMMY_H 25 | -------------------------------------------------------------------------------- /include/SpectrogramCalculatorFFTW.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTROGRAM_CALCULATOR_FFTW_H 7 | #define EMPI_SPECTROGRAM_CALCULATOR_FFTW_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include "Array.h" 14 | #include "SpectrogramCalculator.h" 15 | #include "SpectrumCalculator.h" 16 | 17 | /** 18 | * Short-time Fourier transform calculator based on FFTW. 19 | */ 20 | class SpectrogramCalculatorFFTW : public SpectrogramCalculator, public SpectrumCalculator { 21 | const int max_window_length; 22 | Array2D input_buffers; 23 | Array2D output_buffers; 24 | std::map> plans; 25 | 26 | public: 27 | /** 28 | * Create a new FFTW calculator for multi-channel signals with given number of channels, 29 | * able to serve pre-defined set of possible window lengths. 30 | * 31 | * @param channel_count number of channels in the signal 32 | * @param window_lengths list of window lengths that will be used with compute() calls later on 33 | */ 34 | SpectrogramCalculatorFFTW(int channel_count, const std::set &window_lengths); 35 | 36 | /** 37 | * Create a new FFTW calculator based on plan wisdom from an existing calculator. 38 | * Both calculators will be independent, won't share any internal buffers and will be able to run in parallel.
39 | * 40 | * @param source existing calculator to take plan wisdom from 41 | */ 42 | SpectrogramCalculatorFFTW(const SpectrogramCalculatorFFTW &source); 43 | 44 | /** 45 | * Compute a spectrogram of a given signal, according to the given specification. 46 | * 47 | * @param request specification for the spectrogram to be computed 48 | */ 49 | void compute(const SpectrogramRequest &request) final; 50 | 51 | /** 52 | * Compute spectrum of a given signal with optional zero-padding. 53 | * Return the pointer to the internal buffer which may be overwritten on next calls to computeSpectrum. 54 | * 55 | * @param input input signal 56 | * @param window_length length of the window for FFT, should be at least the size of input 57 | * @return pointer to the internal buffer with size of at least window_length/2+1 58 | */ 59 | const complex *computeSpectrum(Array1D input, int window_length) final; 60 | 61 | void operator=(const SpectrogramCalculatorFFTW &) = delete; 62 | 63 | private: 64 | [[nodiscard]] fftw_plan getPlan(int window_length) const; 65 | }; 66 | 67 | #endif //EMPI_SPECTROGRAM_CALCULATOR_FFTW_H 68 | -------------------------------------------------------------------------------- /include/SpectrogramLoop.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTROGRAM_LOOP_H 7 | #define EMPI_SPECTROGRAM_LOOP_H 8 | 9 | #include 10 | #include "SpectrogramCalculator.h" 11 | #include "SpectrogramRequest.h" 12 | #include "TaskQueue.h" 13 | 14 | /** 15 | * Runnable object used with SpectrogramCalculator instances to be run in separate consumer threads.
16 | */ 17 | class SpectrogramLoop { 18 | std::shared_ptr> task_queue; 19 | std::unique_ptr calculator; 20 | bool reverse_order; // if true, requests are taken from the back of the queue (passed as from_back to TaskQueue::get) 21 | 22 | public: 23 | SpectrogramLoop(std::shared_ptr> task_queue, std::unique_ptr calculator, bool reverse_order = false) 24 | : task_queue(std::move(task_queue)), calculator(std::move(calculator)), reverse_order(reverse_order) {} 25 | 26 | void operator()(bool wait = true) { // consume requests until task_queue->get() returns false (queue terminated, or queue empty when wait == false) 27 | SpectrogramRequest task; 28 | while (task_queue->get(task, wait, reverse_order)) { 29 | calculator->compute(task); // NOTE(review): if compute() throws, neither notify() below is called, so waiters on this request and on the queue are never released 30 | task.interface->notify(); 31 | task_queue->notify(); 32 | } 33 | } 34 | }; 35 | 36 | #endif //EMPI_SPECTROGRAM_LOOP_H 37 | -------------------------------------------------------------------------------- /include/SpectrogramRequest.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTROGRAM_REQUEST_H 7 | #define EMPI_SPECTROGRAM_REQUEST_H 8 | 9 | #include 10 | #include 11 | #include "BlockInterface.h" 12 | #include "Corrector.h" 13 | #include "Extractor.h" 14 | #include "ProtoRequest.h" 15 | #include "Types.h" 16 | 17 | /** 18 | * Plain data object for a single request for spectrogram calculation. 19 | * Such request can be served by SpectrogramCalculator implementations by 20 | * 1) computing a windowed spectrogram of each channel in given multi-channel input data, and 21 | * 2) using a specified Extractor implementation to find the maximum value of each spectrum.
22 | */ 23 | struct SpectrogramRequest : public ProtoRequest { 24 | 25 | /** 26 | * pointer to the multi-channel data, of at least channel_count×channel_length samples 27 | * (in case of CUDA calculator, each channel MUST point to page-locked host memory) 28 | */ 29 | const real *const *data; 30 | 31 | /** 32 | * length (in samples) of each channel (at least input_offset+envelope_length+(how_many-1)*input_shift) 33 | */ 34 | index_t channel_length; 35 | 36 | /** 37 | * offset of the first sample of the input for the first transform 38 | */ 39 | index_t input_offset; 40 | 41 | /** 42 | * envelope (window function) for calculating transforms (at least envelope_length items) 43 | * (in case of CUDA calculator, this MUST point to page-locked host memory) 44 | */ 45 | const real *envelope; 46 | 47 | /** 48 | * objects capable of transforming FFT output into normalized energy values (at least output_bins items) 49 | * (in case of CUDA calculator, this MUST point to page-locked host memory) 50 | */ 51 | const Corrector *correctors; 52 | 53 | /** 54 | * extractor implementation to reduce each spectrum to a single value 55 | */ 56 | Extractor extractor; 57 | 58 | /** 59 | * buffer for extracted maximum values, one for each transform (space for at least how_many items) 60 | * (in case of CUDA calculator, this MUST point to page-locked host memory) 61 | */ 62 | ExtractedMaximum *maxima; 63 | 64 | /** 65 | * listener to be notified after the request is complete 66 | */ 67 | BlockInterface *interface; 68 | 69 | void assertCorrectness() const { // NOTE(review): channel_length is not checked against input_offset+envelope_length+(how_many-1)*input_shift (see its doc above), and interface is not checked although SpectrogramLoop dereferences it 70 | assert(data); 71 | assert(channel_length > 0); 72 | assert(channel_count > 0); 73 | assert(input_shift > 0); 74 | assert(how_many > 0); 75 | 76 | assert(envelope); 77 | assert(envelope_length > 0); 78 | assert(window_length > 0); 79 | assert(envelope_length <= window_length); 80 | 81 | assert(output_bins > 0); 82 | assert(output_bins <= window_length / 2 + 1); 83 | assert(correctors); 84 | assert(extractor); 85 |
assert(maxima); 86 | } 87 | }; 88 | 89 | #endif //EMPI_SPECTROGRAM_REQUEST_H 90 | -------------------------------------------------------------------------------- /include/SpectrumCalculator.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_SPECTRUM_CALCULATOR_H 7 | #define EMPI_SPECTRUM_CALCULATOR_H 8 | 9 | #include "Array.h" 10 | #include "Types.h" 11 | 12 | /** 13 | * Interface for objects calculating spectra (discrete Fourier transforms) of real signals. 14 | */ 15 | class SpectrumCalculator { 16 | public: 17 | /** 18 | * Compute spectrum of a given signal with optional zero-padding. 19 | * Return the pointer to the internal buffer which may be overwritten on next calls to computeSpectrum. 20 | * 21 | * @param input input signal 22 | * @param window_length length of the window for FFT, should be at least the size of input 23 | * @return pointer to the internal buffer with size of at least window_length/2+1 24 | */ 25 | virtual const complex *computeSpectrum(Array1D input, int window_length) = 0; 26 | 27 | virtual ~SpectrumCalculator() = default; 28 | }; 29 | 30 | #endif //EMPI_SPECTRUM_CALCULATOR_H 31 | -------------------------------------------------------------------------------- /include/TaskQueue.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #ifndef EMPI_TASK_QUEUE_H 7 | #define EMPI_TASK_QUEUE_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | /** 16 | * Synchronized (thread-safe) task queue for producer-consumer pattern with multiple consumers. 17 | * All public methods are thread-safe. 18 | * 19 | * @tparam T type of elements (tasks) that will be stored in queue. 20 | */ 21 | template 22 | class TaskQueue { 23 | std::mutex mutex; 24 | std::list tasks; 25 | bool terminated; 26 | size_t unfinished_count; 27 | 28 | std::condition_variable task_count_up_or_terminated, unfinished_count_down_or_terminated; 29 | 30 | public: 31 | /** 32 | * Create a new queue with no tasks. 33 | */ 34 | TaskQueue() : terminated(false), unfinished_count(0) {} 35 | 36 | /** 37 | * Add a single task to the queue and wake up one waiting consumer; this method does not wait for the task to be finished. 38 | * This method should be called by the producer thread. 39 | */ 40 | void put(const T &task) { 41 | std::unique_lock lock(mutex); 42 | unfinished_count++; 43 | tasks.push_back(task); 44 | lock.unlock(); 45 | 46 | task_count_up_or_terminated.notify_one(); 47 | } 48 | 49 | /** 50 | * Add all tasks from the given list to the queue. 51 | * This method should be called by the producer thread. 52 | */ 53 | void put(std::list &&list_of_tasks) { 54 | std::unique_lock lock(mutex); 55 | unfinished_count += list_of_tasks.size(); 56 | tasks.splice(tasks.end(), list_of_tasks); 57 | lock.unlock(); 58 | 59 | task_count_up_or_terminated.notify_all(); 60 | } 61 | 62 | /** 63 | * Get the next task from the queue. If no tasks are available, wait until one becomes available and then return. 64 | * This method should be called by consumer threads. 65 | * After processing the task, each consumer should call notify() prior to get-ting the next one.
66 | * @param from_back if true, take the most recently added task (back of the queue) instead of the oldest one (front) 67 | * @param task reference to store the task taken from the queue 68 | * @param wait if true, and the queue is empty, wait until the task is available or the queue terminates; 69 | * if false, return false immediately 70 | * @return true if task was taken from the queue, false if the queue was terminated instead 71 | * (or not available and wait=false) 72 | */ 73 | bool get(T &task, bool wait = true, bool from_back = false) { // NOTE(review): taking the address of std::list member functions is unspecified by the standard ([namespace.std]); small lambdas would be more portable 74 | return from_back 75 | ? custom_get(task, wait, &decltype(tasks)::back, &decltype(tasks)::pop_back) 76 | : custom_get(task, wait, &decltype(tasks)::front, &decltype(tasks)::pop_front); 77 | } 78 | 79 | /** 80 | * Notify the producers that the task previously obtained with get() is now completed. 81 | */ 82 | void notify() { 83 | std::unique_lock lock(mutex); 84 | unfinished_count--; 85 | lock.unlock(); 86 | unfinished_count_down_or_terminated.notify_all(); 87 | } 88 | 89 | /** 90 | * Terminate the queue. All consumers and producers will exit: get() returns false from now on (even if tasks remain), and wait_for_tasks() throws if any task is still unfinished. 91 | */ 92 | void terminate() { 93 | std::unique_lock lock(mutex); 94 | terminated = true; 95 | lock.unlock(); 96 | task_count_up_or_terminated.notify_all(); 97 | unfinished_count_down_or_terminated.notify_all(); 98 | } 99 | 100 | /** 101 | * Wait until all tasks are finished. 102 | * This method should be called by the producer thread. Throws std::logic_error if the queue is terminated while some tasks are still unfinished.
103 | */ 104 | void wait_for_tasks() { 105 | std::unique_lock lock(mutex); 106 | unfinished_count_down_or_terminated.wait(lock, [&] { return terminated || !unfinished_count; }); 107 | if (unfinished_count) { 108 | throw std::logic_error("TaskQueue was terminated while waiting for tasks"); 109 | } 110 | lock.unlock(); // redundant: the lock would be released on return anyway 111 | } 112 | 113 | private: // custom_get: ref/pop select which end of the list is consumed (front/pop_front or back/pop_back) 114 | bool custom_get(T &task, bool wait, T &(std::list::*ref)(), void (std::list::*pop)()) { 115 | std::unique_lock lock(mutex); 116 | if (wait) { 117 | task_count_up_or_terminated.wait(lock, [&] { return terminated || !tasks.empty(); }); 118 | } else if (tasks.empty()) { 119 | return false; 120 | } 121 | bool result = false; 122 | if (!terminated) { 123 | task = (tasks.*ref)(); 124 | (tasks.*pop)(); 125 | result = true; 126 | } 127 | 128 | return result; 129 | } 130 | }; 131 | 132 | #endif //EMPI_TASK_QUEUE_H 133 | -------------------------------------------------------------------------------- /include/Testing.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #ifndef EMPI_TESTING_H 7 | #define EMPI_TESTING_H 8 | 9 | #include 10 | #include 11 | // each macro expands to a single do { ... } while (0) statement, so it is safe inside unbraced if/else 12 | #define ASSERT(ACTUAL) do { if (!(ACTUAL)) errx(EXIT_FAILURE, #ACTUAL " evaluates to false"); } while (0) 13 | #define ASSERT_EQUALS(EXPECTED, ACTUAL) do { if ((EXPECTED) != (ACTUAL)) errx(EXIT_FAILURE, #ACTUAL " does not equal expected " #EXPECTED); } while (0) 14 | #define ASSERT_NEAR_ZERO(ACTUAL) do { if (std::abs(ACTUAL) > 1.0e-6) errx(EXIT_FAILURE, #ACTUAL " (%le) is significantly nonzero", (ACTUAL)); } while (0) 15 | #define ASSERT_APPROX(EXPECTED, ACTUAL, EPSILON) do { if (std::abs((EXPECTED)-(ACTUAL)) > (EPSILON)) \ 16 | errx(EXIT_FAILURE, #ACTUAL " (%lf) is significantly (" #EPSILON ") different than " #EXPECTED " (%lf)", (ACTUAL), (EXPECTED)); } while (0) 17 | #define ASSERT_SAME_PHASE(EXPECTED, ACTUAL, EPSILON) do { if (std::min(std::abs((EXPECTED)-(ACTUAL)), 2*M_PI-std::abs((EXPECTED)-(ACTUAL))) > (EPSILON)) \ 18 | errx(EXIT_FAILURE, #ACTUAL " (%lf) and " #EXPECTED " (%lf) differ significantly (" #EPSILON ") in phase", (ACTUAL), (EXPECTED)); } while (0) 19 | 20 | #endif //EMPI_TESTING_H 21 | -------------------------------------------------------------------------------- /include/Thread.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #ifndef EMPI_THREAD_H 7 | #define EMPI_THREAD_H 8 | 9 | #include 10 | #include 11 | 12 | class Thread : public std::thread 13 | { 14 | public: 15 | static bool affinity_enabled; 16 | 17 | template 18 | explicit Thread(Function function) : std::thread(std::move(function)) 19 | { 20 | affix_to_cpu(); 21 | } 22 | 23 | private: 24 | void affix_to_cpu(); 25 | }; 26 | 27 | #endif //EMPI_THREAD_H 28 | -------------------------------------------------------------------------------- /include/Timer.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_TIMER_H 7 | #define EMPI_TIMER_H 8 | 9 | #include 10 | #include 11 | 12 | class ElapsingTimer { 13 | const bool hasMessage_; 14 | 15 | std::chrono::time_point start_; 16 | // non-copyable: deleted (not merely left undefined) so misuse fails at compile time instead of link time 17 | ElapsingTimer(const ElapsingTimer &) = delete; 18 | 19 | ElapsingTimer &operator=(const ElapsingTimer &) = delete; 20 | 21 | public: 22 | ElapsingTimer(); 23 | 24 | ElapsingTimer(const char *msg, va_list ap); 25 | 26 | ~ElapsingTimer(); 27 | 28 | [[nodiscard]] float time() const; 29 | }; 30 | 31 | class Timer { 32 | std::unique_ptr timer_; 33 | 34 | public: 35 | Timer() =default; 36 | 37 | void start(); 38 | 39 | void start(const char *msg, ...)
40 | #ifdef __GNUC__ 41 | __attribute__ (( format (printf, 2, 3))) 42 | #endif 43 | ; 44 | 45 | [[nodiscard]] float time() const; 46 | 47 | void stop(); 48 | }; 49 | 50 | #endif //EMPI_TIMER_H 51 | -------------------------------------------------------------------------------- /include/TriangularFamily.h: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_TRIANGULAR_FAMILY_H 7 | #define EMPI_TRIANGULAR_FAMILY_H 8 | 9 | #include "Family.h" 10 | 11 | /** 12 | * Family implementation for triangular envelope. 13 | */ 14 | class TriangularFamily : public FamilyTemplate { 15 | public: 16 | /** 17 | * Create a new instance representing a triangular envelope. 18 | */ 19 | explicit TriangularFamily() = default; 20 | 21 | double max_arg() const final; 22 | 23 | double min_arg() const final; 24 | 25 | const char *name() const final; 26 | 27 | double value(double t) const final; 28 | 29 | double scale_integral(double log_scale) const final; 30 | 31 | double inv_scale_integral(double value) const final; 32 | 33 | double freq_integral(double x) const final; 34 | 35 | double inv_freq_integral(double value) const final; 36 | 37 | double skew_integral(double x) const final; 38 | 39 | double time_integral(double x) const final; 40 | 41 | double inv_time_integral(double value) const final; 42 | 43 | double optimality_factor_e2(double epsilon2) const final; 44 | 45 | double optimality_factor_sf(double scale_frequency) const final; 46 | }; 47 | 48 | #endif //EMPI_TRIANGULAR_FAMILY_H 49 | -------------------------------------------------------------------------------- /include/Types.h: -------------------------------------------------------------------------------- 1 |
/********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_TYPES_H 7 | #define EMPI_TYPES_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using index_t = std::ptrdiff_t; 16 | using real = double; // TODO #ifdef SINGLE_PRECISION 17 | using complex = std::complex; 18 | 19 | template 20 | inline size_t mulx(X x, Y y) { 21 | // TODO range checking 22 | return static_cast(x) * static_cast(y); 23 | } 24 | 25 | template 26 | int as_positive_int(T x) { 27 | assert(x > 0); 28 | assert(static_cast(x) <= static_cast(std::numeric_limits::max())); 29 | return static_cast(x); 30 | } 31 | 32 | class Types { 33 | template 34 | static T rint(double x); 35 | 36 | public: 37 | template 38 | static T round(double x) { 39 | if (!(x > static_cast(std::numeric_limits::lowest()) && x < static_cast(std::numeric_limits::max()))) { // the negated form also rejects NaN 40 | throw std::runtime_error("rounding overflow detected"); 41 | } 42 | return rint(x); 43 | } 44 | 45 | private: 46 | template 47 | static T round_with_mode(double x, int rounding_mode) { 48 | const int default_mode = fegetround(); 49 | fesetround(rounding_mode); 50 | // restore the caller's rounding mode on every exit path, including when round() throws on overflow 51 | struct RoundingModeGuard { int mode; ~RoundingModeGuard() { fesetround(mode); } } guard{default_mode}; 52 | return round(x); 53 | } 54 | 55 | public: 56 | template 57 | static T ceil(double x) { 58 | return round_with_mode(x, FE_UPWARD); 59 | } 60 | 61 | template 62 | static T floor(double x) { 63 | return round_with_mode(x, FE_DOWNWARD); 64 | } 65 | }; 66 | 67 | template<> 68 | int Types::rint(double x); 69 | 70 | template<> 71 | long Types::rint(double x); 72 | 73 | template<> 74 | long long Types::rint(double x); 75 | 76 | #endif //EMPI_TYPES_H 77 | -------------------------------------------------------------------------------- /include/Worker.h:
-------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #ifndef EMPI_WORKER_H 7 | #define EMPI_WORKER_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "Array.h" 16 | #include "BlockInterface.h" 17 | #include "Dictionary.h" 18 | #include "IndexRange.h" 19 | #include "OptimizationMode.h" 20 | #include "SpectrogramCalculator.h" 21 | #include "SpectrogramLoop.h" 22 | #include "SpectrogramRequest.h" 23 | #include "TaskQueue.h" 24 | #include "Thread.h" 25 | #include "Types.h" 26 | 27 | /** 28 | * The central point of the computation process. 29 | * Worker instances facilitate communication between Dictionary objects 30 | * (which have an internal cache pointing to the current best atom) 31 | * and SpectrogramCalculator objects performing the main part of calculations and allowing dictionaries to update. 32 | * Resulting decomposition can be accessed by repeated calls to get_next_atom(). 33 | */ 34 | class Worker { 35 | Array2D data; 36 | const OptimizationMode mode; 37 | IndexRange updated_index_range; 38 | 39 | std::shared_ptr> atom_queue; 40 | std::shared_ptr> spectrogram_queue; 41 | std::unique_ptr primary_calculator_loop; 42 | std::vector threads; 43 | 44 | std::list> dictionaries; 45 | 46 | public: 47 | /** 48 | * Create a new Worker to analyze given multi-channel signal. 49 | * Dictionaries and calculators must be added manually after constructing the Worker instance.
50 | * 51 | * @param data reference to multi-channel data of the analysed signal 52 | * @param cpu_threads number of CPU threads to use for global optimization 53 | * @param mode optimization mode to use (OPTIMIZATION_DISABLED by default) */ 54 | explicit Worker(Array2D data, unsigned cpu_threads, OptimizationMode mode = OPTIMIZATION_DISABLED); 55 | 56 | Worker(Worker&& source) = default; 57 | Worker(const Worker&) = delete; 58 | Worker &operator=(const Worker&) = delete; 59 | 60 | /** 61 | * Destroy the Worker instance by terminating the task queue and closing all worker threads. 62 | */ 63 | ~Worker(); 64 | 65 | /** 66 | * Associate a SpectrogramCalculator object with this Worker instance. 67 | * It will be used during following calls to get_next_atom(). 68 | * Each calculator should only be added once. 69 | * 70 | * @param calculator smart pointer to SpectrogramCalculator instance 71 | * @param prefer_long_fft should be set to true in case of GPU calculators, false otherwise 72 | */ 73 | void add_calculator(std::unique_ptr calculator, bool prefer_long_fft = false); 74 | 75 | /** 76 | * Associate a Dictionary object with this Worker instance. 77 | * It will be used during following calls to get_next_atom(). 78 | * Each dictionary should only be added once. 79 | * 80 | * @param dictionary smart pointer to Dictionary instance 81 | */ 82 | void add_dictionary(std::unique_ptr dictionary); 83 | 84 | /** 85 | * Compute and return the atom that best matches the analyzed signal. 86 | * This atom will be subtracted from signal before returning, 87 | * so consecutive calls to get_next_atom() will provide consecutive atoms for the decomposition. 88 | * 89 | * @return best matching atom 90 | */ 91 | ExtendedAtomPointer get_next_atom(); 92 | 93 | /** 94 | * Create a list of all request templates that can be requested by this worker's dictionaries. 95 | */ 96 | std::list get_proto_requests(); 97 | 98 | /** 99 | * Reset the internal state of all dictionaries.
100 | * This method should be called whenever the input signal has been replaced with a new signal segment 101 | * and computation should be started from scratch. 102 | */ 103 | void reset(); 104 | }; 105 | 106 | #endif //EMPI_WORKER_H 107 | -------------------------------------------------------------------------------- /run-benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | S=2000000 3 | N=256 4 | LOG="`git rev-parse --short HEAD`.log" 5 | rm -f "$LOG" 6 | while [ $N -lt $S ] ; do 7 | printf "%7d " $N | tee -a "$LOG" 8 | ./benchmark $S $N | tee -a "$LOG" 9 | N=`expr 2 \* $N` 10 | done 11 | -------------------------------------------------------------------------------- /src/BlockAtomBase.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #include "BlockAtomBase.h" 7 | 8 | ////////////////////////////////////////////////////////////////////////////// 9 | // Converters map BlockAtomParams to optimizer coordinates (frequency, position and log(scale) divided by their grid steps) and back. 10 | BlockAtomParamsConverter::BlockAtomParamsConverter(double frequency_step, double position_step, double log_scale_step, double frequency_max) 11 | : frequency_step(frequency_step), position_step(position_step), log_scale_step(log_scale_step), 12 | max_scaled_frequency(frequency_max / frequency_step) {} 13 | 14 | std::array BlockAtomParamsConverter::arrayFromParams(const BlockAtomParams &params) const { 15 | return { 16 | params.frequency / frequency_step, 17 | params.position / position_step, 18 | std::log(params.scale) / log_scale_step 19 | }; 20 | } 21 | 22 | std::pair BlockAtomParamsConverter::paramsFromArray(const std::array &array) const { 23 | double scaled_frequency = array[0]; 24 | double scaled_position = array[1]; 25 | double scaled_log_scale = array[2]; 26 | 27 | double penalty = fix_frequency(scaled_frequency) + fix_log_scale(scaled_log_scale); 28 | 29 | return std::make_pair( 30 | BlockAtomParams( 31 | scaled_frequency * frequency_step, 32 | scaled_position * position_step, 33 | std::exp(scaled_log_scale * log_scale_step) 34 | ), 35 | penalty 36 | ); 37 | } 38 | 39 | double BlockAtomParamsConverter::fix_frequency(double &scaled_frequency) const { 40 | return fix_scaled_argument(scaled_frequency, 0.0, max_scaled_frequency); 41 | } 42 | 43 | double BlockAtomParamsConverter::fix_log_scale(double & /*scaled_log_scale*/) const { 44 | return 0.0; // this converter never clamps the scale, so there is never a penalty for it 45 | } 46 | // Clamp value to [min, max] and return the distance by which it was out of range (0 if it already fit). 47 | double BlockAtomParamsConverter::fix_scaled_argument(double &value, double min, double max) { 48 | double penalty = 0.0; 49 | if (value < min) { 50 | penalty += min - value; 51 | value = min; 52 | } else if (value > max) { 53 | penalty += value - max; 54 | value = max; 55 | } 56 | return penalty; 57 | } 58 | 59 | ////////////////////////////////////////////////////////////////////////////// 60 | 61 |
BlockAtomParamsConverterBounded::BlockAtomParamsConverterBounded(double frequency_step, double position_step, double log_scale_step, 62 | double frequency_max, double min_scale, double max_scale) 63 | : BlockAtomParamsConverter(frequency_step, position_step, log_scale_step, frequency_max), 64 | min_scaled_log_scale(std::log(min_scale) / log_scale_step), 65 | max_scaled_log_scale(std::log(max_scale) / log_scale_step) {} 66 | 67 | double BlockAtomParamsConverterBounded::fix_log_scale(double &scaled_log_scale) const { 68 | return fix_scaled_argument(scaled_log_scale, min_scaled_log_scale, max_scaled_log_scale); 69 | } 70 | -------------------------------------------------------------------------------- /src/BlockAtomCache.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #include 7 | #include "BlockAtomCache.h" 8 | 9 | ////////////////////////////////////////////////////////////////////////////// 10 | 11 | std::shared_ptr BlockAtomCache::get(size_t key) { 12 | std::shared_lock items_lock_guard(items_mutex); 13 | auto iterator = items.find(key); 14 | if (iterator == items.end()) { 15 | return nullptr; 16 | } 17 | return iterator->second.atom; 18 | } 19 | 20 | void BlockAtomCache::set(size_t key, IndexRange index_range, std::shared_ptr atom) { 21 | std::unique_lock items_lock_guard(items_mutex); 22 | items.insert_or_assign(key, BlockAtomCacheItem{index_range, std::move(atom)}); 23 | } 24 | 25 | void BlockAtomCache::remove_overlapping(IndexRange range_to_overlap) { 26 | std::unique_lock items_lock_guard(items_mutex); 27 | for (auto iterator = items.begin(); iterator != items.end();) { 28 | if (!iterator->second.index_range.overlap(range_to_overlap)) { 29 | ++iterator; 30 | } else { 31 | iterator = items.erase(iterator); 32 | } 33 | } 34 | } 35 | 36 | ////////////////////////////////////////////////////////////////////////////// 37 | 38 | BlockAtomCacheSlot::BlockAtomCacheSlot(const std::shared_ptr& cache, size_t key) : cache(cache), key(key) {} 39 | 40 | std::shared_ptr BlockAtomCacheSlot::get() const { 41 | // the cache is only weakly referenced: once it has been destroyed, report a miss instead of dereferencing null 42 | if (const auto shared_cache = cache.lock()) { 43 | return shared_cache->get(key); 44 | } 45 | return nullptr; 46 | } 47 | 48 | void BlockAtomCacheSlot::set(IndexRange index_range, std::shared_ptr atom) { 49 | // storing into a cache that no longer exists is a no-op 50 | if (const auto shared_cache = cache.lock()) { 51 | shared_cache->set(key, index_range, std::move(atom)); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/BlockAtomObjective.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #include "BlockAtomObjective.h" 7 | 8 | ////////////////////////////////////////////////////////////////////////////// 9 | 10 | BlockAtomObjective::BlockAtomObjective(std::shared_ptr family, Array2D data_, Extractor extractor, 11 | std::shared_ptr converter_) : 12 | family(std::move(family)), 13 | channel_count(data_.height()), 14 | data(std::move(data_)), 15 | products(channel_count, 1), 16 | extractor(extractor), 17 | converter(std::move(converter_)) {} 18 | // Objective in optimizer coordinates: energy of the decoded atom, damped by exp(-penalty) with the penalty reported by converter->paramsFromArray(). 19 | double BlockAtomObjective::calculate_energy(const std::array &array, double *out_norm, ExtraData *out_extra_data) { 20 | auto [params, penalty] = converter->paramsFromArray(array); 21 | return std::exp(-penalty) * calculate_energy(params, out_norm, out_extra_data); 22 | } 23 | // Energy of the atom with the given parameters as computed by the extractor over all channels; optionally reports the envelope norm via out_norm. 24 | double BlockAtomObjective::calculate_energy(const BlockAtomParams& params, double *out_norm, ExtraData *out_extra_data) { 25 | index_t envelope_length = family->size_for_values(params.position, params.scale, nullptr); 26 | index_t envelope_offset = 0; 27 | envelope.resize(envelope_length); 28 | oscillating.resize(envelope_length); 29 | double norm = family->generate_values(params.position, params.scale, &envelope_offset, envelope.data(), true); 30 | if (out_norm) { 31 | *out_norm = norm; 32 | } 33 | 34 | const double omega = 2 * M_PI * params.frequency; 35 | for (index_t i=0; i < envelope_length; ++i) { 36 | const double t = static_cast(envelope_offset + i) - params.position; 37 | oscillating[i] = std::polar(envelope[i], -omega * t); 38 | } 39 | 40 | complex FT = 0.0; 41 | for (index_t i = 0; i < envelope_length; ++i) { 42 | FT += oscillating[i] * oscillating[i]; 43 | } 44 | // intersect the envelope support with the signal samples [0, sample_count); the range is empty (first > last) when the atom lies entirely outside the signal, and the loop below then reads nothing 45 | const index_t sample_count = data.length(); 46 | const index_t first_sample_offset = std::max(index_t{0}, envelope_offset); 47 | const index_t last_sample_offset = std::min(sample_count - 1, envelope_offset + envelope_length - 1); 48 |
49 | 50 | products.fill(0.0); 51 | for (index_t i = first_sample_offset; i <= last_sample_offset; ++i) { 52 | complex z = oscillating[i - envelope_offset]; 53 | for (int c = 0; c < channel_count; ++c) { 54 | products[c][0] += data[c][i] * z; 55 | } 56 | } 57 | 58 | Corrector corrector(FT); 59 | double tmp_for_extractor; 60 | double result = extractor(channel_count, 1, products.get(), &corrector, &tmp_for_extractor, out_extra_data).energy; 61 | return result; 62 | } 63 | -------------------------------------------------------------------------------- /src/BlockAtomProductCalculator.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include "BlockAtomProductCalculator.h" 7 | #include "Extractor.h" 8 | #include "nelder_mead.h" 9 | 10 | ////////////////////////////////////////////////////////////////////////////// 11 | 12 | BlockAtomProductCalculator::BlockAtomProductCalculator(std::shared_ptr family) 13 | : family(std::move(family)) {} 14 | 15 | double BlockAtomProductCalculator::calculate_squared_product(const BlockAtomParams &p, const BlockAtomParams &q, double q_phase) const { 16 | IndexRange p_range = family->compute_range(p.position, p.scale); 17 | IndexRange q_range = family->compute_range(q.position, q.scale); 18 | 19 | IndexRange pq_range = p_range.overlap(q_range); 20 | 21 | double product = 0.0; 22 | if (pq_range.first_index < pq_range.end_index) { 23 | IndexRange wide_range( 24 | std::min(p_range.first_index, q_range.first_index), 25 | std::max(p_range.end_index, q_range.end_index) 26 | ); 27 | complex sum_p2 = 0.0; 28 | double sum_ep2 = 0.0; 29 | double sum_q2 = 0.0; 30 | complex sum_pq = 0.0; 31 | for (index_t i =
wide_range.first_index; i < wide_range.end_index; ++i) { 32 | complex p_value = 0.0; 33 | double q_value = 0.0; 34 | if (p_range.includes(i)) { 35 | double ep_value = family->value((i - p.position) / p.scale); 36 | p_value = std::polar(ep_value, 2 * M_PI * p.frequency * (i - p.position)); 37 | sum_ep2 += ep_value * ep_value; 38 | sum_p2 += p_value * p_value; 39 | } 40 | if (q_range.includes(i)) { 41 | double eq_value = family->value((i - q.position) / q.scale); 42 | q_value = eq_value * cos(2 * M_PI * q.frequency * (i - q.position) + q_phase); 43 | sum_q2 += q_value * q_value; 44 | } 45 | sum_pq += p_value * q_value; 46 | } 47 | 48 | Corrector corrector(sum_p2 / sum_ep2); 49 | product = corrector.compute(sum_pq / std::sqrt(sum_q2 * sum_ep2)); 50 | } 51 | return product; 52 | } 53 | 54 | double BlockAtomProductCalculator::calculate_squared_product(const BlockAtomParams &p, const BlockAtomParams &q) const { 55 | const int GRID = 10; 56 | double best_phase = 0.0; 57 | double min_product2 = INFINITY; 58 | for (int i=0; i( 67 | [&](const std::array& x) { 68 | return this->calculate_squared_product(p, q, x[0]); 69 | }, 70 | std::array{best_phase}, 71 | 0.01, 72 | std::array{0.1} 73 | ); 74 | if (result.ifault) { 75 | throw std::runtime_error("cannot find optimal phase"); 76 | } 77 | return result.ynewlo; 78 | } 79 | -------------------------------------------------------------------------------- /src/BlockDictionary.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include "BlockDictionary.h" 7 | #include "BlockHelper.h" 8 | 9 | static const int MAX_BLOCKS_PER_SCALE = 1000; // sanity check, quite arbitrary 10 | 11 | ////////////////////////////////////////////////////////////////////////////// 12 | 13 | BlockDictionary::BlockDictionary(const BlockDictionaryStructure& structure, const PinnedArray2D& data, 14 | Extractor extractor, SpectrumCalculator &calculator, bool allow_overstep) { 15 | const double booster = 1.0 / structure.family->optimality_factor_e2(structure.energy_error); 16 | 17 | for (const auto& bs : structure.block_structures) { 18 | int output_bins = Types::floor(bs.transform_size * structure.frequency_max) + 1; 19 | auto converter = std::make_shared( 20 | 1.0 / bs.transform_size, 21 | static_cast(bs.input_shift), 22 | structure.log_scale_step, 23 | structure.frequency_max, 24 | structure.scale_min, 25 | structure.scale_max 26 | ); 27 | 28 | if (bs.input_shift < 1) { 29 | double number_of_blocks_as_float = 1.0 / bs.input_shift; 30 | if (number_of_blocks_as_float > MAX_BLOCKS_PER_SCALE) { 31 | throw std::runtime_error("requested minimum atom scale is too small"); 32 | } 33 | int number_of_blocks = Types::round(number_of_blocks_as_float); 34 | for (int i=0; i(i) / static_cast(number_of_blocks); 36 | blocks.push_back( 37 | BlockHelper::create_block(data, structure.family, bs.scale, converter, booster, bs.transform_size, 38 | output_bins, 1, subsample_offset, extractor, calculator, allow_overstep) 39 | ); 40 | } 41 | } else { 42 | int input_shift = Types::round(bs.input_shift); 43 | blocks.push_back( 44 | BlockHelper::create_block(data, structure.family, bs.scale, converter, booster, bs.transform_size, 45 | output_bins, input_shift, 0.0, extractor, calculator, allow_overstep) 46 | ); 47 | } 48 | } 49 | } 50 | 51 | BlockDictionary::BlockDictionary(Block&& block) { 52 | blocks.push_back(block); 53 | } 54 | 55 | size_t 
BlockDictionary::get_atom_count() { 56 | size_t result = 0; 57 | for (auto &block : blocks) { 58 | result += block.get_atom_count(); 59 | } 60 | return result; 61 | } 62 | 63 | BasicAtomPointer BlockDictionary::get_best_match() { 64 | std::shared_ptr result; 65 | for (auto &block : blocks) { 66 | BlockAtom atom = block.get_best_match(); 67 | if (!result || *result < atom) { 68 | result = std::make_shared(std::move(atom)); 69 | } 70 | } 71 | if (!result) { 72 | return nullptr; 73 | } 74 | return result; 75 | } 76 | 77 | std::list BlockDictionary::get_candidate_matches(double energy_to_exceed) { 78 | // TODO 79 | std::list atoms; 80 | for (auto &block : blocks) { 81 | atoms.splice(atoms.end(), block.get_candidate_matches(energy_to_exceed)); 82 | } 83 | std::list result; 84 | for (auto &atom : atoms) { 85 | result.push_back(std::make_shared(std::move(atom))); 86 | } 87 | return result; 88 | } 89 | 90 | void BlockDictionary::fetch_proto_requests(std::list &requests) { 91 | for (auto &block: blocks) { 92 | requests.push_back(block.buildRequest()); 93 | } 94 | } 95 | 96 | void BlockDictionary::fetch_requests(IndexRange signal_range, std::list &requests) { 97 | for (auto &block: blocks) { 98 | SpectrogramRequest request = block.buildRequest(signal_range.first_index, signal_range.end_index); 99 | if (request.how_many > 0) { 100 | requests.push_back(request); 101 | } 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/BlockDictionaryStructure.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include "BlockDictionaryStructure.h" 7 | #include "BlockHelper.h" 8 | 9 | ////////////////////////////////////////////////////////////////////////////// 10 | // NOTE(review): log_scale_step, dt_scale, df_scale and block_structures are computed from the member `family` (family_ has already been moved into it) and block_structures also reads the members log_scale_step/df_scale/dt_scale, so this relies on the member declaration order in BlockDictionaryStructure.h — confirm. 11 | BlockDictionaryStructure::BlockDictionaryStructure(std::shared_ptr family_, double energy_error, double scale_min, double scale_max, double frequency_max) 12 | : family(std::move(family_)), energy_error(energy_error), scale_min(scale_min), scale_max(scale_max), frequency_max(frequency_max), 13 | log_scale_step(family->inv_scale_integral(1 - energy_error)), 14 | dt_scale(family->inv_time_integral(1 - energy_error)), 15 | df_scale(family->inv_freq_integral(1 - energy_error)), 16 | block_structures(BlockHelper::compute_block_structures(family.get(), scale_min, scale_max, log_scale_step, df_scale, dt_scale)) 17 | { } 18 | // Collect the transform size of every block structure; the std::set keeps each distinct size once. 19 | [[nodiscard]] std::set BlockDictionaryStructure::get_transform_sizes() const 20 | { 21 | std::set transform_sizes; 22 | for (const auto& bs : block_structures) { 23 | transform_sizes.insert(transform_sizes.end(), bs.transform_size); 24 | } 25 | return transform_sizes; 26 | } 27 | -------------------------------------------------------------------------------- /src/BufferedWriter.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #include "BufferedWriter.h" 7 | 8 | ////////////////////////////////////////////////////////////////////////////// 9 | // The BookWriter base is constructed from *actual_writer (the constructor parameter) before that parameter is moved into the actual_writer member; this is well-defined because base classes are always initialised before data members. 10 | BufferedWriter::BufferedWriter(int total_channel_count, int real_epoch_count, std::unique_ptr actual_writer) 11 | : BookWriter(*actual_writer), total_channel_count(total_channel_count), real_epoch_count(real_epoch_count), 12 | actual_writer(std::move(actual_writer)), 13 | semaphores(real_epoch_count), results(real_epoch_count) {} 14 | 15 | BufferedResult& BufferedWriter::get_result_ref(int epoch_counter, index_t sample_count) { 16 | BufferedResult& result = results[epoch_counter]; 17 | std::lock_guard lock(result.mutex); 18 | if (!result.data) { 19 | result.data = Array2D(total_channel_count, sample_count); 20 | } 21 | if (result.atoms.empty()) { 22 | result.atoms.resize(total_channel_count); 23 | } 24 | return result; 25 | } 26 | 27 | void BufferedWriter::finalize() { 28 | for (int epoch=0; epochwrite(result.data, result.index, result.atoms); 32 | result.data.reset(); 33 | result.atoms.clear(); 34 | } 35 | actual_writer->finalize(); 36 | } 37 | 38 | void BufferedWriter::write(Array2D data, EpochIndex index, const std::vector> &atoms) { 39 | const index_t sample_count = data.length(); 40 | BufferedResult& result = get_result_ref(index.epoch_counter, sample_count); 41 | result.index = EpochIndex{index.epoch_counter, index.epoch_offset, 0}; 42 | for (int c=0; c= 1) { 15 | re_factor = 2.0 / (1 + std::sqrt(norm_ft)); 16 | im_factor = re_factor * complex(0.0, 1.0); 17 | } else { 18 | const double common_factor = 2 / (1 - norm_ft); 19 | re_factor = (1.0 - ft) * common_factor; 20 | im_factor = (1.0 + ft) * common_factor * complex(0.0, 1.0); 21 | } 22 | } 23 | 24 | double Corrector::compute(complex value, ExtraData* extra) const { 25 | const complex corrected = value.real() * re_factor + value.imag() * im_factor; 26 | const complex corrected2 = corrected * corrected; 27 | const
double norm_corrected = std::norm(corrected); 28 | const double energy = 0.5 * (norm_corrected + corrected2.real() * ft.real() + corrected2.imag() * ft.imag()); 29 | if (extra) { 30 | extra->amplitude = std::sqrt(norm_corrected); 31 | extra->energy = energy; 32 | extra->phase = std::arg(corrected); 33 | } 34 | return energy; 35 | } 36 | -------------------------------------------------------------------------------- /src/DeltaAtom.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include "DeltaAtom.h" 7 | 8 | ////////////////////////////////////////////////////////////////////////////// 9 | 10 | DeltaAtom::DeltaAtom(Array2D data, double energy, index_t position) 11 | : BasicAtom(std::move(data), energy), position(position) {} 12 | 13 | ExtendedAtomPointer DeltaAtom::extend(bool allow_optimization) { 14 | const int channel_count = data.height(); 15 | std::vector amplitudes(channel_count); 16 | for (int c = 0; c < channel_count; ++c) { 17 | amplitudes[c] = data[c][position]; 18 | } 19 | return std::make_shared(data, energy, position, std::move(amplitudes)); 20 | } 21 | 22 | [[nodiscard]] double DeltaAtom::get_energy_upper_bound() const { 23 | return energy; 24 | } 25 | 26 | ////////////////////////////////////////////////////////////////////////////// 27 | 28 | DeltaExtendedAtom::DeltaExtendedAtom(Array2D data, double energy, index_t position, std::vector &&litudes) 29 | : ExtendedAtom(std::move(data), energy), position(position), amplitudes(amplitudes) {} 30 | 31 | void DeltaExtendedAtom::export_atom(std::list *atoms) { 32 | const int channel_count = data.height(); 33 | for (int c = 0; c < channel_count; ++c) { 34 | ExportedAtom atom(amplitudes[c] * 
amplitudes[c]); 35 | atom.amplitude = amplitudes[c]; 36 | atom.envelope = "delta"; 37 | atom.position = static_cast(position); 38 | atoms[c].push_back(std::move(atom)); 39 | } 40 | } 41 | 42 | IndexRange DeltaExtendedAtom::subtract_from_signal() const { 43 | const int channel_count = data.height(); 44 | for (int c = 0; c < channel_count; ++c) { 45 | data[c][position] -= amplitudes[c]; 46 | } 47 | return {position, position + 1}; 48 | } 49 | -------------------------------------------------------------------------------- /src/DeltaDictionary.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include "DeltaAtom.h" 7 | #include "DeltaDictionary.h" 8 | 9 | ////////////////////////////////////////////////////////////////////////////// 10 | 11 | DeltaDictionary::DeltaDictionary(PinnedArray2D data) : data(std::move(data)) 12 | { } 13 | 14 | size_t DeltaDictionary::get_atom_count() { 15 | return data.length(); 16 | } 17 | 18 | BasicAtomPointer DeltaDictionary::get_best_match() { 19 | const int channel_count = data.height(); 20 | const index_t sample_count = data.length(); 21 | 22 | std::pair best_match; 23 | for (index_t i=0; i best_match.second) { 29 | best_match = {i, energy}; 30 | } 31 | } 32 | return std::make_shared(data, best_match.second, best_match.first); 33 | } 34 | 35 | void DeltaDictionary::fetch_requests(IndexRange signal_range, std::list &requests) { 36 | // TODO recalculate only part 37 | } 38 | -------------------------------------------------------------------------------- /src/ExportedAtom.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr 
T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "ExportedAtom.h" 8 | 9 | ////////////////////////////////////////////////////////////////////////////// 10 | 11 | ExportedAtom::ExportedAtom(double energy) 12 | : amplitude(NAN), energy(energy), frequency(NAN), phase(NAN), scale(NAN), position(NAN) {} 13 | -------------------------------------------------------------------------------- /src/Family.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include "Family.h" 7 | #include "GaussianFamily.h" 8 | #include "TriangularFamily.h" 9 | 10 | const std::map> Family::ALL = { 11 | { "gauss", std::make_shared() }, 12 | { "triangular", std::make_shared() }, 13 | }; 14 | -------------------------------------------------------------------------------- /src/File.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "File.h" 8 | 9 | #ifdef _WIN32 10 | int fseeko(FILE *stream, off_t offset, int whence) { 11 | return _fseeki64(stream, static_cast<__int64>(offset), whence); 12 | } 13 | off_t ftello(FILE *stream) { 14 | return static_cast(_ftelli64(stream)); 15 | } 16 | #endif 17 | 18 | ////////////////////////////////////////////////////////////////////////////// 19 | 20 | File::File(const char *path, const char *mode) 21 | : std::shared_ptr(fopen(path, mode), [](FILE *f) { f && fclose(f); }) { 22 | if (!get()) { 23 | throw std::runtime_error("failed to open file"); 24 | } 25 | } 26 | 27 | ////////////////////////////////////////////////////////////////////////////// 28 | 29 | FileToRead::FileToRead(const char *path) 30 | : File(path, "rb"), file_size(read_file_size(get())) {} 31 | 32 | size_t FileToRead::read_file_size(FILE *file) { 33 | off_t end_position; 34 | if (fseeko(file, 0, SEEK_END) 35 | || (end_position = ftello(file)) < 0) { 36 | end_position = 0; // default value when file size could not be read 37 | } 38 | rewind(file); 39 | return end_position; 40 | } 41 | 42 | size_t FileToRead::get_file_size() const { 43 | return file_size; 44 | } 45 | -------------------------------------------------------------------------------- /src/GaussianFamily.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "GaussianFamily.h" 8 | 9 | static const double NORM = std::sqrt(M_SQRT2); 10 | 11 | ////////////////////////////////////////////////////////////////////////////// 12 | 13 | GaussianFamily::GaussianFamily(double min_max) : min_max(min_max) { 14 | assert(min_max > 0); 15 | } 16 | 17 | double GaussianFamily::max_arg() const { 18 | return min_max; 19 | } 20 | 21 | double GaussianFamily::min_arg() const { 22 | return -min_max; 23 | } 24 | 25 | const char *GaussianFamily::name() const { 26 | return "gauss"; 27 | } 28 | 29 | double GaussianFamily::value(double t) const { 30 | return NORM * std::exp(-M_PI * t * t); 31 | } 32 | 33 | double GaussianFamily::scale_integral(double log_scale) const { 34 | return 1.0 / std::sqrt(std::cosh(log_scale)); 35 | } 36 | 37 | double GaussianFamily::inv_scale_integral(double value) const { 38 | return std::acosh(1 / (value * value)); 39 | } 40 | 41 | double GaussianFamily::freq_integral(double x) const { 42 | return std::exp(-M_PI_2 * x * x); 43 | } 44 | 45 | double GaussianFamily::inv_freq_integral(double value) const { 46 | return std::sqrt(-M_2_PI * std::log(value)); 47 | } 48 | 49 | double GaussianFamily::skew_integral(double) const { 50 | return 0.0; 51 | } 52 | 53 | double GaussianFamily::time_integral(double x) const { 54 | return std::exp(-M_PI_2 * x * x); 55 | } 56 | 57 | double GaussianFamily::inv_time_integral(double value) const { 58 | return std::sqrt(-M_2_PI * std::log(value)); 59 | } 60 | 61 | double GaussianFamily::optimality_factor_e2(double epsilon2) const { 62 | return 1 - 1.5 * epsilon2; 63 | } 64 | 65 | double GaussianFamily::optimality_factor_sf(double sf) const { 66 | return 1 - exp(-1.59 * sf - 2.11); 67 | } 68 | -------------------------------------------------------------------------------- /src/IndexRange.cpp: -------------------------------------------------------------------------------- 1 | 
/********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "IndexRange.h" 8 | 9 | ////////////////////////////////////////////////////////////////////////////// 10 | 11 | IndexRange IndexRange::overlap(const IndexRange &other) const { 12 | index_t new_first_index = std::max(first_index, other.first_index); 13 | index_t new_end_index = std::min(end_index, other.end_index); 14 | if (new_first_index < new_end_index) { 15 | return {new_first_index, new_end_index}; 16 | } else { 17 | return {}; 18 | } 19 | } 20 | 21 | bool IndexRange::includes(index_t index) const { 22 | return first_index <= index && index < end_index; 23 | } 24 | 25 | bool IndexRange::operator!() const { 26 | return first_index >= end_index; 27 | } 28 | -------------------------------------------------------------------------------- /src/Logger.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Logger.h" 8 | 9 | #ifdef _WIN32 10 | void ctime_r(const time_t* timep, char* buf) { 11 | ctime_s(buf, 26, timep); 12 | } 13 | #endif 14 | 15 | ////////////////////////////////////////////////////////////////////////////// 16 | 17 | void Logger::header(const char* type) { 18 | time_t now = time(nullptr); 19 | char ctime_buffer[26]; 20 | ctime_r(&now, ctime_buffer); 21 | ctime_buffer[24] = 0; 22 | 23 | fprintf(stderr, "[%s] %s: ", ctime_buffer, type); 24 | } 25 | 26 | void Logger::message(const char* type, const char* string) { 27 | header(type); 28 | fputs(string, stderr); 29 | fputc('\n', stderr); 30 | } 31 | -------------------------------------------------------------------------------- /src/Progress.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Progress.h" 8 | #include "Types.h" 9 | 10 | Progress::Progress(int total_epoch_count) 11 | : total_epoch_count(total_epoch_count), epochs_completed(0), last_progress(0) {} 12 | 13 | void Progress::print_progress() { 14 | double completed = epochs_completed; 15 | for (const auto &pair: progress_map) { 16 | completed += pair.second; 17 | } 18 | int progress = Types::floor(100.0 * completed / total_epoch_count); 19 | if (progress > last_progress) { 20 | printf("%d%% completed (finished %d out of %d segments)\n", progress, epochs_completed, total_epoch_count); 21 | fflush(stdout); 22 | last_progress = progress; 23 | } 24 | } 25 | 26 | void Progress::epoch_started(EpochIndex index) { 27 | std::lock_guard lock(mutex); 28 | progress_map[index] = 0; 29 | } 30 | 31 | void Progress::epoch_progress(EpochIndex index, double progress) { 32 | std::lock_guard lock(mutex); 33 | progress_map[index] = progress; 34 | print_progress(); 35 | } 36 | 37 | void Progress::epoch_finished(EpochIndex index) { 38 | std::lock_guard lock(mutex); 39 | progress_map.erase(index); 40 | ++epochs_completed; 41 | print_progress(); 42 | } 43 | -------------------------------------------------------------------------------- /src/SignalReader.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015–2018 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include "SignalReader.h" 7 | 8 | //------------------------------------------------------------------------------ 9 | 10 | SignalReaderSingleChannel::SignalReaderSingleChannel(std::shared_ptr source_) 11 | : source(std::move(source_)), epoch(source->get_epoch_channel_count(), source->get_epoch_sample_count()) 12 | { } 13 | 14 | size_t SignalReaderSingleChannel::get_epoch_count() const { 15 | return source->get_epoch_count() * source->get_epoch_channel_count(); 16 | } 17 | 18 | int SignalReaderSingleChannel::get_epoch_channel_count() const { 19 | return 1; 20 | } 21 | 22 | index_t SignalReaderSingleChannel::get_epoch_sample_count() const { 23 | return source->get_epoch_sample_count(); 24 | } 25 | 26 | std::optional SignalReaderSingleChannel::read(Array2D buffer) { 27 | std::lock_guard lock(this->mutex); 28 | if (!last_epoch || ++last_epoch->channel_offset >= epoch.height()) { 29 | last_epoch = source->read(epoch); 30 | if (!last_epoch) { 31 | return std::nullopt; 32 | } 33 | } 34 | std::copy(epoch[last_epoch->channel_offset], epoch[last_epoch->channel_offset] + epoch.length(), buffer[0]); 35 | 36 | return last_epoch; 37 | } 38 | 39 | //------------------------------------------------------------------------------ 40 | -------------------------------------------------------------------------------- /src/SpectrogramCalculatorCUDACallback.cu: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include "CUDA.h" 7 | 8 | static __device__ real myInputCallback(void *dataIn, size_t offset, void *callerInfo, void *) { 9 | auto *info = reinterpret_cast(callerInfo); 10 | size_t index_in_batch = offset & info->window_length_mask; 11 | real ret = (index_in_batch < info->envelope_length) 12 | ? ((real *) dataIn)[(offset >> info->window_length_bits) * info->input_shift + index_in_batch] * info->envelope[index_in_batch] 13 | : 0; 14 | return ret; 15 | } 16 | 17 | static __device__ void 18 | myOutputCallback(void *dataOut, size_t offset, cucomplex element, void *callerInfo, void *) { 19 | auto *info = reinterpret_cast(callerInfo); 20 | size_t index_in_batch = offset % info->spectrum_length; 21 | if (index_in_batch < info->output_bins) { 22 | ((cucomplex *) dataOut)[offset / info->spectrum_length * info->output_bins + index_in_batch] = element; 23 | } 24 | } 25 | 26 | static __device__ cufftCallbackLoadD myInputPtr = myInputCallback; 27 | 28 | static __device__ cufftCallbackStoreZ myOutputPtr = myOutputCallback; 29 | 30 | void CudaCallback::initialize() { 31 | cuda_check(cudaMemcpyFromSymbol(&hostCopyOfInputCallback, myInputPtr, sizeof(void*))); 32 | cuda_check(cudaMemcpyFromSymbol(&hostCopyOfOutputCallback, myOutputPtr, sizeof(void*))); 33 | } 34 | 35 | void CudaCallback::associate(cufftHandle plan, CudaCallbackInfo *dev_info) { 36 | cufft_check(cufftXtSetCallback(plan, (void **) &hostCopyOfInputCallback, CUFFT_CB_LD_REAL_DOUBLE, (void **) &dev_info)); 37 | cufft_check(cufftXtSetCallback(plan, (void **) &hostCopyOfOutputCallback, CUFFT_CB_ST_COMPLEX_DOUBLE, (void **) &dev_info)); 38 | } 39 | -------------------------------------------------------------------------------- /src/SpectrogramCalculatorDummy.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. 
Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "Array.h" 8 | #include "SpectrogramCalculatorDummy.h" 9 | 10 | ////////////////////////////////////////////////////////////////////////////// 11 | 12 | void SpectrogramCalculatorDummy::compute(const SpectrogramRequest &request) { 13 | std::vector tmp_for_extractor(request.output_bins); 14 | request.assertCorrectness(); 15 | 16 | Array2D tmp_spectra(request.channel_count, request.output_bins); 17 | index_t input_offset = request.input_offset; 18 | for (int m = 0; m < request.how_many; ++m) { 19 | for (int c = 0; c < request.channel_count; ++c) { 20 | for (int k = 0; k < request.output_bins; ++k) { 21 | complex sum = 0.0; 22 | for (int n = 0; n < request.envelope_length; ++n) { 23 | index_t index = input_offset + n; 24 | if (index >= 0 && index < request.channel_length) { 25 | sum += request.data[c][index] * request.envelope[n] * 26 | std::polar(1.0, -2 * M_PI * k * n / request.window_length); 27 | } 28 | } 29 | tmp_spectra[c][k] = sum; 30 | } 31 | } 32 | input_offset += request.input_shift; 33 | request.maxima[m] = request.extractor(request.channel_count, request.output_bins, tmp_spectra.get(), 34 | request.correctors, tmp_for_extractor.data(), nullptr); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/SpectrogramCalculatorFFTW.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include "IndexRange.h" 9 | #include "SpectrogramCalculatorFFTW.h" 10 | 11 | template 12 | static void mul(size_t count, TX *__restrict x, const TY *__restrict y) noexcept { 13 | while (count--) { 14 | (*x++) *= (*y++); 15 | } 16 | } 17 | 18 | ////////////////////////////////////////////////////////////////////////////// 19 | 20 | SpectrogramCalculatorFFTW::SpectrogramCalculatorFFTW(int channel_count, const std::set &window_lengths) 21 | : max_window_length(*window_lengths.rbegin()), 22 | input_buffers(channel_count, max_window_length, fftw_alloc_real, fftw_free), 23 | output_buffers(channel_count, max_window_length / 2 + 1, fftw_alloc_complex, fftw_free) { 24 | for (int window_length : window_lengths) { 25 | fftw_plan plan = fftw_plan_dft_r2c_1d( 26 | window_length, 27 | input_buffers[0], 28 | output_buffers[0], 29 | FFTW_ESTIMATE | FFTW_DESTROY_INPUT 30 | ); 31 | plans[window_length].reset(plan, fftw_destroy_plan); 32 | } 33 | } 34 | 35 | SpectrogramCalculatorFFTW::SpectrogramCalculatorFFTW(const SpectrogramCalculatorFFTW &source) 36 | : max_window_length(source.max_window_length), 37 | input_buffers(source.input_buffers.height(), source.input_buffers.length(), fftw_alloc_real, fftw_free), 38 | output_buffers(source.output_buffers.height(), source.output_buffers.length(), fftw_alloc_complex, fftw_free), 39 | plans(source.plans) {} 40 | 41 | void SpectrogramCalculatorFFTW::compute(const SpectrogramRequest &request) { 42 | request.assertCorrectness(); 43 | fftw_plan plan = getPlan(request.window_length); 44 | 45 | for (int c = 0; c < request.channel_count; ++c) { 46 | real *const input_buffer = input_buffers[c]; 47 | std::fill(input_buffer + request.envelope_length, input_buffer + request.window_length, 0); 48 | } 49 | 50 | for (int h = 0; h < request.how_many; ++h) { 51 | index_t offset = request.input_offset + mulx(h, request.input_shift); 52 | IndexRange overlap = 
IndexRange(-offset, request.channel_length - offset).overlap(request.envelope_length); 53 | if (!overlap) { 54 | // all samples consist of zero-padding 55 | request.maxima[h] = ExtractedMaximum{0, 0}; 56 | continue; 57 | } 58 | 59 | for (int c = 0; c < request.channel_count; ++c) { 60 | real *const input_buffer = input_buffers[c]; 61 | const real *input = request.data[c] + offset; 62 | 63 | if (overlap.first_index > 0) { 64 | std::fill(input_buffer, input_buffer + overlap.first_index, 0); 65 | } 66 | std::copy(input + overlap.first_index, input + overlap.end_index, input_buffer + overlap.first_index); 67 | if (overlap.end_index < static_cast(request.window_length)) { 68 | std::fill(input_buffer + overlap.end_index, input_buffer + request.window_length, 0); 69 | } 70 | 71 | mul(request.envelope_length, input_buffer, request.envelope); 72 | fftw_execute_dft_r2c(plan, input_buffer, output_buffers[c]); 73 | } 74 | request.maxima[h] = request.extractor(request.channel_count, request.output_bins, 75 | reinterpret_cast(output_buffers.get()), request.correctors, input_buffers[0], 76 | nullptr); 77 | } 78 | } 79 | 80 | const complex *SpectrogramCalculatorFFTW::computeSpectrum(Array1D input, int window_length) { 81 | fftw_plan plan = getPlan(window_length); 82 | 83 | double *input_buffer = input_buffers[0]; 84 | fftw_complex *output_buffer = output_buffers[0]; 85 | 86 | std::copy(input.get(), input.get() + input.length(), input_buffer); 87 | if (input.length() < window_length) { 88 | std::fill(input_buffer + input.length(), input_buffer + window_length, 0); 89 | } 90 | 91 | fftw_execute_dft_r2c(plan, input_buffer, output_buffer); 92 | return reinterpret_cast(output_buffer); 93 | } 94 | 95 | fftw_plan SpectrogramCalculatorFFTW::getPlan(int window_length) const { 96 | auto it = plans.find(window_length); 97 | if (it == plans.end()) { 98 | throw std::logic_error("invalid window_length passed to FFTW calculator"); 99 | } 100 | return it->second.get(); 101 | } 102 | 
-------------------------------------------------------------------------------- /src/Thread.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "Thread.h" 8 | 9 | bool Thread::affinity_enabled = false; 10 | 11 | void Thread::affix_to_cpu() { 12 | #ifdef _GNU_SOURCE 13 | if (!affinity_enabled) { 14 | return; 15 | } 16 | static unsigned cpu_index = 0; 17 | cpu_set_t cpuset; 18 | CPU_ZERO(&cpuset); 19 | CPU_SET(cpu_index, &cpuset); 20 | pthread_setaffinity_np(native_handle(), sizeof(cpu_set_t), &cpuset); 21 | cpu_index = (cpu_index + 1) % std::thread::hardware_concurrency(); 22 | #endif 23 | } 24 | -------------------------------------------------------------------------------- /src/Timer.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include "Timer.h" 9 | 10 | ////////////////////////////////////////////////////////////////////////////// 11 | 12 | ElapsingTimer::ElapsingTimer() 13 | : hasMessage_(false) { 14 | start_ = std::chrono::steady_clock::now(); 15 | } 16 | 17 | ElapsingTimer::ElapsingTimer(const char *msg, va_list ap) 18 | : hasMessage_(msg) { 19 | if (hasMessage_) { 20 | vfprintf(stderr, msg, ap); 21 | fputs("... 
", stderr); 22 | fflush(stderr); 23 | } 24 | 25 | start_ = std::chrono::steady_clock::now(); 26 | } 27 | 28 | ElapsingTimer::~ElapsingTimer() { 29 | if (hasMessage_) { 30 | fprintf(stderr, "(%.3f s)\n", time()); 31 | fflush(stderr); 32 | } 33 | } 34 | 35 | float ElapsingTimer::time() const { 36 | auto now = std::chrono::steady_clock::now(); 37 | return static_cast(std::chrono::duration_cast(now - start_).count()) * 1.0e-6f; 38 | } 39 | 40 | ////////////////////////////////////////////////////////////////////////////// 41 | 42 | void Timer::start() { 43 | timer_ = std::make_unique(); 44 | } 45 | 46 | void Timer::start(const char *msg, ...) { 47 | va_list ap; 48 | va_start(ap, msg); 49 | timer_ = std::make_unique(msg, ap); 50 | va_end(ap); 51 | } 52 | 53 | float Timer::time() const { 54 | return timer_ ? timer_->time() : 0.0f; 55 | } 56 | 57 | void Timer::stop() { 58 | timer_.reset(); 59 | } 60 | -------------------------------------------------------------------------------- /src/TriangularFamily.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "TriangularFamily.h" 8 | 9 | static const double MIN_MAX = std::sqrt(2.5*M_1_PI); 10 | static const double NORM = std::pow(0.9 * M_PI, 0.25); 11 | 12 | ////////////////////////////////////////////////////////////////////////////// 13 | 14 | double TriangularFamily::max_arg() const { 15 | return MIN_MAX; 16 | } 17 | 18 | double TriangularFamily::min_arg() const { 19 | return -MIN_MAX; 20 | } 21 | 22 | const char *TriangularFamily::name() const { 23 | return "triangular"; 24 | } 25 | 26 | double TriangularFamily::value(double t) const { 27 | const double abs_t = std::abs(t); 28 | return (abs_t < MIN_MAX) ? NORM * (1.0 - abs_t / MIN_MAX) : 0.0; 29 | } 30 | 31 | double TriangularFamily::scale_integral(double log_scale) const { 32 | const double exp_half_log_scale = std::exp(-0.5 * log_scale); 33 | return (3.0 - exp_half_log_scale * exp_half_log_scale) * exp_half_log_scale / 2; 34 | } 35 | 36 | double TriangularFamily::inv_scale_integral(double value) const { 37 | return solve_integral(&TriangularFamily::scale_integral, value); 38 | } 39 | 40 | double TriangularFamily::freq_integral(double x) const { 41 | if (std::abs(x) < 0.001) { 42 | // approximation to avoid numerical errors 43 | return 1.0 - M_PI_2 * x * x; 44 | } 45 | const double x_rel = std::sqrt(10 * M_PI) * x; 46 | return 6 / (x_rel * x_rel) * (1 - sin(x_rel) / x_rel); 47 | } 48 | 49 | double TriangularFamily::inv_freq_integral(double value) const { 50 | return solve_integral(&TriangularFamily::freq_integral, value); 51 | } 52 | 53 | double TriangularFamily::skew_integral(double) const { 54 | return 0.0; 55 | } 56 | 57 | double TriangularFamily::time_integral(double x) const { 58 | const double x_rel = std::abs(x) / MIN_MAX; 59 | if (x_rel <= 1.0) { 60 | return 1 + 0.75 * x_rel * x_rel * (x_rel - 2.0); 61 | } 62 | if (x_rel <= 2.0) { 63 | const double t = 1.0 - 0.5 * x_rel; 64 | return 2 * t * t * t; 65 | } 66 | return 
0.0; 67 | } 68 | 69 | double TriangularFamily::inv_time_integral(double value) const { 70 | return solve_integral(&TriangularFamily::time_integral, value); 71 | } 72 | 73 | double TriangularFamily::optimality_factor_e2(double epsilon2) const { 74 | return 1 - 1.52 * epsilon2; 75 | } 76 | 77 | double TriangularFamily::optimality_factor_sf(double sf) const { 78 | return 1 - exp(-0.9 * sf - 1.97); 79 | } 80 | -------------------------------------------------------------------------------- /src/Types.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "Types.h" 8 | 9 | template<> 10 | int Types::rint(double x) { 11 | return static_cast(std::lrint(x)); 12 | } 13 | 14 | template<> 15 | long Types::rint(double x) { 16 | return std::lrint(x); 17 | } 18 | 19 | template<> 20 | long long Types::rint(double x) { 21 | return std::llrint(x); 22 | } 23 | -------------------------------------------------------------------------------- /src/Worker.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "ExtenderLoop.h" 8 | #include "Worker.h" 9 | 10 | ////////////////////////////////////////////////////////////////////////////// 11 | 12 | Worker::Worker(Array2D data, unsigned cpu_threads, OptimizationMode mode) : data(std::move(data)), mode(mode) { 13 | atom_queue = std::make_shared>(); 14 | spectrogram_queue = std::make_shared>(); 15 | reset(); 16 | if (mode == OPTIMIZATION_GLOBAL) { 17 | for (unsigned i = 1; i < cpu_threads; ++i) { 18 | threads.emplace_back(ExtenderLoop{atom_queue}); 19 | } 20 | } 21 | } 22 | 23 | Worker::~Worker() { 24 | if (atom_queue) { 25 | atom_queue->terminate(); 26 | } 27 | if (spectrogram_queue) { 28 | spectrogram_queue->terminate(); 29 | } 30 | for (auto &thread : threads) { 31 | thread.join(); 32 | } 33 | } 34 | 35 | void Worker::add_calculator(std::unique_ptr calculator, bool prefer_long_fft) { 36 | if (!primary_calculator_loop) { 37 | primary_calculator_loop = std::make_unique(spectrogram_queue, std::move(calculator), prefer_long_fft); 38 | } else { 39 | threads.emplace_back(SpectrogramLoop(spectrogram_queue, std::move(calculator), prefer_long_fft)); 40 | } 41 | } 42 | 43 | void Worker::add_dictionary(std::unique_ptr dictionary) { 44 | dictionaries.push_back(std::move(dictionary)); 45 | } 46 | 47 | ExtendedAtomPointer Worker::get_next_atom() { 48 | std::list requests; 49 | for (auto &dictionary : dictionaries) { 50 | dictionary->fetch_requests(updated_index_range, requests); 51 | } 52 | if (!requests.empty()) { 53 | if (!primary_calculator_loop) { 54 | throw std::logic_error("need at least one SpectrogramCalculator instance"); 55 | } 56 | requests.sort([](const SpectrogramRequest &a, const SpectrogramRequest &b) { 57 | return a.window_length < b.window_length; 58 | }); 59 | 60 | spectrogram_queue->put(std::move(requests)); 61 | (*primary_calculator_loop)(false); 62 | spectrogram_queue->wait_for_tasks(); 63 | } 64 | 65 | BasicAtomPointer 
best_match; 66 | for (auto &dictionary : dictionaries) { 67 | BasicAtomPointer match = dictionary->get_best_match(); 68 | if (match && (!best_match || *best_match < *match)) { 69 | best_match = match; 70 | } 71 | } 72 | 73 | if (!best_match) { 74 | // strangely, no more atoms 75 | return nullptr; 76 | } 77 | 78 | ExtendedAtomPointer best_atom = best_match->extend(mode >= OPTIMIZATION_LOCAL); 79 | assert(best_atom); 80 | if (mode >= OPTIMIZATION_GLOBAL) { 81 | // we also have to check other candidates 82 | std::list candidates; 83 | for (auto &dictionary : dictionaries) { 84 | candidates.splice(candidates.end(), dictionary->get_candidate_matches(best_atom->energy)); 85 | } 86 | atom_queue->put(std::list(candidates)); 87 | ExtenderLoop{atom_queue}(false); 88 | atom_queue->wait_for_tasks(); 89 | for (const auto &candidate : candidates) { 90 | ExtendedAtomPointer another_atom = candidate->extend(true); 91 | if (another_atom->energy > best_atom->energy) { 92 | best_atom = another_atom; 93 | } 94 | } 95 | } 96 | 97 | updated_index_range = best_atom->subtract_from_signal(); 98 | return best_atom; 99 | } 100 | 101 | std::list Worker::get_proto_requests() { 102 | std::list proto_requests; 103 | for (auto &dictionary : dictionaries) { 104 | dictionary->fetch_proto_requests(proto_requests); 105 | } 106 | return proto_requests; 107 | } 108 | 109 | void Worker::reset() { 110 | updated_index_range = data.length(); 111 | } 112 | -------------------------------------------------------------------------------- /src/special/alloc.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "PinnedArray.h" 8 | 9 | // This file is meant as a replacement for CUDA-based pinned memory allocation 10 | // in case it isn't needed (for some unit tests) or CUDA is not available. 11 | // Each executable should link with either alloc.cpp or alloc.cu, but not both at once. 12 | 13 | void cuda_host_disable() { 14 | // nothing here 15 | } 16 | 17 | void *cuda_host_malloc(size_t length) { 18 | void *result = malloc(length); 19 | if (!result) { 20 | throw std::bad_alloc(); 21 | } 22 | return result; 23 | } 24 | 25 | void cuda_host_free(void *pointer) { 26 | free(pointer); 27 | } 28 | -------------------------------------------------------------------------------- /src/special/alloc.cu: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include "PinnedArray.h" 7 | 8 | // This file provides allocation of pinned (page-locked) host memory, 9 | // so it can be safely used in some CUDA routines. 10 | // Each executable should link with either alloc.cpp or alloc.cu, but not both at once. 
11 | 12 | static bool CUDA_ALLOC_DISABLED = false; 13 | 14 | void cuda_host_disable() { 15 | CUDA_ALLOC_DISABLED = true; 16 | } 17 | 18 | void *cuda_host_malloc(size_t length) { 19 | void *result; 20 | if (CUDA_ALLOC_DISABLED) { 21 | result = malloc(length); 22 | if (!result) { 23 | throw std::bad_alloc(); 24 | } 25 | } else { 26 | if (cudaHostAlloc(&result, length, cudaHostAllocPortable)) { 27 | throw std::bad_alloc(); 28 | } 29 | } 30 | return result; 31 | } 32 | 33 | void cuda_host_free(void *pointer) { 34 | if (CUDA_ALLOC_DISABLED) { 35 | free(pointer); 36 | } else { 37 | cudaFreeHost(pointer); 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /tests/special/benchmark.cpp: -------------------------------------------------------------------------------- 1 | #include "Array.h" 2 | #include "PinnedArray.h" 3 | #include "SpectrogramCalculatorCUDA.h" 4 | #include "SpectrogramCalculatorFFTW.h" 5 | #include "Timer.h" 6 | 7 | const int REPEATS = 10; 8 | 9 | double rand_float() { 10 | return (real) rand() / (real) RAND_MAX - 0.5; 11 | } 12 | 13 | int main(int argc, char** argv) 14 | { 15 | if (argc < 3) { 16 | fprintf(stderr, "USAGE: %s total_input_length window_length\n", argv[0]); 17 | return 1; 18 | } 19 | srand(time(nullptr)); 20 | const long total_input_length = atol(argv[1]); 21 | const int window_length = atoi(argv[2]); 22 | const int channel_count = 10; 23 | 24 | const int input_shift = window_length / 8; 25 | const int how_many = (total_input_length - window_length) / input_shift + 1; 26 | const int spectrum_length = window_length/2 + 1; 27 | 28 | PinnedArray1D envelope(window_length); 29 | PinnedArray1D correctors(spectrum_length); 30 | PinnedArray1D maximaCUDA(how_many); 31 | Array1D maximaFFTW(how_many); 32 | 33 | SpectrogramRequest request; 34 | request.channel_length = total_input_length; 35 | request.channel_count = channel_count; 36 | request.input_offset = 0; 37 | request.input_shift = input_shift; 38 | 
request.how_many = how_many; 39 | request.envelope = envelope.get(); 40 | request.envelope_length = window_length; 41 | request.window_length = window_length; 42 | request.output_bins = spectrum_length; 43 | request.correctors = correctors.get(); 44 | request.extractor = extractorVariablePhase; 45 | 46 | PinnedArray2D inputs(channel_count, total_input_length); 47 | request.data = inputs.get(); 48 | 49 | WorkerFFTW fftw(channel_count, {window_length }); 50 | WorkerCUDA cuda(channel_count, {request }); 51 | 52 | double fftwTime = 0.0, cudaTime = 0.0; 53 | for (int repeat=0; repeat 1.0e-10) { 90 | printf("ERROR: max diff = %le\n", max_diff); 91 | return 1; 92 | } 93 | } 94 | 95 | printf("%.6lf %.6lf\n", fftwTime/REPEATS, cudaTime/REPEATS); 96 | 97 | return 0; 98 | } 99 | -------------------------------------------------------------------------------- /tests/special/test-cuda.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Array.h" 8 | #include "Corrector.h" 9 | #include "PinnedArray.h" 10 | #include "Testing.h" 11 | #include "SpectrogramCalculatorDummy.h" 12 | #include "SpectrogramCalculatorCUDA.h" 13 | 14 | double rand_float() { 15 | return (double) rand() / (double) RAND_MAX - 0.5; 16 | } 17 | 18 | void benchmark(SpectrogramCalculator &calculator, const SpectrogramRequest &request, const ExtractedMaximum *reference_maxima) { 19 | calculator.compute(request); 20 | 21 | if (reference_maxima) { 22 | double max_diff = 0; 23 | for (int i = 0; i < request.how_many; ++i) { 24 | ASSERT_EQUALS(reference_maxima[i].bin_index, request.maxima[i].bin_index); 25 | max_diff = std::max(max_diff, std::abs(request.maxima[i].energy - reference_maxima[i].energy)); 26 | } 27 | ASSERT(max_diff > 0); 28 | ASSERT_NEAR_ZERO(max_diff); 29 | } 30 | } 31 | 32 | void run_test(Extractor extractor) { 33 | const int envelope_length = 800; 34 | Array1D envelope(envelope_length); 35 | for (int i = 0; i < envelope_length; ++i) { 36 | envelope[i] = rand_float(); 37 | } 38 | 39 | const int output_bins = 300; 40 | Array1D correctors(output_bins); 41 | for (int i = 0; i < output_bins; ++i) { 42 | correctors[i] = Corrector(0.5 * complex(rand_float(), rand_float())); 43 | } 44 | 45 | const index_t channel_length = 12000; 46 | const int channel_count = 10; 47 | PinnedArray2D data(channel_count, channel_length); 48 | for (int c = 0; c < channel_count; ++c) { 49 | real *channel = data[c]; 50 | for (index_t i = 0; i < channel_length; ++i) { 51 | channel[i] = rand_float(); 52 | } 53 | } 54 | 55 | SpectrogramRequest request; 56 | request.data = data.get(); 57 | request.channel_length = 12000; 58 | request.channel_count = 10; 59 | request.input_offset = -1000; 60 | request.input_shift = 200; 61 | request.how_many = 60; 62 | request.envelope = envelope.get(); 63 | request.envelope_length = envelope_length; 64 | request.window_length = 
1024; 65 | request.output_bins = output_bins; 66 | request.correctors = correctors.get(); 67 | request.extractor = extractor; 68 | 69 | Array1D maxima_dummy(request.how_many); 70 | Array1D maxima_cuda(request.how_many); 71 | 72 | SpectrogramCalculatorCUDA cuda(channel_count, {request}, 0); 73 | SpectrogramCalculatorDummy dummy; 74 | 75 | request.maxima = maxima_dummy.get(); 76 | benchmark(dummy, request, nullptr); 77 | 78 | request.maxima = maxima_cuda.get(); 79 | benchmark(cuda, request, maxima_dummy.get()); 80 | } 81 | 82 | int main() { 83 | run_test(extractorVariablePhase); 84 | run_test(extractorConstantPhase); 85 | puts("OK"); 86 | } 87 | -------------------------------------------------------------------------------- /tests/test-allocation.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Array.h" 8 | #include "Testing.h" 9 | 10 | struct Item { 11 | unsigned id; 12 | }; 13 | 14 | unsigned N = 5; 15 | unsigned M = 10; 16 | unsigned allocator_called = 0; 17 | unsigned deallocator_called = 0; 18 | 19 | Item *myAllocator(size_t n) { 20 | ++allocator_called; 21 | static unsigned next_id = 0; 22 | Item *result = new Item[n]; 23 | for (unsigned i = 0; i < n; ++i) { 24 | result[i].id = next_id++; 25 | } 26 | return result; 27 | } 28 | 29 | void myDeallocator(Item *array) { 30 | ++deallocator_called; 31 | delete[] array; 32 | } 33 | 34 | void test_alloc() { 35 | { 36 | auto data = Array2D(N, M, myAllocator, myDeallocator); 37 | for (unsigned n = 0; n < N; ++n) { 38 | for (unsigned m = 0; m < M; ++m) { 39 | ASSERT_EQUALS(n * M + m, data[n][m].id); 40 | } 41 | } 42 | } 43 | ASSERT_EQUALS(N, allocator_called); 44 | ASSERT_EQUALS(N, deallocator_called); 45 | } 46 | 47 | int main(void) { 48 | test_alloc(); 49 | puts("OK"); 50 | } 51 | -------------------------------------------------------------------------------- /tests/test-best-match.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Array.h" 8 | #include "BlockAtom.h" 9 | #include "BlockDictionary.h" 10 | #include "BlockHelper.h" 11 | #include "Extractor.h" 12 | #include "GaussianFamily.h" 13 | #include "SpectrogramCalculatorFFTW.h" 14 | #include "SpectrogramRequest.h" 15 | #include "Testing.h" 16 | 17 | const int N = 300; 18 | 19 | void test(SpectrogramCalculatorFFTW &calculator, double frequency, int position, double scale, double phase0, double phase1, double amplitude0, double amplitude1) { 20 | std::shared_ptr family = std::make_shared(); 21 | index_t envelope_offset; 22 | index_t envelope_length = family->size_for_values(position, scale, nullptr); 23 | PinnedArray1D envelope(envelope_length); 24 | family->generate_values(position, scale, &envelope_offset, envelope.get(), true); 25 | 26 | PinnedArray2D data(2, N); 27 | const double value_max = envelope[envelope.length() / 2]; 28 | 29 | double energies[2]; 30 | energies[0] = energies[1] = 0.0; 31 | for (int i = 0; i < N; ++i) { 32 | // a single dual-channel Gabor atom with specified phases and amplitudes 33 | const int io = i - envelope_offset; 34 | const double v = (io >= 0 && io < envelope_length) ? 
envelope[io] / value_max : 0.0; 35 | const double phi = 2 * M_PI * frequency * (i - position); 36 | data[0][i] = amplitude0 * v * std::cos(phi + phase0); 37 | data[1][i] = amplitude1 * v * std::cos(phi + phase1); 38 | energies[0] += data[0][i] * data[0][i]; 39 | energies[1] += data[1][i] * data[1][i]; 40 | } 41 | auto converter = std::make_shared(); 42 | auto correctors = BlockHelper::generate_correctors(envelope, 256, 129, calculator); 43 | BlockDictionary dictionary(Block(data, family, scale, envelope, correctors, converter, NAN, 256, 1, envelope_length/2, extractorVariablePhase)); 44 | 45 | std::list requests; 46 | dictionary.fetch_requests({0, N}, requests); 47 | for (const auto &request : requests) { 48 | calculator.compute(request); 49 | request.interface->notify(); 50 | } 51 | 52 | BasicAtomPointer atom = dictionary.get_best_match(); 53 | ASSERT(atom); 54 | 55 | const BlockAtom &block_atom = static_cast(*atom); 56 | 57 | ASSERT_NEAR_ZERO(block_atom.params.scale - scale); 58 | ASSERT_NEAR_ZERO(block_atom.params.frequency - frequency); 59 | ASSERT_NEAR_ZERO(block_atom.params.position - position); 60 | ASSERT_NEAR_ZERO(block_atom.energy - (energies[0] + energies[1])); 61 | } 62 | 63 | int main() { 64 | SpectrogramCalculatorFFTW fftw(2, {256}); 65 | 66 | test(fftw, 67 | 9.0 / 256, 144, 10, 68 | 0.71529, 0.92517, 69 | 1.5, 2.5); 70 | test(fftw, 71 | 119.0 / 256, 148, 10, 72 | 0.71529, 0.92517, 73 | 1.5, 2.5); 74 | test(fftw, 75 | 0.25, 100, 1.0, 76 | 0.0, 0.0, 77 | 1.5, 2.5); 78 | 79 | puts("OK"); 80 | } 81 | -------------------------------------------------------------------------------- /tests/test-block.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Block.h" 8 | #include "Corrector.h" 9 | #include "Extractor.h" 10 | #include "GaussianFamily.h" 11 | #include "SpectrogramRequest.h" 12 | #include "Testing.h" 13 | #include "Types.h" 14 | 15 | ExtractedMaximum extractorForTest(int, int, complex *const *, const Corrector *, double *, ExtraData *) { 16 | return ExtractedMaximum(); 17 | } 18 | 19 | int main() { 20 | PinnedArray2D data(10, 300); 21 | PinnedArray1D envelope(129); 22 | PinnedArray1D correctors(100); 23 | 24 | auto converter = std::make_shared(); 25 | auto family = std::make_shared(); 26 | Block block(data, family, NAN, envelope, correctors, converter, NAN, 256, 30, 64.0, extractorForTest); 27 | 28 | SpectrogramRequest request = block.buildRequest(0, 300); 29 | ASSERT_EQUALS(data.get(), request.data); 30 | ASSERT_EQUALS(300, request.channel_length); 31 | ASSERT_EQUALS(10, request.channel_count); 32 | ASSERT_EQUALS(-64, request.input_offset); // half of envelope 33 | ASSERT_EQUALS(30, request.input_shift); 34 | ASSERT_EQUALS(10, request.how_many); // centered at: 0, 30, 60, 90, 120, 150, 180, 210, 240 and 270 35 | ASSERT_EQUALS(envelope.get(), request.envelope); 36 | ASSERT_EQUALS(129, request.envelope_length); 37 | ASSERT_EQUALS(256, request.window_length); 38 | ASSERT_EQUALS(100, request.output_bins); 39 | ASSERT_EQUALS(correctors.get(), request.correctors); 40 | ASSERT_EQUALS(extractorForTest, request.extractor); 41 | 42 | request = block.buildRequest(64, 300); 43 | ASSERT_EQUALS(-64, request.input_offset); // half of envelope 44 | ASSERT_EQUALS(10, request.how_many); // still the same, the last envelope is [270..334] 45 | 46 | request = block.buildRequest(65, 300); 47 | ASSERT_EQUALS(-34, request.input_offset); // just got outside the first envelope [-64..64] 48 | ASSERT_EQUALS(9, request.how_many); // one less 49 | 50 | request = block.buildRequest(65, 207); 51 | ASSERT_EQUALS(-34, request.input_offset); 52 | 
ASSERT_EQUALS(9, request.how_many); // still the same 53 | 54 | request = block.buildRequest(65, 206); 55 | ASSERT_EQUALS(-34, request.input_offset); 56 | ASSERT_EQUALS(8, request.how_many); // got outside the last envelope [206..334] 57 | 58 | puts("OK"); 59 | } 60 | -------------------------------------------------------------------------------- /tests/test-book-writer.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include "BookWriter.h" 9 | #include "Testing.h" 10 | 11 | const char *const tmp_name = ".test-book-writer.tmp"; 12 | 13 | const char *const expected = "{\n\ 14 | \"version\": \"" APP_VERSION "\",\n\ 15 | \"channel_count\": 2,\n\ 16 | \"sampling_frequency_Hz\": 16,\n\ 17 | \"segments\": [{\n\ 18 | \"sample_count\": 4,\n\ 19 | \"segment_length_s\": 0.25,\n\ 20 | \"segment_offset_s\": 2,\n\ 21 | \"channels\": [{\n\ 22 | \"atoms\": [{\n\ 23 | \"amplitude\": 20,\n\ 24 | \"energy\": 10,\n\ 25 | \"envelope\": \"gauss\",\n\ 26 | \"f_Hz\": 8,\n\ 27 | \"phase\": 3,\n\ 28 | \"scale_s\": 3,\n\ 29 | \"t0_s\": 0.25,\n\ 30 | \"t0_abs_s\": 2.25\n\ 31 | }],\n\ 32 | \"samples\": [\n\ 33 | 0,\n\ 34 | 0,\n\ 35 | 0,\n\ 36 | 0\n\ 37 | ]\n\ 38 | },{\n\ 39 | \"atoms\": [{\n\ 40 | \"amplitude\": 40,\n\ 41 | \"energy\": 20,\n\ 42 | \"envelope\": \"gauss\",\n\ 43 | \"f_Hz\": 8,\n\ 44 | \"phase\": 3,\n\ 45 | \"scale_s\": 3,\n\ 46 | \"t0_s\": 0.25,\n\ 47 | \"t0_abs_s\": 2.25\n\ 48 | }],\n\ 49 | \"samples\": [\n\ 50 | 0,\n\ 51 | 0,\n\ 52 | 0,\n\ 53 | 0\n\ 54 | ]\n\ 55 | }]\n\ 56 | }],\n\ 57 | \"segment_count\": 1\n\ 58 | }"; 59 | 60 | ExportedAtom prepare_gabor(double energy) { 61 | ExportedAtom result(energy); 62 | result.envelope = "gauss"; 63 | 
result.amplitude = energy / 8; 64 | result.frequency = 0.5; 65 | result.phase = 3.0; 66 | result.position = 4.0; 67 | result.scale = 48.0; 68 | return result; 69 | } 70 | 71 | std::vector> prepare_atoms() { 72 | std::vector> atoms(2); 73 | atoms[0].push_back(prepare_gabor(160.0)); 74 | atoms[1].push_back(prepare_gabor(320.0)); 75 | return atoms; 76 | } 77 | 78 | void test_json_writer() { 79 | Array2D data(2, 4); 80 | data.fill(0); 81 | auto atoms = prepare_atoms(); 82 | JsonBookWriter writer(16, 4, tmp_name); 83 | writer.write(data, EpochIndex{0, 8}, atoms); 84 | writer.finalize(); 85 | 86 | size_t length = strlen(expected); 87 | char buffer[length]; 88 | FILE *f = fopen(tmp_name, "rb"); 89 | ASSERT_EQUALS(length, fread(buffer, 1, length, f)); 90 | ASSERT(!memcmp(buffer, expected, length)); 91 | fclose(f); 92 | } 93 | 94 | int main() { 95 | test_json_writer(); 96 | 97 | remove(tmp_name); 98 | puts("OK"); 99 | } 100 | -------------------------------------------------------------------------------- /tests/test-dictionary.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include "Array.h" 9 | #include "BlockAtom.h" 10 | #include "BlockDictionary.h" 11 | #include "BlockHelper.h" 12 | #include "GaussianFamily.h" 13 | #include "SpectrumCalculator.h" 14 | #include "SpectrogramRequest.h" 15 | #include "Testing.h" 16 | 17 | class SpectrumCalculatorForTest : public SpectrumCalculator { 18 | Array1D result; 19 | 20 | public: 21 | SpectrumCalculatorForTest() : result(9) {} 22 | 23 | const complex *computeSpectrum(Array1D input, int window_length) final { 24 | ASSERT_EQUALS(7, input.length()); 25 | ASSERT_EQUALS(16, window_length); 26 | return result.get(); 27 | } 28 | }; 29 | 30 | int main() { 31 | SpectrumCalculatorForTest calculator; 32 | PinnedArray2D data(1, 11); 33 | auto family = std::make_shared(); 34 | 35 | auto converter = std::make_shared(); 36 | BlockDictionary dictionary(BlockHelper::create_block(data, family, 2.0, converter, NAN, 16, 4, 1, 0.0, extractorVariablePhase, calculator)); 37 | 38 | std::list requests; 39 | dictionary.fetch_requests({0, 11}, requests); 40 | 41 | ASSERT_EQUALS(1, requests.size()); 42 | 43 | const SpectrogramRequest &request = requests.front(); 44 | ASSERT_EQUALS(11, request.channel_length); 45 | ASSERT_EQUALS(1, request.channel_count); 46 | ASSERT_EQUALS(-3, request.input_offset); 47 | ASSERT_EQUALS(1, request.input_shift); 48 | ASSERT_EQUALS(11, request.how_many); 49 | ASSERT_EQUALS(7, request.envelope_length); 50 | ASSERT_EQUALS(16, request.window_length); 51 | ASSERT_EQUALS(4, request.output_bins); 52 | 53 | for (int i = 0; i < 11; ++i) { 54 | request.maxima[i].energy = std::min(i + 1, 11 - i); 55 | request.maxima[i].bin_index = i % 4; 56 | } 57 | request.interface->notify(); 58 | 59 | BasicAtomPointer atom_pointer = dictionary.get_best_match(); 60 | BlockAtom atom = *dynamic_cast(atom_pointer.get()); 61 | ASSERT_EQUALS(5, atom.params.position); 62 | 63 | puts("OK"); 64 | } 65 | 
-------------------------------------------------------------------------------- /tests/test-envelope.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "Array.h" 8 | #include "GaussianFamily.h" 9 | #include "TriangularFamily.h" 10 | #include "Testing.h" 11 | 12 | #define N 1000 13 | 14 | void test_integrals(Family& family, const double delta_t = 0.3) { 15 | const double min_arg = family.min_arg(); 16 | const double max_arg = family.max_arg(); 17 | const double step = (max_arg - min_arg) / N; 18 | 19 | // any values would do fine 20 | const double delta_log_scale = 0.1; 21 | const double delta_f = 0.2; 22 | 23 | double sum2 = 0.0, sum2t = 0.0, sum2t2 = 0.0, sumA = 0.0, sumB = 0.0, sumC = 0.0; 24 | for (int i = 0; i < N; ++i) { 25 | double t = min_arg + (i + 0.5) * step; 26 | double value = family.value(t); 27 | sum2 += value * value; 28 | sum2t += value * value * t; 29 | sum2t2 += value * value * t * t; 30 | sumA += family.value(t * exp(delta_log_scale / 2)) * family.value(t * exp(-delta_log_scale / 2)); 31 | sumB += value * value * std::cos(2 * M_PI * delta_f * t); 32 | sumC += family.value(t + delta_t / 2) * family.value(t - delta_t / 2); 33 | } 34 | ASSERT_APPROX(1.0, sum2 * step, 1.0e-5); 35 | ASSERT_NEAR_ZERO(sum2t * step); 36 | ASSERT_NEAR_ZERO(sum2t2 * step - 0.25 / M_PI); 37 | ASSERT_APPROX(sumA * step, family.scale_integral(delta_log_scale), 1.0e-5); 38 | ASSERT_APPROX(sumB * step, family.freq_integral(delta_f), 1.0e-5); 39 | ASSERT_APPROX(sumC * step, family.time_integral(delta_t), 1.0e-5); 40 | 41 | ASSERT_NEAR_ZERO(family.inv_scale_integral(family.scale_integral(delta_log_scale)) - delta_log_scale); 42 | 
ASSERT_NEAR_ZERO(family.inv_freq_integral(family.freq_integral(delta_f)) - delta_f); 43 | ASSERT_NEAR_ZERO(family.inv_time_integral(family.time_integral(delta_t)) - delta_t); 44 | } 45 | 46 | void test_generate(Family& family, int expected_sample_count, int expected_first_sample_offset) { 47 | index_t first_sample_offset; 48 | index_t sample_count = family.size_for_values(400.0, 100.0, &first_sample_offset); 49 | ASSERT_EQUALS(expected_sample_count, sample_count); 50 | ASSERT_EQUALS(expected_first_sample_offset, first_sample_offset); 51 | 52 | Array1D values(sample_count); 53 | family.generate_values(400.0, 100.0, &first_sample_offset, values.get(), true); 54 | ASSERT_EQUALS(expected_first_sample_offset, first_sample_offset); 55 | 56 | double sum2 = 0.0, sum2t = 0.0, sum2t2 = 0.0; 57 | for (int i = 0; i < sample_count; ++i) { 58 | double t = static_cast(first_sample_offset + i) - 400.0; 59 | double value = values[i]; 60 | sum2 += value * value; 61 | sum2t += value * value * t; 62 | sum2t2 += value * value * t * t; 63 | } 64 | ASSERT_NEAR_ZERO(sum2 - 1.0); 65 | ASSERT_NEAR_ZERO(sum2t); 66 | ASSERT_APPROX(2500 / M_PI, sum2t2, 0.1); 67 | } 68 | 69 | int main() { 70 | GaussianFamily gauss; 71 | TriangularFamily triangle; 72 | 73 | test_integrals(gauss); 74 | test_integrals(triangle); 75 | test_integrals(triangle, 1.0); 76 | test_generate(gauss, 301, 250); 77 | test_generate(triangle, 179, 311); 78 | 79 | puts("OK"); 80 | } 81 | -------------------------------------------------------------------------------- /tests/test-fftw.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "Array.h" 8 | #include "Testing.h" 9 | #include "SpectrogramCalculatorDummy.h" 10 | #include "SpectrogramCalculatorFFTW.h" 11 | 12 | double rand_float() { 13 | return (double) rand() / (double) RAND_MAX - 0.5; 14 | } 15 | 16 | void benchmark(SpectrogramCalculator &calculator, const SpectrogramRequest &request, const ExtractedMaximum *reference_maxima) { 17 | calculator.compute(request); 18 | 19 | if (reference_maxima) { 20 | double max_diff = 0; 21 | for (int i = 0; i < request.how_many; ++i) { 22 | ASSERT_EQUALS(reference_maxima[i].bin_index, request.maxima[i].bin_index); 23 | max_diff = std::max(max_diff, std::abs(request.maxima[i].energy - reference_maxima[i].energy)); 24 | } 25 | ASSERT(max_diff > 0); 26 | ASSERT_NEAR_ZERO(max_diff); 27 | } 28 | } 29 | 30 | void run_test(Extractor extractor) { 31 | const int envelope_length = 800; 32 | Array1D envelope(envelope_length); 33 | for (int i = 0; i < envelope_length; ++i) { 34 | envelope[i] = rand_float(); 35 | } 36 | 37 | const int output_bins = 300; 38 | Array1D correctors(output_bins); 39 | for (int i = 0; i < output_bins; ++i) { 40 | correctors[i] = Corrector(0.5 * complex(rand_float(), rand_float())); 41 | } 42 | 43 | const index_t channel_length = 12000; 44 | const int channel_count = 10; 45 | Array2D data(channel_count, channel_length); 46 | for (int c = 0; c < channel_count; ++c) { 47 | real *channel = data[c]; 48 | for (index_t i = 0; i < channel_length; ++i) { 49 | channel[i] = rand_float(); 50 | } 51 | } 52 | 53 | SpectrogramRequest request; 54 | request.data = data.get(); 55 | request.channel_length = 12000; 56 | request.channel_count = 10; 57 | request.input_offset = -1000; 58 | request.input_shift = 200; 59 | request.how_many = 64; 60 | request.envelope = envelope.get(); 61 | request.envelope_length = envelope_length; 62 | request.window_length = 1024; 63 | request.output_bins = output_bins; 64 | 
request.correctors = correctors.get(); 65 | request.extractor = extractor; 66 | 67 | Array1D maxima_dummy(request.how_many); 68 | Array1D maxima_fftw(request.how_many); 69 | 70 | SpectrogramCalculatorFFTW fftw(channel_count, {1024}); 71 | SpectrogramCalculatorDummy dummy; 72 | 73 | request.maxima = maxima_dummy.get(); 74 | benchmark(dummy, request, nullptr); 75 | 76 | request.maxima = maxima_fftw.get(); 77 | benchmark(fftw, request, maxima_dummy.get()); 78 | } 79 | 80 | int main() { 81 | run_test(extractorVariablePhase); 82 | run_test(extractorConstantPhase); 83 | puts("OK"); 84 | } 85 | -------------------------------------------------------------------------------- /tests/test-index-range.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "IndexRange.h" 8 | #include "Testing.h" 9 | 10 | void assertNoOverlap(index_t start, index_t end, index_t other_length) { 11 | auto overlap = IndexRange(start, end).overlap(other_length); 12 | ASSERT(!overlap); 13 | } 14 | 15 | void assertOverlap(index_t start, index_t end, index_t other_length, index_t first_index, index_t end_index) { 16 | auto overlap = IndexRange(start, end).overlap(other_length); 17 | ASSERT_EQUALS(first_index, overlap.first_index); 18 | ASSERT_EQUALS(end_index, overlap.end_index); 19 | } 20 | 21 | int main() { 22 | assertOverlap(10, 30, 50, 10, 30); 23 | assertOverlap(10, 50, 50, 10, 50); 24 | assertOverlap(10, 60, 50, 10, 50); 25 | assertOverlap(-10, 10, 30, 0, 10); 26 | assertOverlap(-10, 40, 30, 0, 30); 27 | assertNoOverlap(-10, 0, 30); 28 | assertNoOverlap(30, 40, 30); 29 | puts("OK"); 30 | } 31 | -------------------------------------------------------------------------------- 
/tests/test-measure.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include 9 | 10 | #include "BlockAtomProductCalculator.h" 11 | #include "GaussianFamily.h" 12 | 13 | class OptimalityCalculator 14 | { 15 | BlockAtomProductCalculator calculator; 16 | const double frequency_step; 17 | const double position_step; 18 | const double log_scale_step; 19 | 20 | public: 21 | OptimalityCalculator(const std::shared_ptr& family, double energy_error) 22 | : calculator(family), 23 | frequency_step(0.5 * family->inv_freq_integral(1 - energy_error)), 24 | position_step(0.5 * family->inv_time_integral(1 - energy_error)), 25 | log_scale_step(0.5 * family->inv_scale_integral(1 - energy_error)) 26 | { } 27 | 28 | double compute_optimality_factor(double scale, double frequency) { 29 | BlockAtomParams params{frequency, 0.0, scale}; 30 | const double a = std::exp(log_scale_step); 31 | double result = 1.0; 32 | 33 | for (int i0 = -1; i0 <= 1; i0 += 2) { 34 | for (int i1 = -1; i1 <= 1; i1 += 2) { 35 | BlockAtomParams other = params; 36 | other.scale *= a; 37 | other.frequency += i0 * frequency_step / other.scale; 38 | if (other.frequency >= 0 && other.frequency <= 0.5) { 39 | other.position += i1 * position_step * other.scale; 40 | double y = calculator.calculate_squared_product(other, params); 41 | result = std::min(result, y); 42 | } 43 | } 44 | } 45 | for (int i0 = -1; i0 <= 1; i0 += 2) { 46 | for (int i1 = -1; i1 <= 1; i1 += 2) { 47 | BlockAtomParams other = params; 48 | other.scale /= a; 49 | other.frequency += i0 * frequency_step / other.scale; 50 | if (other.frequency >= 0 && other.frequency <= 0.5) { 51 | other.position += i1 * position_step * 
other.scale; 52 | double y = calculator.calculate_squared_product(other, params); 53 | result = std::min(result, y); 54 | } 55 | } 56 | } 57 | return result; 58 | } 59 | }; 60 | 61 | void run_test() { 62 | // TODO 63 | } 64 | 65 | int main(int argc, char **argv) { 66 | if (argc == 1) { 67 | // actual test is run only if no parameters are passed 68 | run_test(); 69 | puts("OK"); 70 | return 0; 71 | } 72 | 73 | // otherwise, perform a numerical study for the optimality factor for the given envelope 74 | if (argc != 5) { 75 | errx(EXIT_FAILURE, "USAGE: %s family energy_error scale frequency", argv[0]); 76 | } 77 | srand(time(nullptr)); 78 | auto it = Family::ALL.find(argv[1]); 79 | if (it == Family::ALL.end()) { 80 | errx(EXIT_FAILURE, "ERROR: invalid family %s", argv[1]); 81 | } 82 | 83 | const std::shared_ptr family = it->second; 84 | const double energy_error = atof(argv[2]); 85 | const double scale = atof(argv[3]); 86 | const double frequency = atof(argv[4]); 87 | 88 | if (frequency < 0 || frequency > 0.5) { 89 | errx(EXIT_FAILURE, "frequency %lf is invalid", frequency); 90 | } 91 | 92 | OptimalityCalculator calculator(family, energy_error); 93 | const double result = calculator.compute_optimality_factor(scale, frequency); 94 | printf("%lf %lf %lf %lf\n", energy_error, scale, frequency, result); 95 | } 96 | -------------------------------------------------------------------------------- /tests/test-move-semantics.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include "Testing.h" 9 | 10 | class Item { 11 | public: 12 | static int copy_called; 13 | static int copy_ctor_called; 14 | static int ctor_called; 15 | static int move_ctor_called; 16 | 17 | static void reset() { 18 | copy_called = copy_ctor_called = ctor_called = move_ctor_called = 0; 19 | } 20 | 21 | Item() { 22 | ++ctor_called; 23 | } 24 | 25 | Item(const Item &) { 26 | ++copy_ctor_called; 27 | } 28 | 29 | Item(Item &&) { 30 | ++move_ctor_called; 31 | } 32 | 33 | void operator=(const Item &) { 34 | ++copy_called; 35 | } 36 | }; 37 | 38 | int Item::copy_called = 0; 39 | int Item::copy_ctor_called = 0; 40 | int Item::ctor_called = 0; 41 | int Item::move_ctor_called = 0; 42 | 43 | class Internal { 44 | Item item; 45 | 46 | public: 47 | Internal(Item item) : item(std::move(item)) {} 48 | }; 49 | 50 | class External { 51 | Internal internal; 52 | 53 | public: 54 | External(Item item) : internal(std::move(item)) {} 55 | }; 56 | 57 | int main(void) { 58 | Item item; 59 | Item::reset(); 60 | 61 | External external(std::move(item)); 62 | ASSERT_EQUALS(0, Item::copy_called); 63 | ASSERT_EQUALS(0, Item::copy_ctor_called); 64 | ASSERT_EQUALS(0, Item::ctor_called); 65 | ASSERT_EQUALS(3, Item::move_ctor_called); 66 | 67 | puts("OK"); 68 | } 69 | -------------------------------------------------------------------------------- /tests/test-product-calculator.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. 
* 5 | **********************************************************/ 6 | #include 7 | #include "BlockAtomProductCalculator.h" 8 | #include "GaussianFamily.h" 9 | #include "Testing.h" 10 | /* Checks BlockAtomProductCalculator. The make_shared template argument was lost from this listing; it is presumably GaussianFamily (TODO confirm). An atom's squared product with itself must be 1. Offsetting frequency, position or scale by the amount derived from the family's inv_freq/inv_time/inv_scale integrals at 1 - energy_error must give a squared product of (1 - energy_error)^2. The tolerance is 1e-6, or 2e-6 for the scale offset. The time offset is multiplied by x0.scale, the frequency offset is divided by it, and the scale offset is applied multiplicatively as exp(dl). */ 11 | void test_product() { 12 | auto family = std::make_shared(); 13 | BlockAtomProductCalculator calculator(family); 14 | 15 | BlockAtomParams x0{0.2, 2.0, 20.0}; 16 | ASSERT_APPROX(1.0, calculator.calculate_squared_product(x0, x0), 1.0e-6); 17 | 18 | const double energy_error = 0.01; 19 | const double dl = family->inv_scale_integral(1 - energy_error); 20 | const double dt = family->inv_time_integral(1 - energy_error) * x0.scale; 21 | const double df = family->inv_freq_integral(1 - energy_error) / x0.scale; 22 | 23 | BlockAtomParams xf(x0), xt(x0), xs(x0); 24 | xf.frequency += df; 25 | xt.position += dt; 26 | xs.scale *= std::exp(dl); 27 | 28 | const double expected = std::pow(1 - energy_error, 2); 29 | ASSERT_APPROX(expected, calculator.calculate_squared_product(x0, xf), 1.0e-6); 30 | ASSERT_APPROX(expected, calculator.calculate_squared_product(x0, xt), 1.0e-6); 31 | ASSERT_APPROX(expected, calculator.calculate_squared_product(x0, xs), 2.0e-6); 32 | } 33 | 34 | int main() { 35 | test_product(); 36 | 37 | puts("OK"); 38 | } 39 | -------------------------------------------------------------------------------- /tests/test-rounding.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details.
* 5 | **********************************************************/ 6 | #include 7 | #include "Testing.h" 8 | #include "Types.h" 9 | /* Evaluates fun(input) and asserts that the result equals expected. T is the result type of the rounding helper under test. NOTE(review): the template parameter lists were lost from this listing (angle brackets stripped). */ 10 | template 11 | void assert_equals(double input, T expected, T (*fun)(double)) { 12 | T actual = fun(input); 13 | ASSERT_EQUALS(expected, actual); 14 | } 15 | /* Table of expected results for Types::round, Types::ceil and Types::floor over the same inputs. round(-0.5) is expected to be 0, whereas std::round(-0.5) would give -1. main() runs this table three times, presumably for three different result types (template arguments lost from this listing; TODO confirm). */ 16 | template 17 | void test_all_cases() { 18 | assert_equals(-4.0, -4, Types::round); 19 | assert_equals(-3.3, -3, Types::round); 20 | assert_equals(-2.6, -3, Types::round); 21 | assert_equals(-1.9, -2, Types::round); 22 | assert_equals(-1.2, -1, Types::round); 23 | assert_equals(-0.5, 0, Types::round); 24 | assert_equals(0.2, 0, Types::round); 25 | assert_equals(0.9, 1, Types::round); 26 | assert_equals(1.6, 2, Types::round); 27 | assert_equals(2.3, 2, Types::round); 28 | assert_equals(3.0, 3, Types::round); 29 | 30 | assert_equals(-4.0, -4, Types::ceil); 31 | assert_equals(-3.3, -3, Types::ceil); 32 | assert_equals(-2.6, -2, Types::ceil); 33 | assert_equals(-1.9, -1, Types::ceil); 34 | assert_equals(-1.2, -1, Types::ceil); 35 | assert_equals(-0.5, 0, Types::ceil); 36 | assert_equals(0.2, 1, Types::ceil); 37 | assert_equals(0.9, 1, Types::ceil); 38 | assert_equals(1.6, 2, Types::ceil); 39 | assert_equals(2.3, 3, Types::ceil); 40 | assert_equals(3.0, 3, Types::ceil); 41 | 42 | assert_equals(-4.0, -4, Types::floor); 43 | assert_equals(-3.3, -4, Types::floor); 44 | assert_equals(-2.6, -3, Types::floor); 45 | assert_equals(-1.9, -2, Types::floor); 46 | assert_equals(-1.2, -2, Types::floor); 47 | assert_equals(-0.5, -1, Types::floor); 48 | assert_equals(0.2, 0, Types::floor); 49 | assert_equals(0.9, 0, Types::floor); 50 | assert_equals(1.6, 1, Types::floor); 51 | assert_equals(2.3, 2, Types::floor); 52 | assert_equals(3.0, 3, Types::floor); 53 | } 54 | 55 | int main() { 56 | test_all_cases(); 57 | test_all_cases(); 58 | test_all_cases(); 59 | 60 | puts("OK"); 61 | } 62 | -------------------------------------------------------------------------------- /tests/test-signal-reader.cpp:
-------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "SignalReader.h" 8 | #include "Testing.h" 9 | 10 | const char *const tmp_name = ".test-signal-reader.tmp"; 11 | /* Writes the floats 0, 1, ..., total_sample_count - 1 to tmp_name, replacing any previous content. It uses a std::vector instead of a variable-length array: VLAs are a compiler extension, not standard C++ (MSVC rejects them), and the prepare_signal(0) calls below would otherwise declare a zero-length array. The fopen() result is checked so that failing to create the file is reported, rather than passing a null FILE* to fwrite(). */ 12 | void prepare_signal(int total_sample_count) { 13 | FILE *file = fopen(tmp_name, "w+b"); ASSERT(file != nullptr); 14 | std::vector<float> data(total_sample_count); 15 | for (int i = 0; i < total_sample_count; ++i) { 16 | data[i] = static_cast<float>(i); 17 | } 18 | if (!data.empty()) { fwrite(data.data(), sizeof(float), data.size(), file); } 19 | fclose(file); 20 | } 21 | /* An empty file yields no epochs in all-epochs mode, and read() reports that nothing is available. */ 22 | void test_empty_file_all_epochs() { 23 | prepare_signal(0); 24 | 25 | std::vector selected_channels = {1, 2}; 26 | SignalReaderForAllEpochs reader(tmp_name, 3, std::move(selected_channels), 8); 27 | 28 | ASSERT_EQUALS(0, reader.get_epoch_count()); 29 | ASSERT_EQUALS(2, reader.get_epoch_channel_count()); 30 | ASSERT_EQUALS(8, reader.get_epoch_sample_count()); 31 | 32 | Array2D buffer(2, 8); 33 | auto result = reader.read(buffer); 34 | ASSERT(!result); 35 | } 36 | /* In whole-signal mode an empty file still counts as one epoch of zero samples, but read() reports nothing to read. */ 37 | void test_empty_file_whole_signal() { 38 | prepare_signal(0); 39 | 40 | std::vector selected_channels = {1, 2}; 41 | SignalReaderForWholeSignal reader(tmp_name, 3, std::move(selected_channels)); 42 | 43 | ASSERT_EQUALS(1, reader.get_epoch_count()); 44 | ASSERT_EQUALS(2, reader.get_epoch_channel_count()); 45 | ASSERT_EQUALS(0, reader.get_epoch_sample_count()); 46 | 47 | Array2D buffer(2, 0); 48 | auto result = reader.read(buffer); 49 | ASSERT(!result); 50 | } 51 | /* 36 interleaved floats over 3 channels give 12 samples per channel. The selected channels {1, 2} map to interleaved offsets 0 and 1 (values 3*i and 3*i + 1), so channel numbers appear to be 1-based. With 8-sample epochs, the second epoch holds 4 real samples followed by zero padding. */ 52 | void test_all_epochs() { 53 | prepare_signal(36); 54 | 55 | std::vector selected_channels = {1, 2}; 56 | SignalReaderForAllEpochs reader(tmp_name, 3, std::move(selected_channels), 8); 57 | 58 | ASSERT_EQUALS(2, reader.get_epoch_count()); 59 |
ASSERT_EQUALS(2, reader.get_epoch_channel_count()); 60 | ASSERT_EQUALS(8, reader.get_epoch_sample_count()); 61 | 62 | Array2D buffer(2, 8); 63 | auto result = reader.read(buffer); 64 | ASSERT(result); 65 | ASSERT_EQUALS(0, result->channel_offset); 66 | ASSERT_EQUALS(0, result->epoch_offset); 67 | ASSERT_EQUALS(0, result->epoch_counter); 68 | for (int i = 0; i < 8; ++i) { 69 | ASSERT_EQUALS(3 * i, buffer[0][i]); 70 | ASSERT_EQUALS(3 * i + 1, buffer[1][i]); 71 | } 72 | result = reader.read(buffer); 73 | ASSERT(result); 74 | ASSERT_EQUALS(0, result->channel_offset); 75 | ASSERT_EQUALS(1, result->epoch_offset); 76 | ASSERT_EQUALS(1, result->epoch_counter); 77 | for (int i = 0; i < 4; ++i) { 78 | ASSERT_EQUALS(3 * (i + 8), buffer[0][i]); 79 | ASSERT_EQUALS(3 * (i + 8) + 1, buffer[1][i]); 80 | } 81 | for (int i = 4; i < 8; ++i) { 82 | ASSERT_EQUALS(0, buffer[0][i]); 83 | ASSERT_EQUALS(0, buffer[1][i]); 84 | } 85 | result = reader.read(buffer); 86 | ASSERT(!result); 87 | } 88 | /* The whole 12-sample-per-channel signal is returned as a single epoch by the first read(); a second read() returns nothing. */ 89 | void test_whole_signal() { 90 | prepare_signal(36); 91 | 92 | std::vector selected_channels = {1, 2}; 93 | SignalReaderForWholeSignal reader(tmp_name, 3, std::move(selected_channels)); 94 | 95 | ASSERT_EQUALS(1, reader.get_epoch_count()); 96 | ASSERT_EQUALS(2, reader.get_epoch_channel_count()); 97 | ASSERT_EQUALS(12, reader.get_epoch_sample_count()); 98 | 99 | Array2D buffer(2, 12); 100 | auto result = reader.read(buffer); 101 | ASSERT(result); 102 | ASSERT_EQUALS(0, result->channel_offset); 103 | ASSERT_EQUALS(0, result->epoch_offset); 104 | ASSERT_EQUALS(0, result->epoch_counter); 105 | for (int i = 0; i < 12; ++i) { 106 | ASSERT_EQUALS(3 * i, buffer[0][i]); 107 | ASSERT_EQUALS(3 * i + 1, buffer[1][i]); 108 | } 109 | result = reader.read(buffer); 110 | ASSERT(!result); 111 | } 112 | /* The temporary file is removed only once every test above has passed. */ 113 | int main() { 114 | test_empty_file_all_epochs(); 115 | test_empty_file_whole_signal(); 116 | test_all_epochs(); 117 | test_whole_signal(); 118 | 119 | remove(tmp_name); 120 | puts("OK"); 121 | } 122 |
-------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include "BlockDictionary.h" 8 | #include "BlockDictionaryStructure.h" 9 | #include "GaussianFamily.h" 10 | #include "SpectrogramCalculatorFFTW.h" 11 | #include "OptimizationMode.h" 12 | #include "Testing.h" 13 | #include "Worker.h" 14 | 15 | static const double ENERGY_ERROR = 0.001; 16 | static const double SCALE = 10.0; 17 | /* Writes one family envelope, at a possibly fractional position and with scale SCALE, into a PinnedArray2D(1, 1024) buffer. A Worker with OPTIMIZATION_DISABLED then extracts one atom. Its exported position and frequency must match exactly, and its phase must match within 1e-12. NOTE(review): part of this function was lost from this listing, because text between the '<' in the for-loop and the next '>' was stripped. The lost part includes the loop that presumably applies the carrier omega = 2*pi*frequency with the given phase, and the structure/calculator setup. The make_shared template argument is missing too (presumably GaussianFamily; TODO confirm). ASSERT_EQUALS appears to compare the doubles exactly. */ 18 | void test_subsample(double position, double frequency, double phase = 0.0) { 19 | auto family = std::make_shared(); 20 | PinnedArray2D data(1, 1024); 21 | data.fill(0.0); 22 | 23 | index_t offset, sample_count = family->size_for_values(position, SCALE, &offset); 24 | family->generate_values(position, SCALE, NULL, data[0]+offset, true); 25 | 26 | const double omega = 2 * M_PI * frequency; 27 | for (int i=0; i(1, std::set{ 512 }); 33 | auto dictionary = std::make_unique(structure, data, extractorSingleChannel, *calculator, true); 34 | 35 | Worker worker(data, 1, OPTIMIZATION_DISABLED); 36 | worker.add_calculator(std::move(calculator)); 37 | worker.add_dictionary(std::move(dictionary)); 38 | auto result = worker.get_next_atom(); 39 | 40 | std::list atoms; 41 | result->export_atom(&atoms); 42 | ASSERT_EQUALS(1, atoms.size()); 43 | 44 | const ExportedAtom& atom = atoms.front(); 45 | ASSERT_EQUALS(position, atom.position); 46 | ASSERT_EQUALS(frequency, atom.frequency); 47 | ASSERT_SAME_PHASE(phase, atom.phase, 1.0e-12); 48 | } 49 | 50 | int main() { 51 | test_subsample(511.25, 0/512.0, 0.0); 52 | test_subsample(512.00, 50/512.0, 1.5); 53 | test_subsample(512.75, 100/512.0,
1.0); 54 | test_subsample(513.50, 150/512.0, 0.5); 55 | 56 | puts("OK"); 57 | } 58 | -------------------------------------------------------------------------------- /tests/test-task-queue.cpp: -------------------------------------------------------------------------------- 1 | /********************************************************** 2 | * Piotr T. Różański (c) 2015-2023 * 3 | * Enhanced Matching Pursuit Implementation (empi) * 4 | * See README.md and LICENCE for details. * 5 | **********************************************************/ 6 | #include 7 | #include 8 | #include 9 | #include "TaskQueue.h" 10 | #include "Testing.h" 11 | 12 | const int THREAD_COUNT = 4; 13 | const int ITEM_COUNT = 8; 14 | 15 | int processed[ITEM_COUNT]; 16 | /* Consumer thread body: takes numbers from the queue until get() returns false, marks each one as processed and acknowledges it with notify(). */ 17 | void consumer_main(TaskQueue* task_queue) { 18 | int number; 19 | while (task_queue->get(number)) { 20 | processed[number] = true; 21 | task_queue->notify(); 22 | } 23 | } 24 | /* NOTE(review): this listing is corrupted from the for-loop below onwards, because text between the '<' and the next '>' was stripped. The rest of test_put_twice, part of the following test, and the beginning of tests/test-worker.cpp (including its file header) are missing. */ 25 | void test_put_twice() { 26 | TaskQueue task_queue; 27 | memset(processed, 0, sizeof processed); 28 | 29 | std::vector threads; 30 | for (int i=0; i task_queue; 54 | memset(processed, 0, sizeof processed); 55 | 56 | std::vector threads; 57 | for (int i=0; i numbers; 62 | for (int i=0; i 7 | #include 8 | #include "SpectrogramCalculator.h" 9 | #include "Testing.h" 10 | #include "Worker.h" 11 | /* Fake calculator: stores 10 * (how_many * 13 % 100) as the energy of the request's first maximum. The dictionary below issues how_many = 1..R with R = 100, and 13 is coprime to 100, so the largest stored energy is 10 * (R - 1). */ 12 | class SpectrogramCalculatorForTest : public SpectrogramCalculator { 13 | void compute(const SpectrogramRequest &request) final { 14 | request.maxima[0].energy = 10.0 * (request.how_many * 13 % 100); 15 | } 16 | }; 17 | /* NOTE(review): BlockImplForTest is not referenced anywhere else in this test; it is possibly dead code. */ 18 | class BlockImplForTest : public BlockInterface { 19 | ExtractedMaximum *maximum; 20 | 21 | public: 22 | explicit BlockImplForTest(ExtractedMaximum *maximum) : maximum(maximum) {} 23 | 24 | void notify() final { 25 | maximum->bin_index = 1; 26 | } 27 | }; 28 | 29 | class ExtendedAtomForTest : public ExtendedAtom { 30 | public: 31 | explicit ExtendedAtomForTest(Array2D data, double energy) : ExtendedAtom(std::move(data), energy) {} 32 |
33 | void export_atom(std::list *) final { 34 | // nothing here 35 | } 36 | 37 | IndexRange subtract_from_signal() const final { 38 | return IndexRange(); 39 | } 40 | }; 41 | 42 | class BasicAtomForTest : public BasicAtom { 43 | public: 44 | explicit BasicAtomForTest(Array2D data, double energy) : BasicAtom(std::move(data), energy) {} 45 | 46 | [[nodiscard]] ExtendedAtomPointer extend(bool allow_optimization) final { 47 | return std::make_shared(data, energy); 48 | } 49 | 50 | [[nodiscard]] double get_energy_upper_bound() const final { 51 | return energy; 52 | } 53 | }; 54 | 55 | const int R = 100; // number of requests 56 | /* Fake dictionary that also acts as the BlockInterface of its own requests. fetch_requests() asserts the signal range [0, 55), resets R maxima and emits one request per maximum with how_many = r + 1. notify() counts callbacks in the atomic notify_called_count. get_best_match() wraps the largest energy written to the maxima. */ 57 | class DictionaryImplForTest : public Dictionary, public BlockInterface { 58 | Array2D data_; 59 | ExtractedMaximum maxima[R]; 60 | 61 | public: 62 | static std::atomic notify_called_count; 63 | 64 | DictionaryImplForTest(Array2D data) : data_(std::move(data)) {} 65 | 66 | size_t get_atom_count() final { 67 | return 1; // does not matter, really 68 | } 69 | 70 | BasicAtomPointer get_best_match() final { 71 | double max_energy = 0; 72 | for (const auto &maximum : maxima) { 73 | max_energy = std::max(max_energy, maximum.energy); 74 | } 75 | return std::make_shared(data_, max_energy); 76 | } 77 | 78 | std::list get_candidate_matches(double energy_to_exceed) final { 79 | return std::list(); 80 | } 81 | 82 | void fetch_requests(IndexRange signal_range, std::list &requests) final { 83 | ASSERT_EQUALS(0, signal_range.first_index); 84 | ASSERT_EQUALS(55, signal_range.end_index); 85 | for (int r = 0; r < R; ++r) { 86 | SpectrogramRequest request; 87 | request.how_many = r + 1; 88 | request.maxima = &maxima[r]; 89 | request.interface = this; 90 | requests.push_back(request); 91 | 92 | maxima[r].energy = 0.0; 93 | maxima[r].bin_index = 0; 94 | } 95 | } 96 | 97 | void fetch_proto_requests(std::list &) final { 98 | // nothing here 99 | } 100 | 101 | void notify() final { 102 | notify_called_count++; 103 | } 104 | }; 105 | 106 | std::atomic
DictionaryImplForTest::notify_called_count; 107 | /* Runs one get_next_atom() iteration with one dictionary and W = 7 calculators. It checks that the best atom has energy 10 * (R - 1) and that notify() was called R times. */ 108 | int main() { 109 | const int W = 7; // number of workers 110 | Array2D data(2, 55); 111 | 112 | Worker worker(data, 1); 113 | worker.add_dictionary(std::make_unique(data)); 114 | for (int w = 0; w < W; ++w) { 115 | worker.add_calculator(std::make_unique()); 116 | } 117 | ExtendedAtomPointer atom = worker.get_next_atom(); 118 | ASSERT_EQUALS(10.0 * (R - 1), atom->energy); 119 | ASSERT_EQUALS(R, DictionaryImplForTest::notify_called_count); 120 | puts("OK"); 121 | } 122 | --------------------------------------------------------------------------------