├── MANIFEST.in ├── cpp └── zimt │ ├── test.wav │ ├── visqol_model.h │ ├── test_file_paths.cc │ ├── visqol_model.cc │ ├── test_file_paths.h │ ├── visqol.h │ ├── mos_test.cc │ ├── mos.h │ ├── dtw_test.cc │ ├── audio_test.cc │ ├── distance_benchmark_test.cc │ ├── audio.h │ ├── pyohrli_test.py │ ├── resample.h │ ├── nsim_test.cc │ ├── audio.cc │ ├── pyohrli.py │ ├── goohrli.cc │ ├── visqol.cc │ ├── compare.cc │ └── pyohrli.cc ├── pyproject.toml ├── .clang-format ├── .gitignore ├── cmake ├── pffft.cmake ├── armadillo.cmake ├── libsvm.cmake ├── benchmark.cmake ├── deps.cmake ├── protobuf.cmake ├── tests.cmake ├── visqol.cmake └── visqol_manager.cc ├── go.mod ├── .github └── workflows │ ├── build_python_wheels.yml │ └── test.yml ├── configure.sh ├── go ├── README.md ├── data │ └── table.go ├── resource │ └── pool.go ├── bin │ ├── csv │ │ └── csv.go │ ├── coresvnet │ │ └── coresvnet.go │ ├── perceptual_audio │ │ └── perceptual_audio.go │ ├── compare │ │ └── compare.go │ ├── tcd_voip │ │ └── tcd_voip.go │ ├── sebass_db │ │ └── sebass_db.go │ ├── score │ │ └── score.go │ └── odaq │ │ └── odaq.go ├── goohrli │ ├── goohrli.h │ └── goohrli_test.go ├── worker │ └── pool.go ├── aio │ └── aio.go ├── progress │ └── bar.go ├── audio │ ├── audio_test.go │ └── audio.go └── pipe │ └── metric.go ├── CONTRIBUTING.md ├── setup.py ├── CMakeLists.txt ├── go.sum ├── README.md └── tools └── optimizer ├── simplex_fork.py └── random_fork.py /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include cpp/zimt/mos.h 2 | include cpp/zimt/zimtohrli.h 3 | -------------------------------------------------------------------------------- /cpp/zimt/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/zimtohrli/HEAD/cpp/zimt/test.wav -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | IndentWidth: 2 3 | ColumnLimit: 80 4 | PointerAlignment: Left 5 | DerivePointerAlignment: false 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | debug_build 3 | asan_build 4 | .vscode 5 | python/__pycache__/** 6 | **/METADATA 7 | # Some files produced by one of the external metrics. 8 | pesq_results.txt 9 | analized 10 | Testing/* 11 | go/goohrli/goohrli.a 12 | -------------------------------------------------------------------------------- /cmake/pffft.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(pffft 2 | EXCLUDE_FROM_ALL 3 | GIT_REPOSITORY https://bitbucket.org/jpommier/pffft.git 4 | GIT_TAG 7c3b5a7dc510a0f513b9c5b6dc5b56f7aeeda422 5 | ) 6 | 7 | FetchContent_MakeAvailable(pffft) 8 | 9 | add_library(pffft STATIC 10 | ${pffft_SOURCE_DIR}/pffft.c 11 | ${pffft_SOURCE_DIR}/pffft.h 12 | ) 13 | target_include_directories(pffft PUBLIC ${pffft_SOURCE_DIR}) 14 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/google/zimtohrli 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.23.7 6 | 7 | require ( 8 | github.com/PuerkitoBio/goquery v1.10.3 9 | github.com/dgryski/go-onlinestats v0.0.0-20170612111826-1c7d19468768 10 | github.com/mattn/go-sqlite3 v1.14.28 11 | ) 12 | 13 | require ( 14 | github.com/aclements/go-moremath 
v0.0.0-20241023150245-c8bbc672ef66 // indirect 15 | github.com/andybalholm/cascadia v1.3.3 // indirect 16 | golang.org/x/net v0.39.0 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /cmake/armadillo.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(armadillo 2 | EXCLUDE_FROM_ALL 3 | GIT_REPOSITORY https://gitlab.com/conradsnicta/armadillo-code.git 4 | GIT_TAG 5e57e49667d8913b88855925fcb6ef3b1f6ebe98 5 | ) 6 | set(BUILD_SMOKE_TEST OFF CACHE INTERNAL "") 7 | FetchContent_MakeAvailable(armadillo) 8 | 9 | file(GLOB_RECURSE armadillo_files ${armadillo_SOURCE_DIR} *.cc *.c *.h) 10 | set_source_files_properties( 11 | ${armadillo_files} 12 | TARGET_DIRECTORY armadillo 13 | PROPERTIES SKIP_LINTING ON 14 | ) 15 | -------------------------------------------------------------------------------- /cmake/libsvm.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(libsvm 2 | EXCLUDE_FROM_ALL 3 | GIT_REPOSITORY https://github.com/cjlin1/libsvm.git 4 | GIT_TAG v332 5 | ) 6 | FetchContent_MakeAvailable(libsvm) 7 | 8 | set(libsvm_files 9 | ${libsvm_SOURCE_DIR}/svm.cpp 10 | ${libsvm_SOURCE_DIR}/svm.h 11 | ) 12 | add_library(libsvm STATIC ${libsvm_files}) 13 | target_include_directories(libsvm PUBLIC ${libsvm_SOURCE_DIR}) 14 | 15 | set_source_files_properties( 16 | ${libsvm_files} 17 | TARGET_DIRECTORY libsvm 18 | PROPERTIES SKIP_LINTING ON 19 | ) 20 | -------------------------------------------------------------------------------- /.github/workflows/build_python_wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build Python wheels 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build_python_wheels: 7 | name: Build Python wheels on ${{matrix.os}} 8 | runs-on: ${{matrix.os}} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, ubuntu-24.04-arm, 
windows-latest, windows-11-arm, macos-13, macos-14] 12 | 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Build Python wheels 17 | uses: pypa/cibuildwheel@v3.1.2 18 | env: 19 | MACOSX_DEPLOYMENT_TARGET: 10.13 20 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test Zimtohrli 2 | 3 | on: 4 | pull_request: 5 | 6 | jobs: 7 | build: 8 | 9 | runs-on: ubuntu-latest 10 | 11 | if: '! github.event.pull_request.draft' 12 | 13 | steps: 14 | - name: Install dependencies 15 | run: sudo apt install -y libogg-dev libvorbis-dev libflac-dev cmake ninja-build libasound2-dev libopus-dev libsoxr-dev 16 | - name: Check out code 17 | uses: actions/checkout@v3 18 | - name: Configure 19 | run: ./configure.sh 20 | - name: Build 21 | run: (cd build && env ninja) 22 | - name: Test 23 | run: (cd build && env ctest --output-on-failure) 24 | -------------------------------------------------------------------------------- /cmake/benchmark.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(benchmark 2 | EXCLUDE_FROM_ALL 3 | GIT_REPOSITORY https://github.com/google/benchmark.git 4 | GIT_TAG v1.8.3 5 | ) 6 | set(BENCHMARK_ENABLE_TESTING OFF CACHE INTERNAL "") 7 | set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE INTERNAL "") 8 | set(BENCHMARK_ENABLE_ASSEMBLY_TESTS OFF CACHE INTERNAL "") 9 | set(BENCHMARK_ENABLE_WERROR OFF CACHE INTERNAL "") 10 | FetchContent_MakeAvailable(benchmark) 11 | 12 | file(GLOB_RECURSE benchmark_files ${benchmark_SOURCE_DIR} *.cc *.c *.h) 13 | set_source_files_properties( 14 | ${benchmark_files} 15 | TARGET_DIRECTORY benchmark 16 | PROPERTIES SKIP_LINTING ON 17 | ) 18 | -------------------------------------------------------------------------------- /cmake/deps.cmake: -------------------------------------------------------------------------------- 1 | 
find_package(PkgConfig REQUIRED) 2 | pkg_check_modules(flac REQUIRED flac) 3 | pkg_check_modules(ogg REQUIRED ogg) 4 | pkg_check_modules(vorbis REQUIRED vorbis) 5 | pkg_check_modules(vorbisenc REQUIRED vorbisenc) 6 | 7 | pkg_check_modules(soxr REQUIRED IMPORTED_TARGET soxr) 8 | 9 | include(FetchContent) 10 | 11 | include(cmake/protobuf.cmake) 12 | 13 | FetchContent_Declare(libsndfile 14 | EXCLUDE_FROM_ALL 15 | GIT_REPOSITORY https://github.com/libsndfile/libsndfile.git 16 | GIT_TAG 1.2.2 17 | ) 18 | FetchContent_MakeAvailable(libsndfile) 19 | 20 | FetchContent_Declare(googletest 21 | EXCLUDE_FROM_ALL 22 | GIT_REPOSITORY https://github.com/google/googletest.git 23 | GIT_TAG v1.14.0 24 | ) 25 | FetchContent_MakeAvailable(googletest) 26 | 27 | include(cmake/benchmark.cmake) 28 | include(cmake/pffft.cmake) 29 | include(cmake/libsvm.cmake) 30 | include(cmake/armadillo.cmake) 31 | include(cmake/visqol.cmake) 32 | -------------------------------------------------------------------------------- /cmake/protobuf.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(protobuf 2 | EXCLUDE_FROM_ALL 3 | GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git 4 | GIT_TAG v26.1 5 | ) 6 | set(protobuf_BUILD_TESTS OFF) 7 | set(BUILD_TESTING OFF CACHE INTERNAL "") 8 | FetchContent_MakeAvailable(protobuf) 9 | 10 | target_compile_options(libprotobuf PRIVATE -Wno-attributes) 11 | 12 | file(GLOB_RECURSE protobuf_files ${protobuf_SOURCE_DIR} *.cc *.c *.h) 13 | set_source_files_properties( 14 | ${protobuf_files} 15 | TARGET_DIRECTORY 16 | libprotobuf 17 | absl::base 18 | absl::strings 19 | absl::debugging 20 | absl::hash 21 | absl::flags 22 | absl::log 23 | absl::sample_recorder 24 | absl::any 25 | absl::time 26 | absl::container_common 27 | absl::crc_internal 28 | absl::numeric 29 | absl::synchronization 30 | absl::status 31 | PROPERTIES SKIP_LINTING ON 32 | ) 33 | 
-------------------------------------------------------------------------------- /configure.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CLANG_OPTS="" 4 | if [ -n "${CLANG_PREFIX}" ]; then 5 | CLANG_OPTS="-DCMAKE_CXX_COMPILER=${CLANG_PREFIX}/clang++ -DCMAKE_C_COMPILER=${CLANG_PREFIX}/clang" 6 | fi 7 | 8 | if [ "${1}" == "debug" ]; then 9 | mkdir -p debug_build 10 | (cd debug_build && cmake -G Ninja ${CLANG_OPTS} -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_C_FLAGS='-fPIC' -DCMAKE_CXX_FLAGS='-fPIC' -DCMAKE_BUILD_TYPE=RelWithDebInfo ..) 11 | elif [ "${1}" == "asan" ]; then 12 | mkdir -p asan_build 13 | # CMAKE_LINKER_FLAGS_DEBUG is not a CMake variable; pass the sanitizer to the executable and shared-library link steps explicitly. 14 | (cd asan_build && cmake -G Ninja ${CLANG_OPTS} -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_C_FLAGS='-fsanitize=address -fPIC' -DCMAKE_CXX_FLAGS='-fsanitize=address -fPIC' -DCMAKE_EXE_LINKER_FLAGS='-fsanitize=address' -DCMAKE_SHARED_LINKER_FLAGS='-fsanitize=address' -DCMAKE_BUILD_TYPE=RelWithDebInfo ..) 15 | else 16 | mkdir -p build 17 | (cd build && cmake -G Ninja ${CLANG_OPTS} -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_C_FLAGS='-fPIC -march=native -O3' -DCMAKE_CXX_FLAGS='-fPIC -march=native -O3' -DCMAKE_BUILD_TYPE=Release ..) 18 | fi 19 | -------------------------------------------------------------------------------- /cpp/zimt/visqol_model.h: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef CPP_ZIMT_VISQOL_MODEL_H_ 16 | #define CPP_ZIMT_VISQOL_MODEL_H_ 17 | 18 | #include "zimt/zimtohrli.h" 19 | 20 | namespace zimtohrli { 21 | 22 | // Returns the bytes of the default ViSQOL model. 23 | Span ViSQOLModel(); 24 | 25 | } // namespace zimtohrli 26 | 27 | #endif // CPP_ZIMT_VISQOL_MODEL_H_ 28 | -------------------------------------------------------------------------------- /cpp/zimt/test_file_paths.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "zimt/test_file_paths.h" 16 | 17 | #include 18 | 19 | namespace zimtohrli { 20 | 21 | std::filesystem::path GetTestFilePath( 22 | const std::filesystem::path& relative_path) { 23 | return std::filesystem::path(_xstr(CMAKE_CURRENT_SOURCE_DIR)) / relative_path; 24 | } 25 | 26 | } // namespace zimtohrli -------------------------------------------------------------------------------- /cpp/zimt/visqol_model.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "visqol_model.h" 16 | 17 | #include "libsvm_nu_svr_model.h" 18 | #include "zimt/zimtohrli.h" 19 | 20 | namespace zimtohrli { 21 | 22 | Span ViSQOLModel() { 23 | return Span(reinterpret_cast(visqol_model_bytes), 24 | visqol_model_bytes_len); 25 | } 26 | 27 | } // namespace zimtohrli -------------------------------------------------------------------------------- /go/README.md: -------------------------------------------------------------------------------- 1 | # Zimtohrli Go tools 2 | 3 | Zimtohrli has a Go wrapper, goohrli, which can be fetched with `go get`. 4 | 5 | It works by using `cgo` to wrap the Zimtohrli C++ library. 6 | 7 | Installing it requires a few system dependencies. To install them on a Debian-like system: 8 | 9 | ``` 10 | sudo apt install -y libc++-dev libc++abi-dev libflac-dev libogg-dev libvorbis-dev libopus-dev 11 | ``` 12 | 13 | ## Goohrli wrapper 14 | 15 | To install the wrapper library when inside a Go module (a directory with a `go.mod` file): 16 | 17 | ``` 18 | go get github.com/google/zimtohrli/go/goohrli 19 | ``` 20 | 21 | For documentation about the API, see [https://pkg.go.dev/github.com/google/zimtohrli/go/goohrli](https://pkg.go.dev/github.com/google/zimtohrli/go/goohrli) 22 | 23 | ## Compare command line tool 24 | 25 | A simple command line tool to compare WAV files is provided. 
26 | 27 | To install it: 28 | 29 | ``` 30 | go install github.com/google/zimtohrli/go/bin/compare 31 | ``` 32 | 33 | To run it (run `go env` to see your $GOPATH): 34 | 35 | ``` 36 | $GOPATH/bin/compare -path_a reference.wav -path_b distortion.wav 37 | ``` -------------------------------------------------------------------------------- /cpp/zimt/test_file_paths.h: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef CPP_ZIMT_TEST_FILE_PATHS_H_ 16 | #define CPP_ZIMT_TEST_FILE_PATHS_H_ 17 | 18 | #include 19 | 20 | #ifndef CMAKE_CURRENT_SOURCE_DIR 21 | #error "CMAKE_CURRENT_SOURCE_DIR must be #defined in the analysis test!" 22 | #endif 23 | #define _xstr(a) _str(a) 24 | #define _str(a) #a 25 | 26 | namespace zimtohrli { 27 | 28 | std::filesystem::path GetTestFilePath( 29 | const std::filesystem::path& relative_path); 30 | 31 | } // namespace zimtohrli 32 | 33 | #endif // CPP_ZIMT_TEST_FILE_PATHS_H_ 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We currently do not accept patches and contributions to this project. 
4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our Community Guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. 34 | -------------------------------------------------------------------------------- /cpp/zimt/visqol.h: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef CPP_ZIMT_VISQOL_H_ 16 | #define CPP_ZIMT_VISQOL_H_ 17 | 18 | #include 19 | 20 | #include "absl/status/statusor.h" 21 | #include "zimt/zimtohrli.h" 22 | 23 | namespace zimtohrli { 24 | 25 | class ViSQOL { 26 | public: 27 | ViSQOL(); 28 | ~ViSQOL(); 29 | absl::StatusOr MOS(Span reference, 30 | Span degraded, 31 | float sample_rate) const; 32 | 33 | private: 34 | std::filesystem::path model_path_; 35 | }; 36 | 37 | } // namespace zimtohrli 38 | 39 | #endif // CPP_ZIMT_VISQOL_H_ 40 | -------------------------------------------------------------------------------- /cpp/zimt/mos_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "zimt/mos.h" 16 | 17 | #include 18 | #include 19 | 20 | #include "gtest/gtest.h" 21 | 22 | namespace zimtohrli { 23 | 24 | namespace { 25 | 26 | TEST(MOS, MOSFromZimtohrli) { 27 | const std::vector zimt_scores = {0, 0.001, 0.01, 0.02, 0.03, 0.04}; 28 | const std::vector mos = { 29 | 5.0, 30 | 4.7487573623657227, 31 | 3.0908994674682617, 32 | 2.0929651260375977, 33 | 1.5713200569152832, 34 | 1.2986432313919067, 35 | }; 36 | for (size_t index = 0; index < zimt_scores.size(); ++index) { 37 | ASSERT_NEAR(MOSFromZimtohrli(zimt_scores[index]), mos[index], 1e-2); 38 | } 39 | } 40 | 41 | } // namespace 42 | 43 | } // namespace zimtohrli -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, Extension 2 | from setuptools.command.build_ext import build_ext 3 | 4 | class PyohrliBuildExt(build_ext): 5 | def build_extensions(self): 6 | if self.compiler.compiler_type == 'msvc': 7 | # MSVC is strict about designated initializers (and requires special 8 | # flag syntax anyway). 
9 | cpp_standard_flag = '/std:c++20' 10 | else: 11 | cpp_standard_flag = '-std=c++17' 12 | 13 | for ext in self.extensions: 14 | ext.extra_compile_args.append(cpp_standard_flag) 15 | 16 | super().build_extensions() 17 | 18 | setup( 19 | name='pyohrli', 20 | version='0.2.1', 21 | author='Martin Bruse, Jyrki Alakuijala', 22 | author_email='zond@google.com, jyrki@google.com', 23 | description='Psychoacoustic perceptual metric that quantifies the human observable difference in two audio signals in the proximity of just-noticeable-differences', 24 | long_description=open('README.md').read(), 25 | long_description_content_type='text/markdown', 26 | install_requires=[ 27 | 'numpy>=1.20.0', 28 | ], 29 | package_dir={'': 'cpp/zimt'}, 30 | packages=find_packages(where='cpp/zimt'), 31 | py_modules=['pyohrli'], 32 | ext_modules=[ 33 | Extension( 34 | name='_pyohrli', 35 | sources=['cpp/zimt/pyohrli.cc'], 36 | include_dirs=['cpp'], 37 | ), 38 | ], 39 | cmdclass={'build_ext': PyohrliBuildExt}, 40 | zip_safe=False, 41 | python_requires='>=3.10', 42 | ) 43 | -------------------------------------------------------------------------------- /cmake/tests.cmake: -------------------------------------------------------------------------------- 1 | add_executable(zimtohrli_test 2 | cpp/zimt/audio.cc 3 | cpp/zimt/audio_test.cc 4 | cpp/zimt/dtw_test.cc 5 | cpp/zimt/mos_test.cc 6 | cpp/zimt/nsim_test.cc 7 | cpp/zimt/test_file_paths.cc 8 | ) 9 | target_include_directories(zimtohrli_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cpp) 10 | target_link_libraries(zimtohrli_test gtest sndfile gmock_main benchmark absl::statusor absl::check PkgConfig::soxr) 11 | target_compile_definitions(zimtohrli_test PRIVATE CMAKE_CURRENT_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR}) 12 | gtest_discover_tests(zimtohrli_test) 13 | 14 | set(python3_VENV_DIR ${CMAKE_CURRENT_BINARY_DIR}/venv) 15 | set(python3_VENV ${python3_VENV_DIR}/bin/python3) 16 | configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cpp/zimt/pyohrli.py 
${CMAKE_CURRENT_BINARY_DIR}/pyohrli.py COPYONLY) 17 | configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cpp/zimt/pyohrli_test.py ${CMAKE_CURRENT_BINARY_DIR}/pyohrli_test.py COPYONLY) 18 | add_test(NAME zimtohrli_pyohrli_test 19 | COMMAND sh -c "${Python3_EXECUTABLE} -m venv ${python3_VENV_DIR} && 20 | ${python3_VENV} -m pip install jax jaxlib numpy scipy && 21 | ${python3_VENV} pyohrli_test.py" 22 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} 23 | ) 24 | 25 | add_test(NAME zimtohrli_go_test 26 | COMMAND go test ./... 27 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} 28 | ) 29 | 30 | add_executable(zimtohrli_benchmark 31 | cpp/zimt/dtw_test.cc 32 | cpp/zimt/nsim_test.cc 33 | cpp/zimt/distance_benchmark_test.cc 34 | ) 35 | target_include_directories(zimtohrli_benchmark PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cpp) 36 | target_link_libraries(zimtohrli_benchmark gtest gmock benchmark_main PkgConfig::soxr) 37 | -------------------------------------------------------------------------------- /cpp/zimt/mos.h: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #ifndef CPP_ZIMT_MOS_H_ 16 | #define CPP_ZIMT_MOS_H_ 17 | 18 | #include 19 | #include 20 | 21 | namespace zimtohrli { 22 | 23 | namespace { 24 | 25 | const std::array mos_params = {1.000e+00, -6.799e-09, 6.487e+01}; 26 | 27 | float sigmoid(float x) { 28 | return mos_params[0] / (mos_params[1] + std::exp(mos_params[2] * x)); 29 | } 30 | 31 | const float zero_crossing_reciprocal = 1.0 / sigmoid(0); 32 | 33 | } // namespace 34 | 35 | // Returns a _very_approximate_ mean opinion score based on the 36 | // provided Zimtohrli distance. 37 | // This is calibrated using default settings of v0.1.5, with a 38 | // minimum channel bandwidth (zimtohrli::Cam.minimum_bandwidth_hz) 39 | // of 5Hz and perceptual sample rate 40 | // (zimtohrli::Distance(..., perceptual_sample_rate, ...) of 100Hz. 41 | // The result is 1 + 4 * sigmoid(distance) / sigmoid(0): 5.0 at a distance 42 | // of 0, decreasing towards 1.0 as a non-negative distance grows. 43 | // Declared inline: the function has external linkage and is defined in a 44 | // header, so a non-inline definition fails to link as soon as two 45 | // translation units of the same binary include this file. 46 | inline float MOSFromZimtohrli(float zimtohrli_distance) { 47 | return 1.0 + 4.0 * sigmoid(zimtohrli_distance) * zero_crossing_reciprocal; 48 | } 49 | 50 | } // namespace zimtohrli 51 | 52 | #endif // CPP_ZIMT_MOS_H_ -------------------------------------------------------------------------------- /go/data/table.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package data 16 | 17 | import ( 18 | "bytes" 19 | "fmt" 20 | ) 21 | 22 | // Row is a row of table data.
23 | type Row []string 24 | 25 | // Table is table structured data that can render in straight columns in a terminal. 26 | type Table []Row 27 | 28 | // String returns a string representation of the table. Every column is padded 29 | // with blanks to one more than the byte length of its widest cell, and cells are 30 | // separated by '|'. A nil row renders as a horizontal rule of '-'. Rows with 31 | // fewer cells than the widest row are padded with empty cells. 32 | func (t Table) String() string { 33 | maxCells := 0 34 | for _, row := range t { 35 | if len(row) > maxCells { 36 | maxCells = len(row) 37 | } 38 | } 39 | maxCellWidths := make([]int, maxCells) 40 | for _, row := range t { 41 | for cellIndex, cell := range row { 42 | if len(cell) > maxCellWidths[cellIndex] { 43 | maxCellWidths[cellIndex] = len(cell) 44 | } 45 | } 46 | } 47 | out := &bytes.Buffer{} 48 | for _, row := range t { 49 | fmt.Fprint(out, "|") 50 | for cellIndex, maxCellWidth := range maxCellWidths { 51 | if row == nil { 52 | for i := 0; i < maxCellWidth+1; i++ { 53 | fmt.Fprint(out, "-") 54 | } 55 | } else { 56 | // Rows may be shorter than the widest row; render missing cells as 57 | // empty instead of indexing past the end of the slice. 58 | cell := "" 59 | if cellIndex < len(row) { 60 | cell = row[cellIndex] 61 | } 62 | fmt.Fprint(out, cell) 63 | for i := len(cell); i < maxCellWidth+1; i++ { 64 | fmt.Fprint(out, " ") 65 | } 66 | } 67 | fmt.Fprint(out, "|") 68 | } 69 | fmt.Fprint(out, "\n") 70 | } 71 | return out.String() 72 | } 73 | -------------------------------------------------------------------------------- /cpp/zimt/dtw_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "benchmark/benchmark.h" 20 | #include "gtest/gtest.h" 21 | #include "zimt/zimtohrli.h" 22 | 23 | namespace zimtohrli { 24 | 25 | namespace { 26 | 27 | TEST(DTW, DTWTest) { 28 | Spectrogram spec_a(10, 1, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); 29 | Spectrogram spec_b(10, 1, {0, 1, 2, 3, 3, 4, 5, 6, 8, 9}); 30 | 31 | const std::vector> got_dtw = DTW(spec_a, spec_b); 32 | const std::vector> expected_dtw = { 33 | {0, 0}, {1, 1}, {2, 2}, {3, 3}, {3, 4}, {4, 5}, 34 | {5, 6}, {6, 7}, {7, 8}, {8, 8}, {9, 9}}; 35 | EXPECT_EQ(got_dtw, expected_dtw); 36 | } 37 | 38 | void BM_DTW(benchmark::State& state) { 39 | Spectrogram spec_a(state.range(0), 1024); 40 | Spectrogram spec_b(state.range(0), 1024); 41 | for (size_t step_index = 0; step_index < spec_a.num_steps; ++step_index) { 42 | for (size_t channel_index = 0; channel_index < spec_a.num_dims; 43 | ++channel_index) { 44 | spec_a[step_index][channel_index] = 1.0; 45 | } 46 | } 47 | 48 | for (auto s : state) { 49 | DTW(spec_a, spec_b); 50 | } 51 | state.SetItemsProcessed(state.range(0) * state.iterations()); 52 | } 53 | BENCHMARK_RANGE(BM_DTW, 100, 5000); 54 | 55 | } // namespace 56 | 57 | } // namespace zimtohrli 58 | -------------------------------------------------------------------------------- /go/resource/pool.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package resource contains tools to handle resource pools. 16 | package resource 17 | 18 | import ( 19 | "sync" 20 | 21 | "github.com/google/zimtohrli/go/worker" 22 | ) 23 | 24 | // Closer can close. 25 | type Closer interface { 26 | Close() error 27 | } 28 | 29 | // Pool produces, keeps track of, and releases resources. 30 | type Pool[T Closer] struct { 31 | Create func() (T, error) 32 | 33 | lock sync.Mutex 34 | sequence int 35 | resources map[int]T 36 | } 37 | 38 | // Get returns a resource, which might create a new one. 39 | func (p *Pool[T]) Get() (T, error) { 40 | p.lock.Lock() 41 | defer p.lock.Unlock() 42 | if p.resources == nil { 43 | p.resources = map[int]T{} 44 | } 45 | for id, res := range p.resources { 46 | delete(p.resources, id) 47 | return res, nil 48 | } 49 | return p.Create() 50 | } 51 | 52 | // Return returns a resource to the pool. 53 | func (p *Pool[T]) Return(t T) { 54 | p.lock.Lock() 55 | defer p.lock.Unlock() 56 | if p.resources == nil { 57 | p.resources = map[int]T{} 58 | } 59 | p.sequence++ 60 | p.resources[p.sequence] = t 61 | } 62 | 63 | // Close closes every resource currently held by the pool and then empties the 64 | // pool, so closed resources are never handed out again by Get and a repeated 65 | // Close does not close them twice. Resources currently checked out via Get are 66 | // not touched. Errors from individual Close calls are collected and returned 67 | // together as a worker.Errors. 68 | func (p *Pool[T]) Close() error { 69 | p.lock.Lock() 70 | defer p.lock.Unlock() 71 | errs := worker.Errors{} 72 | for _, res := range p.resources { 73 | if err := res.Close(); err != nil { 74 | errs = append(errs, err) 75 | } 76 | } 77 | // Drop the closed resources; Get and Return lazily recreate the map. 78 | p.resources = nil 79 | if len(errs) > 0 { 80 | return errs 81 | } 82 | return nil 83 | } 84 | -------------------------------------------------------------------------------- /cpp/zimt/audio_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "zimt/audio.h" 16 | 17 | #include 18 | #include 19 | 20 | #include "absl/log/check.h" 21 | #include "absl/status/statusor.h" 22 | #include "gtest/gtest.h" 23 | #include "zimt/test_file_paths.h" 24 | 25 | namespace zimtohrli { 26 | 27 | namespace { 28 | 29 | TEST(AudioFile, LoadAudioFileTest) { 30 | const std::filesystem::path test_wav_path = 31 | GetTestFilePath("cpp/zimt/test.wav"); 32 | absl::StatusOr audio_file = 33 | AudioFile::Load(test_wav_path.string()); 34 | CHECK_OK(audio_file.status()); 35 | EXPECT_EQ(audio_file->Info().channels, 2); 36 | EXPECT_EQ(audio_file->Info().frames, 10); 37 | for (size_t frame_index = 0; frame_index < audio_file->Info().frames; 38 | ++frame_index) { 39 | for (size_t channel_index = 0; channel_index < audio_file->Info().channels; 40 | ++channel_index) { 41 | switch (channel_index) { 42 | case 0: 43 | switch (frame_index % 2) { 44 | case 0: 45 | EXPECT_EQ((*audio_file)[channel_index][frame_index], 0.5); 46 | break; 47 | case 1: 48 | EXPECT_EQ((*audio_file)[channel_index][frame_index], -0.5); 49 | break; 50 | } 51 | break; 52 | case 1: 53 | switch (frame_index % 2) { 54 | case 0: 55 | EXPECT_EQ((*audio_file)[channel_index][frame_index], 0.25); 56 | break; 57 | case 1: 58 | EXPECT_EQ((*audio_file)[channel_index][frame_index], -0.25); 59 | break; 60 | } 61 | break; 62 | } 63 | } 64 | } 65 | } 66 | 67 | } // namespace 68 | 69 | } // namespace zimtohrli 70 | -------------------------------------------------------------------------------- 
/cpp/zimt/distance_benchmark_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "benchmark/benchmark.h" 20 | #include "gtest/gtest.h" 21 | #include "zimt/zimtohrli.h" 22 | 23 | namespace zimtohrli { 24 | 25 | namespace { 26 | 27 | // Generate a simple sine wave at given frequency 28 | std::vector GenerateSineWave(float frequency, float duration_seconds, 29 | float sample_rate = 48000.0f) { 30 | const size_t num_samples = static_cast(duration_seconds * sample_rate); 31 | std::vector samples(num_samples); 32 | const float angular_freq = 2.0f * M_PI * frequency / sample_rate; 33 | 34 | for (size_t i = 0; i < num_samples; ++i) { 35 | samples[i] = 0.5f * std::sin(angular_freq * i); 36 | } 37 | 38 | return samples; 39 | } 40 | 41 | static void BM_ZimtohrliFullPipeline(benchmark::State& state) { 42 | // Generate two 2-second sine waves at slightly different frequencies 43 | const float duration = 2.0f; // 2 seconds 44 | const std::vector signal_a = GenerateSineWave(440.0f, duration); // A4 45 | const std::vector signal_b = GenerateSineWave(445.0f, duration); // Slightly sharp A4 46 | 47 | Zimtohrli zimtohrli; 48 | 49 | // Benchmark the full pipeline 50 | for (auto _ : state) { 51 | Spectrogram spec_a = 
zimtohrli.Analyze({signal_a.data(), signal_a.size()}); 52 | Spectrogram spec_b = zimtohrli.Analyze({signal_b.data(), signal_b.size()}); 53 | float distance = zimtohrli.Distance(spec_a, spec_b); 54 | benchmark::DoNotOptimize(distance); 55 | } 56 | 57 | state.SetItemsProcessed(state.iterations()); 58 | } 59 | BENCHMARK(BM_ZimtohrliFullPipeline); 60 | 61 | } // namespace 62 | 63 | } // namespace zimtohrli 64 | -------------------------------------------------------------------------------- /cpp/zimt/audio.h: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef CPP_ZIMT_AUDIO_FILE_H_ 16 | #define CPP_ZIMT_AUDIO_FILE_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "absl/status/status.h" 24 | #include "absl/status/statusor.h" 25 | #include "sndfile.h" 26 | #include "zimt/resample.h" 27 | #include "zimt/zimtohrli.h" 28 | 29 | namespace zimtohrli { 30 | 31 | // Returns a string representation of the libsndfile format ID. 32 | std::string GetFormatName(size_t format_id); 33 | 34 | // An audio file. 35 | class AudioFile { 36 | public: 37 | // Reads from the path and returns an audio file. 38 | static absl::StatusOr Load(const std::string& path); 39 | 40 | // Returns the path this audio file was loaded from. 
41 | const std::string& Path() const { return path_; } 42 | 43 | // Returns the metadata about this audio file. 44 | const SF_INFO& Info() const { return info_; } 45 | 46 | // Returns a channel of this audio file. 47 | Span operator[](size_t n) const { 48 | return Span(buffer_.data() + info_.frames * n, info_.frames); 49 | } 50 | 51 | // Returns a channel of this audio file. 52 | Span operator[](size_t n) { 53 | return Span(buffer_.data() + info_.frames * n, info_.frames); 54 | } 55 | 56 | std::vector AtRate(size_t channel_id, float want_rate) { 57 | return Resample(operator[](channel_id), 58 | static_cast(info_.samplerate), 59 | want_rate); 60 | } 61 | 62 | private: 63 | AudioFile(const std::string& path, const SF_INFO& info, 64 | std::vector buffer) 65 | : path_(path), info_(info), buffer_(buffer) {} 66 | std::string path_; 67 | SF_INFO info_; 68 | std::vector buffer_; 69 | }; 70 | 71 | } // namespace zimtohrli 72 | 73 | #endif // CPP_ZIMT_AUDIO_FILE_H_ 74 | -------------------------------------------------------------------------------- /cpp/zimt/pyohrli_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Zimtohrli Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for google3.third_party.zimtohrli.cpp.python.pyohrli.""" 15 | 16 | import numpy as np 17 | 18 | import unittest 19 | import pyohrli 20 | import functools 21 | 22 | 23 | def parameterize(*kwargs): 24 | def decorator(func): 25 | @functools.wraps(func) 26 | def call_with_parameters(self, **inner_kwargs): 27 | for kwarg in kwargs: 28 | func(self, **kwarg) 29 | 30 | return call_with_parameters 31 | 32 | return decorator 33 | 34 | 35 | class PyohrliTest(unittest.TestCase): 36 | 37 | def test_num_rotators(self): 38 | self.assertEqual(128, pyohrli.Pyohrli().num_rotators) 39 | 40 | def test_sample_rate(self): 41 | self.assertEqual(48000, pyohrli.Pyohrli().sample_rate) 42 | 43 | @parameterize( 44 | dict( 45 | a_hz=5000.0, 46 | b_hz=5000.0, 47 | distance=0, 48 | ), 49 | dict( 50 | a_hz=5000.0, 51 | b_hz=5010.0, 52 | distance=3.737211227416992e-05, 53 | ), 54 | dict( 55 | a_hz=5000.0, 56 | b_hz=10000.0, 57 | distance=0.3206554651260376, 58 | ), 59 | ) 60 | def test_distance(self, a_hz: float, b_hz: float, distance: float): 61 | sample_rate = 48000.0 62 | metric = pyohrli.Pyohrli() 63 | signal_a = np.sin(np.linspace(0.0, np.pi * 2 * a_hz, int(sample_rate))) 64 | signal_b = np.sin(np.linspace(0.0, np.pi * 2 * b_hz, int(sample_rate))) 65 | distance = metric.distance(signal_a, signal_b) 66 | self.assertLess(abs(distance - distance), 1e-3) 67 | 68 | @parameterize( 69 | dict(zimtohrli_distance=0.0, mos=5.0), 70 | dict(zimtohrli_distance=0.001, mos=4.748757362365723), 71 | dict(zimtohrli_distance=0.04, mos=1.2986432313919067), 72 | ) 73 | def test_mos_from_zimtohrli(self, zimtohrli_distance: float, mos: float): 74 | self.assertAlmostEqual( 75 | mos, pyohrli.mos_from_zimtohrli(zimtohrli_distance), places=3 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /cpp/zimt/resample.h: -------------------------------------------------------------------------------- 1 | // 
Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #ifndef CPP_ZIMT_RESAMPLE_H_ 16 | #define CPP_ZIMT_RESAMPLE_H_ 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | #include "absl/log/check.h" 23 | #include "soxr.h" 24 | #include "zimt/zimtohrli.h" 25 | 26 | namespace zimtohrli { 27 | 28 | template 29 | std::vector Convert(Span input) { 30 | if constexpr (std::is_same::value) { 31 | std::vector result(input.size); 32 | memcpy(result.data(), input.data, input.size * sizeof(I)); 33 | return result; 34 | } 35 | std::vector output(input.size); 36 | for (size_t sample_index = 0; sample_index < input.size; ++sample_index) { 37 | output[sample_index] = static_cast(input[sample_index]); 38 | } 39 | return output; 40 | } 41 | 42 | template 43 | inline constexpr soxr_datatype_t SoxrType() { 44 | if constexpr (std::is_same_v) { 45 | return SOXR_INT16_I; 46 | } else if constexpr (std::is_same_v) { 47 | return SOXR_INT32_I; 48 | } else if constexpr (std::is_same_v) { 49 | return SOXR_FLOAT32_I; 50 | } else if constexpr (std::is_same_v) { 51 | return SOXR_FLOAT64_I; 52 | } else { 53 | // This can't be `static_assert(false)`, as explained here: 54 | // https://devblogs.microsoft.com/oldnewthing/20200311-00/?p=103553 55 | static_assert(sizeof(T) < 0, "Unsupported type for resampling"); 56 | } 57 | } 58 | 59 | template 60 | std::vector Resample(Span samples, float 
in_sample_rate, 61 | float out_sample_rate) { 62 | if (in_sample_rate == out_sample_rate) { 63 | return Convert(samples); 64 | } 65 | 66 | std::vector result( 67 | static_cast(samples.size * out_sample_rate / in_sample_rate)); 68 | soxr_quality_spec_t quality = soxr_quality_spec(SOXR_VHQ, SOXR_LINEAR_PHASE); 69 | soxr_io_spec_t io_spec = soxr_io_spec(SoxrType(), SoxrType()); 70 | const soxr_error_t error = soxr_oneshot( 71 | in_sample_rate, out_sample_rate, 1, samples.data, samples.size, nullptr, 72 | result.data(), result.size(), nullptr, &io_spec, &quality, nullptr); 73 | assert(error == 0); 74 | return result; 75 | } 76 | 77 | } // namespace zimtohrli 78 | 79 | #endif // CPP_ZIMT_RESAMPLE_H_ 80 | -------------------------------------------------------------------------------- /cpp/zimt/nsim_test.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "benchmark/benchmark.h" 20 | #include "gmock/gmock.h" 21 | #include "gtest/gtest.h" 22 | #include "zimt/zimtohrli.h" 23 | 24 | namespace zimtohrli { 25 | 26 | namespace { 27 | 28 | void CheckEqual(Span span, std::vector expected) { 29 | for (size_t i = 0; i < span.size; i++) { 30 | EXPECT_EQ(span[i], expected[i]); 31 | } 32 | } 33 | 34 | TEST(NSIM, WindowMeanTest) { 35 | Spectrogram spec(5, 5, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 36 | 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); 37 | Spectrogram mean_3x3 = WindowMean( 38 | 5, 5, 3, 3, [&](size_t step, size_t dim) { return spec[step][dim]; }); 39 | CheckEqual(mean_3x3[0], {0.0, 1.0 / 9.0, 3.0 / 9.0, 6.0 / 9.0, 1.0}); 40 | CheckEqual(mean_3x3[1], {5.0 / 9.0, 12.0 / 9.0, 21.0 / 9.0, 3.0, 33.0 / 9.0}); 41 | CheckEqual(mean_3x3[2], {15.0 / 9.0, 33.0 / 9.0, 6.0, 7.0, 8.0}); 42 | CheckEqual(mean_3x3[3], {30.0 / 9.0, 7.0, 11.0, 12.0, 13.0}); 43 | CheckEqual(mean_3x3[4], {5.0, 93.0 / 9.0, 16.0, 17.0, 18.0}); 44 | } 45 | 46 | TEST(NSIM, NSIMTest) { 47 | Spectrogram spec_a(5, 5, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 48 | 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); 49 | Spectrogram spec_b(5, 5, {5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 50 | 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29}); 51 | EXPECT_THAT( 52 | NSIM(spec_a, spec_b, {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}}, 3, 3), 53 | 0.97899121); 54 | 55 | Spectrogram spec_c(5, 5, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 56 | 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); 57 | EXPECT_THAT( 58 | NSIM(spec_a, spec_c, {{0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4}}, 3, 3), 59 | 1.0); 60 | } 61 | 62 | void BM_NSIM(benchmark::State& state) { 63 | Spectrogram spec_a(state.range(0) * 100, 1000); 64 | std::vector> time_pairs(spec_a.num_steps); 65 | for (size_t i = 0; i < time_pairs.size(); i++) { 66 | time_pairs[i] = {i, i}; 67 | } 68 | for (auto s : state) { 69 | NSIM(spec_a, spec_a, 
time_pairs, 9, 9); 70 | } 71 | state.SetItemsProcessed(spec_a.size() * state.iterations()); 72 | } 73 | BENCHMARK_RANGE(BM_NSIM, 1, 60); 74 | 75 | } // namespace 76 | 77 | } // namespace zimtohrli 78 | -------------------------------------------------------------------------------- /cpp/zimt/audio.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 
14 | 15 | #include "zimt/audio.h" 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "absl/base/attributes.h" 24 | #include "absl/log/check.h" 25 | #include "absl/status/status.h" 26 | #include "sndfile.h" 27 | #include "zimt/zimtohrli.h" 28 | 29 | namespace zimtohrli { 30 | 31 | std::string GetFormatName(size_t format_id) { 32 | if (format_id & SF_FORMAT_WAV) { 33 | return "wav"; 34 | } else if (format_id & SF_FORMAT_AIFF) { 35 | return "aiff"; 36 | } else if (format_id & SF_FORMAT_AU) { 37 | return "au"; 38 | } else if (format_id & SF_FORMAT_RAW) { 39 | return "raw"; 40 | } else if (format_id & SF_FORMAT_PAF) { 41 | return "paf"; 42 | } else if (format_id & SF_FORMAT_SVX) { 43 | return "svx"; 44 | } else if (format_id & SF_FORMAT_NIST) { 45 | return "nist"; 46 | } else if (format_id & SF_FORMAT_VOC) { 47 | return "voc"; 48 | } else if (format_id & SF_FORMAT_IRCAM) { 49 | return "ircam"; 50 | } else if (format_id & SF_FORMAT_W64) { 51 | return "w64"; 52 | } else if (format_id & SF_FORMAT_MAT4) { 53 | return "mat4"; 54 | } else if (format_id & SF_FORMAT_MAT5) { 55 | return "mat5"; 56 | } else if (format_id & SF_FORMAT_PVF) { 57 | return "pvf"; 58 | } else if (format_id & SF_FORMAT_XI) { 59 | return "xi"; 60 | } else if (format_id & SF_FORMAT_HTK) { 61 | return "htk"; 62 | } else if (format_id & SF_FORMAT_SDS) { 63 | return "sds"; 64 | } else if (format_id & SF_FORMAT_AVR) { 65 | return "avr"; 66 | } else if (format_id & SF_FORMAT_WAVEX) { 67 | return "wavex"; 68 | } else if (format_id & SF_FORMAT_SD2) { 69 | return "sd2"; 70 | } else if (format_id & SF_FORMAT_FLAC) { 71 | return "flac"; 72 | } else if (format_id & SF_FORMAT_CAF) { 73 | return "caf"; 74 | } else if (format_id & SF_FORMAT_WVE) { 75 | return "wve"; 76 | } else if (format_id & SF_FORMAT_OGG) { 77 | return "ogg"; 78 | } else if (format_id & SF_FORMAT_MPC2K) { 79 | return "mpc2k"; 80 | } else if (format_id & SF_FORMAT_RF64) { 81 | return "rf64"; 82 | } else if 
(format_id & SF_FORMAT_MPEG) { 83 | return "mpeg"; 84 | } 85 | return "unknown"; 86 | } 87 | 88 | absl::StatusOr AudioFile::Load(const std::string& path) { 89 | SF_INFO info{}; 90 | SNDFILE* file = sf_open(path.c_str(), SFM_READ, &info); 91 | if (sf_error(file)) { 92 | return absl::InternalError(sf_strerror(file)); 93 | } 94 | std::vector samples(info.channels * info.frames); 95 | CHECK_EQ(sf_readf_float(file, samples.data(), info.frames), info.frames); 96 | std::vector buffer(info.frames * info.channels); 97 | for (size_t frame_index = 0; frame_index < info.frames; ++frame_index) { 98 | for (size_t channel_index = 0; channel_index < info.channels; 99 | ++channel_index) { 100 | buffer[channel_index * info.frames + frame_index] = 101 | samples[frame_index * info.channels + channel_index]; 102 | } 103 | } 104 | sf_close(file); 105 | return AudioFile(path, info, std::move(buffer)); 106 | } 107 | 108 | } // namespace zimtohrli 109 | -------------------------------------------------------------------------------- /go/bin/csv/csv.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // csv crates a Zimtohrli dataset based on a CSV file. 
16 | package main 17 | 18 | import ( 19 | "encoding/csv" 20 | "flag" 21 | "log" 22 | "os" 23 | "strconv" 24 | 25 | "github.com/google/zimtohrli/go/aio" 26 | "github.com/google/zimtohrli/go/data" 27 | "github.com/google/zimtohrli/go/progress" 28 | "github.com/google/zimtohrli/go/worker" 29 | ) 30 | 31 | func populate(dst string, src string, refHeader string, distHeader string, mosHeader string, workers int) error { 32 | study, err := data.OpenStudy(dst) 33 | if err != nil { 34 | return err 35 | } 36 | defer study.Close() 37 | 38 | srcFile, err := os.Open(src) 39 | if err != nil { 40 | return err 41 | } 42 | defer srcFile.Close() 43 | 44 | csvReader := csv.NewReader(srcFile) 45 | headers, err := csvReader.Read() 46 | if err != nil { 47 | return err 48 | } 49 | indices := map[string]int{} 50 | for index, header := range headers { 51 | indices[header] = index 52 | } 53 | 54 | references := []*data.Reference{} 55 | err = nil 56 | bar := progress.New("Downloading") 57 | pool := worker.Pool[any]{ 58 | Workers: workers, 59 | OnChange: bar.Update, 60 | } 61 | defer bar.Finish() 62 | for records, err := csvReader.Read(); err == nil; records, err = csvReader.Read() { 63 | refURL := records[indices[refHeader]] 64 | distURL := records[indices[distHeader]] 65 | mos, err := strconv.ParseFloat(records[indices[mosHeader]], 64) 66 | if err != nil { 67 | return err 68 | } 69 | ref := &data.Reference{ 70 | Name: refURL, 71 | } 72 | pool.Submit(func(func(any)) error { 73 | var err error 74 | ref.Path, err = aio.Fetch(ref.Name, dst) 75 | return err 76 | }) 77 | pool.Submit(func(func(any)) error { 78 | dist := &data.Distortion{ 79 | Name: distURL, 80 | Scores: map[data.ScoreType]float64{}, 81 | } 82 | pool.Submit(func(func(any)) error { 83 | var err error 84 | dist.Path, err = aio.Fetch(dist.Name, dst) 85 | return err 86 | }) 87 | dist.Scores[data.MOS] = mos 88 | ref.Distortions = append(ref.Distortions, dist) 89 | return nil 90 | }) 91 | references = append(references, ref) 92 | } 93 | 
if err := pool.Error(); err != nil { 94 | return err 95 | } 96 | if err := study.Put(references); err != nil { 97 | return err 98 | } 99 | return nil 100 | } 101 | 102 | func main() { 103 | destination := flag.String("dst", "", "Destination directory.") 104 | workers := flag.Int("workers", 1, "Number of workers downloading sounds.") 105 | source := flag.String("src", "", "Source CSV.") 106 | refHeader := flag.String("ref_header", "", "Header in the CSV with the URL to the reference file.") 107 | distHeader := flag.String("dist_header", "", "Header in the CSV with the URL to the distortion file.") 108 | mosHeader := flag.String("mos_header", "", "Header in the CSV with the MOS score.") 109 | flag.Parse() 110 | 111 | if *destination == "" || *source == "" || *refHeader == "" || *distHeader == "" || *mosHeader == "" { 112 | flag.Usage() 113 | os.Exit(1) 114 | } 115 | 116 | if err := populate(*destination, *source, *refHeader, *distHeader, *mosHeader, *workers); err != nil { 117 | log.Fatal(err) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /go/bin/coresvnet/coresvnet.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // coresvnet downloads the listening test at https://listening-test.coresv.net/results.htm. 
16 | package main 17 | 18 | import ( 19 | "flag" 20 | "fmt" 21 | "log" 22 | "net/http" 23 | "net/url" 24 | "os" 25 | "strconv" 26 | 27 | "github.com/PuerkitoBio/goquery" 28 | "github.com/google/zimtohrli/go/aio" 29 | "github.com/google/zimtohrli/go/data" 30 | "github.com/google/zimtohrli/go/progress" 31 | "github.com/google/zimtohrli/go/worker" 32 | ) 33 | 34 | func populate(dest string, workers int) error { 35 | study, err := data.OpenStudy(dest) 36 | if err != nil { 37 | return err 38 | } 39 | defer study.Close() 40 | 41 | rootURL, err := url.Parse("https://listening-test.coresv.net/results.htm") 42 | if err != nil { 43 | return err 44 | } 45 | res, err := http.Get(rootURL.String()) 46 | if err != nil { 47 | return err 48 | } 49 | defer res.Body.Close() 50 | if res.StatusCode != 200 { 51 | return fmt.Errorf("status code error: %d %s", res.StatusCode, res.Status) 52 | } 53 | doc, err := goquery.NewDocumentFromReader(res.Body) 54 | if err != nil { 55 | return err 56 | } 57 | resultTable := doc.Find("h2#list3:contains(\"All sets of tracks (5 sets, each 8 tracks)\")").Next().Find("table.table") 58 | 59 | references := []*data.Reference{} 60 | err = nil 61 | bar := progress.New("Downloading") 62 | pool := worker.Pool[any]{ 63 | Workers: workers, 64 | OnChange: bar.Update, 65 | } 66 | resultTable.Find("tbody > tr").Each(func(index int, sel *goquery.Selection) { 67 | columns := sel.Find("td") 68 | if columns.Length() != 8 { 69 | return 70 | } 71 | u, parseErr := rootURL.Parse(columns.Eq(0).Find("a").AttrOr("href", "")) 72 | if parseErr != nil { 73 | err = parseErr 74 | return 75 | } 76 | ref := &data.Reference{ 77 | Name: u.String(), 78 | } 79 | pool.Submit(func(func(any)) error { 80 | var err error 81 | ref.Path, err = aio.Fetch(ref.Name, dest) 82 | return err 83 | }) 84 | for columnIndex := 2; columnIndex < columns.Length(); columnIndex++ { 85 | u, parseErr := rootURL.Parse(columns.Eq(columnIndex).Find("a").AttrOr("href", "")) 86 | if parseErr != nil { 87 | err = 
parseErr 88 | return 89 | } 90 | dist := &data.Distortion{ 91 | Name: u.String(), 92 | Scores: map[data.ScoreType]float64{}, 93 | } 94 | pool.Submit(func(func(any)) error { 95 | var err error 96 | dist.Path, err = aio.Fetch(dist.Name, dest) 97 | return err 98 | }) 99 | score, parseErr := strconv.ParseFloat(columns.Eq(columnIndex).Text(), 64) 100 | if parseErr != nil { 101 | err = parseErr 102 | return 103 | } 104 | dist.Scores[data.MOS] = score 105 | ref.Distortions = append(ref.Distortions, dist) 106 | } 107 | references = append(references, ref) 108 | }) 109 | if err := pool.Error(); err != nil { 110 | return err 111 | } 112 | if err := study.Put(references); err != nil { 113 | return err 114 | } 115 | bar.Finish() 116 | return nil 117 | } 118 | 119 | func main() { 120 | destination := flag.String("dest", "", "Destination directory.") 121 | workers := flag.Int("workers", 1, "Number of workers downloading sounds.") 122 | flag.Parse() 123 | if *destination == "" { 124 | flag.Usage() 125 | os.Exit(1) 126 | } 127 | 128 | if err := populate(*destination, *workers); err != nil { 129 | log.Fatal(err) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /go/bin/perceptual_audio/perceptual_audio.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // perceptual_audio creates a study from https://github.com/pranaymanocha/PerceptualAudio/blob/master/dataset/README.md. 16 | // 17 | // Download and unpack the dataset ZIP, and provide the unpacked directory 18 | // as -source when running this binary. 19 | package main 20 | 21 | import ( 22 | "bufio" 23 | "flag" 24 | "fmt" 25 | "log" 26 | "net/http" 27 | "net/url" 28 | "os" 29 | "path/filepath" 30 | "runtime" 31 | "strconv" 32 | "strings" 33 | 34 | "github.com/google/zimtohrli/go/aio" 35 | "github.com/google/zimtohrli/go/data" 36 | "github.com/google/zimtohrli/go/progress" 37 | "github.com/google/zimtohrli/go/worker" 38 | ) 39 | 40 | func populate(source string, dest string, workers int) error { 41 | study, err := data.OpenStudy(dest) 42 | if err != nil { 43 | return err 44 | } 45 | defer study.Close() 46 | 47 | csvURL, err := url.Parse("https://raw.githubusercontent.com/pranaymanocha/PerceptualAudio/master/dataset/dataset_combined.txt") 48 | if err != nil { 49 | return err 50 | } 51 | res, err := http.Get(csvURL.String()) 52 | if err != nil { 53 | return err 54 | } 55 | defer res.Body.Close() 56 | if res.StatusCode != 200 { 57 | return fmt.Errorf("status code error: %d %s", res.StatusCode, res.Status) 58 | } 59 | 60 | lineReader := bufio.NewReader(res.Body) 61 | err = nil 62 | bar := progress.New("Transcoding") 63 | pool := worker.Pool[*data.Reference]{ 64 | Workers: workers, 65 | OnChange: bar.Update, 66 | } 67 | line := "" 68 | lineIndex := 0 69 | for line, err = lineReader.ReadString('\n'); err == nil; line, err = lineReader.ReadString('\n') { 70 | fields := strings.Split(strings.TrimSpace(line), "\t") 71 | jnd, err := strconv.ParseFloat(fields[2], 64) 72 | if err != nil { 73 | return err 74 | } 75 | refIndex := lineIndex 76 | pool.Submit(func(f func(*data.Reference)) error { 77 | ref := &data.Reference{ 78 | Name: fmt.Sprintf("ref-%v", 
refIndex), 79 | } 80 | var err error 81 | ref.Path, err = aio.Fetch(filepath.Join(source, fields[0]), dest) 82 | if err != nil { 83 | return fmt.Errorf("unable to fetch %q", fields[0]) 84 | } 85 | dist := &data.Distortion{ 86 | Name: fmt.Sprintf("dist-%v", refIndex), 87 | Scores: map[data.ScoreType]float64{ 88 | data.JND: jnd, 89 | }, 90 | } 91 | dist.Path, err = aio.Fetch(filepath.Join(source, fields[1]), dest) 92 | if err != nil { 93 | return fmt.Errorf("unable to fetch %q", fields[1]) 94 | } 95 | ref.Distortions = append(ref.Distortions, dist) 96 | f(ref) 97 | return nil 98 | }) 99 | lineIndex++ 100 | } 101 | if err := pool.Error(); err != nil { 102 | log.Println(err.Error()) 103 | } 104 | bar.Finish() 105 | refs := []*data.Reference{} 106 | for ref := range pool.Results() { 107 | refs = append(refs, ref) 108 | } 109 | if err := study.Put(refs); err != nil { 110 | return err 111 | } 112 | return nil 113 | } 114 | 115 | func main() { 116 | source := flag.String("source", "", "Directory containing the unpacked http://percepaudio.cs.princeton.edu/icassp2020_perceptual/audio_perception.zip.") 117 | destination := flag.String("dest", "", "Destination directory.") 118 | workers := flag.Int("workers", runtime.NumCPU(), "Number of workers transcoding sounds.") 119 | flag.Parse() 120 | if *source == "" || *destination == "" { 121 | flag.Usage() 122 | os.Exit(1) 123 | } 124 | 125 | if err := populate(*source, *destination, *workers); err != nil { 126 | log.Fatal(err) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.27) 2 | 3 | project(Zimtohrli) 4 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 5 | set(CMAKE_CXX_CLANG_TIDY clang-tidy) 6 | add_compile_options(-fPIC) 7 | 8 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=native -O3") 9 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -O3") 
10 | 11 | include(cmake/deps.cmake) 12 | 13 | enable_testing() 14 | include(GoogleTest) 15 | 16 | add_library(zimtohrli_visqol_adapter STATIC 17 | cpp/zimt/visqol_model.h 18 | cpp/zimt/visqol_model.cc 19 | cpp/zimt/visqol.h 20 | cpp/zimt/visqol.cc 21 | cpp/zimt/resample.h 22 | cpp/zimt/zimtohrli.h 23 | ) 24 | target_include_directories(zimtohrli_visqol_adapter PUBLIC cpp) 25 | target_link_libraries(zimtohrli_visqol_adapter PRIVATE visqol PkgConfig::soxr) 26 | target_link_libraries(zimtohrli_visqol_adapter PUBLIC absl::span) 27 | # Use selective fast-math flags only for zimtohrli code to avoid protobuf warnings 28 | target_compile_options(zimtohrli_visqol_adapter PRIVATE 29 | -freciprocal-math -fno-signed-zeros -fno-math-errno) # NOTE(review): unlike zimtohrli_goohrli_glue and zimtohrli_compare, this omits -fassociative-math; confirm that is intentional. 30 | 31 | find_package(Python3 COMPONENTS Interpreter Development) 32 | add_library(zimtohrli_pyohrli SHARED 33 | cpp/zimt/zimtohrli.h 34 | cpp/zimt/pyohrli.cc 35 | ) 36 | target_include_directories(zimtohrli_pyohrli PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cpp) 37 | set_target_properties(zimtohrli_pyohrli PROPERTIES 38 | PREFIX "" 39 | OUTPUT_NAME _pyohrli.so 40 | SUFFIX "" 41 | ) 42 | target_link_libraries(zimtohrli_pyohrli Python3::Python) 43 | 44 | add_library(zimtohrli_goohrli_glue STATIC 45 | cpp/zimt/goohrli.cc 46 | ) 47 | target_include_directories(zimtohrli_goohrli_glue PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cpp ${CMAKE_CURRENT_SOURCE_DIR}/go/goohrli ${CMAKE_CURRENT_SOURCE_DIR}/go/gosqol) # NOTE(review): go/gosqol does not appear in the repository tree; this include directory may be stale. 48 | target_link_libraries(zimtohrli_goohrli_glue zimtohrli_visqol_adapter PkgConfig::soxr) 49 | # Use selective fast-math flags only for zimtohrli code 50 | target_compile_options(zimtohrli_goohrli_glue PRIVATE 51 | -fassociative-math -freciprocal-math -fno-signed-zeros -fno-math-errno) 52 | # Bundle the glue code and its static dependencies into one relocatable object (goohrli.o) and archive (goohrli.a), copied into go/goohrli for cgo. Skipped on Windows. 53 | set(zimtohrli_goohrli_object ${CMAKE_CURRENT_BINARY_DIR}/goohrli.o) 54 | set(zimtohrli_goohrli_archive_build ${CMAKE_CURRENT_BINARY_DIR}/goohrli.a) 55 | set(zimtohrli_goohrli_archive ${CMAKE_CURRENT_SOURCE_DIR}/go/goohrli/goohrli.a) 56 | 57 | if (NOT WIN32) 58 |
add_custom_command( 59 | OUTPUT ${zimtohrli_goohrli_archive_build} 60 | COMMAND ${CMAKE_LINKER} -r 61 | $$\(find ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/zimtohrli_goohrli_glue.dir/ -name \"*.o\"\) 62 | $$\(find ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/zimtohrli_visqol_adapter.dir/ -name \"*.o\"\) 63 | $$\(find ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/visqol.dir/ -name \"*.o\"\) 64 | $$\(find ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/visqol_proto.dir/ -name \"*.o\"\) 65 | $$\(find ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/libsvm.dir/ -name \"*.o\"\) 66 | $$\(find ${CMAKE_CURRENT_BINARY_DIR}/CMakeFiles/pffft.dir/ -name \"*.o\"\) 67 | $$\(find ${libsndfile_BINARY_DIR}/CMakeFiles/sndfile.dir/ -name \"*.o\"\) 68 | $$\(find ${protobuf_BINARY_DIR} -name \"*.o\" ! -ipath \"*/google/protobuf/compiler/main.cc.o\"\) 69 | -o ${zimtohrli_goohrli_object} 70 | COMMAND ${CMAKE_AR} rcs ${zimtohrli_goohrli_archive_build} ${zimtohrli_goohrli_object} 71 | COMMAND ${CMAKE_COMMAND} -E copy ${zimtohrli_goohrli_archive_build} ${zimtohrli_goohrli_archive} 72 | DEPENDS zimtohrli_goohrli_glue zimtohrli_visqol_adapter # NOTE(review): visqol, visqol_proto, libsvm, pffft, sndfile and protobuf objects are also collected via find, but only these two targets are listed; confirm their transitive link dependencies guarantee those objects are built first. 73 | ) 74 | add_custom_target(zimtohrli_goohrli ALL DEPENDS ${zimtohrli_goohrli_archive_build}) 75 | endif() 76 | 77 | add_executable(zimtohrli_compare 78 | cpp/zimt/audio.cc 79 | cpp/zimt/compare.cc 80 | ) 81 | target_include_directories(zimtohrli_compare PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/cpp) 82 | target_link_libraries(zimtohrli_compare sndfile absl::statusor absl::check absl::flags_parse PkgConfig::soxr) 83 | set_target_properties(zimtohrli_compare PROPERTIES 84 | PREFIX "" 85 | OUTPUT_NAME compare 86 | SUFFIX "" 87 | ) 88 | # Use selective fast-math flags for compare tool 89 | target_compile_options(zimtohrli_compare PRIVATE 90 | -fassociative-math -freciprocal-math -fno-signed-zeros -fno-math-errno) 91 | 92 | option(BUILD_ZIMTOHRLI_TESTS "Build Zimtohrli test binaries."
ON) 93 | if (BUILD_ZIMTOHRLI_TESTS) 94 | include(cmake/tests.cmake) 95 | endif() 96 | -------------------------------------------------------------------------------- /go/goohrli/goohrli.h: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // 16 | // This file contains a C-compatible API layer simple to integrate in Go via 17 | // cgo. 18 | // 19 | // All C++ classes used are aliased as void*, and all plumbing is done in 20 | // C-style function calls. 21 | // 22 | 23 | #ifndef GO_LIB_GOOHRLI_H_ 24 | #define GO_LIB_GOOHRLI_H_ 25 | 26 | #ifdef __cplusplus 27 | extern "C" { 28 | #endif 29 | // NOTE(review): the NUM_LOUDNESS_* macros below are not referenced by any declaration in this header; confirm they are still used. 30 | #define NUM_LOUDNESS_A_F_PARAMS 10 31 | #define NUM_LOUDNESS_L_U_PARAMS 16 32 | #define NUM_LOUDNESS_T_F_PARAMS 13 33 | // Parameterless functions are declared with (void) so that C callers (cgo) see a real prototype instead of an old-style declaration with unspecified arguments. 34 | // The supported sample rate for Zimtohrli. 35 | float SampleRate(void); 36 | 37 | // The number of rotators used by Zimtohrli, i.e. the number 38 | // of dimensions in the spectrograms. 39 | int NumRotators(void); 40 | 41 | // Contains the parameters controlling Zimtohrli behavior. 42 | typedef struct ZimtohrliParameters { 43 | float PerceptualSampleRate; 44 | float FullScaleSineDB; 45 | int NSIMStepWindow; 46 | int NSIMChannelWindow; 47 | } ZimtohrliParameters; 48 | 49 | // Returns the default parameters.
50 | ZimtohrliParameters DefaultZimtohrliParameters(void); 51 | 52 | // void* representation of zimtohrli::Zimtohrli. 53 | typedef void* Zimtohrli; 54 | 55 | // Returns a zimtohrli::Zimtohrli for the given parameters. 56 | Zimtohrli CreateZimtohrli(ZimtohrliParameters params); 57 | 58 | // Deletes a zimtohrli::Zimtohrli. 59 | void FreeZimtohrli(Zimtohrli z); 60 | 61 | // Returns the number of steps a spectrogram of the given number 62 | // of samples requires. 63 | int SpectrogramSteps(Zimtohrli zimtohrli, int samples); 64 | 65 | // Represents a zimtohrli::Span. 66 | typedef struct { 67 | float* data; 68 | int size; 69 | } GoSpan; 70 | 71 | // Represents a zimtohrli::Spectrogram. 72 | typedef struct { 73 | float* values; 74 | int steps; 75 | int dims; 76 | } GoSpectrogram; 77 | 78 | // Computes the spectrogram of signal with the provided zimtohrli::Zimtohrli and writes it into spec->values, a caller-owned buffer 79 | // described by spec->steps and spec->dims (presumably steps * dims floats; see SpectrogramSteps and NumRotators). 80 | void Analyze(Zimtohrli zimtohrli, const GoSpan* signal, GoSpectrogram* spec); 81 | 82 | // Returns an approximate mean opinion score based on the 83 | // provided Zimtohrli distance. 84 | // This is calibrated using default settings of v0.1.5, with a 85 | // minimum channel bandwidth (zimtohrli::Cam.minimum_bandwidth_hz) 86 | // of 5Hz and perceptual sample rate 87 | // (zimtohrli::Distance(..., perceptual_sample_rate, ...)) of 100Hz. 88 | float MOSFromZimtohrli(float zimtohrli_distance); 89 | 90 | // Returns the Zimtohrli distance between two analyses using the provided 91 | // zimtohrli::Zimtohrli. NOTE(review): b is declared non-const while a is const; confirm whether Distance needs to write through b. 92 | float Distance(Zimtohrli zimtohrli, const GoSpectrogram* a, GoSpectrogram* b); 93 | 94 | // Sets the parameters. 95 | // 96 | // Sample rate, frequency resolution, and filter parameters can only be set when 97 | // an instance is created and will be ignored in this function. 98 | void SetZimtohrliParameters(Zimtohrli zimtohrli, 99 | ZimtohrliParameters parameters); 100 | 101 | // Returns the parameters.
102 | ZimtohrliParameters GetZimtohrliParameters(Zimtohrli zimtohrli); 103 | 104 | // void* representation of zimtohrli::ViSQOL. 105 | typedef void* ViSQOL; 106 | 107 | // Returns a zimtohrli::ViSQOL. 108 | ViSQOL CreateViSQOL(void); 109 | 110 | // Deletes a zimtohrli::ViSQOL. 111 | void FreeViSQOL(ViSQOL v); 112 | 113 | // MOSResult contains a MOS value and a status code: Status is 0 on success, otherwise the numeric absl::StatusCode of the failure, in which case MOS is 0. 114 | typedef struct { 115 | float MOS; 116 | int Status; 117 | } MOSResult; 118 | 119 | // ViSQOLMOS returns the ViSQOL MOS between reference and distorted, both sampled at sample_rate. 120 | MOSResult ViSQOLMOS(ViSQOL v, float sample_rate, const float* reference, 121 | int reference_size, const float* distorted, 122 | int distorted_size); 123 | 124 | #ifdef __cplusplus 125 | } 126 | #endif 127 | 128 | #endif // GO_LIB_GOOHRLI_H_ 129 | -------------------------------------------------------------------------------- /cpp/zimt/pyohrli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Zimtohrli Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Pyohrli is a Zimtohrli wrapper in Python.""" 15 | 16 | import numpy as np 17 | import numpy.typing as npt 18 | 19 | import _pyohrli 20 | 21 | def mos_from_signals(signal_a: npt.ArrayLike, signal_b: npt.ArrayLike) -> float: 22 | """Returns the Mean Opinion Score (MOS) between two audio signals using Zimtohrli. Both signals must be (num_samples,)-shaped 48kHz [-1, 1] float arrays, as for Pyohrli.distance.""" 23 | distance = Pyohrli().distance(signal_a, signal_b) 24 | mos = mos_from_zimtohrli(distance) 25 | return mos 26 | 27 | def mos_from_zimtohrli(zimtohrli_distance: float) -> float: 28 | """Returns an approximate mean opinion score based on the provided Zimtohrli distance.""" 29 | return _pyohrli.MOSFromZimtohrli(zimtohrli_distance) 30 | 31 | 32 | class Spectrogram: 33 | """Wrapper around C++ zimtohrli::Spectrogram.""" 34 | # NOTE(review): not referenced elsewhere in this module; Pyohrli.analyze returns a plain numpy array instead. 35 | _cc_analysis: _pyohrli.Spectrogram 36 | 37 | 38 | class Pyohrli: 39 | """Wrapper around C++ zimtohrli::Zimtohrli.""" 40 | 41 | _cc_pyohrli: _pyohrli.Pyohrli 42 | 43 | def __init__(self): 44 | """Initializes the instance.""" 45 | self._cc_pyohrli = _pyohrli.Pyohrli() 46 | 47 | def distance(self, signal_a: npt.ArrayLike, signal_b: npt.ArrayLike) -> float: 48 | """Computes the distance between two signals. 49 | 50 | The signals must be (num_samples,)-shaped 48kHz [-1, 1] float arrays. 51 | 52 | Args: 53 | signal_a: A signal to compare. 54 | signal_b: Another signal to compare with. 55 | 56 | Returns: 57 | The Zimtohrli distance between the signals. 58 | """ 59 | return self._cc_pyohrli.distance( 60 | np.asarray(signal_a).astype(np.float32).ravel().data, 61 | np.asarray(signal_b).astype(np.float32).ravel().data, 62 | ) 63 | 64 | def analyze(self, signal: npt.ArrayLike) -> np.ndarray: 65 | """Computes the Zimtohrli spectrogram of a signal. 66 | 67 | Args: 68 | signal: A (num_samples,)-shaped 48kHz [-1, 1] float array. 69 | 70 | Returns: 71 | A (num_steps, num_dims)-shaped float array with the Zimtohrli 72 | spectrogram of the signal.
73 | """ 74 | bts = self._cc_pyohrli.analyze( 75 | np.asarray(signal).astype(np.float32).ravel().data 76 | ) 77 | result = np.frombuffer(bts, dtype=np.float32) 78 | num_rotators = self._cc_pyohrli.num_rotators() 79 | return result.reshape((result.shape[0] // num_rotators, num_rotators)) # rows are time steps, columns are rotators 80 | 81 | @property 82 | def sample_rate(self) -> float: 83 | """Expected sample rate of analyzed audio.""" 84 | return self._cc_pyohrli.sample_rate() 85 | 86 | @property 87 | def num_rotators(self) -> int: 88 | """Number of rotators (spectrogram dimensions).""" 89 | return int(self._cc_pyohrli.num_rotators()) 90 | 91 | @property 92 | def full_scale_sine_db(self) -> float: 93 | """Reference intensity for an amplitude 1.0 sine wave at 1kHz. 94 | 95 | Defaults to 80dB SPL. 96 | """ 97 | return self._cc_pyohrli.get_full_scale_sine_db() 98 | 99 | @full_scale_sine_db.setter 100 | def full_scale_sine_db(self, value: float): 101 | self._cc_pyohrli.set_full_scale_sine_db(value) 102 | # NOTE(review): goohrli.h exposes the NSIM windows as int; confirm whether the float annotations below should be int. 103 | @property 104 | def nsim_step_window(self) -> float: 105 | """Order of the window in perceptual_sample_rate time steps when computing the NSIM.""" 106 | return self._cc_pyohrli.get_nsim_step_window() 107 | 108 | @nsim_step_window.setter 109 | def nsim_step_window(self, value: float): 110 | self._cc_pyohrli.set_nsim_step_window(value) 111 | 112 | @property 113 | def nsim_channel_window(self) -> float: 114 | """Order of the window in channels when computing the NSIM.""" 115 | return self._cc_pyohrli.get_nsim_channel_window() 116 | 117 | @nsim_channel_window.setter 118 | def nsim_channel_window(self, value: float): 119 | self._cc_pyohrli.set_nsim_channel_window(value) 120 | -------------------------------------------------------------------------------- /cpp/zimt/goohrli.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved.
2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include "goohrli.h" 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "absl/log/check.h" 23 | #include "absl/status/statusor.h" 24 | #include "absl/types/span.h" 25 | #include "zimt/mos.h" 26 | #include "zimt/visqol.h" 27 | #include "zimt/zimtohrli.h" 28 | 29 | float SampleRate() { return zimtohrli::kSampleRate; } 30 | 31 | int NumRotators() { return zimtohrli::kNumRotators; } 32 | 33 | float MOSFromZimtohrli(float zimtohrli_distance) { 34 | return zimtohrli::MOSFromZimtohrli(zimtohrli_distance); 35 | } 36 | 37 | Zimtohrli CreateZimtohrli(ZimtohrliParameters params) { 38 | zimtohrli::Zimtohrli* result = new zimtohrli::Zimtohrli{}; 39 | SetZimtohrliParameters(result, params); 40 | return result; 41 | } 42 | 43 | void FreeZimtohrli(Zimtohrli zimtohrli) { 44 | delete static_cast(zimtohrli); 45 | } 46 | 47 | int SpectrogramSteps(Zimtohrli zimtohrli, int samples) { 48 | zimtohrli::Zimtohrli* z = static_cast(zimtohrli); 49 | return z->SpectrogramSteps(static_cast(samples)); 50 | } 51 | // result->values is caller-owned memory: it is wrapped in a zimtohrli::Spectrogram only for the duration of the call and then release()d so the Spectrogram does not free it. 52 | void Analyze(Zimtohrli zimtohrli, const GoSpan* signal, GoSpectrogram* result) { 53 | zimtohrli::Zimtohrli* z = static_cast(zimtohrli); 54 | zimtohrli::Spectrogram spec = 55 | zimtohrli::Spectrogram(result->steps, result->dims, result->values); 56 | z->Analyze(zimtohrli::Span(signal->data, signal->size), spec); 57 | spec.values.release(); 58 | } 59 | // Same borrowing scheme as Analyze: a->values and b->values stay owned by the caller and are release()d after use.
60 | float Distance(Zimtohrli zimtohrli, const GoSpectrogram* a, GoSpectrogram* b) { 61 | zimtohrli::Zimtohrli* z = static_cast(zimtohrli); 62 | zimtohrli::Spectrogram spec_a = 63 | zimtohrli::Spectrogram(a->steps, a->dims, a->values); 64 | zimtohrli::Spectrogram spec_b = 65 | zimtohrli::Spectrogram(b->steps, b->dims, b->values); 66 | const float result = z->Distance(spec_a, spec_b); 67 | spec_a.values.release(); 68 | spec_b.values.release(); 69 | return result; 70 | } 71 | 72 | ZimtohrliParameters GetZimtohrliParameters(const Zimtohrli zimtohrli) { 73 | zimtohrli::Zimtohrli* z = static_cast(zimtohrli); 74 | ZimtohrliParameters result; 75 | result.PerceptualSampleRate = z->perceptual_sample_rate; 76 | result.FullScaleSineDB = z->full_scale_sine_db; 77 | result.NSIMStepWindow = z->nsim_step_window; 78 | result.NSIMChannelWindow = z->nsim_channel_window; 79 | return result; 80 | } 81 | 82 | void SetZimtohrliParameters(Zimtohrli zimtohrli, 83 | const ZimtohrliParameters parameters) { 84 | zimtohrli::Zimtohrli* z = static_cast(zimtohrli); 85 | z->perceptual_sample_rate = parameters.PerceptualSampleRate; 86 | z->full_scale_sine_db = parameters.FullScaleSineDB; 87 | z->nsim_step_window = parameters.NSIMStepWindow; 88 | z->nsim_channel_window = parameters.NSIMChannelWindow; 89 | } 90 | 91 | ZimtohrliParameters DefaultZimtohrliParameters() { 92 | zimtohrli::Zimtohrli default_zimtohrli{}; 93 | return GetZimtohrliParameters(&default_zimtohrli); 94 | } 95 | 96 | ViSQOL CreateViSQOL() { return new zimtohrli::ViSQOL(); } 97 | // Named cast, like FreeZimtohrli; a C-style cast would also silently accept unrelated conversions. 98 | void FreeViSQOL(ViSQOL v) { delete static_cast<zimtohrli::ViSQOL*>(v); } 99 | // On success returns the MOS with Status 0; otherwise MOS 0 and the numeric absl status code. 100 | MOSResult ViSQOLMOS(const ViSQOL v, float sample_rate, const float* reference, 101 | int reference_size, const float* distorted, 102 | int distorted_size) { 103 | const zimtohrli::ViSQOL* visqol = static_cast(v); 104 | const absl::StatusOr result = visqol->MOS( 105 | zimtohrli::Span(reference, reference_size), 106 | zimtohrli::Span(distorted, distorted_size), sample_rate); 107 |
if (result.ok()) { 108 | return MOSResult{.MOS = result.value(), .Status = 0}; 109 | } else { 110 | return MOSResult{.MOS = 0.0, 111 | .Status = static_cast(result.status().code())}; 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /go/worker/pool.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package worker contains functionality to parallelize tasks with a pool of workers. 16 | package worker 17 | 18 | import ( 19 | "bytes" 20 | "fmt" 21 | "log" 22 | "sync" 23 | "sync/atomic" 24 | ) 25 | 26 | // ChangeHandler is updated when the worker pool increases the number of submitted, completed, or error jobs. 27 | type ChangeHandler func(submitted, completed, errors int) 28 | 29 | // ErrorHandler is updated when the worker pool encounters an error. The encountered error will be replaced with the return value of the handler. 30 | type ErrorHandler func(error) error 31 | 32 | // Pool is a pool of workers.
33 | type Pool[T any] struct { 34 | Workers int 35 | OnChange ChangeHandler 36 | OnError ErrorHandler 37 | FailFast bool 38 | // Internal state below is initialized lazily by init(). 39 | startOnce sync.Once 40 | 41 | jobs chan func(func(T)) error 42 | jobsWaitGroup sync.WaitGroup 43 | results chan T 44 | resultsWaitGroup sync.WaitGroup 45 | errors chan error 46 | errorsWaitGroup sync.WaitGroup 47 | // The counters below are only accessed via sync/atomic. 48 | submittedJobs uint32 49 | completedJobs uint32 50 | errorJobs uint32 51 | } 52 | // init creates the channels and starts max(Workers, 1) worker goroutines, exactly once. 53 | func (p *Pool[T]) init() { 54 | p.startOnce.Do(func() { 55 | p.jobs = make(chan func(func(T)) error) 56 | p.results = make(chan T) 57 | p.errors = make(chan error) 58 | for i := 0; i < max(p.Workers, 1); i++ { // At least one worker, otherwise Error() would wait forever for jobs nobody receives. 59 | go func() { 60 | for job := range p.jobs { 61 | if err := job(func(t T) { 62 | p.resultsWaitGroup.Add(1) 63 | go func() { 64 | p.results <- t 65 | p.resultsWaitGroup.Done() 66 | }() 67 | }); err != nil { 68 | if err = p.err(err); err != nil { 69 | if p.FailFast { 70 | log.Fatal(err) 71 | } 72 | p.errorsWaitGroup.Add(1) 73 | go func() { 74 | p.errors <- err 75 | p.errorsWaitGroup.Done() 76 | }() 77 | atomic.AddUint32(&p.errorJobs, 1) 78 | p.change() 79 | } 80 | } 81 | atomic.AddUint32(&p.completedJobs, 1) 82 | p.change() 83 | p.jobsWaitGroup.Done() // Last, so this job's OnChange calls happen before Error() returns. 84 | } 85 | }() 86 | } 87 | }) 88 | } 89 | 90 | func (p *Pool[T]) err(err error) error { 91 | if p.OnError != nil { 92 | return p.OnError(err) 93 | } 94 | return err 95 | } 96 | 97 | func (p *Pool[T]) change() { 98 | if p.OnChange != nil { 99 | p.OnChange(int(atomic.LoadUint32(&p.submittedJobs)), int(atomic.LoadUint32(&p.completedJobs)), int(atomic.LoadUint32(&p.errorJobs))) 100 | } 101 | } 102 | 103 | // Submit submits a job to the pool. It does not block and always returns nil. 104 | func (p *Pool[T]) Submit(job func(func(T)) error) error { 105 | p.init() 106 | 107 | p.jobsWaitGroup.Add(1) 108 | atomic.AddUint32(&p.submittedJobs, 1) 109 | p.change() 110 | 111 | go func() { 112 | p.jobs <- job 113 | }() 114 | return nil 115 | } 116 | 117 | // Errors is a slice of errors.
118 | type Errors []error 119 | 120 | func (e Errors) Error() string { 121 | buf := &bytes.Buffer{} 122 | for _, err := range e { 123 | fmt.Fprintln(buf, err.Error()) 124 | } 125 | return buf.String() 126 | } 127 | 128 | // Error waits for all submitted jobs to finish, closes the submission channel, and returns the errors 129 | // produced by the jobs (as Errors), or nil if there were none. 130 | // 131 | // Must be called exactly once, after all jobs are added: a second call would close the already-closed job channel and panic. 132 | func (p *Pool[T]) Error() error { 133 | p.init() 134 | 135 | p.jobsWaitGroup.Wait() 136 | close(p.jobs) 137 | go func() { 138 | p.errorsWaitGroup.Wait() 139 | close(p.errors) 140 | }() 141 | result := Errors{} 142 | for err := range p.errors { 143 | result = append(result, err) 144 | } 145 | if len(result) > 0 { 146 | return result 147 | } 148 | return nil 149 | } 150 | 151 | // Results returns all results produced. The result channel will close once all results are processed. 152 | // 153 | // Must be called if any jobs might have produced results. 154 | // 155 | // Error() must be called before Results(). 156 | func (p *Pool[T]) Results() <-chan T { 157 | go func() { 158 | p.resultsWaitGroup.Wait() 159 | close(p.results) 160 | }() 161 | return p.results 162 | } 163 | -------------------------------------------------------------------------------- /go/aio/aio.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package aio handles audio in/out. 16 | package aio 17 | 18 | import ( 19 | "bytes" 20 | "encoding/binary" 21 | "fmt" 22 | "os" 23 | "os/exec" 24 | "path/filepath" 25 | "strings" 26 | 27 | "github.com/google/zimtohrli/go/audio" 28 | ) 29 | 30 | // Fetch calls Recode if path ends with .wav, otherwise Copy. 31 | func Fetch(path string, dir string) (string, error) { 32 | if strings.ToLower(filepath.Ext(path)) == ".wav" { 33 | return Recode(path, dir) 34 | } 35 | return Copy(path, dir) 36 | } 37 | 38 | // Load loads audio from an ffmpeg-decodable file from a path (which may be a URL). 39 | func Load(path string) (*audio.Audio, error) { 40 | return LoadAtRate(path, 48000) 41 | } 42 | 43 | // LoadAtRate loads audio from an ffmpeg-decodable file from a path (which may be a URL) and returns it at the given sample rate. 44 | func LoadAtRate(path string, rate int) (*audio.Audio, error) { 45 | cmd := exec.Command("ffmpeg", "-i", path, "-vn", "-acodec", "pcm_s16le", "-f", "wav", "-ar", fmt.Sprint(rate), "-") 46 | stdout, stderr := &bytes.Buffer{}, &bytes.Buffer{} 47 | cmd.Stdout, cmd.Stderr = stdout, stderr 48 | if err := cmd.Run(); err != nil { 49 | return nil, fmt.Errorf("while executing %v: %v\n%v", cmd, err, stderr.String()) 50 | } 51 | w, err := audio.ReadWAV(stdout) 52 | if err != nil { 53 | return nil, err 54 | } 55 | return w.Audio() 56 | } 57 | 58 | // Copy copies any file from a path (which may be a URL) into a new temporary file inside dir and returns that file's path relative to dir. 59 | func Copy(path string, dir string) (string, error) { 60 | // This function uses ffmpeg since it both verifies that the file is a proper media file, and handles 61 | // URLs and paths exactly like the other functions in this package.
62 | outFile, err := os.CreateTemp(dir, fmt.Sprintf("zimtohrli.go.aio.Copy.*%s", filepath.Ext(path))) 63 | if err != nil { 64 | return "", err 65 | } 66 | outFile.Close() 67 | cmd := exec.Command("ffmpeg", "-y", "-i", path, "-vn", "-acodec", "copy", outFile.Name()) 68 | ffmpegResult, err := cmd.CombinedOutput() 69 | if err != nil { _ = os.Remove(outFile.Name()) // Don't leave an empty or partial temp file behind in dir. 70 | return "", fmt.Errorf("trying to execute %v: %v\n%s", cmd, err, ffmpegResult) 71 | } 72 | return filepath.Rel(dir, outFile.Name()) 73 | } 74 | 75 | // Recode transcodes an ffmpeg-decodable file from path (which may be a URL) to FLAC in a new temporary file inside dir and returns that file's path relative to dir. 76 | func Recode(path string, dir string) (string, error) { 77 | flacFile, err := os.CreateTemp(dir, "zimtohrli.go.aio.Recode.*.flac") 78 | if err != nil { 79 | return "", err 80 | } 81 | flacFile.Close() 82 | cmd := exec.Command("ffmpeg", "-y", "-i", path, "-vn", "-acodec", "flac", "-f", "flac", flacFile.Name()) 83 | ffmpegResult, err := cmd.CombinedOutput() 84 | if err != nil { _ = os.Remove(flacFile.Name()) // Don't leave an empty or partial temp file behind in dir. 85 | return "", fmt.Errorf("trying to execute %v: %v\n%s", cmd, err, ffmpegResult) 86 | } 87 | return filepath.Rel(dir, flacFile.Name()) 88 | } 89 | 90 | // DumpWAV stores the audio as a WAV in a temporary directory and returns the path. The file is left in place even if saving fails. 91 | func DumpWAV(audio *audio.Audio) (string, error) { 92 | wavFile, err := os.CreateTemp(os.TempDir(), "zimtohrli.go.aio.DumpWAV.*.wav") 93 | if err != nil { 94 | return "", err 95 | } 96 | wavFile.Close() 97 | return wavFile.Name(), Save(audio, wavFile.Name()) 98 | } 99 | 100 | // Save stores the audio in the ffmpeg-encodable path by piping interleaved samples to ffmpeg as raw f32le (presumably audio.Samples holds float32, since binary.Write emits the element type as-is; TODO confirm). It returns an error if audio has no channels.
101 | func Save(audio *audio.Audio, path string) error { if len(audio.Samples) == 0 { return fmt.Errorf("no audio channels to save to %q", path) } 102 | buf := &bytes.Buffer{} 103 | for sampleIndex := range audio.Samples[0] { 104 | for channelIndex := range audio.Samples { 105 | if err := binary.Write(buf, binary.LittleEndian, audio.Samples[channelIndex][sampleIndex]); err != nil { 106 | return err 107 | } 108 | } 109 | } 110 | cmd := exec.Command("ffmpeg", "-y", "-ac", fmt.Sprint(len(audio.Samples)), "-f", "f32le", "-ar", fmt.Sprint(int(audio.Rate)), "-i", "-", path) 111 | cmd.Stdin = buf 112 | ffmpegResult, err := cmd.CombinedOutput() 113 | if err != nil { 114 | return fmt.Errorf("trying to execute %v: %v\n%s", cmd, err, ffmpegResult) 115 | } 116 | return nil 117 | } 118 | -------------------------------------------------------------------------------- /cpp/zimt/visqol.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | #include "zimt/visqol.h" 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | #include "absl/log/check.h" 25 | #include "absl/status/statusor.h" 26 | #include "absl/types/span.h" 27 | #include "visqol_api.h" 28 | #include "zimt/resample.h" 29 | #include "zimt/visqol_model.h" 30 | #include "zimt/zimtohrli.h" 31 | 32 | #ifdef _WIN32 33 | #include 34 | #endif 35 | 36 | constexpr size_t SAMPLE_RATE = 48000; 37 | 38 | namespace zimtohrli { 39 | // Writes the embedded ViSQOL model (ViSQOLModel()) to a fresh temporary file; its path is kept in model_path_, passed to ViSQOL as the SVR model path, and removed by the destructor. 40 | ViSQOL::ViSQOL() { 41 | std::string path_template = (std::filesystem::temp_directory_path() / 42 | "zimtohrli_cpp_zimt_visqol_model_XXXXXX") 43 | .string(); 44 | std::vector populated_path_template(path_template.begin(), 45 | path_template.end()); 46 | populated_path_template.push_back('\0'); 47 | const int model_path_file = mkstemp(populated_path_template.data()); 48 | CHECK_GE(model_path_file, 0) << strerror(errno); // mkstemp returns -1 on failure; 0 is a valid descriptor. 49 | #ifdef _WIN32 50 | CHECK_EQ(_close(model_path_file), 0); 51 | #else 52 | CHECK_EQ(close(model_path_file), 0); 53 | #endif 54 | model_path_ = std::filesystem::path(std::string( 55 | populated_path_template.data())); // Stops at the NUL terminator pushed above, so it is not part of the path. 56 | std::ofstream output_stream(model_path_); 57 | CHECK(output_stream.good()); 58 | Span model = ViSQOLModel(); 59 | output_stream.write(model.data, model.size); 60 | CHECK(output_stream.good()); 61 | output_stream.close(); 62 | CHECK(output_stream.good()); 63 | } 64 | 65 | ViSQOL::~ViSQOL() { std::filesystem::remove(model_path_); } 66 | // Resamples both signals from sample_rate to SAMPLE_RATE and scores them in ViSQOL audio mode. 67 | absl::StatusOr ViSQOL::MOS(Span reference, 68 | Span degraded, 69 | float sample_rate) const { 70 | std::vector resampled_reference = 71 | Resample(reference, sample_rate, SAMPLE_RATE); 72 | std::vector resampled_degraded = 73 | Resample(degraded, sample_rate, SAMPLE_RATE); 74 | 75 | Visqol::VisqolConfig config; 76 | config.mutable_options()->set_svr_model_path(model_path_.string()); 77 | config.mutable_audio()->set_sample_rate(SAMPLE_RATE); 78 | 79 | // When running
in audio mode, a sample rate of 48k is recommended for 80 | // the input signals. Using non-48k input will very likely negatively 81 | // affect the comparison result. If, however, API users wish to run with 82 | // non-48k input, set this to true. 83 | config.mutable_options()->set_allow_unsupported_sample_rates(false); 84 | 85 | // ViSQOL will run in audio mode comparison by default. 86 | // If speech mode comparison is desired, set to true. 87 | config.mutable_options()->set_use_speech_scoring(false); 88 | 89 | // Speech mode will scale the MOS mapping by default. This means that a 90 | // perfect NSIM score of 1.0 will be mapped to a perfect MOS-LQO of 5.0. 91 | // Set to true to use unscaled speech mode. This means that a perfect 92 | // NSIM score will instead be mapped to a MOS-LQO of ~4.x. 93 | config.mutable_options()->set_use_unscaled_speech_mos_mapping(false); 94 | // NOTE(review): a VisqolApi is created (presumably reloading the SVR model) on every call; consider caching it if MOS is called often. 95 | Visqol::VisqolApi visqol; 96 | CHECK_OK(visqol.Create(config)); 97 | 98 | absl::StatusOr comparison_status_or = 99 | visqol.Measure(absl::Span(resampled_reference.data(), 100 | resampled_reference.size()), 101 | absl::Span(resampled_degraded.data(), 102 | resampled_degraded.size())); 103 | if (!comparison_status_or.ok()) { 104 | return absl::Status(comparison_status_or.status().code(), 105 | "when calling visqol.Measure: " + std::string(comparison_status_or.status().message())); 106 | } 107 | 108 | Visqol::SimilarityResultMsg similarity_result = comparison_status_or.value(); 109 | 110 | return similarity_result.moslqo(); 111 | } 112 | 113 | } // namespace zimtohrli -------------------------------------------------------------------------------- /go/bin/compare/compare.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // compare is a Go version of compare.cc: it compares two audio files and prints one NAME=VALUE line per requested metric. 16 | package main 17 | 18 | import ( 19 | "encoding/json" 20 | "flag" 21 | "fmt" 22 | "log" 23 | "os" 24 | "reflect" 25 | 26 | "github.com/google/zimtohrli/go/aio" 27 | "github.com/google/zimtohrli/go/goohrli" 28 | "github.com/google/zimtohrli/go/pipe" 29 | ) 30 | // main parses the flags, loads both files resampled to goohrli.SampleRate(), and prints the metrics selected by -pipe_metric, -visqol and -zimtohrli. 31 | func main() { 32 | pathA := flag.String("path_a", "", "Path to ffmpeg-decodable file with signal A.") 33 | pathB := flag.String("path_b", "", "Path to ffmpeg-decodable file with signal B.") 34 | visqol := flag.Bool("visqol", false, "Whether to measure using ViSQOL.") 35 | pipeMetric := flag.String("pipe_metric", "", "Path to a binary that serves metrics via stdin/stdout pipe.
Install some of them via 'install_python_metrics.py'.") 36 | zimtohrli := flag.Bool("zimtohrli", true, "Whether to measure using Zimtohrli.") 37 | outputZimtohrliDistance := flag.Bool("output_zimtohrli_distance", false, "Whether to output the raw Zimtohrli distance instead of a mapped mean opinion score.") 38 | zimtohrliParameters := goohrli.DefaultParameters() 39 | b, err := json.Marshal(zimtohrliParameters) 40 | if err != nil { 41 | log.Panic(err) 42 | } 43 | zimtohrliParametersJSON := flag.String("zimtohrli_parameters", string(b), "Zimtohrli model parameters.") 44 | perChannel := flag.Bool("per_channel", false, "Whether to output the produced metric per channel instead of a single value for all channels.") 45 | flag.Parse() 46 | 47 | if *pathA == "" || *pathB == "" { 48 | flag.Usage() 49 | os.Exit(1) 50 | } 51 | // Both signals are loaded at Zimtohrli's sample rate. 52 | signalA, err := aio.LoadAtRate(*pathA, int(goohrli.SampleRate())) 53 | if err != nil { 54 | log.Panic(err) 55 | } 56 | signalB, err := aio.LoadAtRate(*pathB, int(goohrli.SampleRate())) 57 | if err != nil { 58 | log.Panic(err) 59 | } 60 | // Both loads requested the same rate, so this is a sanity check. 61 | if signalA.Rate != signalB.Rate { 62 | log.Panic(fmt.Errorf("sample rate of %q is %v, and sample rate of %q is %v", *pathA, signalA.Rate, *pathB, signalB.Rate)) 63 | } 64 | // The per-channel loops below index signalB with signalA's channel indices. 65 | if len(signalA.Samples) != len(signalB.Samples) { 66 | log.Panic(fmt.Errorf("%q has %v channels, and %q has %v channels", *pathA, len(signalA.Samples), *pathB, len(signalB.Samples))) 67 | } 68 | // Optional external metric binary speaking the stdin/stdout pipe protocol. 
69 | if *pipeMetric != "" { 70 | metric, err := pipe.StartMetric(*pipeMetric) 71 | if err != nil { 72 | log.Panic(err) 73 | } 74 | defer metric.Close() 75 | scoreType, err := metric.ScoreType() 76 | if err != nil { 77 | log.Panic(err) 78 | } 79 | score, err := metric.Measure(signalA, signalB) 80 | if err != nil { 81 | log.Panic(err) 82 | } 83 | fmt.Printf("%v=%v\n", scoreType, score) 84 | } 85 | // ViSQOL MOS, either per channel or for the whole (multi-channel) audio. 86 | if *visqol { 87 | v := goohrli.NewViSQOL() 88 | if *perChannel { 89 | for channelIndex := range signalA.Samples { 90 | mos, err :=
v.MOS(signalA.Rate, signalA.Samples[channelIndex], signalB.Samples[channelIndex]) 91 | if err != nil { 92 | log.Panic(err) 93 | } 94 | fmt.Printf("ViSQOL#%v=%v\n", channelIndex, mos) 95 | } 96 | } else { 97 | mos, err := v.AudioMOS(signalA, signalB) 98 | if err != nil { 99 | log.Panic(err) 100 | } 101 | fmt.Printf("ViSQOL=%v\n", mos) 102 | } 103 | } 104 | // Zimtohrli: the raw distance if -output_zimtohrli_distance is set, otherwise the distance mapped to a MOS. 105 | if *zimtohrli { 106 | getMetric := func(f float64) float64 { 107 | if *outputZimtohrliDistance { 108 | return f 109 | } 110 | return goohrli.MOSFromZimtohrli(f) 111 | } 112 | // The flag defaults to the JSON of DefaultParameters, so without an override this presumably reproduces the defaults. 113 | if err := zimtohrliParameters.Update([]byte(*zimtohrliParametersJSON)); err != nil { 114 | log.Panic(err) 115 | } 116 | if !reflect.DeepEqual(zimtohrliParameters, goohrli.DefaultParameters()) { 117 | log.Printf("Using %+v", zimtohrliParameters) 118 | } 119 | g := goohrli.New(zimtohrliParameters) 120 | if *perChannel { 121 | for channelIndex := range signalA.Samples { 122 | fmt.Printf("Zimtohrli#%v=%v\n", channelIndex, getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex]))) 123 | } 124 | } else { 125 | dist, err := g.NormalizedAudioDistance(signalA, signalB) 126 | if err != nil { 127 | log.Panic(err) 128 | } 129 | fmt.Printf("Zimtohrli=%v\n", getMetric(dist)) 130 | } 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /go/progress/bar.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package progress paints a very simple progress bar on the screen 16 | package progress 17 | 18 | import ( 19 | "bytes" 20 | "fmt" 21 | "log" 22 | "math" 23 | "os" 24 | "sync" 25 | "syscall" 26 | "time" 27 | "unsafe" 28 | ) 29 | 30 | type winsize struct { 31 | Row uint16 32 | Col uint16 33 | Xpixel uint16 34 | Ypixel uint16 35 | } 36 | 37 | func getTerminalWidth() (int, error) { 38 | ws := &winsize{} 39 | retCode, _, errno := syscall.Syscall(syscall.SYS_IOCTL, 40 | uintptr(syscall.Stdin), 41 | uintptr(syscall.TIOCGWINSZ), 42 | uintptr(unsafe.Pointer(ws))) 43 | 44 | if int(retCode) == -1 { 45 | return 0, fmt.Errorf("Syscall returned %v", errno) 46 | } 47 | return int(ws.Col), nil 48 | } 49 | 50 | // New returns a new progress bar. 51 | func New(name string) *Bar { 52 | now := time.Now() 53 | return &Bar{ 54 | name: name, 55 | created: now, 56 | lastRender: now, 57 | } 58 | } 59 | 60 | // Bar contains state for a progress bar. 61 | type Bar struct { 62 | name string 63 | created time.Time 64 | completed int 65 | errors int 66 | total int 67 | emaCompletedSpeed float64 68 | emaFractionSpeed float64 69 | lastRender time.Time 70 | lock sync.Mutex 71 | } 72 | 73 | // AddCompleted adds completed tasks to the bar and renders it. 74 | func (b *Bar) AddCompleted(num int) { 75 | b.Update(b.total, b.completed+num, b.errors) 76 | } 77 | 78 | // Finish prints the final actual time of completion and a newline. 
79 | func (b *Bar) Finish() { 80 | prefix := fmt.Sprintf("%s, %d/%d/%d ", b.name, b.completed, b.errors, b.total) 81 | atc := time.Since(b.created) 82 | speed := float64(b.completed) / float64(atc) 83 | round := time.Minute 84 | if atc < time.Minute { 85 | round = time.Second 86 | } 87 | suffix := fmt.Sprintf(" %.2f/s ATC: %s", speed*float64(time.Second), atc.Round(round)) 88 | 89 | fmt.Fprintf(os.Stderr, "\r%s%s%s\n", prefix, b.filler(prefix, suffix), suffix) 90 | } 91 | 92 | func (b *Bar) filler(prefix, suffix string) string { 93 | width, err := getTerminalWidth() 94 | if err != nil { 95 | log.Println(err) 96 | return "" 97 | } 98 | numFiller := width - len(prefix) - len(suffix) 99 | completedFiller := int(float64(numFiller) * float64(b.completed) / float64(b.total)) 100 | errorFiller := int(float64(numFiller) * float64(b.errors) / float64(b.total)) 101 | filler := &bytes.Buffer{} 102 | for i := 0; i < numFiller; i++ { 103 | if i < completedFiller { 104 | fmt.Fprintf(filler, "#") 105 | } else if i < completedFiller+errorFiller { 106 | fmt.Fprintf(filler, "☠") 107 | } else { 108 | fmt.Fprintf(filler, " ") 109 | } 110 | } 111 | return filler.String() 112 | } 113 | 114 | // Update update completed and total tasks to the bar and updates it. 
115 | func (b *Bar) Update(total, completed, errors int) { 116 | b.lock.Lock() 117 | defer b.lock.Unlock() 118 | if completed < b.completed { 119 | return 120 | } 121 | 122 | prefix := fmt.Sprintf("%s, %d/%d/%d ", b.name, completed, errors, total) 123 | 124 | now := time.Now() 125 | fraction := float64(completed) / float64(total) 126 | if timeLived := now.Sub(b.created); timeLived < 2*time.Second { 127 | b.emaCompletedSpeed = float64(completed) / float64(timeLived) 128 | b.emaFractionSpeed = fraction / float64(timeLived) 129 | } else { 130 | timeUsed := now.Sub(b.lastRender) 131 | currentCompletedSpeed := float64(completed-b.completed) / float64(timeUsed) 132 | currentFractionSpeed := (fraction - (float64(b.completed) / float64(b.total))) / float64(timeUsed) 133 | secondsUsed := float64(timeUsed) / float64(time.Second) 134 | w := math.Exp(-0.1 * secondsUsed) 135 | b.emaCompletedSpeed = b.emaCompletedSpeed*w + currentCompletedSpeed*(1-w) 136 | b.emaFractionSpeed = b.emaFractionSpeed*w + currentFractionSpeed*(1-w) 137 | } 138 | eta := time.Duration((1 - fraction) / b.emaFractionSpeed) 139 | round := time.Minute 140 | if eta < 2*time.Minute { 141 | round = time.Second 142 | } 143 | suffix := fmt.Sprintf(" %.2f/s ETA: %s", b.emaCompletedSpeed*float64(time.Second), eta.Round(round)) 144 | 145 | b.completed = completed 146 | b.errors = errors 147 | b.total = total 148 | b.lastRender = now 149 | 150 | os.Stderr.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) 151 | fmt.Fprintf(os.Stderr, "\r%s%s%s", prefix, b.filler(prefix, suffix), suffix) 152 | } 153 | -------------------------------------------------------------------------------- /go/bin/tcd_voip/tcd_voip.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // tcd_voip creates a study from https://qxlab.ucd.ie/index.php/tcd-voip-dataset/. 16 | // 17 | // It requires the user to open "TCD VOIP - Test Set Conditions and MOS Results.xlsx" 18 | // in a compatible spreadsheet application and export the "Subjective Test Scores" 19 | // tab to a CSV file. 20 | package main 21 | 22 | import ( 23 | "encoding/csv" 24 | "flag" 25 | "fmt" 26 | "log" 27 | "os" 28 | "path/filepath" 29 | "regexp" 30 | "runtime" 31 | "strconv" 32 | "strings" 33 | 34 | "github.com/google/zimtohrli/go/aio" 35 | "github.com/google/zimtohrli/go/data" 36 | "github.com/google/zimtohrli/go/progress" 37 | "github.com/google/zimtohrli/go/worker" 38 | ) 39 | // fileReg matches a degraded file name. Group 2 (lowercased) is its condition directory under "Test Set", and "R" + group 1 is the name of its reference in that directory's "ref" folder. NOTE(review): the "." before "wav" is unescaped and matches any character. 40 | var fileReg = regexp.MustCompile("[^_]+(_[^_]+_([^_]+)_[^_]+.wav)") 41 | // populate imports every row of the single CSV file in source into a study at dest, transcoding each reference/degraded pair on a worker pool. 
42 | func populate(source string, dest string, workers int, failFast bool) error { 43 | study, err := data.OpenStudy(dest) 44 | if err != nil { 45 | return err 46 | } 47 | defer study.Close() 48 | 49 | csvFiles, err := filepath.Glob(filepath.Join(source, "*.csv")) 50 | if err != nil { 51 | return err 52 | } 53 | if len(csvFiles) != 1 { 54 | return fmt.Errorf("not exactly one .csv file in %q", source) 55 | } 56 | csvFile, err := os.Open(csvFiles[0]) 57 | if err != nil { 58 | return err 59 | } 60 | defer csvFile.Close() 61 | csvReader := csv.NewReader(csvFile) 62 | header, err := csvReader.Read() 63 | if err != nil { 64 | return err 65 | } 66 | if strings.Join(header, ",") != "Filename,ConditionID,sample MOS,listener #
->,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24" { 67 | return fmt.Errorf("header %+v doesn't match expected TCD VOIP header", header) 68 | } 69 | err = nil 70 | bar := progress.New("Transcoding") 71 | pool := worker.Pool[*data.Reference]{ 72 | Workers: workers, 73 | OnChange: bar.Update, 74 | FailFast: failFast, 75 | } 76 | var loopLine []string 77 | lineIndex := 0 78 | for loopLine, err = csvReader.Read(); err == nil; loopLine, err = csvReader.Read() { 79 | line := loopLine 80 | match := fileReg.FindStringSubmatch(line[0]) 81 | if match == nil { 82 | return fmt.Errorf("line %+v doesn't have a file matching %v", line, fileReg) 83 | } 84 | distPath := filepath.Join(source, "Test Set", strings.ToLower(match[2]), line[0]) 85 | if _, err := os.Stat(distPath); err != nil { 86 | return err 87 | } 88 | refPath := filepath.Join(source, "Test Set", strings.ToLower(match[2]), "ref", fmt.Sprintf("R%s", match[1])) 89 | if _, err := os.Stat(refPath); err != nil { 90 | return err 91 | } 92 | mos, err := strconv.ParseFloat(line[2], 64) 93 | if err != nil { 94 | return err 95 | } 96 | refIndex := lineIndex 97 | pool.Submit(func(f func(*data.Reference)) error { 98 | ref := &data.Reference{ 99 | Name: fmt.Sprintf("ref-%v", refIndex), 100 | } 101 | var err error 102 | ref.Path, err = aio.Recode(refPath, dest) 103 | if err != nil { // NOTE(review): both "unable to fetch" errors in this closure omit err's own message. 104 | return fmt.Errorf("unable to fetch %q", refPath) 105 | } 106 | dist := &data.Distortion{ 107 | Name: fmt.Sprintf("dist-%v", refIndex), 108 | Scores: map[data.ScoreType]float64{ 109 | data.MOS: mos, 110 | }, 111 | } 112 | dist.Path, err = aio.Recode(distPath, dest) 113 | if err != nil { 114 | return fmt.Errorf("unable to fetch %q", distPath) 115 | } 116 | ref.Distortions = append(ref.Distortions, dist) 117 | f(ref) 118 | return nil 119 | }) 120 | lineIndex++ 121 | } // NOTE(review): the loop also ends on a non-EOF Read error (e.g. a malformed row); that error is never reported and the remaining rows are silently skipped. 
122 | if err := pool.Error(); err != nil { 123 | log.Println(err.Error()) 124 | } 125 | bar.Finish() 126 | refs := []*data.Reference{} 127 | for ref := range
pool.Results() { 128 | refs = append(refs, ref) 129 | } 130 | if err := study.Put(refs); err != nil { 131 | return err 132 | } 133 | return nil 134 | } 135 | // main validates the required -source and -dest flags and runs populate. 136 | func main() { 137 | source := flag.String("source", "", "Directory containing the unpacked Dataset zip from https://qxlab.ucd.ie/index.php/tcd-voip-dataset/ along with a CSV export of the 'Subjective Test Scores' tab of 'TCD VOIP - Test Set Conditions and MOS Results.xlsx'.") 138 | destination := flag.String("dest", "", "Destination directory.") 139 | workers := flag.Int("workers", runtime.NumCPU(), "Number of workers transcoding sounds.") 140 | failFast := flag.Bool("fail_fast", false, "Whether to exit immediately at the first error.") 141 | flag.Parse() 142 | if *source == "" || *destination == "" { 143 | flag.Usage() 144 | os.Exit(1) 145 | } 146 | 147 | if err := populate(*source, *destination, *workers, *failFast); err != nil { 148 | log.Fatal(err) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /go/audio/audio_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.
14 | 15 | package audio 16 | 17 | import ( 18 | "bytes" 19 | "math" 20 | "testing" 21 | ) 22 | // test1khz48khzWAV is an embedded WAV fixture holding a 1 kHz sine sampled at 48 kHz. 23 | type test1khz48khzWAV struct { 24 | name string 25 | data []byte 26 | channels int 27 | } 28 | // wavs holds mono and stereo 16-bit little-endian PCM fixtures; the stereo one carries the same samples on both channels. 29 | var ( 30 | wavs = []test1khz48khzWAV{ 31 | { 32 | name: "1khz_sine_48khz_rate_1ch_s16le_wav", 33 | data: []byte{ 34 | 0x52, 0x49, 0x46, 0x46, 0xa6, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45, 35 | 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 36 | 0x80, 0xbb, 0x00, 0x00, 0x00, 0x77, 0x01, 0x00, 0x02, 0x00, 0x10, 0x00, 37 | 0x4c, 0x49, 0x53, 0x54, 0x1a, 0x00, 0x00, 0x00, 0x49, 0x4e, 0x46, 0x4f, 38 | 0x49, 0x53, 0x46, 0x54, 0x0e, 0x00, 0x00, 0x00, 0x4c, 0x61, 0x76, 0x66, 39 | 0x36, 0x30, 0x2e, 0x31, 0x36, 0x2e, 0x31, 0x30, 0x30, 0x00, 0x64, 0x61, 40 | 0x74, 0x61, 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x02, 0x24, 0x04, 41 | 0x1e, 0x06, 0xff, 0x07, 0xbd, 0x09, 0x4f, 0x0b, 0xb1, 0x0c, 0xda, 0x0d, 42 | 0xc7, 0x0e, 0x73, 0x0f, 0xdc, 0x0f, 0xff, 0x0f, 0xdc, 0x0f, 0x74, 0x0f, 43 | 0xc8, 0x0e, 0xdb, 0x0d, 0xb1, 0x0c, 0x50, 0x0b, 0xbd, 0x09, 0x00, 0x08, 44 | 0x20, 0x06, 0x24, 0x04, 0x17, 0x02, 0x01, 0x00, 0xea, 0xfd, 0xdc, 0xfb, 45 | 0xe2, 0xf9, 0x01, 0xf8, 0x43, 0xf6, 0xb1, 0xf4, 0x4f, 0xf3, 0x26, 0xf2, 46 | 0x39, 0xf1, 0x8d, 0xf0, 0x24, 0xf0, 0x01, 0xf0, 0x24, 0xf0, 0x8c, 0xf0, 47 | 0x38, 0xf1, 0x25, 0xf2, 0x4f, 0xf3, 0xb0, 0xf4, 0x43, 0xf6, 0x00, 0xf8, 48 | 0xe0, 0xf9, 0xdc, 0xfb, 0xe9, 0xfd, 49 | }, 50 | channels: 1, 51 | }, 52 | { 53 | name: "1khz_sine_48khz_rate_2ch_s16le_wav", 54 | data: []byte{ 55 | 0x52, 0x49, 0x46, 0x46, 0x06, 0x01, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45, 56 | 0x66, 0x6d, 0x74, 0x20, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x00, 57 | 0x80, 0xbb, 0x00, 0x00, 0x00, 0xee, 0x02, 0x00, 0x04, 0x00, 0x10, 0x00, 58 | 0x4c, 0x49, 0x53, 0x54, 0x1a, 0x00, 0x00, 0x00, 0x49, 0x4e, 
0x46, 0x4f, 59 | 0x49, 0x53, 0x46, 0x54, 0x0e, 0x00, 0x00, 0x00, 0x4c, 0x61, 0x76, 0x66, 60 | 0x36, 0x30, 0x2e, 0x31, 0x36, 0x2e, 0x31, 0x30, 0x30, 0x00, 0x64,
0x61, 61 | 0x74, 0x61, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x02, 62 | 0x16, 0x02, 0x24, 0x04, 0x24, 0x04, 0x1e, 0x06, 0x1e, 0x06, 0xff, 0x07, 63 | 0xff, 0x07, 0xbd, 0x09, 0xbd, 0x09, 0x4f, 0x0b, 0x4f, 0x0b, 0xb1, 0x0c, 64 | 0xb1, 0x0c, 0xda, 0x0d, 0xda, 0x0d, 0xc7, 0x0e, 0xc7, 0x0e, 0x73, 0x0f, 65 | 0x73, 0x0f, 0xdc, 0x0f, 0xdc, 0x0f, 0xff, 0x0f, 0xff, 0x0f, 0xdc, 0x0f, 66 | 0xdc, 0x0f, 0x74, 0x0f, 0x74, 0x0f, 0xc8, 0x0e, 0xc8, 0x0e, 0xdb, 0x0d, 67 | 0xdb, 0x0d, 0xb1, 0x0c, 0xb1, 0x0c, 0x50, 0x0b, 0x50, 0x0b, 0xbd, 0x09, 68 | 0xbd, 0x09, 0x00, 0x08, 0x00, 0x08, 0x20, 0x06, 0x20, 0x06, 0x24, 0x04, 69 | 0x24, 0x04, 0x17, 0x02, 0x17, 0x02, 0x01, 0x00, 0x01, 0x00, 0xea, 0xfd, 70 | 0xea, 0xfd, 0xdc, 0xfb, 0xdc, 0xfb, 0xe2, 0xf9, 0xe2, 0xf9, 0x01, 0xf8, 71 | 0x01, 0xf8, 0x43, 0xf6, 0x43, 0xf6, 0xb1, 0xf4, 0xb1, 0xf4, 0x4f, 0xf3, 72 | 0x4f, 0xf3, 0x26, 0xf2, 0x26, 0xf2, 0x39, 0xf1, 0x39, 0xf1, 0x8d, 0xf0, 73 | 0x8d, 0xf0, 0x24, 0xf0, 0x24, 0xf0, 0x01, 0xf0, 0x01, 0xf0, 0x24, 0xf0, 74 | 0x24, 0xf0, 0x8c, 0xf0, 0x8c, 0xf0, 0x38, 0xf1, 0x38, 0xf1, 0x25, 0xf2, 75 | 0x25, 0xf2, 0x4f, 0xf3, 0x4f, 0xf3, 0xb0, 0xf4, 0xb0, 0xf4, 0x43, 0xf6, 76 | 0x43, 0xf6, 0x00, 0xf8, 0x00, 0xf8, 0xe0, 0xf9, 0xe0, 0xf9, 0xdc, 0xfb, 77 | 0xdc, 0xfb, 0xe9, 0xfd, 0xe9, 0xfd, 78 | }, 79 | channels: 2, 80 | }, 81 | } 82 | ) 83 | // readWAVTest returns a subtest that decodes data and checks the channel count, the 48 kHz rate, and that the peak-normalized samples follow sin(2*pi*1000*t) within 1e-3. 
84 | func readWAVTest(data []byte, wantChannels int) func(t *testing.T) { 85 | return func(t *testing.T) { 86 | t.Helper() 87 | w, err := ReadWAV(bytes.NewBuffer(data)) 88 | if err != nil { 89 | t.Fatal(err) 90 | } 91 | if w.FormatChunk.NumChannels != int16(wantChannels) { 92 | t.Errorf("got %v channels, want %v", w.FormatChunk.NumChannels, wantChannels) 93 | } 94 | if w.FormatChunk.SampleRate != 48000 { 95 | t.Errorf("got sample rate %v, want %v", w.FormatChunk.SampleRate, 48000) 96 | } 97 | sampleRateReciprocal := 1.0 / float64(w.FormatChunk.SampleRate) 98 | audio, err := w.Audio() 99 | if err != nil { 100 | t.Fatal(err) 101 | } 102 | audio.Amplify(1.0 /
audio.MaxAbsAmplitude) 103 | for _, channel := range audio.Samples { 104 | for sampleIndex, sample := range channel { 105 | wantSample := math.Sin(2 * math.Pi * 1000 * sampleRateReciprocal * float64(sampleIndex)) 106 | if math.Abs(wantSample-float64(sample)) > 1e-3 { 107 | t.Errorf("got sample %v %v, want %v", sampleIndex, sample, wantSample) 108 | } 109 | } 110 | } 111 | 112 | } 113 | } 114 | // TestReadWAV runs readWAVTest as a named subtest for every fixture in wavs. 115 | func TestReadWAV(t *testing.T) { 116 | for _, w := range wavs { 117 | t.Run(w.name, readWAVTest(w.data, w.channels)) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /go/bin/sebass_db/sebass_db.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // sebass_db creates a study from https://www.audiolabs-erlangen.de/resources/2019-WASPAA-SEBASS/. 16 | // 17 | // It currently supports SASSEC, SiSEC08, SAOC, and PEASS-DB. SiSEC18 at 18 | // that web site doesn't contain all audio, and is not currently supported. 19 | // 20 | // Download and unpack one of the supported ZIP files, and use the directory 21 | // it unpacked as -source when running this binary.
22 | package main 23 | 24 | import ( 25 | "encoding/csv" 26 | "flag" 27 | "fmt" 28 | "log" 29 | "math" 30 | "os" 31 | "path/filepath" 32 | "reflect" 33 | "runtime" 34 | "strconv" 35 | 36 | "github.com/google/zimtohrli/go/aio" 37 | "github.com/google/zimtohrli/go/data" 38 | "github.com/google/zimtohrli/go/progress" 39 | "github.com/google/zimtohrli/go/worker" 40 | ) 41 | // populate imports every SEBASS CSV export found in source into a study at dest. Rows of SAOC_N_anonymized.csv read their audio from Signals_N; all other exports read from Signals. 42 | func populate(source string, dest string, workers int, failFast bool) error { 43 | study, err := data.OpenStudy(dest) 44 | if err != nil { 45 | return err 46 | } 47 | defer study.Close() 48 | 49 | csvFiles, err := filepath.Glob(filepath.Join(source, "*.csv")) 50 | if err != nil { 51 | return err 52 | } 53 | for _, csvFile := range csvFiles { 54 | signals := "Signals" 55 | switch filepath.Base(csvFile) { 56 | case "SAOC_1_anonymized.csv": 57 | signals = "Signals_1" 58 | case "SAOC_2_anonymized.csv": 59 | signals = "Signals_2" 60 | case "SAOC_3_anonymized.csv": 61 | signals = "Signals_3" 62 | } 63 | fileReader, err := os.Open(csvFile) 64 | if err != nil { 65 | return err 66 | } 67 | defer fileReader.Close() // NOTE(review): deferred inside the loop, so every CSV stays open until populate returns. 68 | csvReader := csv.NewReader(fileReader) 69 | header, err := csvReader.Read() 70 | if err != nil { 71 | return err 72 | } 73 | if !reflect.DeepEqual(header, []string{"Testname", "Listener", "Trial", "Condition", "Ratingscore"}) { 74 | return fmt.Errorf("header %+v doesn't match expected SEBASS-DB header", header) 75 | } 76 | err = nil 77 | bar := progress.New(fmt.Sprintf("Transcoding from %q", csvFile)) 78 | pool := worker.Pool[*data.Reference]{ 79 | Workers: workers, 80 | OnChange: bar.Update, 81 | FailFast: failFast, 82 | } 83 | var loopLine []string 84 | lineIndex := 0 85 | for loopLine, err = csvReader.Read(); err == nil; loopLine, err = csvReader.Read() { 86 | line := loopLine 87 | if len(line) == 0 { 88 | continue 89 | } 90 | if line[3] == "anchor" { 91 | line[3] = "anker_mix" 92 | } 93 | if line[3] == "hidden_ref" { 94 | line[3] = "orig" 95 | } 96 | if line[3] == "SAOC" { 97 |
continue 98 | } 99 | mos, err := strconv.ParseFloat(line[4], 64) 100 | if err != nil { 101 | return err 102 | } 103 | if math.IsNaN(mos) { 104 | continue 105 | } 106 | refIndex := lineIndex 107 | pool.Submit(func(f func(*data.Reference)) error { 108 | ref := &data.Reference{ 109 | Name: fmt.Sprintf("ref-%v", refIndex), 110 | } 111 | var err error 112 | path := filepath.Join(source, signals, "orig", fmt.Sprintf("%s.wav", line[2])) 113 | ref.Path, err = aio.Recode(path, dest) 114 | if err != nil { 115 | return fmt.Errorf("unable to fetch %q", path) 116 | } 117 | dist := &data.Distortion{ 118 | Name: fmt.Sprintf("dist-%v", refIndex), 119 | Scores: map[data.ScoreType]float64{ 120 | data.MOS: mos, 121 | }, 122 | } 123 | path = filepath.Join(source, signals, line[3], fmt.Sprintf("%s.wav", line[2])) 124 | dist.Path, err = aio.Recode(path, dest) 125 | if err != nil { 126 | return fmt.Errorf("unable to fetch %q", path) 127 | } 128 | ref.Distortions = append(ref.Distortions, dist) 129 | f(ref) 130 | return nil 131 | }) 132 | lineIndex++ 133 | } // NOTE(review): the loop also ends on a non-EOF Read error (e.g. csv.ErrFieldCount); that error is never reported and the remaining rows are silently skipped. 134 | if err := pool.Error(); err != nil { 135 | log.Println(err.Error()) 136 | } 137 | bar.Finish() 138 | refs := []*data.Reference{} 139 | for ref := range pool.Results() { 140 | refs = append(refs, ref) 141 | } 142 | if err := study.Put(refs); err != nil { 143 | return err 144 | } 145 | } 146 | return nil 147 | } 148 | // main validates the required -source and -dest flags and runs populate. 149 | func main() { 150 | source := flag.String("source", "", "Directory containing one of the unpacked datasets from https://www.audiolabs-erlangen.de/resources/2019-WASPAA-SEBASS/.") 151 | destination := flag.String("dest", "", "Destination directory.") 152 | workers := flag.Int("workers", runtime.NumCPU(), "Number of workers transcoding sounds.") 153 | failFast := flag.Bool("fail_fast", false, "Whether to exit immediately at the first error.") 
154 | flag.Parse() 155 | if *source == "" || *destination == "" { 156 | flag.Usage() 157 | os.Exit(1) 158 | } 159 | 160 | if err := populate(*source, *destination, *workers,
*failFast); err != nil { 161 | log.Fatal(err) 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /cmake/visqol.cmake: -------------------------------------------------------------------------------- 1 | FetchContent_Declare(visqol 2 | EXCLUDE_FROM_ALL 3 | GIT_REPOSITORY https://github.com/google/visqol.git 4 | GIT_TAG v3.3.3 5 | PATCH_COMMAND cp ${CMAKE_CURRENT_SOURCE_DIR}/cmake/visqol_manager.cc src/visqol_manager.cc 6 | ) 7 | FetchContent_MakeAvailable(visqol) 8 | # visqol_proto holds the C++ code that protobuf_generate (below) produces from ViSQOL's .proto files. 9 | add_library(visqol_proto STATIC ${visqol_SOURCE_DIR}/src/proto/similarity_result.proto ${visqol_SOURCE_DIR}/src/proto/visqol_config.proto) 10 | 11 | include(${protobuf_BINARY_DIR}/cmake/protobuf/protobuf-generate.cmake) 12 | 13 | set(visqol_PROTO_DIR ${visqol_BINARY_DIR}/generated) 14 | protobuf_generate( 15 | TARGET visqol_proto 16 | IMPORT_DIRS ${visqol_SOURCE_DIR} 17 | PROTOC_OUT_DIR ${visqol_PROTO_DIR} 18 | LANGUAGE cpp 19 | ) 20 | target_link_libraries(visqol_proto protobuf::libprotobuf) 21 | target_include_directories(visqol_proto PUBLIC ${visqol_PROTO_DIR}) 22 | # Embed ViSQOL's SVR model text as a C array named visqol_model_bytes (xxd -i, then sed renames the symbol). 23 | set(visqol_MODEL_H ${visqol_SOURCE_DIR}/src/include/libsvm_nu_svr_model.h) 24 | add_custom_command( 25 | OUTPUT ${visqol_MODEL_H} 26 | COMMAND sh -c "xxd -i ${visqol_SOURCE_DIR}/model/libsvm_nu_svr_model.txt | sed 's/[^ ]*libsvm_nu_svr_model_txt/visqol_model_bytes/' > ${visqol_MODEL_H}" 27 | DEPENDS ${visqol_SOURCE_DIR}/model/libsvm_nu_svr_model.txt 28 | VERBATIM 29 | ) 30 | add_custom_target(visqol_model DEPENDS ${visqol_MODEL_H}) 31 | # ViSQOL itself; src/visqol_manager.cc is this project's copy, installed by PATCH_COMMAND above. 
32 | add_library(visqol STATIC 33 | ${visqol_SOURCE_DIR}/src/alignment.cc 34 | ${visqol_SOURCE_DIR}/src/amatrix.cc 35 | ${visqol_SOURCE_DIR}/src/analysis_window.cc 36 | ${visqol_SOURCE_DIR}/src/commandline_parser.cc 37 | ${visqol_SOURCE_DIR}/src/comparison_patches_selector.cc 38 | ${visqol_SOURCE_DIR}/src/complex_valarray.cc 39 | ${visqol_SOURCE_DIR}/src/convolution_2d.cc 40 | ${visqol_SOURCE_DIR}/src/envelope.cc 41 |
${visqol_SOURCE_DIR}/src/equivalent_rectangular_bandwidth.cc 42 | ${visqol_SOURCE_DIR}/src/fast_fourier_transform.cc 43 | ${visqol_SOURCE_DIR}/src/fft_manager.cc 44 | ${visqol_SOURCE_DIR}/src/gammatone_filterbank.cc 45 | ${visqol_SOURCE_DIR}/src/gammatone_spectrogram_builder.cc 46 | ${visqol_SOURCE_DIR}/src/image_patch_creator.cc 47 | ${visqol_SOURCE_DIR}/src/include/alignment.h 48 | ${visqol_SOURCE_DIR}/src/include/amatrix.h 49 | ${visqol_SOURCE_DIR}/src/include/analysis_window.h 50 | ${visqol_SOURCE_DIR}/src/include/audio_channel.h 51 | ${visqol_SOURCE_DIR}/src/include/audio_signal.h 52 | ${visqol_SOURCE_DIR}/src/include/commandline_parser.h 53 | ${visqol_SOURCE_DIR}/src/include/comparison_patches_selector.h 54 | ${visqol_SOURCE_DIR}/src/include/complex_valarray.h 55 | ${visqol_SOURCE_DIR}/src/include/conformance.h 56 | ${visqol_SOURCE_DIR}/src/include/convolution_2d.h 57 | ${visqol_SOURCE_DIR}/src/include/envelope.h 58 | ${visqol_SOURCE_DIR}/src/include/equivalent_rectangular_bandwidth.h 59 | ${visqol_SOURCE_DIR}/src/include/fast_fourier_transform.h 60 | ${visqol_SOURCE_DIR}/src/include/fft_manager.h 61 | ${visqol_SOURCE_DIR}/src/include/file_path.h 62 | ${visqol_SOURCE_DIR}/src/include/gammatone_filterbank.h 63 | ${visqol_SOURCE_DIR}/src/include/gammatone_spectrogram_builder.h 64 | ${visqol_SOURCE_DIR}/src/include/image_patch_creator.h 65 | ${visqol_SOURCE_DIR}/src/include/libsvm_target_observation_convertor.h 66 | ${visqol_SOURCE_DIR}/src/include/machine_learning.h 67 | ${visqol_SOURCE_DIR}/src/include/misc_audio.h 68 | ${visqol_SOURCE_DIR}/src/include/misc_math.h 69 | ${visqol_SOURCE_DIR}/src/include/misc_vector.h 70 | ${visqol_SOURCE_DIR}/src/include/neurogram_similiarity_index_measure.h 71 | ${visqol_SOURCE_DIR}/src/include/patch_similarity_comparator.h 72 | ${visqol_SOURCE_DIR}/src/include/rms_vad.h 73 | ${visqol_SOURCE_DIR}/src/include/signal_filter.h 74 | ${visqol_SOURCE_DIR}/src/include/similarity_result.h 75 |
${visqol_SOURCE_DIR}/src/include/similarity_to_quality_mapper.h 76 | ${visqol_SOURCE_DIR}/src/include/sim_results_writer.h 77 | ${visqol_SOURCE_DIR}/src/include/spectrogram_builder.h 78 | ${visqol_SOURCE_DIR}/src/include/spectrogram.h 79 | ${visqol_SOURCE_DIR}/src/include/speech_similarity_to_quality_mapper.h 80 | ${visqol_SOURCE_DIR}/src/include/status_macros.h 81 | ${visqol_SOURCE_DIR}/src/include/support_vector_regression_model.h 82 | ${visqol_SOURCE_DIR}/src/include/svr_similarity_to_quality_mapper.h 83 | ${visqol_SOURCE_DIR}/src/include/vad_patch_creator.h 84 | ${visqol_SOURCE_DIR}/src/include/visqol_api.h 85 | ${visqol_SOURCE_DIR}/src/include/visqol.h 86 | ${visqol_SOURCE_DIR}/src/include/wav_reader.h 87 | ${visqol_SOURCE_DIR}/src/include/xcorr.h 88 | ${visqol_SOURCE_DIR}/src/libsvm_target_observation_convertor.cc 89 | ${visqol_SOURCE_DIR}/src/misc_audio.cc 90 | ${visqol_SOURCE_DIR}/src/misc_math.cc 91 | ${visqol_SOURCE_DIR}/src/misc_vector.cc 92 | ${visqol_SOURCE_DIR}/src/neurogram_similiarity_index_measure.cc 93 | ${visqol_SOURCE_DIR}/src/rms_vad.cc 94 | ${visqol_SOURCE_DIR}/src/signal_filter.cc 95 | ${visqol_SOURCE_DIR}/src/spectrogram.cc 96 | ${visqol_SOURCE_DIR}/src/speech_similarity_to_quality_mapper.cc 97 | ${visqol_SOURCE_DIR}/src/support_vector_regression_model.cc 98 | ${visqol_SOURCE_DIR}/src/svr_similarity_to_quality_mapper.cc 99 | ${visqol_SOURCE_DIR}/src/svr_training/training_data_file_reader.cc 100 | ${visqol_SOURCE_DIR}/src/svr_training/training_data_file_reader.h 101 | ${visqol_SOURCE_DIR}/src/vad_patch_creator.cc 102 | ${visqol_SOURCE_DIR}/src/visqol_api.cc 103 | ${visqol_SOURCE_DIR}/src/visqol.cc 104 | ${visqol_SOURCE_DIR}/src/visqol_manager.cc 105 | ${visqol_SOURCE_DIR}/src/wav_reader.cc 106 | ${visqol_SOURCE_DIR}/src/xcorr.cc 107 | ) 108 | target_include_directories(visqol PUBLIC ${visqol_SOURCE_DIR} ${visqol_SOURCE_DIR}/src/include) 109 | target_link_libraries(visqol visqol_proto libsvm armadillo absl::flags_parse absl::span pffft) 110 |
add_dependencies(visqol visqol_model) 111 | # Skip linting ViSQOL's sources. The globs must be rooted at visqol_SOURCE_DIR: bare patterns such as *.cc resolve against CMAKE_CURRENT_SOURCE_DIR, i.e. this project's own tree. 112 | file(GLOB_RECURSE visqol_files ${visqol_SOURCE_DIR}/*.cc ${visqol_SOURCE_DIR}/*.c ${visqol_SOURCE_DIR}/*.h) 113 | set_source_files_properties( 114 | ${visqol_files} 115 | TARGET_DIRECTORY visqol 116 | PROPERTIES SKIP_LINTING ON 117 | ) 118 | -------------------------------------------------------------------------------- /go/pipe/metric.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package pipe manages services communicating via pipes. 16 | package pipe 17 | 18 | import ( 19 | "bufio" 20 | "bytes" 21 | "fmt" 22 | "io" 23 | "os" 24 | "os/exec" 25 | "strconv" 26 | "strings" 27 | 28 | "github.com/google/zimtohrli/go/aio" 29 | "github.com/google/zimtohrli/go/audio" 30 | "github.com/google/zimtohrli/go/data" 31 | "github.com/google/zimtohrli/go/resource" 32 | ) 33 | 34 | // MeterPool contains a resource pool of pipe-communicating processes. 35 | type MeterPool struct { 36 | *resource.Pool[*Metric] 37 | // ScoreType is the score type reported by the first process, started when the pool was created. 
38 | ScoreType data.ScoreType 39 | } 40 | 41 | // NewMeterPool returns a new pool of pipe-communicating processes. It starts one process up front to learn the pool's ScoreType.
42 | func NewMeterPool(path string) (*MeterPool, error) { 43 | result := &MeterPool{ 44 | Pool: &resource.Pool[*Metric]{ 45 | Create: func() (*Metric, error) { return StartMetric(path) }, 46 | }, 47 | } 48 | metric, err := result.Get() 49 | if err != nil { 50 | return nil, err 51 | } 52 | defer result.Pool.Return(metric) 53 | result.ScoreType, err = metric.ScoreType() 54 | return result, err 55 | } 56 | 57 | // Close closes all the processes in the pool. 58 | func (m *MeterPool) Close() error { 59 | return m.Pool.Close() 60 | } 61 | 62 | // Measure returns the distance between ref and dist using a metric in the pool, and then returns it to the pool. 63 | func (m *MeterPool) Measure(ref, dist *audio.Audio) (float64, error) { 64 | metric, err := m.Pool.Get() 65 | if err != nil { 66 | return 0, err 67 | } 68 | result, err := metric.Measure(ref, dist) 69 | if err != nil { 70 | return 0, err 71 | } 72 | m.Pool.Return(metric) 73 | return result, nil 74 | } 75 | 76 | // Metric wraps a pipe-communicating process. 77 | type Metric struct { 78 | scoreType data.ScoreType 79 | stdin io.WriteCloser 80 | stdout *bufio.Reader 81 | stderr *bytes.Buffer 82 | nextLine string 83 | } 84 | 85 | // StartMetric starts a new pipe-communicating process. 
86 | func StartMetric(path string) (*Metric, error) { 87 | cmd := exec.Command(path) 88 | stdin, err := cmd.StdinPipe() 89 | if err != nil { 90 | return nil, fmt.Errorf("creating stdin pipe for %v: %v", cmd, err) 91 | } 92 | stderr := &bytes.Buffer{} 93 | cmd.Stderr = stderr 94 | stdout, err := cmd.StdoutPipe() 95 | if err != nil { 96 | return nil, fmt.Errorf("creating stdout pipe for %v: %v", cmd, err) 97 | } 98 | m := &Metric{ 99 | stdin: stdin, 100 | stderr: stderr, 101 | stdout: bufio.NewReader(stdout), 102 | } 103 | if err := cmd.Start(); err != nil { 104 | return nil, fmt.Errorf("running %v: %v\n%s", cmd, err, stderr) 105 | } 106 | return m, nil 107 | } 108 | 109 | func (m *Metric) awaitReady() error { 110 | if m.scoreType != "" { 111 | return nil 112 | } 113 | var err error 114 | for ; err == nil && !strings.HasPrefix(m.nextLine, "READY:"); m.nextLine, err = m.stdout.ReadString('\n') { 115 | } 116 | if err != nil { 117 | return fmt.Errorf("waiting for READY: %v\n%s", err, m.stderr) 118 | } 119 | scoreType, found := strings.CutPrefix(strings.TrimSpace(m.nextLine), "READY:") 120 | if !found { 121 | return fmt.Errorf("%q doesn't have the prefix 'READY:'", m.nextLine) 122 | } 123 | m.scoreType = data.ScoreType(scoreType) 124 | return nil 125 | } 126 | 127 | // ScoreType waits for the process to emit it's score type and returns it. 
128 | func (m *Metric) ScoreType() (data.ScoreType, error) { 129 | if err := m.awaitReady(); err != nil { 130 | return "", err 131 | } 132 | return m.scoreType, nil 133 | } 134 | 135 | func (m *Metric) await(msg string) error { 136 | var err error 137 | for ; err == nil && strings.TrimSpace(m.nextLine) != msg; m.nextLine, err = m.stdout.ReadString('\n') { 138 | } 139 | if err != nil { 140 | return fmt.Errorf("waiting for %q: %v\n%s", msg, err, m.stderr) 141 | } 142 | return nil 143 | } 144 | 145 | // Measure waits until the process has emitted it's score type (which signals that it's ready) and returns the score for the provided ref and dist. 146 | func (m *Metric) Measure(ref, dist *audio.Audio) (float64, error) { 147 | if err := m.awaitReady(); err != nil { 148 | return 0, err 149 | } 150 | refPath, err := aio.DumpWAV(ref) 151 | if err != nil { 152 | return 0, fmt.Errorf("dumping referenc audio: %v", err) 153 | } 154 | defer os.RemoveAll(refPath) 155 | distPath, err := aio.DumpWAV(dist) 156 | if err != nil { 157 | return 0, fmt.Errorf("dumping distortion audio: %v", err) 158 | } 159 | defer os.RemoveAll(distPath) 160 | if err := m.await("REF"); err != nil { 161 | return 0, err 162 | } 163 | if _, err := fmt.Fprintln(m.stdin, refPath); err != nil { 164 | return 0, fmt.Errorf("printing ref path: %v\n%s", err, m.stderr) 165 | } 166 | if err := m.await("DIST"); err != nil { 167 | return 0, err 168 | } 169 | if _, err := fmt.Fprintln(m.stdin, distPath); err != nil { 170 | return 0, fmt.Errorf("printing dist path: %v\n%s", err, m.stderr) 171 | } 172 | for ; err == nil && !strings.HasPrefix(m.nextLine, "SCORE="); m.nextLine, err = m.stdout.ReadString('\n') { 173 | } 174 | if err != nil { 175 | return 0, fmt.Errorf("waiting for SCORE=: %v\n%s", err, m.stderr) 176 | } 177 | scoreString, found := strings.CutPrefix(strings.TrimSpace(m.nextLine), "SCORE=") 178 | if !found { 179 | return 0, fmt.Errorf("%q doesn't have the prefix SCORE=", m.nextLine) 180 | } 181 | return 
strconv.ParseFloat(scoreString, 64) 182 | } 183 | 184 | // Close closes the process by closing it's stdin. 185 | func (m *Metric) Close() error { 186 | return m.stdin.Close() 187 | } 188 | -------------------------------------------------------------------------------- /go/audio/audio.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Package audio handles WAV data. 16 | package audio 17 | 18 | import ( 19 | "encoding/binary" 20 | "fmt" 21 | "io" 22 | "math" 23 | "unsafe" 24 | ) 25 | 26 | // FixString contains 4 bytes that can render to a String. 27 | type FixString [4]byte 28 | 29 | func (f FixString) String() string { 30 | return string(f[:]) 31 | } 32 | 33 | // RIFFHeader contains a RIFF header. 34 | type RIFFHeader struct { 35 | ChunkID FixString 36 | ChunkSize int32 37 | Format FixString 38 | } 39 | 40 | // ReadRIFFHeader reads a RIFF header. 41 | func ReadRIFFHeader(r io.Reader) (*RIFFHeader, error) { 42 | result := &RIFFHeader{} 43 | if err := binary.Read(r, binary.LittleEndian, result); err != nil { 44 | return nil, err 45 | } 46 | return result, nil 47 | } 48 | 49 | // ChunkHeader contains a chunk header. 50 | type ChunkHeader struct { 51 | SubChunkID FixString 52 | SubChunkSize int32 53 | } 54 | 55 | // ReadChunkHeader reads a chunk header. 
56 | func ReadChunkHeader(r io.Reader) (*ChunkHeader, error) { 57 | result := &ChunkHeader{} 58 | if err := binary.Read(r, binary.LittleEndian, result); err != nil { 59 | return nil, err 60 | } 61 | return result, nil 62 | } 63 | 64 | // FormatChunk contains a format chunk. 65 | type FormatChunk struct { 66 | AudioFormat uint16 67 | NumChannels int16 68 | SampleRate int32 69 | ByteRate int32 70 | BlockAlign int16 71 | BitsPerSample int16 72 | } 73 | 74 | // ReadFormatChunk reads a format chunk. 75 | func ReadFormatChunk(r io.Reader) (*FormatChunk, error) { 76 | result := &FormatChunk{} 77 | if err := binary.Read(r, binary.LittleEndian, result); err != nil { 78 | return nil, err 79 | } 80 | return result, nil 81 | } 82 | 83 | // WAV contains a WAV file. 84 | type WAV struct { 85 | RIFFHeader *RIFFHeader 86 | FormatChunk *FormatChunk 87 | Data []byte 88 | } 89 | 90 | // Audio contains audio data. 91 | type Audio struct { 92 | // Samples is a (num_channels, num_samples)-shaped array containing samples between -1 and 1. 93 | Samples [][]float32 94 | // Rate is the sample rate of the sound. 95 | Rate float64 96 | // MaxAbsAmplitude contains the max amplitude of the sound. 97 | MaxAbsAmplitude float32 98 | } 99 | 100 | // Amplify multiplies all samples in the audio with the amplification. 101 | func (a *Audio) Amplify(amplification float32) { 102 | for _, channel := range a.Samples { 103 | for sampleIndex := range channel { 104 | channel[sampleIndex] *= amplification 105 | } 106 | } 107 | a.MaxAbsAmplitude *= float32(math.Abs(float64(amplification))) 108 | } 109 | 110 | // Audio returns the audio in a WAV file. 
111 | func (w *WAV) Audio() (*Audio, error) { 112 | result := &Audio{ 113 | Samples: make([][]float32, w.FormatChunk.NumChannels), 114 | Rate: float64(w.FormatChunk.SampleRate), 115 | } 116 | if w.FormatChunk.AudioFormat == 1 { 117 | pcmSamples := unsafe.Slice((*int16)(unsafe.Pointer(&w.Data[0])), len(w.Data)/2) 118 | numFrames := len(pcmSamples) / int(w.FormatChunk.NumChannels) 119 | for channelIndex := 0; channelIndex < int(w.FormatChunk.NumChannels); channelIndex++ { 120 | result.Samples[channelIndex] = make([]float32, numFrames) 121 | } 122 | scaleReciprocal := 1.0 / float32(int(1)<<(w.FormatChunk.BitsPerSample-1)) 123 | for sampleIndex, sample := range pcmSamples { 124 | scaledSample := float32(sample) * scaleReciprocal 125 | absSample := scaledSample 126 | if absSample < 0 { 127 | absSample = -absSample 128 | } 129 | if absSample > result.MaxAbsAmplitude { 130 | result.MaxAbsAmplitude = absSample 131 | } 132 | result.Samples[sampleIndex%int(w.FormatChunk.NumChannels)][sampleIndex/int(w.FormatChunk.NumChannels)] = scaledSample 133 | } 134 | } else if w.FormatChunk.AudioFormat == 3 { 135 | return nil, fmt.Errorf("blah") 136 | } else { 137 | return nil, fmt.Errorf("not audio format 1 (PCM) or 3 (IEEE float)") 138 | } 139 | return result, nil 140 | } 141 | 142 | // ReadWAV reads a WAV file from a reader. 
143 | func ReadWAV(r io.Reader) (*WAV, error) { 144 | result := &WAV{} 145 | var err error 146 | if result.RIFFHeader, err = ReadRIFFHeader(r); err != nil { 147 | return nil, fmt.Errorf("while reading RIFF header: %v", err) 148 | } 149 | if result.RIFFHeader.ChunkID.String() != "RIFF" { 150 | return nil, fmt.Errorf("not RIFF") 151 | } 152 | if result.RIFFHeader.Format.String() != "WAVE" { 153 | return nil, fmt.Errorf("not WAVE") 154 | } 155 | for { 156 | chunkHeader, err := ReadChunkHeader(r) 157 | if err == io.EOF { 158 | break 159 | } 160 | if err != nil { 161 | return nil, fmt.Errorf("while reading chunk header: %v", err) 162 | } 163 | if chunkHeader.SubChunkID.String() == "fmt " { 164 | if result.FormatChunk, err = ReadFormatChunk(r); err != nil { 165 | return nil, fmt.Errorf("while reading format chunk: %v", err) 166 | } 167 | if result.FormatChunk.AudioFormat != 1 { 168 | return nil, fmt.Errorf("not audio format 1 (PCM): %x", result.FormatChunk.AudioFormat) 169 | } 170 | if result.FormatChunk.NumChannels != 1 && result.FormatChunk.NumChannels != 2 { 171 | return nil, fmt.Errorf("not 1 or 2 channels: %v", result.FormatChunk.NumChannels) 172 | } 173 | if result.FormatChunk.BitsPerSample != 16 { 174 | return nil, fmt.Errorf("not 16 bits: %v", result.FormatChunk.BitsPerSample) 175 | } 176 | } else if chunkHeader.SubChunkID.String() == "data" { 177 | if result.Data, err = io.ReadAll(r); err != nil { 178 | return nil, fmt.Errorf("while reading data chunk: %v", err) 179 | } 180 | break 181 | } else { 182 | buf := make([]byte, chunkHeader.SubChunkSize) 183 | if readBytes, err := r.Read(buf); readBytes != int(chunkHeader.SubChunkSize) || err != nil { 184 | return nil, fmt.Errorf("tried to read %v bytes of %q chunk: read %v bytes, got %v", chunkHeader.SubChunkSize, chunkHeader.SubChunkID.String(), readBytes, err) 185 | } 186 | } 187 | } 188 | if result.FormatChunk == nil { 189 | return nil, fmt.Errorf("no format chunk") 190 | } 191 | if result.Data == nil { 192 | 
return nil, fmt.Errorf("no data chunk") 193 | } 194 | return result, nil 195 | } 196 | -------------------------------------------------------------------------------- /cpp/zimt/compare.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | #include "absl/flags/flag.h" 20 | #include "absl/flags/parse.h" 21 | #include "sndfile.h" 22 | #include "zimt/audio.h" 23 | #include "zimt/mos.h" 24 | #include "zimt/zimtohrli.h" 25 | 26 | ABSL_FLAG(std::string, path_a, "", "file A to compare"); 27 | ABSL_FLAG(std::vector, path_b, {}, "files B to compare to file A"); 28 | ABSL_FLAG(float, perceptual_sample_rate, 29 | zimtohrli::Zimtohrli{}.perceptual_sample_rate, 30 | "the frequency corresponding to the maximum time resolution, Hz"); 31 | ABSL_FLAG(float, full_scale_sine_db, 80, 32 | "reference dB SPL for a sine signal of amplitude 1"); 33 | ABSL_FLAG(bool, verbose, false, "verbose output"); 34 | ABSL_FLAG(bool, output_zimtohrli_distance, false, 35 | "Whether to output the raw Zimtohrli distance instead of a mapped " 36 | "mean opinion score."); 37 | ABSL_FLAG(bool, per_channel, false, 38 | "Whether to output the produced metric per channel instead of a " 39 | "single value for all channels."); 40 | 41 | namespace zimtohrli { 42 | 43 
| namespace { 44 | 45 | void PrintLoadFileInfo(const std::string& path, const SF_INFO& file_info) { 46 | std::cout << "Loaded " << path << " (" << file_info.channels << "x" 47 | << file_info.frames << "@" << file_info.samplerate << "Hz " 48 | << GetFormatName(file_info.format) << ", " 49 | << (static_cast(file_info.frames) / 50 | static_cast(file_info.samplerate)) 51 | << "s)\n"; 52 | } 53 | 54 | float GetMetric(float zimtohrli_score) { 55 | if (absl::GetFlag(FLAGS_output_zimtohrli_distance)) { 56 | return zimtohrli_score; 57 | } 58 | return MOSFromZimtohrli(zimtohrli_score); 59 | } 60 | 61 | int Main(int argc, char* argv[]) { 62 | absl::ParseCommandLine(argc, argv); 63 | const std::string path_a = absl::GetFlag(FLAGS_path_a); 64 | const std::vector path_b = absl::GetFlag(FLAGS_path_b); 65 | if (path_a.empty() || path_b.empty()) { 66 | std::cerr << "Both path_a and path_b have to be specified." << std::endl; 67 | return 1; 68 | } 69 | const float full_scale_sine_db = absl::GetFlag(FLAGS_full_scale_sine_db); 70 | if (full_scale_sine_db < 1) { 71 | std::cerr << "Full scale sine dB must be >= 1." 
<< std::endl; 72 | return 3; 73 | } 74 | 75 | absl::StatusOr file_a = AudioFile::Load(path_a); 76 | if (!file_a.ok()) { 77 | std::cerr << file_a.status().message(); 78 | return 4; 79 | } 80 | std::vector> channels_a; 81 | channels_a.reserve(file_a->Info().channels); 82 | for (size_t channel_idx = 0; channel_idx < file_a->Info().channels; 83 | ++channel_idx) { 84 | channels_a.push_back(file_a->AtRate(channel_idx, kSampleRate)); 85 | } 86 | const bool verbose = absl::GetFlag(FLAGS_verbose); 87 | if (verbose) { 88 | PrintLoadFileInfo(path_a, file_a->Info()); 89 | } 90 | 91 | std::vector file_b_vector; 92 | file_b_vector.reserve(path_b.size()); 93 | std::vector>> channels_b_vector; 94 | channels_b_vector.reserve(path_b.size()); 95 | for (const std::string& path : path_b) { 96 | absl::StatusOr file_b = AudioFile::Load(path); 97 | if (!file_b.ok()) { 98 | std::cerr << file_b.status().message(); 99 | return 4; 100 | } 101 | std::vector> channels_b; 102 | channels_b.reserve(file_b->Info().channels); 103 | for (size_t channel_idx = 0; channel_idx < file_b->Info().channels; 104 | ++channel_idx) { 105 | channels_b.push_back(file_b->AtRate(channel_idx, kSampleRate)); 106 | } 107 | if (verbose) { 108 | PrintLoadFileInfo(file_b->Path(), file_b->Info()); 109 | } 110 | CHECK_EQ(file_a->Info().channels, file_b->Info().channels); 111 | CHECK_EQ(file_a->Info().samplerate, file_b->Info().samplerate); 112 | file_b_vector.push_back(*std::move(file_b)); 113 | channels_b_vector.push_back(channels_b); 114 | } 115 | 116 | Zimtohrli z = { 117 | .perceptual_sample_rate = absl::GetFlag(FLAGS_perceptual_sample_rate), 118 | .full_scale_sine_db = absl::GetFlag(FLAGS_full_scale_sine_db), 119 | }; 120 | 121 | const bool per_channel = absl::GetFlag(FLAGS_per_channel); 122 | std::vector file_a_spectrograms; 123 | for (size_t channel_index = 0; channel_index < file_a->Info().channels; 124 | ++channel_index) { 125 | Spectrogram spectrogram = z.Analyze(channels_a[channel_index]); 126 | 
file_a_spectrograms.push_back(std::move(spectrogram)); 127 | } 128 | for (int file_b_index = 0; file_b_index < file_b_vector.size(); 129 | ++file_b_index) { 130 | const AudioFile& file_b = file_b_vector[file_b_index]; 131 | const std::vector>& channels_b = 132 | channels_b_vector[file_b_index]; 133 | std::optional spectrogram_b; 134 | float sum_of_squares = 0; 135 | for (size_t channel_index = 0; channel_index < file_a->Info().channels; 136 | ++channel_index) { 137 | if (spectrogram_b.has_value()) { 138 | z.Analyze(channels_b[channel_index], *spectrogram_b); 139 | } else { 140 | spectrogram_b = z.Analyze(channels_b[channel_index]); 141 | } 142 | const float distance = 143 | z.Distance(file_a_spectrograms[channel_index], *spectrogram_b); 144 | if (per_channel) { 145 | std::cout << GetMetric(distance) << std::endl; 146 | } else { 147 | sum_of_squares += distance * distance; 148 | } 149 | } 150 | if (!per_channel) { 151 | for (int file_b_index = 0; file_b_index < file_b_vector.size(); 152 | ++file_b_index) { 153 | std::cout << GetMetric(std::sqrt(sum_of_squares / 154 | float(file_a->Info().channels))) 155 | << std::endl; 156 | } 157 | } 158 | } 159 | return 0; 160 | } 161 | 162 | } // namespace 163 | 164 | } // namespace zimtohrli 165 | 166 | int main(int argc, char* argv[]) { return zimtohrli::Main(argc, argv); } 167 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= 2 | github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= 3 | github.com/aclements/go-moremath v0.0.0-20241023150245-c8bbc672ef66 h1:siNQlUMcFUDZWCOt0p+RHl7et5Nnwwyq/sFZmr4iG1I= 4 | github.com/aclements/go-moremath v0.0.0-20241023150245-c8bbc672ef66/go.mod h1:FDw7qicTbJ1y1SZcNnOvym2BogPdC3lY9Z1iUM4MVhw= 5 | github.com/andybalholm/cascadia v1.3.3 
h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= 6 | github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= 7 | github.com/dgryski/go-onlinestats v0.0.0-20170612111826-1c7d19468768 h1:Xzl7CSuSnGsyU+9xmSU2h8w3d7Tnis66xeoNN207tLo= 8 | github.com/dgryski/go-onlinestats v0.0.0-20170612111826-1c7d19468768/go.mod h1:alfmlCqcg4uw9jaoIU1nOp9RFdJLMuu8P07BCEgpgoo= 9 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 10 | github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= 11 | github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 12 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 13 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 14 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 15 | golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= 16 | golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= 17 | golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= 18 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= 19 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 20 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 21 | golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 22 | golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 23 | golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 24 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 25 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod 
h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 26 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 27 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 28 | golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= 29 | golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= 30 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 31 | golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= 32 | golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= 33 | golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= 34 | golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= 35 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 36 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 37 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 38 | golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= 39 | golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 40 | golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 41 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 42 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 43 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 44 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 45 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 46 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 47 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 48 | golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 49 | golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 50 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 51 | golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 52 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 53 | golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= 54 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 55 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 56 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 57 | golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= 58 | golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= 59 | golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= 60 | golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= 61 | golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= 62 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 63 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 64 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 65 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 66 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 67 | golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= 68 | golang.org/x/text v0.14.0/go.mod 
h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 69 | golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 70 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 71 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 72 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 73 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 74 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 75 | golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= 76 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 77 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 78 | -------------------------------------------------------------------------------- /go/bin/score/score.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // score handles listening test datasets. 
16 | package main 17 | 18 | import ( 19 | "encoding/json" 20 | "flag" 21 | "fmt" 22 | "log" 23 | "os" 24 | "reflect" 25 | "runtime" 26 | "sort" 27 | "strings" 28 | 29 | "github.com/google/zimtohrli/go/data" 30 | "github.com/google/zimtohrli/go/goohrli" 31 | "github.com/google/zimtohrli/go/pipe" 32 | "github.com/google/zimtohrli/go/progress" 33 | "github.com/google/zimtohrli/go/worker" 34 | ) 35 | 36 | func main() { 37 | details := flag.String("details", "", "Glob to directories with databases to show the details of.") 38 | calculate := flag.String("calculate", "", "Glob to directories with databases to calculate metrics for.") 39 | force := flag.Bool("force", false, "Whether to recalculate scores that already exist.") 40 | calculateZimtohrli := flag.Bool("calculate_zimtohrli", false, "Whether to calculate Zimtohrli scores.") 41 | zimtohrliScoreType := flag.String("zimtohrli_score_type", string(data.Zimtohrli), "Score type name to use when storing Zimtohrli scores in a dataset.") 42 | calculateViSQOL := flag.Bool("calculate_visqol", false, "Whether to calculate ViSQOL scores.") 43 | calculatePipeMetric := flag.String("calculate_pipe", "", "Path to a binary that serves metrics via stdin/stdout pipe. Install some of the via 'install_python_metrics.py'.") 44 | zimtohrliParameters := goohrli.DefaultParameters() 45 | b, err := json.Marshal(zimtohrliParameters) 46 | if err != nil { 47 | log.Panic(err) 48 | } 49 | zimtohrliParametersJSON := flag.String("zimtohrli_parameters", string(b), "Zimtohrli model parameters. 
Sample rate will be set to the sample rate of the measured audio files.") 50 | correlate := flag.String("correlate", "", "Glob to directories with databases to correlate scores for.") 51 | leaderboard := flag.String("leaderboard", "", "Glob to directories with databases to compute leaderboard for.") 52 | report := flag.String("report", "", "Glob to directories with databases to generate a report for.") 53 | skipMetrics := flag.String("skip_metrics", "", "Comma separated list of metrics to exclude from the output.") 54 | accuracy := flag.String("accuracy", "", "Glob to directories with databases to provide JND accuracy for.") 55 | workers := flag.Int("workers", runtime.NumCPU(), "Number of concurrent workers for tasks.") 56 | failFast := flag.Bool("fail_fast", false, "Whether to panic immediately on any error.") 57 | clear := flag.String("clear", "", "Glob to directories with databases to clear a particular score type from.") 58 | clearScore := flag.String("clear_score", "", "Name of score type to clear.") 59 | flag.Parse() 60 | 61 | if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *clear == "" { 62 | flag.Usage() 63 | os.Exit(1) 64 | } 65 | 66 | if err := zimtohrliParameters.Update([]byte(*zimtohrliParametersJSON)); err != nil { 67 | log.Panic(err) 68 | } 69 | 70 | skipMap := map[data.ScoreType]bool{} 71 | for _, skip := range strings.Split(*skipMetrics, ",") { 72 | skipMap[data.ScoreType(skip)] = true 73 | } 74 | 75 | if *clear != "" { 76 | studies, err := data.OpenStudies(*clear) 77 | if err != nil { 78 | log.Fatal(err) 79 | } 80 | defer studies.Close() 81 | for _, study := range studies { 82 | if err := study.ClearScore(data.ScoreType(*clearScore)); err != nil { 83 | log.Fatal(err) 84 | } 85 | } 86 | } 87 | 88 | if *calculate != "" { 89 | studies, err := data.OpenStudies(*calculate) 90 | if err != nil { 91 | log.Fatal(err) 92 | } 93 | defer studies.Close() 94 | for _, study := range studies { 
95 | measurements := map[data.ScoreType]data.Measurement{} 96 | if *calculateZimtohrli { 97 | if !reflect.DeepEqual(zimtohrliParameters, goohrli.DefaultParameters()) { 98 | log.Printf("Using %+v", zimtohrliParameters) 99 | } 100 | z := goohrli.New(zimtohrliParameters) 101 | measurements[data.ScoreType(*zimtohrliScoreType)] = z.NormalizedAudioDistance 102 | } 103 | if *calculateViSQOL { 104 | v := goohrli.NewViSQOL() 105 | measurements[data.ViSQOL] = v.AudioMOS 106 | } 107 | if *calculatePipeMetric != "" { 108 | pool, err := pipe.NewMeterPool(*calculatePipeMetric) 109 | if err != nil { 110 | log.Fatal(err) 111 | } 112 | defer pool.Close() 113 | measurements[pool.ScoreType] = pool.Measure 114 | } 115 | if len(measurements) == 0 { 116 | log.Print("No metrics to calculate, provide one of the -calculate_XXX flags!") 117 | os.Exit(2) 118 | } 119 | sortedTypes := sort.StringSlice{} 120 | for scoreType := range measurements { 121 | sortedTypes = append(sortedTypes, string(scoreType)) 122 | } 123 | sort.Sort(sortedTypes) 124 | bundle, err := study.ToBundle() 125 | if err != nil { 126 | log.Fatal(err) 127 | } 128 | log.Printf("*** Calculating %+v (force=%v) for %v", sortedTypes, *force, bundle.Dir) 129 | bar := progress.New("Calculating") 130 | pool := &worker.Pool[any]{ 131 | Workers: *workers, 132 | FailFast: *failFast, 133 | OnChange: bar.Update, 134 | } 135 | if err := bundle.Calculate(measurements, pool, *force); err != nil { 136 | log.Printf("%#v", err) 137 | log.Fatal(err) 138 | } 139 | if err := study.Put(bundle.References); err != nil { 140 | log.Fatal(err) 141 | } 142 | bar.Finish() 143 | } 144 | } 145 | 146 | if *correlate != "" { 147 | bundles, err := data.OpenBundles(*correlate) 148 | if err != nil { 149 | log.Fatal(err) 150 | } 151 | for _, bundle := range bundles { 152 | if bundle.IsJND() { 153 | fmt.Printf("Not computing correlation for JND dataset %q\n\n", bundle.Dir) 154 | } else { 155 | corrTable, err := bundle.Correlate(skipMap) 156 | if err != nil { 157 
| log.Fatal(err) 158 | } 159 | fmt.Printf("## %v\n\n", bundle.Dir) 160 | fmt.Println(corrTable) 161 | } 162 | } 163 | } 164 | 165 | if *accuracy != "" { 166 | bundles, err := data.OpenBundles(*accuracy) 167 | if err != nil { 168 | log.Fatal(err) 169 | } 170 | for _, bundle := range bundles { 171 | if bundle.IsJND() { 172 | accuracy, err := bundle.JNDAccuracy(skipMap) 173 | if err != nil { 174 | log.Fatal(err) 175 | } 176 | fmt.Printf("## %v\n", bundle.Dir) 177 | fmt.Println(accuracy) 178 | } else { 179 | fmt.Printf("Not computing accuracy for non-JND dataset %q\n\n", bundle.Dir) 180 | } 181 | } 182 | } 183 | 184 | if *report != "" { 185 | bundles, err := data.OpenBundles(*report) 186 | if err != nil { 187 | log.Fatal(err) 188 | } 189 | report, err := bundles.Report(skipMap) 190 | if err != nil { 191 | log.Fatal(err) 192 | } 193 | fmt.Println(report) 194 | } 195 | 196 | if *leaderboard != "" { 197 | bundles, err := data.OpenBundles(*leaderboard) 198 | if err != nil { 199 | log.Fatal(err) 200 | } 201 | board, err := bundles.Leaderboard(15, skipMap) 202 | if err != nil { 203 | log.Fatal(err) 204 | } 205 | fmt.Println(board) 206 | } 207 | 208 | if *details != "" { 209 | bundles, err := data.OpenBundles(*details) 210 | if err != nil { 211 | log.Fatal(err) 212 | } 213 | b, err := json.MarshalIndent(bundles, "", " ") 214 | if err != nil { 215 | log.Fatal(err) 216 | } 217 | fmt.Printf("%s\n", b) 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /go/goohrli/goohrli_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | package goohrli 16 | 17 | import ( 18 | "encoding/json" 19 | "log" 20 | "math" 21 | "reflect" 22 | "testing" 23 | "time" 24 | ) 25 | 26 | func TestMOSFromZimtohrli(t *testing.T) { 27 | for _, tc := range []struct { 28 | zimtDistance float64 29 | wantMOS float64 30 | }{ 31 | { 32 | zimtDistance: 0, 33 | wantMOS: 5.0, 34 | }, 35 | { 36 | zimtDistance: 0.001, 37 | wantMOS: 4.748757362365723, 38 | }, 39 | { 40 | zimtDistance: 0.04, 41 | wantMOS: 1.2986432313919067, 42 | }, 43 | } { 44 | if mos := MOSFromZimtohrli(tc.zimtDistance); math.Abs(mos-tc.wantMOS) > 1e-2 { 45 | t.Errorf("MOSFromZimtohrli(%v) = %v, want %v", tc.zimtDistance, mos, tc.wantMOS) 46 | } 47 | } 48 | } 49 | 50 | func TestParams(t *testing.T) { 51 | g := New(DefaultParameters()) 52 | 53 | params := g.Parameters() 54 | params.FullScaleSineDB *= 0.5 55 | params.NSIMChannelWindow *= 2 56 | params.NSIMStepWindow *= 2 57 | params.PerceptualSampleRate *= 0.5 58 | 59 | g.Set(params) 60 | newParams := g.Parameters() 61 | if !reflect.DeepEqual(newParams, params) { 62 | t.Errorf("Expected updated parameters to be %+v, got %+v", params, newParams) 63 | } 64 | } 65 | // TestGoohrli checks SpecDistance (on Analyze output) and Distance (on raw samples) of one-second sine tones against reference distances. 66 | func TestGoohrli(t *testing.T) { 67 | for _, tc := range []struct { 68 | freqA float64 69 | freqB float64 70 | distance float64 71 | }{ 72 | { 73 | freqA: 5000, 74 | freqB: 5000, 75 | distance: 0.0, 76 | }, 77 | { 78 | freqA: 5000, 79 | freqB: 5010, 80 | distance: 1.3709068298339844e-06, 81 | }, 82 | { 83 | freqA: 5000, 84 | freqB: 10000, 85 | distance: 0.01977372169494629, 86 | }, 87 | } {
88 | params := DefaultParameters() 89 | g := New(params) 90 | soundA := make([]float32, int(SampleRate())) 91 | for index := 0; index < len(soundA); index++ { 92 | soundA[index] = float32(math.Sin(2 * math.Pi * tc.freqA * float64(index) / SampleRate())) 93 | } 94 | analysisA := g.Analyze(soundA) 95 | soundB := make([]float32, int(SampleRate())) 96 | for index := 0; index < len(soundB); index++ { 97 | soundB[index] = float32(math.Sin(2 * math.Pi * tc.freqB * float64(index) / SampleRate())) 98 | } 99 | analysisB := g.Analyze(soundB) 100 | analysisDistance := float64(g.SpecDistance(analysisA, analysisB)) 101 | if d := rdiff(analysisDistance, tc.distance); d > 0.1 { 102 | t.Errorf("SpecDistance = %v, want %v", analysisDistance, tc.distance) 103 | } 104 | distance := float64(g.Distance(soundA, soundB)) 105 | if d := rdiff(distance, tc.distance); d > 0.1 { 106 | t.Errorf("Distance = %v, want %v", distance, tc.distance) 107 | } 108 | } 109 | } 110 | 111 | func TestViSQOL(t *testing.T) { 112 | sampleRate := 48000.0 113 | g := NewViSQOL() 114 | for _, tc := range []struct { 115 | freqA float64 116 | freqB float64 117 | wantMOS float64 118 | }{ 119 | { 120 | freqA: 5000, 121 | freqB: 5000, 122 | wantMOS: 4.7321014404296875, 123 | }, 124 | { 125 | freqA: 5000, 126 | freqB: 10000, 127 | wantMOS: 1.5407887697219849, 128 | }, 129 | } { 130 | soundA := make([]float32, int(sampleRate)) 131 | for index := 0; index < len(soundA); index++ { 132 | soundA[index] = float32(math.Sin(2 * math.Pi * tc.freqA * float64(index) / sampleRate)) 133 | } 134 | soundB := make([]float32, int(sampleRate)) 135 | for index := 0; index < len(soundB); index++ { 136 | soundB[index] = float32(math.Sin(2 * math.Pi * tc.freqB * float64(index) / sampleRate)) 137 | } 138 | mos, err := g.MOS(sampleRate, soundA, soundB) 139 | if err != nil { 140 | t.Fatal(err) 141 | } 142 | if math.Abs(mos-tc.wantMOS) > 1e-3 { 143 | t.Errorf("got mos %v, wanted mos %v", mos, tc.wantMOS) 144 | } 145 | } 146 | } 147 | 148 | var
goohrliDurationType = reflect.TypeOf(Duration{}) 149 | // populate fills every field of the struct pointed to by s with counter-derived values (bools set to true) so that parameter round-trips can be compared field by field. 150 | func populate(s any) { 151 | counter := 1 152 | val := reflect.ValueOf(s).Elem() 153 | typ := val.Type() 154 | for _, field := range reflect.VisibleFields(typ) { 155 | switch field.Type.Kind() { 156 | case reflect.Float64: 157 | val.FieldByIndex(field.Index).SetFloat(float64(counter)) 158 | case reflect.Int: 159 | val.FieldByIndex(field.Index).SetInt(int64(counter & 0xffff)) 160 | case reflect.Bool: 161 | val.FieldByIndex(field.Index).SetBool(true) 162 | case reflect.Array: 163 | if field.Type.Elem().Kind() == reflect.Float64 { 164 | for i := 0; i < field.Type.Len(); i++ { 165 | val.FieldByIndex(field.Index).Index(i).SetFloat(float64(counter)) 166 | counter++ 167 | } 168 | } else { 169 | log.Panicf("Unsupported array type %v", field.Type.Elem().Kind()) 170 | } 171 | default: 172 | if field.Type == goohrliDurationType { 173 | val.FieldByIndex(field.Index).Set(reflect.ValueOf(Duration{Duration: time.Duration(counter) * time.Minute})) 174 | } else { 175 | log.Panicf("Unsupported field %v", field) 176 | } 177 | } 178 | counter++ 179 | } 180 | } 181 | // rdiff returns the absolute difference of a and b relative to their mean. 182 | func rdiff(a, b float64) float64 { 183 | return math.Abs(float64(a-b) / (0.5 * (a + b))) 184 | } 185 | // checkNear reports a test error for each field of a and b (which must have the same struct type) whose values differ by more than the relative tolerance rtol; ints and bools must match exactly. 186 | func checkNear(a, b any, rtol float64, t *testing.T) { 187 | t.Helper() 188 | aVal := reflect.ValueOf(a) 189 | bVal := reflect.ValueOf(b) 190 | if aVal.Type() != bVal.Type() { 191 | t.Fatalf("%v and %v not same type", a, b) 192 | } 193 | for _, field := range reflect.VisibleFields(aVal.Type()) { 194 | switch field.Type.Kind() { 195 | case reflect.Float64: 196 | aFloat := aVal.FieldByIndex(field.Index).Float() 197 | bFloat := bVal.FieldByIndex(field.Index).Float() 198 | if d := rdiff(aFloat, bFloat); d > rtol { 199 | t.Errorf("%v: %v is more than %v off from %v", field, aFloat, rtol, bFloat) 200 | } 201 | case reflect.Int: 202 | aInt := aVal.FieldByIndex(field.Index).Int() 203 | bInt := bVal.FieldByIndex(field.Index).Int() 204 | if aInt != bInt { 205
| t.Errorf("%v: %v != %v", field, aInt, bInt) 206 | } 207 | case reflect.Bool: 208 | aBool := aVal.FieldByIndex(field.Index).Bool() 209 | bBool := bVal.FieldByIndex(field.Index).Bool() 210 | if aBool != bBool { 211 | t.Errorf("%v: %v != %v", field, aBool, bBool) 212 | } 213 | case reflect.Array: 214 | if field.Type.Elem().Kind() == reflect.Float64 { 215 | for i := 0; i < aVal.FieldByIndex(field.Index).Len(); i++ { 216 | aFloat := aVal.FieldByIndex(field.Index).Index(i).Float() 217 | bFloat := bVal.FieldByIndex(field.Index).Index(i).Float() 218 | if d := rdiff(aFloat, bFloat); d > rtol { 219 | t.Errorf("%v[%v]: %v is more than %v off from %v", field, i, aFloat, rtol, bFloat) 220 | } 221 | } 222 | } else { 223 | log.Panicf("Unsupported array type %v", field.Type.Elem()) 224 | } 225 | default: 226 | if field.Type == goohrliDurationType { 227 | aDur := aVal.FieldByIndex(field.Index).Interface().(Duration).Duration 228 | bDur := bVal.FieldByIndex(field.Index).Interface().(Duration).Duration 229 | if d := rdiff(float64(aDur), float64(bDur)); d > rtol { 230 | t.Errorf("%v: %v is more than %v off from %v", field, aDur, rtol, bDur) 231 | } 232 | } else { 233 | log.Panicf("Unsupported field %v", field) 234 | } 235 | } 236 | } 237 | } 238 | 239 | func TestParamConversion(t *testing.T) { 240 | params := Parameters{} 241 | populate(&params) 242 | cParams := cFromGoParameters(params) 243 | reconvertedParams := goFromCParameters(cParams) 244 | checkNear(reconvertedParams, params, 1e-6, t) 245 | } 246 | 247 | func TestParamUpdate(t *testing.T) { 248 | params := DefaultParameters() 249 | js, err := json.Marshal(params) 250 | if err != nil { 251 | t.Fatal(err) 252 | } 253 | m := map[string]any{} 254 | if err := json.Unmarshal(js, &m); err != nil { 255 | t.Fatal(err) 256 | } 257 | for k, v := range m { 258 | js, err = json.Marshal(map[string]any{k: v}) 259 | if err != nil { 260 | t.Fatal(err) 261 | } 262 | if err := params.Update(js); err != nil { 263 | t.Error(err) 264 | } 265 | } 266 |
} 267 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Tests](https://github.com/google/zimtohrli/workflows/Test%20Zimtohrli/badge.svg)](https://github.com/google/zimtohrli/actions) 2 | 3 | # Zimtohrli: A New Psychoacoustic Perceptual Metric for Audio Compression 4 | 5 | Zimtohrli is a psychoacoustic perceptual metric that quantifies the human 6 | observable difference in two audio signals in the proximity of 7 | just-noticeable-differences. 8 | 9 | In this project we study the psychological and physiological responses 10 | associated with sound, to create a new more accurate model for measuring 11 | human-subjective similarity between sounds. 12 | The main focus will be on just-noticeable-difference to get most significant 13 | benefits in high quality audio compression. 14 | The main goals of the project is to further both existing and new practical 15 | audio (and video containing audio) compression, and also be able to plug in the 16 | resulting psychoacoustic similarity measure into audio related machine learning 17 | models. 18 | 19 | ## Design 20 | 21 | Zimtohrli implements a perceptually-motivated audio similarity metric that 22 | models the human auditory system through a multi-stage signal processing 23 | pipeline. The metric operates on audio signals sampled at 48 kHz and produces a 24 | scalar distance value that correlates with human perception of audio quality 25 | differences. 26 | 27 | ### Signal Processing Pipeline 28 | 29 | The algorithm consists of four main stages: 30 | 31 | 1. **3rd Order Complex Gammatone Filterbank**: The input signal is processed 32 | through a bank of 128 bins of 3rd order filters with center frequencies 33 | spaced between 24.349 Hz and 19658.3 Hz.
These filters are implemented using 34 | a computationally efficient rotating phasor algorithm that computes spectral 35 | energy at each frequency band. The filterbank incorporates 36 | bandwidth-dependent exponential windowing to model frequency selectivity of 37 | the basilar membrane. 38 | 39 | 2. **Physiological Modeling**: The filtered signals undergo several transformations 40 | inspired by auditory physiology: 41 | - A resonator model simulating the mechanical response of the ear drum and 42 | middle ear structures, implemented as a linear 32-point FIR filter with 43 | physiologically-motivated coefficients 44 | - A second 32-point FIR filter for modeling direct absorption of sound, 45 | e.g., through the skull. 46 | - Loudness transformation using a logarithmic function with 47 | frequency-dependent gains inspired by equal-loudness contours 48 | - A frequency-dependent bias energy, added to the signals before taking the 49 | logarithm, models the hearing threshold. 50 | 51 | 3. **Temporal Alignment**: To handle temporal misalignments between reference 52 | and test signals, the algorithm employs Dynamic Time Warping (DTW) with a 53 | perceptually-motivated cost function. The warping path minimizes a weighted 54 | combination of spectral distance (raised to power 0.323) and temporal 55 | distortion penalties. 56 | 57 | 4. **Perceptual Similarity Computation**: The aligned spectrograms are compared 58 | using a modified Neurogram Similarity Index Measure (NSIM). This metric 59 | computes windowed statistics (mean, variance, covariance) over 8 temporal 60 | frames and 5 frequency channels, combining intensity and structure components 61 | through empirically-optimized non-linear functions inspired by SSIM.
62 | 63 | ### Key Parameters 64 | 65 | - **Perceptual sampling rate**: 85 Hz (derived from [high gamma band](https://doi.org/10.1523/JNEUROSCI.5297-10.2011) frequency) 66 | - **NSIM temporal window**: 8 frames (~94 ms), a smaller window of 6 samples 67 | gives better results in the high MOS (high quality, MOS >= 2.9) subset of 68 | our test corpora, 8 is a compromise where a better performance across a 69 | broad range of MOS is achieved. 70 | - **NSIM frequency window**: 5 channels 71 | - **Reference level**: 78.3 dB SPL for unity amplitude sine wave 72 | 73 | The final distance metric is computed as 1 - NSIM, providing a value between 0 74 | (identical) and 1 (maximally different) that correlates with subjective quality 75 | assessments. 76 | 77 | ## Performance 78 | 79 | For correlation performance with a few datasets see [CORRELATION.md](CORRELATION.md). 80 | 81 | The datasets can be acquired using the tools [coresvnet](go/bin/coresvnet), 82 | [perceptual_audio](go/bin/perceptual_audio), [sebass_db](go/bin/sebass_db), 83 | [odaq](go/bin/odaq), and [tcd_voip](go/bin/tcd_voip). 84 | 85 | Zimtohrli can compare ~70 seconds of audio per second on a single 2.5GHz core. 86 | 87 | ## Correlation Testing 88 | 89 | Zimtohrli includes a comprehensive correlation testing framework to validate how 90 | well audio quality metrics correlate with human perception. The system evaluates 91 | metrics against multiple listening test datasets containing either Mean Opinion 92 | Scores (MOS) or Just Noticeable Difference (JND) ratings. 93 | 94 | ### How Correlation Scoring Works 95 | 96 | The system uses two different evaluation methods depending on the dataset type: 97 | 98 | - **For MOS datasets**: Calculates Spearman rank correlation coefficient between 99 | predicted scores and human ratings. Higher correlation values (closer to 1.0) 100 | indicate better alignment with human perception. 
101 | - **For JND datasets**: Determines classification accuracy by finding an optimal 102 | threshold that maximizes correct predictions of whether differences are 103 | audible. The score represents the percentage of correct classifications. 104 | 105 | ### Running Correlation Tests 106 | 107 | 1. **Install external metrics** (optional): 108 | ```bash 109 | ./install_external_metrics.sh /path/to/destination 110 | ``` 111 | 112 | 2. **Acquire datasets** using the provided tools in `go/bin/` 113 | 114 | 3. **Calculate metrics**: 115 | ```bash 116 | go run go/bin/score/score.go -calculate "/path/to/datasets/*" -calculate_zimtohrli -calculate_visqol 117 | ``` 118 | 119 | 4. **Generate correlation report**: 120 | ```bash 121 | go run go/bin/score/score.go -report "/path/to/datasets/*" > correlation_report.md 122 | ``` 123 | 124 | The report includes correlation tables for each dataset and a global leaderboard 125 | showing mean squared error across all studies, where lower values indicate 126 | better overall performance. 127 | 128 | ## Compatibility 129 | 130 | Zimtohrli is a project under development, and is built and tested in a Debian-like 131 | environment. It's built to work with C++17. 132 | 133 | ## Minimal simple usage 134 | 135 | The very simplest way to use Zimtohrli is to just include the `zimtohrli.h` header. 136 | 137 | This allows you to analyze two signals and compute the distance between them: 138 | 139 | ``` 140 | #include "zimtohrli.h" 141 | 142 | const Zimtohrli z{}; 143 | const Spectrogram spec_a = z.Analyze(Span(samples_a, size_a)); 144 | Spectrogram spec_b = z.Analyze(Span(samples_b, size_b)); 145 | const float distance = z.Distance(spec_a, spec_b); 146 | ``` 147 | 148 | The samples have to be floats between -1 and 1 at 48kHz sample rate. 149 | 150 | ## Build 151 | 152 | Some dependencies for Zimtohrli are downloaded and managed by the build script, 153 | but others need to be installed before building.
154 | 155 | - cmake 156 | - ninja-build 157 | 158 | To build the compare tool, a few more dependencies are necessary: 159 | 160 | - libogg-dev 161 | - libvorbis-dev 162 | - libflac-dev 163 | - libopus-dev 164 | - libasound2-dev 165 | - libglfw3-dev 166 | - libsoxr-dev 167 | 168 | Finally, to build and test the Python and Go wrappers, the following dependencies 169 | are necessary: 170 | 171 | - golang-go 172 | - python3 173 | - xxd 174 | - zlib1g-dev 175 | - ffmpeg 176 | 177 | To install these in a Debian-like system: 178 | 179 | ``` 180 | sudo apt install -y cmake ninja-build clang clang-tidy libogg-dev libvorbis-dev libflac-dev libopus-dev libasound2-dev libglfw3-dev libsoxr-dev golang-go python3 xxd zlib1g-dev ffmpeg 181 | ``` 182 | 183 | Once they are installed, configure the project: 184 | 185 | ``` 186 | ./configure.sh 187 | ``` 188 | 189 | Build the project: 190 | ``` 191 | (cd build && ninja) 192 | ``` 193 | 194 | ### Address sanitizer build 195 | 196 | To build with address sanitizer, configure a new build directory with asan configured: 197 | 198 | 199 | ``` 200 | ./configure.sh asan 201 | ``` 202 | 203 | Build the project: 204 | ``` 205 | (cd asan_build && ninja) 206 | ``` 207 | 208 | ### Debug build 209 | 210 | To build with debug symbols, configure a new build directory with debugging configured: 211 | 212 | 213 | ``` 214 | ./configure.sh debug 215 | ``` 216 | 217 | Build the project: 218 | ``` 219 | (cd debug_build && ninja) 220 | ``` 221 | 222 | ### Testing 223 | 224 | ``` 225 | (cd build && ninja && ninja test) 226 | ``` 227 | -------------------------------------------------------------------------------- /cpp/zimt/pyohrli.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 
5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | #include 16 | #define PY_SSIZE_T_CLEAN 17 | #include 18 | 19 | #include 20 | 21 | #include "structmember.h" // NOLINT // For PyMemberDef 22 | #include "zimt/mos.h" 23 | #include "zimt/zimtohrli.h" 24 | 25 | namespace { 26 | 27 | struct SpectrogramObject { 28 | // clang-format off 29 | PyObject_HEAD 30 | void *spectrogram; 31 | // clang-format on 32 | }; 33 | 34 | void Spectrogram_dealloc(SpectrogramObject* self) { 35 | if (self) { 36 | if (self->spectrogram) { 37 | delete static_cast(self->spectrogram); 38 | self->spectrogram = nullptr; 39 | } 40 | Py_TYPE(self)->tp_free((PyObject*)self); 41 | } 42 | } 43 | 44 | PyTypeObject SpectrogramType = { 45 | // clang-format off 46 | .ob_base = PyVarObject_HEAD_INIT(nullptr, 0) 47 | .tp_name = "pyohrli.Spectrogram", 48 | // clang-format on 49 | .tp_basicsize = sizeof(SpectrogramObject), 50 | .tp_itemsize = 0, 51 | .tp_dealloc = (destructor)Spectrogram_dealloc, 52 | .tp_flags = Py_TPFLAGS_DEFAULT, 53 | .tp_doc = PyDoc_STR("Python wrapper around C++ zimtohrli::Spectrogram."), 54 | .tp_new = PyType_GenericNew, 55 | }; 56 | 57 | struct PyohrliObject { 58 | // clang-format off 59 | PyObject_HEAD 60 | void *zimtohrli; 61 | // clang-format on 62 | }; 63 | // Pyohrli.__init__: allocates the wrapped zimtohrli::Zimtohrli; raises MemoryError if allocation fails. 64 | int Pyohrli_init(PyohrliObject* self, PyObject* args, PyObject* kwds) { 65 | try { 66 | self->zimtohrli = new zimtohrli::Zimtohrli{}; 67 | } catch (const std::bad_alloc&) { 68 | PyErr_SetNone(PyExc_MemoryError); 69 | return -1; 70 | } 71 | return 0; 72 | } 73 | // Releases the wrapped zimtohrli::Zimtohrli (if any) and frees the Python object. 74 | void Pyohrli_dealloc(PyohrliObject* self) { 75 | if
(self) { 76 | if (self->zimtohrli) { 77 | delete static_cast(self->zimtohrli); 78 | self->zimtohrli = nullptr; 79 | } 80 | Py_TYPE(self)->tp_free((PyObject*)self); 81 | } 82 | } 83 | 84 | struct BufferDeleter { 85 | void operator()(Py_buffer* buffer) const { PyBuffer_Release(buffer); } 86 | }; 87 | 88 | // Plain C++ function to analyze a Python buffer object using Zimtohrli. 89 | // 90 | // Calls to Analyze never need to be cleaned up (with e.g. delete or DECREF) 91 | // afterwards. 92 | // 93 | // If the return value is std::nullopt that means a Python error is set and the 94 | // current operation should be terminated ASAP. 95 | std::optional Analyze( 96 | const zimtohrli::Zimtohrli& zimtohrli, PyObject* buffer_object) { 97 | Py_buffer buffer_view; 98 | if (PyObject_GetBuffer(buffer_object, &buffer_view, PyBUF_C_CONTIGUOUS)) { 99 | PyErr_SetString(PyExc_TypeError, "object is not buffer"); 100 | return std::nullopt; 101 | } 102 | std::unique_ptr buffer_view_deleter(&buffer_view); 103 | if (buffer_view.itemsize != sizeof(float)) { 104 | PyErr_SetString(PyExc_TypeError, "buffer does not contain floats"); 105 | return std::nullopt; 106 | } 107 | if (buffer_view.ndim != 1) { 108 | PyErr_SetString(PyExc_TypeError, "buffer has more than 1 axis"); 109 | return std::nullopt; 110 | } 111 | return std::optional(zimtohrli.Analyze( 112 | zimtohrli::Span(static_cast(buffer_view.buf), 113 | buffer_view.len / sizeof(float)))); 114 | } 115 | 116 | PyObject* BadArgument(const std::string& message) { 117 | PyErr_SetString(PyExc_TypeError, message.c_str()); 118 | return nullptr; 119 | } 120 | // Pyohrli.distance(a, b): analyzes both float buffers and returns their Zimtohrli distance. The wrapped instance is bound by const reference so it is not copied on every call. 121 | PyObject* Pyohrli_distance(PyohrliObject* self, PyObject* const* args, 122 | Py_ssize_t nargs) { 123 | if (nargs != 2) { 124 | return BadArgument("not exactly 2 arguments provided"); 125 | } 126 | const zimtohrli::Zimtohrli& zimtohrli = 127 | *static_cast(self->zimtohrli); 128 | std::optional spectrogram_a = 129 | Analyze(zimtohrli, args[0]); 130 | if (!spectrogram_a.has_value()) { 131 |
return nullptr; 132 | } 133 | std::optional spectrogram_b = 134 | Analyze(zimtohrli, args[1]); 135 | if (!spectrogram_b.has_value()) { 136 | return nullptr; 137 | } 138 | return PyFloat_FromDouble( 139 | zimtohrli.Distance(spectrogram_a.value(), spectrogram_b.value())); 140 | } 141 | // Pyohrli.analyze(signal): returns the spectrogram values of the float buffer as raw float bytes. The wrapped instance is bound by const reference to avoid a copy per call. 142 | PyObject* Pyohrli_analyze(PyohrliObject* self, PyObject* const* args, 143 | Py_ssize_t nargs) { 144 | if (nargs != 1) { 145 | return BadArgument("not exactly 1 argument provided"); 146 | } 147 | const zimtohrli::Zimtohrli& zimtohrli = 148 | *static_cast(self->zimtohrli); 149 | const std::optional spectrogram = 150 | Analyze(zimtohrli, args[0]); 151 | if (!spectrogram.has_value()) { 152 | return nullptr; 153 | } 154 | return PyBytes_FromStringAndSize( 155 | reinterpret_cast(spectrogram->values.get()), 156 | spectrogram->size() * sizeof(float)); 157 | } 158 | 159 | PyObject* Pyohrli_num_rotators(PyohrliObject* self, PyObject* const* args, 160 | Py_ssize_t nargs) { 161 | if (nargs != 0) { 162 | return BadArgument("not exactly 0 arguments provided"); 163 | } 164 | return PyLong_FromLong(zimtohrli::kNumRotators); 165 | } 166 | 167 | PyObject* Pyohrli_sample_rate(PyohrliObject* self, PyObject* const* args, 168 | Py_ssize_t nargs) { 169 | if (nargs != 0) { 170 | return BadArgument("not exactly 0 arguments provided"); 171 | } 172 | return PyLong_FromLong(zimtohrli::kSampleRate); 173 | } 174 | 175 | PyMethodDef Pyohrli_methods[] = { 176 | {"num_rotators", (PyCFunction)Pyohrli_num_rotators, METH_FASTCALL, 177 | "Returns the number of rotators, i.e.
the number of dimensions in a " 178 | "spectrogram."}, 179 | {"analyze", (PyCFunction)Pyohrli_analyze, METH_FASTCALL, 180 | "Returns a spectrogram of the provided signal."}, 181 | {"distance", (PyCFunction)Pyohrli_distance, METH_FASTCALL, 182 | "Returns the distance between the two provided signals."}, 183 | {"sample_rate", (PyCFunction)Pyohrli_sample_rate, METH_FASTCALL, 184 | "Returns the expected sample rate for analyzed audio."}, 185 | {nullptr} /* Sentinel */ 186 | }; 187 | 188 | PyTypeObject PyohrliType = { 189 | // clang-format off 190 | .ob_base = PyVarObject_HEAD_INIT(nullptr, 0) 191 | .tp_name = "pyohrli.Pyohrli", 192 | // clang-format on 193 | .tp_basicsize = sizeof(PyohrliObject), 194 | .tp_itemsize = 0, 195 | .tp_dealloc = (destructor)Pyohrli_dealloc, 196 | .tp_flags = Py_TPFLAGS_DEFAULT, 197 | .tp_doc = 198 | PyDoc_STR("Python wrapper around the C++ zimtohrli::Zimtohrli type."), 199 | .tp_methods = Pyohrli_methods, 200 | .tp_init = (initproc)Pyohrli_init, 201 | .tp_new = PyType_GenericNew, 202 | }; 203 | // Module-level MOSFromZimtohrli(distance): maps a Zimtohrli distance to an approximate MOS. PyFloat_AsDouble signals failure with -1.0 plus a pending exception, which must be propagated. 204 | PyObject* MOSFromZimtohrli(PyohrliObject* self, PyObject* const* args, 205 | Py_ssize_t nargs) { 206 | if (nargs != 1) { 207 | return BadArgument("not exactly 1 argument provided"); 208 | } 209 | const double distance = PyFloat_AsDouble(args[0]); 210 | if (distance == -1.0 && PyErr_Occurred()) return nullptr; 211 | return PyFloat_FromDouble(zimtohrli::MOSFromZimtohrli(distance)); 212 | } 213 | static PyMethodDef PyohrliModuleMethods[] = { 214 | {"MOSFromZimtohrli", (PyCFunction)MOSFromZimtohrli, METH_FASTCALL, 215 | "Returns an approximate mean opinion score based on the provided " 216 | "Zimtohrli distance."}, 217 | {NULL, NULL, 0, NULL}, 218 | }; 219 | 220 | PyModuleDef PyohrliModule = { 221 | .m_base = PyModuleDef_HEAD_INIT, 222 | .m_name = "pyohrli", 223 | .m_doc = "Python wrapper around the C++ zimtohrli library.", 224 | .m_size = -1, 225 | .m_methods = PyohrliModuleMethods, 226 | }; 227 | 228 | PyMODINIT_FUNC PyInit__pyohrli(void) { 229 | PyObject* m = PyModule_Create(&PyohrliModule); 230 | if (m == nullptr) return nullptr; 231 |
232 | if (PyType_Ready(&SpectrogramType) < 0) { 233 | Py_DECREF(m); 234 | return nullptr; 235 | } 236 | if (PyModule_AddObjectRef(m, "Spectrogram", (PyObject*)&SpectrogramType) < 237 | 0) { 238 | Py_DECREF(m); 239 | return nullptr; 240 | } 241 | 242 | if (PyType_Ready(&PyohrliType) < 0) { 243 | Py_DECREF(m); 244 | return nullptr; 245 | } 246 | if (PyModule_AddObjectRef(m, "Pyohrli", (PyObject*)&PyohrliType) < 0) { 247 | Py_DECREF(m); 248 | return nullptr; 249 | } 250 | 251 | return m; 252 | } 253 | 254 | } // namespace 255 | -------------------------------------------------------------------------------- /go/bin/odaq/odaq.go: -------------------------------------------------------------------------------- 1 | // Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // odaq downloads the listening test at https://zenodo.org/records/13377284.
16 | package main 17 | 18 | import ( 19 | "archive/zip" 20 | "encoding/xml" 21 | "flag" 22 | "fmt" 23 | "io" 24 | "log" 25 | "net/http" 26 | "os" 27 | "path/filepath" 28 | "runtime" 29 | "sync" 30 | 31 | "github.com/google/zimtohrli/go/aio" 32 | "github.com/google/zimtohrli/go/data" 33 | "github.com/google/zimtohrli/go/progress" 34 | "github.com/google/zimtohrli/go/worker" 35 | ) 36 | // withZipFiles downloads the zip archive at u into a temporary file and calls f with the number of archive entries and an iterator that extracts each non-directory entry into a temporary directory and passes (that directory, the extracted path) to yield; the HTTP body and all temporary files are released on return. 37 | func withZipFiles(u string, f func(count int, each func(yield func(root string, path string) error) error) error) error { 38 | res, err := http.Get(u) 39 | if err != nil { 40 | return fmt.Errorf("GETing %q: %v", u, err) 41 | } 42 | defer res.Body.Close() 43 | if res.StatusCode != 200 { 44 | return fmt.Errorf("GETing %q: %v", u, res.Status) 45 | } 46 | tmpFile, err := os.CreateTemp("", "odaq.*.zip") 47 | if err != nil { 48 | return fmt.Errorf("creating temp file: %v", err) 49 | } 50 | defer os.Remove(tmpFile.Name()) 51 | 52 | if err := func() error { 53 | defer tmpFile.Close() 54 | bar := progress.New(u) 55 | bar.Update(int(res.ContentLength), 0, 0) 56 | defer bar.Finish() 57 | buf := make([]byte, 1024*1024) 58 | read := 0 59 | sum := 0 60 | for read, err = res.Body.Read(buf); err == nil || err == io.EOF; read, err = res.Body.Read(buf) { 61 | sum += read 62 | bar.Update(int(res.ContentLength), sum, 0) 63 | if _, err := tmpFile.Write(buf[:read]); err != nil { 64 | return fmt.Errorf("writing to %q: %v", tmpFile.Name(), err) 65 | } 66 | if err == io.EOF { 67 | break 68 | } 69 | } 70 | if err != io.EOF { 71 | return fmt.Errorf("reading %q: %v", u, err) 72 | } 73 | return nil 74 | }(); err != nil { 75 | return err 76 | } 77 | zipReader, err := zip.OpenReader(tmpFile.Name()) 78 | if err != nil { 79 | return fmt.Errorf("reading %q: %v", tmpFile.Name(), err) 80 | } 81 | defer zipReader.Close() 82 | 83 | tmpDir, err := os.MkdirTemp("", "odaq.*") 84 | if err != nil { 85 | return fmt.Errorf("creating temp directory: %v", err) 86 | } 87 | defer os.RemoveAll(tmpDir) 88 | 89 | return f(len(zipReader.File), func(yield
func(root string, path string) error) error { 90 | for _, file := range zipReader.File { 91 | if file.FileInfo().IsDir() { 92 | continue 93 | } 94 | destPath := filepath.Join(tmpDir, file.Name) 95 | if err := os.MkdirAll(filepath.Dir(destPath), 0700); err != nil { 96 | return fmt.Errorf("creating directory %q: %v", filepath.Dir(destPath), err) 97 | } 98 | dest, err := os.Create(destPath) 99 | if err != nil { 100 | return fmt.Errorf("creating %q: %v", destPath, err) 101 | } 102 | if err := func() error { 103 | defer dest.Close() 104 | reader, err := file.Open() 105 | if err != nil { 106 | return fmt.Errorf("opening zip reader for %q: %v", file.Name, err) 107 | } 108 | defer reader.Close() 109 | if _, err := io.Copy(dest, reader); err != nil { 110 | return fmt.Errorf("copying zip reader for %q to %q: %v", file.Name, destPath, err) 111 | } 112 | return nil 113 | }(); err != nil { 114 | return fmt.Errorf("copying zip %q to %q: %v", file.Name, destPath, err) 115 | } 116 | if err := func() error { 117 | return yield(tmpDir, destPath) 118 | }(); err != nil { 119 | return err 120 | } 121 | } 122 | return nil 123 | }) 124 | } 125 | 126 | type result struct { 127 | FileName string `xml:"fileName,attr"` 128 | Score float64 `xml:"score,attr"` 129 | } 130 | 131 | type trial struct { 132 | Name string `xml:"trialName,attr"` 133 | Results []result `xml:"testFile"` 134 | } 135 | 136 | type mushra struct { 137 | Trials []trial `xml:"trials>trial"` 138 | } 139 | // readMushraXML decodes the MUSHRA results XML file at path. 140 | func readMushraXML(path string) (*mushra, error) { 141 | f, err := os.Open(path) 142 | if err != nil { 143 | return nil, err 144 | } 145 | defer f.Close() 146 | result := &mushra{} 147 | if err := xml.NewDecoder(f).Decode(result); err != nil { 148 | return nil, err 149 | } 150 | return result, nil 151 | } 152 | // populate downloads the ODAQ MUSHRA results and audio, recodes the audio into dest, and stores each distortion with its mean MUSHRA score (as data.MOS) in the study at dest. 153 | func populate(dest string, workers int) error { 154 | study, err := data.OpenStudy(dest) 155 | if err != nil { 156 | return err 157 | } 158 | defer study.Close() 159 | 160 | refToDistToScores :=
map[string]map[string][]float64{} 161 | 162 | appendTrials := func(xmlPath string) error { 163 | m, err := readMushraXML(xmlPath) 164 | if err != nil { 165 | return fmt.Errorf("reading MUSHRA XML from %q: %v", xmlPath, err) 166 | } 167 | for _, trial := range m.Trials { 168 | refPath := filepath.Join("ODAQ", "ODAQ_listening_test", trial.Name, "reference.wav") 169 | distToScores, found := refToDistToScores[refPath] 170 | if !found { 171 | distToScores = map[string][]float64{} 172 | refToDistToScores[refPath] = distToScores 173 | } 174 | for _, result := range trial.Results { 175 | if result.FileName != "reference.wav" { 176 | distPath := filepath.Join("ODAQ", "ODAQ_listening_test", trial.Name, result.FileName) 177 | distToScores[distPath] = append(distToScores[distPath], result.Score) 178 | } 179 | } 180 | } 181 | return nil 182 | } 183 | 184 | if err := withZipFiles("https://zenodo.org/records/13377284/files/ODAQ_v1_BSU.zip?download=1", func(count int, each func(yield func(root string, path string) error) error) error { 185 | return each(func(root string, zipPath string) error { 186 | if filepath.Ext(zipPath) == ".xml" { 187 | if err := appendTrials(zipPath); err != nil { 188 | return fmt.Errorf("trying to append trials from %q: %v", zipPath, err) 189 | } 190 | } 191 | return nil 192 | }) 193 | }); err != nil { 194 | return err 195 | } 196 | 197 | recodedPaths := map[string]string{} 198 | recodedPathLock := sync.Mutex{} 199 | if err := withZipFiles("https://zenodo.org/records/10405774/files/ODAQ.zip?download=1", func(count int, each func(yield func(root string, path string) error) error) error { 200 | bar := progress.New("Converting") 201 | pool := worker.Pool[any]{ 202 | Workers: workers, 203 | OnChange: func(submitted, completed, errors int) { 204 | bar.Update(count, completed, errors) 205 | }, 206 | } 207 | bar.Update(count, 0, 0) 208 | if err := each(func(root string, zipPath string) error { 209 | if ext := filepath.Ext(zipPath); ext == ".xml" { 210 | if err :=
appendTrials(zipPath); err != nil { 211 | return fmt.Errorf("trying to append trials from %q: %v", zipPath, err) 212 | } 213 | } else if ext == ".wav" { 214 | pool.Submit(func(func(any)) error { 215 | recodedPath, err := aio.Recode(zipPath, dest) 216 | if err != nil { 217 | return err 218 | } 219 | rel, err := filepath.Rel(root, zipPath) 220 | if err != nil { 221 | return err 222 | } 223 | recodedPathLock.Lock() 224 | defer recodedPathLock.Unlock() 225 | recodedPaths[rel] = recodedPath 226 | return nil 227 | }) 228 | } 229 | return nil 230 | }); err != nil { 231 | return err 232 | } 233 | if err := pool.Error(); err != nil { 234 | return err 235 | } 236 | bar.Finish() 237 | return nil 238 | }); err != nil { 239 | return err 240 | } 241 | 242 | references := []*data.Reference{} 243 | for refPath, distToScores := range refToDistToScores { 244 | recodePath, found := recodedPaths[refPath] 245 | if !found { 246 | return fmt.Errorf("recoded path for %q not found", refPath) 247 | } 248 | ref := &data.Reference{ 249 | Name: refPath, 250 | Path: recodePath, 251 | } 252 | for distPath, scores := range distToScores { 253 | recodePath, found = recodedPaths[distPath] 254 | if !found { 255 | return fmt.Errorf("recoded path for %q not found", distPath) 256 | } 257 | sum := 0.0 258 | for _, score := range scores { 259 | sum += score 260 | } 261 | mean := sum / float64(len(scores)) 262 | ref.Distortions = append(ref.Distortions, &data.Distortion{ 263 | Name: distPath, 264 | Scores: map[data.ScoreType]float64{ 265 | data.MOS: mean, 266 | }, 267 | Path: recodePath, 268 | }) 269 | } 270 | references = append(references, ref) 271 | } 272 | if err := study.Put(references); err != nil { 273 | return fmt.Errorf("storing references in %q: %v", dest, err) 274 | } 275 | return nil 276 | } 277 | 278 | func main() { 279 | destination := flag.String("dest", "", "Destination directory.") 280 | workers := flag.Int("workers", runtime.NumCPU(), "Number of sounds converted in parallel.") 281 | flag.Parse() 282 | if *destination == "" { 283 | flag.Usage() 284 | os.Exit(1) 285 | } 286 |
287 | if err := populate(*destination, *workers); err != nil { 288 | log.Fatal(err) 289 | } 290 | } 291 | -------------------------------------------------------------------------------- /tools/optimizer/simplex_fork.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of simplex search for an external process. 17 | 18 | The external process gets the input vector through environment variables. 19 | Input of vector as setenv("VAR%dimension", val) 20 | Getting the optimized function with regexp match from stdout 21 | of the forked process. 
22 | 23 | https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method 24 | 25 | start as ./simplex_fork.py binary dimensions amount 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | from six.moves import range 32 | import copy 33 | import os 34 | import random 35 | import re 36 | import subprocess 37 | import sys 38 | 39 | def Midpoint(simplex): 40 | """Nelder-Mead-like simplex midpoint calculation.""" 41 | simplex.sort() 42 | dim = len(simplex) - 1 43 | retval = [None] + [0.0] * dim 44 | for i in range(1, dim + 1): 45 | for k in range(dim): 46 | retval[i] += simplex[k][i] 47 | retval[i] /= dim 48 | return retval 49 | 50 | 51 | def Subtract(a, b): 52 | """Vector arithmetic, with [0] being ignored.""" 53 | return [None if k == 0 else a[k] - b[k] for k in range(len(a))] 54 | 55 | def Add(a, b): 56 | """Vector arithmetic, with [0] being ignored.""" 57 | return [None if k == 0 else a[k] + b[k] for k in range(len(a))] 58 | 59 | def Average(a, b): 60 | """Vector arithmetic, with [0] being ignored.""" 61 | return [None if k == 0 else 0.5 * (a[k] + b[k]) for k in range(len(a))] 62 | 63 | 64 | eval_hash = {} 65 | g_best_val = None 66 | g_sample = None 67 | g_perceptual_sample_rate = None 68 | 69 | def EvalCacheForget(): 70 | global eval_hash 71 | eval_hash = {} 72 | global g_sample 73 | g_sample = "zimtohrli_scores_sample" + str(random.randint(0, 10)) 74 | g_sample = "zimtohrli_scores2" 75 | global g_best_val 76 | g_best_val = None 77 | global g_perceptual_sample_rate 78 | g_perceptual_sample_rate = 97 + random.random() * 2.0 79 | 80 | 81 | def Eval(vec, binary_name, cached=True): 82 | """Evaluates the objective function by forking a process. 83 | 84 | Args: 85 | vec: [0] will be set to the objective function, [1:] will 86 | contain the vector position for the objective function. 87 | binary_name: the name of the binary that evaluates the value. 
88 | """ 89 | global eval_hash 90 | global g_best_val 91 | global g_sample 92 | global g_perceptual_sample_rate 93 | key = "" 94 | # os.environ["BUTTERAUGLI_OPTIMIZE"] = "1" 95 | for i in range(300): 96 | os.environ["VAR%d" % i] = "0" 97 | for i in range(len(vec) - 1): 98 | os.environ["VAR%d" % i] = str(vec[i + 1]) 99 | key += str(vec[i + 1]) + ":" 100 | if cached and (key in eval_hash): 101 | vec[0] = eval_hash[key] 102 | return 103 | 104 | #corpus = 'coresvnet' 105 | corpus = '*' 106 | print("popen") 107 | process = subprocess.Popen( 108 | ('/usr/lib/google-golang/bin/go', 'run', '../go/bin/score/score.go', '--force', 109 | '--calculate_zimtohrli', 110 | '--calculate', '/usr/local/google/home/jyrki/' + g_sample + '/' + corpus, 111 | '--leaderboard', '/usr/local/google/home/jyrki/' + g_sample + '/' + corpus, 112 | # '--zimtohrli_parameters', '{"PerceptualSampleRate":%.15f}' % g_perceptual_sample_rate 113 | ), 114 | stdout=subprocess.PIPE, 115 | stderr=subprocess.PIPE, 116 | env=dict(os.environ)) 117 | 118 | #print("wait") 119 | #process.wait() 120 | #print("wait complete") 121 | found_score = False 122 | vec[0] = 1.0 123 | dct2 = 0.0 124 | dct4 = 0.0 125 | dct16 = 0.0 126 | dct32 = 0.0 127 | n = 0 128 | print("communicate") 129 | for line in process.communicate(input=None, timeout=600)[0].splitlines(): 130 | print("BE", line) 131 | sys.stdout.flush() 132 | linesplit = line.split(b'|') 133 | if len(linesplit) >= 3 and linesplit[1][:5] == b'Zimto': 134 | mse = float(linesplit[2]) 135 | mean = 1.0 - float(linesplit[5]) 136 | minval = 1.0 - float(linesplit[3]) 137 | maxval = 1.0 - float(linesplit[4]) 138 | vec[0] = mse # * (maxval ** 0.02)# * (mean ** 0.5) # * (minval ** 0.3) * (maxval ** 0.1) 139 | # vec[0] = float(linesplit[2]) 140 | 141 | found_score = True 142 | print("end communicate", found_score) 143 | print("eval: ", vec) 144 | if (vec[0] <= 0.0): 145 | vec[0] = 1e30 146 | eval_hash[key] = vec[0] 147 | if found_score: 148 | if not g_best_val or vec[0] < 
g_best_val: 149 | g_best_val = vec[0] 150 | print("\nSaving best simplex\n") 151 | with open("best_simplex.txt", "w") as f: 152 | print(vec, file=f) 153 | print("wait really") 154 | process.wait() 155 | print("wait really done") 156 | return 157 | vec[0] = 1e31 158 | print("wait really [not score]") 159 | process.wait() 160 | print("wait really done [no score]") 161 | return 162 | # sys.exit("awful things happened") 163 | 164 | def Reflect(simplex, binary): 165 | """Main iteration step of Nelder-Mead optimization. Modifies `simplex`.""" 166 | simplex.sort() 167 | last = simplex[-1] 168 | mid = Midpoint(simplex) 169 | diff = Subtract(mid, last) 170 | mirrored = Add(mid, diff) 171 | Eval(mirrored, binary) 172 | if mirrored[0] > simplex[-2][0]: 173 | print("\nStill worst\n\n") 174 | # Still the worst, shrink towards the best. 175 | shrinking = Average(simplex[-1], simplex[0]) 176 | Eval(shrinking, binary) 177 | print("\nshrinking...\n\n") 178 | simplex[-1] = shrinking 179 | return 180 | if mirrored[0] < simplex[0][0]: 181 | # new best 182 | print("\nNew Best\n\n") 183 | even_further = Add(mirrored, diff) 184 | Eval(even_further, binary) 185 | if even_further[0] < mirrored[0]: 186 | print("\nEven Further\n\n") 187 | mirrored = even_further 188 | simplex[-1] = mirrored 189 | # try to extend 190 | return 191 | else: 192 | # not a best, not a worst point 193 | simplex[-1] = mirrored 194 | 195 | 196 | def InitialSimplex(vec, dim, amount): 197 | """Initialize the simplex at origin.""" 198 | EvalCacheForget() 199 | best = vec[:] 200 | Eval(best, g_binary) 201 | retval = [best] 202 | comp_order = list(range(1, dim + 1)) 203 | random.shuffle(comp_order) 204 | 205 | for i in range(dim): 206 | index = comp_order[i] 207 | best = retval[0][:] 208 | best_vals = [None, best[0], None] 209 | best[index] += amount 210 | Eval(best, g_binary) 211 | retval.append(best) 212 | best_vals[2] = best[0] 213 | if (retval[0][0] < retval[-1][0]): 214 | print("not best, let's negate this axis") 215 
| best = copy.copy(retval[0][:]) 216 | best[index] -= amount 217 | Eval(best, g_binary) 218 | best_vals[0] = best[0] 219 | if (best[0] < retval[-1][0]): 220 | print("found new displacement best by negating") 221 | retval[-1] = best 222 | # perhaps one more try with shrinking amount 223 | best = copy.copy(retval[0][:]) 224 | if (best_vals[1] < best_vals[0] and 225 | best_vals[1] < best_vals[2]): 226 | if (best_vals[0] < best_vals[2] and 227 | best_vals[2] - best_vals[0] > best_vals[0] - best_vals[1]): 228 | best[index] -= 0.1 * amount 229 | Eval(best, g_binary) 230 | if (best[0] < retval[-1][0]): 231 | print("found new best displacement by shrinking (neg)") 232 | retval[-1] = best 233 | if (best_vals[0] > best_vals[2] and 234 | best_vals[0] - best_vals[2] > best_vals[2] - best_vals[1]): 235 | best[index] += 0.1 * amount 236 | Eval(best, g_binary) 237 | if (best[0] < retval[-1][0]): 238 | print("found new best displacement by shrinking (neg)") 239 | retval[-1] = best 240 | 241 | 242 | retval.sort() 243 | return retval 244 | 245 | 246 | if len(sys.argv) != 4: 247 | print("usage: ", sys.argv[0], "binary-name number-of-dimensions simplex-size") 248 | exit(1) 249 | 250 | EvalCacheForget() 251 | 252 | g_dim = int(sys.argv[2]) 253 | g_amount = float(sys.argv[3]) 254 | g_binary = sys.argv[1] 255 | g_simplex = InitialSimplex([None] + [0.0] * g_dim, 256 | g_dim, 7.0 * g_amount) 257 | best = g_simplex[0][:] 258 | g_simplex = InitialSimplex(best, g_dim, g_amount * 2.47) 259 | best = g_simplex[0][:] 260 | g_simplex = InitialSimplex(best, g_dim, g_amount) 261 | best = g_simplex[0][:] 262 | g_simplex = InitialSimplex(best, g_dim, g_amount * 0.33) 263 | best = g_simplex[0][:] 264 | 265 | for restarts in range(99999): 266 | for ii in range(g_dim * 5): 267 | g_simplex.sort() 268 | print("reflect", ii, g_simplex[0]) 269 | Reflect(g_simplex, g_binary) 270 | 271 | mulli = 0.1 + 15 * random.random()**2.0 272 | print("\n\n\nRestart", restarts, "mulli", mulli) 273 | g_simplex.sort() 274 | 
best = g_simplex[0][:] 275 | g_simplex = InitialSimplex(best, g_dim, g_amount * mulli) 276 | -------------------------------------------------------------------------------- /tools/optimizer/random_fork.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2025 The Zimtohrli Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Implementation of simplex search for an external process. 17 | 18 | The external process gets the input vector through environment variables. 19 | Input of vector as setenv("VAR%dimension", val) 20 | Getting the optimized function with regexp match from stdout 21 | of the forked process. 
22 | 23 | https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method 24 | 25 | start as ./simplex_fork.py binary dimensions amount 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | from six.moves import range 32 | import copy 33 | import os 34 | import random 35 | import re 36 | import subprocess 37 | import sys 38 | 39 | global g_offset 40 | g_offset = int(sys.argv[4]) 41 | 42 | def Midpoint(simplex): 43 | """Nelder-Mead-like simplex midpoint calculation.""" 44 | simplex.sort() 45 | dim = len(simplex) - 1 46 | retval = [None] + [0.0] * dim 47 | for i in range(1, dim + 1): 48 | for k in range(dim): 49 | retval[i] += simplex[k][i] 50 | retval[i] /= dim 51 | return retval 52 | 53 | 54 | def Subtract(a, b): 55 | """Vector arithmetic, with [0] being ignored.""" 56 | return [None if k == 0 else a[k] - b[k] for k in range(len(a))] 57 | 58 | def Add(a, b): 59 | """Vector arithmetic, with [0] being ignored.""" 60 | return [None if k == 0 else a[k] + b[k] for k in range(len(a))] 61 | 62 | def Average(a, b): 63 | """Vector arithmetic, with [0] being ignored.""" 64 | return [None if k == 0 else 0.5 * (a[k] + b[k]) for k in range(len(a))] 65 | 66 | 67 | eval_hash = {} 68 | g_best_val = None 69 | g_sample = None 70 | g_perceptual_sample_rate = None 71 | 72 | def EvalCacheForget(): 73 | global eval_hash 74 | eval_hash = {} 75 | global g_sample 76 | # g_sample = "zimtohrli_scores_sample" + str(random.randint(0, 10)) 77 | # g_sample = "odaqi" 78 | g_sample = "zimtohrli_scores2" 79 | global g_best_val 80 | g_best_val = None 81 | global g_perceptual_sample_rate 82 | g_perceptual_sample_rate = 97 + random.random() * 2.0 83 | 84 | 85 | def Eval(vec, binary_name, cached=True): 86 | """Evaluates the objective function by forking a process. 87 | 88 | Args: 89 | vec: [0] will be set to the objective function, [1:] will 90 | contain the vector position for the objective function. 
91 | binary_name: the name of the binary that evaluates the value. 92 | """ 93 | global eval_hash 94 | global g_best_val 95 | global g_sample 96 | global g_perceptual_sample_rate 97 | key = "" 98 | # os.environ["BUTTERAUGLI_OPTIMIZE"] = "1" 99 | for i in range(999): 100 | os.environ["VAR%d" % i] = "0" 101 | for i in range(len(vec) - 1): 102 | os.environ["VAR%d" % i] = str(vec[i + 1]) 103 | key += str(vec[i + 1]) + ":" 104 | if cached and (key in eval_hash): 105 | vec[0] = eval_hash[key] 106 | return 107 | 108 | corpus = '*' 109 | #corpus = 'coresvnet' 110 | #corpus = 'odaq' 111 | print("popen") 112 | process = subprocess.Popen( 113 | ('/usr/bin/go', 'run', '../go/bin/score/score.go', '--force', 114 | '--calculate_zimtohrli', 115 | '--calculate', '/usr/local/google/home/jyrki/' + g_sample + '/' + corpus, 116 | '--leaderboard', '/usr/local/google/home/jyrki/' + g_sample + '/' + corpus, 117 | # '--zimtohrli_parameters', '{"PerceptualSampleRate":%.15f}' % g_perceptual_sample_rate 118 | ), 119 | stdout=subprocess.PIPE, 120 | stderr=subprocess.PIPE, 121 | env=dict(os.environ)) 122 | 123 | #print("wait") 124 | #process.wait() 125 | #print("wait complete") 126 | found_score = False 127 | vec[0] = 1.0 128 | dct2 = 0.0 129 | dct4 = 0.0 130 | dct16 = 0.0 131 | dct32 = 0.0 132 | n = 0 133 | print("communicate") 134 | for line in process.communicate(input=None, timeout=1000)[0].splitlines(): 135 | print("BE", line) 136 | sys.stdout.flush() 137 | linesplit = line.split(b'|') 138 | if len(linesplit) >= 3 and linesplit[1][:5] == b'Zimto': 139 | mse = float(linesplit[2]) 140 | mean = 1.0 - float(linesplit[5]) 141 | minval = 1.0 - float(linesplit[3]) 142 | maxval = 1.0 - float(linesplit[4]) 143 | vec[0] = mse # * (maxval ** 0.02)# * (mean ** 0.5) # * (minval ** 0.3) * (maxval ** 0.1) 144 | # vec[0] = float(linesplit[2]) 145 | 146 | found_score = True 147 | print("end communicate", found_score) 148 | print("eval: ", vec) 149 | if (vec[0] <= 0.0): 150 | vec[0] = 1e30 151 | 
eval_hash[key] = vec[0] 152 | if found_score: 153 | if not g_best_val or vec[0] < g_best_val: 154 | g_best_val = vec[0] 155 | print("\nSaving best simplex\n") 156 | with open("best_simplex.txt", "w") as f: 157 | print(vec, file=f) 158 | print("wait really") 159 | process.wait() 160 | print("wait really done") 161 | return 162 | vec[0] = 1e31 163 | print("wait really [not score]") 164 | process.wait() 165 | print("wait really done [no score]") 166 | return 167 | # sys.exit("awful things happened") 168 | 169 | def Reflect(simplex, binary): 170 | """Main iteration step of Nelder-Mead optimization. Modifies `simplex`.""" 171 | simplex.sort() 172 | last = simplex[-1] 173 | mid = Midpoint(simplex) 174 | diff = Subtract(mid, last) 175 | mirrored = Add(mid, diff) 176 | Eval(mirrored, binary) 177 | if mirrored[0] > simplex[-2][0]: 178 | print("\nStill worst\n\n") 179 | # Still the worst, shrink towards the best. 180 | shrinking = Average(simplex[-1], simplex[0]) 181 | Eval(shrinking, binary) 182 | print("\nshrinking...\n\n") 183 | simplex[-1] = shrinking 184 | return 185 | if mirrored[0] < simplex[0][0]: 186 | # new best 187 | print("\nNew Best\n\n") 188 | even_further = Add(mirrored, diff) 189 | Eval(even_further, binary) 190 | if even_further[0] < mirrored[0]: 191 | print("\nEven Further\n\n") 192 | mirrored = even_further 193 | simplex[-1] = mirrored 194 | # try to extend 195 | return 196 | else: 197 | # not a best, not a worst point 198 | simplex[-1] = mirrored 199 | 200 | import copy 201 | 202 | def InitialSimplex(vec, dim, amount, biases): 203 | """Initialize the simplex at origin.""" 204 | EvalCacheForget() 205 | best = vec[:] 206 | Eval(best, g_binary) 207 | retval = [best] 208 | 209 | for i in range(dim): 210 | best = retval[0][:] 211 | bestcopy = copy.copy(best) 212 | rangelimit = random.random() < 0.4 213 | for k in range(dim): 214 | r = ((random.random() - 0.5) * 2) ** 5 215 | best[k + 1] += amount * r 216 | ave = 0 217 | if (k != 0 and k != dim - 1 and biases[k 
+ 1] != 0): 218 | ave = (best[k] + best[k + 2]) * 0.5 219 | best[k + 1] += 0.001 * amount * (biases[k] + ave - best[k + 1]) * random.random() 220 | #if k < 90 or (k >= 130 and k < g_offset): 221 | if k < g_offset: 222 | best[k + 1] = 0.0 223 | if rangelimit: 224 | rangevals = max(0, dim - g_offset - 5) 225 | temp_offset = random.randint(g_offset, g_offset + rangevals) 226 | if k < temp_offset or k >= temp_offset + 5: 227 | best[k + 1] = bestcopy[k + 1] 228 | Eval(best, g_binary) 229 | retval.append(best) 230 | if (best[0] < retval[0][0]): 231 | amount *= 1.25 232 | print("scaling up amount", amount) 233 | else: 234 | amount *= 0.98 235 | print("scaling down amount", amount) 236 | retval.sort() 237 | return retval 238 | 239 | 240 | if len(sys.argv) != 5: 241 | print("usage: ", sys.argv[0], "binary-name number-of-dimensions simplex-size first-non-zero-dim") 242 | exit(1) 243 | 244 | EvalCacheForget() 245 | 246 | 247 | def FileLine(i): 248 | path = "/usr/local/google/home/jyrki/github/zimtohrli/cpp/zimt/fourier_bank.cc" 249 | 250 | linenoset = [] 251 | valset = [] 252 | print("trying", i) 253 | for lineno, line in enumerate(open(path).readlines()): 254 | if 'atof(getenv("VAR' in line: 255 | ixu = int(line.split("VAR")[1].split('"')[0]) 256 | if not (ixu == i - 1 or ixu == i or ixu == i + 1): 257 | continue 258 | linenoset.append(lineno) 259 | if not '+ atof(' in line: 260 | return 0.0 261 | numstr = line[:line.index('+ atof(')].strip().split(' ')[-1] 262 | numstr = numstr.split("(")[-1] 263 | if numstr[-1] == 'f': 264 | numstr = numstr[0:-1] 265 | num = float(numstr) 266 | valset.append(num) 267 | if len(linenoset) == 3 and abs(linenoset[2] - linenoset[0]) == 2 and linenoset[0] + linenoset[2] == 2 * linenoset[1]: 268 | # good 269 | bias = 0.5 * (valset[0] + valset[2]) - valset[1] 270 | print("found bias for", i, bias) 271 | return bias 272 | 273 | return 0.0 274 | 275 | g_dim = int(sys.argv[2]) 276 | g_biases = [FileLine(i) for i in range(g_dim)] 277 | 278 | 279 | 
g_amount = float(sys.argv[3]) 280 | g_binary = sys.argv[1] 281 | g_simplex = InitialSimplex([None] + [0.0] * g_dim, 282 | g_dim, 7.0 * g_amount, g_biases) 283 | best = g_simplex[0][:] 284 | g_simplex = InitialSimplex(best, g_dim, g_amount * 2.47, g_biases) 285 | best = g_simplex[0][:] 286 | g_simplex = InitialSimplex(best, g_dim, g_amount, g_biases) 287 | best = g_simplex[0][:] 288 | g_simplex = InitialSimplex(best, g_dim, g_amount * 0.33, g_biases) 289 | best = g_simplex[0][:] 290 | 291 | for restarts in range(99999): 292 | for ii in range(g_dim * 5): 293 | g_simplex.sort() 294 | print("reflect", ii, g_simplex[0]) 295 | Reflect(g_simplex, g_binary) 296 | 297 | mulli = 0.1 + 15 * random.random()**2.0 298 | print("\n\n\nRestart", restarts, "mulli", mulli) 299 | g_simplex.sort() 300 | best = g_simplex[0][:] 301 | g_simplex = InitialSimplex(best, g_dim, g_amount * mulli, g_biases) 302 | -------------------------------------------------------------------------------- /cmake/visqol_manager.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Google LLC, Andrew Hines 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // This file replaces visqol_manager.cc in the normal visqol distribution, 16 | // since it lets us build ViSQOL without redundant TensorFlow dependencies. 
17 | 18 | #include "visqol_manager.h" 19 | 20 | #include 21 | #include 22 | #include 23 | 24 | #include "absl/base/internal/raw_logging.h" 25 | #include "absl/memory/memory.h" 26 | #include "absl/status/status.h" 27 | #include "absl/status/statusor.h" 28 | #include "absl/strings/str_cat.h" 29 | #include "alignment.h" 30 | #include "analysis_window.h" 31 | #include "audio_signal.h" 32 | #include "gammatone_filterbank.h" 33 | #include "misc_audio.h" 34 | #include "neurogram_similiarity_index_measure.h" 35 | #include "similarity_result.h" 36 | #include "speech_similarity_to_quality_mapper.h" 37 | #include "src/proto/similarity_result.pb.h" // Generated by cc_proto_library rule 38 | #include "status_macros.h" 39 | #include "vad_patch_creator.h" 40 | #include "visqol.h" 41 | 42 | namespace Visqol { 43 | 44 | const size_t k16kSampleRate = 16000; 45 | const size_t k48kSampleRate = 48000; 46 | const size_t VisqolManager::kPatchSize = 30; 47 | const size_t VisqolManager::kPatchSizeSpeech = 20; 48 | const size_t VisqolManager::kNumBandsAudio = 32; 49 | const size_t VisqolManager::kNumBandsSpeech = 21; 50 | const double VisqolManager::kMinimumFreq = 50; // wideband 51 | const double VisqolManager::kOverlap = 0.25; // 25% overlap 52 | const double VisqolManager::kDurationMismatchTolerance = 1.0; 53 | 54 | absl::Status VisqolManager::Init( 55 | const FilePath& similarity_to_quality_mapper_model, bool use_speech_mode, 56 | bool use_unscaled_speech, int search_window, bool use_lattice_model) { 57 | use_speech_mode_ = use_speech_mode; 58 | assert(use_speech_mode == false); 59 | use_unscaled_speech_mos_mapping_ = use_unscaled_speech; 60 | search_window_ = search_window; 61 | use_lattice_model_ = use_lattice_model; 62 | 63 | InitPatchCreator(); 64 | InitPatchSelector(); 65 | InitSpectrogramBuilder(); 66 | auto status = 67 | InitSimilarityToQualityMapper(similarity_to_quality_mapper_model); 68 | 69 | if (status.ok()) { 70 | is_initialized_ = true; 71 | } else { 72 | 
ABSL_RAW_LOG(ERROR, "%s", status.ToString().c_str()); 73 | } 74 | 75 | return status; 76 | } 77 | 78 | absl::Status VisqolManager::Init( 79 | absl::string_view similarity_to_quality_mapper_model_string, 80 | bool use_speech_mode, bool use_unscaled_speech, int search_window, 81 | bool use_lattice_model) { 82 | return Init(FilePath(similarity_to_quality_mapper_model_string), 83 | use_speech_mode, use_unscaled_speech, search_window, 84 | use_lattice_model); 85 | } 86 | 87 | void VisqolManager::InitPatchCreator() { 88 | if (use_speech_mode_) { 89 | patch_creator_ = absl::make_unique(kPatchSizeSpeech); 90 | } else { 91 | patch_creator_ = absl::make_unique(kPatchSize); 92 | } 93 | } 94 | 95 | void VisqolManager::InitPatchSelector() { 96 | // Setup the patch similarity comparator to use the Neurogram. 97 | patch_selector_ = absl::make_unique( 98 | absl::make_unique()); 99 | } 100 | 101 | void VisqolManager::InitSpectrogramBuilder() { 102 | if (use_speech_mode_) { 103 | spectrogram_builder_ = absl::make_unique( 104 | GammatoneFilterBank{kNumBandsSpeech, kMinimumFreq}, true); 105 | } else { 106 | spectrogram_builder_ = absl::make_unique( 107 | GammatoneFilterBank{kNumBandsAudio, kMinimumFreq}, false); 108 | } 109 | } 110 | 111 | absl::Status VisqolManager::InitSimilarityToQualityMapper( 112 | FilePath sim_to_quality_mapper_model) { 113 | if (use_lattice_model_) { 114 | ABSL_RAW_LOG( 115 | WARNING, 116 | "Lattice models are not yet supported for audio mode, falling back " 117 | "to SVR model."); 118 | } 119 | sim_to_qual_ = absl::make_unique( 120 | sim_to_quality_mapper_model); 121 | return sim_to_qual_->Init(); 122 | } 123 | 124 | absl::StatusOr VisqolManager::Run( 125 | const FilePath& ref_signal_path, const FilePath& deg_signal_path) { 126 | // Ensure the initialization succeeded. 127 | VISQOL_RETURN_IF_ERROR(ErrorIfNotInitialized()); 128 | 129 | // Load the wav audio files as mono. 
130 | const AudioSignal ref_signal = MiscAudio::LoadAsMono(ref_signal_path); 131 | AudioSignal deg_signal = MiscAudio::LoadAsMono(deg_signal_path); 132 | 133 | // If the sim result was successfully calculated, set the signal file paths. 134 | // Else, return the StatusOr failure. 135 | SimilarityResultMsg sim_result_msg; 136 | VISQOL_ASSIGN_OR_RETURN(sim_result_msg, Run(ref_signal, deg_signal)); 137 | sim_result_msg.set_reference_filepath(ref_signal_path.Path()); 138 | sim_result_msg.set_degraded_filepath(deg_signal_path.Path()); 139 | return sim_result_msg; 140 | } 141 | 142 | absl::StatusOr VisqolManager::Run( 143 | const AudioSignal& ref_signal, AudioSignal& deg_signal) { 144 | // Ensure the initialization succeeded. 145 | VISQOL_RETURN_IF_ERROR(ErrorIfNotInitialized()); 146 | 147 | VISQOL_RETURN_IF_ERROR(ValidateInputAudio(ref_signal, deg_signal)); 148 | 149 | // Adjust for codec initial padding. 150 | auto alignment_result = Alignment::GloballyAlign(ref_signal, deg_signal); 151 | deg_signal = std::get<0>(alignment_result); 152 | 153 | const AnalysisWindow window{ref_signal.sample_rate, kOverlap}; 154 | 155 | // If the sim result is successfully calculated, populate the protobuf msg. 156 | // Else, return the StatusOr failure. 
157 | const Visqol visqol; 158 | SimilarityResult sim_result; 159 | VISQOL_ASSIGN_OR_RETURN( 160 | sim_result, visqol.CalculateSimilarity( 161 | ref_signal, deg_signal, spectrogram_builder_.get(), 162 | window, patch_creator_.get(), patch_selector_.get(), 163 | sim_to_qual_.get(), search_window_)); 164 | SimilarityResultMsg sim_result_msg = PopulateSimResultMsg(sim_result); 165 | sim_result_msg.set_alignment_lag_s(std::get<1>(alignment_result)); 166 | return sim_result_msg; 167 | } 168 | 169 | SimilarityResultMsg VisqolManager::PopulateSimResultMsg( 170 | const SimilarityResult& sim_result) { 171 | SimilarityResultMsg sim_result_msg; 172 | sim_result_msg.set_moslqo(sim_result.moslqo); 173 | sim_result_msg.set_vnsim(sim_result.vnsim); 174 | 175 | for (double val : sim_result.fvnsim) { 176 | sim_result_msg.add_fvnsim(val); 177 | } 178 | 179 | for (double val : sim_result.fvnsim10) { 180 | sim_result_msg.add_fvnsim10(val); 181 | } 182 | 183 | for (double val : sim_result.fstdnsim) { 184 | sim_result_msg.add_fstdnsim(val); 185 | } 186 | 187 | for (double val : sim_result.center_freq_bands) { 188 | sim_result_msg.add_center_freq_bands(val); 189 | } 190 | 191 | for (double val : sim_result.fvdegenergy) { 192 | sim_result_msg.add_fvdegenergy(val); 193 | } 194 | 195 | for (const PatchSimilarityResult& patch : sim_result.debug_info.patch_sims) { 196 | SimilarityResultMsg_PatchSimilarityMsg* patch_msg = 197 | sim_result_msg.add_patch_sims(); 198 | patch_msg->set_similarity(patch.similarity); 199 | patch_msg->set_ref_patch_start_time(patch.ref_patch_start_time); 200 | patch_msg->set_ref_patch_end_time(patch.ref_patch_end_time); 201 | patch_msg->set_deg_patch_start_time(patch.deg_patch_start_time); 202 | patch_msg->set_deg_patch_end_time(patch.deg_patch_end_time); 203 | for (double each_fbm : patch.freq_band_means.ToVector()) { 204 | patch_msg->add_freq_band_means(each_fbm); 205 | } 206 | } 207 | 208 | return sim_result_msg; 209 | } 210 | 211 | absl::Status 
VisqolManager::ErrorIfNotInitialized() { 212 | if (is_initialized_ == false) { 213 | return absl::Status(absl::StatusCode::kAborted, 214 | "VisqolManager must be initialized before use."); 215 | } else { 216 | return absl::Status(); 217 | } 218 | } 219 | 220 | absl::Status VisqolManager::ValidateInputAudio(const AudioSignal& ref_signal, 221 | const AudioSignal& deg_signal) { 222 | // Warn if there is an excessive difference in durations. 223 | double ref_duration = ref_signal.GetDuration(); 224 | double deg_duration = deg_signal.GetDuration(); 225 | if (std::abs(ref_duration - deg_duration) > kDurationMismatchTolerance) { 226 | ABSL_RAW_LOG(WARNING, 227 | "Mismatch in duration between reference and " 228 | "degraded signal. Reference is %.2f seconds. Degraded is " 229 | "%.2f seconds.", 230 | ref_duration, deg_duration); 231 | } 232 | 233 | // Error if the signals have different sample rates. 234 | if (ref_signal.sample_rate != deg_signal.sample_rate) { 235 | return absl::InvalidArgumentError(absl::StrCat( 236 | "Input audio signals have different sample rates! Reference audio " 237 | "sample rate: ", 238 | ref_signal.sample_rate, 239 | ". Degraded audio sample rate: ", deg_signal.sample_rate)); 240 | } 241 | 242 | if (use_speech_mode_) { 243 | // Warn if input sample rate is > 16khz. 244 | if (ref_signal.sample_rate > k16kSampleRate) { 245 | ABSL_RAW_LOG(WARNING, 246 | "Input audio sample rate is above 16kHz, which" 247 | " may have undesired effects for speech mode. Consider" 248 | " resampling to 16kHz."); 249 | } 250 | } else { 251 | // Warn if the signals' sample rate is not 48k for full audio mode. 252 | if (ref_signal.sample_rate != k48kSampleRate) { 253 | ABSL_RAW_LOG(WARNING, 254 | "Input audio does not have the expected sample" 255 | " rate of 48kHz! 
This may negatively effect the prediction" 256 | " of the MOS-LQO score."); 257 | } 258 | } 259 | 260 | return absl::Status(); 261 | } 262 | } // namespace Visqol 263 | --------------------------------------------------------------------------------