├── third_party ├── BUILD ├── eigen │ ├── BUILD │ ├── eigen.bzl │ └── eigen.BUILD ├── glog │ ├── BUILD │ └── glog.bzl ├── audionamix │ ├── BUILD │ ├── wave.BUILD │ └── audionamix.bzl ├── ffmpeg │ ├── BUILD │ ├── ffmpeg.bzl │ └── ffmpeg.BUILD ├── googletest │ ├── BUILD │ └── googletest.bzl ├── models │ ├── BUILD │ ├── stems.BUILD │ ├── models.BUILD │ └── models.bzl ├── nlohmann │ ├── BUILD │ ├── nlohmann.BUILD │ └── nlohmann.bzl ├── tensorflow │ ├── BUILD │ └── tensorflow.bzl ├── audio_example │ ├── BUILD │ └── audio_example.bzl ├── tensorflowlite │ ├── BUILD │ └── tensorflowlite.bzl ├── zlib │ ├── BUILD │ ├── zlib.bzl │ └── zlib.BUILD └── dependencies.bzl ├── .bazelrc ├── spleeter ├── logging │ ├── logging.h │ ├── test │ │ ├── BUILD │ │ └── logging_tests.cpp │ └── BUILD ├── argument_parser │ ├── test │ │ ├── BUILD │ │ └── argument_parser_tests.cpp │ ├── BUILD │ ├── i_argument_parser.h │ ├── argument_parser.h │ ├── cli_options.h │ └── argument_parser.cpp ├── audio │ ├── test │ │ ├── BUILD │ │ └── audio_adapter_tests.cpp │ ├── BUILD │ ├── i_audio_adapter.h │ ├── audionamix_audio_adapter.h │ ├── audionamix_audio_adapter.cpp │ ├── ffmpeg_audio_adapter.h │ └── ffmpeg_audio_adapter.cpp ├── datatypes │ ├── BUILD │ ├── audio_properties.h │ ├── waveform.h │ └── inference_engine.h ├── test │ ├── spleeter_tests.cpp │ ├── BUILD │ └── separator_tests.cpp ├── inference_engine │ ├── null_inference_engine.cpp │ ├── i_inference_engine.h │ ├── null_inference_engine.h │ ├── BUILD │ ├── inference_engine_strategy.h │ ├── inference_engine_strategy.cpp │ ├── tf_inference_engine.h │ ├── tflite_inference_engine.h │ ├── tf_inference_engine.cpp │ ├── test │ │ └── inference_engine_tests.cpp │ └── tflite_inference_engine.cpp ├── BUILD ├── spleeter.h ├── spleeter.cpp ├── i_separator.h ├── separator.cpp └── separator.h ├── WORKSPACE ├── application ├── BUILD └── main.cpp ├── .gitattributes ├── tools ├── generate_compile_commands.sh └── tflite_converter │ ├── tflite_converter.py │ ├── README.md │ 
└── export_model.py ├── .gitignore ├── Dockerfile ├── LICENSE ├── .clang-format ├── BUILD ├── .gitlab-ci.yml ├── .clang-tidy ├── README.md └── scripts └── unet.py /third_party/BUILD: -------------------------------------------------------------------------------- 1 | # DO NOT REMOVE THIS FILE 2 | -------------------------------------------------------------------------------- /third_party/eigen/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/glog/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/audionamix/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/ffmpeg/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/googletest/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/models/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/nlohmann/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/tensorflow/BUILD: 
-------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/audio_example/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/tensorflowlite/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) 2 | -------------------------------------------------------------------------------- /third_party/zlib/BUILD: -------------------------------------------------------------------------------- 1 | licenses(["notice"]) # BSD/MIT-like license (for zlib) 2 | -------------------------------------------------------------------------------- /.bazelrc: -------------------------------------------------------------------------------- 1 | build --color=yes 2 | build --test_env="GTEST_COLOR=TRUE" --test_output=errors 3 | build --cxxopt="-std=c++14" --cxxopt="-Wall" 4 | -------------------------------------------------------------------------------- /third_party/ffmpeg/ffmpeg.bzl: -------------------------------------------------------------------------------- 1 | def ffmpeg(): 2 | if "ffmpeg" not in native.existing_rules(): 3 | native.new_local_repository( 4 | name = "ffmpeg", 5 | build_file = "//third_party/ffmpeg:ffmpeg.BUILD", 6 | path = "/usr/", 7 | ) 8 | -------------------------------------------------------------------------------- /third_party/models/stems.BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | filegroup( 4 | name = "saved_model", 5 | srcs = glob(["saved_model/**/*"]), 6 | ) 7 | 8 | filegroup( 9 | name = "tflite", 10 | srcs = glob(["*.tflite"]), 11 | ) 12 | 
-------------------------------------------------------------------------------- /spleeter/logging/logging.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020-2021. MIT License. 4 | /// 5 | #ifndef SPLEETER_COMMON_LOGGING_H 6 | #define SPLEETER_COMMON_LOGGING_H 7 | 8 | #define GOOGLE_STRIP_LOG (WARNING) 9 | #include <glog/logging.h> 10 | 11 | #endif /// SPLEETER_COMMON_LOGGING_H 12 | -------------------------------------------------------------------------------- /third_party/nlohmann/nlohmann.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "json", 7 | hdrs = ["single_include/nlohmann/json.hpp"], 8 | copts = ["-std=c++14"], 9 | includes = ["single_include"], 10 | ) 11 | -------------------------------------------------------------------------------- /spleeter/logging/test/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_test( 6 | name = "unit_test", 7 | srcs = glob(["*.cpp"]), 8 | linkstatic = True, 9 | deps = [ 10 | "//spleeter/logging", 11 | "@googletest//:gtest_main", 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /spleeter/argument_parser/test/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_test( 6 | name = "unit_test", 7 | srcs = glob(["*.cpp"]), 8 | linkstatic = True, 9 | deps = [ 10 | "//spleeter/argument_parser", 11 | "//spleeter/logging", 12 | "@googletest//:gtest_main", 13 | ], 14 | ) 15 | 
-------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- 1 | workspace(name = "spleeter") 2 | 3 | load("@spleeter//third_party:dependencies.bzl", "spleeter_dependencies") 4 | 5 | spleeter_dependencies() 6 | 7 | load("@tensorflowlite//third_party:dependencies.bzl", "tensorflowlite_dependencies") 8 | 9 | tensorflowlite_dependencies() 10 | 11 | load("@tensorflow//third_party:dependencies.bzl", "tensorflow_dependencies") 12 | 13 | tensorflow_dependencies() 14 | -------------------------------------------------------------------------------- /application/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_binary") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_binary( 6 | name = "spleeter", 7 | srcs = glob(["*.cpp"]), 8 | copts = [ 9 | "-std=c++14", 10 | "-Wall", 11 | "-Werror", 12 | ], 13 | data = ["@audio_example//file"], 14 | deps = [ 15 | "//spleeter", 16 | ], 17 | ) 18 | -------------------------------------------------------------------------------- /spleeter/audio/test/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_test( 6 | name = "unit_test", 7 | srcs = glob(["*.cpp"]), 8 | data = [ 9 | "@audio_example//file", 10 | ], 11 | linkstatic = True, 12 | deps = [ 13 | "//spleeter/audio", 14 | "@googletest//:gtest_main", 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /third_party/eigen/eigen.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") 2 | 3 | def eigen(): 4 | if "eigen" not in native.existing_rules(): 5 | new_git_repository( 6 | 
name = "eigen", 7 | build_file = "//third_party/eigen:eigen.BUILD", 8 | remote = "https://gitlab.com/libeigen/eigen.git", 9 | tag = "3.3.7", 10 | ) 11 | -------------------------------------------------------------------------------- /third_party/eigen/eigen.BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | cc_library( 4 | name = "eigen", 5 | hdrs = glob( 6 | ["Eigen/**"], 7 | exclude = [ 8 | "Eigen/src/Core/arch/AVX/PacketMathGoogleTest.cc", 9 | ], 10 | ), 11 | copts = [ 12 | "-std=c++14", 13 | ], 14 | includes = ["."], 15 | visibility = ["//visibility:public"], 16 | ) 17 | -------------------------------------------------------------------------------- /third_party/googletest/googletest.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def googletest(): 4 | if "googletest" not in native.existing_rules(): 5 | http_archive( 6 | name = "googletest", 7 | url = "https://github.com/google/googletest/archive/release-1.10.0.zip", 8 | sha256 = "94c634d499558a76fa649edb13721dce6e98fb1e7018dfaeba3cd7a083945e91", 9 | strip_prefix = "googletest-release-1.10.0", 10 | ) 11 | -------------------------------------------------------------------------------- /third_party/nlohmann/nlohmann.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def nlohmann(): 4 | if "nlohmann" not in native.existing_rules(): 5 | http_archive( 6 | name = "nlohmann", 7 | build_file = "//third_party/nlohmann:nlohmann.BUILD", 8 | url = "https://github.com/nlohmann/json/releases/download/v3.7.3/include.zip", 9 | sha256 = "87b5884741427220d3a33df1363ae0e8b898099fbc59f1c451113f6732891014", 10 | ) 11 | 
-------------------------------------------------------------------------------- /spleeter/datatypes/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 3 | 4 | package(default_visibility = ["//visibility:public"]) 5 | 6 | cc_library( 7 | name = "datatypes", 8 | srcs = glob(["*.cpp"]), 9 | hdrs = glob(["*.h"]), 10 | linkstatic = True, 11 | ) 12 | 13 | pkg_tar( 14 | name = "includes", 15 | srcs = glob(["*.h"]), 16 | mode = "0644", 17 | package_dir = "spleeter/datatypes", 18 | tags = ["manual"], 19 | ) 20 | -------------------------------------------------------------------------------- /spleeter/logging/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 3 | 4 | package(default_visibility = ["//visibility:public"]) 5 | 6 | cc_library( 7 | name = "logging", 8 | srcs = glob(["*.cpp"]), 9 | hdrs = glob(["*.h"]), 10 | deps = [ 11 | "@glog", 12 | ], 13 | ) 14 | 15 | pkg_tar( 16 | name = "includes", 17 | srcs = glob(["*.h"]), 18 | mode = "0644", 19 | package_dir = "spleeter/logging", 20 | tags = ["manual"], 21 | ) 22 | -------------------------------------------------------------------------------- /third_party/tensorflow/tensorflow.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def tensorflow(): 4 | """ Load TensorFlow as Dependency """ 5 | if "tensorflow" not in native.existing_rules(): 6 | http_archive( 7 | name = "tensorflow", 8 | sha256 = "74d79a3c7b7c8cd23f08a8695bc8936deeae5223dc920c863bccb684171a5e7a", 9 | strip_prefix = "libtensorflow_cc-2.3.0-linux", 10 | url = 
"https://github.com/jinay1991/spleeter/releases/download/v2.3/libtensorflow_cc-2.3.0-linux.tar.gz", 11 | ) 12 | -------------------------------------------------------------------------------- /third_party/audionamix/wave.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_library( 6 | name = "wave", 7 | srcs = glob( 8 | [ 9 | "src/wave/*.cc", 10 | "src/wave/header/*.cc", 11 | ], 12 | exclude = glob(["src/wave/*_test.cc"]), 13 | ), 14 | hdrs = glob([ 15 | "src/wave/*.h", 16 | "src/wave/header/*.h", 17 | ]), 18 | includes = [ 19 | "src", 20 | "src/wave/header", 21 | ], 22 | linkstatic = True, 23 | ) 24 | -------------------------------------------------------------------------------- /third_party/audionamix/audionamix.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def audionamix(): 4 | """ Load audionamix as Dependency """ 5 | if "audionamix" not in native.existing_rules(): 6 | http_archive( 7 | name = "audionamix", 8 | strip_prefix = "wave-0.8.2a", 9 | build_file = "//third_party/audionamix:wave.BUILD", 10 | sha256 = "b03fb60abf053107864e69c530c8aad5866ccb9de7d216b18ce347cf2b644dd6", 11 | url = "https://github.com/audionamix/wave/archive/v0.8.2a.zip", 12 | ) 13 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pb filter=lfs diff=lfs merge=lfs -text 2 | *.tflite filter=lfs diff=lfs merge=lfs -text 3 | *.ckpt.* filter=lfs diff=lfs merge=lfs -text 4 | *.jpg filter=lfs diff=lfs merge=lfs -text 5 | *.bmp filter=lfs diff=lfs merge=lfs -text 6 | *.png filter=lfs diff=lfs merge=lfs -text 7 | *.jpeg filter=lfs diff=lfs merge=lfs -text 8 | *.wav filter=lfs diff=lfs 
merge=lfs -text 9 | *.ogg filter=lfs diff=lfs merge=lfs -text 10 | *.m4a filter=lfs diff=lfs merge=lfs -text 11 | *.wma filter=lfs diff=lfs merge=lfs -text 12 | *.flac filter=lfs diff=lfs merge=lfs -text 13 | *.mp3 filter=lfs diff=lfs merge=lfs -text 14 | -------------------------------------------------------------------------------- /spleeter/argument_parser/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 3 | 4 | package(default_visibility = ["//visibility:public"]) 5 | 6 | cc_library( 7 | name = "argument_parser", 8 | srcs = glob(["*.cpp"]), 9 | hdrs = glob(["*.h"]), 10 | deps = [ 11 | "//spleeter/datatypes", 12 | "//spleeter/logging", 13 | ], 14 | ) 15 | 16 | pkg_tar( 17 | name = "includes", 18 | srcs = glob(["*.h"]), 19 | mode = "0644", 20 | package_dir = "spleeter/argument_parser", 21 | tags = ["manual"], 22 | ) 23 | -------------------------------------------------------------------------------- /third_party/tensorflowlite/tensorflowlite.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def tensorflowlite(): 4 | """ Load TensorFlow Lite as Dependency """ 5 | if "tensorflowlite" not in native.existing_rules(): 6 | http_archive( 7 | name = "tensorflowlite", 8 | sha256 = "09f660b02222f1e6e59287e9ddd9755a7835aaf4cb0e9f46a1efb5af0df5a2c6", 9 | strip_prefix = "libtensorflowlite_cc-2.3.0-linux", 10 | url = "https://github.com/jinay1991/spleeter/releases/download/v2.3/libtensorflowlite_cc-2.3.0-linux.tar.gz", 11 | ) 12 | -------------------------------------------------------------------------------- /spleeter/audio/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 
load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 3 | 4 | package(default_visibility = ["//visibility:public"]) 5 | 6 | cc_library( 7 | name = "audio", 8 | srcs = glob(["*.cpp"]), 9 | hdrs = glob(["*.h"]), 10 | deps = [ 11 | "//spleeter/datatypes", 12 | "//spleeter/logging", 13 | "@audionamix//:wave", 14 | "@ffmpeg", 15 | ], 16 | ) 17 | 18 | pkg_tar( 19 | name = "includes", 20 | srcs = glob(["*.h"]), 21 | mode = "0644", 22 | package_dir = "spleeter/audio", 23 | tags = ["manual"], 24 | ) 25 | -------------------------------------------------------------------------------- /third_party/zlib/zlib.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def zlib(): 4 | if "zlib" not in native.existing_rules(): 5 | http_archive( 6 | name = "zlib", 7 | build_file = "//third_party/zlib:zlib.BUILD", 8 | sha256 = "c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1", 9 | strip_prefix = "zlib-1.2.11", 10 | urls = [ 11 | "https://storage.googleapis.com/mirror.tensorflow.org/zlib.net/zlib-1.2.11.tar.gz", 12 | "https://zlib.net/zlib-1.2.11.tar.gz", 13 | ], 14 | ) 15 | -------------------------------------------------------------------------------- /tools/generate_compile_commands.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | INSTALL_DIR="/usr/local/bin" 3 | VERSION="0.4.3" 4 | 5 | # Download and symlink. 6 | ( 7 | cd "${INSTALL_DIR}" \ 8 | && curl -L "https://github.com/grailbio/bazel-compilation-database/archive/${VERSION}.tar.gz" | tar -xz \ 9 | && ln -f -s "${INSTALL_DIR}/bazel-compilation-database-${VERSION}/generate.sh" bazel-compdb 10 | ) 11 | 12 | bazel-compdb # This will generate compile_commands.json in your workspace root. 13 | 14 | # # You can tweak some behavior with flags: 15 | # # 1. 
To use the source dir instead of bazel-execroot for directory in which clang commands are run. 16 | # bazel-compdb -s 17 | -------------------------------------------------------------------------------- /spleeter/test/spleeter_tests.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains component tests for Spleeter 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #include "spleeter/spleeter.h" 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | namespace spleeter 15 | { 16 | namespace 17 | { 18 | TEST(SpleeterSpec, ScenarioFiveStems_GivenAudioWaveform_ExpectFiveSplitAudioWaveforms) 19 | { 20 | auto cli_options = CLIOptions{}; 21 | Spleeter unit{cli_options}; 22 | unit.Init(); 23 | 24 | unit.Execute(); 25 | } 26 | 27 | } // namespace 28 | } // namespace spleeter 29 | -------------------------------------------------------------------------------- /spleeter/inference_engine/null_inference_engine.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. All Rights Reserved. 
4 | /// 5 | #include "spleeter/inference_engine/null_inference_engine.h" 6 | 7 | namespace spleeter 8 | { 9 | 10 | NullInferenceEngine::NullInferenceEngine(const InferenceEngineParameters& params) 11 | : results_{params.output_tensor_names.size()} 12 | { 13 | } 14 | 15 | void NullInferenceEngine::Init() {} 16 | 17 | void NullInferenceEngine::Execute(const Waveform& /* waveform */) {} 18 | 19 | void NullInferenceEngine::Shutdown() {} 20 | 21 | Waveforms NullInferenceEngine::GetResults() const 22 | { 23 | return results_; 24 | } 25 | 26 | } // namespace spleeter 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Extras 2 | .DS_Store 3 | .devcontainer/* 4 | *.zip 5 | *.tar.xz 6 | *.pcm 7 | *.raw 8 | *.mp3 9 | *.wav 10 | 11 | # bazel build files 12 | bazel-* 13 | 14 | # vscode files 15 | .vscode 16 | 17 | # Models 18 | model/ 19 | *.tflite 20 | *.pb 21 | *.pb 22 | *.pbtxt 23 | *.data* 24 | 25 | 26 | # Prerequisites 27 | *.d 28 | 29 | # Compiled Object files 30 | *.slo 31 | *.lo 32 | *.o 33 | *.obj 34 | 35 | # Precompiled Headers 36 | *.gch 37 | *.pch 38 | 39 | # Compiled Dynamic libraries 40 | *.so 41 | *.dylib 42 | *.dll 43 | 44 | # Fortran module files 45 | *.mod 46 | *.smod 47 | 48 | # Compiled Static libraries 49 | *.lai 50 | *.la 51 | *.a 52 | *.lib 53 | 54 | # Executables 55 | *.exe 56 | *.out 57 | *.app 58 | -------------------------------------------------------------------------------- /third_party/ffmpeg/ffmpeg.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | AVLIBS = [ 6 | "avcodec", 7 | "avdevice", 8 | "avfilter", 9 | "avformat", 10 | "avutil", 11 | "swresample", 12 | "swscale", 13 | ] 14 | 15 | [ 16 | cc_library( 17 | name = avlib, 18 | srcs = 
glob(["lib/x86_64-linux-gnu/lib{}.so*".format(avlib)]), 19 | hdrs = glob(["include/x86_64-linux-gnu/lib{}/*.h".format(avlib)]), 20 | includes = ["include/x86_64-linux-gnu"], 21 | ) 22 | for avlib in AVLIBS 23 | ] 24 | 25 | cc_library( 26 | name = "ffmpeg", 27 | deps = [ 28 | ":{}".format(avlib) 29 | for avlib in AVLIBS 30 | ], 31 | ) 32 | -------------------------------------------------------------------------------- /spleeter/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 3 | 4 | package(default_visibility = ["//visibility:public"]) 5 | 6 | cc_library( 7 | name = "spleeter", 8 | srcs = glob(["*.cpp"]), 9 | hdrs = glob(["*.h"]), 10 | linkopts = [ 11 | "-lstdc++fs", 12 | ], 13 | deps = [ 14 | "//spleeter/argument_parser", 15 | "//spleeter/audio", 16 | "//spleeter/datatypes", 17 | "//spleeter/inference_engine", 18 | "//spleeter/logging", 19 | "@nlohmann//:json", 20 | ], 21 | ) 22 | 23 | pkg_tar( 24 | name = "includes", 25 | srcs = glob(["*.h"]), 26 | mode = "0644", 27 | package_dir = "spleeter", 28 | tags = ["manual"], 29 | ) 30 | -------------------------------------------------------------------------------- /third_party/glog/glog.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def glog(): 4 | if "com_github_gflags_gflags" not in native.existing_rules(): 5 | http_archive( 6 | name = "com_github_gflags_gflags", 7 | sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf", 8 | strip_prefix = "gflags-2.2.2", 9 | urls = ["https://github.com/gflags/gflags/archive/v2.2.2.tar.gz"], 10 | ) 11 | 12 | if "glog" not in native.existing_rules(): 13 | http_archive( 14 | name = "glog", 15 | url = "https://github.com/google/glog/archive/v0.4.0.zip", 16 | sha256 = 
"9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc", 17 | strip_prefix = "glog-0.4.0", 18 | ) 19 | -------------------------------------------------------------------------------- /third_party/models/models.BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//visibility:public"]) 2 | 3 | filegroup( 4 | name = "2stems", 5 | srcs = glob([ 6 | "2stems/*.tflite", 7 | "2stems/saved_model/*.pb", 8 | "2stems/saved_model/variables/*.data-*", 9 | "2stems/saved_model/variables/*.index", 10 | ]), 11 | ) 12 | 13 | filegroup( 14 | name = "4stems", 15 | srcs = glob([ 16 | "4stems/*.tflite", 17 | "4stems/saved_model/*.pb", 18 | "4stems/saved_model/variables/*.data-*", 19 | "4stems/saved_model/variables/*.index", 20 | ]), 21 | ) 22 | 23 | filegroup( 24 | name = "5stems", 25 | srcs = glob([ 26 | "5stems/*.tflite", 27 | "5stems/saved_model/*.pb", 28 | "5stems/saved_model/variables/*.data-*", 29 | "5stems/saved_model/variables/*.index", 30 | ]), 31 | ) 32 | -------------------------------------------------------------------------------- /spleeter/argument_parser/i_argument_parser.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains Argument Parser Interface class 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_ARGUMENT_PARSER_I_ARGUMENT_PARSER_H 7 | #define SPLEETER_ARGUMENT_PARSER_I_ARGUMENT_PARSER_H 8 | 9 | #include "spleeter/argument_parser/cli_options.h" 10 | 11 | namespace spleeter 12 | { 13 | /// @brief Argument Parser Interface class 14 | class IArgumentParser 15 | { 16 | public: 17 | /// @brief Destructor 18 | virtual ~IArgumentParser() = default; 19 | 20 | /// @brief Provides Parsed Arguments 21 | virtual CLIOptions GetParsedArgs() const = 0; 22 | 23 | protected: 24 | /// @brief Parse Arguments from argc, argv 25 | virtual CLIOptions ParseArgs(int argc, char** argv) = 0; 26 | }; 27 
| } // namespace spleeter 28 | 29 | #endif /// SPLEETER_ARGUMENT_PARSER_I_ARGUMENT_PARSER_H 30 | -------------------------------------------------------------------------------- /application/main.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020, MIT License 4 | /// 5 | #include "spleeter/argument_parser/argument_parser.h" 6 | #include "spleeter/argument_parser/i_argument_parser.h" 7 | #include "spleeter/logging/logging.h" 8 | #include "spleeter/spleeter.h" 9 | 10 | #include <memory> 11 | 12 | int main(int argc, char** argv) 13 | { 14 | try 15 | { 16 | std::unique_ptr<spleeter::IArgumentParser> argument_parser = 17 | std::make_unique<spleeter::ArgumentParser>(argc, argv); 18 | auto spleeter = std::make_unique<spleeter::Spleeter>(argument_parser->GetParsedArgs()); 19 | spleeter->Init(); 20 | 21 | spleeter->Execute(); 22 | 23 | spleeter->Shutdown(); 24 | } 25 | catch (std::exception& e) 26 | { 27 | LOG(ERROR) << "Caught Exception!! " << e.what() << std::endl; 28 | return 1; 29 | } 30 | 31 | return 0; 32 | } 33 | -------------------------------------------------------------------------------- /third_party/audio_example/audio_example.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file") 2 | 3 | def audio_example(): 4 | if "audio_example" not in native.existing_rules(): 5 | http_file( 6 | name = "audio_example", 7 | downloaded_file_path = "audio_example.wav", 8 | sha256 = "409f930a5409fe134a16fcb4c12b484794377d7214b14bf42e63957bf17d8f2e", 9 | urls = ["https://gitlab.com/jinay1991/spleeter/uploads/ded94b0b51d7021328491bac73d6c00a/audio_example.wav"], 10 | ) 11 | 12 | # if "audio_example" not in native.existing_rules(): 13 | # http_file( 14 | # name = "audio_example", 15 | # downloaded_file_path = "audio_example.mp3", 16 | # sha256 = "4b431d535d235bd81b62f816f08c3f1afb6679d2706d84b3b75903b7df909507", 17 | # urls = 
["https://github.com/deezer/spleeter/raw/master/audio_example.mp3"], 18 | # ) 19 | -------------------------------------------------------------------------------- /spleeter/test/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_test") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | cc_test( 6 | name = "unit_test", 7 | srcs = glob( 8 | ["*.cpp"], 9 | exclude = ["spleeter_tests.cpp"], 10 | ), 11 | data = [ 12 | "@audio_example//file", 13 | "@models//:2stems", 14 | "@models//:4stems", 15 | "@models//:5stems", 16 | ], 17 | linkopts = [ 18 | "-lstdc++fs", 19 | ], 20 | linkstatic = True, 21 | deps = [ 22 | "//spleeter", 23 | "//spleeter/audio", 24 | "@googletest//:gtest_main", 25 | ], 26 | ) 27 | 28 | cc_test( 29 | name = "component_test", 30 | srcs = ["spleeter_tests.cpp"], 31 | data = [ 32 | "@audio_example//file", 33 | "@models//:5stems", 34 | ], 35 | linkstatic = True, 36 | deps = [ 37 | "//spleeter", 38 | "@googletest//:gtest_main", 39 | ], 40 | ) 41 | -------------------------------------------------------------------------------- /tools/tflite_converter/tflite_converter.py: -------------------------------------------------------------------------------- 1 | savedModelDir = "export_dir/0" 2 | 3 | import os 4 | import tensorflow as tf 5 | 6 | model = tf.saved_model.load(savedModelDir) 7 | print("Num signatures:") 8 | print(len(model.signatures)) 9 | concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY] 10 | 11 | converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) 12 | converter.allow_custom_ops = True 13 | converter.target_ops = set([ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ]) 14 | converter.target_spec.supported_ops = set([ tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS ]) 15 | converter.experimental_new_converter = True 16 | tflite_model = converter.convert() 17 | 
open("converted_model.tflite", "wb").write(tflite_model) 18 | 19 | interpreter = tf.lite.Interpreter(model_content=tflite_model) 20 | interpreter.allocate_tensors() 21 | print("Output", interpreter.get_output_details()[0]["name"], interpreter.get_output_details()[0]["shape"]) -------------------------------------------------------------------------------- /third_party/dependencies.bzl: -------------------------------------------------------------------------------- 1 | load("@spleeter//third_party/audio_example:audio_example.bzl", "audio_example") 2 | load("@spleeter//third_party/audionamix:audionamix.bzl", "audionamix") 3 | load("@spleeter//third_party/eigen:eigen.bzl", "eigen") 4 | load("@spleeter//third_party/ffmpeg:ffmpeg.bzl", "ffmpeg") 5 | load("@spleeter//third_party/glog:glog.bzl", "glog") 6 | load("@spleeter//third_party/googletest:googletest.bzl", "googletest") 7 | load("@spleeter//third_party/models:models.bzl", "models") 8 | load("@spleeter//third_party/nlohmann:nlohmann.bzl", "nlohmann") 9 | load("@spleeter//third_party/tensorflow:tensorflow.bzl", "tensorflow") 10 | load("@spleeter//third_party/tensorflowlite:tensorflowlite.bzl", "tensorflowlite") 11 | load("@spleeter//third_party/zlib:zlib.bzl", "zlib") 12 | 13 | def spleeter_dependencies(): 14 | """ Load 3rd Party Dependencies """ 15 | audio_example() 16 | audionamix() 17 | eigen() 18 | ffmpeg() 19 | glog() 20 | googletest() 21 | models() 22 | nlohmann() 23 | tensorflow() 24 | tensorflowlite() 25 | zlib() 26 | -------------------------------------------------------------------------------- /spleeter/spleeter.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains High Level API/Class for Spleeter application 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_SPLEETER_H 7 | #define SPLEETER_SPLEETER_H 8 | 9 | #include "spleeter/argument_parser/cli_options.h" 10 | #include "spleeter/i_separator.h" 11 
| 12 | #include 13 | 14 | namespace spleeter 15 | { 16 | /// @brief Contains High Level APIs for performing Spleeter tasks 17 | class Spleeter 18 | { 19 | public: 20 | /// @brief Constructor. 21 | explicit Spleeter(const CLIOptions& cli_options); 22 | 23 | /// @brief Initialize Spleeter 24 | void Init(); 25 | 26 | /// @brief Execute Spleeter 27 | void Execute(); 28 | 29 | /// @brief Shutdown Spleeter (release any occupied resources) 30 | void Shutdown(); 31 | 32 | private: 33 | /// @brief Parsed Command Line options 34 | CLIOptions cli_options_; 35 | 36 | /// @brief Audio Separator 37 | std::unique_ptr separator_; 38 | }; 39 | } // namespace spleeter 40 | 41 | #endif /// SPLEETER_SPLEETER_H 42 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | 3 | ARG TARGETOS 4 | ARG TARGETARCH 5 | 6 | ENV DEBIAN_FRONTEND=noninteractive 7 | RUN apt-get update 8 | 9 | # Installation of dev environment dependencies 10 | RUN apt-get install -y \ 11 | gcc g++ clang-format clang-tidy lcov \ 12 | wget git git-lfs \ 13 | openjdk-11-jdk openjdk-11-jre \ 14 | libavcodec-dev libavformat-dev libavfilter-dev libavdevice-dev libswresample-dev libswscale-dev ffmpeg 15 | 16 | # Installation of Bazel Package 17 | RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-${TARGETOS}-${TARGETARCH} && \ 18 | chmod +x bazelisk-${TARGETOS}-${TARGETARCH} && \ 19 | mv bazelisk-${TARGETOS}-${TARGETARCH} /usr/bin/bazel 20 | 21 | # Installation of Bazel Tools 22 | RUN wget https://github.com/bazelbuild/buildtools/releases/download/5.0.1/buildifier-${TARGETOS}-${TARGETARCH} && \ 23 | chmod +x buildifier-${TARGETOS}-${TARGETARCH} && \ 24 | mv buildifier-${TARGETOS}-${TARGETARCH} /usr/bin/buildifier 25 | 26 | # cleanup 27 | RUN apt-get clean && rm -rf /var/lib/apt/lists/* && \ 28 | apt-get autoremove && apt-get autoclean 29 | 
-------------------------------------------------------------------------------- /spleeter/datatypes/audio_properties.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains definition of Audio Properties 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_DATATYPES_AUDIO_PROPERTIES_H 7 | #define SPLEETER_DATATYPES_AUDIO_PROPERTIES_H 8 | 9 | #include 10 | #include 11 | 12 | namespace spleeter 13 | { 14 | /// @brief Audio Properties 15 | struct AudioProperties 16 | { 17 | /// @brief Number of frames/samples 18 | std::uint64_t nb_frames; 19 | 20 | /// @brief Number of channels 21 | std::uint64_t nb_channels; 22 | 23 | /// @brief Sample Rate 24 | std::uint32_t sample_rate; 25 | }; 26 | 27 | /// @brief Prepare output stream for AudioProperties 28 | inline std::ostream& operator<<(std::ostream& out, const AudioProperties& audio_properties) 29 | { 30 | out << "AudioProperties{nb_channels: " << audio_properties.nb_channels 31 | << ", nb_frames: " << audio_properties.nb_frames << ", sample_rate: " << audio_properties.sample_rate << "}"; 32 | return out; 33 | } 34 | } // namespace spleeter 35 | #endif /// SPLEETER_DATATYPES_AUDIO_PROPERTIES_H 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jinay Patel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /spleeter/spleeter.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains definitions for Spleeter class methods 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #include "spleeter/spleeter.h" 7 | 8 | #include "spleeter/logging/logging.h" 9 | #include "spleeter/separator.h" 10 | 11 | namespace spleeter 12 | { 13 | Spleeter::Spleeter(const CLIOptions& cli_options) 14 | : cli_options_{cli_options}, 15 | separator_{std::make_unique(cli_options_.inference_engine_params, cli_options_.mwf)} 16 | { 17 | } 18 | 19 | void Spleeter::Init() {} 20 | 21 | void Spleeter::Execute() 22 | { 23 | separator_->SeparateToFile(cli_options_.inputs, 24 | cli_options_.output_path, 25 | cli_options_.audio_adapter, 26 | cli_options_.offset, 27 | cli_options_.duration, 28 | cli_options_.codec, 29 | cli_options_.bitrate, 30 | cli_options_.filename_format, 31 | false); 32 | LOG(INFO) << "Successfully executed spleeter!!"; 33 | } 34 | void Spleeter::Shutdown() {} 35 | 36 | } // namespace spleeter 37 | -------------------------------------------------------------------------------- /spleeter/inference_engine/i_inference_engine.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | 
/// @brief Contains Inference Interface Engine definition 4 | /// @copyright Copyright (c) 2020. MIT License 5 | /// 6 | #ifndef SPLEETER_INFERENCE_ENGINE_I_INFERENCE_ENGINE_H 7 | #define SPLEETER_INFERENCE_ENGINE_I_INFERENCE_ENGINE_H 8 | 9 | #include "spleeter/datatypes/waveform.h" 10 | 11 | #include 12 | 13 | namespace spleeter 14 | { 15 | 16 | /// @brief Inference Engine Interface class 17 | class IInferenceEngine 18 | { 19 | public: 20 | /// @brief Destructor 21 | virtual ~IInferenceEngine() = default; 22 | 23 | /// @brief Initialise Inference Engine 24 | virtual void Init() = 0; 25 | 26 | /// @brief Execute Inference with Inference Engine 27 | /// @param waveform [in] - Waveform to be split 28 | virtual void Execute(const Waveform& waveform) = 0; 29 | 30 | /// @brief Release Inference Engine 31 | virtual void Shutdown() = 0; 32 | 33 | /// @brief Obtain Results for provided input waveform 34 | /// @return List of waveforms (split waveforms) 35 | virtual Waveforms GetResults() const = 0; 36 | }; 37 | 38 | /// @brief InferenceEngine unique instance pointer 39 | using InferenceEnginePtr = std::unique_ptr; 40 | 41 | } // namespace spleeter 42 | #endif /// SPLEETER_INFERENCE_ENGINE_I_INFERENCE_ENGINE_H 43 | -------------------------------------------------------------------------------- /third_party/zlib/zlib.BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | config_setting( 6 | name = "macos", 7 | constraint_values = ["@bazel_tools//platforms:osx"], 8 | ) 9 | 10 | config_setting( 11 | name = "windows", 12 | constraint_values = ["@bazel_tools//platforms:windows"], 13 | ) 14 | 15 | config_setting( 16 | name = "linux", 17 | constraint_values = ["@bazel_tools//platforms:linux"], 18 | ) 19 | 20 | cc_library( 21 | name = "zlib", 22 | srcs = [ 23 | "adler32.c", 24 | "compress.c", 25 | "crc32.c", 26 | "crc32.h", 27 | 
"deflate.c", 28 | "deflate.h", 29 | "gzclose.c", 30 | "gzguts.h", 31 | "gzlib.c", 32 | "gzread.c", 33 | "gzwrite.c", 34 | "infback.c", 35 | "inffast.c", 36 | "inffast.h", 37 | "inffixed.h", 38 | "inflate.c", 39 | "inflate.h", 40 | "inftrees.c", 41 | "inftrees.h", 42 | "trees.c", 43 | "trees.h", 44 | "uncompr.c", 45 | "zconf.h", 46 | "zutil.c", 47 | "zutil.h", 48 | ], 49 | hdrs = ["zlib.h"], 50 | copts = select({ 51 | ":windows": [], 52 | "//conditions:default": [ 53 | "-Wno-shift-negative-value", 54 | "-DZ_HAVE_UNISTD_H", 55 | ], 56 | }), 57 | includes = ["."], 58 | ) 59 | -------------------------------------------------------------------------------- /spleeter/logging/test/logging_tests.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020-2021. MIT License. 4 | /// 5 | #include "spleeter/logging/logging.h" 6 | 7 | #include 8 | #include 9 | 10 | namespace perception 11 | { 12 | namespace 13 | { 14 | 15 | TEST(LoggingTest, BasicLoggingMacro_INFO) 16 | { 17 | const std::string test_log{"Sanity Test for LogSeverityLevel = INFO!!"}; 18 | ::testing::internal::CaptureStderr(); 19 | 20 | LOG(INFO) << test_log; 21 | 22 | const std::string result = ::testing::internal::GetCapturedStderr(); 23 | EXPECT_THAT(result, ::testing::HasSubstr(test_log)); 24 | } 25 | 26 | TEST(LoggingTest, BasicLoggingMacro_WARN) 27 | { 28 | const std::string test_log{"Sanity Test for LogSeverityLevel = WARNING!!"}; 29 | ::testing::internal::CaptureStderr(); 30 | 31 | LOG(WARNING) << test_log; 32 | 33 | const std::string result = ::testing::internal::GetCapturedStderr(); 34 | EXPECT_THAT(result, ::testing::HasSubstr(test_log)); 35 | } 36 | 37 | TEST(LoggingTest, BasicLoggingMacro_ERROR) 38 | { 39 | const std::string test_log{"Sanity Test for LogSeverityLevel = ERROR!!"}; 40 | ::testing::internal::CaptureStderr(); 41 | 42 | LOG(ERROR) << test_log; 43 | 44 | const std::string result = 
::testing::internal::GetCapturedStderr(); 45 | EXPECT_THAT(result, ::testing::HasSubstr(test_log)); 46 | } 47 | 48 | } // namespace 49 | } // namespace perception 50 | -------------------------------------------------------------------------------- /spleeter/inference_engine/null_inference_engine.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. MIT License 4 | /// 5 | #ifndef SPLEETER_INFERENCE_ENGINE_NULL_INFERENCE_ENGINE_H 6 | #define SPLEETER_INFERENCE_ENGINE_NULL_INFERENCE_ENGINE_H 7 | 8 | #include "spleeter/datatypes/inference_engine.h" 9 | #include "spleeter/inference_engine/i_inference_engine.h" 10 | 11 | namespace spleeter 12 | { 13 | /// @brief Null Inference Engine class 14 | class NullInferenceEngine final : public IInferenceEngine 15 | { 16 | public: 17 | /// @brief Constructor 18 | /// @param params[in] Inference Engine Parameters such as model input/output node names 19 | explicit NullInferenceEngine(const InferenceEngineParameters& params); 20 | 21 | /// @brief Initialise Null Inference Engine 22 | void Init() override; 23 | 24 | /// @brief Execute Inference with Null Inference Engine 25 | /// @param waveform [in] - Waveform to be split 26 | void Execute(const Waveform& waveform) override; 27 | 28 | /// @brief Release Null Inference Engine 29 | void Shutdown() override; 30 | 31 | /// @brief Provide results in terms of Matrix 32 | /// @return List of waveforms (split waveforms) 33 | Waveforms GetResults() const override; 34 | 35 | private: 36 | /// @brief Output Tensors saved as Waveforms 37 | const Waveforms results_; 38 | }; 39 | 40 | } // namespace spleeter 41 | 42 | #endif /// SPLEETER_INFERENCE_ENGINE_NULL_INFERENCE_ENGINE_H 43 | -------------------------------------------------------------------------------- /tools/tflite_converter/README.md: -------------------------------------------------------------------------------- 1 | # spleeter-tflite-convert 
2 | Things to know: 3 | * those are python 3 scripts that require tensorflow's latest version (I installed it with pip3) 4 | * the first one, "export_model.py", will create a "saved model" from the checkpoint 5 | * the second one, "tofreezelite.py", will convert the "saved model" to "converted_model.tflite" (which is roughly 150MB with the 4stems checkpoint) 6 | * the path to the model checkpoint is hardcoded in the first script: 7 | 8 | I have my model checkpoint at this location : /Users/tinou/Desktop/4stems 9 | 10 | that prefix variable must be the full path of the .meta file, up to the extension. 11 | 12 | * the input and output tensor names are hardcoded for the 4stems model checkpoint, if you plan on converting another one, you will have to manually change the output tensor names (I found those by looking into spleeterpp python script that didn't work for me). 13 | * you will need TensorFlowLiteSelectTfOps in addition to TensorFlowLiteC when using it. 14 | 15 | It uses a *lot* of RAM, you'll only be able to process small chunks of audio at a time on iOS (or Android). Sounds 'ok-ish' with 2s chunks, with cross-faded overlap (4096 samples of overlap) 16 | I'm interested if anyone finds a way to modify the checkpoint to make it use static-sized tensors. As it is, acceleration (GPU/Metal or coreml) isn't possible. 
17 | 18 | # Note 19 | 20 | This directory contains files and work from https://github.com/tinoucas/spleeter-tflite-convert -------------------------------------------------------------------------------- /spleeter/inference_engine/BUILD: -------------------------------------------------------------------------------- 1 | load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test") 2 | load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 3 | 4 | cc_library( 5 | name = "inference_engine", 6 | srcs = [ 7 | "inference_engine_strategy.cpp", 8 | "null_inference_engine.cpp", 9 | "tf_inference_engine.cpp", 10 | "tflite_inference_engine.cpp", 11 | ], 12 | hdrs = [ 13 | "i_inference_engine.h", 14 | "inference_engine_strategy.h", 15 | "null_inference_engine.h", 16 | "tf_inference_engine.h", 17 | "tflite_inference_engine.h", 18 | ], 19 | visibility = ["//visibility:public"], 20 | deps = [ 21 | "//spleeter/datatypes", 22 | "//spleeter/logging", 23 | "@tensorflow", 24 | "@tensorflowlite", 25 | ], 26 | ) 27 | 28 | cc_test( 29 | name = "unit_tests", 30 | srcs = ["test/inference_engine_tests.cpp"], 31 | data = [ 32 | "@audio_example//file", 33 | "@models//:5stems", 34 | ], 35 | linkstatic = True, 36 | tags = ["unit"], 37 | visibility = ["//visibility:public"], 38 | deps = [ 39 | ":inference_engine", 40 | "//spleeter/audio", 41 | "//spleeter/logging", 42 | "@googletest//:gtest_main", 43 | ], 44 | ) 45 | 46 | pkg_tar( 47 | name = "includes", 48 | srcs = glob(["*.h"]), 49 | mode = "0644", 50 | package_dir = "spleeter/inference_engine", 51 | tags = ["manual"], 52 | visibility = ["//visibility:public"], 53 | ) 54 | -------------------------------------------------------------------------------- /spleeter/datatypes/waveform.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains Data Structure to hold Waveform (i.e. 
audio samples) 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_DATATYPES_WAVEFORM_H 7 | #define SPLEETER_DATATYPES_WAVEFORM_H 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace spleeter 15 | { 16 | struct Waveform 17 | { 18 | std::int32_t nb_frames; 19 | std::int32_t nb_channels; 20 | std::vector data; 21 | }; 22 | 23 | /// @brief List of waveforms 24 | using Waveforms = std::vector; 25 | 26 | /// @brief Provide output stream for waveform (list of samples), prints number of samples it holds. 27 | inline std::ostream& operator<<(std::ostream& out, const Waveform& waveform) 28 | { 29 | out << "Waveform{nb_frames: " << waveform.nb_frames << ", nb_channels: " << waveform.nb_channels 30 | << ", nb_size: " << waveform.data.size() << "}"; 31 | return out; 32 | } 33 | 34 | /// @brief Provide output stream for waveforms (list of waveform), prints number of sample in each waveform 35 | inline std::ostream& operator<<(std::ostream& out, const Waveforms& waveforms) 36 | { 37 | std::int32_t idx{0}; 38 | out << "Waveforms{\n"; 39 | std::for_each(waveforms.begin(), waveforms.end(), [&](const auto& waveform) { 40 | out << " " << idx++ << ". 
" << waveform << "\n"; 41 | }); 42 | out << "}"; 43 | return out; 44 | } 45 | } // namespace spleeter 46 | #endif /// SPLEETER_DATATYPES_WAVEFORM_H 47 | -------------------------------------------------------------------------------- /spleeter/argument_parser/argument_parser.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains Argument Parser definitions 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_ARGUMENT_PARSER_ARGUMENT_PARSER_H 7 | #define SPLEETER_ARGUMENT_PARSER_ARGUMENT_PARSER_H 8 | 9 | #include "spleeter/argument_parser/cli_options.h" 10 | #include "spleeter/argument_parser/i_argument_parser.h" 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | namespace spleeter 20 | { 21 | /// @brief Argument Parser class 22 | class ArgumentParser : public IArgumentParser 23 | { 24 | public: 25 | /// @brief Default Constructor 26 | ArgumentParser(); 27 | 28 | /// @brief Constructor 29 | /// @param [in] argc - number of arguments 30 | /// @param [in] argv - list of arguments 31 | explicit ArgumentParser(int argc, char* argv[]); 32 | 33 | /// @brief Destructor 34 | virtual ~ArgumentParser(); 35 | 36 | /// @brief Provides Parsed Arguments 37 | virtual CLIOptions GetParsedArgs() const override; 38 | 39 | protected: 40 | /// @brief Parse Arguments from argc, argv 41 | virtual CLIOptions ParseArgs(int argc, char* argv[]) override; 42 | 43 | private: 44 | /// @brief parsed arguments 45 | CLIOptions cli_options_; 46 | 47 | /// @brief long option list 48 | std::vector long_options_; 49 | 50 | /// @brief short option list 51 | std::string optstring_; 52 | }; 53 | 54 | } // namespace spleeter 55 | 56 | #endif /// SPLEETER_ARGUMENT_PARSER_ARGUMENT_PARSER_H 57 | -------------------------------------------------------------------------------- /spleeter/datatypes/inference_engine.h: 
-------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. MIT License 4 | /// 5 | #ifndef SPLEETER_DATATYPES_INFERENCE_ENGINE_H 6 | #define SPLEETER_DATATYPES_INFERENCE_ENGINE_H 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | namespace spleeter 13 | { 14 | /// @brief List of Inference Engine supported 15 | enum class InferenceEngineType : std::uint8_t 16 | { 17 | kTensorFlowLite = 0U, 18 | kTensorFlow = 1U, 19 | kInvalid = 255U 20 | }; 21 | 22 | /// @brief InferenceEngine Parameters 23 | struct InferenceEngineParameters 24 | { 25 | /// @brief Path to Model 26 | std::string model_path{}; 27 | 28 | /// @brief Input Node/Tensor Name 29 | std::string input_tensor_name{}; 30 | 31 | /// @brief List of Output Nodes/Tensors Name 32 | std::vector output_tensor_names{}; 33 | 34 | /// @brief Path to Model configurations 35 | std::string configuration{}; 36 | }; 37 | 38 | inline const char* to_string(const InferenceEngineType& inference_engine_type) 39 | { 40 | switch (inference_engine_type) 41 | { 42 | case InferenceEngineType::kTensorFlow: 43 | return "kTensorFlow"; 44 | case InferenceEngineType::kTensorFlowLite: 45 | return "kTensorFlowLite"; 46 | default: 47 | return "ERROR: Unknown InferenceEngineType."; 48 | } 49 | return "ERROR: Unknown InferenceEngineType."; 50 | } 51 | 52 | inline std::ostream& operator<<(std::ostream& stream, const InferenceEngineType& inference_engine_type) 53 | { 54 | const char* name = to_string(inference_engine_type); 55 | stream << name; 56 | return stream; 57 | } 58 | 59 | } // namespace spleeter 60 | #endif /// SPLEETER_DATATYPES_INFERENCE_ENGINE_H 61 | -------------------------------------------------------------------------------- /third_party/models/models.bzl: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 2 | 3 | def models(): 4 | """ Download 
Models """ 5 | if "models" not in native.existing_rules(): 6 | http_archive( 7 | name = "models", 8 | build_file = "//third_party/models:models.BUILD", 9 | strip_prefix = "models", 10 | sha256 = "e1314aa69ed793e2677689a9cf6f3ac1d2357c80d5fb64769ee7175755ef6a5a", 11 | url = "https://github.com/jinay1991/spleeter/releases/download/v2.3/models.tar.gz", 12 | ) 13 | 14 | if "2stems" not in native.existing_rules(): 15 | http_archive( 16 | name = "2stems", 17 | build_file = "//third_party/models:stems.BUILD", 18 | strip_prefix = "models/2stems", 19 | sha256 = "6c2bb7dc3f1c1e162e4567ea4f0f5c20257a0d96c9b3c4e88b363aa83d4cf150", 20 | url = "https://github.com/jinay1991/spleeter/releases/download/v2.3/2stems.tar.gz", 21 | ) 22 | 23 | if "4stems" not in native.existing_rules(): 24 | http_archive( 25 | name = "4stems", 26 | build_file = "//third_party/models:stems.BUILD", 27 | sha256 = "1edbb653058a9d474370a8579b18f18e9d383965c6749006ded2e1ace0e1af65", 28 | strip_prefix = "models/4stems", 29 | url = "https://github.com/jinay1991/spleeter/releases/download/v2.3/4stems.tar.gz", 30 | ) 31 | 32 | if "5stems" not in native.existing_rules(): 33 | http_archive( 34 | name = "5stems", 35 | build_file = "//third_party/models:stems.BUILD", 36 | strip_prefix = "models/5stems", 37 | sha256 = "667c13427502623ba0a6b3d6ef6b81b9f81b0d13e7e22239c662890b477a8f02", 38 | url = "https://github.com/jinay1991/spleeter/releases/download/v2.3/5stems.tar.gz", 39 | ) 40 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | BasedOnStyle: Google 4 | AccessModifierOffset: -2 5 | AllowAllParametersOfDeclarationOnNextLine: false 6 | AllowShortFunctionsOnASingleLine: Inline 7 | AllowShortIfStatementsOnASingleLine: false 8 | AllowShortLoopsOnASingleLine: false 9 | BinPackArguments: false 10 | BinPackParameters: false 11 | BraceWrapping: 12 | AfterClass: true 
13 | AfterControlStatement: true 14 | AfterEnum: true 15 | AfterFunction: true 16 | AfterNamespace: true 17 | AfterObjCDeclaration: true 18 | AfterStruct: true 19 | AfterUnion: true 20 | BeforeCatch: true 21 | BeforeElse: true 22 | IndentBraces: false 23 | BreakBeforeBraces: Allman 24 | ColumnLimit: 120 25 | DerivePointerAlignment: false 26 | SortIncludes: true 27 | IncludeCategories: 28 | - Regex: '^(<|")(assert|complex|ctype|errno|fenv|float|inttypes|iso646|limits|locale|math|setjmp|signal|stdalign|stdargh|stdatomic|stdbool|stddef|stdint|stdio|stdlib|stdnoreturn|string|tgmath|threads|time|uchar|wchar|wctype)\.h' 29 | Priority: 4 30 | - Regex: '^(<|")(cstdlib|csignal|csetjmp|cstdarg|typeinfo|typeindex|type_traits|bitset|functional|utility|ctime|chrono|cstddef|initializer_list|tuple|any|optional|variant|new|memory|scoped_allocator|memory_resource|climits|cfloat|cstdint|cinttypes|limits|exception|stdexcept|cassert|system_error|cerrno|cctype|cwctype|cstring|cwchar|cuchar|string|string_view|array|vector|deque|list|forward_list|set|map|unordered_set|unordered_map|stack|queue|algorithm|execution|teratorslibrary|iterator|cmath|complex|valarray|random|numeric|ratio|cfenv|iosfwd|ios|istream|ostream|iostream|fstream|sstream|strstream|iomanip|streambuf|cstdio|locale|clocale|codecvt|regex|atomic|thread|mutex|shared_mutex|future|condition_variable|filesystem|ciso646|ccomplex|ctgmath|cstdalign|cstdbool)(>|")$' 31 | Priority: 3 32 | - Regex: '^<.*\.(h|hpp)>' 33 | Priority: 2 34 | - Regex: '^".*"' 35 | Priority: 1 36 | IndentWidth: 4 37 | KeepEmptyLinesAtTheStartOfBlocks: true 38 | ... 39 | -------------------------------------------------------------------------------- /spleeter/inference_engine/inference_engine_strategy.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. 
MIT License 4 | /// 5 | #ifndef SPLEETER_INFERENCE_ENGINE_INFERENCE_ENGINE_STRATEGY_H 6 | #define SPLEETER_INFERENCE_ENGINE_INFERENCE_ENGINE_STRATEGY_H 7 | 8 | #include "spleeter/datatypes/inference_engine.h" 9 | #include "spleeter/inference_engine/i_inference_engine.h" 10 | 11 | #include 12 | 13 | namespace spleeter 14 | { 15 | 16 | /// @brief Inference Engine Strategy 17 | class InferenceEngineStrategy final 18 | { 19 | public: 20 | /// @brief Default Constructor 21 | InferenceEngineStrategy(); 22 | 23 | /// @brief Initialise Inference Engine 24 | void Init(); 25 | 26 | /// @brief Execute Inference with Inference Engine 27 | /// @param waveform [in] - Waveform to be split 28 | void Execute(const Waveform& waveform); 29 | 30 | /// @brief Release Inference Engine 31 | void Shutdown(); 32 | 33 | /// @brief Select Inference Engine 34 | /// @param inference_engine_type [in] Inference Engine type (TF, TFLite) 35 | /// @param inference_engine_param [in] Inference Engine Parameters 36 | void SelectInferenceEngine(const InferenceEngineType& inference_engine_type, 37 | const InferenceEngineParameters& inference_engine_parameters); 38 | 39 | /// @brief Obtain Results for provided input waveform 40 | /// @return List of waveforms (split waveforms) 41 | Waveforms GetResults() const; 42 | 43 | /// @brief Provide selected inference engine type 44 | /// @return InferenceEngineType 45 | InferenceEngineType GetInferenceEngineType() const; 46 | 47 | private: 48 | /// @brief Inference Engine 49 | std::unique_ptr inference_engine_; 50 | 51 | /// @brief Inference Engine Type 52 | InferenceEngineType inference_engine_type_; 53 | }; 54 | } // namespace spleeter 55 | #endif /// SPLEETER_INFERENCE_ENGINE_INFERENCE_ENGINE_STRATEGY_H 56 | -------------------------------------------------------------------------------- /spleeter/inference_engine/inference_engine_strategy.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// 
@copyright Copyright (c) 2020. MIT License 4 | /// 5 | #include "spleeter/inference_engine/inference_engine_strategy.h" 6 | 7 | #include "spleeter/inference_engine/null_inference_engine.h" 8 | #include "spleeter/inference_engine/tf_inference_engine.h" 9 | #include "spleeter/inference_engine/tflite_inference_engine.h" 10 | #include "spleeter/logging/logging.h" 11 | 12 | namespace spleeter 13 | { 14 | 15 | InferenceEngineStrategy::InferenceEngineStrategy() : inference_engine_type_{InferenceEngineType::kInvalid} {} 16 | 17 | void InferenceEngineStrategy::SelectInferenceEngine(const InferenceEngineType& inference_engine_type, 18 | const InferenceEngineParameters& inference_engine_parameters) 19 | { 20 | inference_engine_type_ = inference_engine_type; 21 | switch (inference_engine_type) 22 | { 23 | case InferenceEngineType::kTensorFlow: 24 | { 25 | inference_engine_ = std::make_unique(inference_engine_parameters); 26 | break; 27 | } 28 | case InferenceEngineType::kTensorFlowLite: 29 | { 30 | inference_engine_ = std::make_unique(inference_engine_parameters); 31 | break; 32 | } 33 | case InferenceEngineType::kInvalid: 34 | default: 35 | { 36 | inference_engine_ = std::make_unique(inference_engine_parameters); 37 | LOG(ERROR) << "Received " << inference_engine_type; 38 | break; 39 | } 40 | } 41 | } 42 | 43 | void InferenceEngineStrategy::Init() 44 | { 45 | inference_engine_->Init(); 46 | } 47 | 48 | void InferenceEngineStrategy::Execute(const Waveform& waveform) 49 | { 50 | inference_engine_->Execute(waveform); 51 | } 52 | 53 | void InferenceEngineStrategy::Shutdown() 54 | { 55 | inference_engine_->Shutdown(); 56 | } 57 | 58 | Waveforms InferenceEngineStrategy::GetResults() const 59 | { 60 | return inference_engine_->GetResults(); 61 | } 62 | 63 | InferenceEngineType InferenceEngineStrategy::GetInferenceEngineType() const 64 | { 65 | return inference_engine_type_; 66 | } 67 | } // namespace spleeter 68 | 
-------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar") 2 | 3 | package(default_visibility = ["//visibility:public"]) 4 | 5 | pkg_tar( 6 | name = "spleeter-bin", 7 | srcs = [ 8 | "//application:spleeter", 9 | ], 10 | mode = "0647", 11 | package_dir = "bin", 12 | tags = ["manual"], 13 | ) 14 | 15 | pkg_tar( 16 | name = "spleeter-include", 17 | srcs = [], 18 | mode = "0644", 19 | package_dir = "include", 20 | tags = ["manual"], 21 | deps = [ 22 | "//spleeter:includes", 23 | "//spleeter/argument_parser:includes", 24 | "//spleeter/audio:includes", 25 | "//spleeter/datatypes:includes", 26 | "//spleeter/inference_engine:includes", 27 | "//spleeter/logging:includes", 28 | ], 29 | ) 30 | 31 | pkg_tar( 32 | name = "spleeter-lib", 33 | srcs = [ 34 | "//spleeter", 35 | "//spleeter/argument_parser", 36 | "//spleeter/audio", 37 | "//spleeter/datatypes", 38 | "//spleeter/inference_engine", 39 | "//spleeter/logging", 40 | ], 41 | mode = "0644", 42 | package_dir = "lib", 43 | tags = ["manual"], 44 | ) 45 | 46 | pkg_tar( 47 | name = "spleeter-data", 48 | srcs = [ 49 | "@audio_example//file", 50 | ], 51 | mode = "0644", 52 | package_dir = "data", 53 | tags = ["manual"], 54 | ) 55 | 56 | pkg_tar( 57 | name = "spleeter-model", 58 | srcs = [ 59 | "@model//:5stems", 60 | ], 61 | mode = "0644", 62 | package_dir = "model", 63 | tags = ["manual"], 64 | ) 65 | 66 | filegroup( 67 | name = "spleeter-third_party_files", 68 | srcs = glob(["third_party/**/*"]), 69 | ) 70 | 71 | pkg_tar( 72 | name = "spleeter-third_party", 73 | srcs = [ 74 | ":spleeter-third_party_files", 75 | ], 76 | mode = "0644", 77 | package_dir = "third_party", 78 | tags = ["manual"], 79 | ) 80 | 81 | pkg_tar( 82 | name = "spleeter-dev", 83 | testonly = True, 84 | extension = "tar.gz", 85 | package_dir = "/spleeter", 86 | strip_prefix = "/", 87 | 
tags = ["manual"], 88 | deps = [ 89 | ":spleeter-bin", 90 | ":spleeter-data", 91 | ":spleeter-include", 92 | ":spleeter-lib", 93 | ":spleeter-third_party", 94 | ], 95 | ) 96 | -------------------------------------------------------------------------------- /spleeter/audio/i_audio_adapter.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains interface for Audio Adapter 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_AUDIO_I_AUDIO_ADAPTER_H 7 | #define SPLEETER_AUDIO_I_AUDIO_ADAPTER_H 8 | 9 | #include "spleeter/datatypes/audio_properties.h" 10 | #include "spleeter/datatypes/waveform.h" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace spleeter 18 | { 19 | /// @brief Audio Adapter to read/write audio files 20 | class IAudioAdapter 21 | { 22 | public: 23 | /// @brief Destructor. 24 | virtual ~IAudioAdapter() = default; 25 | 26 | /// @brief Loads the audio file denoted by the given path and returns its data as a waveform. 27 | /// 28 | /// @param path [in] - Path of the audio file to load data from. 29 | /// @param offset [in] - Start offset to load from in seconds. 30 | /// @param duration [in] - Duration to load in seconds. 31 | /// @param sample_rate [in] - Sample rate to load audio with. 32 | /// 33 | /// @returns Loaded data as waveform 34 | virtual Waveform Load(const std::string& path, 35 | const double offset, 36 | const double duration, 37 | const std::int32_t sample_rate) = 0; 38 | 39 | /// @brief Write waveform data to the file denoted by the given path using FFMPEG process. 40 | /// 41 | /// @param path [in] - Path of the audio file to save data in. 42 | /// @param data [in] - Waveform data to write. 43 | /// @param sample_rate [in] - Sample rate to write file in. 44 | /// @param codec [in] - Writing codec to use. 45 | /// @param bitrate [in] - Bitrate of the written audio file. 
46 | virtual void Save(const std::string& path, 47 | const Waveform& data, 48 | const std::int32_t sample_rate, 49 | const std::string& codec, 50 | const std::int32_t bitrate) = 0; 51 | 52 | /// @brief Provide properties of the Waveform (nb_frames, nb_channels, sample_rate) 53 | /// 54 | /// @return audio properties 55 | virtual AudioProperties GetProperties() const = 0; 56 | }; 57 | } // namespace spleeter 58 | 59 | #endif /// SPLEETER_AUDIO_I_AUDIO_ADAPTER_H 60 | -------------------------------------------------------------------------------- /spleeter/audio/audionamix_audio_adapter.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains declaration of the Audionamix based Audio Adapter 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_AUDIO_AUDIONAMIX_AUDIO_ADAPTER_H 7 | #define SPLEETER_AUDIO_AUDIONAMIX_AUDIO_ADAPTER_H 8 | 9 | #include "spleeter/audio/i_audio_adapter.h" 10 | #include "spleeter/datatypes/audio_properties.h" 11 | 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | namespace spleeter 18 | { 19 | /// @brief Audio Adapter to read/write audio files based on Audionamix lib 20 | class AudionamixAudioAdapter final : public IAudioAdapter 21 | { 22 | public: 23 | /// @brief Constructor. 24 | AudionamixAudioAdapter(); 25 | 26 | /// @brief Loads the audio file denoted by the given path and returns its data as a waveform. 27 | /// 28 | /// @param path [in] - Path of the audio file to load data from. 29 | /// @param offset [in] - Start offset to load from in seconds. 30 | /// @param duration [in] - Duration to load in seconds. 31 | /// @param sample_rate [in] - Sample rate to load audio with. 32 | /// 33 | /// @returns Loaded data as waveform 
34 | Waveform Load(const std::string& path, 35 | const double offset, 36 | const double duration, 37 | const std::int32_t sample_rate) override; 38 | 39 | /// @brief Write waveform data to the file denoted by the given path using the Audionamix wave library. 40 | /// 41 | /// @param path [in] - Path of the audio file to save data in. 42 | /// @param waveform [in] - Waveform data to write. 43 | /// @param sample_rate [in] - Sample rate to write file in. 44 | /// @param codec [in] - Writing codec to use. 45 | /// @param bitrate [in] - Bitrate of the written audio file. 46 | void Save(const std::string& path, 47 | const Waveform& waveform, 48 | const std::int32_t sample_rate, 49 | const std::string& codec, 50 | const std::int32_t bitrate) override; 51 | 52 | /// @brief Provide properties of the Waveform (nb_frames, nb_channels, sample_rate) 53 | /// 54 | /// @return audio properties 55 | AudioProperties GetProperties() const override; 56 | 57 | private: 58 | /// @brief Loaded Audio Properties 59 | AudioProperties audio_properties_; 60 | }; 61 | } // namespace spleeter 62 | 63 | #endif /// SPLEETER_AUDIO_AUDIONAMIX_AUDIO_ADAPTER_H 64 | -------------------------------------------------------------------------------- /spleeter/audio/audionamix_audio_adapter.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains definitions for Audionamix Audio Adapter class methods 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #include "spleeter/audio/audionamix_audio_adapter.h" 7 | 8 | #include "spleeter/logging/logging.h" 9 | 10 | namespace spleeter 11 | { 12 | AudionamixAudioAdapter::AudionamixAudioAdapter() {} 13 | 14 | Waveform AudionamixAudioAdapter::Load(const std::string& path, 15 | const double /*offset*/, 16 | const double /*duration*/, 17 | const std::int32_t /*sample_rate*/) 18 | { 19 | auto file = wave::File{}; 20 | 21 | auto ret = file.Open(path, wave::OpenMode::kIn); 22 | CHECK(!ret) << 
"Unable to open " << path << ", only *.wav is supported! (Returned: " << ret << ")"; 23 | 24 | Waveform waveform{}; 25 | ret = file.Read(&waveform.data); 26 | CHECK(!ret) << "Unable to read " << path << ", only *.wav is supported! (Returned: " << ret << ")"; 27 | 28 | /// Save loaded audio properties 29 | audio_properties_.nb_frames = file.frame_number(); 30 | audio_properties_.nb_channels = file.channel_number(); 31 | audio_properties_.sample_rate = file.sample_rate(); 32 | waveform.nb_channels = audio_properties_.nb_channels; 33 | waveform.nb_frames = audio_properties_.nb_frames; 34 | 35 | LOG(INFO) << "Loaded waveform from " << path << " using Audionamix."; 36 | return waveform; 37 | } 38 | 39 | void AudionamixAudioAdapter::Save(const std::string& path, 40 | const Waveform& waveform, 41 | const std::int32_t sample_rate, 42 | const std::string& /*codec*/, 43 | const std::int32_t /*bitrate*/) 44 | { 45 | auto file = wave::File{}; 46 | 47 | auto ret = file.Open(path, wave::OpenMode::kOut); 48 | CHECK(!ret) << "Unable to open " << path << "! (Returned: " << ret << ")"; 49 | 50 | /// Set user provided configurations 51 | if (sample_rate != -1) 52 | { 53 | file.set_sample_rate(sample_rate); 54 | } 55 | file.set_channel_number(2); 56 | 57 | ret = file.Write(waveform.data); 58 | CHECK(!ret) << "Unable to write to " << path << "! (Returned: " << ret << ")"; 59 | 60 | LOG(INFO) << "Saved waveform to " << path << " using Audionamix."; 61 | } 62 | 63 | AudioProperties AudionamixAudioAdapter::GetProperties() const 64 | { 65 | return audio_properties_; 66 | } 67 | 68 | } // namespace spleeter 69 | -------------------------------------------------------------------------------- /spleeter/inference_engine/tf_inference_engine.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. 
MIT License 4 | /// 5 | #ifndef SPLEETER_INFERENCE_ENGINE_TF_INFERENCE_ENGINE_H 6 | #define SPLEETER_INFERENCE_ENGINE_TF_INFERENCE_ENGINE_H 7 | 8 | #include "spleeter/datatypes/inference_engine.h" 9 | #include "spleeter/inference_engine/i_inference_engine.h" 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | namespace spleeter 20 | { 21 | /// @brief TensorFlow Inference Engine class 22 | class TFInferenceEngine final : public IInferenceEngine 23 | { 24 | public: 25 | /// @brief Constructor 26 | /// @param params [in] Inference Engine Parameters such as model input/output node names 27 | explicit TFInferenceEngine(const InferenceEngineParameters& params); 28 | 29 | /// @brief Initialise TensorFlow Inference Engine 30 | void Init() override; 31 | 32 | /// @brief Execute Inference with TensorFlow Inference Engine 33 | /// @param waveform [in] - Waveform to be split 34 | void Execute(const Waveform& waveform) override; 35 | 36 | /// @brief Release TensorFlow Inference Engine 37 | void Shutdown() override; 38 | 39 | /// @brief Obtain Results for provided input waveform 40 | /// @return List of waveforms (split waveforms) 41 | Waveforms GetResults() const override; 42 | 43 | private: 44 | /// @brief Updates Input Tensor by copying waveform to input_tensor 45 | /// @param waveform [in] - Waveform to be split 46 | void UpdateInput(const Waveform& waveform); 47 | 48 | /// @brief Updates Output Tensors by running the tensorflow session 49 | void UpdateTensors(); 50 | 51 | /// @brief Converts output_tensors to Waveforms results 52 | void UpdateOutputs(); 53 | 54 | /// @brief Saved Model bundle 55 | std::shared_ptr bundle_; 56 | 57 | /// @brief Input Tensor 58 | tensorflow::Tensor input_tensor_; 59 | 60 | /// @brief Input Tensor name 61 | const std::string input_tensor_name_; 62 | 63 | /// @brief Output Tensors 64 | std::vector output_tensors_; 65 | 66 | /// @brief Output Tensors names 67 | const std::vector output_tensor_names_; 68 | 
69 | /// @brief Model root directory 70 | const std::string model_path_; 71 | 72 | /// @brief Output Tensors saved as Waveforms 73 | Waveforms results_; 74 | }; 75 | 76 | } // namespace spleeter 77 | 78 | #endif /// SPLEETER_INFERENCE_ENGINE_TF_INFERENCE_ENGINE_H 79 | -------------------------------------------------------------------------------- /spleeter/argument_parser/cli_options.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains command line interface options definitions 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_ARGUMENT_PARSER_CLI_OPTIONS_H 7 | #define SPLEETER_ARGUMENT_PARSER_CLI_OPTIONS_H 8 | 9 | #include "spleeter/datatypes/inference_engine.h" 10 | 11 | #include 12 | 13 | namespace spleeter 14 | { 15 | /// @brief Contains Command Line Interface (CLI) Options 16 | struct CLIOptions 17 | { 18 | /// @brief List of input audio filenames 19 | std::string inputs{"external/audio_example/file/audio_example.wav"}; 20 | 21 | /// @brief Path of the output directory to write audio files in 22 | std::string output_path{"separated_audio"}; 23 | 24 | /// @brief Template string that will be formatted to generate the 25 | /// output filename. 
Such a template should be a Python formattable 26 | /// string, and could use {filename}, {instrument}, and {codec} 27 | /// variables 28 | std::string filename_format{"{filename}/{instrument}.{codec}"}; 29 | 30 | /// @brief JSON filename that contains params 31 | /// choices: { "spleeter:5stems", "spleeter:4stems", "spleeter:2stems" } 32 | std::string configuration{"spleeter:5stems"}; 33 | 34 | /// @brief Inference Engine Parameters (contains model_path, input/output tensor names) 35 | InferenceEngineParameters inference_engine_params{ 36 | "external/models/5stems/5stems.tflite", 37 | "waveform", 38 | {"strided_slice_18", "strided_slice_38", "strided_slice_48", "strided_slice_28", "strided_slice_58"}, 39 | "spleeter:5stems"}; 40 | 41 | /// @brief Set the starting offset to separate audio from 42 | double offset{0.0}; 43 | 44 | /// @brief Set a maximum duration for the processing audio 45 | /// (only separate offset + duration first seconds of the input file) 46 | double duration{600.0}; 47 | 48 | /// @brief Audio codec to be used for separate output 49 | /// @todo Currently supports only *.wav, extend support to work with other audio codecs 50 | /// choices: { wav, mp3, ogg, m4a, wma, flac } 51 | std::string codec{"wav"}; 52 | 53 | /// @brief Audio bitrate to be used for separate output 54 | std::int32_t bitrate{128000}; 55 | 56 | /// @brief Whether to use multichannel Wiener filtering for separation 57 | bool mwf{false}; 58 | 59 | /// @brief Path to folder with musDB 60 | std::string mus_dir{}; 61 | 62 | /// @brief Path of the folder containing audio data for training 63 | std::string audio_path{}; 64 | 65 | /// @brief Name of the audio adapter to use for audio I/O 66 | std::string audio_adapter{"audionamix"}; 67 | 68 | /// @brief Shows verbose logs 69 | bool verbose{false}; 70 | }; 71 | 72 | } // namespace spleeter 73 | #endif /// SPLEETER_ARGUMENT_PARSER_CLI_OPTIONS_H 74 | -------------------------------------------------------------------------------- 
/spleeter/audio/ffmpeg_audio_adapter.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains implementation of FFMPEG based Audio Adapter 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_AUDIO_FFMPEG_AUDIO_ADAPTER_H 7 | #define SPLEETER_AUDIO_FFMPEG_AUDIO_ADAPTER_H 8 | 9 | #include "spleeter/audio/i_audio_adapter.h" 10 | #include "spleeter/datatypes/audio_properties.h" 11 | 12 | #ifndef __cplusplus__ 13 | extern "C" 14 | { 15 | #endif 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #ifndef __cplusplus__ 24 | } 25 | #endif 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | namespace spleeter 33 | { 34 | /// @brief An AudioAdapter implementation that use FFMPEG libraries to perform I/O operation for audio processing. 35 | class FfmpegAudioAdapter final : public IAudioAdapter 36 | { 37 | public: 38 | /// @brief Constructor. 39 | FfmpegAudioAdapter(); 40 | 41 | /// @brief Loads the audio file denoted by the given path and returns it data as a waveform. 42 | /// 43 | /// @param path [in] - Path of the audio file to load data from. 44 | /// @param offset [in] - Start offset to load from in seconds. 45 | /// @param duration [in] - Duration to load in seconds. 46 | /// @param sample_rate [in] - Sample rate to load audio with. 47 | /// 48 | /// @returns Loaded data as waveform 49 | Waveform Load(const std::string& path, 50 | const double offset, 51 | const double duration, 52 | const std::int32_t sample_rate) override; 53 | 54 | /// @brief Write waveform data to the file denoted by the given path using FFMPEG process. 55 | /// 56 | /// @param path [in] - Path of the audio file to save data in. 57 | /// @param waveform [in] - Waveform data to write. 58 | /// @param sample_rate [in] - Sample rate to write file in. 59 | /// @param codec [in] - Writing codec to use. 
60 | /// @param bitrate [in] - Bitrate of the written audio file. 61 | void Save(const std::string& path, 62 | const Waveform& waveform, 63 | const std::int32_t sample_rate, 64 | const std::string& codec, 65 | const std::int32_t bitrate) override; 66 | 67 | /// @brief Provide properties of the Waveform (nb_frames, nb_channels, sample_rate) 68 | /// 69 | /// @return audio properties 70 | AudioProperties GetProperties() const override; 71 | 72 | private: 73 | /// @brief Loaded Audio Properties 74 | AudioProperties audio_properties_; 75 | }; 76 | } // namespace spleeter 77 | 78 | #endif /// SPLEETER_AUDIO_FFMPEG_AUDIO_ADAPTER_H 79 | -------------------------------------------------------------------------------- /tools/tflite_converter/export_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | 4 | trained_checkpoint_prefix = '/workspace/5stems/model' 5 | export_dir = os.path.join('export_dir', '0') 6 | 7 | graph = tf.Graph() 8 | with tf.compat.v1.Session(graph=graph) as sess: 9 | # Restore from checkpoint 10 | loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta') 11 | loader.restore(sess, trained_checkpoint_prefix) 12 | 13 | waveform_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name("waveform:0") 14 | waveform_tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(waveform_tensor) 15 | 16 | bass_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name("strided_slice_48:0") 17 | bass_tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(bass_tensor) 18 | 19 | drums_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name("strided_slice_38:0") 20 | drums_tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(drums_tensor) 21 | 22 | other_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name("strided_slice_58:0") 23 | other_tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(other_tensor) 24 | 25 | 
piano_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name("strided_slice_28:0") 26 | piano_tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(piano_tensor) 27 | 28 | vocals_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name("strided_slice_18:0") 29 | vocals_tensor_info = tf.compat.v1.saved_model.utils.build_tensor_info(vocals_tensor) 30 | 31 | current_graph = tf.compat.v1.get_default_graph() 32 | 33 | print(waveform_tensor_info) 34 | print(waveform_tensor) 35 | 36 | separate_signature = ( 37 | tf.compat.v1.saved_model.signature_def_utils.build_signature_def( 38 | inputs={ 'waveform': waveform_tensor_info }, 39 | outputs={ 'bass': bass_tensor_info, 40 | 'drums': drums_tensor_info, 41 | 'other': other_tensor_info, 42 | 'vocals': vocals_tensor_info, 43 | 'piano': piano_tensor_info }, 44 | method_name=tf.compat.v1.saved_model.signature_constants 45 | .PREDICT_METHOD_NAME)) 46 | 47 | # Export checkpoint to SavedModel 48 | builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) 49 | builder.add_meta_graph_and_variables(sess, [tf.compat.v1.saved_model.tag_constants.SERVING], 50 | signature_def_map={ 51 | tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 52 | separate_signature, 53 | }, 54 | strip_default_attrs=True) 55 | builder.save() 56 | -------------------------------------------------------------------------------- /spleeter/i_separator.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains interface definitions for Separator. 
4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_I_SEPARATOR_H 7 | #define SPLEETER_I_SEPARATOR_H 8 | 9 | #include "spleeter/datatypes/audio_properties.h" 10 | #include "spleeter/datatypes/waveform.h" 11 | 12 | #include 13 | #include 14 | 15 | namespace spleeter 16 | { 17 | /// @brief Interface for Separator APIs 18 | class ISeparator 19 | { 20 | public: 21 | /// @brief Destructor 22 | virtual ~ISeparator() = default; 23 | 24 | /// @brief Performs source separation over the given waveform. 25 | /// 26 | /// The separation is performed synchronously but the result processing is done asynchronously, allowing for 27 | /// instance to export audio in parallel (through multiprocessing). 28 | /// 29 | /// The result is passed to the given consumer, which is waited on until the task finishes if the 30 | /// synchronous flag is True. 31 | /// 32 | /// @param waveform [in] Waveform to apply separation on. 33 | /// 34 | /// @returns Separated waveforms 35 | virtual Waveforms Separate(const Waveform& waveform) = 0; 36 | 37 | /// @brief Performs source separation and exports the result to file using the given audio adapter. 38 | /// 39 | /// Filename format should be a Python formattable string that could use the following parameters: 40 | /// {instrument}, {filename} and {codec}. 41 | /// 42 | /// @param audio_descriptor [in] - Describes the song to separate, used by audio adapter to retrieve and load audio data, 43 | /// in case of file based audio adapter, such descriptor would be a file path. 44 | /// @param destination [in] - Target directory to write output to. 45 | /// @param audio_adapter [in] - Audio adapter to use for I/O. 46 | /// @param offset [in] - Offset of loaded song. 47 | /// @param duration [in] - Duration of loaded song. 48 | /// @param codec [in] - Export codec. 49 | /// @param bitrate [in] - Export bitrate. 50 | /// @param filename_format [in] - Filename format. 51 | /// @param synchronous [in] - True if it should be synchronous. 
52 | virtual void SeparateToFile(const std::string& audio_descriptor, 53 | const std::string& destination, 54 | const std::string& audio_adapter, 55 | const double offset, 56 | const double duration, 57 | const std::string& codec, 58 | const std::int32_t bitrate, 59 | const std::string& filename_format, 60 | const bool synchronous) = 0; 61 | }; 62 | } // namespace spleeter 63 | 64 | #endif /// SPLEETER_I_SEPARATOR_H 65 | -------------------------------------------------------------------------------- /spleeter/argument_parser/test/argument_parser_tests.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains unit tests for Argument Parser APIs 4 | /// @copyright Copyright (c) 2020. All Rights Reserved. 5 | /// 6 | #include "spleeter/argument_parser/argument_parser.h" 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | namespace spleeter 14 | { 15 | namespace 16 | { 17 | TEST(ArgumentParserTest, DefaultConstructor) 18 | { 19 | auto unit = ArgumentParser(); 20 | auto actual = unit.GetParsedArgs(); 21 | 22 | EXPECT_EQ(actual.inputs, "external/audio_example/file/audio_example.wav"); 23 | EXPECT_EQ(actual.output_path, "separated_audio"); 24 | EXPECT_EQ(actual.filename_format, "{filename}/{instrument}.{codec}"); 25 | EXPECT_EQ(actual.configuration, "spleeter:5stems"); 26 | EXPECT_DOUBLE_EQ(actual.offset, 0.0); 27 | EXPECT_DOUBLE_EQ(actual.duration, 600.0); 28 | EXPECT_EQ(actual.codec, "wav"); 29 | EXPECT_EQ(actual.bitrate, 128000); 30 | EXPECT_FALSE(actual.mwf); 31 | EXPECT_TRUE(actual.mus_dir.empty()); 32 | EXPECT_TRUE(actual.audio_path.empty()); 33 | EXPECT_EQ(actual.audio_adapter, "audionamix"); 34 | EXPECT_FALSE(actual.verbose); 35 | } 36 | TEST(ArgumentParserTest, WhenHelpArgument) 37 | { 38 | char* argv[] = {"app", "-h"}; 39 | int argc = sizeof(argv) / sizeof(char*); 40 | EXPECT_EXIT(ArgumentParser(argc, argv), ::testing::ExitedWithCode(1), ""); 41 | } 42 | 43 | 
TEST(ArgumentParserTest, ParameterizedConstructor) 44 | { 45 | // clang-format off 46 | char* argv[] = {"app", 47 | "-i", "sample.mp3", 48 | "-o", "audio_out", 49 | "-f", "{filename}_{instrument}.{codec}", 50 | "--params_filename", "spleeter:5stems", 51 | "-s", "1.0", 52 | "-d", "20", 53 | "-c", "mp3", 54 | "-b", "320k", 55 | "-m", "1", 56 | "-v", "1", 57 | "-u", "my_mus_dir", 58 | "-t", "data.file", 59 | "-a", "adapter"}; 60 | // clang-format on 61 | int argc = sizeof(argv) / sizeof(char*); 62 | auto unit = ArgumentParser(argc, argv); 63 | auto actual = unit.GetParsedArgs(); 64 | 65 | EXPECT_EQ(actual.inputs, "sample.mp3"); 66 | EXPECT_EQ(actual.output_path, "audio_out"); 67 | EXPECT_EQ(actual.filename_format, "{filename}_{instrument}.{codec}"); 68 | EXPECT_EQ(actual.configuration, "spleeter:5stems"); 69 | EXPECT_DOUBLE_EQ(actual.offset, 1.0); 70 | EXPECT_DOUBLE_EQ(actual.duration, 20.0); 71 | EXPECT_EQ(actual.codec, "mp3"); 72 | EXPECT_EQ(actual.bitrate, 320000); 73 | EXPECT_TRUE(actual.mwf); 74 | EXPECT_EQ(actual.mus_dir, "my_mus_dir"); 75 | EXPECT_EQ(actual.audio_path, "data.file"); 76 | EXPECT_EQ(actual.audio_adapter, "adapter"); 77 | EXPECT_TRUE(actual.verbose); 78 | } 79 | } // namespace 80 | } // namespace spleeter 81 | -------------------------------------------------------------------------------- /spleeter/separator.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains definitions for Separator class methods 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #include "spleeter/separator.h" 7 | 8 | #include "spleeter/audio/audionamix_audio_adapter.h" 9 | #include "spleeter/audio/ffmpeg_audio_adapter.h" 10 | #include "spleeter/logging/logging.h" 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace spleeter 19 | { 20 | namespace internal 21 | { 22 | std::vector GetWaveformNames(const std::string& configuration) 23 | { 24 | auto 
waveform_names = std::vector{}; 25 | if (configuration == "spleeter:2stems") 26 | { 27 | waveform_names = std::vector{"vocal", "accompaniment"}; 28 | } 29 | 30 | else if (configuration == "spleeter:4stems") 31 | { 32 | waveform_names = std::vector{"vocal", "drums", "bass", "accompaniment"}; 33 | } 34 | else // default to "spleeter:5stems" 35 | { 36 | waveform_names = std::vector{"vocal", "drums", "bass", "piano", "accompaniment"}; 37 | } 38 | return waveform_names; 39 | } 40 | } // namespace internal 41 | 42 | Separator::Separator(const InferenceEngineParameters& inference_engine_param, const bool mwf) 43 | : mwf_{mwf}, 44 | audio_adapter_{std::make_unique()}, 45 | inference_engine_{}, 46 | waveform_name_{internal::GetWaveformNames(inference_engine_param.configuration)} 47 | { 48 | inference_engine_.SelectInferenceEngine(InferenceEngineType::kTensorFlowLite, inference_engine_param); 49 | inference_engine_.Init(); 50 | } 51 | 52 | Waveforms Separator::Separate(const Waveform& waveform) 53 | { 54 | inference_engine_.Execute(waveform); 55 | return inference_engine_.GetResults(); 56 | } 57 | 58 | void Separator::SeparateToFile(const std::string& audio_descriptor, 59 | const std::string& destination, 60 | const std::string& /*audio_adapter*/, 61 | const double offset, 62 | const double duration, 63 | const std::string& codec, 64 | const std::int32_t bitrate, 65 | const std::string& /*filename_format*/, 66 | const bool /*synchronous*/) 67 | { 68 | const auto waveform = audio_adapter_->Load(audio_descriptor, offset, duration, -1); 69 | 70 | auto waveforms = Separate(waveform); 71 | 72 | for (auto idx = 0U; idx < waveforms.size(); ++idx) 73 | { 74 | if (!std::experimental::filesystem::exists(destination)) 75 | { 76 | std::experimental::filesystem::create_directory(destination); 77 | } 78 | const auto path = destination + "/" + waveform_name_[idx] + "." 
+ codec; 79 | audio_adapter_->Save(path, waveforms[idx], -1, codec, bitrate); 80 | } 81 | } 82 | 83 | } // namespace spleeter 84 | -------------------------------------------------------------------------------- /spleeter/inference_engine/tflite_inference_engine.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. MIT License 4 | /// 5 | #ifndef SPLEETER_INFERENCE_ENGINE_TFLITE_INFERENCE_ENGINE_H 6 | #define SPLEETER_INFERENCE_ENGINE_TFLITE_INFERENCE_ENGINE_H 7 | 8 | #include "spleeter/datatypes/inference_engine.h" 9 | #include "spleeter/inference_engine/i_inference_engine.h" 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | namespace spleeter 17 | { 18 | /// @brief TFLite Inference Engine class 19 | class TFLiteInferenceEngine final : public IInferenceEngine 20 | { 21 | public: 22 | /// @brief Constructor 23 | explicit TFLiteInferenceEngine(const InferenceEngineParameters& params); 24 | 25 | /// @brief Initialise TFLite Inference Engine 26 | void Init() override; 27 | 28 | /// @brief Execute Inference with TFLite Inference Engine 29 | /// @param waveform [in] - Waveform to be split 30 | void Execute(const Waveform& waveform) override; 31 | 32 | /// @brief Release TFLite Inference Engine 33 | void Shutdown() override; 34 | 35 | /// @brief Obtain Results for provided input waveform 36 | /// 37 | /// @return List of waveforms (split waveforms) 38 | Waveforms GetResults() const override; 39 | 40 | private: 41 | /// @brief Updates Input Tensor by copying waveform to input_tensor 42 | /// @param waveform [in] - Waveform to be split 43 | void UpdateInput(const Waveform& waveform); 44 | 45 | /// @brief Updates Output Tensors by running the tensorflow session 46 | void UpdateTensors(); 47 | 48 | /// @brief Converts output_tensors to Waveforms results 49 | void UpdateOutputs(); 50 | 51 | /// @brief Resizes Input Tensor based on the provided waveform shape 52 | void 
ResizeInputTensor(const Waveform& waveform); 53 | 54 | /// @brief Resizes Output Tensor based on the provided waveform shape 55 | void ResizeOutputTensor(const Waveform& waveform); 56 | 57 | /// @brief Resizes Tensor at given index to the provided dimensions 58 | /// @param tensor_index[in] Tensor Index to resize 59 | /// @param dims[in] Requested dimensions for the Tensor 60 | void ResizeTensor(const std::int32_t tensor_index, const std::vector dims); 61 | 62 | /// @brief TFLite Model Buffer Instance 63 | std::unique_ptr model_; 64 | 65 | /// @brief TFLite Model Interpreter instance 66 | std::unique_ptr interpreter_; 67 | 68 | /// @brief Input Tensor name 69 | const std::string input_tensor_name_; 70 | 71 | /// @brief Output Tensors names 72 | const std::vector output_tensor_names_; 73 | 74 | /// @brief TFLite Inference output tensor indices 75 | std::vector output_tensor_indicies_; 76 | 77 | /// @brief Model root directory 78 | const std::string model_path_; 79 | 80 | /// @brief Output Tensors saved as Waveforms 81 | Waveforms results_; 82 | }; 83 | 84 | } // namespace spleeter 85 | #endif /// SPLEETER_INFERENCE_ENGINE_TFLITE_INFERENCE_ENGINE_H 86 | -------------------------------------------------------------------------------- /spleeter/test/separator_tests.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020, MIT License 4 | /// 5 | #include "spleeter/argument_parser/cli_options.h" 6 | #include "spleeter/audio/audionamix_audio_adapter.h" 7 | #include "spleeter/i_separator.h" 8 | #include "spleeter/separator.h" 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | namespace spleeter 16 | { 17 | namespace 18 | { 19 | 20 | class SeparatorTest : public ::testing::TestWithParam 21 | { 22 | public: 23 | SeparatorTest() : cli_options_{}, test_waveform_{}, audio_adapter_{} {} 24 | 25 | protected: 26 | void SetUp() override 27 | { 28 | cli_options_.output_path = 
"/tmp/separated_audio2"; 29 | const auto stem{std::to_string(GetParam()) + "stems"}; 30 | cli_options_.configuration = "spleeter:" + stem; 31 | cli_options_.inference_engine_params.model_path = "external/models/" + stem + "/" + stem + ".tflite"; 32 | cli_options_.inference_engine_params.output_tensor_names = GetOutputTensorNames(cli_options_.configuration); 33 | cli_options_.inference_engine_params.configuration = cli_options_.configuration; 34 | 35 | test_waveform_ = audio_adapter_.Load(cli_options_.inputs, 0, -1, 44100); 36 | 37 | unit_ = std::make_unique(cli_options_.inference_engine_params, cli_options_.mwf); 38 | } 39 | 40 | static std::vector GetOutputTensorNames(const std::string& configuration) 41 | { 42 | auto output_tensor_names = std::vector{}; 43 | if (configuration == "spleeter:2stems") 44 | { 45 | output_tensor_names = std::vector{"strided_slice_13", "strided_slice_23"}; 46 | } 47 | 48 | else if (configuration == "spleeter:4stems") 49 | { 50 | output_tensor_names = std::vector{ 51 | "strided_slice_13", "strided_slice_23", "strided_slice_33", "strided_slice_43"}; 52 | } 53 | else // default to "spleeter:5stems" 54 | { 55 | output_tensor_names = std::vector{ 56 | "strided_slice_18", "strided_slice_38", "strided_slice_48", "strided_slice_28", "strided_slice_58"}; 57 | } 58 | return output_tensor_names; 59 | } 60 | 61 | CLIOptions cli_options_; 62 | std::unique_ptr unit_; 63 | Waveform test_waveform_; 64 | AudionamixAudioAdapter audio_adapter_; 65 | }; 66 | 67 | TEST_P(SeparatorTest, GivenConfiguration_ExpectSeparatedWaveforms) 68 | { 69 | auto actual = unit_->Separate(test_waveform_); 70 | EXPECT_EQ(GetParam(), actual.size()); 71 | } 72 | 73 | /// @test In this, test that SeparateToFile saves all the resultant 5 waveforms to file (for given 5stems model) 74 | TEST_P(SeparatorTest, GivenTypicalInputs_ExpectSeparatedToFiles) 75 | { 76 | unit_->SeparateToFile(cli_options_.inputs, 77 | cli_options_.output_path, 78 | cli_options_.audio_adapter, 79 | 
cli_options_.offset, 80 | cli_options_.duration, 81 | cli_options_.codec, 82 | cli_options_.bitrate, 83 | cli_options_.filename_format, 84 | false); 85 | 86 | EXPECT_TRUE(std::experimental::filesystem::exists(cli_options_.output_path)); 87 | } 88 | 89 | INSTANTIATE_TEST_CASE_P(Configurations, SeparatorTest, ::testing::Values(2, 4, 5)); 90 | 91 | } // namespace 92 | } // namespace spleeter 93 | -------------------------------------------------------------------------------- /spleeter/separator.h: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains class definition for Separator which implements class ISeparator. 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #ifndef SPLEETER_SEPARATOR_H 7 | #define SPLEETER_SEPARATOR_H 8 | 9 | #include "spleeter/argument_parser/cli_options.h" 10 | #include "spleeter/audio/i_audio_adapter.h" 11 | #include "spleeter/datatypes/inference_engine.h" 12 | #include "spleeter/i_separator.h" 13 | #include "spleeter/inference_engine/inference_engine_strategy.h" 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | namespace spleeter 21 | { 22 | /// @brief Perform separation 23 | class Separator final : public ISeparator 24 | { 25 | public: 26 | /// @brief Constructor 27 | /// 28 | /// @param inference_engine_param [in] - Parameters for inference 29 | /// @param mwf [in] - (Optional) True if MWF should be used, False otherwise. 30 | explicit Separator(const InferenceEngineParameters& inference_engine_param, const bool mwf = false); 31 | 32 | /// @brief Performs source separation over the given waveform. 33 | /// 34 | /// The separation is performed synchronously but the result processing is done asynchronously, allowing for 35 | /// instance to export audio in parallel (through multiprocessing). 36 | /// 37 | /// Given result is passed by to the given consumer, which will be waited for task finishing if 38 | /// synchronous flag is True. 
39 | /// 40 | /// @param waveform [in] Waveform to apply separation on. 41 | /// 42 | /// @returns Separated waveforms 43 | Waveforms Separate(const Waveform& waveform) override; 44 | 45 | /// @brief Performs source separation and exports the result to file using the given audio adapter. 46 | /// 47 | /// Filename format should be a Python formattable string that could use the following parameters: 48 | /// {instrument}, {filename} and {codec}. 49 | /// 50 | /// @param audio_descriptor [in] - Describes the song to separate; used by the audio adapter to retrieve and load audio data, 51 | /// in case of a file based audio adapter, such a descriptor would be a file path. 52 | /// @param destination [in] - Target directory to write output to. 53 | /// @param audio_adapter [in] - Audio adapter to use for I/O (no default value; unnamed and unused in the separator.cpp definition). 54 | /// @param offset [in] - (Optional) Offset of loaded song. 55 | /// @param duration [in] - (Optional) Duration of loaded song. 56 | /// @param codec [in] - (Optional) Export codec. 57 | /// @param bitrate [in] - (Optional) Export bitrate. 58 | /// @param filename_format [in] - (Optional) Filename format (unnamed and unused in the separator.cpp definition). 59 | /// @param synchronous [in] - (Optional) True if the call should be synchronous (unnamed and unused in the separator.cpp definition). 60 | void SeparateToFile(const std::string& audio_descriptor, 61 | const std::string& destination, 62 | const std::string& audio_adapter, 63 | const double offset = 0.0, 64 | const double duration = 600.0, 65 | const std::string& codec = "wav", 66 | const std::int32_t bitrate = 128000, 67 | const std::string& filename_format = "{filename}/{instrument}.{codec}", 68 | const bool synchronous = true) override; 69 | 70 | private: 71 | /// @brief Is mwf enabled? 
72 | bool mwf_; 73 | 74 | /// @brief Audio R/W Adapter 75 | std::unique_ptr audio_adapter_; 76 | 77 | /// @brief Inference Engine 78 | InferenceEngineStrategy inference_engine_; 79 | 80 | /// @brief List of waveform names (used to save files) 81 | std::vector waveform_name_; 82 | }; 83 | } // namespace spleeter 84 | 85 | #endif /// SPLEETER_SEPARATOR_H 86 | -------------------------------------------------------------------------------- /spleeter/audio/test/audio_adapter_tests.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020, MIT License 4 | /// 5 | #include "spleeter/audio/audionamix_audio_adapter.h" 6 | #include "spleeter/audio/ffmpeg_audio_adapter.h" 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace spleeter 16 | { 17 | namespace 18 | { 19 | template 20 | class AudioAdapterTest : public ::testing::Test 21 | { 22 | public: 23 | AudioAdapterTest() 24 | : test_sample_rate_{44100}, 25 | test_waveform_path_{"external/audio_example/file/audio_example.wav"}, 26 | test_waveform_{}, 27 | test_waveform_properties_{} 28 | { 29 | } 30 | 31 | Waveform GetTestWaveform() const { return test_waveform_; } 32 | 33 | AudioProperties GetTestWaveformProperties() const { return test_waveform_properties_; } 34 | 35 | protected: 36 | void SetUp() override 37 | { 38 | AudionamixAudioAdapter audio_adapter{}; 39 | test_waveform_ = audio_adapter.Load(test_waveform_path_, 0, -1, 44100); 40 | test_waveform_properties_ = audio_adapter.GetProperties(); 41 | 42 | unit_ = std::make_unique(); 43 | } 44 | 45 | const std::int32_t test_sample_rate_; 46 | const std::string test_waveform_path_; 47 | Waveform test_waveform_; 48 | AudioProperties test_waveform_properties_; 49 | 50 | std::unique_ptr unit_; 51 | }; 52 | TYPED_TEST_SUITE_P(AudioAdapterTest); 53 | 54 | /// @test In this, test that Load method returns raw waveform and updates audio properties 55 | 
TYPED_TEST_P(AudioAdapterTest, GivenAudioFile_ExpectRawWaveform) 56 | { 57 | // When 58 | const auto actual = this->unit_->Load(this->test_waveform_path_, 0.0, -1, this->test_sample_rate_); 59 | 60 | // Then 61 | EXPECT_FALSE(actual.data.empty()); 62 | EXPECT_EQ(959664U, actual.data.size()); 63 | 64 | const auto actual_properties = this->unit_->GetProperties(); 65 | const auto expected_properties = this->GetTestWaveformProperties(); 66 | EXPECT_EQ(expected_properties.nb_channels, actual_properties.nb_channels); 67 | EXPECT_EQ(expected_properties.nb_frames, actual_properties.nb_frames); 68 | EXPECT_EQ(expected_properties.sample_rate, actual_properties.sample_rate); 69 | } 70 | 71 | /// @test In this, test that Save method saves the file and validated against source audio properties. 72 | TYPED_TEST_P(AudioAdapterTest, GivenRawWaveform_ExpectSavedFileHasSameProperties) 73 | { 74 | // Given 75 | const std::int32_t bitrate{192000}; 76 | const auto test_codec = "wav"; 77 | const auto test_results = "test_sample.wav"; 78 | 79 | // When 80 | this->unit_->Save(test_results, this->GetTestWaveform(), this->test_sample_rate_, test_codec, bitrate); 81 | 82 | // Then 83 | AudionamixAudioAdapter audio_adapter{}; 84 | const auto actual = audio_adapter.Load(test_results, 0, -1, this->test_sample_rate_); 85 | EXPECT_FALSE(actual.data.empty()); 86 | EXPECT_EQ(959664U, actual.data.size()); 87 | 88 | const auto actual_properties = audio_adapter.GetProperties(); 89 | const auto expected_properties = this->GetTestWaveformProperties(); 90 | EXPECT_EQ(expected_properties.nb_channels, actual_properties.nb_channels); 91 | EXPECT_EQ(expected_properties.nb_frames, actual_properties.nb_frames); 92 | EXPECT_EQ(expected_properties.sample_rate, actual_properties.sample_rate); 93 | } 94 | 95 | REGISTER_TYPED_TEST_SUITE_P(AudioAdapterTest, 96 | GivenAudioFile_ExpectRawWaveform, 97 | GivenRawWaveform_ExpectSavedFileHasSameProperties); 98 | 99 | typedef ::testing::Types AudioAdapterTestTypes; 100 | 
INSTANTIATE_TYPED_TEST_SUITE_P(TypeTests, AudioAdapterTest, AudioAdapterTestTypes); 101 | 102 | } // namespace 103 | } // namespace spleeter 104 | -------------------------------------------------------------------------------- /spleeter/inference_engine/tf_inference_engine.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. All Rights Reserved. 4 | /// 5 | #include "spleeter/inference_engine/tf_inference_engine.h" 6 | 7 | #include "spleeter/logging/logging.h" 8 | 9 | #include 10 | 11 | namespace spleeter 12 | { 13 | namespace 14 | { 15 | /// @brief Converts tensorflow::Tensor to waveform 16 | /// 17 | /// @param tensor[in] tensorflow::Tensor in [NxHxWxC form] 18 | /// 19 | /// @return Equivalent waveform for given tensorflow::Tensor by copying contents. 20 | Waveform ConvertToWaveform(const tensorflow::Tensor& tensor) 21 | { 22 | tensorflow::Tensor tensor_matrix = tensor; 23 | const auto samples = tensor_matrix.dims() > 0 ? static_cast(tensor_matrix.dim_size(0)) : 1; 24 | const auto channels = tensor_matrix.dims() > 1 ? static_cast(tensor_matrix.dim_size(1)) : 1; 25 | const auto size = samples * channels; 26 | auto* tensor_ptr = tensor_matrix.flat().data(); 27 | Waveform waveform{}; 28 | waveform.nb_frames = samples; 29 | waveform.nb_channels = channels; 30 | std::copy(tensor_ptr, tensor_ptr + size, std::back_inserter(waveform.data)); 31 | return waveform; 32 | } 33 | 34 | /// @brief Converts waveform to tensorflow::Tensor 35 | /// 36 | /// @param waveform[in] waveform 37 | /// 38 | /// @return Equivalent tensorflow::Tensor for given waveform by copying the its contents to tensor. 
39 | tensorflow::Tensor ConvertToTensor(const Waveform& waveform) 40 | { 41 | tensorflow::Tensor tensor{tensorflow::DT_FLOAT, tensorflow::TensorShape{waveform.nb_frames, waveform.nb_channels}}; 42 | std::copy(waveform.data.begin(), waveform.data.end(), tensor.matrix().data()); 43 | return tensor; 44 | } 45 | } // namespace 46 | 47 | TFInferenceEngine::TFInferenceEngine(const InferenceEngineParameters& params) 48 | : bundle_{std::make_shared()}, 49 | input_tensor_{}, 50 | input_tensor_name_{params.input_tensor_name}, 51 | output_tensors_{}, 52 | output_tensor_names_{params.output_tensor_names}, 53 | model_path_{params.model_path}, 54 | results_{} 55 | { 56 | } 57 | 58 | void TFInferenceEngine::Init() 59 | { 60 | tensorflow::SessionOptions session_options{}; 61 | tensorflow::RunOptions run_options{}; 62 | std::unordered_set tags{"serve"}; 63 | 64 | const auto ret = tensorflow::LoadSavedModel(session_options, run_options, model_path_, tags, bundle_.get()); 65 | CHECK(ret.ok()) << "Failed to load saved model '" << model_path_ << "', (Message: " << ret.error_message() << ")"; 66 | 67 | LOG(INFO) << "Successfully loaded saved model from '" << model_path_ << "'."; 68 | } 69 | 70 | void TFInferenceEngine::Execute(const Waveform& waveform) 71 | { 72 | UpdateInput(waveform); 73 | UpdateTensors(); 74 | UpdateOutputs(); 75 | } 76 | 77 | void TFInferenceEngine::Shutdown() {} 78 | 79 | Waveforms TFInferenceEngine::GetResults() const 80 | { 81 | return results_; 82 | } 83 | 84 | void TFInferenceEngine::UpdateInput(const Waveform& waveform) 85 | { 86 | input_tensor_ = ConvertToTensor(waveform); 87 | } 88 | 89 | void TFInferenceEngine::UpdateTensors() 90 | { 91 | const std::vector> inputs{{input_tensor_name_, input_tensor_}}; 92 | const std::vector target_node_names{}; 93 | 94 | const auto ret = bundle_->GetSession()->Run(inputs, output_tensor_names_, target_node_names, &output_tensors_); 95 | CHECK(ret.ok()) << "Unable to run Session, (Message: " << ret.error_message() << ")"; 96 
| 97 | LOG(INFO) << "Successfully received results " << output_tensors_.size() << " outputs."; 98 | } 99 | 100 | void TFInferenceEngine::UpdateOutputs() 101 | { 102 | results_.clear(); /* discard waveforms of any previous Execute() so GetResults() only reports the latest run */ std::transform( 103 | output_tensors_.cbegin(), output_tensors_.cend(), std::back_inserter(results_), [](auto const& tensor) { 104 | return ConvertToWaveform(tensor); 105 | }); 106 | } 107 | 108 | } // namespace spleeter 109 | -------------------------------------------------------------------------------- /spleeter/inference_engine/test/inference_engine_tests.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. MIT License 4 | /// 5 | #include "spleeter/audio/audionamix_audio_adapter.h" 6 | #include "spleeter/datatypes/inference_engine.h" 7 | #include "spleeter/inference_engine/i_inference_engine.h" 8 | #include "spleeter/inference_engine/inference_engine_strategy.h" 9 | #include "spleeter/inference_engine/null_inference_engine.h" 10 | #include "spleeter/inference_engine/tf_inference_engine.h" 11 | #include "spleeter/inference_engine/tflite_inference_engine.h" 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | #include 19 | 20 | namespace spleeter 21 | { 22 | namespace 23 | { 24 | 25 | using ::testing::AllOf; 26 | using ::testing::Each; 27 | using ::testing::Field; 28 | using ::testing::Property; 29 | 30 | template 31 | InferenceEngineParameters GetInferenceEngineParameter() 32 | { 33 | return InferenceEngineParameters{}; 34 | } 35 | 36 | template <> 37 | InferenceEngineParameters GetInferenceEngineParameter() 38 | { 39 | return InferenceEngineParameters{ 40 | "external/models/5stems/saved_model", 41 | "waveform", 42 | {"strided_slice_18", "strided_slice_38", "strided_slice_48", "strided_slice_28", "strided_slice_58"}, 43 | "spleeter:5stems"}; 44 | } 45 | 46 | template <> 47 | InferenceEngineParameters GetInferenceEngineParameter() 48 | { 49 | return InferenceEngineParameters{ 50 | 
"external/models/5stems/5stems.tflite", 51 | "waveform", 52 | {"strided_slice_18", "strided_slice_38", "strided_slice_48", "strided_slice_28", "strided_slice_58"}, 53 | "spleeter:5stems"}; 54 | } 55 | 56 | template 57 | class InferenceEngineFixture_WithInferenceEngineType : public ::testing::Test 58 | { 59 | public: 60 | InferenceEngineFixture_WithInferenceEngineType() 61 | : test_waveform_path_{"external/audio_example/file/audio_example.wav"}, 62 | audio_adapter_{}, 63 | test_waveform_{audio_adapter_.Load(test_waveform_path_, 0, -1, 44100)}, 64 | inference_engine_parameters_{GetInferenceEngineParameter()}, 65 | unit_{std::make_unique(inference_engine_parameters_)} 66 | { 67 | } 68 | 69 | protected: 70 | void SetUp() override { unit_->Init(); } 71 | void RunOnce() { unit_->Execute(test_waveform_); } 72 | void TearDown() override { unit_->Shutdown(); } 73 | 74 | Waveforms GetInferenceResults() const { return unit_->GetResults(); } 75 | InferenceEngineParameters GetInferenceParameters() const { return inference_engine_parameters_; } 76 | 77 | std::int32_t GetFrames() const { return test_waveform_.nb_frames; } 78 | 79 | private: 80 | const std::string test_waveform_path_; 81 | AudionamixAudioAdapter audio_adapter_; 82 | const Waveform test_waveform_; 83 | const InferenceEngineParameters inference_engine_parameters_; 84 | InferenceEnginePtr unit_; 85 | }; 86 | TYPED_TEST_SUITE_P(InferenceEngineFixture_WithInferenceEngineType); 87 | 88 | TYPED_TEST_P(InferenceEngineFixture_WithInferenceEngineType, InferenceEngine_GivenTypicalInputs_ExpectInferenceResults) 89 | { 90 | // When 91 | this->RunOnce(); 92 | 93 | // Then 94 | const auto actual = this->GetInferenceResults(); 95 | EXPECT_EQ(this->GetInferenceParameters().output_tensor_names.size(), actual.size()); 96 | EXPECT_THAT(actual, 97 | Each(AllOf(Field(&Waveform::nb_channels, 2), 98 | Field(&Waveform::nb_frames, this->GetFrames()), 99 | Field(&Waveform::data, Property(&std::vector::size, this->GetFrames() * 2))))); 100 | 
} 101 | 102 | REGISTER_TYPED_TEST_SUITE_P(InferenceEngineFixture_WithInferenceEngineType, 103 | InferenceEngine_GivenTypicalInputs_ExpectInferenceResults); 104 | 105 | typedef ::testing::Types InferenceEngineTestTypes; 106 | INSTANTIATE_TYPED_TEST_SUITE_P(InferenceEngine, 107 | InferenceEngineFixture_WithInferenceEngineType, 108 | InferenceEngineTestTypes); 109 | 110 | class InferenceEngineStrategyTest : public ::testing::TestWithParam 111 | { 112 | public: 113 | InferenceEngineStrategyTest() : inference_engine_params_{}, unit_{} {} 114 | 115 | protected: 116 | InferenceEngineParameters inference_engine_params_; 117 | InferenceEngineStrategy unit_; 118 | }; 119 | TEST_P(InferenceEngineStrategyTest, GivenInferenceEngine_ExpectSelectedEngine) 120 | { 121 | unit_.SelectInferenceEngine(GetParam(), inference_engine_params_); 122 | 123 | EXPECT_EQ(GetParam(), unit_.GetInferenceEngineType()); 124 | } 125 | INSTANTIATE_TEST_CASE_P(InferenceEngine, 126 | InferenceEngineStrategyTest, 127 | ::testing::Values(InferenceEngineType::kTensorFlow, 128 | InferenceEngineType::kTensorFlowLite, 129 | InferenceEngineType::kInvalid)); 130 | } // namespace 131 | } // namespace spleeter 132 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | # ------------ 2 | # Docker Image 3 | # ------------ 4 | include: 5 | - template: Code-Quality.gitlab-ci.yml 6 | - template: SAST.gitlab-ci.yml 7 | 8 | # ------------ 9 | # Various Jobs 10 | # ------------ 11 | stages: 12 | - check 13 | - build 14 | - test 15 | - deploy 16 | 17 | # ------------ 18 | # Stage: check 19 | # ------------ 20 | clang-format: 21 | stage: check 22 | image: ubuntu:20.04 23 | before_script: 24 | - apt-get update && apt-get install -y clang-format 25 | script: 26 | - find . -regex '.*\.\(ino\|cpp\|hpp\|cc\|cxx\|h\)' -exec cat {} \; | diff -u <(find . 
-regex '.*\.\(ino\|cpp\|hpp\|cc\|cxx\|h\)' -exec clang-format -style=file {} -verbose \;) - 27 | 28 | buildifier: 29 | stage: check 30 | image: ubuntu:20.04 31 | before_script: 32 | - apt-get update && apt-get install -y wget 33 | - wget https://github.com/bazelbuild/buildtools/releases/download/3.5.0/buildifier 34 | - chmod +x buildifier 35 | - mv buildifier /usr/bin 36 | script: 37 | - buildifier -v -d -r . 38 | 39 | cppcheck: 40 | stage: check 41 | image: ubuntu:20.04 42 | before_script: 43 | - apt-get update && apt-get install -y cppcheck python python-pygments 44 | script: 45 | - cppcheck --template=gcc --enable=all --inconclusive --std=c++14 -I spleeter/**/*.h spleeter/**/*.cpp > static_code_analysis.log 46 | - cppcheck --template=gcc --enable=all --inconclusive --std=c++14 -I spleeter/**/*.h spleeter/**/*.cpp --xml 2> static_code_analysis.xml 47 | - cppcheck-htmlreport --file=static_code_analysis.xml --report-dir=static_code_analysis_report --source-dir=. 48 | artifacts: 49 | name: static_code_analysis 50 | paths: 51 | - static_code_analysis_report/ 52 | - static_code_analysis.xml 53 | - static_code_analysis.log 54 | expire_in: 7 days 55 | 56 | # ------------ 57 | # Stage: Build 58 | # ------------ 59 | bazel-build-and-test: 60 | stage: build 61 | image: registry.gitlab.com/jinay1991/spleeter 62 | script: 63 | - bazel build //... 64 | - bazel test //... 
--test_output=all 65 | after_script: 66 | - apt-get update && apt-get install -y wget 67 | - wget https://github.com/drazisil/junit-merge/releases/download/v2.0.0/junit-merge-linux 68 | - chmod +x junit-merge-linux 69 | - mv junit-merge-linux /usr/bin 70 | - junit-merge-linux $(find bazel-out/* -name test.xml) --out merged_test_report.xml 71 | artifacts: 72 | when: always 73 | reports: 74 | junit: merged_test_report.xml 75 | 76 | doxygen: 77 | stage: build 78 | image: ubuntu:20.04 79 | needs: [] 80 | before_script: 81 | - apt-get update && apt-get install -y doxygen graphviz plantuml 82 | script: 83 | - echo "Generating Doxygen Documentation" 84 | - doxygen Doxyfile 85 | artifacts: 86 | paths: 87 | - doc 88 | expire_in: 1 day 89 | only: 90 | refs: 91 | - master 92 | 93 | # ------------ 94 | # Stage: Test 95 | # ------------ 96 | code-coverage: 97 | stage: test 98 | image: ubuntu:20.04 99 | dependencies: 100 | - bazel-build-and-test 101 | before_script: 102 | - export DEBIAN_FRONTEND=noninteractive 103 | - apt-get update && apt-get install -y build-essential curl git libtool lcov 104 | - curl https://bazel.build/bazel-release.pub.gpg | apt-key add - 105 | - echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list 106 | - apt-get update && apt-get install -y bazel 107 | - apt-get update && apt-get install -y openjdk-11-jdk openjdk-11-jre 108 | - apt-get update && apt-get install -y libavcodec-dev libavformat-dev libavfilter-dev libavdevice-dev libswresample-dev libswscale-dev ffmpeg 109 | script: 110 | - bazel coverage -s --combined_report=lcov --instrumentation_filter=//... --coverage_report_generator=@bazel_tools//tools/test:coverage_report_generator //... 
111 | after_script: 112 | - export OUTPUT_DIR=$(bazel info execution_root) 113 | - export COVERAGE_INFO=$(find $OUTPUT_DIR -name coverage.dat) 114 | - genhtml -s --num-spaces 4 --legend --highlight --sort -t "Code Coverage" --demangle-cpp --function-coverage --branch-coverage -o coverage $COVERAGE_INFO 115 | coverage: /functions.*:\s(\d+.\d+%)/ 116 | allow_failure: true 117 | artifacts: 118 | paths: 119 | - coverage/ 120 | name: code-coverage 121 | when: on_success 122 | expire_in: 1 day 123 | only: 124 | refs: 125 | - master 126 | 127 | pages: 128 | stage: deploy 129 | image: ubuntu:20.04 130 | needs: 131 | - code-coverage 132 | - cppcheck 133 | - doxygen 134 | script: 135 | - mkdir -p reports 136 | - echo " Spleeter " > reports/index.html 137 | - mv coverage reports/coverage 138 | - mv static_code_analysis_report/ reports/static_code_analysis_report 139 | - mv doc/ reports/doc 140 | - mv reports/ public 141 | artifacts: 142 | paths: 143 | - public 144 | only: 145 | refs: 146 | - master 147 | -------------------------------------------------------------------------------- /spleeter/inference_engine/tflite_inference_engine.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020. All Rights Reserved. 4 | /// 5 | #include "spleeter/inference_engine/tflite_inference_engine.h" 6 | 7 | #include "flatbuffers/flatbuffers.h" 8 | #include "spleeter/logging/logging.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | namespace spleeter 20 | { 21 | namespace 22 | { 23 | /// @brief Converts TfLiteTensor to waveform 24 | /// 25 | /// @param tensor[in] TfLiteTensor in [NxHxWxC form] 26 | /// 27 | /// @return Equivalent waveform for given TfLiteTensor by copying contents. 
28 | Waveform ConvertToWaveform(const TfLiteTensor* tensor) 29 | { 30 | const TfLiteIntArray* tensor_dims = tensor->dims; 31 | const auto samples = tensor_dims->size > 0 ? static_cast(tensor_dims->data[0]) : 1; 32 | const auto channels = tensor_dims->size > 1 ? static_cast(tensor_dims->data[1]) : 1; 33 | const auto size = samples * channels; 34 | const float* tensor_ptr = reinterpret_cast(tensor->data.raw); 35 | Waveform waveform{}; 36 | waveform.nb_frames = samples; 37 | waveform.nb_channels = channels; 38 | std::copy(tensor_ptr, tensor_ptr + size, std::back_inserter(waveform.data)); 39 | return waveform; 40 | } 41 | 42 | } // namespace 43 | 44 | inline std::ostream& operator<<(std::ostream& os, const TfLiteIntArray* v) 45 | { 46 | if (!v) 47 | { 48 | os << " (null)"; 49 | return os; 50 | } 51 | for (int k = 0; k < v->size; k++) 52 | { 53 | os << " " << std::dec << v->data[k]; 54 | } 55 | return os; 56 | } 57 | 58 | TFLiteInferenceEngine::TFLiteInferenceEngine(const InferenceEngineParameters& params) 59 | : input_tensor_name_{params.input_tensor_name}, 60 | output_tensor_names_{params.output_tensor_names}, 61 | output_tensor_indicies_{}, 62 | model_path_{params.model_path}, 63 | results_{} 64 | { 65 | } 66 | 67 | void TFLiteInferenceEngine::Init() 68 | { 69 | model_ = tflite::FlatBufferModel::BuildFromFile(model_path_.c_str()); 70 | CHECK(model_) << "Failed to read model " << model_path_; 71 | model_->error_reporter(); 72 | 73 | tflite::ops::builtin::BuiltinOpResolver resolver{}; 74 | tflite::InterpreterBuilder(*model_, resolver)(&interpreter_); 75 | CHECK(interpreter_) << "Failed to construct interpreter"; 76 | 77 | CHECK_EQ(interpreter_->AllocateTensors(), TfLiteStatus::kTfLiteOk) << "Failed to allocate tensors!"; 78 | 79 | LOG(INFO) << "Successfully loaded tflite model from '" << model_path_ << "'."; 80 | } 81 | 82 | void TFLiteInferenceEngine::Execute(const Waveform& waveform) 83 | { 84 | UpdateInput(waveform); 85 | UpdateTensors(); 86 | UpdateOutputs(); 87 | 
} 88 | 89 | void TFLiteInferenceEngine::Shutdown() {} 90 | 91 | Waveforms TFLiteInferenceEngine::GetResults() const 92 | { 93 | return results_; 94 | } 95 | 96 | void TFLiteInferenceEngine::UpdateInput(const Waveform& waveform) 97 | { 98 | ResizeInputTensor(waveform); 99 | ResizeOutputTensor(waveform); 100 | CHECK_EQ(interpreter_->AllocateTensors(), TfLiteStatus::kTfLiteOk) << "Failed to allocate tensors!"; 101 | 102 | float* tensor_ptr = interpreter_->typed_input_tensor(0U); 103 | std::copy(waveform.data.cbegin(), waveform.data.cend(), tensor_ptr); 104 | 105 | LOG(INFO) << "Successfully loaded input waveform!!"; 106 | } 107 | 108 | void TFLiteInferenceEngine::UpdateTensors() 109 | { 110 | const auto ret = interpreter_->Invoke(); 111 | CHECK_EQ(ret, TfLiteStatus::kTfLiteOk) << "Failed to invoke tflite!"; 112 | 113 | output_tensor_indicies_ = interpreter_->outputs(); 114 | LOG(INFO) << "Successfully received results " << output_tensor_indicies_.size() << " outputs."; 115 | } 116 | 117 | void TFLiteInferenceEngine::UpdateOutputs() 118 | { 119 | results_.clear(); /* discard waveforms of any previous Execute() so GetResults() only reports the latest run */ std::transform(output_tensor_indicies_.cbegin(), 120 | output_tensor_indicies_.cend(), 121 | std::back_inserter(results_), 122 | [this](const auto& tensor_index) { 123 | const TfLiteTensor* tensor = interpreter_->tensor(tensor_index); 124 | return ConvertToWaveform(tensor); 125 | }); 126 | } 127 | 128 | void TFLiteInferenceEngine::ResizeInputTensor(const Waveform& waveform) 129 | { 130 | const auto input_tensor_indicies = interpreter_->inputs(); 131 | CHECK_EQ(input_tensor_indicies.size(), 1) << "Model has more than one input tensor"; 132 | 133 | ResizeTensor(input_tensor_indicies.at(0U), {waveform.nb_frames, waveform.nb_channels}); 134 | } 135 | 136 | void TFLiteInferenceEngine::ResizeOutputTensor(const Waveform& waveform) 137 | { 138 | const auto output_tensor_indicies = interpreter_->outputs(); 139 | CHECK_GT(output_tensor_indicies.size(), 0) << "Model has no output tensor(s)"; 140 | 141 | for (auto tensor_index : 
output_tensor_indicies) 142 | { 143 | ResizeTensor(tensor_index, {waveform.nb_frames, waveform.nb_channels}); 144 | } 145 | } 146 | 147 | void TFLiteInferenceEngine::ResizeTensor(const std::int32_t tensor_index, const std::vector dims) 148 | { 149 | TfLiteTensor* tensor = interpreter_->tensor(tensor_index); 150 | tensor->data.raw = nullptr; 151 | tensor->allocation_type = TfLiteAllocationType::kTfLiteDynamic; 152 | 153 | CHECK_EQ(interpreter_->ResizeInputTensor(tensor_index, dims), TfLiteStatus::kTfLiteOk); 154 | } 155 | } // namespace spleeter 156 | -------------------------------------------------------------------------------- /.clang-tidy: -------------------------------------------------------------------------------- 1 | --- 2 | Checks: '-*,clang-analyzer-*,clang-analyzer-cplusplus*,cppcoreguidelines-*,google-build-using-namespace,readability-*,-readability-avoid-const-params-in-decls' 3 | WarningsAsErrors: '' 4 | HeaderFilterRegex: 'perception/*' 5 | FormatStyle: file 6 | AnalyzeTemporaryDtors: false 7 | CheckOptions: 8 | - key: cppcoreguidelines-no-malloc.Allocations 9 | value: '::malloc;::calloc' 10 | - key: cppcoreguidelines-no-malloc.Deallocations 11 | value: '::free' 12 | - key: cppcoreguidelines-no-malloc.Reallocations 13 | value: '::realloc' 14 | - key: cppcoreguidelines-owning-memory.LegacyResourceConsumers 15 | value: '::free;::realloc;::freopen;::fclose' 16 | - key: cppcoreguidelines-owning-memory.LegacyResourceProducers 17 | value: '::malloc;::aligned_alloc;::realloc;::calloc;::fopen;::freopen;::tmpfile' 18 | - key: cppcoreguidelines-pro-bounds-constant-array-index.GslHeader 19 | value: '' 20 | - key: cppcoreguidelines-pro-bounds-constant-array-index.IncludeStyle 21 | value: '0' 22 | - key: cppcoreguidelines-pro-type-member-init.IgnoreArrays 23 | value: '0' 24 | - key: cppcoreguidelines-special-member-functions.AllowMissingMoveFunctions 25 | value: '0' 26 | - key: cppcoreguidelines-special-member-functions.AllowSoleDefaultDtor 27 | value: '0' 28 | 
- key: google-readability-braces-around-statements.ShortStatementLines 29 | value: '1' 30 | - key: google-readability-function-size.StatementThreshold 31 | value: '800' 32 | - key: google-readability-namespace-comments.ShortNamespaceLines 33 | value: '10' 34 | - key: google-readability-namespace-comments.SpacesBeforeComments 35 | value: '2' 36 | - key: modernize-loop-convert.MaxCopySize 37 | value: '16' 38 | - key: modernize-loop-convert.MinConfidence 39 | value: reasonable 40 | - key: modernize-loop-convert.NamingStyle 41 | value: CamelCase 42 | - key: modernize-make-shared.IgnoreMacros 43 | value: '1' 44 | - key: modernize-make-shared.IncludeStyle 45 | value: '0' 46 | - key: modernize-make-shared.MakeSmartPtrFunction 47 | value: 'std::make_shared' 48 | - key: modernize-make-shared.MakeSmartPtrFunctionHeader 49 | value: memory 50 | - key: modernize-make-unique.IgnoreMacros 51 | value: '1' 52 | - key: modernize-make-unique.IncludeStyle 53 | value: '0' 54 | - key: modernize-make-unique.MakeSmartPtrFunction 55 | value: 'std::make_unique' 56 | - key: modernize-make-unique.MakeSmartPtrFunctionHeader 57 | value: memory 58 | - key: modernize-pass-by-value.IncludeStyle 59 | value: llvm 60 | - key: modernize-pass-by-value.ValuesOnly 61 | value: '0' 62 | - key: modernize-raw-string-literal.ReplaceShorterLiterals 63 | value: '0' 64 | - key: modernize-replace-auto-ptr.IncludeStyle 65 | value: llvm 66 | - key: modernize-replace-random-shuffle.IncludeStyle 67 | value: llvm 68 | - key: modernize-use-auto.RemoveStars 69 | value: '0' 70 | - key: modernize-use-default-member-init.IgnoreMacros 71 | value: '1' 72 | - key: modernize-use-default-member-init.UseAssignment 73 | value: '0' 74 | - key: modernize-use-emplace.ContainersWithPushBack 75 | value: '::std::vector;::std::list;::std::deque' 76 | - key: modernize-use-emplace.SmartPointers 77 | value: '::std::shared_ptr;::std::unique_ptr;::std::auto_ptr;::std::weak_ptr' 78 | - key: modernize-use-emplace.TupleMakeFunctions 79 | value: 
'::std::make_pair;::std::make_tuple' 80 | - key: modernize-use-emplace.TupleTypes 81 | value: '::std::pair;::std::tuple' 82 | - key: modernize-use-equals-default.IgnoreMacros 83 | value: '1' 84 | - key: modernize-use-noexcept.ReplacementString 85 | value: '' 86 | - key: modernize-use-noexcept.UseNoexceptFalse 87 | value: '1' 88 | - key: modernize-use-nullptr.NullMacros 89 | value: 'NULL' 90 | - key: modernize-use-transparent-functors.SafeMode 91 | value: '0' 92 | - key: modernize-use-using.IgnoreMacros 93 | value: '1' 94 | - key: readability-braces-around-statements.ShortStatementLines 95 | value: '0' 96 | - key: readability-function-size.BranchThreshold 97 | value: '4294967295' 98 | - key: readability-function-size.LineThreshold 99 | value: '4294967295' 100 | - key: readability-function-size.NestingThreshold 101 | value: '4294967295' 102 | - key: readability-function-size.ParameterThreshold 103 | value: '4294967295' 104 | - key: readability-function-size.StatementThreshold 105 | value: '800' 106 | - key: readability-identifier-naming.IgnoreFailedSplit 107 | value: '0' 108 | - key: readability-implicit-bool-conversion.AllowIntegerConditions 109 | value: '0' 110 | - key: readability-implicit-bool-conversion.AllowPointerConditions 111 | value: '0' 112 | - key: readability-simplify-boolean-expr.ChainedConditionalAssignment 113 | value: '0' 114 | - key: readability-simplify-boolean-expr.ChainedConditionalReturn 115 | value: '0' 116 | - key: readability-static-accessed-through-instance.NameSpecifierNestingThreshold 117 | value: '3' 118 | ... 
119 | -------------------------------------------------------------------------------- /spleeter/argument_parser/argument_parser.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @copyright Copyright (c) 2020, MIT License 4 | /// 5 | #include "spleeter/argument_parser/argument_parser.h" 6 | 7 | #include "spleeter/logging/logging.h" 8 | 9 | namespace spleeter 10 | { 11 | namespace 12 | { 13 | void PrintUsage() 14 | { 15 | LOG(INFO) << "spleeter\n" 16 | << "--inputs, -i: List of input audio filenames\n" 17 | << "--output_path, -o: Path of the output directory to write audio files in\n" 18 | << "--filename_format, -f: Template string that will be formatted to generated\n" 19 | " output filename. Such template should be Python formattable\n" 20 | " string, and could use {filename}, {instrument}, and {codec} variables\n" 21 | << "--duration, -d: Set a maximum duration for the processing audio\n" 22 | " (only separate offset + duration first seconds of the input file)\n" 23 | << "--offset, -s: Set the starting offset to separate audio from\n" 24 | << "--codec, -c: Audio codec to be used for separate output\n" 25 | " choices: { wav, mp3, ogg, m4a, wma, flac }\n" 26 | << "--bitrate, -b: Audio bitrate to be used for separate output\n" 27 | << "--mwf, -m: [0|1] Whether to use multichannel Wiener filtering for separation\n" 28 | << "--mus_dir, -u: Path to folder with musDB\n" 29 | << "--adapter, -a: Name of the audio adapter to use for audio I/O\n" 30 | << "--params_filename, -p: JSON filename that contains params\n" 31 | << "--verbose, -v: [0|1] Shows verbose logs\n" 32 | << "--data, -t: Path of the folder containing audio data for training\n" 33 | << "--help, -h: print help\n"; 34 | } 35 | } // namespace 36 | ArgumentParser::ArgumentParser() : cli_options_{} {} 37 | 38 | ArgumentParser::ArgumentParser(int argc, char* argv[]) 39 | : long_options_{{"inputs", required_argument, nullptr, 'i'}, 40 | 
{"output_path", required_argument, nullptr, 'o'}, 41 | {"filename_format", required_argument, nullptr, 'f'}, 42 | {"duration", required_argument, nullptr, 'd'}, 43 | {"offset", required_argument, nullptr, 's'}, 44 | {"codec", required_argument, nullptr, 'c'}, 45 | {"bitrate", required_argument, nullptr, 'b'}, 46 | {"mwf", required_argument, nullptr, 'm'}, 47 | {"mus_dir", required_argument, nullptr, 'u'}, 48 | {"adapter", required_argument, nullptr, 'a'}, 49 | {"params_filename", required_argument, nullptr, 'p'}, 50 | {"verbose", required_argument, nullptr, 'v'}, 51 | {"data", required_argument, nullptr, 't'}, 52 | {"help", 0, nullptr, 'h'}, 53 | {nullptr, 0, nullptr, 0}}, 54 | optstring_{"a:b:c:d:f:hi:m:p:o:s:t:u:v:"} // 'h' has no trailing ':' so that -h, like --help, takes no argument 55 | { 56 | cli_options_ = ParseArgs(argc, argv); 57 | } 58 | 59 | ArgumentParser::~ArgumentParser() {} 60 | 61 | CLIOptions ArgumentParser::GetParsedArgs() const 62 | { 63 | return cli_options_; 64 | } 65 | 66 | CLIOptions ArgumentParser::ParseArgs(int argc, char* argv[]) 67 | { 68 | if (argc > 1) 69 | { 70 | LOG(INFO) << "Parsed arguments: "; 71 | } 72 | while (true) 73 | { 74 | std::int32_t c = 0; 75 | std::int32_t optindex = 0; 76 | 77 | c = getopt_long(argc, argv, optstring_.c_str(), long_options_.data(), &optindex); 78 | /* Detect the end of the options. 
*/ 79 | if (c == -1) 80 | { 81 | break; 82 | } 83 | switch (c) 84 | { 85 | case 'i': 86 | cli_options_.inputs = optarg; 87 | LOG(INFO) << " (+) inputs: " << cli_options_.inputs; 88 | break; 89 | case 'o': 90 | cli_options_.output_path = optarg; 91 | LOG(INFO) << " (+) output_path: " << cli_options_.output_path; 92 | break; 93 | case 'f': 94 | cli_options_.filename_format = optarg; 95 | LOG(INFO) << " (+) filename_format: " << cli_options_.filename_format; 96 | break; 97 | case 'd': 98 | cli_options_.duration = strtod(optarg, nullptr); 99 | LOG(INFO) << " (+) duration: " << cli_options_.duration; 100 | break; 101 | case 's': 102 | cli_options_.offset = strtod(optarg, nullptr); 103 | LOG(INFO) << " (+) offset: " << cli_options_.offset; 104 | break; 105 | case 'c': 106 | cli_options_.codec = optarg; 107 | LOG(INFO) << " (+) codec: " << cli_options_.codec; 108 | break; 109 | case 'b': 110 | cli_options_.bitrate = std::stoi(optarg) * 1000; 111 | LOG(INFO) << " (+) bitrate: " << cli_options_.bitrate; 112 | break; 113 | case 'm': 114 | cli_options_.mwf = strtol(optarg, nullptr, 10); 115 | LOG(INFO) << " (+) mwf: " << std::boolalpha << cli_options_.mwf; 116 | break; 117 | case 'u': 118 | cli_options_.mus_dir = optarg; 119 | LOG(INFO) << " (+) mus_dir: " << cli_options_.mus_dir; 120 | break; 121 | case 'a': 122 | cli_options_.audio_adapter = optarg; 123 | LOG(INFO) << " (+) audio_adapter: " << cli_options_.audio_adapter; 124 | break; 125 | case 'p': 126 | cli_options_.configuration = optarg; 127 | LOG(INFO) << " (+) configuration: " << cli_options_.configuration; 128 | break; 129 | case 'v': 130 | cli_options_.verbose = strtol(optarg, nullptr, 10); 131 | LOG(INFO) << " (+) verbose: " << std::boolalpha << cli_options_.verbose; 132 | break; 133 | case 't': 134 | cli_options_.audio_path = optarg; 135 | LOG(INFO) << " (+) audio_path: " << cli_options_.audio_path; 136 | break; 137 | case 'h': 138 | case '?': 139 | /* getopt_long already printed an error message. 
*/ 140 | PrintUsage(); 141 | default: 142 | exit(1); 143 | } 144 | } 145 | return cli_options_; 146 | } 147 | 148 | } // namespace spleeter 149 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spleeter 2 | 3 | [![Pipeline](https://img.shields.io/badge/pipeline-status-status.svg)](https://gitlab.com/jinay1991/spleeter/pipelines) 4 | [![Doxygen](https://img.shields.io/badge/doc-doxygen-blue.svg)](https://jinay1991.gitlab.io/spleeter/doc/html/index.html) 5 | [![Coverage](https://img.shields.io/badge/coverage-report-green.svg)](https://jinay1991.gitlab.io/spleeter/coverage/index.html) 6 | [![Analysis](https://img.shields.io/badge/static-analysis-orange.svg)](https://jinay1991.gitlab.io/spleeter/static_code_analysis_report/) 7 | 8 | ## Dev Environment 9 | 10 | Supported OS: 11 | 12 | * Ubuntu 20.04 13 | * macOS Big Sur v11 14 | 15 | To set up the developer environment, the project requires the following packages. 16 | 17 | ```bash 18 | apt-get update && apt-get upgrade -y && apt-get autoremove -y 19 | 20 | # Installation of general dependencies 21 | apt-get install -y build-essential clang-format clang-tidy clangd git git-lfs wget curl gnupg openjdk-11-jdk openjdk-11-jre lcov 22 | 23 | # Installation of FFMPEG 24 | apt-get install -y libavcodec-dev libavformat-dev libavfilter-dev libavdevice-dev libswresample-dev libswscale-dev ffmpeg 25 | ``` 26 | 27 | ### Build System 28 | 29 | This project uses the `bazel` build system. To install it, run the following commands or see the installation documentation on the official site [here](https://docs.bazel.build/versions/master/install-ubuntu.html#installing-bazel). 
30 | 31 | For Linux/macOS systems, 32 | 33 | ```bash 34 | # Installation 35 | curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg 36 | mv bazel.gpg /etc/apt/trusted.gpg.d/ 37 | echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list 38 | apt-get update && apt-get install -y bazel 39 | 40 | # Command Completion 41 | echo "source /etc/bash_completion.d/bazel" >> ~/.bashrc 42 | 43 | # Autoformat 44 | wget https://github.com/bazelbuild/buildtools/releases/download/3.5.0/buildifier 45 | chmod +x buildifier 46 | mv buildifier /usr/bin 47 | ``` 48 | 49 | ### Docker 50 | 51 | Alternatively, a Docker container can be used as the dev environment. 52 | 53 | Install the `docker` tool from the official site [here](https://www.docker.com/products/docker-desktop) 54 | 55 | ```bash 56 | docker pull registry.gitlab.com/jinay1991/spleeter 57 | ``` 58 | ### TensorFlow 59 | 60 | To build `libtensorflow_cc.so` 61 | 62 | ```bash 63 | ~/$ git clone https://github.com/tensorflow/tensorflow.git 64 | ~/$ cd tensorflow 65 | ~tensorflow/$ git checkout v2.3.0 66 | ~tensorflow/$ bazel build -c opt --config=monolithic //tensorflow:libtensorflow_cc.so //tensorflow:install_headers 67 | ~tensorflow/$ cd bazel-bin/tensorflow 68 | ~tensorflow/bazel-bin/tensorflow$ mkdir -p libtensorflow_cc-2.3.0-linux/include libtensorflow_cc-2.3.0-linux/lib 69 | ~tensorflow/bazel-bin/tensorflow$ cp -R include/* libtensorflow_cc-2.3.0-linux/include 70 | ~tensorflow/bazel-bin/tensorflow$ cp -P libtensorflow_cc.so libtensorflow_cc.so.2 libtensorflow_cc.so.2.3.0 libtensorflow_cc-2.3.0-linux/lib/ 71 | ~tensorflow/bazel-bin/tensorflow$ tar cvzf libtensorflow_cc-2.3.0-linux.tar.gz libtensorflow_cc-2.3.0-linux 72 | ``` 73 | 74 | The package `libtensorflow_cc-2.3.0-linux.tar.gz` contains the required TensorFlow libraries. 
75 | 76 | ### TensorFlow Lite 77 | 78 | 79 | To build `libtensorflowlite.so` 80 | 81 | ```bash 82 | ~/$ git clone https://github.com/tensorflow/tensorflow.git 83 | ~/$ cd tensorflow 84 | ~tensorflow/$ git checkout v2.3.0 85 | ~tensorflow/$ bazel build -c opt --config=monolithic //tensorflow/lite:libtensorflowlite.so 86 | ~tensorflow/$ cd bazel-bin/tensorflow/lite 87 | ~tensorflow/bazel-bin/tensorflow$ mkdir -p libtensorflowlite-2.3.0-linux/include libtensorflowlite-2.3.0-linux/lib 88 | ~tensorflow/bazel-bin/tensorflow$ cp --parents -r ~tensorflow/lite/**/**/*.h libtensorflowlite-2.3.0-linux/include 89 | ~tensorflow/bazel-bin/tensorflow$ cp -P libtensorflowlite.so libtensorflowlite-2.3.0-linux/lib/ 90 | ~tensorflow/bazel-bin/tensorflow$ tar cvzf libtensorflowlite-2.3.0-linux.tar.gz libtensorflowlite-2.3.0-linux 91 | ``` 92 | 93 | The package `libtensorflowlite-2.3.0-linux.tar.gz` contains the required TensorFlow Lite libraries. 94 | 95 | ## Usage 96 | 97 | To quickly run the `spleeter` application, run `bazel run //application:spleeter` 98 | 99 | To quickly run unit/component tests, run `bazel test //... --test_output=all --cache_test_results=false` 100 | 101 | To build the package (`*.tar.gz`), run `bazel build //:spleeter-dev`. (This contains all the required headers and libraries from this repository for smooth integration into other projects. See the [Integration](#integration) section for details.) 102 | 103 | ## Integration 104 | 105 | Use the tarball for integration; it contains all the required headers and libraries along with the `third_party` dependencies, which can be fetched using `bazel`. 106 | Please note that `spleeter` has some external dependencies (e.g. `libtensorflow`), which is why the `third_party` directory is necessary for downloading all the dependencies and linking them together with your application. 107 | 108 | API usage is documented in the [Doxygen](https://jinay1991.gitlab.io/spleeter/doc/html/) documentation. 
109 | 110 | An example implementation can be found in `example/spleeter_app.cpp` (snippet below): 111 | 112 | ```cpp 113 | /// 114 | /// @file 115 | /// @copyright Copyright (c) 2020, MIT License 116 | /// 117 | #include "spleeter/argument_parser/cli_options.h" 118 | #include "spleeter/spleeter.h" 119 | 120 | #include 121 | #include 122 | 123 | int main(void) 124 | { 125 | try 126 | { 127 | /// Initialize 128 | auto cli_options = spleeter::CLIOptions{}; 129 | cli_options.inputs = std::string{"external/audio_example/file/audio_example.wav"}; 130 | cli_options.output_path = std::string{"separated_audio"}; 131 | cli_options.configuration = std::string{"spleeter:5stems"}; 132 | 133 | cli_options.audio_adapter = std::string{"audionamix"}; 134 | cli_options.codec = std::string{"wav"}; 135 | cli_options.bitrate = 192000; 136 | 137 | auto spleeter = std::make_unique(cli_options); 138 | spleeter->Init(); 139 | 140 | /// Run 141 | spleeter->Execute(); 142 | 143 | /// Deinitialize 144 | spleeter->Shutdown(); 145 | } 146 | catch (std::exception& e) 147 | { 148 | std::cerr << "Caught Exception!! 
" << e.what() << std::endl; 149 | return 1; 150 | } 151 | 152 | return 0; 153 | } 154 | ``` 155 | 156 | Prebuilt binaries for `ubuntu-20.04-amd64` https://github.com/jinay1991/spleeter/releases/download/v1.4/spleeter-dev.tar.gz 157 | 158 | ## Reference 159 | 160 | - Deezer Research - Source Separation Engine Story - deezer.io blog post: 161 | * [English version](https://deezer.io/releasing-spleeter-deezer-r-d-source-separation-engine-2b88985e797e) 162 | * [Japanese version](http://dzr.fm/splitterjp) 163 | - [Music Source Separation tool with pre-trained models / ISMIR2019 extended abstract](http://archives.ismir.net/ismir2019/latebreaking/000036.pdf) 164 | 165 | If you use **Spleeter** in your work, please cite: 166 | 167 | ```BibTeX 168 | @misc{spleeter2019, 169 | title={Spleeter: A Fast And State-of-the Art Music Source Separation Tool With Pre-trained Models}, 170 | author={Romain Hennequin and Anis Khlif and Felix Voituret and Manuel Moussallam}, 171 | howpublished={Late-Breaking/Demo ISMIR 2019}, 172 | month={November}, 173 | note={Deezer Research}, 174 | year={2019} 175 | } 176 | ``` 177 | 178 | Converted official checkpoint to TFLite Model using https://github.com/tinoucas/spleeter-tflite-convert 179 | 180 | ## License 181 | The code of **Spleeter** is MIT-licensed. 182 | 183 | ## Disclaimer 184 | If you plan to use Spleeter on copyrighted material, make sure you get proper authorization from right owners beforehand. 185 | 186 | ## Note 187 | 188 | This repository include a demo audio file `audio_example.mp3` which is an excerpt from 189 | 190 | ``` 191 | Slow Motion Dream by Steven M Bryant 192 | Copyright (c) 2011 Licensed under a Creative Commons Attribution (3.0) license. 
193 | 194 | http://dig.ccmixter.org/files/stevieb357/34740 195 | Ft: CSoul Alex Beroza & Robert Siek" 196 | ``` 197 | -------------------------------------------------------------------------------- /spleeter/audio/ffmpeg_audio_adapter.cpp: -------------------------------------------------------------------------------- 1 | /// 2 | /// @file 3 | /// @brief Contains definitions for FFMPEG Audio Adapter class methods 4 | /// @copyright Copyright (c) 2020, MIT License 5 | /// 6 | #include "spleeter/audio/ffmpeg_audio_adapter.h" 7 | 8 | #include "spleeter/logging/logging.h" 9 | 10 | #include 11 | #include 12 | 13 | namespace spleeter 14 | { 15 | #define MAX_AUDIO_FRAME_SIZE 192000 // 1 second of 48khz 32bit audio 16 | 17 | namespace internal 18 | { 19 | /// @brief Encode given frame to the media file 20 | /// 21 | /// @param frame [in/out] - frame to be written to media file 22 | /// @param audio_codec_context [in/out] - encoder context 23 | /// @param format_context [in/out] - media format context 24 | /// @param data_present [out] - writes 1 if encoded data is present, 0 otherwise. 25 | /// 26 | /// @return ret value - 0 on success, AVERROR (negative value) on error. Exception: returns 0 on EOF. 
27 | static std::int32_t Encode(AVFrame* frame, 28 | AVCodecContext* audio_codec_context, 29 | AVFormatContext* format_context, 30 | std::int32_t* data_present) 31 | { 32 | AVPacket packet; 33 | av_init_packet(&packet); 34 | packet.data = nullptr; 35 | packet.size = 0; 36 | *data_present = 0; 37 | 38 | if (frame) 39 | { 40 | static std::int64_t pts{0}; 41 | frame->pts = pts; 42 | pts += frame->nb_samples; 43 | } 44 | auto ret = avcodec_send_frame(audio_codec_context, frame); 45 | CHECK_LE(0, ret) << "Failed to send frame to encoder (Returned: " << ret << ")"; 46 | 47 | while (ret >= 0) 48 | { 49 | ret = avcodec_receive_packet(audio_codec_context, &packet); 50 | if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) 51 | { 52 | break; 53 | } 54 | CHECK_LE(0, ret) << "Error during encoding (Returned: " << ret << ")"; 55 | 56 | *data_present = 1; 57 | packet.stream_index = 0; 58 | ret = av_write_frame(format_context, &packet); 59 | CHECK_LE(0, ret) << "Failed to write frame. (Returned: " << ret << ")"; 60 | av_packet_unref(&packet); 61 | } 62 | 63 | return 0; 64 | } 65 | 66 | } // namespace internal 67 | 68 | FfmpegAudioAdapter::FfmpegAudioAdapter() 69 | { 70 | av_register_all(); 71 | } 72 | 73 | Waveform FfmpegAudioAdapter::Load(const std::string& path, 74 | const double /*offset*/, 75 | const double /*duration*/, 76 | const std::int32_t sample_rate) 77 | { 78 | /// 79 | /// Open Input Audio 80 | /// 81 | AVFormatContext* format_context = avformat_alloc_context(); 82 | CHECK(format_context) << "Failed to allocate format context"; 83 | 84 | auto ret = avformat_open_input(&format_context, path.c_str(), nullptr, nullptr); 85 | CHECK_LE(0, ret) << "Failed to open file " << path; 86 | 87 | ret = avformat_find_stream_info(format_context, nullptr); 88 | CHECK_LE(0, ret) << "Failed to retrieve stream info from file " << path; 89 | 90 | ret = av_find_best_stream(format_context, AVMEDIA_TYPE_AUDIO, -1, -1, nullptr, 0); 91 | CHECK_LE(0, ret) << "Failed find any audio stream in the 
file " << path; 92 | auto stream_index = ret; 93 | AVStream* audio_stream = format_context->streams[stream_index]; 94 | 95 | AVCodec* audio_codec = avcodec_find_decoder(audio_stream->codecpar->codec_id); 96 | CHECK(audio_codec) << "Failed to find {" << audio_stream->codecpar->codec_id << "} codec"; 97 | 98 | AVCodecContext* audio_codec_context = avcodec_alloc_context3(audio_codec); 99 | CHECK(audio_codec_context) << "Failed to allocate {" << audio_stream->codecpar->codec_id << "} codec context."; 100 | 101 | ret = avcodec_parameters_to_context(audio_codec_context, audio_stream->codecpar); 102 | CHECK_LE(0, ret) << "Failed to copy {" << audio_stream->codecpar->codec_id 103 | << "} codec parameters to decoder context."; 104 | 105 | ret = avcodec_open2(audio_codec_context, audio_codec, nullptr); 106 | CHECK_LE(0, ret) << "Failed to open {" << audio_stream->codecpar->codec_id << "} codec"; 107 | 108 | av_dump_format(format_context, 0, path.c_str(), 0); 109 | 110 | /// 111 | /// Read Audio 112 | /// 113 | SwrContext* swr_context = swr_alloc(); 114 | CHECK(swr_context) << "Failed to allocate resampler."; 115 | 116 | swr_context = swr_alloc_set_opts(swr_context, 117 | AV_CH_LAYOUT_STEREO, 118 | AV_SAMPLE_FMT_FLT, 119 | sample_rate, 120 | av_get_default_channel_layout(audio_codec_context->channels), 121 | audio_codec_context->sample_fmt, 122 | audio_codec_context->sample_rate, 123 | 0, 124 | nullptr); 125 | CHECK(swr_context) << "Failed to set options for resampler."; 126 | 127 | ret = swr_init(swr_context); 128 | CHECK_LE(0, ret) << "Failed to initialize resampler. 
(Returned: " << ret << ")"; 129 | 130 | std::uint8_t* buffer = (std::uint8_t*)av_malloc(MAX_AUDIO_FRAME_SIZE * 2); 131 | CHECK(buffer) << "Failed to allocate buffer"; 132 | 133 | AVPacket packet; 134 | av_init_packet(&packet); 135 | 136 | Waveform waveform{}; 137 | std::int32_t nb_samples{0}; 138 | while (av_read_frame(format_context, &packet) >= 0) 139 | { 140 | AVFrame* frame = av_frame_alloc(); 141 | CHECK(frame) << "Failed to allocate frame"; 142 | 143 | ret = avcodec_send_packet(audio_codec_context, &packet); 144 | CHECK_LE(0, ret) << "Failed to send packet for decoding. (Returned: " << ret << ")"; 145 | while (ret >= 0) 146 | { 147 | ret = avcodec_receive_frame(audio_codec_context, frame); 148 | if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) 149 | { 150 | break; 151 | } 152 | CHECK_EQ(0, ret) << "Failed to decode received packet. (Returned: " << ret << ")"; 153 | 154 | auto buffer_size = frame->nb_samples * av_get_bytes_per_sample(audio_codec_context->sample_fmt); 155 | ret = swr_convert(swr_context, &buffer, buffer_size, (const std::uint8_t**)frame->data, frame->nb_samples); 156 | CHECK_LE(0, ret) << "Failed to resample. 
(Returned: " << ret << ")"; 157 | 158 | for (auto idx = 0; idx < buffer_size; ++idx) 159 | { 160 | waveform.data.push_back(buffer[idx]); 161 | } 162 | 163 | nb_samples += frame->nb_samples; 164 | } 165 | 166 | av_frame_free(&frame); 167 | av_packet_unref(&packet); 168 | } 169 | /// Update Audio properties before releasing resources 170 | audio_properties_.nb_channels = audio_codec_context->channels; 171 | audio_properties_.nb_frames = nb_samples; 172 | audio_properties_.sample_rate = sample_rate; 173 | waveform.nb_frames = audio_properties_.nb_frames; 174 | waveform.nb_channels = audio_properties_.nb_channels; 175 | 176 | av_packet_unref(&packet); 177 | av_free(buffer); 178 | swr_free(&swr_context); 179 | avcodec_free_context(&audio_codec_context); // closes the codec and frees the context from avcodec_alloc_context3 (avcodec_close alone leaked it) 180 | avformat_close_input(&format_context); 181 | 182 | LOG(INFO) << "Decoded waveform with " << audio_properties_; 183 | LOG(INFO) << "Loaded waveform from " << path << " using FFMPEG."; 184 | return waveform; 185 | } 186 | 187 | void FfmpegAudioAdapter::Save(const std::string& path, 188 | const Waveform& waveform, 189 | const std::int32_t sample_rate, 190 | const std::string& /*codec*/, 191 | const std::int32_t bitrate) 192 | { 193 | /// 194 | /// Open Output Audio 195 | /// 196 | AVFormatContext* format_context{nullptr}; 197 | auto ret = avformat_alloc_output_context2(&format_context, nullptr, nullptr, path.c_str()); 198 | CHECK_LE(0, ret) << "Failed to deduce output format from the file extension. 
(Returned: " << ret << ")"; 199 | 200 | AVOutputFormat* output_format = format_context->oformat; 201 | CHECK(output_format) << "Failed to find output format"; 202 | 203 | AVCodec* audio_codec = avcodec_find_encoder(output_format->audio_codec); 204 | CHECK(audio_codec) << "Failed to find encoder for '" << avcodec_get_name(output_format->audio_codec) << "' codec."; 205 | 206 | AVStream* audio_stream = avformat_new_stream(format_context, nullptr); 207 | CHECK(audio_stream) << "Failed to allocate stream"; 208 | 209 | AVCodecContext* audio_codec_context = avcodec_alloc_context3(audio_codec); 210 | CHECK(audio_codec_context) << "Failed to allocate encoding context"; 211 | 212 | /// 213 | /// Adjust Encoding Parameters 214 | /// 215 | audio_codec_context->codec_id = format_context->audio_codec_id; 216 | audio_codec_context->codec_type = AVMEDIA_TYPE_AUDIO; 217 | audio_codec_context->sample_fmt = audio_codec->sample_fmts ? audio_codec->sample_fmts[0] : AV_SAMPLE_FMT_FLTP; 218 | audio_codec_context->sample_rate = 219 | audio_codec->supported_samplerates ? audio_codec->supported_samplerates[0] : sample_rate; 220 | audio_codec_context->channel_layout = AV_CH_LAYOUT_STEREO; 221 | audio_codec_context->channels = av_get_channel_layout_nb_channels(audio_codec_context->channel_layout); 222 | audio_codec_context->bit_rate = bitrate; 223 | 224 | audio_stream->time_base = AVRational{1, sample_rate}; 225 | 226 | if (output_format->flags & AVFMT_GLOBALHEADER) 227 | { 228 | audio_codec_context->flags |= AV_CODEC_FLAG_GLOBAL_HEADER; 229 | } 230 | 231 | /// 232 | /// Open Codec 233 | /// 234 | ret = avcodec_open2(audio_codec_context, audio_codec, nullptr); 235 | CHECK_LE(0, ret) << "Failed to open the context with the audio codec {" 236 | << avcodec_get_name(output_format->audio_codec) << "}. 
(Returned: " << ret << ")"; 237 | 238 | ret = avcodec_parameters_from_context(audio_stream->codecpar, audio_codec_context); 239 | CHECK_LE(0, ret) << "Failed to copy {" << audio_stream->codecpar->codec_id 240 | << "} codec parameters to decoder context. (Returned: " << ret << ")"; 241 | av_dump_format(format_context, 0, path.c_str(), 1); 242 | 243 | if (!(format_context->flags & AVFMT_NOFILE)) 244 | { 245 | ret = avio_open(&format_context->pb, path.c_str(), AVIO_FLAG_WRITE); 246 | CHECK_LE(0, ret) << "Failed to open " << path << ". (Returned: " << ret << ")"; 247 | LOG(INFO) << "Successfully opened " << path << " for writing."; 248 | } 249 | 250 | if (audio_codec_context->codec->capabilities & AV_CODEC_CAP_VARIABLE_FRAME_SIZE) 251 | { 252 | audio_codec_context->frame_size = 253 | av_rescale_rnd(waveform.data.size() / av_get_channel_layout_nb_channels(AV_CH_LAYOUT_STEREO), 254 | sample_rate, 255 | sample_rate, 256 | AV_ROUND_UP); 257 | } 258 | 259 | /// 260 | /// Allocate sample frame 261 | /// 262 | AVFrame* frame = av_frame_alloc(); 263 | CHECK(frame) << "Failed to allocate frame"; 264 | 265 | frame->nb_samples = audio_codec_context->frame_size; 266 | frame->format = audio_codec_context->sample_fmt; 267 | frame->channels = audio_codec_context->channels; 268 | frame->channel_layout = audio_codec_context->channel_layout; 269 | frame->sample_rate = audio_codec_context->sample_rate; 270 | 271 | ret = av_frame_get_buffer(frame, 0); 272 | CHECK_LE(0, ret) << "Failed to allocate audio data buffers. (Returned: " << ret << ")"; 273 | 274 | ret = av_frame_make_writable(frame); 275 | CHECK_LE(0, ret) << "Failed to make frame writable. 
(Returned: " << ret << ")"; 276 | 277 | /// 278 | /// Allocate Resampler 279 | /// 280 | SwrContext* swr_context = swr_alloc(); 281 | CHECK(swr_context) << "Failed to allocate resampler."; 282 | 283 | swr_context = swr_alloc_set_opts(swr_context, 284 | audio_codec_context->channel_layout, 285 | audio_codec_context->sample_fmt, 286 | audio_codec_context->sample_rate, 287 | AV_CH_LAYOUT_STEREO, 288 | AV_SAMPLE_FMT_FLT, 289 | sample_rate, 290 | 0, 291 | nullptr); 292 | CHECK(swr_context) << "Failed to set options for resampler."; 293 | 294 | ret = swr_init(swr_context); 295 | CHECK_LE(0, ret) << "Failed to initialize resampler. (Returned: " << ret << ")"; 296 | 297 | /// 298 | /// Resample the data 299 | /// 300 | std::uint8_t** src_data{nullptr}; 301 | std::int32_t src_linesize{0}; 302 | std::uint8_t** dst_data{nullptr}; 303 | std::int32_t dst_linesize{0}; 304 | ret = av_samples_alloc_array_and_samples(&src_data, 305 | &src_linesize, 306 | av_get_channel_layout_nb_channels(AV_CH_LAYOUT_STEREO), 307 | waveform.data.size(), 308 | AV_SAMPLE_FMT_FLT, 309 | 0); 310 | CHECK_LE(0, ret) << "Failed to allocate src samples. (Returned: " << ret << ")"; 311 | 312 | ret = av_samples_fill_arrays(src_data, 313 | &src_linesize, 314 | (const std::uint8_t*)waveform.data.data(), 315 | av_get_channel_layout_nb_channels(AV_CH_LAYOUT_STEREO), 316 | waveform.data.size(), 317 | AV_SAMPLE_FMT_FLT, 318 | 0); 319 | CHECK_LE(0, ret) << "Failed to fill src samples. (Returned: " << ret << ")"; 320 | 321 | ret = av_samples_alloc_array_and_samples(&dst_data, 322 | &dst_linesize, 323 | audio_codec_context->channels, 324 | audio_codec_context->frame_size, 325 | audio_codec_context->sample_fmt, 326 | 0); 327 | CHECK_LE(0, ret) << "Failed to allocate dst samples. 
(Returned: " << ret << ")"; 328 | 329 | LOG(INFO) << "Converting given waveform {size: " << waveform.data.size() 330 | << ", sample_fmt: " << av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT) 331 | << "} to {size: " << audio_codec_context->frame_size 332 | << ", sample_fmt: " << av_get_sample_fmt_name(audio_codec_context->sample_fmt) << "}"; 333 | ret = swr_convert( 334 | swr_context, dst_data, audio_codec_context->frame_size, (const std::uint8_t**)src_data, waveform.data.size()); 335 | CHECK_LE(0, ret) << "Failed to convert buffer using resampler. (Returned: " << ret << ")"; 336 | 337 | /// 338 | /// Encode samples 339 | /// 340 | ret = avformat_write_header(format_context, nullptr); 341 | CHECK_LE(0, ret) << "Failed to write header information for " << path << ". (Returned: " << ret << ")"; 342 | 343 | std::int32_t data_present{0}; 344 | frame->data[0] = dst_data[0]; 345 | frame->data[1] = dst_data[0] + dst_linesize; 346 | ret = internal::Encode(frame, audio_codec_context, format_context, &data_present); 347 | CHECK_LE(0, ret) << "Failed to encode frame. (Returned: " << ret << ")"; 348 | 349 | /// 350 | /// Write queued samples 351 | /// 352 | data_present = 0; 353 | do 354 | { 355 | ret = internal::Encode(nullptr, audio_codec_context, format_context, &data_present); 356 | CHECK_LE(0, ret) << "Failed to encode frame. (Returned: " << ret << ")"; 357 | } while (data_present); 358 | 359 | ret = av_write_trailer(format_context); 360 | CHECK_LE(0, ret) << "Failed to write output file trailer. (Returned: " << ret << ")"; 361 | 362 | /// 363 | /// Cleanup 364 | /// 365 | if (format_context && !(format_context->flags & AVFMT_NOFILE)) 366 | { 367 | ret = avio_close(format_context->pb); 368 | CHECK_LE(0, ret) << "Failed to close " << path << ". 
(Returned: " << ret << ")"; 369 | LOG(INFO) << "Successfully closed " << path << " after writing."; 370 | } 371 | swr_close(swr_context); 372 | avcodec_close(audio_codec_context); 373 | // av_freep(&src_data[0]); 374 | // av_freep(&src_data); 375 | av_freep(&dst_data[0]); 376 | av_freep(&dst_data); 377 | swr_free(&swr_context); 378 | av_frame_free(&frame); 379 | avformat_free_context(format_context); // releases the muxer context and audio_stream, which the context owns (created via avformat_new_stream) 380 | avcodec_free_context(&audio_codec_context); 381 | 382 | LOG(INFO) << "Saved waveform to " << path << " using FFMPEG."; 383 | } 384 | 385 | AudioProperties FfmpegAudioAdapter::GetProperties() const 386 | { 387 | return audio_properties_; 388 | } 389 | 390 | } // namespace spleeter 391 | -------------------------------------------------------------------------------- /scripts/unet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf8 3 | 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | 9 | class UNet(tf.keras.Model): 10 | def __init__(self, output_name='output', output_mask_logit=False): 11 | super(UNet, self).__init__() 12 | 13 | # First layer. 14 | self.conv1 = tf.keras.layers.Conv2D(filters=16, 15 | kernel_size=(5, 5), 16 | strides=(2, 2), 17 | padding="same", 18 | kernel_initializer="he_uniform") 19 | self.batch1 = tf.keras.layers.BatchNormalization(axis=-1) 20 | self.act1 = tf.keras.layers.Activation(activation="relu") 21 | # Second layer. 22 | self.conv2 = tf.keras.layers.Conv2D(filters=32, 23 | kernel_size=(5, 5), 24 | strides=(2, 2), 25 | padding="same", 26 | kernel_initializer="he_uniform") 27 | self.batch2 = tf.keras.layers.BatchNormalization(axis=-1) 28 | self.act2 = tf.keras.layers.Activation(activation="relu") 29 | # Third layer. 
30 | self.conv3 = tf.keras.layers.Conv2D(filters=64, 31 | kernel_size=(5, 5), 32 | strides=(2, 2), 33 | padding="same", 34 | kernel_initializer="he_uniform") 35 | self.batch3 = tf.keras.layers.BatchNormalization(axis=-1) 36 | self.act3 = tf.keras.layers.Activation(activation="relu") 37 | # Fourth layer. 38 | self.conv4 = tf.keras.layers.Conv2D(filters=128, 39 | kernel_size=(5, 5), 40 | strides=(2, 2), 41 | padding="same", 42 | kernel_initializer="he_uniform") 43 | self.batch4 = tf.keras.layers.BatchNormalization(axis=-1) 44 | self.act4 = tf.keras.layers.Activation(activation="relu") 45 | # Fifth layer. 46 | self.conv5 = tf.keras.layers.Conv2D(filters=256, 47 | kernel_size=(5, 5), 48 | strides=(2, 2), 49 | padding="same", 50 | kernel_initializer="he_uniform") 51 | self.batch5 = tf.keras.layers.BatchNormalization(axis=-1) 52 | self.act5 = tf.keras.layers.Activation(activation="relu") 53 | # Sixth layer 54 | self.conv6 = tf.keras.layers.Conv2D(filters=512, 55 | kernel_size=(5, 5), 56 | strides=(2, 2), 57 | padding="same", 58 | kernel_initializer="he_uniform") 59 | self.batch6 = tf.keras.layers.BatchNormalization(axis=-1) 60 | self.act6 = tf.keras.layers.Activation(activation="relu") 61 | # 62 | # 63 | # 64 | self.up1 = tf.keras.layers.Conv2DTranspose(filters=256, 65 | kernel_size=(5, 5), 66 | strides=(2, 2), 67 | activation="relu", 68 | padding='same', 69 | kernel_initializer="he_uniform") 70 | self.batch7 = tf.keras.layers.BatchNormalization(axis=-1) 71 | self.drop1 = tf.keras.layers.Dropout(0.5) 72 | # 73 | self.up2 = tf.keras.layers.Conv2DTranspose(filters=128, 74 | kernel_size=(5, 5), 75 | strides=(2, 2), 76 | activation="relu", 77 | padding='same', 78 | kernel_initializer="he_uniform") 79 | self.batch8 = tf.keras.layers.BatchNormalization(axis=-1) 80 | self.drop2 = tf.keras.layers.Dropout(0.5) 81 | # 82 | self.up3 = tf.keras.layers.Conv2DTranspose(filters=64, 83 | kernel_size=(5, 5), 84 | strides=(2, 2), 85 | activation="relu", 86 | padding='same', 87 | 
kernel_initializer="he_uniform") 88 | self.batch9 = tf.keras.layers.BatchNormalization(axis=-1) 89 | self.drop3 = tf.keras.layers.Dropout(0.5) 90 | # 91 | self.up4 = tf.keras.layers.Conv2DTranspose(filters=32, 92 | kernel_size=(5, 5), 93 | strides=(2, 2), 94 | activation="relu", 95 | padding='same', 96 | kernel_initializer="he_uniform") 97 | self.batch10 = tf.keras.layers.BatchNormalization(axis=-1) 98 | # 99 | self.up5 = tf.keras.layers.Conv2DTranspose(filters=16, 100 | kernel_size=(5, 5), 101 | strides=(2, 2), 102 | activation="relu", 103 | padding='same', 104 | kernel_initializer="he_uniform") 105 | self.batch11 = tf.keras.layers.BatchNormalization(axis=-1) 106 | # 107 | self.up6 = tf.keras.layers.Conv2DTranspose(filters=1, 108 | kernel_size=(5, 5), 109 | strides=(2, 2), 110 | activation="relu", 111 | padding='same', 112 | kernel_initializer="he_uniform") 113 | self.batch12 = tf.keras.layers.BatchNormalization(axis=-1) 114 | 115 | # Last layer to ensure initial shape reconstruction. 116 | self.output_name = output_name 117 | if not output_mask_logit: 118 | self.output_mask_logit = False 119 | self.up7 = tf.keras.layers.Conv2D(filters=2, 120 | kernel_size=(4, 4), 121 | dilation_rate=(2, 2), 122 | activation='sigmoid', 123 | padding='same', 124 | kernel_initializer="he_uniform") 125 | else: 126 | self.output_mask_logit = True 127 | self.logits = tf.keras.layers.Conv2D(filters=2, 128 | kernel_size=(4, 4), 129 | dilation_rate=(2, 2), 130 | padding='same', 131 | kernel_initializer="he_uniform") 132 | 133 | def call(self, inputs, training=False): 134 | conv1 = self.conv1(inputs) 135 | batch1 = self.batch1(conv1) 136 | act1 = self.act1(batch1) 137 | # Second layer. 138 | conv2 = self.conv2(act1) 139 | batch2 = self.batch2(conv2) 140 | act2 = self.act2(batch2) 141 | # Third layer. 142 | conv3 = self.conv3(act2) 143 | batch3 = self.batch3(conv3) 144 | act3 = self.act3(batch3) 145 | # Fourth layer. 
146 | conv4 = self.conv4(act3) 147 | batch4 = self.batch4(conv4) 148 | act4 = self.act4(batch4) 149 | # Fifth layer. 150 | conv5 = self.conv5(act4) 151 | batch5 = self.batch5(conv5) 152 | act5 = self.act5(batch5) 153 | # Sixth layer 154 | conv6 = self.conv6(act5) 155 | batch6 = self.batch6(conv6) 156 | _ = self.act6(batch6) 157 | # 158 | # 159 | # 160 | up1 = self.up1(conv6) 161 | batch7 = self.batch7(up1) 162 | drop1 = self.drop1(batch7) 163 | merge1 = tf.keras.layers.Concatenate(axis=-1)([conv5, drop1]) 164 | # 165 | up2 = self.up1(merge1) 166 | batch8 = self.batch8(up2) 167 | drop2 = self.drop2(batch8) 168 | merge2 = tf.keras.layers.Concatenate(axis=-1)([conv4, drop2]) 169 | # 170 | up3 = self.up3(merge2) 171 | batch9 = self.batch9(up3) 172 | drop3 = self.drop3(batch9) 173 | merge3 = tf.keras.layers.Concatenate(axis=-1)([conv3, drop3]) 174 | # 175 | up4 = self.up4(merge3) 176 | batch10 = self.batch10(up4) 177 | merge4 = tf.keras.layers.Concatenate(axis=-1)([conv2, batch10]) 178 | # 179 | up5 = self.up5(merge4) 180 | batch11 = self.batch11(up5) 181 | merge5 = tf.keras.layers.Concatenate(axis=-1)([conv1, batch11]) 182 | # 183 | up6 = self.up6(merge5) 184 | batch12 = self.batch12(up6) 185 | 186 | # Last layer to ensure initial shape reconstruction. 187 | if not self.output_mask_logit: 188 | up7 = self.up7(batch12) 189 | return tf.keras.layers.Multiply(name=self.output_name)([up7, input_tensor]) 190 | else: 191 | return self.logits(batch12) 192 | 193 | 194 | @tf.function 195 | def apply_unet( 196 | input_tensor, 197 | output_name='output', 198 | output_mask_logit=False): 199 | """ Apply a convolutionnal U-net to model a single instrument (one U-net 200 | is used for each instrument). 201 | 202 | :param input_tensor: 203 | :param output_name: (Optional) , default to 'output' 204 | :param output_mask_logit: (Optional) , default to False. 205 | """ 206 | # First layer. 
207 | conv1 = tf.keras.layers.Conv2D(filters=16, 208 | kernel_size=(5, 5), 209 | strides=(2, 2), 210 | padding="same", 211 | kernel_initializer="he_uniform")(input_tensor) 212 | batch1 = tf.keras.layers.BatchNormalization(axis=-1)(conv1) 213 | act1 = tf.keras.layers.Activation(activation="relu")(batch1) 214 | # Second layer. 215 | conv2 = tf.keras.layers.Conv2D(filters=32, 216 | kernel_size=(5, 5), 217 | strides=(2, 2), 218 | padding="same", 219 | kernel_initializer="he_uniform")(act1) 220 | batch2 = tf.keras.layers.BatchNormalization(axis=-1)(conv2) 221 | act2 = tf.keras.layers.Activation(activation="relu")(batch2) 222 | # Third layer. 223 | conv3 = tf.keras.layers.Conv2D(filters=64, 224 | kernel_size=(5, 5), 225 | strides=(2, 2), 226 | padding="same", 227 | kernel_initializer="he_uniform")(act2) 228 | batch3 = tf.keras.layers.BatchNormalization(axis=-1)(conv3) 229 | act3 = tf.keras.layers.Activation(activation="relu")(batch3) 230 | # Fourth layer. 231 | conv4 = tf.keras.layers.Conv2D(filters=128, 232 | kernel_size=(5, 5), 233 | strides=(2, 2), 234 | padding="same", 235 | kernel_initializer="he_uniform")(act3) 236 | batch4 = tf.keras.layers.BatchNormalization(axis=-1)(conv4) 237 | act4 = tf.keras.layers.Activation(activation="relu")(batch4) 238 | # Fifth layer. 
239 | conv5 = tf.keras.layers.Conv2D(filters=256, 240 | kernel_size=(5, 5), 241 | strides=(2, 2), 242 | padding="same", 243 | kernel_initializer="he_uniform")(act4) 244 | batch5 = tf.keras.layers.BatchNormalization(axis=-1)(conv5) 245 | act5 = tf.keras.layers.Activation(activation="relu")(batch5) 246 | # Sixth layer 247 | conv6 = tf.keras.layers.Conv2D(filters=512, 248 | kernel_size=(5, 5), 249 | strides=(2, 2), 250 | padding="same", 251 | kernel_initializer="he_uniform")(act5) 252 | batch6 = tf.keras.layers.BatchNormalization(axis=-1)(conv6) 253 | _ = tf.keras.layers.Activation(activation="relu")(batch6) 254 | # 255 | # 256 | # 257 | up1 = tf.keras.layers.Conv2DTranspose(filters=256, 258 | kernel_size=(5, 5), 259 | strides=(2, 2), 260 | activation="relu", 261 | padding='same', 262 | kernel_initializer="he_uniform")(conv6) 263 | batch7 = tf.keras.layers.BatchNormalization(axis=-1)(up1) 264 | drop1 = tf.keras.layers.Dropout(0.5)(batch7) 265 | merge1 = tf.keras.layers.Concatenate(axis=-1)([conv5, drop1]) 266 | # 267 | up2 = tf.keras.layers.Conv2DTranspose(filters=128, 268 | kernel_size=(5, 5), 269 | strides=(2, 2), 270 | activation="relu", 271 | padding='same', 272 | kernel_initializer="he_uniform")(merge1) 273 | batch8 = tf.keras.layers.BatchNormalization(axis=-1)(up2) 274 | drop2 = tf.keras.layers.Dropout(0.5)(batch8) 275 | merge2 = tf.keras.layers.Concatenate(axis=-1)([conv4, drop2]) 276 | # 277 | up3 = tf.keras.layers.Conv2DTranspose(filters=64, 278 | kernel_size=(5, 5), 279 | strides=(2, 2), 280 | activation="relu", 281 | padding='same', 282 | kernel_initializer="he_uniform")(merge2) 283 | batch9 = tf.keras.layers.BatchNormalization(axis=-1)(up3) 284 | drop3 = tf.keras.layers.Dropout(0.5)(batch9) 285 | merge3 = tf.keras.layers.Concatenate(axis=-1)([conv3, drop3]) 286 | # 287 | up4 = tf.keras.layers.Conv2DTranspose(filters=32, 288 | kernel_size=(5, 5), 289 | strides=(2, 2), 290 | activation="relu", 291 | padding='same', 292 | 
kernel_initializer="he_uniform")(merge3) 293 | batch10 = tf.keras.layers.BatchNormalization(axis=-1)(up4) 294 | merge4 = tf.keras.layers.Concatenate(axis=-1)([conv2, batch10]) 295 | # 296 | up5 = tf.keras.layers.Conv2DTranspose(filters=16, 297 | kernel_size=(5, 5), 298 | strides=(2, 2), 299 | activation="relu", 300 | padding='same', 301 | kernel_initializer="he_uniform")(merge4) 302 | batch11 = tf.keras.layers.BatchNormalization(axis=-1)(up5) 303 | merge5 = tf.keras.layers.Concatenate(axis=-1)([conv1, batch11]) 304 | # 305 | up6 = tf.keras.layers.Conv2DTranspose(filters=1, 306 | kernel_size=(5, 5), 307 | strides=(2, 2), 308 | activation="relu", 309 | padding='same', 310 | kernel_initializer="he_uniform")(merge5) 311 | batch12 = tf.keras.layers.BatchNormalization(axis=-1)(up6) 312 | 313 | # Last layer to ensure initial shape reconstruction. 314 | if not output_mask_logit: 315 | up7 = tf.keras.layers.Conv2D(filters=2, 316 | kernel_size=(4, 4), 317 | dilation_rate=(2, 2), 318 | activation='sigmoid', 319 | padding='same', 320 | kernel_initializer="he_uniform")(batch12) 321 | output = tf.keras.layers.Multiply(name=output_name)([up7, input_tensor]) 322 | return output 323 | 324 | return tf.keras.layers.Conv2D(filters=2, 325 | kernel_size=(4, 4), 326 | dilation_rate=(2, 2), 327 | padding='same', 328 | kernel_initializer="he_uniform")(batch12) 329 | 330 | 331 | if __name__ == "__main__": 332 | input_tensor = np.zeros(shape=(1, 256, 256, 2), dtype=np.float32) 333 | 334 | model = UNet() 335 | model.trainable = False 336 | model._set_inputs(input_tensor) 337 | # model.compile(tf.keras.optimizers.Adam(learning_rate=1e-3), loss=tf.keras.losses.MeanSquaredError()) 338 | # model.build() 339 | # model.summary() 340 | 341 | tf.saved_model.save(model, export_dir="saved_model") 342 | --------------------------------------------------------------------------------