├── models ├── __init__.py ├── convbnrelu.py └── convnet.py ├── .gitignore ├── main.py ├── utils.py ├── README.md ├── opts.py ├── train.py └── dataset.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from convnet import ConvNet as convnet 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python ### 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | ### PyCharm ### 8 | \.idea/ 9 | -------------------------------------------------------------------------------- /models/convbnrelu.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | import chainer.links as L 4 | 5 | 6 | class ConvBNReLU(chainer.Chain): 7 | def __init__(self, in_channels, out_channels, ksize, stride=1, pad=0, 8 | initialW=chainer.initializers.HeNormal(), nobias=True): 9 | super(ConvBNReLU, self).__init__( 10 | conv=L.Convolution2D(in_channels, out_channels, ksize, stride, pad, 11 | initialW=initialW, nobias=nobias), 12 | bn=L.BatchNormalization(out_channels, eps=1e-5) 13 | ) 14 | 15 | def __call__(self, x, train): 16 | h = self.conv(x) 17 | h = self.bn(h, test=not train) 18 | 19 | return F.relu(h) 20 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Between-class Learning for Image Classification. 
3 | Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada 4 | 5 | """ 6 | 7 | import sys 8 | import os 9 | import chainer 10 | 11 | import opts 12 | import models 13 | import dataset 14 | from train import Trainer 15 | 16 | 17 | def main(): 18 | opt = opts.parse() 19 | chainer.cuda.get_device_from_id(opt.gpu).use() 20 | for i in range(1, opt.nTrials + 1): 21 | print('+-- Trial {} --+'.format(i)) 22 | train(opt, i) 23 | 24 | 25 | def train(opt, trial): 26 | model = getattr(models, opt.netType)(opt.nClasses) 27 | model.to_gpu() 28 | optimizer = chainer.optimizers.NesterovAG(lr=opt.LR, momentum=opt.momentum) 29 | optimizer.setup(model) 30 | optimizer.add_hook(chainer.optimizer.WeightDecay(opt.weightDecay)) 31 | train_iter, val_iter = dataset.setup(opt) 32 | trainer = Trainer(model, optimizer, train_iter, val_iter, opt) 33 | 34 | for epoch in range(1, opt.nEpochs + 1): 35 | train_loss, train_top1 = trainer.train(epoch) 36 | val_top1 = trainer.val() 37 | sys.stderr.write('\r\033[K') 38 | sys.stdout.write( 39 | '| Epoch: {}/{} | Train: LR {} Loss {:.3f} top1 {:.2f} | Val: top1 {:.2f}\n'.format( 40 | epoch, opt.nEpochs, trainer.optimizer.lr, train_loss, train_top1, val_top1)) 41 | sys.stdout.flush() 42 | 43 | if opt.save != 'None': 44 | chainer.serializers.save_npz( 45 | os.path.join(opt.save, 'model_trial{}.npz'.format(trial)), model) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import chainer.functions as F 4 | 5 | 6 | def padding(pad): 7 | def f(image): 8 | return np.pad(image, ((0, 0), (pad, pad), (pad, pad)), 'constant') 9 | 10 | return f 11 | 12 | 13 | def random_crop(size): 14 | def f(image): 15 | _, h, w = image.shape 16 | p = random.randint(0, h - size) 17 | q = random.randint(0, w - size) 18 | return image[:, p: p + size, q: q + size] 
19 | 20 | return f 21 | 22 | 23 | def horizontal_flip(): 24 | def f(image): 25 | if random.randint(0, 1): 26 | image = image[:, :, ::-1] 27 | return image 28 | 29 | return f 30 | 31 | 32 | def normalize(mean, std): 33 | def f(image): 34 | return (image - mean[:, None, None]) / std[:, None, None] 35 | 36 | return f 37 | 38 | 39 | # For BC+ 40 | def zero_mean(mean, std): 41 | def f(image): 42 | image_mean = np.mean(image, keepdims=True) 43 | return (image - image_mean - mean[:, None, None]) / std[:, None, None] 44 | 45 | return f 46 | 47 | 48 | def kl_divergence(y, t): 49 | entropy = - F.sum(t[t.data.nonzero()] * F.log(t[t.data.nonzero()])) 50 | crossEntropy = - F.sum(t * F.log_softmax(y)) 51 | 52 | return (crossEntropy - entropy) / y.shape[0] 53 | 54 | 55 | def to_hms(time): 56 | h = int(time // 3600) 57 | m = int((time - h * 3600) // 60) 58 | s = int(time - h * 3600 - m * 60) 59 | if h > 0: 60 | line = '{}h{:02d}m'.format(h, m) 61 | else: 62 | line = '{}m{:02d}s'.format(m, s) 63 | 64 | return line 65 | -------------------------------------------------------------------------------- /models/convnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import chainer 3 | import chainer.functions as F 4 | import chainer.links as L 5 | from chainer.initializers import Uniform 6 | from convbnrelu import ConvBNReLU 7 | 8 | 9 | class ConvNet(chainer.Chain): 10 | def __init__(self, n_classes): 11 | super(ConvNet, self).__init__( 12 | conv11=ConvBNReLU(3, 64, 3, pad=1), 13 | conv12=ConvBNReLU(64, 64, 3, pad=1), 14 | conv21=ConvBNReLU(64, 128, 3, pad=1), 15 | conv22=ConvBNReLU(128, 128, 3, pad=1), 16 | conv31=ConvBNReLU(128, 256, 3, pad=1), 17 | conv32=ConvBNReLU(256, 256, 3, pad=1), 18 | conv33=ConvBNReLU(256, 256, 3, pad=1), 19 | conv34=ConvBNReLU(256, 256, 3, pad=1), 20 | fc4=L.Linear(256 * 4 * 4, 1024, initialW=Uniform(1. / math.sqrt(256 * 4 * 4))), 21 | fc5=L.Linear(1024, 1024, initialW=Uniform(1. 
/ math.sqrt(1024))), 22 | fc6=L.Linear(1024, n_classes, initialW=Uniform(1. / math.sqrt(1024))) 23 | ) 24 | self.train = True 25 | 26 | def __call__(self, x): 27 | h = self.conv11(x, self.train) 28 | h = self.conv12(h, self.train) 29 | h = F.max_pooling_2d(h, 2) 30 | 31 | h = self.conv21(h, self.train) 32 | h = self.conv22(h, self.train) 33 | h = F.max_pooling_2d(h, 2) 34 | 35 | h = self.conv31(h, self.train) 36 | h = self.conv32(h, self.train) 37 | h = self.conv33(h, self.train) 38 | h = self.conv34(h, self.train) 39 | h = F.max_pooling_2d(h, 2) 40 | 41 | h = F.dropout(F.relu(self.fc4(h)), train=self.train) 42 | h = F.dropout(F.relu(self.fc5(h)), train=self.train) 43 | 44 | return self.fc6(h) 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | BC learning for images 2 | ========================= 3 | 4 | Implementation of [Between-class Learning for Image Classification](https://arxiv.org/abs/1711.10284) by Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada. 5 | 6 | Our preliminary experimental results on CIFAR-10 and ImageNet-1K were already presented in ILSVRC2017 on July 26, 2017. 7 | 8 | #### Between-class (BC) learning: 9 | - We generate between-class examples by mixing two training examples belonging to different classes with a random ratio. 10 | - We then input the mixed data to the model and 11 | train the model to output the mixing ratio. 12 | - Original paper: [Learning from Between-class Examples for Deep Sound Recognition](https://arxiv.org/abs/1711.10282) by us ([github](https://github.com/mil-tokyo/bc_learning_sound)) 13 | 14 | ## Contents 15 | - BC learning for images 16 | - BC: mix two images simply using internal divisions. 17 | - BC+: mix two images treating them as waveforms. 
18 | - Training of 11-layer CNN on CIFAR datasets 19 | 20 | 21 | ## Setup 22 | - Install [Chainer](https://chainer.org/) v1.24 on a machine with a CUDA GPU. 23 | - Prepare the CIFAR datasets (the pickled batch files read by dataset.py: `data_batch_1`-`data_batch_5` and `test_batch` for CIFAR-10, `train` and `test` for CIFAR-100). 24 | 25 | 26 | ## Training 27 | - Template: 28 | 29 | python main.py --dataset [cifar10 or cifar100] --netType convnet --data path/to/dataset/directory/ (--BC) (--plus) 30 | 31 | - Recipes: 32 | - Standard learning on CIFAR-10 (around 6.1% error): 33 | 34 | python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ 35 | 36 | 37 | - BC learning on CIFAR-10 (around 5.4% error): 38 | 39 | python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC 40 | 41 | - BC+ learning on CIFAR-10 (around 5.2% error): 42 | 43 | python main.py --dataset cifar10 --netType convnet --data path/to/dataset/directory/ --BC --plus 44 | 45 | - Notes: 46 | - It uses the same data augmentation scheme as [fb.resnet.torch](https://github.com/facebook/fb.resnet.torch). 47 | - By default, it runs training 10 times. You can specify the number of trials with the --nTrials option. 48 | - Please check [opts.py](https://github.com/mil-tokyo/bc_learning_image/blob/master/opts.py) for other command-line arguments. 49 | 50 | ## Results 51 | 52 | Error rate (average of 10 trials) 53 | 54 | | Learning | CIFAR-10 | CIFAR-100 | 55 | |:--|:-:|:-:| 56 | | Standard | 6.07 | 26.68 | 57 | | BC (ours) | 5.40 | 24.28 | 58 | | BC+ (ours) | **5.22** | **23.68** | 59 | 60 | - Other results (please see [paper](https://arxiv.org/abs/1711.10284)): 61 | - The error rate of [Shake-Shake Regularization](https://github.com/xgastaldi/shake-shake) [[1]](#1) on CIFAR-10 was reduced from 2.86% to 2.26%. 62 | - The single-crop top-1 validation error of [ResNeXt](https://github.com/facebookresearch/ResNeXt) [[2]](#2) on ImageNet-1K was reduced from 20.4% to 19.4%. 63 | 64 | --- 65 | 66 | #### References 67 | [1] X. Gastaldi. Shake-shake regularization.
In *ICLR Workshop*, 2017. 68 | 69 | [2] S. Xie, R. Girshick, P. Dollar, Z. Tu, and K. He. Aggregated residual transformations for deep neural networks. In *CVPR*, 2017. 70 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def parse(): 6 | parser = argparse.ArgumentParser(description='BC learning for image classification') 7 | 8 | # General settings 9 | parser.add_argument('--dataset', required=True, choices=['cifar10', 'cifar100']) 10 | parser.add_argument('--netType', required=True, choices=['convnet']) 11 | parser.add_argument('--data', required=True, help='Path to dataset') 12 | parser.add_argument('--nTrials', type=int, default=10) 13 | parser.add_argument('--save', default='None', help='Directory to save the results') 14 | parser.add_argument('--gpu', type=int, default=0) 15 | 16 | # Learning settings 17 | parser.add_argument('--BC', action='store_true', help='BC learning') 18 | parser.add_argument('--plus', action='store_true', help='Use BC+') 19 | parser.add_argument('--nEpochs', type=int, default=-1) 20 | parser.add_argument('--LR', type=float, default=-1, help='Initial learning rate') 21 | parser.add_argument('--schedule', type=float, nargs='*', default=-1, help='When to divide the LR') 22 | parser.add_argument('--warmup', type=int, default=-1, help='Number of epochs to warm up') 23 | parser.add_argument('--batchSize', type=int, default=-1) 24 | parser.add_argument('--weightDecay', type=float, default=5e-4) 25 | parser.add_argument('--momentum', type=float, default=0.9) 26 | 27 | opt = parser.parse_args() 28 | if opt.plus and not opt.BC: 29 | raise Exception('Using only --plus option is invalid.') 30 | 31 | # Dataset details 32 | if opt.dataset == 'cifar10': 33 | opt.nClasses = 10 34 | else: # cifar100 35 | opt.nClasses = 100 36 | 37 | # Default settings 38 | default_settings = dict() 39 | 
default_settings['cifar10'] = { 40 | 'convnet': {'nEpochs': 250, 'LR': 0.1, 'schedule': [0.4, 0.6, 0.8], 'warmup': 0, 'batchSize': 128} 41 | } 42 | default_settings['cifar100'] = { 43 | 'convnet': {'nEpochs': 250, 'LR': 0.1, 'schedule': [0.4, 0.6, 0.8], 'warmup': 0, 'batchSize': 128} 44 | } 45 | for key in ['nEpochs', 'LR', 'schedule', 'warmup', 'batchSize']: 46 | if eval('opt.{}'.format(key)) == -1: 47 | setattr(opt, key, default_settings[opt.dataset][opt.netType][key]) 48 | 49 | if opt.save != 'None' and not os.path.isdir(opt.save): 50 | os.makedirs(opt.save) 51 | 52 | display_info(opt) 53 | 54 | return opt 55 | 56 | 57 | def display_info(opt): 58 | if opt.BC: 59 | if opt.plus: 60 | learning = 'BC+' 61 | else: 62 | learning = 'BC' 63 | else: 64 | learning = 'standard' 65 | 66 | print('+------------------------------+') 67 | print('| CIFAR classification') 68 | print('+------------------------------+') 69 | print('| dataset : {}'.format(opt.dataset)) 70 | print('| netType : {}'.format(opt.netType)) 71 | print('| learning : {}'.format(learning)) 72 | print('| nEpochs : {}'.format(opt.nEpochs)) 73 | print('| LRInit : {}'.format(opt.LR)) 74 | print('| schedule : {}'.format(opt.schedule)) 75 | print('| warmup : {}'.format(opt.warmup)) 76 | print('| batchSize: {}'.format(opt.batchSize)) 77 | print('+------------------------------+') 78 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import chainer 4 | from chainer import cuda 5 | import chainer.functions as F 6 | import time 7 | 8 | import utils 9 | 10 | 11 | class Trainer: 12 | def __init__(self, model, optimizer, train_iter, val_iter, opt): 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.train_iter = train_iter 16 | self.val_iter = val_iter 17 | self.opt = opt 18 | self.n_batches = (len(train_iter.dataset) - 1) // opt.batchSize + 1 
19 | self.start_time = time.time() 20 | 21 | def train(self, epoch): 22 | self.optimizer.lr = self.lr_schedule(epoch) 23 | train_loss = 0 24 | train_acc = 0 25 | for i, batch in enumerate(self.train_iter): 26 | x_array, t_array = chainer.dataset.concat_examples(batch) 27 | x = chainer.Variable(cuda.to_gpu(x_array)) 28 | t = chainer.Variable(cuda.to_gpu(t_array)) 29 | self.optimizer.zero_grads() 30 | y = self.model(x) 31 | if self.opt.BC: 32 | loss = utils.kl_divergence(y, t) 33 | acc = F.accuracy(y, F.argmax(t, axis=1)) 34 | else: 35 | loss = F.softmax_cross_entropy(y, t) 36 | acc = F.accuracy(y, t) 37 | 38 | loss.backward() 39 | self.optimizer.update() 40 | train_loss += float(loss.data) * len(t.data) 41 | train_acc += float(acc.data) * len(t.data) 42 | 43 | elapsed_time = time.time() - self.start_time 44 | progress = (self.n_batches * (epoch - 1) + i + 1) * 1.0 / (self.n_batches * self.opt.nEpochs) 45 | eta = elapsed_time / progress - elapsed_time 46 | 47 | line = '* Epoch: {}/{} ({}/{}) | Train: LR {} | Time: {} (ETA: {})'.format( 48 | epoch, self.opt.nEpochs, i + 1, self.n_batches, 49 | self.optimizer.lr, utils.to_hms(elapsed_time), utils.to_hms(eta)) 50 | sys.stderr.write('\r\033[K' + line) 51 | sys.stderr.flush() 52 | 53 | self.train_iter.reset() 54 | train_loss /= len(self.train_iter.dataset) 55 | train_top1 = 100 * (1 - train_acc / len(self.train_iter.dataset)) 56 | 57 | return train_loss, train_top1 58 | 59 | def val(self): 60 | self.model.train = False 61 | val_acc = 0 62 | for batch in self.val_iter: 63 | x_array, t_array = chainer.dataset.concat_examples(batch) 64 | x = chainer.Variable(cuda.to_gpu(x_array), volatile=True) 65 | t = chainer.Variable(cuda.to_gpu(t_array), volatile=True) 66 | y = F.softmax(self.model(x)) 67 | acc = F.accuracy(y, t) 68 | val_acc += float(acc.data) * len(t.data) 69 | 70 | self.val_iter.reset() 71 | self.model.train = True 72 | val_top1 = 100 * (1 - val_acc / len(self.val_iter.dataset)) 73 | 74 | return val_top1 75 | 76 | def 
lr_schedule(self, epoch): 77 | divide_epoch = np.array([self.opt.nEpochs * i for i in self.opt.schedule]) 78 | decay = sum(epoch > divide_epoch) 79 | if epoch <= self.opt.warmup: 80 | decay = 1 81 | 82 | return self.opt.LR * np.power(0.1, decay) 83 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import cPickle 5 | import chainer 6 | 7 | import utils as U 8 | 9 | 10 | class ImageDataset(chainer.dataset.DatasetMixin): 11 | def __init__(self, images, labels, opt, train=True): 12 | self.base = chainer.datasets.TupleDataset(images, labels) 13 | self.opt = opt 14 | self.train = train 15 | self.mix = (opt.BC and train) 16 | if opt.dataset == 'cifar10': 17 | if opt.plus: 18 | self.mean = np.array([4.60, 2.24, -6.84]) 19 | self.std = np.array([55.9, 53.7, 56.5]) 20 | else: 21 | self.mean = np.array([125.3, 123.0, 113.9]) 22 | self.std = np.array([63.0, 62.1, 66.7]) 23 | else: 24 | if opt.plus: 25 | self.mean = np.array([7.37, 2.13, -9.50]) 26 | self.std = np.array([57.6, 54.0, 58.5]) 27 | else: 28 | self.mean = np.array([129.3, 124.1, 112.4]) 29 | self.std = np.array([68.2, 65.4, 70.4]) 30 | 31 | self.preprocess_funcs = self.preprocess_setup() 32 | 33 | def __len__(self): 34 | return len(self.base) 35 | 36 | def preprocess_setup(self): 37 | if self.opt.plus: 38 | normalize = U.zero_mean 39 | else: 40 | normalize = U.normalize 41 | if self.train: 42 | funcs = [normalize(self.mean, self.std), 43 | U.horizontal_flip(), 44 | U.padding(4), 45 | U.random_crop(32), 46 | ] 47 | else: 48 | funcs = [normalize(self.mean, self.std)] 49 | 50 | return funcs 51 | 52 | def preprocess(self, image): 53 | for f in self.preprocess_funcs: 54 | image = f(image) 55 | 56 | return image 57 | 58 | def get_example(self, i): 59 | if self.mix: # Training phase of BC learning 60 | while True: # Select two training 
examples 61 | image1, label1 = self.base[random.randint(0, len(self.base) - 1)] 62 | image2, label2 = self.base[random.randint(0, len(self.base) - 1)] 63 | if label1 != label2: 64 | break 65 | image1 = self.preprocess(image1) 66 | image2 = self.preprocess(image2) 67 | 68 | # Mix two images 69 | r = np.array(random.random()) 70 | if self.opt.plus: 71 | g1 = np.std(image1) 72 | g2 = np.std(image2) 73 | p = 1.0 / (1 + g1 / g2 * (1 - r) / r) 74 | image = ((image1 * p + image2 * (1 - p)) / np.sqrt(p ** 2 + (1 - p) ** 2)).astype(np.float32) 75 | else: 76 | image = (image1 * r + image2 * (1 - r)).astype(np.float32) 77 | 78 | # Mix two labels 79 | eye = np.eye(self.opt.nClasses) 80 | label = (eye[label1] * r + eye[label2] * (1 - r)).astype(np.float32) 81 | 82 | else: # Training phase of standard learning or testing phase 83 | image, label = self.base[i] 84 | image = self.preprocess(image).astype(np.float32) 85 | label = np.array(label, dtype=np.int32) 86 | 87 | return image, label 88 | 89 | 90 | def setup(opt): 91 | def unpickle(fn): 92 | with open(fn, 'rb') as f: 93 | data = cPickle.load(f) 94 | return data 95 | 96 | if opt.dataset == 'cifar10': 97 | train = [unpickle(os.path.join(opt.data, 'data_batch_{}'.format(i))) for i in range(1, 6)] 98 | train_images = np.concatenate([d['data'] for d in train]).reshape((-1, 3, 32, 32)) 99 | train_labels = np.concatenate([d['labels'] for d in train]) 100 | val = unpickle(os.path.join(opt.data, 'test_batch')) 101 | val_images = val['data'].reshape((-1, 3, 32, 32)) 102 | val_labels = val['labels'] 103 | else: 104 | train = unpickle(os.path.join(opt.data, 'train')) 105 | train_images = train['data'].reshape(-1, 3, 32, 32) 106 | train_labels = train['fine_labels'] 107 | val = unpickle(os.path.join(opt.data, 'test')) 108 | val_images = val['data'].reshape((-1, 3, 32, 32)) 109 | val_labels = val['fine_labels'] 110 | 111 | # Iterator setup 112 | train_data = ImageDataset(train_images, train_labels, opt, train=True) 113 | val_data = 
ImageDataset(val_images, val_labels, opt, train=False) 114 | train_iter = chainer.iterators.MultiprocessIterator(train_data, opt.batchSize, repeat=False) 115 | val_iter = chainer.iterators.SerialIterator(val_data, opt.batchSize, repeat=False, shuffle=False) 116 | 117 | return train_iter, val_iter 118 | --------------------------------------------------------------------------------