├── .gitignore
├── Makefile
├── README.md
├── benchmark
    ├── article.py
    ├── benchmark.h
    ├── fftw.cpp
    ├── graphs.py
    ├── kissfft.cpp
    ├── main.cpp
    └── others
    │   └── contents.txt
├── common
    ├── console-colours.h
    ├── csv.h
    ├── simple-args.h
    ├── sqlite-cpp.h
    └── test
    │   ├── example-test.cpp
    │   ├── main.cpp
    │   └── tests.h
├── comparison.svg
├── signalsmith-fft.h
└── tests
    ├── 00-fft.cpp
    ├── 01-real.cpp
    └── tests-common.h


/.gitignore:
--------------------------------------------------------------------------------
out
benchmark/others
fft-v4/
.DS_Store

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Targets that never produce a file of the same name.
# (Pattern targets like benchmark-% must NOT be listed: phony targets skip
# implicit-rule search, so benchmark-main etc. would stop matching benchmark-%.)
.PHONY: test clean graphs benchmarks dev
ifndef VERBOSE
.SILENT:
endif

SHARED_PATH := "common"
# Shared headers/sources, so that editing them triggers a rebuild
SHARED_DEPS := $(shell find ${SHARED_PATH} -iname "*.h" -o -iname "*.cpp")

# NOTE(review): `clean` is the first real target and therefore the default goal,
# so a bare `make` deletes ./out instead of building. Kept to preserve behaviour.
clean:
	rm -rf out

############## Testing ##############

TEST_CPP_FILES := $(shell find tests -iname "*.cpp" | sort)

test: out/test
	./out/test

out/test: *.h $(shell find tests -iname "*.h") $(shell find tests -iname "*.cpp") ${SHARED_DEPS}
	echo "building tests: ${TEST_CPP_FILES}"
	mkdir -p out
	g++ -std=c++11 -Wall -Wextra -Wfatal-errors -O0 \
		-Wpedantic -pedantic-errors \
		"${SHARED_PATH}/test/main.cpp" -I "${SHARED_PATH}" \
		-I tests/ ${TEST_CPP_FILES} \
		-o out/test

############## Benchmarking ##############

graphs:
	python benchmark/graphs.py

benchmarks: test benchmark-main benchmark-kissfft benchmark-fftw graphs

BENCHMARK_TEST_TIME := 0.05

# Generic versions

benchmark-%: out/benchmark-%
	mkdir -p out/results
	cd out && ./benchmark-$* --test-time=${BENCHMARK_TEST_TIME}

out/benchmark-%: *.h $(shell find benchmark -iname "*.h") $(shell find benchmark -iname "*.cpp") ${SHARED_DEPS}
	mkdir -p out
	g++ -std=c++11 -msse2 -mavx -Wfatal-errors -O3 \
		"${SHARED_PATH}/test/main.cpp" -I "${SHARED_PATH}" \
		-I benchmark/ benchmark/$*.cpp \
		-o out/benchmark-$*

# Custom versions which need more config

out/benchmark-fftw: $(shell find benchmark -iname "*.h") $(shell find benchmark -iname "*.cpp") ${SHARED_DEPS}
	mkdir -p out
	g++ -std=c++11 -msse2 -mavx -Wfatal-errors -g -O3 \
		"${SHARED_PATH}/test/main.cpp" -I "${SHARED_PATH}" \
		-I benchmark/ benchmark/fftw.cpp \
		-lfftw3 \
		-o out/benchmark-fftw

############## Development ##############

# After an initial "make benchmarks" (which produces results for every library),
# "make dev" re-runs only the tests and the Signalsmith benchmark, then redraws the graphs.
dev: test benchmark-main graphs

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Signalsmith FFT

A small and (reasonably) performant C++11 FFT implementation.

It is still under development, so things might be a little untidy at the moment, but the API is stable and the tests should build and pass.

```cpp
#include "signalsmith-fft.h"

signalsmith::FFT<double> fft(size);
```

![speed comparison graph](comparison.svg)

## Setting the size

```cpp
fft.setSize(1024);
```

It is faster for certain sizes (powers of 2, 3 and 5), and it's _strongly_ recommended that you use these. For convenience there are two methods for finding fast sizes above and below a limit:

```cpp
actualSize = fft.setSizeMinimum(1025); // sets (and returns) a fast size >= 1025

actualSize = fft.setSizeMaximum(1025); // sets (and returns) a fast size <= 1025
```


## Forward/reverse FFT

```cpp
fft.fft(complexTime, complexSpectrum);

fft.ifft(complexSpectrum, complexTime);
```

These methods are templated, and accept any iterator or container holding `std::complex` values. This could be a pointer (e.g. `std::complex<double> *`), a `std::vector`, or whatever.
39 | 40 | ## Real FFT 41 | 42 | ```cpp 43 | 44 | signalsmith::RealFFT fft(size); 45 | 46 | fft.fft(realTime, complexSpectrum); 47 | fft.ifft(complexSpectrum, realTime); 48 | ``` 49 | 50 | The size _must_ be even. The complex spectrum is half the size - e.g. 256 real inputs produce 128 complex outputs. 51 | 52 | Since the 0 and Nyquist frequencies are both real, these are packed into the real/imaginary parts of index 0. 53 | -------------------------------------------------------------------------------- /benchmark/article.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib 4 | from matplotlib import pyplot 5 | 6 | def small(*args, **kwargs): 7 | figure, axes = pyplot.subplots(*args, **kwargs) 8 | figure.set_size_inches(4.5, 3) 9 | return figure, axes 10 | 11 | def medium(*args, **kwargs): 12 | figure, axes = pyplot.subplots(*args, **kwargs) 13 | figure.set_size_inches(6.5, 4) 14 | return figure, axes 15 | 16 | def tall(*args, **kwargs): 17 | figure, axes = pyplot.subplots(*args, **kwargs) 18 | figure.set_size_inches(4.5, 5.5) 19 | return figure, axes 20 | 21 | def short(*args, **kwargs): 22 | figure, axes = pyplot.subplots(*args, **kwargs) 23 | figure.set_size_inches(7, 3) 24 | return figure, axes 25 | 26 | def wide(*args, **kwargs): 27 | figure, axes = pyplot.subplots(*args, **kwargs) 28 | figure.set_size_inches(11, 4) 29 | return figure, axes 30 | 31 | def full(*args, **kwargs): 32 | figure, axes = pyplot.subplots(*args, **kwargs) 33 | figure.set_size_inches(16, 10) 34 | return figure, axes 35 | 36 | def save(prefix, figure, legend_loc=0): 37 | dirname = os.path.dirname(prefix) 38 | if len(dirname) and not os.path.exists(dirname): 39 | os.makedirs(dirname) 40 | 41 | if len(figure.get_axes()) > 1: 42 | figure.set_tight_layout(True) 43 | for axes in figure.get_axes(): 44 | if len(axes.get_lines()) > 0 and legend_loc != None: 45 | axes.legend(loc=legend_loc) 46 | 47 | print(prefix) 48 | 
figure.savefig(prefix + '.svg', bbox_inches='tight') 49 | -------------------------------------------------------------------------------- /benchmark/benchmark.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "../tests/tests-common.h" 7 | 8 | template 9 | class OutOfPlaceRunner { 10 | std::vector> inVector, outVector; 11 | public: 12 | size_t size; 13 | 14 | OutOfPlaceRunner(size_t size) : size(size) {} 15 | virtual ~OutOfPlaceRunner() {} 16 | 17 | virtual void getPointers(std::complex **inPointer, std::complex **outPointer) { 18 | inVector.resize(size); 19 | outVector.resize(size); 20 | *inPointer = inVector.data(); 21 | *outPointer = outVector.data(); 22 | } 23 | }; 24 | 25 | template 26 | struct Benchmark { 27 | static std::vector getSizes(int customMaxSize=16777216) { // 16777216 28 | std::vector possibleSizes(0); 29 | int size = 1; 30 | int maxSize = std::min(constMaxSize, customMaxSize); 31 | while (size <= maxSize) { 32 | possibleSizes.push_back(size); 33 | size *= 2; 34 | } 35 | 36 | std::vector result; 37 | for (int size : possibleSizes) { 38 | if (size <= maxSize) { 39 | result.push_back(size); 40 | } 41 | 42 | std::vector mults = {3, 5, 7, 9, 11, 15, 23}; 43 | mults = {3, 5, 9, 15, 25}; 44 | mults = {3, 9}; 45 | for (int mult : mults) { 46 | if (size*mult < maxSize) { 47 | result.push_back(size*mult); 48 | } 49 | } 50 | } 51 | return result; 52 | } 53 | 54 | template 55 | static void runBenchmark(Test &test, std::string name, std::string resultPrefix) { 56 | auto sizes = getSizes(); 57 | std::sort(sizes.begin(), sizes.end()); 58 | 59 | std::cout << name << ":\n"; 60 | std::vector rates = BenchmarkRate::map(sizes, [](int size, int repeats, Timer &timer) { 61 | auto runner = Runner(size); 62 | 63 | // The runner handles input/output allocation 64 | std::complex *input, *output; 65 | runner.getPointers(&input, &output); 66 | 67 | for (int i = 0; i < size; 
i++) { 68 | input[i] = randomComplex(); 69 | } 70 | 71 | timer.start(); 72 | for (int repeat = 0; repeat < repeats; ++repeat) { 73 | runner.forward(input, output); 74 | } 75 | timer.stop(); 76 | }, true); 77 | std::vector scaled = rates; 78 | for (size_t i = 0; i < sizes.size(); i++) { 79 | double size = sizes[i]; 80 | double scaling = std::max(1.0, size*log(size))*1e-6; 81 | scaled[i] = rates[i]*scaling; 82 | } 83 | 84 | std::ofstream outputJs, outputCsv; 85 | outputJs.open(resultPrefix + ".js"); 86 | outputCsv.open(resultPrefix + ".csv"); 87 | 88 | outputJs << "addResults(\"" << name << "\", ["; 89 | outputCsv << "size,ops/sec," << name << "\n"; 90 | outputCsv.precision(15); 91 | 92 | for (size_t i = 0; i < sizes.size(); i++) { 93 | if (i > 0) outputJs << ","; 94 | outputJs << "\n\t{size: " << sizes[i] << ", rate: " << scaled[i] << "}"; 95 | outputCsv << sizes[i] << "," << rates[i] << "," << scaled[i] << "\n"; 96 | } 97 | outputJs << "\n]);"; 98 | 99 | return test.pass(); 100 | } 101 | 102 | TEST_METHOD("double out-of-place", double_out) { 103 | std::string resultPrefix = std::string("results/") + Implementation::resultTag(); 104 | runBenchmark(test, Implementation::name(), resultPrefix); 105 | } 106 | }; 107 | -------------------------------------------------------------------------------- /benchmark/fftw.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "benchmark.h" 4 | 5 | #include 6 | 7 | template 8 | struct FFTW_Estimate { 9 | static std::string name() { 10 | if (searchLevel == FFTW_ESTIMATE) { 11 | return "FFTW (estimate)"; 12 | } else if (searchLevel == FFTW_MEASURE) { 13 | return "FFTW (measure)"; 14 | } 15 | return "FFTW"; 16 | } 17 | static std::string resultTag() { 18 | if (searchLevel == FFTW_ESTIMATE) { 19 | return "fftw-estimate"; 20 | } else if (searchLevel == FFTW_MEASURE) { 21 | return "fftw-measure"; 22 | } 23 | return "fftw"; 24 | } 25 | static std::string version() { 26 | 
return "3.3.8"; 27 | } 28 | 29 | template 30 | struct OutOfPlace : public OutOfPlaceRunner { 31 | fftw_plan fftForward; 32 | fftw_complex *input, *output; 33 | 34 | OutOfPlace(size_t size) : OutOfPlaceRunner(size) { 35 | input = (fftw_complex*)fftw_malloc(size*sizeof(fftw_complex)); 36 | output = (fftw_complex*)fftw_malloc(size*sizeof(fftw_complex)); 37 | fftForward = fftw_plan_dft_1d(size, input, output, FFTW_FORWARD, searchLevel); 38 | } 39 | ~OutOfPlace() { 40 | fftw_destroy_plan(fftForward); 41 | fftw_free(input); 42 | fftw_free(output); 43 | } 44 | 45 | virtual void getPointers(std::complex **inPointer, std::complex **outPointer) { 46 | *inPointer = (std::complex*)input; 47 | *outPointer = (std::complex*)output; 48 | } 49 | 50 | void forward(std::complex *input, std::complex *output) { 51 | fftw_execute(fftForward); 52 | } 53 | }; 54 | 55 | using DoubleOutOfPlace = OutOfPlace; 56 | }; 57 | static Benchmark> benchEstimate; 58 | static Benchmark, 65536> benchMeasure; 59 | -------------------------------------------------------------------------------- /benchmark/graphs.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import csv 3 | import numpy 4 | import matplotlib 5 | from matplotlib import pyplot 6 | 7 | import article 8 | 9 | def plot_if_exists(axis, csvFile): 10 | axis.set(xlabel="size", ylabel="normalised rate") 11 | axis.set_xscale('log', basex=4) 12 | axis.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: ("%i"%x) if (x <= 65536) else ("2^%i"%int(numpy.log2(x))))) 13 | if not os.path.exists(csvFile): 14 | return False 15 | with open(csvFile) as inputCsv: 16 | reader = csv.reader(inputCsv) 17 | header = next(reader) 18 | (_, _, lineName) = header 19 | 20 | sizes = [] 21 | rates = [] 22 | for row in reader: 23 | (size, rate, scaledRate) = [float(x) for x in row] 24 | sizes.append(size) 25 | rates.append(scaledRate) 26 | axis.plot(sizes, rates, label=lineName) 27 | 28 | 
figure, axis = article.wide() 29 | plot_if_exists(axis, "out/results/signalsmith.csv") 30 | plot_if_exists(axis, "out/results/fftw-estimate.csv") 31 | plot_if_exists(axis, "out/results/fftw-measure.csv") 32 | plot_if_exists(axis, "out/results/kissfft.csv") 33 | article.save("out/comparison", figure) 34 | 35 | figure, axis = article.wide() 36 | plot_if_exists(axis, "out/results/signalsmith.csv") 37 | plot_if_exists(axis, "out/results/previous-v4.csv") 38 | plot_if_exists(axis, "out/results/previous-permute.csv") 39 | plot_if_exists(axis, "out/results/previous.csv") 40 | article.save("out/previous", figure) 41 | 42 | figure, axis = article.wide() 43 | plot_if_exists(axis, "out/results/signalsmith.csv") 44 | plot_if_exists(axis, "out/results/dev-history-radix234.csv") 45 | plot_if_exists(axis, "out/results/dev-history-radix23.csv") 46 | plot_if_exists(axis, "out/results/dev-history-factorise.csv") 47 | plot_if_exists(axis, "out/results/dev-history-direct.csv") 48 | article.save("out/history", figure) 49 | -------------------------------------------------------------------------------- /benchmark/kissfft.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "benchmark.h" 4 | 5 | #include "others/kissfft/kissfft.hh" 6 | 7 | struct Kiss { 8 | static std::string name() { 9 | return "KissFFT"; 10 | } 11 | static std::string resultTag() { 12 | return "kissfft"; 13 | } 14 | static std::string version() { 15 | return "(?)"; 16 | } 17 | 18 | template 19 | struct OutOfPlace : public OutOfPlaceRunner { 20 | kissfft fftForward; 21 | 22 | OutOfPlace(size_t size) : OutOfPlaceRunner(size), fftForward(size, false) {} 23 | 24 | void forward(std::complex *input, std::complex *output) { 25 | fftForward.transform(input, output); 26 | } 27 | }; 28 | 29 | using DoubleOutOfPlace = OutOfPlace; 30 | }; 31 | Benchmark benchKiss; 32 | -------------------------------------------------------------------------------- 
/benchmark/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "benchmark.h" 4 | 5 | struct Signalsmith { 6 | static std::string name() { 7 | return "Signalsmith"; 8 | } 9 | static std::string resultTag() { 10 | return "signalsmith"; 11 | } 12 | static std::string version() { 13 | return "(development)"; 14 | } 15 | 16 | template 17 | struct OutOfPlace : public OutOfPlaceRunner { 18 | signalsmith::FFT fft; 19 | OutOfPlace(size_t size) : OutOfPlaceRunner(size), fft(size) {} 20 | 21 | void forward(std::complex *input, std::complex *output) { 22 | fft.fft(input, output); 23 | } 24 | }; 25 | 26 | using DoubleOutOfPlace = OutOfPlace; 27 | }; 28 | Benchmark benchSignalsmith; 29 | -------------------------------------------------------------------------------- /benchmark/others/contents.txt: -------------------------------------------------------------------------------- 1 | kissfft/ 2 | cloned from https://github.com/mborgerding/kissfft.git 3 | 4 | fftw-3.3.8/ 5 | downloaded from http://www.fftw.org/fftw-3.3.8.tar.gz 6 | (via http://www.fftw.org/download.html) 7 | 8 | compiled as: 9 | ./configure --enable-sse2 --enable-avx --enable-avx2 10 | (I tried --enable-avx-128-fma --enable-generic-simd128 --enable-generic-simd256 and it came out slower?) 
11 | (There are more flags on the FFTW site, but I couldn't get them to build) 12 | -------------------------------------------------------------------------------- /common/console-colours.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef _CONSOLE_COLOURS_H 3 | #define _CONSOLE_COLOURS_H 4 | 5 | #include 6 | 7 | namespace Console { 8 | std::string Reset = "\x1b[0m"; 9 | std::string Bright = "\x1b[1m"; 10 | std::string Dim = "\x1b[2m"; 11 | std::string Underscore = "\x1b[4m"; 12 | std::string Blink = "\x1b[5m"; 13 | std::string Reverse = "\x1b[7m"; 14 | std::string Hidden = "\x1b[8m"; 15 | 16 | namespace Foreground { 17 | std::string Black = "\x1b[30m"; 18 | std::string Red = "\x1b[31m"; 19 | std::string Green = "\x1b[32m"; 20 | std::string Yellow = "\x1b[33m"; 21 | std::string Blue = "\x1b[34m"; 22 | std::string Magenta = "\x1b[35m"; 23 | std::string Cyan = "\x1b[36m"; 24 | std::string White = "\x1b[37m"; 25 | } 26 | 27 | namespace Background { 28 | std::string Black = "\x1b[40m"; 29 | std::string Red = "\x1b[41m"; 30 | std::string Green = "\x1b[42m"; 31 | std::string Yellow = "\x1b[43m"; 32 | std::string Blue = "\x1b[44m"; 33 | std::string Magenta = "\x1b[45m"; 34 | std::string Cyan = "\x1b[46m"; 35 | std::string White = "\x1b[47m"; 36 | } 37 | 38 | using namespace Foreground; 39 | } 40 | 41 | #endif -------------------------------------------------------------------------------- /common/csv.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace csv { 9 | class RowReader { 10 | std::ifstream ifstream; 11 | char columnSeparator = ','; 12 | char lineSeparator = '\n'; 13 | char quote = '"'; 14 | 15 | public: 16 | bool open(std::string filename) { 17 | ifstream = std::ifstream(filename.data(), std::ifstream::in); 18 | if (ifstream.good()) return true; 19 | return false; 20 | } 21 | 22 | 
bool done() { 23 | return !ifstream.good() || ifstream.eof(); 24 | } 25 | 26 | std::vector next() { 27 | std::vector result; 28 | if (done()) return result; 29 | 30 | constexpr int N = 1024; 31 | char line[N]; 32 | ifstream.getline(line, N, lineSeparator); 33 | // Split line into sections 34 | unsigned int length = strlen(line); 35 | unsigned int startIndex = 0, index = 0; 36 | while (index <= length) { 37 | if (index == length || line[index] == columnSeparator) { 38 | unsigned int endIndex = index; 39 | // Strip quotes 40 | if (line[startIndex] == quote) startIndex++; 41 | if (endIndex > 0 && line[endIndex - 1] == quote) endIndex--; 42 | 43 | std::string data(line + startIndex, endIndex - startIndex); 44 | result.push_back(data); 45 | 46 | index++; 47 | startIndex = index; 48 | } else { 49 | if (line[index] == quote) { // Skip ahead to ending quote 50 | index++; 51 | while (index < length && line[index] != quote) { 52 | index++; 53 | } 54 | } 55 | index++; 56 | } 57 | } 58 | return result; 59 | } 60 | }; 61 | } 62 | -------------------------------------------------------------------------------- /common/simple-args.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "console-colours.h" 8 | 9 | /** Expected use: 10 | 11 | SimpleArgs args(argc, argv); 12 | 13 | std::string foo = args.arg("foo"); 14 | std::string bar = args.arg("bar", "a string for Bar", "default"); 15 | 16 | // Exits if "foo" not supplied. 
"bar" has a default value, so it's fine to omit 17 | if (args.help(std::cerr)) return 1; 18 | 19 | **/ 20 | class SimpleArgs { 21 | int argc; 22 | char** argv; 23 | 24 | template 25 | T valueFromString(const char *arg); 26 | 27 | std::string parsedCommand; 28 | struct Keywords { 29 | std::string keyword; 30 | std::string description; 31 | bool isHelp; 32 | }; 33 | std::vector keywordOptions; 34 | std::vector argDetails; 35 | std::map flagOptions; 36 | void clearKeywords() { 37 | keywordOptions.resize(0); 38 | flagOptions.clear(); 39 | } 40 | 41 | bool helpMode = false; 42 | bool hasError = false; 43 | std::string errorMessage; 44 | void setError(std::string message) { 45 | if (!hasError) { 46 | hasError = true; 47 | errorMessage = message; 48 | } 49 | } 50 | 51 | std::map flagMap; 52 | void consumeFlags() { 53 | while (index < argc && std::strlen(argv[index]) > 0 && argv[index][0] == '-') { 54 | const char* arg = argv[index++]; 55 | size_t length = strlen(arg); 56 | 57 | size_t keyStart = 1, keyEnd = keyStart + 1; 58 | size_t valueStart = keyEnd; 59 | // If it's "--long-arg" format 60 | if (length > 1 && arg[1] == '-') { 61 | keyStart++; 62 | while (keyEnd < length && arg[keyEnd] != '=') { 63 | keyEnd++; 64 | } 65 | valueStart = keyEnd; 66 | if (keyEnd < length) valueStart++; 67 | } 68 | 69 | std::string key = std::string(arg + keyStart, keyEnd - keyStart); 70 | std::string value = std::string(arg + valueStart); 71 | 72 | flagMap[key] = value; 73 | } 74 | } 75 | 76 | int index = 1; 77 | public: 78 | SimpleArgs(int argc, char* argv[]) : argc(argc), argv(argv) { 79 | parsedCommand = argv[0]; 80 | } 81 | 82 | int help(std::ostream& out=std::cerr) { 83 | if (keywordOptions.size() > 0) { 84 | parsedCommand += std::string(" "); 85 | } 86 | out << "Usage:\n\t" << parsedCommand << "\n\n"; 87 | if (keywordOptions.size() > 0) { 88 | out << "Commands:\n"; 89 | for (unsigned int i = 0; i < keywordOptions.size(); i++) { 90 | out << "\t" << keywordOptions[i].keyword; 91 | if 
(keywordOptions[i].isHelp) out << " ..."; 92 | if (keywordOptions[i].description.size()) out << " - " << keywordOptions[i].description; 93 | out << "\n"; 94 | } 95 | out << "\n"; 96 | } 97 | if (argDetails.size() > 0) { 98 | out << "Arguments:\n"; 99 | for (auto iter = flagOptions.begin(); iter != flagOptions.end(); iter++) { 100 | Keywords &pair = iter->second; 101 | out << "\t" << (pair.keyword.length() > 1 ? "--" : "-") << pair.keyword; 102 | if (pair.description.size()) out << " - " << pair.description; 103 | out << "\n"; 104 | } 105 | out << "\n"; 106 | } 107 | if (flagOptions.size() > 0) { 108 | out << "Options:\n"; 109 | for (auto iter = flagOptions.begin(); iter != flagOptions.end(); iter++) { 110 | Keywords &pair = iter->second; 111 | out << "\t" << (pair.keyword.length() > 1 ? "--" : "-") << pair.keyword; 112 | if (pair.description.size()) out << " - " << pair.description; 113 | out << "\n"; 114 | } 115 | out << "\n"; 116 | } 117 | return hasError ? -1 : 0; 118 | } 119 | 120 | bool error(std::ostream& out=std::cerr) { 121 | if (!hasError && !helpMode) return false; 122 | help(out); 123 | if (!helpMode) { 124 | out << Console::Red << errorMessage << Console::Reset << "\n"; 125 | } 126 | return true; 127 | } 128 | bool error(std::string forcedError, std::ostream& out=std::cerr) { 129 | if (!hasError) { 130 | hasError = true; 131 | errorMessage = forcedError; 132 | } 133 | return error(out); 134 | } 135 | 136 | template 137 | T arg(std::string name, std::string longName, T defaultValue) { 138 | clearKeywords(); 139 | consumeFlags(); 140 | parsedCommand += std::string(" [?") + name + "]"; 141 | argDetails.push_back(Keywords{name, longName, false}); 142 | 143 | if (index >= argc) return defaultValue; 144 | return valueFromString(argv[index++]); 145 | } 146 | 147 | template 148 | T arg(std::string name, std::string longName="") { 149 | clearKeywords(); 150 | consumeFlags(); 151 | parsedCommand += std::string(" [") + name + "]"; 152 | 
argDetails.push_back(Keywords{name, longName, false}); 153 | 154 | if (index >= argc) { 155 | if (longName.length() > 0) { 156 | setError("Missing " + longName + " <" + name + ">"); 157 | } else { 158 | setError("Missing argument <" + name + ">"); 159 | } 160 | return T(); 161 | } 162 | 163 | return valueFromString(argv[index++]); 164 | } 165 | 166 | bool command(std::string keyword, std::string description="", bool isHelp=false) { 167 | consumeFlags(); 168 | if (index < argc && !keyword.compare(argv[index])) { 169 | clearKeywords(); 170 | index++; 171 | if (!isHelp) parsedCommand += std::string(" ") + keyword; 172 | return true; 173 | } 174 | keywordOptions.push_back(Keywords{keyword, description, isHelp}); 175 | return false; 176 | } 177 | bool helpCommand(std::string keyword) { 178 | helpMode = command(keyword, "", true); 179 | if (helpMode) { 180 | keywordOptions.insert(keywordOptions.begin(), Keywords{keyword, "", true}); 181 | } 182 | return helpMode; 183 | } 184 | 185 | template 186 | T flag(std::string key, std::string description, T defaultValue) { 187 | consumeFlags(); 188 | if (!hasFlag(key, description)) return defaultValue; 189 | 190 | auto iterator = flagMap.find(key); 191 | return valueFromString(iterator->second.c_str()); 192 | } 193 | template 194 | T flag(std::string key, T defaultValue) { 195 | consumeFlags(); 196 | if (!hasFlag(key, "")) return defaultValue; 197 | 198 | auto iterator = flagMap.find(key); 199 | return valueFromString(iterator->second.c_str()); 200 | } 201 | template 202 | T flag(std::string key) { 203 | return flag(key, T()); 204 | } 205 | bool hasFlag(std::string key, std::string description="") { 206 | consumeFlags(); 207 | auto iterator = flagMap.find(key); 208 | if (description.length() > 0 || iterator == flagMap.end()) { 209 | flagOptions[key] = Keywords{key, description, false}; 210 | } 211 | 212 | iterator = flagMap.find(key); 213 | return iterator != flagMap.end(); 214 | } 215 | bool helpFlag(std::string key, std::string 
description="") { 216 | consumeFlags(); 217 | flagOptions[key] = Keywords{key, description, true}; 218 | auto iterator = flagMap.find(key); 219 | helpMode = (iterator != flagMap.end()); 220 | return helpMode; 221 | } 222 | }; 223 | 224 | template<> 225 | std::string SimpleArgs::valueFromString(const char *arg) { 226 | return arg; 227 | } 228 | template<> 229 | const char * SimpleArgs::valueFromString(const char *arg) { 230 | return arg; 231 | } 232 | template<> 233 | int SimpleArgs::valueFromString(const char *arg) { 234 | return std::stoi(arg); 235 | } 236 | template<> 237 | long SimpleArgs::valueFromString(const char *arg) { 238 | return std::stol(arg); 239 | } 240 | template<> 241 | float SimpleArgs::valueFromString(const char *arg) { 242 | return std::stof(arg); 243 | } 244 | template<> 245 | double SimpleArgs::valueFromString(const char *arg) { 246 | return std::stod(arg); 247 | } -------------------------------------------------------------------------------- /common/sqlite-cpp.h: -------------------------------------------------------------------------------- 1 | //extern C { 2 | #include 3 | //} 4 | 5 | #include 6 | #include 7 | 8 | #include "console-colours.h" 9 | 10 | /* Expected use: 11 | 12 | SQLite3Versioned db("database.sqlite"); 13 | // Backwards-compatible structure - append to this as changes are made 14 | db.bumpVersion(1, "CREATE TABLE hello_world (foo INTEGER, bar TEXT)"); 15 | 16 | // Template-based query 17 | auto find = db.statement("SELECT bar FROM hello_world WHERE foo=@foo"); 18 | // Specify values by index, or by string 19 | find->bind(1, 123); 20 | find->bind("foo", 456); // can explicitly set type 21 | 22 | // Step through result 23 | if (find->step()) { 24 | std::string bar = find->get(0); 25 | ... 26 | } 27 | 28 | // Transactions are mostly used for efficiency. If you have multiple threads, re-consider using SQLite. 
29 | { 30 | auto transaction = db.transaction(); 31 | 32 | auto query = db.statement(...); 33 | } 34 | */ 35 | 36 | class SQLite3 { 37 | std::string filename; 38 | sqlite3* m_db = nullptr; 39 | 40 | sqlite3* db() { // Lazy-loading 41 | if (!m_db) { 42 | #ifdef DEBUG_SQL 43 | std::cout << Console::Blue << "lazy-loading SQL file: " << filename << Console::Reset << "\n"; 44 | #endif 45 | int resultCode = sqlite3_open(filename.c_str(), &m_db); 46 | if (resultCode) { 47 | error = 0; 48 | errorMessage = std::string(sqlite3_errmsg(m_db)); 49 | sqlite3_close(m_db); 50 | m_db = nullptr; 51 | 52 | std::cerr << Console::Bright << Console::Red << "SQL error: " << error << " - " << errorMessage << Console::Reset << "\n"; 53 | } 54 | } 55 | return m_db; 56 | } 57 | public: 58 | int error = 0; 59 | std::string errorMessage; 60 | 61 | SQLite3(std::string filename) : filename(filename) {} 62 | ~SQLite3() { 63 | if (m_db) sqlite3_close(m_db); 64 | } 65 | 66 | class Statement { 67 | SQLite3* db; 68 | sqlite3_stmt* statement = 0; 69 | const char * sqlRemainder = 0; 70 | 71 | void getError(int code) { 72 | error = db->error = code; 73 | errorMessage = db->errorMessage = std::string(sqlite3_errmsg(db->db())); 74 | } 75 | void checkForError(int resultCode) { 76 | if (resultCode != SQLITE_OK) { 77 | getError(resultCode); 78 | //statement = nullptr; 79 | 80 | std::cerr << Console::Bright << Console::Red << "SQL error: " << error << " - " << errorMessage << Console::Reset << "\n"; 81 | } 82 | } 83 | void customError(std::string message) { 84 | error = -1; 85 | errorMessage = message; 86 | std::cerr << Console::Bright << Console::Red << "Error: " << message << Console::Reset << "\n"; 87 | } 88 | public: 89 | int error = 0; 90 | std::string errorMessage; 91 | 92 | Statement(SQLite3* db, std::string sqlString) : db(db) { 93 | #ifdef DEBUG_SQL 94 | std::cout << Console::Dim << sqlString << Console::Reset << "\n"; 95 | #endif 96 | int resultCode = sqlite3_prepare_v2(db->db(), sqlString.c_str(), 
sqlString.length(), &statement, &sqlRemainder); 97 | if (resultCode != SQLITE_OK) { 98 | getError(resultCode); 99 | statement = nullptr; 100 | 101 | std::cerr << Console::Bright << Console::Red << "SQL error: " << error << " - " << errorMessage << Console::Reset << "\n"; 102 | } 103 | } 104 | ~Statement() { 105 | if (statement) sqlite3_finalize(statement); 106 | } 107 | 108 | template 109 | void bind(const char* name, T value) { 110 | int index = sqlite3_bind_parameter_index(statement, name); 111 | this->bind(index, value); 112 | } 113 | 114 | void bind(int index, std::string value) { 115 | checkForError(sqlite3_bind_text(statement, index, value.data(), value.size(), SQLITE_TRANSIENT)); 116 | } 117 | void bind(int index, char const *value) { 118 | checkForError(sqlite3_bind_text(statement, index, value, -1, SQLITE_TRANSIENT)); 119 | } 120 | void bind(int index, double value) { 121 | checkForError(sqlite3_bind_double(statement, index, value)); 122 | } 123 | void bind(int index, int value) { 124 | checkForError(sqlite3_bind_int(statement, index, value)); 125 | } 126 | void bind(int index, long value) { 127 | checkForError(sqlite3_bind_int64(statement, index, value)); 128 | } 129 | void bind(int index) { 130 | checkForError(sqlite3_bind_null(statement, index)); 131 | } 132 | void reset() { 133 | checkForError(sqlite3_reset(statement)); 134 | } 135 | 136 | bool step() { 137 | if (!statement) return false; 138 | int queryStatus = sqlite3_step(statement); 139 | if (queryStatus != SQLITE_ROW && queryStatus != SQLITE_DONE) { 140 | getError(queryStatus); 141 | } 142 | return queryStatus == SQLITE_ROW; 143 | } 144 | 145 | int columns() { 146 | return statement ? sqlite3_column_count(statement) : 0; 147 | } 148 | 149 | int columnType(int i) { 150 | return statement ? 
sqlite3_column_type(statement, i) : 0; 151 | } 152 | 153 | std::string columnName(int i) { 154 | if (!statement) return ""; 155 | const char* cstr = sqlite3_column_name(statement, i); 156 | return std::string(reinterpret_cast(cstr)); 157 | } 158 | 159 | template 160 | T get(int i); 161 | 162 | template 163 | void bind(int i, T value); 164 | }; 165 | 166 | std::shared_ptr statement(std::string sqlString) { 167 | return std::make_shared(this, sqlString); 168 | } 169 | std::shared_ptr query(std::string sqlString) { 170 | return this->statement(sqlString); 171 | } 172 | 173 | sqlite3_int64 insertId() { 174 | return sqlite3_last_insert_rowid(db()); 175 | } 176 | 177 | class Transaction { 178 | SQLite3* db; 179 | Statement open; 180 | Statement close; 181 | public: 182 | Transaction(SQLite3* db) : db(db), open(db, "BEGIN TRANSACTION"), close(db, "COMMIT TRANSACTION") { 183 | open.step(); 184 | } 185 | ~Transaction() { 186 | close.step(); 187 | db->currentTransaction = nullptr; 188 | } 189 | }; 190 | 191 | std::shared_ptr transaction() { 192 | if (currentTransaction) return nullptr; 193 | return currentTransaction = std::make_shared(this); 194 | } 195 | private: 196 | std::shared_ptr currentTransaction; 197 | }; 198 | 199 | template<> 200 | int SQLite3::Statement::get(int i) { 201 | return statement ? sqlite3_column_int(statement, i) : 0; 202 | } 203 | template<> 204 | long SQLite3::Statement::get(int i) { 205 | return statement ? sqlite3_column_int64(statement, i) : 0; 206 | } 207 | template<> 208 | double SQLite3::Statement::get(int i) { 209 | return statement ? 
sqlite3_column_double(statement, i) : 0; 210 | } 211 | template<> 212 | std::string SQLite3::Statement::get(int i) { 213 | if (!statement) return ""; 214 | const unsigned char* cstr = sqlite3_column_text(statement, i); 215 | if (!cstr) return ""; 216 | return std::string(reinterpret_cast(cstr)); 217 | } 218 | 219 | template<> 220 | void SQLite3::Statement::bind(int i, int value) { 221 | if (statement) sqlite3_bind_int(statement, i, value); 222 | } 223 | 224 | /********/ 225 | 226 | class SQLite3Versioned : public SQLite3 { 227 | int userVersion = 0; 228 | bool versionSuccess = true; 229 | public: 230 | SQLite3Versioned(std::string filename) : SQLite3(filename) { 231 | auto query = statement("PRAGMA user_version"); 232 | query->step(); 233 | userVersion = query->get(0); 234 | } 235 | 236 | /* Use to create sequence of changes which upgrade old DBs, like: 237 | db.bumpVersion(1, "CREATE TABLE hello_world (foo INTEGER, bar TEXT)"); 238 | db.bumpVersion(2, "ALTER TABLE hello_world ADD beep FLOAT"); 239 | */ 240 | bool bumpVersion(int version, std::string sql) { 241 | if (!versionSuccess) return false; // Fail once, don't continue 242 | if (userVersion < version) { 243 | userVersion = version; 244 | auto query = statement(sql); 245 | if (query->error) return versionSuccess = false; 246 | query->step(); 247 | if (query->error) return versionSuccess = false; 248 | 249 | query = statement(std::string("PRAGMA user_version=") + std::to_string(version)); 250 | if (query->error) return versionSuccess = false; 251 | query->step(); 252 | return true; 253 | } 254 | return versionSuccess; 255 | } 256 | }; 257 | -------------------------------------------------------------------------------- /common/test/example-test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | // from the shared library 6 | #include 7 | 8 | TEST("Example test", example_test) { 9 | if (false) test.fail("it failed"); 10 | } 11 | 
12 | struct TestObject { 13 | TEST_METHOD("Example method-test", example_test) { 14 | if (false) test.fail("it failed"); 15 | } 16 | }; 17 | TestObject myObject; // Can be customised/templated 18 | 19 | TEST("Example benchmark", example_benchmarks) { 20 | std::vector configs = {10, 100, 1000}; 21 | 22 | std::cout << "size:\t"; 23 | BenchmarkRate::print(configs); 24 | 25 | std::cout << "rate:\t"; 26 | std::vector rates = BenchmarkRate::map(configs, [](int config, int repeats, Timer &timer) { 27 | timer.scaleRate(config/1e6); // scale to mega-ops/second 28 | timer.start(); 29 | 30 | for (int repeat = 0; repeat < repeats; ++repeat) { 31 | // Actual test code 32 | int sum = 0; 33 | for (int i = 0; i < config; ++i) { 34 | sum++; 35 | } 36 | } 37 | 38 | timer.stop(); 39 | }, true); 40 | 41 | return test.pass(); 42 | } -------------------------------------------------------------------------------- /common/test/main.cpp: -------------------------------------------------------------------------------- 1 | #include "../simple-args.h" 2 | #include "../console-colours.h" 3 | 4 | #include "tests.h" 5 | #include // srand, rand 6 | #include // time 7 | 8 | TestList _globalTestList; 9 | 10 | void Test::run(int depth, bool silent) { 11 | if (running) return fail("Re-entered test function"); 12 | if (!silent) { 13 | std::cerr << Console::Dim; 14 | for (int i = 0; i < depth - 1; i++) { 15 | std::cerr << " > "; 16 | } 17 | std::cerr << Console::Cyan << "Test: " 18 | << Console::Reset << Console::Cyan << testName 19 | << Console::White << " (" << codeLocation << ")" << Console::Reset << std::endl; 20 | } 21 | running = true; 22 | 23 | runFn(*this); 24 | 25 | running = false; 26 | } 27 | 28 | void TestList::add(Test& test) { 29 | if (currentlyRunning.size() > 0) { 30 | Test *latest = currentlyRunning[0]; 31 | if (!latest->success) return; 32 | // This is a sub-test, run it immediately instead of adding it 33 | currentlyRunning.push_back(&test); 34 | test.run(currentlyRunning.size(), 
currentlySilent); 35 | currentlyRunning.pop_back(); 36 | return; 37 | } 38 | tests.push_back(test); 39 | } 40 | 41 | void TestList::fail(std::string reason) { 42 | for (auto testPtr : currentlyRunning) { 43 | testPtr->fail(reason); 44 | } 45 | } 46 | 47 | int TestList::run(int repeats) { 48 | currentlySilent = false; 49 | for (int repeat = 0; repeat < repeats; repeat++) { 50 | for (unsigned int i = 0; i < tests.size(); i++) { 51 | Test& test = tests[i]; 52 | currentlyRunning = {&test}; 53 | test.run(0, currentlySilent); 54 | if (!test.success) { 55 | std::cerr << Console::Red << Console::Bright << "\nFailed: " 56 | << Console::Reset << test.reason << "\n\n"; 57 | return 1; 58 | } 59 | } 60 | currentlySilent = true; 61 | } 62 | currentlyRunning.resize(0); 63 | return 0; 64 | } 65 | 66 | double defaultBenchmarkTime = 1; 67 | int defaultBenchmarkDivisions = 5; 68 | 69 | int main(int argc, char* argv[]) { 70 | SimpleArgs args(argc, argv); 71 | args.helpFlag("help"); 72 | 73 | int repeats = args.flag("repeats", "loop the tests a certain number of times", 1); 74 | defaultBenchmarkTime = args.flag("test-time", "target per-test duration for benchmarks (excluding setup)", 1); 75 | defaultBenchmarkDivisions = args.flag("test-divisions", "target number of sub-divisions for benchmarks", 5); 76 | if (args.error()) return 1; 77 | 78 | int randomSeed = args.flag("seed", "random seed", time(NULL)); 79 | srand(randomSeed); 80 | std::cout << Console::Dim << "random seed: " << randomSeed << Console::Reset << "\n"; 81 | return _globalTestList.run(repeats); 82 | } 83 | -------------------------------------------------------------------------------- /common/test/tests.h: -------------------------------------------------------------------------------- 1 | #ifndef _TEST_FRAMEWORK_TESTS_H 2 | #define _TEST_FRAMEWORK_TESTS_H 3 | 4 | #include // std::cout 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | class Test; 11 | 12 | // A test-runner, defined statically 13 | class 
TestList { 14 | std::vector tests; 15 | std::vector currentlyRunning; 16 | bool currentlySilent = false; 17 | public: 18 | void add(Test& test); 19 | int run(int repeats=1); 20 | void fail(std::string reason); 21 | }; 22 | 23 | // A test object with an associated function, which adds itself to the above list 24 | class Test { 25 | using TestFn = std::function; 26 | 27 | TestList& testList; 28 | std::string codeLocation; 29 | std::string testName; 30 | TestFn runFn; 31 | 32 | bool running = false; 33 | public: 34 | Test(TestList& testList, std::string codeLocation, std::string testName, TestFn fn) : testList(testList), codeLocation(codeLocation), testName(testName), runFn(fn) { 35 | testList.add(*this); 36 | } 37 | void run(int depth, bool silent=false); 38 | 39 | bool success = true; 40 | std::string reason; 41 | void fail(std::string r="") { 42 | if (!success) return; 43 | success = false; 44 | reason = r; 45 | testList.fail(reason); 46 | } 47 | void pass() {} 48 | }; 49 | 50 | #define TEST_VAR_NAME test 51 | // A macro to define a new test 52 | // Use like: TEST(some_unique_name) {...} 53 | #define TEST(description, uniqueName) \ 54 | static void test_##uniqueName (Test &test); \ 55 | static Test Test_##uniqueName {_globalTestList, std::string(__FILE__ ":") + std::to_string(__LINE__), description, test_##uniqueName}; \ 56 | static void test_##uniqueName (Test &TEST_VAR_NAME) 57 | // Use if defining test inside a struct (e.g. 
for templating) 58 | #define TEST_METHOD(description, uniqueName) \ 59 | static void test_##uniqueName (Test &test) { \ 60 | testbody_##uniqueName(test); \ 61 | } \ 62 | Test Test_##uniqueName {_globalTestList, std::string(__FILE__ ":") + std::to_string(__LINE__), description, test_##uniqueName}; \ 63 | static void testbody_##uniqueName (Test &TEST_VAR_NAME) 64 | 65 | #define FAIL(reason) TEST_VAR_NAME.fail(reason) 66 | 67 | extern TestList _globalTestList; 68 | 69 | /***** Benchmarking stuff *****/ 70 | 71 | class Timer { 72 | std::chrono::high_resolution_clock::time_point startTime; 73 | double totalTime = 0; 74 | int segmentCount = 0; 75 | double scaleFactor = 1; 76 | public: 77 | void start() { 78 | startTime = std::chrono::high_resolution_clock::now(); 79 | } 80 | double stop() { 81 | std::chrono::duration duration = std::chrono::high_resolution_clock::now() - startTime; 82 | segmentCount++; 83 | return (totalTime += duration.count()); 84 | } 85 | void clear() { 86 | totalTime = 0; 87 | segmentCount = 0; 88 | scaleFactor = 1; 89 | } 90 | void scale(double scale) { 91 | scaleFactor *= scale; 92 | } 93 | void scaleRate(double scale) { 94 | scaleFactor /= scale; 95 | } 96 | double time() const { 97 | return totalTime; 98 | } 99 | double scaledTime() const { 100 | return totalTime*scaleFactor; 101 | } 102 | double segments() const { 103 | return segmentCount; 104 | } 105 | }; 106 | 107 | /* 108 | Executes a test function with an increasing number of repeats 109 | until a certain amount of time is spent on the computation. 110 | 111 | Performs a few repeated measurements with shorter periods, 112 | and collects the fastest. 
113 | 114 | Example use: 115 | 116 | BenchmarkRate trial([](int repeats, Timer &timer) { 117 | timer.start(); 118 | // test code 119 | timer.stop(); 120 | }); 121 | trial.run(1); // spend at least a second on it 122 | trial.fastest // also returned from .run() 123 | 124 | Use for a range of configurations 125 | 126 | std::vector configSize = {1, 2, 4, 8, 16}; 127 | std::vector rates = BenchmarkRate::map(configSize, [](int configSize, int repeats, Timer &timer) { 128 | 129 | }); 130 | 131 | */ 132 | 133 | extern double defaultBenchmarkTime; 134 | extern int defaultBenchmarkDivisions; 135 | 136 | struct BenchmarkRate { 137 | using TestFunction = std::function; 138 | 139 | TestFunction fn; 140 | std::vector rates; 141 | double fastest = 0; 142 | double optimistic = 0; 143 | 144 | BenchmarkRate(TestFunction fn) : fn(fn) {} 145 | 146 | void clear() { 147 | rates.resize(0); 148 | fastest = 0; 149 | } 150 | 151 | double run(double targetTotalTime=0, int divisions=0) { 152 | if (targetTotalTime == 0) targetTotalTime = defaultBenchmarkTime; 153 | if (divisions == 0) divisions = defaultBenchmarkDivisions; 154 | 155 | Timer timer; 156 | double totalTime = 0; 157 | 158 | int repeats = 1; 159 | double targetBlockTime = std::min(targetTotalTime/(divisions + 1), 0.05); // 50ms blocks or less 160 | while (repeats < 1e10) { 161 | timer.clear(); 162 | fn(repeats, timer); 163 | if (timer.segments() == 0) { 164 | std::cerr << "Benchmark function didn't call timer.start()/.stop()\n"; 165 | // The test isn't calling the timer 166 | return 0; 167 | } 168 | 169 | double time = timer.time(); 170 | totalTime += time; 171 | if (time >= targetBlockTime) { 172 | break; 173 | } else { 174 | int estimatedRepeats = repeats*targetBlockTime/(time + targetBlockTime*0.01); 175 | repeats = std::max(repeats*2, (int)estimatedRepeats); 176 | } 177 | } 178 | 179 | rates.push_back(repeats/timer.scaledTime()); 180 | 181 | while (totalTime < targetTotalTime) { 182 | timer.clear(); 183 | fn(repeats, timer); 
184 | 185 | double time = timer.time(); 186 | totalTime += time; 187 | rates.push_back(repeats/timer.scaledTime()); 188 | } 189 | 190 | double sum = 0; 191 | for (double rate : rates) { 192 | fastest = std::max(fastest, rate); 193 | sum += rate; 194 | } 195 | double mean = sum/rates.size(); 196 | 197 | double optimisticSum = 0; 198 | int optimisticCount = 0; 199 | for (double rate : rates) { 200 | if (rate >= mean) { 201 | optimisticSum += rate; 202 | optimisticCount++; 203 | } 204 | } 205 | optimistic = optimisticSum/optimisticCount; 206 | return optimistic; 207 | } 208 | 209 | template 210 | static std::vector map(std::vector &args, std::function fn, bool print=false) { 211 | std::vector results; 212 | for (Arg& arg : args) { 213 | BenchmarkRate trial([&, arg, fn](int repeats, Timer &timer) { 214 | fn(arg, repeats, timer); 215 | }); 216 | results.push_back(trial.run()); 217 | } 218 | if (print) { 219 | BenchmarkRate::print(results, true); 220 | } 221 | return results; 222 | } 223 | 224 | template 225 | static void print(std::vector array, bool newline=true) { 226 | for (unsigned int i = 0; i < array.size(); ++i) { 227 | if (i > 0) std::cout << "\t"; 228 | std::cout << array[i]; 229 | } 230 | if (newline) std::cout << std::endl; 231 | } 232 | }; 233 | 234 | #endif -------------------------------------------------------------------------------- /comparison.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 10 | 11 | 12 | 13 | 20 | 21 | 22 | 23 | 30 | 31 | 32 | 102 | 103 | 104 | 150 | 151 | 152 | 222 | 223 | 224 | 294 | 295 | 296 | 297 | 298 | 301 | 302 | 303 | 304 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 429 | 
430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 482 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 528 | 551 | 552 | 553 | 554 | 555 | 556 | 557 | 558 | 559 | 560 | 561 | 562 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | 572 | 573 | 574 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 615 | 616 | 617 | 618 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 630 | 667 | 699 | 700 | 701 | 702 | 703 | 704 | 705 | 706 | 707 | 708 | 709 | 710 | 711 | 714 | 715 | 716 | 717 | 718 | 719 | 720 | 721 | 722 | 723 | 724 | 725 | 726 | 727 | 728 | 729 | 730 | 731 | 732 | 733 | 736 | 737 | 738 | 739 | 740 | 741 | 742 | 743 | 744 | 745 | 755 | 756 | 757 | 758 | 759 | 760 | 761 | 762 | 763 | 764 | 765 | 766 | 769 | 770 | 771 | 772 | 773 | 774 | 775 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 | 784 | 785 | 786 | 787 | 790 | 791 | 792 | 793 | 794 | 795 | 796 | 797 | 798 | 799 | 800 | 801 | 802 | 803 | 804 | 805 | 806 | 807 | 808 | 811 | 812 | 813 | 814 | 815 | 816 | 817 | 818 | 819 | 820 | 821 | 822 | 823 | 824 | 825 | 826 | 827 | 828 | 829 | 830 | 833 | 834 | 835 | 836 | 837 | 838 | 839 | 840 | 841 | 842 | 843 | 844 | 845 | 846 | 847 | 848 | 849 | 850 | 851 | 852 | 853 | 854 | 855 | 856 | 857 | 858 | 859 | 860 | 861 | 862 | 863 | 864 | 865 | 866 | 867 | 868 | 869 | 870 | 871 | 872 | 873 | 874 | 875 | 876 | 877 | 878 | 879 | 880 | 881 | 882 | 883 | 884 | 885 | 886 | 887 | 888 | 889 | 890 | 891 | 892 | 893 | 894 | 895 | 896 | 897 | 898 | 899 | 900 | 901 | 902 | 903 | 904 | 905 | 906 | 907 | 908 | 909 | 910 | 911 | 912 | 913 | 914 | 915 | 916 | 917 | 918 | 919 | 920 | 921 | 922 | 923 | 924 | 925 | 926 | 959 | 972 | 1016 | 1032 | 1033 | 1034 | 1035 | 1036 | 1037 | 1038 | 1039 | 1040 | 1041 | 1042 | 1043 | 1044 | 1047 | 1048 | 1049 | 1050 | 1053 | 1054 | 1055 | 1056 | 1057 | 1058 | 1059 | 1060 | 1061 
| 1062 | 1063 | 1064 | 1065 | 1066 | 1067 | 1070 | 1071 | 1072 | 1073 | 1074 | 1075 | 1076 | 1077 | 1078 | 1079 | 1080 | 1081 | 1082 | 1083 | 1084 | 1085 | 1086 | 1087 | 1090 | 1091 | 1092 | 1093 | 1094 | 1095 | 1096 | 1097 | 1098 | 1099 | 1100 | 1101 | 1102 | 1103 | 1104 | 1105 | 1106 | 1107 | 1108 | 1111 | 1112 | 1113 | 1114 | 1115 | 1116 | 1117 | 1118 | 1119 | 1120 | 1121 | 1122 | 1123 | 1124 | 1125 | 1126 | 1127 | 1128 | 1129 | 1132 | 1133 | 1134 | 1135 | 1136 | 1137 | 1138 | 1139 | 1140 | 1141 | 1142 | 1143 | 1144 | 1145 | 1146 | 1147 | 1148 | 1149 | 1150 | 1153 | 1154 | 1155 | 1156 | 1157 | 1158 | 1159 | 1160 | 1161 | 1162 | 1163 | 1164 | 1165 | 1166 | 1167 | 1168 | 1169 | 1170 | 1171 | 1174 | 1175 | 1176 | 1177 | 1178 | 1179 | 1180 | 1181 | 1182 | 1183 | 1184 | 1185 | 1186 | 1187 | 1188 | 1189 | 1190 | 1191 | 1192 | 1195 | 1196 | 1197 | 1198 | 1199 | 1200 | 1201 | 1202 | 1203 | 1204 | 1205 | 1206 | 1207 | 1208 | 1209 | 1210 | 1211 | 1212 | 1213 | 1214 | 1230 | 1254 | 1292 | 1338 | 1339 | 1374 | 1411 | 1434 | 1451 | 1452 | 1453 | 1454 | 1455 | 1456 | 1457 | 1458 | 1459 | 1460 | 1461 | 1462 | 1463 | 1464 | 1465 | 1466 | 1467 | 1468 | 1469 | 1470 | 1471 | 1472 | 1475 | 1476 | 1477 | 1480 | 1481 | 1482 | 1483 | 1491 | 1492 | 1493 | 1496 | 1497 | 1498 | 1499 | 1500 | 1501 | 1549 | 1598 | 1619 | 1620 | 1621 | 1622 | 1623 | 1624 | 1625 | 1626 | 1627 | 1628 | 1629 | 1630 | 1631 | 1632 | 1633 | 1634 | 1635 | 1638 | 1639 | 1640 | 1641 | 1642 | 1643 | 1656 | 1674 | 1690 | 1708 | 1719 | 1720 | 1721 | 1722 | 1723 | 1724 | 1725 | 1726 | 1727 | 1728 | 1729 | 1730 | 1731 | 1732 | 1733 | 1734 | 1735 | 1736 | 1737 | 1738 | 1739 | 1742 | 1743 | 1744 | 1745 | 1746 | 1747 | 1770 | 1771 | 1772 | 1773 | 1774 | 1775 | 1776 | 1777 | 1778 | 1779 | 1780 | 1781 | 1782 | 1783 | 1784 | 1785 | 1786 | 1787 | 1788 | 1789 | 1792 | 1793 | 1794 | 1795 | 1796 | 1797 | 1812 | 1813 | 1814 | 1815 | 1816 | 1817 | 1818 | 1819 | 1820 | 1821 | 1822 | 1823 | 1824 | 1825 | 1826 | 1827 | 1828 | 1829 | 
1830 | 1831 | 1832 | -------------------------------------------------------------------------------- /signalsmith-fft.h: -------------------------------------------------------------------------------- 1 | #ifndef SIGNALSMITH_FFT_V5 2 | #define SIGNALSMITH_FFT_V5 3 | // So we can easily include multiple versions by redefining this 4 | #ifndef SIGNALSMITH_FFT_NAMESPACE 5 | #define SIGNALSMITH_FFT_NAMESPACE signalsmith 6 | #endif 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #ifndef SIGNALSMITH_INLINE 15 | #ifdef __GNUC__ 16 | #define SIGNALSMITH_INLINE __attribute__((always_inline)) inline 17 | #elif defined(__MSVC__) 18 | #define SIGNALSMITH_INLINE __forceinline inline 19 | #else 20 | #define SIGNALSMITH_INLINE inline 21 | #endif 22 | #endif 23 | 24 | #ifndef M_PI 25 | #define M_PI 3.14159265358979323846264338327950288 26 | #endif 27 | 28 | namespace SIGNALSMITH_FFT_NAMESPACE { 29 | 30 | namespace perf { 31 | // Complex multiplication has edge-cases around Inf/NaN - handling those properly makes std::complex non-inlineable, so we use our own 32 | template 33 | SIGNALSMITH_INLINE std::complex complexMul(const std::complex &a, const std::complex &b) { 34 | return conjugateSecond ? std::complex{ 35 | b.real()*a.real() + b.imag()*a.imag(), 36 | b.real()*a.imag() - b.imag()*a.real() 37 | } : std::complex{ 38 | a.real()*b.real() - a.imag()*b.imag(), 39 | a.real()*b.imag() + a.imag()*b.real() 40 | }; 41 | } 42 | 43 | template 44 | SIGNALSMITH_INLINE std::complex complexAddI(const std::complex &a, const std::complex &b) { 45 | return flipped ? 
std::complex{ 46 | a.real() + b.imag(), 47 | a.imag() - b.real() 48 | } : std::complex{ 49 | a.real() - b.imag(), 50 | a.imag() + b.real() 51 | }; 52 | } 53 | } 54 | 55 | // Use SFINAE to get an iterator from std::begin(), if supported - otherwise assume the value itself is an iterator 56 | template 57 | struct GetIterator { 58 | static T get(const T &t) { 59 | return t; 60 | } 61 | }; 62 | template 63 | struct GetIterator()))> { 64 | static auto get(const T &t) -> decltype(std::begin(t)) { 65 | return std::begin(t); 66 | } 67 | }; 68 | 69 | template 70 | class FFT { 71 | using complex = std::complex; 72 | size_t _size; 73 | std::vector workingVector; 74 | 75 | enum class StepType { 76 | generic, step2, step3, step4 77 | }; 78 | struct Step { 79 | StepType type; 80 | size_t factor; 81 | size_t startIndex; 82 | size_t innerRepeats; 83 | size_t outerRepeats; 84 | size_t twiddleIndex; 85 | }; 86 | std::vector factors; 87 | std::vector plan; 88 | std::vector twiddleVector; 89 | 90 | struct PermutationPair {size_t from, to;}; 91 | std::vector permutation; 92 | 93 | void addPlanSteps(size_t factorIndex, size_t start, size_t length, size_t repeats) { 94 | if (factorIndex >= factors.size()) return; 95 | 96 | size_t factor = factors[factorIndex]; 97 | if (factorIndex + 1 < factors.size()) { 98 | if (factors[factorIndex] == 2 && factors[factorIndex + 1] == 2) { 99 | ++factorIndex; 100 | factor = 4; 101 | } 102 | } 103 | 104 | size_t subLength = length/factor; 105 | Step mainStep{StepType::generic, factor, start, subLength, repeats, twiddleVector.size()}; 106 | 107 | if (factor == 2) mainStep.type = StepType::step2; 108 | if (factor == 3) mainStep.type = StepType::step3; 109 | if (factor == 4) mainStep.type = StepType::step4; 110 | 111 | // Twiddles 112 | bool foundStep = false; 113 | for (const Step &existingStep : plan) { 114 | if (existingStep.factor == mainStep.factor && existingStep.innerRepeats == mainStep.innerRepeats) { 115 | foundStep = true; 116 | 
mainStep.twiddleIndex = existingStep.twiddleIndex; 117 | break; 118 | } 119 | } 120 | if (!foundStep) { 121 | for (size_t i = 0; i < subLength; ++i) { 122 | for (size_t f = 0; f < factor; ++f) { 123 | V phase = 2*M_PI*i*f/length; 124 | complex twiddle = {cos(phase), -sin(phase)}; 125 | twiddleVector.push_back(twiddle); 126 | } 127 | } 128 | } 129 | 130 | if (repeats == 1 && sizeof(complex)*subLength > 65536) { 131 | for (size_t i = 0; i < factor; ++i) { 132 | addPlanSteps(factorIndex + 1, start + i*subLength, subLength, 1); 133 | } 134 | } else { 135 | addPlanSteps(factorIndex + 1, start, subLength, repeats*factor); 136 | } 137 | plan.push_back(mainStep); 138 | } 139 | void setPlan() { 140 | factors.resize(0); 141 | size_t size = _size, f = 2; 142 | while (size > 1) { 143 | if (size%f == 0) { 144 | factors.push_back(f); 145 | size /= f; 146 | } else if (f > sqrt(size)) { 147 | f = size; 148 | } else { 149 | ++f; 150 | } 151 | } 152 | 153 | plan.resize(0); 154 | twiddleVector.resize(0); 155 | addPlanSteps(0, 0, _size, 1); 156 | 157 | permutation.resize(0); 158 | permutation.push_back(PermutationPair{0, 0}); 159 | size_t indexLow = 0, indexHigh = factors.size(); 160 | size_t inputStepLow = _size, outputStepLow = 1; 161 | size_t inputStepHigh = 1, outputStepHigh = _size; 162 | while (outputStepLow*inputStepHigh < _size) { 163 | size_t f, inputStep, outputStep; 164 | if (outputStepLow <= inputStepHigh) { 165 | f = factors[indexLow++]; 166 | inputStep = (inputStepLow /= f); 167 | outputStep = outputStepLow; 168 | outputStepLow *= f; 169 | } else { 170 | f = factors[--indexHigh]; 171 | inputStep = inputStepHigh; 172 | inputStepHigh *= f; 173 | outputStep = (outputStepHigh /= f); 174 | } 175 | size_t oldSize = permutation.size(); 176 | for (size_t i = 1; i < f; ++i) { 177 | for (size_t j = 0; j < oldSize; ++j) { 178 | PermutationPair pair = permutation[j]; 179 | pair.from += i*inputStep; 180 | pair.to += i*outputStep; 181 | permutation.push_back(pair); 182 | } 183 | } 184 
| } 185 | } 186 | 187 | template 188 | void fftStepGeneric(RandomAccessIterator &&origData, const Step &step) { 189 | complex *working = workingVector.data(); 190 | const size_t stride = step.innerRepeats; 191 | 192 | for (size_t outerRepeat = 0; outerRepeat < step.outerRepeats; ++outerRepeat) { 193 | RandomAccessIterator data = origData; 194 | 195 | const complex *twiddles = twiddleVector.data() + step.twiddleIndex; 196 | const size_t factor = step.factor; 197 | for (size_t repeat = 0; repeat < step.innerRepeats; ++repeat) { 198 | for (size_t i = 0; i < step.factor; ++i) { 199 | working[i] = perf::complexMul(data[i*stride], twiddles[i]); 200 | } 201 | for (size_t f = 0; f < factor; ++f) { 202 | complex sum = working[0]; 203 | for (size_t i = 1; i < factor; ++i) { 204 | V phase = 2*M_PI*f*i/factor; 205 | complex factor = {cos(phase), -sin(phase)}; 206 | sum += perf::complexMul(working[i], factor); 207 | } 208 | data[f*stride] = sum; 209 | } 210 | ++data; 211 | twiddles += factor; 212 | } 213 | origData += step.factor*step.innerRepeats; 214 | } 215 | } 216 | 217 | template 218 | void fftStep2(RandomAccessIterator &&origData, const Step &step) { 219 | const size_t stride = step.innerRepeats; 220 | const complex *origTwiddles = twiddleVector.data() + step.twiddleIndex; 221 | for (size_t outerRepeat = 0; outerRepeat < step.outerRepeats; ++outerRepeat) { 222 | const complex* twiddles = origTwiddles; 223 | for (RandomAccessIterator data = origData; data < origData + stride; ++data) { 224 | complex A = data[0]; 225 | complex B = perf::complexMul(data[stride], twiddles[1]); 226 | 227 | data[0] = A + B; 228 | data[stride] = A - B; 229 | twiddles += 2; 230 | } 231 | origData += 2*stride; 232 | } 233 | } 234 | 235 | template 236 | void fftStep3(RandomAccessIterator &&origData, const Step &step) { 237 | constexpr complex factor3 = {-0.5, inverse ? 
0.8660254037844386 : -0.8660254037844386}; 238 | const size_t stride = step.innerRepeats; 239 | const complex *origTwiddles = twiddleVector.data() + step.twiddleIndex; 240 | 241 | for (size_t outerRepeat = 0; outerRepeat < step.outerRepeats; ++outerRepeat) { 242 | const complex* twiddles = origTwiddles; 243 | for (RandomAccessIterator data = origData; data < origData + stride; ++data) { 244 | complex A = data[0]; 245 | complex B = perf::complexMul(data[stride], twiddles[1]); 246 | complex C = perf::complexMul(data[stride*2], twiddles[2]); 247 | 248 | complex realSum = A + (B + C)*factor3.real(); 249 | complex imagSum = (B - C)*factor3.imag(); 250 | 251 | data[0] = A + B + C; 252 | data[stride] = perf::complexAddI(realSum, imagSum); 253 | data[stride*2] = perf::complexAddI(realSum, imagSum); 254 | 255 | twiddles += 3; 256 | } 257 | origData += 3*stride; 258 | } 259 | } 260 | 261 | template 262 | void fftStep4(RandomAccessIterator &&origData, const Step &step) { 263 | const size_t stride = step.innerRepeats; 264 | const complex *origTwiddles = twiddleVector.data() + step.twiddleIndex; 265 | 266 | for (size_t outerRepeat = 0; outerRepeat < step.outerRepeats; ++outerRepeat) { 267 | const complex* twiddles = origTwiddles; 268 | for (RandomAccessIterator data = origData; data < origData + stride; ++data) { 269 | complex A = data[0]; 270 | complex C = perf::complexMul(data[stride], twiddles[2]); 271 | complex B = perf::complexMul(data[stride*2], twiddles[1]); 272 | complex D = perf::complexMul(data[stride*3], twiddles[3]); 273 | 274 | complex sumAC = A + C, sumBD = B + D; 275 | complex diffAC = A - C, diffBD = B - D; 276 | 277 | data[0] = sumAC + sumBD; 278 | data[stride] = perf::complexAddI(diffAC, diffBD); 279 | data[stride*2] = sumAC - sumBD; 280 | data[stride*3] = perf::complexAddI(diffAC, diffBD); 281 | 282 | twiddles += 4; 283 | } 284 | origData += 4*stride; 285 | } 286 | } 287 | 288 | template 289 | void permute(InputIterator input, OutputIterator data) { 290 | for 
(auto pair : permutation) { 291 | data[pair.from] = input[pair.to]; 292 | } 293 | } 294 | 295 | template 296 | void run(InputIterator &&input, OutputIterator &&data) { 297 | permute(input, data); 298 | 299 | for (const Step &step : plan) { 300 | switch (step.type) { 301 | case StepType::generic: 302 | fftStepGeneric(data + step.startIndex, step); 303 | break; 304 | case StepType::step2: 305 | fftStep2(data + step.startIndex, step); 306 | break; 307 | case StepType::step3: 308 | fftStep3(data + step.startIndex, step); 309 | break; 310 | case StepType::step4: 311 | fftStep4(data + step.startIndex, step); 312 | break; 313 | } 314 | } 315 | } 316 | 317 | static bool validSize(size_t size) { 318 | constexpr static bool filter[32] = { 319 | 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, // 0-9 320 | 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, // 10-19 321 | 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, // 20-29 322 | 0, 0 323 | }; 324 | return filter[size]; 325 | } 326 | public: 327 | static size_t sizeMinimum(size_t size) { 328 | size_t power2 = 1; 329 | while (size >= 32) { 330 | size = (size - 1)/2 + 1; 331 | power2 *= 2; 332 | } 333 | while (size < 32 && !validSize(size)) { 334 | ++size; 335 | } 336 | return power2*size; 337 | } 338 | static size_t sizeMaximum(size_t size) { 339 | size_t power2 = 1; 340 | while (size >= 32) { 341 | size /= 2; 342 | power2 *= 2; 343 | } 344 | while (size > 1 && !validSize(size)) { 345 | --size; 346 | } 347 | return power2*size; 348 | } 349 | 350 | FFT(size_t size, int fastDirection=0) : _size(0) { 351 | if (fastDirection > 0) size = sizeMinimum(size); 352 | if (fastDirection < 0) size = sizeMaximum(size); 353 | this->setSize(size); 354 | } 355 | 356 | size_t setSize(size_t size) { 357 | if (size != _size) { 358 | _size = size; 359 | workingVector.resize(size); 360 | setPlan(); 361 | } 362 | return _size; 363 | } 364 | size_t setSizeMinimum(size_t size) { 365 | return setSize(sizeMinimum(size)); 366 | } 367 | size_t setSizeMaximum(size_t size) { 368 | return 
setSize(sizeMaximum(size)); 369 | } 370 | const size_t & size() const { 371 | return _size; 372 | } 373 | 374 | template 375 | void fft(InputIterator &&input, OutputIterator &&output) { 376 | auto inputIter = GetIterator::get(input); 377 | auto outputIter = GetIterator::get(output); 378 | return run(inputIter, outputIter); 379 | } 380 | 381 | template 382 | void ifft(InputIterator &&input, OutputIterator &&output) { 383 | auto inputIter = GetIterator::get(input); 384 | auto outputIter = GetIterator::get(output); 385 | return run(inputIter, outputIter); 386 | } 387 | }; 388 | 389 | struct FFTOptions { 390 | static constexpr int halfFreqShift = 1; 391 | }; 392 | 393 | template 394 | class RealFFT { 395 | static constexpr bool modified = (optionFlags&FFTOptions::halfFreqShift); 396 | 397 | using complex = std::complex; 398 | std::vector complexBuffer1, complexBuffer2; 399 | std::vector twiddlesMinusI; 400 | std::vector modifiedRotations; 401 | FFT complexFft; 402 | public: 403 | static size_t sizeMinimum(size_t size) { 404 | return (FFT::sizeMinimum((size - 1)/2) + 1)*2; 405 | } 406 | static size_t sizeMaximum(size_t size) { 407 | return FFT::sizeMinimum(size/2)*2; 408 | } 409 | 410 | RealFFT(size_t size, int fastDirection=0) : complexFft(0) { 411 | if (fastDirection > 0) size = sizeMinimum(size); 412 | if (fastDirection < 0) size = sizeMaximum(size); 413 | this->setSize(size); 414 | } 415 | 416 | size_t setSize(size_t size) { 417 | complexBuffer1.resize(size/2); 418 | complexBuffer2.resize(size/2); 419 | 420 | size_t hhSize = size/4 + 1; 421 | twiddlesMinusI.resize(hhSize); 422 | for (size_t i = 0; i < hhSize; ++i) { 423 | double rotPhase = -2*M_PI*(modified ? 
i + 0.5 : i)/size; 424 | twiddlesMinusI[i] = {sin(rotPhase), -cos(rotPhase)}; 425 | } 426 | if (modified) { 427 | modifiedRotations.resize(size/2); 428 | for (size_t i = 0; i < size/2; ++i) { 429 | double rotPhase = -2*M_PI*i/size; 430 | modifiedRotations[i] = {cos(rotPhase), sin(rotPhase)}; 431 | } 432 | } 433 | 434 | return complexFft.setSize(size/2); 435 | } 436 | size_t setSizeMinimum(size_t size) { 437 | return setSize(sizeMinimum(size)); 438 | } 439 | size_t setSizeMaximum(size_t size) { 440 | return setSize(sizeMaximum(size)); 441 | } 442 | size_t size() const { 443 | return complexFft.size()*2; 444 | } 445 | 446 | template 447 | void fft(InputIterator &&input, OutputIterator &&output) { 448 | size_t hSize = complexFft.size(); 449 | for (size_t i = 0; i < hSize; ++i) { 450 | if (modified) { 451 | complexBuffer1[i] = perf::complexMul({input[2*i], input[2*i + 1]}, modifiedRotations[i]); 452 | } else { 453 | complexBuffer1[i] = {input[2*i], input[2*i + 1]}; 454 | } 455 | } 456 | 457 | complexFft.fft(complexBuffer1.data(), complexBuffer2.data()); 458 | 459 | if (!modified) output[0] = { 460 | complexBuffer2[0].real() + complexBuffer2[0].imag(), 461 | complexBuffer2[0].real() - complexBuffer2[0].imag() 462 | }; 463 | for (size_t i = modified ? 0 : 1; i <= hSize/2; ++i) { 464 | size_t conjI = modified ? (hSize - 1 - i) : (hSize - i); 465 | 466 | complex odd = (complexBuffer2[i] + conj(complexBuffer2[conjI]))*(V)0.5; 467 | complex evenI = (complexBuffer2[i] - conj(complexBuffer2[conjI]))*(V)0.5; 468 | complex evenRotMinusI = perf::complexMul(evenI, twiddlesMinusI[i]); 469 | 470 | output[i] = odd + evenRotMinusI; 471 | output[conjI] = conj(odd - evenRotMinusI); 472 | } 473 | } 474 | 475 | template 476 | void ifft(InputIterator &&input, OutputIterator &&output) { 477 | size_t hSize = complexFft.size(); 478 | if (!modified) complexBuffer1[0] = { 479 | input[0].real() + input[0].imag(), 480 | input[0].real() - input[0].imag() 481 | }; 482 | for (size_t i = modified ? 
0 : 1; i <= hSize/2; ++i) { 483 | size_t conjI = modified ? (hSize - 1 - i) : (hSize - i); 484 | complex v = input[i], v2 = input[conjI]; 485 | 486 | complex odd = v + conj(v2); 487 | complex evenRotMinusI = v - conj(v2); 488 | complex evenI = perf::complexMul(evenRotMinusI, twiddlesMinusI[i]); 489 | 490 | complexBuffer1[i] = odd + evenI; 491 | complexBuffer1[conjI] = conj(odd - evenI); 492 | } 493 | 494 | complexFft.ifft(complexBuffer1.data(), complexBuffer2.data()); 495 | 496 | for (size_t i = 0; i < hSize; ++i) { 497 | complex v = complexBuffer2[i]; 498 | if (modified) v = perf::complexMul(v, modifiedRotations[i]); 499 | output[2*i] = v.real(); 500 | output[2*i + 1] = v.imag(); 501 | } 502 | } 503 | }; 504 | 505 | template 506 | struct ModifiedRealFFT : public RealFFT { 507 | using RealFFT::RealFFT; 508 | }; 509 | } 510 | 511 | #undef SIGNALSMITH_FFT_NAMESPACE 512 | #endif // SIGNALSMITH_FFT_V5 513 | -------------------------------------------------------------------------------- /tests/00-fft.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "tests-common.h" 7 | 8 | std::vector testSizes() { 9 | return { 10 | 1, 2, 4, 8, 16, 32, 64, 128, 256, 11 | 3, 6, 9, 12, 18, 24, 12 | 5, 10, 15, 20, 25, 13 | 7, 14, 21, 28, 49, 14 | 11, 13, 17, 19, 22, 23 15 | }; 16 | } 17 | 18 | 19 | TEST("Individual bins", test_2N_bins) { 20 | using signalsmith::FFT; 21 | using std::vector; 22 | using std::complex; 23 | 24 | std::vector sizes = testSizes(); 25 | 26 | for (int size : sizes) { 27 | vector> input(size); 28 | vector> inputCopy(size); 29 | vector> output(size); 30 | vector> expected(size); 31 | 32 | FFT fft(size); 33 | 34 | // Test each bin 35 | for (int bin = 0; bin < size; bin++) { 36 | for (int i = 0; i < size; i++) { 37 | double phase = 2*M_PI*i*bin/size; 38 | input[i] = inputCopy[i] = complex{cos(phase), sin(phase)}; 39 | expected[i] = (i == bin) ? 
size : 0; 40 | } 41 | 42 | fft.fft(input, output); 43 | 44 | if (!closeEnough(input, inputCopy)) { 45 | return test.fail("input was changed"); 46 | } 47 | if (!closeEnough(output, expected)) { 48 | std::cout << "N = " << size << "\n"; 49 | std::cout << " input:\t"; 50 | printArray(input); 51 | std::cout << " output:\t"; 52 | printArray(output); 53 | std::cout << "expected:\t"; 54 | printArray(expected); 55 | 56 | return test.fail("output != expected"); 57 | } 58 | } 59 | } 60 | } 61 | 62 | TEST("Linearity", test_2N_linearity) { 63 | using signalsmith::FFT; 64 | using std::vector; 65 | using std::complex; 66 | 67 | std::vector sizes = testSizes(); 68 | 69 | for (int size : sizes) { 70 | vector> inputA(size); 71 | vector> inputB(size); 72 | vector> inputAB(size); 73 | vector> outputA(size); 74 | vector> outputB(size); 75 | vector> outputAB(size); 76 | vector> outputSummed(size); 77 | 78 | FFT fft(size); 79 | 80 | // Test linearity 81 | for (int i = 0; i < size; i++) { 82 | inputA[i] = randomComplex(); 83 | inputB[i] = randomComplex(); 84 | inputAB[i] = inputA[i] + inputB[i]; 85 | } 86 | 87 | fft.fft(inputA, outputA); 88 | fft.fft(inputB, outputB); 89 | fft.fft(inputAB, outputAB); 90 | 91 | for (int i = 0; i < size; i++) { 92 | outputSummed[i] = outputA[i] + outputB[i]; 93 | } 94 | 95 | if (!closeEnough(outputAB, outputSummed)) { 96 | return test.fail("result was not linear"); 97 | } 98 | } 99 | } 100 | 101 | template 102 | void inverseTest(Test test) { 103 | using signalsmith::FFT; 104 | 105 | std::vector sizes = testSizes(); 106 | 107 | for (int size : sizes) { 108 | std::vector> input(size); 109 | std::vector> spectrum(size); 110 | std::vector> output(size); 111 | std::vector> expected(size); 112 | 113 | for (int i = 0; i < size; i++) { 114 | if (fixedHarmonics >= 0) { 115 | double freq = fixedHarmonics; 116 | double phase = 2*M_PI*i*freq/size; 117 | input[i] = {cos(phase), sin(phase)}; 118 | } else { 119 | input[i] = randomComplex(); 120 | } 121 | expected[i] = 
input[i]*(double)size; 122 | } 123 | 124 | FFT fft(size); 125 | fft.fft(input, spectrum); 126 | fft.ifft(spectrum, output); 127 | 128 | if (!closeEnough(output, expected)) { 129 | printArray(input); 130 | printArray(spectrum); 131 | std::cout << size << "\n"; 132 | printArray(output); 133 | printArray(expected); 134 | return test.fail("inverse did not match"); 135 | } 136 | } 137 | } 138 | 139 | TEST("Inverse (first harmonic)", inverse_f1) { 140 | return inverseTest<1>(test); 141 | } 142 | 143 | TEST("Inverse (random)", inverse_random) { 144 | return inverseTest<-1>(test); 145 | } 146 | 147 | struct Powers { 148 | size_t two = 0, three = 0, five = 0, remainder = 1; 149 | }; 150 | Powers factorise(size_t size) { 151 | Powers powers; 152 | while (size%2 == 0) { 153 | size /= 2; 154 | powers.two++; 155 | } 156 | while (size%3 == 0) { 157 | size /= 3; 158 | powers.three++; 159 | } 160 | while (size%5 == 0) { 161 | size /= 5; 162 | powers.five++; 163 | } 164 | powers.remainder = size; 165 | return powers; 166 | } 167 | 168 | TEST("Sizes", sizes) { 169 | using signalsmith::FFT; 170 | 171 | for (size_t i = 1; i < 1000; ++i) { 172 | size_t above = FFT::sizeMinimum(i); 173 | size_t below = FFT::sizeMaximum(i); 174 | 175 | if (above < i) return test.fail("above < i"); 176 | if (below > i) return test.fail("below > i"); 177 | 178 | auto factorsAbove = factorise(above); 179 | auto factorsBelow = factorise(below); 180 | 181 | if (factorsAbove.remainder != 1) return test.fail("non-fast above remainder"); 182 | if (factorsBelow.remainder != 1) return test.fail("non-fast below remainder"); 183 | 184 | if (factorsAbove.three + factorsAbove.five > 2) return test.fail("above is too complex"); 185 | if (factorsBelow.three + factorsBelow.five > 2) return test.fail("below is too complex"); 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /tests/01-real.cpp: 
-------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "tests-common.h" 7 | 8 | #define LOG_VALUE(expr) \ 9 | (std::cout << #expr << " = " << (expr) << "\n") 10 | 11 | #define FAIL_VALUE_PAIR(expr1, expr2) \ 12 | ( \ 13 | LOG_VALUE(expr1), LOG_VALUE(expr2), \ 14 | test.fail(#expr1 " and " #expr2) \ 15 | ) 16 | 17 | template 18 | void test_real(Test &test) { 19 | using signalsmith::FFT; 20 | using signalsmith::RealFFT; 21 | using signalsmith::ModifiedRealFFT; 22 | using std::vector; 23 | using std::complex; 24 | 25 | for (int size = 2; size < 100; size += 2) { 26 | vector> complexInput(size); 27 | vector> complexMid(size); 28 | vector> complexOutput(size); 29 | vector realInput(size); 30 | vector> realMid(size); // Only need half, but check it's undisturbed 31 | vector realOutput(size); 32 | 33 | FFT fft(size); 34 | typename std::conditional, RealFFT>::type realFft(size); 35 | 36 | // Random inputs 37 | for (int i = 0; i < size; ++i) { 38 | double v = rand()/(double)RAND_MAX - 0.5; 39 | complexInput[i] = v; 40 | realInput[i] = v; 41 | } 42 | if (modified) { 43 | for (int i = 0; i < size; ++i) { 44 | double rotPhase = -M_PI*i/size; 45 | complex rot = {cos(rotPhase), sin(rotPhase)}; 46 | complexInput[i] *= rot; 47 | } 48 | } 49 | for (int i = size/2; i < size; ++i) { 50 | // Should be undisturbed - fill with known value 51 | realMid[i] = complex{52, 21}; 52 | } 53 | 54 | fft.fft(complexInput, complexMid); 55 | realFft.fft(realInput, realMid); 56 | 57 | // Check complex spectrum matches 58 | if (!modified) { 59 | if (complexMid[0].imag() > 1e-6) return test.fail("complexMid[0].imag()"); 60 | if (abs(complexMid[0].real() - realMid[0].real()) > 1e-6) return FAIL_VALUE_PAIR(complexMid[0].real(), realMid[0].real()); 61 | if (abs(complexMid[size/2].real() - realMid[0].imag()) > 1e-6) return FAIL_VALUE_PAIR(complexMid[size/2].real(), realMid[0].imag()); 62 | } 63 | for (int i = 
modified ? 0 : 1; i < size/2; ++i) { 64 | complex diff = complexMid[i] - realMid[i]; 65 | if (abs(diff) > size*1e-6) { 66 | LOG_VALUE(i); 67 | return FAIL_VALUE_PAIR(complexMid[i], realMid[i]); 68 | } 69 | } 70 | for (int i = size/2; i < size; ++i) { 71 | // It should have left the second half of realMid completely alone 72 | if (realMid[i] != complex{52, 21}) return test.fail("realMid second half"); 73 | } 74 | 75 | fft.ifft(complexMid, complexOutput); 76 | realFft.ifft(realMid, realOutput); 77 | 78 | if (modified) { 79 | for (int i = 0; i < size; ++i) { 80 | double rotPhase = M_PI*i/size; 81 | complex rot = {cos(rotPhase), sin(rotPhase)}; 82 | complexOutput[i] *= rot; 83 | } 84 | } 85 | 86 | for (int i = 0; i < size; ++i) { 87 | if (complexOutput[i].imag() > size*1e-6) return test.fail("complexOutput[i].imag"); 88 | 89 | if (abs(complexOutput[i].real() - realOutput[i]) > size*1e-6) { 90 | LOG_VALUE(size); 91 | LOG_VALUE(i); 92 | return FAIL_VALUE_PAIR(complexOutput[i], realOutput[i]); 93 | } 94 | } 95 | } 96 | } 97 | 98 | TEST("Random real", random_real) { 99 | test_real(test); 100 | } 101 | 102 | TEST("Modified real", random_modified_real) { 103 | test_real(test); 104 | } 105 | -------------------------------------------------------------------------------- /tests/tests-common.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // from the shared library 5 | #include 6 | 7 | #include "../signalsmith-fft.h" 8 | 9 | template 10 | std::complex randomComplex() { 11 | std::complex r; 12 | r.real(rand()/(double)RAND_MAX - 0.5); 13 | r.imag(rand()/(double)RAND_MAX - 0.5); 14 | return r; 15 | } 16 | 17 | template 18 | bool closeEnough(std::vector> vectorA, std::vector> vectorB) { 19 | double totalEnergy = 0; 20 | double totalError = 0; 21 | if (vectorA.size() != vectorB.size()) return false; 22 | 23 | for (unsigned int i = 0; i < vectorA.size(); ++i) { 24 | T error = norm(vectorA[i] - vectorB[i]); 25 | T 
energy = norm(vectorA[i]*vectorA[i]) + norm(vectorB[i]*vectorB[i]); 26 | totalEnergy += energy; 27 | totalError += error; 28 | } 29 | 30 | if (!totalEnergy) return true; 31 | T errorRatio = sqrt(totalError/totalEnergy); 32 | return errorRatio < 1e-6; 33 | } 34 | 35 | template 36 | static void printArray(std::vector array, bool newline=true) { 37 | for (unsigned int i = 0; i < array.size(); ++i) { 38 | if (i > 0) std::cout << "\t"; 39 | std::cout << array[i]; 40 | } 41 | if (newline) std::cout << std::endl; 42 | } 43 | --------------------------------------------------------------------------------