├── .gitignore
├── LICENSE
├── README.md
├── architect.py
├── config
│   ├── __init__.py
│   └── default.py
├── data_objects
│   ├── DeepSpeakerDataset.py
│   ├── VoxcelebTestset.py
│   ├── __init__.py
│   ├── audio.py
│   ├── compute_mean_std.py
│   ├── params_data.py
│   ├── partition_voxceleb.py
│   ├── preprocess.py
│   ├── speaker.py
│   ├── transforms.py
│   └── utterance.py
├── data_preprocess.py
├── dl_script.sh
├── evaluate_identification.py
├── evaluate_verification.py
├── exps
│   ├── baseline
│   │   ├── resnet18_iden.yaml
│   │   ├── resnet18_veri.yaml
│   │   ├── resnet34_iden.yaml
│   │   └── resnet34_veri.yaml
│   ├── scratch
│   │   ├── scratch_iden.yaml
│   │   └── scratch_veri.yaml
│   └── search.yaml
├── figures
│   ├── searched_arch_normal.png
│   └── searched_arch_reduce.png
├── functions.py
├── loss.py
├── models
│   ├── __init__.py
│   ├── model.py
│   ├── model_search.py
│   └── resnet.py
├── operations.py
├── requirements.txt
├── search.py
├── spaces.py
├── train_baseline_identification.py
├── train_baseline_verification.py
├── train_identification.py
├── train_verification.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/**
2 | logs_search/**
3 | models/__pycache__/**
4 | __pycache__/**
5 | config/__pycache__/**
6 | data_objects/__pycache__/**
7 | logs_scratch/**
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 VITA-Group
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoSpeech: Neural Architecture Search for Speaker Recognition
2 |
3 | Code for this paper [AutoSpeech: Neural Architecture Search for Speaker Recognition](https://arxiv.org/abs/2005.03215)
4 |
5 | Shaojin Ding*, Tianlong Chen*, Xinyu Gong, Weiwei Zha, Zhangyang Wang
6 |
7 | ## Overview
8 | Speaker recognition systems based on Convolutional Neural Networks (CNNs) are often built with off-the-shelf backbones such as VGG-Net or ResNet. However, these backbones were originally proposed for image classification and may therefore not be a natural fit for speaker recognition. Due to the prohibitive complexity of manually exploring the design space, we propose the first neural architecture search approach for speaker recognition tasks, named **AutoSpeech**. Our evaluation results on [VoxCeleb1](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) demonstrate that the CNN architectures derived by the proposed approach significantly outperform current speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones, while enjoying lower model complexity.
9 |
10 | ## Results
11 |
12 | Our proposed approach outperforms speaker recognition systems based on VGG-M, ResNet-18, and ResNet-34 backbones. The detailed comparison can be found in our paper.
13 |
14 | | Method | Top-1 (%) | EER (%) | Parameters | Pretrained model |
15 | | :------------: | :---: | :---: | :---: | :---: |
16 | | VGG-M | 80.50 | 10.20 | 67M | [iden/veri](https://github.com/a-nagrani/VGGVox) |
17 | | ResNet-18 | 79.48 | 12.30 | 12M | [iden](https://drive.google.com/file/d/16P071LB1kwiQEoKhQRD-B3XBSwQR_6eG/view?usp=sharing), [veri](https://drive.google.com/file/d/1uNA34GTPBmrlG2gTwgrhkTfsn7zBnC7d/view?usp=sharing) |
18 | | ResNet-34 | 81.34 | 11.99| 22M | [iden](https://drive.google.com/file/d/1UJ_N5hQkVESifNJlvFdCte0yMPaqbaXJ/view?usp=sharing), [veri](https://drive.google.com/file/d/1JD34RhuvDoc19ulWQNPNArSKfDXpUYud/view?usp=sharing) |
19 | | Proposed | **87.66** | **8.95** | **18M** | [iden](https://drive.google.com/file/d/1Ph4atwl603xrbiq8OkvjdINGBCIQyXy9/view?usp=sharing), [veri](https://drive.google.com/file/d/16TrxrkRK5A0J6UxjYrQHUlEBdAESC087/view?usp=sharing) |
20 |
21 | ### Visualization
22 |
23 | Left: normal cell. Right: reduction cell.
24 |
25 | ![normal cell](figures/searched_arch_normal.png)
26 | ![reduction cell](figures/searched_arch_reduce.png)
27 |
28 |
29 | ##
30 |
31 | ## Quick start
32 | ### Requirements
33 | * Python 3.7
34 |
35 | * PyTorch >= 1.0: `pip install torch torchvision`
36 |
37 | * Other dependencies: `pip install -r requirements.txt`
38 |
39 | ### Dataset
40 | [VoxCeleb1](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html): You will need the `DevA-DevD` and `Test` parts. Additionally, you will need the original metadata files `vox1_meta.csv`, `iden_split.txt`, and `veri_test.txt` from the official website. Alternatively, the dataset can be downloaded using `dl_script.sh`.
41 |
42 | The data should be organized as:
43 | * VoxCeleb1
44 | * dev/wav/...
45 | * test/wav/...
46 | * vox1_meta.csv
47 | * iden_split.txt
48 | * veri_test.txt
49 |
50 | ### Running the code
51 | * data preprocess:
52 |
53 | `python data_preprocess.py /path/to/VoxCeleb1`
54 |
55 | The output folder structure should be:
56 | * feature
57 | * dev
58 | * test
59 | * merged
60 |
61 | `dev` and `test` are used for verification, while `merged` is used for identification.
62 |
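To sanity-check the preprocessing output, the sketch below uses placeholder paths and file names: features are saved as `.npy` spectrograms of shape `(n_frames, 257)` given `n_fft = 512` in `data_objects/params_data.py`, and the global `mean.npy`/`std.npy` statistics are written to the dataset root by `data_preprocess.py`.

```python
# Sketch with placeholder paths: inspect one preprocessed feature.
import numpy as np

feature = np.load('/path/to/VoxCeleb1/feature/merged/id10001/some_utterance.npy')
print(feature.shape)                           # (n_frames, 257): frames x frequency bins

mean = np.load('/path/to/VoxCeleb1/mean.npy')  # per-bin statistics from data_preprocess.py
std = np.load('/path/to/VoxCeleb1/std.npy')
normalized = (feature - mean) / std            # same normalization as data_objects.transforms.Normalize
```
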
63 | * Training and evaluating ResNet-18, ResNet-34 baselines:
64 |
65 | `python train_baseline_identification.py --cfg exps/baseline/resnet18_iden.yaml`
66 |
67 | `python train_baseline_verification.py --cfg exps/baseline/resnet18_veri.yaml`
68 |
69 | `python train_baseline_identification.py --cfg exps/baseline/resnet34_iden.yaml`
70 |
71 | `python train_baseline_verification.py --cfg exps/baseline/resnet34_veri.yaml`
72 |
73 | You need to modify the `DATA_DIR` field in the `.yaml` file.
74 |
75 | * Architecture search:
76 |
77 | `python search.py --cfg exps/search.yaml`
78 |
79 | You need to modify the `DATA_DIR` field in the `.yaml` file.
80 |
81 | * Training from scratch for identification:
82 |
83 | `python train_identification.py --cfg exps/scratch/scratch_iden.yaml --text_arch GENOTYPE`
84 |
85 | You need to modify the `DATA_DIR` field in the `.yaml` file.
86 |
87 | `GENOTYPE` is the text representation of the searched architecture. For example, the `GENOTYPE` of the architecture reported in the paper is:
88 |
89 | `"Genotype(normal=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0), ('dil_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('sep_conv_3x3', 2), ('dil_conv_3x3', 2), ('max_pool_3x3', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_5x5', 2), ('max_pool_3x3', 1), ('dil_conv_5x5', 3), ('dil_conv_3x3', 2), ('dil_conv_5x5', 4), ('dil_conv_5x5', 2)], reduce_concat=range(2, 6))"`
90 |
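To inspect a `GENOTYPE` string programmatically, the sketch below (not part of this repo) rebuilds it with a DARTS-style `Genotype` namedtuple; the field names are inferred from the string above, and the training scripts are assumed to perform an equivalent conversion internally.

```python
# Hedged sketch: parse a GENOTYPE string into a Genotype structure.
# The namedtuple definition is an assumption based on the DARTS convention and on
# the fields appearing in the string above.
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

text_arch = ("Genotype(normal=[('dil_conv_5x5', 1), ('dil_conv_3x3', 0)], "
             "normal_concat=range(2, 6), "
             "reduce=[('max_pool_3x3', 1), ('max_pool_3x3', 0)], "
             "reduce_concat=range(2, 6))")

genotype = eval(text_arch)           # only evaluate trusted strings
print(genotype.normal)               # [(operation, input node), ...] for the normal cell
print(list(genotype.reduce_concat))  # nodes whose outputs are concatenated in the reduction cell
```
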
91 | * Training from scratch for verification:
92 |
93 | `python train_verification.py --cfg exps/scratch/scratch_veri.yaml --text_arch GENOTYPE`
94 |
95 |
96 | * Evaluation:
97 |
98 | * Identification
99 |
100 | `python evaluate_identification.py --cfg exps/scratch/scratch_iden.yaml --load_path /path/to/the/trained/model`
101 |
102 | * Verification
103 |
104 | `python evaluate_verification.py --cfg exps/scratch/scratch_veri.yaml --load_path /path/to/the/trained/model`
105 |
106 |
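Checkpoints of the searched model (`MODEL.NAME: 'model'` in the configs) also store the architecture genotype under the `genotype` key, which is how `evaluate_identification.py` and `evaluate_verification.py` rebuild the network. A minimal sketch of reading it back, with a placeholder checkpoint path:

```python
# Sketch with a placeholder path: recover the genotype stored in a checkpoint.
import torch

checkpoint = torch.load('/path/to/the/trained/model', map_location='cpu')
genotype = checkpoint['genotype']  # describes the normal and reduction cells
print(genotype)                    # its repr has the same form as the GENOTYPE string above
```
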
107 |
108 | ## Citation
109 |
110 | If you use this code for your research, please cite our paper.
111 |
112 | ```
113 | @misc{ding2020autospeech,
114 | title={AutoSpeech: Neural Architecture Search for Speaker Recognition},
115 | author={Shaojin Ding and Tianlong Chen and Xinyu Gong and Weiwei Zha and Zhangyang Wang},
116 | year={2020},
117 | eprint={2005.03215},
118 | archivePrefix={arXiv},
119 | primaryClass={eess.AS}
120 | }
121 | ```
122 |
123 | ## Acknowledgement
124 |
125 | Part of the codes are adapted from [deep-speaker](https://github.com/philipperemy/deep-speaker) and [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning).
126 |
--------------------------------------------------------------------------------
/architect.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def _concat(xs):
5 | return torch.cat([x.view(-1) for x in xs])
6 |
7 |
8 | class Architect(object):
9 |
10 | def __init__(self, model, cfg):
11 | self.model = model
12 | self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
13 | lr=cfg.TRAIN.ARCH_LR, betas=(0.5, 0.999), weight_decay=cfg.TRAIN.ARCH_WD)
14 |
15 | def step(self, input_valid, target_valid):
16 | self.optimizer.zero_grad()
17 | self._backward_step(input_valid, target_valid)
18 | self.optimizer.step()
19 |
20 | def _backward_step(self, input_valid, target_valid):
21 | loss = self.model._loss(input_valid, target_valid)
22 | loss.backward()
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .default import _C as cfg
2 | from .default import update_config
3 |
--------------------------------------------------------------------------------
/config/default.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from yacs.config import CfgNode as CN
6 |
7 |
8 | _C = CN()
9 |
10 | _C.PRINT_FREQ = 20
11 | _C.VAL_FREQ = 20
12 |
13 | # Cudnn related params
14 | _C.CUDNN = CN()
15 | _C.CUDNN.BENCHMARK = True
16 | _C.CUDNN.DETERMINISTIC = False
17 | _C.CUDNN.ENABLED = True
18 |
19 | # seed
20 | _C.SEED = 3
21 |
22 | # common params for NETWORK
23 | _C.MODEL = CN()
24 | _C.MODEL.NAME = 'foo_net'
25 | _C.MODEL.NUM_CLASSES = 500
26 | _C.MODEL.LAYERS = 8
27 | _C.MODEL.INIT_CHANNELS = 16
28 | _C.MODEL.DROP_PATH_PROB = 0.2
29 | _C.MODEL.PRETRAINED = False
30 |
31 | # DATASET related params
32 | _C.DATASET = CN()
33 | _C.DATASET.DATA_DIR = ''
34 | _C.DATASET.SUB_DIR = ''
35 | _C.DATASET.TEST_DATA_DIR = ''
36 | _C.DATASET.TEST_DATASET = ''
37 | _C.DATASET.NUM_WORKERS = 0
38 | _C.DATASET.PARTIAL_N_FRAMES = 32
39 |
40 |
41 | # train
42 | _C.TRAIN = CN()
43 |
44 | _C.TRAIN.BATCH_SIZE = 32
45 | _C.TRAIN.LR = 0.1
46 | _C.TRAIN.LR_MIN = 0.001
47 | _C.TRAIN.WD = 0.0
48 | _C.TRAIN.BETA1 = 0.9
49 | _C.TRAIN.BETA2 = 0.999
50 |
51 | _C.TRAIN.ARCH_LR = 0.1
52 | _C.TRAIN.ARCH_WD = 0.0
53 | _C.TRAIN.ARCH_BETA1 = 0.9
54 | _C.TRAIN.ARCH_BETA2 = 0.999
55 |
56 | _C.TRAIN.DROPPATH_PROB = 0.2
57 |
58 | _C.TRAIN.BEGIN_EPOCH = 0
59 | _C.TRAIN.END_EPOCH = 140
60 |
61 |
62 | def update_config(cfg, args):
63 | cfg.defrost()
64 | cfg.merge_from_file(args.cfg)
65 | cfg.merge_from_list(args.opts)
66 |
67 | cfg.freeze()
68 |
--------------------------------------------------------------------------------
/data_objects/DeepSpeakerDataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 |
4 | import numpy as np
5 | import torch.utils.data as data
6 | from data_objects.speaker import Speaker
7 | from torchvision import transforms as T
8 | from data_objects.transforms import Normalize, TimeReverse, generate_test_sequence
9 |
10 |
11 | def find_classes(speakers):
12 | classes = list(set([speaker.name for speaker in speakers]))
13 | classes.sort()
14 | class_to_idx = {classes[i]: i for i in range(len(classes))}
15 | return classes, class_to_idx
16 |
17 |
18 | class DeepSpeakerDataset(data.Dataset):
19 |
20 | def __init__(self, data_dir, sub_dir, partial_n_frames, partition=None, is_test=False):
21 | super(DeepSpeakerDataset, self).__init__()
22 | self.data_dir = data_dir
23 | self.root = data_dir.joinpath('feature', sub_dir)
24 | self.partition = partition
25 | self.partial_n_frames = partial_n_frames
26 | self.is_test = is_test
27 |
28 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
29 | if len(speaker_dirs) == 0:
30 | raise Exception("No speakers found. Make sure you are pointing to the directory "
31 | "containing all preprocessed speaker directories.")
32 | self.speakers = [Speaker(speaker_dir, self.partition) for speaker_dir in speaker_dirs]
33 |
34 | classes, class_to_idx = find_classes(self.speakers)
35 | sources = []
36 | for speaker in self.speakers:
37 | sources.extend(speaker.sources)
38 | self.features = []
39 | for source in sources:
40 | item = (source[0].joinpath(source[1]), class_to_idx[source[2]])
41 | self.features.append(item)
42 | mean = np.load(self.data_dir.joinpath('mean.npy'))
43 | std = np.load(self.data_dir.joinpath('std.npy'))
44 | self.transform = T.Compose([
45 | Normalize(mean, std),
46 | TimeReverse(),
47 | ])
48 |
49 | def load_feature(self, feature_path, speaker_id):
50 | feature = np.load(feature_path)
51 | if self.is_test:
52 | test_sequence = generate_test_sequence(feature, self.partial_n_frames)
53 | return test_sequence, speaker_id
54 | else:
55 | if feature.shape[0] <= self.partial_n_frames:
56 | start = 0
57 | while feature.shape[0] < self.partial_n_frames:
58 | feature = np.repeat(feature, 2, axis=0)
59 | else:
60 | start = np.random.randint(0, feature.shape[0] - self.partial_n_frames)
61 | end = start + self.partial_n_frames
62 | return feature[start:end], speaker_id
63 |
64 | def __getitem__(self, index):
65 | feature_path, speaker_id = self.features[index]
66 | feature, speaker_id = self.load_feature(feature_path, speaker_id)
67 |
68 | if self.transform is not None:
69 | feature = self.transform(feature)
70 | return feature, speaker_id
71 |
72 | def __len__(self):
73 | return len(self.features)
74 |
75 |
--------------------------------------------------------------------------------
/data_objects/VoxcelebTestset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | import numpy as np
4 | from torchvision import transforms as T
5 | from data_objects.transforms import Normalize, generate_test_sequence
6 |
7 |
8 | def get_test_paths(pairs_path, db_dir):
9 | def convert_folder_name(path):
10 | basename = os.path.splitext(path)[0]
11 | items = basename.split('/')
12 | speaker_dir = items[0]
13 | fname = '{}_{}.npy'.format(items[1], items[2])
14 | p = os.path.join(speaker_dir, fname)
15 | return p
16 |
17 | pairs = [line.strip().split() for line in open(pairs_path, 'r').readlines()]
18 | nrof_skipped_pairs = 0
19 | path_list = []
20 | issame_list = []
21 |
22 | for pair in pairs:
23 | if pair[0] == '1':
24 | issame = True
25 | else:
26 | issame = False
27 |
28 | path0 = db_dir.joinpath(convert_folder_name(pair[1]))
29 | path1 = db_dir.joinpath(convert_folder_name(pair[2]))
30 |
31 | if os.path.exists(path0) and os.path.exists(path1): # Only add the pair if both paths exist
32 | path_list.append((path0,path1,issame))
33 | issame_list.append(issame)
34 | else:
35 | nrof_skipped_pairs += 1
36 | if nrof_skipped_pairs>0:
37 | print('Skipped %d test pairs' % nrof_skipped_pairs)
38 |
39 | return path_list
40 |
41 |
42 | class VoxcelebTestset(data.Dataset):
43 | def __init__(self, data_dir, partial_n_frames):
44 | super(VoxcelebTestset, self).__init__()
45 | self.data_dir = data_dir
46 | self.root = data_dir.joinpath('feature', 'test')
47 | self.test_pair_txt_fpath = data_dir.joinpath('veri_test.txt')
48 | self.test_pairs = get_test_paths(self.test_pair_txt_fpath, self.root)
49 | self.partial_n_frames = partial_n_frames
50 | mean = np.load(self.data_dir.joinpath('mean.npy'))
51 | std = np.load(self.data_dir.joinpath('std.npy'))
52 | self.transform = T.Compose([
53 | Normalize(mean, std)
54 | ])
55 |
56 | def load_feature(self, feature_path):
57 | feature = np.load(feature_path)
58 | test_sequence = generate_test_sequence(feature, self.partial_n_frames)
59 | return test_sequence
60 |
61 | def __getitem__(self, index):
62 | (path_1, path_2, issame) = self.test_pairs[index]
63 |
64 | feature1 = self.load_feature(path_1)
65 | feature2 = self.load_feature(path_2)
66 |
67 | if self.transform is not None:
68 | feature1 = self.transform(feature1)
69 | feature2 = self.transform(feature2)
70 | return feature1, feature2, issame
71 |
72 | def __len__(self):
73 | return len(self.test_pairs)
74 |
--------------------------------------------------------------------------------
/data_objects/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NoviceMAn-prog/AutoSpeech/190049c87737a51452a7a4540ea2f1df200e0238/data_objects/__init__.py
--------------------------------------------------------------------------------
/data_objects/audio.py:
--------------------------------------------------------------------------------
1 | from data_objects.params_data import *
2 | from pathlib import Path
3 | from typing import Optional, Union
4 | import numpy as np
5 | import librosa
6 |
7 | int16_max = (2 ** 15) - 1
8 |
9 |
10 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
11 | source_sr: Optional[int] = None):
12 | # Load the wav from disk if needed
13 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
14 | wav, source_sr = librosa.load(fpath_or_wav, sr=None)
15 | else:
16 | wav = fpath_or_wav
17 |
18 | # Resample the wav if needed
19 | if source_sr is not None and source_sr != sampling_rate:
20 | wav = librosa.resample(wav, source_sr, sampling_rate)
21 |
22 | # Apply the preprocessing: normalize volume
23 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
24 |
25 | return wav
26 |
27 |
28 | def wav_to_spectrogram(wav):
29 | frames = np.abs(librosa.core.stft(
30 | wav,
31 | n_fft=n_fft,
32 | hop_length=int(sampling_rate * window_step / 1000),
33 | win_length=int(sampling_rate * window_length / 1000),
34 | ))
35 | return frames.astype(np.float32).T
36 |
37 |
38 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
39 | if increase_only and decrease_only:
40 | raise ValueError("Both increase only and decrease only are set")
41 | rms = np.sqrt(np.mean((wav * int16_max) ** 2))
42 | wave_dBFS = 20 * np.log10(rms / int16_max)
43 | dBFS_change = target_dBFS - wave_dBFS
44 | if dBFS_change < 0 and increase_only or dBFS_change > 0 and decrease_only:
45 | return wav
46 | return wav * (10 ** (dBFS_change / 20))
47 |
--------------------------------------------------------------------------------
/data_objects/compute_mean_std.py:
--------------------------------------------------------------------------------
1 | from data_objects.speaker import Speaker
2 | import numpy as np
3 |
4 | def compute_mean_std(dataset_dir, output_path_mean, output_path_std):
5 | print("Computing mean std...")
6 | speaker_dirs = [f for f in dataset_dir.glob("*") if f.is_dir()]
7 | if len(speaker_dirs) == 0:
8 | raise Exception("No speakers found. Make sure you are pointing to the directory "
9 | "containing all preprocessed speaker directories.")
10 | speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
11 |
12 | sources = []
13 | for speaker in speakers:
14 | sources.extend(speaker.sources)
15 |
16 | sumx = np.zeros(257, dtype=np.float32)
17 | sumx2 = np.zeros(257, dtype=np.float32)
18 | count = 0
19 | n = len(sources)
20 | for i, source in enumerate(sources):
21 | feature = np.load(source[0].joinpath(source[1]))
22 | sumx += feature.sum(axis=0)
23 | sumx2 += (feature * feature).sum(axis=0)
24 | count += feature.shape[0]
25 |
26 | mean = sumx / count
27 | std = np.sqrt(sumx2 / count - mean * mean)
28 |
29 | mean = mean.astype(np.float32)
30 | std = std.astype(np.float32)
31 |
32 | np.save(output_path_mean, mean)
33 | np.save(output_path_std, std)
--------------------------------------------------------------------------------
/data_objects/params_data.py:
--------------------------------------------------------------------------------
1 |
2 | ## Mel-filterbank
3 | window_length = 25 # In milliseconds
4 | window_step = 10 # In milliseconds
5 | n_fft = 512
6 |
7 |
8 | ## Audio
9 | sampling_rate = 16000
10 | # Number of spectrogram frames in a partial utterance
11 | partials_n_frames = 300 # 3000 ms
12 |
13 |
14 | ## Audio volume normalization
15 | audio_norm_target_dBFS = -30
16 |
17 |
--------------------------------------------------------------------------------
/data_objects/partition_voxceleb.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def partition_voxceleb(feature_root, split_txt_path):
5 | print("partitioning VoxCeleb...")
6 | with open(split_txt_path, 'r') as f:
7 | split_txt = f.readlines()
8 | train_set = []
9 | val_set = []
10 | test_set = []
11 | for line in split_txt:
12 | items = line.strip().split()
13 | if items[0] == '3':
14 | test_set.append(items[1])
15 | elif items[0] == '2':
16 | val_set.append(items[1])
17 | else:
18 | train_set.append(items[1])
19 |
20 | speakers = os.listdir(feature_root)
21 |
22 | for speaker in speakers:
23 | speaker_dir = os.path.join(feature_root, speaker)
24 | if not os.path.isdir(speaker_dir):
25 | continue
26 | with open(os.path.join(speaker_dir, '_sources.txt'), 'r') as f:
27 | speaker_files = f.readlines()
28 |
29 | train = []
30 | val = []
31 | test = []
32 | for line in speaker_files:
33 | address = line.strip().split(',')[1]
34 | fname = os.path.join(*address.split('/')[-3:])
35 | if fname in test_set:
36 | test.append(line)
37 | elif fname in val_set:
38 | val.append(line)
39 | elif fname in train_set:
40 | train.append(line)
41 | else:
42 | print('file not in the train, val, or test set')
43 |
44 | with open(os.path.join(speaker_dir, '_sources_train.txt'), 'w') as f:
45 | f.writelines('%s' % line for line in train)
46 | with open(os.path.join(speaker_dir, '_sources_val.txt'), 'w') as f:
47 | f.writelines('%s' % line for line in val)
48 | with open(os.path.join(speaker_dir, '_sources_test.txt'), 'w') as f:
49 | f.writelines('%s' % line for line in test)
50 |
--------------------------------------------------------------------------------
/data_objects/preprocess.py:
--------------------------------------------------------------------------------
1 | from multiprocess.pool import ThreadPool
2 | from data_objects.params_data import *
3 | from datetime import datetime
4 | from data_objects import audio
5 | from pathlib import Path
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
10 |
11 | class DatasetLog:
12 | """
13 | Registers metadata about the dataset in a text file.
14 | """
15 |
16 | def __init__(self, root, name):
17 | self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
18 | self.sample_data = dict()
19 |
20 | start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
21 | self.write_line("Creating dataset %s on %s" % (name, start_time))
22 | self.write_line("-----")
23 | self._log_params()
24 |
25 | def _log_params(self):
26 | from data_objects import params_data
27 | self.write_line("Parameter values:")
28 | for param_name in (p for p in dir(params_data) if not p.startswith("__")):
29 | value = getattr(params_data, param_name)
30 | self.write_line("\t%s: %s" % (param_name, value))
31 | self.write_line("-----")
32 |
33 | def write_line(self, line):
34 | self.text_file.write("%s\n" % line)
35 |
36 | def add_sample(self, **kwargs):
37 | for param_name, value in kwargs.items():
38 | if not param_name in self.sample_data:
39 | self.sample_data[param_name] = []
40 | self.sample_data[param_name].append(value)
41 |
42 | def finalize(self):
43 | self.write_line("Statistics:")
44 | for param_name, values in self.sample_data.items():
45 | self.write_line("\t%s:" % param_name)
46 | self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
47 | self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
48 | self.write_line("-----")
49 | end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
50 | self.write_line("Finished on %s" % end_time)
51 | self.text_file.close()
52 |
53 |
54 | def _init_preprocess_dataset(dataset_name, dataset_root, out_dir) -> (Path, DatasetLog):
55 | if not dataset_root.exists():
56 | print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57 | return None, None
58 | return dataset_root, DatasetLog(out_dir, dataset_name)
59 |
60 |
61 | def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62 | skip_existing, logger):
63 | print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64 |
65 | # Function to preprocess utterances for one speaker
66 | def preprocess_speaker(speaker_dir: Path):
67 | # Give a name to the speaker that includes its dataset
68 | speaker_name = speaker_dir.parts[-1]
69 |
70 | # Create an output directory with that name, as well as a txt file containing a
71 | # reference to each source file.
72 | speaker_out_dir = out_dir.joinpath(speaker_name)
73 | speaker_out_dir.mkdir(exist_ok=True)
74 | sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75 |
76 | # There's a possibility that the preprocessing was interrupted earlier, check if
77 | # there already is a sources file.
78 | if sources_fpath.exists():
79 | try:
80 | with sources_fpath.open("r") as sources_file:
81 | existing_fnames = {line.split(",")[0] for line in sources_file}
82 | except:
83 | existing_fnames = {}
84 | else:
85 | existing_fnames = {}
86 |
87 | # Gather all audio files for that speaker recursively
88 | sources_file = sources_fpath.open("a" if skip_existing else "w")
89 | for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90 | # Check if the target output file already exists
91 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92 | out_fname = out_fname.replace(".%s" % extension, ".npy")
93 | if skip_existing and out_fname in existing_fnames:
94 | continue
95 |
96 | # Load and preprocess the waveform
97 | wav = audio.preprocess_wav(in_fpath)
98 | if len(wav) == 0:
99 | print(in_fpath)
100 | continue
101 |
102 | # Create the mel spectrogram, discard those that are too short
103 | # frames = audio.wav_to_mel_spectrogram(wav)
104 | frames = audio.wav_to_spectrogram(wav)
105 | if len(frames) < partials_n_frames:
106 | continue
107 |
108 | out_fpath = speaker_out_dir.joinpath(out_fname)
109 | np.save(out_fpath, frames)
110 | logger.add_sample(duration=len(wav) / sampling_rate)
111 | sources_file.write("%s,%s\n" % (out_fname, in_fpath))
112 |
113 | sources_file.close()
114 |
115 | # Process the utterances for each speaker
116 | preprocess_speaker(speaker_dirs[0])
117 | with ThreadPool(1) as pool:
118 | list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
119 | unit="speakers"))
120 | logger.finalize()
121 | print("Done preprocessing %s.\n" % dataset_name)
122 |
123 |
124 | def preprocess_voxceleb1(dataset_root: Path, partition: str, out_dir: Path, skip_existing=False):
125 | # Initialize the preprocessing
126 | dataset_name = "VoxCeleb1"
127 | dataset_root, logger = _init_preprocess_dataset(dataset_name, dataset_root, out_dir)
128 | if not dataset_root:
129 | return
130 |
131 | # Get the contents of the meta file
132 | with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
133 | metadata = [line.split("\t") for line in metafile][1:]
134 |
135 | # Select the ID and the nationality, filter out non-anglophone speakers
136 | nationalities = {line[0]: line[3] for line in metadata}
137 | keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
138 | nationality.lower() in anglophone_nationalites]
139 | print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
140 | (len(keep_speaker_ids), len(nationalities)))
141 |
142 | # Get the speaker directories (note: no nationality filtering is applied here)
143 | speaker_dirs = dataset_root.joinpath(partition, 'wav').glob("*")
144 | speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs]
145 |
146 | print("VoxCeleb1: found %d anglophone speakers on the disk." %
147 | (len(speaker_dirs)))
148 | # Preprocess all speakers
149 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, dataset_root, out_dir, "wav",
150 | skip_existing, logger)
151 |
152 |
--------------------------------------------------------------------------------
/data_objects/speaker.py:
--------------------------------------------------------------------------------
1 | from data_objects.utterance import Utterance
2 | from pathlib import Path
3 |
4 |
5 | # Contains the set of utterances of a single speaker
6 | class Speaker:
7 | def __init__(self, root: Path, partition=None):
8 | self.root = root
9 | self.partition = partition
10 | self.name = root.name
11 | self.utterances = None
12 | self.utterance_cycler = None
13 | if self.partition is None:
14 | with self.root.joinpath("_sources.txt").open("r") as sources_file:
15 | sources = [l.strip().split(",") for l in sources_file]
16 | else:
17 | with self.root.joinpath("_sources_{}.txt".format(self.partition)).open("r") as sources_file:
18 | sources = [l.strip().split(",") for l in sources_file]
19 | self.sources = [[self.root, frames_fname, self.name, wav_path] for frames_fname, wav_path in sources]
20 |
21 | def _load_utterances(self):
22 | self.utterances = [Utterance(source[0].joinpath(source[1])) for source in self.sources]
23 |
24 | def random_partial(self, count, n_frames):
25 | """
26 | Samples a batch of unique partial utterances from the disk in a way that all
27 | utterances come up at least once every two cycles and in a random order every time.
28 |
29 | :param count: The number of partial utterances to sample from the set of utterances from
30 | that speaker. Utterances are guaranteed not to be repeated if is not larger than
31 | the number of utterances available.
32 | :param n_frames: The number of frames in the partial utterance.
33 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
34 | frames are the frames of the partial utterances and range is the range of the partial
35 | utterance with regard to the complete utterance.
36 | """
37 | if self.utterances is None:
38 | self._load_utterances()
39 |
40 | utterances = self.utterance_cycler.sample(count)
41 |
42 | a = [(u,) + u.random_partial(n_frames) for u in utterances]
43 |
44 | return a
--------------------------------------------------------------------------------
/data_objects/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 | class Normalize(object):
5 | def __init__(self, mean, std):
6 | super(Normalize, self).__init__()
7 | self.mean = mean
8 | self.std = std
9 |
10 | def __call__(self, input):
11 | return (input - self.mean) / self.std
12 |
13 |
14 | class TimeReverse(object):
15 | def __init__(self, p=0.5):
16 | super(TimeReverse, self).__init__()
17 | self.p = p
18 |
19 | def __call__(self, input):
20 | if random.random() < self.p:
21 | return np.flip(input, axis=0).copy()
22 | return input
23 |
24 |
25 | def generate_test_sequence(feature, partial_n_frames, shift=None):
26 | while feature.shape[0] <= partial_n_frames:
27 | feature = np.repeat(feature, 2, axis=0)
28 | if shift is None:
29 | shift = partial_n_frames // 2
30 | test_sequence = []
31 | start = 0
32 | while start + partial_n_frames <= feature.shape[0]:
33 | test_sequence.append(feature[start: start + partial_n_frames])
34 | start += shift
35 | test_sequence = np.stack(test_sequence, axis=0)
36 | return test_sequence
--------------------------------------------------------------------------------
/data_objects/utterance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Utterance:
5 | def __init__(self, frames_fpath):
6 | self.frames_fpath = frames_fpath
7 |
8 | def get_frames(self):
9 | return np.load(self.frames_fpath)
10 |
11 | def random_partial(self, n_frames):
12 | """
13 | Crops the frames into a partial utterance of n_frames
14 |
15 | :param n_frames: The number of frames of the partial utterance
16 | :return: the partial utterance frames and a tuple indicating the start and end of the
17 | partial utterance in the complete utterance.
18 | """
19 | frames = self.get_frames()
20 | if frames.shape[0] == n_frames:
21 | start = 0
22 | else:
23 | start = np.random.randint(0, frames.shape[0] - n_frames)
24 | end = start + n_frames
25 | return frames[start:end], (start, end)
--------------------------------------------------------------------------------
/data_preprocess.py:
--------------------------------------------------------------------------------
1 | from data_objects.preprocess import preprocess_voxceleb1
2 | from data_objects.compute_mean_std import compute_mean_std
3 | from data_objects.partition_voxceleb import partition_voxceleb
4 | from pathlib import Path
5 | import argparse
6 | import subprocess
7 |
8 | if __name__ == "__main__":
9 | class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
10 | pass
11 |
12 | parser = argparse.ArgumentParser(description="Preprocesses audio files from datasets.",
13 | formatter_class=MyFormatter
14 | )
15 | parser.add_argument("dataset_root", type=Path, help= \
16 | "Path to the directory containing VoxCeleb datasets. It should be arranged as:")
17 | parser.add_argument("-s", "--skip_existing", action="store_true", help= \
18 | "Whether to skip existing output files with the same name. Useful if this script was "
19 | "interrupted.")
20 | args = parser.parse_args()
21 |
22 | # Process the arguments
23 | dev_out_dir = args.dataset_root.joinpath("feature", "dev")
24 | test_out_dir = args.dataset_root.joinpath("feature", "test")
25 | merged_out_dir = args.dataset_root.joinpath("feature", "merged")
26 | assert args.dataset_root.exists()
27 | assert args.dataset_root.joinpath('iden_split.txt').exists()
28 | assert args.dataset_root.joinpath('veri_test.txt').exists()
29 | assert args.dataset_root.joinpath('vox1_meta.csv').exists()
30 | dev_out_dir.mkdir(exist_ok=True, parents=True)
31 | test_out_dir.mkdir(exist_ok=True, parents=True)
32 | merged_out_dir.mkdir(exist_ok=True, parents=True)
33 |
34 | # Preprocess the datasets
35 | preprocess_voxceleb1(args.dataset_root, 'dev', dev_out_dir, args.skip_existing)
36 | preprocess_voxceleb1(args.dataset_root, 'test', test_out_dir, args.skip_existing)
37 | for path in dev_out_dir.iterdir():
38 | subprocess.call(['cp', '-r', path.as_posix(), merged_out_dir.as_posix()])
39 | for path in test_out_dir.iterdir():
40 | subprocess.call(['cp', '-r', path.as_posix(), merged_out_dir.as_posix()])
41 | compute_mean_std(merged_out_dir, args.dataset_root.joinpath('mean.npy'),
42 | args.dataset_root.joinpath('std.npy'))
43 | partition_voxceleb(merged_out_dir, args.dataset_root.joinpath('iden_split.txt'))
44 | print("Done")
45 |
--------------------------------------------------------------------------------
/dl_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Contributed by Aaron Soellinger
4 | # Usage (*nix):
5 | # $ mkdir VoxCeleb1; cd VoxCeleb1; /bin/bash path/to/dl_script.sh "yourusername" "yourpassword"
6 | # Note: I found my username and password in an email titled "VoxCeleb dataset"
7 |
8 | U=$1
9 | P=$2
10 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa --user "$U" --password "$P" &
11 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab --user "$U" --password "$P" &
12 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac --user "$U" --password "$P" &
13 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad --user "$U" --password "$P" &
14 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip --user "$U" --password "$P" &
15 | wait; cat vox1_dev_wav_parta* > vox1_dev_wav.zip
16 | unzip vox1_dev_wav.zip -d "dev" &
17 | unzip vox1_test_wav.zip -d "test"
18 | rm vox1_dev_wav_part*
19 | rm wget*
20 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/vox1_meta.csv
21 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test.txt
22 | wget http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt
23 |
--------------------------------------------------------------------------------
/evaluate_identification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import numpy as np
7 | import os
8 | from pathlib import Path
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.backends.cudnn as cudnn
13 |
14 | from models.model import Network
15 | from models import resnet
16 | from config import cfg, update_config
17 | from utils import create_logger, Genotype
18 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset
19 | from functions import validate_identification
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description='Evaluate autospeech network (identification)')
24 | # general
25 | parser.add_argument('--cfg',
26 | help='experiment configure file name',
27 | required=True,
28 | type=str)
29 |
30 | parser.add_argument('opts',
31 | help="Modify config options using the command-line",
32 | default=None,
33 | nargs=argparse.REMAINDER)
34 |
35 | parser.add_argument('--load_path',
36 | help="The path to resumed dir",
37 | required=True,
38 | default=None)
39 |
40 | args = parser.parse_args()
41 |
42 | return args
43 |
44 |
45 | def main():
46 | args = parse_args()
47 | update_config(cfg, args)
48 | if args.load_path is None:
49 | raise AttributeError("Please specify load path.")
50 |
51 | # cudnn related setting
52 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
53 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
54 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
55 |
56 | # Set the random seed manually for reproducibility.
57 | np.random.seed(cfg.SEED)
58 | torch.manual_seed(cfg.SEED)
59 | torch.cuda.manual_seed_all(cfg.SEED)
60 |
61 | # model and optimizer
62 | if cfg.MODEL.NAME == 'model':
63 | if args.load_path and os.path.exists(args.load_path):
64 | checkpoint = torch.load(args.load_path)
65 | genotype = checkpoint['genotype']
66 | else:
67 | raise AssertionError('Please specify the model to evaluate')
68 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype)
69 | model.drop_path_prob = 0.0
70 | else:
71 | model = eval('resnet.{}(num_classes={})'.format(cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES))
72 | model = model.cuda()
73 |
74 | criterion = nn.CrossEntropyLoss().cuda()
75 |
76 | # resume && make log dir and logger
77 | if args.load_path and os.path.exists(args.load_path):
78 | checkpoint = torch.load(args.load_path)
79 |
80 | # load checkpoint
81 | model.load_state_dict(checkpoint['state_dict'])
82 | args.path_helper = checkpoint['path_helper']
83 |
84 | logger = create_logger(os.path.dirname(args.load_path))
85 | logger.info("=> loaded checkpoint '{}'".format(args.load_path))
86 | else:
87 | raise AssertionError('Please specify the model to evaluate')
88 | logger.info(args)
89 | logger.info(cfg)
90 |
91 | # dataloader
92 | test_dataset_identification = DeepSpeakerDataset(
93 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True)
94 |
95 |
96 | test_loader_identification = torch.utils.data.DataLoader(
97 | dataset=test_dataset_identification,
98 | batch_size=1,
99 | num_workers=cfg.DATASET.NUM_WORKERS,
100 | pin_memory=True,
101 | shuffle=False,
102 | drop_last=True,
103 | )
104 |
105 | validate_identification(cfg, model, test_loader_identification, criterion)
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
--------------------------------------------------------------------------------
/evaluate_verification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import numpy as np
7 | import os
8 | from pathlib import Path
9 |
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 |
13 | from models.model import Network
14 | from models import resnet
15 | from config import cfg, update_config
16 | from utils import create_logger, Genotype
17 | from data_objects.VoxcelebTestset import VoxcelebTestset
18 | from functions import validate_verification
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(description='Evaluate autospeech network (verification)')
23 | # general
24 | parser.add_argument('--cfg',
25 | help='experiment configure file name',
26 | required=True,
27 | type=str)
28 |
29 | parser.add_argument('opts',
30 | help="Modify config options using the command-line",
31 | default=None,
32 | nargs=argparse.REMAINDER)
33 |
34 | parser.add_argument('--load_path',
35 | help="The path to resumed dir",
36 | default=None)
37 | parser.add_argument('--text_arch',
38 | help="The path to arch",
39 | default=None)
40 |
41 | args = parser.parse_args()
42 |
43 | return args
44 |
45 |
46 | def main():
47 | args = parse_args()
48 | update_config(cfg, args)
49 | if args.load_path is None:
50 | raise AttributeError("Please specify load path.")
51 |
52 | # cudnn related setting
53 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
54 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
55 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
56 |
57 | # Set the random seed manually for reproducibility.
58 | np.random.seed(cfg.SEED)
59 | torch.manual_seed(cfg.SEED)
60 | torch.cuda.manual_seed_all(cfg.SEED)
61 |
62 | # model and optimizer
63 | if cfg.MODEL.NAME == 'model':
64 | if args.load_path and os.path.exists(args.load_path):
65 | checkpoint = torch.load(args.load_path)
66 | genotype = checkpoint['genotype']
67 | else:
68 | raise AssertionError('Please specify the model to evaluate')
69 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype)
70 | model.drop_path_prob = 0.0
71 | else:
72 | model = eval('resnet.{}(num_classes={})'.format(cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES))
73 | model = model.cuda()
74 |
75 | # resume && make log dir and logger
76 | if args.load_path and os.path.exists(args.load_path):
77 | checkpoint = torch.load(args.load_path)
78 |
79 | # load checkpoint
80 | model.load_state_dict(checkpoint['state_dict'])
81 | args.path_helper = checkpoint['path_helper']
82 |
83 | logger = create_logger(os.path.dirname(args.load_path))
84 | logger.info("=> loaded checkpoint '{}'".format(args.load_path))
85 | else:
86 | raise AssertionError('Please specify the model to evaluate')
87 | logger.info(args)
88 | logger.info(cfg)
89 |
90 | # dataloader
91 | test_dataset_verification = VoxcelebTestset(
92 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES
93 | )
94 | test_loader_verification = torch.utils.data.DataLoader(
95 | dataset=test_dataset_verification,
96 | batch_size=1,
97 | num_workers=cfg.DATASET.NUM_WORKERS,
98 | pin_memory=True,
99 | shuffle=False,
100 | drop_last=False,
101 | )
102 |
103 | validate_verification(cfg, model, test_loader_verification)
104 |
105 |
106 |
107 | if __name__ == '__main__':
108 | main()
109 |
--------------------------------------------------------------------------------
/exps/baseline/resnet18_iden.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 10
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'merged'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 256
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | BETA1: 0.9
20 | BETA2: 0.999
21 |
22 | BEGIN_EPOCH: 0
23 | END_EPOCH: 301
24 |
25 | MODEL:
26 | NAME: 'resnet18'
27 | NUM_CLASSES: 1251
28 | INIT_CHANNELS: 64
--------------------------------------------------------------------------------
/exps/baseline/resnet18_veri.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 10
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'dev'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 256
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | BETA1: 0.9
20 | BETA2: 0.999
21 |
22 | BEGIN_EPOCH: 0
23 | END_EPOCH: 301
24 |
25 | MODEL:
26 | NAME: 'resnet18'
27 | NUM_CLASSES: 1211
28 | INIT_CHANNELS: 64
--------------------------------------------------------------------------------
/exps/baseline/resnet34_iden.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 10
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'merged'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 128
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | BETA1: 0.9
20 | BETA2: 0.999
21 |
22 | BEGIN_EPOCH: 0
23 | END_EPOCH: 301
24 |
25 | MODEL:
26 | NAME: 'resnet34'
27 | NUM_CLASSES: 1251
28 | INIT_CHANNELS: 64
--------------------------------------------------------------------------------
/exps/baseline/resnet34_veri.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 10
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'dev'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 128
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | BETA1: 0.9
20 | BETA2: 0.999
21 |
22 | BEGIN_EPOCH: 0
23 | END_EPOCH: 301
24 |
25 | MODEL:
26 | NAME: 'resnet34'
27 | NUM_CLASSES: 1211
28 | INIT_CHANNELS: 128
--------------------------------------------------------------------------------
/exps/scratch/scratch_iden.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 10
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'merged'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 96
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | BETA1: 0.9
20 | BETA2: 0.999
21 |
22 | BEGIN_EPOCH: 0
23 | END_EPOCH: 301
24 |
25 | MODEL:
26 | NAME: 'model'
27 | NUM_CLASSES: 1251
28 | LAYERS: 8
29 | INIT_CHANNELS: 128
--------------------------------------------------------------------------------
/exps/scratch/scratch_veri.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 10
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'dev'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 48
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | BETA1: 0.9
20 | BETA2: 0.999
21 |
22 | BEGIN_EPOCH: 0
23 | END_EPOCH: 301
24 |
25 | MODEL:
26 | NAME: 'model'
27 | NUM_CLASSES: 1211
28 | LAYERS: 8
29 | INIT_CHANNELS: 128
--------------------------------------------------------------------------------
/exps/search.yaml:
--------------------------------------------------------------------------------
1 | PRINT_FREQ: 200
2 | VAL_FREQ: 5
3 |
4 | CUDNN:
5 | BENCHMARK: true
6 | DETERMINISTIC: false
7 | ENABLED: true
8 |
9 | DATASET:
10 | DATA_DIR: '/path/to/VoxCeleb1'
11 | SUB_DIR: 'merged'
12 | NUM_WORKERS: 0
13 | PARTIAL_N_FRAMES: 300
14 |
15 | TRAIN:
16 | BATCH_SIZE: 2
17 | LR: 0.01
18 | LR_MIN: 0.001
19 | WD: 0.0003
20 | BETA1: 0.9
21 | BETA2: 0.999
22 |
23 | ARCH_LR: 0.001
24 | ARCH_WD: 0.001
25 | ARCH_BETA1: 0.9
26 | ARCH_BETA2: 0.999
27 | DROPPATH_PROB: 0.2
28 |
29 | BEGIN_EPOCH: 0
30 | END_EPOCH: 50
31 |
32 | MODEL:
33 | NAME: 'model_search'
34 | NUM_CLASSES: 1251
35 | LAYERS: 8
36 | INIT_CHANNELS: 16
37 |
--------------------------------------------------------------------------------
/figures/searched_arch_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NoviceMAn-prog/AutoSpeech/190049c87737a51452a7a4540ea2f1df200e0238/figures/searched_arch_normal.png
--------------------------------------------------------------------------------
/figures/searched_arch_reduce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NoviceMAn-prog/AutoSpeech/190049c87737a51452a7a4540ea2f1df200e0238/figures/searched_arch_reduce.png
--------------------------------------------------------------------------------
/functions.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn.functional as F
4 | import logging
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 |
8 | from utils import compute_eer
9 | from utils import AverageMeter, ProgressMeter, accuracy
10 |
11 | plt.switch_backend('agg')
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def train(cfg, model, optimizer, train_loader, val_loader, criterion, architect, epoch, writer_dict, lr_scheduler=None):
16 | batch_time = AverageMeter('Time', ':6.3f')
17 | data_time = AverageMeter('Data', ':6.3f')
18 | losses = AverageMeter('Loss', ':.4e')
19 | top1 = AverageMeter('Acc@1', ':6.2f')
20 | top5 = AverageMeter('Acc@5', ':6.2f')
21 | alpha_entropies = AverageMeter('Entropy', ':.4e')
22 | progress = ProgressMeter(
23 | len(train_loader), batch_time, data_time, losses, top1, top5, alpha_entropies,
24 | prefix="Epoch: [{}]".format(epoch), logger=logger)
25 | writer = writer_dict['writer']
26 |
27 | # switch to train mode
28 | model.train()
29 |
30 | end = time.time()
31 | for i, (input, target) in enumerate(train_loader):
32 | global_steps = writer_dict['train_global_steps']
33 |
34 | if lr_scheduler:
35 | current_lr = lr_scheduler.set_lr(optimizer, global_steps, epoch)
36 | else:
37 | current_lr = cfg.TRAIN.LR
38 |
39 | # measure data loading time
40 | data_time.update(time.time() - end)
41 |
42 | input = input.cuda(non_blocking=True)
43 | target = target.cuda(non_blocking=True)
44 |
45 | input_search, target_search = next(iter(val_loader))
46 | input_search = input_search.cuda(non_blocking=True)
47 | target_search = target_search.cuda(non_blocking=True)
48 |
49 | # step architecture
50 | architect.step(input_search, target_search)
51 |
52 | alpha_entropy = architect.model.compute_arch_entropy()
53 | alpha_entropies.update(alpha_entropy.mean(), input.size(0))
54 |
55 | # compute output
56 | output = model(input)
57 |
58 | # measure accuracy and record loss
59 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
60 | top1.update(acc1[0], input.size(0))
61 | top5.update(acc5[0], input.size(0))
62 | loss = criterion(output, target)
63 | losses.update(loss.item(), input.size(0))
64 |
65 | # compute gradient and do SGD step
66 | optimizer.zero_grad()
67 | loss.backward()
68 | optimizer.step()
69 |
70 | # measure elapsed time
71 | batch_time.update(time.time() - end)
72 | end = time.time()
73 |
74 | # write to logger
75 | writer.add_scalar('lr', current_lr, global_steps)
76 | writer.add_scalar('train_loss', losses.val, global_steps)
77 | writer.add_scalar('arch_entropy', alpha_entropies.val, global_steps)
78 |
79 | writer_dict['train_global_steps'] = global_steps + 1
80 |
81 | # log acc for cross entropy loss
82 | writer.add_scalar('train_acc1', top1.val, global_steps)
83 | writer.add_scalar('train_acc5', top5.val, global_steps)
84 |
85 | if i % cfg.PRINT_FREQ == 0:
86 | progress.print(i)
87 |
88 |
89 | def train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict, lr_scheduler=None):
90 | batch_time = AverageMeter('Time', ':6.3f')
91 | data_time = AverageMeter('Data', ':6.3f')
92 | losses = AverageMeter('Loss', ':.4e')
93 | top1 = AverageMeter('Acc@1', ':6.2f')
94 | top5 = AverageMeter('Acc@5', ':6.2f')
95 | progress = ProgressMeter(
96 | len(train_loader), batch_time, data_time, losses, top1, top5, prefix="Epoch: [{}]".format(epoch), logger=logger)
97 | writer = writer_dict['writer']
98 |
99 | # switch to train mode
100 | model.train()
101 |
102 | end = time.time()
103 | for i, (input, target) in enumerate(train_loader):
104 | global_steps = writer_dict['train_global_steps']
105 |
106 | if lr_scheduler:
107 | current_lr = lr_scheduler.get_lr()
108 | else:
109 | current_lr = cfg.TRAIN.LR
110 |
111 | # measure data loading time
112 | data_time.update(time.time() - end)
113 |
114 | input = input.cuda(non_blocking=True)
115 | target = target.cuda(non_blocking=True)
116 |
117 | # compute output
118 | output = model(input)
119 |
120 | # measure accuracy and record loss
121 | loss = criterion(output, target)
122 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
123 | top1.update(acc1[0], input.size(0))
124 | top5.update(acc5[0], input.size(0))
125 | losses.update(loss.item(), input.size(0))
126 |
127 | # compute gradient and do SGD step
128 | optimizer.zero_grad()
129 | loss.backward()
130 | optimizer.step()
131 |
132 | # measure elapsed time
133 | batch_time.update(time.time() - end)
134 | end = time.time()
135 |
136 | # write to logger
137 | writer.add_scalar('lr', current_lr, global_steps)
138 | writer.add_scalar('train_loss', losses.val, global_steps)
139 | writer_dict['train_global_steps'] = global_steps + 1
140 |
141 | # log acc for cross entropy loss
142 | writer.add_scalar('train_acc1', top1.val, global_steps)
143 | writer.add_scalar('train_acc5', top5.val, global_steps)
144 |
145 | if i % cfg.PRINT_FREQ == 0:
146 | progress.print(i)
147 |
148 |
149 | def validate_verification(cfg, model, test_loader):
150 | batch_time = AverageMeter('Time', ':6.3f')
151 | progress = ProgressMeter(
152 | len(test_loader), batch_time, prefix='Test: ', logger=logger)
153 |
154 | # switch to evaluate mode
155 | model.eval()
156 | labels, distances = [], []
157 |
158 | with torch.no_grad():
159 | end = time.time()
160 | for i, (input1, input2, label) in enumerate(test_loader):
161 | input1 = input1.cuda(non_blocking=True).squeeze(0)
162 | input2 = input2.cuda(non_blocking=True).squeeze(0)
163 | label = label.cuda(non_blocking=True)
164 |
165 | # compute output
166 | outputs1 = model(input1).mean(dim=0).unsqueeze(0)
167 | outputs2 = model(input2).mean(dim=0).unsqueeze(0)
168 |
169 | dists = F.cosine_similarity(outputs1, outputs2)
170 | dists = dists.data.cpu().numpy()
171 | distances.append(dists)
172 | labels.append(label.data.cpu().numpy())
173 |
174 | # measure elapsed time
175 | batch_time.update(time.time() - end)
176 | end = time.time()
177 |
178 | if i % 2000 == 0:
179 | progress.print(i)
180 |
181 | labels = np.array([sublabel for label in labels for sublabel in label])
182 | distances = np.array([subdist for dist in distances for subdist in dist])
183 |
184 | eer = compute_eer(distances, labels)
185 | logger.info('Test EER: {:.8f}'.format(np.mean(eer)))
186 |
187 | return eer
188 |
189 |
190 | def validate_identification(cfg, model, test_loader, criterion):
191 | batch_time = AverageMeter('Time', ':6.3f')
192 | losses = AverageMeter('Loss', ':.4e')
193 | top1 = AverageMeter('Acc@1', ':6.2f')
194 | top5 = AverageMeter('Acc@5', ':6.2f')
195 | progress = ProgressMeter(
196 | len(test_loader), batch_time, losses, top1, top5, prefix='Test: ', logger=logger)
197 |
198 | # switch to evaluate mode
199 | model.eval()
200 |
201 | with torch.no_grad():
202 | end = time.time()
203 | for i, (input, target) in enumerate(test_loader):
204 | input = input.cuda(non_blocking=True).squeeze(0)
205 | target = target.cuda(non_blocking=True)
206 |
207 | # compute output
208 | output = model(input)
209 | output = torch.mean(output, dim=0, keepdim=True)
210 | output = model.forward_classifier(output)
211 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
212 | top1.update(acc1[0], input.size(0))
213 | top5.update(acc5[0], input.size(0))
214 | loss = criterion(output, target)
215 |
216 | losses.update(loss.item(), 1)
217 |
218 | # measure elapsed time
219 | batch_time.update(time.time() - end)
220 | end = time.time()
221 |
222 | if i % 2000 == 0:
223 | progress.print(i)
224 |
225 | logger.info('Test Acc@1: {:.8f} Acc@5: {:.8f}'.format(top1.avg, top5.avg))
226 |
227 | return top1.avg
228 |
229 |
--------------------------------------------------------------------------------
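`validate_verification` above scores each trial pair with a cosine similarity between averaged embeddings and then reduces the score/label lists to a single equal error rate through a `compute_eer` helper defined elsewhere in the repo. As a rough reference only, here is a minimal EER sketch built on scikit-learn's ROC curve; the helper name `eer_from_scores` and the toy scores are illustrative, and the repo's own `compute_eer` may differ in details:

```python
import numpy as np
from sklearn.metrics import roc_curve

def eer_from_scores(scores, labels):
    """Equal error rate: the operating point where false-accept and false-reject rates meet."""
    fpr, tpr, _ = roc_curve(labels, scores, pos_label=1)
    fnr = 1.0 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))   # threshold index where FAR is closest to FRR
    return (fpr[idx] + fnr[idx]) / 2.0

# Toy check: perfectly separated scores give an EER of 0.
print(eer_from_scores(np.array([0.9, 0.8, 0.2, 0.1]), np.array([1, 1, 0, 0])))
```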
/loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class CrossEntropyLoss(nn.Module):
10 | r"""Cross entropy loss with label smoothing regularizer.
11 |
12 | Reference:
13 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
14 | With label smoothing, the label :math:`y` for a class is computed by
15 |
16 | .. math::
17 | \begin{equation}
18 | (1 - \epsilon) \times y + \frac{\epsilon}{K},
19 | \end{equation}
20 | where :math:`K` denotes the number of classes and :math:`\epsilon` is a weight. When
21 | :math:`\epsilon = 0`, the loss function reduces to the normal cross entropy.
22 |
23 | Args:
24 | num_classes (int): number of classes.
25 | epsilon (float, optional): weight. Default is 0.1.
26 | use_gpu (bool, optional): whether to use gpu devices. Default is True.
27 | label_smooth (bool, optional): whether to apply label smoothing. Default is True.
28 | """
29 |
30 | def __init__(
31 | self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True
32 | ):
33 | super(CrossEntropyLoss, self).__init__()
34 | self.num_classes = num_classes
35 | self.epsilon = epsilon if label_smooth else 0
36 | self.use_gpu = use_gpu
37 | self.logsoftmax = nn.LogSoftmax(dim=1)
38 |
39 | def forward(self, inputs, targets):
40 | """
41 | Args:
42 | inputs (torch.Tensor): prediction matrix (before softmax) with
43 | shape (batch_size, num_classes).
44 | targets (torch.LongTensor): ground truth labels with shape (batch_size).
45 | Each position contains the label index.
46 | """
47 | log_probs = self.logsoftmax(inputs)
48 | zeros = torch.zeros(log_probs.size())
49 | targets = zeros.scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
50 | if self.use_gpu:
51 | targets = targets.cuda()
52 | targets = (
53 | 1 - self.epsilon
54 | ) * targets + self.epsilon / self.num_classes
55 | return (-targets * log_probs).mean(0).sum()
56 |
--------------------------------------------------------------------------------
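To make the smoothing concrete, here is a small illustration (not part of the repo) of the target that the loss above builds internally for `num_classes=4`, `epsilon=0.1` and a ground-truth label of 2, followed by a CPU-only call of the loss itself:

```python
import torch
from loss import CrossEntropyLoss

num_classes, epsilon = 4, 0.1
target = torch.tensor([2])

# The smoothed target that forward() constructs via scatter_:
one_hot = torch.zeros(1, num_classes).scatter_(1, target.unsqueeze(1), 1)
smoothed = (1 - epsilon) * one_hot + epsilon / num_classes
print(smoothed)   # tensor([[0.0250, 0.0250, 0.9250, 0.0250]])

# Calling the loss on CPU (use_gpu=False keeps the targets on the CPU):
criterion = CrossEntropyLoss(num_classes=num_classes, use_gpu=False)
logits = torch.randn(1, num_classes)
print(criterion(logits, target))
```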
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-08-08
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import models.resnet
12 | #import models.multi_scale_model
13 | #import models.multi_scale_model_v2
14 | #import models.multi_scale_model_v3
15 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | from operations import *
2 | from utils import drop_path
3 |
4 |
5 | class Cell(nn.Module):
6 |
7 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
8 | super(Cell, self).__init__()
9 | print(C_prev_prev, C_prev, C)
10 |
11 | if reduction_prev:
12 | self.preprocess0 = FactorizedReduce(C_prev_prev, C)
13 | else:
14 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
15 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
16 |
17 | if reduction:
18 | op_names, indices = zip(*genotype.reduce)
19 | concat = genotype.reduce_concat
20 | else:
21 | op_names, indices = zip(*genotype.normal)
22 | concat = genotype.normal_concat
23 | self._compile(C, op_names, indices, concat, reduction)
24 |
25 | def _compile(self, C, op_names, indices, concat, reduction):
26 | assert len(op_names) == len(indices)
27 | self._steps = len(op_names) // 2
28 | self._concat = concat
29 | self.multiplier = len(concat)
30 |
31 | self._ops = nn.ModuleList()
32 | for name, index in zip(op_names, indices):
33 | stride = 2 if reduction and index < 2 else 1
34 | op = OPS[name](C, stride, True)
35 | self._ops += [op]
36 | self._indices = indices
37 |
38 | def forward(self, s0, s1, drop_prob):
39 | s0 = self.preprocess0(s0)
40 | s1 = self.preprocess1(s1)
41 |
42 | states = [s0, s1]
43 | for i in range(self._steps):
44 | h1 = states[self._indices[2 * i]]
45 | h2 = states[self._indices[2 * i + 1]]
46 | op1 = self._ops[2 * i]
47 | op2 = self._ops[2 * i + 1]
48 | h1 = op1(h1)
49 | h2 = op2(h2)
50 | if self.training and drop_prob > 0.:
51 | if not isinstance(op1, Identity):
52 | h1 = drop_path(h1, drop_prob)
53 | if not isinstance(op2, Identity):
54 | h2 = drop_path(h2, drop_prob)
55 | s = h1 + h2
56 | states += [s]
57 | return torch.cat([states[i] for i in self._concat], dim=1)
58 |
59 |
60 | class Network(nn.Module):
61 |
62 | def __init__(self, C, num_classes, layers, genotype):
63 | super(Network, self).__init__()
64 | self._C = C
65 | self._num_classes = num_classes
66 | self._layers = layers
67 |
68 | self.stem0 = nn.Sequential(
69 | nn.Conv2d(1, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
70 | nn.BatchNorm2d(C // 2),
71 | nn.ReLU(inplace=True),
72 | nn.Conv2d(C // 2, C, kernel_size=3, stride=2, padding=1, bias=False),
73 | nn.BatchNorm2d(C),
74 | )
75 |
76 | self.stem1 = nn.Sequential(
77 | nn.ReLU(inplace=True),
78 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
79 | nn.BatchNorm2d(C),
80 | )
81 |
82 | C_prev_prev, C_prev, C_curr = C, C, C
83 |
84 | self.cells = nn.ModuleList()
85 | reduction_prev = True
86 | for i in range(layers):
87 | if i in [layers // 3, 2 * layers // 3]:
88 | C_curr *= 2
89 | reduction = True
90 | else:
91 | reduction = False
92 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
93 | reduction_prev = reduction
94 | self.cells += [cell]
95 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
96 |
97 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
98 | self.classifier = nn.Linear(C_prev, num_classes)
99 |
100 | def forward(self, input):
101 | input = input.unsqueeze(1)
102 | s0 = self.stem0(input)
103 | s1 = self.stem1(s0)
104 | for i, cell in enumerate(self.cells):
105 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
106 | v = self.global_pooling(s1)
107 | v = v.view(v.size(0), -1)
108 | if not self.training:
109 | return v
110 |
111 | y = self.classifier(v)
112 |
113 | return y
114 |
115 |
116 | def forward_classifier(self, v):
117 | y = self.classifier(v)
118 | return y
119 |
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
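A minimal sketch (not from the repo) of evaluating a derived architecture: the `Genotype` fields follow `models/model_search.py`, the op names come from `operations.py`, and the channel, layer, class and input-shape numbers are illustrative (1,251 classes matches VoxCeleb1 identification). Note that `forward()` reads `drop_path_prob`, which the training scripts assign every epoch, so it must be set before the first call:

```python
import torch
from utils import Genotype           # namedtuple: normal / normal_concat / reduce / reduce_concat
from models.model import Network

# An illustrative genotype; the real one is produced by the search phase.
genotype = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('dil_conv_3x3', 2)],
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('skip_connect', 2), ('max_pool_3x3', 1)],
    reduce_concat=range(2, 6),
)

model = Network(C=16, num_classes=1251, layers=8, genotype=genotype)
model.drop_path_prob = 0.0            # required: forward() expects this attribute
model.eval()
with torch.no_grad():
    feats = torch.randn(2, 300, 64)              # (batch, frames, freq bins), example shape
    embedding = model(feats)                     # eval mode returns the embedding
    logits = model.forward_classifier(embedding)
print(embedding.shape, logits.shape)
```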
/models/model_search.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn.functional as F
3 | from operations import *
4 | from utils import Genotype
5 | from utils import gumbel_softmax, drop_path
6 |
7 |
8 | class MixedOp(nn.Module):
9 |
10 | def __init__(self, C, stride, PRIMITIVES):
11 | super(MixedOp, self).__init__()
12 | self._ops = nn.ModuleList()
13 | for primitive in PRIMITIVES:
14 | op = OPS[primitive](C, stride, False)
15 | if 'pool' in primitive:
16 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
17 | self._ops.append(op)
18 |
19 | def forward(self, x, weights):
20 | """
21 | Weighted sum of the candidate operations on this edge.
22 | :param x: Feature map
23 | :param weights: A tensor of weights controlling the path flow
24 | :return: A weighted sum of the candidate paths
25 | """
26 | output = 0
27 | for op_idx, op in enumerate(self._ops):
28 | if weights[op_idx].item() != 0:
29 | if math.isnan(weights[op_idx]):
30 | raise OverflowError(f'weight: {weights}')
31 | output += weights[op_idx] * op(x)
32 | return output
33 |
34 |
35 | class Cell(nn.Module):
36 |
37 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
38 | super(Cell, self).__init__()
39 | self.reduction = reduction
40 | self.primitives = self.PRIMITIVES['primitives_reduct' if reduction else 'primitives_normal']
41 |
42 | if reduction_prev:
43 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
44 | else:
45 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
46 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
47 | self._steps = steps
48 | self._multiplier = multiplier
49 |
50 | self._ops = nn.ModuleList()
51 | self._bns = nn.ModuleList()
52 |
53 | edge_index = 0
54 |
55 | for i in range(self._steps):
56 | for j in range(2 + i):
57 | stride = 2 if reduction and j < 2 else 1
58 | op = MixedOp(C, stride, self.primitives[edge_index])
59 | self._ops.append(op)
60 | edge_index += 1
61 |
62 | def forward(self, s0, s1, weights, drop_prob=0.0):
63 | s0 = self.preprocess0(s0)
64 | s1 = self.preprocess1(s1)
65 |
66 | states = [s0, s1]
67 | offset = 0
68 | for i in range(self._steps):
69 | if drop_prob > 0. and self.training:
70 | s = sum(
71 | drop_path(self._ops[offset + j](h, weights[offset + j]), drop_prob) for j, h in enumerate(states))
72 | else:
73 | s = sum(self._ops[offset + j](h, weights[offset + j]) for j, h in enumerate(states))
74 | offset += len(states)
75 | states.append(s)
76 |
77 | return torch.cat(states[-self._multiplier:], dim=1)
78 |
79 |
80 | class Network(nn.Module):
81 |
82 | def __init__(self, C, num_classes, layers, criterion, primitives,
83 | steps=4, multiplier=4, stem_multiplier=3, drop_path_prob=0.0):
84 | super(Network, self).__init__()
85 | self._C = C
86 | self._num_classes = num_classes
87 | self._layers = layers
88 | self._criterion = criterion
89 | self._steps = steps
90 | self._multiplier = multiplier
91 | self.drop_path_prob = drop_path_prob
92 |
93 | nn.Module.PRIMITIVES = primitives
94 |
95 | C_curr = stem_multiplier * C
96 | self.stem = nn.Sequential(
97 | nn.Conv2d(1, C_curr, 3, padding=1, bias=False),
98 | nn.BatchNorm2d(C_curr),
99 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
100 | )
101 |
102 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
103 | self.cells = nn.ModuleList()
104 | reduction_prev = False
105 | for i in range(layers):
106 | if i in [layers // 3, 2 * layers // 3]:
107 | C_curr *= 2
108 | reduction = True
109 | else:
110 | reduction = False
111 | cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
112 | reduction_prev = reduction
113 | self.cells += [cell]
114 | C_prev_prev, C_prev = C_prev, multiplier * C_curr
115 |
116 | self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
117 | self.classifier = nn.Linear(C_prev, self._num_classes)
118 |
119 | self._initialize_alphas()
120 |
121 | def new(self):
122 | model_new = Network(self._C, self._num_classes, self._layers, self._criterion,
123 | self.PRIMITIVES, drop_path_prob=self.drop_path_prob).cuda()
124 | for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
125 | x.data.copy_(y.data)
126 | return model_new
127 |
128 | def forward(self, input, discrete=False):
129 | input = input.unsqueeze(1)
130 | s0 = s1 = self.stem(input)
131 | for i, cell in enumerate(self.cells):
132 | if cell.reduction:
133 | if discrete:
134 | weights = self.alphas_reduce
135 | else:
136 | weights = gumbel_softmax(F.log_softmax(self.alphas_reduce, dim=-1))
137 | else:
138 | if discrete:
139 | weights = self.alphas_normal
140 | else:
141 | weights = gumbel_softmax(F.log_softmax(self.alphas_normal, dim=-1))
142 | s0, s1 = s1, cell(s0, s1, weights, self.drop_path_prob)
143 | v = self.global_pooling(s1)
144 | v = v.view(v.size(0), -1)
145 | if not self.training:
146 | return v
147 |
148 | y = self.classifier(v)
149 |
150 | return y
151 |
152 | def forward_classifier(self, v):
153 | y = self.classifier(v)
154 | return y
155 |
156 | def _loss(self, input, target):
157 | logits = self(input)
158 | return self._criterion(logits, target)
159 |
160 | def _initialize_alphas(self):
161 | k = sum(1 for i in range(self._steps) for n in range(2 + i))
162 | num_ops = len(self.PRIMITIVES['primitives_normal'][0])
163 |
164 | self.alphas_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops))
165 | self.alphas_reduce = nn.Parameter(1e-3 * torch.randn(k, num_ops))
166 | self._arch_parameters = [
167 | self.alphas_normal,
168 | self.alphas_reduce,
169 | ]
170 |
171 | def arch_parameters(self):
172 | return self._arch_parameters
173 |
174 | def compute_arch_entropy(self, dim=-1):
175 | alpha = self.arch_parameters()[0]
176 | prob = F.softmax(alpha, dim=dim)
177 | log_prob = F.log_softmax(alpha, dim=dim)
178 | entropy = - (log_prob * prob).sum(-1, keepdim=False)
179 | return entropy
180 |
181 | def genotype(self):
182 | def _parse(weights, normal=True):
183 | PRIMITIVES = self.PRIMITIVES['primitives_normal' if normal else 'primitives_reduct']
184 |
185 | gene = []
186 | n = 2
187 | start = 0
188 | for i in range(self._steps):
189 | end = start + n
190 | W = weights[start:end].copy()
191 | try:
192 | edges = sorted(range(i + 2), key=lambda x: -max(
193 | W[x][k] for k in range(len(W[x])) if k != PRIMITIVES[x].index('none')))[:2]
194 | except ValueError: # This error happens when the 'none' op is not present in the ops
195 | edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x]))))[:2]
196 | for j in edges:
197 | k_best = None
198 | for k in range(len(W[j])):
199 | if 'none' in PRIMITIVES[j]:
200 | if k != PRIMITIVES[j].index('none'):
201 | if k_best is None or W[j][k] > W[j][k_best]:
202 | k_best = k
203 | else:
204 | if k_best is None or W[j][k] > W[j][k_best]:
205 | k_best = k
206 | gene.append((PRIMITIVES[start+j][k_best], j))
207 | start = end
208 | n += 1
209 | return gene
210 |
211 | gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy(), True)
212 | gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy(), False)
213 |
214 | concat = range(2 + self._steps - self._multiplier, self._steps + 2)
215 | genotype = Genotype(
216 | normal=gene_normal, normal_concat=concat,
217 | reduce=gene_reduce, reduce_concat=concat
218 | )
219 | return genotype
220 |
221 |
--------------------------------------------------------------------------------
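A small CPU-only sketch (run from the repo root; the channel, layer and class counts are illustrative) of building the supernet above and reading off a discrete architecture. `genotype()` keeps, for every node, the two strongest incoming edges and the strongest non-`none` op on each of them:

```python
from loss import CrossEntropyLoss
from spaces import primitives_2               # the full DARTS-style op set
from models.model_search import Network

criterion = CrossEntropyLoss(num_classes=1251, use_gpu=False)
supernet = Network(C=16, num_classes=1251, layers=8,
                   criterion=criterion, primitives=primitives_2)

alphas_normal, alphas_reduce = supernet.arch_parameters()
print(alphas_normal.shape)        # one row per edge, one column per candidate op
print(supernet.genotype())        # discrete cell decoded from the (still random) alphas
```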
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Code source: https://github.com/pytorch/vision
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 |
7 | __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d',
8 | 'resnext101_32x8d', 'resnet50_fc512']
9 |
10 | from torch import nn
11 | import torch
12 | import torch.utils.model_zoo as model_zoo
13 |
14 | model_urls = {
15 | 'resnet18':
16 | 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34':
18 | 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
19 | 'resnet50':
20 | 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101':
22 | 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
23 | 'resnet152':
24 | 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
25 | 'resnext50_32x4d':
26 | 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
27 | 'resnext101_32x8d':
28 | 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
29 | }
30 |
31 |
32 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
33 | """3x3 convolution with padding"""
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35 | padding=dilation, groups=groups, bias=False, dilation=dilation)
36 |
37 |
38 | def conv1x1(in_planes, out_planes, stride=1):
39 | """1x1 convolution"""
40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
41 |
42 |
43 | class BasicBlock(nn.Module):
44 | expansion = 1
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
47 | base_width=64, dilation=1, norm_layer=None):
48 | super(BasicBlock, self).__init__()
49 | if norm_layer is None:
50 | norm_layer = nn.BatchNorm2d
51 | if groups != 1 or base_width != 64:
52 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
53 | if dilation > 1:
54 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
55 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
56 | self.conv1 = conv3x3(inplanes, planes, stride)
57 | self.bn1 = norm_layer(planes)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.conv2 = conv3x3(planes, planes)
60 | self.bn2 = norm_layer(planes)
61 | self.downsample = downsample
62 | self.stride = stride
63 |
64 | def forward(self, x):
65 | identity = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.relu(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 |
74 | if self.downsample is not None:
75 | identity = self.downsample(x)
76 |
77 | out += identity
78 | out = self.relu(out)
79 |
80 | return out
81 |
82 |
83 | class Bottleneck(nn.Module):
84 | expansion = 4
85 |
86 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
87 | base_width=64, dilation=1, norm_layer=None):
88 | super(Bottleneck, self).__init__()
89 | if norm_layer is None:
90 | norm_layer = nn.BatchNorm2d
91 | width = int(planes * (base_width / 64.)) * groups
92 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
93 | self.conv1 = conv1x1(inplanes, width)
94 | self.bn1 = norm_layer(width)
95 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
96 | self.bn2 = norm_layer(width)
97 | self.conv3 = conv1x1(width, planes * self.expansion)
98 | self.bn3 = norm_layer(planes * self.expansion)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.downsample = downsample
101 | self.stride = stride
102 |
103 | def forward(self, x):
104 | identity = x
105 |
106 | out = self.conv1(x)
107 | out = self.bn1(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv2(out)
111 | out = self.bn2(out)
112 | out = self.relu(out)
113 |
114 | out = self.conv3(out)
115 | out = self.bn3(out)
116 |
117 | if self.downsample is not None:
118 | identity = self.downsample(x)
119 |
120 | out += identity
121 | out = self.relu(out)
122 |
123 | return out
124 |
125 |
126 | class ResNet(nn.Module):
127 | """Residual network.
128 |
129 | Reference:
130 | - He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
131 | - Xie et al. Aggregated Residual Transformations for Deep Neural Networks. CVPR 2017.
132 | Public keys:
133 | - ``resnet18``: ResNet18.
134 | - ``resnet34``: ResNet34.
135 | - ``resnet50``: ResNet50.
136 | - ``resnet101``: ResNet101.
137 | - ``resnet152``: ResNet152.
138 | - ``resnext50_32x4d``: ResNeXt50.
139 | - ``resnext101_32x8d``: ResNeXt101.
140 | - ``resnet50_fc512``: ResNet50 + FC.
141 | """
142 |
143 | def __init__(self, num_classes, loss, block, layers, zero_init_residual=False,
144 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
145 | norm_layer=None, last_stride=2, fc_dims=None, dropout_p=None, **kwargs):
146 | super(ResNet, self).__init__()
147 | if norm_layer is None:
148 | norm_layer = nn.BatchNorm2d
149 | self._norm_layer = norm_layer
150 | self.loss = loss
151 | self.feature_dim = 512 * block.expansion
152 | self.inplanes = 64
153 | self.dilation = 1
154 | if replace_stride_with_dilation is None:
155 | # each element in the tuple indicates if we should replace
156 | # the 2x2 stride with a dilated convolution instead
157 | replace_stride_with_dilation = [False, False, False]
158 | if len(replace_stride_with_dilation) != 3:
159 | raise ValueError("replace_stride_with_dilation should be None "
160 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
161 | self.groups = groups
162 | self.base_width = width_per_group
163 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=2,
164 | bias=False)
165 | self.bn1 = norm_layer(self.inplanes)
166 | self.relu = nn.ReLU(inplace=True)
167 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
168 | self.layer1 = self._make_layer(block, 64, layers[0])
169 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
170 | dilate=replace_stride_with_dilation[0])
171 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
172 | dilate=replace_stride_with_dilation[1])
173 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride,
174 | dilate=replace_stride_with_dilation[2])
175 |
176 | self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
177 | self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p)
178 | self.classifier = nn.Linear(self.feature_dim, num_classes)
179 |
180 | self._init_params()
181 |
182 | # Zero-initialize the last BN in each residual branch,
183 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
184 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
185 | if zero_init_residual:
186 | for m in self.modules():
187 | if isinstance(m, Bottleneck):
188 | nn.init.constant_(m.bn3.weight, 0)
189 | elif isinstance(m, BasicBlock):
190 | nn.init.constant_(m.bn2.weight, 0)
191 |
192 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
193 | norm_layer = self._norm_layer
194 | downsample = None
195 | previous_dilation = self.dilation
196 | if dilate:
197 | self.dilation *= stride
198 | stride = 1
199 | if stride != 1 or self.inplanes != planes * block.expansion:
200 | downsample = nn.Sequential(
201 | conv1x1(self.inplanes, planes * block.expansion, stride),
202 | norm_layer(planes * block.expansion),
203 | )
204 |
205 | layers = []
206 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
207 | self.base_width, previous_dilation, norm_layer))
208 | self.inplanes = planes * block.expansion
209 | for _ in range(1, blocks):
210 | layers.append(block(self.inplanes, planes, groups=self.groups,
211 | base_width=self.base_width, dilation=self.dilation,
212 | norm_layer=norm_layer))
213 |
214 | return nn.Sequential(*layers)
215 |
216 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
217 | """Constructs fully connected layer
218 | Args:
219 | fc_dims (list or tuple): dimensions of fc layers, if None, no fc layers are constructed
220 | input_dim (int): input dimension
221 | dropout_p (float): dropout probability, if None, dropout is unused
222 | """
223 | if fc_dims is None:
224 | self.feature_dim = input_dim
225 | return None
226 |
227 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(
228 | type(fc_dims))
229 |
230 | layers = []
231 | for dim in fc_dims:
232 | layers.append(nn.Linear(input_dim, dim))
233 | layers.append(nn.BatchNorm1d(dim))
234 | layers.append(nn.ReLU(inplace=True))
235 | if dropout_p is not None:
236 | layers.append(nn.Dropout(p=dropout_p))
237 | input_dim = dim
238 |
239 | self.feature_dim = fc_dims[-1]
240 |
241 | return nn.Sequential(*layers)
242 |
243 | def _init_params(self):
244 | for m in self.modules():
245 | if isinstance(m, nn.Conv2d):
246 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
247 | if m.bias is not None:
248 | nn.init.constant_(m.bias, 0)
249 | elif isinstance(m, nn.BatchNorm2d):
250 | nn.init.constant_(m.weight, 1)
251 | nn.init.constant_(m.bias, 0)
252 | elif isinstance(m, nn.BatchNorm1d):
253 | nn.init.constant_(m.weight, 1)
254 | nn.init.constant_(m.bias, 0)
255 | elif isinstance(m, nn.Linear):
256 | nn.init.normal_(m.weight, 0, 0.01)
257 | if m.bias is not None:
258 | nn.init.constant_(m.bias, 0)
259 |
260 | def featuremaps(self, x):
261 | x = x.unsqueeze(1)
262 | x = self.conv1(x)
263 | x = self.bn1(x)
264 | x = self.relu(x)
265 | x = self.maxpool(x)
266 | x = self.layer1(x)
267 | x = self.layer2(x)
268 | x = self.layer3(x)
269 | x = self.layer4(x)
270 | return x
271 |
272 | def forward(self, x):
273 | f = self.featuremaps(x)
274 | v = self.global_avgpool(f)
275 | v = v.view(v.size(0), -1)
276 |
277 | if self.fc is not None:
278 | v = self.fc(v)
279 |
280 | if not self.training:
281 | return v
282 |
283 | y = self.classifier(v)
284 |
285 | if self.loss == 'xent':
286 | return y
287 | elif self.loss == 'htri':
288 | return y, v
289 | else:
290 | raise KeyError("Unsupported loss: {}".format(self.loss))
291 |
292 | def forward_classifier(self, v):
293 | y = self.classifier(v)
294 | return y
295 |
296 | """ResNet"""
297 |
298 |
299 | def init_pretrained_weights(model, model_url):
300 | """Initializes model with pretrained weights.
301 |
302 | Layers that don't match with pretrained layers in name or size are kept unchanged.
303 | """
304 | pretrain_dict = model_zoo.load_url(model_url)
305 | model_dict = model.state_dict()
306 | pretrain_dict = {
307 | k: v
308 | for k, v in pretrain_dict.items()
309 | if k in model_dict and model_dict[k].size() == v.size()
310 | }
311 | model_dict.update(pretrain_dict)
312 | model.load_state_dict(model_dict)
313 |
314 |
315 | def resnet18(num_classes, loss='xent', pretrained=True, **kwargs):
316 | model = ResNet(
317 | num_classes=num_classes,
318 | loss=loss,
319 | block=BasicBlock,
320 | layers=[2, 2, 2, 2],
321 | last_stride=2,
322 | fc_dims=None,
323 | dropout_p=None,
324 | **kwargs
325 | )
326 | if pretrained:
327 | init_pretrained_weights(model, model_urls['resnet18'])
328 | return model
329 |
330 |
331 | def resnet34(num_classes, loss='xent', pretrained=True, **kwargs):
332 | model = ResNet(
333 | num_classes=num_classes,
334 | loss=loss,
335 | block=BasicBlock,
336 | layers=[3, 4, 6, 3],
337 | last_stride=2,
338 | fc_dims=None,
339 | dropout_p=None,
340 | **kwargs
341 | )
342 | if pretrained:
343 | init_pretrained_weights(model, model_urls['resnet34'])
344 | return model
345 |
--------------------------------------------------------------------------------
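A quick usage sketch (not from the repo): the baselines build these ResNets for single-channel spectrogram input, and in eval mode `forward()` returns the pooled embedding rather than logits. `num_classes=1251` matches VoxCeleb1 identification; `pretrained=False` skips downloading ImageNet weights (whose 3-channel `conv1` would not be copied anyway because its shape differs), and the input shape below is only an example:

```python
import torch
from models import resnet

model = resnet.resnet18(num_classes=1251, pretrained=False)
model.eval()
with torch.no_grad():
    feats = torch.randn(2, 300, 64)        # (batch, frames, freq bins)
    v = model(feats)                       # 512-d embedding in eval mode
    y = model.forward_classifier(v)
print(v.shape, y.shape)                    # torch.Size([2, 512]) torch.Size([2, 1251])
```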
/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | OPS = {
5 | 'none' : lambda C, stride, affine: Zero(stride),
6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
9 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
10 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
11 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
12 | 'sep_conv_3x1' : lambda C, stride, affine: SepConvTime(C, C, 3, stride, 1, affine=affine),
13 | 'sep_conv_1x3' : lambda C, stride, affine: SepConvFreq(C, C, 3, stride, 1, affine=affine),
14 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
15 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
16 | 'dil_conv_3x1' : lambda C, stride, affine: DilConvTime(C, C, 3, stride, 2, 2, affine=affine),
17 | 'dil_conv_1x3' : lambda C, stride, affine: DilConvFreq(C, C, 3, stride, 2, 2, affine=affine),
18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
19 | nn.ReLU(inplace=False),
20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
22 | nn.BatchNorm2d(C, affine=affine)
23 | ),
24 | }
25 |
26 | class ReLUConvBN(nn.Module):
27 |
28 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
29 | super(ReLUConvBN, self).__init__()
30 | self.op = nn.Sequential(
31 | nn.ReLU(inplace=False),
32 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
33 | nn.BatchNorm2d(C_out, affine=affine)
34 | )
35 |
36 | def forward(self, x):
37 | return self.op(x)
38 |
39 | class DilConv(nn.Module):
40 |
41 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
42 | super(DilConv, self).__init__()
43 | self.op = nn.Sequential(
44 | nn.ReLU(inplace=False),
45 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
46 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
47 | nn.BatchNorm2d(C_out, affine=affine),
48 | )
49 |
50 | def forward(self, x):
51 | return self.op(x)
52 |
53 |
54 | class DilConvTime(nn.Module):
55 |
56 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
57 | super(DilConvTime, self).__init__()
58 | self.op = nn.Sequential(
59 | nn.ReLU(inplace=False),
60 | nn.Conv2d(C_in, C_in, kernel_size=(kernel_size, 1), stride=stride, padding=(padding, 0), dilation=(dilation, 1), groups=C_in,
61 | bias=False),
62 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
63 | nn.BatchNorm2d(C_out, affine=affine),
64 | )
65 |
66 | def forward(self, x):
67 | return self.op(x)
68 |
69 |
70 | class DilConvFreq(nn.Module):
71 |
72 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
73 | super(DilConvFreq, self).__init__()
74 | self.op = nn.Sequential(
75 | nn.ReLU(inplace=False),
76 | nn.Conv2d(C_in, C_in, kernel_size=(1, kernel_size), stride=stride, padding=(0, padding), dilation=(1, dilation), groups=C_in,
77 | bias=False),
78 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
79 | nn.BatchNorm2d(C_out, affine=affine),
80 | )
81 |
82 | def forward(self, x):
83 | return self.op(x)
84 |
85 | class SepConv(nn.Module):
86 |
87 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
88 | super(SepConv, self).__init__()
89 | self.op = nn.Sequential(
90 | nn.ReLU(inplace=False),
91 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
92 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
93 | nn.BatchNorm2d(C_in, affine=affine),
94 | nn.ReLU(inplace=False),
95 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
96 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
97 | nn.BatchNorm2d(C_out, affine=affine),
98 | )
99 |
100 | def forward(self, x):
101 | return self.op(x)
102 |
103 |
104 | class SepConvTime(nn.Module):
105 |
106 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
107 | super(SepConvTime, self).__init__()
108 | self.op = nn.Sequential(
109 | nn.ReLU(inplace=False),
110 | nn.Conv2d(C_in, C_in, kernel_size=(kernel_size, 1), stride=stride, padding=(padding, 0), groups=C_in, bias=False),
111 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
112 | nn.BatchNorm2d(C_in, affine=affine),
113 | nn.ReLU(inplace=False),
114 | nn.Conv2d(C_in, C_in, kernel_size=(kernel_size, 1), stride=1, padding=(padding, 0), groups=C_in, bias=False),
115 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
116 | nn.BatchNorm2d(C_out, affine=affine),
117 | )
118 |
119 | def forward(self, x):
120 | return self.op(x)
121 |
122 |
123 | class SepConvFreq(nn.Module):
124 |
125 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
126 | super(SepConvFreq, self).__init__()
127 | self.op = nn.Sequential(
128 | nn.ReLU(inplace=False),
129 | nn.Conv2d(C_in, C_in, kernel_size=(1, kernel_size), stride=stride, padding=(0, padding), groups=C_in, bias=False),
130 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
131 | nn.BatchNorm2d(C_in, affine=affine),
132 | nn.ReLU(inplace=False),
133 | nn.Conv2d(C_in, C_in, kernel_size=(1, kernel_size), stride=1, padding=(0, padding), groups=C_in, bias=False),
134 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
135 | nn.BatchNorm2d(C_out, affine=affine),
136 | )
137 |
138 | def forward(self, x):
139 | return self.op(x)
140 |
141 | class Identity(nn.Module):
142 |
143 | def __init__(self):
144 | super(Identity, self).__init__()
145 |
146 | def forward(self, x):
147 | return x
148 |
149 |
150 | class Zero(nn.Module):
151 |
152 | def __init__(self, stride):
153 | super(Zero, self).__init__()
154 | self.stride = stride
155 |
156 | def forward(self, x):
157 | if self.stride == 1:
158 | return x.mul(0.)
159 | return x[:,:,::self.stride,::self.stride].mul(0.)
160 |
161 |
162 | class FactorizedReduce(nn.Module):
163 |
164 | def __init__(self, C_in, C_out, affine=True):
165 | super(FactorizedReduce, self).__init__()
166 | assert C_out % 2 == 0
167 | self.relu = nn.ReLU(inplace=False)
168 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
169 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
170 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
171 |
172 | def forward(self, x):
173 | x = self.relu(x)
174 | # out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
175 | out = torch.cat([self.conv_1(x), self.conv_2(x)], dim=1)
176 | out = self.bn(out)
177 | return out
178 |
179 |
--------------------------------------------------------------------------------
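The mixed operation in `models/model_search.py` can only sum the candidate paths because every entry in `OPS` preserves the channel count and, at stride 2, all of them halve the spatial resolution in the same way. A small illustrative check (op names and tensor sizes chosen arbitrarily):

```python
import torch
from operations import OPS

x = torch.randn(1, 16, 32, 32)
for name in ['sep_conv_3x3', 'dil_conv_5x5', 'max_pool_3x3', 'skip_connect', 'none']:
    out1 = OPS[name](16, 1, True)(x)       # stride 1: shape preserved
    out2 = OPS[name](16, 2, True)(x)       # stride 2: spatial dims halved (reduction cell)
    print(name, tuple(out1.shape), tuple(out2.shape))
```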
/requirements.txt:
--------------------------------------------------------------------------------
1 | audioread==2.1.8
2 | certifi==2020.4.5.1
3 | cffi==1.14.0
4 | cycler==0.10.0
5 | decorator==4.4.2
6 | dill==0.3.1.1
7 | future==0.18.2
8 | joblib==0.14.1
9 | kiwisolver==1.2.0
10 | librosa==0.7.2
11 | llvmlite==0.32.0
12 | matplotlib==3.2.1
13 | multiprocess==0.70.9
14 | numba==0.49.0
15 | numpy==1.18.4
16 | Pillow==7.1.2
17 | protobuf==3.11.3
18 | pycparser==2.20
19 | pyparsing==2.4.7
20 | python-dateutil==2.8.1
21 | PyYAML==5.3.1
22 | resampy==0.2.2
23 | scikit-learn==0.22.2.post1
24 | scipy==1.4.1
25 | six==1.14.0
26 | sklearn==0.0
27 | SoundFile==0.10.3.post1
28 | tensorboardX==2.0
29 | torch==1.5.0
30 | torchvision==0.6.0
31 | tqdm==4.46.0
32 | yacs==0.1.7
33 |
--------------------------------------------------------------------------------
/search.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-08-09
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import argparse
13 | import numpy as np
14 | import shutil
15 | import os
16 | from tensorboardX import SummaryWriter
17 | from tqdm import tqdm
18 | from pathlib import Path
19 |
20 | import torch
21 | import torch.optim as optim
22 | import torch.backends.cudnn as cudnn
23 |
24 | from config import cfg, update_config
25 | from utils import set_path, create_logger, save_checkpoint
26 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset
27 | from functions import train, validate_identification
28 | from architect import Architect
29 | from loss import CrossEntropyLoss
30 | from torch.utils.data import DataLoader
31 | from spaces import primitives_1, primitives_2, primitives_3
32 | from models.model_search import Network
33 |
34 |
35 | def parse_args():
36 | parser = argparse.ArgumentParser(description='AutoSpeech architecture search')
37 | # general
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | required=True,
41 | type=str)
42 |
43 | parser.add_argument('opts',
44 | help="Modify config options using the command-line",
45 | default=None,
46 | nargs=argparse.REMAINDER)
47 |
48 | parser.add_argument('--load_path',
49 | help="The experiment directory to resume from",
50 | default=None)
51 |
52 | args = parser.parse_args()
53 |
54 | return args
55 |
56 |
57 | def main():
58 | args = parse_args()
59 | update_config(cfg, args)
60 |
61 | # cudnn related setting
62 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
63 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
64 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
65 |
66 | # Set the random seed manually for reproducibility.
67 | np.random.seed(cfg.SEED)
68 | torch.manual_seed(cfg.SEED)
69 | torch.cuda.manual_seed_all(cfg.SEED)
70 |
71 | # Loss
72 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()
73 |
74 | # model and optimizer
75 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, criterion, primitives_2,
76 | drop_path_prob=cfg.TRAIN.DROPPATH_PROB)
77 | model = model.cuda()
78 |
79 | # weight params
80 | arch_params = list(map(id, model.arch_parameters()))
81 | weight_params = filter(lambda p: id(p) not in arch_params,
82 | model.parameters())
83 |
84 | # Optimizer
85 | optimizer = optim.Adam(
86 | weight_params,
87 | lr=cfg.TRAIN.LR
88 | )
89 |
90 | # resume && make log dir and logger
91 | if args.load_path and os.path.exists(args.load_path):
92 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
93 | assert os.path.exists(checkpoint_file)
94 | checkpoint = torch.load(checkpoint_file)
95 |
96 | # load checkpoint
97 | begin_epoch = checkpoint['epoch']
98 | last_epoch = checkpoint['epoch']
99 | model.load_state_dict(checkpoint['state_dict'])
100 | best_acc1 = checkpoint['best_acc1']
101 | optimizer.load_state_dict(checkpoint['optimizer'])
102 | args.path_helper = checkpoint['path_helper']
103 |
104 | logger = create_logger(args.path_helper['log_path'])
105 | logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
106 | else:
107 | exp_name = args.cfg.split('/')[-1].split('.')[0]
108 | args.path_helper = set_path('logs_search', exp_name)
109 | logger = create_logger(args.path_helper['log_path'])
110 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH
111 | best_acc1 = 0.0
112 | last_epoch = -1
113 |
114 | logger.info(args)
115 | logger.info(cfg)
116 |
117 | # copy model file
118 | this_dir = os.path.dirname(__file__)
119 | shutil.copy2(
120 | os.path.join(this_dir, 'models', cfg.MODEL.NAME + '.py'),
121 | args.path_helper['ckpt_path'])
122 |
123 | # dataloader
124 | train_dataset = DeepSpeakerDataset(
125 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'train')
126 | val_dataset = DeepSpeakerDataset(
127 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'val')
128 | train_loader = torch.utils.data.DataLoader(
129 | dataset=train_dataset,
130 | batch_size=cfg.TRAIN.BATCH_SIZE,
131 | num_workers=cfg.DATASET.NUM_WORKERS,
132 | pin_memory=True,
133 | shuffle=True,
134 | drop_last=True,
135 | )
136 | val_loader = torch.utils.data.DataLoader(
137 | dataset=val_dataset,
138 | batch_size=cfg.TRAIN.BATCH_SIZE,
139 | num_workers=cfg.DATASET.NUM_WORKERS,
140 | pin_memory=True,
141 | shuffle=True,
142 | drop_last=True,
143 | )
144 | test_dataset = DeepSpeakerDataset(
145 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True)
146 | test_loader = torch.utils.data.DataLoader(
147 | dataset=test_dataset,
148 | batch_size=1,
149 | num_workers=cfg.DATASET.NUM_WORKERS,
150 | pin_memory=True,
151 | shuffle=True,
152 | drop_last=True,
153 | )
154 |
155 | # training setting
156 | writer_dict = {
157 | 'writer': SummaryWriter(args.path_helper['log_path']),
158 | 'train_global_steps': begin_epoch * len(train_loader),
159 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
160 | }
161 |
162 | # training loop
163 | architect = Architect(model, cfg)
164 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
165 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
166 | last_epoch=last_epoch
167 | )
168 |
169 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='search progress'):
170 | model.train()
171 |
172 | genotype = model.genotype()
173 | logger.info('genotype = %s', genotype)
174 |
175 | if cfg.TRAIN.DROPPATH_PROB != 0:
176 | model.drop_path_prob = cfg.TRAIN.DROPPATH_PROB * epoch / (cfg.TRAIN.END_EPOCH - 1)
177 |
178 | train(cfg, model, optimizer, train_loader, val_loader, criterion, architect, epoch, writer_dict)
179 |
180 | if epoch % cfg.VAL_FREQ == 0:
181 | # evaluate identification accuracy on the held-out test split
182 | acc = validate_identification(cfg, model, test_loader, criterion)
183 |
184 | # remember best acc@1 and save checkpoint
185 | is_best = acc > best_acc1
186 | best_acc1 = max(acc, best_acc1)
187 |
188 | # save
189 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
190 | save_checkpoint({
191 | 'epoch': epoch + 1,
192 | 'state_dict': model.state_dict(),
193 | 'best_acc1': best_acc1,
194 | 'optimizer': optimizer.state_dict(),
195 | 'arch': model.arch_parameters(),
196 | 'genotype': genotype,
197 | 'path_helper': args.path_helper
198 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))
199 |
200 | lr_scheduler.step(epoch)
201 |
202 |
203 |
204 | if __name__ == '__main__':
205 | main()
206 |
--------------------------------------------------------------------------------
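One detail worth noting above is the `id()`-based split: the Adam optimizer receives only the network weights, while the architecture alphas are updated by the `Architect`. A tiny stand-alone illustration of that filtering idiom (the linear layer and the alpha tensor here are made up stand-ins):

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)                                  # stand-in for the supernet weights
alphas = nn.Parameter(1e-3 * torch.randn(14, 8))       # stand-in for the architecture alphas

all_params = list(net.parameters()) + [alphas]
arch_ids = set(map(id, [alphas]))
weight_params = [p for p in all_params if id(p) not in arch_ids]
print(len(all_params), len(weight_params))             # 3 2 -> alphas excluded from Adam
```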
/spaces.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | primitives_1 = OrderedDict([('primitives_normal', [['skip_connect',
4 | 'dil_conv_3x3'],
5 | ['skip_connect',
6 | 'dil_conv_5x5'],
7 | ['skip_connect',
8 | 'dil_conv_5x5'],
9 | ['skip_connect',
10 | 'sep_conv_3x3'],
11 | ['skip_connect',
12 | 'dil_conv_3x3'],
13 | ['max_pool_3x3',
14 | 'skip_connect'],
15 | ['skip_connect',
16 | 'sep_conv_3x3'],
17 | ['skip_connect',
18 | 'sep_conv_3x3'],
19 | ['skip_connect',
20 | 'dil_conv_3x3'],
21 | ['skip_connect',
22 | 'sep_conv_3x3'],
23 | ['max_pool_3x3',
24 | 'skip_connect'],
25 | ['skip_connect',
26 | 'dil_conv_3x3'],
27 | ['dil_conv_3x3',
28 | 'dil_conv_5x5'],
29 | ['dil_conv_3x3',
30 | 'dil_conv_5x5']]),
31 | ('primitives_reduct', [['max_pool_3x3',
32 | 'avg_pool_3x3'],
33 | ['max_pool_3x3',
34 | 'dil_conv_3x3'],
35 | ['max_pool_3x3',
36 | 'avg_pool_3x3'],
37 | ['max_pool_3x3',
38 | 'avg_pool_3x3'],
39 | ['skip_connect',
40 | 'dil_conv_5x5'],
41 | ['max_pool_3x3',
42 | 'avg_pool_3x3'],
43 | ['max_pool_3x3',
44 | 'sep_conv_3x3'],
45 | ['skip_connect',
46 | 'dil_conv_3x3'],
47 | ['skip_connect',
48 | 'dil_conv_5x5'],
49 | ['max_pool_3x3',
50 | 'avg_pool_3x3'],
51 | ['max_pool_3x3',
52 | 'avg_pool_3x3'],
53 | ['skip_connect',
54 | 'dil_conv_5x5'],
55 | ['skip_connect',
56 | 'dil_conv_5x5'],
57 | ['skip_connect',
58 | 'dil_conv_5x5']])])
59 |
60 | PRIMITIVES = [
61 | 'none',
62 | 'max_pool_3x3',
63 | 'avg_pool_3x3',
64 | 'skip_connect',
65 | 'sep_conv_3x3',
66 | 'sep_conv_5x5',
67 | 'dil_conv_3x3',
68 | 'dil_conv_5x5'
69 | ]
70 |
71 | primitives_2 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES]),
72 | ('primitives_reduct', 14 * [PRIMITIVES])])
73 |
74 | PRIMITIVES_SMALL = [
75 | 'none',
76 | 'max_pool_3x3',
77 | 'avg_pool_3x3',
78 | 'skip_connect',
79 | 'sep_conv_3x3',
80 | 'sep_conv_3x1',
81 | 'sep_conv_1x3',
82 | 'dil_conv_3x3',
83 | 'dil_conv_3x1',
84 | 'dil_conv_1x3',
85 | ]
86 |
87 | primitives_3 = OrderedDict([('primitives_normal', 14 * [PRIMITIVES_SMALL]),
88 | ('primitives_reduct', 14 * [PRIMITIVES_SMALL])])
89 |
90 | spaces_dict = {
91 | 's1': primitives_1, # space from https://openreview.net/forum?id=H1gDNyrKDS
92 | 's2': primitives_2, # original DARTS space
93 | 's3': primitives_3, # space with 1D conv
94 | }
--------------------------------------------------------------------------------
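The `14 * [PRIMITIVES]` construction mirrors the cell topology in `models/model_search.py`: with 4 intermediate nodes there are 2 + 3 + 4 + 5 = 14 edges per cell, and each edge carries its own candidate list. A quick check of that bookkeeping:

```python
from spaces import spaces_dict

steps = 4
print(sum(2 + i for i in range(steps)))              # 14 edges per cell

space = spaces_dict['s2']                            # original DARTS space
print(len(space['primitives_normal']),               # 14 edges ...
      len(space['primitives_normal'][0]))            # ... with 8 candidate ops each
```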
/train_baseline_identification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import numpy as np
7 | import os
8 | from tensorboardX import SummaryWriter
9 | from tqdm import tqdm
10 | from pathlib import Path
11 |
12 | import torch
13 | import torch.optim as optim
14 | import torch.backends.cudnn as cudnn
15 |
16 | from models import resnet
17 | from config import cfg, update_config
18 | from utils import set_path, create_logger, save_checkpoint, count_parameters
19 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset
20 | from functions import train_from_scratch, validate_identification
21 | from loss import CrossEntropyLoss
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description='Train baseline identification model')
26 | # general
27 | parser.add_argument('--cfg',
28 | help='experiment configure file name',
29 | required=True,
30 | type=str)
31 |
32 | parser.add_argument('opts',
33 | help="Modify config options using the command-line",
34 | default=None,
35 | nargs=argparse.REMAINDER)
36 |
37 | parser.add_argument('--load_path',
38 | help="The experiment directory to resume from",
39 | default=None)
40 |
41 | args = parser.parse_args()
42 |
43 | return args
44 |
45 |
46 | def main():
47 | args = parse_args()
48 | update_config(cfg, args)
49 |
50 | # cudnn related setting
51 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
52 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
53 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
54 |
55 | # Set the random seed manually for reproducibility.
56 | np.random.seed(cfg.SEED)
57 | torch.manual_seed(cfg.SEED)
58 | torch.cuda.manual_seed_all(cfg.SEED)
59 |
60 | # model and optimizer
61 | model = eval('resnet.{}(num_classes={})'.format(
62 | cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES))
63 | model = model.cuda()
64 | optimizer = optim.Adam(
65 | model.net_parameters() if hasattr(model, 'net_parameters') else model.parameters(),
66 | lr=cfg.TRAIN.LR
67 | )
68 |
69 | # Loss
70 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()
71 |
72 | # resume && make log dir and logger
73 | if args.load_path and os.path.exists(args.load_path):
74 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
75 | # checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint.pth')
76 | assert os.path.exists(checkpoint_file)
77 | checkpoint = torch.load(checkpoint_file)
78 |
79 | # load checkpoint
80 | begin_epoch = checkpoint['epoch']
81 | last_epoch = checkpoint['epoch']
82 | model.load_state_dict(checkpoint['state_dict'])
83 | best_acc1 = checkpoint['best_acc1']
84 | optimizer.load_state_dict(checkpoint['optimizer'])
85 | args.path_helper = checkpoint['path_helper']
86 |
87 | logger = create_logger(args.path_helper['log_path'])
88 | logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
89 | else:
90 | exp_name = args.cfg.split('/')[-1].split('.')[0]
91 | args.path_helper = set_path('logs', exp_name)
92 | logger = create_logger(args.path_helper['log_path'])
93 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH
94 | best_acc1 = 0.0
95 | last_epoch = -1
96 | logger.info(args)
97 | logger.info(cfg)
98 | logger.info("Number of parameters: {}".format(count_parameters(model)))
99 |
100 | # dataloader
101 | train_dataset = DeepSpeakerDataset(
102 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'train')
103 | test_dataset_identification = DeepSpeakerDataset(
104 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True)
105 | train_loader = torch.utils.data.DataLoader(
106 | dataset=train_dataset,
107 | batch_size=cfg.TRAIN.BATCH_SIZE,
108 | num_workers=cfg.DATASET.NUM_WORKERS,
109 | pin_memory=True,
110 | shuffle=True,
111 | drop_last=True,
112 | )
113 | test_loader_identification = torch.utils.data.DataLoader(
114 | dataset=test_dataset_identification,
115 | batch_size=1,
116 | num_workers=cfg.DATASET.NUM_WORKERS,
117 | pin_memory=True,
118 | shuffle=True,
119 | drop_last=True,
120 | )
121 |
122 | # training setting
123 | writer_dict = {
124 | 'writer': SummaryWriter(args.path_helper['log_path']),
125 | 'train_global_steps': begin_epoch * len(train_loader),
126 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
127 | }
128 |
129 | # training loop
130 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
131 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
132 | last_epoch=last_epoch
133 | )
134 |
135 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'):
136 | model.train()
137 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict, lr_scheduler)
138 | if epoch % cfg.VAL_FREQ == 0:
139 | acc = validate_identification(cfg, model, test_loader_identification, criterion)
140 |
141 | # remember best acc@1 and save checkpoint
142 | is_best = acc > best_acc1
143 | best_acc1 = max(acc, best_acc1)
144 |
145 | # save
146 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
147 | save_checkpoint({
148 | 'epoch': epoch + 1,
149 | 'state_dict': model.state_dict(),
150 | 'best_acc1': best_acc1,
151 | 'optimizer': optimizer.state_dict(),
152 | 'path_helper': args.path_helper
153 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))
154 | lr_scheduler.step(epoch)
155 |
156 |
157 | if __name__ == '__main__':
158 | main()
159 |
--------------------------------------------------------------------------------
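Assumed invocation (the config names come from the `exps/baseline/` folder in this tree; the data paths inside the YAML must point at the preprocessed VoxCeleb1 features first): `python train_baseline_identification.py --cfg exps/baseline/resnet18_iden.yaml`, optionally with `--load_path logs/<experiment_dir>` to resume from the saved `checkpoint_best.pth`.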
/train_baseline_verification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import numpy as np
7 | import os
8 | from tensorboardX import SummaryWriter
9 | from tqdm import tqdm
10 | from pathlib import Path
11 |
12 | import torch
13 | import torch.optim as optim
14 | import torch.backends.cudnn as cudnn
15 |
16 | from models import resnet
17 | from config import cfg, update_config
18 | from utils import set_path, create_logger, save_checkpoint, count_parameters
19 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset
20 | from data_objects.VoxcelebTestset import VoxcelebTestset
21 | from functions import train_from_scratch, validate_verification
22 | from loss import CrossEntropyLoss
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description='Train baseline verification model')
27 | # general
28 | parser.add_argument('--cfg',
29 | help='experiment configure file name',
30 | required=True,
31 | type=str)
32 |
33 | parser.add_argument('opts',
34 | help="Modify config options using the command-line",
35 | default=None,
36 | nargs=argparse.REMAINDER)
37 |
38 | parser.add_argument('--load_path',
39 | help="The experiment directory to resume from",
40 | default=None)
41 |
42 | args = parser.parse_args()
43 |
44 | return args
45 |
46 |
47 | def main():
48 | args = parse_args()
49 | update_config(cfg, args)
50 |
51 | # cudnn related setting
52 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
53 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
54 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
55 |
56 | # Set the random seed manually for reproducibility.
57 | np.random.seed(cfg.SEED)
58 | torch.manual_seed(cfg.SEED)
59 | torch.cuda.manual_seed_all(cfg.SEED)
60 |
61 | # model and optimizer
62 | model = eval('resnet.{}(num_classes={})'.format(
63 | cfg.MODEL.NAME, cfg.MODEL.NUM_CLASSES))
64 | model = model.cuda()
65 | optimizer = optim.Adam(
66 | model.net_parameters() if hasattr(model, 'net_parameters') else model.parameters(),
67 | lr=cfg.TRAIN.LR,
68 | )
69 |
70 | # Loss
71 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()
72 |
73 | # resume && make log dir and logger
74 | if args.load_path and os.path.exists(args.load_path):
75 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
76 | assert os.path.exists(checkpoint_file)
77 | checkpoint = torch.load(checkpoint_file)
78 |
79 | # load checkpoint
80 | begin_epoch = checkpoint['epoch']
81 | last_epoch = checkpoint['epoch']
82 | model.load_state_dict(checkpoint['state_dict'])
83 | best_eer = checkpoint['best_eer']
84 | optimizer.load_state_dict(checkpoint['optimizer'])
85 | args.path_helper = checkpoint['path_helper']
86 |
87 | logger = create_logger(args.path_helper['log_path'])
88 | logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
89 | else:
90 | exp_name = args.cfg.split('/')[-1].split('.')[0]
91 | args.path_helper = set_path('logs', exp_name)
92 | logger = create_logger(args.path_helper['log_path'])
93 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH
94 | best_eer = 1.0
95 | last_epoch = -1
96 | logger.info(args)
97 | logger.info(cfg)
98 | logger.info("Number of parameters: {}".format(count_parameters(model)))
99 |
100 | # dataloader
101 | train_dataset = DeepSpeakerDataset(
102 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES)
103 | test_dataset_verification = VoxcelebTestset(
104 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES
105 | )
106 | train_loader = torch.utils.data.DataLoader(
107 | dataset=train_dataset,
108 | batch_size=cfg.TRAIN.BATCH_SIZE,
109 | num_workers=cfg.DATASET.NUM_WORKERS,
110 | pin_memory=True,
111 | shuffle=True,
112 | drop_last=True,
113 | )
114 | test_loader_verification = torch.utils.data.DataLoader(
115 | dataset=test_dataset_verification,
116 | batch_size=1,
117 | num_workers=cfg.DATASET.NUM_WORKERS,
118 | pin_memory=True,
119 | shuffle=False,
120 | drop_last=False,
121 | )
122 |
123 | # training setting
124 | writer_dict = {
125 | 'writer': SummaryWriter(args.path_helper['log_path']),
126 | 'train_global_steps': begin_epoch * len(train_loader),
127 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
128 | }
129 |
130 | # training loop
131 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
132 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
133 | last_epoch=last_epoch
134 | )
135 |
136 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'):
137 | model.train()
138 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict, lr_scheduler)
139 | if epoch % cfg.VAL_FREQ == 0:
140 | eer = validate_verification(cfg, model, test_loader_verification)
141 |
142 | # remember best (lowest) EER and save checkpoint
143 | is_best = eer < best_eer
144 | best_eer = min(eer, best_eer)
145 |
146 | # save
147 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
148 | save_checkpoint({
149 | 'epoch': epoch + 1,
150 | 'state_dict': model.state_dict(),
151 | 'best_eer': best_eer,
152 | 'optimizer': optimizer.state_dict(),
153 | 'path_helper': args.path_helper
154 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))
155 | lr_scheduler.step(epoch)
156 |
157 |
158 | if __name__ == '__main__':
159 | main()
160 |
--------------------------------------------------------------------------------
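The verification counterpart is run the same way, e.g. `python train_baseline_verification.py --cfg exps/baseline/resnet18_veri.yaml` (again an assumed invocation); the only bookkeeping difference is that the tracked metric is EER, so `is_best` means a lower value than `best_eer` rather than a higher accuracy.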
/train_identification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import numpy as np
7 | import shutil
8 | import os
9 | from pathlib import Path
10 | from tensorboardX import SummaryWriter
11 | from tqdm import tqdm
12 |
13 | import torch
14 | import torch.optim as optim
15 |
16 | import torch.backends.cudnn as cudnn
17 |
18 | from models.model import Network
19 | from config import cfg, update_config
20 | from utils import set_path, create_logger, save_checkpoint, count_parameters, Genotype
21 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset
22 | from functions import train_from_scratch, validate_identification
23 | from loss import CrossEntropyLoss
24 |
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser(description='Train searched architecture for identification')
28 | # general
29 | parser.add_argument('--cfg',
30 | help='experiment configure file name',
31 | required=True,
32 | type=str)
33 |
34 | parser.add_argument('opts',
35 | help="Modify config options using the command-line",
36 | default=None,
37 | nargs=argparse.REMAINDER)
38 |
39 | parser.add_argument('--load_path',
40 | help="The experiment directory to resume from",
41 | default=None)
42 |
43 | parser.add_argument('--text_arch',
44 | help="The searched architecture, given as a Genotype string (passed to eval())",
45 | default=None)
46 |
47 | args = parser.parse_args()
48 |
49 | return args
50 |
51 |
52 | def main():
53 | args = parse_args()
54 | update_config(cfg, args)
55 | assert args.text_arch
56 |
57 | # cudnn related setting
58 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
59 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
60 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
61 |
62 | # Set the random seed manually for reproducibility.
63 | np.random.seed(cfg.SEED)
64 | torch.manual_seed(cfg.SEED)
65 | torch.cuda.manual_seed_all(cfg.SEED)
66 |
67 | # Loss
68 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()
69 |
70 | # load arch
71 | genotype = eval(args.text_arch)
72 |
73 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype)
74 | model = model.cuda()
75 |
76 | optimizer = optim.Adam(
77 | model.parameters(),
78 | lr=cfg.TRAIN.LR
79 | )
80 |
81 | # resume && make log dir and logger
82 | if args.load_path and os.path.exists(args.load_path):
83 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
84 | assert os.path.exists(checkpoint_file)
85 | checkpoint = torch.load(checkpoint_file)
86 |
87 | # load checkpoint
88 | begin_epoch = checkpoint['epoch']
89 | last_epoch = checkpoint['epoch']
90 | model.load_state_dict(checkpoint['state_dict'])
91 | best_acc1 = checkpoint['best_acc1']
92 | optimizer.load_state_dict(checkpoint['optimizer'])
93 | args.path_helper = checkpoint['path_helper']
94 |
95 | logger = create_logger(args.path_helper['log_path'])
96 |         logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
97 | else:
98 | exp_name = args.cfg.split('/')[-1].split('.')[0]
99 | args.path_helper = set_path('logs_scratch', exp_name)
100 | logger = create_logger(args.path_helper['log_path'])
101 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH
102 | best_acc1 = 0.0
103 | last_epoch = -1
104 | logger.info(args)
105 | logger.info(cfg)
106 | logger.info(f"selected architecture: {genotype}")
107 | logger.info("Number of parameters: {}".format(count_parameters(model)))
108 |
109 | # dataloader
110 | train_dataset = DeepSpeakerDataset(
111 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'train')
112 | train_loader = torch.utils.data.DataLoader(
113 | dataset=train_dataset,
114 | batch_size=cfg.TRAIN.BATCH_SIZE,
115 | num_workers=cfg.DATASET.NUM_WORKERS,
116 | pin_memory=True,
117 | shuffle=True,
118 | drop_last=True,
119 | )
120 | test_dataset = DeepSpeakerDataset(
121 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES, 'test', is_test=True)
122 | test_loader = torch.utils.data.DataLoader(
123 | dataset=test_dataset,
124 | batch_size=1,
125 | num_workers=cfg.DATASET.NUM_WORKERS,
126 | pin_memory=True,
127 | shuffle=True,
128 | drop_last=True,
129 | )
130 |
131 | # training setting
132 | writer_dict = {
133 | 'writer': SummaryWriter(args.path_helper['log_path']),
134 | 'train_global_steps': begin_epoch * len(train_loader),
135 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
136 | }
137 |
138 | # training loop
139 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
140 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
141 | last_epoch=last_epoch
142 | )
143 |
144 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'):
145 | model.train()
146 | model.drop_path_prob = cfg.MODEL.DROP_PATH_PROB * epoch / cfg.TRAIN.END_EPOCH
147 |
148 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict)
149 |
150 | if epoch % cfg.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH - 1:
151 | acc = validate_identification(cfg, model, test_loader, criterion)
152 |
153 | # remember best acc@1 and save checkpoint
154 | is_best = acc > best_acc1
155 | best_acc1 = max(acc, best_acc1)
156 |
157 | # save
158 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
159 | save_checkpoint({
160 | 'epoch': epoch + 1,
161 | 'state_dict': model.state_dict(),
162 | 'best_acc1': best_acc1,
163 | 'optimizer': optimizer.state_dict(),
164 | 'path_helper': args.path_helper,
165 | 'genotype': genotype,
166 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))
167 |
168 | lr_scheduler.step(epoch)
169 |
170 |
171 | if __name__ == '__main__':
172 | main()
173 |
--------------------------------------------------------------------------------
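Both from-scratch trainers expect --text_arch to be the repr of a utils.Genotype, which they eval() back into a namedtuple. A minimal sketch of producing such a string (the operation names below are placeholders; the actual genotype is the one logged by the architecture search):

from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Placeholder cell description; real operation names live in operations.py and
# the real genotype is produced by search.py.
genotype = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1)],
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 0), ('sep_conv_5x5', 1)],
    reduce_concat=range(2, 6),
)

# repr(genotype) is the string --text_arch expects, e.g.:
#   python train_identification.py --cfg exps/scratch/scratch_iden.yaml \
#       --text_arch "Genotype(normal=[('sep_conv_3x3', 0), ...], ...)"
print(repr(genotype))
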
/train_verification.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-08-09
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import argparse
13 | import numpy as np
14 | import shutil
15 | import os
16 | from pathlib import Path
17 | from tensorboardX import SummaryWriter
18 | from tqdm import tqdm
19 |
20 | import torch
21 | import torch.optim as optim
22 | import torch.nn as nn
23 | import torch.backends.cudnn as cudnn
24 |
25 | from models.model import Network
26 | from config import cfg, update_config
27 | from utils import set_path, create_logger, save_checkpoint, count_parameters, Genotype
28 | from data_objects.DeepSpeakerDataset import DeepSpeakerDataset
29 | from data_objects.VoxcelebTestset import VoxcelebTestset
30 | from functions import train_from_scratch, validate_verification
31 | from loss import CrossEntropyLoss
32 |
33 |
34 | def parse_args():
35 |     parser = argparse.ArgumentParser(description='Train speaker verification network')
36 | # general
37 | parser.add_argument('--cfg',
38 | help='experiment configure file name',
39 | required=True,
40 | type=str)
41 |
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | parser.add_argument('--load_path',
48 |                         help="Path to the directory to resume from",
49 | default=None)
50 |
51 | parser.add_argument('--text_arch',
52 |                         help="Genotype string of the searched architecture",
53 | default=None)
54 |
55 | args = parser.parse_args()
56 |
57 | return args
58 |
59 |
60 | def main():
61 | args = parse_args()
62 | update_config(cfg, args)
63 | assert args.text_arch
64 |
65 | # cudnn related setting
66 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
67 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
68 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
69 |
70 | # Set the random seed manually for reproducibility.
71 | np.random.seed(cfg.SEED)
72 | torch.manual_seed(cfg.SEED)
73 | torch.cuda.manual_seed_all(cfg.SEED)
74 |
75 | # Loss
76 | criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()
77 |
78 | # load arch
79 | genotype = eval(args.text_arch)
80 |
81 | model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, genotype)
82 | model = model.cuda()
83 |
84 | optimizer = optim.Adam(
85 | model.parameters(),
86 | lr=cfg.TRAIN.LR
87 | )
88 |
89 | # resume && make log dir and logger
90 | if args.load_path and os.path.exists(args.load_path):
91 | checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
92 | assert os.path.exists(checkpoint_file)
93 | checkpoint = torch.load(checkpoint_file)
94 |
95 | # load checkpoint
96 | begin_epoch = checkpoint['epoch']
97 | last_epoch = checkpoint['epoch']
98 | model.load_state_dict(checkpoint['state_dict'])
99 | best_eer = checkpoint['best_eer']
100 | optimizer.load_state_dict(checkpoint['optimizer'])
101 | args.path_helper = checkpoint['path_helper']
102 |
103 | logger = create_logger(args.path_helper['log_path'])
104 |         logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
105 | else:
106 | exp_name = args.cfg.split('/')[-1].split('.')[0]
107 | args.path_helper = set_path('logs_scratch', exp_name)
108 | logger = create_logger(args.path_helper['log_path'])
109 | begin_epoch = cfg.TRAIN.BEGIN_EPOCH
110 | best_eer = 1.0
111 | last_epoch = -1
112 | logger.info(args)
113 | logger.info(cfg)
114 | logger.info(f"selected architecture: {genotype}")
115 | logger.info("Number of parameters: {}".format(count_parameters(model)))
116 |
117 | # dataloader
118 | train_dataset = DeepSpeakerDataset(
119 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.SUB_DIR, cfg.DATASET.PARTIAL_N_FRAMES)
120 | train_loader = torch.utils.data.DataLoader(
121 | dataset=train_dataset,
122 | batch_size=cfg.TRAIN.BATCH_SIZE,
123 | num_workers=cfg.DATASET.NUM_WORKERS,
124 | pin_memory=True,
125 | shuffle=True,
126 | drop_last=True,
127 | )
128 | test_dataset_verification = VoxcelebTestset(
129 | Path(cfg.DATASET.DATA_DIR), cfg.DATASET.PARTIAL_N_FRAMES)
130 | test_loader_verification = torch.utils.data.DataLoader(
131 | dataset=test_dataset_verification,
132 | batch_size=1,
133 | num_workers=cfg.DATASET.NUM_WORKERS,
134 | pin_memory=True,
135 | shuffle=False,
136 | drop_last=False,
137 | )
138 |
139 | # training setting
140 | writer_dict = {
141 | 'writer': SummaryWriter(args.path_helper['log_path']),
142 | 'train_global_steps': begin_epoch * len(train_loader),
143 | 'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
144 | }
145 |
146 | # training loop
147 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
148 | optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
149 | last_epoch=last_epoch
150 | )
151 |
152 | for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='train progress'):
153 | model.train()
154 | model.drop_path_prob = cfg.MODEL.DROP_PATH_PROB * epoch / cfg.TRAIN.END_EPOCH
155 |
156 | train_from_scratch(cfg, model, optimizer, train_loader, criterion, epoch, writer_dict)
157 |
158 | if epoch % cfg.VAL_FREQ == 0 or epoch == cfg.TRAIN.END_EPOCH - 1:
159 | eer = validate_verification(cfg, model, test_loader_verification)
160 |
161 |             # remember best EER and save checkpoint
162 | is_best = eer < best_eer
163 | best_eer = min(eer, best_eer)
164 |
165 | # save
166 | logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
167 | save_checkpoint({
168 | 'epoch': epoch + 1,
169 | 'state_dict': model.state_dict(),
170 | 'best_eer': best_eer,
171 | 'optimizer': optimizer.state_dict(),
172 | 'path_helper': args.path_helper
173 | }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))
174 |
175 | lr_scheduler.step(epoch)
176 |
177 |
178 | if __name__ == '__main__':
179 | main()
180 |
--------------------------------------------------------------------------------
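The verification trainers select checkpoints by equal error rate. utils.compute_eer (shown below) reads the EER off an ROC curve as the false-positive rate at the point where it meets the false-negative rate; a small sketch with made-up scores:

import numpy as np
from sklearn.metrics import roc_curve

# Similarity scores for six trial pairs and their same-speaker labels (made up).
scores = np.array([0.91, 0.15, 0.78, 0.35, 0.88, 0.20])
labels = np.array([1, 0, 1, 0, 1, 0])

fprs, tprs, _ = roc_curve(labels, scores)
eer = fprs[np.nanargmin(np.absolute((1 - tprs) - fprs))]  # same rule as utils.compute_eer
print('EER: {:.3f}'.format(eer))
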
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import dateutil.tz
3 | import time
4 | import logging
5 | import os
6 |
7 |
8 | import numpy as np
9 | from sklearn.metrics import roc_curve
10 | from datetime import datetime
11 | import matplotlib.pyplot as plt
12 | from collections import namedtuple
13 |
14 | plt.switch_backend('agg')
15 |
16 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
17 |
18 | def count_parameters(model):
19 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
20 |
21 |
22 | def init_pretrained_weights(model, checkpoint):
23 | """Initializes model with pretrained weights.
24 |
25 | Layers that don't match with pretrained layers in name or size are kept unchanged.
26 | """
27 | checkpoint_file = torch.load(checkpoint)
28 | pretrain_dict = checkpoint_file['state_dict']
29 | model_dict = model.state_dict()
30 | pretrain_dict = {
31 | k: v
32 | for k, v in pretrain_dict.items()
33 | if k in model_dict and model_dict[k].size() == v.size()
34 | }
35 | model_dict.update(pretrain_dict)
36 | model.load_state_dict(model_dict)
37 |
38 |
39 | class AverageMeter(object):
40 | """Computes and stores the average and current value"""
41 |
42 | def __init__(self, name, fmt=':f'):
43 | self.name = name
44 | self.fmt = fmt
45 | self.reset()
46 |
47 | def reset(self):
48 | self.val = 0
49 | self.avg = 0
50 | self.sum = 0
51 | self.count = 0
52 |
53 | def update(self, val, n=1):
54 | self.val = val
55 | self.sum += val * n
56 | self.count += n
57 | self.avg = self.sum / self.count
58 |
59 | def __str__(self):
60 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
61 | return fmtstr.format(**self.__dict__)
62 |
63 |
64 | class ProgressMeter(object):
65 | def __init__(self, num_batches, *meters, prefix="", logger=None):
66 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
67 | self.meters = meters
68 | self.prefix = prefix
69 | self.logger = logger
70 |
71 | def print(self, batch):
72 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
73 | entries += [str(meter) for meter in self.meters]
74 | if self.logger:
75 | self.logger.info('\t'.join(entries))
76 | else:
77 | print('\t'.join(entries))
78 |
79 | def _get_batch_fmtstr(self, num_batches):
80 |         num_digits = len(str(num_batches))
81 | fmt = '{:' + str(num_digits) + 'd}'
82 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
83 |
84 | def compute_eer(distances, labels):
85 | # Calculate evaluation metrics
86 | fprs, tprs, _ = roc_curve(labels, distances)
87 | eer = fprs[np.nanargmin(np.absolute((1 - tprs) - fprs))]
88 | return eer
89 |
90 |
91 | def accuracy(output, target, topk=(1,)):
92 | """Computes the accuracy over the k top predictions for the specified values of k"""
93 | with torch.no_grad():
94 | maxk = max(topk)
95 | batch_size = target.size(0)
96 |
97 | _, pred = output.topk(maxk, 1, True, True)
98 | pred = pred.t()
99 | correct = pred.eq(target.view(1, -1).expand_as(pred))
100 |
101 | res = []
102 | for k in topk:
103 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
104 | res.append(correct_k.mul_(100.0 / batch_size))
105 | return res
106 |
107 |
108 | def create_logger(log_dir, phase='train'):
109 | time_str = time.strftime('%Y-%m-%d-%H-%M')
110 | log_file = '{}_{}.log'.format(time_str, phase)
111 | final_log_file = os.path.join(log_dir, log_file)
112 | head = '%(asctime)-15s %(message)s'
113 | logging.basicConfig(filename=str(final_log_file),
114 | format=head)
115 | logger = logging.getLogger()
116 | logger.setLevel(logging.INFO)
117 | console = logging.StreamHandler()
118 | logging.getLogger('').addHandler(console)
119 |
120 | return logger
121 |
122 |
123 | def set_path(root_dir, exp_name):
124 | path_dict = {}
125 | os.makedirs(root_dir, exist_ok=True)
126 |
127 | # set log path
128 | exp_path = os.path.join(root_dir, exp_name)
129 | now = datetime.now(dateutil.tz.tzlocal())
130 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
131 | prefix = exp_path + '_' + timestamp
132 | os.makedirs(prefix)
133 | path_dict['prefix'] = prefix
134 |
135 | # set checkpoint path
136 | ckpt_path = os.path.join(prefix, 'Model')
137 | os.makedirs(ckpt_path)
138 | path_dict['ckpt_path'] = ckpt_path
139 |
140 | log_path = os.path.join(prefix, 'Log')
141 | os.makedirs(log_path)
142 | path_dict['log_path'] = log_path
143 |
144 | # set sample image path for fid calculation
145 | sample_path = os.path.join(prefix, 'Samples')
146 | os.makedirs(sample_path)
147 | path_dict['sample_path'] = sample_path
148 |
149 | return path_dict
150 |
151 |
152 | def to_item(x):
153 | """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
154 | if isinstance(x, (float, int)):
155 | return x
156 |
157 | if float(torch.__version__[0:3]) < 0.4:
158 | assert (x.dim() == 1) and (len(x) == 1)
159 | return x[0]
160 |
161 | return x.item()
162 |
163 |
164 | def save_checkpoint(states, is_best, output_dir,
165 | filename='checkpoint.pth'):
166 | torch.save(states, os.path.join(output_dir, filename))
167 | if is_best:
168 | torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))
169 |
170 | def drop_path(x, drop_prob):
171 | if drop_prob > 0.:
172 | keep_prob = 1.-drop_prob
173 | mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
174 | x.div_(keep_prob)
175 | x.mul_(mask)
176 | return x
177 |
178 |
179 | def gumbel_softmax(logits, tau=1, hard=True, eps=1e-10, dim=-1):
180 | # type: (Tensor, float, bool, float, int) -> Tensor
181 | """
182 | Samples from the `Gumbel-Softmax distribution`_ and optionally discretizes.
183 |
184 | Args:
185 | logits: `[..., num_features]` unnormalized log probabilities
186 | tau: non-negative scalar temperature
187 | hard: if ``True``, the returned samples will be discretized as one-hot vectors,
188 | but will be differentiated as if it is the soft sample in autograd
189 | dim (int): A dimension along which softmax will be computed. Default: -1.
190 |
191 | Returns:
192 | Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
193 | If ``hard=True``, the returned samples will be one-hot, otherwise they will
194 | be probability distributions that sum to 1 across `dim`.
195 |
196 | .. note::
197 | This function is here for legacy reasons, may be removed from nn.Functional in the future.
198 |
199 | .. note::
200 | The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft`
201 |
202 | It achieves two things:
203 | - makes the output value exactly one-hot
204 | (since we add then subtract y_soft value)
205 | - makes the gradient equal to y_soft gradient
206 | (since we strip all other gradients)
207 |
208 | Examples::
209 | >>> logits = torch.randn(20, 32)
210 | >>> # Sample soft categorical using reparametrization trick:
211 | >>> F.gumbel_softmax(logits, tau=1, hard=False)
212 | >>> # Sample hard categorical using "Straight-through" trick:
213 | >>> F.gumbel_softmax(logits, tau=1, hard=True)
214 |
215 | .. _Gumbel-Softmax distribution:
216 | https://arxiv.org/abs/1611.00712
217 | https://arxiv.org/abs/1611.01144
218 | """
219 | def _gen_gumbels():
220 | gumbels = -torch.empty_like(logits).exponential_().log()
221 | if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum():
222 |             # resample: a zero from the exponential draw would make log() return inf/nan
223 | gumbels = _gen_gumbels()
224 | return gumbels
225 |
226 | gumbels = _gen_gumbels() # ~Gumbel(0,1)
227 | gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
228 | y_soft = gumbels.softmax(dim)
229 |
230 | if hard:
231 | # Straight through.
232 | index = y_soft.max(dim, keepdim=True)[1]
233 | y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
234 | ret = y_hard - y_soft.detach() + y_soft
235 | else:
236 | # Reparametrization trick.
237 | ret = y_soft
238 |
239 | if torch.isnan(ret).sum():
240 | import ipdb
241 | ipdb.set_trace()
242 | raise OverflowError(f'gumbel softmax output: {ret}')
243 | return ret
--------------------------------------------------------------------------------
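A brief usage sketch of the gumbel_softmax helper above, assuming it is imported from utils (shapes and temperature are arbitrary):

import torch
from utils import gumbel_softmax

logits = torch.randn(4, 8, requires_grad=True)
sample = gumbel_softmax(logits, tau=1.0, hard=True)  # one-hot along the last dim
print(sample.sum(dim=-1))         # every row sums to 1
sample.sum().backward()           # gradients flow through the soft sample
print(logits.grad.shape)          # torch.Size([4, 8])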