├── first_year
│   ├── loss.py
│   ├── README.md
│   ├── data_stat.py
│   ├── config
│   │   └── train1.yaml
│   ├── data_loader.py
│   ├── tester.py
│   └── main.py
├── second_year
│   └── src
│       ├── README.md
│       ├── cp_src.sh
│       ├── dataloader
│       │   ├── __init__.py
│       │   ├── random_gpu_rir_generator
│       │   │   └── __init__.py
│       │   ├── dataloader_test.yaml
│       │   ├── room_config.yaml
│       │   ├── mic_array.py
│       │   └── dataloader_train.yaml
│       ├── __init__.py
│       ├── util
│       │   ├── __init__.py
│       │   └── util.py
│       ├── loss
│       │   └── __init__.py
│       ├── hyparam
│       │   ├── __init__.py
│       │   ├── train.yaml
│       │   ├── logger.yaml
│       │   ├── test.yaml
│       │   └── learner.yaml
│       ├── models
│       │   ├── __init__.py
│       │   └── convtasnet_SSL_FiLM
│       │       ├── Causal_CRN_SPL_target
│       │       │   ├── __init__.py
│       │       │   ├── CRN_main.py
│       │       │   ├── CRN_SPL_target.py
│       │       │   ├── FFT.py
│       │       │   └── CRN.py
│       │       ├── convtasnet_module
│       │       │   ├── __init__.py
│       │       │   ├── conv_tasnet.py
│       │       │   └── utility
│       │       │       └── models.py
│       │       ├── main.py
│       │       ├── model.yaml
│       │       └── convtasnet.py
│       ├── metadata
│       │   └── Readme.md
│       ├── run_trainer.sh
│       ├── run_tester.sh
│       ├── inference.py
│       └── train.py
├── third_year
│   ├── src
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   └── EABNET
│   │   │       ├── main.py
│   │   │       ├── model.yaml
│   │   │       └── FFT.py
│   │   ├── dataloader
│   │   │   ├── __init__.py
│   │   │   ├── random_gpu_rir_generator
│   │   │   │   └── __init__.py
│   │   │   ├── dataloader_test.yaml
│   │   │   ├── room_config.yaml
│   │   │   ├── mic_array.py
│   │   │   ├── dataloader_test_maker.yaml
│   │   │   └── dataloader_train.yaml
│   │   ├── __init__.py
│   │   ├── util
│   │   │   ├── __init__.py
│   │   │   └── util.py
│   │   ├── hyparam
│   │   │   ├── __init__.py
│   │   │   ├── train.yaml
│   │   │   ├── logger.yaml
│   │   │   ├── test.yaml
│   │   │   └── learner.yaml
│   │   ├── loss
│   │   │   ├── __init__.py
│   │   │   ├── SI_SDR_sync.py
│   │   │   └── SDR_loss.py
│   │   ├── preprocess
│   │   │   ├── README.md
│   │   │   ├── ms-snsd_preprocess.py
│   │   │   └── librispeech_preprocess.py
│   │   ├── metadata
│   │   │   └── Readme.md
│   │   ├── run_tester.sh
│   │   ├── val_eval_set_prepare_sitec
│   │   │   └── run_prepare.py
│   │   ├── run_make_test_set.sh
│   │   ├── run_trainer.sh
│   │   ├── inference.py
│   │   └── make_test_set.py
│   └── README.md
├── fourth_year
│   └── src
│       ├── models
│       │   ├── __init__.py
│       │   └── FSPEN
│       │       ├── main.py
│       │       ├── fspen_total.py
│       │       ├── model.yaml
│       │       ├── modules
│       │       │   ├── en_decoder.py
│       │       │   └── sequence_modules.py
│       │       ├── FFT.py
│       │       └── fspen.py
│       ├── __init__.py
│       ├── util
│       │   ├── __init__.py
│       │   └── util.py
│       ├── hyparam
│       │   ├── __init__.py
│       │   ├── train.yaml
│       │   ├── logger.yaml
│       │   ├── test.yaml
│       │   └── learner.yaml
│       ├── loss
│       │   ├── __init__.py
│       │   └── SI_SDR_sync.py
│       ├── run_tester.sh
│       └── run_trainer.sh
├── README.md
└── .gitignore
/first_year/loss.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/second_year/src/README.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/second_year/src/cp_src.sh:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/third_year/src/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/fourth_year/src/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/second_year/src/dataloader/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/third_year/src/dataloader/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/fourth_year/src/__init__.py:
--------------------------------------------------------------------------------
import models
--------------------------------------------------------------------------------
/second_year/src/__init__.py:
--------------------------------------------------------------------------------
import models
--------------------------------------------------------------------------------
/third_year/src/__init__.py:
--------------------------------------------------------------------------------
import models
--------------------------------------------------------------------------------
/first_year/README.md:
--------------------------------------------------------------------------------
# Beamformer
Beamformer
--------------------------------------------------------------------------------
/fourth_year/src/util/__init__.py:
--------------------------------------------------------------------------------
from . import util
--------------------------------------------------------------------------------
/second_year/src/util/__init__.py:
--------------------------------------------------------------------------------
from . import util
--------------------------------------------------------------------------------
/third_year/src/util/__init__.py:
--------------------------------------------------------------------------------
from . import util
--------------------------------------------------------------------------------
/second_year/src/loss/__init__.py:
--------------------------------------------------------------------------------
from . import bce_loss  # note: bce_loss.py is not included in this tree
--------------------------------------------------------------------------------
/fourth_year/src/hyparam/__init__.py:
--------------------------------------------------------------------------------
from . import hyparam_set
--------------------------------------------------------------------------------
/second_year/src/dataloader/random_gpu_rir_generator/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/second_year/src/hyparam/__init__.py:
--------------------------------------------------------------------------------
from . import hyparam_set
--------------------------------------------------------------------------------
/third_year/src/dataloader/random_gpu_rir_generator/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/third_year/src/hyparam/__init__.py:
--------------------------------------------------------------------------------
from . import hyparam_set
--------------------------------------------------------------------------------
/second_year/src/models/__init__.py:
--------------------------------------------------------------------------------
from . import convtasnet_SSL_FiLM
--------------------------------------------------------------------------------
/third_year/README.md:
--------------------------------------------------------------------------------
EaBNet: https://github.com/Andong-Li-speech/EaBNet
--------------------------------------------------------------------------------
/third_year/src/loss/__init__.py:
--------------------------------------------------------------------------------
from . import bce_loss  # note: bce_loss.py is not included in this tree
from . import SI_SDR_sync
--------------------------------------------------------------------------------
/fourth_year/src/loss/__init__.py:
--------------------------------------------------------------------------------
from . import bce_loss  # note: bce_loss.py is not included in this tree
from . import SI_SDR_sync
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/Causal_CRN_SPL_target/__init__.py:
--------------------------------------------------------------------------------
from . import CRN, FFT
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/convtasnet_module/__init__.py:
--------------------------------------------------------------------------------
from . import conv_tasnet
--------------------------------------------------------------------------------
/third_year/src/preprocess/README.md:
--------------------------------------------------------------------------------
Prepares the dataset metadata for LibriSpeech.

Expected LibriSpeech root: /root/harddisk1/Dataset/librispeech/LibriSpeech/train-clean-100
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
IITP project

Beamforming model updates.
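
Each year's directory is a self-contained setup for that year's model (see the tree above), with its own dataloader, hyparam configs, losses, and run scripts.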

1st & 2nd year: v1.2.1

3rd year: v2.1.1

--------------------------------------------------------------------------------
/third_year/src/models/EABNET/main.py:
--------------------------------------------------------------------------------

from .total_EABNET import Total_model  # note: total_EABNET.py is not included in this tree

def get_model(args):

    return Total_model(args)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.wav
*.txt
*.npz
*.npy
*.tar
*.pt
*.pth
*.csv
*.pyc
/**/wandb/
/**/results/
*.png
*.pkl
*.gif
--------------------------------------------------------------------------------
/fourth_year/src/hyparam/train.yaml:
--------------------------------------------------------------------------------
torch_start_method: 'forkserver'
randomseed: 20


last_epoch: 5
resume_epoch: 0

GPGPU:
  device_ids: [0]
--------------------------------------------------------------------------------
/second_year/src/hyparam/train.yaml:
--------------------------------------------------------------------------------
torch_start_method: 'forkserver'
randomseed: 0


last_epoch: 100
resume_epoch: 0

GPGPU:
  device_ids: [0]
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/main.py:
--------------------------------------------------------------------------------
# from .CRN import main_model
from .convtasnet import main_model

def get_model(args):

    return main_model(args)
--------------------------------------------------------------------------------
/third_year/src/hyparam/train.yaml:
--------------------------------------------------------------------------------
torch_start_method: 'forkserver'
randomseed: 20


last_epoch: 100
resume_epoch: 0

GPGPU:
  device_ids: [0]
--------------------------------------------------------------------------------
/fourth_year/src/models/FSPEN/main.py:
--------------------------------------------------------------------------------
# from .CRN import main_model
# from .EABNET import EaBNet
# from .total_EABNET import Total_model
from .fspen_total import Total_model

def get_model(args):

    return Total_model(args)
--------------------------------------------------------------------------------
/second_year/src/metadata/Readme.md:
--------------------------------------------------------------------------------
original_audio_metadata.csv: sitec total metadata
test_audio_list.csv: sitec speech for test
trcv_audio_list.csv: sitec speech for train/validation
tr_audio_list.csv: sitec speech for train
cv_audio_list.csv: sitec speech for validation
--------------------------------------------------------------------------------
/third_year/src/metadata/Readme.md:
--------------------------------------------------------------------------------
original_audio_metadata.csv: sitec total metadata
test_audio_list.csv: sitec speech for test
trcv_audio_list.csv: sitec speech for train/validation
tr_audio_list.csv: sitec speech for train
cv_audio_list.csv: sitec speech for validation
--------------------------------------------------------------------------------
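The two metadata Readmes above describe how the sitec corpus is partitioned. A minimal sketch of loading those lists and checking that the splits are disjoint — the `file_path` column name is an assumption borrowed from the preprocess scripts further down, and the CSVs themselves are not part of this dump:

```python
# Sketch only: load the split lists described in metadata/Readme.md and verify
# that train / validation / test do not overlap. The 'file_path' column name
# is an assumption (the CSVs are not included here).
import pandas as pd

meta = './metadata/'
splits = {name: set(pd.read_csv(meta + f'{name}_audio_list.csv')['file_path'])
          for name in ('tr', 'cv', 'test')}

assert splits['tr'].isdisjoint(splits['cv'])
assert splits['tr'].isdisjoint(splits['test'])
assert splits['cv'].isdisjoint(splits['test'])
print({name: len(files) for name, files in splits.items()})
```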
/third_year/src/run_tester.sh:
--------------------------------------------------------------------------------
device=0
CUDA_VISIBLE_DEVICES=$device python inference.py \
    "model ./models/EABNET/model.yaml" \
    "dataloader ./dataloader/dataloader_test.yaml" \
    "hyparam ./hyparam/test.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"
--------------------------------------------------------------------------------
/third_year/src/val_eval_set_prepare_sitec/run_prepare.py:
--------------------------------------------------------------------------------
from prepare_main import val_csv_prepare, eval_csv_prepare
import sys


if __name__=='__main__':

    t=val_csv_prepare(sys.argv[1], 'val')

    t=eval_csv_prepare(sys.argv[1], 'test')
--------------------------------------------------------------------------------
/second_year/src/run_trainer.sh:
--------------------------------------------------------------------------------
mkdir -p ../results/
CUDA_VISIBLE_DEVICES=0 python train.py \
    "model ./models/convtasnet_SSL_FiLM/model.yaml" \
    "dataloader ./dataloader/dataloader_train.yaml" \
    "hyparam ./hyparam/train.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"
--------------------------------------------------------------------------------
/second_year/src/run_tester.sh:
--------------------------------------------------------------------------------
rm -rf ./wandb/

CUDA_VISIBLE_DEVICES=0 python inference.py \
    "model ./models/convtasnet_SSL_FiLM/model.yaml" \
    "dataloader ./dataloader/dataloader_test.yaml" \
    "hyparam ./hyparam/test.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"
--------------------------------------------------------------------------------
/fourth_year/src/run_tester.sh:
--------------------------------------------------------------------------------
device=0

CUDA_VISIBLE_DEVICES=$device python inference.py \
    "model ./models/FSPEN/model.yaml" \
    "dataloader ./dataloader/dataloader_test.yaml" \
    "hyparam ./hyparam/test.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"
--------------------------------------------------------------------------------
/first_year/data_stat.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np

data_dir='/home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/'

train_dir=data_dir+'test_audio.csv'

df=pd.read_csv(train_dir)
df=np.array(df['length'].tolist())
# total audio duration in hours (lengths are sample counts at 16 kHz)
print(df.sum()/16000/3600)
--------------------------------------------------------------------------------
/fourth_year/src/hyparam/logger.yaml:
--------------------------------------------------------------------------------
wandb:
  wandb_ok: False
  init:
    project:
    entity:
    name:

optimize_method: 'min'

save_csv: ../results/logger/log.csv
png_dir: ../results/logger/loss.png

model_save_dir: ../results/model_checkpoint/
out_txt:
  dir: ../results/logger/exp.txt
--------------------------------------------------------------------------------
/second_year/src/hyparam/logger.yaml:
--------------------------------------------------------------------------------
wandb:
  wandb_ok: False
  init:
    project:
    entity:
    name:

optimize_method: 'min'

save_csv: ../results/logger/log.csv
png_dir: ../results/logger/loss.png

model_save_dir: ../results/model_checkpoint/
out_txt:
  dir: ../results/logger/exp.txt
--------------------------------------------------------------------------------
/third_year/src/hyparam/logger.yaml:
--------------------------------------------------------------------------------
wandb:
  wandb_ok: False
  init:
    project:
    entity:
    name:

optimize_method: 'min'

save_csv: ../results/logger/log.csv
png_dir: ../results/logger/loss.png

model_save_dir: ../results/model_checkpoint/
out_txt:
  dir: ../results/logger/exp.txt
--------------------------------------------------------------------------------
/third_year/src/hyparam/test.yaml:
--------------------------------------------------------------------------------
torch_start_method: 'fork'
randomseed: 0


model: './best_model.tar'

result_folder:
room_type: ['409', '819', 'meeting', 'seminar', 'house', 'hospital', 'cafe', 'car']

inference_folder: '../results/inference_best/'


GPGPU:
  device_ids: [0]
--------------------------------------------------------------------------------
/third_year/src/run_make_test_set.sh:
--------------------------------------------------------------------------------
device=0
echo "device: $device"

CUDA_VISIBLE_DEVICES=$device python make_test_set.py \
    "model ./models/Causal_CRN_SPL_target/model.yaml" \
    "dataloader ./dataloader/dataloader_test_maker.yaml" \
    "hyparam ./hyparam/train.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"
--------------------------------------------------------------------------------
/third_year/src/run_trainer.sh:
--------------------------------------------------------------------------------
rm -rf ./wandb/
device=0
echo "device: $device"

CUDA_VISIBLE_DEVICES=$device python train.py \
    "model ./models/EABNET/model.yaml" \
    "dataloader ./dataloader/dataloader_train.yaml" \
    "hyparam ./hyparam/train.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"

rm -rf ./wandb/
--------------------------------------------------------------------------------
/fourth_year/src/hyparam/test.yaml:
--------------------------------------------------------------------------------
torch_start_method: 'fork'
randomseed: 0


resolution: 1
acc_threshold: 10
local_maximum_distance: 30

model:

result_folder:
room_type: ['409', '819', 'meeting', 'seminar', 'house', 'hospital', 'cafe', 'car']

inference_folder: '../results/inference_best/'


GPGPU:
  device_ids: [0]
--------------------------------------------------------------------------------
/fourth_year/src/run_trainer.sh:
--------------------------------------------------------------------------------
rm -rf ./wandb/
device=0
echo "device: $device"

CUDA_VISIBLE_DEVICES=$device python train.py \
    "model ./models/FSPEN/model.yaml" \
    "dataloader ./dataloader/dataloader_train.yaml" \
    "hyparam ./hyparam/train.yaml" \
    "learner ./hyparam/learner.yaml" \
    "logger ./hyparam/logger.yaml"

rm -rf ./wandb/

bash run_tester.sh
--------------------------------------------------------------------------------
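All of the run scripts above share one convention: each quoted argument is a "<config-name> <yaml-path>" pair handed to train.py / inference.py / make_test_set.py. A minimal sketch of a loader for that convention — the real parser lives inside train.py/inference.py, which this dump does not include, so the names here are assumptions:

```python
# Sketch only: turn the run_*.sh argument pairs into {name: parsed_yaml}.
# The actual loader in this repo is not shown; this is an assumed shape.
import sys
import yaml

def load_configs(argv):
    configs = {}
    for pair in argv[1:]:              # e.g. "model ./models/FSPEN/model.yaml"
        name, path = pair.split(maxsplit=1)
        with open(path) as f:
            configs[name] = yaml.safe_load(f)
    return configs

if __name__ == '__main__':
    cfg = load_configs(sys.argv)
    # expected keys: model, dataloader, hyparam, learner, logger
    print(sorted(cfg))
```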
/second_year/src/hyparam/test.yaml:
--------------------------------------------------------------------------------
torch_start_method: 'fork'
randomseed: 0


wav_save: False
wav_folder: '../results/inference_best_wav/'
model: '../results/model_checkpoint/best_model.tar'

result_folder:
room_type: ['409', '819', 'meeting', 'seminar', 'house', 'hospital', 'cafe', 'car']
inference_folder: '../results/inference_best/'


GPGPU:
  device_ids: [0]
--------------------------------------------------------------------------------
/second_year/src/hyparam/learner.yaml:
--------------------------------------------------------------------------------
optimizer:
  type: Adam
  gradient_clip: 1
  config:
    lr: 1.0e-4
    weight_decay: 1.0e-6


optimizer_scheduler:
  type: ReduceLROnPlateau
  config:
    mode: 'min'
    factor: 0.9
    patience: 4000000
    threshold: 1.0e-8
    min_lr: 1.0e-4
    verbose: True

loss:
  optimize_method: 'min'
  type: SI-SDR
--------------------------------------------------------------------------------
/second_year/src/dataloader/dataloader_test.yaml:
--------------------------------------------------------------------------------
test:

  pkl_dir:

  metadata_dir:
  noise_csv:
  speech_csv:
  pkl_csv:

  room_yaml:

  ans_azi: 0
  degree_resolution: 1

  mic_type: circular # circular, ellipsoid, linear
  mic_num: 4 # 4, 6, 8
  max_n_mic: 8

  dataloader_dict:
    batch_size: 1
    shuffle: False
    num_workers: 8
    drop_last: False
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/Causal_CRN_SPL_target/CRN_main.py:
--------------------------------------------------------------------------------
# from .CRN import main_model
from .CRN_SPL_target import main_model
import torch

def get_model(args):

    model=main_model(args)
    trained=torch.load(args['pretrain'])
    model.load_state_dict(trained['model_state_dict'])
    if args['freeze']:
        for param in model.parameters():
            param.requires_grad=False

    return model
--------------------------------------------------------------------------------
/third_year/src/models/EABNET/model.yaml:
--------------------------------------------------------------------------------
name: EABNET
mics: 7


EABNET:
  k1: [2,3]
  k2: [1,3]
  c: 64
  M: 4 # number of mics
  embed_dim: 64
  kd1: 5
  cd1: 64
  d_feat: 256
  p: 6
  q: 3
  is_causal: True
  is_u2: True
  bf_type: 'lstm'
  topo_type: 'mimo'
  intra_connect: 'cat'
  norm_type: 'IN'

FFT:
  win_len: 320
  win_inc: 160
  fft_len: 320
  win_type: 'hamming'
  sqrt_window: False
--------------------------------------------------------------------------------
/fourth_year/src/hyparam/learner.yaml:
--------------------------------------------------------------------------------
optimizer:
  type: Adam
  gradient_clip: 5
  config:
    lr: 1.0e-3
    weight_decay: 1.0e-6


optimizer_scheduler:
  type: ReduceLROnPlateau
  config:
    mode: 'min'
    factor: 0.9
    patience: 0
    threshold: 1.0e-8
    min_lr: 1.0e-4
    verbose: True

loss:
  optimize_method: 'min'

  type: sync_SI_SDR

  option:
    train_map_num: [0, 1, 2]
    each_layer_weight: [1.0, 1.0, 1.0]
--------------------------------------------------------------------------------
/third_year/src/dataloader/dataloader_test.yaml:
--------------------------------------------------------------------------------
test:

  pkl_dir: ../DB/circle_4/

  metadata_dir: ./metadata/
  noise_csv: ms-snsd_test.csv
  speech_csv: librispeech_test.csv
  pkl_csv: test_csv.csv

  room_yaml: ./dataloader/dataloader_test.yaml

  ans_azi: 0
  degree_resolution: 1

  mic_type: circular # circular, ellipsoid, linear
  mic_num: 4 # 4, 6, 8
  max_n_mic: 4

  dataloader_dict:
    batch_size: 1
    shuffle: False
    num_workers: 4
    drop_last: False
--------------------------------------------------------------------------------
/third_year/src/loss/SI_SDR_sync.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.nn.modules.loss import _Loss

from asteroid.losses import singlesrc_neg_sisdr
from asteroid.losses.sdr import SingleSrcNegSDR


class sync_SI_SDR(_Loss):
    def __init__(self, reduction="none"):
        super(sync_SI_SDR, self).__init__()
        self.reduction=reduction
        self.EPS=1e-8
        self.func_for_loss=SingleSrcNegSDR(reduction=reduction, zero_mean=True, take_log=True, sdr_type='sisdr')

    def forward(self, output, target):

        loss=self.func_for_loss(output, target)

        return loss
--------------------------------------------------------------------------------
/third_year/src/hyparam/learner.yaml:
--------------------------------------------------------------------------------
optimizer:
  type: Adam
  gradient_clip: 5
  config:
    lr: 1.0e-3
    weight_decay: 1.0e-6


beampattern:
  sound_speed: 343.0
  fs: 16000
  fft_len: 320
  theta_step: 5.0

  angle_candidates: [0, 5, 15] # degree

  time_avg: True

  sigma: 10.0
  p: 0.707106781

  device: 'cuda'
  clip: True

optimizer_scheduler:
  type: ReduceLROnPlateau
  config:
    mode: 'min'
    factor: 0.9
    patience: 0
    threshold: 1.0e-8
    min_lr: 1.0e-4
    verbose: True

loss:
  optimize_method: 'min'
  type: sync_SI_SDR
  # type: mse
  option:
    train_map_num: [0, 1, 2]
    each_layer_weight: [1.0, 1.0, 1.0]
--------------------------------------------------------------------------------
/fourth_year/src/loss/SI_SDR_sync.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from torch.nn.modules.loss import _Loss

from asteroid.losses import singlesrc_neg_sisdr
from asteroid.losses.sdr import SingleSrcNegSDR


class sync_SI_SDR(_Loss):
    def __init__(self, reduction="none"):
        super(sync_SI_SDR, self).__init__()
        self.reduction=reduction
        self.EPS=1e-8
        self.func_for_loss=SingleSrcNegSDR(reduction=reduction, zero_mean=True, take_log=True, sdr_type='sisdr')

    def forward(self, output, target):

        loss=self.func_for_loss(output, target)

        return loss
--------------------------------------------------------------------------------
/third_year/src/preprocess/ms-snsd_preprocess.py:
--------------------------------------------------------------------------------
from glob import glob
import pandas as pd
import soundfile as sf
from soundfile import SoundFile, SEEK_END
import tqdm

class PreProcessor():
    def __init__(self):
        self.train_dir='/MS-SNSD/noise_train/'
        self.test_dir='/MS-SNSD/noise_test/'

        self.metadata_dir='./metadata/'
        train_file='ms-snsd_train.csv'
        test_file='ms-snsd_test.csv'

        self.make_csv(self.metadata_dir+train_file, self.train_dir)
        self.make_csv(self.metadata_dir+test_file, self.test_dir)

    def make_csv(self, save_dir, audio_dir):

        # dir + speaker_id + chapter_id + wav_name
        column_names=['file_path', 'duration', ]
        csv_dict={}

        for column in column_names:
            csv_dict[column]=[]

        audio_list=glob(audio_dir+'/**/*.wav', recursive=True)

        for audio in tqdm.tqdm(audio_list, total=len(audio_list)):
            f=SoundFile(audio)

            wav_len = f.seek(0, SEEK_END)    # duration in samples
            audio_name=audio.split('/')[-1]

            for column, data in zip(column_names, [audio_name, wav_len,]):
                csv_dict[column].append(data)

        pd.DataFrame(csv_dict).to_csv(save_dir)


if __name__=='__main__':
    PreProcessor()
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/model.yaml:
--------------------------------------------------------------------------------
name: convtasnet_SSL_FiLM


TasNet:
  ch_size: 8
  skip: True
  enc_dim: 512
  feature_dim: 128
  sr: 16000
  win: 2
  layer: 8
  stack: 3
  kernel: 3
  num_spk: 1
  causal: True
  condi_weight: [360, 128]
  condi_bias: [360, 128]
  Film_loc: [7]
  padding: [128, 16]

CRN:
  pretrain: './pretrained_CRN/circular_4.tar'
  freeze: True

  degree_resolution: 1

  ref_ch: 0

  ##### sigma
  p: 0.707106781
  wait_epoch: 0
  sigma_start: [16.0]
  sigma_end:
    min: [2.5]
    max: [16.0]

  sigma_rate: [-0.54]
  sigma_update_method: 'add'

  iter:
    update: False
    update_period: 200

  epoch:
    update: True
    update_period: 1

  max_spk: 2

  FFT:
    win_len: 256
    win_inc: 128
    fft_len: 256
    vad_threshold: 0.6666

  CRN:

    input_audio_channel: 8
    fft_freq_bin_num: 129

    CNN:
      layer_num: 4
      kernel_size: [3,3] # F X T
      filter: 64

      max_pool:
        kernel_size: [2,1]
        stride: [2,1]

    GRU:
      input_size: 512
      hidden_size: 256
      num_layers: 3
      batch_first: True
      dropout: 0.0

    GRU_init:
      shape: [3,1, 256]
      learnable: False
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/convtasnet.py:
--------------------------------------------------------------------------------
from torch.nn.modules import conv
from torch import nn
import torch
from util import *
import matplotlib.pyplot as plt
import numpy as np
from .Causal_CRN_SPL_target import CRN_main
from .convtasnet_module import conv_tasnet


class main_model(nn.Module):
    def __init__(self, config):
        super(main_model, self).__init__()
        self.config=config

        self.eps=np.finfo(np.float32).eps
        ### CRN
        self.CRN=CRN_main.get_model(self.config['CRN'])

        ### convtasnet
        self.convtasnet=conv_tasnet.TasNet(**self.config['TasNet'])

    # NOTE: retained from the CRN variant; unused here (self.stft_model and
    # self.ref_ch are not defined on this class).
    def irtf_featue(self, x, target):
        r, i, target = self.stft_model(x, target, cplx=True)

        comp = torch.complex(r, i)

        comp_ref = comp[..., [self.ref_ch], :, :]
        comp_ref = torch.complex(
            comp_ref.real.clamp(self.eps), comp_ref.imag.clamp(self.eps)
        )

        comp=torch.cat(
            (comp[..., self.ref_ch-1:self.ref_ch, :, :], comp[..., self.ref_ch+1:, :, :]),
            dim=-3) / comp_ref
        x=torch.cat((comp.real, comp.imag), dim=1)

        return x, target

    def forward(self, x, ):

        ssl_condition=self.CRN(x)
        x=self.convtasnet(x, ssl_condition)

        return x
--------------------------------------------------------------------------------
/third_year/src/preprocess/librispeech_preprocess.py:
--------------------------------------------------------------------------------
from glob import glob
import pandas as pd
import soundfile as sf
from soundfile import SoundFile, SEEK_END
import tqdm

class PreProcessor():
    def __init__(self):
        self.train_dir='/LibriSpeech/train-clean-100/'
        self.test_dir='/LibriSpeech/test-clean/'

        self.metadata_dir='../metadata/'
        train_file='librispeech_train.csv'
        test_file='librispeech_test.csv'

        self.make_csv(self.metadata_dir+train_file, self.train_dir)
        self.make_csv(self.metadata_dir+test_file, self.test_dir)

    def make_csv(self, save_dir, audio_dir):

        # dir + speaker_id + chapter_id + wav_name
        column_names=['file_path', 'duration', 'speaker_id']
        csv_dict={}

        for column in column_names:
            csv_dict[column]=[]

        audio_list=glob(audio_dir+'/**/*.flac', recursive=True)

        for audio in tqdm.tqdm(audio_list, total=len(audio_list)):
            f=SoundFile(audio)

            wav_len = f.seek(0, SEEK_END)    # duration in samples
            audio_name=audio.split('/')[-1]
            speaker_id=audio_name.split('-')[0]
            audio_name=audio.replace(audio_dir, '')

            for column, data in zip(column_names, [audio_name, wav_len, speaker_id]):
                csv_dict[column].append(data)

        pd.DataFrame(csv_dict).to_csv(save_dir)


if __name__=='__main__':
    PreProcessor()
--------------------------------------------------------------------------------
/first_year/config/train1.yaml:
--------------------------------------------------------------------------------
train:
  dataloader:
    csv: /home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/dataset/tr.csv
    batch_size: 32
    shuffle: True
    num_workers: 4
    drop_last: False
    audio_path: /home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/dataset
    duration: 64000

  optimizer:
    type: Adam
    learning_rate: 1e-4
    weight_decay: 1e-7

  loss: SI-SNR

  FFT:
    window_size: 512
    hop_size: 256
val:
  dataloader:
    csv: /home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/dataset/cv.csv
    batch_size: 32
    shuffle: False
    num_workers: 4
    drop_last: False
    audio_path: /home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/dataset
    duration: 64000

test:
  FFT:
    window_size: 512
    hop_size: 256
  model:
    type: CRNN
    trained: /home/intern0/Desktop/project/IITP/Beamformer/exp_result/2021_09_08_20_59_15/6_model.tar

  dataloader:
    csv: /home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/dataset/tt.csv
    batch_size: 1
    shuffle: False
    num_workers: 4
    drop_last: False
    audio_path: /home/intern0/Desktop/project/IITP/Sound_source_localization/Data_processing/speech_setup/dataset
    duration: None


exp:
  result: ./exp_result/
  epoch: 300
  model: CRNN
--------------------------------------------------------------------------------
/first_year/data_loader.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
import pandas as pd
import soundfile as sf
import numpy as np

class wav_data_loader():
    def __init__(self, config):
        csv=pd.read_csv(config['csv'])
        self.input_list=csv['input_path'].tolist()
        self.label_list=csv['label_path'].tolist()
        self.duration=config['duration']
        self.location=config['audio_path']
        self.snr=csv['SNR'].tolist()

    def __len__(self):
        return len(self.input_list)

    def __getitem__(self, idx):
        input_file=self.input_list[idx]
        input_file=self.location+self.input_list[idx][9:]
        input_file, _ = sf.read(input_file, dtype='float32')

        label_file=self.label_list[idx]
        label_file=self.location+self.label_list[idx][9:]
        label_file,_=sf.read(label_file, dtype='float32')

        if self.duration == 'None':
            return input_file.T, label_file.T, self.snr[idx], self.input_list[idx]

        if input_file.shape[0]>self.duration:
            start=np.random.randint(0, input_file.shape[0]-self.duration)
            input_file=input_file[start:start+self.duration,:]
            label_file=label_file[start:start+self.duration]

        elif input_file.shape[0] losses.avg:
            checkpoint = {
                'epoch': epoch + 1,
                'state_dict_NS': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict()
            }
            if epoch>-1:
                torch.save(checkpoint, self.exp_dir + "/{}_model.tar".format(epoch))
            self.best_loss = losses.avg
        print("\n")
        # return losses.avg
        # exit()

    def train(self, epoch):
        self.model=self.model.train()
        self.optimizer.zero_grad()

        losses = AverageMeter()
        times = AverageMeter()
        losses.reset()
        times.reset()

        for iter_num, (input_data, label) in tqdm(enumerate(self.train_loader), desc='Train', total=len(self.train_loader)):
            input_data=input_data.to(self.device)

            output=self.model(input_data)
            label=label.to(self.device)
            loss_iter=self.loss_function(output, label)
            losses.update(loss_iter.item())

            loss_iter.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
            self.optimizer.step()
            self.optimizer.zero_grad()

            # break
            # print(losses.avg, times)
        print('epoch %d, training losses: %f'%(epoch, losses.avg), end='\r')
        print("\n")
        # exit()
        # return

    def init_optimizer(self):
        opti_option=self.config_data['train']['optimizer']
        opti_type=opti_option['type']
        if opti_type=='Adam':
            optimizer=torch.optim.Adam(self.model.parameters(), lr=float(opti_option['learning_rate']), weight_decay=float(opti_option['weight_decay']))

        return optimizer


if __name__=='__main__':
    device=randomseed_init(777)
    t=trainer('./config/train1.yaml', device)
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/convtasnet_module/conv_tasnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from .utility import models


# Conv-TasNet
class TasNet(nn.Module):
    def __init__(self, enc_dim=512, feature_dim=128, sr=16000, win=2, layer=8, stack=3,
                 kernel=3, num_spk=2, causal=True, ch_size=8, skip=False, condi_weight=[360, 128], condi_bias=[360, 128], Film_loc=8, padding='auto'):
        super(TasNet, self).__init__()

        # hyper parameters
        self.num_spk = num_spk
        self.skip=skip

        self.enc_dim = enc_dim
        self.feature_dim = feature_dim

        self.win = int(sr*win/1000)
        self.stride = self.win // 2

        self.layer = layer
        self.stack = stack
        self.kernel = kernel

        self.causal = causal
        self.ch_size=ch_size
        self.Film_loc=Film_loc
        self.padding=padding
        # input encoder
        self.encoder = nn.Conv1d(1, self.enc_dim, self.win, bias=False, stride=self.stride)

        # conditioning FiLM
        self.FiLM_weight=nn.Conv1d(condi_weight[0], condi_weight[1], 1)
        self.FiLM_bias=nn.Conv1d(condi_bias[0], condi_bias[1], 1)

        # TCN separator
        self.TCN = models.TCN(self.ch_size, self.enc_dim, self.enc_dim*self.num_spk, self.feature_dim, self.feature_dim*4,
                              self.layer, self.stack, self.kernel, causal=self.causal, skip=self.skip, Film_loc=self.Film_loc)

        self.receptive_field = self.TCN.receptive_field

        # output decoder
        self.decoder = nn.ConvTranspose1d(self.enc_dim, 1, self.win, bias=False, stride=self.stride)

    def pad_signal(self, input):

        # input is the waveforms: (B, T) or (B, 1, T)
        # reshape and padding
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size = input.size(0)
        nsample = input.size(2)
        rest = self.win - (self.stride + nsample % self.win) % self.win

        if rest > 0:
            pad = Variable(torch.zeros(batch_size, 1, rest)).type(input.type())
            input = torch.cat([input, pad], 2)

        pad_aux = Variable(torch.zeros(batch_size, 1, self.stride)).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def forward(self, input, ssl_condition):

        batch_size, ch_size, sample_size = input.shape

        ref_ch=[ch_size*k for k in range(batch_size)]
        input=input.view(batch_size*ch_size, sample_size)
        # padding
        if self.padding =='auto':
            output, rest = self.pad_signal(input)
        else:
            output=F.pad(input, self.padding, mode='constant').unsqueeze(1)

        # waveform encoder
        enc_output = self.encoder(output)  # B, N, L

        target_output=enc_output[ref_ch]

        enc_output=enc_output.view(batch_size, -1, enc_output.shape[-1])

        remainder=enc_output.shape[-1]%ssl_condition.shape[-1]

        frame_ratio=enc_output.shape[-1]//ssl_condition.shape[-1]

        ssl_condition=ssl_condition.squeeze(1).repeat_interleave(frame_ratio, dim=-1)
        # ssl_condition=F.pad(ssl_condition, [0, remainder], mode='constant')

        ssl_weight=self.FiLM_weight(ssl_condition)
        ssl_bias=self.FiLM_bias(ssl_condition)

        # generate masks
        masks = torch.sigmoid(self.TCN(enc_output, ssl_weight, ssl_bias)).view(batch_size, self.num_spk, self.enc_dim, -1)  # B, C, N, L
        masked_output = target_output.unsqueeze(1) * masks  # B, C, N, L

        # waveform decoder
        output = self.decoder(masked_output.view(batch_size*self.num_spk, self.enc_dim, -1))  # B*C, 1, L
        if self.padding=='auto':
            output = output[:,:,self.stride:-(rest+self.stride)].contiguous()  # B*C, 1, L
        else:
            output=output[:, :, self.padding[0]:-self.padding[1]]
        output = output.view(batch_size, -1)  # B, C, T

        return output


def test_conv_tasnet():
    x = torch.rand(2, 32000)
    nnet = TasNet()
    x = nnet(x)
    s1 = x[0]
    print(s1.shape)


if __name__ == "__main__":
    test_conv_tasnet()
--------------------------------------------------------------------------------
/fourth_year/src/models/FSPEN/modules/sequence_modules.py:
--------------------------------------------------------------------------------
# !/user/bin/env python
# -*-coding:utf-8 -*-

"""
# File : sequence_modules.py
# Time : 2024/4/10 9:35 AM
# Author : wukeyi
# version : python3.9
"""
from typing import List

import torch
from torch import nn, Tensor


class GroupRNN(nn.Module):
    def __init__(self, input_size: int,
                 hidden_size: int,
                 groups: int,
                 rnn_type: str,
                 num_layers: int = 1,
                 bidirectional: bool = False,
                 batch_first: bool = True):
        super().__init__()
        assert input_size % groups == 0, \
            f"input_size % groups must be equal to 0, but got {input_size} % {groups} = {input_size % groups}"

        self.groups = groups
        self.rnn_list = nn.ModuleList()
        for _ in range(groups):
            self.rnn_list.append(
                getattr(nn, rnn_type)(input_size=input_size // groups, hidden_size=hidden_size//groups,
                                      num_layers=num_layers,
                                      bidirectional=bidirectional, batch_first=batch_first)
            )

    def forward(self, inputs: Tensor, hidden_state: List[Tensor]):
        """
        :param hidden_state: List[state1, state2, ...], len(hidden_state) = groups
            state shape = (num_layers*bidirectional, batch*[], hidden_size) if rnn_type is GRU or RNN, otherwise,
            state = (h0, c0), h0/c0 shape = (num_layers*bidirectional, batch*[], hidden_size).
        :param inputs: (batch, steps, input_size)
        :return:
        """
        outputs = []
        out_states = []
        batch, steps, _ = inputs.shape

        inputs = torch.reshape(inputs, shape=(batch, steps, self.groups, -1))  # (batch, steps, groups, width)
        for idx, rnn in enumerate(self.rnn_list):
            out, state = rnn(inputs[:, :, idx, :], hidden_state[idx])
            outputs.append(out)  # (batch, steps, hidden_size)
            out_states.append(state)  # (num_layers*bidirectional, batch*[], hidden_size)

        outputs = torch.cat(outputs, dim=2)  # (batch, steps, hidden_size * groups)

        return outputs, out_states


class DualPathExtensionRNN(nn.Module):
    def __init__(self, input_size: int,
                 intra_hidden_size: int,
                 inter_hidden_size: int,
                 groups: int,
                 rnn_type: str):
        super().__init__()
        assert rnn_type in ["RNN", "GRU", "LSTM"], f"rnn_type should be RNN/GRU/LSTM, but got {rnn_type}!"

        self.intra_chunk_rnn = getattr(nn, rnn_type)(input_size=input_size, hidden_size=intra_hidden_size,
                                                     num_layers=1, bidirectional=True, batch_first=True)
        self.intra_chunk_fc = nn.Linear(in_features=intra_hidden_size*2, out_features=input_size)
        self.intra_chunk_norm = nn.LayerNorm(normalized_shape=input_size, elementwise_affine=True)

        self.inter_chunk_rnn = GroupRNN(input_size=input_size, hidden_size=inter_hidden_size, groups=groups,
                                        rnn_type=rnn_type)
        self.inter_chunk_fc = nn.Linear(in_features=inter_hidden_size, out_features=input_size)

    def forward(self, inputs: Tensor, hidden_state: List[Tensor]):
        """
        :param hidden_state: List[state1, state2, ...], len(hidden_state) = groups
            state shape = (num_layers*bidirectional, batch*[], hidden_size) if rnn_type is GRU or RNN, otherwise,
            state = (h0, c0), h0/c0 shape = (num_layers*bidirectional, batch*[], hidden_size).
        :param inputs: (B, F, T, N)
        :return:
        """
        B, F, T, N = inputs.shape
        intra_out = torch.transpose(inputs, dim0=1, dim1=2).contiguous()  # (B, T, F, N)
        intra_out = torch.reshape(intra_out, shape=(B * T, F, N))
        intra_out, _ = self.intra_chunk_rnn(intra_out)
        intra_out = self.intra_chunk_fc(intra_out)  # (B, T, F, N)
        intra_out = torch.reshape(intra_out, shape=(B, T, F, N))
        intra_out = torch.transpose(intra_out, dim0=1, dim1=2).contiguous()  # (B, F, T, N)
        intra_out = self.intra_chunk_norm(intra_out)  # (B, F, T, N)

        intra_out = inputs + intra_out  # residual add

        inter_out = torch.reshape(intra_out, shape=(B * F, T, N))  # (B*F, T, N)
        inter_out, hidden_state = self.inter_chunk_rnn(inter_out, hidden_state)
        inter_out = torch.reshape(inter_out, shape=(B, F, T, -1))  # (B, F, T, groups * N)
        inter_out = self.inter_chunk_fc(inter_out)  # (B, F, T, N)

        inter_out = inter_out + intra_out  # residual add

        return inter_out, hidden_state


if __name__ == "__main__":
    test_model = DualPathExtensionRNN(input_size=32, intra_hidden_size=16, inter_hidden_size=16,
                                      groups=8, rnn_type="LSTM")
    test_data = torch.randn(5, 32, 10, 32)
    # each group's LSTM uses hidden size inter_hidden_size // groups = 16 // 8 = 2
    test_state = [(torch.randn(1, 5*32, 2), torch.randn(1, 5*32, 2)) for _ in range(8)]
    test_out = test_model(test_data, test_state)
--------------------------------------------------------------------------------
/third_year/src/loss/SDR_loss.py:
--------------------------------------------------------------------------------

import numpy as np
from itertools import permutations
from torch.autograd import Variable
from torch.nn.modules.loss import _Loss

import scipy,time,numpy
import itertools

import torch

EPS = 1e-8
# class SingleSrcNegSDR(_Loss):
#     r"""Base class for single-source negative SI-SDR, SD-SDR and SNR.
#
#     Args:
#         sdr_type (str): choose between ``snr`` for plain SNR, ``sisdr`` for
#             SI-SDR and ``sdsdr`` for SD-SDR [1].
#         zero_mean (bool, optional): by default it zero mean the target and
#             estimate before computing the loss.
#         take_log (bool, optional): by default the log10 of sdr is returned.
#         reduction (string, optional): Specifies the reduction to apply to
#             the output:
#             ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
#             ``'mean'``: the sum of the output will be divided by the number of
#             elements in the output.
#
#     Shape:
#         - est_targets : :math:`(batch, time)`.
#         - targets: :math:`(batch, time)`.
#
#     Returns:
#         :class:`torch.Tensor`: with shape :math:`(batch)` if ``reduction='none'`` else
#         [] scalar if ``reduction='mean'``.
#
#     Examples
#         >>> import torch
#         >>> from asteroid.losses import PITLossWrapper
#         >>> targets = torch.randn(10, 2, 32000)
#         >>> est_targets = torch.randn(10, 2, 32000)
#         >>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
#         >>>                            pit_from='pw_pt')
#         >>> loss = loss_func(est_targets, targets)
#
#     References
#         [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE
#             International Conference on Acoustics, Speech and Signal
#             Processing (ICASSP) 2019.
49 | # """ 50 | 51 | # def __init__(self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8): 52 | # assert reduction != "sum", NotImplementedError 53 | # super().__init__(reduction=reduction) 54 | 55 | # assert sdr_type in ["snr", "sisdr", "sdsdr"] 56 | # self.sdr_type = sdr_type 57 | # self.zero_mean = zero_mean 58 | # self.take_log = take_log 59 | # self.EPS = 1e-8 60 | 61 | # def forward(self, est_target, target): 62 | # if target.size() != est_target.size() or target.ndim != 2: 63 | # raise TypeError( 64 | # f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" 65 | # ) 66 | # # Step 1. Zero-mean norm 67 | # if self.zero_mean: 68 | # mean_source = torch.mean(target, dim=1, keepdim=True) 69 | # mean_estimate = torch.mean(est_target, dim=1, keepdim=True) 70 | # target = target - mean_source 71 | # est_target = est_target - mean_estimate 72 | # # Step 2. Pair-wise SI-SDR. 73 | # if self.sdr_type in ["sisdr", "sdsdr"]: 74 | # # [batch, 1] 75 | # dot = torch.sum(est_target * target, dim=1, keepdim=True) 76 | # # [batch, 1] 77 | # s_target_energy = torch.sum(target ** 2, dim=1, keepdim=True) + self.EPS 78 | # # [batch, time] 79 | # scaled_target = dot * target / s_target_energy 80 | # else: 81 | # # [batch, time] 82 | # scaled_target = target 83 | # if self.sdr_type in ["sdsdr", "snr"]: 84 | # e_noise = est_target - target 85 | # else: 86 | # e_noise = est_target - scaled_target 87 | # # [batch] 88 | # losses = torch.sum(scaled_target ** 2, dim=1) / (torch.sum(e_noise ** 2, dim=1) + self.EPS) 89 | # if self.take_log: 90 | # losses = 10 * torch.log10(losses + self.EPS) 91 | # losses = losses.mean() if self.reduction == "mean" else losses 92 | # return -losses 93 | 94 | class SDR_loss(_Loss): 95 | def __init__(self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8, clipping=30): 96 | super(SDR_loss, self).__init__() 97 | self.reduction=reduction 98 | self.EPS=float(EPS) 99 | self.clipping=10**(-1*clipping/10) 100 | self.sdr_type = sdr_type 101 | self.zero_mean = zero_mean 102 | self.take_log = take_log 103 | 104 | 105 | def forward(self, est_target, target): 106 | if target.size() != est_target.size() or target.ndim != 2: 107 | raise TypeError( 108 | f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" 109 | ) 110 | # Step 1. Zero-mean norm 111 | if self.zero_mean: 112 | mean_source = torch.mean(target, dim=1, keepdim=True) 113 | mean_estimate = torch.mean(est_target, dim=1, keepdim=True) 114 | target = target - mean_source 115 | est_target = est_target - mean_estimate 116 | # Step 2. Pair-wise SI-SDR. 
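        # For sisdr/sdsdr, the estimate is projected onto the target:
        #     scaled_target = (<est_target, target> / ||target||^2) * target
        # and the loss below is -10*log10(||scaled_target||^2 / ||e_noise||^2).
        # The extra self.clipping * ||scaled_target||^2 term in the denominator
        # soft-limits the achievable SDR to roughly `clipping` dB.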
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1]
            dot = torch.sum(est_target * target, dim=1, keepdim=True)
            # [batch, 1]

            s_target_energy = torch.sum(target ** 2, dim=1, keepdim=True) + self.EPS
            # [batch, time]
            scaled_target = dot * target / s_target_energy
        else:
            # [batch, time]
            scaled_target = target

        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target
        # [batch]
        scaled_target_power=torch.sum(scaled_target ** 2, dim=1)

        losses = scaled_target_power / (torch.sum(e_noise ** 2, dim=1) + self.EPS + self.clipping*scaled_target_power)

        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses

        return -losses
--------------------------------------------------------------------------------
/third_year/src/models/EABNET/FFT.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import torch as th
import numpy as np
from scipy.signal import get_window


EPSILON = th.finfo(th.float32).eps
MATH_PI = math.pi


def init_kernels(win_len,
                 win_inc,
                 fft_len,
                 win_type=None,
                 invers=False,
                 sqrt_window=False):
    if win_type == 'None' or win_type is None:
        # N
        window = np.ones(win_len)
    else:
        # N
        window = get_window(win_type, win_len, fftbins=True)

    if sqrt_window:
        window = np.sqrt(window)

    N = fft_len
    # N x F
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    # N x F
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    # 2F x N
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T
    if invers:
        kernel = np.linalg.pinv(kernel).T

    # 2F x N * N => 2F x N
    kernel = kernel*window
    # 2F x 1 x N
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32))


class ConvSTFT(nn.Module):

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 vad_threshold=2/3,
                 win_type='hamming',
                 sqrt_window=False
                 # fix=True
                 ):
        super(ConvSTFT, self).__init__()

        if fft_len == None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        # 2F x 1 x N
        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type, sqrt_window=sqrt_window)
        get_target_kernel=torch.ones_like(kernel).mean(dim=0, keepdim=True)/kernel.shape[-1]

        self.register_buffer('weight', kernel)
        self.register_buffer('target_kernel', get_target_kernel)

        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs, cplx=False):

        if inputs.dim() == 2:
            # N x 1 x L
            inputs = torch.unsqueeze(inputs, 1)
            inputs = F.pad(inputs,[self.win_len-self.stride, self.win_len-self.stride])
            # N x 2F x T
            outputs = F.conv1d(inputs, self.weight, stride=self.stride)
            # N x F x T
            r, i = th.chunk(outputs, 2, dim=1)
        else:

            N, C, L = inputs.shape
            inputs = inputs.view(N * C, 1, L)
            # NC x 1 x L
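            # Same framing as the 2-D branch above: pad (win_len - stride) on
            # both sides so edge samples are covered by full analysis windows;
            # ConviSTFT trims win_len - stride at the head after overlap-add.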
            inputs = F.pad(inputs, [self.win_len-self.stride, self.win_len-self.stride])

            # NC x 2F x T
            outputs = F.conv1d(inputs, self.weight, stride=self.stride)

            # N x C x 2F x T
            outputs = outputs.view(N, C, -1, outputs.shape[-1])

            # N x C x F x T
            r, i = th.chunk(outputs, 2, dim=2)

        if cplx:
            return r, i
        else:

            mags=torch.pow(r**2+i**2, 0.5)
            mags=th.clamp(mags, 1e-12)

            phase=torch.atan2(i, r)
            return mags, phase, r, i


class ConviSTFT(nn.Module):

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 sqrt_window=False
                 # fix=True
                 ):
        super(ConviSTFT, self).__init__()
        if fft_len == None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        # kernel: 2F x 1 x N
        # window: 1 x N x 1
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True, sqrt_window=sqrt_window)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:,None,:])

    def forward(self, inputs, phase, cplx=False):
        """
        inputs : [B, N//2+1, T] (mags, real)
        phase: [B, N//2+1, T] (phase, imag)
        """

        if cplx:
            # N x 2F x T
            cspec = torch.cat([inputs, phase], dim=1)
        else:
            # N x F x T
            real = inputs*torch.cos(phase)
            imag = inputs*torch.sin(phase)
            # N x 2F x T
            cspec = torch.cat([real, imag], dim=1)
        # N x 1 x L
        outputs = F.conv_transpose1d(cspec, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        # 1 x N x T
        t = self.window.repeat(1,1,inputs.size(-1))**2
        # 1 x 1 x L
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)

        outputs = outputs/(coff+1e-12)

        # outputs = torch.where(coff == 0, outputs, outputs/coff)
        # N x 1 x L

        outputs = outputs[...,self.win_len-self.stride:]
        # N x L
        outputs = outputs.squeeze(1)
        return outputs
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/Causal_CRN_SPL_target/CRN_SPL_target.py:
--------------------------------------------------------------------------------
from .FFT import EPSILON, ConvSTFT, ConviSTFT
from torch import nn
import torch
from util import *

import numpy as np

class Causal_Conv2D_Block(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Causal_Conv2D_Block, self).__init__()

        self.conv2d=nn.Conv2d(*args, **kwargs)

        self.norm=nn.BatchNorm2d(args[1])

        self.activation=nn.ELU()

    def forward(self, x):
        original_frame_num=x.shape[-1]
        x=self.conv2d(x)
        x=self.norm(x)
        x=self.activation(x)
        x=x[...,:original_frame_num]

        return x

class Conv1D_Block(nn.Module):
    def __init__(self, *args, **kwargs):
        super(Conv1D_Block, self).__init__()

        self.conv1d=nn.Conv1d(*args, **kwargs)

self.norm=nn.BatchNorm1d(args[1]) 40 | 41 | self.activation=nn.ELU() 42 | 43 | def forward(self, x): 44 | 45 | x=self.conv1d(x) 46 | x=self.norm(x) 47 | x=self.activation(x) 48 | 49 | 50 | 51 | return x 52 | 53 | 54 | class crn(nn.Module): 55 | def __init__(self, config, output_num, azi_size): 56 | super(crn, self).__init__() 57 | 58 | 59 | self.output_num=output_num 60 | self.azi_size=azi_size 61 | 62 | 63 | self.cnn_num=config['CNN']['layer_num'] 64 | self.kernel_size=config['CNN']['kernel_size'] 65 | self.filter_size=config['CNN']['filter'] 66 | 67 | self.max_pool_kernel=config['CNN']['max_pool']['kernel_size'] 68 | self.max_pool_stride=config['CNN']['max_pool']['stride'] 69 | 70 | args=[2*(config['input_audio_channel']-1),self.filter_size,self.kernel_size] # in_channel, out_channel, kernel size 71 | 72 | kwargs={'stride': 1, 'padding': [1,2], 'dilation': 1} 73 | 74 | 75 | 76 | 77 | self.cnn=nn.ModuleList() 78 | self.pooling=nn.ModuleList() 79 | self.cnn.append(Causal_Conv2D_Block(*args, **kwargs)) 80 | self.pooling.append(nn.MaxPool2d(self.max_pool_kernel, stride=self.max_pool_stride)) 81 | 82 | args[0]=config['CNN']['filter'] 83 | for count in range(self.cnn_num-1): 84 | self.cnn.append(Causal_Conv2D_Block(*args, **kwargs)) 85 | self.pooling.append(nn.MaxPool2d(self.max_pool_kernel, stride=self.max_pool_stride)) 86 | 87 | self.GRU_layer=nn.GRU(**config['GRU']) 88 | self.h0=torch.zeros(*config['GRU_init']['shape']) 89 | self.h0=torch.nn.parameter.Parameter(self.h0, requires_grad=config['GRU_init']['learnable']) 90 | 91 | 92 | self.azi_mapping_conv_layer=nn.ModuleList() 93 | self.azi_mapping_final=nn.ModuleList() 94 | 95 | args[0]=config['GRU']['hidden_size'] 96 | args[1]=config['GRU']['hidden_size'] 97 | args[2]=1 98 | kwargs['padding']=0 99 | 100 | for _ in range(output_num): 101 | self.azi_mapping_conv_layer.append(Conv1D_Block(*args, **kwargs)) 102 | self.azi_mapping_final.append(nn.Conv1d(config['GRU']['hidden_size'], self.azi_size, 1)) 103 | 104 | 105 | 106 | 107 | 108 | 109 | def forward(self, x): 110 | 111 | for cnn_layer, pooling_layer in zip(self.cnn, self.pooling): 112 | 113 | x=cnn_layer(x)[...,:x.shape[-1]] 114 | x=pooling_layer(x) 115 | 116 | 117 | 118 | b, c, f, t=x.shape 119 | x=x.view(b, -1, t).permute(0,2,1) 120 | 121 | 122 | h0=self.h0.repeat_interleave(x.shape[0]) 123 | self.GRU_layer.flatten_parameters() 124 | 125 | h0=h0.view(self.h0.shape[0], x.shape[0], self.h0.shape[-1]) 126 | 127 | x, h=self.GRU_layer(x, h0) 128 | 129 | x=x.permute(0,2,1) 130 | 131 | outputs=[] 132 | 133 | for final_layer, cnn_layer in zip(self.azi_mapping_final, self.azi_mapping_conv_layer): 134 | x=cnn_layer(x) 135 | res_output=final_layer(x) 136 | outputs.append(res_output) 137 | output=torch.stack(outputs).permute(1,0,2,3) 138 | 139 | return output 140 | 141 | 142 | class main_model(nn.Module): 143 | def __init__(self, config): 144 | super(main_model, self).__init__() 145 | self.config=config 146 | 147 | self.eps=np.finfo(np.float32).eps 148 | self.ref_ch=self.config['ref_ch'] 149 | 150 | ###### sigma 151 | 152 | self.p=torch.tensor(self.config['p']) 153 | self.sigma=torch.tensor(self.config['sigma_start']) 154 | self.sigma_max=torch.tensor(self.config['sigma_end']['max']) 155 | self.sigma_min=torch.tensor(self.config['sigma_end']['min']) 156 | self.sigma_rate=torch.tensor(self.config['sigma_rate']) 157 | self.sigma_udpate_method=self.config['sigma_update_method'] 158 | 159 | self.iteration_count=0 160 | self.epoch_count=0 161 | self.now_epoch=0 162 | 163 | 164 | ###### 165 | 166 | 
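# azimuth grid: 360 degrees quantized by degree_resolution into azi_size bins,
# one map per CRN output head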
self.max_spk=self.config['max_spk'] 167 | self.degree_resolution = self.config['degree_resolution'] 168 | self.azi_size=360//self.degree_resolution 169 | 170 | self.stft_model=ConvSTFT(**self.config['FFT']) 171 | self.crn=crn(self.config['CRN'], self.sigma.shape[0], self.azi_size) 172 | 173 | 174 | def irtf_featue(self, x,): 175 | r, i, =self.stft_model(x, cplx=True) 176 | 177 | 178 | 179 | comp = torch.complex(r, i) 180 | 181 | comp_ref = comp[..., [self.ref_ch], :, :] 182 | comp_ref = torch.complex( 183 | comp_ref.real.clamp(self.eps), comp_ref.imag.clamp(self.eps) 184 | ) 185 | 186 | comp=torch.cat( 187 | (comp[..., self.ref_ch-1:self.ref_ch, :, :], comp[..., self.ref_ch+1:, :, :]), 188 | dim=-3) / comp_ref 189 | x=torch.cat((comp.real, comp.imag), dim=1) 190 | 191 | return x 192 | 193 | 194 | 195 | def forward(self, x,): 196 | 197 | ###### irtf feature 198 | x=self.irtf_featue(x) 199 | 200 | x=self.crn(x).sigmoid() 201 | 202 | 203 | 204 | return x 205 | 206 | -------------------------------------------------------------------------------- /fourth_year/src/models/FSPEN/FFT.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import math 5 | import torch as th 6 | import torch 7 | import numpy as np 8 | from scipy.signal import get_window 9 | import matplotlib.pyplot as plt 10 | 11 | EPSILON = th.finfo(th.float32).eps 12 | MATH_PI = math.pi 13 | def init_kernels(win_len, 14 | win_inc, 15 | fft_len, 16 | win_type=None, 17 | invers=False, 18 | sqrt_window=False): 19 | if win_type == 'None' or win_type is None: 20 | # N 21 | window = np.ones(win_len) 22 | else: 23 | # N 24 | window = get_window(win_type, win_len, fftbins=True)#**0.5 25 | 26 | 27 | 28 | if sqrt_window: 29 | window = np.sqrt(window) 30 | 31 | N = fft_len 32 | # N x F 33 | fourier_basis = np.fft.rfft(np.eye(N))[:win_len] 34 | # N x F 35 | real_kernel = np.real(fourier_basis) 36 | imag_kernel = np.imag(fourier_basis) 37 | # 2F x N 38 | kernel = np.concatenate([real_kernel, imag_kernel], 1).T 39 | if invers : 40 | kernel = np.linalg.pinv(kernel).T 41 | 42 | # 2F x N * N => 2F x N 43 | kernel = kernel*window 44 | # 2F x 1 x N 45 | kernel = kernel[:, None, :] 46 | return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32)) 47 | 48 | 49 | class ConvSTFT(nn.Module): 50 | 51 | def __init__(self, 52 | win_len, 53 | win_inc, 54 | fft_len=None, 55 | vad_threshold=2/3, 56 | win_type='hamming', 57 | sqrt_window=False 58 | # fix=True 59 | ): 60 | super(ConvSTFT, self).__init__() 61 | 62 | if fft_len == None: 63 | self.fft_len = np.int(2**np.ceil(np.log2(win_len))) 64 | else: 65 | self.fft_len = fft_len 66 | 67 | # 2F x 1 x N 68 | kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type, sqrt_window=sqrt_window) 69 | get_target_kernel=torch.ones_like(kernel).mean(dim=0, keepdim=True)/kernel.shape[-1] 70 | # print(kernel.shape, get_target_kernel.shape) 71 | # exit() 72 | self.register_buffer('weight', kernel) 73 | self.register_buffer('target_kernel', get_target_kernel) 74 | 75 | 76 | self.stride = win_inc 77 | self.win_len = win_len 78 | self.dim = self.fft_len 79 | 80 | def forward(self, inputs, cplx=False): 81 | 82 | 83 | if inputs.dim() == 2: 84 | # N x 1 x L 85 | inputs = torch.unsqueeze(inputs, 1) 86 | inputs = F.pad(inputs,[self.win_len-self.stride, self.win_len-self.stride]) 87 | # N x 2F x T 88 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 89 | # 
N x F x T 90 | r, i = th.chunk(outputs, 2, dim=1) 91 | else: 92 | 93 | N, C, L = inputs.shape 94 | inputs = inputs.view(N * C, 1, L) 95 | # NC x 1 x L 96 | inputs = F.pad(inputs, [self.win_len-self.stride, self.win_len-self.stride]) 97 | 98 | 99 | 100 | 101 | # NC x 2F x T 102 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 103 | 104 | # N x C x 2F x T 105 | outputs = outputs.view(N, C, -1, outputs.shape[-1]) 106 | 107 | # N x C x F x T 108 | r, i = th.chunk(outputs, 2, dim=2) 109 | 110 | 111 | if cplx: 112 | return r, i 113 | else: 114 | # mags = th.clamp(r**2 + i**2, EPSILON)**0.5 115 | # phase = th.atan2(i+EPSILON, r+EPSILON) 116 | # return mags, phase, r, i 117 | 118 | mags=torch.pow(r**2+i**2, 0.5) 119 | mags=th.clamp(mags, 1e-12) 120 | 121 | phase=torch.atan2(i, r) 122 | return mags, phase, r, i 123 | 124 | 125 | def get_target_frame(self, inputs, threshold=0.5): 126 | inputs = torch.unsqueeze(inputs, 1) 127 | 128 | inputs = F.pad(inputs,[self.win_len-self.stride, self.win_len-self.stride]) 129 | # N x 2F x T 130 | outputs = F.conv1d(inputs, self.target_kernel, stride=self.stride).squeeze(1) 131 | # plt.plot(outputs[0,0].detach().cpu().numpy()) 132 | # plt.plot(outputs[1,0].detach().cpu().numpy()) 133 | # plt.plot(outputs[2,0].detach().cpu().numpy()) 134 | # plt.plot(outputs[3,0].detach().cpu().numpy()) 135 | # plt.savefig('../results/pngs/frames.png') 136 | # print(outputs.shape) 137 | # exit() 138 | outputs=torch.ge(outputs, threshold).float() 139 | 140 | # outputs=outputs.squeeze().detach().cpu().numpy() 141 | return outputs 142 | class ConviSTFT(nn.Module): 143 | 144 | def __init__(self, 145 | win_len, 146 | win_inc, 147 | fft_len=None, 148 | win_type='hamming', 149 | sqrt_window=False 150 | # fix=True 151 | ): 152 | super(ConviSTFT, self).__init__() 153 | if fft_len == None: 154 | self.fft_len = np.int(2**np.ceil(np.log2(win_len))) 155 | else: 156 | self.fft_len = fft_len 157 | 158 | # kernel: 2F x 1 x N 159 | # window: 1 x N x 1 160 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True, sqrt_window=sqrt_window) 161 | #self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 162 | self.register_buffer('weight', kernel) 163 | self.win_type = win_type 164 | self.win_len = win_len 165 | self.stride = win_inc 166 | self.stride = win_inc 167 | self.dim = self.fft_len 168 | self.register_buffer('window', window) 169 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:]) 170 | 171 | def forward(self, inputs, phase, cplx=False): 172 | """ 173 | inputs : [B, N//2+1, T] (mags, real) 174 | phase: [B, N//2+1, T] (phase, imag) 175 | """ 176 | 177 | if cplx: 178 | # N x 2F x T 179 | cspec = torch.cat([inputs, phase], dim=1) 180 | else: 181 | # N x F x T 182 | real = inputs*torch.cos(phase) 183 | imag = inputs*torch.sin(phase) 184 | # N x 2F x T 185 | cspec = torch.cat([real, imag], dim=1) 186 | # N x 1 x L 187 | outputs = F.conv_transpose1d(cspec, self.weight, stride=self.stride) 188 | 189 | # this is from torch-stft: https://github.com/pseeth/torch-stft 190 | # 1 x N x T 191 | t = self.window.repeat(1,1,inputs.size(-1))**2 192 | # 1 x 1 x L 193 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 194 | 195 | outputs = outputs/(coff+1e-12) 196 | 197 | #outputs = torch.where(coff == 0, outputs, outputs/coff) 198 | # N x 1 x L 199 | 200 | outputs = outputs[...,self.win_len-self.stride:] 201 | # N x L 202 | outputs = outputs.squeeze(1) 203 | return outputs 
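# A minimal round-trip check for the ConvSTFT/ConviSTFT pair above (a sketch;
# the 512/256 hamming settings are assumptions, not values from the configs):
#
#     stft = ConvSTFT(win_len=512, win_inc=256, fft_len=512, win_type='hamming')
#     istft = ConviSTFT(win_len=512, win_inc=256, fft_len=512, win_type='hamming')
#     x = torch.randn(2, 16000)                      # two 1 s mono signals at 16 kHz
#     mags, phase, r, i = stft(x, cplx=False)
#     y = istft(mags, phase, cplx=False)[:, :x.shape[-1]]
#     print((x - y).abs().max())                     # should be small away from the edges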
--------------------------------------------------------------------------------
/second_year/src/models/convtasnet_SSL_FiLM/Causal_CRN_SPL_target/FFT.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 | import math
5 | import torch as th
6 |
7 | import numpy as np
8 | from scipy.signal import get_window
9 |
10 | EPSILON = th.finfo(th.float32).eps
11 | MATH_PI = math.pi
12 | def init_kernels(win_len,
13 |                  win_inc,
14 |                  fft_len,
15 |                  win_type=None,
16 |                  invers=False):
17 |     if win_type == 'None' or win_type is None:
18 |         # N
19 |         window = np.ones(win_len)
20 |     else:
21 |         # N
22 |         window = get_window(win_type, win_len, fftbins=True)#**0.5
23 |     N = fft_len
24 |     # N x F
25 |     fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
26 |     # N x F
27 |     real_kernel = np.real(fourier_basis)
28 |     imag_kernel = np.imag(fourier_basis)
29 |     # 2F x N
30 |     kernel = np.concatenate([real_kernel, imag_kernel], 1).T
31 |     if invers:
32 |         kernel = np.linalg.pinv(kernel).T
33 |
34 |     # 2F x N * N => 2F x N
35 |     kernel = kernel*window
36 |     # 2F x 1 x N
37 |     kernel = kernel[:, None, :]
38 |     return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None,:,None].astype(np.float32))
39 |
40 |
41 | class ConvSTFT(nn.Module):
42 |
43 |     def __init__(self,
44 |                  win_len,
45 |                  win_inc,
46 |                  fft_len=None,
47 |                  vad_threshold=2/3,
48 |                  win_type='hamming',
49 |                  # fix=True
50 |                  ):
51 |         super(ConvSTFT, self).__init__()
52 |
53 |         if fft_len is None:
54 |             self.fft_len = int(2**np.ceil(np.log2(win_len)))
55 |         else:
56 |             self.fft_len = fft_len
57 |
58 |         # 2F x 1 x N
59 |         kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
60 |         vad_kernel = torch.ones((1, 1, self.fft_len), dtype=torch.float32)/self.fft_len
61 |         self.register_buffer('vad_kernel', vad_kernel)
62 |
63 |         self.register_buffer('weight', kernel)
64 |         self.vad_threshold = vad_threshold
65 |
66 |         self.stride = win_inc
67 |         self.win_len = win_len
68 |         self.dim = self.fft_len
69 |
70 |     def azimuth_strided(self, vad, azi):
71 |         B, spk_num, T = azi.shape
72 |
73 |         azi = azi.view(spk_num*B, T)
74 |
75 |
76 |
77 |
78 |         result = []
79 |
80 |         for frame_count in range(vad.shape[-1]):
81 |
82 |
83 |
84 |             if frame_count == 0:
85 |                 now_azi = azi[:, :self.stride]
86 |                 now_azi = now_azi.float().mean(dim=-1)
87 |                 result.append(now_azi)
88 |
89 |                 continue
90 |             elif frame_count == (vad.shape[-1]-1):
91 |                 now_azi = azi[:, self.stride*frame_count:]
92 |                 now_azi = now_azi.float().mean(dim=-1)
93 |                 result.append(now_azi)
94 |
95 |             else:
96 |                 now_azi = azi[:, self.stride*frame_count:self.stride*frame_count+self.win_len]
97 |                 now_azi = now_azi.float().mean(dim=-1)
98 |                 result.append(now_azi)
99 |
100 |         azi = torch.round(torch.stack(result, dim=-1))
101 |         return azi
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |     def forward(self, inputs, cplx=True):
114 |
115 |
116 |         if inputs.dim() == 2:
117 |             # N x 1 x L
118 |             inputs = torch.unsqueeze(inputs, 1)
119 |             inputs = F.pad(inputs, [self.win_len-self.stride, self.win_len-self.stride])
120 |             # N x 2F x T
121 |             outputs = F.conv1d(inputs, self.weight, stride=self.stride)
122 |             # N x F x T
123 |             r, i = th.chunk(outputs, 2, dim=1)
124 |         else:
125 |
126 |             N, C, L = inputs.shape
127
| inputs = inputs.view(N * C, 1, L) 128 | # NC x 1 x L 129 | inputs = F.pad(inputs, [self.win_len-self.stride, self.win_len-self.stride]) 130 | 131 | # print(inputs.shape) 132 | # exit() 133 | 134 | 135 | # NC x 2F x T 136 | outputs = F.conv1d(inputs, self.weight, stride=self.stride) 137 | 138 | # N x C x 2F x T 139 | outputs = outputs.view(N, C, -1, outputs.shape[-1]) 140 | 141 | # N x C x F x T 142 | r, i = th.chunk(outputs, 2, dim=2) 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | if cplx: 153 | return r, i 154 | else: 155 | mags = th.clamp(r**2 + i**2, EPSILON)**0.5 156 | phase = th.atan2(i+EPSILON, r+EPSILON) 157 | return mags, phase 158 | 159 | class ConviSTFT(nn.Module): 160 | 161 | def __init__(self, 162 | win_len, 163 | win_inc, 164 | fft_len=None, 165 | win_type='hamming', 166 | # fix=True 167 | ): 168 | super(ConviSTFT, self).__init__() 169 | if fft_len == None: 170 | self.fft_len = np.int(2**np.ceil(np.log2(win_len))) 171 | else: 172 | self.fft_len = fft_len 173 | 174 | # kernel: 2F x 1 x N 175 | # window: 1 x N x 1 176 | kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True) 177 | #self.weight = nn.Parameter(kernel, requires_grad=(not fix)) 178 | self.register_buffer('weight', kernel) 179 | self.win_type = win_type 180 | self.win_len = win_len 181 | self.stride = win_inc 182 | self.stride = win_inc 183 | self.dim = self.fft_len 184 | self.register_buffer('window', window) 185 | self.register_buffer('enframe', torch.eye(win_len)[:,None,:]) 186 | 187 | def forward(self, inputs, phase, cplx=False): 188 | """ 189 | inputs : [B, N//2+1, T] (mags, real) 190 | phase: [B, N//2+1, T] (phase, imag) 191 | """ 192 | 193 | if cplx: 194 | # N x 2F x T 195 | cspec = torch.cat([inputs, phase], dim=1) 196 | else: 197 | # N x F x T 198 | real = inputs*torch.cos(phase) 199 | imag = inputs*torch.sin(phase) 200 | # N x 2F x T 201 | cspec = torch.cat([real, imag], dim=1) 202 | # N x 1 x L 203 | outputs = F.conv_transpose1d(cspec, self.weight, stride=self.stride) 204 | 205 | # this is from torch-stft: https://github.com/pseeth/torch-stft 206 | # 1 x N x T 207 | t = self.window.repeat(1,1,inputs.size(-1))**2 208 | # 1 x 1 x L 209 | coff = F.conv_transpose1d(t, self.enframe, stride=self.stride) 210 | 211 | outputs = outputs/(coff+1e-8) 212 | 213 | #outputs = torch.where(coff == 0, outputs, outputs/coff) 214 | # N x 1 x L 215 | 216 | outputs = outputs[...,self.win_len-self.stride:] 217 | # N x L 218 | outputs = outputs.squeeze(1) 219 | return outputs -------------------------------------------------------------------------------- /third_year/src/inference.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | import util 4 | import torch 5 | 6 | import numpy as np 7 | import random 8 | import importlib 9 | 10 | from tqdm import tqdm 11 | from dataloader.data_loader import IITP_test_dataload 12 | 13 | import pandas as pd 14 | 15 | from asteroid.losses.sdr import SingleSrcNegSDR 16 | 17 | 18 | class Hyparam_set(): 19 | 20 | def __init__(self, args): 21 | self.args=args 22 | 23 | 24 | def set_torch_method(self,): 25 | try: 26 | torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=False) # spawn 27 | except: 28 | torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=True) # spawn 29 | 30 | 31 | def randomseed_init(self,): 32 | np.random.seed(self.args['hyparam']['randomseed']) 33 | random.seed(self.args['hyparam']['randomseed']) 34 | 
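# torch (and, below, CUDA) get the same seed as numpy/random above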
torch.manual_seed(self.args['hyparam']['randomseed']) 35 | if torch.cuda.is_available(): 36 | torch.cuda.manual_seed(self.args['hyparam']['randomseed']) 37 | 38 | device_primary_num=self.args['hyparam']['GPGPU']['device_ids'][0] 39 | device= 'cuda'+':'+str(device_primary_num) 40 | else: 41 | device= 'cpu' 42 | self.args['hyparam']['GPGPU']['device']=device 43 | return device 44 | def set_on(self): 45 | self.set_torch_method() 46 | self.device=self.randomseed_init() 47 | 48 | return self.args 49 | 50 | class Learner_config(): 51 | def __init__(self, args) -> None: 52 | self.args=args 53 | 54 | 55 | 56 | 57 | def memory_delete(self, *args): 58 | for a in args: 59 | del a 60 | 61 | def model_select(self): 62 | model_name=self.args['model']['name'] 63 | model_import='models.'+model_name+'.main' 64 | 65 | 66 | model_dir=importlib.import_module(model_import) 67 | 68 | self.model=model_dir.get_model(self.args['model']).to(self.device) 69 | 70 | trained=torch.load(self.args['hyparam']['model'], map_location=self.device) 71 | self.model.load_state_dict(trained['model_state_dict'], ) 72 | self.model=torch.nn.DataParallel(self.model, self.args['hyparam']['GPGPU']['device_ids']) 73 | 74 | 75 | 76 | 77 | def config(self): 78 | self.device=self.args['hyparam']['GPGPU']['device'] 79 | self.model_select() 80 | 81 | 82 | return self.args 83 | 84 | class Logger_config(): 85 | def __init__(self, args) -> None: 86 | self.args=args 87 | self.result_folder=self.args['hyparam']['result_folder'] 88 | 89 | 90 | 91 | 92 | 93 | 94 | def save_output(self, DB_type): 95 | try: 96 | now_dict=self.save_config_dict[DB_type] 97 | except: 98 | now_dict=self.save_config_dict[int(DB_type)] 99 | DB_type=int(DB_type) 100 | 101 | with open(self.result_folder['inference_folder']+'/'+DB_type+'/result.txt', 'w') as f: 102 | 103 | f.write('si_sdr\n\n') 104 | k=(now_dict['si_sdr']/now_dict['num']) 105 | for j in k: 106 | 107 | j=str(j) 108 | f.write(j) 109 | f.write('\n') 110 | df=pd.DataFrame(self.csv_dict[DB_type]) 111 | df.to_csv(self.result_folder['inference_folder']+'/'+DB_type+'/result.csv') 112 | 113 | 114 | 115 | 116 | 117 | def error_update(self, DB_type, si_sdr, pkl_name, num=1): 118 | now_dict=self.save_config_dict[DB_type] 119 | 120 | now_dict['si_sdr']+=si_sdr 121 | now_dict['num']+=num 122 | 123 | self.save_config_dict[DB_type]=now_dict 124 | self.csv_dict[DB_type]['pkl_name'].append(pkl_name[0]) 125 | self.csv_dict[DB_type]['si_sdr'].append(si_sdr.detach().cpu().numpy()[0]) 126 | 127 | 128 | 129 | def config(self,): 130 | 131 | 132 | self.csv_dict=dict() 133 | 134 | self.save_config_dict=dict() 135 | 136 | metric_data={} 137 | metric_data['si_sdr']=0 138 | metric_data['num']=0 139 | self.pandas_df=dict() 140 | self.pandas_df['pkl_name']=[] 141 | self.pandas_df['si_sdr']=[] 142 | 143 | 144 | 145 | 146 | metric_data['number_of_degrees']=0 147 | os.makedirs('../results/', exist_ok=True) 148 | 149 | 150 | 151 | 152 | return self.args 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | class Dataloader_config(): 161 | def __init__(self, args) -> None: 162 | self.args=args 163 | 164 | 165 | def config(self): 166 | self.test_loader=IITP_test_dataload(self.args['dataloader']['test']) 167 | 168 | 169 | return self.args 170 | 171 | 172 | 173 | class Tester(): 174 | 175 | def __init__(self, args): 176 | 177 | 178 | self.args=args 179 | 180 | self.hyperparameter=Hyparam_set(self.args) 181 | self.args=self.hyperparameter.set_on() 182 | 183 | self.learner=Learner_config(self.args) 184 | self.args=self.learner.config() 185 | 
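        # keep a direct handle to the DataParallel-wrapped model built in Learner_config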
        self.model = self.learner.model
186 |
187 |
188 |         self.dataloader = Dataloader_config(self.args)
189 |         self.args = self.dataloader.config()
190 |
191 |         self.logger = Logger_config(self.args)
192 |         self.args = self.logger.config()
193 |
194 |
195 |
196 |     def run(self, ):
197 |
198 |
199 |
200 |
201 |         self.test()
202 |
203 |
204 |
205 |     def test(self, ):
206 |         self.model.eval()
207 |
208 |
209 |         metric_func = SingleSrcNegSDR(reduction='none', zero_mean=True, take_log=True, sdr_type='sisdr')
210 |         sisdr_list = []
211 |
212 |         mic_type = self.args['dataloader']['test']['mic_type']
213 |
214 |         audio_save_dir = '../results/circle_4_result/'
215 |         os.makedirs(audio_save_dir, exist_ok=True)
216 |
217 |         for room_type in tqdm(self.args['hyparam']['result_folder']['room_type'], desc='room, mic_type: '+mic_type):
218 |             room_type = str(room_type)
219 |             self.dataloader.test_loader.dataset.room_type = str(room_type)
220 |             self.dataloader.test_loader.dataset.pkl_list = os.listdir(self.dataloader.test_loader.dataset.pkl_dir+room_type)
221 |
222 |             with torch.no_grad():
223 |                 for iter_num, (mixed, target_wav, pkl_name) in enumerate(self.dataloader.test_loader):
224 |
225 |
226 |
227 |
228 |
229 |                     mixed = mixed.to(self.hyperparameter.device)
230 |                     target_wav = target_wav.to(self.hyperparameter.device)
231 |
232 |
233 |
234 |
235 |                     out, beamforming_weight_real, beamforming_weight_imag = self.model(mixed)
236 |
237 |
238 |
239 |
240 |
241 |                     si_sdr = -metric_func(out, target_wav)
242 |                     si_sdr = si_sdr.detach().cpu().numpy()[0]
243 |                     sisdr_list.append(si_sdr)
244 |
245 |
246 |
247 |                     self.learner.memory_delete([mixed, target_wav, out, si_sdr])
248 |
249 |         df = pd.DataFrame(sisdr_list)
250 |         df.to_csv(audio_save_dir+'si_sdr_result.csv')  # persist per-utterance scores (file name assumed)
251 |
252 |
253 |
254 |
255 | if __name__=='__main__':
256 |     args = sys.argv[1:]
257 |
258 |     args = util.util.get_yaml_args(args)
259 |
260 |     t = Tester(args)
261 |     t.run()
--------------------------------------------------------------------------------
/second_year/src/inference.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 |
3 | import util
4 | import torch
5 |
6 | import numpy as np
7 | import random
8 | import importlib
9 |
10 | from tqdm import tqdm
11 | from dataloader.data_loader import IITP_test_dataload
12 | import pandas as pd
13 | import soundfile as sf
14 |
15 |
16 | class Hyparam_set():
17 |
18 |     def __init__(self, args):
19 |         self.args = args
20 |
21 |
22 |     def set_torch_method(self,):
23 |         try:
24 |             torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=False) # spawn
25 |         except:
26 |             torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=True) # spawn
27 |
28 |
29 |     def randomseed_init(self,):
30 |         np.random.seed(self.args['hyparam']['randomseed'])
31 |         random.seed(self.args['hyparam']['randomseed'])
32 |         torch.manual_seed(self.args['hyparam']['randomseed'])
33 |
34 |         if torch.cuda.is_available():
35 |             torch.cuda.manual_seed(self.args['hyparam']['randomseed'])
36 |
37 |             device_primary_num = self.args['hyparam']['GPGPU']['device_ids'][0]
38 |             device = 'cuda'+':'+str(device_primary_num)
39 |         else:
40 |             device = 'cpu'
41 |         self.args['hyparam']['GPGPU']['device'] = device
42 |         return device
43 |     def set_on(self):
44 |         self.set_torch_method()
45 |         self.device = self.randomseed_init()
46 |
47 |         return self.args
48 |
49 | class Learner_config():
50 |     def __init__(self, args) -> None:
51 |         self.args = args
52 |
53 |
54 |
55 |
56 |     def memory_delete(self, 
*args):
57 |         for a in args:
58 |             del a
59 |
60 |     def model_select(self):
61 |         model_name = self.args['model']['name']
62 |         model_import = 'models.'+model_name+'.main'
63 |
64 |
65 |         model_dir = importlib.import_module(model_import)
66 |
67 |         self.model = model_dir.get_model(self.args['model']).to(self.device)
68 |
69 |         trained = torch.load(self.args['hyparam']['model'], map_location=self.device)
70 |         self.model.load_state_dict(trained['model_state_dict'], )
71 |         self.model = torch.nn.DataParallel(self.model, self.args['hyparam']['GPGPU']['device_ids'])
72 |
73 |
74 |     def init_loss_func(self):
75 |
76 |
77 |
78 |         if self.args['learner']['loss']['type']=='weighted_bce':
79 |             from loss.bce_loss import weighted_binary_cross_entropy
80 |             self.loss_func = weighted_binary_cross_entropy(**self.args['learner']['loss']['option'])
81 |         elif self.args['learner']['loss']['type']=='BCEWithLogitsLoss':
82 |             # update() applies sigmoid() to the network output itself, so plain BCELoss is the right choice here
83 |             self.loss_func = torch.nn.modules.loss.BCELoss(reduction='none')
84 |
85 |         self.loss_train_map_num = self.args['learner']['loss']['option']['train_map_num']
86 |
87 |     def update(self, output, target):
88 |
89 |
90 |         target = target[:, self.loss_train_map_num]
91 |         output = output[:, self.loss_train_map_num].sigmoid()
92 |
93 |         loss = self.loss_func(output, target)
94 |
95 |
96 |         loss_mean = loss.mean()
97 |
98 |
99 |         return loss_mean
100 |
101 |     def config(self):
102 |         self.device = self.args['hyparam']['GPGPU']['device']
103 |         self.model_select()
104 |         self.init_loss_func()
105 |
106 |         return self.args
107 |
108 | class Logger_config():
109 |     def __init__(self, args) -> None:
110 |         self.args = args
111 |         self.result_folder = self.args['hyparam']['result_folder']
112 |
113 |         self.wav_save = self.args['hyparam']['wav_save']
114 |         self.wav_folder = self.args['hyparam']['wav_folder']
115 |
116 |
117 |
118 |
119 |
120 |
121 |     def save_output(self, DB_type):
122 |         try:
123 |             now_dict = self.save_config_dict[DB_type]
124 |         except:
125 |             now_dict = self.save_config_dict[int(DB_type)]
126 |             DB_type = int(DB_type)
127 |
128 |         with open(self.result_folder['inference_folder']+'/'+str(DB_type)+'/result.txt', 'w') as f:
129 |             save_folder = self.result_folder['inference_folder']+str(DB_type)+'/'
130 |
131 |             pd.DataFrame(now_dict).to_csv(save_folder+'result.csv')
132 |             f.write('\nSI-SDR\n\n')
133 |             j = np.array(now_dict['SI-SDR']).mean()
134 |             f.write(str(j))
135 |             f.write('\n')
136 |
137 |
138 |     def save_wav(self, DB_type, audio_list, iter_num):
139 |         save_folder = self.result_folder['inference_folder']+str(DB_type)+'/wav/'+str(iter_num)+'/'
140 |         os.makedirs(save_folder, exist_ok=True)
141 |         sf.write(save_folder+'noisy.wav', audio_list[0].cpu().numpy(), 16000)
142 |
143 |         sf.write(save_folder+'clean.wav', audio_list[1].cpu().numpy(), 16000)
144 |         sf.write(save_folder+'estimate.wav', audio_list[2].cpu().numpy(), 16000)
145 |
146 |
147 |     def error_update(self, DB_type, si_sdr, iter_num, audio_list):
148 |         now_dict = self.save_config_dict[DB_type]
149 |         now_dict['file_num'].append(iter_num)
150 |         now_dict['SI-SDR'].append(si_sdr.cpu().item())
151 |         self.save_config_dict[DB_type] = now_dict
152 |
153 |         if self.wav_save:
154 |             self.save_wav(DB_type, audio_list, iter_num)
155 |
156 |
157 |
158 |
159 |     def config(self,):
160 |         from copy import deepcopy
161 |
162 |         self.save_config_dict = dict()
163 |
164 |         metric_data = {}
165 |         metric_data['file_num'] = []
166 |         metric_data['SI-SDR'] = []
167 |
168 |
169 |
170 |         for room_type in self.result_folder['room_type']:
171 | 
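            # one result folder per room type, each with a wav/ subfolder for the saved audio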
os.makedirs(self.result_folder['inference_folder']+room_type, exist_ok=True) 172 | os.makedirs(self.result_folder['inference_folder']+room_type+'/wav', exist_ok=True) 173 | 174 | self.save_config_dict[room_type]=deepcopy(metric_data) 175 | 176 | 177 | 178 | 179 | return self.args 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | class Dataloader_config(): 188 | def __init__(self, args) -> None: 189 | self.args=args 190 | 191 | 192 | def config(self): 193 | self.test_loader=IITP_test_dataload(self.args['dataloader']['test']) 194 | 195 | 196 | return self.args 197 | 198 | 199 | 200 | class Tester(): 201 | 202 | def __init__(self, args): 203 | 204 | 205 | self.args=args 206 | 207 | self.hyperparameter=Hyparam_set(self.args) 208 | self.args=self.hyperparameter.set_on() 209 | 210 | self.learner=Learner_config(self.args) 211 | self.args=self.learner.config() 212 | self.model=self.learner.model 213 | 214 | 215 | self.dataloader=Dataloader_config(self.args) 216 | self.args=self.dataloader.config() 217 | 218 | self.logger=Logger_config(self.args) 219 | self.args=self.logger.config() 220 | 221 | 222 | 223 | def run(self, ): 224 | 225 | 226 | 227 | 228 | self.test() 229 | 230 | 231 | 232 | def test(self, ): 233 | self.model.eval() 234 | 235 | from asteroid.losses.sdr import SingleSrcNegSDR 236 | si_sdr_func=SingleSrcNegSDR('sisdr') 237 | 238 | 239 | 240 | 241 | for room_type in self.args['hyparam']['result_folder']['room_type']: 242 | room_type=str(room_type) 243 | print(room_type) 244 | self.dataloader.test_loader.dataset.room_type=str(room_type) 245 | 246 | with torch.no_grad(): 247 | for iter_num, (mixed, clean, SNR) in enumerate(tqdm(self.dataloader.test_loader, desc='Test', total=len(self.dataloader.test_loader), )): 248 | 249 | 250 | mixed=mixed.to(self.hyperparameter.device) 251 | 252 | clean=clean.to(self.hyperparameter.device) 253 | 254 | out=self.model(mixed) 255 | 256 | 257 | si_sdr=-si_sdr_func(out, clean) 258 | 259 | self.logger.error_update(room_type, si_sdr, iter_num, [mixed[0,0], clean[0], out[0]]) 260 | 261 | 262 | 263 | self.learner.memory_delete([mixed, clean, out, si_sdr]) 264 | 265 | self.logger.save_output(room_type) 266 | 267 | 268 | 269 | 270 | if __name__=='__main__': 271 | args=sys.argv[1:] 272 | 273 | args=util.util.get_yaml_args(args) 274 | 275 | t=Tester(args) 276 | t.run() -------------------------------------------------------------------------------- /third_year/src/util/util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | from datetime import datetime 7 | import matplotlib.pyplot as plt 8 | import shutil 9 | import select 10 | import sys 11 | 12 | 13 | def load_yaml(yaml_dir): 14 | yaml_file=open(yaml_dir, 'r') 15 | data=yaml.safe_load(yaml_file) 16 | yaml_file.close() 17 | return data 18 | 19 | 20 | 21 | def copy_folder(dir): 22 | shutil.copytree('./',dir+'src', ignore=shutil.ignore_patterns('./wandb/**','*.png', '*.wav')) 23 | 24 | 25 | def exp_mkdir(config): 26 | dir=config['exp']['result_dir'] 27 | 28 | # experiment directory 29 | dir=dir+config['exp']['name']+'/' 30 | os.makedirs(dir, exist_ok=True) 31 | exp_description_txt=dir+'exp_description.txt' 32 | 33 | if os.path.exists(exp_description_txt)==False: 34 | print('Please write exp Description') 35 | f=open(exp_description_txt, 'w') 36 | f.close() 37 | exit() 38 | exp_dir=dir 39 | 40 | # model directory 41 | dir=dir+config['model']['name']+'/' 42 | os.makedirs(dir, exist_ok=True) 43 | 
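# same guard as the experiment folder above: create placeholder description files,
# then exit so they can be filled in before the run is restarted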
model_description_txt=dir+'model_description.txt' 44 | if os.path.exists(model_description_txt)==False: 45 | print('Please write model Description') 46 | f=open(model_description_txt, 'w') 47 | f.close() 48 | f=open(dir+'model_structure.txt', 'w') 49 | f.close() 50 | f=open(dir+'model_summary.txt', 'w') 51 | f.close() 52 | exit() 53 | model_dir=dir 54 | 55 | # date directory 56 | 57 | if config['exp']['temp']==True: 58 | dir=dir+'temp/' 59 | else: 60 | dir=dir+'exp/' 61 | 62 | 63 | 64 | 65 | os.makedirs(dir, exist_ok=True) 66 | 67 | 68 | now_time=datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 69 | dir=dir+now_time+'/' 70 | os.makedirs(dir, exist_ok=True) 71 | 72 | os.makedirs(dir+'model_save', exist_ok=True) 73 | os.makedirs(dir+'result', exist_ok=True) 74 | 75 | f=open(dir+'train_status.txt', 'w') 76 | f.close() 77 | 78 | 79 | 80 | print('Please write date Description') 81 | f=open(dir+'description.txt', 'w') 82 | f.close() 83 | 84 | 85 | return exp_dir, model_dir, dir 86 | 87 | 88 | def get_yaml_args(yaml_list): 89 | # print(yaml_list) 90 | # exit() 91 | yaml_out={} 92 | for a in yaml_list: 93 | a=a.split(' ') 94 | yaml_out[a[0]]=load_yaml(a[1]) 95 | 96 | 97 | return yaml_out 98 | 99 | def randomseed_init(num): 100 | np.random.seed(num) 101 | random.seed(num) 102 | torch.manual_seed(num) 103 | # torch.Generator.manual_seed(num) 104 | if torch.cuda.is_available(): 105 | torch.cuda.manual_seed(num) 106 | return 'cuda' 107 | else: 108 | return 'cpu' 109 | 110 | def draw_result_pic(dir, epoch, train, val): 111 | fig1 = plt.figure(figsize=(7,4)) 112 | epo = np.arange(epoch+1) 113 | 114 | os.makedirs(os.path.dirname(dir), exist_ok=True) 115 | a2 = fig1.add_subplot(1, 1, 1) 116 | a2.plot(epo, train, epo,val) 117 | a2.set_title("Loss") 118 | a2.legend(['Train', 'Eval']) 119 | a2.set_ylabel('Loss') 120 | a2.set_xlabel('Epochs') 121 | a2.grid(axis='y', linestyle='dashed') 122 | fig1.tight_layout() 123 | fig1.savefig(dir, dpi=300) 124 | plt.close(fig1) 125 | 126 | 127 | 128 | def log_saving(record_dir, record_param, epoch, model, writer, optimizer, result_text,restart_num, end=False, temp=False): 129 | write_file=open(result_text, 'a') 130 | # print(epoch) 131 | print("\nAccuracy(train, eval) : %3.3f %3.3f" % ( 132 | record_param['accu_list_train'][epoch] * 100, record_param['accu_list_eval'][epoch] * 100)) 133 | write_file.write("\nAccuracy(train, eval) : %3.3f %3.3f \n" % ( 134 | record_param['accu_list_train'][epoch] * 100, record_param['accu_list_eval'][epoch] * 100)) 135 | 136 | print("Max accuracy(eval) : %d epoch, %3.3f\n" % ( 137 | record_param['accu_list_eval'].index(max(record_param['accu_list_eval'])) , 138 | max(record_param['accu_list_eval']) * 100)) 139 | write_file.write("Max accuracy(eval) : %d epoch, %3.3f\n\n" % ( 140 | record_param['accu_list_eval'].index(max(record_param['accu_list_eval'])), 141 | max(record_param['accu_list_eval']) * 100)) 142 | 143 | print("Loss(train, eval) : %3.3f %3.3f" % ( 144 | record_param['loss_list_train'][epoch], record_param['loss_list_eval'][epoch])) 145 | write_file.write("Loss(train, eval) : %3.3f %3.3f\n" % ( 146 | record_param['loss_list_train'][epoch], record_param['loss_list_eval'][epoch])) 147 | 148 | print("Min loss(eval): %d epoch, %3.3f\n" % ( 149 | record_param['loss_list_eval'].index(min(record_param['loss_list_eval'])) , min(record_param['loss_list_eval']))) 150 | write_file.write("Min loss(eval): %d epoch, %3.3f\n\n" % ( 151 | record_param['loss_list_eval'].index(min(record_param['loss_list_eval'])) , 
min(record_param['loss_list_eval'])))
152 |
153 |     write_file.close()
154 |
155 |     if temp == False:
156 |         writer.add_scalars('Accuracy', {'Accuracy/Train': record_param['accu_list_train'][-1], 'Accuracy/Val': record_param['accu_list_eval'][-1]}, epoch)
157 |         writer.add_scalars('Loss', {'Loss/Train': record_param['loss_list_train'][-1], 'Loss/Val': record_param['loss_list_eval'][-1]}, epoch)
158 |
159 |         if (epoch % 10 == 0) or end == True or epoch < 10:
160 |             for name, parameter in model.named_parameters():
161 |                 writer.add_histogram(name, parameter.clone().detach().cpu().data.numpy(), epoch)
162 |
163 |     for x in restart_num:
164 |         if epoch >= (x-6) and epoch <= (x-1):
165 |             torch.save({'epoch': epoch,
166 |                         'model_state_dict': model.state_dict(),
167 |                         'optimizer_state_dict': optimizer.state_dict(),
168 |                         'param': record_param
169 |                         }, record_dir + "/model_"+str(epoch)+".pth")
170 |
171 |     torch.save({'epoch': epoch,
172 |                 'model_state_dict': model.state_dict(),
173 |                 'optimizer_state_dict': optimizer.state_dict(),
174 |                 'param': record_param
175 |                 }, record_dir + "/current_model.pth")
176 |     if record_param['loss_list_eval'][-1] == min(record_param['loss_list_eval']):
177 |         torch.save({'epoch': epoch,
178 |                     'model_state_dict': model.state_dict(),
179 |                     'optimizer_state_dict': optimizer.state_dict(),
180 |                     'param': record_param
181 |                     }, record_dir + "/best_loss_model.pth")
182 |
183 |     if record_param['accu_list_eval'][-1] == max(record_param['accu_list_eval']):
184 |         torch.save({'epoch': epoch,
185 |                     'model_state_dict': model.state_dict(),
186 |                     'optimizer_state_dict': optimizer.state_dict(),
187 |                     'param': record_param
188 |                     }, record_dir + "/best_accu_model.pth")
189 |
190 |
191 |     fig1 = plt.figure(figsize=(7,4))
192 |     epo = np.arange(0, epoch + 1, 1)
193 |     a1 = fig1.add_subplot(2, 1, 1)
194 |     a1.plot(epo, np.array(record_param['accu_list_train']), epo, np.array(record_param['accu_list_eval']))
195 |     a1.set_title("Accuracy")
196 |     a1.legend(['Train', 'Eval'])
197 |     a1.set_xlabel('Epochs')
198 |     a1.set_ylabel('Accuracy')
199 |     a1.grid(axis='y', linestyle='dashed')
200 |
201 |     a2 = fig1.add_subplot(2, 1, 2)
202 |     a2.plot(epo, np.array(record_param['loss_list_train']), epo, np.array(record_param['loss_list_eval']))
203 |     a2.set_title("Loss")
204 |     a2.legend(['Train', 'Eval'])
205 |     a2.set_ylabel('Loss')
206 |     a2.set_xlabel('Epochs')
207 |     a2.grid(axis='y', linestyle='dashed')
208 |     fig1.tight_layout()
209 |     fig1.savefig(record_dir + '/accu_loss.png', dpi=300)
210 |     plt.close(fig1)
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
--------------------------------------------------------------------------------
/fourth_year/src/models/FSPEN/fspen.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn, Tensor 4 | from .modules.en_decoder import FullBandEncoderBlock, FullBandDecoderBlock 5 | from .modules.en_decoder import SubBandEncoderBlock, SubBandDecoderBlock 6 | from .modules.sequence_modules import DualPathExtensionRNN 7 | # from configs.train_configs import TrainConfig 8 | 9 | 10 | 11 | class FullBandEncoder(nn.Module): 12 | def __init__(self, configs): 13 | super().__init__() 14 | 15 | last_channels = 0 16 | self.full_band_encoder = nn.ModuleList() 17 | for encoder_name, conv_parameter in configs['full_band_encoder'].items(): 18 | self.full_band_encoder.append(FullBandEncoderBlock(**conv_parameter)) 19 | last_channels = conv_parameter["out_channels"] 20 | 21 | self.global_features = nn.Conv1d(in_channels=last_channels, out_channels=last_channels, kernel_size=1, stride=1) 22 | 23 | 24 | 25 | def forward(self, complex_spectrum: Tensor): 26 | """ 27 | :param complex_spectrum: (batch*frame, channels, frequency) 28 | :return: 29 | """ 30 | full_band_encodes = [] 31 | for encoder in self.full_band_encoder: 32 | complex_spectrum = encoder(complex_spectrum) 33 | full_band_encodes.append(complex_spectrum) 34 | 35 | global_feature = self.global_features(complex_spectrum) 36 | 37 | return full_band_encodes[::-1], global_feature 38 | 39 | 40 | class SubBandEncoder(nn.Module): 41 | def __init__(self, configs): 42 | super().__init__() 43 | 44 | self.sub_band_encoders = nn.ModuleList() 45 | for encoder_name, conv_parameters in configs['sub_band_encoder'].items(): 46 | self.sub_band_encoders.append(SubBandEncoderBlock(**conv_parameters["conv"])) 47 | 48 | def forward(self, amplitude_spectrum: Tensor): 49 | """ 50 | :param amplitude_spectrum: (batch * frames, channels, frequency) 51 | :return: 52 | """ 53 | sub_band_encodes = list() 54 | for encoder in self.sub_band_encoders: 55 | encode_out = encoder(amplitude_spectrum) 56 | sub_band_encodes.append(encode_out) 57 | 58 | local_feature = torch.cat(sub_band_encodes, dim=2) # feature cat 59 | 60 | return sub_band_encodes, local_feature 61 | 62 | 63 | class FullBandDecoder(nn.Module): 64 | def __init__(self, configs): 65 | super().__init__() 66 | self.full_band_decoders = nn.ModuleList() 67 | for decoder_name, parameters in configs["full_band_decoder"].items(): 68 | self.full_band_decoders.append( 69 | FullBandDecoderBlock(**parameters)) 70 | 71 | def forward(self, feature: Tensor, encode_outs: list): 72 | for decoder, encode_out in zip(self.full_band_decoders, encode_outs): 73 | feature = decoder(feature, encode_out) 74 | 75 | return feature 76 | 77 | 78 | class SubBandDecoder(nn.Module): 79 | def __init__(self, configs): 80 | super().__init__() 81 | start_idx = 0 82 | self.sub_band_decoders = nn.ModuleList() 83 | for (decoder_name, parameters), bands in zip(configs["sub_band_decoder"].items(), configs["bands_num_in_groups"]): 84 | end_idx = start_idx + bands 85 | self.sub_band_decoders.append(SubBandDecoderBlock(start_idx=start_idx, end_idx=end_idx, **parameters)) 86 | 87 | def forward(self, feature: Tensor, sub_encodes: list): 88 | """ 89 | :param feature: (batch*frames, channels, bands) 90 | :param sub_encodes: [sub_encode_0, sub_encode_1, ...], each element is (batch*frames, channels, sub_bands) 91 | :return: (batch*frames, full-frequency) 92 | """ 93 | sub_decoder_outs = [] 94 | for decoder, sub_encode in zip(self.sub_band_decoders, sub_encodes): 95 | sub_decoder_out = decoder(feature, sub_encode) 96 | 
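# collect each group's band estimate; the dim=1 concatenation below stitches the groups back into the full band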
sub_decoder_outs.append(sub_decoder_out) 97 | 98 | sub_decoder_outs = torch.cat(tensors=sub_decoder_outs, dim=1) # feature cat 99 | 100 | return sub_decoder_outs 101 | 102 | 103 | class FullSubPathExtension(nn.Module): 104 | def __init__(self, configs): 105 | super().__init__() 106 | self.full_band_encoder = FullBandEncoder(configs) 107 | self.sub_band_encoder = SubBandEncoder(configs) 108 | 109 | self.ch_size=configs['channel_size'] 110 | merge_split = configs["merge_split"] 111 | merge_channels = merge_split["channels"] 112 | merge_bands = merge_split["bands"] 113 | compress_rate = merge_split["compress_rate"] 114 | 115 | self.feature_merge_layer = nn.Sequential( 116 | nn.Linear(in_features=merge_channels, out_features=merge_channels//compress_rate), 117 | nn.ELU(), 118 | nn.Conv1d(in_channels=merge_bands, out_channels=merge_bands//compress_rate, kernel_size=1, stride=1) 119 | ) 120 | 121 | self.dual_path_extension_rnn_list = nn.ModuleList() 122 | for _ in range(configs["dual_path_extension"]["num_modules"]): 123 | self.dual_path_extension_rnn_list.append(DualPathExtensionRNN(**configs["dual_path_extension"]["parameters"])) 124 | 125 | self.feature_split_layer = nn.Sequential( 126 | nn.Conv1d(in_channels=merge_bands//compress_rate, out_channels=merge_bands, kernel_size=1, stride=1), 127 | nn.Linear(in_features=merge_channels//compress_rate, out_features=merge_channels), 128 | nn.ELU() 129 | ) 130 | 131 | self.full_band_decoder = FullBandDecoder(configs) 132 | self.sub_band_decoder = SubBandDecoder(configs) 133 | 134 | self.mask_padding = nn.ConstantPad2d(padding=(1, 0, 0, 0), value=0.0) 135 | 136 | 137 | self.complex_multi_to_single = nn.Linear(in_features=self.ch_size, out_features=1) 138 | self.complex_multi_to_single.weight.data.fill_(0.5) 139 | 140 | self.amp_multi_to_single = nn.Linear(in_features=self.ch_size, out_features=1) 141 | self.amp_multi_to_single.weight.data.fill_(0.5) 142 | 143 | 144 | def forward(self, in_complex_spectrum: Tensor, in_amplitude_spectrum: Tensor, hidden_state: list): 145 | """ 146 | :param in_amplitude_spectrum: (batch, frames, 1, frequency) 147 | :param hidden_state: 148 | :param in_complex_spectrum: (batch, frames, channels, frequency) 149 | :return: 150 | """ 151 | in_complex_spectrum_1ch=in_complex_spectrum[:,0] 152 | in_amplitude_spectrum_1ch=in_amplitude_spectrum[:,0] 153 | in_complex_spectrum = in_complex_spectrum.permute(0, 2,3,4,1) 154 | in_complex_spectrum=self.complex_multi_to_single(in_complex_spectrum)[...,0] 155 | in_amplitude_spectrum = in_amplitude_spectrum.permute(0, 2,3,4,1) 156 | in_amplitude_spectrum=self.amp_multi_to_single(in_amplitude_spectrum)[...,0] 157 | 158 | 159 | 160 | 161 | 162 | batch, frames, channels, frequency = in_complex_spectrum.shape 163 | complex_spectrum = torch.reshape(in_complex_spectrum, shape=(batch * frames, channels, frequency)) 164 | amplitude_spectrum = torch.reshape(in_amplitude_spectrum, shape=(batch*frames, 1, frequency)) 165 | 166 | full_band_encode_outs, global_feature = self.full_band_encoder(complex_spectrum) 167 | sub_band_encode_outs, local_feature = self.sub_band_encoder(amplitude_spectrum) 168 | 169 | merge_feature = torch.cat(tensors=[global_feature, local_feature], dim=2) # feature cat 170 | merge_feature = self.feature_merge_layer(merge_feature) 171 | # (batch*frames, channels, frequency) -> (batch*frames, channels//2, frequency//2) 172 | _, channels, frequency = merge_feature.shape 173 | merge_feature = torch.reshape(merge_feature, shape=(batch, frames, channels, frequency)) 174 | 
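# put frequency ahead of time so the dual-path RNNs below can alternate over the (frequency, frames) axes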
merge_feature = torch.permute(merge_feature, dims=(0, 3, 1, 2)).contiguous() 175 | # (batch, frequency, frames, channels) 176 | out_hidden_state = list() 177 | for idx, rnn_layer in enumerate(self.dual_path_extension_rnn_list): 178 | merge_feature, state = rnn_layer(merge_feature, hidden_state[idx]) 179 | out_hidden_state.append(state) 180 | 181 | merge_feature = torch.permute(merge_feature, dims=(0, 2, 3, 1)).contiguous() 182 | merge_feature = torch.reshape(merge_feature, shape=(batch * frames, channels, frequency)) 183 | 184 | split_feature = self.feature_split_layer(merge_feature) 185 | first_dim, channels, frequency = split_feature.shape 186 | split_feature = torch.reshape(split_feature, shape=(first_dim, channels, -1, 2)) 187 | 188 | full_band_mask = self.full_band_decoder(split_feature[..., 0], full_band_encode_outs) 189 | sub_band_mask = self.sub_band_decoder(split_feature[..., 1], sub_band_encode_outs) 190 | 191 | full_band_mask = torch.reshape(full_band_mask, shape=(batch, frames, 2, -1)) 192 | sub_band_mask = torch.reshape(sub_band_mask, shape=(batch, frames, 1, -1)) 193 | 194 | # Zero padding in the DC signal part removes the DC component 195 | full_band_mask = self.mask_padding(full_band_mask) 196 | sub_band_mask = self.mask_padding(sub_band_mask) 197 | 198 | # full_band_out = in_complex_spectrum * full_band_mask 199 | # sub_band_out = in_amplitude_spectrum * sub_band_mask 200 | # outputs is (batch, frames, 2, frequency), complex style. 201 | 202 | full_band_out=in_complex_spectrum_1ch*full_band_mask 203 | sub_band_out=in_amplitude_spectrum_1ch*sub_band_mask 204 | 205 | 206 | 207 | return full_band_out + sub_band_out, out_hidden_state 208 | -------------------------------------------------------------------------------- /second_year/src/models/convtasnet_SSL_FiLM/convtasnet_module/utility/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | class cLN(nn.Module): 9 | def __init__(self, dimension, eps = 1e-8, trainable=True): 10 | super(cLN, self).__init__() 11 | 12 | self.eps = eps 13 | if trainable: 14 | self.gain = nn.Parameter(torch.ones(1, dimension, 1)) 15 | self.bias = nn.Parameter(torch.zeros(1, dimension, 1)) 16 | else: 17 | self.gain = Variable(torch.ones(1, dimension, 1), requires_grad=False) 18 | self.bias = Variable(torch.zeros(1, dimension, 1), requires_grad=False) 19 | 20 | def forward(self, input): 21 | # input size: (Batch, Freq, Time) 22 | # cumulative mean for each time step 23 | 24 | batch_size = input.size(0) 25 | channel = input.size(1) 26 | time_step = input.size(2) 27 | 28 | step_sum = input.sum(1) # B, T 29 | step_pow_sum = input.pow(2).sum(1) # B, T 30 | cum_sum = torch.cumsum(step_sum, dim=1) # B, T 31 | cum_pow_sum = torch.cumsum(step_pow_sum, dim=1) # B, T 32 | 33 | entry_cnt = np.arange(channel, channel*(time_step+1), channel) 34 | entry_cnt = torch.from_numpy(entry_cnt).type(input.type()) 35 | entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum) 36 | 37 | cum_mean = cum_sum / entry_cnt # B, T 38 | cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2) # B, T 39 | cum_std = (cum_var + self.eps).sqrt() # B, T 40 | 41 | cum_mean = cum_mean.unsqueeze(1) 42 | cum_std = cum_std.unsqueeze(1) 43 | 44 | x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input) 45 | return x * self.gain.expand_as(x).type(x.type()) + 
self.bias.expand_as(x).type(x.type()) 46 | 47 | def repackage_hidden(h): 48 | """ 49 | Wraps hidden states in new Variables, to detach them from their history. 50 | """ 51 | 52 | if type(h) == Variable: 53 | return Variable(h.data) 54 | else: 55 | return tuple(repackage_hidden(v) for v in h) 56 | 57 | class MultiRNN(nn.Module): 58 | """ 59 | Container module for multiple stacked RNN layers. 60 | 61 | args: 62 | rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. 63 | input_size: int, dimension of the input feature. The input should have shape 64 | (batch, seq_len, input_size). 65 | hidden_size: int, dimension of the hidden state. The corresponding output should 66 | have shape (batch, seq_len, hidden_size). 67 | num_layers: int, number of stacked RNN layers. Default is 1. 68 | bidirectional: bool, whether the RNN layers are bidirectional. Default is False. 69 | """ 70 | 71 | def __init__(self, rnn_type, input_size, hidden_size, dropout=0, num_layers=1, bidirectional=False): 72 | super(MultiRNN, self).__init__() 73 | 74 | self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, num_layers, dropout=dropout, 75 | batch_first=True, bidirectional=bidirectional) 76 | 77 | 78 | 79 | self.rnn_type = rnn_type 80 | self.hidden_size = hidden_size 81 | self.num_layers = num_layers 82 | self.num_direction = int(bidirectional) + 1 83 | 84 | def forward(self, input): 85 | hidden = self.init_hidden(input.size(0)) 86 | self.rnn.flatten_parameters() 87 | return self.rnn(input, hidden) 88 | 89 | def init_hidden(self, batch_size): 90 | weight = next(self.parameters()).data 91 | if self.rnn_type == 'LSTM': 92 | return (Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()), 93 | Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_())) 94 | else: 95 | return Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()) 96 | 97 | 98 | class FCLayer(nn.Module): 99 | """ 100 | Container module for a fully-connected layer. 101 | 102 | args: 103 | input_size: int, dimension of the input feature. The input should have shape 104 | (batch, input_size). 105 | hidden_size: int, dimension of the output. The corresponding output should 106 | have shape (batch, hidden_size). 107 | nonlinearity: string, the nonlinearity applied to the transformation. Default is None. 108 | """ 109 | 110 | def __init__(self, input_size, hidden_size, bias=True, nonlinearity=None): 111 | super(FCLayer, self).__init__() 112 | 113 | self.input_size = input_size 114 | self.hidden_size = hidden_size 115 | self.bias = bias 116 | self.FC = nn.Linear(self.input_size, self.hidden_size, bias=bias) 117 | if nonlinearity: 118 | self.nonlinearity = getattr(F, nonlinearity) 119 | else: 120 | self.nonlinearity = None 121 | 122 | self.init_hidden() 123 | 124 | def forward(self, input): 125 | if self.nonlinearity is not None: 126 | return self.nonlinearity(self.FC(input)) 127 | else: 128 | return self.FC(input) 129 | 130 | def init_hidden(self): 131 | initrange = 1. 
/ np.sqrt(self.input_size * self.hidden_size) 132 | self.FC.weight.data.uniform_(-initrange, initrange) 133 | if self.bias: 134 | self.FC.bias.data.fill_(0) 135 | 136 | 137 | class DepthConv1d(nn.Module): 138 | 139 | def __init__(self, input_channel, hidden_channel, kernel, padding, dilation=1, skip=True, causal=False): 140 | super(DepthConv1d, self).__init__() 141 | 142 | self.causal = causal 143 | self.skip = skip 144 | 145 | self.conv1d = nn.Conv1d(input_channel, hidden_channel, 1) 146 | if self.causal: 147 | self.padding = (kernel - 1) * dilation 148 | else: 149 | self.padding = padding 150 | self.dconv1d = nn.Conv1d(hidden_channel, hidden_channel, kernel, dilation=dilation, 151 | groups=hidden_channel, 152 | padding=self.padding) 153 | self.res_out = nn.Conv1d(hidden_channel, input_channel, 1) 154 | self.nonlinearity1 = nn.PReLU() 155 | self.nonlinearity2 = nn.PReLU() 156 | if self.causal: 157 | self.reg1 = cLN(hidden_channel, eps=1e-08) 158 | self.reg2 = cLN(hidden_channel, eps=1e-08) 159 | else: 160 | self.reg1 = nn.GroupNorm(1, hidden_channel, eps=1e-08) 161 | self.reg2 = nn.GroupNorm(1, hidden_channel, eps=1e-08) 162 | 163 | if self.skip: 164 | self.skip_out = nn.Conv1d(hidden_channel, input_channel, 1) 165 | 166 | def forward(self, input): 167 | output = self.reg1(self.nonlinearity1(self.conv1d(input))) 168 | if self.causal: 169 | output = self.reg2(self.nonlinearity2(self.dconv1d(output)[:,:,:-self.padding])) 170 | else: 171 | output = self.reg2(self.nonlinearity2(self.dconv1d(output))) 172 | residual = self.res_out(output) 173 | if self.skip: 174 | skip = self.skip_out(output) 175 | return residual, skip 176 | else: 177 | return residual 178 | 179 | class TCN(nn.Module): 180 | def __init__(self, ch_size, input_dim, output_dim, BN_dim, hidden_dim, 181 | layer, stack, kernel=3, skip=True, 182 | causal=False, dilated=True, Film_loc=8): 183 | super(TCN, self).__init__() 184 | 185 | # input is a sequence of features of shape (B, N, L) 186 | 187 | # normalization 188 | if not causal: 189 | self.LN = nn.GroupNorm(1, input_dim*ch_size, eps=1e-8) 190 | else: 191 | self.LN = cLN(input_dim*ch_size, eps=1e-8) 192 | 193 | self.BN = nn.Conv1d(input_dim*ch_size, BN_dim, 1) 194 | 195 | # TCN for feature extraction 196 | self.receptive_field = 0 197 | self.dilated = dilated 198 | self.Film_loc=Film_loc 199 | self.TCN = nn.ModuleList([]) 200 | for s in range(stack): 201 | for i in range(layer): 202 | if self.dilated: 203 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=2**i, padding=2**i, skip=skip, causal=causal)) 204 | else: 205 | self.TCN.append(DepthConv1d(BN_dim, hidden_dim, kernel, dilation=1, padding=1, skip=skip, causal=causal)) 206 | if i == 0 and s == 0: 207 | self.receptive_field += kernel 208 | else: 209 | if self.dilated: 210 | self.receptive_field += (kernel - 1) * 2**i 211 | else: 212 | self.receptive_field += (kernel - 1) 213 | 214 | #print("Receptive field: {:3d} frames.".format(self.receptive_field)) 215 | 216 | # output layer 217 | 218 | self.output = nn.Sequential(nn.PReLU(), 219 | nn.Conv1d(BN_dim, output_dim, 1) 220 | ) 221 | 222 | self.skip = skip 223 | 224 | def forward(self, input, ssl_weight, ssl_bias): 225 | 226 | # input shape: (B, N, L) 227 | 228 | # normalization 229 | output = self.BN(self.LN(input)) 230 | 231 | # pass to TCN 232 | if self.skip: 233 | skip_connection = 0. 
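# FiLM conditioning: at the block indices listed in Film_loc the residual branch
# is modulated as ssl_weight * residual + ssl_bias before being added back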
234 | for i in range(len(self.TCN)): 235 | residual, skip = self.TCN[i](output) 236 | if i in self.Film_loc: 237 | residual=ssl_weight*residual+ssl_bias 238 | 239 | output = output + residual 240 | skip_connection = skip_connection + skip 241 | else: 242 | for i in range(len(self.TCN)): 243 | 244 | residual = self.TCN[i](output) 245 | if i in self.Film_loc: 246 | residual=ssl_weight*residual+ssl_bias 247 | output = output + residual 248 | 249 | # output layer 250 | if self.skip: 251 | output = self.output(skip_connection) 252 | else: 253 | output = self.output(output) 254 | 255 | return output -------------------------------------------------------------------------------- /fourth_year/src/util/util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | from datetime import datetime 7 | import matplotlib.pyplot as plt 8 | import shutil 9 | import select 10 | import sys 11 | from time import sleep 12 | from signal import signal, alarm, SIGALRM 13 | # shutil, plotly 14 | 15 | def load_yaml(yaml_dir): 16 | yaml_file=open(yaml_dir, 'r') 17 | data=yaml.safe_load(yaml_file) 18 | yaml_file.close() 19 | return data 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value""" 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | def copy_folder(dir): 39 | shutil.copytree('./',dir+'src', ignore=shutil.ignore_patterns('./wandb/**','*.png', '*.wav')) 40 | 41 | 42 | def exp_mkdir(config): 43 | dir=config['exp']['result_dir'] 44 | 45 | # experiment directory 46 | dir=dir+config['exp']['name']+'/' 47 | os.makedirs(dir, exist_ok=True) 48 | exp_description_txt=dir+'exp_description.txt' 49 | 50 | if os.path.exists(exp_description_txt)==False: 51 | print('Please write exp Description') 52 | f=open(exp_description_txt, 'w') 53 | f.close() 54 | exit() 55 | exp_dir=dir 56 | 57 | # model directory 58 | dir=dir+config['model']['name']+'/' 59 | os.makedirs(dir, exist_ok=True) 60 | model_description_txt=dir+'model_description.txt' 61 | if os.path.exists(model_description_txt)==False: 62 | print('Please write model Description') 63 | f=open(model_description_txt, 'w') 64 | f.close() 65 | f=open(dir+'model_structure.txt', 'w') 66 | f.close() 67 | f=open(dir+'model_summary.txt', 'w') 68 | f.close() 69 | exit() 70 | model_dir=dir 71 | 72 | # date directory 73 | 74 | if config['exp']['temp']==True: 75 | dir=dir+'temp/' 76 | else: 77 | dir=dir+'exp/' 78 | 79 | 80 | 81 | 82 | os.makedirs(dir, exist_ok=True) 83 | 84 | 85 | now_time=datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 86 | dir=dir+now_time+'/' 87 | os.makedirs(dir, exist_ok=True) 88 | 89 | os.makedirs(dir+'model_save', exist_ok=True) 90 | os.makedirs(dir+'result', exist_ok=True) 91 | 92 | f=open(dir+'train_status.txt', 'w') 93 | f.close() 94 | 95 | 96 | 97 | print('Please write date Description') 98 | f=open(dir+'description.txt', 'w') 99 | f.close() 100 | 101 | 102 | return exp_dir, model_dir, dir 103 | 104 | 105 | def get_yaml_args(yaml_list): 106 | # print(yaml_list) 107 | # exit() 108 | yaml_out={} 109 | for a in yaml_list: 110 | a=a.split(' ') 111 | yaml_out[a[0]]=load_yaml(a[1]) 112 | 113 | 114 | return yaml_out 115 | 116 | def randomseed_init(num): 117 | 
np.random.seed(num) 118 | random.seed(num) 119 | torch.manual_seed(num) 120 | # torch.Generator.manual_seed(num) 121 | if torch.cuda.is_available(): 122 | torch.cuda.manual_seed(num) 123 | return 'cuda' 124 | else: 125 | return 'cpu' 126 | 127 | def draw_result_pic(dir, epoch, train, val): 128 | fig1 = plt.figure(figsize=(7,4)) 129 | epo = np.arange(epoch+1) 130 | 131 | os.makedirs(os.path.dirname(dir), exist_ok=True) 132 | a2 = fig1.add_subplot(1, 1, 1) 133 | a2.plot(epo, train, epo,val) 134 | a2.set_title("Loss") 135 | a2.legend(['Train', 'Eval']) 136 | a2.set_ylabel('Loss') 137 | a2.set_xlabel('Epochs') 138 | a2.grid(axis='y', linestyle='dashed') 139 | fig1.tight_layout() 140 | fig1.savefig(dir, dpi=300) 141 | plt.close(fig1) 142 | 143 | def check_list(config): 144 | 145 | 146 | 147 | print('Enter anything for checklist') 148 | 149 | i, o, e = select.select( [sys.stdin], [], [], 10 ) 150 | 151 | if (i): 152 | print('\nChecklist Starts!!') 153 | else: 154 | print("\nAuto start!!!") 155 | return 156 | 157 | print('\nBatch size') 158 | print(config['train']['dataloader']['batch_size']) 159 | 160 | _=input() 161 | 162 | print('\nnum_workers') 163 | print(config['train']['dataloader']['num_workers']) 164 | _=input() 165 | 166 | print('\nGPU ID') 167 | print(config['train']['GPGPU']['device_ids']) 168 | _=input() 169 | 170 | print('\nExperiment name') 171 | print(config['exp']['name']) 172 | _=input() 173 | 174 | print('\nModel name') 175 | print(config['model']['name']) 176 | _=input() 177 | 178 | print('\nIs this temp or not? (temp: Y, not: N)') 179 | temp=input() 180 | 181 | if temp in ['N', 'n']: 182 | config['exp']['temp']=False 183 | else: 184 | print('This is temporaray exp!!!') 185 | config['exp']['temp']=True 186 | 187 | 188 | 189 | def log_saving(record_dir, record_param, epoch, model, writer, optimizer, result_text,restart_num, end=False, temp=False): 190 | write_file=open(result_text, 'a') 191 | # print(epoch) 192 | print("\nAccuracy(train, eval) : %3.3f %3.3f" % ( 193 | record_param['accu_list_train'][epoch] * 100, record_param['accu_list_eval'][epoch] * 100)) 194 | write_file.write("\nAccuracy(train, eval) : %3.3f %3.3f \n" % ( 195 | record_param['accu_list_train'][epoch] * 100, record_param['accu_list_eval'][epoch] * 100)) 196 | 197 | print("Max accuracy(eval) : %d epoch, %3.3f\n" % ( 198 | record_param['accu_list_eval'].index(max(record_param['accu_list_eval'])) , 199 | max(record_param['accu_list_eval']) * 100)) 200 | write_file.write("Max accuracy(eval) : %d epoch, %3.3f\n\n" % ( 201 | record_param['accu_list_eval'].index(max(record_param['accu_list_eval'])), 202 | max(record_param['accu_list_eval']) * 100)) 203 | 204 | print("Loss(train, eval) : %3.3f %3.3f" % ( 205 | record_param['loss_list_train'][epoch], record_param['loss_list_eval'][epoch])) 206 | write_file.write("Loss(train, eval) : %3.3f %3.3f\n" % ( 207 | record_param['loss_list_train'][epoch], record_param['loss_list_eval'][epoch])) 208 | 209 | print("Min loss(eval): %d epoch, %3.3f\n" % ( 210 | record_param['loss_list_eval'].index(min(record_param['loss_list_eval'])) , min(record_param['loss_list_eval']))) 211 | write_file.write("Min loss(eval): %d epoch, %3.3f\n\n" % ( 212 | record_param['loss_list_eval'].index(min(record_param['loss_list_eval'])) , min(record_param['loss_list_eval']))) 213 | 214 | write_file.close() 215 | 216 | if temp ==False: 217 | writer.add_scalars('Accuracy', {'Accuracy/Train': record_param['accu_list_train'][-1], 'Accuracy/Val':record_param['accu_list_eval'][-1]}, epoch) 218 | 
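# NOTE (annotation): the histogram-logging condition a few lines below,
# `(epoch % 10 == 0) or end == True or epoch < 10 or epoch != 0`, is always
# true — `epoch != 0` covers every epoch except 0, and `epoch % 10 == 0`
# covers epoch 0 — so parameter histograms are written every epoch; an `and`
# between the clauses was probably intended.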
writer.add_scalars('Loss', {'Loss/Train': record_param['loss_list_train'][-1], 'Loss/Val': record_param['loss_list_eval'][-1]}, epoch) 219 | 220 | if (epoch % 10 == 0) or end ==True or epoch<10 or epoch!=0: 221 | for name, parameter in model.named_parameters(): 222 | writer.add_histogram(name, parameter.clone().detach().cpu().data.numpy(), epoch) 223 | 224 | for x in restart_num: 225 | if epoch>=(x-6) and epoch<=(x-1): 226 | torch.save({'epoch': epoch, 227 | 'model_state_dict': model.state_dict(), 228 | 'optimizer_state_dict': optimizer.state_dict(), 229 | 'param': record_param 230 | }, record_dir + "/model_"+str(epoch)+".pth") 231 | 232 | torch.save({'epoch': epoch, 233 | 'model_state_dict': model.state_dict(), 234 | 'optimizer_state_dict': optimizer.state_dict(), 235 | 'param': record_param 236 | }, record_dir + "/current_model.pth") 237 | if record_param['loss_list_eval'][-1]==min(record_param['loss_list_eval']): 238 | torch.save({'epoch': epoch, 239 | 'model_state_dict': model.state_dict(), 240 | 'optimizer_state_dict': optimizer.state_dict(), 241 | 'param': record_param 242 | }, record_dir + "/best_loss_model.pth") 243 | 244 | if record_param['accu_list_eval'][-1]==max(record_param['accu_list_eval']): 245 | torch.save({'epoch': epoch, 246 | 'model_state_dict': model.state_dict(), 247 | 'optimizer_state_dict': optimizer.state_dict(), 248 | 'param': record_param 249 | }, record_dir + "/best_accu_model.pth") 250 | 251 | 252 | fig1 = plt.figure(figsize=(7,4)) 253 | epo = np.arange(0, epoch + 1, 1) 254 | a1 = fig1.add_subplot(2, 1, 1) 255 | a1.plot(epo, np.array(record_param['accu_list_train']), epo, np.array(record_param['accu_list_eval'])) 256 | a1.set_title("Accuracy") 257 | a1.legend(['Train', 'Eval']) 258 | a1.set_xlabel('Epochs') 259 | a1.set_ylabel('Accuracy') 260 | a1.grid(axis='y', linestyle='dashed') 261 | 262 | a2 = fig1.add_subplot(2, 1, 2) 263 | a2.plot(epo, np.array(record_param['loss_list_train']), epo, np.array(record_param['loss_list_eval'])) 264 | a2.set_title("Loss") 265 | a2.legend(['Train', 'Eval']) 266 | a2.set_ylabel('Loss') 267 | a2.set_xlabel('Epochs') 268 | a2.grid(axis='y', linestyle='dashed') 269 | fig1.tight_layout() 270 | fig1.savefig(record_dir + '/accu_loss.png', dpi=300) 271 | plt.close(fig1) 272 | fig1 = plt.figure(figsize=(7,4)) 273 | epo = np.arange(0, epoch + 1, 1) 274 | a1 = fig1.add_subplot(2, 1, 1) 275 | a1.plot(epo, np.array(record_param['accu_list_train']), epo, np.array(record_param['accu_list_eval'])) 276 | a1.set_title("Accuracy") 277 | a1.legend(['Train', 'Eval']) 278 | a1.set_xlabel('Epochs') 279 | a1.set_ylabel('Accuracy') 280 | a1.grid(axis='y', linestyle='dashed') 281 | 282 | a2 = fig1.add_subplot(2, 1, 2) 283 | a2.plot(epo, np.array(record_param['loss_list_train']), epo, np.array(record_param['loss_list_eval'])) 284 | a2.set_title("Loss") 285 | a2.legend(['Train', 'Eval']) 286 | a2.set_ylabel('Loss') 287 | a2.set_xlabel('Epochs') 288 | a2.grid(axis='y', linestyle='dashed') 289 | fig1.tight_layout() 290 | fig1.savefig(record_dir + '/accu_loss.png', dpi=300) 291 | plt.close(fig1) 292 | 293 | 294 | -------------------------------------------------------------------------------- /second_year/src/util/util.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import numpy as np 4 | import random 5 | import os 6 | from datetime import datetime 7 | import matplotlib.pyplot as plt 8 | import shutil 9 | import select 10 | import sys 11 | from time import sleep 12 | from signal 
import signal, alarm, SIGALRM 13 | # shutil, plotly 14 | 15 | def load_yaml(yaml_dir): 16 | yaml_file=open(yaml_dir, 'r') 17 | data=yaml.safe_load(yaml_file) 18 | yaml_file.close() 19 | return data 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value""" 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | def copy_folder(dir): 39 | shutil.copytree('./',dir+'src', ignore=shutil.ignore_patterns('./wandb/**','*.png', '*.wav')) 40 | 41 | 42 | def exp_mkdir(config): 43 | dir=config['exp']['result_dir'] 44 | 45 | # experiment directory 46 | dir=dir+config['exp']['name']+'/' 47 | os.makedirs(dir, exist_ok=True) 48 | exp_description_txt=dir+'exp_description.txt' 49 | 50 | if os.path.exists(exp_description_txt)==False: 51 | print('Please write exp Description') 52 | f=open(exp_description_txt, 'w') 53 | f.close() 54 | exit() 55 | exp_dir=dir 56 | 57 | # model directory 58 | dir=dir+config['model']['name']+'/' 59 | os.makedirs(dir, exist_ok=True) 60 | model_description_txt=dir+'model_description.txt' 61 | if os.path.exists(model_description_txt)==False: 62 | print('Please write model Description') 63 | f=open(model_description_txt, 'w') 64 | f.close() 65 | f=open(dir+'model_structure.txt', 'w') 66 | f.close() 67 | f=open(dir+'model_summary.txt', 'w') 68 | f.close() 69 | exit() 70 | model_dir=dir 71 | 72 | # date directory 73 | 74 | if config['exp']['temp']==True: 75 | dir=dir+'temp/' 76 | else: 77 | dir=dir+'exp/' 78 | 79 | 80 | 81 | 82 | os.makedirs(dir, exist_ok=True) 83 | 84 | 85 | now_time=datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 86 | dir=dir+now_time+'/' 87 | os.makedirs(dir, exist_ok=True) 88 | 89 | os.makedirs(dir+'model_save', exist_ok=True) 90 | os.makedirs(dir+'result', exist_ok=True) 91 | 92 | f=open(dir+'train_status.txt', 'w') 93 | f.close() 94 | 95 | 96 | 97 | print('Please write date Description') 98 | f=open(dir+'description.txt', 'w') 99 | f.close() 100 | 101 | 102 | return exp_dir, model_dir, dir 103 | 104 | 105 | def get_yaml_args(yaml_list): 106 | # print(yaml_list) 107 | # exit() 108 | yaml_out={} 109 | for a in yaml_list: 110 | a=a.split(' ') 111 | yaml_out[a[0]]=load_yaml(a[1]) 112 | 113 | 114 | return yaml_out 115 | 116 | def randomseed_init(num): 117 | np.random.seed(num) 118 | random.seed(num) 119 | torch.manual_seed(num) 120 | # torch.Generator.manual_seed(num) 121 | if torch.cuda.is_available(): 122 | torch.cuda.manual_seed(num) 123 | return 'cuda' 124 | else: 125 | return 'cpu' 126 | 127 | def draw_result_pic(dir, epoch, train, val): 128 | fig1 = plt.figure(figsize=(7,4)) 129 | epo = np.arange(epoch+1) 130 | 131 | os.makedirs(os.path.dirname(dir), exist_ok=True) 132 | a2 = fig1.add_subplot(1, 1, 1) 133 | a2.plot(epo, train, epo,val) 134 | a2.set_title("Loss") 135 | a2.legend(['Train', 'Eval']) 136 | a2.set_ylabel('Loss') 137 | a2.set_xlabel('Epochs') 138 | a2.grid(axis='y', linestyle='dashed') 139 | fig1.tight_layout() 140 | fig1.savefig(dir, dpi=300) 141 | plt.close(fig1) 142 | 143 | def check_list(config): 144 | 145 | 146 | 147 | print('Enter anything for checklist') 148 | 149 | i, o, e = select.select( [sys.stdin], [], [], 10 ) 150 | 151 | if (i): 152 | print('\nChecklist Starts!!') 153 | else: 154 | print("\nAuto start!!!") 155 | return 156 | 157 | print('\nBatch size') 158 | 
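# check_list walks through the settings below one item at a time: it prints
# a config value, then blocks on input() until Enter is pressed, so a run
# can be sanity-checked before training starts. The select.select() call
# above gives a 10-second window to opt in; with no keypress the run starts
# automatically.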
print(config['train']['dataloader']['batch_size']) 159 | 160 | _=input() 161 | 162 | print('\nnum_workers') 163 | print(config['train']['dataloader']['num_workers']) 164 | _=input() 165 | 166 | print('\nGPU ID') 167 | print(config['train']['GPGPU']['device_ids']) 168 | _=input() 169 | 170 | print('\nExperiment name') 171 | print(config['exp']['name']) 172 | _=input() 173 | 174 | print('\nModel name') 175 | print(config['model']['name']) 176 | _=input() 177 | 178 | print('\nIs this temp or not? (temp: Y, not: N)') 179 | temp=input() 180 | 181 | if temp in ['N', 'n']: 182 | config['exp']['temp']=False 183 | else: 184 | print('This is temporaray exp!!!') 185 | config['exp']['temp']=True 186 | 187 | 188 | 189 | def log_saving(record_dir, record_param, epoch, model, writer, optimizer, result_text,restart_num, end=False, temp=False): 190 | write_file=open(result_text, 'a') 191 | # print(epoch) 192 | print("\nAccuracy(train, eval) : %3.3f %3.3f" % ( 193 | record_param['accu_list_train'][epoch] * 100, record_param['accu_list_eval'][epoch] * 100)) 194 | write_file.write("\nAccuracy(train, eval) : %3.3f %3.3f \n" % ( 195 | record_param['accu_list_train'][epoch] * 100, record_param['accu_list_eval'][epoch] * 100)) 196 | 197 | print("Max accuracy(eval) : %d epoch, %3.3f\n" % ( 198 | record_param['accu_list_eval'].index(max(record_param['accu_list_eval'])) , 199 | max(record_param['accu_list_eval']) * 100)) 200 | write_file.write("Max accuracy(eval) : %d epoch, %3.3f\n\n" % ( 201 | record_param['accu_list_eval'].index(max(record_param['accu_list_eval'])), 202 | max(record_param['accu_list_eval']) * 100)) 203 | 204 | print("Loss(train, eval) : %3.3f %3.3f" % ( 205 | record_param['loss_list_train'][epoch], record_param['loss_list_eval'][epoch])) 206 | write_file.write("Loss(train, eval) : %3.3f %3.3f\n" % ( 207 | record_param['loss_list_train'][epoch], record_param['loss_list_eval'][epoch])) 208 | 209 | print("Min loss(eval): %d epoch, %3.3f\n" % ( 210 | record_param['loss_list_eval'].index(min(record_param['loss_list_eval'])) , min(record_param['loss_list_eval']))) 211 | write_file.write("Min loss(eval): %d epoch, %3.3f\n\n" % ( 212 | record_param['loss_list_eval'].index(min(record_param['loss_list_eval'])) , min(record_param['loss_list_eval']))) 213 | 214 | write_file.close() 215 | 216 | if temp ==False: 217 | writer.add_scalars('Accuracy', {'Accuracy/Train': record_param['accu_list_train'][-1], 'Accuracy/Val':record_param['accu_list_eval'][-1]}, epoch) 218 | writer.add_scalars('Loss', {'Loss/Train': record_param['loss_list_train'][-1], 'Loss/Val': record_param['loss_list_eval'][-1]}, epoch) 219 | 220 | if (epoch % 10 == 0) or end ==True or epoch<10 or epoch!=0: 221 | for name, parameter in model.named_parameters(): 222 | writer.add_histogram(name, parameter.clone().detach().cpu().data.numpy(), epoch) 223 | 224 | for x in restart_num: 225 | if epoch>=(x-6) and epoch<=(x-1): 226 | torch.save({'epoch': epoch, 227 | 'model_state_dict': model.state_dict(), 228 | 'optimizer_state_dict': optimizer.state_dict(), 229 | 'param': record_param 230 | }, record_dir + "/model_"+str(epoch)+".pth") 231 | 232 | torch.save({'epoch': epoch, 233 | 'model_state_dict': model.state_dict(), 234 | 'optimizer_state_dict': optimizer.state_dict(), 235 | 'param': record_param 236 | }, record_dir + "/current_model.pth") 237 | if record_param['loss_list_eval'][-1]==min(record_param['loss_list_eval']): 238 | torch.save({'epoch': epoch, 239 | 'model_state_dict': model.state_dict(), 240 | 'optimizer_state_dict': optimizer.state_dict(), 
241 | 'param': record_param 242 | }, record_dir + "/best_loss_model.pth") 243 | 244 | if record_param['accu_list_eval'][-1]==max(record_param['accu_list_eval']): 245 | torch.save({'epoch': epoch, 246 | 'model_state_dict': model.state_dict(), 247 | 'optimizer_state_dict': optimizer.state_dict(), 248 | 'param': record_param 249 | }, record_dir + "/best_accu_model.pth") 250 | 251 | 252 | fig1 = plt.figure(figsize=(7,4)) 253 | epo = np.arange(0, epoch + 1, 1) 254 | a1 = fig1.add_subplot(2, 1, 1) 255 | a1.plot(epo, np.array(record_param['accu_list_train']), epo, np.array(record_param['accu_list_eval'])) 256 | a1.set_title("Accuracy") 257 | a1.legend(['Train', 'Eval']) 258 | a1.set_xlabel('Epochs') 259 | a1.set_ylabel('Accuracy') 260 | a1.grid(axis='y', linestyle='dashed') 261 | 262 | a2 = fig1.add_subplot(2, 1, 2) 263 | a2.plot(epo, np.array(record_param['loss_list_train']), epo, np.array(record_param['loss_list_eval'])) 264 | a2.set_title("Loss") 265 | a2.legend(['Train', 'Eval']) 266 | a2.set_ylabel('Loss') 267 | a2.set_xlabel('Epochs') 268 | a2.grid(axis='y', linestyle='dashed') 269 | fig1.tight_layout() 270 | fig1.savefig(record_dir + '/accu_loss.png', dpi=300) 271 | plt.close(fig1) 272 | fig1 = plt.figure(figsize=(7,4)) 273 | epo = np.arange(0, epoch + 1, 1) 274 | a1 = fig1.add_subplot(2, 1, 1) 275 | a1.plot(epo, np.array(record_param['accu_list_train']), epo, np.array(record_param['accu_list_eval'])) 276 | a1.set_title("Accuracy") 277 | a1.legend(['Train', 'Eval']) 278 | a1.set_xlabel('Epochs') 279 | a1.set_ylabel('Accuracy') 280 | a1.grid(axis='y', linestyle='dashed') 281 | 282 | a2 = fig1.add_subplot(2, 1, 2) 283 | a2.plot(epo, np.array(record_param['loss_list_train']), epo, np.array(record_param['loss_list_eval'])) 284 | a2.set_title("Loss") 285 | a2.legend(['Train', 'Eval']) 286 | a2.set_ylabel('Loss') 287 | a2.set_xlabel('Epochs') 288 | a2.grid(axis='y', linestyle='dashed') 289 | fig1.tight_layout() 290 | fig1.savefig(record_dir + '/accu_loss.png', dpi=300) 291 | plt.close(fig1) 292 | 293 | 294 | -------------------------------------------------------------------------------- /second_year/src/models/convtasnet_SSL_FiLM/Causal_CRN_SPL_target/CRN.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules import conv 2 | from .FFT import EPSILON, ConvSTFT, ConviSTFT 3 | from torch import nn 4 | import torch 5 | from util import * 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | class Causal_Conv2D_Block(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super(Causal_Conv2D_Block, self).__init__() 12 | 13 | 14 | self.conv2d=nn.Conv2d(*args, **kwargs) 15 | 16 | 17 | self.norm=nn.BatchNorm2d(args[1]) 18 | 19 | self.activation=nn.ELU() 20 | 21 | def forward(self, x): 22 | original_frame_num=x.shape[-1] 23 | x=self.conv2d(x) 24 | x=self.norm(x) 25 | x=self.activation(x) 26 | x=x[...,:original_frame_num] 27 | 28 | 29 | 30 | return x 31 | 32 | class Conv1D_Block(nn.Module): 33 | def __init__(self, *args, **kwargs): 34 | super(Conv1D_Block, self).__init__() 35 | 36 | 37 | self.conv1d=nn.Conv1d(*args, **kwargs) 38 | 39 | 40 | self.norm=nn.BatchNorm1d(args[1]) 41 | 42 | self.activation=nn.ELU() 43 | 44 | def forward(self, x): 45 | 46 | x=self.conv1d(x) 47 | x=self.norm(x) 48 | x=self.activation(x) 49 | 50 | 51 | 52 | return x 53 | 54 | 55 | class crn(nn.Module): 56 | def __init__(self, config, output_num, azi_size): 57 | super(crn, self).__init__() 58 | 59 | 60 | self.output_num=output_num 61 | self.azi_size=azi_size 62 
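# crn architecture (annotation): a stack of causal Conv2D blocks with max
# pooling extracts spectro-temporal features, a GRU models the frame
# sequence, and one Conv1D head per output map projects the GRU state to
# azi_size azimuth logits per frame.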
| 63 | 64 | self.cnn_num=config['CNN']['layer_num'] 65 | self.kernel_size=config['CNN']['kernel_size'] 66 | self.filter_size=config['CNN']['filter'] 67 | 68 | self.max_pool_kernel=config['CNN']['max_pool']['kernel_size'] 69 | self.max_pool_stride=config['CNN']['max_pool']['stride'] 70 | 71 | args=[2*(config['input_audio_channel']-1),self.filter_size,self.kernel_size] # in_channel, out_channel, kernel size 72 | 73 | kwargs={'stride': 1, 'padding': [1,2], 'dilation': 1} 74 | 75 | 76 | 77 | 78 | self.cnn=nn.ModuleList() 79 | self.pooling=nn.ModuleList() 80 | self.cnn.append(Causal_Conv2D_Block(*args, **kwargs)) 81 | self.pooling.append(nn.MaxPool2d(self.max_pool_kernel, stride=self.max_pool_stride)) 82 | 83 | args[0]=config['CNN']['filter'] 84 | for count in range(self.cnn_num-1): 85 | self.cnn.append(Causal_Conv2D_Block(*args, **kwargs)) 86 | self.pooling.append(nn.MaxPool2d(self.max_pool_kernel, stride=self.max_pool_stride)) 87 | 88 | self.GRU_layer=nn.GRU(**config['GRU']) 89 | 90 | self.azi_mapping_conv_layer=nn.ModuleList() 91 | self.azi_mapping_final=nn.ModuleList() 92 | 93 | args[0]=config['GRU']['hidden_size'] 94 | args[1]=config['GRU']['hidden_size'] 95 | args[2]=1 96 | kwargs['padding']=0 97 | 98 | for _ in range(output_num): 99 | self.azi_mapping_conv_layer.append(Conv1D_Block(*args, **kwargs)) 100 | self.azi_mapping_final.append(nn.Conv1d(config['GRU']['hidden_size'], self.azi_size, 1)) 101 | 102 | 103 | 104 | 105 | 106 | 107 | def forward(self, x): 108 | 109 | for cnn_layer, pooling_layer in zip(self.cnn, self.pooling): 110 | # print(x.shape) 111 | x=cnn_layer(x)[...,:x.shape[-1]] 112 | x=pooling_layer(x) 113 | # print(x.shape) 114 | # exit() 115 | 116 | 117 | b, c, f, t=x.shape 118 | x=x.view(b, -1, t).permute(0,2,1) 119 | 120 | # x_cnn=self.lstm_cnn_layer(x) 121 | self.GRU_layer.flatten_parameters() 122 | x, h=self.GRU_layer(x) 123 | x=x.permute(0,2,1) 124 | 125 | outputs=[] 126 | 127 | for final_layer, cnn_layer in zip(self.azi_mapping_final, self.azi_mapping_conv_layer): 128 | x=cnn_layer(x) 129 | res_output=final_layer(x) 130 | outputs.append(res_output) 131 | output=torch.stack(outputs).permute(1,0,2,3) 132 | 133 | return output 134 | 135 | 136 | class main_model(nn.Module): 137 | def __init__(self, config): 138 | super(main_model, self).__init__() 139 | self.config=config 140 | 141 | self.eps=np.finfo(np.float32).eps 142 | self.ref_ch=self.config['ref_ch'] 143 | 144 | ###### sigma 145 | 146 | self.sigma=torch.tensor(self.config['sigma_start']) 147 | self.sigma_max=torch.tensor(self.config['sigma_end']['max']) 148 | self.sigma_min=torch.tensor(self.config['sigma_end']['min']) 149 | self.sigma_rate=torch.tensor(self.config['sigma_rate']) 150 | self.sigma_udpate_method=self.config['sigma_update_method'] 151 | 152 | self.iteration_count=0 153 | self.epoch_count=0 154 | self.now_epoch=0 155 | 156 | 157 | ###### 158 | 159 | self.max_spk=self.config['max_spk'] 160 | self.degree_resolution = self.config['degree_resolution'] 161 | self.azi_size=360//self.degree_resolution 162 | 163 | self.stft_model=ConvSTFT(**self.config['FFT']) 164 | self.crn=crn(self.config['CRN'], self.sigma.shape[0], self.azi_size) 165 | 166 | def save_fig(self, data): 167 | data=data[2].cpu().detach().numpy() 168 | fig=plt.figure(figsize=(7,3),)#, nrows=azi_num, ncols=1, sharey=True) 169 | 170 | for i in range(data.shape[1]): 171 | plt.subplot(1, data.shape[1], i+1) 172 | 173 | lj=plt.imshow(data[:,i], vmin=0, vmax=1.0, cmap = plt.get_cmap('plasma'), aspect='auto') 174 | plt.colorbar(lj) 175 | 176 | 
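# save_fig is a debugging helper: it renders one batch element of the
# target coding as heatmaps and then calls exit(), so it is meant for
# inspecting targets interactively, never for use during training.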
plt.tight_layout() 177 | plt.savefig('./png_folder/target_coding.png', dpi=400, interpolation="none") 178 | plt.close() 179 | plt.clf() 180 | plt.cla() 181 | exit() 182 | 183 | def sigma_update(self, iter_num, epoch): 184 | if iter_num%500==0: 185 | print(self.sigma) 186 | def update(): 187 | 188 | 189 | if self.sigma_udpate_method=='add': 190 | self.sigma+=self.sigma_rate 191 | elif self.sigma_udpate_method=='multiply': 192 | self.sigma*=self.sigma_rate 193 | else: 194 | "Not exist!!!" 195 | exit() 196 | 197 | self.sigma=torch.clamp(self.sigma, self.sigma_min, self.sigma_max) 198 | 199 | 200 | if self.training: 201 | 202 | if self.config['iter']['update']: 203 | if self.iteration_count!=self.config['iter']['update_period']: 204 | self.iteration_count+=1 205 | 206 | else: 207 | print('sigma_iter update') 208 | update() 209 | self.iteration_count=0 210 | return 211 | 212 | if self.config['epoch']['update']: 213 | 214 | if self.now_epoch!=epoch: 215 | self.now_epoch=epoch 216 | self.epoch_count+=1 217 | 218 | if self.epoch_count==self.config['epoch']['update_period']: 219 | print('sigma_epoch update') 220 | update() 221 | self.epoch_count=0 222 | return 223 | 224 | 225 | 226 | 227 | def make_target(self, target, azi, iter_num, epoch): 228 | 229 | 230 | 231 | azi_target=torch.div(azi, 360//self.azi_size, rounding_mode='floor').long() 232 | azi_range=torch.arange(0, self.azi_size).unsqueeze(0).to(azi_target.device) 233 | 234 | # target_for_loss=torch.zeros((target.shape[0], self.azi_size,self.sigma_resolution_tensor.shape[0], target.shape[-1]), dtype=torch.float32, device=azi_target.device) 235 | 236 | distance=azi_target.unsqueeze(-1)*self.degree_resolution-azi_range*self.degree_resolution 237 | 238 | distance_abs=torch.abs(distance) 239 | distance_abs=torch.stack((distance_abs, 360-distance_abs), dim=0) 240 | distance=torch.pow(torch.min(distance_abs, dim=0).values,2) 241 | 242 | 243 | sigma=self.sigma.view(1,1,-1).to(distance.device) 244 | 245 | labelling=torch.exp(-distance.unsqueeze(-1)/sigma**2).unsqueeze(-1) 246 | 247 | target=target.unsqueeze(2).unsqueeze(2) 248 | target=labelling*target 249 | target=torch.max(target, dim=1).values 250 | # print(target.shape) 251 | # exit() 252 | 253 | target=target.permute(0,2,1,3) 254 | 255 | # tt=target[0, :, 0, :].cpu().detach().numpy() 256 | # plt.imshow(tt) 257 | # plt.savefig('../data/target.png') 258 | # plt.close() 259 | 260 | # plt.plot(tt[:,5]) 261 | # plt.savefig('../data/one_frame.png') 262 | 263 | self.sigma_update(iter_num, epoch) 264 | 265 | return target # batch, sigma_num, degree, frame 266 | 267 | def irtf_featue(self, x, target): 268 | r, i, target =self.stft_model(x, target, cplx=True) 269 | 270 | 271 | 272 | comp = torch.complex(r, i) 273 | 274 | comp_ref = comp[..., [self.ref_ch], :, :] 275 | comp_ref = torch.complex( 276 | comp_ref.real.clamp(self.eps), comp_ref.imag.clamp(self.eps) 277 | ) 278 | 279 | comp=torch.cat( 280 | (comp[..., self.ref_ch-1:self.ref_ch, :, :], comp[..., self.ref_ch+1:, :, :]), 281 | dim=-3) / comp_ref 282 | x=torch.cat((comp.real, comp.imag), dim=1) 283 | 284 | return x, target 285 | 286 | def vad_framing(self, vad_batch): 287 | 288 | vad_output_th = vad_batch.mean(axis=2) > 2 / 3 289 | 290 | vad_output_th = vad_output_th[:, np.newaxis, :, np.newaxis, np.newaxis] 291 | vad_output_th = torch.from_numpy(vad_output_th.astype(float)).to(maps.device) 292 | repeat_factor = np.array(maps.shape) 293 | repeat_factor[:-2] = 1 294 | maps *= vad_output_th.float().repeat(repeat_factor.tolist()) 295 | 296 | def 
plot_vad_freq(self, data): 297 | data=data[0].detach().cpu().numpy() 298 | 299 | for i in range(data.shape[0]): 300 | plt.plot(data[i]) 301 | plt.savefig('../results/vad_freq.png') 302 | exit() 303 | 304 | def forward(self, x, vad, each_azi, iter_num, epoch, LOCATA=False): 305 | 306 | ###### irtf feature 307 | x, vad_frame=self.irtf_featue(x, vad) 308 | 309 | # self.vad_framing(vad) 310 | 311 | 312 | 313 | if LOCATA: 314 | target=self.stft_model.azimuth_strided(vad_frame, each_azi).unsqueeze(0) 315 | else: 316 | target=self.make_target( vad_frame, each_azi, iter_num, epoch) 317 | # target=self.cjh_target(vad_frame, each_azi, iter_num, epoch) 318 | 319 | 320 | 321 | 322 | 323 | x=self.crn(x) 324 | 325 | 326 | 327 | 328 | 329 | 330 | return x, target, vad_frame 331 | 332 | if __name__=='__main__': 333 | device='cuda' 334 | yaml_file=load_yaml('./config/train.yaml')['model']['structure'] 335 | 336 | 337 | 338 | 339 | model=main_model(yaml_file).eval().to(device) 340 | 341 | batch=2 342 | length=64000 343 | with torch.no_grad(): 344 | for i in range(1): 345 | mixture=torch.randn((batch, 8, length)).to(device) 346 | target=torch.randn((batch, 2, length)).to(device) 347 | azi=torch.tensor([180, 355]).to(device).unsqueeze(0) 348 | azi=azi.repeat_interleave(batch, dim=0) 349 | 350 | output=model(mixture, target, azi) 351 | 352 | -------------------------------------------------------------------------------- /third_year/src/make_test_set.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import util 3 | import torch 4 | import numpy as np 5 | import random 6 | import importlib 7 | import math 8 | import wandb 9 | from tqdm import tqdm 10 | from dataloader.data_loader_for_db_make import Test_data_maker_load, val_data_maker_load 11 | import matplotlib.pyplot as plt 12 | import pandas as pd 13 | 14 | class Hyparam_set(): 15 | 16 | def __init__(self, args): 17 | self.args=args 18 | 19 | 20 | def set_torch_method(self,): 21 | try: 22 | torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=False) # spawn 23 | except: 24 | torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=True) # spawn 25 | 26 | 27 | def randomseed_init(self,): 28 | np.random.seed(self.args['hyparam']['randomseed']) 29 | random.seed(self.args['hyparam']['randomseed']) 30 | torch.manual_seed(self.args['hyparam']['randomseed']) 31 | if torch.cuda.is_available(): 32 | torch.cuda.manual_seed(self.args['hyparam']['randomseed']) 33 | 34 | device_primary_num=self.args['hyparam']['GPGPU']['device_ids'][0] 35 | device= 'cuda'+':'+str(device_primary_num) 36 | else: 37 | device= 'cpu' 38 | self.args['hyparam']['GPGPU']['device']=device 39 | return device 40 | def set_on(self): 41 | self.set_torch_method() 42 | self.device=self.randomseed_init() 43 | 44 | return self.args 45 | 46 | class Learner_config(): 47 | def __init__(self, args) -> None: 48 | self.args=args 49 | 50 | def memory_delete(self, *args): 51 | for a in args: 52 | del a 53 | 54 | def model_select(self): 55 | model_name=self.args['model']['name'] 56 | model_import='models.'+model_name+'.main' 57 | 58 | 59 | 60 | self.model=None 61 | 62 | 63 | 64 | def init_optimizer(self): 65 | 66 | a=importlib.import_module('torch.optim') 67 | assert hasattr(a, self.args['learner']['optimizer']['type']), "optimizer {} is not in {}".format(self.args['learner']['optimizer']['type'], 'torch') 68 | a=getattr(a, self.args['learner']['optimizer']['type']) 69 | 70 | 
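# The optimizer is resolved by name via reflection: importlib loads
# torch.optim, getattr fetches the class named in the learner yaml (the
# assert above fails fast on a bad name), and the next line instantiates it
# with the kwargs under ['optimizer']['config'].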
self.optimizer=a(self.model.parameters(), **self.args['learner']['optimizer']['config']) 71 | self.gradient_clip=self.args['learner']['optimizer']['gradient_clip'] 72 | 73 | 74 | def init_optimzer_scheduler(self, ): 75 | a=importlib.import_module('torch.optim.lr_scheduler') 76 | assert hasattr(a, self.args['learner']['optimizer_scheduler']['type']), "optimizer scheduler {} is not in {}".format(self.args['learner']['optimizer']['type'], 'torch') 77 | a=getattr(a, self.args['learner']['optimizer_scheduler']['type']) 78 | 79 | self.optimizer_scheduler=a(self.optimizer, **self.args['learner']['optimizer_scheduler']['config']) 80 | 81 | 82 | 83 | def init_loss_func(self): 84 | 85 | 86 | 87 | if self.args['learner']['loss']['type']=='BCEWithLogitsLoss': 88 | self.loss_func=torch.nn.modules.loss.BCEWithLogitsLoss(reduction='none') 89 | self.loss_func=torch.nn.modules.loss.BCELoss(reduction='none') 90 | 91 | 92 | elif self.args['learner']['loss']['type']=='kld': 93 | self.loss_func=torch.nn.modules.loss.KLDivLoss(reduction='none') 94 | elif self.args['learner']['loss']['type']=='mse': 95 | self.loss_func=torch.nn.modules.loss.MSELoss(reduction='none') 96 | 97 | self.loss_train_map_num=self.args['learner']['loss']['option']['train_map_num'] 98 | self.loss_weight=self.args['learner']['loss']['option']['each_layer_weight'] 99 | 100 | if self.args['learner']['loss']['optimize_method']=='min': 101 | self.best_val_loss=math.inf 102 | self.best_train_loss=math.inf 103 | else: 104 | self.best_val_loss=-math.inf 105 | self.best_train_loss=-math.inf 106 | 107 | def train_update(self, output, target): 108 | target=target[:, self.loss_train_map_num] 109 | output=output[:, self.loss_train_map_num].sigmoid() 110 | loss=self.loss_func(output, target) 111 | 112 | for j in range(len(self.loss_weight)): 113 | loss[:, j]=loss[:,j]*self.loss_weight[j] 114 | 115 | 116 | 117 | loss_mean=loss.mean() 118 | 119 | 120 | if torch.isnan(loss_mean): 121 | print('nan occured') 122 | self.optimizer.zero_grad() 123 | return loss_mean 124 | 125 | loss_mean.backward() 126 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip) 127 | self.optimizer.step() 128 | self.optimizer.zero_grad() 129 | 130 | return loss_mean 131 | 132 | def test_update(self, output, target): 133 | 134 | 135 | target=target[:, self.loss_train_map_num] 136 | output=output[:, self.loss_train_map_num].sigmoid() 137 | loss=self.loss_func(output, target) 138 | 139 | for j in range(len(self.loss_weight)): 140 | loss[:, j]=loss[:,j]*self.loss_weight[j] 141 | loss_mean=loss.mean() 142 | 143 | 144 | if torch.isnan(loss_mean): 145 | print('nan occured') 146 | self.optimizer.zero_grad() 147 | return loss_mean 148 | 149 | return loss_mean 150 | 151 | def config(self): 152 | self.device=self.args['hyparam']['GPGPU']['device'] 153 | 154 | return self.args 155 | 156 | class Logger_config(): 157 | def __init__(self, args) -> None: 158 | self.args=args 159 | self.csv=dict() 160 | self.csv['train_epoch_loss']=[] 161 | self.csv['train_best_loss']=[] 162 | self.csv['test_epoch_loss']=[] 163 | self.csv['test_best_loss']=[] 164 | 165 | self.csv_dir=self.args['logger']['save_csv'] 166 | self.model_save_dir=self.args['logger']['model_save_dir'] 167 | self.png_dir=self.args['logger']['png_dir'] 168 | 169 | if self.args['logger']['optimize_method']=='min': 170 | self.best_test_loss=math.inf 171 | self.best_train_loss=math.inf 172 | else: 173 | self.best_test_loss=-math.inf 174 | self.best_train_loss=-math.inf 175 | 176 | def train_iter_log(self, loss): 177 | 
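# wandb calls here are wrapped in try/except so logging failures (e.g.
# wandb never initialised for a local run) cannot interrupt training; the
# loss is still accumulated locally for the epoch average either way.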
try: 178 | wandb.log({'train_iter_loss':loss}) 179 | except: 180 | None 181 | self.epoch_train_loss.append(loss.cpu().detach().item()) 182 | 183 | 184 | def train_epoch_log(self): 185 | loss_mean=np.array(self.epoch_train_loss).mean() 186 | 187 | self.csv['train_epoch_loss'].append(loss_mean) 188 | 189 | if self.best_train_loss > loss_mean: 190 | self.best_train_loss = loss_mean 191 | 192 | try: 193 | wandb.log({'train_epoch_loss':loss_mean}) 194 | wandb.log({'train_best_loss':self.best_train_loss}) 195 | except: 196 | None 197 | 198 | self.csv['train_best_loss'].append(self.best_train_loss) 199 | 200 | 201 | 202 | def test_iter_log(self, loss): 203 | try: 204 | wandb.log({'test_iter_loss':loss}) 205 | except: 206 | None 207 | self.epoch_test_loss.append(loss.cpu().detach().item()) 208 | 209 | def test_epoch_log(self, optimizer_scheduler): 210 | loss_mean=np.array(self.epoch_test_loss).mean() 211 | self.csv['test_epoch_loss'].append(loss_mean) 212 | 213 | self.model_save=False 214 | if self.best_test_loss > loss_mean: 215 | self.model_save=True 216 | self.best_test_loss = loss_mean 217 | try: 218 | wandb.log({'test_epoch_loss':loss_mean}) 219 | wandb.log({'test_best_loss':self.best_test_loss}) 220 | except: 221 | None 222 | self.csv['test_best_loss'].append(self.best_test_loss) 223 | 224 | optimizer_scheduler.step(loss_mean) 225 | 226 | 227 | 228 | 229 | def epoch_init(self,): 230 | self.epoch_train_loss=[] 231 | self.epoch_test_loss=[] 232 | 233 | 234 | def epoch_finish(self, epoch, model, optimizer): 235 | os.makedirs(os.path.dirname(self.csv_dir), exist_ok=True) 236 | pd.DataFrame(self.csv).to_csv(self.csv_dir) 237 | 238 | checkpoint = { 239 | 'epoch': epoch, 240 | 'model_state_dict': model.module.state_dict(), 241 | 'optimizer': optimizer.state_dict() 242 | } 243 | 244 | os.makedirs(os.path.dirname(self.model_save_dir + "best_model.tar"), exist_ok=True) 245 | if self.model_save: 246 | os.makedirs(os.path.dirname(self.model_save_dir + "best_model.tar"), exist_ok=True) 247 | torch.save(checkpoint, self.model_save_dir + "best_model.tar") 248 | print("new best model\n") 249 | torch.save(checkpoint, self.model_save_dir + "{}_model.tar".format(epoch)) 250 | 251 | 252 | util.util.draw_result_pic(self.png_dir, epoch, self.csv['train_epoch_loss'], self.csv['test_epoch_loss']) 253 | 254 | 255 | 256 | def wandb_config(self): 257 | 258 | return self.args 259 | 260 | def config(self,): 261 | self.wandb_config() 262 | return self.args 263 | 264 | 265 | class Dataloader_config(): 266 | def __init__(self, args) -> None: 267 | self.args=args 268 | 269 | 270 | def config(self): 271 | self.test_loader=Test_data_maker_load(self.args['dataloader']['test']) 272 | self.val_loader=val_data_maker_load(self.args['dataloader']['val']) 273 | 274 | return self.args 275 | 276 | 277 | 278 | 279 | class Trainer(): 280 | 281 | def __init__(self, args): 282 | 283 | # self.temp() 284 | self.args=args 285 | 286 | self.hyperparameter=Hyparam_set(self.args) 287 | self.args=self.hyperparameter.set_on() 288 | 289 | 290 | self.learner=Learner_config(self.args) 291 | self.args=self.learner.config() 292 | 293 | 294 | self.dataloader=Dataloader_config(self.args) 295 | self.args=self.dataloader.config() 296 | 297 | self.logger=Logger_config(self.args) 298 | self.args=self.logger.config() 299 | 300 | 301 | 302 | def run(self, ): 303 | 304 | self.validation(0) 305 | 306 | self.test(0) 307 | 308 | def validation(self, epoch): 309 | 310 | 311 | with torch.no_grad(): 312 | for iter_num, (mixed, vad, speech_azi, num_spk) in 
enumerate(tqdm(self.dataloader.val_loader , desc='Test', total=len(self.dataloader.val_loader), )): 313 | # break 314 | continue 315 | 316 | 317 | # pkl_csv='./metadata/val_csv_linear_8.csv' 318 | # pkl_csv='./metadata/val_csv_ellipsoid_6.csv' 319 | pkl_csv='./metadata/val_csv_circular_4.csv' 320 | 321 | pkl_dir=self.dataloader.val_loader.dataset.pkl_dir 322 | pkl_list=os.listdir(pkl_dir) 323 | df={} 324 | df['data']=pkl_list 325 | pd.DataFrame(df).to_csv(pkl_csv) 326 | 327 | 328 | 329 | def test(self, epoch): 330 | # self.model.eval() 331 | with torch.no_grad(): 332 | for iter_num, (mixed, vad, speech_azi, num_spk) in enumerate(tqdm(self.dataloader.test_loader, desc='Test', total=len(self.dataloader.test_loader), )): 333 | # break 334 | continue 335 | 336 | # pkl_csv='./metadata/test_csv_linear_8.csv' 337 | # pkl_csv='./metadata/test_csv_ellipsoid_6.csv' 338 | pkl_csv='./metadata/test_csv_circular_4.csv' 339 | 340 | pkl_dir=self.dataloader.test_loader.dataset.pkl_dir 341 | pkl_list=os.listdir(pkl_dir) 342 | df={} 343 | df['data']=pkl_list 344 | pd.DataFrame(df).to_csv(pkl_csv) 345 | 346 | 347 | 348 | 349 | 350 | 351 | if __name__=='__main__': 352 | args=sys.argv[1:] 353 | 354 | args=util.util.get_yaml_args(args) 355 | t=Trainer(args) 356 | t.run() -------------------------------------------------------------------------------- /second_year/src/train.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import util 3 | import torch 4 | import numpy as np 5 | import random 6 | import importlib 7 | import math 8 | import wandb 9 | from tqdm import tqdm 10 | from dataloader.data_loader import Train_dataload, Eval_dataload 11 | import matplotlib.pyplot as plt 12 | import pandas as pd 13 | 14 | class Hyparam_set(): 15 | 16 | def __init__(self, args): 17 | self.args=args 18 | 19 | 20 | def set_torch_method(self,): 21 | try: 22 | torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=False) # spawn 23 | except: 24 | torch.multiprocessing.set_start_method(self.args['hyparam']['torch_start_method'], force=True) # spawn 25 | 26 | 27 | def randomseed_init(self,): 28 | np.random.seed(self.args['hyparam']['randomseed']) 29 | random.seed(self.args['hyparam']['randomseed']) 30 | torch.manual_seed(self.args['hyparam']['randomseed']) 31 | if torch.cuda.is_available(): 32 | torch.cuda.manual_seed(self.args['hyparam']['randomseed']) 33 | 34 | device_primary_num=self.args['hyparam']['GPGPU']['device_ids'][0] 35 | device= 'cuda'+':'+str(device_primary_num) 36 | else: 37 | device= 'cpu' 38 | self.args['hyparam']['GPGPU']['device']=device 39 | return device 40 | 41 | def set_on(self): 42 | self.set_torch_method() 43 | self.device=self.randomseed_init() 44 | 45 | return self.args 46 | 47 | class Learner_config(): 48 | def __init__(self, args) -> None: 49 | self.args=args 50 | 51 | def memory_delete(self, *args): 52 | for a in args: 53 | del a 54 | 55 | def model_select(self): 56 | model_name=self.args['model']['name'] 57 | model_import='models.'+model_name+'.main' 58 | 59 | 60 | model_dir=importlib.import_module(model_import) 61 | self.model=model_dir.get_model(self.args['model']).to(self.device) 62 | self.model=torch.nn.DataParallel(self.model, self.args['hyparam']['GPGPU']['device_ids']) 63 | 64 | 65 | def init_optimizer(self): 66 | 67 | a=importlib.import_module('torch.optim') 68 | assert hasattr(a, self.args['learner']['optimizer']['type']), "optimizer {} is not in {}".format(self.args['learner']['optimizer']['type'], 
'torch') 69 | a=getattr(a, self.args['learner']['optimizer']['type']) 70 | self.optimizer=a(self.model.parameters(), **self.args['learner']['optimizer']['config']) 71 | self.gradient_clip=self.args['learner']['optimizer']['gradient_clip'] 72 | 73 | 74 | def init_optimzer_scheduler(self, ): 75 | a=importlib.import_module('torch.optim.lr_scheduler') 76 | assert hasattr(a, self.args['learner']['optimizer_scheduler']['type']), "optimizer scheduler {} is not in {}".format(self.args['learner']['optimizer']['type'], 'torch') 77 | a=getattr(a, self.args['learner']['optimizer_scheduler']['type']) 78 | self.optimizer_scheduler=a(self.optimizer, **self.args['learner']['optimizer_scheduler']['config']) 79 | 80 | def init_loss_func(self): 81 | if self.args['learner']['loss']['type']=='weighted_bce': 82 | from loss.bce_loss import weighted_binary_cross_entropy 83 | self.loss_func=weighted_binary_cross_entropy(**self.args['learner']['loss']['option']) 84 | elif self.args['learner']['loss']['type']=='BCEWithLogitsLoss': 85 | self.loss_func=torch.nn.modules.loss.BCEWithLogitsLoss(reduction='none') 86 | self.loss_func=torch.nn.modules.loss.BCELoss(reduction='none') 87 | 88 | elif self.args['learner']['loss']['type']=='kld': 89 | self.loss_func=torch.nn.modules.loss.KLDivLoss(reduction='none') 90 | elif self.args['learner']['loss']['type']=='mse': 91 | self.loss_func=torch.nn.modules.loss.MSELoss(reduction='none') 92 | 93 | elif self.args['learner']['loss']['type']=='SI-SDR': 94 | from asteroid.losses.sdr import SingleSrcNegSDR 95 | self.loss_func=SingleSrcNegSDR('sisdr') 96 | 97 | self.loss_train_map_num=self.args['learner']['loss']['option']['train_map_num'] 98 | self.loss_weight=self.args['learner']['loss']['option']['each_layer_weight'] 99 | 100 | if self.args['learner']['loss']['optimize_method']=='min': 101 | self.best_val_loss=math.inf 102 | self.best_train_loss=math.inf 103 | else: 104 | self.best_val_loss=-math.inf 105 | self.best_train_loss=-math.inf 106 | 107 | def train_update(self, output, target): 108 | 109 | loss=self.loss_func(output, target) 110 | 111 | loss_mean=loss.mean() 112 | 113 | if torch.isnan(loss_mean): 114 | print('nan occured') 115 | self.optimizer.zero_grad() 116 | return loss_mean 117 | 118 | loss_mean.backward() 119 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip) 120 | self.optimizer.step() 121 | self.optimizer.zero_grad() 122 | 123 | return loss_mean 124 | 125 | def test_update(self, output, target): 126 | 127 | loss=self.loss_func(output, target) 128 | loss_mean=loss.mean() 129 | 130 | 131 | if torch.isnan(loss_mean): 132 | print('nan occured') 133 | self.optimizer.zero_grad() 134 | return loss_mean 135 | 136 | return loss_mean 137 | 138 | def config(self): 139 | self.device=self.args['hyparam']['GPGPU']['device'] 140 | self.model_select() 141 | self.init_optimizer() 142 | self.init_optimzer_scheduler() 143 | self.init_loss_func() 144 | return self.args 145 | 146 | class Logger_config(): 147 | def __init__(self, args) -> None: 148 | self.args=args 149 | self.csv=dict() 150 | self.csv['train_epoch_loss']=[] 151 | self.csv['train_best_loss']=[] 152 | self.csv['test_epoch_loss']=[] 153 | self.csv['test_best_loss']=[] 154 | 155 | self.csv_dir=self.args['logger']['save_csv'] 156 | self.model_save_dir=self.args['logger']['model_save_dir'] 157 | self.png_dir=self.args['logger']['png_dir'] 158 | 159 | if self.args['logger']['optimize_method']=='min': 160 | self.best_test_loss=math.inf 161 | self.best_train_loss=math.inf 162 | else: 163 | 
self.best_test_loss=-math.inf 164 | self.best_train_loss=-math.inf 165 | 166 | def train_iter_log(self, loss): 167 | try: 168 | wandb.log({'train_iter_loss':loss}) 169 | except: 170 | None 171 | self.epoch_train_loss.append(loss.cpu().detach().item()) 172 | 173 | def train_epoch_log(self): 174 | loss_mean=np.array(self.epoch_train_loss).mean() 175 | 176 | self.csv['train_epoch_loss'].append(loss_mean) 177 | 178 | if self.best_train_loss > loss_mean: 179 | self.best_train_loss = loss_mean 180 | 181 | try: 182 | wandb.log({'train_epoch_loss':loss_mean}) 183 | wandb.log({'train_best_loss':self.best_train_loss}) 184 | except: 185 | None 186 | 187 | self.csv['train_best_loss'].append(self.best_train_loss) 188 | 189 | 190 | 191 | def test_iter_log(self, loss): 192 | try: 193 | wandb.log({'test_iter_loss':loss}) 194 | except: 195 | None 196 | self.epoch_test_loss.append(loss.cpu().detach().item()) 197 | 198 | def test_epoch_log(self, optimizer_scheduler): 199 | loss_mean=np.array(self.epoch_test_loss).mean() 200 | self.csv['test_epoch_loss'].append(loss_mean) 201 | 202 | self.model_save=False 203 | if self.best_test_loss > loss_mean: 204 | self.model_save=True 205 | self.best_test_loss = loss_mean 206 | try: 207 | wandb.log({'test_epoch_loss':loss_mean}) 208 | wandb.log({'test_best_loss':self.best_test_loss}) 209 | except: 210 | None 211 | self.csv['test_best_loss'].append(self.best_test_loss) 212 | optimizer_scheduler.step(loss_mean) 213 | 214 | 215 | 216 | 217 | def epoch_init(self,): 218 | self.epoch_train_loss=[] 219 | self.epoch_test_loss=[] 220 | 221 | 222 | def epoch_finish(self, epoch, model, optimizer): 223 | os.makedirs(os.path.dirname(self.csv_dir), exist_ok=True) 224 | pd.DataFrame(self.csv).to_csv(self.csv_dir) 225 | 226 | checkpoint = { 227 | 'epoch': epoch, 228 | 'model_state_dict': model.module.state_dict(), 229 | 'optimizer': optimizer.state_dict() 230 | } 231 | 232 | os.makedirs(os.path.dirname(self.model_save_dir + "best_model.tar"), exist_ok=True) 233 | if self.model_save: 234 | os.makedirs(os.path.dirname(self.model_save_dir + "best_model.tar"), exist_ok=True) 235 | torch.save(checkpoint, self.model_save_dir + "best_model.tar") 236 | print("new best model\n") 237 | torch.save(checkpoint, self.model_save_dir + "{}_model.tar".format(epoch)) 238 | 239 | 240 | util.util.draw_result_pic(self.png_dir, epoch, self.csv['train_epoch_loss'], self.csv['test_epoch_loss']) 241 | 242 | 243 | 244 | def wandb_config(self): 245 | if self.args['logger']['wandb']['wandb_ok']: 246 | wandb.init(**self.args['logger']['wandb']['init']) 247 | return self.args 248 | 249 | def config(self,): 250 | self.wandb_config() 251 | return self.args 252 | 253 | 254 | class Dataloader_config(): 255 | def __init__(self, args) -> None: 256 | self.args=args 257 | 258 | 259 | def config(self): 260 | 261 | self.train_loader=Train_dataload(self.args['dataloader']['train'], self.args['hyparam']['randomseed']) 262 | self.test_loader=Eval_dataload(self.args['dataloader']['test']) 263 | 264 | return self.args 265 | 266 | 267 | 268 | 269 | class Trainer(): 270 | 271 | def __init__(self, args): 272 | 273 | 274 | self.args=args 275 | 276 | self.hyperparameter=Hyparam_set(self.args) 277 | self.args=self.hyperparameter.set_on() 278 | 279 | 280 | self.learner=Learner_config(self.args) 281 | self.args=self.learner.config() 282 | 283 | self.model=self.learner.model 284 | self.optimizer=self.learner.optimizer 285 | self.optimizer_scheduler=self.learner.optimizer_scheduler 286 | 287 | self.dataloader=Dataloader_config(self.args) 
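# Config-chaining pattern (annotation): each component — Hyparam_set,
# Learner_config, Dataloader_config, Logger_config — receives args,
# configures itself, and hands the (possibly updated) args dict back, so
# later components see everything set by earlier ones, e.g. the device
# string written by Hyparam_set.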
288 | self.args=self.dataloader.config() 289 | 290 | self.logger=Logger_config(self.args) 291 | self.args=self.logger.config() 292 | 293 | 294 | 295 | def run(self, ): 296 | 297 | 298 | 299 | for epoch in range(self.args['hyparam']['resume_epoch'], self.args['hyparam']['last_epoch']): 300 | 301 | 302 | self.logger.epoch_init() 303 | 304 | 305 | self.train(epoch) 306 | self.test(epoch) 307 | 308 | self.logger.epoch_finish(epoch, self.model, self.optimizer) 309 | 310 | 311 | def train(self, epoch): 312 | 313 | ######## train init 314 | self.model.train() 315 | 316 | self.optimizer.zero_grad() 317 | 318 | for iter_num, (mixed, target) in enumerate(tqdm(self.dataloader.train_loader, desc='Train {}'.format(epoch), total=len(self.dataloader.train_loader), )): 319 | 320 | mixed=mixed.to(self.hyperparameter.device) 321 | target=target.to(self.hyperparameter.device) 322 | 323 | 324 | 325 | out=self.model(mixed) 326 | 327 | 328 | loss=self.learner.train_update(out, target) 329 | 330 | self.logger.train_iter_log(loss) 331 | self.learner.memory_delete([mixed, out, target, loss]) 332 | 333 | self.logger.train_epoch_log() 334 | 335 | 336 | def test(self, epoch): 337 | self.model.eval() 338 | with torch.no_grad(): 339 | for iter_num, (mixed, target) in enumerate(tqdm(self.dataloader.test_loader, desc='Test', total=len(self.dataloader.test_loader), )): 340 | 341 | mixed=mixed.to(self.hyperparameter.device) 342 | target=target.to(self.hyperparameter.device) 343 | 344 | 345 | 346 | out=self.model(mixed) 347 | 348 | 349 | 350 | loss=self.learner.test_update( out, target) 351 | self.logger.test_iter_log(loss) 352 | 353 | self.learner.memory_delete([mixed, out, target, loss]) 354 | 355 | self.logger.test_epoch_log(self.optimizer_scheduler) 356 | 357 | 358 | 359 | 360 | 361 | if __name__=='__main__': 362 | args=sys.argv[1:] 363 | 364 | args=util.util.get_yaml_args(args) 365 | t=Trainer(args) 366 | t.run() --------------------------------------------------------------------------------
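A minimal launch sketch for second_year/src/train.py, assuming the yaml keys match those read in Trainer (hyparam, model, learner, logger, dataloader). get_yaml_args expects each CLI argument to be a "<key> <path>" pair split on a single space, which is presumably how run_trainer.sh passes them; the paths below are illustrative placeholders, not taken from the repo.

# Hypothetical invocation sketch — equivalent shell call:
#   python train.py "hyparam hyparam/train.yaml" "learner hyparam/learner.yaml" ...
import util

yaml_pairs = [
    'hyparam hyparam/train.yaml',                    # seeds, GPU ids, epoch range
    'learner hyparam/learner.yaml',                  # optimizer / scheduler / loss
    'logger hyparam/logger.yaml',                    # csv, checkpoint, png dirs, wandb
    'model models/convtasnet_SSL_FiLM/model.yaml',   # model name + structure
    'dataloader dataloader/dataloader_train.yaml',   # must hold 'train' and 'test' keys
]
args = util.util.get_yaml_args(yaml_pairs)  # -> {'hyparam': {...}, 'model': {...}, ...}
# Trainer(args).run() would then train and evaluate per the loaded configs.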