├── config
│   ├── __init__.py
│   ├── option.py
│   └── train.yml
├── logger
│   ├── __init__.py
│   └── set_logger.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
├── model
│   ├── __init__.py
│   ├── loss.py
│   ├── model.py
│   └── torch_utils.py
├── utils
│   ├── __init__.py
│   ├── stft_istft.py
│   └── util.py
├── data_loader
│   ├── __init__.py
│   ├── AudioData.py
│   └── Dataloader.py
├── cmvn.ark
├── .gitignore
├── README.md
├── create_scp.py
├── train.py
└── utils.py
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
from .option import *
--------------------------------------------------------------------------------
/logger/__init__.py:
--------------------------------------------------------------------------------
from .set_logger import *
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
from .trainer import *
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
from .loss import *
from .model import *
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from .util import *
from .stft_istft import *
--------------------------------------------------------------------------------
/data_loader/__init__.py:
--------------------------------------------------------------------------------
from .AudioData import *
from .Dataloader import *
--------------------------------------------------------------------------------
/cmvn.ark:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JusperLee/DANet-For-Speech-Separation/HEAD/cmvn.ark
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
.vscode/settings.json
data_loader/.ipynb_checkpoints/Dataloader-checkpoint.py
config/.ipynb_checkpoints/train-checkpoint.yml
trainer/.ipynb_checkpoints/trainer-checkpoint.py
--------------------------------------------------------------------------------
/config/option.py:
--------------------------------------------------------------------------------
import yaml


def parse(opt_path):
    with open(opt_path, mode='r') as f:
        opt = yaml.load(f, Loader=yaml.FullLoader)

    opt['resume']['path'] = opt['resume']['path'] + '/' + opt['name']
    opt['logger']['path'] = opt['logger']['path'] + '/' + opt['name']
    return opt


if __name__ == "__main__":
    parse('train.yml')
--------------------------------------------------------------------------------
/logger/set_logger.py:
--------------------------------------------------------------------------------
import logging
from datetime import datetime
import os


def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')


def setup_logger(logger_name, root, level=logging.INFO, screen=False, tofile=False):
    '''set up logger'''
    lg = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    lg.setLevel(level)
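    # make sure the log directory exists before any FileHandler is created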
    os.makedirs(root, exist_ok=True)
    if tofile:
        log_file = os.path.join(root, '_{}.log'.format(get_timestamp()))
        fh = logging.FileHandler(log_file, mode='w')
        fh.setFormatter(formatter)
        lg.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        lg.addHandler(sh)


if __name__ == "__main__":
    setup_logger('base', 'root', level=logging.INFO, screen=True, tofile=False)
    logger = logging.getLogger('base')
    logger.info('hello')
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')

import torch


class Loss(object):
    """
    MSE as the training objective. The mask estimation loss is calculated.
    You can also change it into the spectrogram estimation loss, i.e. the
    MSE between the clean source spectrograms and the masked mixture
    spectrograms.

    mix_samp: the spectrogram of the mixture;
              shape: (B, T, F)

    wf: the target masks, which are the Wiener-filter-like masks here;
        shape: (B, T*F, nspk)

    mask: the estimated masks generated by the network;
          shape: (B, T*F, nspk)
    """

    def __init__(self, mix_samp, wf, mask):
        self.mix_samp = mix_samp
        self.wf = wf
        self.mask = mask

    def loss(self):
        # To switch to the spectrogram estimation loss, weight the mask
        # difference by the mixture spectrogram instead:
        # loss = self.mix_samp.view(-1, self.mix_samp.size(1)*self.mix_samp.size(2), 1).expand(
        #     self.mix_samp.size(0), self.mix_samp.size(1), self.wf.size(2)) * (self.wf - self.mask)
        loss = self.wf - self.mask
        loss = loss.view(-1, loss.size(1)*loss.size(2))
        return torch.mean(torch.sum(torch.pow(loss, 2), 1))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DANet-For-Speech-Separation

A PyTorch implementation of DANet for speech separation.

> Chen Z, Luo Y, Mesgarani N. Deep attractor network for single-microphone speaker separation[C]//2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2017: 246-250.

## Requirements

- **PyTorch 0.4.0**
- **librosa 0.7.1**
- **PyYAML 5.1.2**

Due to a multi-GPU parallelism problem with RNNs, only PyTorch 0.4.0 is supported.

## Training steps

1. First, use the create_scp script to generate the training and test scp files:

```shell
python create_scp.py
```

2. Then, to reduce the mismatch between the training and test environments, run the util script to generate a feature-normalization file (CMVN):

```shell
python ./utils/util.py
```

3. Finally, train the network with the following command:

```shell
python train.py --opt ./config/train.yml
```

The model code draws on [DANet](https://github.com/naplab/DANet "DANet") from naplab. In my experiments the loss failed to decrease, and I was unable to identify the cause, so only the training code is released, for reference only.
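## Notes

`create_scp.py` and `config/train.yml` hard-code the dataset, scp and checkpoint paths under `/home/likai/...`; edit them to match your own wsj0-2mix layout before running. Each generated scp file is a plain two-column `key path` list, one utterance per line, for example (filenames here are only illustrative):

```
mix_0001.wav /your/path/to/wsj0-mix/2speakers/wav8k/min/tr/mix/mix_0001.wav
mix_0002.wav /your/path/to/wsj0-mix/2speakers/wav8k/min/tr/mix/mix_0002.wav
```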
--------------------------------------------------------------------------------
/utils/stft_istft.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')

import librosa
import numpy as np
from utils import util


class STFT(object):
    '''
    STFT/iSTFT using the librosa implementation
    window: a window function
    nfft: length of the windowed signal after padding with zeros
    window_length: length of the window function
    hop_length: number of audio samples between adjacent STFT columns
    '''

    def __init__(self, window='hann', nfft=256, window_length=256, hop_length=64, center=False):
        self.window = window
        self.nfft = nfft
        self.window_length = window_length
        self.hop_length = hop_length
        self.center = center

    def stft(self, samp, is_mag=False, is_log=False):
        # is_mag: whether to return the magnitude instead of the complex result
        stft_r = librosa.stft(samp, n_fft=self.nfft, hop_length=self.hop_length,
                              win_length=self.window_length, window=self.window, center=self.center)
        stft_r = np.transpose(stft_r)
        if is_mag:
            stft_r = np.abs(stft_r)
        if is_log:
            min_z = np.finfo(float).eps
            stft_r = np.log(np.maximum(stft_r, min_z))
        return stft_r

    def istft(self, stft_samp):
        stft_samp = np.transpose(stft_samp)
        output = librosa.istft(stft_samp, hop_length=self.hop_length,
                               win_length=self.window_length, window=self.window, center=self.center)
        return output


if __name__ == "__main__":
    samp = util.read_wav('../1.wav')
    stft_i = STFT()
    stft = stft_i.stft(samp, is_mag=True, is_log=True)
    print(stft)
--------------------------------------------------------------------------------
/config/train.yml:
--------------------------------------------------------------------------------
#### general settings
name: DANet
use_tb_logger: true
num_spks: 2

#### datasets
datasets:
  train:
    dataroot_mix: /home/likai/data1/create_scp/tr_mix.scp
    dataroot_targets: [/home/likai/data1/create_scp/tr_s1.scp, /home/likai/data1/create_scp/tr_s2.scp]

  val:
    dataroot_mix: /home/likai/data1/create_scp/cv_mix.scp
    dataroot_targets: [/home/likai/data1/create_scp/cv_s1.scp, /home/likai/data1/create_scp/cv_s2.scp]

  dataloader_setting:
    shuffle: true
    num_workers: 10  # per GPU
    batch_size: 1
    cmvn_file: /home/likai/data1/DANet/cmvn.ark

  audio_setting:
    window: hann
    nfft: 256
    window_length: 256
    hop_length: 64
    center: False
    is_mag: True   # abs(tf-domain)
    is_log: True   # log(tf-domain)
    chunk_size: 32000
    least: 16000

#### network structures
DANet:
  name: LSTM        # RNN, LSTM, GRU
  num_layer: 4
  input_size: 129   # nfft/2+1
  hidden_cells: 300
  emb_D: 20
  dropout: 0.5
  batch_first: true
  bidirectional: true
  activation: Tanh

#### training settings: learning rate scheme, loss
train:
  epoch: 100
  early_stop: 10
  path: /home/likai/data1/DANet/checkpoint
  gpuid: [0,1,2,3,4,5,6,7]

#### Optimizer settings
optim:
  name: Adam        ### Adam, RMSprop, SGD
  lr: 1.0e-3
  momentum: 0.9
  weight_decay: 0
  clip_norm: 200

#### Resume training settings
resume:
  state: false
  path: /home/likai/data1/DANet/checkpoint

#### logger
logger:
  name: DANet
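  # config/option.py appends the top-level experiment name to this path, so logs end up in <path>/DANet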
  path: /home/likai/data1/DANet/checkpoint
  screen: true
  tofile: false
  print_freq: 1000
--------------------------------------------------------------------------------
/data_loader/AudioData.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')

import utils.util as ut
from utils.stft_istft import STFT
import numpy as np


class AudioData(object):
    '''
    Loading wave files
    scp_file: the scp file path
    the other kwargs are the STFT's kwargs
    is_mag: if True, abs(stft)
    '''

    def __init__(self, scp_file, window='hann', nfft=256, window_length=256, hop_length=64, center=False, is_mag=True, is_log=True, chunk_size=32000, least=16000):
        self.wave = ut.read_scp(scp_file)
        self.wave_keys = [key for key in self.wave.keys()]
        self.STFT = STFT(window=window, nfft=nfft,
                         window_length=window_length, hop_length=hop_length, center=center)
        self.is_mag = is_mag
        self.is_log = is_log
        self.samp_list = []
        self.samp_stft = []
        self.chunk_size = chunk_size
        self.least = least
        self.split()
        self.stft()

    def __len__(self):
        return len(self.wave_keys)

    def split(self):
        for key in self.wave_keys:
            wave_path = self.wave[key]
            samp = ut.read_wav(wave_path)
            length = samp.shape[0]
            if length < self.least:
                continue
            if length < self.chunk_size:
                gap = self.chunk_size - length
                samp = np.pad(samp, (0, gap), mode='constant')
                self.samp_list.append(samp)
            else:
                chunk_start = 0
                while True:
                    if chunk_start + self.chunk_size > length:
                        break
                    self.samp_list.append(
                        samp[chunk_start:chunk_start + self.chunk_size])
                    chunk_start += self.least

    def stft(self):
        for samp in self.samp_list:
            # use the configured is_mag/is_log instead of hard-coded values
            self.samp_stft.append(self.STFT.stft(
                samp, is_mag=self.is_mag, is_log=self.is_log))

    def __iter__(self):
        for stft in self.samp_stft:
            yield stft

    def __getitem__(self, index):
        return self.samp_stft[index]


if __name__ == "__main__":
    ad = AudioData("/home/likai/Desktop/create_scp/cv_mix.scp",
                   is_mag=True, is_log=True)
    print(ad.samp_stft[0].shape)
    print(ad.samp_stft[100].shape)
--------------------------------------------------------------------------------
/create_scp.py:
--------------------------------------------------------------------------------
import os

train_mix_scp = 'tr_mix.scp'
train_s1_scp = 'tr_s1.scp'
train_s2_scp = 'tr_s2.scp'

test_mix_scp = 'tt_mix.scp'
test_s1_scp = 'tt_s1.scp'
test_s2_scp = 'tt_s2.scp'

train_mix = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/tr/mix'
train_s1 = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/tr/s1'
train_s2 = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/tr/s2'

test_mix = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/tt/mix'
test_s1 = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/tt/s1'
test_s2 = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/tt/s2'

tr_mix = open(train_mix_scp, 'w')
for root, dirs, files in os.walk(train_mix):
    files.sort()
    for file in files:
        tr_mix.write(file + " " + root + '/' + file)
        tr_mix.write('\n')

tr_s1 = open(train_s1_scp, 'w')
for root, dirs, files in os.walk(train_s1):
    files.sort()
    for file in files:
        tr_s1.write(file + " " + root + '/' + file)
        tr_s1.write('\n')

tr_s2 = open(train_s2_scp, 'w')
for root, dirs, files in os.walk(train_s2):
    files.sort()
    for file in files:
        tr_s2.write(file + " " + root + '/' + file)
        tr_s2.write('\n')

tt_mix = open(test_mix_scp, 'w')
for root, dirs, files in os.walk(test_mix):
    files.sort()
    for file in files:
        tt_mix.write(file + " " + root + '/' + file)
        tt_mix.write('\n')

tt_s1 = open(test_s1_scp, 'w')
for root, dirs, files in os.walk(test_s1):
    files.sort()
    for file in files:
        tt_s1.write(file + " " + root + '/' + file)
        tt_s1.write('\n')

tt_s2 = open(test_s2_scp, 'w')
for root, dirs, files in os.walk(test_s2):
    files.sort()
    for file in files:
        tt_s2.write(file + " " + root + '/' + file)
        tt_s2.write('\n')

cv_mix_scp = 'cv_mix.scp'
cv_s1_scp = 'cv_s1.scp'
cv_s2_scp = 'cv_s2.scp'

cv_mix = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/cv/mix'
cv_s1 = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/cv/s1'
cv_s2 = '/home/likai/likai/Dataset/wsj0-mix/2speakers/wav8k/min/cv/s2'

cv_mix_file = open(cv_mix_scp, 'w')
for root, dirs, files in os.walk(cv_mix):
    files.sort()
    for file in files:
        cv_mix_file.write(file + " " + root + '/' + file)
        cv_mix_file.write('\n')

cv_s1_file = open(cv_s1_scp, 'w')
for root, dirs, files in os.walk(cv_s1):
    files.sort()
    for file in files:
        cv_s1_file.write(file + " " + root + '/' + file)
        cv_s1_file.write('\n')

cv_s2_file = open(cv_s2_scp, 'w')
for root, dirs, files in os.walk(cv_s2):
    files.sort()
    for file in files:
        cv_s2_file.write(file + " " + root + '/' + file)
        cv_s2_file.write('\n')
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')

from config import option
import torch
import torch.nn as nn
import torch.nn.functional as Fun
from . import torch_utils
import warnings
warnings.filterwarnings('ignore')


class DANet(nn.Module):
    def __init__(self, name='LSTM',
                 num_layer=2, input_size=129,
                 hidden_cells=600, emb_D=40,
                 dropout=0.0, bidirectional=True,
                 batch_first=True, activation="Tanh"):
        super(DANet, self).__init__()
        # honor the configured RNN type and dropout instead of hard-coding them
        self.rnn = torch_utils.MultiRNN(name, input_size, hidden_cells,
                                        num_layers=num_layer,
                                        dropout=dropout,
                                        bidirectional=bidirectional)
        # the RNN output size doubles when the RNN is bidirectional
        rnn_out = hidden_cells * 2 if bidirectional else hidden_cells
        self.linear = torch_utils.FCLayer(rnn_out, input_size * emb_D,
                                          nonlinearity=activation.lower())

        self.input_size = input_size
        self.emb_D = emb_D
        self.eps = 1e-8

    def forward(self, input_list):
        """
        input: the input feature;
               shape: (B, T, F)

        ibm: the ideal binary mask used for calculating the
             ideal attractors;
             shape: (B, T*F, nspk)

        non_silent: the binary energy threshold matrix for masking
                    out T-F bins;
                    shape: (B, T*F, 1)
        """
        # self.rnn.flatten_parameters()
        input, ibm, non_silent, hidden = input_list
        B, T, F = input.shape
        # BT x H
        input, hidden = self.rnn(input, hidden)
        input = input.contiguous().view(-1, input.size(2))
        # BT x H -> BT x FD
        input = self.linear(input)
        # BT x FD -> B x TF x D
        v = input.view(-1, T*F, self.emb_D)
        # calculate the ideal attractors
        # first calculate the source assignment matrix Y
        # B x TF x nspk
        y = ibm * non_silent.expand_as(ibm)
        # attractors are the weighted average of the embeddings
        # calculated by V and Y
        # B x K x nspk
        v_y = torch.bmm(torch.transpose(v, 1, 2), y)
        # B x K x nspk
        sum_y = torch.sum(y, 1, keepdim=True).expand_as(v_y)
        # B x K x nspk
        attractor = v_y / (sum_y + self.eps)

        # calculate the distance between embeddings and attractors
        # and generate the masks
        # B x TF x nspk
        dist = v.bmm(attractor)
        # B x TF x nspk
        mask = Fun.softmax(dist, dim=2)
        return mask, hidden

    def init_hidden(self, batch_size):
        return self.rnn.init_hidden(batch_size)


if __name__ == "__main__":
    opt = option.parse('../config/train.yml')
    net = DANet(**opt['DANet'])
    input = torch.randn(5, 10, 129)
    ibm = torch.randint(2, (5, 1290, 2)).float()
    non_silent = torch.randn(5, 1290, 1)
    # forward expects a single list of (input, ibm, non_silent, hidden)
    hidden = net.init_hidden(5)
    mask, hidden = net([input, ibm, non_silent, hidden])
--------------------------------------------------------------------------------
/data_loader/Dataloader.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')

from data_loader import AudioData
import torch
from torch.utils.data import Dataset
from utils import util
import pickle
import numpy as np


def compute_mask(mixture, targets_list, mask_type):
    """
    Arguments:
        mixture: STFT of the mixture signal (complex result)
        targets_list: python list of the target signals' STFT results (complex result)
        mask_type: ["irm", "ibm", "iam", "wfm", "psm"]
    Returns an array of shape (nspk, T, F).
    """
    if mask_type == 'ibm':
        max_index = np.argmax(
            np.stack([np.abs(mat) for mat in targets_list]), 0)
        return np.array([max_index == s for s in range(len(targets_list))], dtype=np.float)

    if mask_type == "irm":
        denominator = sum([np.abs(mat) for mat in targets_list])
    else:
        if mask_type == "wfm":
            denominator = sum([np.power(np.abs(mat), 2) for mat in targets_list])
        else:
            denominator = np.abs(mixture)
    if mask_type != "psm":
        if mask_type == "wfm":
            masks = [np.power(np.abs(mat), 2) / denominator for mat in targets_list]
        else:
            masks = [np.abs(mat) / denominator for mat in targets_list]
    else:
        mixture_phase = np.angle(mixture)
        masks = [
            np.abs(mat) * np.cos(mixture_phase - np.angle(mat)) / denominator
            for mat in targets_list
        ]
    return np.array(masks)


class dataset(Dataset):
    def __init__(self, mix_reader, target_readers, cmvn_file='../cmvn.ark'):
        super(dataset, self).__init__()
        self.mix_reader = mix_reader
        self.target_readers = target_readers
        self.cmvn = pickle.load(open(cmvn_file, 'rb'))

    def __len__(self):
        return len(self.mix_reader)

    def __getitem__(self, index):
        if index >= len(self.mix_reader):
            raise ValueError
        mix_samp = self.mix_reader[index]
        T, F = mix_samp.shape[0], mix_samp.shape[1]
        target_lists = [target[index] for target in self.target_readers]
        # compute_mask returns (nspk, T, F); permute to (T, F, nspk) before
        # flattening so each row of the (T*F, nspk) result holds the
        # per-speaker masks of one T-F bin, as the network expects
        wf = torch.from_numpy(compute_mask(mix_samp, target_lists, 'wfm')).permute(1, 2, 0).reshape(T*F, -1).type(torch.float32)
        ibm = torch.from_numpy(compute_mask(mix_samp, target_lists, 'ibm')).permute(1, 2, 0).reshape(T*F, -1).type(torch.float32)
        non_silent = torch.from_numpy(util.compute_non_silent(mix_samp)).reshape(T*F, -1).type(torch.float32)
        return torch.from_numpy(util.apply_cmvn(mix_samp, self.cmvn)).type(torch.float32), wf, ibm, non_silent


if __name__ == "__main__":
    mix_reader = AudioData(
        "/home/likai/data1/create_scp/tr_mix.scp", is_mag=True, is_log=True)
    target_readers = [AudioData("/home/likai/data1/create_scp/tr_s1.scp", is_mag=True, is_log=True),
                      AudioData("/home/likai/data1/create_scp/tr_s2.scp", is_mag=True, is_log=True)]
    dataset = dataset(mix_reader, target_readers)
    mix_samp, wf, ibm, non_silent = dataset[0]
    print(mix_samp.shape)
    print(wf.shape)
    print(ibm.shape)
    print(non_silent.shape)
--------------------------------------------------------------------------------
/model/torch_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


class MultiRNN(nn.Module):
    """
    Container module for multiple stacked RNN layers.

    args:
        rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'.
        input_size: int, dimension of the input feature. The input should have shape
                    (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state. The corresponding output should
                     have shape (batch, seq_len, hidden_size).
        num_layers: int, number of stacked RNN layers. Default is 1.
        bidirectional: bool, whether the RNN layers are bidirectional. Default is False.
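        dropout: float, dropout applied to the output of each RNN layer except the last. Default is 0.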
    """

    def __init__(self, rnn_type, input_size, hidden_size, dropout=0, num_layers=1, bidirectional=False):
        super(MultiRNN, self).__init__()

        self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, num_layers, dropout=dropout,
                                         batch_first=True, bidirectional=bidirectional)

        self.rnn_type = rnn_type
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_direction = int(bidirectional) + 1

    def forward(self, input, hidden):
        self.rnn.flatten_parameters()
        return self.rnn(input, hidden)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()),
                    Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_()))
        else:
            return Variable(weight.new(self.num_layers*self.num_direction, batch_size, self.hidden_size).zero_())


class FCLayer(nn.Module):
    """
    Container module for a fully-connected layer.

    args:
        input_size: int, dimension of the input feature. The input should have shape
                    (batch, input_size).
        hidden_size: int, dimension of the output. The corresponding output should
                     have shape (batch, hidden_size).
        nonlinearity: string, the nonlinearity applied to the transformation. Default is None.
    """

    def __init__(self, input_size, hidden_size, nonlinearity=None):
        super(FCLayer, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.FC = nn.Linear(self.input_size, self.hidden_size)
        if nonlinearity:
            self.nonlinearity = getattr(torch, nonlinearity)
        else:
            self.nonlinearity = None

        self.init_weights()

    def forward(self, input):
        if self.nonlinearity is not None:
            return self.nonlinearity(self.FC(input))
        else:
            return self.FC(input)

    def init_weights(self):
        # uniform initialization scaled by the fan-in
        initrange = 1. / np.sqrt(self.input_size)
        self.FC.bias.data.fill_(0)
        self.FC.weight.data.uniform_(-initrange, initrange)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('./')
from torch.utils.data import DataLoader as Loader
from data_loader import Dataloader, AudioData
from model import model
from logger import set_logger
import logging
from config import option
import argparse
import torch
from trainer import Trainer


def make_dataloader(opt):
    # make the train dataloader
    train_mix_reader = AudioData(
        opt['datasets']['train']['dataroot_mix'], **opt['datasets']['audio_setting'])
    train_target_readers = [AudioData(opt['datasets']['train']['dataroot_targets'][0], **opt['datasets']['audio_setting']),
                            AudioData(opt['datasets']['train']['dataroot_targets'][1], **opt['datasets']['audio_setting'])]
    train_dataset = Dataloader.dataset(
        train_mix_reader, train_target_readers, opt['datasets']['dataloader_setting']['cmvn_file'])
    train_dataloader = Loader(train_dataset,
                              batch_size=opt['datasets']['dataloader_setting']['batch_size'],
                              num_workers=opt['datasets']['dataloader_setting']['num_workers'],
                              shuffle=opt['datasets']['dataloader_setting']['shuffle'])

    # make the validation dataloader
    val_mix_reader = AudioData(
        opt['datasets']['val']['dataroot_mix'], **opt['datasets']['audio_setting'])
    val_target_readers = [AudioData(opt['datasets']['val']['dataroot_targets'][0], **opt['datasets']['audio_setting']),
                          AudioData(opt['datasets']['val']['dataroot_targets'][1], **opt['datasets']['audio_setting'])]
    val_dataset = Dataloader.dataset(
        val_mix_reader, val_target_readers, opt['datasets']['dataloader_setting']['cmvn_file'])
    val_dataloader = Loader(val_dataset,
                            batch_size=opt['datasets']['dataloader_setting']['batch_size'],
                            num_workers=opt['datasets']['dataloader_setting']['num_workers'],
                            shuffle=opt['datasets']['dataloader_setting']['shuffle'])
    return train_dataloader, val_dataloader


def make_optimizer(params, opt):
    optimizer = getattr(torch.optim, opt['optim']['name'])
    if opt['optim']['name'] == 'Adam':
        optimizer = optimizer(
            params, lr=opt['optim']['lr'], weight_decay=opt['optim']['weight_decay'])
    else:
        optimizer = optimizer(params, lr=opt['optim']['lr'], weight_decay=opt['optim']
                              ['weight_decay'], momentum=opt['optim']['momentum'])

    return optimizer


def train():
    parser = argparse.ArgumentParser(
        description='Parameters for training DANet')
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt)
    set_logger.setup_logger(opt['logger']['name'], opt['logger']['path'],
                            screen=opt['logger']['screen'], tofile=opt['logger']['tofile'])
    logger = logging.getLogger(opt['logger']['name'])

    logger.info("Building the model of DANet")
    danet = model.DANet(**opt['DANet'])

    logger.info("Building the optimizer of DANet")
    optimizer = make_optimizer(danet.parameters(), opt)

    logger.info('Building the dataloader of DANet')
    train_dataloader, val_dataloader = make_dataloader(opt)

    logger.info('Train Datasets Length: {}, Val Datasets Length: {}'.format(
        len(train_dataloader), len(val_dataloader)))
    logger.info('Building the Trainer of DANet')
    trainer = Trainer(train_dataloader, val_dataloader, danet, optimizer, opt)
    trainer.run()


if __name__ == "__main__":
    train()
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
import json
from pathlib import Path
from collections import OrderedDict
import librosa
import os
import sys
sys.path.append('../')
import numpy as np
from data_loader import AudioData
from tqdm import tqdm
import pickle


def ensure_dir(dirname):
    dirname = Path(dirname)
    if not dirname.is_dir():
        dirname.mkdir(parents=True, exist_ok=False)


def read_json(fname):
    fname = Path(fname)
    with fname.open('rt') as handle:
        return json.load(handle, object_hook=OrderedDict)


def write_json(content, fname):
    fname = Path(fname)
    with fname.open('wt') as handle:
        json.dump(content, handle, indent=4, sort_keys=False)


def read_wav(file_path, sr=8000, is_return_sr=False):
    '''
    file_path: wav file path
    is_return_sr: if True, also return the sample rate
    '''
    samp, sr = librosa.load(file_path, sr=sr)
    if is_return_sr:
        return samp, sr
    return samp


def write_wav(file_path, filename, samp, sr=8000):
    '''
    file_path: output directory
    filename: name of the output wav file
    sr: sample rate
    '''
    os.makedirs(file_path, exist_ok=True)
    filepath = os.path.join(file_path, filename)
    librosa.output.write_wav(filepath, samp, sr)


def read_scp(scp_file):
    '''
    read the scp file
    '''
    files = open(scp_file, 'r')
    lines = files.readlines()
    wave = {}
    for line in lines:
        line = line.split()
        if line[0] in wave.keys():
            raise ValueError
        wave[line[0]] = line[1]
    return wave


def compute_non_silent(samp, threshold=40, is_linear=True):
    '''
    samp: spectrogram
    threshold: threshold (dB)
    is_linear: if True, convert the log spectrogram back to linear first
    '''
    # to linear first if needed
    if is_linear:
        samp = np.exp(samp)
    # to dB
    spectra_db = 20 * np.log10(samp)
    max_magnitude_db = np.max(spectra_db)
    # linear-domain threshold, `threshold` dB below the maximum
    threshold = 10**((max_magnitude_db - threshold) / 20)
    non_silent = np.array(samp > threshold, dtype=np.float32)
    return non_silent


def compute_cmvn(scp_file, save_file, **kwargs):
    '''
    Feature normalization
    scp_file: the file path of the scp
    save_file: the cmvn result file (.ark)
    **kwargs: the audio configuration for the reader

    saves:
        mean: [frequency-bins]
        std: [frequency-bins]
    '''
    wave_reader = AudioData(scp_file, **kwargs)
    tf_bin = int(kwargs['nfft']/2+1)
    mean = np.zeros(tf_bin)
    std = np.zeros(tf_bin)
    num_frames = 0
    for spectrogram in tqdm(wave_reader):
        num_frames += spectrogram.shape[0]
        mean += np.sum(spectrogram, 0)
        std += np.sum(spectrogram**2, 0)
    mean = mean / num_frames
    std = np.sqrt(std / num_frames - mean**2)
    with open(save_file, "wb") as f:
        cmvn_dict = {"mean": mean, "std": std}
        pickle.dump(cmvn_dict, f)
    print("Totally processed {} frames".format(num_frames))
    print("Global mean: {}".format(mean))
    print("Global std: {}".format(std))


def apply_cmvn(samp, cmvn_dict):
    '''
    apply cmvn to a spectrogram
    samp: stft spectrogram
    cmvn_dict: the cmvn dictionary (generated by `python util.py`)

    calculates: x = (x - mean) / std
    '''
    return (samp - cmvn_dict['mean']) / cmvn_dict['std']


if __name__ == "__main__":
    kwargs = {'window': 'hann', 'nfft': 256, 'window_length': 256, 'hop_length': 64, 'center': False, 'is_mag': True, 'is_log': True}
    compute_cmvn("/home/likai/data1/create_scp/tr_mix.scp", '../cmvn.ark', **kwargs)
    #file = pickle.load(open('cmvn.ark','rb'))
    #print(file)
    #samp = read_wav('../1.wav')
    #print(compute_non_silent(samp))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding=utf-8
# wujian@2018

import os
import warnings
import yaml

import librosa as audio_lib
import numpy as np

MAX_INT16 = np.iinfo(np.int16).max
EPSILON = np.finfo(np.float32).eps

config_keys = [
    "trainer", "dcnet", "spectrogram_reader", "dataloader", "train_scp_conf",
    "valid_scp_conf", "debug_scp_conf"
]


def nfft(window_size):
    return int(2**np.ceil(int(np.log2(window_size))))


# return F x T or T x F
def stft(file,
         frame_length=1024,
         frame_shift=256,
         center=False,
         window="hann",
         return_samps=False,
         apply_abs=False,
         apply_log=False,
         apply_pow=False,
         transpose=True):
    if not os.path.exists(file):
        raise FileNotFoundError("Input file {} does not exist!".format(file))
    if apply_log and not apply_abs:
        apply_abs = True
        warnings.warn(
            "Ignore apply_abs=False cause function return real values")
    samps, _ = audio_lib.load(file, sr=8000)
    stft_mat = audio_lib.stft(
        samps,
        nfft(frame_length),
        frame_shift,
        frame_length,
        window=window,
        center=center)
    if apply_abs:
        stft_mat = np.abs(stft_mat)
    if apply_pow:
        stft_mat = np.power(stft_mat, 2)
    if apply_log:
        stft_mat = np.log(np.maximum(stft_mat, EPSILON))
    if transpose:
        stft_mat = np.transpose(stft_mat)
    return stft_mat if not return_samps else (samps, stft_mat)


def istft(file,
          stft_mat,
          frame_length=1024,
          frame_shift=256,
          center=False,
          window="hann",
          transpose=True,
          norm=None,
          fs=16000,
          nsamps=None):
    if transpose:
        stft_mat = np.transpose(stft_mat)
    samps = audio_lib.istft(
        stft_mat,
        frame_shift,
        frame_length,
        window=window,
        center=center,
        length=nsamps)
    # renormalize to the given infinity norm if requested
    if norm:
        samps_norm = np.linalg.norm(samps, np.inf)
        samps = samps * norm / samps_norm
    #samps_int16 = (samps * MAX_INT16).astype(np.int16)
    fdir = os.path.dirname(file)
    if fdir and not os.path.exists(fdir):
        os.makedirs(fdir)
    audio_lib.output.write_wav(file, samps, fs)


def compute_vad_mask(spectra, threshold_db=40, apply_exp=True):
    # to linear first if needed
    if apply_exp:
        spectra = np.exp(spectra)
    # to dB
    spectra_db = 20 * np.log10(spectra)
    max_magnitude_db = np.max(spectra_db)
    threshold = 10**((max_magnitude_db - threshold_db) / 20)
    mask = np.array(spectra > threshold, dtype=np.float32)
    return mask


def apply_cmvn(feats, cmvn_dict):
    if type(cmvn_dict) != dict:
        raise TypeError("Input must be a python dictionary")
    if 'mean' in cmvn_dict:
        feats = feats - cmvn_dict['mean']
    if 'std' in cmvn_dict:
        feats = feats / cmvn_dict['std']
    return feats


def parse_scps(scp_path):
    assert os.path.exists(scp_path)
    scp_dict = dict()
    with open(scp_path, 'r') as f:
        for scp in f:
            scp_tokens = scp.strip().split()
            if len(scp_tokens) != 2:
                raise RuntimeError(
                    "Error format of context \'{}\'".format(scp))
            key, addr = scp_tokens
            if key in scp_dict:
                raise ValueError("Duplicate key \'{}\' exists!".format(key))
            scp_dict[key] = addr
    return scp_dict


def filekey(path):
    fname = os.path.basename(path)
    if not fname:
        raise ValueError("{}(Is directory path?)".format(path))
    token = fname.split(".")
    if len(token) == 1:
        return token[0]
    else:
        return '.'.join(token[:-1])


def parse_yaml(yaml_conf):
    if not os.path.exists(yaml_conf):
        raise FileNotFoundError(
            "Could not find configure files...{}".format(yaml_conf))
    with open(yaml_conf, 'r') as f:
        config_dict = yaml.load(f, Loader=yaml.FullLoader)

    for key in config_keys:
        if key not in config_dict:
            raise KeyError("Missing {} configs in yaml".format(key))
    batch_size = config_dict["dataloader"]["batch_size"]
    if batch_size <= 0:
        raise ValueError("Invalid batch_size: {}".format(batch_size))

    num_frames = config_dict["spectrogram_reader"]["frame_length"]
    num_bins = nfft(num_frames) // 2 + 1
    if len(config_dict["train_scp_conf"]) != len(
            config_dict["valid_scp_conf"]):
        raise ValueError("Check configures in train_scp_conf/valid_scp_conf")
    num_spks = 0
    for key in config_dict["train_scp_conf"]:
        if key[:3] == "spk":
            num_spks += 1
    if num_spks != config_dict["trainer"]["num_spks"]:
        warnings.warn(
            "Number of speakers configured in trainer do not match *_scp_conf, "
            " correct to {}".format(num_spks))
        config_dict["trainer"]["num_spks"] = num_spks
    return num_bins, config_dict


if __name__ == "__main__":
    a = stft('1.wav')
    test = np.argmax(a, 0)
    print(test)
--------------------------------------------------------------------------------
/trainer/trainer.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('../')
from torch.autograd import Variable
from torch.nn.parallel import data_parallel
import matplotlib.pyplot as plt
import os
import torch
from model.loss import Loss
from logger.set_logger import setup_logger
import logging
import time


class Trainer(object):
    def __init__(self, train_dataloader, val_dataloader, DANet, optimizer, opt):
        super(Trainer, self).__init__()
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.num_spks = opt['num_spks']
        self.cur_epoch = 0
        self.total_epoch = opt['train']['epoch']
        self.early_stop = opt['train']['early_stop']

        self.print_freq = opt['logger']['print_freq']
        # setup_logger(opt['logger']['name'], opt['logger']['path'],
        #              screen=opt['logger']['screen'], tofile=opt['logger']['tofile'])
        self.logger = logging.getLogger(opt['logger']['name'])
        self.checkpoint = opt['train']['path']
        self.name = opt['name']

        if opt['train']['gpuid']:
            self.logger.info('Load Nvidia GPU .....')
            self.device = torch.device(
                'cuda:{}'.format(opt['train']['gpuid'][0]))
            self.gpuid = opt['train']['gpuid']
            self.danet = DANet.to(self.device)
        else:
            self.logger.info('Load CPU ...........')
            self.device = torch.device('cpu')
            self.gpuid = None
            self.danet = DANet.to(self.device)

        if opt['resume']['state']:
            ckp = torch.load(opt['resume']['path'], map_location='cpu')
            self.cur_epoch = ckp['epoch']
            self.logger.info("Resume from checkpoint {}: epoch {:d}".format(
                opt['resume']['path'], self.cur_epoch))
            # load_state_dict modifies the module in place and does not
            # return the model, so load first and move to the device after
            DANet.load_state_dict(ckp['model_state_dict'])
            self.danet = DANet.to(self.device)
            optimizer.load_state_dict(ckp['optim_state_dict'])
            self.optimizer = optimizer
        else:
            self.danet = DANet.to(self.device)
            self.optimizer = optimizer

        if opt['optim']['clip_norm']:
            self.clip_norm = opt['optim']['clip_norm']
            self.logger.info(
                "Gradient clipping by {}, default L2".format(self.clip_norm))
        else:
            self.clip_norm = 0

    def train(self, epoch):
        self.logger.info(
            'Start training from epoch: {:d}, iter: {:d}'.format(epoch, 1))
        self.danet.train()
        num_batchs = len(self.train_dataloader)
        total_loss = 0.0
        num_index = 1
        start_time = time.time()
        for mix_samp, wf, ibm, non_silent in self.train_dataloader:
            mix_samp = Variable(mix_samp).contiguous().to(self.device)
            wf = Variable(wf).contiguous().to(self.device)
            ibm = Variable(ibm).contiguous().to(self.device)
            non_silent = Variable(non_silent).contiguous().to(self.device)

            hidden = self.danet.init_hidden(mix_samp.size(0))

            input_list = [mix_samp, ibm, non_silent, hidden]
            self.optimizer.zero_grad()

            if self.gpuid:
                #mask = torch.nn.parallel.data_parallel(self.danet, input_list, device_ids=self.gpuid)
                mask, hidden = self.danet(input_list)
            else:
                mask, hidden = self.danet(input_list)

            l = Loss(mix_samp, wf, mask)
            epoch_loss = l.loss()
            total_loss += epoch_loss.item()
            epoch_loss.backward()

            #if self.clip_norm:
            #    torch.nn.utils.clip_grad_norm_(
            #        self.danet.parameters(), self.clip_norm)

            self.optimizer.step()
            if num_index % self.print_freq == 0:
                message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.4f}>'.format(
                    epoch, num_index, self.optimizer.param_groups[0]['lr'], total_loss/num_index)
                self.logger.info(message)
            num_index += 1
        end_time = time.time()
        total_loss = total_loss/num_index
        message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.4f}, time:{:.2f}min>'.format(
            epoch, num_index, self.optimizer.param_groups[0]['lr'], total_loss, (end_time-start_time)/60)
        self.logger.info(message)
        return total_loss

    def validation(self, epoch):
        self.logger.info(
            'Start Validation from epoch: {:d}, iter: {:d}'.format(epoch, 1))
        self.danet.eval()
        num_batchs = len(self.val_dataloader)
        num_index = 1
        total_loss = 0.0
        start_time = time.time()
        with torch.no_grad():
            for mix_samp, wf, ibm, non_silent in self.val_dataloader:
                mix_samp = Variable(mix_samp).contiguous().to(self.device)
                wf = Variable(wf).contiguous().to(self.device)
                ibm = Variable(ibm).contiguous().to(self.device)
                non_silent = Variable(non_silent).contiguous().to(self.device)

                hidden = self.danet.init_hidden(mix_samp.size(0))
                input_list = [mix_samp, ibm, non_silent, hidden]

                if self.gpuid:
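                    # data_parallel is left disabled: splitting the (input, ibm, non_silent, hidden) list across GPUs broke the stateful RNN (see README)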
                    #mask = torch.nn.parallel.data_parallel(self.danet, input_list, device_ids=self.gpuid)
                    mask, hidden = self.danet(input_list)
                else:
                    mask, hidden = self.danet(input_list)

                l = Loss(mix_samp, wf, mask)
                epoch_loss = l.loss()
                total_loss += epoch_loss.item()
                if num_index % self.print_freq == 0:
                    message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.4f}>'.format(
                        epoch, num_index, self.optimizer.param_groups[0]['lr'], total_loss/num_index)
                    self.logger.info(message)
                num_index += 1
        end_time = time.time()
        total_loss = total_loss/num_index
        message = '<epoch:{:d}, iter:{:d}, lr:{:.3e}, loss:{:.4f}, time:{:.2f}min>'.format(
            epoch, num_index, self.optimizer.param_groups[0]['lr'], total_loss, (end_time-start_time)/60)
        self.logger.info(message)
        return total_loss

    def run(self):
        train_loss = []
        val_loss = []
        with torch.cuda.device(self.gpuid[0]):
            self.save_checkpoint(self.cur_epoch, best=False)
            v_loss = self.validation(self.cur_epoch)
            best_loss = v_loss
            self.logger.info("Starting epoch from {:d}, loss = {:.4f}".format(
                self.cur_epoch, best_loss))
            no_improve = 0
            # starting training part
            while self.cur_epoch < self.total_epoch:
                self.cur_epoch += 1
                t_loss = self.train(self.cur_epoch)
                v_loss = self.validation(self.cur_epoch)

                train_loss.append(t_loss)
                val_loss.append(v_loss)

                if v_loss >= best_loss:
                    no_improve += 1
                    self.logger.info(
                        'No improvement, Best Loss: {:.4f}'.format(best_loss))
                else:
                    best_loss = v_loss
                    no_improve = 0
                    self.save_checkpoint(self.cur_epoch, best=True)
                    self.logger.info('Epoch: {:d}, Now Best Loss Change: {:.4f}'.format(
                        self.cur_epoch, best_loss))

                if no_improve == self.early_stop:
                    self.logger.info(
                        "Stop training cause no impr for {:d} epochs".format(
                            no_improve))
                    break
            self.save_checkpoint(self.cur_epoch, best=False)
            self.logger.info("Training for {:d}/{:d} epoches done!".format(
                self.cur_epoch, self.total_epoch))

        # draw the loss curves
        plt.title("Loss of train and test")
        x = [i for i in range(len(train_loss))]
        plt.plot(x, train_loss, 'b-', label=u'train_loss', linewidth=0.8)
        plt.plot(x, val_loss, 'c-', label=u'val_loss', linewidth=0.8)
        plt.legend()
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.savefig('loss.png')

    def save_checkpoint(self, epoch, best=True):
        '''
        save the model
        best: whether this is the best model so far
        '''
        os.makedirs(os.path.join(self.checkpoint, self.name), exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.danet.state_dict(),
            'optim_state_dict': self.optimizer.state_dict()
        },
            os.path.join(self.checkpoint, self.name, '{0}.pt'.format('best' if best else 'last')))
--------------------------------------------------------------------------------
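The repository releases training code only. For completeness, below is a minimal sketch of how the released modules could be chained at test time to mask a mixture and resynthesize the sources. It is not part of the repository: the checkpoint path follows `save_checkpoint`'s naming, `mix.wav` is a placeholder input, and the uniform assignment matrix is a stand-in, since proper DANet inference forms the attractors by k-means over the embeddings, which this codebase does not implement.

```python
# inference_sketch.py -- NOT part of the original repo; assumes a checkpoint
# saved by trainer.save_checkpoint() and a placeholder mixture file 'mix.wav'.
import os
import pickle

import numpy as np
import torch

from config import option
from model import model
from utils import util
from utils.stft_istft import STFT

opt = option.parse('./config/train.yml')
danet = model.DANet(**opt['DANet'])
ckp_path = os.path.join(opt['train']['path'], opt['name'], 'best.pt')
ckp = torch.load(ckp_path, map_location='cpu')
danet.load_state_dict(ckp['model_state_dict'])
danet.eval()

audio = opt['datasets']['audio_setting']
stft_tool = STFT(window=audio['window'], nfft=audio['nfft'],
                 window_length=audio['window_length'],
                 hop_length=audio['hop_length'], center=audio['center'])

samp = util.read_wav('mix.wav')                        # placeholder input file
mix_complex = stft_tool.stft(samp)                     # T x F, complex
mix_log_mag = np.log(np.maximum(np.abs(mix_complex), np.finfo(float).eps))
T, F = mix_log_mag.shape

cmvn = pickle.load(open(opt['datasets']['dataloader_setting']['cmvn_file'], 'rb'))
feats = torch.from_numpy(util.apply_cmvn(mix_log_mag, cmvn)).unsqueeze(0).float()
non_silent = torch.from_numpy(util.compute_non_silent(mix_log_mag)).reshape(1, T * F, 1).float()

# The released forward pass needs oracle assignments; true DANet inference
# replaces this with k-means over the embeddings, so the uniform assignment
# below only exercises the code path.
assign = torch.ones(1, T * F, opt['num_spks']) * (1.0 / opt['num_spks'])
hidden = danet.init_hidden(1)
with torch.no_grad():
    mask, _ = danet([feats, assign, non_silent, hidden])

for s in range(opt['num_spks']):
    m = mask[0, :, s].reshape(T, F).numpy()
    est = stft_tool.istft(mix_complex * m)             # mask and resynthesize
    util.write_wav('separated', 'spk{}.wav'.format(s + 1), est)
```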