├── README.md ├── audio ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-39.pyc ├── bk │ └── infer.py ├── dataloader │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── audio_dataset.cpython-36.pyc │ │ ├── audio_dataset.cpython-39.pyc │ │ ├── record_audio.cpython-36.pyc │ │ └── record_audio.cpython-39.pyc │ ├── audio_dataset.py │ ├── create_data.py │ ├── crop_audio.py │ ├── reader.py │ ├── record_audio.py │ └── tempCodeRunnerFile.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── mobilenet_v2.cpython-36.pyc │ │ ├── mobilenet_v2.cpython-39.pyc │ │ ├── resnet.cpython-36.pyc │ │ └── resnet.cpython-39.pyc │ ├── mobilenet_v2.py │ └── resnet.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── file_utils.cpython-36.pyc │ ├── file_utils.cpython-39.pyc │ ├── image_utils.cpython-36.pyc │ ├── image_utils.cpython-39.pyc │ ├── utility.cpython-36.pyc │ └── utility.cpython-39.pyc │ ├── create_UrbanSound8K_file.py │ ├── file_utils.py │ ├── image_utils.py │ ├── log.py │ ├── plot_utils.py │ ├── setup_config.py │ ├── summary.py │ ├── torch_data.py │ ├── torch_tools.py │ └── utility.py ├── data ├── .DS_Store ├── UrbanSound8K │ ├── README.txt │ ├── class_name.txt │ ├── metadata │ │ └── UrbanSound8K.csv │ ├── test.txt │ ├── train.txt │ └── trainval.txt ├── audio │ ├── .DS_Store │ ├── air_conditioner │ │ ├── 13230-0-0-3.wav │ │ └── 13230-0-0-5.wav │ ├── car_horn │ │ ├── 7389-1-2-3.wav │ │ └── 7389-1-3-0.wav │ ├── dog_bark │ │ ├── 18581-3-1-3.wav │ │ └── 19218-3-0-0.wav │ ├── engine_idling │ │ ├── 17592-5-1-2.wav │ │ └── 17592-5-1-3.wav │ └── street_music │ │ ├── 6508-9-0-3.wav │ │ └── 6508-9-0-4.wav ├── pretrained │ └── model_075_0.965.pth └── record_audio │ ├── 20211004174340.wav │ └── 20211004174446.wav ├── demo.py ├── docs └── example.py ├── drawAudio.py ├── picture ├── 
chroma.png ├── mfcc.png ├── mfcc_scaled.png ├── spectral_centroid.png ├── spectrogram.png ├── spectrogram_log.png ├── wave.png └── zero_crossing_rate.png ├── train.py └── work_space └── mbv2 ├── log ├── events.out.tfevents.1636439290.pjq ├── events.out.tfevents.1653419301.mlzdeMBP.lan ├── events.out.tfevents.1653424414.mlzdeMBP.lan └── events.out.tfevents.1653492757.mlzdeMBP.lan └── model └── model_075_0.965.pth /README.md: -------------------------------------------------------------------------------- 1 | # torch-Audio-Recognition 2 | 3 | ## 1.目录结构 4 | 5 | ``` 6 | . 7 | ├── audio 8 | ├── data 9 | ├── picture 10 | ├── work_space/mbv2 11 | ├── README.md 12 | ├── demo.py 13 | ├── drawAudio.py 14 | └── train.py 15 | ``` 16 | 17 | ## 2.环境 18 | - 使用pip命令安装librosa、pyaudio和pydub 19 | 20 | ```shell 21 | pip install librosa 22 | pip install pyaudio 23 | pip install pydub 24 | ``` 25 | 26 | 27 | ## 3.数据处理 28 | #### (1)数据集Urbansound8K 29 | 30 | - `Urbansound8K`是目前应用较为广泛的用于自动城市环境声分类研究的公共数据集, 31 | 包含10个分类:空调声、汽车鸣笛声、儿童玩耍声、狗叫声、钻孔声、引擎空转声、枪声、手提钻、警笛声和街道音乐声。 32 | - [数据集下载](https://zenodo.org/record/1203745/files/UrbanSound8K.tar.gz) 33 | 34 | #### (2)自定义数据集 35 | 36 | - 可以自己录制音频信号,制作自己的数据集,参考[record_audio.py](audio/dataloader/record_audio.py) 37 | - 每个文件夹存放一个类别的音频数据,每条音频数据长度在3秒以上,建议每类的音频数据均衡 38 | - 生成train和test数据列表:参考[create_data.py](audio/dataloader/create_data.py) 39 | 40 | #### (3)音频特征提取 41 | 42 | 音频信号是一维的时域信号,不能直接用于模型训练,需要使用librosa将音频转为梅尔频谱(Mel Spectrogram) 43 | 44 | ```python 45 | wav, sr = librosa.load(data_path, sr=16000) 46 | # 使用librosa获得音频的梅尔频谱 47 | spec_image = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256) 48 | ``` 49 | 50 | #### (4)音频图谱可视化 51 | 52 | 可以直接运行[drawAudio.py](drawAudio.py)查看音频图谱 53 | 54 | ```shell 55 | python drawAudio.py 56 | ``` 57 | 58 | 59 | ## 4.Train 60 | 61 | ```shell 62 | python train.py \ 63 | --data_dir path_to_UrbanSound8K \ 64 | --train_data path_to_UrbanSound8K/train.txt \ 65 | --test_data path_to_UrbanSound8K/test.txt 66 | ``` 67 | 68 | 
## 5.预测 69 | 70 | ```shell 71 | python demo.py \ 72 | --model_file data/pretrained/model_075_0.965.pth \ 73 | --file_dir data/audio 74 | ``` 75 | -------------------------------------------------------------------------------- /audio/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /audio/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /audio/bk/infer.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | # 加载模型 6 | model_path = 'models/resnet34.pth' 7 | device = torch.device("cuda") 8 | model = torch.jit.load(model_path) 9 | model.to(device) 10 | model.eval() 11 | 12 | 13 | # 读取音频数据 14 | def load_data(data_path): 15 | # 读取音频 16 | wav, sr = librosa.load(data_path, sr=16000) 17 | spec_mag = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256).astype(np.float32) 18 | mean = np.mean(spec_mag, 0, keepdims=True) 19 | std = np.std(spec_mag, 0, keepdims=True) 20 | spec_mag = (spec_mag - mean) / (std + 1e-5) 21 | spec_mag = spec_mag[np.newaxis, np.newaxis, :] 22 | return spec_mag 23 | 24 | 25 | def infer(audio_path): 26 | data = load_data(audio_path) 27 | data = torch.tensor(data, dtype=torch.float32, device=device) 28 | # 执行预测 29 | output = model(data) 30 | result = torch.nn.functional.softmax(output) 31 | result = 
result.data.cpu().numpy() 32 | print(result) 33 | # 显示图片并输出结果最大的label 34 | lab = np.argsort(result)[0][-1] 35 | return lab 36 | 37 | 38 | if __name__ == '__main__': 39 | # 要预测的音频文件 40 | path = 'dataset/UrbanSound8K/audio/fold5/156634-5-2-5.wav' 41 | label = infer(path) 42 | print('音频:%s 的预测结果标签为:%d' % (path, label)) 43 | -------------------------------------------------------------------------------- /audio/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__init__.py -------------------------------------------------------------------------------- /audio/dataloader/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /audio/dataloader/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /audio/dataloader/__pycache__/audio_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__pycache__/audio_dataset.cpython-36.pyc -------------------------------------------------------------------------------- 
/audio/dataloader/__pycache__/audio_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__pycache__/audio_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /audio/dataloader/__pycache__/record_audio.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__pycache__/record_audio.cpython-36.pyc -------------------------------------------------------------------------------- /audio/dataloader/__pycache__/record_audio.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/dataloader/__pycache__/record_audio.cpython-39.pyc -------------------------------------------------------------------------------- /audio/dataloader/audio_dataset.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | 3 | import os 4 | import random 5 | import librosa 6 | import numpy as np 7 | import pickle 8 | from torch.utils.data import Dataset 9 | from audio.utils import file_utils 10 | 11 | 12 | def load_audio(audio_file, cache=False): 13 | """ 14 | 加载并预处理音频 15 | :param audio_file: 16 | :param cache: librosa.load加载音频数据特别慢,建议使用进行缓存进行加速 17 | :return: 18 | """ 19 | # 读取音频数据 20 | cache_path = audio_file + ".pk" 21 | # t = librosa.get_duration(filename=audio_file) 22 | if cache and os.path.exists(cache_path): 23 | tmp = open(cache_path, 'rb') 24 | wav, sr = pickle.load(tmp) 25 | else: 26 | wav, sr = 
librosa.load(audio_file, sr=16000) 27 | if cache: 28 | f = open(cache_path, 'wb') 29 | pickle.dump([wav, sr], f) 30 | f.close() 31 | 32 | # Compute a mel-scaled spectrogram: 梅尔频谱图 33 | spec_image = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256) 34 | return spec_image 35 | 36 | 37 | def normalization(spec_image, ymin=0.0, ymax=1.0): 38 | """ 39 | 数据归一化 40 | """ 41 | spec_image = spec_image.astype(np.float32) 42 | spec_image = (spec_image - spec_image.min()) / (spec_image.max() - spec_image.min()) 43 | spec_image = spec_image * (ymax - ymin) + ymin 44 | return spec_image 45 | 46 | # xmax = np.max(spec_image) # 计算最大值 47 | # # xmin = np.min(spec_image) # 计算最小值 48 | # xmin = 0 # 计算最小值 49 | # spec_image = (ymax - ymin) * (spec_image - xmin) / (xmax - xmin) + ymin 50 | # return spec_image 51 | 52 | 53 | def normalization_v1(spec_image, mean=None, std=None): 54 | """ 55 | 通过期望和方差实现数据归一化 56 | """ 57 | if not mean: 58 | mean = np.mean(spec_image, 0, keepdims=True) 59 | if not std: 60 | std = np.std(spec_image, 0, keepdims=True) 61 | std = std + 1e-8 62 | spec_image = (spec_image - mean) / std 63 | return spec_image 64 | 65 | 66 | class AudioDataset(Dataset): 67 | def __init__(self, filename, class_name, data_dir=None, mode='train', spec_len=128): 68 | """ 69 | 数据加载器 70 | :param filename: 数据文件 71 | :param data_dir: 数据文件所在目录 72 | :param class_name: 类别名称 73 | :param mode: 数据集类型,train/test 74 | :param spec_len: 梅尔频谱图长度 75 | """ 76 | super(AudioDataset, self).__init__() 77 | self.class_name, self.class_dict = file_utils.parser_classes(class_name, split=None) 78 | self.file_list = self.read_file(filename, data_dir, self.class_dict) 79 | self.mode = mode 80 | self.spec_len = spec_len 81 | self.num_file = len(self.file_list) 82 | 83 | def read_file(self, filename, data_dir, class_dict, split=","): 84 | """ 85 | :param filename: 86 | :param data_dir: 87 | :param class_dict: 88 | :return: 89 | """ 90 | with open(filename, 'r') as f: 91 | contents = f.readlines() 92 | 
contents = [content.rstrip().split(split) for content in contents] 93 | if not data_dir: 94 | data_dir = os.path.dirname(filename) 95 | file_list = [] 96 | for path, label in contents: 97 | label = class_dict[label] 98 | item = [os.path.join(data_dir, path), label] 99 | file_list.append(item) 100 | return file_list 101 | 102 | def __getitem__(self, idx): 103 | audio_path, label = self.file_list[idx] 104 | spec_image = load_audio(audio_path, cache=True) 105 | if spec_image.shape[1] > self.spec_len: 106 | if self.mode == 'train': 107 | # 梅尔频谱数据随机裁剪 108 | crop_start = random.randint(0, spec_image.shape[1] - self.spec_len) 109 | input = spec_image[:, crop_start:crop_start + self.spec_len] 110 | else: 111 | input = spec_image[:, :self.spec_len] 112 | # 将梅尔频谱图(灰度图)是转为为3通道RGB图 113 | # spec_image = cv2.cvtColor(spec_image, cv2.COLOR_GRAY2RGB) 114 | input = normalization(input) 115 | # spec_image = normalization_v1(spec_image) 116 | input = input[np.newaxis, :] 117 | else: 118 | # 如果音频长度不足,则用0填充 119 | # input = np.zeros(shape=(self.spec_len, self.spec_len), dtype=np.float32) 120 | # input[:, 0:spec_image.shape[1]] = spec_image 121 | # 如果音频较短,则丢弃,并随机读取一个音频 122 | idx = random.randint(0, self.num_file - 1) 123 | return self.__getitem__(idx) 124 | return input, np.array(int(label), dtype=np.int64) 125 | 126 | def __len__(self): 127 | return len(self.file_list) 128 | 129 | 130 | if __name__ == "__main__": 131 | data_dir = "E:/dataset/UrbanSound8K" 132 | filename = "../../data/UrbanSound8K/train.txt" 133 | dataset = AudioDataset(filename, data_dir=data_dir) 134 | for data in dataset: 135 | image, label = data 136 | image = image.transpose(1, 2, 0).copy() 137 | print("image:{},label:{}".format(image.shape, label)) 138 | 139 | # from audio.utils import image_utils 140 | # import cv2 141 | # image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) # 将BGR转为RGB 142 | # image_utils.cv_show_image("image", image) 143 | # image_utils.show_image("image", image) 144 | 
-------------------------------------------------------------------------------- /audio/dataloader/create_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | from tqdm import tqdm 4 | 5 | 6 | def create_train_test_list(data_dir, out_root, class_name): 7 | """ 8 | 生成数据列表 9 | """ 10 | sound_sum = 0 11 | audios = os.listdir(data_dir) 12 | f_train = open(os.path.join(out_root, 'train.txt'), 'w') 13 | f_test = open(os.path.join(out_root, 'test.txt'), 'w') 14 | class_dict = {name: i for i, name in enumerate(class_name)} 15 | for name in tqdm(audios): 16 | sounds = os.listdir(os.path.join(data_dir, name)) 17 | for sound in sounds: 18 | if not sound.endswith('.wav'): 19 | continue 20 | path = os.path.join(name, sound) 21 | sound_path = os.path.join(data_dir, path) 22 | t = librosa.get_duration(filename=sound_path) 23 | content = os.path.join(os.path.basename(data_dir), path) 24 | if t < 1.5: 25 | continue 26 | label = class_dict[name] 27 | if sound_sum % 100 == 0: 28 | f_test.write('%s,%d\n' % (content, label)) 29 | else: 30 | f_train.write('%s,%d\n' % (content, label)) 31 | sound_sum += 1 32 | 33 | f_test.close() 34 | f_train.close() 35 | 36 | 37 | if __name__ == '__main__': 38 | # data_dir = "/media/pan/新加卷/dataset/UrbanSound8K/audio" 39 | data_dir = "E:/dataset/UrbanSound8K/audio" 40 | out_root = '../../data/UrbanSound8K' 41 | class_name = ["fold1", "fold2", "fold3", "fold4", "fold5", 42 | "fold6", "fold7", "fold8", "fold9", "fold10", 43 | ] 44 | create_train_test_list(data_dir, out_root, class_name=class_name) 45 | -------------------------------------------------------------------------------- /audio/dataloader/crop_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uuid 3 | import wave 4 | from pydub import AudioSegment 5 | 6 | 7 | # 按秒截取音频 8 | def get_part_wav(sound, start_time, end_time, part_wav_path): 9 | save_path = 
os.path.dirname(part_wav_path) 10 | if not os.path.exists(save_path): 11 | os.makedirs(save_path) 12 | start_time = int(start_time) * 1000 13 | end_time = int(end_time) * 1000 14 | word = sound[start_time:end_time] 15 | word.export(part_wav_path, format="wav") 16 | 17 | 18 | def crop_wav(path, crop_len): 19 | for src_wav_path in os.listdir(path): 20 | wave_path = os.path.join(path, src_wav_path) 21 | print(wave_path[-4:]) 22 | if wave_path[-4:] != '.wav': 23 | continue 24 | file = wave.open(wave_path) 25 | # 帧总数 26 | a = file.getparams().nframes 27 | # 采样频率 28 | f = file.getparams().framerate 29 | # 获取音频时间长度 30 | t = int(a / f) 31 | print('总时长为 %d s' % t) 32 | # 读取语音 33 | sound = AudioSegment.from_wav(wave_path) 34 | for start_time in range(0, t, crop_len): 35 | save_path = os.path.join(path, os.path.basename(wave_path)[:-4], str(uuid.uuid1()) + '.wav') 36 | get_part_wav(sound, start_time, start_time + crop_len, save_path) 37 | 38 | 39 | if __name__ == '__main__': 40 | crop_len = 3 41 | crop_wav('save_audio', crop_len) 42 | -------------------------------------------------------------------------------- /audio/dataloader/reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import librosa 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import pickle 8 | 9 | 10 | # 加载并预处理音频 11 | def load_audio(audio_path, mode='train', spec_len=128, cache=True ): 12 | # 读取音频数据 13 | cache_path = audio_path + ".pk" 14 | if cache and os.path.exists(cache_path): 15 | tmp = open(cache_path, 'rb') 16 | wav, sr = pickle.load(tmp) 17 | else: 18 | wav, sr = librosa.load(audio_path, sr=16000) 19 | f = open(cache_path, 'wb') 20 | pickle.dump([wav, sr], f) 21 | f.close() 22 | spec_mag = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256) 23 | if mode == 'train': 24 | crop_start = random.randint(0, spec_mag.shape[1] - spec_len) 25 | spec_mag = spec_mag[:, crop_start:crop_start + spec_len] 
26 | else: 27 | spec_mag = spec_mag[:, :spec_len] 28 | mean = np.mean(spec_mag, 0, keepdims=True) 29 | std = np.std(spec_mag, 0, keepdims=True) 30 | spec_mag = (spec_mag - mean) / (std + 1e-5) 31 | spec_mag = spec_mag[np.newaxis, :] 32 | return spec_mag 33 | 34 | 35 | # 数据加载器 36 | class CustomDataset(Dataset): 37 | def __init__(self, data_list_path, model='train', spec_len=128): 38 | super(CustomDataset, self).__init__() 39 | with open(data_list_path, 'r') as f: 40 | self.lines = f.readlines() 41 | self.model = model 42 | self.spec_len = spec_len 43 | 44 | def __getitem__(self, idx): 45 | audio_path, label = self.lines[idx].replace('\n', '').split('\t') 46 | spec_mag = load_audio(audio_path, mode=self.model, spec_len=self.spec_len) 47 | return spec_mag, np.array(int(label), dtype=np.int64) 48 | 49 | def __len__(self): 50 | return len(self.lines) 51 | -------------------------------------------------------------------------------- /audio/dataloader/record_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wave 3 | import librosa 4 | import numpy as np 5 | import pyaudio 6 | 7 | 8 | def record_audio(audio_file): 9 | """录制音频""" 10 | # 录音参数 11 | if not os.path.exists(os.path.dirname(audio_file)): 12 | os.makedirs(os.path.dirname(audio_file)) 13 | FORMAT = pyaudio.paInt16 14 | CHANNELS = 1 15 | RATE = 16000 16 | RECORD_SECONDS = 3 17 | CHUNK = 1024 18 | length = RATE / CHUNK * RECORD_SECONDS 19 | # 打开录音 20 | audio = pyaudio.PyAudio() 21 | audio_stream = audio.open(format=FORMAT, 22 | channels=CHANNELS, 23 | rate=RATE, 24 | input=True, 25 | frames_per_buffer=CHUNK) 26 | print("开始录音......") 27 | frames = [] 28 | for i in range(0, int(length)): 29 | data = audio_stream.read(CHUNK) 30 | frames.append(data) 31 | print("录音已结束!") 32 | wf = wave.open(audio_file, 'wb') 33 | wf.setnchannels(CHANNELS) 34 | wf.setsampwidth(audio.get_sample_size(FORMAT)) 35 | wf.setframerate(RATE) 36 | wf.writeframes(b''.join(frames)) 
37 | wf.close() 38 | 39 | audio_stream.stop_stream() 40 | audio_stream.close() 41 | audio.terminate() 42 | return audio_file 43 | 44 | 45 | if __name__ == '__main__': 46 | audio_file = "audio.wav" 47 | record_audio(audio_file) 48 | -------------------------------------------------------------------------------- /audio/dataloader/tempCodeRunnerFile.py: -------------------------------------------------------------------------------- 1 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) # 将BGR转为RGB 2 | image_utils.cv_show_image("image", image) 3 | image_utils.show_image("image", image) -------------------------------------------------------------------------------- /audio/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__init__.py -------------------------------------------------------------------------------- /audio/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /audio/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /audio/models/__pycache__/mobilenet_v2.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__pycache__/mobilenet_v2.cpython-36.pyc -------------------------------------------------------------------------------- /audio/models/__pycache__/mobilenet_v2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__pycache__/mobilenet_v2.cpython-39.pyc -------------------------------------------------------------------------------- /audio/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /audio/models/__pycache__/resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/models/__pycache__/resnet.cpython-39.pyc -------------------------------------------------------------------------------- /audio/models/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | """ 3 | @Project: AudioClassification-Pytorch 4 | @File : mobilenet_v2.py 5 | @Author : panjq 6 | @E-mail : pan_jinquan@163.com 7 | @Date : 2021-10-02 11:21:05 8 | """ 9 | 10 | from torch import nn 11 | from torch import Tensor 12 | # from .utils import load_state_dict_from_url 13 | from typing import Callable, Any, Optional, List 14 | 15 | 16 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 17 | 18 | 
19 | model_urls = { 20 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 21 | } 22 | 23 | 24 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 25 | """ 26 | This function is taken from the original tf repo. 27 | It ensures that all layers have a channel number that is divisible by 8 28 | It can be seen here: 29 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 30 | """ 31 | if min_value is None: 32 | min_value = divisor 33 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 34 | # Make sure that round down does not go down by more than 10%. 35 | if new_v < 0.9 * v: 36 | new_v += divisor 37 | return new_v 38 | 39 | 40 | class ConvBNActivation(nn.Sequential): 41 | def __init__( 42 | self, 43 | in_planes: int, 44 | out_planes: int, 45 | kernel_size: int = 3, 46 | stride: int = 1, 47 | groups: int = 1, 48 | norm_layer: Optional[Callable[..., nn.Module]] = None, 49 | activation_layer: Optional[Callable[..., nn.Module]] = None, 50 | dilation: int = 1, 51 | ) -> None: 52 | padding = (kernel_size - 1) // 2 * dilation 53 | if norm_layer is None: 54 | norm_layer = nn.BatchNorm2d 55 | if activation_layer is None: 56 | activation_layer = nn.ReLU6 57 | super(ConvBNReLU, self).__init__( 58 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, 59 | bias=False), 60 | norm_layer(out_planes), 61 | activation_layer(inplace=True) 62 | ) 63 | self.out_channels = out_planes 64 | 65 | 66 | # necessary for backwards compatibility 67 | ConvBNReLU = ConvBNActivation 68 | 69 | 70 | class InvertedResidual(nn.Module): 71 | def __init__( 72 | self, 73 | inp: int, 74 | oup: int, 75 | stride: int, 76 | expand_ratio: int, 77 | norm_layer: Optional[Callable[..., nn.Module]] = None 78 | ) -> None: 79 | super(InvertedResidual, self).__init__() 80 | self.stride = stride 81 | assert stride in [1, 2] 82 | 83 | if norm_layer is None: 
84 | norm_layer = nn.BatchNorm2d 85 | 86 | hidden_dim = int(round(inp * expand_ratio)) 87 | self.use_res_connect = self.stride == 1 and inp == oup 88 | 89 | layers: List[nn.Module] = [] 90 | if expand_ratio != 1: 91 | # pw 92 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 93 | layers.extend([ 94 | # dw 95 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 96 | # pw-linear 97 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 98 | norm_layer(oup), 99 | ]) 100 | self.conv = nn.Sequential(*layers) 101 | self.out_channels = oup 102 | self._is_cn = stride > 1 103 | 104 | def forward(self, x: Tensor) -> Tensor: 105 | if self.use_res_connect: 106 | return x + self.conv(x) 107 | else: 108 | return self.conv(x) 109 | 110 | 111 | class MobileNetV2(nn.Module): 112 | def __init__( 113 | self, 114 | num_classes: int = 1000, 115 | width_mult: float = 1.0, 116 | inverted_residual_setting: Optional[List[List[int]]] = None, 117 | round_nearest: int = 8, 118 | block: Optional[Callable[..., nn.Module]] = None, 119 | norm_layer: Optional[Callable[..., nn.Module]] = None 120 | ) -> None: 121 | """ 122 | MobileNet V2 main class 123 | 124 | Args: 125 | num_classes (int): Number of classes 126 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 127 | inverted_residual_setting: Network structure 128 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 129 | Set to 1 to turn off rounding 130 | block: Module specifying inverted residual building block for mobilenet 131 | norm_layer: Module specifying the normalization layer to use 132 | 133 | """ 134 | super(MobileNetV2, self).__init__() 135 | 136 | if block is None: 137 | block = InvertedResidual 138 | 139 | if norm_layer is None: 140 | norm_layer = nn.BatchNorm2d 141 | 142 | input_channel = 64 143 | last_channel = 1280 144 | 145 | if inverted_residual_setting is None: 146 
| inverted_residual_setting = [ 147 | # t, c, n, s 148 | [1, 16, 1, 1], 149 | [6, 24, 2, 2], 150 | [6, 32, 3, 2], 151 | [6, 64, 4, 2], 152 | [6, 96, 3, 1], 153 | [6, 160, 3, 2], 154 | [6, 320, 1, 1], 155 | ] 156 | 157 | # only check the first element, assuming user knows t,c,n,s are required 158 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 159 | raise ValueError("inverted_residual_setting should be non-empty " 160 | "or a 4-element list, got {}".format(inverted_residual_setting)) 161 | 162 | # building first layer 163 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 164 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 165 | features: List[nn.Module] = [ConvBNReLU(1, input_channel, stride=2, norm_layer=norm_layer)] 166 | # building inverted residual blocks 167 | for t, c, n, s in inverted_residual_setting: 168 | output_channel = _make_divisible(c * width_mult, round_nearest) 169 | for i in range(n): 170 | stride = s if i == 0 else 1 171 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 172 | input_channel = output_channel 173 | # building last several layers 174 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 175 | # make it nn.Sequential 176 | self.features = nn.Sequential(*features) 177 | 178 | # building classifier 179 | self.classifier = nn.Sequential( 180 | nn.Dropout(0.2), 181 | nn.Linear(self.last_channel, num_classes), 182 | ) 183 | 184 | # weight initialization 185 | for m in self.modules(): 186 | if isinstance(m, nn.Conv2d): 187 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 188 | if m.bias is not None: 189 | nn.init.zeros_(m.bias) 190 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 191 | nn.init.ones_(m.weight) 192 | nn.init.zeros_(m.bias) 193 | elif isinstance(m, nn.Linear): 194 | nn.init.normal_(m.weight, 0, 0.01) 195 | 
nn.init.zeros_(m.bias) 196 | 197 | def _forward_impl(self, x: Tensor) -> Tensor: 198 | # This exists since TorchScript doesn't support inheritance, so the superclass method 199 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 200 | x = self.features(x) 201 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 202 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1) 203 | x = self.classifier(x) 204 | return x 205 | 206 | def forward(self, x: Tensor) -> Tensor: 207 | return self._forward_impl(x) 208 | 209 | 210 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: 211 | """ 212 | Constructs a MobileNetV2 architecture from 213 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 214 | 215 | Args: 216 | pretrained (bool): If True, returns a model pre-trained on ImageNet 217 | progress (bool): If True, displays a progress bar of the download to stderr 218 | """ 219 | model = MobileNetV2(**kwargs) 220 | if pretrained: 221 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 222 | progress=progress) 223 | model.load_state_dict(state_dict) 224 | return model 225 | 226 | -------------------------------------------------------------------------------- /audio/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, 9 | inplanes, 10 | planes, 11 | stride=1, 12 | downsample=None, 13 | groups=1, 14 | base_width=64, 15 | dilation=1, 16 | norm_layer=None): 17 | super(BasicBlock, self).__init__() 18 | if norm_layer is None: 19 | norm_layer = nn.BatchNorm2d 20 | 21 | if dilation > 1: 22 | raise NotImplementedError( 23 | "Dilation > 1 not supported in BasicBlock") 24 | 25 | self.conv1 = nn.Conv2d( 26 | inplanes, planes, 3, padding=1, stride=stride) 27 | 
self.bn1 = norm_layer(planes) 28 | self.relu = nn.ReLU() 29 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1) 30 | self.bn2 = norm_layer(planes) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | identity = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out += identity 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class ResNet(nn.Module): 54 | """ResNet model from 55 | `"Deep Residual Learning for Image Recognition" `_ 56 | 57 | Args: 58 | Block (BasicBlock|BottleneckBlock): block module of model. 59 | depth (int): layers of resnet, default: 50. 60 | num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer 61 | will not be defined. Default: 1000. 62 | with_pool (bool): use pool before the last fc layer or not. Default: True. 63 | 64 | Examples: 65 | .. 
code-block:: python 66 | 67 | from paddle.vision.models import ResNet 68 | from paddle.vision.models.resnet import BottleneckBlock, BasicBlock 69 | 70 | resnet50 = ResNet(BottleneckBlock, 50) 71 | 72 | resnet18 = ResNet(BasicBlock, 18) 73 | 74 | """ 75 | 76 | def __init__(self, block, depth, num_classes=1000, with_pool=True): 77 | super(ResNet, self).__init__() 78 | layer_cfg = { 79 | 18: [2, 2, 2, 2], 80 | 34: [3, 4, 6, 3], 81 | 50: [3, 4, 6, 3], 82 | 101: [3, 4, 23, 3], 83 | 152: [3, 8, 36, 3] 84 | } 85 | layers = layer_cfg[depth] 86 | self.num_classes = num_classes 87 | self.with_pool = with_pool 88 | self._norm_layer = nn.BatchNorm2d 89 | 90 | self.inplanes = 64 91 | self.dilation = 1 92 | 93 | self.conv1 = nn.Conv2d( 94 | 1, 95 | self.inplanes, 96 | kernel_size=7, 97 | stride=2, 98 | padding=3) 99 | self.bn1 = self._norm_layer(self.inplanes) 100 | self.relu = nn.ReLU() 101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 102 | self.layer1 = self._make_layer(block, 64, layers[0]) 103 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 104 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 105 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 106 | if with_pool: 107 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 108 | 109 | if num_classes > 0: 110 | self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 113 | norm_layer = self._norm_layer 114 | downsample = None 115 | previous_dilation = self.dilation 116 | if dilate: 117 | self.dilation *= stride 118 | stride = 1 119 | if stride != 1 or self.inplanes != planes * block.expansion: 120 | downsample = nn.Sequential( 121 | nn.Conv2d( 122 | self.inplanes, 123 | planes * block.expansion, 124 | 1, 125 | stride=stride), 126 | norm_layer(planes * block.expansion), ) 127 | 128 | layers = [] 129 | layers.append( 130 | block(self.inplanes, planes, stride, downsample, 1, 64, 131 | 
previous_dilation, norm_layer)) 132 | self.inplanes = planes * block.expansion 133 | for _ in range(1, blocks): 134 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | if self.with_pool: 149 | x = self.avgpool(x) 150 | 151 | if self.num_classes > 0: 152 | x = torch.flatten(x, 1) 153 | x = self.fc(x) 154 | 155 | return x 156 | 157 | 158 | def resnet34(**kwargs): 159 | model = ResNet(BasicBlock, 34, **kwargs) 160 | return model 161 | 162 | 163 | def resnet18(**kwargs): 164 | model = ResNet(BasicBlock, 18, **kwargs) 165 | return model 166 | -------------------------------------------------------------------------------- /audio/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__init__.py -------------------------------------------------------------------------------- /audio/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /audio/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /audio/utils/__pycache__/file_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/file_utils.cpython-36.pyc -------------------------------------------------------------------------------- /audio/utils/__pycache__/file_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/file_utils.cpython-39.pyc -------------------------------------------------------------------------------- /audio/utils/__pycache__/image_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/image_utils.cpython-36.pyc -------------------------------------------------------------------------------- /audio/utils/__pycache__/image_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/image_utils.cpython-39.pyc -------------------------------------------------------------------------------- /audio/utils/__pycache__/utility.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/utility.cpython-36.pyc 
-------------------------------------------------------------------------------- /audio/utils/__pycache__/utility.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/audio/utils/__pycache__/utility.cpython-39.pyc -------------------------------------------------------------------------------- /audio/utils/create_UrbanSound8K_file.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | 3 | import random 4 | import numpy as np 5 | import pandas as pd 6 | from audio.utils import file_utils 7 | 8 | 9 | def read_metadata_file(metadata_file, shuffle=False): 10 | """ 11 | 读取UrbanSound8K标注数据,并转换为[path,class_name]的形式 12 | """ 13 | data = pd.read_csv(metadata_file) 14 | valid_data = data[['slice_file_name', 'fold', 'classID', 'class']][data['end'] - data['start'] >= 3] 15 | valid_data['path'] = 'fold' + valid_data['fold'].astype('str') + '/' + valid_data['slice_file_name'].astype('str') 16 | paths = np.asarray(valid_data['path']).tolist() 17 | labels = np.asarray(valid_data['class']).tolist() 18 | assert len(paths) == len(labels) 19 | item_list = [[p, l] for p, l in zip(paths, labels)] 20 | item_list = sorted(item_list) 21 | if shuffle: 22 | random.seed(200) 23 | random.shuffle(item_list) 24 | return item_list 25 | 26 | 27 | if __name__ == '__main__': 28 | metadata_file = "/home/dm/data3/release/MYGit/torch-Audio-Recognition/data/UrbanSound8K/metadata/UrbanSound8K.csv" 29 | save_file = "../../data/UrbanSound8K/trainval.txt" 30 | item_list = read_metadata_file(metadata_file) 31 | file_utils.write_data(save_file, item_list, split=",") 32 | -------------------------------------------------------------------------------- /audio/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | 3 | import glob 4 | import os 5 | import time 6 | import os, shutil 7 | import numpy as np 8 | import json 9 | import random 10 | import os 11 | import subprocess 12 | import concurrent.futures 13 | from datetime import datetime 14 | 15 | 16 | def get_time(format="S"): 17 | """ 18 | :param format: 19 | :return: 20 | """ 21 | if format in ["S", "s"]: 22 | # time = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S') 23 | time = datetime.strftime(datetime.now(), '%Y%m%d%H%M%S') 24 | elif format in ["P", "p"]: 25 | # 20200508_143059_379116 26 | time = datetime.strftime(datetime.now(), '%Y%m%d_%H%M%S_%f') 27 | time = time[:-2] 28 | else: 29 | time = (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-') 30 | return time 31 | 32 | 33 | def get_kwargs_name(**kwargs): 34 | prefix = [] 35 | for k, v in kwargs.items(): 36 | if isinstance(v, list): 37 | v = [str(l) for l in v] 38 | prefix += v 39 | else: 40 | f = "{}_{}".format(k, v) 41 | prefix.append(f) 42 | prefix = "_".join(prefix) 43 | return prefix 44 | 45 | 46 | def combine_flags(flags: list, use_time=True, info=True): 47 | """ 48 | :param flags: 49 | :param info: 50 | :return: 51 | """ 52 | out_flags = [] 53 | for f in flags: 54 | if isinstance(f, dict): 55 | f = get_kwargs_name(**f) 56 | out_flags.append(f) 57 | if use_time: 58 | out_flags += [get_time()] 59 | out_flags = [str(f) for f in out_flags if f] 60 | out_flags = "_".join(out_flags) 61 | if info: 62 | print(out_flags) 63 | return out_flags 64 | 65 | 66 | class WriterTXT(object): 67 | """ write data in txt files""" 68 | 69 | def __init__(self, filename, mode='w'): 70 | self.f = None 71 | if filename: 72 | self.f = open(filename, mode=mode) 73 | 74 | def write_line_str(self, line_str, endline="\n"): 75 | if self.f: 76 | line_str = line_str + endline 77 | self.f.write(line_str) 78 | self.f.flush() 79 | 80 | def write_line_list(self, line_list, endline="\n"): 81 | if self.f: 82 | for line_list in line_list: 83 | # 将list转为string 84 | 
line_str = " ".join('%s' % id for id in line_list) 85 | self.write_line_str(line_str, endline=endline) 86 | self.f.flush() 87 | 88 | def close(self): 89 | if self.f: 90 | self.f.close() 91 | 92 | 93 | def parser_classes(class_name, split=None): 94 | """ 95 | class_dict = {class_name: i for i, class_name in enumerate(class_name)} 96 | :param 97 | :return: 98 | """ 99 | if isinstance(class_name, str): 100 | class_name = read_data(class_name, split=split) 101 | if isinstance(class_name, list): 102 | class_dict = {class_name: i for i, class_name in enumerate(class_name)} 103 | elif isinstance(class_name, dict): 104 | class_dict = class_name 105 | else: 106 | class_dict = None 107 | return class_name, class_dict 108 | 109 | 110 | def read_json_data(json_path): 111 | """ 112 | 读取数据 113 | :param json_path: 114 | :return: 115 | """ 116 | with open(json_path, 'r') as f: 117 | json_data = json.load(f) 118 | return json_data 119 | 120 | 121 | def write_json_path(out_json_path, json_data): 122 | """ 123 | 写入 JSON 数据 124 | :param out_json_path: 125 | :param json_data: 126 | :return: 127 | """ 128 | with open(out_json_path, 'w', encoding="utf-8") as f: 129 | json.dump(json_data, f, indent=4, ensure_ascii=False) 130 | 131 | 132 | def write_data(filename, content_list, split=" ", mode='w'): 133 | """保存list[list[]]的数据到txt文件 134 | :param filename:文件名 135 | :param content_list:需要保存的数据,type->list 136 | :param mode:读写模式:'w' or 'a' 137 | :return: void 138 | """ 139 | with open(filename, mode=mode, encoding='utf-8') as f: 140 | for line_list in content_list: 141 | # 将list转为string 142 | line = "{}".format(split).join('%s' % id for id in line_list) 143 | f.write(line + "\n") 144 | f.flush() 145 | 146 | 147 | def write_list_data(filename, list_data, mode='w'): 148 | """保存list[]的数据到txt文件,每个元素分行 149 | :param filename:文件名 150 | :param list_data:需要保存的数据,type->list 151 | :param mode:读写模式:'w' or 'a' 152 | :return: void 153 | """ 154 | with open(filename, mode=mode, encoding='utf-8') as f: 155 | 
for line in list_data: 156 | # 将list转为string 157 | f.write(str(line) + "\n") 158 | f.flush() 159 | 160 | 161 | def read_data(filename, split=" ", convertNum=True): 162 | """ 163 | 读取txt数据函数 164 | :param filename:文件名 165 | :param split :分割符 166 | :param convertNum :是否将list中的string转为int/float类型的数字 167 | :return: txt的数据列表 168 | Python中有三个去除头尾字符、空白符的函数,它们依次为: 169 | strip: 用来去除头尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格) 170 | lstrip:用来去除开头字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格) 171 | rstrip:用来去除结尾字符、空白符(包括\n、\r、\t、' ',即:换行、回车、制表符、空格) 172 | 注意:这些函数都只会删除头和尾的字符,中间的不会删除。 173 | """ 174 | with open(filename, mode="r", encoding='utf-8') as f: 175 | content_list = f.readlines() 176 | if split is None: 177 | content_list = [content.rstrip() for content in content_list] 178 | return content_list 179 | else: 180 | content_list = [content.rstrip().split(split) for content in content_list] 181 | if convertNum: 182 | for i, line in enumerate(content_list): 183 | line_data = [] 184 | for l in line: 185 | if is_int(l): # isdigit() 方法检测字符串是否只由数字组成,只能判断整数 186 | line_data.append(int(l)) 187 | elif is_float(l): # 判断是否为小数 188 | line_data.append(float(l)) 189 | else: 190 | line_data.append(l) 191 | content_list[i] = line_data 192 | return content_list 193 | 194 | 195 | def read_line_image_label(line_image_label): 196 | ''' 197 | line_image_label:[image_id,boxes_nums,x1, y1, w, h, label_id,x1, y1, w, h, label_id,...] 
198 | :param line_image_label: 199 | :return: 200 | ''' 201 | line_image_label = line_image_label.strip().split() 202 | image_id = line_image_label[0] 203 | boxes_nums = int(line_image_label[1]) 204 | box = [] 205 | label = [] 206 | for i in range(boxes_nums): 207 | x = float(line_image_label[2 + 5 * i]) 208 | y = float(line_image_label[3 + 5 * i]) 209 | w = float(line_image_label[4 + 5 * i]) 210 | h = float(line_image_label[5 + 5 * i]) 211 | c = int(line_image_label[6 + 5 * i]) 212 | if w <= 0 or h <= 0: 213 | continue 214 | box.append([x, y, x + w, y + h]) 215 | label.append(c) 216 | return image_id, box, label 217 | 218 | 219 | def read_lines_image_labels(filename): 220 | """ 221 | :param filename: 222 | :return: 223 | """ 224 | boxes_label_lists = [] 225 | with open(filename) as f: 226 | lines = f.readlines() 227 | for line in lines: 228 | image_id, box, label = read_line_image_label(line) 229 | boxes_label_lists.append([image_id, box, label]) 230 | return boxes_label_lists 231 | 232 | 233 | def save_file_root_list(file_root: str, 234 | out_path=None, 235 | postfix=["*.jpg"], 236 | replace_postfix=False, 237 | add_dirname=False, 238 | shuffle=False): 239 | """ 240 | 保存file_dir目录下所有后缀名为postfix的文件 241 | :param file_root: 文件根目录 242 | :param out_path: 保存列表的路径,默认为file_dir的上一级目录"file_id.txt" 243 | :param postfix: 文件后缀名,支持多个后缀格式 244 | :param remove_postfix: or 文件列表是否包含后缀名 245 | :param dirname: 文件列表是否增加父目录名称 246 | :return: 247 | """ 248 | file_dir_len = len(file_root) 249 | if not file_root.endswith(os.sep): 250 | file_dir_len = file_dir_len + 1 251 | dirname = os.path.basename(file_root) 252 | anno_list = get_files(file_root, postfix=postfix) 253 | if shuffle: 254 | random.seed(100) 255 | random.shuffle(anno_list) 256 | image_idx = [] 257 | for path in anno_list: 258 | # basename = os.path.basename(path) 259 | basename = path[file_dir_len:] 260 | if replace_postfix: 261 | basename = basename.split(".")[0] + replace_postfix 262 | elif replace_postfix == "": 263 | 
basename = basename.split(".")[0] 264 | if add_dirname: 265 | basename = os.path.join(dirname, basename) 266 | image_idx.append(basename) 267 | if not out_path: 268 | out_path = os.path.join(os.path.dirname(file_root), "file_id.txt") 269 | print("num files:{},out_path:{}".format(len(image_idx), out_path)) 270 | write_list_data(out_path, image_idx) 271 | return image_idx 272 | 273 | 274 | def is_int(str): 275 | """ 276 | 判断是否为整数 277 | :param str: 278 | :return: 279 | """ 280 | try: 281 | x = int(str) 282 | return isinstance(x, int) 283 | except ValueError: 284 | return False 285 | 286 | 287 | def is_float(str): 288 | """ 289 | 判断是否为整数和小数 290 | :param str: 291 | :return: 292 | """ 293 | try: 294 | x = float(str) 295 | return isinstance(x, float) 296 | except ValueError: 297 | return False 298 | 299 | 300 | def list2str(content_list): 301 | """ 302 | convert list to string 303 | :param content_list: 304 | :return: 305 | """ 306 | content_str_list = [] 307 | for line_list in content_list: 308 | line_str = " ".join('%s' % id for id in line_list) 309 | content_str_list.append(line_str) 310 | return content_str_list 311 | 312 | 313 | def get_images_list(image_dir, postfix=['*.jpg'], basename=False): 314 | ''' 315 | 获得文件列表 316 | :param image_dir: 图片文件目录 317 | :param postfix: 后缀名,可是多个如,['*.jpg','*.png'] 318 | :param basename: 返回的列表是文件名(True),还是文件的完整路径(False) 319 | :return: 320 | ''' 321 | images_list = [] 322 | for format in postfix: 323 | image_format = os.path.join(image_dir, format) 324 | image_list = glob.glob(image_format) 325 | if not image_list == []: 326 | images_list += image_list 327 | images_list = sorted(images_list) 328 | if basename: 329 | images_list = get_basename(images_list) 330 | return images_list 331 | 332 | 333 | def get_basename(file_list): 334 | """ 335 | get files basename 336 | :param file_list: 337 | :return: 338 | """ 339 | dest_list = [] 340 | for file_path in file_list: 341 | basename = os.path.basename(file_path) 342 | 
dest_list.append(basename) 343 | return dest_list 344 | 345 | 346 | def randam_select_images(image_list, nums, shuffle=True): 347 | """ 348 | randam select nums images 349 | :param image_list: 350 | :param nums: 351 | :param shuffle: 352 | :return: 353 | """ 354 | image_nums = len(image_list) 355 | if image_nums <= nums: 356 | return image_list 357 | if shuffle: 358 | random.seed(100) 359 | random.shuffle(image_list) 360 | out = image_list[:nums] 361 | return out 362 | 363 | 364 | def remove_dir(dir): 365 | """ 366 | remove directory 367 | :param dir: 368 | :return: 369 | """ 370 | if os.path.exists(dir): 371 | shutil.rmtree(dir) 372 | 373 | 374 | def get_prefix_files(file_dir, prefix): 375 | """ 376 | :param file_dir: 377 | :param prefix: "best*" 378 | :return: 379 | """ 380 | file_list = glob.glob(os.path.join(file_dir, prefix)) 381 | return file_list 382 | 383 | 384 | def remove_prefix_files(file_dir, prefix): 385 | """ 386 | :param file_dir: 387 | :param prefix: "best*" 388 | :return: 389 | """ 390 | file_list = get_prefix_files(file_dir, prefix) 391 | for file in file_list: 392 | if os.path.isfile(file): 393 | remove_file(file) 394 | elif os.path.isdir(file): 395 | remove_dir(file) 396 | 397 | 398 | def remove_file(path): 399 | """ 400 | remove files 401 | :param path: 402 | :return: 403 | """ 404 | if os.path.exists(path): 405 | os.remove(path) 406 | 407 | 408 | def remove_file_list(file_list): 409 | """ 410 | remove file list 411 | :param file_list: 412 | :return: 413 | """ 414 | for file_path in file_list: 415 | remove_file(file_path) 416 | 417 | 418 | def copy_dir_multi_thread(sync_source_root, sync_dest_dir, dataset, max_workers=1): 419 | """ 420 | :param sync_source_dir: 421 | :param sync_dest_dir: 422 | :param dataset: 423 | :return: 424 | """ 425 | 426 | def rsync_cmd(source_dir, dest_dir): 427 | cmd_line = "rsync -a {0} {1}".format(source_dir, dest_dir) 428 | # subprocess.call(cmd_line.split()) 429 | subprocess.call(cmd_line) 430 | 431 | sync_dest_dir 
= sync_dest_dir.rstrip('/') 432 | 433 | with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: 434 | future_to_rsync = {} 435 | for source_dir in dataset: 436 | sync_source_dir = os.path.join(sync_source_root, source_dir.strip('/')) 437 | future_to_rsync[executor.submit(rsync_cmd, sync_source_dir, sync_dest_dir)] = source_dir 438 | 439 | for future in concurrent.futures.as_completed(future_to_rsync): 440 | source_dir = future_to_rsync[future] 441 | try: 442 | _ = future.result() 443 | except Exception as exc: 444 | print("%s copy data generated an exception: %s" % (source_dir, exc)) 445 | else: 446 | print("%s copy data successful." % (source_dir,)) 447 | 448 | 449 | def copy_dir_delete(src, dst): 450 | """ 451 | copy src directory to dst directory,will detete the dst same directory 452 | :param src: 453 | :param dst: 454 | :return: 455 | """ 456 | if os.path.exists(dst): 457 | shutil.rmtree(dst) 458 | shutil.copytree(src, dst) 459 | # time.sleep(3 / 1000.) 460 | 461 | 462 | def copy_dir(src, dst): 463 | """ copy src-directory to dst-directory, will cover the same files""" 464 | if not os.path.exists(src): 465 | print("\nno src path:{}".format(src)) 466 | return 467 | for root, dirs, files in os.walk(src, topdown=False): 468 | dest_path = os.path.join(dst, os.path.relpath(root, src)) 469 | if not os.path.exists(dest_path): 470 | os.makedirs(dest_path) 471 | for filename in files: 472 | copy_file( 473 | os.path.join(root, filename), 474 | os.path.join(dest_path, filename) 475 | ) 476 | 477 | 478 | def move_file(srcfile, dstfile): 479 | """ 移动文件或重命名""" 480 | if not os.path.isfile(srcfile): 481 | print("%s not exist!" % (srcfile)) 482 | else: 483 | fpath, fname = os.path.split(dstfile) # 分离文件名和路径 484 | if not os.path.exists(fpath): 485 | os.makedirs(fpath) # 创建路径 486 | shutil.move(srcfile, dstfile) 487 | # print("copy %s -> %s"%( srcfile,dstfile)) 488 | # time.sleep(1 / 1000.) 
489 | 490 | 491 | def copy_file(srcfile, dstfile): 492 | """ 493 | copy src file to dst file 494 | :param srcfile: 495 | :param dstfile: 496 | :return: 497 | """ 498 | if not os.path.isfile(srcfile): 499 | print("%s not exist!" % (srcfile)) 500 | else: 501 | fpath, fname = os.path.split(dstfile) # 分离文件名和路径 502 | if not os.path.exists(fpath): 503 | os.makedirs(fpath) # 创建路径 504 | shutil.copyfile(srcfile, dstfile) # 复制文件 505 | # print("copy %s -> %s"%( srcfile,dstfile)) 506 | # time.sleep(1 / 1000.) 507 | 508 | 509 | def copy_file_to_dir(srcfile, des_dir): 510 | if not os.path.isfile(srcfile): 511 | print("%s not exist!" % (srcfile)) 512 | else: 513 | fpath, fname = os.path.split(srcfile) # 分离文件名和路径 514 | if not os.path.exists(des_dir): 515 | os.makedirs(des_dir) # 创建路径 516 | dstfile = os.path.join(des_dir, fname) 517 | shutil.copyfile(srcfile, dstfile) # 复制文件 518 | 519 | 520 | def move_file_to_dir(srcfile, des_dir): 521 | if not os.path.isfile(srcfile): 522 | print("%s not exist!" % (srcfile)) 523 | else: 524 | fpath, fname = os.path.split(srcfile) # 分离文件名和路径 525 | if not os.path.exists(des_dir): 526 | os.makedirs(des_dir) # 创建路径 527 | dstfile = os.path.join(des_dir, fname) 528 | # shutil.copyfile(srcfile, dstfile) # 复制文件 529 | move_file(srcfile, dstfile) # 复制文件 530 | 531 | 532 | def merge_dir(src, dst, sub, merge_same): 533 | src_dir = os.path.join(src, sub) 534 | dst_dir = os.path.join(dst, sub) 535 | 536 | if not os.path.exists(src_dir): 537 | print("\nno src path:{}".format(src)) 538 | return 539 | if not os.path.exists(dst_dir): 540 | os.makedirs(dst_dir) 541 | elif not merge_same: 542 | t = get_time() 543 | dst_dir = os.path.join(dst, sub + "_{}".format(t)) 544 | print("have save sub:{}".format(dst_dir)) 545 | copy_dir(src_dir, dst_dir) 546 | 547 | 548 | # def merge_dir(src, dst, merge_same=False): 549 | # ''' 550 | # move and merge files, move/merge files from source directory to dst directory 551 | # root 所指的是当前正在遍历的这个文件夹的本身的地址 552 | # dirs 是一个 list 
,内容是该文件夹中所有的目录的名字(不包括子目录) 553 | # files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录) 554 | # :param source_dir: 555 | # :param dest_dir: 556 | # :return: 557 | # ''' 558 | # if not os.path.exists(src): 559 | # print("\nno src path:{}".format(src)) 560 | # for root, dirs, files in os.walk(src, topdown=False): 561 | # dest_path = os.path.join(dst, os.path.relpath(root, src)) 562 | # if not os.path.exists(dest_path): 563 | # os.makedirs(dest_path) 564 | # for filename in files: 565 | # copy_file( 566 | # os.path.join(root, filename), 567 | # os.path.join(dest_path, filename) 568 | # ) 569 | 570 | 571 | def create_dir(parent_dir, dir1=None, filename=None): 572 | """ 573 | create directory 574 | :param parent_dir: 575 | :param dir1: 576 | :param filename: 577 | :return: 578 | """ 579 | out_path = parent_dir 580 | if dir1: 581 | out_path = os.path.join(parent_dir, dir1) 582 | if not os.path.exists(out_path): 583 | os.makedirs(out_path) 584 | if filename: 585 | out_path = os.path.join(out_path, filename) 586 | return out_path 587 | 588 | 589 | def create_file_path(filename): 590 | """ 591 | create file in path 592 | :param filename: 593 | :return: 594 | """ 595 | basename = os.path.basename(filename) 596 | dirname = os.path.dirname(filename) 597 | out_path = create_dir(dirname, dir1=None, filename=basename) 598 | return out_path 599 | 600 | 601 | def merge_list(data1, data2): 602 | ''' 603 | 将两个list进行合并 604 | :param data1: 605 | :param data2: 606 | :return:返回合并后的list 607 | ''' 608 | if not len(data1) == len(data2): 609 | return 610 | all_data = [] 611 | for d1, d2 in zip(data1, data2): 612 | if not isinstance(d1, list): 613 | d1 = [d1] 614 | if not isinstance(d2, list): 615 | d2 = [d2] 616 | all_data.append(d1 + d2) 617 | return all_data 618 | 619 | 620 | def split_list(data, split_index=1): 621 | ''' 622 | 将data切分成两部分 623 | :param data: list 624 | :param split_index: 切分的位置 625 | :return: 626 | ''' 627 | data1 = [] 628 | data2 = [] 629 | for d in data: 630 | d1 = d[0:split_index] 631 
| d2 = d[split_index:] 632 | data1.append(d1) 633 | data2.append(d2) 634 | return data1, data2 635 | 636 | 637 | def getFilePathList(file_dir): 638 | ''' 639 | 获取file_dir目录下,所有文本路径,包括子目录文件 640 | :param rootDir: 641 | :return: 642 | ''' 643 | filePath_list = [] 644 | for walk in os.walk(file_dir): 645 | part_filePath_list = [os.path.join(walk[0], file).replace("\\", "/") for file in walk[2]] 646 | filePath_list.extend(part_filePath_list) 647 | return filePath_list 648 | 649 | 650 | def get_sub_directory_list(input_dir): 651 | ''' 652 | 当前路径下所有子目录 653 | :param input_dir: 654 | :return: 655 | ''' 656 | dirs_list = [] 657 | for root, dirs, files in os.walk(input_dir): 658 | dirs_list = dirs 659 | break 660 | # print(root) # 当前目录路径 661 | # print(dirs) # 当前路径下所有子目录 662 | # print(files) # 当前路径下所有非目录子文件 663 | dirs_list.sort() 664 | return dirs_list 665 | 666 | 667 | def get_sub_paths(path_list=[], parent=""): 668 | """ 669 | :param path_list: 670 | :param parent: 671 | :return: 672 | """ 673 | out_dir_list = [] 674 | if not parent.endswith(os.sep): 675 | parent = parent + os.sep 676 | l = len(parent) 677 | for path in path_list: 678 | if parent in path: 679 | p = path[l:] 680 | out_dir_list.append(p) 681 | return out_dir_list 682 | 683 | 684 | def get_files_lists(image_dir, subname="", postfix=["*.jpg", "*.png"], shuffle=False): 685 | """ 686 | 读取文件和列表: list,*.txt ,image path, directory 687 | :param image_dir: list,*.txt ,image path, directory 688 | :param subname: "JPEGImages" 689 | :return: 690 | """ 691 | if isinstance(image_dir, list): 692 | image_list = image_dir 693 | elif image_dir.endswith(".txt"): 694 | data_root = os.path.dirname(image_dir) 695 | image_list = read_data(image_dir) 696 | image_list = [os.path.join(data_root, subname, str(name[0]) + ".jpg") for name in image_list] 697 | elif os.path.isdir(image_dir): 698 | image_list = get_files(image_dir, postfix=postfix) 699 | elif os.path.isfile(image_dir): 700 | image_list = [image_dir] 701 | else: 702 | raise 
Exception("Error:{}".format(image_dir)) 703 | 704 | if shuffle: 705 | random.seed(100) 706 | random.shuffle(image_list) 707 | return image_list 708 | 709 | 710 | def get_files(file_dir, postfix=None): 711 | ''' 712 | 获得file_dir目录下,后缀名为postfix所有文件列表,包括子目录 713 | :param file_dir: 714 | :param postfix: ['*.jpg','*.png'],postfix=None表示全部文件 715 | :return: 716 | ''' 717 | file_list = [] 718 | filePath_list = getFilePathList(file_dir) 719 | if postfix is None: 720 | file_list = filePath_list 721 | else: 722 | postfix = [p.split('.')[-1] for p in postfix] 723 | for file in filePath_list: 724 | basename = os.path.basename(file) # 获得路径下的文件名 725 | postfix_name = basename.split('.')[-1] 726 | if postfix_name.lower() in postfix: 727 | file_list.append(file) 728 | file_list.sort() 729 | return file_list 730 | 731 | 732 | def get_files_labels(files_dir, postfix=None): 733 | ''' 734 | 获取files_dir路径下所有文件路径,以及labels,其中labels用子级文件名表示 735 | files_dir目录下,同一类别的文件放一个文件夹,其labels即为文件的名 736 | :param files_dir: 737 | :postfix 后缀名 738 | :return:filePath_list所有文件的路径,label_list对应的labels 739 | ''' 740 | # filePath_list = getFilePathList(files_dir) 741 | filePath_list = get_files(files_dir, postfix=postfix) 742 | print("files nums:{}".format(len(filePath_list))) 743 | # 获取所有样本标签 744 | label_list = [] 745 | for filePath in filePath_list: 746 | label = filePath.split(os.sep)[-2] 747 | label_list.append(label) 748 | 749 | labels_set = list(set(label_list)) 750 | # print("labels:{}".format(labels_set)) 751 | 752 | # 标签统计计数 753 | # print(pd.value_counts(label_list)) 754 | return filePath_list, label_list 755 | 756 | 757 | def decode_label(label_list, name_table): 758 | ''' 759 | 根据name_table解码label 760 | :param label_list: 761 | :param name_table: 762 | :return: 763 | ''' 764 | name_list = [] 765 | for label in label_list: 766 | name = name_table[label] 767 | name_list.append(name) 768 | return name_list 769 | 770 | 771 | def encode_label(name_list, name_table, unknow=0): 772 | ''' 773 | 
根据name_table,编码label 774 | :param name_list: 775 | :param name_table: 776 | :param unknow :未知的名称,默认label为0,一般在name_table中index=0是背景,未知的label也当做背景处理 777 | :return: 778 | ''' 779 | label_list = [] 780 | # name_table = {name_table[i]: i for i in range(len(name_table))} 781 | for name in name_list: 782 | if name in name_table: 783 | index = name_table.index(name) 784 | else: 785 | index = unknow 786 | label_list.append(index) 787 | return label_list 788 | 789 | 790 | def list2dict(data): 791 | """ 792 | convert list to dict 793 | :param data: 794 | :return: 795 | """ 796 | data = {data[i]: i for i in range(len(data))} 797 | return data 798 | 799 | 800 | def print_dict(dict_data, save_path): 801 | """ 802 | print dict info 803 | :param dict_data: 804 | :param save_path: 805 | :return: 806 | """ 807 | list_config = [] 808 | for key in dict_data: 809 | info = "conf.{}={}".format(key, dict_data[key]) 810 | print(info) 811 | list_config.append(info) 812 | if save_path is not None: 813 | with open(save_path, "w") as f: 814 | for info in list_config: 815 | f.writelines(info + "\n") 816 | 817 | 818 | def read_pair_data(filename, split=True): 819 | ''' 820 | read pair data,data:[image1.jpg image2.jpg 0] 821 | :param filename: 822 | :param split: 823 | :return: 824 | ''' 825 | content_list = read_data(filename) 826 | if split: 827 | content_list = np.asarray(content_list) 828 | faces_list1 = content_list[:, :1].reshape(-1) 829 | faces_list2 = content_list[:, 1:2].reshape(-1) 830 | # convert to 0/1 831 | issames_data = np.asarray(content_list[:, 2:3].reshape(-1), dtype=np.int) 832 | issames_data = np.where(issames_data > 0, 1, 0) 833 | faces_list1 = faces_list1.tolist() 834 | faces_list2 = faces_list2.tolist() 835 | issames_data = issames_data.tolist() 836 | return faces_list1, faces_list2, issames_data 837 | return content_list 838 | 839 | 840 | def check_files(files_list, sizeTh=1 * 1024, isRemove=False): 841 | ''' 去除不存的文件和文件过小的文件列表 842 | :param files_list: 843 | :param sizeTh: 
文件大小阈值,单位:字节B,默认1000B ,33049513/1024/1024=33.0MB 844 | :param isRemove: 是否在硬盘上删除被损坏的原文件 845 | :return: 846 | ''' 847 | i = 0 848 | while i < len(files_list): 849 | path = files_list[i] 850 | # 判断文件是否存在 851 | if not (os.path.exists(path)): 852 | print(" non-existent file:{}".format(path)) 853 | files_list.pop(i) 854 | continue 855 | # 判断文件是否为空 856 | f_size = os.path.getsize(path) 857 | if f_size < sizeTh: 858 | print(" empty file:{}".format(path)) 859 | if isRemove: 860 | os.remove(path) 861 | print(" info:----------------remove image_dict:{}".format(path)) 862 | files_list.pop(i) 863 | continue 864 | i += 1 865 | return files_list 866 | 867 | 868 | def get_files_id(file_list): 869 | """ 870 | :param file_list: 871 | :return: 872 | """ 873 | image_idx = [] 874 | for path in file_list: 875 | basename = os.path.basename(path) 876 | id = basename.split(".")[0] 877 | image_idx.append(id) 878 | return image_idx 879 | 880 | 881 | def get_loacl_eth2(): 882 | ''' 883 | 想要获取linux设备网卡接口,并用列表进行保存 884 | :return: 885 | ''' 886 | eth_list = [] 887 | os.system("ls -l /sys/class/net/ | grep -v virtual | sed '1d' | awk 'BEGIN {FS=\"/\"} {print $NF}' > eth.yaml") 888 | try: 889 | with open('./eth.yaml', "r") as f: 890 | for line in f.readlines(): 891 | line = line.strip() 892 | eth_list.append(line.lower()) 893 | except Exception as e: 894 | print(e) 895 | eth_list = [] 896 | return eth_list 897 | 898 | 899 | def get_loacl_eth(): 900 | ''' 901 | 想要获取linux设备网卡接口,并用列表进行保存 902 | :return: 903 | ''' 904 | eth_list = [] 905 | cmd = "ls -l /sys/class/net/ | grep -v virtual | sed '1d' | awk 'BEGIN {FS=\"/\"} {print $NF}'" 906 | try: 907 | with os.popen(cmd) as f: 908 | for line in f.readlines(): 909 | line = line.strip() 910 | eth_list.append(line.lower()) 911 | except Exception as e: 912 | print(e, "can not found eth,will set default eth is:eth0") 913 | eth_list = ["eth0"] 914 | if not eth_list: 915 | eth_list = ["eth0"] 916 | return eth_list 917 | 918 | 919 | def merge_files(files_list): 
920 | """ 921 | 合并文件列表 922 | :return: 923 | """ 924 | content_list = [] 925 | for file in files_list: 926 | data = read_data(file) 927 | 928 | return content_list 929 | 930 | 931 | def multi_thread_task(content_list, func, num_processes=4, remove_bad=False, Async=True, **kwargs): 932 | """ 933 | 多线程处理content_list的数据 934 | Usage: 935 | def task_fun(item, save_root): 936 | ''' 937 | :param item: 对应content_list的每一项item 938 | :param save_root: 对应kwargs 939 | :return: 940 | ''' 941 | pass 942 | multi_thread_task(content_list, 943 | func=task_fun, 944 | num_processes=num_processes, 945 | remove_bad=remove_bad, 946 | Async=Async, 947 | save_root=save_root) 948 | ===================================================== 949 | :param content_list: content_list 950 | :param func: func:task function 951 | :param num_processes: 开启线程个数 952 | :param remove_bad: 是否去除下载失败的数据 953 | :param Async:是否异步 954 | :param kwargs:需要传递给func的相关参数 955 | :return: 返回图片的存储地址列表 956 | """ 957 | from multiprocessing.pool import ThreadPool 958 | # 开启多线程 959 | pool = ThreadPool(processes=num_processes) 960 | thread_list = [] 961 | for item in content_list: 962 | if Async: 963 | out = pool.apply_async(func=func, args=(item,), kwds=kwargs) # 异步 964 | else: 965 | out = pool.apply(func=func, args=(item,), kwds=kwargs) # 同步 966 | thread_list.append(out) 967 | 968 | pool.close() 969 | pool.join() 970 | # 获取输出结果 971 | dst_content_list = [] 972 | if Async: 973 | for p in thread_list: 974 | image = p.get() # get会阻塞 975 | dst_content_list.append(image) 976 | else: 977 | dst_content_list = thread_list 978 | if remove_bad: 979 | dst_content_list = [i for i in dst_content_list if i is not None] 980 | return dst_content_list 981 | 982 | 983 | if __name__ == '__main__': 984 | parent = "/media/dm/dm1/git/python-learning-notes/dataset/dataset" 985 | dir_list = getFilePathList(parent) 986 | dir_list1 = get_sub_paths(dir_list, parent) 987 | print(dir_list1) 988 | 
-------------------------------------------------------------------------------- /audio/utils/log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import datetime 5 | import logging 6 | import threading 7 | import re 8 | import time 9 | from logging.handlers import TimedRotatingFileHandler 10 | from memory_profiler import profile 11 | import threading 12 | 13 | 14 | def singleton(cls): 15 | _instance_lock = threading.Lock() 16 | instances = {} 17 | 18 | def _singleton(*args, **kwargs): 19 | with _instance_lock: 20 | if cls not in instances: 21 | instances[cls] = cls(*args, **kwargs) 22 | return instances[cls] 23 | 24 | return _singleton 25 | 26 | 27 | @singleton # 使用singleton,会出现loger的level失效的问题 28 | class CustomLogger(logging.Logger): 29 | def __init__(self, name="LOG", level="debug"): 30 | """ 31 | Initialize the logger with a name and an optional level. 32 | Args: 33 | name: 34 | level: debug,info,warning,critical,fatal 35 | """ 36 | super().__init__(name) 37 | # super(CustomLogger, self).__init__(name) 38 | self.setLevel(level=level) 39 | 40 | @staticmethod 41 | def levels(level): 42 | if level == 'debug': 43 | return logging.DEBUG 44 | if level == 'info': 45 | return logging.INFO 46 | if level == 'warning': 47 | return logging.WARN 48 | if level == 'critical': 49 | return logging.CRITICAL 50 | if level == 'fatal': 51 | return logging.FATAL 52 | return logging.DEBUG 53 | 54 | def setLevel(self, level): 55 | """ 56 | Args: 57 | level: debug,info,warning,critical,fatal 58 | Returns: 59 | """ 60 | level = self.levels(level) 61 | super().setLevel(level) 62 | 63 | @staticmethod 64 | def set_format(handler, format): 65 | # handler.suffix = "%Y%m%d" 66 | if format: 67 | logFormatter = logging.Formatter("%(asctime)s %(filename)s %(funcName)s %(levelname)s: %(message)s", 68 | "%Y-%m-%d %H:%M:%S") 69 | else: 70 | logFormatter = logging.Formatter("%(levelname)s: %(message)s") 71 | 
handler.setFormatter(logFormatter) 72 | 73 | def show_batch_tensor(self, title, batch_imgs, index=0): 74 | pass 75 | 76 | 77 | class FileHandler(TimedRotatingFileHandler): 78 | def __init__(self, filename, when='h', interval=1, backupCount=0, encoding=None, delay=False, utc=False, 79 | atTime=None): 80 | logging.handlers.BaseRotatingHandler.__init__(self, filename, 'a', encoding, delay) 81 | self.when = when.upper() 82 | self.backupCount = backupCount 83 | self.utc = utc 84 | self.atTime = atTime 85 | if self.when == 'S': 86 | self.interval = 1 # one second 87 | self.suffix = "%Y-%m-%d_%H-%M-%S" 88 | self.extMatch = r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}(\.\w+)?$" 89 | elif self.when == 'M': 90 | self.interval = 60 # one minute 91 | self.suffix = "%Y-%m-%d_%H-%M" 92 | self.extMatch = r"^\d{4}-\d{2}-\d{2}_\d{2}-\d{2}(\.\w+)?$" 93 | elif self.when == 'H': 94 | self.interval = 60 * 60 # one hour 95 | self.suffix = "%Y-%m-%d_%H" 96 | self.extMatch = r"^\d{4}-\d{2}-\d{2}_\d{2}(\.\w+)?$" 97 | elif self.when == 'D' or self.when == 'MIDNIGHT': 98 | self.interval = 60 * 60 * 24 # one day 99 | self.suffix = "%Y-%m-%d" 100 | self.extMatch = r"^\d{4}-\d{2}-\d{2}(\.\w+)?$" 101 | elif self.when.startswith('W'): 102 | self.interval = 60 * 60 * 24 * 7 # one week 103 | if len(self.when) != 2: 104 | raise ValueError("You must specify a day for weekly rollover from 0 to 6 (0 is Monday): %s" % self.when) 105 | if self.when[1] < '0' or self.when[1] > '6': 106 | raise ValueError("Invalid day specified for weekly rollover: %s" % self.when) 107 | self.dayOfWeek = int(self.when[1]) 108 | self.suffix = "%Y-%m-%d" 109 | self.extMatch = r"^\d{4}-\d{2}-\d{2}(\.\w+)?$" 110 | elif self.when == 'Y': 111 | self.interval = 60 * 60 * 24 * 365 # one yes 112 | self.suffix = "%Y-%m-%d" 113 | self.extMatch = r"^\d{4}-\d{2}-\d{2}(\.\w+)?$" 114 | else: 115 | raise ValueError("Invalid rollover interval specified: %s" % self.when) 116 | 117 | self.extMatch = re.compile(self.extMatch, re.ASCII) 118 | 
self.interval = self.interval * interval # multiply by units requested 119 | # The following line added because the filename passed in could be a 120 | # path object (see Issue #27493), but self.baseFilename will be a string 121 | filename = self.baseFilename 122 | if os.path.exists(filename): 123 | t = os.stat(filename)[logging.handlers.ST_MTIME] 124 | else: 125 | t = int(time.time()) 126 | self.rolloverAt = self.computeRollover(t) 127 | 128 | 129 | def set_logger(name="LOG", level="debug", logfile=None, format=False, is_main_process=True): 130 | """ 131 | logger = set_logging(name="LOG", level="debug", logfile="log.txt", format=False) 132 | url:https://cuiqingcai.com/6080.html 133 | level级别:debug>info>warning>error>critical 134 | :param level: 设置log输出级别 135 | :param logfile: log保存路径,如果为None,则在控制台打印log 136 | :param is_main_process: 是否是主进程 137 | :return: 138 | """ 139 | if not is_main_process: 140 | level = "fatal" 141 | logfile = None 142 | # logger = logging.getLogger(name) 143 | logger = CustomLogger(name, level=level) 144 | if logfile and os.path.exists(logfile): 145 | os.remove(logfile) 146 | # define a FileHandler write messages to file 147 | if logfile: 148 | # filehandler = logging.handlers.RotatingFileHandler(filename="./log.txt") 149 | # filehandler = TimedRotatingFileHandler(logfile, when="midnight", interval=1) 150 | filehandler = FileHandler(logfile, when="Y", interval=1) 151 | logger.set_format(filehandler, format) 152 | logger.addHandler(filehandler) 153 | 154 | # define a StreamHandler print messages to console 155 | console = logging.StreamHandler() 156 | logger.set_format(console, format) 157 | logger.addHandler(console) 158 | return logger 159 | 160 | 161 | def print_args(args): 162 | logger = get_logger() 163 | logger.info("---" * 10) 164 | args = args.__dict__ 165 | for k, v in args.items(): 166 | # print("{}: {}".format(k, v)) 167 | logger.info("{}: {}".format(k, v)) 168 | logger.info("---" * 10) 169 | 170 | 171 | def get_logger(name="LOG", 
level="debug"): 172 | logger = CustomLogger(name) 173 | # if logger.isEnabledFor(CustomLogger.levels(level)): 174 | # logger = CustomLogger(name, level=level) 175 | return logger 176 | 177 | 178 | def RUN_TIME(deta_time): 179 | ''' 180 | 计算时间差,返回毫秒,deta_time.seconds获得秒数=1000ms,deta_time.microseconds获得微妙数=1/1000ms 181 | :param deta_time: ms 182 | :return: 183 | ''' 184 | time_ = deta_time.seconds * 1000 + deta_time.microseconds / 1000.0 185 | return time_ 186 | 187 | 188 | def TIME(): 189 | ''' 190 | 获得当前时间 191 | :return: 192 | ''' 193 | return datetime.datetime.now() 194 | 195 | 196 | def run_time_decorator(title=""): 197 | def decorator(func): 198 | def wrapper(*args, **kwargs): 199 | # torch.cuda.synchronize() 200 | T0 = TIME() 201 | result = func(*args, **kwargs) 202 | # torch.cuda.synchronize() 203 | T1 = TIME() 204 | print("{}-- function : {}-- rum time : {}ms ".format(title, func.__name__, RUN_TIME(T1 - T0))) 205 | # logger.debug("{}-- function : {}-- rum time : {}s ".format(title, func.__name__, RUN_TIME(T1 - T0)/1000.0)) 206 | return result 207 | 208 | return wrapper 209 | 210 | return decorator 211 | 212 | 213 | @profile(precision=4) 214 | def memory_test(): 215 | """ 216 | 1.先导入: 217 | > from memory_profiler import profile 218 | 2.函数前加装饰器: 219 | > @profile(precision=4,stream=open('memory_profiler.log','w+')) 220 |    参数含义:precision:精确到小数点后几位 221 |    stream:此模块分析结果保存到 'memory_profiler.log' 日志文件。如果没有此参数,分析结果会在控制台输出 222 | :return: 223 | """ 224 | c = 0 225 | for item in range(10): 226 | c += 1 227 | # logger.error("c:{}".format(c)) 228 | # print(c) 229 | 230 | 231 | if __name__ == '__main__': 232 | # logger = set_logger(name="LOG", level="warning", logfile="log.txt", format=False) 233 | # T0 = TIME() 234 | # do something 235 | # T1 = TIME() 236 | # print("rum time:{}ms".format(RUN_TIME(T1 - T0))) 237 | # t_logger = set_logging(name=__name__, level="info", logfile=None) 238 | # t_logger.debug('debug') 239 | # t_logger.info('info') 240 | # 
t_logger.warning('Warning exists') 241 | # t_logger.error('Finish') 242 | # memory_test() 243 | # logger1 = set_logger(name="LOG", level="debug", logfile="log.txt", format=False) 244 | logger = set_logger(logfile=None, level="debug") 245 | logger1 = get_logger() 246 | logger1.info("---" * 20) 247 | logger1.debug("work_space:{}".format("work_dir")) 248 | logger1.info("work_space:{}".format("work_dir")) 249 | logger1.error("work_space:{}".format("work_dir")) 250 | logger1.fatal("work_space:{}".format("work_dir")) 251 | # logger1.show_batch_tensor() 252 | -------------------------------------------------------------------------------- /audio/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | 3 | # 导入需要用到的库 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import PIL.Image as Image 7 | 8 | 9 | def plot_bar(x_data, y_data, title, xlabel, ylabel, isshow=False): 10 | # 准备数据 11 | # 用 Matplotlib 画条形图 12 | plt.bar(x_data, y_data) 13 | # plt.xlim([0.0, 1.0]) 14 | # plt.ylim([0.0, 1.05]) 15 | # 设置横纵坐标的名称以及对应字体格式 16 | font = {'family': 'Times New Roman', 17 | 'weight': 'normal', 18 | 'size': 10, 19 | } 20 | plt.xlabel(xlabel, font) 21 | plt.ylabel(ylabel, font) 22 | 23 | plt.title(title) 24 | plt.legend(loc="lower right") # "upper right" 25 | # plt.legend(loc="upper right")#"upper right" 26 | plt.grid(True) # 显示网格; 27 | plt.savefig('out.png') 28 | if isshow: 29 | plt.show() 30 | 31 | 32 | def plot_multi_line(x_data_list, y_data_list, line_names=None, title="", xlabel="", ylabel=""): 33 | # 绘图 34 | # plt.figure() 35 | lw = 2 36 | plt.figure(figsize=(10, 10)) 37 | colors = ["b", "r", "c", "m", "g", "y", "k", "w"] 38 | xlim_max = 0 39 | ylim_max = 0 40 | 41 | xlim_min = 0 42 | ylim_min = 0 43 | if not line_names: 44 | line_names = " " * len(x_data_list) 45 | for x, y, color, line_name in zip(x_data_list, y_data_list, colors, line_names): 46 | plt.plot(x, y, color=color, lw=lw, 
label=line_name) # 假正率为横坐标,真正率为纵坐标做曲线 47 | if xlim_max < max(x): 48 | xlim_max = max(x) 49 | if ylim_max < max(y): 50 | ylim_max = max(y) 51 | if xlim_min > min(x): 52 | xlim_min = min(x) 53 | if ylim_min > min(y): 54 | ylim_min = min(y) 55 | # plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') 56 | # plt.plot([0, 1], [1, 0], color='navy', lw=lw, linestyle='--') # 绘制y=1-x的直线 57 | x_deta = xlim_max - xlim_min 58 | y_deta = ylim_max - ylim_min 59 | plt.xlim([xlim_min - 0.01 * x_deta, xlim_max + 0.1 * x_deta]) 60 | plt.ylim([ylim_min - 0.01 * y_deta, ylim_max + 0.1 * y_deta]) 61 | # 设置横纵坐标的名称以及对应字体格式 62 | font = {'family': 'Times New Roman', 63 | 'weight': 'normal', 64 | 'size': 20, 65 | } 66 | plt.xlabel(xlabel, font) 67 | plt.ylabel(ylabel, font) 68 | 69 | plt.title(title) 70 | plt.legend(loc="lower right") # "upper right" 71 | # plt.legend(loc="upper right")#"upper right" 72 | plt.grid(True) # 显示网格; 73 | plt.show() 74 | 75 | 76 | def plot_skew_kurt(data, name="Title"): 77 | """ 78 | https://blog.csdn.net/u012735708/article/details/84750295 79 | 计算偏度(skew)和峰度(kurt) 80 | :return: 81 | """ 82 | import pandas as pd 83 | plt.figure(figsize=(10, 10)) 84 | skew = pd.Series(data).skew() 85 | kurt = pd.Series(data).kurt() 86 | info = 'skew={:.4f},kurt={:.4f},mean:{:.4f}'.format(skew, kurt, np.mean(data)) # 标注 87 | info = "{}:\n{}".format(name, info) 88 | plt.title(info) 89 | print(info) 90 | plt.hist(data, 100, facecolor='r', alpha=0.9) 91 | plt.grid(True) 92 | plt.show() 93 | 94 | 95 | def demo(image1, image2): 96 | fig = plt.figure(2) # 新开一个窗口 97 | # fig1 98 | ax1 = fig.add_subplot(1, 2, 1) 99 | ax1.imshow(image1) 100 | ax1.set_title("image1") 101 | 102 | # fig2 103 | ax2 = fig.add_subplot(1, 2, 2) 104 | ax2.imshow(image2) 105 | ax2.set_title("image2") 106 | plt.show() 107 | 108 | 109 | def demo_for_skew_kurt(): 110 | """ 111 | https://blog.csdn.net/u012735708/article/details/84750295 112 | 计算偏度(skew)和峰度(kurt) 113 | :return: 114 | """ 115 | import numpy as np 
116 | data = list(np.random.randn(10000)) 117 | plot_skew_kurt(data) 118 | 119 | 120 | if __name__ == "__main__": 121 | import cv2 122 | 123 | # image_path="/media/dm/dm1/git/python-learning-notes/dataset/test_image/1.jpg" 124 | # image=cv2.imread(image_path) 125 | # image1=cv2.resize(image,dsize=(100,100)) 126 | # demo(image, image1) 127 | demo_for_skew_kurt() 128 | -------------------------------------------------------------------------------- /audio/utils/setup_config.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | """ 3 | 提供工具函数的模块 4 | """ 5 | import os 6 | import argparse 7 | import numbers 8 | import easydict 9 | import yaml 10 | from . import file_utils 11 | 12 | 13 | def parser_work_space(cfg, flags: list = [], time=True): 14 | """生成工程空间 15 | flag = [cfg.net_type, cfg.width_mult, cfg.input_size[0], cfg.input_size[1], 16 | cfg.loss_type, cfg.optim_type, flag, file_utils.get_time()] 17 | """ 18 | if isinstance(flags, str): 19 | flags = [flags] 20 | if time: 21 | flags += [file_utils.get_time()] 22 | name = [str(n) for n in flags if n] 23 | name = "_".join(name) 24 | work_dir = os.path.join(cfg.work_dir, name) 25 | return work_dir 26 | 27 | 28 | def parser_config(args: argparse.Namespace, cfg_updata: bool = True): 29 | """ 30 | 解析并合并配置参数:(1)命令行argparse (2)使用*.yaml配置文件 31 | :param args: 命令行参数 32 | :param cfg_updata:True: 合并配置参数时,相同参数由*.yaml文件参数决定 33 | False: 合并配置参数时,相同参数由命令行argparse参数决定 34 | :return: 35 | """ 36 | if "config_file" in args and args.config_file: 37 | cfg = load_config(args.config_file) 38 | if cfg_updata: 39 | cfg = dict(args.__dict__, **cfg) 40 | else: 41 | cfg = dict(cfg, **args.__dict__) 42 | cfg["config_file"] = args.config_file 43 | else: 44 | cfg = args.__dict__ 45 | cfg['config_file'] = save_config(cfg, 'args_config.yaml') 46 | print_dict(cfg) 47 | cfg = easydict.EasyDict(cfg) 48 | return cfg 49 | 50 | 51 | def parser_config_file(config: easydict.EasyDict, config_file: 
str, cfg_updata: bool = True): 52 | """ 53 | 解析并合并配置参数 54 | :param config: EasyDict参数 55 | :param cfg_updata:True: 合并配置参数时,相同参数由config参数决定 56 | False: 合并配置参数时,相同参数由config_file中的参数决定 57 | :return: 58 | """ 59 | cfg = load_config(config_file) 60 | if cfg_updata: 61 | cfg = dict(cfg, **config.__dict__) 62 | else: 63 | cfg = dict(config.__dict__, **cfg) 64 | print_dict(cfg) 65 | cfg = easydict.EasyDict(cfg) 66 | return cfg 67 | 68 | 69 | class Dict2Obj: 70 | ''' 71 | dict转类对象 72 | ''' 73 | 74 | def __init__(self, args): 75 | self.__dict__.update(args) 76 | 77 | 78 | def load_config(config_file='config.yaml'): 79 | """ 80 | 读取配置文件,并返回一个python dict 对象 81 | :param config_file: 配置文件路径 82 | :return: python dict 对象 83 | """ 84 | with open(config_file, 'r', encoding="UTF-8") as stream: 85 | try: 86 | config = yaml.load(stream, Loader=yaml.FullLoader) 87 | # config = Dict2Obj(config) 88 | except yaml.YAMLError as e: 89 | print(e) 90 | return None 91 | return config 92 | 93 | 94 | def save_config(cfg: dict, config_file='config.yaml'): 95 | """保存yaml文件""" 96 | if isinstance(cfg, easydict.EasyDict) or isinstance(cfg, argparse.Namespace): 97 | cfg = cfg.__dict__ 98 | fw = open(config_file, 'w', encoding='utf-8') 99 | yaml.dump(cfg, fw) 100 | return config_file 101 | 102 | 103 | def print_dict(dict_data, save_path=None): 104 | list_config = [] 105 | print("=" * 60) 106 | for key in dict_data: 107 | info = "{}: {}".format(key, dict_data[key]) 108 | print(info) 109 | list_config.append(info) 110 | if save_path is not None: 111 | with open(save_path, "w") as f: 112 | for info in list_config: 113 | f.writelines(info + "\n") 114 | print("=" * 60) 115 | 116 | 117 | if __name__ == '__main__': 118 | data = None 119 | config_file = "config.yaml" 120 | save_config(data, config_file) 121 | -------------------------------------------------------------------------------- /audio/utils/summary.py: -------------------------------------------------------------------------------- 1 | # -*-coding: 
utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import tensorboardX as tensorboard 6 | # from torch.utils import tensorboard 7 | 8 | 9 | class SummaryWriter(): 10 | def __init__(self, log_dir, *args, **kwargs): 11 | self.tensorboard = None 12 | if log_dir: 13 | # 修复tensorboard版本BUG 14 | self.tensorboard = tensorboard.SummaryWriter(log_dir, *args, **kwargs) 15 | 16 | def add_scalar(self, *args, **kwargs): 17 | if self.tensorboard: 18 | self.tensorboard.add_scalar(*args, **kwargs) 19 | 20 | def add_scalars(self, *args, **kwargs): 21 | if self.tensorboard: 22 | self.tensorboard.add_scalars(*args, **kwargs) 23 | 24 | def add_image(self, *args, **kwargs): 25 | if self.tensorboard: 26 | self.tensorboard.add_image(*args, **kwargs) 27 | 28 | 29 | if __name__ == '__main__': 30 | main_process = True 31 | log_root = "./" 32 | writer1 = SummaryWriter(log_root if main_process else None) 33 | # main_process=False 34 | writer2 = SummaryWriter(log_root if main_process else None) 35 | writer3 = SummaryWriter(log_root if main_process else None) 36 | epochs = 200 37 | for epoch in range(epochs): 38 | if writer3: 39 | print(writer3) 40 | writer1.add_scalar("lr_epoch", epoch, epoch) 41 | writer2.add_scalar("lr_epoch", epoch, epoch) 42 | writer3.add_scalar("lr_epoch", epoch, epoch) 43 | -------------------------------------------------------------------------------- /audio/utils/torch_data.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | import torch.utils.data as torch_utils 5 | from torch.utils.data.dataset import Dataset 6 | from torch.utils.data.dataloader import DataLoader 7 | from ..engine import comm 8 | from .torch_tools import get_torch_version 9 | 10 | 11 | def build_dataloader(dataset: Dataset, 12 | batch_size: int, 13 | num_workers: int, 14 | shuffle: bool = True, 15 | persistent_workers: bool = True, 16 | phase: str = "train", 17 | distributed=True, 18 | **kwargs) -> DataLoader: 
19 | """ 20 | :param dataset: Dataset 21 | :param batch_size: 22 | :param num_workers: 23 | :param shuffle: 24 | :param persistent_workers: 该参数仅支持torch>=1.6 25 | False: 数据加载器运行完一个Epoch后会关闭worker进程,在分布式训练,会出现每个epoch初始化多进程的问题 26 | True: 会保持worker进程实例激活状态 27 | :param phase: "train", "test", "val" 28 | :param distributed: True: use DDP; False: use DP (是否使用分布式训练) 29 | :param kwargs: 30 | :return: 31 | """ 32 | assert phase in ["train", "test", "val"] 33 | sampler = None 34 | if comm.get_world_size() > 1 and phase == "train" and distributed: 35 | # DistributedSampler为每个子进程分发数据,避免数据重复 36 | sampler = torch_utils.distributed.DistributedSampler(dataset, 37 | num_replicas=comm.get_world_size(), 38 | rank=comm.get_local_rank(), 39 | shuffle=shuffle) 40 | shuffle = False # sampler option is mutually exclusive with shuffle 41 | else: 42 | # Fix a Bug: RuntimeError: can't start new thread 43 | persistent_workers = False 44 | try: 45 | # Fix a Bug: torch<=1.6 have no argument 'persistent_workers' 46 | if get_torch_version() >= 1.7: 47 | kwargs["persistent_workers"] = persistent_workers 48 | # fix a bug: persistent_workers option needs num_workers > 0 49 | if persistent_workers and num_workers == 0: 50 | kwargs["persistent_workers"] = False 51 | except: 52 | print("torch<=1.6 have no argument persistent_workers") 53 | dataloader = torch_utils.DataLoader(dataset, 54 | batch_size=batch_size, 55 | num_workers=num_workers, 56 | sampler=sampler, 57 | shuffle=shuffle, 58 | **kwargs) 59 | return dataloader 60 | 61 | 62 | def build_model_parallel(model: nn.Module, 63 | device_ids=None, 64 | distributed=True, 65 | **kwargs) -> nn.Module: 66 | """ 67 | :param model: 68 | :param device_ids: 69 | :param distributed: True: use DDP; False: use DP (是否使用分布式训练) 70 | :param kwargs: 71 | :return: 72 | """ 73 | print("device_ids:{},device:{}".format(device_ids, comm.get_device(device_ids))) 74 | model.to(comm.get_device(device_ids)) 75 | # use DistributedDataParallel 76 | if comm.get_world_size() > 1 
and distributed: 77 | model = nn.parallel.DistributedDataParallel(model, 78 | device_ids=[comm.get_device(device_ids)], 79 | output_device=comm.get_device(device_ids), 80 | **kwargs 81 | ) 82 | else: 83 | # use DataParallel 84 | model = nn.DataParallel(model, device_ids=device_ids, output_device=comm.get_device(device_ids), **kwargs) 85 | return model 86 | -------------------------------------------------------------------------------- /audio/utils/torch_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import random 5 | import os 6 | import numpy as np 7 | from collections import OrderedDict 8 | from collections.abc import Iterable 9 | 10 | 11 | def get_torch_version(): 12 | try: 13 | v = torch.__version__ 14 | print("torch.version:{}".format(v)) 15 | vid = v.split(".") 16 | vid = float("{}.{}".format(vid[0], vid[1])) 17 | except Exception as e: 18 | vid = None 19 | return vid 20 | 21 | 22 | def set_env_random_seed(seed=2020): 23 | """ 24 | :param seed: 25 | :return: 26 | """ 27 | random.seed(seed) 28 | os.environ['PYTHONHASHSEED'] = str(seed) 29 | np.random.seed(seed) 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
33 | # torch.backends.cudnn.benchmark = False 34 | # torch.backends.cudnn.deterministic = True 35 | 36 | 37 | def get_loacl_eth(): 38 | ''' 39 | 想要获取linux设备网卡接口,并用列表进行保存 40 | :return: 41 | ''' 42 | eth_list = [] 43 | cmd = "ls -l /sys/class/net/ | grep -v virtual | sed '1d' | awk 'BEGIN {FS=\"/\"} {print $NF}'" 44 | try: 45 | with os.popen(cmd) as f: 46 | for line in f.readlines(): 47 | line = line.strip() 48 | eth_list.append(line.lower()) 49 | except Exception as e: 50 | print(e, "can not found eth,will set default eth is:eth0") 51 | eth_list = ["eth0"] 52 | if not eth_list: 53 | eth_list = ["eth0"] 54 | return eth_list 55 | 56 | 57 | def set_node_env(master_addr="localhost", master_port="1200", eth_name=None): 58 | """ 59 | 设置多卡训练的节点信息 60 | parser = argparse.ArgumentParser(description="for face verification train") 61 | parser.add_argument("-c", "--config", help="configs file", default="configs/config_distributed.yaml", type=str) 62 | parser.add_argument("-e", '--eth_name', type=str, default=None, help="set eth name") 63 | parser.add_argument("-a", '--master_addr', type=str, default='localhost', help="set master node address") 64 | parser.add_argument("-p", '--master_port', type=str, default='1200', help="set master node port") 65 | parser.add_argument("--local_rank", type=int, default=0, help="torch.distributed.launch会给模型分配一个args.local_rank的参数," 66 | "也可以通过torch.distributed.get_rank()获取进程id") 67 | parser.add_argument("--init_method", type=str, default="env://") 68 | args = parser.parse_args() 69 | ==================================== 70 | :param master_addr: 主节点地址,default localhost 71 | :param master_port: 主节点接口,default 1200 72 | :param eth_name: 网卡名称,None会自动获取 73 | :return: 74 | """ 75 | if eth_name is None: # auto get eth_name 76 | eth_name = get_loacl_eth()[0] 77 | print("eth_name:{}".format(eth_name)) 78 | os.environ['NCCL_SOCKET_IFNAME'] = eth_name 79 | os.environ['MASTER_ADDR'] = master_addr 80 | os.environ['MASTER_PORT'] = master_port 81 | 82 | 83 | def 
set_distributed_env(backend="nccl", init_method="env://"): 84 | """ 85 | initialize the distributed environment 86 | :param backend: 87 | :param init_method: 88 | :return: world_size :参与工作的进程数 89 | rank: 当前进程的rank(这个Worker是全局第几个Worker) 90 | local_rank:这个Worker是这台机器上的第几个Worker 91 | """ 92 | # initialize process group 93 | # use nccl backend to speedup gpu communication 94 | torch.distributed.init_process_group(backend=backend, init_method=init_method) 95 | world_size = torch.distributed.get_world_size() 96 | # torch.distributed.launch 会给模型分配一个args.local_rank的参数,也可以通过torch.distributed.get_rank()获取进程id。 97 | rank = torch.distributed.get_rank() # os.environ["RANK"] 98 | return world_size, rank 99 | 100 | 101 | def get_distributed_sampler(dataset: torch.utils.data.Dataset, world_size, rank): 102 | """ 103 | Example: 104 | sampler = DistributedSampler(dataset) if is_distributed else None 105 | loader = DataLoader(dataset, shuffle=(sampler is None), 106 | sampler=sampler) 107 | for epoch in range(start_epoch, n_epochs): 108 | if is_distributed: 109 | sampler.set_epoch(epoch) 110 | train(loader) 111 | :param dataset: 112 | :param world_size: 113 | :param rank: 114 | :return: 115 | """ 116 | sampler = torch.utils.data.distributed.DistributedSampler( 117 | dataset, num_replicas=world_size, rank=rank) 118 | return sampler 119 | 120 | 121 | def get_device(): 122 | """ 123 | 返回当前设备索引 124 | torch.cuda.current_device() 125 | 返回GPU的数量 126 | torch.cuda.device_count() 127 | 返回gpu名字,设备索引默认从0开始 128 | torch.cuda.get_device_name(0) 129 | cuda是否可用 130 | torch.cuda.is_available() 131 | ========== 132 | CUDA_VISIBLE_DEVICES=4,5,6 python train.py 133 | 134 | Usage: 135 | gpu_id = get_device() 136 | model = build_model() 137 | model.cuda() 138 | model = torch.nn.DataParallel(model, device_ids=gpu_id) 139 | ... 
140 | :return: 141 | """ 142 | gpu_id = list(range(torch.cuda.device_count())) 143 | return gpu_id 144 | 145 | 146 | def print_model(model): 147 | """ 148 | :param model: 149 | :return: 150 | """ 151 | for k, v in model.named_parameters(): 152 | # print(k,v) 153 | print(k) 154 | 155 | 156 | def freeze_net_layers(net): 157 | """ 158 | https://www.zhihu.com/question/311095447/answer/589307812 159 | example: 160 | freeze_net_layers(net.base_net) 161 | freeze_net_layers(net.source_layer_add_ons) 162 | freeze_net_layers(net.extras) 163 | :param net: 164 | :return: 165 | """ 166 | # for param in net.parameters(): 167 | # param.requires_grad = False 168 | for name, child in net.named_children(): 169 | # print(name, child) 170 | for param in child.parameters(): 171 | param.requires_grad = False 172 | 173 | 174 | def load_state_dict(model_path, module=True): 175 | """ 176 | Usage: 177 | model=Model() 178 | state_dict = torch_tools.load_state_dict(model_path, module=False) 179 | model.load_state_dict(state_dict) 180 | :param model_path: 181 | :param module: 182 | :return: 183 | """ 184 | state_dict = None 185 | if model_path: 186 | print('=> loading model from {}'.format(model_path)) 187 | state_dict = torch.load(model_path, map_location=torch.device('cpu')) 188 | if module: 189 | state_dict = get_module_state_dict(state_dict) 190 | else: 191 | raise Exception("Error:no model file:{}".format(model_path)) 192 | return state_dict 193 | 194 | 195 | def get_module_state_dict(state_dict): 196 | """ 197 | :param state_dict: 198 | :return: 199 | """ 200 | # 初始化一个空 dict 201 | new_state_dict = OrderedDict() 202 | # 修改 key,没有module字段则需要不del,如果有,则需要修改为 module.features 203 | for k, v in state_dict.items(): 204 | if k.startswith("module."): 205 | # k = k.replace('module.', '') 206 | k = k[len("module."):] 207 | new_state_dict[k] = v 208 | return new_state_dict 209 | 210 | 211 | def load_pretrained_model(model, ckpt): 212 | """Loads pretrianed weights to model. 
213 | Features:只会加载完全匹配的模型参数,不匹配的模型将会忽略 214 | - Incompatible layers (unmatched in name or size) will be ignored. 215 | - Can automatically deal with keys containing "module.". 216 | Args: 217 | model (nn.Module): network model. 218 | ckpt (str): OrderedDict or model file 219 | """ 220 | if isinstance(ckpt, str): 221 | checkpoint = load_state_dict(ckpt) 222 | if 'state_dict' in checkpoint: 223 | state_dict = checkpoint['state_dict'] 224 | else: 225 | state_dict = checkpoint 226 | elif isinstance(ckpt, OrderedDict): 227 | state_dict = ckpt 228 | else: 229 | raise Exception("nonsupport type:{} ".format(ckpt)) 230 | model_dict = model.state_dict() 231 | new_state_dict = OrderedDict() 232 | matched_layers, discarded_layers = [], [] 233 | for k, v in state_dict.items(): 234 | if k.startswith('module.'): 235 | k = k[7:] # discard module. 236 | if k in model_dict and model_dict[k].size() == v.size(): 237 | new_state_dict[k] = v 238 | matched_layers.append(k) 239 | else: 240 | discarded_layers.append(k) 241 | model_dict.update(new_state_dict) 242 | model.load_state_dict(model_dict) 243 | print("=" * 60) 244 | if len(matched_layers) == 0: 245 | raise Exception('The model checkpoint cannot be loaded,' 246 | 'please check the key names manually') 247 | else: 248 | print('Successfully loaded model checkpoint') 249 | # [print('{}'.format(layer)) for layer in matched_layers] 250 | if len(discarded_layers) > 0: 251 | print('The following layers are discarded due to unmatched keys or layer size') 252 | [print('{}'.format(layer)) for layer in discarded_layers] 253 | 254 | print("=" * 60) 255 | return model 256 | 257 | 258 | def plot_model(model, output=None, input_shape=None): 259 | """ 260 | Usage: 261 | output = model(inputs) 262 | vis_graph = make_dot(output, params=dict(model.named_parameters())) 263 | vis_graph.view() 264 | ================================================================= 265 | output/input_shape至少已知一个 266 | :param model: 267 | :param output: 268 | :param 
input_shape: (batch_size, 3, input_size[0], input_size[1]) 269 | :return: 270 | """ 271 | from torchviz import make_dot 272 | if output is None: 273 | output = model_forward(model, input_shape, device="cpu") 274 | vis_graph = make_dot(output, params=dict(model.named_parameters())) 275 | vis_graph.view() 276 | 277 | 278 | def model_forward(model, input_shape, device="cpu"): 279 | """ 280 | input_shape=(batch_size, 3, input_size[0], input_size[1]) 281 | :param model: 282 | :param input_shape: 283 | :param device: 284 | :return: 285 | """ 286 | inputs = torch.randn(size=input_shape) 287 | inputs = inputs.to(device) 288 | model = model.to(device) 289 | model.eval() 290 | output = model(inputs) 291 | return output 292 | 293 | 294 | def summary_model(model, batch_size=1, input_size=[112, 112], plot=False, device="cpu"): 295 | """ 296 | ----This tools can show---- 297 | Total params: 359,592 298 | Total memory: 47.32MB 299 | Total MAdd: 297.37MMAdd 300 | Total Flops: 153.31MFlops 301 | Total MemR+W: 99.7MB 302 | ==================================================== 303 | https://www.cnblogs.com/xuanyuyt/p/12653041.html 304 | Total number of network parameters (params) 305 | Theoretical amount of floating point arithmetics (FLOPs) 306 | Theoretical amount of multiply-adds (MAdd MACC) (乘加运算) 307 | Memory usage (memory) 308 | MACCs:是multiply-accumulate operations,指点积运算, 一个 macc = 2FLOPs 309 | FLOPs 的全称是 floating points of operations,即浮点运算次数,用来衡量模型的计算复杂度。 310 | 计算 FLOPs 实际上是计算模型中乘法和加法的运算次数。 311 | 卷积层的浮点运算次数不仅取决于卷积核的大小和输入输出通道数,还取决于特征图的大小; 312 | 而全连接层的浮点运算次数和参数量是相同的。 313 | ==================================================== 314 | :param model: 315 | :param batch_size: 316 | :param input_size: 317 | :param plot: plot model 318 | :param device: 319 | :return: 320 | """ 321 | from torchsummary import summary 322 | from torchstat import stat 323 | inputs = torch.randn(size=(batch_size, 3, input_size[1], input_size[0])) 324 | inputs = inputs.to(device) 325 | model = model.to(device) 
326 | model.eval() 327 | output = model(inputs) 328 | # 统计模型参数 329 | summary(model, input_size=(3, input_size[1], input_size[0]), batch_size=batch_size, device=device) 330 | # 统计模型参数和计算FLOPs 331 | stat(model, (3, input_size[1], input_size[0])) 332 | # summary可能报错,可使用该方法 333 | # summary_v2(model, inputs, item_length=26, verbose=True) 334 | # from thop import profile 335 | # macs, params = profile(model, inputs=(inputs,)) 336 | # print("Total Flops :{}".format(macs)) 337 | # print("Total params:{}".format(params)) 338 | print("===" * 10) 339 | print("inputs.shape:{}".format(inputs.shape)) 340 | # print("output.shape:{}".format(output.shape)) 341 | if plot: 342 | plot_model(model, output) 343 | 344 | 345 | def torchinfo_summary(model, batch_size=1, input_size=[112, 112], plot=False, device="cpu"): 346 | """ 347 | ----This tools can show---- 348 | Total params: 359,592 349 | Total memory: 47.32MB 350 | Total MAdd: 297.37MMAdd 351 | Total Flops: 153.31MFlops 352 | Total MemR+W: 99.7MB 353 | ==================================================== 354 | https://www.cnblogs.com/xuanyuyt/p/12653041.html 355 | Total number of network parameters (params) 356 | Theoretical amount of floating point arithmetics (FLOPs) 357 | Theoretical amount of multiply-adds (MAdd MACC) (乘加运算) 358 | Memory usage (memory) 359 | MACCs:是multiply-accumulate operations,指点积运算, 一个 macc = 2FLOPs 360 | FLOPs 的全称是 floating points of operations,即浮点运算次数,用来衡量模型的计算复杂度。 361 | 计算 FLOPs 实际上是计算模型中乘法和加法的运算次数。 362 | 卷积层的浮点运算次数不仅取决于卷积核的大小和输入输出通道数,还取决于特征图的大小; 363 | 而全连接层的浮点运算次数和参数量是相同的。 364 | ==================================================== 365 | :param model: 366 | :param batch_size: 367 | :param input_size: 368 | :param plot: plot model 369 | :param device: 370 | :return: 371 | """ 372 | from torchinfo import summary 373 | from torchstat import stat 374 | inputs = torch.randn(size=(batch_size, 3, input_size[1], input_size[0])) 375 | inputs = inputs.to(device) 376 | model = model.to(device) 377 | model.eval() 378 
| output = model(inputs) 379 | # 统计模型参数 380 | summary(model, input_size=(batch_size, 3, input_size[1], input_size[0]), device=device) 381 | # 统计模型参数和计算FLOPs 382 | stat(model, (3, input_size[1], input_size[0])) 383 | # summary可能报错,可使用该方法 384 | # summary_v2(model, inputs, item_length=26, verbose=True) 385 | # from thop import profile 386 | # macs, params = profile(model, inputs=(inputs,)) 387 | # print("Total Flops :{}".format(macs)) 388 | # print("Total params:{}".format(params)) 389 | print("===" * 10) 390 | print("inputs.shape:{}".format(inputs.shape)) 391 | # print("output.shape:{}".format(output.shape)) 392 | if plot: 393 | plot_model(model, output) 394 | -------------------------------------------------------------------------------- /audio/utils/utility.py: -------------------------------------------------------------------------------- 1 | import distutils.util 2 | 3 | 4 | def print_arguments(args): 5 | print("----------- Configuration Arguments -----------") 6 | for arg, value in sorted(vars(args).items()): 7 | print("%s: %s" % (arg, value)) 8 | print("------------------------------------------------") 9 | 10 | 11 | def add_arguments(argname, type, default, help, argparser, **kwargs): 12 | type = distutils.util.strtobool if type == bool else type 13 | argparser.add_argument("--" + argname, 14 | default=default, 15 | type=type, 16 | help=help + ' 默认: %(default)s.', 17 | **kwargs) 18 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/.DS_Store -------------------------------------------------------------------------------- /data/UrbanSound8K/README.txt: -------------------------------------------------------------------------------- 1 | UrbanSound8K 2 | ============ 3 | 4 | Created By 5 | 
---------- 6 | 7 | Justin Salamon*^, Christopher Jacoby* and Juan Pablo Bello* 8 | * Music and Audio Research Lab (MARL), New York University, USA 9 | ^ Center for Urban Science and Progress (CUSP), New York University, USA 10 | http://serv.cusp.nyu.edu/projects/urbansounddataset 11 | http://marl.smusic.nyu.edu/ 12 | http://cusp.nyu.edu/ 13 | 14 | Version 1.0 15 | 16 | 17 | Description 18 | ----------- 19 | 20 | This dataset contains 8732 labeled sound excerpts (<=4s) of urban sounds from 10 classes: air_conditioner, car_horn, 21 | children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, and street_music. The classes are 22 | drawn from the urban sound taxonomy described in the following article, which also includes a detailed description of 23 | the dataset and how it was compiled: 24 | 25 | J. Salamon, C. Jacoby and J. P. Bello, "A Dataset and Taxonomy for Urban Sound Research", 26 | 22nd ACM International Conference on Multimedia, Orlando USA, Nov. 2014. 27 | 28 | All excerpts are taken from field recordings uploaded to www.freesound.org. The files are pre-sorted into ten folds 29 | (folders named fold1-fold10) to help in the reproduction of and comparison with the automatic classification results 30 | reported in the article above. 31 | 32 | In addition to the sound excerpts, a CSV file containing metadata about each excerpt is also provided. 33 | 34 | 35 | Audio Files Included 36 | -------------------- 37 | 38 | 8732 audio files of urban sounds (see description above) in WAV format. The sampling rate, bit depth, and number of 39 | channels are the same as those of the original file uploaded to Freesound (and hence may vary from file to file). 40 | 41 | 42 | Meta-data Files Included 43 | ------------------------ 44 | 45 | UrbanSound8K.csv 46 | 47 | This file contains meta-data information about every audio file in the dataset. This includes: 48 | 49 | * slice_file_name: 50 | The name of the audio file.
The name takes the following format: [fsID]-[classID]-[occurrenceID]-[sliceID].wav, where: 51 | [fsID] = the Freesound ID of the recording from which this excerpt (slice) is taken 52 | [classID] = a numeric identifier of the sound class (see description of classID below for further details) 53 | [occurrenceID] = a numeric identifier to distinguish different occurrences of the sound within the original recording 54 | [sliceID] = a numeric identifier to distinguish different slices taken from the same occurrence 55 | 56 | * fsID: 57 | The Freesound ID of the recording from which this excerpt (slice) is taken 58 | 59 | * start: 60 | The start time of the slice in the original Freesound recording 61 | 62 | * end: 63 | The end time of the slice in the original Freesound recording 64 | 65 | * salience: 66 | A (subjective) salience rating of the sound. 1 = foreground, 2 = background. 67 | 68 | * fold: 69 | The fold number (1-10) to which this file has been allocated. 70 | 71 | * classID: 72 | A numeric identifier of the sound class: 73 | 0 = air_conditioner 74 | 1 = car_horn 75 | 2 = children_playing 76 | 3 = dog_bark 77 | 4 = drilling 78 | 5 = engine_idling 79 | 6 = gun_shot 80 | 7 = jackhammer 81 | 8 = siren 82 | 9 = street_music 83 | 84 | * class: 85 | The class name: air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, 86 | siren, street_music. 87 | 88 | 89 | Please Acknowledge UrbanSound8K in Academic Research 90 | ---------------------------------------------------- 91 | 92 | When UrbanSound8K is used for academic research, we would highly appreciate it if scientific publications of works 93 | partly based on the UrbanSound8K dataset cite the following publication: 94 | 95 | J. Salamon, C. Jacoby and J. P. Bello, "A Dataset and Taxonomy for Urban Sound Research", 96 | 22nd ACM International Conference on Multimedia, Orlando USA, Nov. 2014.
97 | 98 | The creation of this dataset was supported by a seed grant by NYU's Center for Urban Science and Progress (CUSP). 99 | 100 | 101 | Conditions of Use 102 | ----------------- 103 | 104 | Dataset compiled by Justin Salamon, Christopher Jacoby and Juan Pablo Bello. All files are excerpts of recordings 105 | uploaded to www.freesound.org. Please see FREESOUNDCREDITS.txt for an attribution list. 106 | 107 | The UrbanSound8K dataset is offered free of charge for non-commercial use only under the terms of the Creative Commons 108 | Attribution Noncommercial License (by-nc), version 3.0: http://creativecommons.org/licenses/by-nc/3.0/ 109 | 110 | The dataset and its contents are made available on an "as is" basis and without warranties of any kind, including 111 | without limitation satisfactory quality and conformity, merchantability, fitness for a particular purpose, accuracy or 112 | completeness, or absence of errors. Subject to any liability that may not be excluded or limited by law, NYU is not 113 | liable for, and expressly excludes, all liability for loss or damage however and whenever caused to anyone by any use of 114 | the UrbanSound8K dataset or any part of it. 115 | 116 | 117 | Feedback 118 | -------- 119 | 120 | Please help us improve UrbanSound8K by sending your feedback to: justin.salamon@nyu.edu or justin.salamon@gmail.com 121 | In case of a problem report please include as many details as possible. 
122 | -------------------------------------------------------------------------------- /data/UrbanSound8K/class_name.txt: -------------------------------------------------------------------------------- 1 | air_conditioner 2 | car_horn 3 | children_playing 4 | dog_bark 5 | drilling 6 | engine_idling 7 | gun_shot 8 | jackhammer 9 | siren 10 | street_music 11 | -------------------------------------------------------------------------------- /data/UrbanSound8K/test.txt: -------------------------------------------------------------------------------- 1 | fold5/139665-9-0-8.wav,street_music 2 | fold1/193698-2-0-114.wav,children_playing 3 | fold9/189988-0-0-2.wav,air_conditioner 4 | fold4/144351-4-3-15.wav,drilling 5 | fold6/94632-5-0-15.wav,engine_idling 6 | fold10/138017-9-1-6.wav,street_music 7 | fold7/155299-3-1-1.wav,dog_bark 8 | fold9/157866-8-0-11.wav,siren 9 | fold1/138031-2-0-7.wav,children_playing 10 | fold4/61790-9-0-15.wav,street_music 11 | fold9/157866-8-0-15.wav,siren 12 | fold3/166101-5-0-3.wav,engine_idling 13 | fold4/24347-8-0-91.wav,siren 14 | fold10/74364-8-1-20.wav,siren 15 | fold4/39968-9-0-81.wav,street_music 16 | fold9/58937-4-2-9.wav,drilling 17 | fold3/30204-0-0-11.wav,air_conditioner 18 | fold5/17578-5-0-0.wav,engine_idling 19 | fold9/145390-9-0-7.wav,street_music 20 | fold1/180937-7-3-9.wav,jackhammer 21 | fold2/14387-9-0-11.wav,street_music 22 | fold9/137815-4-0-4.wav,drilling 23 | fold5/34771-3-0-5.wav,dog_bark 24 | fold8/76266-2-0-7.wav,children_playing 25 | fold6/115243-9-0-0.wav,street_music 26 | fold2/123688-8-0-13.wav,siren 27 | fold5/115239-9-0-0.wav,street_music 28 | fold6/188004-8-0-0.wav,siren 29 | fold2/201652-5-3-1.wav,engine_idling 30 | fold1/180937-7-4-9.wav,jackhammer 31 | fold3/13230-0-0-12.wav,air_conditioner 32 | fold5/90013-7-0-6.wav,jackhammer 33 | fold8/71177-8-1-4.wav,siren 34 | fold6/63724-0-0-15.wav,air_conditioner 35 | fold8/17009-2-0-10.wav,children_playing 36 | fold1/98223-7-2-0.wav,jackhammer 37 | 
fold10/117889-9-0-30.wav,street_music 38 | fold7/61503-2-0-5.wav,children_playing 39 | fold9/103249-5-0-2.wav,engine_idling 40 | fold8/180134-4-2-15.wav,drilling 41 | fold8/125678-7-2-0.wav,jackhammer 42 | fold8/177726-0-0-30.wav,air_conditioner 43 | fold10/129750-2-0-4.wav,children_playing 44 | fold10/165166-8-0-2.wav,siren 45 | fold7/127443-4-0-7.wav,drilling 46 | fold5/62566-5-0-5.wav,engine_idling 47 | fold3/42937-4-0-1.wav,drilling 48 | fold3/42117-8-0-14.wav,siren 49 | fold4/151005-4-1-0.wav,drilling 50 | fold3/66622-4-0-8.wav,drilling 51 | fold6/116423-2-0-2.wav,children_playing 52 | fold1/40722-8-0-1.wav,siren 53 | fold9/101729-0-0-23.wav,air_conditioner 54 | fold9/79089-0-0-106.wav,air_conditioner 55 | fold7/135527-6-14-10.wav,gun_shot 56 | fold6/189986-0-0-0.wav,air_conditioner 57 | fold10/93567-8-0-8.wav,siren 58 | fold8/95549-3-0-7.wav,dog_bark 59 | fold3/103199-4-2-6.wav,drilling 60 | fold8/72015-2-0-0.wav,children_playing 61 | fold10/180127-4-0-1.wav,drilling 62 | fold5/71171-4-0-1.wav,drilling 63 | fold10/136558-9-0-1.wav,street_music 64 | fold6/58005-4-0-76.wav,drilling 65 | fold4/24347-8-0-31.wav,siren 66 | fold6/155127-9-1-23.wav,street_music 67 | fold6/128891-3-0-4.wav,dog_bark 68 | fold3/13230-0-0-5.wav,air_conditioner 69 | fold6/97331-2-0-20.wav,children_playing 70 | fold2/159747-8-0-7.wav,siren 71 | fold2/204773-3-7-1.wav,dog_bark 72 | fold6/34952-8-0-1.wav,siren 73 | fold8/113202-5-0-27.wav,engine_idling 74 | fold6/58005-4-0-40.wav,drilling 75 | fold4/109711-3-2-4.wav,dog_bark 76 | fold6/193697-2-0-99.wav,children_playing 77 | fold1/46669-4-0-35.wav,drilling 78 | fold2/96475-9-0-0.wav,street_music 79 | fold5/178686-0-0-68.wav,air_conditioner 80 | fold7/146845-0-0-15.wav,air_conditioner 81 | fold2/169098-7-4-0.wav,jackhammer 82 | fold7/201988-5-0-6.wav,engine_idling 83 | fold8/36429-2-0-18.wav,children_playing 84 | fold5/13577-3-0-2.wav,dog_bark 85 | fold10/189982-0-0-30.wav,air_conditioner 86 | fold5/180125-4-2-15.wav,drilling 87 | 
fold1/46669-4-0-45.wav,drilling 88 | fold4/24347-8-0-61.wav,siren 89 | fold9/165567-3-0-0.wav,dog_bark 90 | fold2/74507-0-0-3.wav,air_conditioner 91 | fold4/170564-2-1-27.wav,children_playing 92 | fold6/204240-0-0-26.wav,air_conditioner 93 | fold8/31325-3-1-0.wav,dog_bark 94 | fold3/186334-2-0-37.wav,children_playing 95 | fold3/165039-7-17-1.wav,jackhammer 96 | fold2/146690-0-0-64.wav,air_conditioner 97 | fold4/192382-2-0-66.wav,children_playing 98 | fold9/119449-5-0-2.wav,engine_idling 99 | fold8/157868-8-0-16.wav,siren 100 | fold1/196400-6-0-0.wav,gun_shot 101 | fold6/30206-7-0-11.wav,jackhammer 102 | fold3/166101-5-0-1.wav,engine_idling 103 | fold8/66324-9-0-30.wav,street_music 104 | fold3/62837-7-1-27.wav,jackhammer 105 | fold6/82368-2-0-1.wav,children_playing 106 | fold7/201988-5-0-20.wav,engine_idling 107 | fold6/46299-2-0-22.wav,children_playing 108 | fold4/24347-8-0-67.wav,siren 109 | fold10/74922-4-0-5.wav,drilling 110 | fold2/203929-7-2-3.wav,jackhammer 111 | fold7/74513-3-0-0.wav,dog_bark 112 | fold9/58937-4-5-2.wav,drilling 113 | fold3/172315-9-0-211.wav,street_music 114 | fold6/132021-7-0-3.wav,jackhammer 115 | fold9/105029-7-0-6.wav,jackhammer 116 | fold3/165039-7-4-1.wav,jackhammer 117 | fold4/7389-1-0-7.wav,car_horn 118 | fold2/189023-0-0-11.wav,air_conditioner 119 | fold9/149929-9-1-1.wav,street_music 120 | fold9/39856-5-0-19.wav,engine_idling 121 | fold10/102857-5-0-2.wav,engine_idling 122 | fold2/74507-0-0-6.wav,air_conditioner 123 | fold4/174032-2-0-18.wav,children_playing 124 | fold3/22601-8-0-20.wav,siren 125 | fold2/146690-0-0-140.wav,air_conditioner 126 | fold5/178497-3-0-3.wav,dog_bark 127 | fold1/159738-8-0-18.wav,siren 128 | fold10/189982-0-0-4.wav,air_conditioner 129 | fold3/195451-5-0-11.wav,engine_idling 130 | fold8/74677-0-0-19.wav,air_conditioner 131 | fold2/201652-5-5-0.wav,engine_idling 132 | fold5/159439-2-0-21.wav,children_playing 133 | fold2/74507-0-0-20.wav,air_conditioner 134 | fold4/128160-5-0-14.wav,engine_idling 135 | 
fold8/54383-0-0-8.wav,air_conditioner 136 | fold10/171478-9-0-4.wav,street_music 137 | fold7/115411-3-2-0.wav,dog_bark 138 | fold6/52882-2-0-7.wav,children_playing 139 | fold2/169098-7-4-3.wav,jackhammer 140 | fold2/201652-5-0-1.wav,engine_idling 141 | fold9/159735-2-0-80.wav,children_playing 142 | fold9/60935-2-0-0.wav,children_playing 143 | fold3/30204-0-0-6.wav,air_conditioner 144 | fold1/159738-8-0-15.wav,siren 145 | fold2/102871-8-0-6.wav,siren 146 | fold4/116484-3-0-16.wav,dog_bark 147 | fold3/117048-3-0-25.wav,dog_bark 148 | fold4/128160-5-0-5.wav,engine_idling 149 | fold6/204240-0-0-5.wav,air_conditioner 150 | fold3/22601-8-0-39.wav,siren 151 | fold9/52740-3-0-1.wav,dog_bark 152 | fold6/204240-0-0-29.wav,air_conditioner 153 | fold4/24347-8-0-51.wav,siren 154 | fold1/134717-0-0-20.wav,air_conditioner 155 | fold7/177537-7-0-13.wav,jackhammer 156 | fold8/205610-4-0-5.wav,drilling 157 | fold2/39970-9-0-54.wav,street_music 158 | fold1/99180-9-0-0.wav,street_music 159 | fold5/71171-4-1-2.wav,drilling 160 | fold4/135528-6-4-1.wav,gun_shot 161 | fold2/149370-9-0-22.wav,street_music 162 | fold4/16692-5-0-5.wav,engine_idling 163 | fold10/178261-7-3-6.wav,jackhammer 164 | fold4/166942-0-0-5.wav,air_conditioner 165 | fold4/137971-2-0-4.wav,children_playing 166 | fold8/17009-2-0-1.wav,children_playing 167 | fold6/197075-3-6-0.wav,dog_bark 168 | fold5/72259-1-7-4.wav,car_horn 169 | fold5/180125-4-1-5.wav,drilling 170 | fold10/188813-7-10-2.wav,jackhammer 171 | fold9/180937-4-2-1.wav,drilling 172 | fold4/74950-3-2-5.wav,dog_bark 173 | fold3/66622-4-0-6.wav,drilling 174 | fold7/155238-2-0-17.wav,children_playing 175 | fold6/184805-0-0-48.wav,air_conditioner 176 | fold4/22883-7-49-1.wav,jackhammer 177 | fold10/196084-2-0-1.wav,children_playing 178 | fold10/73524-0-0-30.wav,air_conditioner 179 | fold10/77901-9-0-0.wav,street_music 180 | fold5/178260-7-1-10.wav,jackhammer 181 | fold3/165039-7-1-0.wav,jackhammer 182 | fold3/172315-9-0-212.wav,street_music 183 | 
fold6/85249-2-0-10.wav,children_playing 184 | fold6/132021-7-0-4.wav,jackhammer 185 | fold2/182739-2-0-78.wav,children_playing 186 | fold9/57105-3-1-0.wav,dog_bark 187 | fold1/176787-5-0-25.wav,engine_idling 188 | fold9/58937-4-4-1.wav,drilling 189 | fold5/20015-3-0-12.wav,dog_bark 190 | fold5/104998-7-7-0.wav,jackhammer 191 | fold5/139948-3-0-0.wav,dog_bark 192 | fold7/84143-2-0-15.wav,children_playing 193 | fold3/199769-1-0-17.wav,car_horn 194 | fold1/176787-5-0-15.wav,engine_idling 195 | fold6/52882-2-0-8.wav,children_playing 196 | fold4/55728-9-0-14.wav,street_music 197 | fold3/22601-8-0-14.wav,siren 198 | fold2/74507-0-0-24.wav,air_conditioner 199 | fold6/131918-7-0-7.wav,jackhammer 200 | fold1/177621-0-0-46.wav,air_conditioner 201 | fold3/196083-2-0-0.wav,children_playing 202 | fold2/106015-5-0-15.wav,engine_idling 203 | fold3/121528-8-1-1.wav,siren 204 | fold9/180029-4-4-0.wav,drilling 205 | fold4/132108-9-0-12.wav,street_music 206 | fold3/94636-8-0-19.wav,siren 207 | fold5/13577-3-0-0.wav,dog_bark 208 | fold3/117048-3-0-17.wav,dog_bark 209 | fold3/186334-2-0-44.wav,children_playing 210 | fold4/169466-4-1-3.wav,drilling 211 | fold10/188497-2-0-0.wav,children_playing 212 | fold10/22973-3-0-0.wav,dog_bark 213 | fold8/133090-2-0-76.wav,children_playing 214 | fold2/102871-8-0-7.wav,siren 215 | fold6/62564-5-0-8.wav,engine_idling 216 | fold2/74507-0-0-11.wav,air_conditioner 217 | fold5/17578-5-0-12.wav,engine_idling 218 | fold7/175296-2-0-123.wav,children_playing 219 | fold4/128607-4-1-3.wav,drilling 220 | fold4/7389-1-4-5.wav,car_horn 221 | fold7/135527-6-14-4.wav,gun_shot 222 | fold7/105289-8-0-1.wav,siren 223 | fold2/49808-3-0-6.wav,dog_bark 224 | fold6/58005-4-0-63.wav,drilling 225 | fold3/17853-5-0-13.wav,engine_idling 226 | fold7/127443-4-0-3.wav,drilling 227 | fold2/189023-0-0-5.wav,air_conditioner 228 | fold6/155127-9-1-25.wav,street_music 229 | fold3/107228-5-0-7.wav,engine_idling 230 | fold7/104625-4-0-27.wav,drilling 231 | 
fold4/55728-9-0-8.wav,street_music 232 | fold6/71088-4-0-0.wav,drilling 233 | fold4/47019-2-0-65.wav,children_playing 234 | fold10/73524-0-0-99.wav,air_conditioner 235 | fold1/180256-3-0-2.wav,dog_bark 236 | fold2/196384-9-0-15.wav,street_music 237 | fold5/71173-2-0-24.wav,children_playing 238 | fold3/63095-4-0-14.wav,drilling 239 | fold5/104998-7-9-0.wav,jackhammer 240 | fold5/104998-7-2-4.wav,jackhammer 241 | fold2/102871-8-0-10.wav,siren 242 | fold8/193699-2-0-33.wav,children_playing 243 | fold4/35549-9-0-58.wav,street_music 244 | fold2/174994-3-0-2.wav,dog_bark 245 | fold1/134717-0-0-21.wav,air_conditioner 246 | fold3/177742-0-0-149.wav,air_conditioner 247 | fold8/16860-9-0-28.wav,street_music 248 | fold3/62837-7-1-15.wav,jackhammer 249 | fold1/57320-0-0-6.wav,air_conditioner 250 | fold3/62837-7-0-37.wav,jackhammer 251 | fold2/155219-2-0-51.wav,children_playing 252 | fold2/76086-4-0-32.wav,drilling 253 | fold2/49808-3-1-22.wav,dog_bark 254 | fold1/180937-7-4-3.wav,jackhammer 255 | fold3/94636-8-0-0.wav,siren 256 | fold3/177742-0-0-36.wav,air_conditioner 257 | fold6/14358-3-0-85.wav,dog_bark 258 | fold8/72015-2-0-5.wav,children_playing 259 | fold6/58005-4-0-24.wav,drilling 260 | fold5/155243-9-0-55.wav,street_music 261 | fold9/62567-5-0-7.wav,engine_idling 262 | fold10/88121-8-0-3.wav,siren 263 | fold2/189991-0-0-6.wav,air_conditioner 264 | fold6/30206-7-0-26.wav,jackhammer 265 | fold9/149255-9-0-3.wav,street_music 266 | fold3/33696-3-6-1.wav,dog_bark 267 | fold6/121285-0-0-5.wav,air_conditioner 268 | fold5/104998-7-18-13.wav,jackhammer 269 | fold10/189982-0-0-39.wav,air_conditioner 270 | fold2/76086-4-0-27.wav,drilling 271 | fold7/177729-0-0-53.wav,air_conditioner 272 | fold10/118278-4-0-1.wav,drilling 273 | fold1/57320-0-0-15.wav,air_conditioner 274 | fold4/41918-3-0-1.wav,dog_bark 275 | fold4/195969-0-0-6.wav,air_conditioner 276 | fold4/144351-4-0-2.wav,drilling 277 | fold1/113205-5-0-0.wav,engine_idling 278 | fold8/54383-0-0-5.wav,air_conditioner 279 | 
fold4/55018-0-0-87.wav,air_conditioner 280 | fold2/77751-4-6-0.wav,drilling 281 | fold1/103074-7-3-1.wav,jackhammer 282 | fold3/62837-7-1-80.wav,jackhammer 283 | fold9/119449-5-0-0.wav,engine_idling 284 | fold8/30226-3-1-3.wav,dog_bark 285 | fold3/61791-9-1-42.wav,street_music 286 | fold1/105415-2-0-24.wav,children_playing 287 | fold5/180128-4-14-0.wav,drilling 288 | fold10/180127-4-0-20.wav,drilling 289 | fold3/117072-3-0-8.wav,dog_bark 290 | fold1/103074-7-4-1.wav,jackhammer 291 | fold8/7390-9-1-12.wav,street_music 292 | fold5/128152-9-0-49.wav,street_music 293 | fold7/146845-0-0-6.wav,air_conditioner 294 | fold3/19496-3-0-0.wav,dog_bark 295 | fold5/23219-5-0-8.wav,engine_idling 296 | fold5/194910-9-0-65.wav,street_music 297 | fold9/207211-2-0-58.wav,children_playing 298 | fold7/181102-9-0-23.wav,street_music 299 | fold9/149929-9-0-0.wav,street_music 300 | fold1/157867-8-0-23.wav,siren 301 | fold5/178260-7-0-0.wav,jackhammer 302 | fold10/103438-5-0-1.wav,engine_idling 303 | fold4/144007-5-1-6.wav,engine_idling 304 | fold1/180937-7-3-18.wav,jackhammer 305 | fold7/57596-3-1-0.wav,dog_bark 306 | fold4/7389-1-3-1.wav,car_horn 307 | fold7/101848-9-0-8.wav,street_music 308 | fold1/180256-3-0-3.wav,dog_bark 309 | fold6/128465-1-0-6.wav,car_horn 310 | fold1/73277-9-0-24.wav,street_music 311 | fold4/131428-9-1-0.wav,street_music 312 | fold5/156634-5-0-10.wav,engine_idling 313 | fold7/105289-8-0-3.wav,siren 314 | fold8/7390-9-0-9.wav,street_music 315 | fold3/30204-0-0-2.wav,air_conditioner 316 | fold8/160016-2-0-5.wav,children_playing 317 | fold9/79584-3-0-8.wav,dog_bark 318 | fold4/7389-1-0-3.wav,car_horn 319 | fold2/97193-3-0-0.wav,dog_bark 320 | fold9/103249-5-0-13.wav,engine_idling 321 | fold5/104998-7-19-3.wav,jackhammer 322 | fold7/177537-7-0-10.wav,jackhammer 323 | fold6/62564-5-0-1.wav,engine_idling 324 | fold4/159751-8-0-5.wav,siren 325 | fold9/39856-5-0-26.wav,engine_idling 326 | fold5/72259-1-10-1.wav,car_horn 327 | fold5/109263-9-0-61.wav,street_music 328 | 
fold6/85249-2-0-61.wav,children_playing 329 | fold9/180937-4-1-54.wav,drilling 330 | fold4/192382-2-0-36.wav,children_playing 331 | fold2/146690-0-0-125.wav,air_conditioner 332 | fold9/105029-7-2-2.wav,jackhammer 333 | fold4/151005-4-2-1.wav,drilling 334 | fold3/186339-9-0-1.wav,street_music 335 | fold3/186339-9-0-18.wav,street_music 336 | fold10/147491-9-0-5.wav,street_music 337 | fold5/178825-2-0-53.wav,children_playing 338 | fold10/99192-4-0-15.wav,drilling 339 | fold7/50454-5-0-0.wav,engine_idling 340 | fold6/94632-5-1-19.wav,engine_idling 341 | fold10/202334-9-0-63.wav,street_music 342 | fold3/195451-5-0-8.wav,engine_idling 343 | fold5/104421-2-0-15.wav,children_playing 344 | fold10/155262-2-0-15.wav,children_playing 345 | fold8/117181-8-0-4.wav,siren 346 | fold9/159735-2-0-121.wav,children_playing 347 | fold9/171406-9-0-183.wav,street_music 348 | fold10/83502-0-0-6.wav,air_conditioner 349 | fold3/37560-4-0-3.wav,drilling 350 | fold9/157866-8-0-2.wav,siren 351 | fold8/68080-7-0-9.wav,jackhammer 352 | fold10/102857-5-0-0.wav,engine_idling 353 | fold2/168713-9-0-82.wav,street_music 354 | fold10/162134-7-11-4.wav,jackhammer 355 | fold2/123688-8-0-2.wav,siren 356 | fold10/164377-9-1-43.wav,street_music 357 | fold9/180937-4-0-20.wav,drilling 358 | fold1/103074-7-0-0.wav,jackhammer 359 | fold2/109703-2-0-50.wav,children_playing 360 | fold7/49313-2-0-30.wav,children_playing 361 | fold7/57323-8-0-9.wav,siren 362 | fold8/162103-0-0-7.wav,air_conditioner 363 | fold4/146709-0-0-44.wav,air_conditioner 364 | fold5/196085-2-0-3.wav,children_playing 365 | fold10/129750-2-0-3.wav,children_playing 366 | fold10/100795-3-1-2.wav,dog_bark 367 | fold4/194458-9-1-91.wav,street_music 368 | fold2/168713-9-0-33.wav,street_music 369 | fold4/194458-9-0-2.wav,street_music 370 | fold3/58857-2-0-10.wav,children_playing 371 | fold5/156869-8-0-2.wav,siren 372 | fold6/74726-8-0-3.wav,siren 373 | fold2/76086-4-0-17.wav,drilling 374 | fold1/108362-2-0-12.wav,children_playing 375 | 
fold8/171243-9-0-85.wav,street_music 376 | fold10/74364-8-1-11.wav,siren 377 | fold2/203929-7-3-10.wav,jackhammer 378 | fold4/156362-4-0-3.wav,drilling 379 | fold9/96921-9-0-11.wav,street_music 380 | fold10/164194-2-0-10.wav,children_playing 381 | fold9/58937-4-1-1.wav,drilling 382 | fold1/146186-5-0-12.wav,engine_idling 383 | fold10/118558-5-2-0.wav,engine_idling 384 | fold2/147926-0-0-5.wav,air_conditioner 385 | fold10/93567-8-0-18.wav,siren 386 | fold9/159744-8-0-5.wav,siren 387 | fold4/159751-8-0-10.wav,siren 388 | fold9/52171-3-6-1.wav,dog_bark 389 | fold3/22601-8-0-29.wav,siren 390 | fold3/184725-3-0-1.wav,dog_bark 391 | fold6/184805-0-0-74.wav,air_conditioner 392 | fold6/137969-2-0-20.wav,children_playing 393 | fold5/72259-1-9-4.wav,car_horn 394 | fold4/169466-4-3-8.wav,drilling 395 | fold6/52882-2-0-4.wav,children_playing 396 | fold7/21683-9-0-15.wav,street_music 397 | fold9/105029-7-2-13.wav,jackhammer 398 | fold5/6508-9-0-6.wav,street_music 399 | fold6/30206-7-0-35.wav,jackhammer 400 | fold6/155212-9-1-75.wav,street_music 401 | fold9/39856-5-0-27.wav,engine_idling 402 | fold2/201652-5-4-7.wav,engine_idling 403 | fold4/132016-7-0-4.wav,jackhammer 404 | fold3/125523-3-0-3.wav,dog_bark 405 | fold10/73524-0-0-8.wav,air_conditioner 406 | fold1/165067-2-0-72.wav,children_playing 407 | fold3/72537-3-0-8.wav,dog_bark 408 | fold9/39856-5-0-18.wav,engine_idling 409 | fold4/47019-2-0-73.wav,children_playing 410 | fold7/202516-0-0-7.wav,air_conditioner 411 | fold4/151005-4-3-0.wav,drilling 412 | fold5/204408-2-0-2.wav,children_playing 413 | fold5/121286-0-0-11.wav,air_conditioner 414 | fold6/75490-8-1-0.wav,siren 415 | fold9/188823-7-0-4.wav,jackhammer 416 | fold4/175904-2-0-12.wav,children_playing 417 | fold3/144068-5-0-9.wav,engine_idling 418 | fold1/83199-9-0-0.wav,street_music 419 | fold6/135160-8-0-4.wav,siren 420 | fold5/189895-3-0-0.wav,dog_bark 421 | fold10/187110-2-0-34.wav,children_playing 422 | fold8/126153-9-0-8.wav,street_music 423 | 
fold6/83680-5-0-1.wav,engine_idling 424 | fold6/129356-2-0-199.wav,children_playing 425 | fold3/62837-7-0-22.wav,jackhammer 426 | fold6/124389-8-1-5.wav,siren 427 | fold9/58937-4-0-0.wav,drilling 428 | fold5/36263-9-0-12.wav,street_music 429 | fold4/159753-8-0-1.wav,siren 430 | fold10/188813-7-9-0.wav,jackhammer 431 | fold3/52077-3-0-8.wav,dog_bark 432 | fold2/152908-5-0-11.wav,engine_idling 433 | fold5/180052-3-0-1.wav,dog_bark 434 | fold4/183989-3-1-23.wav,dog_bark 435 | fold1/118101-3-0-0.wav,dog_bark 436 | fold9/39856-5-0-15.wav,engine_idling 437 | fold6/169045-2-0-3.wav,children_playing 438 | fold1/176714-2-0-26.wav,children_playing 439 | fold4/55018-0-0-209.wav,air_conditioner 440 | fold4/55728-9-0-30.wav,street_music 441 | fold8/160016-2-0-25.wav,children_playing 442 | fold6/124389-8-1-16.wav,siren 443 | fold4/81068-5-0-5.wav,engine_idling 444 | fold9/39856-5-0-0.wav,engine_idling 445 | fold4/144351-4-3-10.wav,drilling 446 | fold3/15356-2-0-2.wav,children_playing 447 | fold7/181102-9-0-26.wav,street_music 448 | fold10/180127-4-0-12.wav,drilling 449 | fold1/159738-8-0-5.wav,siren 450 | fold2/180126-4-4-1.wav,drilling 451 | fold10/102857-5-0-28.wav,engine_idling 452 | fold1/59277-0-0-4.wav,air_conditioner 453 | fold1/137156-9-0-73.wav,street_music 454 | fold7/57323-8-0-8.wav,siren 455 | fold10/180127-4-0-17.wav,drilling 456 | fold9/54187-1-0-3.wav,car_horn 457 | fold7/101848-9-0-9.wav,street_music 458 | fold1/157867-8-0-6.wav,siren 459 | fold4/7389-1-4-7.wav,car_horn 460 | fold10/83195-9-0-6.wav,street_music 461 | fold10/28808-1-0-3.wav,car_horn 462 | fold7/130961-4-5-2.wav,drilling 463 | fold6/104327-2-0-3.wav,children_playing 464 | fold7/57323-8-0-2.wav,siren 465 | fold9/62567-5-1-1.wav,engine_idling 466 | fold4/146709-0-0-33.wav,air_conditioner 467 | fold7/98525-8-0-0.wav,siren 468 | fold9/105029-7-2-11.wav,jackhammer 469 | -------------------------------------------------------------------------------- /data/audio/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/.DS_Store -------------------------------------------------------------------------------- /data/audio/air_conditioner/13230-0-0-3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/air_conditioner/13230-0-0-3.wav -------------------------------------------------------------------------------- /data/audio/air_conditioner/13230-0-0-5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/air_conditioner/13230-0-0-5.wav -------------------------------------------------------------------------------- /data/audio/car_horn/7389-1-2-3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/car_horn/7389-1-2-3.wav -------------------------------------------------------------------------------- /data/audio/car_horn/7389-1-3-0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/car_horn/7389-1-3-0.wav -------------------------------------------------------------------------------- /data/audio/dog_bark/18581-3-1-3.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/dog_bark/18581-3-1-3.wav -------------------------------------------------------------------------------- /data/audio/dog_bark/19218-3-0-0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/dog_bark/19218-3-0-0.wav -------------------------------------------------------------------------------- /data/audio/engine_idling/17592-5-1-2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/engine_idling/17592-5-1-2.wav -------------------------------------------------------------------------------- /data/audio/engine_idling/17592-5-1-3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/engine_idling/17592-5-1-3.wav -------------------------------------------------------------------------------- /data/audio/street_music/6508-9-0-3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/street_music/6508-9-0-3.wav -------------------------------------------------------------------------------- /data/audio/street_music/6508-9-0-4.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/audio/street_music/6508-9-0-4.wav -------------------------------------------------------------------------------- /data/pretrained/model_075_0.965.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/pretrained/model_075_0.965.pth -------------------------------------------------------------------------------- /data/record_audio/20211004174340.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/record_audio/20211004174340.wav -------------------------------------------------------------------------------- /data/record_audio/20211004174446.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/data/record_audio/20211004174446.wav -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | 3 | import os 4 | import cv2 5 | import argparse 6 | import librosa 7 | import torch 8 | import numpy as np 9 | from audio.dataloader.audio_dataset import load_audio, normalization 10 | from audio.dataloader.record_audio import record_audio 11 | from audio.utils import file_utils, image_utils 12 | 13 | 14 | class Predictor(object): 15 | def __init__(self, cfg): 16 | self.device = "cpu" 17 | self.class_name, self.class_dict = file_utils.parser_classes(cfg.class_name, split=None) 18 | self.input_shape =
eval(cfg.input_shape) 19 | self.spec_len = self.input_shape[3] 20 | self.model = self.build_model(cfg.model_file) 21 | 22 | def build_model(self, model_file): 23 | # 加载模型 24 | model = torch.jit.load(model_file, map_location="cpu") 25 | model.to(self.device) 26 | model.eval() 27 | return model 28 | 29 | def inference(self, input_tensors): 30 | with torch.no_grad(): 31 | input_tensors = input_tensors.to(self.device) 32 | output = self.model(input_tensors) 33 | return output 34 | 35 | def pre_process(self, spec_image): 36 | """Crop or zero-pad the spectrogram along axis 1 to spec_len frames, apply normalization() and add batch/channel axes.""" 37 | if spec_image.shape[1] > self.spec_len: 38 | input = spec_image[:, 0:self.spec_len] 39 | else: 40 | input = np.zeros(shape=(spec_image.shape[0], self.spec_len), dtype=np.float32)  # keep spec_image's own row count (presumably mel bins); only axis 1 is padded 41 | input[:, 0:spec_image.shape[1]] = spec_image 42 | input = normalization(input) 43 | input = input[np.newaxis, np.newaxis, :] 44 | input_tensors = np.concatenate([input]) 45 | input_tensors = torch.tensor(input_tensors, dtype=torch.float32) 46 | return input_tensors 47 | 48 | def post_process(self, output): 49 | """Softmax the logits along dim 1 and return (top-1 class names, top-1 probabilities), one entry per batch row.""" 50 | scores = torch.nn.functional.softmax(output, dim=1) 51 | scores = scores.data.cpu().numpy() 52 | # 显示图片并输出结果最大的label 53 | label = np.argmax(scores, axis=1) 54 | score = scores[np.arange(scores.shape[0]), label]  # one probability per row; scores[:, label] would build a BxB matrix 55 | label = [self.class_name[l] for l in label] 56 | return label, score 57 | 58 | def detect(self, audio_file): 59 | """ 60 | :param audio_file: 音频文件 61 | :return: label:预测音频的label 62 | score: 预测音频的置信度 63 | """ 64 | spec_image = load_audio(audio_file) 65 | input_tensors = self.pre_process(spec_image) 66 | # 执行预测 67 | output = self.inference(input_tensors) 68 | label, score = self.post_process(output) 69 | return label, score 70 | 71 | def detect_file_dir(self, file_dir): 72 | """ 73 | :param file_dir: 音频文件目录 74 | :return: 75 | """ 76 | file_list = file_utils.get_files_lists(file_dir, postfix=["*.wav"]) 77 | for file in file_list: 78 | print(file) 79 | label, score = self.detect(file) 80 | print("pred-label:{}, score:{}".format(label, score)) 81 |
print("---" * 20) 82 | 83 | def detect_record_audio(self, audio_dir): 84 | """ 85 | :param audio_dir: 录制音频并进行识别 86 | :return: 87 | """ 88 | time = file_utils.get_time() 89 | file = os.path.join(audio_dir, time + ".wav") 90 | record_audio(file) 91 | label, score = self.detect(file) 92 | print(file) 93 | print("pred-label:{}, score:{}".format(label, score)) 94 | print("---"*20) 95 | 96 | 97 | 98 | def get_parser(): 99 | model_file = 'data/pretrained/model_075_0.965.pth' 100 | file_dir = "data/audio" 101 | class_name = 'data/UrbanSound8K/class_name.txt' 102 | parser = argparse.ArgumentParser(description=__doc__) 103 | parser.add_argument('--class_name', type=str, default=class_name, help='类别文件') 104 | parser.add_argument('--input_shape', type=str, default='(None, 1, 128, 128)', help='数据输入的形状') 105 | parser.add_argument('--net_type', type=str, default="mbv2", help='backbone') 106 | parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID') 107 | parser.add_argument('--model_file', type=str, default=model_file, help='模型文件') 108 | parser.add_argument('--file_dir', type=str, default=file_dir, help='音频文件的目录') 109 | return parser 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = get_parser() 114 | args = parser.parse_args() 115 | p = Predictor(args) 116 | p.detect_file_dir(file_dir=args.file_dir) 117 | 118 | # 预测自己录制的数据集 119 | # audio_dir = 'data/record_audio' 120 | # p.detect_record_audio(audio_dir=audio_dir) 121 | -------------------------------------------------------------------------------- /docs/example.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | import librosa 3 | import matplotlib.pyplot as plt 4 | import os 5 | import numpy as np 6 | 7 | def plot_spectrogram(title, data): 8 | plt.imshow(data) 9 | plt.axis('on') 10 | plt.title(title) 11 | plt.show() 12 | 13 | # 采样率 14 | sampling_rate = 16000 15 | try: 16 | # 定位当前文件的绝对路径 17 | path = os.path.dirname(os.path.abspath(__file__))
18 | # 读取音频文件 19 | wav, sr = librosa.load(os.path.join(path, '../data/audio/car_horn/7389-1-2-3.wav'), sr=sampling_rate) 20 | except Exception as e: 21 | print(e) 22 | exit() 23 | 24 | # 对音频进行预处理 25 | # 将音频转换为频谱 26 | spectrogram = librosa.stft(wav) 27 | # 将频谱转换为离散的频率值 28 | spectrogram = np.abs(spectrogram) 29 | # 将频谱转换为对数值 30 | spectrogram = librosa.amplitude_to_db(spectrogram) 31 | # 将频谱转换为图像 32 | plot_spectrogram('spectrogram', spectrogram) 33 | 34 | 35 | # 梅尔频谱 36 | spec_image = librosa.feature.melspectrogram(y=wav, sr=sr) 37 | # convert the mel power spectrogram to dB 38 | spec_image = librosa.power_to_db(spec_image)  # melspectrogram returns power (power=2.0) by default; amplitude_to_db would double the dB scale 39 | # 将梅尔频谱转换为图像 40 | plot_spectrogram('mel spectrogram', spec_image) 41 | 42 | 43 | # 梅尔倒频谱:在梅尔频谱上做倒谱分析(取对数,做DCT变换)就得到了梅尔倒谱 44 | mfcc = librosa.feature.mfcc(y=wav, sr=sampling_rate, n_mfcc=20)  # y is keyword-only in librosa >= 0.10 45 | # 将梅尔倒谱转换为图像 46 | plot_spectrogram('mfcc', mfcc) 47 | -------------------------------------------------------------------------------- /drawAudio.py: -------------------------------------------------------------------------------- 1 | from re import L 2 | import librosa 3 | import numpy as np 4 | import os 5 | 6 | import sklearn.preprocessing  # import the submodule explicitly; `import sklearn` alone does not guarantee sklearn.preprocessing is loaded 7 | 8 | #获得当前文件所在路径 9 | path = os.path.dirname(os.path.abspath(__file__)) 10 | # 获得音频文件路径 11 | audio_path = os.path.join(path, 'data/audio/car_horn/7389-1-2-3.wav') 12 | 13 | assert os.path.exists(audio_path) # 断言,文件是否存在 14 | 15 | x, sr = librosa.load(audio_path) # 读取音频文件 16 | 17 | # print(type(x), type(sr)) 18 | # print(x.shape, sr) 19 | 20 | # 可视化音频 21 | import matplotlib.pyplot as plt 22 | import librosa.display 23 | 24 | plt.figure(figsize=(14, 5)) 25 | librosa.display.waveshow(x, sr=sr) 26 | plt.title('amplitude envelope') 27 | plt.savefig('picture/wave.png') 28 | 29 | 30 | 31 | # 声谱图(spectrogram)是声音或其他信号的频率随时间变化时的频谱(spectrum)的一种直观表示。 32 | # 在二维数组中,第一个轴是频率,第二个轴是时间。 33 | X = librosa.stft(x) 34 | Xdb = librosa.amplitude_to_db(abs(X)) 35 | plt.figure(figsize=(14, 5)) 36 | librosa.display.specshow(Xdb, sr=sr, x_axis='time', y_axis='hz') 37 |
plt.colorbar() 38 | plt.title("spectrogram") 39 | plt.savefig('picture/spectrogram.png') 40 | 41 | plt.figure(figsize=(14, 5)) 42 | librosa.display.specshow(Xdb, sr=sr, x_axis='time', y_axis='log') 43 | plt.colorbar() 44 | plt.title("spectrogram (log)") 45 | plt.savefig('picture/spectrogram_log.png') 46 | 47 | 48 | # 特征提取 49 | 50 | # 过零率 Zero Crossing Rate 是一个信号符号变化的比率,即,在每帧中,语音信号从正变为负或从负变为正的次数。 51 | # Zooming in 52 | n0 = 9000 53 | n1 = 9100 54 | plt.figure(figsize=(14, 5)) 55 | plt.plot(x[n0:n1]) 56 | plt.title('Zero Crossing Rate') 57 | plt.grid() 58 | plt.savefig('picture/zero_crossing_rate.png') 59 | 60 | zero_crossings = librosa.zero_crossings(x[n0:n1], pad=False) 61 | print(sum(zero_crossings)) 62 | 63 | spectral_centroids = librosa.feature.spectral_centroid(y=x, sr=sr)[0] 64 | # print(spectral_centroids.shape) 65 | # (2647,) 66 | # Computing the time variable for visualization 67 | frames = range(len(spectral_centroids)) 68 | t = librosa.frames_to_time(frames, sr=sr) 69 | # Normalising the spectral centroid for visualisation 70 | def normalize(x, axis=0): 71 | return sklearn.preprocessing.minmax_scale(x, axis=axis) 72 | #Plotting the Spectral Centroid along the waveform 73 | plt.figure(figsize=(14, 5)) 74 | # overlay the min-max normalised spectral centroid on the waveform 75 | librosa.display.waveshow(x, sr=sr, alpha=0.4) 76 | plt.plot(t, normalize(spectral_centroids), color='r') 77 | plt.title('Spectral Centroid') 78 | plt.savefig('picture/spectral_centroid.png') 79 | 80 | 81 | 82 | # 信号的Mel频率倒谱系数(MFCC)是一小组特征(通常约10-20),其简明地描述了频谱包络的整体形状,它模拟了人声的特征。 83 | mfccs = librosa.feature.mfcc(y=x, sr=sr) 84 | # print(mfccs.shape) 85 | # (20, 173) 86 | #Displaying the MFCCs: 87 | plt.figure(figsize=(14, 5)) 88 | librosa.display.specshow(mfccs, sr=sr, x_axis='time') 89 | plt.title('MFCC') 90 | plt.savefig('picture/mfcc.png') 91 | 92 | # mfcc计算了超过173帧的20个MFCC。我们还可以执行特征缩放,使得每个系数维度具有零均值和单位方差: 93 | mfccs = sklearn.preprocessing.scale(mfccs, axis=1) 94 | # print(mfccs.mean(axis=1)) 95
| # print(mfccs.var(axis=1)) 96 | plt.figure(figsize=(14, 5)) 97 | librosa.display.specshow(mfccs, sr=sr, x_axis='time') 98 | plt.title('MFCC (scaled)') 99 | plt.savefig('picture/mfcc_scaled.png') 100 | 101 | 102 | # 色度频率 Chroma Frequencies 103 | hop_length = 512 104 | chromagram = librosa.feature.chroma_stft(y=x, sr=sr, hop_length=hop_length) 105 | plt.figure(figsize=(14, 5)) 106 | librosa.display.specshow(chromagram, sr=sr, x_axis='time', y_axis='chroma', hop_length=hop_length, cmap='coolwarm') 107 | plt.title('Chroma') 108 | plt.savefig('picture/chroma.png') 109 | -------------------------------------------------------------------------------- /picture/chroma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/chroma.png -------------------------------------------------------------------------------- /picture/mfcc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/mfcc.png -------------------------------------------------------------------------------- /picture/mfcc_scaled.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/mfcc_scaled.png -------------------------------------------------------------------------------- /picture/spectral_centroid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/spectral_centroid.png
-------------------------------------------------------------------------------- /picture/spectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/spectrogram.png -------------------------------------------------------------------------------- /picture/spectrogram_log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/spectrogram_log.png -------------------------------------------------------------------------------- /picture/wave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/wave.png -------------------------------------------------------------------------------- /picture/zero_crossing_rate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/picture/zero_crossing_rate.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*-coding: utf-8 -*- 2 | 3 | import argparse 4 | import os 5 | import numpy as np 6 | import torch 7 | import tensorboardX as tensorboard 8 | from datetime import datetime 9 | from easydict import EasyDict 10 | from tqdm import tqdm 11 | from torch.utils.data import DataLoader 12 | from torch.optim.lr_scheduler import StepLR, MultiStepLR 13 | 14 | from audio.dataloader.audio_dataset import AudioDataset 15 | from 
audio.utils.utility import print_arguments 16 | from audio.utils import file_utils 17 | from audio.models import mobilenet_v2, resnet 18 | 19 | 20 | class Train(object): 21 | """Training Pipeline""" 22 | 23 | def __init__(self, cfg): 24 | cfg = EasyDict(cfg.__dict__) 25 | self.device = "cuda:{}".format(cfg.gpu_id) if torch.cuda.is_available() else "cpu" 26 | self.num_epoch = cfg.num_epoch 27 | self.net_type = cfg.net_type 28 | self.work_dir = os.path.join(cfg.work_dir, self.net_type) 29 | self.model_dir = os.path.join(self.work_dir, "model") 30 | self.log_dir = os.path.join(self.work_dir, "log") 31 | file_utils.create_dir(self.model_dir) 32 | file_utils.create_dir(self.log_dir) 33 | 34 | self.tensorboard = tensorboard.SummaryWriter(log_dir=self.log_dir) 35 | self.train_loader, self.test_loader = self.build_dataset(cfg) 36 | # 获取模型 37 | self.model = self.build_model(cfg) 38 | # 获取优化方法,分别设定学习率和权重衰减 39 | self.optimizer = torch.optim.Adam(params=self.model.parameters(), lr=cfg.learning_rate, weight_decay=5e-4) 40 | # 获取学习率衰减函数,milestones中的每个元素代表哪几个epoch调整学习率, gamma为学习率调整倍数 41 | self.scheduler = MultiStepLR(self.optimizer, milestones=[50, 80], gamma=0.1) 42 | # 获取损失函数,这里采用交叉熵损失函数 43 | self.losses = torch.nn.CrossEntropyLoss() 44 | 45 | def build_dataset(self, cfg): 46 | """构建训练数据和测试数据""" 47 | input_shape = eval(cfg.input_shape) 48 | # 加载训练数据 49 | train_dataset = AudioDataset(cfg.train_data, 50 | class_name=cfg.class_name, 51 | data_dir=cfg.data_dir, 52 | mode='train', 53 | spec_len=input_shape[3]) 54 | train_loader = DataLoader(dataset=train_dataset, 55 | batch_size=cfg.batch_size, 56 | shuffle=True, 57 | num_workers=cfg.num_workers) 58 | cfg.class_name = train_dataset.class_name 59 | cfg.class_dict = train_dataset.class_dict 60 | cfg.num_classes = len(cfg.class_name) 61 | 62 | # 加载测试数据 63 | test_dataset = AudioDataset(cfg.test_data, 64 | class_name=cfg.class_name, 65 | data_dir=cfg.data_dir, 66 | mode='test', 67 | spec_len=input_shape[3]) 68 | test_loader =
DataLoader(dataset=test_dataset, 69 | batch_size=cfg.batch_size, 70 | shuffle=False, 71 | num_workers=cfg.num_workers) 72 | 73 | print("train nums:{}".format(len(train_dataset))) 74 | print("test nums:{}".format(len(test_dataset))) 75 | return train_loader, test_loader 76 | 77 | def build_model(self, cfg): 78 | """构建模型""" 79 | if cfg.net_type == "mbv2": 80 | model = mobilenet_v2.mobilenet_v2(num_classes=cfg.num_classes) 81 | elif cfg.net_type == "resnet34": 82 | model = resnet.resnet34(num_classes=cfg.num_classes)  # use cfg, not the module-global `args`: only cfg carries num_classes (set in build_dataset) 83 | elif cfg.net_type == "resnet18": 84 | model = resnet.resnet18(num_classes=cfg.num_classes) 85 | else: 86 | raise Exception("Error:{}".format(cfg.net_type)) 87 | model.to(self.device) 88 | return model 89 | 90 | def epoch_test(self, epoch): 91 | """模型测试""" 92 | loss_sum = [] 93 | accuracies = [] 94 | self.model.eval() # model.eval()的作用是在测试时不启用 Batch Normalization 和 Dropout。在测试时,model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变;对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。 95 | 96 | with torch.no_grad(): # with torch.no_grad()主要是用于停止autograd模块的工作,以起到加速和节省显存的作用 97 | for step, (inputs, labels) in enumerate(tqdm(self.test_loader)): 98 | inputs = inputs.to(self.device) 99 | labels = labels.to(self.device).long() 100 | output = self.model(inputs) 101 | # 计算损失值 102 | loss = self.losses(output, labels) 103 | # 计算准确率 104 | output = torch.nn.functional.softmax(output, dim=1) 105 | # 把output中的tensor数据取出来转成numpy类型放在cpu上 106 | output = output.data.cpu().numpy() 107 | # 取出每一行中最大值的索引 108 | output = np.argmax(output, axis=1) 109 | labels = labels.data.cpu().numpy() 110 | acc = np.mean((output == labels).astype(int)) 111 | accuracies.append(acc) 112 | loss_sum.append(loss.item())  # store a Python float rather than a device tensor 113 | acc = sum(accuracies) / len(accuracies) 114 | loss = sum(loss_sum) / len(loss_sum) 115 | print("Test epoch:{:3d},Acc:{:3.3f},loss:{:3.3f}".format(epoch, acc, loss)) 116 | print('=' * 70) 117 | return acc, loss 118 | 119 | def epoch_train(self, epoch): 120 | """模型训练""" 121 | loss_sum
= [] 122 | accuracies = [] 123 | self.model.train() 124 | for step, (inputs, labels) in enumerate(tqdm(self.train_loader)): 125 | inputs = inputs.to(self.device) 126 | labels = labels.to(self.device).long() 127 | output = self.model(inputs) 128 | # 计算损失值 129 | loss = self.losses(output, labels) 130 | # 梯度归零 131 | self.optimizer.zero_grad() 132 | # 反向传播计算得到每个参数的梯度值 133 | loss.backward() 134 | # 通过梯度下降执行一步参数更新 135 | self.optimizer.step() 136 | 137 | # 计算准确率 138 | output = torch.nn.functional.softmax(output, dim=1) 139 | output = output.data.cpu().numpy() 140 | output = np.argmax(output, axis=1) 141 | labels = labels.data.cpu().numpy() 142 | acc = np.mean((output == labels).astype(int)) 143 | accuracies.append(acc) 144 | loss_sum.append(loss.item())  # a float: keeping the tensor holds its autograd history, and sum(loss_sum) below would build new graphs 145 | if step % 50 == 0: 146 | lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 147 | print('[%s] Train epoch %d, batch: %d/%d, loss: %f, accuracy: %f,lr:%f' % ( 148 | datetime.now(), epoch, step, len(self.train_loader), sum(loss_sum) / len(loss_sum), 149 | sum(accuracies) / len(accuracies), lr)) 150 | acc = sum(accuracies) / len(accuracies) 151 | loss = sum(loss_sum) / len(loss_sum) 152 | print("Train epoch:{:3d},Acc:{:3.3f},loss:{:3.3f}".format(epoch, acc, loss)) 153 | print('=' * 70) 154 | return acc, loss 155 | 156 | def run(self): 157 | # 开始训练 158 | for epoch in range(self.num_epoch): 159 | train_acc, train_loss = self.epoch_train(epoch) 160 | test_acc, test_loss = self.epoch_test(epoch) 161 | self.tensorboard.add_scalar("train_acc", train_acc, epoch) 162 | self.tensorboard.add_scalar("train_loss", train_loss, epoch) 163 | self.tensorboard.add_scalar("test_acc", test_acc, epoch) 164 | self.tensorboard.add_scalar("test_loss", test_loss, epoch) 165 | self.scheduler.step() 166 | self.save_model(epoch, test_acc) 167 | 168 | def save_model(self, epoch, acc): 169 | """Save the model as a TorchScript file named by epoch number and test accuracy.""" 170 | model_path = os.path.join(self.model_dir, 'model_{:0=3d}_{:.3f}.pth'.format(epoch, acc)) 171 | if not
os.path.exists(os.path.dirname(model_path)): 172 | os.makedirs(os.path.dirname(model_path)) 173 | torch.jit.save(torch.jit.script(self.model), model_path) 174 | 175 | 176 | def get_parser(): 177 | data_dir = "/home/dm/data3/dataset/UrbanSound8K/audio" 178 | train_data = 'data/UrbanSound8K/train.txt' 179 | test_data = 'data/UrbanSound8K/test.txt' 180 | class_name = 'data/UrbanSound8K/class_name.txt' 181 | parser = argparse.ArgumentParser(description=__doc__) 182 | parser.add_argument('--batch_size', type=int, default=32, help='训练的批量大小') 183 | parser.add_argument('--num_workers', type=int, default=8, help='读取数据的线程数量') 184 | parser.add_argument('--num_epoch', type=int, default=100, help='训练的轮数') 185 | parser.add_argument('--class_name', type=str, default=class_name, help='类别文件') 186 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='初始学习率的大小') 187 | parser.add_argument('--input_shape', type=str, default='(None, 1, 128, 128)', help='数据输入的形状') 188 | parser.add_argument('--gpu_id', type=int, default=0, help='GPU ID') 189 | parser.add_argument('--net_type', type=str, default="mbv2", help='backbone') 190 | parser.add_argument('--data_dir', type=str, default=data_dir, help='数据路径') 191 | parser.add_argument('--train_data', type=str, default=train_data, help='训练数据的数据列表路径') 192 | parser.add_argument('--test_data', type=str, default=test_data, help='测试数据的数据列表路径') 193 | parser.add_argument('--work_dir', type=str, default='work_space/', help='模型保存的路径') 194 | return parser 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = get_parser() 199 | args = parser.parse_args() 200 | print_arguments(args) 201 | t = Train(args) 202 | t.run() 203 | -------------------------------------------------------------------------------- /work_space/mbv2/log/events.out.tfevents.1636439290.pjq: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/work_space/mbv2/log/events.out.tfevents.1636439290.pjq -------------------------------------------------------------------------------- /work_space/mbv2/log/events.out.tfevents.1653419301.mlzdeMBP.lan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/work_space/mbv2/log/events.out.tfevents.1653419301.mlzdeMBP.lan -------------------------------------------------------------------------------- /work_space/mbv2/log/events.out.tfevents.1653424414.mlzdeMBP.lan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/work_space/mbv2/log/events.out.tfevents.1653424414.mlzdeMBP.lan -------------------------------------------------------------------------------- /work_space/mbv2/log/events.out.tfevents.1653492757.mlzdeMBP.lan: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/work_space/mbv2/log/events.out.tfevents.1653492757.mlzdeMBP.lan -------------------------------------------------------------------------------- /work_space/mbv2/model/model_075_0.965.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Friedrich-M/Audio-signal-classification-and-identification/028c846340364c274a45766f5c1cf60c1286371a/work_space/mbv2/model/model_075_0.965.pth --------------------------------------------------------------------------------