├── .gitignore ├── LICENSE ├── README.md ├── architect.py ├── config ├── __init__.py └── default.py ├── data_objects ├── DeepSpeakerDataset.py ├── VoxcelebTestset.py ├── __init__.py ├── audio.py ├── compute_mean_std.py ├── params_data.py ├── partition_voxceleb.py ├── preprocess.py ├── speaker.py ├── transforms.py └── utterance.py ├── data_preprocess.py ├── dl_script.sh ├── evaluate_identification.py ├── evaluate_verification.py ├── exps ├── baseline │ ├── resnet18_iden.yaml │ ├── resnet18_veri.yaml │ ├── resnet34_iden.yaml │ └── resnet34_veri.yaml ├── scratch │ ├── scratch_iden.yaml │ └── scratch_veri.yaml └── search.yaml ├── figures ├── searched_arch_normal.png └── searched_arch_reduce.png ├── functions.py ├── loss.py ├── models ├── __init__.py ├── model.py ├── model_search.py └── resnet.py ├── operations.py ├── requirements.txt ├── search.py ├── spaces.py ├── train_baseline_identification.py ├── train_baseline_verification.py ├── train_identification.py ├── train_verification.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/** 2 | logs_search/** 3 | models/__pycache__/** 4 | __pycache__/** 5 | config/__pycache__/** 6 | data_objects/__pycache__/** 7 | logs_scratch/** -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 VITA-Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoSpeech: Neural Architecture Search for Speaker Recognition 2 | 3 | Code for this paper [AutoSpeech: Neural Architecture Search for Speaker Recognition](https://arxiv.org/abs/2005.03215) 4 | 5 | Shaojin Ding*, Tianlong Chen*, Xinyu Gong, Weiwei Zha, Zhangyang Wang 6 | 7 | ## Overview 8 | Speaker recognition systems based on Convolutional Neural Networks (CNNs) are often built with off-the-shelf backbones such as VGG-Net or ResNet. However, these backbones were originally proposed for image classification, and therefore may not be naturally fit for speaker recognition. Due to the prohibitive complexity of manually exploring the design space, we propose the first neural architecture search approach approach for the speaker recognition tasks, named as **AutoSpeech**. 
Our evaluation results on [VoxCeleb1](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) demonstrate that the derived CNN architectures from the proposed approach significantly outperform current speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 back-bones, while enjoying lower model complexity. 9 | 10 | ## Results 11 | 12 | Our proposed approach outperforms speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones. The detailed comparison can be found in our paper. 13 | 14 | | Method | Top-1 | EER | Parameters | Pretrained model 15 | | :------------: | :---: | :---: | :---: | :---: | 16 | | VGG-M | 80.50 | 10.20 | 67M | [iden/veri](https://github.com/a-nagrani/VGGVox) | 17 | | ResNet-18 | 79.48 | 12.30 | 12M | [iden](https://drive.google.com/file/d/16P071LB1kwiQEoKhQRD-B3XBSwQR_6eG/view?usp=sharing), [veri](https://drive.google.com/file/d/1uNA34GTPBmrlG2gTwgrhkTfsn7zBnC7d/view?usp=sharing) | 18 | | ResNet-34 | 81.34 | 11.99| 22M | [iden](https://drive.google.com/file/d/1UJ_N5hQkVESifNJlvFdCte0yMPaqbaXJ/view?usp=sharing), [veri](https://drive.google.com/file/d/1JD34RhuvDoc19ulWQNPNArSKfDXpUYud/view?usp=sharing) | 19 | | Proposed | **87.66** | **8.95** | **18M** | [iden](https://drive.google.com/file/d/1Ph4atwl603xrbiq8OkvjdINGBCIQyXy9/view?usp=sharing), [veri](https://drive.google.com/file/d/16TrxrkRK5A0J6UxjYrQHUlEBdAESC087/view?usp=sharing) | 20 | 21 | ### Visualization 22 | 23 | left: normal cell. right: reduction cell 24 |

25 | <img src="figures/searched_arch_normal.png" alt="progress_convolutional_normal"> 26 | <img src="figures/searched_arch_reduce.png" alt="progress_convolutional_reduce"> 27 |

28 | 29 | ## 30 | 31 | ## Quick start 32 | ### Requirements 33 | * Python 3.7 34 | 35 | * Pytorch>=1.0: `pip install torch torchvision` 36 | 37 | * Other dependencies: `pip install -r requirements` 38 | 39 | ### Dataset 40 | [VoxCeleb1](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html): You will need `DevA-DevD` and `Test` parts. Additionally, you will need original files: `vox1_meta.csv`, `iden_split.txt`, and `veri_test.txt` from official website. Alternatively, the dataset can be downloaded using `dl_script.sh`. 41 | 42 | The data should be organized as: 43 | * VoxCeleb1 44 | * dev/wav/... 45 | * test/wav/... 46 | * vox1_meta.csv 47 | * iden_split.txt 48 | * veri_test.txt 49 | 50 | ### Running the code 51 | * data preprocess: 52 | 53 | `python data_preprocess.py /path/to/VoxCeleb1` 54 | 55 | The output folder of it should be: 56 | * feature 57 | * dev 58 | * test 59 | * merged 60 | 61 | dev and test are used for verification, and merged are used for identification. 62 | 63 | * Training and evaluating ResNet-18, ResNet-34 baselines: 64 | 65 | `python train_baseline_identification.py --cfg exps/baseline/resnet18_iden.yaml` 66 | 67 | `python train_baseline_verification.py --cfg exps/baseline/resnet18_veri.yaml` 68 | 69 | `python train_baseline_identification.py --cfg exps/baseline/resnet34_iden.yaml` 70 | 71 | `python train_baseline_verification.py --cfg exps/baseline/resnet34_veri.yaml` 72 | 73 | You need to modify the `DATA_DIR` field in `.yaml` file. 74 | 75 | * Architecture search: 76 | 77 | `python search.py --cfg exps/search.yaml` 78 | 79 | You need to modify the `DATA_DIR` field in `.yaml` file. 80 | 81 | * Training from scratch for identification: 82 | 83 | `python train_identification.py --cfg exps/scratch/scratch.yaml --text_arch GENOTYPE` 84 | 85 | You need to modify the `DATA_DIR` field in `.yaml` file. 86 | 87 | `GENOTYPE` is the search architecture object. For example, the `GENOTYPE` of the architecture report in the paper is: 88 | 89 | `"Genotype(normal=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0), ('dil_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('dil_conv_3x3', 2), ('max_pool_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('dil_conv_5x5', 3), ('dil_conv_3x3', 2), ('dil_conv_5x5', 4), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6))"` 90 | 91 | * Training from scratch for verification: 92 | 93 | `python train_verification.py --cfg exps/scratch/scratch.yaml --text_arch GENOTYPE` 94 | 95 | 96 | * Evaluation: 97 | 98 | * Identification 99 | 100 | `python evaluate_identification.py --cfg exps/scratch/scratch_iden.yaml --load_path /path/to/the/trained/model` 101 | 102 | * Verification 103 | 104 | `python evaluate_verification.py --cfg exps/scratch/scratch_veri.yaml --load_path /path/to/the/trained/model` 105 | 106 | 107 | 108 | ## Citation 109 | 110 | If you use this code for your research, please cite our paper. 
111 | 112 | ``` 113 | @misc{ding2020autospeech, 114 | title={AutoSpeech: Neural Architecture Search for Speaker Recognition}, 115 | author={Shaojin Ding and Tianlong Chen and Xinyu Gong and Weiwei Zha and Zhangyang Wang}, 116 | year={2020}, 117 | eprint={2005.03215}, 118 | archivePrefix={arXiv}, 119 | primaryClass={eess.AS} 120 | } 121 | ``` 122 | 123 | ## Acknowledgement 124 | 125 | Part of the codes are adapted from [deep-speaker](https://github.com/philipperemy/deep-speaker) and [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning). 126 | -------------------------------------------------------------------------------- /architect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _concat(xs): 5 | return torch.cat([x.view(-1) for x in xs]) 6 | 7 | 8 | class Architect(object): 9 | 10 | def __init__(self, model, cfg): 11 | self.model = model 12 | self.optimizer = torch.optim.Adam(self.model.arch_parameters(), 13 | lr=cfg.TRAIN.ARCH_LR, betas=(0.5, 0.999), weight_decay=cfg.TRAIN.ARCH_WD) 14 | 15 | def step(self, input_valid, target_valid): 16 | self.optimizer.zero_grad() 17 | self._backward_step(input_valid, target_valid) 18 | self.optimizer.step() 19 | 20 | def _backward_step(self, input_valid, target_valid): 21 | loss = self.model._loss(input_valid, target_valid) 22 | loss.backward() -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import _C as cfg 2 | from .default import update_config 3 | -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | 8 | _C = CN() 9 | 10 | _C.PRINT_FREQ = 20 11 | _C.VAL_FREQ = 20 12 | 13 | # Cudnn related params 14 | _C.CUDNN = CN() 15 | _C.CUDNN.BENCHMARK = True 16 | _C.CUDNN.DETERMINISTIC = False 17 | _C.CUDNN.ENABLED = True 18 | 19 | # seed 20 | _C.SEED = 3 21 | 22 | # common params for NETWORK 23 | _C.MODEL = CN() 24 | _C.MODEL.NAME = 'foo_net' 25 | _C.MODEL.NUM_CLASSES = 500 26 | _C.MODEL.LAYERS = 8 27 | _C.MODEL.INIT_CHANNELS = 16 28 | _C.MODEL.DROP_PATH_PROB = 0.2 29 | _C.MODEL.PRETRAINED = False 30 | 31 | # DATASET related params 32 | _C.DATASET = CN() 33 | _C.DATASET.DATA_DIR = '' 34 | _C.DATASET.SUB_DIR = '' 35 | _C.DATASET.TEST_DATA_DIR = '' 36 | _C.DATASET.TEST_DATASET = '' 37 | _C.DATASET.NUM_WORKERS = 0 38 | _C.DATASET.PARTIAL_N_FRAMES = 32 39 | 40 | 41 | # train 42 | _C.TRAIN = CN() 43 | 44 | _C.TRAIN.BATCH_SIZE = 32 45 | _C.TRAIN.LR = 0.1 46 | _C.TRAIN.LR_MIN = 0.001 47 | _C.TRAIN.WD = 0.0 48 | _C.TRAIN.BETA1 = 0.9 49 | _C.TRAIN.BETA2 = 0.999 50 | 51 | _C.TRAIN.ARCH_LR = 0.1 52 | _C.TRAIN.ARCH_WD = 0.0 53 | _C.TRAIN.ARCH_BETA1 = 0.9 54 | _C.TRAIN.ARCH_BETA2 = 0.999 55 | 56 | _C.TRAIN.DROPPATH_PROB = 0.2 57 | 58 | _C.TRAIN.BEGIN_EPOCH = 0 59 | _C.TRAIN.END_EPOCH = 140 60 | 61 | 62 | def update_config(cfg, args): 63 | cfg.defrost() 64 | cfg.merge_from_file(args.cfg) 65 | cfg.merge_from_list(args.opts) 66 | 67 | cfg.freeze() 68 | -------------------------------------------------------------------------------- /data_objects/DeepSpeakerDataset.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | 4 | import numpy as np 5 | import torch.utils.data as data 6 | from data_objects.speaker import Speaker 7 | from torchvision import transforms as T 8 | from data_objects.transforms import Normalize, TimeReverse, generate_test_sequence 9 | 10 | 11 | def find_classes(speakers): 12 | classes = list(set([speaker.name for speaker in speakers])) 13 | classes.sort() 14 | class_to_idx = {classes[i]: i for i in range(len(classes))} 15 | return classes, class_to_idx 16 | 17 | 18 | class DeepSpeakerDataset(data.Dataset): 19 | 20 | def __init__(self, data_dir, sub_dir, partial_n_frames, partition=None, is_test=False): 21 | super(DeepSpeakerDataset, self).__init__() 22 | self.data_dir = data_dir 23 | self.root = data_dir.joinpath('feature', sub_dir) 24 | self.partition = partition 25 | self.partial_n_frames = partial_n_frames 26 | self.is_test = is_test 27 | 28 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] 29 | if len(speaker_dirs) == 0: 30 | raise Exception("No speakers found. Make sure you are pointing to the directory " 31 | "containing all preprocessed speaker directories.") 32 | self.speakers = [Speaker(speaker_dir, self.partition) for speaker_dir in speaker_dirs] 33 | 34 | classes, class_to_idx = find_classes(self.speakers) 35 | sources = [] 36 | for speaker in self.speakers: 37 | sources.extend(speaker.sources) 38 | self.features = [] 39 | for source in sources: 40 | item = (source[0].joinpath(source[1]), class_to_idx[source[2]]) 41 | self.features.append(item) 42 | mean = np.load(self.data_dir.joinpath('mean.npy')) 43 | std = np.load(self.data_dir.joinpath('std.npy')) 44 | self.transform = T.Compose([ 45 | Normalize(mean, std), 46 | TimeReverse(), 47 | ]) 48 | 49 | def load_feature(self, feature_path, speaker_id): 50 | feature = np.load(feature_path) 51 | if self.is_test: 52 | test_sequence = generate_test_sequence(feature, self.partial_n_frames) 53 | return test_sequence, speaker_id 54 | else: 55 | if feature.shape[0] <= self.partial_n_frames: 56 | start = 0 57 | while feature.shape[0] < self.partial_n_frames: 58 | feature = np.repeat(feature, 2, axis=0) 59 | else: 60 | start = np.random.randint(0, feature.shape[0] - self.partial_n_frames) 61 | end = start + self.partial_n_frames 62 | return feature[start:end], speaker_id 63 | 64 | def __getitem__(self, index): 65 | feature_path, speaker_id = self.features[index] 66 | feature, speaker_id = self.load_feature(feature_path, speaker_id) 67 | 68 | if self.transform is not None: 69 | feature = self.transform(feature) 70 | return feature, speaker_id 71 | 72 | def __len__(self): 73 | return len(self.features) 74 | 75 | -------------------------------------------------------------------------------- /data_objects/VoxcelebTestset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import numpy as np 4 | from torchvision import transforms as T 5 | from data_objects.transforms import Normalize, generate_test_sequence 6 | 7 | 8 | def get_test_paths(pairs_path, db_dir): 9 | def convert_folder_name(path): 10 | basename = os.path.splitext(path)[0] 11 | items = basename.split('/') 12 | speaker_dir = items[0] 13 | fname = '{}_{}.npy'.format(items[1], items[2]) 14 | p = os.path.join(speaker_dir, fname) 15 | return p 16 | 17 | pairs = [line.strip().split() for line in open(pairs_path, 'r').readlines()] 18 | nrof_skipped_pairs = 0 
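    # Each line of veri_test.txt is expected to hold a label and two relative wav paths,
    # e.g. "1 id10270/x6uYqmx31kE/00001.wav id10270/8jEAjG6SegY/00008.wav" (paths here are
    # illustrative), where a leading '1' marks a same-speaker pair and '0' a different-speaker pair.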
19 | path_list = [] 20 | issame_list = [] 21 | 22 | for pair in pairs: 23 | if pair[0] == '1': 24 | issame = True 25 | else: 26 | issame = False 27 | 28 | path0 = db_dir.joinpath(convert_folder_name(pair[1])) 29 | path1 = db_dir.joinpath(convert_folder_name(pair[2])) 30 | 31 | if os.path.exists(path0) and os.path.exists(path1): # Only add the pair if both paths exist 32 | path_list.append((path0,path1,issame)) 33 | issame_list.append(issame) 34 | else: 35 | nrof_skipped_pairs += 1 36 | if nrof_skipped_pairs>0: 37 | print('Skipped %d image pairs' % nrof_skipped_pairs) 38 | 39 | return path_list 40 | 41 | 42 | class VoxcelebTestset(data.Dataset): 43 | def __init__(self, data_dir, partial_n_frames): 44 | super(VoxcelebTestset, self).__init__() 45 | self.data_dir = data_dir 46 | self.root = data_dir.joinpath('feature', 'test') 47 | self.test_pair_txt_fpath = data_dir.joinpath('veri_test.txt') 48 | self.test_pairs = get_test_paths(self.test_pair_txt_fpath, self.root) 49 | self.partial_n_frames = partial_n_frames 50 | mean = np.load(self.data_dir.joinpath('mean.npy')) 51 | std = np.load(self.data_dir.joinpath('std.npy')) 52 | self.transform = T.Compose([ 53 | Normalize(mean, std) 54 | ]) 55 | 56 | def load_feature(self, feature_path): 57 | feature = np.load(feature_path) 58 | test_sequence = generate_test_sequence(feature, self.partial_n_frames) 59 | return test_sequence 60 | 61 | def __getitem__(self, index): 62 | (path_1, path_2, issame) = self.test_pairs[index] 63 | 64 | feature1 = self.load_feature(path_1) 65 | feature2 = self.load_feature(path_2) 66 | 67 | if self.transform is not None: 68 | feature1 = self.transform(feature1) 69 | feature2 = self.transform(feature2) 70 | return feature1, feature2, issame 71 | 72 | def __len__(self): 73 | return len(self.test_pairs) 74 | -------------------------------------------------------------------------------- /data_objects/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NoviceMAn-prog/AutoSpeech/190049c87737a51452a7a4540ea2f1df200e0238/data_objects/__init__.py -------------------------------------------------------------------------------- /data_objects/audio.py: -------------------------------------------------------------------------------- 1 | from data_objects.params_data import * 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | import numpy as np 5 | import librosa 6 | 7 | int16_max = (2 ** 15) - 1 8 | 9 | 10 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], 11 | source_sr: Optional[int] = None): 12 | # Load the wav from disk if needed 13 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 14 | wav, source_sr = librosa.load(fpath_or_wav, sr=None) 15 | else: 16 | wav = fpath_or_wav 17 | 18 | # Resample the wav if needed 19 | if source_sr is not None and source_sr != sampling_rate: 20 | wav = librosa.resample(wav, source_sr, sampling_rate) 21 | 22 | # Apply the preprocessing: normalize volume and shorten long silences 23 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) 24 | 25 | return wav 26 | 27 | 28 | def wav_to_spectrogram(wav): 29 | frames = np.abs(librosa.core.stft( 30 | wav, 31 | n_fft=n_fft, 32 | hop_length=int(sampling_rate * window_step / 1000), 33 | win_length=int(sampling_rate * window_length / 1000), 34 | )) 35 | return frames.astype(np.float32).T 36 | 37 | 38 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): 39 | if increase_only and 
decrease_only: 40 | raise ValueError("Both increase only and decrease only are set") 41 | rms = np.sqrt(np.mean((wav * int16_max) ** 2)) 42 | wave_dBFS = 20 * np.log10(rms / int16_max) 43 | dBFS_change = target_dBFS - wave_dBFS 44 | if dBFS_change < 0 and increase_only or dBFS_change > 0 and decrease_only: 45 | return wav 46 | return wav * (10 ** (dBFS_change / 20)) 47 | -------------------------------------------------------------------------------- /data_objects/compute_mean_std.py: -------------------------------------------------------------------------------- 1 | from data_objects.speaker import Speaker 2 | import numpy as np 3 | 4 | def compute_mean_std(dataset_dir, output_path_mean, output_path_std): 5 | print("Computing mean std...") 6 | speaker_dirs = [f for f in dataset_dir.glob("*") if f.is_dir()] 7 | if len(speaker_dirs) == 0: 8 | raise Exception("No speakers found. Make sure you are pointing to the directory " 9 | "containing all preprocessed speaker directories.") 10 | speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] 11 | 12 | sources = [] 13 | for speaker in speakers: 14 | sources.extend(speaker.sources) 15 | 16 | sumx = np.zeros(257, dtype=np.float32) 17 | sumx2 = np.zeros(257, dtype=np.float32) 18 | count = 0 19 | n = len(sources) 20 | for i, source in enumerate(sources): 21 | feature = np.load(source[0].joinpath(source[1])) 22 | sumx += feature.sum(axis=0) 23 | sumx2 += (feature * feature).sum(axis=0) 24 | count += feature.shape[0] 25 | 26 | mean = sumx / count 27 | std = np.sqrt(sumx2 / count - mean * mean) 28 | 29 | mean = mean.astype(np.float32) 30 | std = std.astype(np.float32) 31 | 32 | np.save(output_path_mean, mean) 33 | np.save(output_path_std, std) -------------------------------------------------------------------------------- /data_objects/params_data.py: -------------------------------------------------------------------------------- 1 | 2 | ## Mel-filterbank 3 | window_length = 25 # In milliseconds 4 | window_step = 10 # In milliseconds 5 | n_fft = 512 6 | 7 | 8 | ## Audio 9 | sampling_rate = 16000 10 | # Number of spectrogram frames in a partial utterance 11 | partials_n_frames = 300 # 3000 ms 12 | 13 | 14 | ## Audio volume normalization 15 | audio_norm_target_dBFS = -30 16 | 17 | -------------------------------------------------------------------------------- /data_objects/partition_voxceleb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def partition_voxceleb(feature_root, split_txt_path): 5 | print("partitioning VoxCeleb...") 6 | with open(split_txt_path, 'r') as f: 7 | split_txt = f.readlines() 8 | train_set = [] 9 | val_set = [] 10 | test_set = [] 11 | for line in split_txt: 12 | items = line.strip().split() 13 | if items[0] == '3': 14 | test_set.append(items[1]) 15 | elif items[0] == '2': 16 | val_set.append(items[1]) 17 | else: 18 | train_set.append(items[1]) 19 | 20 | speakers = os.listdir(feature_root) 21 | 22 | for speaker in speakers: 23 | speaker_dir = os.path.join(feature_root, speaker) 24 | if not os.path.isdir(speaker_dir): 25 | continue 26 | with open(os.path.join(speaker_dir, '_sources.txt'), 'r') as f: 27 | speaker_files = f.readlines() 28 | 29 | train = [] 30 | val = [] 31 | test = [] 32 | for line in speaker_files: 33 | address = line.strip().split(',')[1] 34 | fname = os.path.join(*address.split('/')[-3:]) 35 | if fname in test_set: 36 | test.append(line) 37 | elif fname in val_set: 38 | val.append(line) 39 | elif fname in train_set: 40 | train.append(line) 
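            # Utterances listed in this speaker's _sources.txt but absent from
            # iden_split.txt fall through to the warning below.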
41 | else: 42 | print('file not in either train or test set') 43 | 44 | with open(os.path.join(speaker_dir, '_sources_train.txt'), 'w') as f: 45 | f.writelines('%s' % line for line in train) 46 | with open(os.path.join(speaker_dir, '_sources_val.txt'), 'w') as f: 47 | f.writelines('%s' % line for line in val) 48 | with open(os.path.join(speaker_dir, '_sources_test.txt'), 'w') as f: 49 | f.writelines('%s' % line for line in test) 50 | -------------------------------------------------------------------------------- /data_objects/preprocess.py: -------------------------------------------------------------------------------- 1 | from multiprocess.pool import ThreadPool 2 | from data_objects.params_data import * 3 | from datetime import datetime 4 | from data_objects import audio 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] 10 | 11 | class DatasetLog: 12 | """ 13 | Registers metadata about the dataset in a text file. 14 | """ 15 | 16 | def __init__(self, root, name): 17 | self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") 18 | self.sample_data = dict() 19 | 20 | start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) 21 | self.write_line("Creating dataset %s on %s" % (name, start_time)) 22 | self.write_line("-----") 23 | self._log_params() 24 | 25 | def _log_params(self): 26 | from data_objects import params_data 27 | self.write_line("Parameter values:") 28 | for param_name in (p for p in dir(params_data) if not p.startswith("__")): 29 | value = getattr(params_data, param_name) 30 | self.write_line("\t%s: %s" % (param_name, value)) 31 | self.write_line("-----") 32 | 33 | def write_line(self, line): 34 | self.text_file.write("%s\n" % line) 35 | 36 | def add_sample(self, **kwargs): 37 | for param_name, value in kwargs.items(): 38 | if not param_name in self.sample_data: 39 | self.sample_data[param_name] = [] 40 | self.sample_data[param_name].append(value) 41 | 42 | def finalize(self): 43 | self.write_line("Statistics:") 44 | for param_name, values in self.sample_data.items(): 45 | self.write_line("\t%s:" % param_name) 46 | self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) 47 | self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values))) 48 | self.write_line("-----") 49 | end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) 50 | self.write_line("Finished on %s" % end_time) 51 | self.text_file.close() 52 | 53 | 54 | def _init_preprocess_dataset(dataset_name, dataset_root, out_dir) -> (Path, DatasetLog): 55 | if not dataset_root.exists(): 56 | print("Couldn\'t find %s, skipping this dataset." % dataset_root) 57 | return None, None 58 | return dataset_root, DatasetLog(out_dir, dataset_name) 59 | 60 | 61 | def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension, 62 | skip_existing, logger): 63 | print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs))) 64 | 65 | # Function to preprocess utterances for one speaker 66 | def preprocess_speaker(speaker_dir: Path): 67 | # Give a name to the speaker that includes its dataset 68 | speaker_name = speaker_dir.parts[-1] 69 | 70 | # Create an output directory with that name, as well as a txt file containing a 71 | # reference to each source file. 
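        # Each line of _sources.txt pairs a saved feature file with its source audio,
        # written further down as "<spectrogram>.npy,<path to the original wav>".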
72 | speaker_out_dir = out_dir.joinpath(speaker_name) 73 | speaker_out_dir.mkdir(exist_ok=True) 74 | sources_fpath = speaker_out_dir.joinpath("_sources.txt") 75 | 76 | # There's a possibility that the preprocessing was interrupted earlier, check if 77 | # there already is a sources file. 78 | if sources_fpath.exists(): 79 | try: 80 | with sources_fpath.open("r") as sources_file: 81 | existing_fnames = {line.split(",")[0] for line in sources_file} 82 | except: 83 | existing_fnames = {} 84 | else: 85 | existing_fnames = {} 86 | 87 | # Gather all audio files for that speaker recursively 88 | sources_file = sources_fpath.open("a" if skip_existing else "w") 89 | for in_fpath in speaker_dir.glob("**/*.%s" % extension): 90 | # Check if the target output file already exists 91 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) 92 | out_fname = out_fname.replace(".%s" % extension, ".npy") 93 | if skip_existing and out_fname in existing_fnames: 94 | continue 95 | 96 | # Load and preprocess the waveform 97 | wav = audio.preprocess_wav(in_fpath) 98 | if len(wav) == 0: 99 | print(in_fpath) 100 | continue 101 | 102 | # Create the mel spectrogram, discard those that are too short 103 | # frames = audio.wav_to_mel_spectrogram(wav) 104 | frames = audio.wav_to_spectrogram(wav) 105 | if len(frames) < partials_n_frames: 106 | continue 107 | 108 | out_fpath = speaker_out_dir.joinpath(out_fname) 109 | np.save(out_fpath, frames) 110 | logger.add_sample(duration=len(wav) / sampling_rate) 111 | sources_file.write("%s,%s\n" % (out_fname, in_fpath)) 112 | 113 | sources_file.close() 114 | 115 | # Process the utterances for each speaker 116 | preprocess_speaker(speaker_dirs[0]) 117 | with ThreadPool(1) as pool: 118 | list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), 119 | unit="speakers")) 120 | logger.finalize() 121 | print("Done preprocessing %s.\n" % dataset_name) 122 | 123 | 124 | def preprocess_voxceleb1(dataset_root: Path, parition: str, out_dir: Path, skip_existing=False): 125 | # Initialize the preprocessing 126 | dataset_name = "VoxCeleb1" 127 | dataset_root, logger = _init_preprocess_dataset(dataset_name, dataset_root, out_dir) 128 | if not dataset_root: 129 | return 130 | 131 | # Get the contents of the meta file 132 | with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: 133 | metadata = [line.split("\t") for line in metafile][1:] 134 | 135 | # Select the ID and the nationality, filter out non-anglophone speakers 136 | nationalities = {line[0]: line[3] for line in metadata} 137 | keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if 138 | nationality.lower() in anglophone_nationalites] 139 | print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." % 140 | (len(keep_speaker_ids), len(nationalities))) 141 | 142 | # Get the speaker directories for anglophone speakers only 143 | speaker_dirs = dataset_root.joinpath(parition, 'wav').glob("*") 144 | speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs] 145 | 146 | print("VoxCeleb1: found %d anglophone speakers on the disk." 
% 147 | (len(speaker_dirs))) 148 | # Preprocess all speakers 149 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, dataset_root, out_dir, "wav", 150 | skip_existing, logger) 151 | 152 | -------------------------------------------------------------------------------- /data_objects/speaker.py: -------------------------------------------------------------------------------- 1 | from data_objects.utterance import Utterance 2 | from pathlib import Path 3 | 4 | 5 | # Contains the set of utterances of a single speaker 6 | class Speaker: 7 | def __init__(self, root: Path, partition=None): 8 | self.root = root 9 | self.partition = partition 10 | self.name = root.name 11 | self.utterances = None 12 | self.utterance_cycler = None 13 | if self.partition is None: 14 | with self.root.joinpath("_sources.txt").open("r") as sources_file: 15 | sources = [l.strip().split(",") for l in sources_file] 16 | else: 17 | with self.root.joinpath("_sources_{}.txt".format(self.partition)).open("r") as sources_file: 18 | sources = [l.strip().split(",") for l in sources_file] 19 | self.sources = [[self.root, frames_fname, self.name, wav_path] for frames_fname, wav_path in sources] 20 | 21 | def _load_utterances(self): 22 | self.utterances = [Utterance(source[0].joinpath(source[1])) for source in self.sources] 23 | 24 | def random_partial(self, count, n_frames): 25 | """ 26 | Samples a batch of unique partial utterances from the disk in a way that all 27 | utterances come up at least once every two cycles and in a random order every time. 28 | 29 | :param count: The number of partial utterances to sample from the set of utterances from 30 | that speaker. Utterances are guaranteed not to be repeated if is not larger than 31 | the number of utterances available. 32 | :param n_frames: The number of frames in the partial utterance. 33 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, 34 | frames are the frames of the partial utterances and range is the range of the partial 35 | utterance with regard to the complete utterance. 
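
        Example (a sketch with a hypothetical speaker path; note that `self.utterance_cycler`
        must have been populated with this speaker's utterances, since sampling goes through it):
            speaker = Speaker(Path("feature/merged/id10001"), partition="train")
            partials = speaker.random_partial(count=2, n_frames=300)
            # -> [(utterance, frames, (start, end)), (utterance, frames, (start, end))]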
36 | """ 37 | if self.utterances is None: 38 | self._load_utterances() 39 | 40 | utterances = self.utterance_cycler.sample(count) 41 | 42 | a = [(u,) + u.random_partial(n_frames) for u in utterances] 43 | 44 | return a -------------------------------------------------------------------------------- /data_objects/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | class Normalize(object): 5 | def __init__(self, mean, std): 6 | super(Normalize, self).__init__() 7 | self.mean = mean 8 | self.std = std 9 | 10 | def __call__(self, input): 11 | return (input - self.mean) / self.std 12 | 13 | 14 | class TimeReverse(object): 15 | def __init__(self, p=0.5): 16 | super(TimeReverse, self).__init__() 17 | self.p = p 18 | 19 | def __call__(self, input): 20 | if random.random() < self.p: 21 | return np.flip(input, axis=0).copy() 22 | return input 23 | 24 | 25 | def generate_test_sequence(feature, partial_n_frames, shift=None): 26 | while feature.shape[0] <= partial_n_frames: 27 | feature = np.repeat(feature, 2, axis=0) 28 | if shift is None: 29 | shift = partial_n_frames // 2 30 | test_sequence = [] 31 | start = 0 32 | while start + partial_n_frames <= feature.shape[0]: 33 | test_sequence.append(feature[start: start + partial_n_frames]) 34 | start += shift 35 | test_sequence = np.stack(test_sequence, axis=0) 36 | return test_sequence -------------------------------------------------------------------------------- /data_objects/utterance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Utterance: 5 | def __init__(self, frames_fpath): 6 | self.frames_fpath = frames_fpath 7 | 8 | def get_frames(self): 9 | return np.load(self.frames_fpath) 10 | 11 | def random_partial(self, n_frames): 12 | """ 13 | Crops the frames into a partial utterance of n_frames 14 | 15 | :param n_frames: The number of frames of the partial utterance 16 | :return: the partial utterance frames and a tuple indicating the start and end of the 17 | partial utterance in the complete utterance. 18 | """ 19 | frames = self.get_frames() 20 | if frames.shape[0] == n_frames: 21 | start = 0 22 | else: 23 | start = np.random.randint(0, frames.shape[0] - n_frames) 24 | end = start + n_frames 25 | return frames[start:end], (start, end) -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | from data_objects.preprocess import preprocess_voxceleb1 2 | from data_objects.compute_mean_std import compute_mean_std 3 | from data_objects.partition_voxceleb import partition_voxceleb 4 | from pathlib import Path 5 | import argparse 6 | import subprocess 7 | 8 | if __name__ == "__main__": 9 | class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): 10 | pass 11 | 12 | parser = argparse.ArgumentParser(description="Preprocesses audio files from datasets.", 13 | formatter_class=MyFormatter 14 | ) 15 | parser.add_argument("dataset_root", type=Path, help= \ 16 | "Path to the directory containing VoxCeleb datasets. It should be arranged as:") 17 | parser.add_argument("-s", "--skip_existing", action="store_true", help= \ 18 | "Whether to skip existing output files with the same name. 
Useful if this script was " 19 | "interrupted.") 20 | args = parser.parse_args() 21 | 22 | # Process the arguments 23 | dev_out_dir = args.dataset_root.joinpath("feature", "dev") 24 | test_out_dir = args.dataset_root.joinpath("feature", "test") 25 | merged_out_dir = args.dataset_root.joinpath("feature", "merged") 26 | assert args.dataset_root.exists() 27 | assert args.dataset_root.joinpath('iden_split.txt').exists() 28 | assert args.dataset_root.joinpath('veri_test.txt').exists() 29 | assert args.dataset_root.joinpath('vox1_meta.csv').exists() 30 | dev_out_dir.mkdir(exist_ok=True, parents=True) 31 | test_out_dir.mkdir(exist_ok=True, parents=True) 32 | merged_out_dir.mkdir(exist_ok=True, parents=True) 33 | 34 | # Preprocess the datasets 35 | preprocess_voxceleb1(args.dataset_root, 'dev', dev_out_dir, args.skip_existing) 36 | preprocess_voxceleb1(args.dataset_root, 'test', test_out_dir, args.skip_existing) 37 | for path in dev_out_dir.iterdir(): 38 | subprocess.call(['cp', '-r', path.as_posix(), merged_out_dir.as_posix()]) 39 | for path in test_out_dir.iterdir(): 40 | subprocess.call(['cp', '-r', path.as_posix(), merged_out_dir.as_posix()]) 41 | compute_mean_std(merged_out_dir, args.dataset_root.joinpath('mean.npy'), 42 | args.dataset_root.joinpath('std.npy')) 43 | partition_voxceleb(merged_out_dir, args.dataset_root.joinpath('iden_split.txt')) 44 | print("Done") 45 | -------------------------------------------------------------------------------- /dl_script.sh: -------------------------------------------------------------------------------- 1 | #/bin/bash 2 | 3 | # Contributed by Aaron Soellinger 4 | # Usage (*nix): 5 | # $ mkdir VoxCeleb1; cd VoxCeleb1; /bin/bash path/to/dl_script.sh "yourusername" "yourpassword" 6 | # Note: I found my username and password in an email titled "VoxCeleb dataset" 7 | 8 | U=$1 9 | P=$2 10 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa --user "$U" --password "$P" & 11 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab --user "$U" --password "$P" & 12 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac --user "$U" --password "$P" & 13 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad --user "$U" --password "$P" & 14 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip --user "$U" --password "$P" & 15 | vox1_dev* > vox1_dev_wav.zip 16 | unzip vox1_dev_wav.zip -d "dev" & 17 | unzip vox1_test_wav.zip -d "test" 18 | rm vox1_dev_wav_part* 19 | rm wget* 20 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/vox1_meta.csv 21 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt 22 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt 23 | -------------------------------------------------------------------------------- /evaluate_identification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import numpy as np 7 | import os 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.backends.cudnn as cudnn 13 | 14 | from models.model import Network 15 | from models import resnet 16 | from config import cfg, update_config 17 | from utils import create_logger, Genotype 18 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset 19 | from functions import 
validate_identification 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Train autospeech network') 24 | # general 25 | parser.add_argument('--cfg', 26 | help='experiment configure file name', 27 | required=True, 28 | type=str) 29 | 30 | parser.add_argument('opts', 31 | help="Modify config options using the command-line", 32 | default=None, 33 | nargs=argparse.REMAINDER) 34 | 35 | parser.add_argument('--load_path', 36 | help="The path to resumed dir", 37 | required=True, 38 | default=None) 39 | 40 | args = parser.parse_args() 41 | 42 | return args 43 | 44 | 45 | def main(): 46 | args = parse_args() 47 | update_config(cfg, args) 48 | if args.load_path is None: 49 | raise AttributeError("Please specify load path.") 50 | 51 | # cudnn related setting 52 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 53 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 54 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 55 | 56 | # Set the random seed manually for reproducibility. 57 | np.random.seed(cfg.SEED) 58 | torch.manual_seed(cfg.SEED) 59 | torch.cuda.manual_seed_all(cfg.SEED) 60 | 61 | # model and optimizer 62 | if cfg.MODEL.NAME == 'model': 63 | if args.load_path and os.path.exists(args.load_path): 64 | checkpoint = torch.load(args.load_path) 65 | genotype = checkpoint['genotype'] 66 | else: 67 | raise AssertionError('Please specify the model to evaluate') 68 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype) 69 | model.drop_path_prob = 0.0 70 | else: 71 | model = eval('resnet.{}(num_classes={})'.format(cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES)) 72 | model = model.cuda() 73 | 74 | criterion = nn.CrossEntropyLoss().cuda() 75 | 76 | # resume && make log dir and logger 77 | if args.load_path and os.path.exists(args.load_path): 78 | checkpoint = torch.load(args.load_path) 79 | 80 | # load checkpoint 81 | model.load_state_dict(checkpoint['state_dict']) 82 | args.path_helper = checkpoint['path_helper'] 83 | 84 | logger = create_logger(os.path.dirname(args.load_path)) 85 | logger.info("=> loaded checkpoint '{}'".format(args.load_path)) 86 | else: 87 | raise AssertionError('Please specify the model to evaluate') 88 | logger.info(args) 89 | logger.info(cfg) 90 | 91 | # dataloader 92 | test_dataset_identification = DeepSpeakerDataset( 93 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True) 94 | 95 | 96 | test_loader_identification = torch.utils.data.DataLoader( 97 | dataset=test_dataset_identification, 98 | batch_size=1, 99 | num_workers=cfg.DATASET.NUM_WORKERS, 100 | pin_memory=True, 101 | shuffle=False, 102 | drop_last=True, 103 | ) 104 | 105 | validate_identification(cfg, model, test_loader_identification, criterion) 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /evaluate_verification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import numpy as np 7 | import os 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | 13 | from models.model import Network 14 | from models import resnet 15 | from config import cfg, update_config 16 | from utils import create_logger, Genotype 17 | from data_objects.VoxcelebTestset import VoxcelebTestset 18 | from 
functions import validate_verification 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Train autospeech network') 23 | # general 24 | parser.add_argument('--cfg', 25 | help='experiment configure file name', 26 | required=True, 27 | type=str) 28 | 29 | parser.add_argument('opts', 30 | help="Modify config options using the command-line", 31 | default=None, 32 | nargs=argparse.REMAINDER) 33 | 34 | parser.add_argument('--load_path', 35 | help="The path to resumed dir", 36 | default=None) 37 | parser.add_argument('--text_arch', 38 | help="The path to arch", 39 | default=None) 40 | 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | def main(): 47 | args = parse_args() 48 | update_config(cfg, args) 49 | if args.load_path is None: 50 | raise AttributeError("Please specify load path.") 51 | 52 | # cudnn related setting 53 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 54 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 55 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 56 | 57 | # Set the random seed manually for reproducibility. 58 | np.random.seed(cfg.SEED) 59 | torch.manual_seed(cfg.SEED) 60 | torch.cuda.manual_seed_all(cfg.SEED) 61 | 62 | # model and optimizer 63 | if cfg.MODEL.NAME == 'model': 64 | if args.load_path and os.path.exists(args.load_path): 65 | checkpoint = torch.load(args.load_path) 66 | genotype = checkpoint['genotype'] 67 | else: 68 | raise AssertionError('Please specify the model to evaluate') 69 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype) 70 | model.drop_path_prob = 0.0 71 | else: 72 | model = eval('resnet.{}(num_classes={})'.format(cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES)) 73 | model = model.cuda() 74 | 75 | # resume && make log dir and logger 76 | if args.load_path and os.path.exists(args.load_path): 77 | checkpoint = torch.load(args.load_path) 78 | 79 | # load checkpoint 80 | model.load_state_dict(checkpoint['state_dict']) 81 | args.path_helper = checkpoint['path_helper'] 82 | 83 | logger = create_logger(os.path.dirname(args.load_path)) 84 | logger.info("=> loaded checkpoint '{}'".format(args.load_path)) 85 | else: 86 | raise AssertionError('Please specify the model to evaluate') 87 | logger.info(args) 88 | logger.info(cfg) 89 | 90 | # dataloader 91 | test_dataset_verification = VoxcelebTestset( 92 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES 93 | ) 94 | test_loader_verification = torch.utils.data.DataLoader( 95 | dataset=test_dataset_verification, 96 | batch_size=1, 97 | num_workers=cfg.DATASET.NUM_WORKERS, 98 | pin_memory=True, 99 | shuffle=False, 100 | drop_last=False, 101 | ) 102 | 103 | validate_verification(cfg, model, test_loader_verification) 104 | 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /exps/baseline/resnet18_iden.yaml: -------------------------------------------------------------------------------- 1 | PRINT_FREQ: 200 2 | VAL_FREQ: 10 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'merged' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 256 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | BETA1: 0.9 20 | BETA2: 0.999 21 | 22 | BEGIN_EPOCH: 0 23 | END_EPOCH: 301 24 | 25 | MODEL: 26 | NAME: 'resnet18' 27 | NUM_CLASSES: 1251 28 | INIT_CHANNELS: 64 
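# NUM_CLASSES is 1251 here because identification trains on the 'merged' (dev + test)
# split, which covers all 1251 VoxCeleb1 speakers; the verification configs instead
# train on the 1211 dev-only speakers (SUB_DIR: 'dev', NUM_CLASSES: 1211).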
-------------------------------------------------------------------------------- /exps/baseline/resnet18_veri.yaml: -------------------------------------------------------------------------------- 1 | PRINT_FREQ: 200 2 | VAL_FREQ: 10 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'dev' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 256 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | BETA1: 0.9 20 | BETA2: 0.999 21 | 22 | BEGIN_EPOCH: 0 23 | END_EPOCH: 301 24 | 25 | MODEL: 26 | NAME: 'resnet18' 27 | NUM_CLASSES: 1211 28 | INIT_CHANNELS: 64 -------------------------------------------------------------------------------- /exps/baseline/resnet34_iden.yaml: -------------------------------------------------------------------------------- 1 | PRINT_FREQ: 200 2 | VAL_FREQ: 10 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'merged' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 128 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | BETA1: 0.9 20 | BETA2: 0.999 21 | 22 | BEGIN_EPOCH: 0 23 | END_EPOCH: 301 24 | 25 | MODEL: 26 | NAME: 'resnet34' 27 | NUM_CLASSES: 1251 28 | INIT_CHANNELS: 64 -------------------------------------------------------------------------------- /exps/baseline/resnet34_veri.yaml: -------------------------------------------------------------------------------- 1 | PRINT_FREQ: 200 2 | VAL_FREQ: 10 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'dev' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 128 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | BETA1: 0.9 20 | BETA2: 0.999 21 | 22 | BEGIN_EPOCH: 0 23 | END_EPOCH: 301 24 | 25 | MODEL: 26 | NAME: 'resnet34' 27 | NUM_CLASSES: 1211 28 | INIT_CHANNELS: 128 -------------------------------------------------------------------------------- /exps/scratch/scratch_iden.yaml: -------------------------------------------------------------------------------- 1 | PRINT_FREQ: 200 2 | VAL_FREQ: 10 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'merged' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 96 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | BETA1: 0.9 20 | BETA2: 0.999 21 | 22 | BEGIN_EPOCH: 0 23 | END_EPOCH: 301 24 | 25 | MODEL: 26 | NAME: 'model' 27 | NUM_CLASSES: 1251 28 | LAYERS: 8 29 | INIT_CHANNELS: 128 -------------------------------------------------------------------------------- /exps/scratch/scratch_veri.yaml: -------------------------------------------------------------------------------- 1 | PRINT_FREQ: 200 2 | VAL_FREQ: 10 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'dev' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 48 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | BETA1: 0.9 20 | BETA2: 0.999 21 | 22 | BEGIN_EPOCH: 0 23 | END_EPOCH: 301 24 | 25 | MODEL: 26 | NAME: 'model' 27 | NUM_CLASSES: 1211 28 | LAYERS: 8 29 | INIT_CHANNELS: 128 -------------------------------------------------------------------------------- /exps/search.yaml: -------------------------------------------------------------------------------- 1 | 
PRINT_FREQ: 200 2 | VAL_FREQ: 5 3 | 4 | CUDNN: 5 | BENCHMARK: true 6 | DETERMINISTIC: false 7 | ENABLED: true 8 | 9 | DATASET: 10 | DATA_DIR: '/path/to/VoxCeleb1' 11 | SUB_DIR: 'merged' 12 | NUM_WORKERS: 0 13 | PARTIAL_N_FRAMES: 300 14 | 15 | TRAIN: 16 | BATCH_SIZE: 2 17 | LR: 0.01 18 | LR_MIN: 0.001 19 | WD: 0.0003 20 | BETA1: 0.9 21 | BETA2: 0.999 22 | 23 | ARCH_LR: 0.001 24 | ARCH_WD: 0.001 25 | ARCH_BETA1: 0.9 26 | ARCH_BETA2: 0.999 27 | DROPPATH_PROB: 0.2 28 | 29 | BEGIN_EPOCH: 0 30 | END_EPOCH: 50 31 | 32 | MODEL: 33 | NAME: 'model_search' 34 | NUM_CLASSES: 1251 35 | LAYERS: 8 36 | INIT_CHANNELS: 16 37 | -------------------------------------------------------------------------------- /figures/searched_arch_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NoviceMAn-prog/AutoSpeech/190049c87737a51452a7a4540ea2f1df200e0238/figures/searched_arch_normal.png -------------------------------------------------------------------------------- /figures/searched_arch_reduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NoviceMAn-prog/AutoSpeech/190049c87737a51452a7a4540ea2f1df200e0238/figures/searched_arch_reduce.png -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn.functional as F 4 | import logging 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | from utils import compute_eer 9 | from utils import AverageMeter, ProgressMeter, accuracy 10 | 11 | plt.switch_backend('agg') 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def train(cfg, model, optimizer, train_loader, val_loader, criterion, architect, epoch, writer_dict, lr_scheduler=None): 16 | batch_time = AverageMeter('Time', ':6.3f') 17 | data_time = AverageMeter('Data', ':6.3f') 18 | losses = AverageMeter('Loss', ':.4e') 19 | top1 = AverageMeter('Acc@1', ':6.2f') 20 | top5 = AverageMeter('Acc@5', ':6.2f') 21 | alpha_entropies = AverageMeter('Entropy', ':.4e') 22 | progress = ProgressMeter( 23 | len(train_loader), batch_time, data_time, losses, top1, top5, alpha_entropies, 24 | prefix="Epoch: [{}]".format(epoch), logger=logger) 25 | writer = writer_dict['writer'] 26 | 27 | # switch to train mode 28 | model.train() 29 | 30 | end = time.time() 31 | for i, (input, target) in enumerate(train_loader): 32 | global_steps = writer_dict['train_global_steps'] 33 | 34 | if lr_scheduler: 35 | current_lr = lr_scheduler.set_lr(optimizer, global_steps, epoch) 36 | else: 37 | current_lr = cfg.TRAIN.LR 38 | 39 | # measure data loading time 40 | data_time.update(time.time() - end) 41 | 42 | input = input.cuda(non_blocking=True) 43 | target = target.cuda(non_blocking=True) 44 | 45 | input_search, target_search = next(iter(val_loader)) 46 | input_search = input_search.cuda(non_blocking=True) 47 | target_search = target_search.cuda(non_blocking=True) 48 | 49 | # step architecture 50 | architect.step(input_search, target_search) 51 | 52 | alpha_entropy = architect.model.compute_arch_entropy() 53 | alpha_entropies.update(alpha_entropy.mean(), input.size(0)) 54 | 55 | # compute output 56 | output = model(input) 57 | 58 | # measure accuracy and record loss 59 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 60 | top1.update(acc1[0], input.size(0)) 61 | top5.update(acc5[0], input.size(0)) 62 | loss = 
criterion(output, target) 63 | losses.update(loss.item(), input.size(0)) 64 | 65 | # compute gradient and do SGD step 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | 70 | # measure elapsed time 71 | batch_time.update(time.time() - end) 72 | end = time.time() 73 | 74 | # write to logger 75 | writer.add_scalar('lr', current_lr, global_steps) 76 | writer.add_scalar('train_loss', losses.val, global_steps) 77 | writer.add_scalar('arch_entropy', alpha_entropies.val, global_steps) 78 | 79 | writer_dict['train_global_steps'] = global_steps + 1 80 | 81 | # log acc for cross entropy loss 82 | writer.add_scalar('train_acc1', top1.val, global_steps) 83 | writer.add_scalar('train_acc5', top5.val, global_steps) 84 | 85 | if i % cfg.PRINT_FREQ == 0: 86 | progress.print(i) 87 | 88 | 89 | def train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict, lr_scheduler=None): 90 | batch_time = AverageMeter('Time', ':6.3f') 91 | data_time = AverageMeter('Data', ':6.3f') 92 | losses = AverageMeter('Loss', ':.4e') 93 | top1 = AverageMeter('Acc@1', ':6.2f') 94 | top5 = AverageMeter('Acc@5', ':6.2f') 95 | progress = ProgressMeter( 96 | len(train_loader), batch_time, data_time, losses, top1, top5, prefix="Epoch: [{}]".format(epoch), logger=logger) 97 | writer = writer_dict['writer'] 98 | 99 | # switch to train mode 100 | model.train() 101 | 102 | end = time.time() 103 | for i, (input, target) in enumerate(train_loader): 104 | global_steps = writer_dict['train_global_steps'] 105 | 106 | if lr_scheduler: 107 | current_lr = lr_scheduler.get_lr() 108 | else: 109 | current_lr = cfg.TRAIN.LR 110 | 111 | # measure data loading time 112 | data_time.update(time.time() - end) 113 | 114 | input = input.cuda(non_blocking=True) 115 | target = target.cuda(non_blocking=True) 116 | 117 | # compute output 118 | output = model(input) 119 | 120 | # measure accuracy and record loss 121 | loss = criterion(output, target) 122 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 123 | top1.update(acc1[0], input.size(0)) 124 | top5.update(acc5[0], input.size(0)) 125 | losses.update(loss.item(), input.size(0)) 126 | 127 | # compute gradient and do SGD step 128 | optimizer.zero_grad() 129 | loss.backward() 130 | optimizer.step() 131 | 132 | # measure elapsed time 133 | batch_time.update(time.time() - end) 134 | end = time.time() 135 | 136 | # write to logger 137 | writer.add_scalar('lr', current_lr, global_steps) 138 | writer.add_scalar('train_loss', losses.val, global_steps) 139 | writer_dict['train_global_steps'] = global_steps + 1 140 | 141 | # log acc for cross entropy loss 142 | writer.add_scalar('train_acc1', top1.val, global_steps) 143 | writer.add_scalar('train_acc5', top5.val, global_steps) 144 | 145 | if i % cfg.PRINT_FREQ == 0: 146 | progress.print(i) 147 | 148 | 149 | def validate_verification(cfg, model, test_loader): 150 | batch_time = AverageMeter('Time', ':6.3f') 151 | progress = ProgressMeter( 152 | len(test_loader), batch_time, prefix='Test: ', logger=logger) 153 | 154 | # switch to evaluate mode 155 | model.eval() 156 | labels, distances = [], [] 157 | 158 | with torch.no_grad(): 159 | end = time.time() 160 | for i, (input1, input2, label) in enumerate(test_loader): 161 | input1 = input1.cuda(non_blocking=True).squeeze(0) 162 | input2 = input2.cuda(non_blocking=True).squeeze(0) 163 | label = label.cuda(non_blocking=True) 164 | 165 | # compute output 166 | outputs1 = model(input1).mean(dim=0).unsqueeze(0) 167 | outputs2 = model(input2).mean(dim=0).unsqueeze(0) 168 | 169 | 
dists = F.cosine_similarity(outputs1, outputs2) 170 | dists = dists.data.cpu().numpy() 171 | distances.append(dists) 172 | labels.append(label.data.cpu().numpy()) 173 | 174 | # measure elapsed time 175 | batch_time.update(time.time() - end) 176 | end = time.time() 177 | 178 | if i % 2000 == 0: 179 | progress.print(i) 180 | 181 | labels = np.array([sublabel for label in labels for sublabel in label]) 182 | distances = np.array([subdist for dist in distances for subdist in dist]) 183 | 184 | eer = compute_eer(distances, labels) 185 | logger.info('Test EER: {:.8f}'.format(np.mean(eer))) 186 | 187 | return eer 188 | 189 | 190 | def validate_identification(cfg, model, test_loader, criterion): 191 | batch_time = AverageMeter('Time', ':6.3f') 192 | losses = AverageMeter('Loss', ':.4e') 193 | top1 = AverageMeter('Acc@1', ':6.2f') 194 | top5 = AverageMeter('Acc@5', ':6.2f') 195 | progress = ProgressMeter( 196 | len(test_loader), batch_time, losses, top1, top5, prefix='Test: ', logger=logger) 197 | 198 | # switch to evaluate mode 199 | model.eval() 200 | 201 | with torch.no_grad(): 202 | end = time.time() 203 | for i, (input, target) in enumerate(test_loader): 204 | input = input.cuda(non_blocking=True).squeeze(0) 205 | target = target.cuda(non_blocking=True) 206 | 207 | # compute output 208 | output = model(input) 209 | output = torch.mean(output, dim=0, keepdim=True) 210 | output = model.forward_classifier(output) 211 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 212 | top1.update(acc1[0], input.size(0)) 213 | top5.update(acc5[0], input.size(0)) 214 | loss = criterion(output, target) 215 | 216 | losses.update(loss.item(), 1) 217 | 218 | # measure elapsed time 219 | batch_time.update(time.time() - end) 220 | end = time.time() 221 | 222 | if i % 2000 == 0: 223 | progress.print(i) 224 | 225 | logger.info('Test Acc@1: {:.8f} Acc@5: {:.8f}'.format(top1.avg, top5.avg)) 226 | 227 | return top1.avg 228 | 229 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class CrossEntropyLoss(nn.Module): 10 | r"""Cross entropy loss with label smoothing regularizer. 11 | 12 | Reference: 13 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 14 | With label smoothing, the label :math:`y` for a class is computed by 15 | 16 | .. math:: 17 | \begin{equation} 18 | (1 - \epsilon) \times y + \frac{\epsilon}{K}, 19 | \end{equation} 20 | where :math:`K` denotes the number of classes and :math:`\epsilon` is a weight. When 21 | :math:`\epsilon = 0`, the loss function reduces to the normal cross entropy. 22 | 23 | Args: 24 | num_classes (int): number of classes. 25 | epsilon (float, optional): weight. Default is 0.1. 26 | use_gpu (bool, optional): whether to use gpu devices. Default is True. 27 | label_smooth (bool, optional): whether to apply label smoothing. Default is True. 
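
        Example (hypothetical usage):
            criterion = CrossEntropyLoss(num_classes=1251, epsilon=0.1)
            loss = criterion(logits, targets)  # logits: (batch, num_classes) raw scores, targets: (batch,) label indices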
28 | """ 29 | 30 | def __init__( 31 | self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True 32 | ): 33 | super(CrossEntropyLoss, self).__init__() 34 | self.num_classes = num_classes 35 | self.epsilon = epsilon if label_smooth else 0 36 | self.use_gpu = use_gpu 37 | self.logsoftmax = nn.LogSoftmax(dim=1) 38 | 39 | def forward(self, inputs, targets): 40 | """ 41 | Args: 42 | inputs (torch.Tensor): prediction matrix (before softmax) with 43 | shape (batch_size, num_classes). 44 | targets (torch.LongTensor): ground truth labels with shape (batch_size). 45 | Each position contains the label index. 46 | """ 47 | log_probs = self.logsoftmax(inputs) 48 | zeros = torch.zeros(log_probs.size()) 49 | targets = zeros.scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 50 | if self.use_gpu: 51 | targets = targets.cuda() 52 | targets = ( 53 | 1 - self.epsilon 54 | ) * targets + self.epsilon / self.num_classes 55 | return (-targets * log_probs).mean(0).sum() 56 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-08-08 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import models.resnet 12 | #import models.multi_scale_model 13 | #import models.multi_scale_model_v2 14 | #import models.multi_scale_model_v3 15 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from operations import * 2 | from utils import drop_path 3 | 4 | 5 | class Cell(nn.Module): 6 | 7 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): 8 | super(Cell, self).__init__() 9 | print(C_prev_prev, C_prev, C) 10 | 11 | if reduction_prev: 12 | self.preprocess0 = FactorizedReduce(C_prev_prev, C) 13 | else: 14 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) 15 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) 16 | 17 | if reduction: 18 | op_names, indices = zip(*genotype.reduce) 19 | concat = genotype.reduce_concat 20 | else: 21 | op_names, indices = zip(*genotype.normal) 22 | concat = genotype.normal_concat 23 | self._compile(C, op_names, indices, concat, reduction) 24 | 25 | def _compile(self, C, op_names, indices, concat, reduction): 26 | assert len(op_names) == len(indices) 27 | self._steps = len(op_names) // 2 28 | self._concat = concat 29 | self.multiplier = len(concat) 30 | 31 | self._ops = nn.ModuleList() 32 | for name, index in zip(op_names, indices): 33 | stride = 2 if reduction and index < 2 else 1 34 | op = OPS[name](C, stride, True) 35 | self._ops += [op] 36 | self._indices = indices 37 | 38 | def forward(self, s0, s1, drop_prob): 39 | s0 = self.preprocess0(s0) 40 | s1 = self.preprocess1(s1) 41 | 42 | states = [s0, s1] 43 | for i in range(self._steps): 44 | h1 = states[self._indices[2 * i]] 45 | h2 = states[self._indices[2 * i + 1]] 46 | op1 = self._ops[2 * i] 47 | op2 = self._ops[2 * i + 1] 48 | h1 = op1(h1) 49 | h2 = op2(h2) 50 | if self.training and drop_prob > 0.: 51 | if not isinstance(op1, Identity): 52 | h1 = drop_path(h1, drop_prob) 53 | if not isinstance(op2, Identity): 54 | h2 = drop_path(h2, drop_prob) 55 | s = h1 + h2 56 | states += [s] 57 | return torch.cat([states[i] for i in self._concat], 
dim=1) 58 | 59 | 60 | class Network(nn.Module): 61 | 62 | def __init__(self, C, num_classes, layers, genotype): 63 | super(Network, self).__init__() 64 | self._C = C 65 | self._num_classes = num_classes 66 | self._layers = layers 67 | 68 | self.stem0 = nn.Sequential( 69 | nn.Conv2d(1, C // 2, kernel_size=3, stride=2, padding=1, bias=False), 70 | nn.BatchNorm2d(C // 2), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(C // 2, C, kernel_size=3, stride=2, padding=1, bias=False), 73 | nn.BatchNorm2d(C), 74 | ) 75 | 76 | self.stem1 = nn.Sequential( 77 | nn.ReLU(inplace=True), 78 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), 79 | nn.BatchNorm2d(C), 80 | ) 81 | 82 | C_prev_prev, C_prev, C_curr = C, C, C 83 | 84 | self.cells = nn.ModuleList() 85 | reduction_prev = True 86 | for i in range(layers): 87 | if i in [layers // 3, 2 * layers // 3]: 88 | C_curr *= 2 89 | reduction = True 90 | else: 91 | reduction = False 92 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 93 | reduction_prev = reduction 94 | self.cells += [cell] 95 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr 96 | 97 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 98 | self.classifier = nn.Linear(C_prev, num_classes) 99 | 100 | def forward(self, input): 101 | input = input.unsqueeze(1) 102 | s0 = self.stem0(input) 103 | s1 = self.stem1(s0) 104 | for i, cell in enumerate(self.cells): 105 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 106 | v = self.global_pooling(s1) 107 | v = v.view(v.size(0), -1) 108 | if not self.training: 109 | return v 110 | 111 | y = self.classifier(v) 112 | 113 | return y 114 | 115 | 116 | def forward_classifier(self, v): 117 | y = self.classifier(v) 118 | return y 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /models/model_search.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn.functional as F 3 | from operations import * 4 | from utils import Genotype 5 | from utils import gumbel_softmax, drop_path 6 | 7 | 8 | class MixedOp(nn.Module): 9 | 10 | def __init__(self, C, stride, PRIMITIVES): 11 | super(MixedOp, self).__init__() 12 | self._ops = nn.ModuleList() 13 | for primitive in PRIMITIVES: 14 | op = OPS[primitive](C, stride, False) 15 | if 'pool' in primitive: 16 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False)) 17 | self._ops.append(op) 18 | 19 | def forward(self, x, weights): 20 | """ 21 | This is a forward function. 
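The weights are a softmax (or Gumbel-softmax sample) over the candidate operations on this
        edge, so the output mixes all paths; operations whose weight is exactly zero are skipped.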
22 | :param x: Feature map 23 | :param weights: A tensor of weight controlling the path flow 24 | :return: A weighted sum of several path 25 | """ 26 | output = 0 27 | for op_idx, op in enumerate(self._ops): 28 | if weights[op_idx].item() != 0: 29 | if math.isnan(weights[op_idx]): 30 | raise OverflowError(f'weight: {weights}') 31 | output += weights[op_idx] * op(x) 32 | return output 33 | 34 | 35 | class Cell(nn.Module): 36 | 37 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): 38 | super(Cell, self).__init__() 39 | self.reduction = reduction 40 | self.primitives = self.PRIMITIVES['primitives_reduct' if reduction else 'primitives_normal'] 41 | 42 | if reduction_prev: 43 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False) 44 | else: 45 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False) 46 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False) 47 | self._steps = steps 48 | self._multiplier = multiplier 49 | 50 | self._ops = nn.ModuleList() 51 | self._bns = nn.ModuleList() 52 | 53 | edge_index = 0 54 | 55 | for i in range(self._steps): 56 | for j in range(2 + i): 57 | stride = 2 if reduction and j < 2 else 1 58 | op = MixedOp(C, stride, self.primitives[edge_index]) 59 | self._ops.append(op) 60 | edge_index += 1 61 | 62 | def forward(self, s0, s1, weights, drop_prob=0.0): 63 | s0 = self.preprocess0(s0) 64 | s1 = self.preprocess1(s1) 65 | 66 | states = [s0, s1] 67 | offset = 0 68 | for i in range(self._steps): 69 | if drop_prob > 0. and self.training: 70 | s = sum( 71 | drop_path(self._ops[offset + j](h, weights[offset + j]), drop_prob) for j, h in enumerate(states)) 72 | else: 73 | s = sum(self._ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states)) 74 | offset += len(states) 75 | states.append(s) 76 | 77 | return torch.cat(states[-self._multiplier:], dim=1) 78 | 79 | 80 | class Network(nn.Module): 81 | 82 | def __init__(self, C, num_classes, layers, criterion, primitives, 83 | steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0): 84 | super(Network, self).__init__() 85 | self._C = C 86 | self._num_classes = num_classes 87 | self._layers = layers 88 | self._criterion = criterion 89 | self._steps = steps 90 | self._multiplier = multiplier 91 | self.drop_path_prob = drop_path_prob 92 | 93 | nn.Module.PRIMITIVES = primitives 94 | 95 | C_curr = stem_multiplier * C 96 | self.stem = nn.Sequential( 97 | nn.Conv2d(1, C_curr, 3, padding=1, bias=False), 98 | nn.BatchNorm2d(C_curr), 99 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 100 | ) 101 | 102 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C 103 | self.cells = nn.ModuleList() 104 | reduction_prev = False 105 | for i in range(layers): 106 | if i in [layers // 3, 2 * layers // 3]: 107 | C_curr *= 2 108 | reduction = True 109 | else: 110 | reduction = False 111 | cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 112 | reduction_prev = reduction 113 | self.cells += [cell] 114 | C_prev_prev, C_prev = C_prev, multiplier * C_curr 115 | 116 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) 117 | self.classifier = nn.Linear(C_prev, self._num_classes) 118 | 119 | self._initialize_alphas() 120 | 121 | def new(self): 122 | model_new = Network(self._C, self._embed_dim, self._layers, self._criterion, 123 | self.PRIMITIVES, drop_path_prob=self.drop_path_prob).cuda() 124 | for x, y in zip(model_new.arch_parameters(), self.arch_parameters()): 125 | x.data.copy_(y.data) 126 | return model_new 127 | 128 | def 
forward(self, input, discrete=False): 129 | input = input.unsqueeze(1) 130 | s0 = s1 = self.stem(input) 131 | for i, cell in enumerate(self.cells): 132 | if cell.reduction: 133 | if discrete: 134 | weights = self.alphas_reduce 135 | else: 136 | weights = gumbel_softmax(F.log_softmax(self.alphas_reduce, dim=-1)) 137 | else: 138 | if discrete: 139 | weights = self.alphas_normal 140 | else: 141 | weights = gumbel_softmax(F.log_softmax(self.alphas_normal, dim=-1)) 142 | s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob) 143 | v = self.global_pooling(s1) 144 | v = v.view(v.size(0), -1) 145 | if not self.training: 146 | return v 147 | 148 | y = self.classifier(v) 149 | 150 | return y 151 | 152 | def forward_classifier(self, v): 153 | y = self.classifier(v) 154 | return y 155 | 156 | def _loss(self, input, target): 157 | logits = self(input) 158 | return self._criterion(logits, target) 159 | 160 | def _initialize_alphas(self): 161 | k = sum(1 for i in range(self._steps) for n in range(2 + i)) 162 | num_ops = len(self.PRIMITIVES['primitives_normal'][0]) 163 | 164 | self.alphas_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops)) 165 | self.alphas_reduce = nn.Parameter(1e-3 * torch.randn(k, num_ops)) 166 | self._arch_parameters = [ 167 | self.alphas_normal, 168 | self.alphas_reduce, 169 | ] 170 | 171 | def arch_parameters(self): 172 | return self._arch_parameters 173 | 174 | def compute_arch_entropy(self, dim=-1): 175 | alpha = self.arch_parameters()[0] 176 | prob = F.softmax(alpha, dim=dim) 177 | log_prob = F.log_softmax(alpha, dim=dim) 178 | entropy = - (log_prob * prob).sum(-1, keepdim=False) 179 | return entropy 180 | 181 | def genotype(self): 182 | def _parse(weights, normal=True): 183 | PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct'] 184 | 185 | gene = [] 186 | n = 2 187 | start = 0 188 | for i in range(self._steps): 189 | end = start + n 190 | W = weights[start:end].copy() 191 | try: 192 | edges = sorted(range(i + 2), key=lambda x: -max( 193 | W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2] 194 | except ValueError: # This error happens when the 'none' op is not present in the ops 195 | edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2] 196 | for j in edges: 197 | k_best = None 198 | for k in range(len(W[j])): 199 | if 'none' in PRIMITIVES[j]: 200 | if k != PRIMITIVES[j].index('none'): 201 | if k_best is None or W[j][k] > W[j][k_best]: 202 | k_best = k 203 | else: 204 | if k_best is None or W[j][k] > W[j][k_best]: 205 | k_best = k 206 | gene.append((PRIMITIVES[start+j][k_best], j)) 207 | start = end 208 | n += 1 209 | return gene 210 | 211 | gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy(), True) 212 | gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy(), False) 213 | 214 | concat = range(2 + self._steps - self._multiplier, self._steps + 2) 215 | genotype = Genotype( 216 | normal=gene_normal, normal_concat=concat, 217 | reduce=gene_reduce, reduce_concat=concat 218 | ) 219 | return genotype 220 | 221 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code source: https://github.com/pytorch/vision 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | 7 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 8 | 
'resnext101_32x8d', 'resnet50_fc512'] 9 | 10 | from torch import nn 11 | import torch 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | model_urls = { 15 | 'resnet18': 16 | 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 18 | 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | 'resnet50': 20 | 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 22 | 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 23 | 'resnet152': 24 | 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 25 | 'resnext50_32x4d': 26 | 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 27 | 'resnext101_32x8d': 28 | 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 29 | } 30 | 31 | 32 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 35 | padding=dilation, groups=groups, bias=False, dilation=dilation) 36 | 37 | 38 | def conv1x1(in_planes, out_planes, stride=1): 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 47 | base_width=64, dilation=1, norm_layer=None): 48 | super(BasicBlock, self).__init__() 49 | if norm_layer is None: 50 | norm_layer = nn.BatchNorm2d 51 | if groups != 1 or base_width != 64: 52 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 53 | if dilation > 1: 54 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 55 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 56 | self.conv1 = conv3x3(inplanes, planes, stride) 57 | self.bn1 = norm_layer(planes) 58 | self.relu = nn.ReLU(inplace=True) 59 | self.conv2 = conv3x3(planes, planes) 60 | self.bn2 = norm_layer(planes) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | identity = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | 74 | if self.downsample is not None: 75 | identity = self.downsample(x) 76 | 77 | out += identity 78 | out = self.relu(out) 79 | 80 | return out 81 | 82 | 83 | class Bottleneck(nn.Module): 84 | expansion = 4 85 | 86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 87 | base_width=64, dilation=1, norm_layer=None): 88 | super(Bottleneck, self).__init__() 89 | if norm_layer is None: 90 | norm_layer = nn.BatchNorm2d 91 | width = int(planes * (base_width / 64.)) * groups 92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 93 | self.conv1 = conv1x1(inplanes, width) 94 | self.bn1 = norm_layer(width) 95 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 96 | self.bn2 = norm_layer(width) 97 | self.conv3 = conv1x1(width, planes * self.expansion) 98 | self.bn3 = norm_layer(planes * self.expansion) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.downsample = downsample 101 | self.stride = stride 102 | 103 | def forward(self, x): 104 | identity = x 105 | 106 | out = self.conv1(x) 107 | out = self.bn1(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv2(out) 111 | out = self.bn2(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv3(out) 115 | out = self.bn3(out) 116 | 117 | if 
self.downsample is not None: 118 | identity = self.downsample(x) 119 | 120 | out += identity 121 | out = self.relu(out) 122 | 123 | return out 124 | 125 | 126 | class ResNet(nn.Module): 127 | """Residual network. 128 | 129 | Reference: 130 | - He et al. Deep Residual Learning for Image Recognition. CVPR 2016. 131 | - Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017. 132 | Public keys: 133 | - ``resnet18``: ResNet18. 134 | - ``resnet34``: ResNet34. 135 | - ``resnet50``: ResNet50. 136 | - ``resnet101``: ResNet101. 137 | - ``resnet152``: ResNet152. 138 | - ``resnext50_32x4d``: ResNeXt50. 139 | - ``resnext101_32x8d``: ResNeXt101. 140 | - ``resnet50_fc512``: ResNet50 + FC. 141 | """ 142 | 143 | def __init__(self, num_classes, loss, block, layers, zero_init_residual=False, 144 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 145 | norm_layer=None, last_stride=2, fc_dims=None, dropout_p=None, **kwargs): 146 | super(ResNet, self).__init__() 147 | if norm_layer is None: 148 | norm_layer = nn.BatchNorm2d 149 | self._norm_layer = norm_layer 150 | self.loss = loss 151 | self.feature_dim = 512 * block.expansion 152 | self.inplanes = 64 153 | self.dilation = 1 154 | if replace_stride_with_dilation is None: 155 | # each element in the tuple indicates if we should replace 156 | # the 2x2 stride with a dilated convolution instead 157 | replace_stride_with_dilation = [False, False, False] 158 | if len(replace_stride_with_dilation) != 3: 159 | raise ValueError("replace_stride_with_dilation should be None " 160 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 161 | self.groups = groups 162 | self.base_width = width_per_group 163 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=2, 164 | bias=False) 165 | self.bn1 = norm_layer(self.inplanes) 166 | self.relu = nn.ReLU(inplace=True) 167 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 168 | self.layer1 = self._make_layer(block, 64, layers[0]) 169 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 170 | dilate=replace_stride_with_dilation[0]) 171 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 172 | dilate=replace_stride_with_dilation[1]) 173 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, 174 | dilate=replace_stride_with_dilation[2]) 175 | 176 | self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 177 | self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p) 178 | self.classifier = nn.Linear(self.feature_dim, num_classes) 179 | 180 | self._init_params() 181 | 182 | # Zero-initialize the last BN in each residual branch, 183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 185 | if zero_init_residual: 186 | for m in self.modules(): 187 | if isinstance(m, Bottleneck): 188 | nn.init.constant_(m.bn3.weight, 0) 189 | elif isinstance(m, BasicBlock): 190 | nn.init.constant_(m.bn2.weight, 0) 191 | 192 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 193 | norm_layer = self._norm_layer 194 | downsample = None 195 | previous_dilation = self.dilation 196 | if dilate: 197 | self.dilation *= stride 198 | stride = 1 199 | if stride != 1 or self.inplanes != planes * block.expansion: 200 | downsample = nn.Sequential( 201 | conv1x1(self.inplanes, planes * block.expansion, stride), 202 | norm_layer(planes * block.expansion), 203 | ) 204 | 205 | layers = [] 206 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 207 | self.base_width, previous_dilation, norm_layer)) 208 | self.inplanes = planes * block.expansion 209 | for _ in range(1, blocks): 210 | layers.append(block(self.inplanes, planes, groups=self.groups, 211 | base_width=self.base_width, dilation=self.dilation, 212 | norm_layer=norm_layer)) 213 | 214 | return nn.Sequential(*layers) 215 | 216 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): 217 | """Constructs fully connected layer 218 | Args: 219 | fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed 220 | input_dim (int): input dimension 221 | dropout_p (float): dropout probability, if None, dropout is unused 222 | """ 223 | if fc_dims is None: 224 | self.feature_dim = input_dim 225 | return None 226 | 227 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format( 228 | type(fc_dims)) 229 | 230 | layers = [] 231 | for dim in fc_dims: 232 | layers.append(nn.Linear(input_dim, dim)) 233 | layers.append(nn.BatchNorm1d(dim)) 234 | layers.append(nn.ReLU(inplace=True)) 235 | if dropout_p is not None: 236 | layers.append(nn.Dropout(p=dropout_p)) 237 | input_dim = dim 238 | 239 | self.feature_dim = fc_dims[-1] 240 | 241 | return nn.Sequential(*layers) 242 | 243 | def _init_params(self): 244 | for m in self.modules(): 245 | if isinstance(m, nn.Conv2d): 246 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 247 | if m.bias is not None: 248 | nn.init.constant_(m.bias, 0) 249 | elif isinstance(m, nn.BatchNorm2d): 250 | nn.init.constant_(m.weight, 1) 251 | nn.init.constant_(m.bias, 0) 252 | elif isinstance(m, nn.BatchNorm1d): 253 | nn.init.constant_(m.weight, 1) 254 | nn.init.constant_(m.bias, 0) 255 | elif isinstance(m, nn.Linear): 256 | nn.init.normal_(m.weight, 0, 0.01) 257 | if m.bias is not None: 258 | nn.init.constant_(m.bias, 0) 259 | 260 | def featuremaps(self, x): 261 | x = x.unsqueeze(1) 262 | x = self.conv1(x) 263 | x = self.bn1(x) 264 | x = self.relu(x) 265 | x = self.maxpool(x) 266 | x = self.layer1(x) 267 | x = self.layer2(x) 268 | x = self.layer3(x) 269 | x = self.layer4(x) 270 | return x 271 | 272 | def forward(self, x): 273 | f = self.featuremaps(x) 274 | v = self.global_avgpool(f) 275 | v = v.view(v.size(0), -1) 276 | 277 | if self.fc is not None: 278 | v = self.fc(v) 279 | 280 | if not self.training: 281 | return v 282 | 283 | y = self.classifier(v) 284 | 285 | if self.loss == 'xent': 286 | return y 287 | elif self.loss == 'htri': 288 | return y, v 289 | else: 290 | raise KeyError("Unsupported loss: {}".format(self.loss)) 291 | 292 | def forward_classifier(self, v): 293 | y = self.classifier(v) 294 | 
return y 295 | 296 | """ResNet""" 297 | 298 | 299 | def init_pretrained_weights(model, model_url): 300 | """Initializes model with pretrained weights. 301 | 302 | Layers that don't match with pretrained layers in name or size are kept unchanged. 303 | """ 304 | pretrain_dict = model_zoo.load_url(model_url) 305 | model_dict = model.state_dict() 306 | pretrain_dict = { 307 | k: v 308 | for k, v in pretrain_dict.items() 309 | if k in model_dict and model_dict[k].size() == v.size() 310 | } 311 | model_dict.update(pretrain_dict) 312 | model.load_state_dict(model_dict) 313 | 314 | 315 | def resnet18(num_classes, loss='xent', pretrained=True, **kwargs): 316 | model = ResNet( 317 | num_classes=num_classes, 318 | loss=loss, 319 | block=BasicBlock, 320 | layers=[2, 2, 2, 2], 321 | last_stride=2, 322 | fc_dims=None, 323 | dropout_p=None, 324 | **kwargs 325 | ) 326 | if pretrained: 327 | init_pretrained_weights(model, model_urls['resnet18']) 328 | return model 329 | 330 | 331 | def resnet34(num_classes, loss='xent', pretrained=True, **kwargs): 332 | model = ResNet( 333 | num_classes=num_classes, 334 | loss=loss, 335 | block=BasicBlock, 336 | layers=[3, 4, 6, 3], 337 | last_stride=2, 338 | fc_dims=None, 339 | dropout_p=None, 340 | **kwargs 341 | ) 342 | if pretrained: 343 | init_pretrained_weights(model, model_urls['resnet34']) 344 | return model 345 | -------------------------------------------------------------------------------- /operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | OPS = { 5 | 'none' : lambda C, stride, affine: Zero(stride), 6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 9 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 10 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 11 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 12 | 'sep_conv_3x1' : lambda C, stride, affine: SepConvTime(C, C, 3, stride, 1, affine=affine), 13 | 'sep_conv_1x3' : lambda C, stride, affine: SepConvFreq(C, C, 3, stride, 1, affine=affine), 14 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 15 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 16 | 'dil_conv_3x1' : lambda C, stride, affine: DilConvTime(C, C, 3, stride, 2, 2, affine=affine), 17 | 'dil_conv_1x3' : lambda C, stride, affine: DilConvFreq(C, C, 3, stride, 2, 2, affine=affine), 18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( 19 | nn.ReLU(inplace=False), 20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), 21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), 22 | nn.BatchNorm2d(C, affine=affine) 23 | ), 24 | } 25 | 26 | class ReLUConvBN(nn.Module): 27 | 28 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 29 | super(ReLUConvBN, self).__init__() 30 | self.op = nn.Sequential( 31 | nn.ReLU(inplace=False), 32 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 33 | nn.BatchNorm2d(C_out, affine=affine) 34 | ) 35 | 36 | def forward(self, x): 37 | return self.op(x) 38 | 39 | 
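# DilConv and the *Time / *Freq variants below are dilated depthwise-separable convolutions
# (ReLU -> depthwise dilated conv -> 1x1 pointwise conv -> BN). The *Time variants use a
# (k, 1) kernel, convolving only along the frame (time) axis of the spectrogram, while the
# *Freq variants use a (1, k) kernel and convolve only along the frequency axis.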
class DilConv(nn.Module): 40 | 41 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 42 | super(DilConv, self).__init__() 43 | self.op = nn.Sequential( 44 | nn.ReLU(inplace=False), 45 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 46 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 47 | nn.BatchNorm2d(C_out, affine=affine), 48 | ) 49 | 50 | def forward(self, x): 51 | return self.op(x) 52 | 53 | 54 | class DilConvTime(nn.Module): 55 | 56 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 57 | super(DilConvTime, self).__init__() 58 | self.op = nn.Sequential( 59 | nn.ReLU(inplace=False), 60 | nn.Conv2d(C_in, C_in, kernel_size=(kernel_size, 1), stride=stride, padding=(padding, 0), dilation=(dilation, 1), groups=C_in, 61 | bias=False), 62 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 63 | nn.BatchNorm2d(C_out, affine=affine), 64 | ) 65 | 66 | def forward(self, x): 67 | return self.op(x) 68 | 69 | 70 | class DilConvFreq(nn.Module): 71 | 72 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 73 | super(DilConvFreq, self).__init__() 74 | self.op = nn.Sequential( 75 | nn.ReLU(inplace=False), 76 | nn.Conv2d(C_in, C_in, kernel_size=(1, kernel_size), stride=stride, padding=(0, padding), dilation=(1, dilation), groups=C_in, 77 | bias=False), 78 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 79 | nn.BatchNorm2d(C_out, affine=affine), 80 | ) 81 | 82 | def forward(self, x): 83 | return self.op(x) 84 | 85 | class SepConv(nn.Module): 86 | 87 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 88 | super(SepConv, self).__init__() 89 | self.op = nn.Sequential( 90 | nn.ReLU(inplace=False), 91 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 92 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 93 | nn.BatchNorm2d(C_in, affine=affine), 94 | nn.ReLU(inplace=False), 95 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 96 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 97 | nn.BatchNorm2d(C_out, affine=affine), 98 | ) 99 | 100 | def forward(self, x): 101 | return self.op(x) 102 | 103 | 104 | class SepConvTime(nn.Module): 105 | 106 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 107 | super(SepConvTime, self).__init__() 108 | self.op = nn.Sequential( 109 | nn.ReLU(inplace=False), 110 | nn.Conv2d(C_in, C_in, kernel_size=(kernel_size, 1), stride=stride, padding=(padding, 0), groups=C_in, bias=False), 111 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 112 | nn.BatchNorm2d(C_in, affine=affine), 113 | nn.ReLU(inplace=False), 114 | nn.Conv2d(C_in, C_in, kernel_size=(kernel_size, 1), stride=1, padding=(padding, 0), groups=C_in, bias=False), 115 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 116 | nn.BatchNorm2d(C_out, affine=affine), 117 | ) 118 | 119 | def forward(self, x): 120 | return self.op(x) 121 | 122 | 123 | class SepConvFreq(nn.Module): 124 | 125 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 126 | super(SepConvFreq, self).__init__() 127 | self.op = nn.Sequential( 128 | nn.ReLU(inplace=False), 129 | nn.Conv2d(C_in, C_in, kernel_size=(1, kernel_size), stride=stride, padding=(0, padding), groups=C_in, bias=False), 130 | 
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 131 | nn.BatchNorm2d(C_in, affine=affine), 132 | nn.ReLU(inplace=False), 133 | nn.Conv2d(C_in, C_in, kernel_size=(1, kernel_size), stride=1, padding=(0, padding), groups=C_in, bias=False), 134 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 135 | nn.BatchNorm2d(C_out, affine=affine), 136 | ) 137 | 138 | def forward(self, x): 139 | return self.op(x) 140 | 141 | class Identity(nn.Module): 142 | 143 | def __init__(self): 144 | super(Identity, self).__init__() 145 | 146 | def forward(self, x): 147 | return x 148 | 149 | 150 | class Zero(nn.Module): 151 | 152 | def __init__(self, stride): 153 | super(Zero, self).__init__() 154 | self.stride = stride 155 | 156 | def forward(self, x): 157 | if self.stride == 1: 158 | return x.mul(0.) 159 | return x[:,:,::self.stride,::self.stride].mul(0.) 160 | 161 | 162 | class FactorizedReduce(nn.Module): 163 | 164 | def __init__(self, C_in, C_out, affine=True): 165 | super(FactorizedReduce, self).__init__() 166 | assert C_out % 2 == 0 167 | self.relu = nn.ReLU(inplace=False) 168 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 169 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 170 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 171 | 172 | def forward(self, x): 173 | x = self.relu(x) 174 | # out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) 175 | out = torch.cat([self.conv_1(x), self.conv_2(x)], dim=1) 176 | out = self.bn(out) 177 | return out 178 | 179 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.8 2 | certifi==2020.4.5.1 3 | cffi==1.14.0 4 | cycler==0.10.0 5 | decorator==4.4.2 6 | dill==0.3.1.1 7 | future==0.18.2 8 | joblib==0.14.1 9 | kiwisolver==1.2.0 10 | librosa==0.7.2 11 | llvmlite==0.32.0 12 | matplotlib==3.2.1 13 | multiprocess==0.70.9 14 | numba==0.49.0 15 | numpy==1.18.4 16 | Pillow==7.1.2 17 | protobuf==3.11.3 18 | pycparser==2.20 19 | pyparsing==2.4.7 20 | python-dateutil==2.8.1 21 | PyYAML==5.3.1 22 | resampy==0.2.2 23 | scikit-learn==0.22.2.post1 24 | scipy==1.4.1 25 | six==1.14.0 26 | sklearn==0.0 27 | SoundFile==0.10.3.post1 28 | tensorboardX==2.0 29 | torch==1.5.0 30 | torchvision==0.6.0 31 | tqdm==4.46.0 32 | yacs==0.1.7 33 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-08-09 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import argparse 13 | import numpy as np 14 | import shutil 15 | import os 16 | from tensorboardX import SummaryWriter 17 | from tqdm import tqdm 18 | from pathlib import Path 19 | 20 | import torch 21 | import torch.optim as optim 22 | import torch.backends.cudnn as cudnn 23 | 24 | from config import cfg, update_config 25 | from utils import set_path, create_logger, save_checkpoint 26 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset 27 | from functions import train, validate_identification 28 | from architect import Architect 29 | from loss import CrossEntropyLoss 30 | from torch.utils.data import DataLoader 31 | from spaces import primitives_1, 
primitives_2, primitives_3 32 | from models.model_search import Network 33 | 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser(description='Train energy network') 37 | # general 38 | parser.add_argument('--cfg', 39 | help='experiment configure file name', 40 | required=True, 41 | type=str) 42 | 43 | parser.add_argument('opts', 44 | help="Modify config options using the command-line", 45 | default=None, 46 | nargs=argparse.REMAINDER) 47 | 48 | parser.add_argument('--load_path', 49 | help="The path to resumed dir", 50 | default=None) 51 | 52 | args = parser.parse_args() 53 | 54 | return args 55 | 56 | 57 | def main(): 58 | args = parse_args() 59 | update_config(cfg, args) 60 | 61 | # cudnn related setting 62 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 63 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 64 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 65 | 66 | # Set the random seed manually for reproducibility. 67 | np.random.seed(cfg.SEED) 68 | torch.manual_seed(cfg.SEED) 69 | torch.cuda.manual_seed_all(cfg.SEED) 70 | 71 | # Loss 72 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda() 73 | 74 | # model and optimizer 75 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, criterion, primitives_2, 76 | drop_path_prob=cfg.TRAIN.DROPPATH_PROB) 77 | model = model.cuda() 78 | 79 | # weight params 80 | arch_params = list(map(id, model.arch_parameters())) 81 | weight_params = filter(lambda p: id(p) not in arch_params, 82 | model.parameters()) 83 | 84 | # Optimizer 85 | optimizer = optim.Adam( 86 | weight_params, 87 | lr=cfg.TRAIN.LR 88 | ) 89 | 90 | # resume && make log dir and logger 91 | if args.load_path and os.path.exists(args.load_path): 92 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth') 93 | assert os.path.exists(checkpoint_file) 94 | checkpoint = torch.load(checkpoint_file) 95 | 96 | # load checkpoint 97 | begin_epoch = checkpoint['epoch'] 98 | last_epoch = checkpoint['epoch'] 99 | model.load_state_dict(checkpoint['state_dict']) 100 | best_acc1 = checkpoint['best_acc1'] 101 | optimizer.load_state_dict(checkpoint['optimizer']) 102 | args.path_helper = checkpoint['path_helper'] 103 | 104 | logger = create_logger(args.path_helper['log_path']) 105 | logger.info("=> loaded checkpoint '{}'".format(checkpoint_file)) 106 | else: 107 | exp_name = args.cfg.split('/')[-1].split('.')[0] 108 | args.path_helper = set_path('logs_search', exp_name) 109 | logger = create_logger(args.path_helper['log_path']) 110 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH 111 | best_acc1 = 0.0 112 | last_epoch = -1 113 | 114 | logger.info(args) 115 | logger.info(cfg) 116 | 117 | # copy model file 118 | this_dir = os.path.dirname(__file__) 119 | shutil.copy2( 120 | os.path.join(this_dir, 'models', cfg.MODEL.NAME + '.py'), 121 | args.path_helper['ckpt_path']) 122 | 123 | # dataloader 124 | train_dataset = DeepSpeakerDataset( 125 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'train') 126 | val_dataset = DeepSpeakerDataset( 127 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'val') 128 | train_loader = torch.utils.data.DataLoader( 129 | dataset=train_dataset, 130 | batch_size=cfg.TRAIN.BATCH_SIZE, 131 | num_workers=cfg.DATASET.NUM_WORKERS, 132 | pin_memory=True, 133 | shuffle=True, 134 | drop_last=True, 135 | ) 136 | val_loader = torch.utils.data.DataLoader( 137 | dataset=val_dataset, 138 | batch_size=cfg.TRAIN.BATCH_SIZE, 139 | num_workers=cfg.DATASET.NUM_WORKERS, 
140 | pin_memory=True, 141 | shuffle=True, 142 | drop_last=True, 143 | ) 144 | test_dataset = DeepSpeakerDataset( 145 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True) 146 | test_loader = torch.utils.data.DataLoader( 147 | dataset=test_dataset, 148 | batch_size=1, 149 | num_workers=cfg.DATASET.NUM_WORKERS, 150 | pin_memory=True, 151 | shuffle=True, 152 | drop_last=True, 153 | ) 154 | 155 | # training setting 156 | writer_dict = { 157 | 'writer': SummaryWriter(args.path_helper['log_path']), 158 | 'train_global_steps': begin_epoch * len(train_loader), 159 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ, 160 | } 161 | 162 | # training loop 163 | architect = Architect(model, cfg) 164 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 165 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN, 166 | last_epoch=last_epoch 167 | ) 168 | 169 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='search progress'): 170 | model.train() 171 | 172 | genotype = model.genotype() 173 | logger.info('genotype = %s', genotype) 174 | 175 | if cfg.TRAIN.DROPPATH_PROB != 0: 176 | model.drop_path_prob = cfg.TRAIN.DROPPATH_PROB * epoch / (cfg.TRAIN.END_EPOCH - 1) 177 | 178 | train(cfg, model, optimizer, train_loader, val_loader, criterion, architect, epoch, writer_dict) 179 | 180 | if epoch % cfg.VAL_FREQ == 0: 181 | # get threshold and evaluate on validation set 182 | acc = validate_identification(cfg, model, test_loader, criterion) 183 | 184 | # remember best acc@1 and save checkpoint 185 | is_best = acc > best_acc1 186 | best_acc1 = max(acc, best_acc1) 187 | 188 | # save 189 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path'])) 190 | save_checkpoint({ 191 | 'epoch': epoch + 1, 192 | 'state_dict': model.state_dict(), 193 | 'best_acc1': best_acc1, 194 | 'optimizer': optimizer.state_dict(), 195 | 'arch': model.arch_parameters(), 196 | 'genotype': genotype, 197 | 'path_helper': args.path_helper 198 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch)) 199 | 200 | lr_scheduler.step(epoch) 201 | 202 | 203 | 204 | if __name__ == '__main__': 205 | main() 206 | -------------------------------------------------------------------------------- /spaces.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | primitives_1 = OrderedDict([('primitives_normal', [['skip_connect', 4 | 'dil_conv_3x3'], 5 | ['skip_connect', 6 | 'dil_conv_5x5'], 7 | ['skip_connect', 8 | 'dil_conv_5x5'], 9 | ['skip_connect', 10 | 'sep_conv_3x3'], 11 | ['skip_connect', 12 | 'dil_conv_3x3'], 13 | ['max_pool_3x3', 14 | 'skip_connect'], 15 | ['skip_connect', 16 | 'sep_conv_3x3'], 17 | ['skip_connect', 18 | 'sep_conv_3x3'], 19 | ['skip_connect', 20 | 'dil_conv_3x3'], 21 | ['skip_connect', 22 | 'sep_conv_3x3'], 23 | ['max_pool_3x3', 24 | 'skip_connect'], 25 | ['skip_connect', 26 | 'dil_conv_3x3'], 27 | ['dil_conv_3x3', 28 | 'dil_conv_5x5'], 29 | ['dil_conv_3x3', 30 | 'dil_conv_5x5']]), 31 | ('primitives_reduct', [['max_pool_3x3', 32 | 'avg_pool_3x3'], 33 | ['max_pool_3x3', 34 | 'dil_conv_3x3'], 35 | ['max_pool_3x3', 36 | 'avg_pool_3x3'], 37 | ['max_pool_3x3', 38 | 'avg_pool_3x3'], 39 | ['skip_connect', 40 | 'dil_conv_5x5'], 41 | ['max_pool_3x3', 42 | 'avg_pool_3x3'], 43 | ['max_pool_3x3', 44 | 'sep_conv_3x3'], 45 | ['skip_connect', 46 | 'dil_conv_3x3'], 47 | ['skip_connect', 48 | 'dil_conv_5x5'], 49 | ['max_pool_3x3', 50 | 'avg_pool_3x3'], 51 | 
['max_pool_3x3', 52 | 'avg_pool_3x3'], 53 | ['skip_connect', 54 | 'dil_conv_5x5'], 55 | ['skip_connect', 56 | 'dil_conv_5x5'], 57 | ['skip_connect', 58 | 'dil_conv_5x5']])]) 59 | 60 | PRIMITIVES = [ 61 | 'none', 62 | 'max_pool_3x3', 63 | 'avg_pool_3x3', 64 | 'skip_connect', 65 | 'sep_conv_3x3', 66 | 'sep_conv_5x5', 67 | 'dil_conv_3x3', 68 | 'dil_conv_5x5' 69 | ] 70 | 71 | primitives_2 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES]), 72 | ('primitives_reduct', 14 * [PRIMITIVES])]) 73 | 74 | PRIMITIVES_SMALL = [ 75 | 'none', 76 | 'max_pool_3x3', 77 | 'avg_pool_3x3', 78 | 'skip_connect', 79 | 'sep_conv_3x3', 80 | 'sep_conv_3x1', 81 | 'sep_conv_1x3', 82 | 'dil_conv_3x3', 83 | 'dil_conv_3x1', 84 | 'dil_conv_1x3', 85 | ] 86 | 87 | primitives_3 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES_SMALL]), 88 | ('primitives_reduct', 14 * [PRIMITIVES_SMALL])]) 89 | 90 | spaces_dict = { 91 | 's1': primitives_1, # space from https://openreview.net/forum?id=H1gDNyrKDS 92 | 's2': primitives_2, # original DARTS space 93 | 's3': primitives_3, # space with 1D conv 94 | } -------------------------------------------------------------------------------- /train_baseline_identification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import numpy as np 7 | import os 8 | from tensorboardX import SummaryWriter 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | 12 | import torch 13 | import torch.optim as optim 14 | import torch.backends.cudnn as cudnn 15 | 16 | from models import resnet 17 | from config import cfg, update_config 18 | from utils import set_path, create_logger, save_checkpoint, count_parameters 19 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset 20 | from functions import train_from_scratch, validate_identification 21 | from loss import CrossEntropyLoss 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='Train energy network') 26 | # general 27 | parser.add_argument('--cfg', 28 | help='experiment configure file name', 29 | required=True, 30 | type=str) 31 | 32 | parser.add_argument('opts', 33 | help="Modify config options using the command-line", 34 | default=None, 35 | nargs=argparse.REMAINDER) 36 | 37 | parser.add_argument('--load_path', 38 | help="The path to resumed dir", 39 | default=None) 40 | 41 | args = parser.parse_args() 42 | 43 | return args 44 | 45 | 46 | def main(): 47 | args = parse_args() 48 | update_config(cfg, args) 49 | 50 | # cudnn related setting 51 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 52 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 53 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 54 | 55 | # Set the random seed manually for reproducibility. 
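# Note: fixed seeds alone may not give bit-exact repeats while cfg.CUDNN.BENCHMARK is True,
# since cuDNN autotuning can select different kernels between runs (see the flags set above).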
56 | np.random.seed(cfg.SEED) 57 | torch.manual_seed(cfg.SEED) 58 | torch.cuda.manual_seed_all(cfg.SEED) 59 | 60 | # model and optimizer 61 | model = eval('resnet.{}(num_classes={})'.format( 62 | cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES)) 63 | model = model.cuda() 64 | optimizer = optim.Adam( 65 | model.net_parameters() if hasattr(model, 'net_parameters') else model.parameters(), 66 | lr=cfg.TRAIN.LR 67 | ) 68 | 69 | # Loss 70 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda() 71 | 72 | # resume && make log dir and logger 73 | if args.load_path and os.path.exists(args.load_path): 74 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth') 75 | # checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth') 76 | assert os.path.exists(checkpoint_file) 77 | checkpoint = torch.load(checkpoint_file) 78 | 79 | # load checkpoint 80 | begin_epoch = checkpoint['epoch'] 81 | last_epoch = checkpoint['epoch'] 82 | model.load_state_dict(checkpoint['state_dict']) 83 | best_acc1 = checkpoint['best_acc1'] 84 | optimizer.load_state_dict(checkpoint['optimizer']) 85 | args.path_helper = checkpoint['path_helper'] 86 | 87 | logger = create_logger(args.path_helper['log_path']) 88 | logger.info("=> loaded checkpoint '{}'".format(checkpoint_file)) 89 | else: 90 | exp_name = args.cfg.split('/')[-1].split('.')[0] 91 | args.path_helper = set_path('logs', exp_name) 92 | logger = create_logger(args.path_helper['log_path']) 93 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH 94 | best_acc1 = 0.0 95 | last_epoch = -1 96 | logger.info(args) 97 | logger.info(cfg) 98 | logger.info("Number of parameters: {}".format(count_parameters(model))) 99 | 100 | # dataloader 101 | train_dataset = DeepSpeakerDataset( 102 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'train') 103 | test_dataset_identification = DeepSpeakerDataset( 104 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True) 105 | train_loader = torch.utils.data.DataLoader( 106 | dataset=train_dataset, 107 | batch_size=cfg.TRAIN.BATCH_SIZE, 108 | num_workers=cfg.DATASET.NUM_WORKERS, 109 | pin_memory=True, 110 | shuffle=True, 111 | drop_last=True, 112 | ) 113 | test_loader_identification = torch.utils.data.DataLoader( 114 | dataset=test_dataset_identification, 115 | batch_size=1, 116 | num_workers=cfg.DATASET.NUM_WORKERS, 117 | pin_memory=True, 118 | shuffle=True, 119 | drop_last=True, 120 | ) 121 | 122 | # training setting 123 | writer_dict = { 124 | 'writer': SummaryWriter(args.path_helper['log_path']), 125 | 'train_global_steps': begin_epoch * len(train_loader), 126 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ, 127 | } 128 | 129 | # training loop 130 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 131 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN, 132 | last_epoch=last_epoch 133 | ) 134 | 135 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'): 136 | model.train() 137 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict, lr_scheduler) 138 | if epoch % cfg.VAL_FREQ == 0: 139 | acc = validate_identification(cfg, model, test_loader_identification, criterion) 140 | 141 | # remember best acc@1 and save checkpoint 142 | is_best = acc > best_acc1 143 | best_acc1 = max(acc, best_acc1) 144 | 145 | # save 146 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path'])) 147 | save_checkpoint({ 148 | 'epoch': epoch + 1, 149 | 'state_dict': model.state_dict(), 
150 | 'best_acc1': best_acc1, 151 | 'optimizer': optimizer.state_dict(), 152 | 'path_helper': args.path_helper 153 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch)) 154 | lr_scheduler.step(epoch) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /train_baseline_verification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import numpy as np 7 | import os 8 | from tensorboardX import SummaryWriter 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | 12 | import torch 13 | import torch.optim as optim 14 | import torch.backends.cudnn as cudnn 15 | 16 | from models import resnet 17 | from config import cfg, update_config 18 | from utils import set_path, create_logger, save_checkpoint, count_parameters 19 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset 20 | from data_objects.VoxcelebTestset import VoxcelebTestset 21 | from functions import train_from_scratch, validate_verification 22 | from loss import CrossEntropyLoss 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='Train energy network') 27 | # general 28 | parser.add_argument('--cfg', 29 | help='experiment configure file name', 30 | required=True, 31 | type=str) 32 | 33 | parser.add_argument('opts', 34 | help="Modify config options using the command-line", 35 | default=None, 36 | nargs=argparse.REMAINDER) 37 | 38 | parser.add_argument('--load_path', 39 | help="The path to resumed dir", 40 | default=None) 41 | 42 | args = parser.parse_args() 43 | 44 | return args 45 | 46 | 47 | def main(): 48 | args = parse_args() 49 | update_config(cfg, args) 50 | 51 | # cudnn related setting 52 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 53 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 54 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 55 | 56 | # Set the random seed manually for reproducibility. 
57 | np.random.seed(cfg.SEED) 58 | torch.manual_seed(cfg.SEED) 59 | torch.cuda.manual_seed_all(cfg.SEED) 60 | 61 | # model and optimizer 62 | model = eval('resnet.{}(num_classes={})'.format( 63 | cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES)) 64 | model = model.cuda() 65 | optimizer = optim.Adam( 66 | model.net_parameters() if hasattr(model, 'net_parameters') else model.parameters(), 67 | lr=cfg.TRAIN.LR, 68 | ) 69 | 70 | # Loss 71 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda() 72 | 73 | # resume && make log dir and logger 74 | if args.load_path and os.path.exists(args.load_path): 75 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth') 76 | assert os.path.exists(checkpoint_file) 77 | checkpoint = torch.load(checkpoint_file) 78 | 79 | # load checkpoint 80 | begin_epoch = checkpoint['epoch'] 81 | last_epoch = checkpoint['epoch'] 82 | model.load_state_dict(checkpoint['state_dict']) 83 | best_eer = checkpoint['best_eer'] 84 | optimizer.load_state_dict(checkpoint['optimizer']) 85 | args.path_helper = checkpoint['path_helper'] 86 | 87 | logger = create_logger(args.path_helper['log_path']) 88 | logger.info("=> loaded checkpoint '{}'".format(checkpoint_file)) 89 | else: 90 | exp_name = args.cfg.split('/')[-1].split('.')[0] 91 | args.path_helper = set_path('logs', exp_name) 92 | logger = create_logger(args.path_helper['log_path']) 93 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH 94 | best_eer = 1.0 95 | last_epoch = -1 96 | logger.info(args) 97 | logger.info(cfg) 98 | logger.info("Number of parameters: {}".format(count_parameters(model))) 99 | 100 | # dataloader 101 | train_dataset = DeepSpeakerDataset( 102 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES) 103 | test_dataset_verification = VoxcelebTestset( 104 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES 105 | ) 106 | train_loader = torch.utils.data.DataLoader( 107 | dataset=train_dataset, 108 | batch_size=cfg.TRAIN.BATCH_SIZE, 109 | num_workers=cfg.DATASET.NUM_WORKERS, 110 | pin_memory=True, 111 | shuffle=True, 112 | drop_last=True, 113 | ) 114 | test_loader_verification = torch.utils.data.DataLoader( 115 | dataset=test_dataset_verification, 116 | batch_size=1, 117 | num_workers=cfg.DATASET.NUM_WORKERS, 118 | pin_memory=True, 119 | shuffle=False, 120 | drop_last=False, 121 | ) 122 | 123 | # training setting 124 | writer_dict = { 125 | 'writer': SummaryWriter(args.path_helper['log_path']), 126 | 'train_global_steps': begin_epoch * len(train_loader), 127 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ, 128 | } 129 | 130 | # training loop 131 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 132 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN, 133 | last_epoch=last_epoch 134 | ) 135 | 136 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'): 137 | model.train() 138 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict, lr_scheduler) 139 | if epoch % cfg.VAL_FREQ == 0: 140 | eer = validate_verification(cfg, model, test_loader_verification) 141 | 142 | # remember best acc@1 and save checkpoint 143 | is_best = eer < best_eer 144 | best_eer = min(eer, best_eer) 145 | 146 | # save 147 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path'])) 148 | save_checkpoint({ 149 | 'epoch': epoch + 1, 150 | 'state_dict': model.state_dict(), 151 | 'best_eer': best_eer, 152 | 'optimizer': optimizer.state_dict(), 153 | 'path_helper': args.path_helper 154 | }, is_best, 
args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch)) 155 | lr_scheduler.step(epoch) 156 | 157 | 158 | if __name__ == '__main__': 159 | main() 160 | -------------------------------------------------------------------------------- /train_identification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import numpy as np 7 | import shutil 8 | import os 9 | from pathlib import Path 10 | from tensorboardX import SummaryWriter 11 | from tqdm import tqdm 12 | 13 | import torch 14 | import torch.optim as optim 15 | 16 | import torch.backends.cudnn as cudnn 17 | 18 | from models.model import Network 19 | from config import cfg, update_config 20 | from utils import set_path, create_logger, save_checkpoint, count_parameters, Genotype 21 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset 22 | from functions import train_from_scratch, validate_identification 23 | from loss import CrossEntropyLoss 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser(description='Train energy network') 28 | # general 29 | parser.add_argument('--cfg', 30 | help='experiment configure file name', 31 | required=True, 32 | type=str) 33 | 34 | parser.add_argument('opts', 35 | help="Modify config options using the command-line", 36 | default=None, 37 | nargs=argparse.REMAINDER) 38 | 39 | parser.add_argument('--load_path', 40 | help="The path to resumed dir", 41 | default=None) 42 | 43 | parser.add_argument('--text_arch', 44 | help="The text to arch", 45 | default=None) 46 | 47 | args = parser.parse_args() 48 | 49 | return args 50 | 51 | 52 | def main(): 53 | args = parse_args() 54 | update_config(cfg, args) 55 | assert args.text_arch 56 | 57 | # cudnn related setting 58 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 59 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 60 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 61 | 62 | # Set the random seed manually for reproducibility. 
63 | np.random.seed(cfg.SEED) 64 | torch.manual_seed(cfg.SEED) 65 | torch.cuda.manual_seed_all(cfg.SEED) 66 | 67 | # Loss 68 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda() 69 | 70 | # load arch 71 | genotype = eval(args.text_arch) 72 | 73 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype) 74 | model = model.cuda() 75 | 76 | optimizer = optim.Adam( 77 | model.parameters(), 78 | lr=cfg.TRAIN.LR 79 | ) 80 | 81 | # resume && make log dir and logger 82 | if args.load_path and os.path.exists(args.load_path): 83 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth') 84 | assert os.path.exists(checkpoint_file) 85 | checkpoint = torch.load(checkpoint_file) 86 | 87 | # load checkpoint 88 | begin_epoch = checkpoint['epoch'] 89 | last_epoch = checkpoint['epoch'] 90 | model.load_state_dict(checkpoint['state_dict']) 91 | best_acc1 = checkpoint['best_acc1'] 92 | optimizer.load_state_dict(checkpoint['optimizer']) 93 | args.path_helper = checkpoint['path_helper'] 94 | 95 | logger = create_logger(args.path_helper['log_path']) 96 | logger.info("=> loaded checkloggpoint '{}'".format(checkpoint_file)) 97 | else: 98 | exp_name = args.cfg.split('/')[-1].split('.')[0] 99 | args.path_helper = set_path('logs_scratch', exp_name) 100 | logger = create_logger(args.path_helper['log_path']) 101 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH 102 | best_acc1 = 0.0 103 | last_epoch = -1 104 | logger.info(args) 105 | logger.info(cfg) 106 | logger.info(f"selected architecture: {genotype}") 107 | logger.info("Number of parameters: {}".format(count_parameters(model))) 108 | 109 | # dataloader 110 | train_dataset = DeepSpeakerDataset( 111 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'train') 112 | train_loader = torch.utils.data.DataLoader( 113 | dataset=train_dataset, 114 | batch_size=cfg.TRAIN.BATCH_SIZE, 115 | num_workers=cfg.DATASET.NUM_WORKERS, 116 | pin_memory=True, 117 | shuffle=True, 118 | drop_last=True, 119 | ) 120 | test_dataset = DeepSpeakerDataset( 121 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True) 122 | test_loader = torch.utils.data.DataLoader( 123 | dataset=test_dataset, 124 | batch_size=1, 125 | num_workers=cfg.DATASET.NUM_WORKERS, 126 | pin_memory=True, 127 | shuffle=True, 128 | drop_last=True, 129 | ) 130 | 131 | # training setting 132 | writer_dict = { 133 | 'writer': SummaryWriter(args.path_helper['log_path']), 134 | 'train_global_steps': begin_epoch * len(train_loader), 135 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ, 136 | } 137 | 138 | # training loop 139 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 140 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN, 141 | last_epoch=last_epoch 142 | ) 143 | 144 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'): 145 | model.train() 146 | model.drop_path_prob = cfg.MODEL.DROP_PATH_PROB * epoch / cfg.TRAIN.END_EPOCH 147 | 148 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict) 149 | 150 | if epoch % cfg.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH - 1: 151 | acc = validate_identification(cfg, model, test_loader, criterion) 152 | 153 | # remember best acc@1 and save checkpoint 154 | is_best = acc > best_acc1 155 | best_acc1 = max(acc, best_acc1) 156 | 157 | # save 158 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path'])) 159 | save_checkpoint({ 160 | 'epoch': epoch + 1, 161 | 
'state_dict': model.state_dict(), 162 | 'best_acc1': best_acc1, 163 | 'optimizer': optimizer.state_dict(), 164 | 'path_helper': args.path_helper, 165 | 'genotype': genotype, 166 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch)) 167 | 168 | lr_scheduler.step(epoch) 169 | 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /train_verification.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-08-09 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import argparse 13 | import numpy as np 14 | import shutil 15 | import os 16 | from pathlib import Path 17 | from tensorboardX import SummaryWriter 18 | from tqdm import tqdm 19 | 20 | import torch 21 | import torch.optim as optim 22 | import torch.nn as nn 23 | import torch.backends.cudnn as cudnn 24 | 25 | from models.model import Network 26 | from config import cfg, update_config 27 | from utils import set_path, create_logger, save_checkpoint, count_parameters, Genotype 28 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset 29 | from data_objects.VoxcelebTestset import VoxcelebTestset 30 | from functions import train_from_scratch, validate_verification 31 | from loss import CrossEntropyLoss 32 | 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser(description='Train energy network') 36 | # general 37 | parser.add_argument('--cfg', 38 | help='experiment configure file name', 39 | required=True, 40 | type=str) 41 | 42 | parser.add_argument('opts', 43 | help="Modify config options using the command-line", 44 | default=None, 45 | nargs=argparse.REMAINDER) 46 | 47 | parser.add_argument('--load_path', 48 | help="The path to resumed dir", 49 | default=None) 50 | 51 | parser.add_argument('--text_arch', 52 | help="The text to arch", 53 | default=None) 54 | 55 | args = parser.parse_args() 56 | 57 | return args 58 | 59 | 60 | def main(): 61 | args = parse_args() 62 | update_config(cfg, args) 63 | assert args.text_arch 64 | 65 | # cudnn related setting 66 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 67 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 68 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 69 | 70 | # Set the random seed manually for reproducibility. 
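Note that the identification checkpoint saved above stores the derived `genotype` alongside the weights, so a later run can rebuild the exact network without passing `--text_arch` again. A hedged reload sketch (the path and the numeric arguments are placeholders, not values from the released configs):

```python
# Assumed reload sketch; '<exp_dir>' and the numeric arguments are illustrative only.
import torch
from models.model import Network

ckpt = torch.load('logs_scratch/<exp_dir>/Model/checkpoint_best.pth', map_location='cpu')

# Same positional arguments as the construction above: init channels, classes, layers, genotype.
model = Network(64, 1251, 8, ckpt['genotype'])
model.load_state_dict(ckpt['state_dict'])
model.eval()

print('epoch:', ckpt['epoch'], 'best top-1:', ckpt['best_acc1'])
```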
71 | np.random.seed(cfg.SEED) 72 | torch.manual_seed(cfg.SEED) 73 | torch.cuda.manual_seed_all(cfg.SEED) 74 | 75 | # Loss 76 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda() 77 | 78 | # load arch 79 | genotype = eval(args.text_arch) 80 | 81 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype) 82 | model = model.cuda() 83 | 84 | optimizer = optim.Adam( 85 | model.parameters(), 86 | lr=cfg.TRAIN.LR 87 | ) 88 | 89 | # resume && make log dir and logger 90 | if args.load_path and os.path.exists(args.load_path): 91 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth') 92 | assert os.path.exists(checkpoint_file) 93 | checkpoint = torch.load(checkpoint_file) 94 | 95 | # load checkpoint 96 | begin_epoch = checkpoint['epoch'] 97 | last_epoch = checkpoint['epoch'] 98 | model.load_state_dict(checkpoint['state_dict']) 99 | best_eer = checkpoint['best_eer'] 100 | optimizer.load_state_dict(checkpoint['optimizer']) 101 | args.path_helper = checkpoint['path_helper'] 102 | 103 | logger = create_logger(args.path_helper['log_path']) 104 | logger.info("=> loaded checkloggpoint '{}'".format(checkpoint_file)) 105 | else: 106 | exp_name = args.cfg.split('/')[-1].split('.')[0] 107 | args.path_helper = set_path('logs_scratch', exp_name) 108 | logger = create_logger(args.path_helper['log_path']) 109 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH 110 | best_eer = 1.0 111 | last_epoch = -1 112 | logger.info(args) 113 | logger.info(cfg) 114 | logger.info(f"selected architecture: {genotype}") 115 | logger.info("Number of parameters: {}".format(count_parameters(model))) 116 | 117 | # dataloader 118 | train_dataset = DeepSpeakerDataset( 119 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES) 120 | train_loader = torch.utils.data.DataLoader( 121 | dataset=train_dataset, 122 | batch_size=cfg.TRAIN.BATCH_SIZE, 123 | num_workers=cfg.DATASET.NUM_WORKERS, 124 | pin_memory=True, 125 | shuffle=True, 126 | drop_last=True, 127 | ) 128 | test_dataset_verification = VoxcelebTestset( 129 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES) 130 | test_loader_verification = torch.utils.data.DataLoader( 131 | dataset=test_dataset_verification, 132 | batch_size=1, 133 | num_workers=cfg.DATASET.NUM_WORKERS, 134 | pin_memory=True, 135 | shuffle=False, 136 | drop_last=False, 137 | ) 138 | 139 | # training setting 140 | writer_dict = { 141 | 'writer': SummaryWriter(args.path_helper['log_path']), 142 | 'train_global_steps': begin_epoch * len(train_loader), 143 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ, 144 | } 145 | 146 | # training loop 147 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 148 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN, 149 | last_epoch=last_epoch 150 | ) 151 | 152 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'): 153 | model.train() 154 | model.drop_path_prob = cfg.MODEL.DROP_PATH_PROB * epoch / cfg.TRAIN.END_EPOCH 155 | 156 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict) 157 | 158 | if epoch % cfg.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH - 1: 159 | eer = validate_verification(cfg, model, test_loader_verification) 160 | 161 | # remember best acc@1 and save checkpoint 162 | is_best = eer < best_eer 163 | best_eer = min(eer, best_eer) 164 | 165 | # save 166 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path'])) 167 | save_checkpoint({ 168 | 'epoch': epoch + 1, 169 | 'state_dict': 
model.state_dict(), 170 | 'best_eer': best_eer, 171 | 'optimizer': optimizer.state_dict(), 172 | 'path_helper': args.path_helper 173 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch)) 174 | 175 | lr_scheduler.step(epoch) 176 | 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import dateutil.tz 3 | import time 4 | import logging 5 | import os 6 | 7 | 8 | import numpy as np 9 | from sklearn.metrics import roc_curve 10 | from datetime import datetime 11 | import matplotlib.pyplot as plt 12 | from collections import namedtuple 13 | 14 | plt.switch_backend('agg') 15 | 16 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 17 | 18 | def count_parameters(model): 19 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 20 | 21 | 22 | def init_pretrained_weights(model, checkpoint): 23 | """Initializes model with pretrained weights. 24 | 25 | Layers that don't match with pretrained layers in name or size are kept unchanged. 26 | """ 27 | checkpoint_file = torch.load(checkpoint) 28 | pretrain_dict = checkpoint_file['state_dict'] 29 | model_dict = model.state_dict() 30 | pretrain_dict = { 31 | k: v 32 | for k, v in pretrain_dict.items() 33 | if k in model_dict and model_dict[k].size() == v.size() 34 | } 35 | model_dict.update(pretrain_dict) 36 | model.load_state_dict(model_dict) 37 | 38 | 39 | class AverageMeter(object): 40 | """Computes and stores the average and current value""" 41 | 42 | def __init__(self, name, fmt=':f'): 43 | self.name = name 44 | self.fmt = fmt 45 | self.reset() 46 | 47 | def reset(self): 48 | self.val = 0 49 | self.avg = 0 50 | self.sum = 0 51 | self.count = 0 52 | 53 | def update(self, val, n=1): 54 | self.val = val 55 | self.sum += val * n 56 | self.count += n 57 | self.avg = self.sum / self.count 58 | 59 | def __str__(self): 60 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 61 | return fmtstr.format(**self.__dict__) 62 | 63 | 64 | class ProgressMeter(object): 65 | def __init__(self, num_batches, *meters, prefix="", logger=None): 66 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 67 | self.meters = meters 68 | self.prefix = prefix 69 | self.logger = logger 70 | 71 | def print(self, batch): 72 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 73 | entries += [str(meter) for meter in self.meters] 74 | if self.logger: 75 | self.logger.info('\t'.join(entries)) 76 | else: 77 | print('\t'.join(entries)) 78 | 79 | def _get_batch_fmtstr(self, num_batches): 80 | num_digits = len(str(num_batches // 1)) 81 | fmt = '{:' + str(num_digits) + 'd}' 82 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 83 | 84 | def compute_eer(distances, labels): 85 | # Calculate evaluation metrics 86 | fprs, tprs, _ = roc_curve(labels, distances) 87 | eer = fprs[np.nanargmin(np.absolute((1 - tprs) - fprs))] 88 | return eer 89 | 90 | 91 | def accuracy(output, target, topk=(1,)): 92 | """Computes the accuracy over the k top predictions for the specified values of k""" 93 | with torch.no_grad(): 94 | maxk = max(topk) 95 | batch_size = target.size(0) 96 | 97 | _, pred = output.topk(maxk, 1, True, True) 98 | pred = pred.t() 99 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 100 | 101 | res = [] 102 | for k in topk: 103 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 104 | 
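For reference, `compute_eer` defined above reads the equal error rate off the ROC curve at the point where the false-positive rate equals the false-negative rate (`1 - TPR`); this is the same metric the verification script above tracks as `eer` (lower is better). A tiny synthetic example, not project data, assuming the `roc_curve` convention that a higher score means "more likely the same speaker":

```python
# Synthetic illustration of the EER computation in compute_eer above.
import numpy as np
from sklearn.metrics import roc_curve

labels = np.array([1, 1, 1, 1, 0, 0, 0, 0])                    # 1 = same speaker
scores = np.array([0.9, 0.8, 0.75, 0.4, 0.6, 0.3, 0.2, 0.1])   # higher = more similar

fprs, tprs, _ = roc_curve(labels, scores)
eer = fprs[np.nanargmin(np.absolute((1 - tprs) - fprs))]
print(eer)  # 0.25 here: one positive and one negative pair are ranked on the wrong side
```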
res.append(correct_k.mul_(100.0 / batch_size)) 105 | return res 106 | 107 | 108 | def create_logger(log_dir, phase='train'): 109 | time_str = time.strftime('%Y-%m-%d-%H-%M') 110 | log_file = '{}_{}.log'.format(time_str, phase) 111 | final_log_file = os.path.join(log_dir, log_file) 112 | head = '%(asctime)-15s %(message)s' 113 | logging.basicConfig(filename=str(final_log_file), 114 | format=head) 115 | logger = logging.getLogger() 116 | logger.setLevel(logging.INFO) 117 | console = logging.StreamHandler() 118 | logging.getLogger('').addHandler(console) 119 | 120 | return logger 121 | 122 | 123 | def set_path(root_dir, exp_name): 124 | path_dict = {} 125 | os.makedirs(root_dir, exist_ok=True) 126 | 127 | # set log path 128 | exp_path = os.path.join(root_dir, exp_name) 129 | now = datetime.now(dateutil.tz.tzlocal()) 130 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 131 | prefix = exp_path + '_' + timestamp 132 | os.makedirs(prefix) 133 | path_dict['prefix'] = prefix 134 | 135 | # set checkpoint path 136 | ckpt_path = os.path.join(prefix, 'Model') 137 | os.makedirs(ckpt_path) 138 | path_dict['ckpt_path'] = ckpt_path 139 | 140 | log_path = os.path.join(prefix, 'Log') 141 | os.makedirs(log_path) 142 | path_dict['log_path'] = log_path 143 | 144 | # set sample image path for fid calculation 145 | sample_path = os.path.join(prefix, 'Samples') 146 | os.makedirs(sample_path) 147 | path_dict['sample_path'] = sample_path 148 | 149 | return path_dict 150 | 151 | 152 | def to_item(x): 153 | """Converts x, possibly scalar and possibly tensor, to a Python scalar.""" 154 | if isinstance(x, (float, int)): 155 | return x 156 | 157 | if float(torch.__version__[0:3]) < 0.4: 158 | assert (x.dim() == 1) and (len(x) == 1) 159 | return x[0] 160 | 161 | return x.item() 162 | 163 | 164 | def save_checkpoint(states, is_best, output_dir, 165 | filename='checkpoint.pth'): 166 | torch.save(states, os.path.join(output_dir, filename)) 167 | if is_best: 168 | torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth')) 169 | 170 | def drop_path(x, drop_prob): 171 | if drop_prob > 0.: 172 | keep_prob = 1.-drop_prob 173 | mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) 174 | x.div_(keep_prob) 175 | x.mul_(mask) 176 | return x 177 | 178 | 179 | def gumbel_softmax(logits, tau=1, hard=True, eps=1e-10, dim=-1): 180 | # type: (Tensor, float, bool, float, int) -> Tensor 181 | """ 182 | Samples from the `Gumbel-Softmax distribution`_ and optionally discretizes. 183 | 184 | Args: 185 | logits: `[..., num_features]` unnormalized log probabilities 186 | tau: non-negative scalar temperature 187 | hard: if ``True``, the returned samples will be discretized as one-hot vectors, 188 | but will be differentiated as if it is the soft sample in autograd 189 | dim (int): A dimension along which softmax will be computed. Default: -1. 190 | 191 | Returns: 192 | Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution. 193 | If ``hard=True``, the returned samples will be one-hot, otherwise they will 194 | be probability distributions that sum to 1 across `dim`. 195 | 196 | .. note:: 197 | This function is here for legacy reasons, may be removed from nn.Functional in the future. 198 | 199 | .. 
note:: 200 | The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft` 201 | 202 | It achieves two things: 203 | - makes the output value exactly one-hot 204 | (since we add then subtract y_soft value) 205 | - makes the gradient equal to y_soft gradient 206 | (since we strip all other gradients) 207 | 208 | Examples:: 209 | >>> logits = torch.randn(20, 32) 210 | >>> # Sample soft categorical using reparametrization trick: 211 | >>> F.gumbel_softmax(logits, tau=1, hard=False) 212 | >>> # Sample hard categorical using "Straight-through" trick: 213 | >>> F.gumbel_softmax(logits, tau=1, hard=True) 214 | 215 | .. _Gumbel-Softmax distribution: 216 | https://arxiv.org/abs/1611.00712 217 | https://arxiv.org/abs/1611.01144 218 | """ 219 | def _gen_gumbels(): 220 | gumbels = -torch.empty_like(logits).exponential_().log() 221 | if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum(): 222 | # to avoid zero in exp output 223 | gumbels = _gen_gumbels() 224 | return gumbels 225 | 226 | gumbels = _gen_gumbels() # ~Gumbel(0,1) 227 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau) 228 | y_soft = gumbels.softmax(dim) 229 | 230 | if hard: 231 | # Straight through. 232 | index = y_soft.max(dim, keepdim=True)[1] 233 | y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0) 234 | ret = y_hard - y_soft.detach() + y_soft 235 | else: 236 | # Reparametrization trick. 237 | ret = y_soft 238 | 239 | if torch.isnan(ret).sum(): 240 | import ipdb 241 | ipdb.set_trace() 242 | raise OverflowError(f'gumbel softmax output: {ret}') 243 | return ret --------------------------------------------------------------------------------
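The straight-through trick documented in the docstring can be checked directly: the returned sample is one-hot in the forward pass, while gradients still reach the logits through the soft sample. A small CPU-only check using the function above (illustrative values only):

```python
# Quick check of the straight-through behaviour of gumbel_softmax above.
import torch
from utils import gumbel_softmax

torch.manual_seed(0)
logits = torch.randn(4, 5, requires_grad=True)

y = gumbel_softmax(logits, tau=1.0, hard=True)
print(y[0])           # a one-hot row (values come from y_hard)
print(y.sum(dim=-1))  # each row sums to 1

loss = (y * torch.randn(4, 5)).sum()
loss.backward()
print(logits.grad.norm() > 0)  # True: gradients flow via y_soft despite the hard output
```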