├── README.md ├── data ├── __init__.py ├── cifar.py ├── mnist.py └── utils.py ├── example.sh ├── loss.py ├── main.py └── model.py /README.md: -------------------------------------------------------------------------------- 1 | # Co-teaching 2 | NeurIPS'18: Co-teaching: Robust Training of Deep Neural Networks with Extremely Noisy Labels (PyTorch implementation). 3 | 4 | Another related work in NeurIPS'18: 5 | 6 | [Masking: A New Perspective of Noisy Supervision](https://arxiv.org/abs/1805.08193) 7 | 8 | Code available: https://github.com/bhanML/Masking 9 | 10 | ======== 11 | 12 | This is the code for the paper: 13 | [Co-teaching: Robust Training of Deep Neural Networks with Extremely Noisy Labels](https://arxiv.org/abs/1804.06872) 14 | Bo Han*, Quanming Yao*, Xingrui Yu, Gang Niu, Miao Xu, Weihua Hu, Ivor Tsang, Masashi Sugiyama 15 | To be presented at [NeurIPS 2018](https://nips.cc/Conferences/2018/). 16 | 17 | If you find this code useful in your research then please cite 18 | ```bash 19 | @inproceedings{han2018coteaching, 20 | title={Co-teaching: Robust training of deep neural networks with extremely noisy labels}, 21 | author={Han, Bo and Yao, Quanming and Yu, Xingrui and Niu, Gang and Xu, Miao and Hu, Weihua and Tsang, Ivor and Sugiyama, Masashi}, 22 | booktitle={NeurIPS}, 23 | pages={8535--8545}, 24 | year={2018} 25 | } 26 | ``` 27 | 28 | ## Setups 29 | All code was developed and tested on a single machine equipped with an NVIDIA K80 GPU. 
The environment is as below: 30 | 31 | - CentOS 7.2 32 | - CUDA 8.0 33 | - Python 2.7.12 (Anaconda 4.1.1 64 bit) 34 | - PyTorch 0.3.0.post4 35 | - numpy 1.14.2 36 | 37 | Install PyTorch via: 38 | ```bash 39 | pip install http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp27-cp27mu-linux_x86_64.whl 40 | ``` 41 | 42 | ## Running Co-teaching on benchmark datasets (MNIST, CIFAR-10 and CIFAR-100) 43 | Here is an example: 44 | 45 | ```bash 46 | python main.py --dataset cifar10 --noise_type symmetric --noise_rate 0.5 47 | ``` 48 | 49 | ## Performance 50 | 51 | | (Flipping, Rate) | MNIST | CIFAR-10 | CIFAR-100 | 52 | | ---------------: | -----: | -------: | --------: | 53 | | (Pair, 45%) | 87.58% | 72.85% | 34.40% | 54 | | (Symmetry, 50%) | 91.68% | 74.49% | 41.23% | 55 | | (Symmetry, 20%) | 97.71% | 82.18% | 54.36% | 56 | 57 | Contact: Xingrui Yu (xingrui.yu@student.uts.edu.au); Bo Han (bo.han@riken.jp). 58 | 59 | ## AutoML 60 | Please check the automated machine learning (AutoML) version of Co-teaching in 61 | - Searching to Exploit Memorization Effect in Learning from Corrupted Labels. 
ICML-2020 [paper](https://arxiv.org/abs/1911.02377) [code](https://github.com/AutoML-4Paradigm/S2E) 62 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bhanML/Co-teaching/7c7fbe23e15e517db76a0882b6d108e4508e09d6/data/__init__.py -------------------------------------------------------------------------------- /data/cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | if sys.version_info[0] == 2: 8 | import cPickle as pickle 9 | else: 10 | import pickle 11 | 12 | import torch.utils.data as data 13 | from .utils import download_url, check_integrity, noisify 14 | 15 | class CIFAR10(data.Dataset): 16 | """`CIFAR10 `_ Dataset. 17 | 18 | Args: 19 | root (string): Root directory of dataset where directory 20 | ``cifar-10-batches-py`` exists or will be saved to if download is set to True. 21 | train (bool, optional): If True, creates dataset from training set, otherwise 22 | creates from test set. 23 | transform (callable, optional): A function/transform that takes in an PIL image 24 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 25 | target_transform (callable, optional): A function/transform that takes in the 26 | target and transforms it. 27 | download (bool, optional): If true, downloads the dataset from the internet and 28 | puts it in root directory. If dataset is already downloaded, it is not 29 | downloaded again. 
class CIFAR10(data.Dataset):
    """CIFAR-10 dataset with optional synthetic label noise on the train split.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is
            set to True.
        train (bool, optional): If True, creates dataset from training set,
            otherwise creates from test set.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        noise_type (string, optional): Noise model handed to ``noisify``
            (``'pairflip'`` or ``'symmetric'``). ``None`` and ``'clean'`` both
            mean "keep the original labels".
        noise_rate (float, optional): Corruption rate handed to ``noisify``.
        random_state (int, optional): Seed handed to ``noisify``.

    For the training split the instance exposes ``noise_or_not`` (boolean
    array, True where the label used for training equals the original label)
    and ``actual_noise_rate``; both are also set for clean data so callers can
    read them unconditionally.
    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]

    # Per-dataset constants; subclasses override these instead of __init__.
    dataset = 'cifar10'
    nb_classes = 10

    def __init__(self, root, train=True,
                 transform=None, target_transform=None,
                 download=False,
                 noise_type=None, noise_rate=0.2, random_state=0):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.noise_type = noise_type

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # now load the pickled numpy arrays
        if self.train:
            self.train_data = []
            self.train_labels = []
            for fentry in self.train_list:
                batch, labels = self._load_batch(fentry[0])
                self.train_data.append(batch)
                self.train_labels += labels

            self.train_data = np.concatenate(self.train_data)
            # -1 instead of a hard-coded sample count: same result for the
            # official files, but does not break on a differently sized split.
            self.train_data = self.train_data.reshape((-1, 3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC

            if self._has_label_noise():
                # noisify expects (and returns) labels as an (N, 1) column.
                self.train_labels = np.asarray([[label] for label in self.train_labels])
                noisy_labels, self.actual_noise_rate = noisify(
                    dataset=self.dataset, train_labels=self.train_labels,
                    noise_type=noise_type, noise_rate=noise_rate,
                    random_state=random_state, nb_classes=self.nb_classes)
                self.train_noisy_labels = [row[0] for row in noisy_labels]
                clean_labels = [row[0] for row in self.train_labels]
                self.noise_or_not = np.transpose(self.train_noisy_labels) == np.transpose(clean_labels)
            else:
                # Clean data: every label is "not noisy".
                self.actual_noise_rate = 0.0
                self.noise_or_not = np.ones(len(self.train_labels), dtype=bool)
        else:
            self.test_data, self.test_labels = self._load_batch(self.test_list[0][0])
            self.test_data = self.test_data.reshape((-1, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

    def _has_label_noise(self):
        """Return True when ``noise_type`` asks for corrupted training labels."""
        return self.noise_type is not None and self.noise_type != 'clean'

    def _load_batch(self, name):
        """Unpickle one batch file and return ``(data, labels)``.

        NOTE: pickle executes arbitrary code while loading; the file is only
        read after ``_check_integrity`` matched it against the known md5.
        """
        path = os.path.join(self.root, self.base_folder, name)
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                entry = pickle.load(fo)
            else:
                entry = pickle.load(fo, encoding='latin1')
        # CIFAR-10 stores 'labels', CIFAR-100 stores 'fine_labels'.
        labels = entry['labels'] if 'labels' in entry else entry['fine_labels']
        return entry['data'], labels

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is index of the target
            class (the noisy one for a noisy training split).
        """
        if self.train:
            if self._has_label_noise():
                img, target = self.train_data[index], self.train_noisy_labels[index]
            else:
                img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_integrity(self):
        """Return True when every batch file exists and matches its md5."""
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(self.root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download and unpack the archive into ``root`` unless already present."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        root = self.root
        download_url(self.url, root, self.filename, self.tgz_md5)

        # Extract straight into root; avoids changing the process-wide working
        # directory and closes the archive even if extraction fails.
        with tarfile.open(os.path.join(root, self.filename), "r:gz") as tar:
            tar.extractall(path=root)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


class CIFAR100(CIFAR10):
    """CIFAR-100 dataset with optional synthetic label noise on the train split.

    Same constructor arguments and behaviour as :class:`CIFAR10`; only the
    archive location, file lists and class count differ. Data lives in the
    ``cifar-100-python`` directory below ``root``.
    """
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]

    dataset = 'cifar100'
    nb_classes = 100
tmp = ' Target Transforms (if any): ' 352 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 353 | return fmt_str 354 | 355 | 356 | 357 | 358 | -------------------------------------------------------------------------------- /data/mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import errno 7 | import numpy as np 8 | import torch 9 | import codecs 10 | from .utils import noisify 11 | 12 | 13 | class MNIST(data.Dataset): 14 | """`MNIST `_ Dataset. 15 | 16 | Args: 17 | root (string): Root directory of dataset where ``processed/training.pt`` 18 | and ``processed/test.pt`` exist. 19 | train (bool, optional): If True, creates dataset from ``training.pt``, 20 | otherwise from ``test.pt``. 21 | download (bool, optional): If true, downloads the dataset from the internet and 22 | puts it in root directory. If dataset is already downloaded, it is not 23 | downloaded again. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 
class MNIST(data.Dataset):
    """MNIST dataset with optional synthetic label noise on the train split.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        noise_type (string, optional): Noise model handed to ``noisify``
            (``'pairflip'`` or ``'symmetric'``). ``None`` and ``'clean'`` both
            mean "keep the original labels".
        noise_rate (float, optional): Corruption rate handed to ``noisify``.
        random_state (int, optional): Seed handed to ``noisify``.

    For the training split the instance exposes ``noise_or_not`` (boolean
    array, True where the label used for training equals the original label)
    and ``actual_noise_rate``; both are also set for clean data.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False,
                 noise_type=None, noise_rate=0.2, random_state=0):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.dataset = 'mnist'
        self.noise_type = noise_type

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))

            if self._has_label_noise():
                # noisify expects (and returns) labels as an (N, 1) column of
                # plain integers.
                self.train_labels = np.asarray(self.train_labels.tolist()).reshape(-1, 1)
                noisy_labels, self.actual_noise_rate = noisify(
                    dataset=self.dataset, train_labels=self.train_labels,
                    noise_type=noise_type, noise_rate=noise_rate,
                    random_state=random_state)
                self.train_noisy_labels = [row[0] for row in noisy_labels]
                clean_labels = [row[0] for row in self.train_labels]
                self.noise_or_not = np.transpose(self.train_noisy_labels) == np.transpose(clean_labels)
            else:
                # Clean data: every label is "not noisy".
                self.actual_noise_rate = 0.0
                self.noise_or_not = np.ones(len(self.train_labels), dtype=bool)
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def _has_label_noise(self):
        """Return True when ``noise_type`` asks for corrupted training labels."""
        return self.noise_type is not None and self.noise_type != 'clean'

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target, index) where target is index of the target
            class (the noisy one for a noisy training split).
        """
        if self.train:
            if self._has_label_noise():
                img, target = self.train_data[index], self.train_noisy_labels[index]
            else:
                img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        """Return True when both processed tensor files are present."""
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # Create each folder independently: with a shared try block an already
        # existing raw folder would skip creating the processed folder.
        for folder in (self.raw_folder, self.processed_folder):
            try:
                os.makedirs(os.path.join(self.root, folder))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        for url in self.urls:
            print('Downloading ' + url)
            response = urllib.request.urlopen(url)
            try:
                payload = response.read()
            finally:
                response.close()
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(payload)
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


def get_int(b):
    """Interpret the byte string ``b`` as a big-endian unsigned integer."""
    return int(codecs.encode(b, 'hex'), 16)


def read_label_file(path):
    """Parse an IDX1 label file into a 1-D ``LongTensor``.

    Raises:
        ValueError: if the file does not start with the label magic number 2049.
    """
    with open(path, 'rb') as f:
        data = f.read()
    if get_int(data[:4]) != 2049:
        raise ValueError('%s is not an MNIST label file (bad magic number)' % path)
    length = get_int(data[4:8])
    # frombuffer returns a read-only view of the bytes; copy so the tensor
    # owns writable memory.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=8).copy()
    return torch.from_numpy(parsed).view(length).long()


def read_image_file(path):
    """Parse an IDX3 image file into a ``(length, rows, cols)`` byte tensor.

    Raises:
        ValueError: if the file does not start with the image magic number 2051.
    """
    with open(path, 'rb') as f:
        data = f.read()
    if get_int(data[:4]) != 2051:
        raise ValueError('%s is not an MNIST image file (bad magic number)' % path)
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    # frombuffer returns a read-only view of the bytes; copy so the tensor
    # owns writable memory.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16).copy()
    return torch.from_numpy(parsed).view(length, num_rows, num_cols)
get_int(data[12:16]) 190 | images = [] 191 | parsed = np.frombuffer(data, dtype=np.uint8, offset=16) 192 | return torch.from_numpy(parsed).view(length, num_rows, num_cols) 193 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import copy 4 | import hashlib 5 | import errno 6 | import numpy as np 7 | from numpy.testing import assert_array_almost_equal 8 | 9 | def check_integrity(fpath, md5): 10 | if not os.path.isfile(fpath): 11 | return False 12 | md5o = hashlib.md5() 13 | with open(fpath, 'rb') as f: 14 | # read in 1MB chunks 15 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 16 | md5o.update(chunk) 17 | md5c = md5o.hexdigest() 18 | if md5c != md5: 19 | return False 20 | return True 21 | 22 | 23 | def download_url(url, root, filename, md5): 24 | from six.moves import urllib 25 | 26 | root = os.path.expanduser(root) 27 | fpath = os.path.join(root, filename) 28 | 29 | try: 30 | os.makedirs(root) 31 | except OSError as e: 32 | if e.errno == errno.EEXIST: 33 | pass 34 | else: 35 | raise 36 | 37 | # downloads file 38 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 39 | print('Using downloaded and verified file: ' + fpath) 40 | else: 41 | try: 42 | print('Downloading ' + url + ' to ' + fpath) 43 | urllib.request.urlretrieve(url, fpath) 44 | except: 45 | if url[:5] == 'https': 46 | url = url.replace('https:', 'http:') 47 | print('Failed download. Trying https -> http instead.' 
def download_url(url, root, filename, md5):
    """Download ``url`` to ``root/filename`` unless a verified copy exists.

    Args:
        url (str): Location of the file.
        root (str): Target directory; created when missing.
        filename (str): Name of the file inside ``root``.
        md5 (str): Expected md5 of an already downloaded file. It is only used
            to decide whether the download can be skipped; the freshly
            downloaded file is not re-checked here.

    Raises:
        Exception: whatever the download raised, when it fails and there is no
            https -> http fallback left to try.
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename)

    try:
        os.makedirs(root)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    # downloads file
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath)
    except Exception:
        if url[:5] != 'https':
            # No fallback available: surface the failure instead of
            # silently continuing without the file.
            raise
        # NOTE: retrying over plain http gives up transport security.
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath)


def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [p for p in os.listdir(root)
                   if os.path.isdir(os.path.join(root, p))]

    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]

    return directories


def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [p for p in os.listdir(root)
             if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]

    if prefix is True:
        files = [os.path.join(root, d) for d in files]

    return files


# basic function
def multiclass_noisify(y, P, random_state=0):
    """Flip classes according to the transition probability matrix ``P``.

    Args:
        y (ndarray): Integer labels between 0 and the number of classes - 1,
            either as an ``(N, 1)`` column (what the datasets pass) or 1-D.
        P (ndarray): Row-stochastic ``(C, C)`` matrix; ``P[i, j]`` is the
            probability that class ``i`` is relabelled as ``j``.
        random_state (int, optional): Seed of the private random generator.

    Returns:
        ndarray: Copy of ``y`` with every label redrawn from its row of ``P``.
    """
    # Single-string prints behave the same on Python 2 and 3.
    print('%s %s' % (np.max(y), P.shape[0]))
    assert P.shape[0] == P.shape[1]
    assert np.max(y) < P.shape[0]

    # row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    m = y.shape[0]
    print(m)
    new_y = y.copy()
    flipper = np.random.RandomState(random_state)

    for idx in np.arange(m):
        label = np.ravel(y[idx])[0]
        # draw a one-hot vector; the position of the 1 is the new label
        flipped = flipper.multinomial(1, P[label], 1)[0]
        new_y[idx] = np.argmax(flipped)

    return new_y


# noisify_pairflip call the function "multiclass_noisify"
def noisify_pairflip(y_train, noise, random_state=None, nb_classes=10):
    """Corrupt labels by flipping class ``i`` to ``(i + 1) % nb_classes``.

    Args:
        y_train (ndarray): Clean labels, see ``multiclass_noisify``.
        noise (float): Probability that a label is flipped.
        random_state (int, optional): Seed handed to ``multiclass_noisify``.
        nb_classes (int, optional): Number of classes.

    Returns:
        tuple: ``(labels, actual_noise)``; for ``noise == 0`` the input labels
        are returned unchanged together with ``0.0``.
    """
    P = np.eye(nb_classes)
    n = noise
    actual_noise = 0.0  # stays 0 when no noise is requested

    if n > 0.0:
        for i in range(nb_classes):
            # keep with probability 1 - n, otherwise move to the next class
            # (the last class wraps around to class 0)
            P[i, i] = 1. - n
            P[i, (i + 1) % nb_classes] = n

        y_train_noisy = multiclass_noisify(y_train, P=P,
                                           random_state=random_state)
        actual_noise = (y_train_noisy != y_train).mean()
        assert actual_noise > 0.0
        print('Actual noise %.2f' % actual_noise)
        y_train = y_train_noisy
    print(P)

    return y_train, actual_noise


def noisify_multiclass_symmetric(y_train, noise, random_state=None, nb_classes=10):
    """Corrupt labels by flipping uniformly to any of the other classes.

    Args:
        y_train (ndarray): Clean labels, see ``multiclass_noisify``.
        noise (float): Probability that a label is flipped.
        random_state (int, optional): Seed handed to ``multiclass_noisify``.
        nb_classes (int, optional): Number of classes.

    Returns:
        tuple: ``(labels, actual_noise)``; for ``noise == 0`` the input labels
        are returned unchanged together with ``0.0``.
    """
    n = noise
    # off-diagonal mass n is shared equally by the other nb_classes - 1 classes
    P = (n / (nb_classes - 1)) * np.ones((nb_classes, nb_classes))
    np.fill_diagonal(P, 1. - n)
    actual_noise = 0.0  # stays 0 when no noise is requested

    if n > 0.0:
        y_train_noisy = multiclass_noisify(y_train, P=P,
                                           random_state=random_state)
        actual_noise = (y_train_noisy != y_train).mean()
        assert actual_noise > 0.0
        print('Actual noise %.2f' % actual_noise)
        y_train = y_train_noisy
    print(P)

    return y_train, actual_noise


def noisify(dataset='mnist', nb_classes=10, train_labels=None, noise_type=None, noise_rate=0, random_state=0):
    """Corrupt ``train_labels`` with the requested noise model.

    Args:
        dataset (str, optional): Dataset name; accepted for the callers' sake
            and not used here.
        nb_classes (int, optional): Number of classes.
        train_labels (ndarray): Clean labels, see ``multiclass_noisify``.
        noise_type (str, optional): ``'pairflip'`` or ``'symmetric'``;
            ``None`` and ``'clean'`` return the labels unchanged.
        noise_rate (float, optional): Probability that a label is flipped.
        random_state (int, optional): Seed for the label flipping.

    Returns:
        tuple: ``(train_noisy_labels, actual_noise_rate)``.

    Raises:
        ValueError: if ``noise_type`` is not one of the supported values.
    """
    if noise_type is None or noise_type == 'clean':
        return train_labels, 0.0
    if noise_type == 'pairflip':
        return noisify_pairflip(train_labels, noise_rate,
                                random_state=random_state, nb_classes=nb_classes)
    if noise_type == 'symmetric':
        return noisify_multiclass_symmetric(train_labels, noise_rate,
                                            random_state=random_state, nb_classes=nb_classes)
    raise ValueError("Unknown noise_type %r; expected 'pairflip', 'symmetric' or 'clean'" % (noise_type,))
- n 163 | 164 | y_train_noisy = multiclass_noisify(y_train, P=P, 165 | random_state=random_state) 166 | actual_noise = (y_train_noisy != y_train).mean() 167 | assert actual_noise > 0.0 168 | print('Actual noise %.2f' % actual_noise) 169 | y_train = y_train_noisy 170 | print P 171 | 172 | return y_train, actual_noise 173 | 174 | def noisify(dataset='mnist', nb_classes=10, train_labels=None, noise_type=None, noise_rate=0, random_state=0): 175 | if noise_type == 'pairflip': 176 | train_noisy_labels, actual_noise_rate = noisify_pairflip(train_labels, noise_rate, random_state=0, nb_classes=nb_classes) 177 | if noise_type == 'symmetric': 178 | train_noisy_labels, actual_noise_rate = noisify_multiclass_symmetric(train_labels, noise_rate, random_state=0, nb_classes=nb_classes) 179 | return train_noisy_labels, actual_noise_rate 180 | -------------------------------------------------------------------------------- /example.sh: -------------------------------------------------------------------------------- 1 | python main.py --dataset mnist --noise_type pairflip --noise_rate 0.45 2 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | # Loss functions 8 | def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not): 9 | loss_1 = F.cross_entropy(y_1, t, reduce = False) 10 | ind_1_sorted = np.argsort(loss_1.data).cuda() 11 | loss_1_sorted = loss_1[ind_1_sorted] 12 | 13 | loss_2 = F.cross_entropy(y_2, t, reduce = False) 14 | ind_2_sorted = np.argsort(loss_2.data).cuda() 15 | loss_2_sorted = loss_2[ind_2_sorted] 16 | 17 | remember_rate = 1 - forget_rate 18 | num_remember = int(remember_rate * len(loss_1_sorted)) 19 | 20 | pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_sorted[:num_remember]]])/float(num_remember) 
21 | pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_sorted[:num_remember]]])/float(num_remember) 22 | 23 | ind_1_update=ind_1_sorted[:num_remember] 24 | ind_2_update=ind_2_sorted[:num_remember] 25 | # exchange 26 | loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update]) 27 | loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update]) 28 | 29 | return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2 30 | 31 | 32 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torchvision.transforms as transforms 8 | from data.cifar import CIFAR10, CIFAR100 9 | from data.mnist import MNIST 10 | from model import CNN 11 | import argparse, sys 12 | import numpy as np 13 | import datetime 14 | import shutil 15 | 16 | from loss import loss_coteaching 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--lr', type = float, default = 0.001) 20 | parser.add_argument('--result_dir', type = str, help = 'dir to save result txt files', default = 'results/') 21 | parser.add_argument('--noise_rate', type = float, help = 'corruption rate, should be less than 1', default = 0.2) 22 | parser.add_argument('--forget_rate', type = float, help = 'forget rate', default = None) 23 | parser.add_argument('--noise_type', type = str, help='[pairflip, symmetric]', default='pairflip') 24 | parser.add_argument('--num_gradual', type = int, default = 10, help='how many epochs for linear drop rate, can be 5, 10, 15. This parameter is equal to Tk for R(T) in Co-teaching paper.') 25 | parser.add_argument('--exponent', type = float, default = 1, help='exponent of the forget rate, can be 0.5, 1, 2. 
This parameter is equal to c in Tc for R(T) in Co-teaching paper.') 26 | parser.add_argument('--top_bn', action='store_true') 27 | parser.add_argument('--dataset', type = str, help = 'mnist, cifar10, or cifar100', default = 'mnist') 28 | parser.add_argument('--n_epoch', type=int, default=200) 29 | parser.add_argument('--seed', type=int, default=1) 30 | parser.add_argument('--print_freq', type=int, default=50) 31 | parser.add_argument('--num_workers', type=int, default=4, help='how many subprocesses to use for data loading') 32 | parser.add_argument('--num_iter_per_epoch', type=int, default=400) 33 | parser.add_argument('--epoch_decay_start', type=int, default=80) 34 | 35 | args = parser.parse_args() 36 | 37 | # Seed 38 | torch.manual_seed(args.seed) 39 | torch.cuda.manual_seed(args.seed) 40 | 41 | # Hyper Parameters 42 | batch_size = 128 43 | learning_rate = args.lr 44 | 45 | # load dataset 46 | if args.dataset=='mnist': 47 | input_channel=1 48 | num_classes=10 49 | args.top_bn = False 50 | args.epoch_decay_start = 80 51 | args.n_epoch = 200 52 | train_dataset = MNIST(root='./data/', 53 | download=True, 54 | train=True, 55 | transform=transforms.ToTensor(), 56 | noise_type=args.noise_type, 57 | noise_rate=args.noise_rate 58 | ) 59 | 60 | test_dataset = MNIST(root='./data/', 61 | download=True, 62 | train=False, 63 | transform=transforms.ToTensor(), 64 | noise_type=args.noise_type, 65 | noise_rate=args.noise_rate 66 | ) 67 | 68 | if args.dataset=='cifar10': 69 | input_channel=3 70 | num_classes=10 71 | args.top_bn = False 72 | args.epoch_decay_start = 80 73 | args.n_epoch = 200 74 | train_dataset = CIFAR10(root='./data/', 75 | download=True, 76 | train=True, 77 | transform=transforms.ToTensor(), 78 | noise_type=args.noise_type, 79 | noise_rate=args.noise_rate 80 | ) 81 | 82 | test_dataset = CIFAR10(root='./data/', 83 | download=True, 84 | train=False, 85 | transform=transforms.ToTensor(), 86 | noise_type=args.noise_type, 87 | noise_rate=args.noise_rate 88 | ) 89 | 
90 | if args.dataset=='cifar100': 91 | input_channel=3 92 | num_classes=100 93 | args.top_bn = False 94 | args.epoch_decay_start = 100 95 | args.n_epoch = 200 96 | train_dataset = CIFAR100(root='./data/', 97 | download=True, 98 | train=True, 99 | transform=transforms.ToTensor(), 100 | noise_type=args.noise_type, 101 | noise_rate=args.noise_rate 102 | ) 103 | 104 | test_dataset = CIFAR100(root='./data/', 105 | download=True, 106 | train=False, 107 | transform=transforms.ToTensor(), 108 | noise_type=args.noise_type, 109 | noise_rate=args.noise_rate 110 | ) 111 | 112 | if args.forget_rate is None: 113 | forget_rate=args.noise_rate 114 | else: 115 | forget_rate=args.forget_rate 116 | 117 | noise_or_not = train_dataset.noise_or_not 118 | 119 | # Adjust learning rate and betas for Adam Optimizer 120 | mom1 = 0.9 121 | mom2 = 0.1 122 | alpha_plan = [learning_rate] * args.n_epoch 123 | beta1_plan = [mom1] * args.n_epoch 124 | for i in range(args.epoch_decay_start, args.n_epoch): 125 | alpha_plan[i] = float(args.n_epoch - i) / (args.n_epoch - args.epoch_decay_start) * learning_rate 126 | beta1_plan[i] = mom2 127 | 128 | def adjust_learning_rate(optimizer, epoch): 129 | for param_group in optimizer.param_groups: 130 | param_group['lr']=alpha_plan[epoch] 131 | param_group['betas']=(beta1_plan[epoch], 0.999) # Only change beta1 132 | 133 | # define drop rate schedule 134 | rate_schedule = np.ones(args.n_epoch)*forget_rate 135 | rate_schedule[:args.num_gradual] = np.linspace(0, forget_rate**args.exponent, args.num_gradual) 136 | 137 | save_dir = args.result_dir +'/' +args.dataset+'/coteaching/' 138 | 139 | if not os.path.exists(save_dir): 140 | os.system('mkdir -p %s' % save_dir) 141 | 142 | model_str=args.dataset+'_coteaching_'+args.noise_type+'_'+str(args.noise_rate) 143 | 144 | txtfile=save_dir+"/"+model_str+".txt" 145 | nowTime=datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 146 | if os.path.exists(txtfile): 147 | os.system('mv %s %s' % (txtfile, txtfile+".bak-%s" % 
def accuracy(logit, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        logit: class scores with one row per sample (the code takes the
            top-k along dimension 1).
        target: ground-truth class index per sample.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k`` holding the
        percentage (0-100) of samples whose target is among the ``k``
        highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Softmax is monotonic, so ranking the raw logits gives the same top-k
    # without the extra pass (and without ties created by underflow).
    _, pred = logit.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout of
        # ``pred``; a bare ``view(-1)`` raises on such a slice for k > 1 on
        # current PyTorch, so make it contiguous first.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def _count_correct(model, loader):
    """Return ``(correct, total)`` top-1 counts of ``model`` over ``loader``."""
    model.eval()  # Change model to 'eval' mode.
    # Follow the device the model lives on instead of assuming a GPU.
    use_cuda = any(p.is_cuda for p in model.parameters())
    correct = 0
    total = 0
    for images, labels, _ in loader:
        images = Variable(images)
        if use_cuda:
            images = images.cuda()
        logits = model(images)
        # argmax of the logits equals argmax of their softmax.
        _, pred = torch.max(logits.data, 1)
        total += labels.size(0)
        correct += (pred.cpu() == labels).sum()
    return correct, total


# Evaluate the Model
def evaluate(test_loader, model1, model2):
    """Evaluate both networks on ``test_loader``.

    Each model is switched to eval mode and run over the whole loader, which
    must yield ``(images, labels, index)`` triples.

    Returns:
        ``(acc1, acc2)``: top-1 accuracy in percent for ``model1`` and
        ``model2``.  An empty loader (e.g. ``drop_last=True`` with fewer
        samples than one batch) yields ``0.0`` instead of dividing by zero.
    """
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    print('Evaluating %s...' % model_str)

    correct1, total1 = _count_correct(model1, test_loader)
    correct2, total2 = _count_correct(model2, test_loader)

    acc1 = 100 * float(correct1) / float(total1) if total1 else 0.0
    acc2 = 100 * float(correct2) / float(total2) if total2 else 0.0
    return acc1, acc2
260 | cnn1 = CNN(input_channel=input_channel, n_outputs=num_classes) 261 | cnn1.cuda() 262 | print cnn1.parameters 263 | optimizer1 = torch.optim.Adam(cnn1.parameters(), lr=learning_rate) 264 | 265 | cnn2 = CNN(input_channel=input_channel, n_outputs=num_classes) 266 | cnn2.cuda() 267 | print cnn2.parameters 268 | optimizer2 = torch.optim.Adam(cnn2.parameters(), lr=learning_rate) 269 | 270 | mean_pure_ratio1=0 271 | mean_pure_ratio2=0 272 | 273 | with open(txtfile, "a") as myfile: 274 | myfile.write('epoch: train_acc1 train_acc2 test_acc1 test_acc2 pure_ratio1 pure_ratio2\n') 275 | 276 | epoch=0 277 | train_acc1=0 278 | train_acc2=0 279 | # evaluate models with random weights 280 | test_acc1, test_acc2=evaluate(test_loader, cnn1, cnn2) 281 | print('Epoch [%d/%d] Test Accuracy on the %s test images: Model1 %.4f %% Model2 %.4f %% Pure Ratio1 %.4f %% Pure Ratio2 %.4f %%' % (epoch+1, args.n_epoch, len(test_dataset), test_acc1, test_acc2, mean_pure_ratio1, mean_pure_ratio2)) 282 | # save results 283 | with open(txtfile, "a") as myfile: 284 | myfile.write(str(int(epoch)) + ': ' + str(train_acc1) +' ' + str(train_acc2) +' ' + str(test_acc1) + " " + str(test_acc2) + ' ' + str(mean_pure_ratio1) + ' ' + str(mean_pure_ratio2) + "\n") 285 | 286 | # training 287 | for epoch in range(1, args.n_epoch): 288 | # train models 289 | cnn1.train() 290 | adjust_learning_rate(optimizer1, epoch) 291 | cnn2.train() 292 | adjust_learning_rate(optimizer2, epoch) 293 | train_acc1, train_acc2, pure_ratio_1_list, pure_ratio_2_list=train(train_loader, epoch, cnn1, optimizer1, cnn2, optimizer2) 294 | # evaluate models 295 | test_acc1, test_acc2=evaluate(test_loader, cnn1, cnn2) 296 | # save results 297 | mean_pure_ratio1 = sum(pure_ratio_1_list)/len(pure_ratio_1_list) 298 | mean_pure_ratio2 = sum(pure_ratio_2_list)/len(pure_ratio_2_list) 299 | print('Epoch [%d/%d] Test Accuracy on the %s test images: Model1 %.4f %% Model2 %.4f %%, Pure Ratio 1 %.4f %%, Pure Ratio 2 %.4f %%' % (epoch+1, 
def call_bn(bn, x):
    """Apply the batch-norm layer ``bn`` to ``x``."""
    return bn(x)


class CNN(nn.Module):
    """Nine-layer convolutional classifier.

    Three groups of three 3x3 convolutions, each followed by batch norm and
    leaky ReLU.  The first two groups end with 2x2 max-pooling and channel
    dropout; the last group uses unpadded convolutions and is followed by
    global average pooling and a linear layer producing ``n_outputs`` logits.

    Args:
        input_channel: number of channels of the input images.
        n_outputs: number of classes (size of the logit vector).
        dropout_rate: drop probability of the two ``dropout2d`` stages.
        top_bn: if True, batch-normalise the final logits.
    """

    def __init__(self, input_channel=3, n_outputs=10, dropout_rate=0.25, top_bn=False):
        # Initialise nn.Module before assigning any attribute.
        super(CNN, self).__init__()
        self.dropout_rate = dropout_rate
        self.top_bn = top_bn
        self.c1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.c5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)
        self.c8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
        self.c9 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)
        self.l_c1 = nn.Linear(128, n_outputs)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(256)
        self.bn6 = nn.BatchNorm2d(256)
        self.bn7 = nn.BatchNorm2d(512)
        self.bn8 = nn.BatchNorm2d(256)
        self.bn9 = nn.BatchNorm2d(128)
        if top_bn:
            # forward() uses self.bn_c1 when top_bn is set, but it was never
            # created.  Built last and only on request so the default
            # configuration keeps its parameters and initialisation order.
            self.bn_c1 = nn.BatchNorm1d(n_outputs)

    def forward(self, x,):
        """Return the class logits for the image batch ``x``."""
        h = x
        h = self.c1(h)
        h = F.leaky_relu(call_bn(self.bn1, h), negative_slope=0.01)
        h = self.c2(h)
        h = F.leaky_relu(call_bn(self.bn2, h), negative_slope=0.01)
        h = self.c3(h)
        h = F.leaky_relu(call_bn(self.bn3, h), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        # Tie dropout to the module mode.  The functional default for
        # ``training`` differs between PyTorch releases, so relying on it
        # either never drops or keeps dropping during evaluation.
        h = F.dropout2d(h, p=self.dropout_rate, training=self.training)

        h = self.c4(h)
        h = F.leaky_relu(call_bn(self.bn4, h), negative_slope=0.01)
        h = self.c5(h)
        h = F.leaky_relu(call_bn(self.bn5, h), negative_slope=0.01)
        h = self.c6(h)
        h = F.leaky_relu(call_bn(self.bn6, h), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate, training=self.training)

        h = self.c7(h)
        h = F.leaky_relu(call_bn(self.bn7, h), negative_slope=0.01)
        h = self.c8(h)
        h = F.leaky_relu(call_bn(self.bn8, h), negative_slope=0.01)
        h = self.c9(h)
        h = F.leaky_relu(call_bn(self.bn9, h), negative_slope=0.01)
        # Global average pooling over the remaining spatial extent.
        h = F.avg_pool2d(h, kernel_size=h.data.shape[2])

        h = h.view(h.size(0), h.size(1))
        logit = self.l_c1(h)
        if self.top_bn:
            logit = call_bn(self.bn_c1, logit)
        return logit