├── figures ├── attention-map.png ├── model-architecture1.png ├── model-architecture2.png ├── real-time-curve-IO.png ├── real-time-curve-PO.png └── real-time-curve-SS.png ├── doa-estimation ├── conf.yml ├── utils │ ├── prepare_python_env.sh │ └── parse_options.sh ├── run.sh ├── model_causal.py ├── sigprocess.py ├── model.py ├── train.py ├── base.py ├── eval.py ├── system.py └── dataset.py ├── generate_rir.py ├── sms_wsj_replace ├── create_rirs.py └── scenario.py ├── README.md └── run_testset.py /figures/attention-map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/HEAD/figures/attention-map.png -------------------------------------------------------------------------------- /figures/model-architecture1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/HEAD/figures/model-architecture1.png -------------------------------------------------------------------------------- /figures/model-architecture2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/HEAD/figures/model-architecture2.png -------------------------------------------------------------------------------- /figures/real-time-curve-IO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/HEAD/figures/real-time-curve-IO.png -------------------------------------------------------------------------------- /figures/real-time-curve-PO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/HEAD/figures/real-time-curve-PO.png -------------------------------------------------------------------------------- /figures/real-time-curve-SS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/HEAD/figures/real-time-curve-SS.png -------------------------------------------------------------------------------- /doa-estimation/conf.yml: -------------------------------------------------------------------------------- 1 | # Training config 2 | training: 3 | epochs: 50 4 | batch_size: 8 5 | num_workers: 4 6 | half_lr: yes 7 | early_stop: yes 8 | # Optim config 9 | optim: 10 | optimizer: adam 11 | lr: 0.001 12 | weight_decay: 0. 13 | # Data config 14 | data: 15 | train_dir: data/wav8k/min/tr/ 16 | valid_dir: data/wav8k/min/cv/ 17 | sample_rate: 16000 18 | -------------------------------------------------------------------------------- /doa-estimation/utils/prepare_python_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C 3 | install_dir=~ 4 | asteroid_root=../../../../ 5 | pip_requires=../../../requirements.txt # Expects a requirement.txt 6 | 7 | . 
utils/parse_options.sh || exit 1 8 | 9 | mkdir -p $install_dir 10 | cd $install_dir 11 | echo "Download and install latest version of miniconda3 into ${install_dir}" 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 15 | pip_path=$PWD/miniconda3/bin/pip 16 | 17 | rm Miniconda3-latest-Linux-x86_64.sh 18 | cd - 19 | 20 | if [[ ! -z ${pip_requires} ]]; then 21 | $pip_path install -r $pip_requires 22 | fi 23 | $pip_path install soundfile 24 | $pip_path install -e $asteroid_root 25 | #$pip_path install ${asteroid_root}/\[""evaluate""\] 26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" 27 | -------------------------------------------------------------------------------- /doa-estimation/run.sh: -------------------------------------------------------------------------------- 1 | ##!/bin/bash 2 | 3 | # Exit on error 4 | set -e 5 | set -o pipefail 6 | 7 | export PYTHONPATH=/path/to/asteroid:/path/to/sms_wsj 8 | python_path= 9 | 10 | # General 11 | stage=1 12 | expdir=exp 13 | id=$CUDA_VISIBLE_DEVICES 14 | 15 | train_dir= 16 | val_dir= 17 | test_dir= 18 | sample_rate=16000 19 | . utils/parse_options.sh 20 | 21 | # Training 22 | batch_size=20 23 | num_workers=16 24 | optimizer=adam 25 | lr=0.001 26 | epochs=50 27 | 28 | # Evaluation 29 | eval_use_gpu=1 30 | 31 | mkdir -p $expdir 32 | echo "Results from the following experiment will be stored in $expdir" 33 | 34 | if [[ $stage -le 1 ]]; then 35 | echo -e "Stage 1: Training" 36 | mkdir -p logs 37 | CUDA_VISIBLE_DEVICES=$id $python_path -u train.py \ 38 | --train_dirs $train_dir \ 39 | --val_dirs $val_dir \ 40 | --sample_rate $sample_rate \ 41 | --lr $lr \ 42 | --epochs $epochs \ 43 | --batch_size $batch_size \ 44 | --num_workers $num_workers \ 45 | --exp_dir ${expdir}/ | tee logs/train.log 46 | cp logs/train.log $expdir/train.log 47 | echo -e "Stage 1 - training: Done." 48 | fi 49 | 50 | if [[ $stage -le 2 ]]; then 51 | echo -e "Stage 2 : Evaluation." 52 | echo -e "test set is $test_dir" 53 | CUDA_VISIBLE_DEVICES=$id $python_path -u eval.py \ 54 | --test_dir $test_dir \ 55 | --use_gpu $eval_use_gpu \ 56 | --exp_dir ${expdir} | tee logs/eval.log 57 | cp logs/eval.log $expdir/eval.log 58 | echo -e "Stage 2 - evaluation: Done." 
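  # Note: eval.py loads the test files as np.load(testset + dlist[idx]), so pass --test_dir with a trailing
  # slash and point it at a directory of .npz files containing 'mix', 'label' and 'mix_way' entries.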
59 | fi 60 | -------------------------------------------------------------------------------- /generate_rir.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import random 6 | import time 7 | from multiprocessing import Pool 8 | 9 | from sms_wsj.database.create_rirs import config, scenarios, rirs 10 | from sms_wsj.reverb.reverb_utils import convolve 11 | 12 | T60_LOW, T60_HIGH = 0.15, 0.60 13 | 14 | def _worker_init_fn_(worker_id): 15 | torch_seed = torch.initial_seed() 16 | 17 | random.seed(torch_seed + worker_id) 18 | if torch_seed >= 2**32: 19 | torch_seed = torch_seed % 2**32 20 | np.random.seed(torch_seed + worker_id) 21 | 22 | def generate_rir(i): 23 | _worker_init_fn_(i) 24 | reverb_matrixs_dir = '/path/to/reverb-set/' 25 | geometry, sound_decay_time_range, sample_rate, filter_length = config(T60_LOW, T60_HIGH) 26 | room_dimensions, source_positions, sensor_positions, sound_decay_time = scenarios(geometry, sound_decay_time_range,) 27 | h = rirs(sample_rate, filter_length, room_dimensions, source_positions, sensor_positions, sound_decay_time) 28 | np.savez(reverb_matrixs_dir + str(i).zfill(4) + '.npz', h=h, source_positions=source_positions, sensor_positions=sensor_positions, 29 | room_dimensions=room_dimensions, sound_decay_time=sound_decay_time,) 30 | 31 | if __name__ == "__main__": 32 | nj = 32 33 | num_rir = 3000 34 | reverb_matrixs_dir = '/path/to/reverb-set/' 35 | 36 | if not os.path.exists(reverb_matrixs_dir): 37 | os.makedirs(reverb_matrixs_dir) 38 | else: 39 | if (input('target dir already esists, continue? [y/n] ') == 'n'): 40 | print('Exit. Nothing happends.') 41 | sys.exit() 42 | print('Generating reverb matrixs into ', reverb_matrixs_dir, '......') 43 | 44 | time_start=time.time() 45 | pool = Pool(processes=nj) 46 | args = [] 47 | for i in range (num_rir): 48 | args.append(i) 49 | pool.map(generate_rir, args) 50 | pool.close() 51 | pool.join() 52 | time_end=time.time() 53 | print('totally cost ', round((time_end-time_start)/60), 'minutes') 54 | -------------------------------------------------------------------------------- /sms_wsj_replace/create_rirs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sms_wsj.reverb.reverb_utils import generate_rir 3 | from sms_wsj.reverb.scenario import generate_random_source_positions 4 | from sms_wsj.reverb.scenario import generate_sensor_positions 5 | from sms_wsj.reverb.scenario import sample_from_random_box 6 | 7 | def config(): 8 | # Either set it to zero or above 0.15 s. Otherwise, RIR contains NaN. 
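    # Note: generate_rir.py in this repository calls config(T60_LOW, T60_HIGH); if you keep that caller,
    # let config() accept the two bounds as arguments (e.g. def config(t60_low=0.15, t60_high=0.60)) and
    # use them in sound_decay_time_range below instead of the hard-coded values.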
9 | sound_decay_time_range = dict(low=0.15, high=0.6) 10 | 11 | geometry = dict( 12 | number_of_sources=3, 13 | number_of_sensors=7, 14 | sensor_shape="circular_center", 15 | center=[[3.5], [3.], [1.5]], # m 16 | scale=0.0425, # m 17 | room=[[7.], [6.], [3.]], # m 18 | random_box=[[4.], [2.], [0.4]], # m 19 | ) 20 | 21 | sample_rate = 16000 22 | filter_length = 2 ** 14 # 1.024 seconds when sample_rate == 16000 23 | 24 | return geometry, sound_decay_time_range, sample_rate, filter_length 25 | 26 | def scenarios(geometry,sound_decay_time_range,): 27 | room_dimensions = sample_from_random_box(geometry["room"], geometry["random_box"]) 28 | center = sample_from_random_box(geometry["center"], geometry["random_box"]) 29 | source_positions = generate_random_source_positions(center=center,sources=geometry["number_of_sources"], dims=2) 30 | 31 | sensor_positions = generate_sensor_positions( 32 | shape=geometry["sensor_shape"], 33 | center=center, 34 | room_dimensions = room_dimensions, 35 | scale=geometry["scale"], 36 | number_of_sensors=geometry["number_of_sensors"], 37 | rotate_x=np.random.uniform(0, 0.01 * 2 * np.pi), 38 | rotate_y=np.random.uniform(0, 0.01 * 2 * np.pi), 39 | rotate_z=np.random.uniform(0, 2 * np.pi), 40 | ) 41 | sound_decay_time = np.random.uniform(**sound_decay_time_range) 42 | 43 | return room_dimensions, source_positions, sensor_positions, sound_decay_time 44 | 45 | def rirs(sample_rate, filter_length, room_dimensions, source_positions, sensor_positions, sound_decay_time): 46 | h = generate_rir( 47 | room_dimensions=room_dimensions, 48 | source_positions=source_positions, 49 | sensor_positions=sensor_positions, 50 | sound_decay_time=sound_decay_time, 51 | sample_rate=sample_rate, 52 | filter_length=filter_length, 53 | sensor_orientations=None, 54 | sensor_directivity=None, 55 | sound_velocity=343 56 | ) 57 | 58 | return 59 | -------------------------------------------------------------------------------- /doa-estimation/model_causal.py: -------------------------------------------------------------------------------- 1 | #!/user/bin/env python 2 | # yangyi@2020-2022 3 | # real-time doa estimation via self-attention 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | from sigprocess import STFT, ISTFT 10 | from base import dense_block, attention_block 11 | from asteroid.engine.optimizers import make_optimizer 12 | 13 | def make_model_and_optimizer(conf): 14 | model = proposed() 15 | optimizer = make_optimizer(model.parameters(), **conf['optim']) 16 | return model, optimizer 17 | 18 | 19 | class proposed(nn.Module): 20 | def __init__(self, fftsize=512, window_size=400, stride=100, channel=4, causal=True): 21 | super(proposed, self).__init__() 22 | bins = fftsize // 2 23 | self.channel = channel 24 | self.causal = causal 25 | 26 | self.stft = STFT(fftsize=fftsize, window_size=window_size, stride=stride, trainable=False) 27 | self.input_conv_layer = torch.nn.Sequential( 28 | torch.nn.Conv2d(8,8,[1,1],[1,1]), 29 | torch.nn.ReLU(), 30 | torch.nn.BatchNorm2d(8) 31 | ) 32 | 33 | # dense conv block 34 | self.conv_block = nn.ModuleList() 35 | for i in range (4): 36 | self.conv_block.append(dense_block(in_channels=8, out_channels=8, kernel_size=[2,3], stride=[1,1], padding=[1,1])) 37 | 38 | # self-attention block 39 | self.shared_block = nn.ModuleList() 40 | for i in range (4): 41 | self.shared_block.append(attention_block(in_channels=bins//(4**i), out_channels=bins//(4**(i+1)))) 42 | self.shared_block.append(attention_block(in_channels=8, out_channels=8)) 43 | 
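            # Each stage i owns three attention blocks: index 3*i attends over time frames,
            # 3*i+1 attends over frequency bins, and 3*i+2 is the channel-attention block that
            # model.py uses when attention_type='TFC'; the causal forward below only uses the first two.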
self.shared_block.append(attention_block(in_channels=bins//(4**(i+1)), out_channels=bins//(4**(i+1)))) 44 | 45 | self.re_fc_layer = nn.Linear(8,6) 46 | 47 | 48 | def forward(self, x): # b n c 49 | 50 | x = x.transpose(1,2) # b c n 51 | xs = self.stft(x[:,[0],:])[...,1:,:].unsqueeze(1) # b 1 t f 2 52 | for i in range(1,self.channel): 53 | xs = torch.cat((xs,self.stft(x[:,[i],:])[...,1:,:].unsqueeze(1)),1) # b c t f 2 54 | feat_in = torch.cat((xs[...,0], xs[...,1]), 1) # b 2c t f 55 | 56 | # step1:change channel dim 57 | x_in = self.input_conv_layer(feat_in) # b 2c t f 58 | 59 | for i in range (4): 60 | # step2:dense block 61 | x_out1 = self.conv_block[i](x_in) # (B,8,T,256) (B,8,T,64) (B,8,T,16) (B,8,T,4) 62 | 63 | # step3:self-attention T 64 | x_in2 = x_out1.permute(0,3,1,2) # (B,256,8,T) 65 | x_out2, plot_weight_t = self.shared_block[0+i*3](x_in2, causal=self.causal) # (B,64,8,T) 66 | x_out2 = x_out2.permute(0,2,3,1) 67 | 68 | # step4:self-attention F 69 | x_in3 = x_out2 # (B,8,T,64) 70 | x_out3, plot_weight_f = self.shared_block[1+i*3](x_in3, causal=self.causal) # (B,8,T,64) 71 | 72 | x_in = x_out3 73 | 74 | x_out = x_in.squeeze(-1).transpose(1,2) # b t 8 75 | y = self.re_fc_layer(x_out) # b t 6 76 | 77 | return y.mean(1) 78 | 79 | 80 | if __name__ == "__main__": 81 | import torch 82 | from thop import profile 83 | from thop import clever_format 84 | 85 | model = proposed() 86 | x = torch.randn(1, 16000*8, 7) # b n c 87 | macs, params = profile(model, inputs=(x)) 88 | macs, params = clever_format([macs, params], "%.3f") 89 | 90 | print('macs:', macs) 91 | print('params:', params) 92 | -------------------------------------------------------------------------------- /doa-estimation/utils/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 
40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /doa-estimation/sigprocess.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | EPSILON = 1e-8 6 | 7 | class STFT(nn.Module): 8 | def __init__(self, fftsize, window_size, stride, trainable=False): 9 | super(STFT, self).__init__() 10 | self.fftsize = fftsize 11 | self.window_size = window_size 12 | self.stride = stride 13 | self.window_func = np.hanning(self.window_size) 14 | 15 | fcoef_r = np.zeros((self.fftsize//2 + 1, 1, self.window_size)) 16 | fcoef_i = np.zeros((self.fftsize//2 + 1, 1, self.window_size)) 17 | for w in range(self.fftsize//2+1): 18 | for t in range(self.window_size): 19 | fcoef_r[w, 0, t] = np.cos(2. * np.pi * w * t / self.fftsize) 20 | fcoef_i[w, 0, t] = -np.sin(2. 
* np.pi * w * t / self.fftsize) 21 | 22 | fcoef_r = fcoef_r * self.window_func 23 | fcoef_i = fcoef_i * self.window_func 24 | self.fcoef_r = th.tensor(fcoef_r, dtype=th.float) 25 | self.fcoef_i = th.tensor(fcoef_i, dtype=th.float) 26 | self.encoder_r = nn.Conv1d(1, self.fftsize//2+1, self.window_size, bias=False, stride=self.stride) 27 | self.encoder_i = nn.Conv1d(1, self.fftsize//2+1, self.window_size, bias=False, stride=self.stride) 28 | self.encoder_r.weight = th.nn.Parameter(self.fcoef_r) 29 | self.encoder_i.weight = th.nn.Parameter(self.fcoef_i) 30 | 31 | if trainable: 32 | self.encoder_r.weight.requires_grad = True 33 | self.encoder_i.weight.requires_grad = True 34 | else: 35 | self.encoder_r.weight.requires_grad = False 36 | self.encoder_i.weight.requires_grad = False 37 | 38 | def forward(self, input): # (B, 1, n_sample) 39 | 40 | spec_r = self.encoder_r(input) 41 | spec_i = self.encoder_i(input) 42 | output = th.stack([spec_r,spec_i],dim=-1) 43 | output = output.permute([0, 2, 1, 3]) 44 | 45 | return output # (B,T,F,2) 46 | 47 | class ISTFT(nn.Module): 48 | def __init__(self, fftsize, window_size, stride, trainable=False): 49 | super(ISTFT, self).__init__() 50 | self.fftsize = fftsize 51 | self.window_size = window_size 52 | self.stride = stride 53 | 54 | gain_ifft = (2.0*self.stride) / self.window_size 55 | self.window_func = gain_ifft * np.hanning(self.window_size) 56 | 57 | coef_cos = np.zeros((self.fftsize//2 + 1, 1, self.window_size)) 58 | coef_sin = np.zeros((self.fftsize//2 + 1, 1, self.window_size)) 59 | for w in range(self.fftsize//2+1): 60 | alpha = 1.0 if w==0 or w==fftsize//2 else 2.0 61 | alpha /= fftsize 62 | for t in range(self.window_size): 63 | coef_cos[w, 0, t] = alpha * np.cos(2. * np.pi * w * t / self.fftsize) 64 | coef_sin[w, 0, t] = alpha * np.sin(2. 
* np.pi * w * t / self.fftsize) 65 | 66 | self.coef_cos = th.tensor(coef_cos * self.window_func, dtype=th.float) 67 | self.coef_sin = th.tensor(coef_sin * self.window_func, dtype=th.float) 68 | self.decoder_re = nn.ConvTranspose1d(self.fftsize//2+1, 1, self.window_size, bias=False, stride=self.stride) 69 | self.decoder_im = nn.ConvTranspose1d(self.fftsize//2+1, 1, self.window_size, bias=False, stride=self.stride) 70 | self.decoder_re.weight = th.nn.Parameter(self.coef_cos) 71 | self.decoder_im.weight = th.nn.Parameter(self.coef_sin) 72 | 73 | if trainable: 74 | self.decoder_re.weight.requires_grad = True 75 | self.decoder_im.weight.requires_grad = True 76 | else: 77 | self.decoder_re.weight.requires_grad = False 78 | self.decoder_im.weight.requires_grad = False 79 | 80 | def forward(self, input): # (B,T,F,2) 81 | input = input.permute([0, 2, 1, 3]) # (B,F,T,2) 82 | real_part = input[:,:,:,0] 83 | imag_part = input[:,:,:,1] 84 | 85 | time_cos = self.decoder_re(real_part) 86 | time_sin = self.decoder_im(imag_part) 87 | output = time_cos - time_sin 88 | 89 | return output # (B, 1, n_sample) 90 | -------------------------------------------------------------------------------- /doa-estimation/model.py: -------------------------------------------------------------------------------- 1 | #!/user/bin/env python 2 | # yangyi@2020-2022 3 | # doa estimation via self-attention 4 | # parameters:282k computation:5.2G/10s 2.1G/4s 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | 10 | from sigprocess import STFT, ISTFT 11 | from base import dense_block, attention_block 12 | from asteroid.engine.optimizers import make_optimizer 13 | 14 | def make_model_and_optimizer(conf): 15 | model = proposed() 16 | optimizer = make_optimizer(model.parameters(), **conf['optim']) 17 | return model, optimizer 18 | 19 | class proposed(nn.Module): 20 | def __init__(self, fftsize=512, window_size=400, stride=100, channel=4, attention_type='TF', causal=False): 21 | super(proposed, self).__init__() 22 | bins = fftsize // 2 23 | self.channel = channel 24 | self.attention_type = attention_type 25 | self.causal = causal 26 | 27 | self.stft = STFT(fftsize=fftsize, window_size=window_size, stride=stride, trainable=True) 28 | # increase or decrease the channel dim 29 | self.input_conv_layer = torch.nn.Sequential( 30 | torch.nn.Conv2d(8,8,[1,1],[1,1]), 31 | torch.nn.ReLU(), 32 | torch.nn.BatchNorm2d(8) 33 | ) 34 | # dense conv block 35 | self.conv_block = nn.ModuleList() 36 | for i in range (4): 37 | self.conv_block.append(dense_block(in_channels=8,out_channels=8,kernel_size=[2,3],stride=[1,1],padding=[1,1])) 38 | # self-attention block 39 | self.shared_block = nn.ModuleList() 40 | for i in range (4): 41 | self.shared_block.append(attention_block(in_channels=bins//(4**i),out_channels=bins//(4**(i+1)))) 42 | self.shared_block.append(attention_block(in_channels=8,out_channels=8)) 43 | self.shared_block.append(attention_block(in_channels=bins//(4**(i+1)),out_channels=bins//(4**(i+1)))) 44 | 45 | self.re_fc_layer = nn.Linear(8,6) 46 | 47 | 48 | def forward(self, x): # b n c 49 | x = x.permute(0,2,1) # b c n 50 | xs = self.stft(x[:,[0],:])[...,1:,:].unsqueeze(1) # b 1 t f 2 51 | for i in range(1,self.channel): 52 | xs = torch.cat((xs,self.stft(x[:,[i],:])[...,1:,:].unsqueeze(1)),1) # b c t f 2 53 | feat_in = torch.cat((xs[...,0], xs[...,1]), 1) # b 2c t f 54 | # step1:change channel dim 55 | x_in = self.input_conv_layer(feat_in) # b 2c t f 56 | 57 | for i in range (4): 58 | # step2:dense block 59 | x_out1 = 
self.conv_block[i](x_in) # (B,8,T,256) (B,8,T,64) (B,8,T,16) (B,8,T,4) 60 | 61 | # step3:self-attention T 62 | x_in2 = x_out1.permute(0,3,1,2) # (B,256,8,T) 63 | x_out2, plot_weight_t = self.shared_block[0+i*3](x_in2,causal=self.causal) # (B,64,8,T) 64 | x_out2 = x_out2.permute(0,2,3,1) 65 | 66 | # step4:self-attention F 67 | if (self.attention_type=='TF' or self.attention_type=='TFC'): 68 | x_in3 = x_out2 # (B,8,T,64) 69 | x_out3, plot_weight_f = self.shared_block[1+i*3](x_in3,causal=self.causal) # (B,8,T,64) 70 | else: 71 | x_out3 = x_out2 72 | 73 | # step5:self-attention C 74 | if (self.attention_type=='TFC'): 75 | x_in4 = (x_out2+x_out3).permute(0,3,2,1) # (B,4,T,8) 76 | x_out4, plot_weight_c = self.shared_block[2+i*3](x_in4,causal=self.causal) # (B,64,T,8) 77 | x_out4 = x_out4.permute(0,3,2,1) 78 | else: 79 | x_out4 = x_out3 80 | 81 | x_in = x_out3+x_out4 82 | 83 | x_out = x_in.squeeze(-1).permute(0,2,1) # b t 8 84 | y = self.re_fc_layer(x_out) # b t 6 85 | 86 | return y.mean(1) 87 | 88 | 89 | if __name__ == "__main__": 90 | import torch 91 | from thop import profile 92 | from thop import clever_format 93 | 94 | model = proposed(attention_type='TF', causal=False) 95 | x = torch.randn(1, 16000*8, 7) # b n c 96 | macs, params = profile(model, inputs=(x)) 97 | macs, params = clever_format([macs, params], "%.3f") 98 | 99 | print('macs:', macs) 100 | print('params:', params) 101 | -------------------------------------------------------------------------------- /doa-estimation/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import numpy as np 5 | import random 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch.optim.lr_scheduler import ReduceLROnPlateau 10 | from torch.utils.data import DataLoader 11 | import pytorch_lightning as pl 12 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 13 | 14 | from dataset import Librispeech_Dataset 15 | from asteroid.engine.optimizers import make_optimizer 16 | from system import System, label_loss 17 | from model import make_model_and_optimizer 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--train_dirs", type=str, required=True, help="Training dataset") 21 | parser.add_argument("--val_dirs", type=str, required=True, help="Val dataset") 22 | parser.add_argument("--exp_dir", default="exp/tmp", help="Full path to save best validation model") 23 | 24 | def _worker_init_fn_(worker_id): 25 | torch_seed = torch.initial_seed() 26 | 27 | random.seed(torch_seed + worker_id) 28 | if torch_seed >= 2**32: 29 | torch_seed = torch_seed % 2**32 30 | np.random.seed(torch_seed + worker_id) 31 | 32 | def main(conf): 33 | train_dir = conf["main_args"]['train_dirs'] 34 | val_dir = conf["main_args"]['val_dirs'] 35 | 36 | rirNO_train = len(os.listdir(train_dir)) 37 | rirNO_val = len(os.listdir(val_dir)) 38 | 39 | train_set = Librispeech_Dataset( 40 | train_dir, 41 | rirNO_train, 42 | trainingNO = 10000, 43 | segment=8, 44 | channel=[0,1,2,3], 45 | ) 46 | 47 | val_set = Librispeech_Dataset( 48 | val_dir, 49 | rirNO_val, 50 | trainingNO = 2000, 51 | segment=8, 52 | channel=[0,1,2,3], 53 | ) 54 | 55 | train_loader = DataLoader( 56 | train_set, 57 | shuffle=True, 58 | batch_size=conf["training"]["batch_size"], 59 | num_workers=conf["training"]["num_workers"], 60 | drop_last=True, 61 | worker_init_fn=_worker_init_fn_ 62 | ) 63 | val_loader = DataLoader( 64 | val_set, 65 | shuffle=False, 66 | 
batch_size=conf["training"]["batch_size"], 67 | num_workers=conf["training"]["num_workers"], 68 | drop_last=True, 69 | worker_init_fn=_worker_init_fn_ 70 | ) 71 | 72 | # Define model and optimizer 73 | model, optimizer = make_model_and_optimizer(conf) 74 | 75 | # Define scheduler 76 | scheduler = None 77 | if conf["training"]["half_lr"]: 78 | scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5) 79 | 80 | # Just after instantiating, save the args. Easy loading in the future. 81 | exp_dir = conf["main_args"]["exp_dir"] 82 | os.makedirs(exp_dir, exist_ok=True) 83 | conf_path = os.path.join(exp_dir, "conf.yml") 84 | with open(conf_path, "w") as outfile: 85 | yaml.safe_dump(conf, outfile) 86 | 87 | # Define Loss function. 88 | loss_func = label_loss() 89 | 90 | system = System( 91 | model=model, 92 | loss_func=loss_func, 93 | optimizer=optimizer, 94 | train_loader=train_loader, 95 | val_loader=val_loader, 96 | scheduler=scheduler, 97 | config=conf, 98 | ) 99 | 100 | # Define callbacks 101 | checkpoint_dir = os.path.join(exp_dir, "checkpoints/") 102 | checkpoint = ModelCheckpoint( 103 | checkpoint_dir, monitor="val_loss", mode="min", save_top_k=5, verbose=True 104 | ) 105 | early_stopping = False 106 | if conf["training"]["early_stop"]: 107 | early_stopping = EarlyStopping(monitor="val_loss", patience=30, verbose=True) 108 | 109 | # Don't ask GPU if they are not available. 110 | gpus = -1 if torch.cuda.is_available() else None 111 | trainer = pl.Trainer( 112 | max_epochs=conf["training"]["epochs"], 113 | checkpoint_callback=checkpoint, 114 | #resume_from_checkpoint='exp/epoch=46.ckpt', 115 | early_stop_callback=early_stopping, 116 | default_root_dir=exp_dir, 117 | gpus=gpus, 118 | distributed_backend="dp", 119 | train_percent_check=1.0, # Useful for fast experiment 120 | gradient_clip_val=5.0, 121 | ) 122 | trainer.fit(system) 123 | 124 | best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} 125 | with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f: 126 | json.dump(best_k, f, indent=0) 127 | 128 | if __name__ == "__main__": 129 | 130 | import yaml 131 | from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict 132 | 133 | with open("conf.yml") as f: 134 | def_conf = yaml.safe_load(f) 135 | parser = prepare_parser_from_dict(def_conf, parser=parser) 136 | 137 | arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True) 138 | main(arg_dic) 139 | -------------------------------------------------------------------------------- /doa-estimation/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy import inf 4 | 5 | """ 6 | # ref: 2021, Attention is All You Need in Speech Separation 7 | # https://arxiv.org/abs/2010.13154v2 8 | # ref: 2021, Dense CNN with Self-Attention for Time-Domain Speech Enhancement 9 | # https://arxiv.org/abs/2009.01941 10 | # time delay:stft 25ms 11 | # positional encoding may be added into it. 
original chunk_size=250 12 | """ 13 | class dense_block(nn.Module): 14 | def __init__(self,in_channels,out_channels,kernel_size=[2,3],stride=[1,1],padding=[1,1],bias=False): 15 | super(dense_block,self).__init__() 16 | self.shared_block = nn.ModuleList() 17 | for i in range(5): 18 | self.shared_block.append(nn.Conv2d(in_channels*(i+1),out_channels,kernel_size,stride,padding,bias=bias)) 19 | self.shared_block.append(nn.PReLU()) 20 | self.shared_block.append(torch.nn.BatchNorm2d(out_channels)) 21 | 22 | def forward(self,x): 23 | for i in range(5): 24 | x1 = self.shared_block[0+3*i](x)[:,:,:x.size()[-2],:] 25 | x2 = self.shared_block[1+3*i](x1) 26 | x3 = self.shared_block[2+3*i](x2) 27 | x = torch.cat((x,x3),1) 28 | 29 | return x3 30 | 31 | class attention_block(nn.Module): 32 | def __init__(self,in_channels,out_channels,kernel_size=[1,1],stride=[1,1],padding=[0,0]): 33 | super(attention_block,self).__init__() 34 | self.conv_Q = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding) 35 | self.conv_K = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding) 36 | self.conv_V = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding) 37 | 38 | def forward(self,x,causal=False): 39 | Q = self.conv_Q(x).reshape(x.size()[0],-1,x.size()[3]) 40 | K = self.conv_K(x).reshape(x.size()[0],-1,x.size()[3]) 41 | V = self.conv_V(x).reshape(x.size()[0],-1,x.size()[3]) 42 | # scaled dpd 43 | _attention_weight = (torch.einsum('ikj,ijl->ikl', [Q.permute(0,2,1), K]) / (Q.size()[1])**0.5) 44 | 45 | if (causal==True): 46 | # causal attention 47 | _attention_weight = torch.triu(_attention_weight) 48 | mask = torch.ones_like(_attention_weight) * float(-inf) 49 | _attention_weight = torch.where(_attention_weight != 0, _attention_weight, mask) 50 | 51 | attention_weight = _attention_weight.softmax(dim=-1) 52 | x_out = torch.einsum('ijl,ilk->ijk', [V, attention_weight.permute(0,2,1)]) 53 | x_out = x_out.reshape(x.size()[0],-1,x.size()[2],x.size()[3]) 54 | plot_weight = _attention_weight 55 | 56 | return x_out, plot_weight 57 | 58 | """ 59 | dataset 60 | """ 61 | import numpy as np 62 | import math 63 | EPS=1e-8 64 | 65 | def rms(y): 66 | return np.sqrt(np.mean(np.abs(y) ** 2, axis=0, keepdims=False)) 67 | 68 | def get_amplitude_scaling_factor(s, n, snr, method='rms'): 69 | original_sn_rms_ratio = rms(s) / rms(n) 70 | target_sn_rms_ratio = 10. ** (float(snr) / 20.) # snr = 20 * lg(rms(s) / rms(n)) 71 | signal_scaling_factor = target_sn_rms_ratio / original_sn_rms_ratio 72 | return signal_scaling_factor 73 | 74 | def get_label(_source_positions, _sensor_positions, usage): 75 | """ 76 | Extract Label of Raw Wav. 
77 | Arguments: 78 | _source_positions: source positions, 3 x 2 x channel 79 | _sensor_positions: sensor positions, 3 x 2 x channel 80 | loss_type: categorical or cartesian 81 | Return: 82 | y: (1,) if categorical 83 | (3,2) if cartesian 84 | """ 85 | ''' 86 | _source_positions = np.array([[[1.0], [1.0]], 87 | [[0.0], [1.0]], 88 | [[0.0], [1.0]]]) 89 | _sensor_positions = np.array([[[0.0, 1.0, -1.0, 0.0]], 90 | [[0.0, 1.0, 0.0, 0.0]], 91 | [[0.0, 0.0, 0.0, 0.0]]]) 92 | ''' 93 | #print('src:',_source_positions) 94 | #print('sen:',_sensor_positions) 95 | if (usage == 'simu'): 96 | # step1:translation 97 | x0 = _source_positions[0,:,0]; y0 = _source_positions[1,:,0]; z0 = _source_positions[2,:,0] 98 | x1 = -(_sensor_positions[0,:,0] - x0); y1 = -(_sensor_positions[1,:,0] - y0); z1 = -(_sensor_positions[2,:,0] - z0) 99 | 100 | ref_x0 = _sensor_positions[0,:,1]; ref_y0 = _sensor_positions[1,:,1]; ref_z0 = _sensor_positions[2,:,1] 101 | ref_x1 = -(_sensor_positions[0,:,0] - ref_x0); ref_y1 = -(_sensor_positions[1,:,0] - ref_y0); ref_z1 = -(_sensor_positions[2,:,0] - ref_z0) 102 | 103 | # step2:rotation-azimuth 104 | theta = np.arctan2(ref_y1,ref_x1) 105 | x2 = x1 * np.cos(theta) + y1 * np.sin(theta) 106 | y2 = y1 * np.cos(theta) - x1 * np.sin(theta) 107 | z2 = z1 108 | 109 | # step3:rotation-elevation 110 | phi = np.arctan2(ref_z1,np.sqrt(ref_x1 ** 2 + ref_y1 ** 2)) 111 | x3 = x2 * np.cos(phi) + z2 * np.sin(phi) 112 | y3 = y2 113 | z3 = z2 * np.cos(phi) - x2 * np.sin(phi) 114 | 115 | elif (usage == 'dummy'): 116 | x0 = (_sensor_positions[0,:,0] + _sensor_positions[0,:,2]) / 2; y0 = _sensor_positions[1,:,0] 117 | y3 = _source_positions[1,:,0] - y0 118 | x3 = _source_positions[0,:,0] - x0 119 | z3 = _source_positions[2,:,0] - z0 120 | 121 | y = np.vstack((x3,y3,z3)) # 3 src 122 | #print('dis:',dis) 123 | 124 | return y 125 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DOA-estimation-with-a-stacked-self-attention-network 2 | **A stacked self-attention network for two-dimensional direction-of-arrival estimation in hands-free speech communication** 3 | 4 | **This work has been published on *Journal of the Acoustical Society of America (JASA).* The paper is available [here][Paper].** 5 | 6 | ## Contents 7 | * **[DOA-estimation-with-a-stacked-self-attention-network](#doa-estimation-with-a-stacked-self-attention-network)** 8 | * **[Contents](#contents)** 9 | * **[Introduction](#introduction)** 10 | * **[Dataset](#dataset)** 11 | * **[Requirement](#requirement)** 12 | * **[Train](#train)** 13 | * **[Test](#test)** 14 | * **[Results](#results)** 15 | * **[Citation](#citation)** 16 | * **[References](#references)** 17 | 18 | ## Introduction 19 | **When making voice interactions with hands-free speech communication devices, direction-of-arrival estimation is an essential step. To address the detrimental influence of unavoidable background noise and interference speech on direction-of-arrival estimation, we introduce a stacked self-attention network system, a supervised deep learning method that enables utterance level estimation without requirement for any pre-processing such as voice activity detection. Specifically, alternately stacked time- and frequency-dependent self-attention blocks are designed to process information in terms of time and frequency, respectively. 
The former blocks focus on the importance of each time frame of the received audio mixture and perform temporal selection to reduce the influence of non-speech and interference frames, while the latter blocks are utilized to derive the inner correlation among different frequencies. Additionally, the non-causal convolution and self-attention networks are replaced by causal ones, enabling real-time direction-of-arrival estimation with a latency of only 6.25 ms. Experiments with simulated and measured room impulse responses, as well as real recordings, verify the advantages of the proposed method over the state-of-the-art baselines.**
20 | 
21 | ![image](https://github.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/blob/main/figures/model-architecture1.png)
22 | ![image](https://github.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/blob/main/figures/model-architecture2.png)
23 | 
24 | ## Dataset
25 | **We use [sms_wsj][sms_wsj] to generate the room impulse response (RIR) set. ```sms_wsj/reverb/scenario.py``` and ```sms_wsj/database/create_rirs.py``` should be replaced by the scripts in the 'sms_wsj_replace' folder.**
26 | 
27 | **Use ```python generate_rir.py``` to generate the training and validation data.**
28 | 
29 | ## Requirement
30 | **Our scripts use the [asteroid][asteroid] toolkit as the basic environment.**
31 | 
32 | ## Train
33 | **To train end-to-end, we recommend running:**
34 | 
35 | **```./run.sh --id 0,1,2,3```**
36 | 
37 | **or:**
38 | 
39 | **```./run.sh --id 0,1,2,3 --stage 1```**
40 | 
41 | ## Test
42 | **```./run.sh --id 0 --stage 2```**
43 | 
44 | ## Results
45 | **The average MAE (degree), Acc. (%), model parameters, and latency of the real-time implementation of the proposed system and the CNN baseline [1] on the SS condition of all simulated test sets (E_theta = 15 degrees).**
46 | 
47 | |**Measure** |**MAE** |**Acc.**|**Parameters**|**Latency**|
48 | | :----- | :----: | :----: | :----: | :----: |
49 | |**CNN [1]** |3.6 |99.3 |8.7M |14 ms |
50 | |**Proposed**|2.9 |99.5 |282k |6.25 ms |
51 | 
52 | **The average MAE (degree) and Acc. (%) of the off-line and real-time implementations of the proposed system for each overlap condition on all simulated test sets (E_theta = 15 degrees).**
53 | 
54 | |**Overlap condition**|**SS** |**SS** |**IO** |**IO** |**PO** |**PO** |
55 | | :----- | :----: | :----: | :----: | :----: | :----: | :----: |
56 | |**Measure** |**MAE** |**Acc.**|**MAE** |**Acc.**|**MAE** |**Acc.**|
57 | |**Off-line** |4.3 |97.4 |8.3 |88.8 |7.2 |91.1 |
58 | |**Real-time** |5.2 |95.8 |8.6 |86.5 |9.0 |84.6 |
59 | 
60 | **An attention map of an example speech utterance (room dimensions = 7.0 m × 6.0 m × 3.2 m, RT60 = 400 ms, SNR = 20 dB, SIR = 0 dB). The ground-truth and estimated azimuths for the target speaker (speaker A) are 1.6 degrees and 2.3 degrees, respectively. The ground-truth azimuth for the interference speaker (speaker B) is 125.6 degrees. The horizontal and vertical axes represent the frame index of interest and the frames to which it attends.
The log power spectrums of the input mixture, reverberated utterances of speaker A and speaker B are also attached on the top and left, respectively.** 61 | 62 | ![image](https://github.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/blob/main/figures/attention-map.png) 63 | 64 | **The off-line and real-time 2-D DOA estimation curves for each overlap condition.** 65 | 66 | ![image](https://github.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/blob/main/figures/real-time-curve-SS.png) 67 | ![image](https://github.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/blob/main/figures/real-time-curve-IO.png) 68 | ![image](https://github.com/yangyi0818/DOA-estimation-with-a-stacked-self-attention-network/blob/main/figures/real-time-curve-PO.png) 69 | 70 | ## Citation 71 | **Cite our paper by:** 72 | 73 | **@article{yang2022stacked,** 74 | 75 | **title={A stacked self-attention network for two-dimensional direction-of-arrival estimation in hands-free speech communication},** 76 | 77 | **author={Yang, Yi and Chen, Hangting and Zhang, Pengyuan},** 78 | 79 | **journal={The Journal of the Acoustical Society of America},** 80 | 81 | **volume={152},** 82 | 83 | **number={6},** 84 | 85 | **pages={3444--3457},** 86 | 87 | **year={2022},** 88 | 89 | **publisher={Acoustical Society of America}** 90 | 91 | **}** 92 | 93 | ## References 94 | **[1] A. Kucuk, A. Ganguly, Y. Hao, and I. M. S. Panahi, "Real-time convolutional neural network-based speech source localization on smartphone," IEEE Access 7, 169969–169978 (2019).** 95 | 96 | **Please feel free to contact us if you have any questions.** 97 | 98 | [Paper]: https://doi.org/10.1121/10.0016467 99 | [sms_wsj]: https://github.com/fgnt/sms_wsj 100 | [asteroid]: https://github.com/asteroid-team/asteroid 101 | 102 | -------------------------------------------------------------------------------- /sms_wsj_replace/scenario.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helps to quickly create source and sensor positions. 3 | Try it with the following code: 4 | >>> import numpy as np 5 | >>> import sms_wsj.reverb.scenario as scenario 6 | >>> src = scenario.generate_random_source_positions(dims=2, sources=1000) 7 | >>> src[1, :] = np.abs(src[1, :]) 8 | >>> mic = scenario.generate_sensor_positions(shape='linear', scale=0.1, number_of_sensors=6) 9 | """ 10 | 11 | import numpy as np 12 | from sms_wsj.reverb.rotation import rot_x, rot_y, rot_z 13 | 14 | 15 | def sample_from_random_box(center, edge_lengths, rng=np.random): 16 | """ Sample from a random box to get somewhat random locations. 17 | >>> points = np.asarray([sample_from_random_box( 18 | ... [[10], [20], [30]], [[1], [2], [3]] 19 | ... ) for _ in range(1000)]) 20 | >>> import matplotlib.pyplot as plt 21 | >>> from mpl_toolkits.mplot3d import Axes3D 22 | >>> fig = plt.figure() 23 | >>> ax = fig.add_subplot(111, projection='3d') 24 | >>> _ = ax.scatter(points[:, 0, 0], points[:, 1, 0], points[:, 2, 0]) 25 | >>> _ = plt.show() 26 | Args: 27 | center: Original center (mean). 28 | edge_lengths: Edge length of the box to be sampled from. 
29 | Returns: 30 | """ 31 | center = np.asarray(center) 32 | edge_lengths = np.asarray(edge_lengths) 33 | return center + rng.uniform( 34 | low=-edge_lengths / 2, 35 | high=edge_lengths / 2 36 | ) 37 | 38 | def generate_sensor_positions( 39 | shape='cube', 40 | center=np.zeros((3, 1), dtype=np.float), 41 | room_dimensions = [[6], [4], [3]], 42 | scale=0.01, 43 | number_of_sensors=None, 44 | jitter=None, 45 | rng=np.random, 46 | rotate_x=0, rotate_y=0, rotate_z=0 47 | ): 48 | """ Generate different sensor configurations. 49 | Sensors are index counter-clockwise starting with the 0th sensor below 50 | the x axis. This is done, such that the first two sensors point towards 51 | the x axis. 52 | :param shape: A shape, i.e. 'cube', 'triangle', 'linear' or 'circular'. 53 | :param center: Numpy array with shape (3, 1) 54 | which holds coordinates x, y and z. 55 | :param scale: Scalar responsible for scale of the array. See individual 56 | implementations, if it is used as radius or edge length. 57 | :param jitter: Add random Gaussian noise with standard deviation ``jitter`` 58 | to sensor positions. 59 | :return: Numpy array with shape (3, number_of_sensors). 60 | """ 61 | 62 | center = np.array(center) 63 | if center.ndim == 1: 64 | center = center[:, None] 65 | 66 | if shape == 'circular_center': 67 | radius = scale 68 | delta_phi = 2 * np.pi / (number_of_sensors - 1) 69 | phi_0 = delta_phi / 2 70 | phi = np.arange(0, number_of_sensors-1) * delta_phi - phi_0 71 | sensor_positions_cir = np.asarray([ 72 | radius * np.cos(phi), 73 | radius * np.sin(phi), 74 | np.zeros(phi.shape) 75 | ]) 76 | sensor_positions_cen = np.asarray([ 77 | [0], 78 | [0], 79 | [0] 80 | ]) 81 | sensor_positions = np.hstack([sensor_positions_cen, sensor_positions_cir]) 82 | 83 | elif shape == 'locata_dummy': 84 | sensor_positions = np.asarray( 85 | [ 86 | # 0 1 2 3 87 | [-0.079, -0.079, 0.079, 0.079], 88 | [ 0.000, -0.009, 0.000, -0.009], 89 | [ 0.000, 0.000, 0.000, 0.000] 90 | ] 91 | ) 92 | 93 | else: 94 | raise NotImplementedError('Given shape is not implemented.') 95 | 96 | # NOTE rotation 97 | #sensor_positions = rot_x(rotate_x) @ sensor_positions 98 | #sensor_positions = rot_y(rotate_y) @ sensor_positions 99 | #sensor_positions = rot_z(rotate_z) @ sensor_positions 100 | 101 | if jitter is not None: 102 | sensor_positions += rng.normal( 103 | 0., jitter, size=sensor_positions.shape 104 | ) 105 | 106 | return np.asarray(sensor_positions + center) 107 | 108 | def generate_random_source_positions( 109 | center=np.zeros((3, 1)), 110 | sources=1, 111 | distance_interval=(1, 2), 112 | dims=2, 113 | minimum_angular_distance=None, 114 | maximum_angular_distance=None, 115 | rng=np.random 116 | ): 117 | """ Generates random positions on a hollow sphere or circle. 118 | Samples are drawn from a uniform distribution on a hollow sphere with 119 | inner and outer radius according to distance_interval. 120 | The idea is to sample from an angular centric Gaussian distribution. 121 | Params: 122 | center 123 | sources 124 | distance_interval 125 | dims 126 | minimum_angular_distance: In randiant or None. 127 | maximum_angular_distance: In randiant or None. 128 | rng: Random number generator, if you need to set the seed. 129 | """ 130 | enforce_angular_constrains = ( 131 | minimum_angular_distance is not None or 132 | maximum_angular_distance is not None 133 | ) 134 | 135 | if not dims == 2 and enforce_angular_constrains: 136 | raise NotImplementedError( 137 | 'Only implemented distance constraints for 2D.' 
138 | ) 139 | 140 | accept = False 141 | while not accept: 142 | x = rng.normal(size=(3, sources)) 143 | if dims == 2: 144 | x[2, :] = 0 145 | 146 | if enforce_angular_constrains: 147 | if not sources == 2: 148 | raise NotImplementedError 149 | angle = np.arctan2(x[1, :], x[0, :]) 150 | difference = np.angle( 151 | np.exp(1j * (angle[None, :], angle[:, None]))) 152 | difference = difference[np.triu_indices_from(difference, k=1)] 153 | distance = np.abs(difference) 154 | if ( 155 | minimum_angular_distance is not None and 156 | minimum_angular_distance > np.min(distance) 157 | ): 158 | continue 159 | if ( 160 | maximum_angular_distance is not None and 161 | maximum_angular_distance < np.max(distance) 162 | ): 163 | continue 164 | accept = True 165 | 166 | x /= np.linalg.norm(x, axis=0) # 单位方向向量 167 | 168 | radius = rng.uniform( 169 | distance_interval[0] ** dims, 170 | distance_interval[1] ** dims, 171 | size=(1, sources) 172 | ) ** (1 / dims) 173 | 174 | x *= radius 175 | 176 | return np.asarray(x + center) 177 | -------------------------------------------------------------------------------- /doa-estimation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import soundfile as sf 4 | import torch 5 | from torch import nn 6 | import yaml 7 | import json 8 | import argparse 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from asteroid import torch_utils 13 | from asteroid.utils import tensors_to_device 14 | from model import make_model_and_optimizer 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--test_dir", type=str, required=True, help="Test directory including the json files") 18 | parser.add_argument("--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution") 19 | parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root") 20 | 21 | def load_best_model(model, exp_dir): 22 | try: 23 | with open(os.path.join(exp_dir, 'best_k_models.json'), "r") as f: 24 | best_k = json.load(f) 25 | best_model_path = min(best_k, key=best_k.get) 26 | except FileNotFoundError: 27 | all_ckpt = os.listdir(os.path.join(exp_dir, 'checkpoints/')) 28 | all_ckpt=[(ckpt,int("".join(filter(str.isdigit,ckpt)))) for ckpt in all_ckpt] 29 | all_ckpt.sort(key=lambda x:x[1]) 30 | best_model_path = os.path.join(exp_dir, 'checkpoints', all_ckpt[-1][0]) 31 | print( 'LOADING from ',best_model_path) 32 | # Load checkpoint 33 | checkpoint = torch.load(best_model_path, map_location='cpu') 34 | # Load state_dict into model. 
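    # best_k_models.json maps checkpoint paths to their validation losses, so min(best_k, key=best_k.get)
    # above picks the lowest-loss checkpoint; the FileNotFoundError fallback simply takes the checkpoint
    # with the largest epoch number in its filename.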
35 | model = torch_utils.load_state_dict_in(checkpoint['state_dict'], model) 36 | model = model.eval() 37 | return model 38 | 39 | 40 | def main(conf): 41 | azimuth_resolution = np.array([2.5,5,10,15,20,25,30,35,40]) 42 | True_est_azimuth1, True_est_azimuth2, True_est_azimuth = np.zeros(9), np.zeros(9), np.zeros(9) 43 | azimuth_mean_loss1, azimuth_mean_loss2, azimuth_mean_loss = 0, 0, 0 44 | 45 | model, _ = make_model_and_optimizer(train_conf) 46 | model = load_best_model(model, conf['exp_dir']) 47 | testset = conf['test_dir'] 48 | 49 | if conf["use_gpu"]: 50 | model.cuda() 51 | model_device = next(model.parameters()).device 52 | 53 | dlist = os.listdir(testset) 54 | pbar = tqdm(range(len(dlist))) 55 | torch.no_grad().__enter__() 56 | for idx in pbar: 57 | test_wav = np.load(testset + dlist[idx]) 58 | mix, label, mix_way = tensors_to_device([torch.from_numpy(test_wav['mix']), torch.from_numpy(test_wav['label']), test_wav['mix_way']], device=model_device) 59 | est_label = model(mix[None]) 60 | 61 | # unbiased 62 | if mix_way == 'single': 63 | 64 | # biased 65 | #if mix_way == 'single' or mix_way == 'dominant': 66 | label = label[:,0]; est_label = est_label[:,:3] 67 | 68 | # accuracy 69 | label_azimuth = (torch.atan2(label[1], label[0]) / np.pi * 180).cpu().numpy() 70 | est_azimuth = (torch.atan2(est_label[0,1], est_label[0,0]) / np.pi *180).cpu().numpy() 71 | 72 | error_azimuth = np.abs(label_azimuth - est_azimuth) 73 | if (error_azimuth > 180): 74 | error_azimuth = 360 - error_azimuth 75 | True_est_azimuth += (error_azimuth <= azimuth_resolution) 76 | azimuth_mean_loss += error_azimuth 77 | 78 | pbar.set_description(" {} {} {} {}".format('%.1f'%(azimuth_mean_loss / (idx+1)), '%.1f'%(error_azimuth), '%.1f'%(label_azimuth), '%.1f'%(est_azimuth))) 79 | 80 | else: 81 | label1 = label[:,0]; est_label1 = est_label[:,:3] 82 | label2 = label[:,1]; est_label2 = est_label[:,3:] 83 | 84 | # accuracy 85 | label_azimuth1 = (torch.atan2(label1[1], label1[0]) / np.pi * 180).cpu().numpy() 86 | label_azimuth2 = (torch.atan2(label2[1], label2[0]) / np.pi * 180).cpu().numpy() 87 | est_azimuth1 = (torch.atan2(est_label1[0,1], est_label1[0,0]) / np.pi *180).cpu().numpy() 88 | est_azimuth2 = (torch.atan2(est_label2[0,1], est_label2[0,0]) / np.pi *180).cpu().numpy() 89 | 90 | error_azimuth11 = np.abs(label_azimuth1 - est_azimuth1) 91 | error_azimuth22 = np.abs(label_azimuth2 - est_azimuth2) 92 | error_azimuth12 = np.abs(label_azimuth1 - est_azimuth2) 93 | error_azimuth21 = np.abs(label_azimuth2 - est_azimuth1) 94 | if (error_azimuth11 > 180): 95 | error_azimuth11 = 360 - error_azimuth11 96 | if (error_azimuth22 > 180): 97 | error_azimuth22 = 360 - error_azimuth22 98 | if (error_azimuth12 > 180): 99 | error_azimuth12 = 360 - error_azimuth12 100 | if (error_azimuth21 > 180): 101 | error_azimuth21 = 360 - error_azimuth21 102 | 103 | if error_azimuth11+error_azimuth22 < error_azimuth12+error_azimuth21: 104 | True_est_azimuth1 += (error_azimuth11 <= azimuth_resolution) 105 | azimuth_mean_loss1 += error_azimuth11 106 | True_est_azimuth2 += (error_azimuth22 <= azimuth_resolution) 107 | azimuth_mean_loss2 += error_azimuth22 108 | error_azimuth = error_azimuth11 + error_azimuth22 109 | else: 110 | True_est_azimuth1 += (error_azimuth12 <= azimuth_resolution) 111 | azimuth_mean_loss1 += error_azimuth12 112 | True_est_azimuth2 += (error_azimuth21 <= azimuth_resolution) 113 | azimuth_mean_loss2 += error_azimuth21 114 | error_azimuth = error_azimuth12 + error_azimuth21 115 | 116 | azimuth_mean_loss = (azimuth_mean_loss1 + 
azimuth_mean_loss2) / 2 117 | error_azimuth /= 2 118 | True_est_azimuth = (True_est_azimuth1 + True_est_azimuth2) / 2 119 | 120 | pbar.set_description(" {} {} {} {} {} {}".format('%.1f'%(azimuth_mean_loss / (idx+1)), '%.1f'%(error_azimuth), \ 121 | '%.1f'%(label_azimuth1), '%.1f'%(label_azimuth2), '%.1f'%(est_azimuth1), '%.1f'%(est_azimuth2))) 122 | 123 | azimuth_mean_loss /= len(dlist) 124 | print('azimuth MAE in degree: ', '%.2f'%(azimuth_mean_loss)) 125 | for i in range (len(azimuth_resolution)): 126 | print('Acc. on azimuth resolution ', azimuth_resolution[i], ' : ', '%.3f'%(True_est_azimuth[i]/len(dlist))) 127 | 128 | 129 | if __name__ == "__main__": 130 | args = parser.parse_args() 131 | arg_dic = dict(vars(args)) 132 | 133 | # Load training config 134 | conf_path = os.path.join(args.exp_dir, "conf.yml") 135 | with open(conf_path) as f: 136 | train_conf = yaml.safe_load(f) 137 | arg_dic["sample_rate"] = train_conf["data"]["sample_rate"] 138 | arg_dic["train_conf"] = train_conf 139 | 140 | main(arg_dic) 141 | -------------------------------------------------------------------------------- /doa-estimation/system.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pytorch_lightning as pl 4 | from torch import nn 5 | from argparse import Namespace 6 | from typing import Callable, Optional 7 | from torch.optim.optimizer import Optimizer 8 | from asteroid.utils import flatten_dict 9 | 10 | from torch.nn.modules.loss import _Loss 11 | from asteroid.utils.deprecation_utils import DeprecationMixin 12 | EPS = 1e-8 13 | 14 | class System(pl.LightningModule): 15 | def __init__( 16 | self, 17 | model, 18 | optimizer, 19 | loss_func, 20 | train_loader, 21 | val_loader=None, 22 | scheduler=None, 23 | config=None, 24 | ): 25 | super().__init__() 26 | self.model = model 27 | self.optimizer = optimizer 28 | self.loss_func = loss_func 29 | self.train_loader = train_loader 30 | self.val_loader = val_loader 31 | self.scheduler = scheduler 32 | config = {} if config is None else config 33 | self.config = config 34 | self.hparams = Namespace(**self.config_to_hparams(config)) 35 | 36 | def forward(self, *args, **kwargs): 37 | return self.model(*args, **kwargs) 38 | 39 | def common_step(self, batch, batch_nb, train=True): 40 | inputs, mix_name, label, mix_way = batch 41 | est_label = self(inputs) 42 | loss, loss_dict = self.loss_func(label, est_label, mix_way) 43 | return loss, loss_dict 44 | 45 | def training_step(self, batch, batch_nb): 46 | loss, loss_dict = self.common_step(batch, batch_nb, train=True) 47 | tensorboard_logs = loss_dict 48 | return {"loss": loss, "log": tensorboard_logs, "progress_bar": tensorboard_logs} 49 | 50 | def validation_step(self, batch, batch_nb): 51 | loss, loss_dict = self.common_step(batch, batch_nb, train=False) 52 | tensorboard_logs = loss_dict 53 | return {"val_loss": loss, "log": tensorboard_logs} 54 | 55 | def validation_epoch_end(self, outputs): 56 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 57 | tensorboard_logs = {"val_loss": avg_loss} 58 | return {"val_loss": avg_loss, "log": tensorboard_logs, "progress_bar": tensorboard_logs} 59 | 60 | def optimizer_step(self, *args, **kwargs) -> None: 61 | if self.scheduler is not None: 62 | if not isinstance(self.scheduler, (list, tuple)): 63 | self.scheduler = [self.scheduler] # support multiple schedulers 64 | for sched in self.scheduler: 65 | if isinstance(sched, dict) and sched["interval"] == "batch": 66 | 
sched["scheduler"].step() # call step on each batch scheduler 67 | super().optimizer_step(*args, **kwargs) 68 | 69 | def configure_optimizers(self): 70 | """ Required by pytorch-lightning. """ 71 | 72 | if self.scheduler is not None: 73 | if not isinstance(self.scheduler, (list, tuple)): 74 | self.scheduler = [self.scheduler] # support multiple schedulers 75 | epoch_schedulers = [] 76 | for sched in self.scheduler: 77 | if not isinstance(sched, dict): 78 | epoch_schedulers.append(sched) 79 | else: 80 | assert sched["interval"] in [ 81 | "batch", 82 | "epoch", 83 | ], "Scheduler interval should be either batch or epoch" 84 | if sched["interval"] == "epoch": 85 | epoch_schedulers.append(sched) 86 | return [self.optimizer], epoch_schedulers 87 | return self.optimizer 88 | 89 | def train_dataloader(self): 90 | return self.train_loader 91 | 92 | def val_dataloader(self): 93 | return self.val_loader 94 | 95 | def on_save_checkpoint(self, checkpoint): 96 | """ Overwrite if you want to save more things in the checkpoint.""" 97 | checkpoint["training_config"] = self.config 98 | return checkpoint 99 | 100 | @staticmethod 101 | def config_to_hparams(dic): 102 | dic = flatten_dict(dic) 103 | for k, v in dic.items(): 104 | if v is None: 105 | dic[k] = str(v) 106 | elif isinstance(v, (list, tuple)): 107 | dic[k] = torch.Tensor(v) 108 | return dic 109 | 110 | 111 | class label_loss(nn.Module): 112 | def __init__(self): 113 | super().__init__() 114 | self.mse_loss = torch.nn.MSELoss() 115 | 116 | def forward(self, label, est_label, mix_way): 117 | label_loss = 0 118 | 119 | # permutation 120 | for batch in range(label.size(0)): 121 | if mix_way[batch] == 'single': 122 | label_loss += self.mse_loss(est_label[:,:3], label[:,:,0]) 123 | else: 124 | label_loss1 = self.mse_loss(est_label[[batch],:3], label[[batch],:,0]) + self.mse_loss(est_label[[batch],3:], label[[batch],:,1]) 125 | label_loss2 = self.mse_loss(est_label[[batch],3:], label[[batch],:,0]) + self.mse_loss(est_label[[batch],:3], label[[batch],:,1]) 126 | if label_loss1 < label_loss2: 127 | label_loss += label_loss1 128 | else: 129 | label_loss += label_loss2 130 | label_loss /= label.size(0) 131 | 132 | 133 | # accuracy 134 | MAE1, MAE2 = 0, 0 135 | label_azimuth1 = torch.atan2(label[:,1,0], label[:,0,0]) / np.pi * 180 136 | label_azimuth2 = torch.atan2(label[:,1,1], label[:,0,1]) / np.pi * 180 137 | est_azimuth1 = torch.atan2(est_label[:,1], est_label[:,0]) / np.pi *180 138 | est_azimuth2 = torch.atan2(est_label[:,4], est_label[:,3]) / np.pi *180 139 | 140 | error_azimuth11 = torch.abs(label_azimuth1 - est_azimuth1) 141 | error_azimuth22 = torch.abs(label_azimuth2 - est_azimuth2) 142 | error_azimuth12 = torch.abs(label_azimuth1 - est_azimuth2) 143 | error_azimuth21 = torch.abs(label_azimuth2 - est_azimuth1) 144 | 145 | for batch in range (label.size(0)): 146 | if (error_azimuth11[batch] > 180): 147 | error_azimuth11[batch] = 360 - error_azimuth11[batch] 148 | if (error_azimuth22[batch] > 180): 149 | error_azimuth22[batch] = 360 - error_azimuth22[batch] 150 | if (error_azimuth12[batch] > 180): 151 | error_azimuth12[batch] = 360 - error_azimuth12[batch] 152 | if (error_azimuth21[batch] > 180): 153 | error_azimuth21[batch] = 360 - error_azimuth21[batch] 154 | 155 | for batch in range (label.size(0)): 156 | if error_azimuth11[batch]+error_azimuth22[batch] < error_azimuth12[batch]+error_azimuth21[batch]: 157 | MAE1 += error_azimuth11[batch] 158 | MAE2 += error_azimuth22[batch] 159 | else: 160 | MAE1 += error_azimuth12[batch] 161 | MAE2 += 
error_azimuth21[batch] 162 | 163 | MAE1 = MAE1 / label.size(0) 164 | MAE2 = MAE2 / label.size(0) 165 | 166 | loss_dict = dict(sig_loss=label_loss.mean(), MAE1=MAE1, MAE2=MAE2) 167 | 168 | return label_loss.mean(), loss_dict 169 | -------------------------------------------------------------------------------- /doa-estimation/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import numpy as np 4 | import os 5 | import soundfile as sf 6 | import math 7 | import random 8 | import shutil 9 | 10 | from base import rms, get_amplitude_scaling_factor, get_label 11 | from sms_wsj.database.create_rirs import config, scenarios, rirs 12 | from sms_wsj.reverb.reverb_utils import convolve 13 | 14 | EPS=1e-8 15 | 16 | class Librispeech_Dataset(data.Dataset): 17 | def __init__( 18 | self, 19 | reverb_matrixs_dir, 20 | rirNO = 5, 21 | trainingNO = 5000, 22 | segment = 6, 23 | channel = [0,1,2,3], 24 | overlap = [0.0, 0.1, 0.2, 0.3, 0.4], 25 | raw_dir = '/path/to/LibriSpeech/filelist-all/', 26 | noise_wav = '/path/to/noise/', 27 | ): 28 | super(Librispeech_Dataset, self).__init__() 29 | self.reverb_matrixs_dir = reverb_matrixs_dir 30 | self.rirNO = rirNO 31 | self.trainingNO = trainingNO 32 | self.segment = segment 33 | self.channel = channel 34 | self.overlap = overlap 35 | self.raw_dir = raw_dir 36 | self.noise_wav = noise_wav 37 | 38 | def __len__(self): 39 | return self.trainingNO 40 | 41 | def add_reverb(self,raw_dir1,raw_dir2,h_use): 42 | with open(raw_dir1,'r') as fin1: 43 | with open(raw_dir2,'r') as fin2: 44 | wav1 = fin1.readlines() 45 | wav2 = fin2.readlines() 46 | choose_wav = True 47 | while(choose_wav): 48 | i = np.random.randint(0,len(wav1)) 49 | j = np.random.randint(0,len(wav2)) 50 | w1,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav1[i].rstrip("\n")), dtype="float32") 51 | w2,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav2[j].rstrip("\n")), dtype="float32") 52 | seg_len = int(fs * self.segment) 53 | if (w1.shape[0] > seg_len + 1 and w2.shape[0] > seg_len + 1): 54 | choose_wav = False 55 | 56 | w1_con = convolve(w1, h_use[0,:,:]).T 57 | w2_con = convolve(w2, h_use[1,:,:]).T 58 | 59 | SIR = random.uniform(-5,5) 60 | scalar=get_amplitude_scaling_factor(w1_con, w2_con, snr = SIR) 61 | w2_con = w2_con / scalar 62 | 63 | mix_way = np.random.choice(['single','dominant'], size=1, replace=False) 64 | mix_name = mix_way[0] + '-' + os.path.basename(raw_dir1)[:-4] + '-' + os.path.basename(raw_dir2)[:-4] + '.wav' 65 | if (mix_way == 'single'): 66 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len) 67 | stop1 = int(rand_start1 + seg_len) 68 | 69 | mix_reverb = w1_con[rand_start1:stop1,:] 70 | s1_reverb = w1_con[rand_start1:stop1,:] 71 | s2_reverb = np.zeros_like(w1_con[rand_start1:stop1,:]) 72 | 73 | if (mix_way == 'partial'): 74 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len*0.75) 75 | rand_start2 = np.random.randint(0, w2.shape[0] - seg_len*0.75) 76 | stop1 = int(rand_start1 + seg_len*0.75) 77 | stop2 = int(rand_start2 + seg_len*0.75) 78 | 79 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + int(seg_len*0.25),:], \ 80 | w1_con[rand_start1 + int(seg_len*0.25):stop1,:] + w2_con[rand_start2:rand_start2 + int(seg_len*0.5),:], \ 81 | w2_con[rand_start2 + int(seg_len*0.5):stop2,:]], axis=0) 82 | s1_reverb = np.concatenate([w1_con[rand_start1:stop1,:],np.zeros_like(w2_con[rand_start2 + int(seg_len*0.5):stop2,:])], axis=0) 83 | s2_reverb = 
np.concatenate([np.zeros_like(w1_con[rand_start1:rand_start1 + int(seg_len*0.25),:]),w2_con[rand_start2:stop2,:]], axis=0) 84 | 85 | if (mix_way == 'dominant'): 86 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len) 87 | rand_start2 = np.random.randint(0, w2.shape[0] - seg_len*0.5) 88 | stop1 = int(rand_start1 + seg_len) 89 | stop2 = int(rand_start2 + seg_len*0.5) 90 | 91 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + int(seg_len*0.5),:], \ 92 | w1_con[rand_start1 + int(seg_len*0.5):stop1,:] + w2_con[rand_start2:stop2,:]], axis=0) 93 | s1_reverb = w1_con[rand_start1:stop1,:] 94 | s2_reverb = np.concatenate([np.zeros_like(w1_con[rand_start1:rand_start1 + int(seg_len*0.5),:]),w2_con[rand_start2:stop2,:]], axis=0) 95 | 96 | if (mix_way == 'sequential'): 97 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len*0.5) 98 | rand_start2 = np.random.randint(0, w2.shape[0] - seg_len*0.5) 99 | stop1 = int(rand_start1 + seg_len*0.5) 100 | stop2 = int(rand_start2 + seg_len*0.5) 101 | 102 | mix_reverb = np.concatenate([w1_con[rand_start1:stop1,:],w2_con[rand_start2:stop2,:]], axis=0) 103 | s1_reverb = np.concatenate([w1_con[rand_start1:stop1,:],np.zeros_like(w2_con[rand_start2:stop2,:])], axis=0) 104 | s2_reverb = np.concatenate([np.zeros_like(w1_con[rand_start1:stop1,:]),w2_con[rand_start2:stop2,:]], axis=0) 105 | 106 | return mix_reverb, s1_reverb, s2_reverb, mix_name, mix_way 107 | 108 | def add_noise(self,mix_reverb,w_n): 109 | # dynamic SNR 110 | SNR = random.uniform(5,25) 111 | x = [] 112 | for item in mix_reverb: 113 | x.append(item[0]) 114 | rand_start = np.random.randint(0,len(w_n)-len(x)) 115 | stop = rand_start + len(x) 116 | scalar = get_amplitude_scaling_factor(x, w_n, snr = SNR) 117 | 118 | mix_noise = mix_reverb + (w_n[rand_start:stop]/scalar)[None].transpose() 119 | 120 | return mix_noise 121 | 122 | def __getitem__(self,idx): 123 | raw_list = os.listdir(self.raw_dir) 124 | SpeakerNo = len(raw_list) 125 | 126 | speaker1 = np.random.randint(0,SpeakerNo) 127 | speaker2 = np.random.randint(0,SpeakerNo) 128 | while (speaker1 == speaker2): 129 | speaker2 = np.random.randint(0,SpeakerNo) 130 | raw_dir1 = self.raw_dir+raw_list[speaker1] 131 | raw_dir2 = self.raw_dir+raw_list[speaker2] 132 | 133 | choose_rir = np.random.randint(0,self.rirNO) 134 | rand_rir = np.load(self.reverb_matrixs_dir + str(choose_rir).zfill(4) + '.npz') 135 | h_use, _source_positions, _sensor_positions, = rand_rir['h'], rand_rir['source_positions'], rand_rir['sensor_positions'] 136 | 137 | # step1:add reverb to utterance 138 | mix_reverb, s1_reverb, s2_reverb, mix_name, mix_way = self.add_reverb(raw_dir1,raw_dir2,h_use[:,self.channel,:]) 139 | 140 | # step2:add noise 141 | w_n, _ = sf.read(self.noise_wav, dtype="float32") 142 | mix_noise = self.add_noise(mix_reverb,w_n) 143 | mix_noise = mix_noise.transpose() 144 | 145 | mixture = torch.from_numpy(np.array(mix_noise).astype(np.float32)).permute(1,0) 146 | 147 | _source_positions = _source_positions[...,None] # (3,src,channel) 148 | _sensor_positions = _sensor_positions[:,self.channel][:,None] # (3,src,channel) 149 | # 3d-to-2d 2021.09.21 150 | _source_positions[2] = _sensor_positions[2,0,0] 151 | 152 | label = get_label(_source_positions, _sensor_positions, usage='simu') # (3,src) 153 | label = torch.from_numpy(label.astype(np.float32)) 154 | 155 | if (mix_way[0]=='single'): 156 | label = torch.cat((label[:,[0]],label[:,[0]]),dim=1) 157 | 158 | return mixture, mix_name, label, mix_way[0] 159 | 160 | 
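Librispeech_Dataset returns, per item, a mixture tensor of shape (time, channels), the mixture name, a (3, 2) Cartesian label (duplicated across both source slots for 'single' mixtures), and the overlap type. Below is a minimal sketch, assuming placeholder paths and an RIR directory produced by generate_rir.py, of how such a dataset could be wrapped in a DataLoader and how a label maps back to azimuths in degrees, following the same atan2 convention used by label_loss in system.py; the helper name label_to_azimuth_deg is illustrative only.

```python
import numpy as np
import torch
# from torch.utils.data import DataLoader
# from dataset import Librispeech_Dataset   # paths inside the class are placeholders

def label_to_azimuth_deg(label):
    # label: (3, n_src) Cartesian directions in the array-centred frame;
    # azimuth = atan2(y, x), matching the conversion in system.label_loss
    return torch.atan2(label[1], label[0]) / np.pi * 180

# Self-contained check with a synthetic label: two sources at 45 and 135 degrees
label = torch.tensor([[np.cos(np.pi / 4), np.cos(3 * np.pi / 4)],
                      [np.sin(np.pi / 4), np.sin(3 * np.pi / 4)],
                      [0.0, 0.0]], dtype=torch.float32)
print(label_to_azimuth_deg(label))  # ≈ tensor([45., 135.])

# Hypothetical DataLoader usage once real data paths are filled in:
# train_set = Librispeech_Dataset(reverb_matrixs_dir='/path/to/reverb-set/', rirNO=5, trainingNO=5000)
# loader = DataLoader(train_set, batch_size=8, num_workers=4)
# mixture, mix_name, label, mix_way = next(iter(loader))  # mixture: (batch, time, channels)
```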
-------------------------------------------------------------------------------- /run_testset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import random 5 | import shutil 6 | import numpy as np 7 | import torch 8 | from torch.utils import data 9 | from tqdm import tqdm 10 | import soundfile as sf 11 | 12 | from sms_wsj.database.create_rirs import config, scenarios, rirs 13 | from sms_wsj.reverb.reverb_utils import convolve 14 | 15 | 16 | def rms(y): 17 | return np.sqrt(np.mean(np.abs(y) ** 2, axis=0, keepdims=False)) 18 | 19 | def get_amplitude_scaling_factor(s, n, snr, method='rms'): 20 | original_sn_rms_ratio = rms(s) / rms(n) 21 | target_sn_rms_ratio = 10. ** (float(snr) / 20.) # snr = 20 * lg(rms(s) / rms(n)) 22 | signal_scaling_factor = target_sn_rms_ratio / original_sn_rms_ratio 23 | return signal_scaling_factor 24 | 25 | def get_label(_source_positions, _sensor_positions, usage): 26 | """ 27 | Extract Label of Raw Wav. 28 | Arguments: 29 | _source_positions: source positions, 3 x 2 x channel 30 | _sensor_positions: sensor positions, 3 x 2 x channel 31 | loss_type: categorical or cartesian 32 | Return: 33 | y: (3,2) 34 | """ 35 | ''' 36 | _source_positions = np.array([[[1.0], [1.0]], 37 | [[0.0], [1.0]], 38 | [[0.0], [1.0]]]) 39 | _sensor_positions = np.array([[[0.0, 1.0, -1.0, 0.0]], 40 | [[0.0, 1.0, 0.0, 0.0]], 41 | [[0.0, 0.0, 0.0, 0.0]]]) 42 | ''' 43 | #print('src:',_source_positions) 44 | #print('sen:',_sensor_positions) 45 | if (usage == 'simu'): 46 | # step1:translation 47 | x0 = _source_positions[0,:,0]; y0 = _source_positions[1,:,0]; z0 = _source_positions[2,:,0] 48 | x1 = -(_sensor_positions[0,:,0] - x0); y1 = -(_sensor_positions[1,:,0] - y0); z1 = -(_sensor_positions[2,:,0] - z0) 49 | 50 | ref_x0 = _sensor_positions[0,:,1]; ref_y0 = _sensor_positions[1,:,1]; ref_z0 = _sensor_positions[2,:,1] 51 | ref_x1 = -(_sensor_positions[0,:,0] - ref_x0); ref_y1 = -(_sensor_positions[1,:,0] - ref_y0); ref_z1 = -(_sensor_positions[2,:,0] - ref_z0) 52 | 53 | # step2:rotation-azimuth 54 | theta = np.arctan2(ref_y1,ref_x1) 55 | x2 = x1 * np.cos(theta) + y1 * np.sin(theta) 56 | y2 = y1 * np.cos(theta) - x1 * np.sin(theta) 57 | z2 = z1 58 | 59 | # step3:rotation-elevation 60 | phi = np.arctan2(ref_z1,np.sqrt(ref_x1 ** 2 + ref_y1 ** 2)) 61 | x3 = x2 * np.cos(phi) + z2 * np.sin(phi) 62 | y3 = y2 63 | z3 = z2 * np.cos(phi) - x2 * np.sin(phi) 64 | 65 | elif (usage == 'dummy'): 66 | x0 = (_sensor_positions[0,:,0] + _sensor_positions[0,:,2]) / 2; y0 = _sensor_positions[1,:,0] 67 | y3 = _source_positions[1,:,0] - y0 68 | x3 = _source_positions[0,:,0] - x0 69 | z3 = _source_positions[2,:,0] - z0 70 | 71 | y = np.vstack((x3,y3,z3)) # 3 src 72 | #print('dis:',dis) 73 | 74 | return y 75 | 76 | 77 | class Dataset(data.Dataset): 78 | def __init__( 79 | self, 80 | reverb_matrixs_dir, 81 | num_rir = 100, 82 | num_utt = 100, 83 | segment = 8, 84 | channel = [0,1,2,3], 85 | overlap = [0.0, 0.1, 0.2, 0.3, 0.4,], 86 | overlap_type = 'single', 87 | snr_low = 5, 88 | snr_high = 25, 89 | raw_dir = '/path/to/LibriSpeech/filelist-all/', 90 | noise_dir = '/path/to/noise/', 91 | ): 92 | super(Dataset, self).__init__() 93 | self.reverb_matrixs_dir = reverb_matrixs_dir 94 | self.num_rir = num_rir 95 | self.num_utt = num_utt 96 | self.segment = segment 97 | self.channel = channel 98 | self.overlap = overlap 99 | self.overlap_type = overlap_type 100 | self.snr_low = snr_low 101 | self.snr_high = snr_high 102 | self.raw_dir = raw_dir 
103 | self.noise_wav = noise_dir 104 | 105 | def __len__(self): 106 | return self.num_utt 107 | 108 | def add_reverb(self, raw_dir1, raw_dir2, h_use): 109 | with open(raw_dir1,'r') as fin1: 110 | with open(raw_dir2,'r') as fin2: 111 | wav1 = fin1.readlines() 112 | wav2 = fin2.readlines() 113 | choose_wav = True 114 | while(choose_wav): 115 | i = np.random.randint(0,len(wav1)) 116 | j = np.random.randint(0,len(wav2)) 117 | w1,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav1[i].rstrip("\n")), dtype="float32") 118 | w2,fs = sf.read(os.path.join('/path/to/LibriSpeech', wav2[j].rstrip("\n")), dtype="float32") 119 | seg_len = int(fs * self.segment) 120 | if (w1.shape[0] > seg_len + 1 and w2.shape[0] > seg_len + 1): 121 | choose_wav = False 122 | 123 | w1_con = convolve(w1, h_use[0,:,:]).T 124 | w2_con = convolve(w2, h_use[1,:,:]).T 125 | 126 | SIR = random.uniform(-5,5) 127 | scalar=get_amplitude_scaling_factor(w1_con, w2_con, snr = SIR) 128 | w2_con = w2_con / scalar 129 | 130 | mix_way = np.random.choice([self.overlap_type], size=1, replace=False) 131 | 132 | mix_name = mix_way[0] + '-' + os.path.basename(raw_dir1)[:-4] + '-' + os.path.basename(raw_dir2)[:-4] + '.wav' 133 | if (mix_way == 'single'): 134 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len) 135 | stop1 = int(rand_start1 + seg_len) 136 | 137 | mix_reverb = w1_con[rand_start1:stop1,:] 138 | s1_reverb = w1_con[rand_start1:stop1,:] 139 | s2_reverb = np.zeros_like(w1_con[rand_start1:stop1,:]) 140 | 141 | if (mix_way == 'partial'): 142 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len*0.75) 143 | rand_start2 = np.random.randint(0, w2.shape[0] - seg_len*0.75) 144 | stop1 = int(rand_start1 + seg_len*0.75) 145 | stop2 = int(rand_start2 + seg_len*0.75) 146 | 147 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + int(seg_len*0.25),:], \ 148 | w1_con[rand_start1 + int(seg_len*0.25):stop1,:] + w2_con[rand_start2:rand_start2 + int(seg_len*0.5),:], \ 149 | w2_con[rand_start2 + int(seg_len*0.5):stop2,:]], axis=0) 150 | s1_reverb = np.concatenate([w1_con[rand_start1:stop1,:],np.zeros_like(w2_con[rand_start2 + int(seg_len*0.5):stop2,:])], axis=0) 151 | s2_reverb = np.concatenate([np.zeros_like(w1_con[rand_start1:rand_start1 + int(seg_len*0.25),:]),w2_con[rand_start2:stop2,:]], axis=0) 152 | 153 | if (mix_way == 'dominant'): 154 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len) 155 | rand_start2 = np.random.randint(0, w2.shape[0] - seg_len*0.5) 156 | stop1 = int(rand_start1 + seg_len) 157 | stop2 = int(rand_start2 + seg_len*0.5) 158 | 159 | mix_reverb = np.concatenate([w1_con[rand_start1:rand_start1 + int(seg_len*0.5),:], \ 160 | w1_con[rand_start1 + int(seg_len*0.5):stop1,:] + w2_con[rand_start2:stop2,:]], axis=0) 161 | s1_reverb = w1_con[rand_start1:stop1,:] 162 | s2_reverb = np.concatenate([np.zeros_like(w1_con[rand_start1:rand_start1 + int(seg_len*0.5),:]),w2_con[rand_start2:stop2,:]], axis=0) 163 | 164 | if (mix_way == 'sequential'): 165 | rand_start1 = np.random.randint(0, w1.shape[0] - seg_len*0.5) 166 | rand_start2 = np.random.randint(0, w2.shape[0] - seg_len*0.5) 167 | stop1 = int(rand_start1 + seg_len*0.5) 168 | stop2 = int(rand_start2 + seg_len*0.5) 169 | 170 | mix_reverb = np.concatenate([w1_con[rand_start1:stop1,:],w2_con[rand_start2:stop2,:]], axis=0) 171 | s1_reverb = np.concatenate([w1_con[rand_start1:stop1,:],np.zeros_like(w2_con[rand_start2:stop2,:])], axis=0) 172 | s2_reverb = np.concatenate([np.zeros_like(w1_con[rand_start1:stop1,:]),w2_con[rand_start2:stop2,:]], axis=0) 173 | 174 |
return mix_reverb, s1_reverb, s2_reverb, mix_name, mix_way 175 | 176 | def add_noise(self, mix_reverb, w_n): 177 | SNR = random.uniform(self.snr_low, self.snr_high) 178 | x = [] 179 | for item in mix_reverb: 180 | x.append(item[0]) 181 | rand_start = np.random.randint(0,len(w_n)-len(x)) 182 | stop = rand_start + len(x) 183 | scalar = get_amplitude_scaling_factor(x, w_n, snr = SNR) 184 | 185 | mix_noise = mix_reverb + (w_n[rand_start:stop]/scalar)[None].transpose() 186 | 187 | return mix_noise 188 | 189 | def __getitem__(self,idx): 190 | raw_list = os.listdir(self.raw_dir) 191 | num_spk = len(raw_list) 192 | 193 | speaker1 = np.random.randint(0, num_spk) 194 | speaker2 = np.random.randint(0, num_spk) 195 | while (speaker1 == speaker2): 196 | speaker2 = np.random.randint(0,num_spk) 197 | raw_dir1 = self.raw_dir+raw_list[speaker1] 198 | raw_dir2 = self.raw_dir+raw_list[speaker2] 199 | 200 | choose_rir = np.random.randint(0, self.num_rir) 201 | rand_rir = np.load(self.reverb_matrixs_dir + str(choose_rir).zfill(4) + '.npz') 202 | h_use, _source_positions, _sensor_positions, room_dimensions, sound_decay_time, = \ 203 | rand_rir['h'], rand_rir['source_positions'], rand_rir['sensor_positions'], rand_rir['room_dimensions'], rand_rir['sound_decay_time'] 204 | 205 | mix_reverb, s1_reverb, s2_reverb, mix_name, mix_way = self.add_reverb(raw_dir1,raw_dir2,h_use[:,self.channel,:]) 206 | 207 | w_n, _ = sf.read(self.noise_wav, dtype="float32") 208 | mix_noise = self.add_noise(mix_reverb, w_n) 209 | mix_noise = mix_noise.transpose() 210 | 211 | mixture = torch.from_numpy(np.array(mix_noise).astype(np.float32)).permute(1,0) 212 | 213 | _source_positions = _source_positions[...,None] # (3,src,channel) 214 | _sensor_positions = _sensor_positions[:,self.channel][:,None] # (3,src,channel) 215 | # 3d-to-2d 2021.09.21 216 | _source_positions[2] = _sensor_positions[2,0,0] 217 | 218 | label = get_label(_source_positions, _sensor_positions, usage='simu') # (3,src) 219 | label = torch.from_numpy(label.astype(np.float32)) 220 | 221 | if (mix_way[0]=='single'): 222 | label = torch.cat((label[:,[0]],label[:,[0]]),dim=1) 223 | 224 | return mixture, s1_reverb, s2_reverb, mix_name, label, mix_way[0], _source_positions, _sensor_positions, room_dimensions, sound_decay_time 225 | 226 | 227 | if __name__ == "__main__": 228 | num_rir, num_utt = 1, 1 229 | rir_dir = 'path/to/testset' 230 | wav_dir = os.path.join(rir_dir, 'mix_noise') 231 | reverb_matrixs_dir = os.path.join(rir_dir, 'reverb_matrixs') 232 | s1_dir = os.path.join(rir_dir, 's1') 233 | s2_dir = os.path.join(rir_dir, 's2') 234 | 235 | for item in [wav_dir, reverb_matrixs_dir, s1_dir, s2_dir]: 236 | try: 237 | os.makedirs(item) 238 | except OSError: 239 | pass 240 | 241 | overlap_pattern = ['dominant'] 242 | d = Dataset(reverb_matrixs_dir = '/path/to/reverb-set/', 243 | num_rir = num_rir, 244 | num_utt = num_utt, 245 | overlap_type = overlap_pattern[0], 246 | snr_low = 5, 247 | snr_high = 25,) 248 | 249 | print('saving reverb matrixs into', reverb_matrixs_dir) 250 | pbar = tqdm(range(num_utt)) 251 | for i in pbar: 252 | mixture, s1, s2, mix_name, label, mix_way, source_positions, sensor_positions, room_dimensions, sound_decay_time = d[i] 253 | sf.write(os.path.join(wav_dir, mix_name), mixture.numpy(), 16000) 254 | sf.write(os.path.join(s1_dir, mix_name), s1[:,0], 16000) 255 | sf.write(os.path.join(s2_dir, mix_name), s2[:,0], 16000) 256 | np.savez(os.path.join(reverb_matrixs_dir, mix_name[:-4] + '.npz'), \ 257 | mix=mixture, n=mix_name, label=label, mix_way=mix_way, 
source_positions=source_positions, sensor_positions=sensor_positions, \ 258 | room_dimensions=room_dimensions, sound_decay_time=sound_decay_time) 259 | 260 | print('Done.') 261 | --------------------------------------------------------------------------------
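For reference, each .npz written by run_testset.py bundles the mixture, label, overlap type and room metadata under the keys used in the np.savez call above. Below is a minimal sketch, with a placeholder archive name, of reading one archive back and recovering the ground-truth azimuths in degrees for evaluation:

```python
import numpy as np

# Placeholder name; actual files follow the '<mix_way>-<utt1>-<utt2>.npz' pattern
data = np.load('path/to/testset/reverb_matrixs/dominant-utt1-utt2.npz')

label = data['label']                                     # (3, n_src) Cartesian directions
azimuth_deg = np.arctan2(label[1], label[0]) / np.pi * 180
print('mix_way:', data['mix_way'], 'azimuths (deg):', azimuth_deg)
print('room_dimensions:', data['room_dimensions'], 'T60:', data['sound_decay_time'])
```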