├── lib ├── torchaudio │ ├── version.rb │ ├── transforms │ │ ├── mu_law_encoding.rb │ │ ├── mu_law_decoding.rb │ │ ├── compute_deltas.rb │ │ ├── amplitude_to_db.rb │ │ ├── vol.rb │ │ ├── spectrogram.rb │ │ ├── mel_spectrogram.rb │ │ ├── mel_scale.rb │ │ ├── mfcc.rb │ │ └── fade.rb │ ├── datasets │ │ ├── yesno.rb │ │ └── utils.rb │ └── functional.rb └── torchaudio.rb ├── Gemfile ├── .gitignore ├── ext └── torchaudio │ ├── csrc │ ├── sox_effects.h │ ├── sox_io.h │ ├── sox_effects.cpp │ ├── sox.h │ ├── register.cpp │ ├── sox_utils.h │ ├── sox_io.cpp │ ├── sox_utils.cpp │ └── sox.cpp │ ├── ext.cpp │ └── extconf.rb ├── Rakefile ├── test ├── datasets_test.rb ├── test_helper.rb ├── torchaudio_test.rb ├── functional_test.rb └── transforms_test.rb ├── torchaudio.gemspec ├── CHANGELOG.md ├── .github └── workflows │ └── build.yml ├── LICENSE.txt ├── examples ├── LICENSE-tutorial.txt └── tutorial.rb └── README.md /lib/torchaudio/version.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | VERSION = "0.4.1" 3 | end 4 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec 4 | 5 | gem "rake" 6 | gem "rake-compiler" 7 | gem "minitest" 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /_yardoc/ 4 | /coverage/ 5 | /doc/ 6 | /pkg/ 7 | /spec/reports/ 8 | /tmp/ 9 | *.lock 10 | *.bundle 11 | *.so 12 | .data 13 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox_effects.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCHAUDIO_SOX_EFFECTS_H 2 | #define TORCHAUDIO_SOX_EFFECTS_H 3 | 4 | #include 5 | 6 | 
namespace torchaudio { 7 | namespace sox_effects { 8 | 9 | void initialize_sox_effects(); 10 | 11 | void shutdown_sox_effects(); 12 | 13 | std::vector list_effects(); 14 | 15 | } // namespace sox_effects 16 | } // namespace torchaudio 17 | 18 | #endif 19 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/mu_law_encoding.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class MuLawEncoding < Torch::NN::Module 4 | def initialize(quantization_channels: 256) 5 | super() 6 | @quantization_channels = quantization_channels 7 | end 8 | 9 | def forward(x) 10 | F.mu_law_encoding(x, @quantization_channels) 11 | end 12 | end 13 | end 14 | end 15 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/mu_law_decoding.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class MuLawDecoding < Torch::NN::Module 4 | def initialize(quantization_channels: 256) 5 | super() 6 | @quantization_channels = quantization_channels 7 | end 8 | 9 | def forward(x_mu) 10 | F.mu_law_decoding(x_mu, @quantization_channels) 11 | end 12 | end 13 | end 14 | end 15 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/compute_deltas.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class ComputeDeltas < Torch::NN::Module 4 | def initialize(win_length: 5, mode: "replicate") 5 | super() 6 | @win_length = win_length 7 | @mode = mode 8 | end 9 | 10 | def forward(specgram) 11 | F.compute_deltas(specgram, win_length: @win_length, mode: @mode) 12 | end 13 | end 14 | end 15 | end 16 | -------------------------------------------------------------------------------- /Rakefile: 
-------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rake/testtask" 3 | require "rake/extensiontask" 4 | 5 | Rake::TestTask.new do |t| 6 | t.pattern = "test/**/*_test.rb" 7 | end 8 | 9 | task default: :test 10 | 11 | Rake::ExtensionTask.new("torchaudio") do |ext| 12 | ext.name = "ext" 13 | ext.lib_dir = "lib/torchaudio" 14 | end 15 | 16 | task :remove_ext do 17 | Dir["lib/torchaudio/ext.bundle"].each do |path| 18 | File.unlink(path) if File.exist?(path) 19 | end 20 | end 21 | 22 | Rake::Task["build"].enhance [:remove_ext] 23 | -------------------------------------------------------------------------------- /test/datasets_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class DatasetsTest < Minitest::Test 4 | def test_yesno 5 | yesno_data = TorchAudio::Datasets::YESNO.new(root, download: true) 6 | n = 47 7 | waveform, sample_rate, labels = yesno_data[n] 8 | expected = [-0.00024414062, -0.00030517578, -9.1552734e-05, -0.00039672852, -0.00018310547] 9 | assert_elements_in_delta expected, waveform[0][0..4].to_a 10 | assert_equal [1, 49120], waveform.shape 11 | assert_equal 8000, sample_rate 12 | assert_equal [1, 1, 0, 1, 1, 0, 0, 1], labels 13 | end 14 | end 15 | -------------------------------------------------------------------------------- /test/test_helper.rb: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | Bundler.require(:default) 3 | require "minitest/autorun" 4 | 5 | class Minitest::Test 6 | def root 7 | @root ||= ENV["CI"] ? 
"#{ENV["HOME"]}/data" : Dir.tmpdir 8 | end 9 | 10 | def assert_elements_in_delta(expected, actual) 11 | assert_equal expected.size, actual.size 12 | expected.zip(actual) do |exp, act| 13 | assert_in_delta exp, act, 0.0001 14 | end 15 | end 16 | 17 | def audio_path 18 | @test_path ||= begin 19 | TorchAudio::Datasets::YESNO.new(root, download: true) 20 | "#{root}/waves_yesno/0_0_0_0_1_1_1_1.wav" 21 | end 22 | end 23 | end 24 | -------------------------------------------------------------------------------- /torchaudio.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/torchaudio/version" 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "torchaudio" 5 | spec.version = TorchAudio::VERSION 6 | spec.summary = "Data manipulation and transformation for audio signal processing" 7 | spec.homepage = "https://github.com/ankane/torchaudio-ruby" 8 | spec.license = "BSD-2-Clause" 9 | 10 | spec.author = "Andrew Kane" 11 | spec.email = "andrew@ankane.org" 12 | 13 | spec.files = Dir["*.{md,txt}", "{ext,lib}/**/*"] 14 | spec.require_path = "lib" 15 | spec.extensions = ["ext/torchaudio/extconf.rb"] 16 | 17 | spec.required_ruby_version = ">= 3.1" 18 | 19 | spec.add_dependency "torch-rb", ">= 0.13" 20 | spec.add_dependency "rice", ">= 4.3.3" 21 | end 22 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/amplitude_to_db.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class AmplitudeToDB < Torch::NN::Module 4 | def initialize(stype: :power, top_db: nil) 5 | super() 6 | 7 | @stype = stype 8 | 9 | raise ArgumentError, 'top_db must be a positive numerical' if top_db && top_db.negative? 10 | 11 | @top_db = top_db 12 | @multiplier = stype == :power ? 
10.0 : 20.0 13 | @amin = 1e-10 14 | @ref_value = 1.0 15 | @db_multiplier = Math.log10([@amin, @ref_value].max) 16 | end 17 | 18 | def forward(amplitude_spectrogram) 19 | F.amplitude_to_DB( 20 | amplitude_spectrogram, 21 | @multiplier, @amin, @db_multiplier, 22 | top_db: @top_db 23 | ) 24 | end 25 | end 26 | end 27 | end 28 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/vol.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class Vol < Torch::NN::Module 4 | def initialize(gain, gain_type: "amplitude") 5 | super() 6 | @gain = gain 7 | @gain_type = gain_type 8 | 9 | if ["amplitude", "power"].include?(gain_type) && gain < 0 10 | raise ArgumentError, "If gain_type = amplitude or power, gain must be positive." 11 | end 12 | end 13 | 14 | def forward(waveform) 15 | if @gain_type == "amplitude" 16 | waveform = waveform * @gain 17 | end 18 | 19 | if @gain_type == "db" 20 | waveform = F.gain(waveform, @gain) 21 | end 22 | 23 | if @gain_type == "power" 24 | waveform = F.gain(waveform, 10 * Math.log10(@gain)) 25 | end 26 | 27 | Torch.clamp(waveform, -1, 1) 28 | end 29 | end 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.4.1 (2025-06-26) 2 | 3 | - Improved SoX detection for Homebrew 4 | 5 | ## 0.4.0 (2024-08-02) 6 | 7 | - Dropped support for Ruby < 3.1 8 | 9 | ## 0.3.2 (2023-02-27) 10 | 11 | - Improved LibTorch and SoX detection for Homebrew on Mac ARM and Linux 12 | 13 | ## 0.3.1 (2023-01-29) 14 | 15 | - Added `format` option 16 | 17 | ## 0.3.0 (2022-07-06) 18 | 19 | - Added `center`, `pad_mode`, and `onesided` options to `Spectrogram` transform 20 | - Dropped support for Ruby < 2.7 21 | 22 | ## 0.2.1 (2021-07-16) 23 | 24 | - Added `create_dct` method 25 | - Added
`ComputeDeltas`, `Fade`, `MFCC`, and `Vol` transforms 26 | 27 | ## 0.2.0 (2021-05-23) 28 | 29 | - Updated to Rice 4 30 | - Dropped support for Ruby < 2.6 31 | 32 | ## 0.1.2 (2021-02-06) 33 | 34 | - Added `amplitude_to_DB` and `DB_to_amplitude` methods 35 | - Added `AmplitudeToDB` transform 36 | - Fixed `save` options 37 | 38 | ## 0.1.1 (2020-08-26) 39 | 40 | - Added `save` method 41 | - Added transforms 42 | 43 | ## 0.1.0 (2020-08-24) 44 | 45 | - First release 46 | -------------------------------------------------------------------------------- /test/torchaudio_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class TorchAudioTest < Minitest::Test 4 | def test_load_save 5 | waveform, sample_rate = TorchAudio.load(audio_path) 6 | 7 | save_path = "#{Dir.mktmpdir}/save.wav" 8 | TorchAudio.save(save_path, waveform, sample_rate) 9 | assert File.exist?(save_path) 10 | end 11 | 12 | def test_load_missing 13 | error = assert_raises(ArgumentError) do 14 | TorchAudio.load("missing.wav") 15 | end 16 | assert_equal "missing.wav not found or is a directory", error.message 17 | end 18 | 19 | def test_load_wav 20 | out, sample_rate = TorchAudio.load_wav(audio_path) 21 | assert_equal [1, 50800], out.shape 22 | assert_equal [1, 2, 1, 1, 1], out[0][0..4].to_a 23 | assert_equal 8000, sample_rate 24 | end 25 | 26 | def test_save_sample_rate 27 | save_path = "#{Dir.mktmpdir}/save.wav" 28 | TorchAudio.save(save_path, Torch.zeros([1, 16000]), 16000) 29 | _, sample_rate = TorchAudio.load(save_path) 30 | assert_equal 16000, sample_rate 31 | end 32 | end 33 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox_io.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCHAUDIO_SOX_IO_H 2 | #define TORCHAUDIO_SOX_IO_H 3 | 4 | #include 5 | #include 6 | 7 | namespace torchaudio { 8 | namespace sox_io { 9 | 10 | struct 
SignalInfo : torch::CustomClassHolder { 11 | int64_t sample_rate; 12 | int64_t num_channels; 13 | int64_t num_frames; 14 | 15 | SignalInfo( 16 | const int64_t sample_rate_, 17 | const int64_t num_channels_, 18 | const int64_t num_frames_); 19 | int64_t getSampleRate() const; 20 | int64_t getNumChannels() const; 21 | int64_t getNumFrames() const; 22 | }; 23 | 24 | c10::intrusive_ptr get_info(const std::string& path); 25 | 26 | c10::intrusive_ptr load_audio_file( 27 | const std::string& path, 28 | const int64_t frame_offset = 0, 29 | const int64_t num_frames = -1, 30 | const bool normalize = true, 31 | const bool channels_first = true); 32 | 33 | void save_audio_file( 34 | const std::string& file_name, 35 | const c10::intrusive_ptr& signal, 36 | const double compression = 0.); 37 | 38 | } // namespace sox_io 39 | } // namespace torchaudio 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-latest 6 | env: 7 | BUNDLE_BUILD__TORCH___RB: "--with-torch-dir=/home/runner/libtorch" 8 | LIBTORCH_VERSION: 2.9.0 9 | steps: 10 | - uses: actions/checkout@v5 11 | - uses: actions/cache@v4 12 | with: 13 | path: ~/libtorch 14 | key: libtorch-${{ env.LIBTORCH_VERSION }} 15 | id: cache-libtorch 16 | - name: Download LibTorch 17 | if: steps.cache-libtorch.outputs.cache-hit != 'true' 18 | run: | 19 | cd ~ 20 | wget -q -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-$LIBTORCH_VERSION%2Bcpu.zip 21 | unzip -q libtorch.zip 22 | - run: sudo apt-get update && sudo apt-get install sox libsox-dev libsox-fmt-all 23 | - uses: ruby/setup-ruby@v1 24 | with: 25 | ruby-version: 3.4 26 | bundler-cache: true 27 | - uses: actions/cache@v4 28 | with: 29 | path: ~/data 30 | key: data 31 | - run: mkdir -p ~/data 32 | - run: 
bundle exec rake compile -- --with-torch-dir=$HOME/libtorch 33 | - run: bundle exec rake test 34 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/spectrogram.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class Spectrogram < Torch::NN::Module 4 | def initialize( 5 | n_fft: 400, 6 | win_length: nil, 7 | hop_length: nil, 8 | pad: 0, 9 | window_fn: Torch.method(:hann_window), 10 | power: 2.0, 11 | normalized: false, 12 | wkwargs: nil, 13 | center: true, 14 | pad_mode: "reflect", 15 | onesided: true 16 | ) 17 | super() 18 | @n_fft = n_fft 19 | # number of FFT bins. the returned STFT result will have n_fft.div(2) + 1 20 | # frequencies when onesided is true (the default), as in Torch.stft 21 | @win_length = win_length || n_fft 22 | @hop_length = hop_length || @win_length.div(2) # floor division 23 | window = wkwargs.nil? ? window_fn.call(@win_length) : window_fn.call(@win_length, **wkwargs) 24 | register_buffer("window", window) 25 | @pad = pad 26 | @power = power 27 | @normalized = normalized 28 | @center = center 29 | @pad_mode = pad_mode 30 | @onesided = onesided 31 | end 32 | 33 | def forward(waveform) 34 | F.spectrogram( 35 | waveform, @pad, @window, @n_fft, @hop_length, @win_length, @power, @normalized, 36 | center: @center, pad_mode: @pad_mode, onesided: @onesided 37 | ) 38 | end 39 | end 40 | end 41 | end 42 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2017 Facebook Inc. (Soumith Chintala), 4 | Copyright (c) 2020-2025 Andrew Kane, 5 | All rights reserved.
6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 21 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 22 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 23 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 24 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 25 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/mel_spectrogram.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class MelSpectrogram < Torch::NN::Module 4 | attr_reader :n_mels 5 | 6 | def initialize( 7 | sample_rate: 16000, 8 | n_fft: 400, 9 | win_length: nil, 10 | hop_length: nil, 11 | f_min: 0.0, 12 | f_max: nil, 13 | pad: 0, 14 | n_mels: 128, 15 | window_fn: Torch.method(:hann_window), 16 | power: 2.0, 17 | normalized: false, 18 | wkwargs: nil 19 | ) 20 | super() 21 | @sample_rate = sample_rate 22 | @n_fft = n_fft 23 | @win_length = win_length || n_fft 24 | @hop_length = hop_length || @win_length.div(2) 25 | @pad = pad 26 | @power = power 27 | @normalized = normalized 28 | @n_mels = n_mels # number of mel frequency bins 29 | @f_max = f_max 30 | @f_min = f_min 31 | @spectrogram = 32 | Spectrogram.new( 33 | n_fft: @n_fft, win_length: @win_length, hop_length: @hop_length, pad: @pad, 34 | window_fn: window_fn, power: @power, normalized: @normalized, wkwargs: wkwargs 35 | ) 36 | @mel_scale = MelScale.new(n_mels: @n_mels, sample_rate: @sample_rate, f_min: @f_min, f_max: @f_max, n_stft: @n_fft.div(2) + 1) 37 | end 38 | 39 | def forward(waveform) 40 | specgram = @spectrogram.call(waveform) 41 | @mel_scale.call(specgram) 42 | end 43 | end 44 | end 45 | end 46 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/mel_scale.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class MelScale < Torch::NN::Module 4 | def initialize(n_mels: 128, sample_rate: 16000, f_min: 0.0, f_max: nil, n_stft: nil) 5 | super() 6 | @n_mels = n_mels 7 | @sample_rate = sample_rate 8 | @f_max = f_max || sample_rate.div(2).to_f 9 | @f_min = f_min 10 | 11 | raise ArgumentError, "Require f_min: %f < f_max: %f" % 
[f_min, @f_max] unless f_min <= @f_max 12 | 13 | fb = n_stft.nil? ? Torch.empty(0) : F.create_fb_matrix(n_stft, @f_min, @f_max, @n_mels, @sample_rate) 14 | register_buffer("fb", fb) 15 | end 16 | 17 | def forward(specgram) 18 | shape = specgram.size 19 | specgram = specgram.reshape(-1, shape[-2], shape[-1]) 20 | 21 | if @fb.numel == 0 22 | tmp_fb = F.create_fb_matrix(specgram.size(1), @f_min, @f_max, @n_mels, @sample_rate) 23 | # Attributes cannot be reassigned outside __init__ so workaround 24 | @fb.resize!(tmp_fb.size) 25 | @fb.copy!(tmp_fb) 26 | end 27 | 28 | # (channel, frequency, time).transpose(...) dot (frequency, n_mels) 29 | # -> (channel, time, n_mels).transpose(...) 30 | mel_specgram = Torch.matmul(specgram.transpose(1, 2), @fb).transpose(1, 2) 31 | 32 | # unpack batch 33 | mel_specgram = mel_specgram.reshape(shape[0...-2] + mel_specgram.shape[-2..-1]) 34 | 35 | mel_specgram 36 | end 37 | end 38 | end 39 | end 40 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/mfcc.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class MFCC < Torch::NN::Module 4 | SUPPORTED_DCT_TYPES = [2] 5 | 6 | def initialize(sample_rate: 16000, n_mfcc: 40, dct_type: 2, norm: :ortho, log_mels: false, melkwargs: {}) 7 | super() 8 | 9 | raise ArgumentError, "DCT type not supported: #{dct_type}" unless SUPPORTED_DCT_TYPES.include?(dct_type) 10 | 11 | @sample_rate = sample_rate 12 | @n_mfcc = n_mfcc 13 | @dct_type = dct_type 14 | @norm = norm 15 | @top_db = 80.0 16 | @amplitude_to_db = TorchAudio::Transforms::AmplitudeToDB.new(stype: :power, top_db: @top_db) 17 | 18 | @melspectrogram = TorchAudio::Transforms::MelSpectrogram.new(sample_rate: @sample_rate, **melkwargs) 19 | 20 | raise ArgumentError, "Cannot select more MFCC coefficients than # mel bins" if @n_mfcc > @melspectrogram.n_mels 21 | 22 | dct_mat = F.create_dct(@n_mfcc, 
@melspectrogram.n_mels, norm: @norm) 23 | register_buffer('dct_mat', dct_mat) 24 | 25 | @log_mels = log_mels 26 | end 27 | 28 | def forward(waveform) 29 | mel_specgram = @melspectrogram.(waveform) 30 | if @log_mels 31 | mel_specgram = Torch.log(mel_specgram + 1e-6) 32 | else 33 | mel_specgram = @amplitude_to_db.(mel_specgram) 34 | end 35 | 36 | Torch 37 | .matmul(mel_specgram.transpose(-2, -1), @dct_mat) 38 | .transpose(-2, -1) 39 | end 40 | end 41 | end 42 | end 43 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox_effects.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using namespace torch::indexing; 5 | 6 | namespace torchaudio { 7 | namespace sox_effects { 8 | 9 | namespace { 10 | 11 | enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown }; 12 | SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized; 13 | 14 | } // namespace 15 | 16 | void initialize_sox_effects() { 17 | if (SOX_RESOURCE_STATE == ShutDown) { 18 | throw std::runtime_error( 19 | "SoX Effects has been shut down. Cannot initialize again."); 20 | } 21 | if (SOX_RESOURCE_STATE == NotInitialized) { 22 | if (sox_init() != SOX_SUCCESS) { 23 | throw std::runtime_error("Failed to initialize sox effects."); 24 | }; 25 | SOX_RESOURCE_STATE = Initialized; 26 | } 27 | }; 28 | 29 | void shutdown_sox_effects() { 30 | if (SOX_RESOURCE_STATE == NotInitialized) { 31 | throw std::runtime_error( 32 | "SoX Effects is not initialized. 
Cannot shutdown."); 33 | } 34 | if (SOX_RESOURCE_STATE == Initialized) { 35 | if (sox_quit() != SOX_SUCCESS) { 36 | throw std::runtime_error("Failed to shut down sox effects."); 37 | }; 38 | SOX_RESOURCE_STATE = ShutDown; 39 | } 40 | } 41 | 42 | std::vector list_effects() { 43 | std::vector names; 44 | const sox_effect_fn_t* fns = sox_get_effect_fns(); 45 | for (int i = 0; fns[i]; ++i) { 46 | const sox_effect_handler_t* handler = fns[i](); 47 | if (handler && handler->name) 48 | names.push_back(handler->name); 49 | } 50 | return names; 51 | } 52 | 53 | } // namespace sox_effects 54 | } // namespace torchaudio 55 | -------------------------------------------------------------------------------- /examples/LICENSE-tutorial.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Pytorch contributors 4 | Copyright (c) 2020, Andrew Kane 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /lib/torchaudio/datasets/yesno.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Datasets 3 | class YESNO < Torch::Utils::Data::Dataset 4 | URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz" 5 | FOLDER_IN_ARCHIVE = "waves_yesno" 6 | CHECKSUMS = { 7 | "http://www.openslr.org/resources/1/waves_yesno.tar.gz" => "962ff6e904d2df1126132ecec6978786" 8 | } 9 | 10 | def initialize(root, url: URL, folder_in_archive: FOLDER_IN_ARCHIVE, download: false) 11 | archive = File.basename(url) 12 | archive = File.join(root, archive) 13 | @path = File.join(root, folder_in_archive) 14 | 15 | if download 16 | unless Dir.exist?(@path) 17 | unless File.exist?(archive) 18 | checksum = CHECKSUMS.fetch(url) 19 | Utils.download_url(url, root, hash_value: checksum, hash_type: "md5") 20 | end 21 | Utils.extract_archive(archive) 22 | end 23 | end 24 | 25 | unless Dir.exist?(@path) 26 | raise "Dataset not found. Please use `download: true` to download it." 
27 | end 28 | 29 | walker = Utils.walk_files(@path, ext_audio, prefix: false, remove_suffix: true) 30 | @walker = walker.to_a 31 | end 32 | 33 | def [](n) 34 | fileid = @walker[n] 35 | load_yesno_item(fileid, @path, ext_audio) 36 | end 37 | 38 | def length 39 | @walker.length 40 | end 41 | alias_method :size, :length 42 | 43 | private 44 | 45 | def load_yesno_item(fileid, path, ext_audio) 46 | labels = fileid.split("_").map(&:to_i) 47 | 48 | file_audio = File.join(path, fileid + ext_audio) 49 | waveform, sample_rate = TorchAudio.load(file_audio) 50 | 51 | [waveform, sample_rate, labels] 52 | end 53 | 54 | def ext_audio 55 | ".wav" 56 | end 57 | end 58 | end 59 | end 60 | -------------------------------------------------------------------------------- /ext/torchaudio/ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | extern "C" 9 | void Init_ext() { 10 | auto rb_mTorchAudio = Rice::define_module("TorchAudio"); 11 | 12 | auto rb_mExt = Rice::define_module_under(rb_mTorchAudio, "Ext") 13 | .define_singleton_function( 14 | "read_audio_file", 15 | [](const std::string& file_name, at::Tensor output, bool ch_first, int64_t nframes, int64_t offset, sox_signalinfo_t* si, sox_encodinginfo_t* ei, const char* ft) { 16 | return torch::audio::read_audio_file(file_name, output, ch_first, nframes, offset, si, ei, ft); 17 | }) 18 | .define_singleton_function( 19 | "write_audio_file", 20 | [](const std::string& file_name, const at::Tensor& tensor, sox_signalinfo_t* si, sox_encodinginfo_t* ei, const char* file_type) { 21 | return torch::audio::write_audio_file(file_name, tensor, si, ei, file_type); 22 | }); 23 | 24 | auto rb_cSignalInfo = Rice::define_class_under(rb_mExt, "SignalInfo") 25 | .define_constructor(Rice::Constructor()) 26 | .define_method("rate", [](sox_signalinfo_t& self) { return self.rate; }) 27 | .define_method("channels", [](sox_signalinfo_t& self) { return 
self.channels; }) 28 | .define_method("precision", [](sox_signalinfo_t& self) { return self.precision; }) 29 | .define_method("length", [](sox_signalinfo_t& self) { return self.length; }) 30 | .define_method("rate=", [](sox_signalinfo_t& self, sox_rate_t rate) { self.rate = rate; }) 31 | .define_method("channels=", [](sox_signalinfo_t& self, unsigned channels) { self.channels = channels; }) 32 | .define_method("precision=", [](sox_signalinfo_t& self, unsigned precision) { self.precision = precision; }) 33 | .define_method("length=", [](sox_signalinfo_t& self, sox_uint64_t length) { self.length = length; }); 34 | } 35 | -------------------------------------------------------------------------------- /test/functional_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class FunctionalTest < Minitest::Test 4 | def test_compute_deltas 5 | waveform, _ = TorchAudio.load(audio_path) 6 | transformed = TorchAudio::Functional.compute_deltas(waveform) 7 | assert_equal [1, 50800], transformed.shape 8 | expected = [3.0517579e-06, 0.0, -3.0517579e-06, -6.1035157e-06, 0.0] 9 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 10 | end 11 | 12 | def test_gain 13 | waveform, _ = TorchAudio.load(audio_path) 14 | transformed = TorchAudio::Functional.gain(waveform) 15 | assert_equal [1, 50800], transformed.shape 16 | expected = [3.4241286e-05, 6.848257e-05, 3.4241286e-05, 3.4241286e-05, 3.4241286e-05] 17 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 18 | end 19 | 20 | def test_dither 21 | waveform, _ = TorchAudio.load(audio_path) 22 | transformed = TorchAudio::Functional.dither(waveform) 23 | assert_equal [1, 50800], transformed.shape 24 | expected = [3.0517578e-05, 6.1035156e-05, 3.0517578e-05, 3.0517578e-05, 3.0517578e-05] 25 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 26 | end 27 | 28 | def test_lowpass_biquad 29 | waveform, sample_rate = 
TorchAudio.load(audio_path) 30 | transformed = TorchAudio::Functional.lowpass_biquad(waveform, sample_rate, 3000) 31 | assert_equal [1, 50800], transformed.shape 32 | expected = [1.7364715e-05, 5.308807e-05, 4.8351816e-05, 2.3546876e-05, 3.114574e-05] 33 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 34 | end 35 | 36 | def test_highpass_biquad 37 | waveform, sample_rate = TorchAudio.load(audio_path) 38 | transformed = TorchAudio::Functional.highpass_biquad(waveform, sample_rate, 3000) 39 | assert_equal [1, 50800], transformed.shape 40 | expected = [2.979314e-06, -2.808783e-06, -4.30352e-06, 7.97258e-06, -6.082024e-06] 41 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 42 | end 43 | end 44 | -------------------------------------------------------------------------------- /lib/torchaudio/transforms/fade.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Transforms 3 | class Fade < Torch::NN::Module 4 | def initialize(fade_in_len: 0, fade_out_len: 0, fade_shape: "linear") 5 | super() 6 | @fade_in_len = fade_in_len 7 | @fade_out_len = fade_out_len 8 | @fade_shape = fade_shape 9 | end 10 | 11 | def forward(waveform) 12 | waveform_length = waveform.size[-1] 13 | device = waveform.device 14 | fade_in(waveform_length).to(device) * fade_out(waveform_length).to(device) * waveform 15 | end 16 | 17 | private 18 | 19 | def fade_in(waveform_length) 20 | fade = Torch.linspace(0, 1, @fade_in_len) 21 | ones = Torch.ones(waveform_length - @fade_in_len) 22 | 23 | if @fade_shape == "linear" 24 | fade = fade 25 | end 26 | 27 | if @fade_shape == "exponential" 28 | fade = Torch.pow(2, (fade - 1)) * fade 29 | end 30 | 31 | if @fade_shape == "logarithmic" 32 | fade = Torch.log10(0.1 + fade) + 1 33 | end 34 | 35 | if @fade_shape == "quarter_sine" 36 | fade = Torch.sin(fade * Math::PI / 2) 37 | end 38 | 39 | if @fade_shape == "half_sine" 40 | fade = Torch.sin(fade * Math::PI - Math::PI / 2) / 2 
+ 0.5 41 | end 42 | 43 | Torch.cat([fade, ones]).clamp!(0, 1) 44 | end 45 | 46 | def fade_out(waveform_length) 47 | fade = Torch.linspace(0, 1, @fade_out_len) 48 | ones = Torch.ones(waveform_length - @fade_out_len) 49 | 50 | if @fade_shape == "linear" 51 | fade = - fade + 1 52 | end 53 | 54 | if @fade_shape == "exponential" 55 | fade = Torch.pow(2, - fade) * (1 - fade) 56 | end 57 | 58 | if @fade_shape == "logarithmic" 59 | fade = Torch.log10(1.1 - fade) + 1 60 | end 61 | 62 | if @fade_shape == "quarter_sine" 63 | fade = Torch.sin(fade * Math::PI / 2 + Math::PI / 2) 64 | end 65 | 66 | if @fade_shape == "half_sine" 67 | fade = Torch.sin(fade * Math::PI + Math::PI / 2) / 2 + 0.5 68 | end 69 | 70 | Torch.cat([ones, fade]).clamp!(0, 1) 71 | end 72 | end 73 | end 74 | end 75 | -------------------------------------------------------------------------------- /ext/torchaudio/extconf.rb: -------------------------------------------------------------------------------- 1 | require "mkmf-rice" 2 | 3 | $CXXFLAGS += " -std=c++17 $(optflags)" 4 | 5 | ext = File.expand_path(".", __dir__) 6 | csrc = File.expand_path("csrc", __dir__) 7 | 8 | $srcs = Dir["{#{ext},#{csrc}}/*.cpp"] 9 | $INCFLAGS << " -I#{File.expand_path("..", __dir__)}" 10 | $VPATH << csrc 11 | 12 | # 13 | # keep rest synced with Torch 14 | # 15 | 16 | # change to 0 for Linux pre-cxx11 ABI version 17 | $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1" 18 | 19 | apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i 20 | 21 | if apple_clang 22 | # silence torch warnings 23 | $CXXFLAGS += " -Wno-deprecated-declarations" 24 | else 25 | # silence rice warnings 26 | $CXXFLAGS += " -Wno-noexcept-type" 27 | 28 | # silence torch warnings 29 | $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn" 30 | end 31 | 32 | paths = [ 33 | "/usr/local", 34 | "/opt/homebrew", 35 | "/home/linuxbrew/.linuxbrew" 36 | ] 37 | 38 | inc, lib = dir_config("torch") 39 | inc ||= paths.map { |v| "#{v}/include" }.find { |v| 
Dir.exist?("#{v}/torch") } 40 | lib ||= paths.map { |v| "#{v}/lib" }.find { |v| Dir["#{v}/*torch_cpu*"].any? } 41 | 42 | unless inc && lib 43 | abort "LibTorch not found" 44 | end 45 | 46 | cuda_inc, cuda_lib = dir_config("cuda") 47 | cuda_inc ||= "/usr/local/cuda/include" 48 | cuda_lib ||= "/usr/local/cuda/lib64" 49 | 50 | $LDFLAGS += " -L#{lib}" if Dir.exist?(lib) 51 | abort "LibTorch not found" unless have_library("torch") 52 | 53 | have_library("mkldnn") 54 | have_library("nnpack") 55 | 56 | with_cuda = false 57 | if Dir["#{lib}/*torch_cuda*"].any? 58 | $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib) 59 | with_cuda = have_library("cuda") && have_library("cudnn") 60 | end 61 | 62 | $INCFLAGS += " -I#{inc}" 63 | $INCFLAGS += " -I#{inc}/torch/csrc/api/include" 64 | 65 | $LDFLAGS += " -Wl,-rpath,#{lib}" 66 | $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda 67 | 68 | # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238 69 | $LDFLAGS += " -lc10 -ltorch_cpu -ltorch" 70 | if with_cuda 71 | $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn" 72 | # TODO figure out why this is needed 73 | $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so" 74 | end 75 | 76 | sox_inc, sox_lib = dir_config("sox") 77 | sox_inc ||= paths.map { |v| "#{v}/include" }.find { |v| File.exist?("#{v}/sox.h") } 78 | sox_lib ||= paths.map { |v| "#{v}/lib" }.find { |v| Dir["#{v}/*libsox*"].any? 
} 79 | 80 | $INCFLAGS += " -I#{sox_inc}" if sox_inc 81 | $LDFLAGS += " -L#{sox_lib}" if sox_lib 82 | abort "SoX not found" unless have_library("sox") 83 | 84 | # create makefile 85 | create_makefile("torchaudio/ext") 86 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // same as without 9 | #include 10 | 11 | namespace at { 12 | struct Tensor; 13 | } // namespace at 14 | 15 | namespace torch { namespace audio { 16 | 17 | /// Reads an audio file from the given `path` into the `output` `Tensor` and 18 | /// returns the sample rate of the audio file. 19 | /// Throws `std::runtime_error` if the audio file could not be opened, or an 20 | /// error occurred during reading of the audio data. 21 | int read_audio_file( 22 | const std::string& file_name, 23 | at::Tensor output, 24 | bool ch_first, 25 | int64_t nframes, 26 | int64_t offset, 27 | sox_signalinfo_t* si, 28 | sox_encodinginfo_t* ei, 29 | const char* ft); 30 | 31 | /// Writes the data of a `Tensor` into an audio file at the given `path`, with 32 | /// a certain extension (e.g. `wav`or `mp3`) and sample rate. 33 | /// Throws `std::runtime_error` when the audio file could not be opened for 34 | /// writing, or an error occurred during writing of the audio data. 35 | void write_audio_file( 36 | const std::string& file_name, 37 | const at::Tensor& tensor, 38 | sox_signalinfo_t* si, 39 | sox_encodinginfo_t* ei, 40 | const char* file_type); 41 | 42 | /// Reads an audio file from the given `path` and returns a tuple of 43 | /// sox_signalinfo_t and sox_encodinginfo_t, which contain information about 44 | /// the audio file such as sample rate, length, bit precision, encoding and more. 
45 | /// Throws `std::runtime_error` if the audio file could not be opened, or an 46 | /// error occurred during reading of the audio data. 47 | std::tuple get_info( 48 | const std::string& file_name); 49 | 50 | // Struct for build_flow_effects function 51 | struct SoxEffect { 52 | SoxEffect() : ename(""), eopts({""}) { } 53 | std::string ename; 54 | std::vector eopts; 55 | }; 56 | 57 | /// Build a SoX chain, flow the effects, and capture the results in a tensor. 58 | /// An audio file from the given `path` flows through an effects chain given 59 | /// by a list of effects and effect options to an output buffer which is encoded 60 | /// into memory to a target signal type and target signal encoding. The resulting 61 | /// buffer is then placed into a tensor. This function returns the output tensor 62 | /// and the sample rate of the output tensor. 63 | int build_flow_effects(const std::string& file_name, 64 | at::Tensor otensor, 65 | bool ch_first, 66 | sox_signalinfo_t* target_signal, 67 | sox_encodinginfo_t* target_encoding, 68 | const char* file_type, 69 | std::vector pyeffs, 70 | int max_num_eopts); 71 | }} // namespace torch::audio 72 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/register.cpp: -------------------------------------------------------------------------------- 1 | #ifndef TORCHAUDIO_REGISTER_H 2 | #define TORCHAUDIO_REGISTER_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace torchaudio { 9 | namespace { 10 | 11 | //////////////////////////////////////////////////////////////////////////////// 12 | // sox_utils.h 13 | //////////////////////////////////////////////////////////////////////////////// 14 | static auto registerTensorSignal = 15 | torch::class_("torchaudio", "TensorSignal") 16 | .def(torch::init()) 17 | .def("get_tensor", &sox_utils::TensorSignal::getTensor) 18 | .def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate) 19 | .def("get_channels_first", 
&sox_utils::TensorSignal::getChannelsFirst); 20 | 21 | //////////////////////////////////////////////////////////////////////////////// 22 | // sox_io.h 23 | //////////////////////////////////////////////////////////////////////////////// 24 | static auto registerSignalInfo = 25 | torch::class_("torchaudio", "SignalInfo") 26 | .def("get_sample_rate", &sox_io::SignalInfo::getSampleRate) 27 | .def("get_num_channels", &sox_io::SignalInfo::getNumChannels) 28 | .def("get_num_frames", &sox_io::SignalInfo::getNumFrames); 29 | 30 | static auto registerGetInfo = torch::RegisterOperators().op( 31 | torch::RegisterOperators::options() 32 | .schema( 33 | "torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info") 34 | .catchAllKernel()); 35 | 36 | static auto registerLoadAudioFile = torch::RegisterOperators().op( 37 | torch::RegisterOperators::options() 38 | .schema( 39 | "torchaudio::sox_io_load_audio_file(str path, int frame_offset, int num_frames, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal signal") 40 | .catchAllKernel< 41 | decltype(sox_io::load_audio_file), 42 | &sox_io::load_audio_file>()); 43 | 44 | static auto registerSaveAudioFile = torch::RegisterOperators().op( 45 | torch::RegisterOperators::options() 46 | .schema( 47 | "torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()") 48 | .catchAllKernel< 49 | decltype(sox_io::save_audio_file), 50 | &sox_io::save_audio_file>()); 51 | 52 | //////////////////////////////////////////////////////////////////////////////// 53 | // sox_effects.h 54 | //////////////////////////////////////////////////////////////////////////////// 55 | static auto registerSoxEffects = 56 | torch::RegisterOperators( 57 | "torchaudio::sox_effects_initialize_sox_effects", 58 | &sox_effects::initialize_sox_effects) 59 | .op("torchaudio::sox_effects_shutdown_sox_effects", 60 | 
&sox_effects::shutdown_sox_effects) 61 | .op("torchaudio::sox_effects_list_effects", &sox_effects::list_effects); 62 | 63 | } // namespace 64 | } // namespace torchaudio 65 | #endif 66 | -------------------------------------------------------------------------------- /lib/torchaudio/datasets/utils.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Datasets 3 | module Utils 4 | class << self 5 | def download_url(url, download_folder, filename: nil, hash_value: nil, hash_type: "sha256") 6 | filename ||= File.basename(url) 7 | filepath = File.join(download_folder, filename) 8 | 9 | if File.exist?(filepath) 10 | raise "#{filepath} already exists. Delete the file manually and retry." 11 | end 12 | 13 | puts "Downloading #{url}..." 14 | download_url_to_file(url, filepath, hash_value, hash_type) 15 | end 16 | 17 | # follows redirects 18 | def download_url_to_file(url, dst, hash_value, hash_type, redirects = 0) 19 | raise "Too many redirects" if redirects > 10 20 | 21 | uri = URI(url) 22 | tmp = nil 23 | location = nil 24 | 25 | Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http| 26 | request = Net::HTTP::Get.new(uri) 27 | 28 | http.request(request) do |response| 29 | case response 30 | when Net::HTTPRedirection 31 | location = response["location"] 32 | when Net::HTTPSuccess 33 | tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name 34 | File.open(tmp, "wb") do |f| 35 | response.read_body do |chunk| 36 | f.write(chunk) 37 | end 38 | end 39 | else 40 | raise Error, "Bad response" 41 | end 42 | end 43 | end 44 | 45 | if location 46 | download_url_to_file(location, dst, hash_value, hash_type, redirects + 1) 47 | else 48 | # check hash 49 | # TODO use hash_type 50 | if Digest::MD5.file(tmp).hexdigest != hash_value 51 | raise "The hash of #{dst} does not match. Delete the file manually and retry." 
52 | end 53 | 54 | FileUtils.mv(tmp, dst) 55 | dst 56 | end 57 | end 58 | 59 | # extract_tar_gz doesn't list files, so just return to_path 60 | def extract_archive(from_path, to_path: nil, overwrite: nil) 61 | to_path ||= File.dirname(from_path) 62 | 63 | if from_path.end_with?(".tar.gz") || from_path.end_with?(".tgz") 64 | File.open(from_path, "rb") do |io| 65 | Gem::Package.new("").extract_tar_gz(io, to_path) 66 | end 67 | return to_path 68 | end 69 | 70 | raise "We currently only support tar.gz and tgz archives." 71 | end 72 | 73 | def walk_files(root, suffix, prefix: false, remove_suffix: false) 74 | return enum_for(:walk_files, root, suffix, prefix: prefix, remove_suffix: remove_suffix) unless block_given? 75 | 76 | Dir.glob("**/*", base: root).sort.each do |f| 77 | if f.end_with?(suffix) 78 | if remove_suffix 79 | f = f[0..(-suffix.length - 1)] 80 | end 81 | 82 | if prefix 83 | raise "Not implemented yet" 84 | # f = File.join(dirpath, f) 85 | end 86 | 87 | yield f 88 | end 89 | end 90 | end 91 | end 92 | end 93 | end 94 | end 95 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCHAUDIO_SOX_UTILS_H 2 | #define TORCHAUDIO_SOX_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | namespace torchaudio { 8 | namespace sox_utils { 9 | 10 | struct TensorSignal : torch::CustomClassHolder { 11 | torch::Tensor tensor; 12 | int64_t sample_rate; 13 | bool channels_first; 14 | 15 | TensorSignal( 16 | torch::Tensor tensor_, 17 | int64_t sample_rate_, 18 | bool channels_first_); 19 | 20 | torch::Tensor getTensor() const; 21 | int64_t getSampleRate() const; 22 | bool getChannelsFirst() const; 23 | }; 24 | 25 | /// helper class to automatically close sox_format_t* 26 | struct SoxFormat { 27 | explicit SoxFormat(sox_format_t* fd) noexcept; 28 | SoxFormat(const SoxFormat& other) = delete; 29 | SoxFormat(SoxFormat&& other) = 
delete; 30 | SoxFormat& operator=(const SoxFormat& other) = delete; 31 | SoxFormat& operator=(SoxFormat&& other) = delete; 32 | ~SoxFormat(); 33 | sox_format_t* operator->() const noexcept; 34 | operator sox_format_t*() const noexcept; 35 | 36 | private: 37 | sox_format_t* fd_; 38 | }; 39 | 40 | /// 41 | /// Verify that input file is found, has known encoding, and not empty 42 | void validate_input_file(const SoxFormat& sf); 43 | 44 | /// 45 | /// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32 46 | void validate_input_tensor(const torch::Tensor); 47 | 48 | /// 49 | /// Get target dtype for the given encoding and precision. 50 | caffe2::TypeMeta get_dtype( 51 | const sox_encoding_t encoding, 52 | const unsigned precision); 53 | 54 | /// 55 | /// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor 56 | /// NOTE: This function might modify the values in the input buffer to 57 | /// reduce the number of memory copy. 58 | /// @param buffer Pointer to buffer that contains audio data. 59 | /// @param num_samples The number of samples to read. 60 | /// @param num_channels The number of channels. Used to reshape the resulting 61 | /// Tensor. 62 | /// @param dtype Target dtype. Determines the output dtype and value range in 63 | /// conjunction with normalization. 64 | /// @param noramlize Perform normalization. Only effective when dtype is not 65 | /// kFloat32. When effective, the output tensor is kFloat32 type and value range 66 | /// is [-1.0, 1.0] 67 | /// @param channels_first When True, output Tensor has shape of [num_channels, 68 | /// num_frames]. 69 | torch::Tensor convert_to_tensor( 70 | sox_sample_t* buffer, 71 | const int32_t num_samples, 72 | const int32_t num_channels, 73 | const caffe2::TypeMeta dtype, 74 | const bool normalize, 75 | const bool channels_first); 76 | 77 | /// 78 | /// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox 79 | /// conversion. 
80 | torch::Tensor unnormalize_wav(const torch::Tensor); 81 | 82 | /// Extract extension from file path 83 | const std::string get_filetype(const std::string path); 84 | 85 | /// Get sox_signalinfo_t for passing a torch::Tensor object. 86 | sox_signalinfo_t get_signalinfo( 87 | const torch::Tensor& tensor, 88 | const int64_t sample_rate, 89 | const bool channels_first, 90 | const std::string filetype); 91 | 92 | /// Get sox_encofinginfo_t for saving audoi file 93 | sox_encodinginfo_t get_encodinginfo( 94 | const std::string filetype, 95 | const caffe2::TypeMeta dtype, 96 | const double compression); 97 | 98 | } // namespace sox_utils 99 | } // namespace torchaudio 100 | #endif 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchAudio Ruby 2 | 3 | :fire: An audio library for Torch.rb 4 | 5 | [![Build Status](https://github.com/ankane/torchaudio-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/torchaudio-ruby/actions) 6 | 7 | ## Installation 8 | 9 | First, [install SoX](#sox-installation). For Homebrew, use: 10 | 11 | ```sh 12 | brew install sox 13 | ``` 14 | 15 | Add this line to your application’s Gemfile: 16 | 17 | ```ruby 18 | gem "torchaudio" 19 | ``` 20 | 21 | ## Getting Started 22 | 23 | This library follows the [Python API](https://pytorch.org/audio/). Many methods and options are missing at the moment. PRs welcome! 24 | 25 | ## Tutorial 26 | 27 | - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html) 28 | - [Ruby code](examples/tutorial.rb) 29 | 30 | Download the [audio file](https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav) and install the [matplotlib](https://github.com/mrkn/matplotlib.rb) gem first. 
31 | 32 | ## Basics 33 | 34 | Load a file 35 | 36 | ```ruby 37 | waveform, sample_rate = TorchAudio.load("file.wav") 38 | ``` 39 | 40 | Save a file 41 | 42 | ```ruby 43 | TorchAudio.save("new.wav", waveform, sample_rate) 44 | ``` 45 | 46 | ## Transforms 47 | 48 | ```ruby 49 | TorchAudio::Transforms::Spectrogram.new.call(waveform) 50 | ``` 51 | 52 | Supported transforms are: 53 | 54 | - AmplitudeToDB 55 | - ComputeDeltas 56 | - Fade 57 | - MelScale 58 | - MelSpectrogram 59 | - MFCC 60 | - MuLawDecoding 61 | - MuLawEncoding 62 | - Spectrogram 63 | - Vol 64 | 65 | ## Functional 66 | 67 | ```ruby 68 | TorchAudio::Functional.lowpass_biquad(waveform, sample_rate, cutoff_freq) 69 | ``` 70 | 71 | Supported functions are: 72 | 73 | - amplitude_to_DB 74 | - compute_deltas 75 | - create_dct 76 | - create_fb_matrix 77 | - DB_to_amplitude 78 | - dither 79 | - gain 80 | - highpass_biquad 81 | - lowpass_biquad 82 | - mu_law_decoding 83 | - mu_law_encoding 84 | - spectrogram 85 | 86 | ## Datasets 87 | 88 | Load a dataset 89 | 90 | ```ruby 91 | TorchAudio::Datasets::YESNO.new(".", download: true) 92 | ``` 93 | 94 | Supported datasets are: 95 | 96 | - [YESNO](https://www.openslr.org/1/) 97 | 98 | ## Disclaimer 99 | 100 | This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset. 101 | 102 | If you’re a dataset owner and wish to update any details or remove it from this project, let us know.
103 | 104 | ## SoX Installation 105 | 106 | ### Mac 107 | 108 | ```sh 109 | brew install sox 110 | ``` 111 | 112 | ### Windows 113 | 114 | todo 115 | 116 | ### Ubuntu 117 | 118 | ```sh 119 | sudo apt install sox libsox-dev libsox-fmt-all 120 | ``` 121 | 122 | ### Travis CI 123 | 124 | Add to `.travis.yml`: 125 | 126 | ```yml 127 | addons: 128 | apt: 129 | packages: 130 | - sox 131 | - libsox-dev 132 | - libsox-fmt-all 133 | ``` 134 | 135 | ## History 136 | 137 | View the [changelog](https://github.com/ankane/torchaudio-ruby/blob/master/CHANGELOG.md) 138 | 139 | ## Contributing 140 | 141 | Everyone is encouraged to help improve this project. Here are a few ways you can help: 142 | 143 | - [Report bugs](https://github.com/ankane/torchaudio-ruby/issues) 144 | - Fix bugs and [submit pull requests](https://github.com/ankane/torchaudio-ruby/pulls) 145 | - Write, clarify, or fix documentation 146 | - Suggest or add new features 147 | 148 | To get started with development: 149 | 150 | ```sh 151 | git clone https://github.com/ankane/torchaudio-ruby.git 152 | cd torchaudio-ruby 153 | bundle install 154 | bundle exec rake compile 155 | bundle exec rake test 156 | ``` 157 | -------------------------------------------------------------------------------- /test/transforms_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class TransformsTest < Minitest::Test 4 | def test_spectogram 5 | waveform, _ = TorchAudio.load(audio_path) 6 | transformed = TorchAudio::Transforms::Spectrogram.new.call(waveform) 7 | assert_equal [1, 201, 255], transformed.size 8 | expected = [0.00051381276, 6.306071e-05, 0.0009923399, 0.00330968, 0.00030008898] 9 | assert_elements_in_delta expected, transformed[0][0][0..4].to_a 10 | end 11 | 12 | def test_melspectrogram 13 | waveform, sample_rate = TorchAudio.load(audio_path) 14 | transformed = TorchAudio::Transforms::MelSpectrogram.new(sample_rate: sample_rate).call(waveform) 
15 | assert_equal [1, 128, 255], transformed.size 16 | expected = [4.320904736232478e-06, 0.00026097119553014636, 0.00010256850509904325, 0.0009344223653897643, 0.00013253440556582063] 17 | assert_elements_in_delta expected, transformed[0][0][0..4].to_a 18 | end 19 | 20 | def test_amplitude_to_db 21 | waveform, sample_rate = TorchAudio.load(audio_path) 22 | transformed = TorchAudio::Transforms::MelSpectrogram.new(sample_rate: sample_rate).call(waveform) 23 | assert_equal [1, 128, 255], transformed.size 24 | db = TorchAudio::Transforms::AmplitudeToDB.new(top_db: 80.0).call transformed 25 | expected = [-53.64425277709961, -35.834075927734375, -39.889862060546875, -30.29456901550293, -38.77671432495117] 26 | assert_elements_in_delta expected, db[0][0][0..4].to_a 27 | end 28 | 29 | def test_mfcc 30 | waveform, sample_rate = TorchAudio.load(audio_path) 31 | transformed = TorchAudio::Transforms::MFCC.new(n_mfcc: 16, sample_rate: sample_rate).(waveform) 32 | assert_equal [1, 16, 255], transformed.size 33 | expected = [-588.85400390625, -470.5740051269531, -420.1156005859375, -393.4096374511719, -415.3000793457031] 34 | assert_elements_in_delta expected, transformed[0][0][0..4].to_a 35 | end 36 | 37 | def test_mu_law_encoding 38 | waveform, _ = TorchAudio.load(audio_path) 39 | transformed = TorchAudio::Transforms::MuLawEncoding.new.call(waveform) 40 | assert_equal [1, 50800], transformed.size 41 | expected = [128, 128, 128, 128, 128] 42 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 43 | 44 | reconstructed = TorchAudio::Transforms::MuLawDecoding.new.call(transformed) 45 | assert_equal [1, 50800], reconstructed.size 46 | expected = [8.621309e-05, 8.621309e-05, 8.621309e-05, 8.621309e-05, 8.621309e-05] 47 | assert_elements_in_delta expected, reconstructed[0, 0..4].to_a 48 | 49 | err = ((waveform - reconstructed).abs / waveform.abs).median.item 50 | assert_in_delta err, 0.0199 51 | end 52 | 53 | def test_compute_deltas 54 | waveform, _ = 
TorchAudio.load(audio_path) 55 | transformed = TorchAudio::Transforms::ComputeDeltas.new.call(waveform) 56 | assert_equal [1, 50800], transformed.shape 57 | expected = [3.0517579e-06, 0.0, -3.0517579e-06, -6.1035157e-06, 0.0] 58 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 59 | end 60 | 61 | def test_fade 62 | waveform, _ = TorchAudio.load(audio_path) 63 | transformed = TorchAudio::Transforms::Fade.new.call(waveform) 64 | assert_equal [1, 50800], transformed.shape 65 | expected = [3.0517578e-05, 6.1035156e-05, 3.0517578e-05, 3.0517578e-05, 3.0517578e-05] 66 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 67 | end 68 | 69 | def test_vol 70 | waveform, _ = TorchAudio.load(audio_path) 71 | transformed = TorchAudio::Transforms::Vol.new(2).call(waveform) 72 | assert_equal [1, 50800], transformed.shape 73 | expected = [6.1035156e-05, 0.00012207031, 6.1035156e-05, 6.1035156e-05, 6.1035156e-05] 74 | assert_elements_in_delta expected, transformed[0, 0..4].to_a 75 | end 76 | end 77 | -------------------------------------------------------------------------------- /examples/tutorial.rb: -------------------------------------------------------------------------------- 1 | # ported from PyTorch Tutorials 2 | # https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html 3 | # see LICENSE-tutorial.txt 4 | 5 | # audio file available at 6 | # https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav 7 | 8 | require "torch" 9 | require "torchaudio" 10 | require "matplotlib/pyplot" 11 | 12 | plt = Matplotlib::Pyplot 13 | 14 | filename = "steam-train-whistle-daniel_simon-converted-from-mp3.wav" 15 | waveform, sample_rate = TorchAudio.load(filename) 16 | 17 | puts "Shape of waveform: #{waveform.size}" 18 | puts "Sample rate of waveform: #{sample_rate}" 19 | 20 | plt.figure 21 | plt.plot(waveform.t.to_a) 22 | plt.savefig("waveform.png") 23 | 24 | # --- 25 | 26 | specgram = 
TorchAudio::Transforms::Spectrogram.new.call(waveform) 27 | 28 | puts "Shape of spectrogram: #{specgram.size}" 29 | 30 | plt.figure 31 | plt.imshow(specgram.log2[0, 0..-1, 0..-1].to_a, cmap: "gray") 32 | plt.savefig("spectrogram.png") 33 | 34 | # --- 35 | 36 | specgram = TorchAudio::Transforms::MelSpectrogram.new.call(waveform) 37 | 38 | puts "Shape of spectrogram: #{specgram.size}" 39 | 40 | plt.figure 41 | plt.imshow(specgram.log2[0, 0..-1, 0..-1].to_a, cmap: "gray") 42 | plt.savefig("mel_spectrogram.png") 43 | 44 | # --- 45 | 46 | puts "Min of waveform: #{waveform.min.item}" 47 | puts "Max of waveform: #{waveform.max.item}" 48 | puts "Mean of waveform: #{waveform.mean.item}" 49 | 50 | # --- 51 | 52 | transformed = TorchAudio::Transforms::MuLawEncoding.new.call(waveform) 53 | 54 | puts "Shape of transformed waveform: #{transformed.size}" 55 | 56 | plt.figure 57 | plt.plot(transformed[0, 0..-1].to_a) 58 | plt.savefig("mu_law_encoding.png") 59 | 60 | # --- 61 | 62 | reconstructed = TorchAudio::Transforms::MuLawDecoding.new.call(transformed) 63 | 64 | puts "Shape of recovered waveform: #{reconstructed.size}" 65 | 66 | plt.figure 67 | plt.plot(reconstructed[0, 0..-1].to_a) 68 | plt.savefig("mu_law_decoding.png") 69 | 70 | # --- 71 | 72 | err = ((waveform - reconstructed).abs / waveform.abs).median 73 | 74 | puts "Median relative difference between original and MuLaw reconstructed signals: #{(err.item * 100).round(2)}%" 75 | 76 | # --- 77 | 78 | mu_law_encoding_waveform = TorchAudio::Functional.mu_law_encoding(waveform, 256) 79 | 80 | puts "Shape of transformed waveform: #{mu_law_encoding_waveform.size}" 81 | 82 | plt.figure 83 | plt.plot(mu_law_encoding_waveform[0, 0..-1].to_a) 84 | plt.savefig("mu_law_encoding_functional.png") 85 | 86 | # --- 87 | 88 | computed = TorchAudio::Functional.compute_deltas(specgram.contiguous, win_length: 3) 89 | puts "Shape of computed deltas: #{computed.shape}" 90 | 91 | plt.figure 92 | plt.imshow(computed.log2[0, 0..-1, 
0..-1].detach.to_a, cmap: "gray") 93 | plt.savefig("compute_deltas.png") 94 | 95 | # --- 96 | 97 | gain_waveform = TorchAudio::Functional.gain(waveform, gain_db: 5.0) 98 | puts "Min of gain_waveform: #{gain_waveform.min.item}" 99 | puts "Max of gain_waveform: #{gain_waveform.max.item}" 100 | puts "Mean of gain_waveform: #{gain_waveform.mean.item}" 101 | 102 | dither_waveform = TorchAudio::Functional.dither(waveform) 103 | puts "Min of dither_waveform: #{dither_waveform.min.item}" 104 | puts "Max of dither_waveform: #{dither_waveform.max.item}" 105 | puts "Mean of dither_waveform: #{dither_waveform.mean.item}" 106 | 107 | # --- 108 | 109 | lowpass_waveform = TorchAudio::Functional.lowpass_biquad(waveform, sample_rate, 3000) 110 | 111 | puts "Min of lowpass_waveform: #{lowpass_waveform.min.item}" 112 | puts "Max of lowpass_waveform: #{lowpass_waveform.max.item}" 113 | puts "Mean of lowpass_waveform: #{lowpass_waveform.mean.item}" 114 | 115 | plt.figure 116 | plt.plot(lowpass_waveform.t.to_a) 117 | plt.savefig("lowpass_biquad.png") 118 | 119 | # --- 120 | 121 | highpass_waveform = TorchAudio::Functional.highpass_biquad(waveform, sample_rate, 2000) 122 | 123 | puts "Min of highpass_waveform: #{highpass_waveform.min.item}" 124 | puts "Max of highpass_waveform: #{highpass_waveform.max.item}" 125 | puts "Mean of highpass_waveform: #{highpass_waveform.mean.item}" 126 | 127 | plt.figure 128 | plt.plot(highpass_waveform.t.to_a) 129 | plt.savefig("highpass_biquad.png") 130 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox_io.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | using namespace torch::indexing; 6 | using namespace torchaudio::sox_utils; 7 | 8 | namespace torchaudio { 9 | namespace sox_io { 10 | 11 | SignalInfo::SignalInfo( 12 | const int64_t sample_rate_, 13 | const int64_t num_channels_, 14 | const int64_t num_frames_) 15 | 
: sample_rate(sample_rate_), 16 | num_channels(num_channels_), 17 | num_frames(num_frames_){}; 18 | 19 | int64_t SignalInfo::getSampleRate() const { 20 | return sample_rate; 21 | } 22 | 23 | int64_t SignalInfo::getNumChannels() const { 24 | return num_channels; 25 | } 26 | 27 | int64_t SignalInfo::getNumFrames() const { 28 | return num_frames; 29 | } 30 | 31 | c10::intrusive_ptr get_info(const std::string& path) { 32 | SoxFormat sf(sox_open_read( 33 | path.c_str(), 34 | /*signal=*/nullptr, 35 | /*encoding=*/nullptr, 36 | /*filetype=*/nullptr)); 37 | 38 | if (static_cast(sf) == nullptr) { 39 | throw std::runtime_error("Error opening audio file"); 40 | } 41 | 42 | return c10::make_intrusive( 43 | static_cast(sf->signal.rate), 44 | static_cast(sf->signal.channels), 45 | static_cast(sf->signal.length / sf->signal.channels)); 46 | } 47 | 48 | c10::intrusive_ptr load_audio_file( 49 | const std::string& path, 50 | const int64_t frame_offset, 51 | const int64_t num_frames, 52 | const bool normalize, 53 | const bool channels_first) { 54 | if (frame_offset < 0) { 55 | throw std::runtime_error( 56 | "Invalid argument: frame_offset must be non-negative."); 57 | } 58 | if (num_frames == 0 || num_frames < -1) { 59 | throw std::runtime_error( 60 | "Invalid argument: num_frames must be -1 or greater than 0."); 61 | } 62 | 63 | SoxFormat sf(sox_open_read( 64 | path.c_str(), 65 | /*signal=*/nullptr, 66 | /*encoding=*/nullptr, 67 | /*filetype=*/nullptr)); 68 | 69 | validate_input_file(sf); 70 | 71 | const int64_t num_channels = sf->signal.channels; 72 | const int64_t num_total_samples = sf->signal.length; 73 | const int64_t sample_start = sf->signal.channels * frame_offset; 74 | 75 | if (sox_seek(sf, sample_start, 0) == SOX_EOF) { 76 | throw std::runtime_error("Error reading audio file: offset past EOF."); 77 | } 78 | 79 | const int64_t sample_end = [&]() { 80 | if (num_frames == -1) 81 | return num_total_samples; 82 | const int64_t sample_end_ = num_channels * num_frames + 
sample_start; 83 | if (num_total_samples < sample_end_) { 84 | // For lossy encoding, it is difficult to predict exact size of buffer for 85 | // reading the number of samples required. 86 | // So we allocate buffer size of given `num_frames` and ask sox to read as 87 | // much as possible. For lossless format, sox reads exact number of 88 | // samples, but for lossy encoding, sox can end up reading less. (i.e. 89 | // mp3) For the consistent behavior specification between lossy/lossless 90 | // format, we allow users to provide `num_frames` value that exceeds #of 91 | // available samples, and we adjust it here. 92 | return num_total_samples; 93 | } 94 | return sample_end_; 95 | }(); 96 | 97 | const int64_t max_samples = sample_end - sample_start; 98 | 99 | // Read samples into buffer 100 | std::vector buffer; 101 | buffer.reserve(max_samples); 102 | const int64_t num_samples = sox_read(sf, buffer.data(), max_samples); 103 | if (num_samples == 0) { 104 | throw std::runtime_error( 105 | "Error reading audio file: empty file or read operation failed."); 106 | } 107 | // NOTE: num_samples may be smaller than max_samples if the input 108 | // format is compressed (i.e. mp3). 
109 | 110 | // Convert to Tensor 111 | auto tensor = convert_to_tensor( 112 | buffer.data(), 113 | num_samples, 114 | num_channels, 115 | get_dtype(sf->encoding.encoding, sf->signal.precision), 116 | normalize, 117 | channels_first); 118 | 119 | return c10::make_intrusive( 120 | tensor, static_cast(sf->signal.rate), channels_first); 121 | } 122 | 123 | void save_audio_file( 124 | const std::string& file_name, 125 | const c10::intrusive_ptr& signal, 126 | const double compression) { 127 | const auto tensor = signal->getTensor(); 128 | const auto sample_rate = signal->getSampleRate(); 129 | const auto channels_first = signal->getChannelsFirst(); 130 | 131 | validate_input_tensor(tensor); 132 | 133 | const auto filetype = get_filetype(file_name); 134 | const auto signal_info = 135 | get_signalinfo(tensor, sample_rate, channels_first, filetype); 136 | const auto encoding_info = 137 | get_encodinginfo(filetype, tensor.dtype(), compression); 138 | 139 | SoxFormat sf(sox_open_write( 140 | file_name.c_str(), 141 | &signal_info, 142 | &encoding_info, 143 | /*filetype=*/filetype.c_str(), 144 | /*oob=*/nullptr, 145 | /*overwrite_permitted=*/nullptr)); 146 | 147 | if (static_cast(sf) == nullptr) { 148 | throw std::runtime_error("Error saving audio file: failed to open file."); 149 | } 150 | 151 | auto tensor_ = tensor; 152 | if (channels_first) { 153 | tensor_ = tensor_.t(); 154 | } 155 | 156 | const int64_t frames_per_chunk = 65536; 157 | for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) { 158 | auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()}); 159 | chunk = unnormalize_wav(chunk).contiguous(); 160 | 161 | const size_t numel = chunk.numel(); 162 | if (sox_write(sf, chunk.data_ptr(), numel) != numel) { 163 | throw std::runtime_error( 164 | "Error saving audio file: failed to write the entier buffer."); 165 | } 166 | } 167 | } 168 | 169 | } // namespace sox_io 170 | } // namespace torchaudio 171 | 
-------------------------------------------------------------------------------- /lib/torchaudio.rb: -------------------------------------------------------------------------------- 1 | # dependencies 2 | require "torch" 3 | 4 | # ext 5 | require "torchaudio/ext" 6 | 7 | # stdlib 8 | require "digest" 9 | require "fileutils" 10 | require "rubygems/package" 11 | require "set" 12 | 13 | # modules 14 | require_relative "torchaudio/datasets/utils" 15 | require_relative "torchaudio/datasets/yesno" 16 | require_relative "torchaudio/functional" 17 | require_relative "torchaudio/transforms/compute_deltas" 18 | require_relative "torchaudio/transforms/fade" 19 | require_relative "torchaudio/transforms/mel_scale" 20 | require_relative "torchaudio/transforms/mel_spectrogram" 21 | require_relative "torchaudio/transforms/mu_law_encoding" 22 | require_relative "torchaudio/transforms/mu_law_decoding" 23 | require_relative "torchaudio/transforms/spectrogram" 24 | require_relative "torchaudio/transforms/amplitude_to_db" 25 | require_relative "torchaudio/transforms/mfcc" 26 | require_relative "torchaudio/transforms/vol" 27 | require_relative "torchaudio/version" 28 | 29 | module TorchAudio 30 | class Error < StandardError; end 31 | 32 | class << self 33 | # TODO remove filetype in 0.5.0 34 | def load( 35 | filepath, 36 | out: nil, 37 | normalization: true, 38 | channels_first: true, 39 | num_frames: 0, 40 | offset: 0, 41 | signalinfo: nil, 42 | encodinginfo: nil, 43 | filetype: nil, 44 | format: nil 45 | ) 46 | filepath = filepath.to_s 47 | 48 | # check if valid file 49 | unless File.file?(filepath) # File.exist? is also true for directories, which this message rejects 50 | raise ArgumentError, "#{filepath} not found or is a directory" 51 | end 52 | 53 | # initialize output tensor 54 | if !out.nil?
55 | check_input(out) 56 | else 57 | out = Torch::FloatTensor.new 58 | end 59 | 60 | if num_frames < -1 61 | raise ArgumentError, "Expected value for num_samples -1 (entire file) or >=0" 62 | end 63 | if offset < 0 64 | raise ArgumentError, "Expected positive offset value" 65 | end 66 | 67 | # same logic as C++ 68 | # could also make read_audio_file work with nil 69 | format ||= filetype || File.extname(filepath)[1..-1] 70 | 71 | sample_rate = 72 | Ext.read_audio_file( 73 | filepath, 74 | out, 75 | channels_first, 76 | num_frames, 77 | offset, 78 | signalinfo, 79 | encodinginfo, 80 | format 81 | ) 82 | 83 | # normalize if needed 84 | normalize_audio(out, normalization) 85 | 86 | [out, sample_rate] 87 | end 88 | 89 | def load_wav(filepath, **kwargs) 90 | kwargs[:normalization] = 1 << 16 91 | load(filepath, **kwargs) 92 | end 93 | 94 | def save(filepath, src, sample_rate, precision: 16, channels_first: true) 95 | si = Ext::SignalInfo.new 96 | ch_idx = channels_first ? 0 : 1 97 | si.rate = sample_rate 98 | si.channels = src.dim == 1 ? 1 : src.size(ch_idx) 99 | si.length = src.numel 100 | si.precision = precision 101 | save_encinfo(filepath, src, channels_first: channels_first, signalinfo: si) 102 | end 103 | 104 | def save_encinfo(filepath, src, channels_first: true, signalinfo: nil, encodinginfo: nil, filetype: nil) 105 | ch_idx, _len_idx = channels_first ?
[0, 1] : [1, 0] 106 | 107 | # check if save directory exists 108 | abs_dirpath = File.dirname(File.expand_path(filepath)) 109 | unless Dir.exist?(abs_dirpath) 110 | raise "Directory does not exist: #{abs_dirpath}" 111 | end 112 | # check that src is a CPU tensor 113 | check_input(src) 114 | # Check/Fix shape of source data 115 | if src.dim == 1 116 | # 1d tensors are assumed to be mono signals 117 | src = src.unsqueeze(ch_idx) # non-mutating, so the caller's tensor keeps its shape 118 | elsif src.dim > 2 || src.size(ch_idx) > 16 119 | # assumes num_channels < 16 120 | raise ArgumentError, "Expected format where C < 16, but found #{src.size}" 121 | end 122 | # sox stores the sample rate as a float, though practically sample rates are almost always integers 123 | # convert integers to floats 124 | if signalinfo 125 | if signalinfo.rate && !signalinfo.rate.is_a?(Float) 126 | if signalinfo.rate.to_f == signalinfo.rate 127 | signalinfo.rate = signalinfo.rate.to_f 128 | else 129 | raise ArgumentError, "Sample rate should be a float or int" 130 | end 131 | end 132 | # check if the bit precision (i.e. bits per sample) is an integer 133 | if signalinfo.precision && ! signalinfo.precision.is_a?(Integer) 134 | if signalinfo.precision.to_i == signalinfo.precision 135 | signalinfo.precision = signalinfo.precision.to_i 136 | else 137 | raise ArgumentError, "Bit precision should be an integer" 138 | end 139 | end 140 | end 141 | # programs such as librosa normalize the signal, unnormalize if detected 142 | if src.min.item >= -1.0 && src.max.item <= 1.0 # compare Ruby scalars: a Tensor comparison returns a Tensor, which is always truthy 143 | src = src * (1 << 31) 144 | src = src.long 145 | end 146 | # set filetype and allow for files with no extensions 147 | extension = File.extname(filepath) 148 | filetype = extension.length > 0 ?
extension[1..-1] : filetype 149 | # transpose from C x L -> L x C 150 | if channels_first 151 | src = src.transpose(1, 0) 152 | end 153 | # save data to file 154 | src = src.contiguous 155 | Ext.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype) 156 | end 157 | 158 | private 159 | 160 | def check_input(src) 161 | raise ArgumentError, "Expected a tensor, got #{src.class.name}" unless Torch.tensor?(src) 162 | raise ArgumentError, "Expected a CPU based tensor, got #{src.class.name}" if src.cuda? 163 | end 164 | 165 | def normalize_audio(signal, normalization) 166 | return unless normalization 167 | 168 | normalization = 1 << 31 if normalization == true 169 | 170 | if normalization.is_a?(Numeric) 171 | signal.div!(normalization) 172 | elsif normalization.respond_to?(:call) 173 | signal.div!(normalization.call(signal)) 174 | end 175 | end 176 | end 177 | end 178 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox_utils.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace torchaudio { 6 | namespace sox_utils { 7 | 8 | TensorSignal::TensorSignal( 9 | torch::Tensor tensor_, 10 | int64_t sample_rate_, 11 | bool channels_first_) 12 | : tensor(tensor_), 13 | sample_rate(sample_rate_), 14 | channels_first(channels_first_){}; 15 | 16 | torch::Tensor TensorSignal::getTensor() const { 17 | return tensor; 18 | } 19 | int64_t TensorSignal::getSampleRate() const { 20 | return sample_rate; 21 | } 22 | bool TensorSignal::getChannelsFirst() const { 23 | return channels_first; 24 | } 25 | 26 | SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} 27 | SoxFormat::~SoxFormat() { 28 | if (fd_ != nullptr) { 29 | sox_close(fd_); 30 | } 31 | } 32 | sox_format_t* SoxFormat::operator->() const noexcept { 33 | return fd_; 34 | } 35 | SoxFormat::operator sox_format_t*() const noexcept { 36 | return fd_; 37 | } 38 | 39 | void
validate_input_file(const SoxFormat& sf) { 40 | if (static_cast(sf) == nullptr) { 41 | throw std::runtime_error("Error loading audio file: failed to open file."); 42 | } 43 | if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { 44 | throw std::runtime_error("Error loading audio file: unknown encoding."); 45 | } 46 | if (sf->signal.length == 0) { 47 | throw std::runtime_error("Error reading audio file: unknown length."); 48 | } 49 | } 50 | 51 | void validate_input_tensor(const torch::Tensor tensor) { 52 | if (!tensor.device().is_cpu()) { 53 | throw std::runtime_error("Input tensor has to be on CPU."); 54 | } 55 | 56 | if (tensor.ndimension() != 2) { 57 | throw std::runtime_error("Input tensor has to be 2D."); 58 | } 59 | 60 | const auto dtype = tensor.dtype(); 61 | if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 || 62 | dtype == torch::kInt16 || dtype == torch::kUInt8)) { 63 | throw std::runtime_error( 64 | "Input tensor has to be one of float32, int32, int16 or uint8 type."); 65 | } 66 | } 67 | 68 | caffe2::TypeMeta get_dtype( 69 | const sox_encoding_t encoding, 70 | const unsigned precision) { 71 | const auto dtype = [&]() { 72 | switch (encoding) { 73 | case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV 74 | return torch::kUInt8; 75 | case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV 76 | switch (precision) { 77 | case 16: 78 | return torch::kInt16; 79 | case 32: 80 | return torch::kInt32; 81 | default: 82 | throw std::runtime_error( 83 | "Only 16 and 32 bits are supported for signed PCM."); 84 | } 85 | default: 86 | // default to float32 for the other formats, including 87 | // 32-bit floating-point WAV, 88 | // MP3, 89 | // FLAC, 90 | // VORBIS etc...
91 | return torch::kFloat32; 92 | } 93 | }(); 94 | return c10::scalarTypeToTypeMeta(dtype); 95 | } 96 | 97 | torch::Tensor convert_to_tensor( 98 | sox_sample_t* buffer, 99 | const int32_t num_samples, 100 | const int32_t num_channels, 101 | const caffe2::TypeMeta dtype, 102 | const bool normalize, 103 | const bool channels_first) { 104 | auto t = torch::from_blob( 105 | buffer, {num_samples / num_channels, num_channels}, torch::kInt32); 106 | // Note: Tensor created from_blob does not own data but borrows 107 | // So make sure to create a new copy after processing samples. 108 | if (normalize || dtype == torch::kFloat32) { 109 | t = t.to(torch::kFloat32); 110 | t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.; 111 | } else if (dtype == torch::kInt32) { 112 | t = t.clone(); 113 | } else if (dtype == torch::kInt16) { 114 | t.floor_divide_(1 << 16); 115 | t = t.to(torch::kInt16); 116 | } else if (dtype == torch::kUInt8) { 117 | t.floor_divide_(1 << 24); 118 | t += 128; 119 | t = t.to(torch::kUInt8); 120 | } else { 121 | throw std::runtime_error("Unsupported dtype."); 122 | } 123 | if (channels_first) { 124 | t = t.transpose(1, 0); 125 | } 126 | return t.contiguous(); 127 | } 128 | 129 | torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) { 130 | const auto dtype = input_tensor.dtype(); 131 | auto tensor = input_tensor; 132 | if (dtype == torch::kFloat32) { 133 | double multi_pos = 2147483647.; 134 | double multi_neg = -2147483648.; 135 | auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg; 136 | tensor = tensor.to(torch::dtype(torch::kFloat64)); 137 | tensor *= mult; 138 | tensor.clamp_(multi_neg, multi_pos); 139 | tensor = tensor.to(torch::dtype(torch::kInt32)); 140 | } else if (dtype == torch::kInt32) { 141 | // already denormalized 142 | } else if (dtype == torch::kInt16) { 143 | tensor = tensor.to(torch::dtype(torch::kInt32)); 144 | tensor *= ((tensor != 0) * 65536); 145 | } else if (dtype == torch::kUInt8) { 146 | tensor =
tensor.to(torch::dtype(torch::kInt32)); 147 | tensor -= 128; 148 | tensor *= 16777216; 149 | } else { 150 | throw std::runtime_error("Unexpected dtype."); 151 | } 152 | return tensor; 153 | } 154 | 155 | const std::string get_filetype(const std::string path) { 156 | std::string ext = path.substr(path.find_last_of(".") + 1); 157 | std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); 158 | return ext; 159 | } 160 | 161 | sox_encoding_t get_encoding( 162 | const std::string filetype, 163 | const caffe2::TypeMeta dtype) { 164 | if (filetype == "mp3") 165 | return SOX_ENCODING_MP3; 166 | if (filetype == "flac") 167 | return SOX_ENCODING_FLAC; 168 | if (filetype == "ogg" || filetype == "vorbis") 169 | return SOX_ENCODING_VORBIS; 170 | if (filetype == "wav") { 171 | if (dtype == torch::kUInt8) 172 | return SOX_ENCODING_UNSIGNED; 173 | if (dtype == torch::kInt16) 174 | return SOX_ENCODING_SIGN2; 175 | if (dtype == torch::kInt32) 176 | return SOX_ENCODING_SIGN2; 177 | if (dtype == torch::kFloat32) 178 | return SOX_ENCODING_FLOAT; 179 | throw std::runtime_error("Unsupported dtype."); 180 | } 181 | throw std::runtime_error("Unsupported file type."); 182 | } 183 | 184 | unsigned get_precision( 185 | const std::string filetype, 186 | const caffe2::TypeMeta dtype) { 187 | if (filetype == "mp3") 188 | return SOX_UNSPEC; 189 | if (filetype == "flac") 190 | return 24; 191 | if (filetype == "ogg" || filetype == "vorbis") 192 | return SOX_UNSPEC; 193 | if (filetype == "wav") { 194 | if (dtype == torch::kUInt8) 195 | return 8; 196 | if (dtype == torch::kInt16) 197 | return 16; 198 | if (dtype == torch::kInt32) 199 | return 32; 200 | if (dtype == torch::kFloat32) 201 | return 32; 202 | throw std::runtime_error("Unsupported dtype."); 203 | } 204 | throw std::runtime_error("Unsupported file type."); 205 | } 206 | 207 | sox_signalinfo_t get_signalinfo( 208 | const torch::Tensor& tensor, 209 | const int64_t sample_rate, 210 | const bool channels_first, 211 | const std::string
filetype) { 212 | return sox_signalinfo_t{ 213 | /*rate=*/static_cast(sample_rate), 214 | /*channels=*/static_cast(tensor.size(channels_first ? 0 : 1)), 215 | /*precision=*/get_precision(filetype, tensor.dtype()), 216 | /*length=*/static_cast(tensor.numel())}; 217 | } 218 | 219 | sox_encodinginfo_t get_encodinginfo( 220 | const std::string filetype, 221 | const caffe2::TypeMeta dtype, 222 | const double compression) { 223 | const double compression_ = [&]() { 224 | if (filetype == "mp3") 225 | return compression; 226 | if (filetype == "flac") 227 | return compression; 228 | if (filetype == "ogg" || filetype == "vorbis") 229 | return compression; 230 | if (filetype == "wav") 231 | return 0.; 232 | throw std::runtime_error("Unsupported file type."); 233 | }(); 234 | 235 | return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype), 236 | /*bits_per_sample=*/get_precision(filetype, dtype), 237 | /*compression=*/compression_, 238 | /*reverse_bytes=*/sox_option_default, 239 | /*reverse_nibbles=*/sox_option_default, 240 | /*reverse_bits=*/sox_option_default, 241 | /*opposite_endian=*/sox_false}; 242 | } 243 | 244 | } // namespace sox_utils 245 | } // namespace torchaudio 246 | -------------------------------------------------------------------------------- /lib/torchaudio/functional.rb: -------------------------------------------------------------------------------- 1 | module TorchAudio 2 | module Functional 3 | class << self 4 | def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized, center: true, pad_mode: "reflect", onesided: true) 5 | if pad > 0 6 | # TODO add "with torch.no_grad():" back when JIT supports it 7 | waveform = Torch::NN::Functional.pad(waveform, [pad, pad], mode: "constant") # mode is a keyword argument, as in compute_deltas below 8 | end 9 | 10 | # pack batch 11 | shape = waveform.size 12 | waveform = waveform.reshape(-1, shape[-1]) 13 | 14 | # default values are consistent with librosa.core.spectrum._spectrogram 15 | spec_f = 16 | Torch.stft( 17 | waveform, 18 | n_fft,
19 | hop_length: hop_length, 20 | win_length: win_length, 21 | window: window, 22 | center: center, 23 | pad_mode: pad_mode, 24 | normalized: false, 25 | onesided: onesided, 26 | return_complex: true 27 | ) 28 | 29 | # unpack batch 30 | spec_f = spec_f.reshape(shape[0..-2] + spec_f.shape[-2..-1]) 31 | 32 | if normalized 33 | spec_f /= window.pow(2.0).sum.sqrt 34 | end 35 | if !power.nil? 36 | if power == 1 37 | return spec_f.abs 38 | end 39 | return spec_f.abs.pow(power) 40 | end 41 | spec_f 42 | end 43 | 44 | def mu_law_encoding(x, quantization_channels) 45 | mu = quantization_channels - 1.0 46 | if !x.floating_point? 47 | x = x.to(dtype: :float) 48 | end 49 | mu = Torch.tensor(mu, dtype: x.dtype) 50 | x_mu = Torch.sign(x) * Torch.log1p(mu * Torch.abs(x)) / Torch.log1p(mu) 51 | x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(dtype: :int64) 52 | x_mu 53 | end 54 | 55 | def mu_law_decoding(x_mu, quantization_channels) 56 | mu = quantization_channels - 1.0 57 | if !x_mu.floating_point? 58 | x_mu = x_mu.to(dtype: :float) 59 | end 60 | mu = Torch.tensor(mu, dtype: x_mu.dtype) 61 | x = ((x_mu) / mu) * 2 - 1.0 62 | x = Torch.sign(x) * (Torch.exp(Torch.abs(x) * Torch.log1p(mu)) - 1.0) / mu 63 | x 64 | end 65 | 66 | def complex_norm(complex_tensor, power: 1.0) 67 | complex_tensor.pow(2.0).sum(-1).pow(0.5 * power) 68 | end 69 | 70 | def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, norm: nil) 71 | if norm && norm != "slaney" 72 | raise ArgumentError, "norm must be one of None or 'slaney'" 73 | end 74 | 75 | # freq bins 76 | # Equivalent filterbank construction by Librosa 77 | all_freqs = Torch.linspace(0, sample_rate.div(2), n_freqs) 78 | 79 | # calculate mel freq bins 80 | # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) 81 | m_min = 2595.0 * Math.log10(1.0 + (f_min / 700.0)) 82 | m_max = 2595.0 * Math.log10(1.0 + (f_max / 700.0)) 83 | m_pts = Torch.linspace(m_min, m_max, n_mels + 2) 84 | # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
85 | f_pts = (Torch.pow(10, m_pts / 2595.0) - 1.0) * 700.0 86 | # calculate the difference between each mel point and each stft freq point in hertz 87 | f_diff = f_pts[1..-1] - f_pts[0...-1] # (n_mels + 1) 88 | slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2) 89 | # create overlapping triangles 90 | zero = Torch.zeros(1) 91 | down_slopes = (slopes[0..-1, 0...-2] * -1.0) / f_diff[0...-1] # (n_freqs, n_mels) 92 | up_slopes = slopes[0..-1, 2..-1] / f_diff[1..-1] # (n_freqs, n_mels) 93 | fb = Torch.max(zero, Torch.min(down_slopes, up_slopes)) 94 | 95 | if norm && norm == "slaney" 96 | # Slaney-style mel is scaled to be approx constant energy per channel 97 | enorm = 2.0 / (f_pts[2...(n_mels + 2)] - f_pts[0...n_mels]) # Ruby range for Python's f_pts[:n_mels]; a Symbol index does not slice 98 | fb *= enorm.unsqueeze(0) 99 | end 100 | 101 | fb 102 | end 103 | 104 | def compute_deltas(specgram, win_length: 5, mode: "replicate") 105 | device = specgram.device 106 | dtype = specgram.dtype 107 | 108 | # pack batch 109 | shape = specgram.size 110 | specgram = specgram.reshape(1, -1, shape[-1]) 111 | 112 | raise ArgumentError, "win_length must be >= 3" unless win_length >= 3 113 | 114 | n = (win_length - 1).div(2) 115 | 116 | # twice sum of integer squared 117 | denom = n * (n + 1) * (2 * n + 1) / 3 118 | 119 | specgram = Torch::NN::Functional.pad(specgram, [n, n], mode: mode) 120 | 121 | kernel = Torch.arange(-n, n + 1, 1, device: device, dtype: dtype).repeat([specgram.shape[1], 1, 1]) 122 | 123 | output = Torch::NN::Functional.conv1d(specgram, kernel, groups: specgram.shape[1]) / denom 124 | 125 | # unpack batch 126 | output = output.reshape(shape) 127 | end 128 | 129 | def gain(waveform, gain_db: 1.0) 130 | return waveform if gain_db == 0 131 | 132 | ratio = 10 ** (gain_db / 20.0) # float division: an Integer gain_db would otherwise truncate (e.g. 6 / 20 == 0) 133 | 134 | waveform * ratio 135 | end 136 | 137 | def dither(waveform, density_function: "TPDF", noise_shaping: false) 138 | dithered = _apply_probability_distribution(waveform, density_function: density_function) 139 | 140 | if noise_shaping
141 | raise "Not implemented yet" 142 | # _add_noise_shaping(dithered, waveform) 143 | else 144 | dithered 145 | end 146 | end 147 | 148 | def biquad(waveform, b0, b1, b2, a0, a1, a2) 149 | device = waveform.device 150 | dtype = waveform.dtype 151 | 152 | output_waveform = lfilter( 153 | waveform, 154 | Torch.tensor([a0, a1, a2], dtype: dtype, device: device), 155 | Torch.tensor([b0, b1, b2], dtype: dtype, device: device) 156 | ) 157 | output_waveform 158 | end 159 | 160 | def highpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707) 161 | w0 = 2 * Math::PI * cutoff_freq / sample_rate 162 | alpha = Math.sin(w0) / 2.0 / q 163 | 164 | b0 = (1 + Math.cos(w0)) / 2 165 | b1 = -1 - Math.cos(w0) 166 | b2 = b0 167 | a0 = 1 + alpha 168 | a1 = -2 * Math.cos(w0) 169 | a2 = 1 - alpha 170 | biquad(waveform, b0, b1, b2, a0, a1, a2) 171 | end 172 | 173 | def lowpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707) 174 | w0 = 2 * Math::PI * cutoff_freq / sample_rate 175 | alpha = Math.sin(w0) / 2 / q 176 | 177 | b0 = (1 - Math.cos(w0)) / 2 178 | b1 = 1 - Math.cos(w0) 179 | b2 = b0 180 | a0 = 1 + alpha 181 | a1 = -2 * Math.cos(w0) 182 | a2 = 1 - alpha 183 | biquad(waveform, b0, b1, b2, a0, a1, a2) 184 | end 185 | 186 | def lfilter(waveform, a_coeffs, b_coeffs, clamp: true) 187 | # pack batch 188 | shape = waveform.size 189 | waveform = waveform.reshape(-1, shape[-1]) 190 | 191 | raise ArgumentError unless a_coeffs.size(0) == b_coeffs.size(0) 192 | raise ArgumentError unless waveform.size.length == 2 193 | raise ArgumentError unless waveform.device == a_coeffs.device 194 | raise ArgumentError unless b_coeffs.device == a_coeffs.device 195 | 196 | device = waveform.device 197 | dtype = waveform.dtype 198 | n_channel, n_sample = waveform.size 199 | n_order = a_coeffs.size(0) 200 | n_sample_padded = n_sample + n_order - 1 201 | raise ArgumentError unless n_order > 0 202 | 203 | # Pad the input and create output 204 | padded_waveform = Torch.zeros(n_channel, n_sample_padded,
dtype: dtype, device: device) 205 | padded_waveform[0..-1, (n_order - 1)..-1] = waveform 206 | padded_output_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device) 207 | 208 | # Set up the coefficients matrix 209 | # Flip coefficients' order 210 | a_coeffs_flipped = a_coeffs.flip([0]) 211 | b_coeffs_flipped = b_coeffs.flip([0]) 212 | 213 | # calculate windowed_input_signal in parallel 214 | # create indices of original with shape (n_channel, n_order, n_sample) 215 | window_idxs = Torch.arange(n_sample, device: device).unsqueeze(0) + Torch.arange(n_order, device: device).unsqueeze(1) 216 | window_idxs = window_idxs.repeat([n_channel, 1, 1]) 217 | window_idxs += (Torch.arange(n_channel, device: device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded) 218 | window_idxs = window_idxs.long 219 | # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample) 220 | input_signal_windows = Torch.matmul(b_coeffs_flipped, Torch.take(padded_waveform, window_idxs)) 221 | 222 | input_signal_windows.div!(a_coeffs[0]) 223 | a_coeffs_flipped.div!(a_coeffs[0]) 224 | input_signal_windows.t.each_with_index do |o0, i_sample| 225 | windowed_output_signal = padded_output_waveform[0..-1, i_sample...(i_sample + n_order)] 226 | o0.addmv!(windowed_output_signal, a_coeffs_flipped, alpha: -1) 227 | padded_output_waveform[0..-1, i_sample + n_order - 1] = o0 228 | end 229 | 230 | output = padded_output_waveform[0..-1, (n_order - 1)..-1] 231 | 232 | if clamp 233 | output = Torch.clamp(output, -1.0, 1.0) 234 | end 235 | 236 | # unpack batch 237 | output = output.reshape(shape[0...-1] + output.shape[-1..-1]) 238 | 239 | output 240 | end 241 | 242 | def amplitude_to_DB(amp, multiplier, amin, db_multiplier, top_db: nil) 243 | db = Torch.log10(Torch.clamp(amp, min: amin)) * multiplier 244 | db -= multiplier * db_multiplier 245 | 246 | db = db.clamp(min: db.max.item - top_db) if top_db 247 | 248 | db 249 | end 250 | 251 | def DB_to_amplitude(db, ref, power) 252 |
Torch.pow(Torch.pow(10.0, db * 0.1), power) * ref 253 | end 254 | 255 | def create_dct(n_mfcc, n_mels, norm: nil) 256 | n = Torch.arange(n_mels.to_f) 257 | k = Torch.arange(n_mfcc.to_f).unsqueeze!(1) 258 | dct = Torch.cos((n + 0.5) * k * Math::PI / n_mels.to_f) 259 | 260 | if norm.nil? 261 | dct *= 2.0 262 | else 263 | raise ArgumentError, "Invalid DCT norm value" unless norm.to_s == "ortho" # accepts :ortho (as before) or the Python-style "ortho" 264 | 265 | dct[0] *= 1.0 / Math.sqrt(2.0) 266 | dct *= Math.sqrt(2.0 / n_mels) 267 | end 268 | 269 | dct.t 270 | end 271 | 272 | private 273 | 274 | def _apply_probability_distribution(waveform, density_function: "TPDF") 275 | # pack batch 276 | shape = waveform.size 277 | waveform = waveform.reshape(-1, shape[-1]) 278 | 279 | channel_size = waveform.size[0] - 1 280 | time_size = waveform.size[-1] - 1 281 | 282 | random_channel = channel_size > 0 ? Torch.randint(channel_size, [1]).item.to_i : 0 283 | random_time = time_size > 0 ? Torch.randint(time_size, [1]).item.to_i : 0 284 | 285 | number_of_bits = 16 286 | up_scaling = 2 ** (number_of_bits - 1) - 2 287 | signal_scaled = waveform * up_scaling 288 | down_scaling = 2 ** (number_of_bits - 1) 289 | 290 | signal_scaled_dis = waveform 291 | if density_function == "RPDF" 292 | rpdf = waveform[random_channel][random_time] - 0.5 293 | 294 | signal_scaled_dis = signal_scaled + rpdf 295 | elsif density_function == "GPDF" 296 | # TODO Replace by distribution code once 297 | # https://github.com/pytorch/pytorch/issues/29843 is resolved 298 | # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample() 299 | 300 | num_rand_variables = 6 301 | 302 | gaussian = waveform[random_channel][random_time] 303 | (num_rand_variables * [time_size]).each do |ws| 304 | rand_chan = channel_size > 0 ? Torch.randint(channel_size, [1]).item.to_i : 0 # randint(0) raises, so guard like random_channel above 305 | gaussian += waveform[rand_chan][ws > 0 ? Torch.randint(ws, [1]).item.to_i : 0] 306 | end 307 | 308 | signal_scaled_dis = signal_scaled + gaussian 309 | else 310 | # dtype needed for
https://github.com/pytorch/pytorch/issues/32358 311 | # TODO add support for dtype and device to bartlett_window 312 | tpdf = Torch.bartlett_window(time_size + 1).to(signal_scaled.device, dtype: signal_scaled.dtype) 313 | tpdf = tpdf.repeat([channel_size + 1, 1]) 314 | signal_scaled_dis = signal_scaled + tpdf 315 | end 316 | 317 | quantised_signal_scaled = Torch.round(signal_scaled_dis) 318 | quantised_signal = quantised_signal_scaled / down_scaling 319 | 320 | # unpack batch 321 | quantised_signal.reshape(shape[0...-1] + quantised_signal.shape[-1..-1]) 322 | end 323 | end 324 | end 325 | 326 | F = Functional 327 | end 328 | -------------------------------------------------------------------------------- /ext/torchaudio/csrc/sox.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace torch { 9 | namespace audio { 10 | namespace { 11 | /// Helper struct to safely close the sox_format_t descriptor.
12 | struct SoxDescriptor { 13 | explicit SoxDescriptor(sox_format_t* fd) noexcept : fd_(fd) {} 14 | SoxDescriptor(const SoxDescriptor& other) = delete; 15 | SoxDescriptor(SoxDescriptor&& other) = delete; 16 | SoxDescriptor& operator=(const SoxDescriptor& other) = delete; 17 | SoxDescriptor& operator=(SoxDescriptor&& other) = delete; 18 | ~SoxDescriptor() { 19 | if (fd_ != nullptr) { 20 | sox_close(fd_); 21 | } 22 | } 23 | sox_format_t* operator->() noexcept { 24 | return fd_; 25 | } 26 | sox_format_t* get() noexcept { 27 | return fd_; 28 | } 29 | 30 | private: 31 | sox_format_t* fd_; 32 | }; 33 | 34 | int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) { 35 | std::vector buffer(tensor.numel()); 36 | 37 | AT_DISPATCH_ALL_TYPES(tensor.scalar_type(), "write_audio_buffer", [&] { 38 | auto* data = tensor.data_ptr(); 39 | std::copy(data, data + tensor.numel(), buffer.begin()); 40 | }); 41 | 42 | const auto samples_written = 43 | sox_write(fd.get(), buffer.data(), buffer.size()); 44 | 45 | return samples_written; 46 | } 47 | 48 | void read_audio( 49 | SoxDescriptor& fd, 50 | at::Tensor output, 51 | int64_t buffer_length) { 52 | std::vector buffer(buffer_length); 53 | 54 | int number_of_channels = fd->signal.channels; 55 | const int64_t samples_read = sox_read(fd.get(), buffer.data(), buffer_length); 56 | if (samples_read == 0) { 57 | throw std::runtime_error( 58 | "Error reading audio file: empty file or read failed in sox_read"); 59 | } 60 | 61 | output.resize_({samples_read / number_of_channels, number_of_channels}); 62 | output = output.contiguous(); 63 | 64 | AT_DISPATCH_ALL_TYPES(output.scalar_type(), "read_audio_buffer", [&] { 65 | auto* data = output.data_ptr(); 66 | std::copy(buffer.begin(), buffer.begin() + samples_read, data); 67 | }); 68 | } 69 | } // namespace 70 | 71 | std::tuple get_info( 72 | const std::string& file_name 73 | ) { 74 | SoxDescriptor fd(sox_open_read( 75 | file_name.c_str(), 76 | /*signal=*/nullptr, 77 | /*encoding=*/nullptr, 78 | 
/*filetype=*/nullptr)); 79 | if (fd.get() == nullptr) { 80 | throw std::runtime_error("Error opening audio file"); 81 | } 82 | return std::make_tuple(fd->signal, fd->encoding); 83 | } 84 | 85 | int read_audio_file( 86 | const std::string& file_name, 87 | at::Tensor output, 88 | bool ch_first, 89 | int64_t nframes, 90 | int64_t offset, 91 | sox_signalinfo_t* si, 92 | sox_encodinginfo_t* ei, 93 | const char* ft) { 94 | 95 | SoxDescriptor fd(sox_open_read(file_name.c_str(), si, ei, ft)); 96 | if (fd.get() == nullptr) { 97 | throw std::runtime_error("Error opening audio file"); 98 | } 99 | 100 | // signal info 101 | 102 | const int number_of_channels = fd->signal.channels; 103 | const int sample_rate = fd->signal.rate; 104 | const int64_t total_length = fd->signal.length; 105 | 106 | // multiply offset and number of frames by number of channels 107 | offset *= number_of_channels; 108 | nframes *= number_of_channels; 109 | 110 | if (total_length == 0) { 111 | throw std::runtime_error("Error reading audio file: unknown length"); 112 | } 113 | if (offset > total_length) { 114 | throw std::runtime_error("Offset past EOF"); 115 | } 116 | 117 | // calculate buffer length 118 | int64_t buffer_length = total_length; 119 | if (offset > 0) { 120 | buffer_length -= offset; 121 | } 122 | if (nframes > 0 && buffer_length > nframes) { 123 | buffer_length = nframes; 124 | } 125 | 126 | // seek to offset point before reading data 127 | if (sox_seek(fd.get(), offset, 0) == SOX_EOF) { 128 | throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples"); 129 | } 130 | 131 | // read data and fill output tensor 132 | read_audio(fd, output, buffer_length); 133 | 134 | // L x C -> C x L, if desired 135 | if (ch_first) { 136 | output.transpose_(1, 0); 137 | } 138 | 139 | return sample_rate; 140 | } 141 | 142 | void write_audio_file( 143 | const std::string& file_name, 144 | const at::Tensor& tensor, 145 | sox_signalinfo_t* si, 146 | sox_encodinginfo_t* ei, 147 | const 
char* file_type) { 148 | if (!tensor.is_contiguous()) { 149 | throw std::runtime_error( 150 | "Error writing audio file: input tensor must be contiguous"); 151 | } 152 | 153 | #if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0 154 | si->mult = nullptr; 155 | #endif 156 | 157 | SoxDescriptor fd(sox_open_write( 158 | file_name.c_str(), 159 | si, 160 | ei, 161 | file_type, 162 | /*oob=*/nullptr, 163 | /*overwrite=*/nullptr)); 164 | 165 | if (fd.get() == nullptr) { 166 | throw std::runtime_error( 167 | "Error writing audio file: could not open file for writing"); 168 | } 169 | 170 | const auto samples_written = write_audio(fd, tensor); 171 | 172 | if (samples_written != tensor.numel()) { 173 | throw std::runtime_error( 174 | "Error writing audio file: could not write entire buffer"); 175 | } 176 | } 177 | 178 | int build_flow_effects(const std::string& file_name, 179 | at::Tensor otensor, 180 | bool ch_first, 181 | sox_signalinfo_t* target_signal, 182 | sox_encodinginfo_t* target_encoding, 183 | const char* file_type, 184 | std::vector pyeffs, 185 | int max_num_eopts) { 186 | 187 | /* This function builds an effects flow and puts the results into a tensor. 188 | It can also be used to re-encode audio using any of the available encoding 189 | options in SoX including sample rate and channel re-encoding. 
*/ 190 | 191 | // open input 192 | sox_format_t* input = sox_open_read(file_name.c_str(), nullptr, nullptr, nullptr); 193 | if (input == nullptr) { 194 | throw std::runtime_error("Error opening audio file"); 195 | } 196 | 197 | // only used if target signal or encoding are null 198 | sox_signalinfo_t empty_signal; 199 | sox_encodinginfo_t empty_encoding; 200 | 201 | // set signalinfo and encodinginfo if blank 202 | if(target_signal == nullptr) { 203 | target_signal = &empty_signal; 204 | target_signal->rate = input->signal.rate; 205 | target_signal->channels = input->signal.channels; 206 | target_signal->length = SOX_UNSPEC; 207 | target_signal->precision = input->signal.precision; 208 | #if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0 209 | target_signal->mult = nullptr; 210 | #endif 211 | } 212 | if(target_encoding == nullptr) { 213 | target_encoding = &empty_encoding; 214 | target_encoding->encoding = SOX_ENCODING_SIGN2; // Sample format 215 | target_encoding->bits_per_sample = input->signal.precision; // Bits per sample 216 | target_encoding->compression = 0.0; // Compression factor 217 | target_encoding->reverse_bytes = sox_option_default; // Should bytes be reversed 218 | target_encoding->reverse_nibbles = sox_option_default; // Should nibbles be reversed 219 | target_encoding->reverse_bits = sox_option_default; // Should bits be reversed (pairs of bits?) 
220 | target_encoding->opposite_endian = sox_false; // Reverse endianness 221 | } 222 | 223 | // check for rate or channels effect and change the output signalinfo accordingly 224 | for (SoxEffect se : pyeffs) { 225 | if (se.ename == "rate") { 226 | target_signal->rate = std::stod(se.eopts[0]); 227 | } else if (se.ename == "channels") { 228 | target_signal->channels = std::stoi(se.eopts[0]); 229 | } 230 | } 231 | 232 | // create interm_signal for effects, intermediate steps change this in-place 233 | sox_signalinfo_t interm_signal = input->signal; 234 | 235 | #ifdef __APPLE__ 236 | // According to Mozilla Deepspeech sox_open_memstream_write doesn't work 237 | // with OSX 238 | char tmp_name[] = "/tmp/fileXXXXXX"; 239 | int tmp_fd = mkstemp(tmp_name); 240 | close(tmp_fd); 241 | sox_format_t* output = sox_open_write(tmp_name, target_signal, 242 | target_encoding, "wav", nullptr, nullptr); 243 | #else 244 | // create buffer and buffer_size for output in memwrite 245 | char* buffer; 246 | size_t buffer_size; 247 | // in-memory descriptor (this may not work for OSX) 248 | sox_format_t* output = sox_open_memstream_write(&buffer, 249 | &buffer_size, 250 | target_signal, 251 | target_encoding, 252 | file_type, nullptr); 253 | #endif 254 | if (output == nullptr) { 255 | throw std::runtime_error("Error opening output memstream/temporary file"); 256 | } 257 | // Setup the effects chain to decode/resample 258 | sox_effects_chain_t* chain = 259 | sox_create_effects_chain(&input->encoding, &output->encoding); 260 | 261 | sox_effect_t* e = sox_create_effect(sox_find_effect("input")); 262 | char* io_args[1]; 263 | io_args[0] = (char*)input; 264 | sox_effect_options(e, 1, io_args); 265 | sox_add_effect(chain, e, &interm_signal, &input->signal); 266 | free(e); 267 | 268 | for(SoxEffect tae : pyeffs) { 269 | if(tae.ename == "no_effects") break; 270 | e = sox_create_effect(sox_find_effect(tae.ename.c_str())); 271 | e->global_info->global_info->verbosity = 1; 272 | if(tae.eopts[0] == 
"") { 273 | sox_effect_options(e, 0, nullptr); 274 | } else { 275 | int num_opts = tae.eopts.size(); 276 | char* sox_args[max_num_eopts]; 277 | for(std::vector::size_type i = 0; i != tae.eopts.size(); i++) { 278 | sox_args[i] = (char*) tae.eopts[i].c_str(); 279 | } 280 | if(sox_effect_options(e, num_opts, sox_args) != SOX_SUCCESS) { 281 | #ifdef __APPLE__ 282 | unlink(tmp_name); 283 | #endif 284 | throw std::runtime_error("invalid effect options, see SoX docs for details"); 285 | } 286 | } 287 | sox_add_effect(chain, e, &interm_signal, &output->signal); 288 | free(e); 289 | } 290 | 291 | e = sox_create_effect(sox_find_effect("output")); 292 | io_args[0] = (char*)output; 293 | sox_effect_options(e, 1, io_args); 294 | sox_add_effect(chain, e, &interm_signal, &output->signal); 295 | free(e); 296 | 297 | // Finally run the effects chain 298 | sox_flow_effects(chain, nullptr, nullptr); 299 | sox_delete_effects_chain(chain); 300 | 301 | // Close sox handles, buffer does not get properly sized until these are closed 302 | sox_close(output); 303 | sox_close(input); 304 | 305 | int sr; 306 | // Read the in-memory audio buffer or temp file that we just wrote. 307 | #ifdef __APPLE__ 308 | /* 309 | Temporary filetype must have a valid header. Wav seems to work here while 310 | raw does not. Certain effects like chorus caused strange behavior on the mac. 311 | */ 312 | // read_audio_file reads the temporary file and returns the sr and otensor 313 | sr = read_audio_file(tmp_name, otensor, ch_first, 0, 0, 314 | target_signal, target_encoding, "wav"); 315 | // delete temporary audio file 316 | unlink(tmp_name); 317 | #else 318 | // Resize output tensor to desired dimensions, different effects result in output->signal.length, 319 | // interm_signal.length and buffer size being inconsistent with the result of the file output. 
320 | // We prioritize in the order: output->signal.length > interm_signal.length > buffer_size 321 | // Could be related to: https://sourceforge.net/p/sox/bugs/314/ 322 | int nc, ns; 323 | if (output->signal.length == 0) { 324 | // sometimes interm_signal length is extremely large, but the buffer_size 325 | // is double the length of the output signal 326 | if (interm_signal.length > (buffer_size * 10)) { 327 | ns = buffer_size / 2; 328 | } else { 329 | ns = interm_signal.length; 330 | } 331 | nc = interm_signal.channels; 332 | } else { 333 | nc = output->signal.channels; 334 | ns = output->signal.length; 335 | } 336 | otensor.resize_({ns/nc, nc}); 337 | otensor = otensor.contiguous(); 338 | 339 | input = sox_open_mem_read(buffer, buffer_size, target_signal, target_encoding, file_type); 340 | std::vector samples(buffer_size); 341 | const int64_t samples_read = sox_read(input, samples.data(), buffer_size); 342 | assert(samples_read != nc * ns && samples_read != 0); 343 | AT_DISPATCH_ALL_TYPES(otensor.scalar_type(), "effects_buffer", [&] { 344 | auto* data = otensor.data_ptr(); 345 | std::copy(samples.begin(), samples.begin() + samples_read, data); 346 | }); 347 | // free buffer and close mem_read 348 | sox_close(input); 349 | free(buffer); 350 | 351 | if (ch_first) { 352 | otensor.transpose_(1, 0); 353 | } 354 | sr = target_signal->rate; 355 | 356 | #endif 357 | // return sample rate, output tensor modified in-place 358 | return sr; 359 | } 360 | } // namespace audio 361 | } // namespace torch 362 | --------------------------------------------------------------------------------