├── .DS_Store ├── README.md ├── __init__.py ├── ban ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── ban.cpython-36.pyc ├── ban.py ├── hintonDistill.py ├── pipeline.py └── train_ban.sh ├── common ├── __init__.py └── __pycache__ │ └── __init__.cpython-37.pyc ├── core ├── KSNet.py ├── KSNet.pyc ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── KSNet.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── build_model.cpython-37.pyc │ ├── ensemble.cpython-37.pyc │ ├── resnet2.cpython-37.pyc │ └── smallnet.cpython-37.pyc ├── base.py ├── base.pyc ├── build_model.py ├── build_model.pyc ├── cifar_model_zoo.py ├── densenet.py ├── ensemble.py ├── ensemble.pyc ├── exp1.py ├── inception.py ├── resnet.py ├── resnet.pyc ├── resnet2.py ├── resnet2.pyc ├── smallnet.py ├── smallnet.pyc └── vgg.py ├── coteaching ├── coteaching.py └── loss.py ├── distill ├── __init__.py ├── __init__.pyc ├── distill.py ├── distill.pyc ├── loss.py ├── loss.pyc ├── matching.py └── toyconv.py ├── mentornet └── mentornet.py ├── onthefly ├── ONE.py └── trainone.py └── train.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KnowledgeSharing-Pytorch 2 | This repository is maintained to implement some state-of-the-art knowledge distillation and knowledge transfer methods. 3 | 4 | ## ToDo List 5 | 6 | 7 | ## Knowledge Distillation (KD) 8 | Knowledge distillation was proposed to distill knowledge from a large teacher network to a smaller student network. KD can help the student model to achieve higher generalization performance. It's applications include model compression. 
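For reference, a minimal sketch of the classic soft-target objective is shown below (this repo's `distill/loss.py` implements an MSE-on-scaled-softmax variant; the temperature `T` and weight `alpha` here are illustrative, not this repo's defaults):

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, target, T=2.0, alpha=0.5):
    # supervised term: ordinary cross-entropy on the hard labels
    hard_loss = F.cross_entropy(student_logits, target)
    # distillation term: KL between temperature-softened student and teacher distributions
    soft_loss = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                         F.softmax(teacher_logits / T, dim=1),
                         reduction='batchmean') * (T * T)
    return (1 - alpha) * hard_loss + alpha * soft_loss

# toy check with random logits
student, teacher = torch.randn(4, 10), torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
print(kd_loss(student, teacher.detach(), labels))
```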
9 | 10 | ## Knowledge Transfer (KT) 11 | 12 | 13 | 14 | ### Model List 15 | - Basic knowledge distillation 16 | - Born-Again Neural Networks 17 | - Knowledge Transfer with Jacobian Matching 18 | - Deep Mutual Learning 19 | - Co-teaching 20 | - On-the-fly Native Ensemble 21 | - MentorNet 22 | 23 | 24 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/__init__.py -------------------------------------------------------------------------------- /ban/__init__.py: -------------------------------------------------------------------------------- 1 | from .ban import BAN 2 | 3 | __all__ = ['BAN'] -------------------------------------------------------------------------------- /ban/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/ban/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ban/__pycache__/ban.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/ban/__pycache__/ban.cpython-36.pyc -------------------------------------------------------------------------------- /ban/ban.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | sys.path.append("..") 4 | 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class BAN(nn.Module): 12 | 13 | def __init__(self, teacher_model, student_model=None, basic_criterion=None, 14 | c=1.0, ckpt_dir='./checkpoint/ban/'): 15 | super(BAN, self).__init__() 16 | 17 | self.teacher = teacher_model 18 | self.student = student_model 19 | 20 | self.iter = 0 21 | self.ckpt_dir = ckpt_dir 22 | self.teacher_acc = 0.0 23 | self.student_acc = 0.0 24 | self.c = c 25 | self.basic_criterion = basic_criterion  # optional; _common_fn uses a zero supervised term when this is None 26 | 27 | def forward(self, x): 28 | if self.student is None: 29 | return self.teacher(x) 30 | else: 31 | with torch.no_grad(): 32 | tea_p = self.teacher(x) 33 | stu_p = self.student(x) 34 | return [tea_p, stu_p] 35 | 36 | @staticmethod 37 | def _save_checkpoint(ckpt_dir, model, acc, iter, name): 38 | state = {'state_dict': model.state_dict(), 39 | 'acc': acc, 40 | 'iter': iter} 41 | if not os.path.exists(ckpt_dir): 42 | os.makedirs(ckpt_dir) 43 | 44 | torch.save(state, os.path.join(ckpt_dir, name)) 45 | 46 | 47 | def expand(self, model, tea_acc): 48 | #if self.student is None: 49 | # raise ValueError( 50 | # 'There is no well trained student available for expand!') 51 | # save current teacher to disk 52 | teacher_ckpt_path = 'teacher_%i.pt' % self.iter 53 | BAN._save_checkpoint(self.ckpt_dir, self.teacher, 54 | self.teacher_acc, self.iter, teacher_ckpt_path) 55 | if self.student is not None: 56 | self.teacher = self.student 57 | self.student = model 58 | self.iter += 1 59 | self.teacher_acc = tea_acc 60 | 61 | 62 | def save(self, stu_acc): 63 | self.student_acc = stu_acc 64 | # save current teacher and student model to disk 65 | teacher_ckpt_path = 'teacher_%i.pt' % self.iter 66 | 
student_ckpt_path = 'student_%i.pt' % self.iter 67 | 68 | BAN._save_checkpoint(self.ckpt_dir, self.teacher, 69 | self.teacher_acc, self.iter, teacher_ckpt_path) 70 | BAN._save_checkpoint(self.ckpt_dir, self.student, 71 | self.student_acc, self.iter, student_ckpt_path) 72 | 73 | @staticmethod 74 | def _soft_cross_entropy(x, y): 75 | b = - torch.sum(F.softmax(y, dim=1) * F.log_softmax(x, dim=1), dim=1) 76 | 77 | return b.mean() 78 | 79 | def _common_fn(self, pred, target): 80 | if isinstance(pred, list): 81 | tea_logits, stu_logits = pred[0], pred[1] 82 | if self.basic_criterion is not None: 83 | basic_loss = self.basic_criterion(stu_logits, target) 84 | else: 85 | basic_loss = 0 86 | # distill loss 87 | distill_loss = BAN._soft_cross_entropy(stu_logits, tea_logits.detach()) 88 | return basic_loss + self.c * distill_loss 89 | else: 90 | return F.cross_entropy(pred, target) 91 | 92 | def _cwtm_fn(self, pred, target): 93 | if not isinstance(pred, list): 94 | raise ValueError('cwtm mode needs a student model') 95 | # 96 | tea_logits, stu_logits = pred[0], pred[1] 97 | tea_prob = F.softmax(tea_logits.detach(), dim=1) 98 | # reweight the samples 99 | mvalue, _ = tea_prob.max(1) 100 | weight = mvalue / mvalue.sum() 101 | # compute the weighted cross entropy 102 | b = tea_prob * F.log_softmax(stu_logits, dim=1) 103 | 104 | b = weight.unsqueeze(1) * b 105 | distill_loss = -1.0 * b.sum() 106 | basic_loss = self.basic_criterion(stu_logits, target) 107 | return basic_loss + self.c * distill_loss 108 | 109 | @staticmethod 110 | def _permutate_non_argmax(x): 111 | m, n = x.size(0), x.size(1) 112 | x_max, x_idx = x.max(1) 113 | idx = torch.stack([torch.randperm(n) 114 | for _ in range(m)]).long().to(x.device) 115 | y = torch.zeros(x.size()).to(x.device) 116 | y.scatter_(1, idx, x) 117 | y_idx = idx.gather(1, x_idx.view(-1, 1)) 118 | y_place = y.gather(1, x_idx.view(-1, 1)) 119 | 120 | y.scatter_(1, y_idx.view(-1, 1), y_place.view(-1, 1)) 121 | y.scatter_(1, x_idx.view(-1, 1), x_max.view(-1, 1)) 122 | return y 123 | 124 | def _dkpp_fn(self, pred, target): 125 | if not isinstance(pred, list): 126 | raise ValueError('dkpp mode needs a student model') 127 | tea_logits, stu_logits = pred[0], pred[1] 128 | tea_prob = F.softmax(tea_logits.detach(), dim=1) 129 | # permute the non-argmax entries of the teacher distribution 130 | permuted_tea_prob = BAN._permutate_non_argmax(tea_prob) 131 | 132 | # calculate loss 133 | basic_loss = self.basic_criterion(stu_logits, target) 134 | distill_loss = permuted_tea_prob * F.log_softmax(stu_logits, dim=1) 135 | distill_loss = -1.0 * distill_loss.sum() 136 | 137 | return basic_loss + self.c * distill_loss 138 | 139 | def loss(self, pred, target, mode='common'): 140 | ''' 141 | compute loss with three modes: 142 | common: the same as Hinton knowledge distillation without softened targets 143 | cwtm: confidence weighted by teacher max (reweight each sample by the teacher's max probability) 144 | dkpp: dark knowledge with permuted predictions (permute the non-argmax targets) 145 | ''' 146 | if mode == 'common': 147 | return self._common_fn(pred, target) 148 | elif mode == 'cwtm': 149 | return self._cwtm_fn(pred, target) 150 | elif mode == 'dkpp': 151 | return self._dkpp_fn(pred, target) 152 | else: 153 | raise ValueError( 154 | 'Not supported mode. 
Only "common", "cwtm","dkpp" for selection') 155 | -------------------------------------------------------------------------------- /ban/hintonDistill.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import torchvision.transforms as transforms 8 | import sys 9 | import os 10 | sys.path.append('../') 11 | 12 | import KnowledgeSharing.distill as distill 13 | from KnowledgeSharing.utils.trainutils import AverageMeter 14 | from KnowledgeSharing.utils.metricutils import accuracy 15 | import KnowledgeSharing.utils as utils 16 | import KnowledgeSharing.core as core 17 | import KnowledgeSharing.common as common_args 18 | import time 19 | from tqdm import tqdm 20 | 21 | from copy import deepcopy 22 | from tensorboardX import SummaryWriter 23 | 24 | 25 | args = common_args.parse() 26 | des = '-'.join([args.memo, '_hinton_distll_'+args.model]+ list(map(str,[args.nmodel, args.depth,args.T, args.lr, args.batch_size])) + [utils.random_string(6)]) 27 | writer = SummaryWriter(os.path.join('runs',des)) 28 | 29 | teacher_model = core.build_bottleneck_resnet(args.teach_depth, args.nb_classes) 30 | teacher_model.to(args.devices) 31 | ckpt_name = os.path.join(args.ckpt_dir, '%s_%i_tea_best_checkpoint.pt'%(args.model, args.teach_depth)) 32 | print('try loading ckpt',ckpt_name) 33 | if os.path.exists(ckpt_name): 34 | state_ = torch.load(ckpt_name) 35 | teacher_model.load_state_dict(state_['state_dict']) 36 | else: 37 | # training from scratch 38 | raise Exception('checkpoint not found') 39 | 40 | student_model = core.build_bottleneck_resnet(args.depth, args.nb_classes) 41 | student_model.to(args.devices) 42 | model = distill.DistillModel(teacher_model, student_model) 43 | optimizer, lr_scheduler = utils.trainutils.build_optimizer_scheduler(model, args) 44 | 45 | 46 | def compose_name(*args): 47 | return '-'.join(args) 48 | 49 | @model.loss_fn 50 | def loss(pred, target): 51 | s_pred, t_pred = pred 52 | return distill.loss.hinton_distill_loss(t_pred, s_pred, target) 53 | 54 | """ training teacher model and distill it to student""" 55 | def train_epoch(model, optimizer, train_loader, epoch, callback=None, **kwargs): 56 | global writer 57 | model.train() 58 | loss = AverageMeter() 59 | metrics = AverageMeter() 60 | data_time = time.time() 61 | train_bar = tqdm(train_loader) 62 | devices = 'cuda:0' if 'devices' not in kwargs.keys() else kwargs['devices'] 63 | for idx, (data, target) in enumerate(train_loader): 64 | data = data.to(devices) 65 | target = target.to(devices) 66 | data_time = time.time() - data_time 67 | # forward 68 | batch_time = time.time() 69 | pred = model(data) 70 | batch_time = time.time() - batch_time 71 | batch_loss = model.loss(pred, target) 72 | metrics.update(accuracy(pred[0], target), data.size(0)) 73 | 74 | loss.update(batch_loss.item(), data.size(0)) 75 | # backward 76 | optimizer.zero_grad() 77 | batch_loss.backward() 78 | optimizer.step() 79 | # print(metrics.avg,loss.avg,type(loss.avg)) 80 | train_bar.set_description('Epoch {}/{}, data_time: {:.4f}, batch_time: {:.4f}, loss: {:.4f}, accuracy: {:.4f}'.format(idx, epoch, 81 | data_time, batch_time, loss.avg, metrics.avg.item())) 82 | data_time = time.time() 83 | 84 | writer.add_scalar('data/teacher_train_loss', loss.avg, epoch) 85 | writer.add_scalar('data/teachertrain_acc', metrics.avg.item(),epoch) 86 | 87 | 88 | def validate(model, val_loader, epoch, callback=None, 
**kwargs): 89 | global writer 90 | model.eval() 91 | loss = AverageMeter() 92 | metrics = AverageMeter() 93 | data_time = time.time() 94 | val_bar = tqdm(val_loader) 95 | devices = 'cuda:0' if 'devices' not in kwargs.keys() else kwargs['devices'] 96 | with torch.no_grad(): 97 | for idx, (data, target) in enumerate(val_loader): 98 | data = data.to(devices) 99 | target = target.to(devices) 100 | data_time = time.time() - data_time 101 | # forward 102 | batch_time = time.time() 103 | pred = model(data) 104 | batch_time = time.time() - batch_time 105 | batch_loss = model.loss(pred, target) 106 | metrics.update(accuracy(pred[0], target), data.size(0)) 107 | 108 | loss.update(batch_loss.item(), data.size(0)) 109 | 110 | 111 | val_bar.set_description('Eval epoch {}/{}, data_time: {:.4f}, batch_time: {:.4f}, loss: {:.4f}, accuracy: {:.4f}'.format(idx, epoch, 112 | data_time, batch_time, loss.avg, metrics.avg.item())) 113 | data_time = time.time() 114 | writer.add_scalar('data/teacher_val_loss', loss.avg, epoch) 115 | writer.add_scalar('data/teacher_val_acc', metrics.avg.item(),epoch) 116 | return metrics.avg 117 | 118 | # just training the student model 119 | 120 | print('training student model') 121 | train_transform= transforms.Compose([transforms.RandomHorizontalFlip(), 122 | transforms.RandomCrop(32, padding=4), 123 | transforms.ToTensor(), 124 | transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]], 125 | std=[n/255. for n in [68.2, 65.4, 70.4]])]) 126 | 127 | val_transform= transforms.Compose([transforms.ToTensor(), 128 | transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]], 129 | std=[n/255. for n in [68.2, 65.4, 70.4]])]) 130 | train_loader, val_loader = utils.build_vision_dataloader(args.data_dir, args.dataset, 131 | batch_size=args.batch_size, num_workers=args.num_workers, train_transform=train_transform,val_transform=val_transform) 132 | best_accuracy = 0.0 133 | for epoch in range(args.max_epoch): 134 | lr_scheduler.step(epoch) 135 | train_epoch(model, optimizer, train_loader, epoch,devices=args.devices) 136 | val_acc = validate(model, val_loader, epoch,devices=args.devices) 137 | if val_acc > best_accuracy or (epoch+1)%args.ckpt_interval == 0: 138 | utils.trainutils.save_checkpoints(args.ckpt_dir, model, val_acc, best_accuracy, epoch, name=compose_name(args.memo, args.exp_name, args.model,'student')) 139 | best_accuracy = val_acc 140 | -------------------------------------------------------------------------------- /ban/pipeline.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Born-Again Neural Network 3 | The pipeline is similar to the knowledge distillation with more steps. 
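Roughly, the training loop at the bottom of this file does the following (a simplified sketch of born_agin_nn and the born-again loop; argument names are those defined in common/__init__.py):

    model = BAN(teacher)                   # generation 0 holds only the teacher
    for i in range(args.born_time + 1):
        train the current generation       # plain cross-entropy for i == 0, CE + distillation afterwards
        model.expand(new_student, acc)     # the trained network becomes the next teacher
    model.save(acc)                        # checkpoint the final teacher/student pair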
4 | ref: https://arxiv.org/abs/1805.04770 5 | 6 | Author: Kai Tian 7 | Date: 13/12/2018 8 | ''' 9 | 10 | from __future__ import absolute_import 11 | import sys 12 | import os 13 | sys.path.append('..') 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.optim as optim 18 | import torch.nn.functional as F 19 | 20 | import torchvision.transforms as transforms 21 | 22 | import KnowledgeSharing.distill as distill 23 | import KnowledgeSharing.utils as utils 24 | import KnowledgeSharing.core as core 25 | from KnowledgeSharing.utils.metricutils import accuracy 26 | import KnowledgeSharing.common as common_args 27 | from KnowledgeSharing.utils.trainutils import AverageMeter 28 | import time 29 | from tqdm import tqdm 30 | 31 | from copy import deepcopy 32 | from tensorboardX import SummaryWriter 33 | from ban import BAN 34 | 35 | args = common_args.parse() 36 | des = '-'.join([args.memo, args.model]+ list(map(str,[args.born_time, args.depth, args.c, args.lr, args.batch_size])) + [utils.random_string(6)]) 37 | writer=SummaryWriter(des) 38 | 39 | 40 | """ training teacher model and distill it to student""" 41 | def train_epoch(model, optimizer, train_loader, epoch, callback=None, **kwargs): 42 | global writer 43 | 44 | loss = AverageMeter() 45 | metrics = AverageMeter() 46 | data_time = time.time() 47 | train_bar = tqdm(train_loader) 48 | devices = 'cuda:0' if 'devices' not in kwargs.keys() else 'cuda:'+kwargs['devices'][0] 49 | 50 | for idx, (data, target) in enumerate(train_bar): 51 | data = data.to(devices) 52 | target = target.to(devices) 53 | data_time = time.time() - data_time 54 | # forward 55 | batch_time = time.time() 56 | pred = model(data) 57 | batch_time = time.time() - batch_time 58 | batch_loss = model.loss(pred, target, mode=kwargs['mode']) 59 | metrics.update(accuracy(pred, target), data.size(0)) 60 | 61 | loss.update(batch_loss.item(), data.size(0)) 62 | # backward 63 | optimizer.zero_grad() 64 | batch_loss.backward() 65 | optimizer.step() 66 | # print(metrics.avg,loss.avg,type(loss.avg)) 67 | train_bar.set_description('Epoch {}/{}, data_time: {:.4f}, batch_time: {:.4f}, loss: {:.4f}, accuracy: {:.4f}'.format(idx, epoch, 68 | data_time, batch_time, loss.avg, metrics.avg.item())) 69 | data_time = time.time() 70 | 71 | 72 | return loss.avg, metrics.avg.item() 73 | 74 | def validate(model, val_loader, epoch, callback=None, **kwargs): 75 | model.eval() 76 | global writer 77 | loss = AverageMeter() 78 | metrics = AverageMeter() 79 | data_time = time.time() 80 | val_bar = tqdm(val_loader) 81 | devices = 'cuda:0' if 'devices' not in kwargs.keys() else 'cuda:'+kwargs['devices'][0] 82 | with torch.no_grad(): 83 | for idx, (data, target) in enumerate(val_bar): 84 | data = data.to(devices) 85 | target = target.to(devices) 86 | data_time = time.time() - data_time 87 | # forward 88 | batch_time = time.time() 89 | pred = model(data) 90 | batch_time = time.time() - batch_time 91 | batch_loss = model.loss(pred, target,mode=kwargs['mode']) 92 | metrics.update(accuracy(pred, target), data.size(0)) 93 | 94 | loss.update(batch_loss.item(), data.size(0)) 95 | 96 | 97 | val_bar.set_description('Eval epoch {}/{}, data_time: {:.4f}, batch_time: {:.4f}, loss: {:.4f}, accuracy: {:.4f}'.format(idx, epoch, 98 | data_time, batch_time, loss.avg, metrics.avg.item())) 99 | data_time = time.time() 100 | 101 | return loss.avg, metrics.avg.item() 102 | 103 | 104 | def born_agin_nn(model, student_model, iter, tea_acc=0.0): 105 | global writer 106 | if student_model is not None: 107 | 
model.expand(student_model, tea_acc) 108 | optimizer = optim.SGD(model.student.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=True) 109 | else: 110 | optimizer = optim.SGD(model.teacher.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum, nesterov=True) 111 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, list(map(int,args.milestones.split(','))), gamma=args.gamma) 112 | # move to gpu 113 | model.to('cuda:'+args.devices[0]) 114 | transform= transforms.Compose([transforms.RandomHorizontalFlip(), 115 | transforms.RandomCrop(32, padding=4), 116 | transforms.ToTensor(), 117 | transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]], 118 | std=[n/255. for n in [68.2, 65.4, 70.4]])]) 119 | val_transform= transforms.Compose([transforms.ToTensor(), 120 | transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]], 121 | std=[n/255. for n in [68.2, 65.4, 70.4]])]) 122 | train_loader, val_loader = utils.build_vision_dataloader(args.data_dir, args.dataset, 123 | batch_size=args.batch_size, num_workers=args.num_workers, train_transform=transform,val_transform=val_transform) 124 | best_accuracy = 0.0 125 | for epoch in range(args.max_epoch): 126 | scheduler.step(epoch) 127 | train_loss, train_acc = train_epoch(model, optimizer, train_loader, epoch, devices=args.devices, mode=args.mode) 128 | val_loss, val_acc = validate(model, val_loader, epoch,devices=args.devices, mode=args.mode) 129 | 130 | writer.add_scalar('data/born_%i_train_loss'%(iter+1), train_loss, epoch) 131 | writer.add_scalar('data/born_%i_train_acc'%(iter+1), train_acc,epoch) 132 | 133 | writer.add_scalar('data/born_%i_val_loss'%(iter+1), val_loss, epoch) 134 | writer.add_scalar('data/born_%i_val_acc'%(iter+1), val_acc,epoch) 135 | return model, val_acc 136 | 137 | 138 | teacher_model = core.build_model(args.model, args.depth, args) 139 | student_model = None 140 | # initialize model 141 | model = BAN(teacher_model, c=args.c, ckpt_dir=os.path.join(args.ckpt_dir, 'ban', des)) 142 | tea_acc = 0.0 143 | # born again 144 | # born time plus one, as the zero is the teacher 145 | print('training born agagin neural network') 146 | for i in range(args.born_time+1): 147 | print('born %i times'%i) 148 | if i==0: 149 | model, tea_acc = born_agin_nn(model, student_model, i, tea_acc) 150 | else: 151 | model.expand(student_model, tea_acc) 152 | model, tea_acc = born_agin_nn(model, student_model, i, tea_acc) 153 | student_model = core.build_model(args.model, args.depth, args) 154 | # save all models 155 | model.save(tea_acc) 156 | 157 | -------------------------------------------------------------------------------- /ban/train_ban.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin 2 | 3 | python ./BAN/pipeline.py \ 4 | --memo exp1 \ 5 | --data-dir ./data \ 6 | --model resnet \ 7 | --depth 32 \ 8 | -c 1 \ 9 | --born-time 3 \ 10 | --nb-classes 100 \ 11 | --devices 1 \ 12 | --mode common\ 13 | 14 | python ./BAN/pipeline.py --memo exp-cifar100 --dataset cifar100 --data-dir ./data --weight_decay 5e-4 --model plaincnn --depth 6 -c 1 --born-time 3 --nb-classes 100 --mode common 15 | python ./BAN/pipeline.py --memo exp-cifar10 --dataset cifar10 --data-dir ./data --weight_decay 5e-4 --model plaincnn --depth 6 -c 1 --born-time 3 --nb-classes 10 --mode common 16 | 17 | python ./BAN/pipeline.py --memo exp-cifar100 --data-dir ./data --model wrn --depth 28 -c 1 --born-time 3 --nb-classes 100 --devices 1 --mode common --gamma 
0.2 --max-epoch 200 --milestones 60,120,160 --weight_decay 5e-4 18 | python ./BAN/pipeline.py --memo exp-cifar10 --data-dir ./data --model wrn --depth 28 -c 1 --born-time 3 --nb-classes 10 --devices 1 --mode common --gamma 0.2 --max-epoch 200 --milestones 60,120,160 --weight_decay 5e-4 19 | python ./BAN/pipeline.py --memo exp-cifar10 --data-dir ./data --model resnet --depth 32 -c 1 --born-time 3 --nb-classes 100 --devices 1 --mode common --weight_decay 1e-4 20 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse(): 4 | parser = argparse.ArgumentParser('Common Argument Parser') 5 | parser.add_argument('--memo', default='exp', type=str, help='memo for the experiment') 6 | parser.add_argument('--exp-name', default='exp', type=str, help='memo for the experiment') 7 | # model definition 8 | parser.add_argument('--model', default='resnet', type=str, help='type of model') 9 | parser.add_argument('--depth', default=32, type=int) 10 | parser.add_argument('--teach-depth', default=56, type=int) 11 | parser.add_argument('--nb-classes', default=100, type=int, help='number of classes') 12 | ## model sepcific 13 | parser.add_argument('--mult-mode', default='mult_model',type=str,help='multi model modes') 14 | parser.add_argument('--multi-arch', default='resnet32,resnet32', type=str, help='type of model') 15 | parser.add_argument('--nmodel', default=5, type=int, help='number of models') 16 | parser.add_argument('-T', default=2, type=float, help='temperature for KL-divergence') 17 | parser.add_argument('--loss', default='kl_div', type=str, help='option for loss function') 18 | # born agin network 19 | parser.add_argument('-c', default=1, type=float, help='coefficient of regularization') 20 | parser.add_argument('--born-time', default=3, type=int, help='number of models') 21 | parser.add_argument('--mode', default='common', type=str, help='type of born loss') 22 | 23 | # cb 24 | parser.add_argument('--ratio', default=0.5, type=float, help='coefficient of regularization') 25 | # optimizer 26 | parser.add_argument('--optimizer', default='sgd', type=str, help='type of optimizer') 27 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 28 | parser.add_argument('--momentum', default=0.9, type=float, help='learning rate') 29 | parser.add_argument('--milestones', default='150,225', type=str, help='learning rate') 30 | parser.add_argument('--weight-decay', default=5e-4, type=float, help='learning rate') 31 | parser.add_argument('--gamma', default=0.1, type=float, help='gamma for lr scheduler rate') 32 | # training configuration 33 | parser.add_argument('--debug', action='store_true', help='debug mode') 34 | parser.add_argument('--start_epoch', default=0, type=int, help='start point') 35 | parser.add_argument('--max-epoch', default=300, type=int, help='maxial training epoch') 36 | parser.add_argument('--batch-size', default=256, type=int, help='batch size') 37 | parser.add_argument('--num-workers', default=4, type=int, help='type of optimizer') 38 | 39 | parser.add_argument('--accumulate',default=1, type=int, help='use accumulation for update gradients') 40 | parser.add_argument('--dataset', default='cifar100', type=str, help='type of optimizer') 41 | parser.add_argument('--data-dir', default='./data', type=str, help='type of optimizer') 42 | parser.add_argument('--ckpt-dir', default='./checkpoints', type=str, help='type of 
optimizer') 43 | parser.add_argument('--devices', default='cpu', type=str, help='need specify devices') 44 | 45 | parser.add_argument('--tensorboard', default=1,type=int, help='need specify devices') 46 | 47 | parser.add_argument('--resume', default='', type=str, 48 | help='path to latest checkpoint (default: none)') 49 | parser.add_argument('--print-interval', default=10, type=int) 50 | parser.add_argument('--ckpt-interval', default=10,type=int) 51 | parser.add_argument('--verbose',default=1, type=int, help='display batch details') 52 | args = parser.parse_args() 53 | return args 54 | -------------------------------------------------------------------------------- /common/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/common/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /core/KSNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class KSNet(nn.Module): 4 | def __init__(self, n_specialist, basenets, devices=None): 5 | super(KSNet, self).__init__() 6 | 7 | if len(basenets) != n_specialist+1: 8 | raise ValueError('num of basenet is not correct') 9 | self.n_specialist = n_specialist 10 | 11 | self.specialists = nn.ModuleList(basenets[:-1]) 12 | self.generalist = basenets[-1] 13 | # if len(devices) < n_specialist+1: 14 | # raise ValueError("Number of devices is smaller than number of specialists and generalist.") 15 | self.devices = devices 16 | 17 | def to(self, devices): 18 | if self.devices is None: 19 | [self.specialists[i].to('cpu') for i in range(self.n_specialist)] 20 | self.generalist.to('cpu') 21 | 22 | if len(self.devices) < self.n_specialist: 23 | [self.specialists[i].to('cuda:%s'%self.devices[0]) for i in range(self.n_specialist)] 24 | self.generalist.to('cuda:%s'%self.devices[0]) 25 | else: 26 | for i in range(self.n_specialist): 27 | self.specialists[i].to('cuda:%s'%self.devices[i]) 28 | self.generalist.to('cuda:%s'%self.devices[-1]) 29 | 30 | def forward(self, x): 31 | spec_logits = [net(x.to(net.device)) for net in self.specialists] 32 | gene_logits = self.generalist(x.to(self.generalist.device)) 33 | 34 | # transfer to devices[0] 35 | spec_logits = [p.to(spec_logits[0].device) for p in spec_logits] 36 | gene_logits = gene_logits.to(spec_logits[0].device) 37 | return spec_logits, gene_logits 38 | 39 | def loss_fn(self, fn): 40 | self._loss = fn 41 | return fn 42 | 43 | def loss(self, pred, targets): 44 | self._loss(pred, targets) -------------------------------------------------------------------------------- /core/KSNet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/KSNet.pyc -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .ensemble import Ensemble 2 | from .KSNet import KSNet 3 | from .resnet2 import build_bottleneck_resnet 4 | from .build_model import build_model 5 | -------------------------------------------------------------------------------- /core/__init__.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__init__.pyc -------------------------------------------------------------------------------- /core/__pycache__/KSNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__pycache__/KSNet.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/build_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__pycache__/build_model.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/ensemble.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__pycache__/ensemble.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/resnet2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__pycache__/resnet2.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/smallnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/__pycache__/smallnet.cpython-37.pyc -------------------------------------------------------------------------------- /core/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init 4 | import math 5 | import torch.nn.functional as F 6 | 7 | 8 | def init_weights(model): 9 | for m in model.modules(): 10 | if isinstance(m, nn.Conv2d): 11 | torch.nn.init.kaiming_normal(m.weight) 12 | if m.bias is not None: 13 | m.bias.data.zero_() 14 | elif isinstance(m, nn.BatchNorm2d): 15 | m.weight.data.fill_(1.0) 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | 21 | 22 | def get_n_params(model): 23 | return sum([param.nelement() for param in list(model.parameters())]) 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1, bias=False, dilation=1, padding=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=padding, bias=bias, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1, bias=False): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 35 | 
padding=0, bias=bias) 36 | 37 | 38 | class MultModel(nn.Module): 39 | 40 | def __init__(self, models): 41 | super(MultModel, self).__init__() 42 | self.models = torch.nn.ModuleList(models) 43 | 44 | def forward(self, x): 45 | outs = [m(x) for m in self.models] 46 | return outs 47 | 48 | 49 | class MultOutModel(nn.Module): 50 | 51 | def __init__(self, model, nb_outs=2): 52 | super(MultOutModel, self).__init__() 53 | self.nb_outs = nb_outs 54 | self.model = model 55 | 56 | def forward(self, x): 57 | outs = [self.model(x) for _ in range(self.nb_outs)] 58 | return outs 59 | 60 | 61 | def flip(x, dim): 62 | xsize = x.size() 63 | dim = x.dim() + dim if dim < 0 else dim 64 | x = x.view(-1, *xsize[dim:]) 65 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, 66 | -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :] 67 | return x.view(xsize) 68 | 69 | 70 | class SymmetricOutModel(nn.Module): 71 | def __init__(self, model): 72 | super(SymmetricOutModel, self).__init__() 73 | self.model = model 74 | 75 | def forward(self, x): 76 | y = flip(x, 2) 77 | return [self.model(x), self.model(y)] 78 | 79 | 80 | def add_arguments(parser): 81 | parser.add_argument('--epochs', default=300, type=int, 82 | help='number of total epochs to run') 83 | parser.add_argument('--start-epoch', default=0, type=int, 84 | help='manual epoch number (useful on restarts)') 85 | parser.add_argument('--b', '--batchsize', dest='batchsize', default=64, type=int, 86 | help='mini-batch size (default: 64)') 87 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 88 | help='initial learning rate') 89 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 90 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 91 | help='weight decay (default: 1e-4)') -------------------------------------------------------------------------------- /core/base.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/base.pyc -------------------------------------------------------------------------------- /core/build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnet2 import build_bottleneck_resnet 3 | from .smallnet import ToyConvNet 4 | 5 | def build_model(model_type,depth, args): 6 | if model_type == 'ksnet': 7 | return None 8 | elif model_type=='resnet': 9 | return build_bottleneck_resnet(depth=depth, nb_classes=args.nb_classes) 10 | elif model_type=='small': 11 | return ToyConvNet(3, 16, 100, 4) 12 | -------------------------------------------------------------------------------- /core/build_model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/build_model.pyc -------------------------------------------------------------------------------- /core/cifar_model_zoo.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/cifar_model_zoo.py -------------------------------------------------------------------------------- /core/densenet.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/densenet.py -------------------------------------------------------------------------------- /core/ensemble.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class Ensemble(nn.Module): 4 | def __init__(self, m, basenets, devices=None): 5 | super(Ensemble, self).__init__() 6 | self.m = m 7 | 8 | self.devices=devices 9 | self.net = nn.ModuleList(basenets) 10 | 11 | def to(self,devices): 12 | if self.devices is None: 13 | [self.net[i].to('cpu') for i in range(self.m)] 14 | if len(devices)=44 else BasicBlock 102 | 103 | self.inplanes = 16 104 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 105 | bias=False) 106 | self.bn1 = nn.BatchNorm2d(16) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.layer1 = self._make_layer(block, 16, n) 109 | self.layer2 = self._make_layer(block, 32, n, stride=2) 110 | self.layer3 = self._make_layer(block, 64, n, stride=2) 111 | self.avgpool = nn.AvgPool2d(8) 112 | self.fc = nn.Linear(64 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) # 32x32 143 | 144 | x = self.layer1(x) # 32x32 145 | x = self.layer2(x) # 16x16 146 | x = self.layer3(x) # 8x8 147 | 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | @property 155 | def device(self): 156 | return list(self.parameters())[0].device 157 | 158 | def loss_fn(self, fn): 159 | self._loss = fn 160 | return fn 161 | 162 | def loss(self, *args): 163 | return self._loss(*args) 164 | 165 | def resnet(**kwargs): 166 | """ 167 | Constructs a ResNet model. 
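    For example, ResNet(32, 100), which is what build_bottleneck_resnet(depth=32, nb_classes=100)
    below returns, stacks three stages of residual blocks on 32x32 inputs (BasicBlock units at this
    depth, Bottleneck units once depth >= 44), followed by 8x8 average pooling and a linear
    classifier over num_classes outputs.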
168 | """ 169 | return ResNet(**kwargs) 170 | 171 | def build_bottleneck_resnet(depth, nb_classes=100,mode='cifar'): 172 | return ResNet(depth, nb_classes) 173 | -------------------------------------------------------------------------------- /core/resnet2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/resnet2.pyc -------------------------------------------------------------------------------- /core/smallnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append('../') 6 | from KnowledgeSharing.utils.netutils import ConvBnReLU, ConvBn, count_params 7 | 8 | 9 | class ToyConvNet(nn.Module): 10 | def __init__(self, n_in, n_channels, n_out, depth): 11 | super(ToyConvNet, self).__init__() 12 | 13 | factor = 2 14 | self.base = ConvBnReLU(n_in, n_channels, 3) 15 | self.pool = nn.MaxPool2d(2) 16 | 17 | self.net = nn.ModuleList() 18 | bc = n_channels 19 | oc = bc * factor 20 | for i in range(3): 21 | self.net.append(ConvBnReLU(bc, oc, 3)) 22 | bc = oc 23 | oc = oc * factor if depth <=3 else oc 24 | self.bottleneck = ConvBnReLU(bc, 128, 3, stride=1, padding=1) 25 | 26 | self.linear = nn.Linear(16*16*128, n_out) 27 | 28 | def forward(self, x): 29 | fx = self.pool(self.base(x)) 30 | for net in self.net: 31 | fx = net(fx) 32 | fx = self.bottleneck(fx) 33 | return self.linear(fx.view(fx.size(0), -1)) 34 | 35 | def loss_fn(self, func): 36 | self.calc_loss = func 37 | return func 38 | 39 | def loss(self, *args): 40 | return self.calc_loss(*args) 41 | 42 | -------------------------------------------------------------------------------- /core/smallnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/core/smallnet.pyc -------------------------------------------------------------------------------- /core/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from https://github.com/pytorch/vision.git 3 | ''' 4 | import math 5 | 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | ''' 17 | VGG model 18 | ''' 19 | def __init__(self, features): 20 | super(VGG, self).__init__() 21 | self.features = features 22 | self.classifier = nn.Sequential( 23 | nn.Dropout(), 24 | nn.Linear(512, 512), 25 | nn.ReLU(True), 26 | nn.Dropout(), 27 | nn.Linear(512, 512), 28 | nn.ReLU(True), 29 | nn.Linear(512, 10), 30 | ) 31 | # Initialize weights 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 36 | m.bias.data.zero_() 37 | 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = x.view(x.size(0), -1) 42 | x = self.classifier(x) 43 | return x 44 | 45 | 46 | def make_layers(cfg, batch_norm=False): 47 | layers = [] 48 | in_channels = 3 49 | for v in cfg: 50 | if v == 'M': 51 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 52 | else: 53 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 54 | if batch_norm: 55 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 56 | else: 57 | layers += [conv2d, nn.ReLU(inplace=True)] 58 | in_channels = v 59 | return nn.Sequential(*layers) 60 | 61 | 62 | cfg = { 63 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 67 | 512, 512, 512, 512, 'M'], 68 | } 69 | 70 | 71 | def vgg11(): 72 | """VGG 11-layer model (configuration "A")""" 73 | return VGG(make_layers(cfg['A'])) 74 | 75 | 76 | def vgg11_bn(): 77 | """VGG 11-layer model (configuration "A") with batch normalization""" 78 | return VGG(make_layers(cfg['A'], batch_norm=True)) 79 | 80 | 81 | def vgg13(): 82 | """VGG 13-layer model (configuration "B")""" 83 | return VGG(make_layers(cfg['B'])) 84 | 85 | 86 | def vgg13_bn(): 87 | """VGG 13-layer model (configuration "B") with batch normalization""" 88 | return VGG(make_layers(cfg['B'], batch_norm=True)) 89 | 90 | 91 | def vgg16(): 92 | """VGG 16-layer model (configuration "D")""" 93 | return VGG(make_layers(cfg['D'])) 94 | 95 | 96 | def vgg16_bn(): 97 | """VGG 16-layer model (configuration "D") with batch normalization""" 98 | return VGG(make_layers(cfg['D'], batch_norm=True)) 99 | 100 | 101 | def vgg19(): 102 | """VGG 19-layer model (configuration "E")""" 103 | return VGG(make_layers(cfg['E'])) 104 | 105 | 106 | def vgg19_bn(): 107 | """VGG 19-layer model (configuration 'E') with batch normalization""" 108 | return VGG(make_layers(cfg['E'], batch_norm=True)) -------------------------------------------------------------------------------- /coteaching/coteaching.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | sys.path.append('..') 5 | 6 | import KnowledgeSharing.core as core 7 | import torch.nn as nn 8 | 9 | class MutltiModel(nn.Module): 10 | def __init__(self, coherts, coeff=1.0, rate=50): 11 | super(MutualLearning, self).__init__() 12 | 13 | self.models = nn.ModuleList(coherts) 14 | self.n_model = len(coherts) 15 | self.coeff = 1.0 16 | self.criterion = nn.CrossEntropyLoss() 17 | self.percentile = rate 18 | 19 | def forward(self, x): 20 | return [model(x) for model in self.models] 21 | 22 | 23 | -------------------------------------------------------------------------------- /coteaching/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | # Loss functions 5 | def loss_coteaching(y_1, y_2, t, forget_rate, ind, noise_or_not): 6 | loss_1 = F.cross_entropy(y_1, t, reduce = False) 7 | ind_1_sorted = np.argsort(loss_1.data).cuda() 8 | loss_1_sorted = loss_1[ind_1_sorted] 9 | 10 | loss_2 = F.cross_entropy(y_2, t, reduce = False) 11 | ind_2_sorted = np.argsort(loss_2.data).cuda() 12 | loss_2_sorted = loss_2[ind_2_sorted] 13 | 14 | 
remember_rate = 1 - forget_rate 15 | num_remember = int(remember_rate * len(loss_1_sorted)) 16 | 17 | pure_ratio_1 = np.sum(noise_or_not[ind[ind_1_sorted[:num_remember]]])/float(num_remember) 18 | pure_ratio_2 = np.sum(noise_or_not[ind[ind_2_sorted[:num_remember]]])/float(num_remember) 19 | 20 | ind_1_update=ind_1_sorted[:num_remember] 21 | ind_2_update=ind_2_sorted[:num_remember] 22 | # exchange 23 | loss_1_update = F.cross_entropy(y_1[ind_2_update], t[ind_2_update]) 24 | loss_2_update = F.cross_entropy(y_2[ind_1_update], t[ind_1_update]) 25 | 26 | return torch.sum(loss_1_update)/num_remember, torch.sum(loss_2_update)/num_remember, pure_ratio_1, pure_ratio_2 27 | 28 | -------------------------------------------------------------------------------- /distill/__init__.py: -------------------------------------------------------------------------------- 1 | from .distill import DistillModel 2 | from .loss import hinton_distill_loss, jacobian_matching_loss, intermediate_matching_loss 3 | -------------------------------------------------------------------------------- /distill/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/distill/__init__.pyc -------------------------------------------------------------------------------- /distill/distill.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class DistillModel(nn.Module): 5 | 6 | def __init__(self, teacher, student): 7 | super(DistillModel, self).__init__() 8 | 9 | self.teacher = teacher 10 | self.student = student 11 | 12 | def forward(self, x): 13 | with torch.no_grad(): 14 | ty = self.teacher(x) 15 | sy = self.student(x) 16 | 17 | return sy,ty 18 | 19 | def loss_fn(self, fn): 20 | self._loss = fn 21 | return fn 22 | 23 | def loss(self, pred, target): 24 | return self._loss(pred, target) 25 | 26 | -------------------------------------------------------------------------------- /distill/distill.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/distill/distill.pyc -------------------------------------------------------------------------------- /distill/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import grad 3 | import torch.nn.functional as F 4 | 5 | def scaled_softmax(logits, T=2): 6 | ''' 7 | smooth the logits by divide a temperature 8 | ''' 9 | logits = logits / T 10 | return F.softmax(logits) 11 | 12 | def hinton_distill_loss(t_pred, s_pred, target, T=2, alpha=1): 13 | # compute ce loss 14 | celoss = F.cross_entropy(s_pred, target) 15 | # distll loss 16 | s_s_pred = scaled_softmax(s_pred) 17 | s_t_pred = scaled_softmax(t_pred) 18 | disllloss = F.mse_loss(s_s_pred, s_t_pred.detach()) 19 | return celoss + alpha * T**2 * disllloss 20 | 21 | # comppute derivative 22 | def _calc_jacobian(x, pred, target): 23 | pred.gather(1, target.view(target.size(0),1)).sum().backward(retain_graph=True) 24 | return x.grad.detach() 25 | 26 | def jacobian_matching_loss(*args): 27 | x, t_pred, s_pred, target = args 28 | alpha = 1.0 29 | # compute ce loss 30 | celoss = F.cross_entropy(s_pred, target) 31 | # compute match loss 32 | teacher_jacob = _calc_jacobian(x, t_pred, target) 33 | # 
zero grad 34 | x.grad.zero_() 35 | # compute stdudnt jacob 36 | student_jacob = _calc_jacobian(x, s_pred, target) 37 | matchloss = torch.norm(student_jacob - teacher_jacob.detach(),2) 38 | # distll loss 39 | disllloss = F.mse_loss(s_pred, t_pred.detach()) 40 | # note that by doing loss.baward(), all the gradients will multiply 2, so the lr should divide by 2 41 | loss = celoss + alpha * matchloss + disllloss 42 | return loss 43 | 44 | def intermediate_matching_loss(pred, target): 45 | pass 46 | 47 | -------------------------------------------------------------------------------- /distill/loss.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waitwaitforget/KnowledgeSharing-Pytorch/5b82fbd5f67c191b86dfe31419c2d7ae12217598/distill/loss.pyc -------------------------------------------------------------------------------- /distill/matching.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | # class Matching -------------------------------------------------------------------------------- /distill/toyconv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys 5 | sys.path.append('../') 6 | from KnowledgeSharing.utils.netutils import ConvBnReLU, ConvBn, count_params 7 | 8 | 9 | class ToyConvNet(nn.Module): 10 | def __init__(self, n_in, n_channels, n_out, depth): 11 | super(ToyConvNet, self).__init__() 12 | 13 | factor = 2 14 | self.base = ConvBnReLU(n_in, n_channels, 3) 15 | self.pool = nn.MaxPool2d(2) 16 | 17 | self.net = nn.ModuleList() 18 | bc = n_channels 19 | oc = bc * factor 20 | for i in range(3): 21 | self.net.append(ConvBnReLU(bc, oc, 3)) 22 | bc = oc 23 | oc = oc * factor if depth <=3 else oc 24 | self.bottleneck = ConvBnReLU(bc, 128, 3, stride=1, padding=1) 25 | 26 | self.linear = nn.Linear(16*16*128, n_out) 27 | 28 | def forward(self, x): 29 | fx = self.pool(self.base(x)) 30 | for net in self.net: 31 | fx = net(fx) 32 | fx = self.bottleneck(fx) 33 | return self.linear(fx.view(fx.size(0), -1)) 34 | 35 | def loss_func(self, func): 36 | self.calc_loss = func 37 | return func 38 | 39 | def loss(self,pred, target): 40 | return self.calc_loss(pred, target) 41 | 42 | 43 | def test(): 44 | import torch 45 | x= torch.rand(16, 3, 32, 32) 46 | m = ToyConvNet(3, 16, 10, 5) 47 | print(count_params(m)) 48 | @m.loss_func 49 | def func(x,y): 50 | return nn.functional.mse_loss(x,y) 51 | z = torch.rand(16,10) 52 | y = (m(x)) 53 | loss = m.loss(y,z) 54 | print(loss) 55 | 56 | if __name__ == '__main__': 57 | test() -------------------------------------------------------------------------------- /mentornet/mentornet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | class Mentor(nn.Module): 6 | def __init__(self, nb_classes, n_step=2): 7 | super(Mentor, self).__init__() 8 | 9 | #self.c1 = c1 10 | #self.c2 = c2 11 | self.label_emb = nn.Embedding(nb_classes, 2) 12 | self.percent_emb = nn.Embedding(100, 5) 13 | 14 | self.lstm = nn.LSTM(2, 10, 1, bidirectional=True) 15 | # self.h0 = torch.rand(2,) 16 | self.fc1 = nn.Linear(27, 20) 17 | self.fc2 = nn.Linear(20, 1) 18 | 19 | def forward(self, data): 20 | label, pt, l = data 21 | x_label = self.label_emb(label) 22 | x_percent = 
self.percent_emb(pt) 23 | # print(x_label.size()) 24 | h0 = torch.rand(2, label.size(0), 10) 25 | c0 = torch.rand(2, label.size(0), 10) 26 | 27 | output, (hn,cn) = self.lstm(l, (h0,c0)) 28 | output = output.sum(0).squeeze() 29 | x = torch.cat((x_label, x_percent, output), dim=1) 30 | z = F.tanh(self.fc1(x)) 31 | z = F.sigmoid(self.fc2(z)) 32 | return z 33 | 34 | 35 | class MentorNet(nn.Module): 36 | def __init__(self, student_model, basic_criterion=None,c1=1,c2=1,percentile=75, nb_classes=100, ckpt_dir='./checkpoint/mentor/'): 37 | super(MentorNet, self).__init__() 38 | 39 | self.student = student_model # student model 40 | self.mentor = Mentor(nb_classes, 2) # mentor net 41 | 42 | self.ckpt_dir = ckpt_dir 43 | self.c1 = c1 44 | self.c2 = c2 45 | self.nb_classes = nb_classes 46 | self.percentile = percentile 47 | self.basic_criterion = nn.CrossEntropyLoss() if basic_criterion is None else basic_criterion 48 | 49 | def forward(self, x): 50 | pred = self.student(x) 51 | return pred 52 | 53 | #def target_to_one_hot(self, y): 54 | # return torch.randint(self.nb_classes, (y,)).long() 55 | 56 | def loss(self, pred, target, global_step_ratio): 57 | # compute common cross entropy loss 58 | ce_loss = F.cross_entropy(pred, target, reduce=False) 59 | student_loss = ce_loss.mean() 60 | # compute mentor loss 61 | if self.c2 == 0: 62 | mentor_target = ce_loss.le(self.c1).long() 63 | else: 64 | mentor_target = torch.min(torch.max(0, 1 - (ce_loss - self.c1) / self.c2), 1) 65 | percent = torch.ones(pred.size(0)).long() * global_step_ratio 66 | 67 | 68 | import torch 69 | mentor = Mentor(100, 2) 70 | label = torch.randint(100, (10,)).long() 71 | #label = torch.eye(100)[label] 72 | #label = label.long() 73 | 74 | percent = torch.randint(100, (10,)).long() 75 | #percent = torch.eye(100)[percent] 76 | #percent = percent.long() 77 | 78 | l = torch.rand(2, 10, 2) 79 | 80 | y = mentor((label, percent, l)) 81 | print(y) -------------------------------------------------------------------------------- /onthefly/ONE.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | sys.path.append('../') 7 | 8 | import math 9 | # from KnowledgeSharing.core.resnet2 import BasicBlock, Bottleneck 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | "3x3 convolution with padding" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = 
nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * 4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | class ONEBranch(nn.Module): 87 | def __init__(self, inplanes, num_classes, block, depth): 88 | super(ONEBranch, self).__init__() 89 | self.inplanes = inplanes 90 | self.n = depth 91 | self.block = block 92 | 93 | self.layer2 = self._make_layer(block, 32, depth, stride=2) 94 | self.layer3 = self._make_layer(block, 64, depth, stride=2) 95 | self.avgpool = nn.AvgPool2d(8) 96 | self.fc = nn.Linear(64 * block.expansion, num_classes) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | x= self.layer2(x) 117 | x= self.layer3(x) 118 | x = self.avgpool(x) 119 | x = x.view(x.size(0), -1) 120 | x = self.fc(x) 121 | return x 122 | 123 | class ONE(nn.Module): 124 | def __init__(self, n_branches,depth, nb_classes): 125 | super(ONE, self).__init__() 126 | 127 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 128 | n = (depth - 2) // 6 129 | 130 | block = Bottleneck if depth >=44 else BasicBlock 131 | 132 | self.n_branches = n_branches 133 | self.inplanes = 16 134 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 135 | bias=False) 136 | self.bn1 = nn.BatchNorm2d(16) 137 | self.relu = nn.ReLU(inplace=True) 138 | self.layer1 = self._make_layer(block, 16, n) 139 | self.branches = nn.ModuleList([ONEBranch(self.inplanes, nb_classes, block, n) for _ in range(n_branches)]) 140 | 141 | # gate module 142 | self.gate_conv = nn.Conv2d(16, 64, 1, stride=2) 143 | self.gate_bn = nn.BatchNorm2d(64) 144 | self.gate_relu = nn.ReLU(True) 145 | self.gate_linear = nn.Linear(64*16*16, n_branches) 146 | 147 | self._init_weights() 148 | 149 | def _init_weights(self): 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 153 | m.weight.data.normal_(0, math.sqrt(2. 

class ONE(nn.Module):
    def __init__(self, n_branches, depth, nb_classes):
        super(ONE, self).__init__()

        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
        n = (depth - 2) // 6

        block = Bottleneck if depth >= 44 else BasicBlock

        self.n_branches = n_branches
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, n)
        self.branches = nn.ModuleList([ONEBranch(self.inplanes, nb_classes, block, n) for _ in range(n_branches)])

        # gate module
        self.gate_conv = nn.Conv2d(16, 64, 1, stride=2)
        self.gate_bn = nn.BatchNorm2d(64)
        self.gate_relu = nn.ReLU(True)
        self.gate_linear = nn.Linear(64 * 16 * 16, n_branches)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # 32x32

        x = self.layer1(x)  # 32x32
        # gate branch
        g = self.gate_relu(self.gate_bn(self.gate_conv(x)))
        g = self.gate_linear(g.view(g.size(0), -1))
        g = F.softmax(g, dim=1)
        # classifier branches
        out = [branch(x) for branch in self.branches]
        ensemble_out = sum([o * g[:, i].unsqueeze(1) for o, i in zip(out, range(self.n_branches))])
        return out, ensemble_out


if __name__ == '__main__':
    model = ONE(3, 32, 100)
    x = torch.rand(32, 3, 32, 32)
    out, ensemble_out = model(x)
    print(ensemble_out.size())
    print(model)
--------------------------------------------------------------------------------
/onthefly/trainone.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
import sys
import os
sys.path.append('../')
import KnowledgeSharing.distill as distill
from KnowledgeSharing.utils.trainutils import AverageMeter
from KnowledgeSharing.utils.metricutils import accuracy
import KnowledgeSharing.utils as utils
import KnowledgeSharing.core as core
import KnowledgeSharing.common as common_args
import time
from tqdm import tqdm
from KnowledgeSharing.onthefly.ONE import ONE

from copy import deepcopy
from tensorboardX import SummaryWriter

des = ''
writer = SummaryWriter(des)

args = common_args.parse()
"""train an On-the-fly Native Ensemble (ONE) model"""


def train_epoch(model, optimizer, train_loader, epoch, callback=None, **kwargs):
    global writer
    model.train()
    loss = AverageMeter()
    metrics = AverageMeter()
    data_time = time.time()
    train_bar = tqdm(train_loader)
    devices = 'cuda:0' if 'devices' not in kwargs.keys() else kwargs['devices']
    for idx, (data, target) in enumerate(train_bar):
        data = data.to(devices)
        target = target.to(devices)
        data_time = time.time() - data_time
        # forward
        batch_time = time.time()
        pred = model(data)
        batch_time = time.time() - batch_time
        batch_loss = model.loss(pred, target)
        # track top-1 accuracy of the first branch
        metrics.update(accuracy(pred[0][0], target), data.size(0))

        loss.update(batch_loss.item(), data.size(0))
        # backward
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        # print(metrics.avg,loss.avg,type(loss.avg))
        train_bar.set_description('Epoch {} iter {}, data_time: {:.4f}, batch_time: {:.4f}, loss: {:.4f}, accuracy: {:.4f}'.format(
            epoch, idx, data_time, batch_time, loss.avg, metrics.avg.item()))
        data_time = time.time()

    writer.add_scalar('data/one_train_loss', loss.avg, epoch)
    writer.add_scalar('data/one_train_acc', metrics.avg.item(), epoch)

def scaled_softmax(logits, T=2):
    '''
    smooth the logits by dividing them by a temperature
    '''
    logits = logits / T
    return F.softmax(logits, dim=1)


def kl_div(p, q):
    # F.kl_div expects log-probabilities as its first argument
    return F.kl_div(p.log(), q)


def validate(model, val_loader, epoch, callback=None, **kwargs):
    model.eval()
    loss = AverageMeter()
    metrics = AverageMeter()
    data_time = time.time()
    val_bar = tqdm(val_loader)
    devices = 'cuda:0' if 'devices' not in kwargs.keys() else kwargs['devices']
    with torch.no_grad():
        for idx, (data, target) in enumerate(val_bar):
            data = data.to(devices)
            target = target.to(devices)
            data_time = time.time() - data_time
            # forward
            batch_time = time.time()
            pred = model(data)
            batch_time = time.time() - batch_time
            batch_loss = model.loss(pred, target)
            metrics.update(accuracy(pred[0][0], target), data.size(0))

            loss.update(batch_loss.item(), data.size(0))

            val_bar.set_description('Eval epoch {} iter {}, data_time: {:.4f}, batch_time: {:.4f}, loss: {:.4f}, accuracy: {:.4f}'.format(
                epoch, idx, data_time, batch_time, loss.avg, metrics.avg.item()))
            data_time = time.time()
    writer.add_scalar('data/one_val_loss', loss.avg, epoch)
    writer.add_scalar('data/one_val_acc', metrics.avg.item(), epoch)
    return metrics.avg


def train_one_model(args):
    print('training one model')
    model = ONE(args.n_model, args.depth, args.nb_classes)
    # move the model to the training device (assumes args.devices is a device string such as 'cuda:0')
    model = model.to(args.devices)
    optimizer, scheduler = utils.trainutils.build_optimizer_scheduler(model, args)

    # register a loss for the model
    @model.loss_fn
    def loss(pred, target):
        logit_list, ensemble_pred = pred
        ce_loss = sum([F.cross_entropy(p, target) for p in logit_list]) + F.cross_entropy(ensemble_pred, target)
        prob_list = [scaled_softmax(p, args.T) for p in logit_list]
        kl_loss = [kl_div(scaled_softmax(ensemble_pred, args.T), prob_list[i]) for i in range(len(prob_list))]
        return ce_loss + args.T ** 2 * sum(kl_loss)

    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.RandomCrop(32, padding=4),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]],
                                                               std=[n/255. for n in [68.2, 65.4, 70.4]])])
    val_transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]],
                                                             std=[n/255. for n in [68.2, 65.4, 70.4]])])
    train_loader, val_loader = utils.build_vision_dataloader(args.data_dir, args.dataset,
        batch_size=args.batch_size, num_workers=args.num_workers, train_transform=train_transform, val_transform=val_transform)
    best_accuracy = 0.0
    for epoch in range(args.max_epoch):
        scheduler.step(epoch)
        train_epoch(model, optimizer, train_loader, epoch, devices=args.devices)
        val_acc = validate(model, val_loader, epoch, devices=args.devices)
        if val_acc > best_accuracy or (epoch + 1) % args.ckpt_interval == 0:
            utils.trainutils.save_checkpoints(args.ckpt_dir, model, val_acc, best_accuracy, epoch, name=args.model + '_one_')
            best_accuracy = max(best_accuracy, val_acc)

    return model


if __name__ == '__main__':
    train_one_model(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets

import torchvision.transforms as transforms

import os
import string
import random
import numpy as np
from core.KSNet import KSNet
from core.ensemble import Ensemble
from utils.trainutils import AverageMeter
from utils.metricutils import accuracy, oracle_accuracy
from utils.datautils import build_vision_dataloader
from utils.misc import random_string
from utils.trainutils import PerformanceReporter

import argparse
import time
# tensorboard
from tensorboardX import SummaryWriter

parser = argparse.ArgumentParser('Training Knowledge sharing network.')
parser.add_argument('--memo', default='exp', type=str, help='memo for the experiment')
parser.add_argument('--model', default='ks', type=str, help='type of model')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('-T', default=2, type=float, help='temperature for KL-divergence')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', default=5e-4, type=float, help='weight decay')
parser.add_argument('--max-epoch', default=300, type=int, help='maximal training epoch')
parser.add_argument('--batch-size', default=256, type=int, help='batch size')
parser.add_argument('--nmodel', default=5, type=int, help='number of models')
parser.add_argument('--devices', default=None, required=True, help='need specify devices')
parser.add_argument('--depth', default=32, type=int)
parser.add_argument('--print-interval', default=10, type=int)
args = parser.parse_args()

des = '-'.join([args.memo, args.model] + list(map(str, [args.nmodel, args.depth, args.T, args.lr, args.batch_size])) + [random_string(6)])
devices = args.devices.split(',')
writer = SummaryWriter(os.path.join('runs', des))
# define models and optimizers
if args.model == 'ks':
    model = KSNet(args.nmodel, args.depth, 100, devices)
elif args.model == 'ensemble':
    model = Ensemble(args.nmodel, args.depth, 100, devices)
# .to() needs a single device, not a list; keep the shared parameters on the primary device
model.to('cuda:%s' % devices[0])
optimizers = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=args.weight_decay)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizers, [150, 225], gamma=0.1)
criterion = nn.CrossEntropyLoss(reduction='none')
# define datasets and dataloader
transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                transforms.RandomCrop(32, padding=4),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[n/255. for n in [129.3, 124.1, 112.4]],
                                                     std=[n/255. for n in [68.2, 65.4, 70.4]])])
train_loader, test_loader = build_vision_dataloader('./data', dsname='CIFAR100',
                                                    transform=transform, batch_size=args.batch_size)
reporter = PerformanceReporter(None, acc=accuracy, oracle_acc=oracle_accuracy)


def compute_oracle_loss(preds, targets):
    losses = [criterion(pred, targets) for pred in preds]
    losses = [loss.unsqueeze(1) for loss in losses]
    losses = torch.cat(losses, 1)
    valid_loss, valid_index = torch.topk(losses, k=1, dim=1, largest=False)
    logits = torch.cat([pred.unsqueeze(2) for pred in preds], dim=2)

    valid_index = valid_index.repeat(1, logits.size(1))
    valid_logits = logits.gather(2, valid_index.unsqueeze(2))
    return valid_loss.mean(), valid_logits


def compute_kl_loss(pred, target, temperature=1.0):
    def softmax(x, T):
        y = torch.exp(x / T)
        y = y / (y.sum(1).unsqueeze(1) + 1e-8)
        return y
    probs = softmax(pred.squeeze(), temperature)
    targp = softmax(target.squeeze(), temperature)
    # F.kl_div expects log-probabilities as its first argument
    loss = temperature ** 2 * F.kl_div(probs.log(), targp)
    return loss


# register loss function before training
if args.model == 'ks':
    @model.loss_fn
    def ks_loss_fn(pred, targets):
        spec_logits, gene_logits = pred
        batch_losses, valid_logits = compute_oracle_loss(spec_logits, targets)
        gene_loss = criterion(gene_logits, targets).mean()
        kl_loss = compute_kl_loss(gene_logits, valid_logits.detach(), temperature=args.T)
        batch_loss = batch_losses + kl_loss  # + gene_loss
        return batch_loss
elif args.model == 'ensemble':
    @model.loss_fn
    def ie_loss_fn(pred, targets):
        batch_loss = sum([criterion(pred, targets).mean() for pred in pred[0]])
        return batch_loss


def train_epoch(model, optimizer, criterion, train_loader, epoch):
    model.train()
    loss = AverageMeter()

    for i, (data, targets) in enumerate(train_loader):
        data_time = time.time()
        batch_datas = [data.to('cuda:%s' % dev) for dev in devices]
        targets = targets.to('cuda:%s' % devices[0])
        data_time = time.time() - data_time
        # forward
        batch_time = time.time()
        pred = model(batch_datas)
        batch_time = time.time() - batch_time
        batch_loss = model.loss(pred, targets)

        measures = reporter.write(pred, targets)
        # update metrics
        optimizer.zero_grad()
        batch_loss.backward()
        loss.update(batch_loss.item())
        optimizer.step()

        if (i + 1) % args.print_interval == 0:
            cmdstr = 'Epoch {} iter {}: data_time: {:.4f}, batch_time: {:.4f}, batch_loss: {:.4f}, total_loss: {:.4f}, '.format(
                epoch, i, data_time, batch_time, batch_loss.item(), loss.avg)
            metricstr = ', '.join(['%s: %.4f' % (k, v) for k, v in reporter.metrics.items()])
            print(cmdstr + metricstr)
            reporter.reset()

    # write to summary (assumes reporter.metrics maps metric names to scalar values, as in the print above)
    writer.add_scalar('data/train_loss', loss.avg, epoch)
    for k, v in reporter.metrics.items():
        writer.add_scalar('data/train_%s' % k, v, epoch)

def evaluate(model, criterion, val_loader, epoch):
    model.eval()
    loss = AverageMeter()

    with torch.no_grad():
        for i, (data, targets) in enumerate(val_loader):
            data_time = time.time()
            batch_datas = [data.to('cuda:%s' % dev) for dev in devices]
            targets = targets.to('cuda:%s' % devices[0])
            data_time = time.time() - data_time
            # forward
            batch_time = time.time()
            pred = model(batch_datas)
            batch_time = time.time() - batch_time
            batch_loss = model.loss(pred, targets)

            measures = reporter.write(pred, targets)
            # update metrics
            loss.update(batch_loss.item())

            if (i + 1) % args.print_interval == 0:
                cmdstr = 'Eval epoch {} iter {}: data_time: {:.4f}, batch_time: {:.4f}, batch_loss: {:.4f}, total_loss: {:.4f}, '.format(
                    epoch, i, data_time, batch_time, batch_loss.item(), loss.avg)
                metricstr = ', '.join(['%s: %.4f' % (k, v) for k, v in reporter.metrics.items()])
                print(cmdstr + metricstr)
    # write to summary (assumes reporter.metrics maps metric names to scalar values, as in the print above)
    writer.add_scalar('data/val_loss', loss.avg, epoch)
    for k, v in reporter.metrics.items():
        writer.add_scalar('data/val_%s' % k, v, epoch)


def main():
    for epoch in range(args.max_epoch):
        lr_scheduler.step()
        train_epoch(model, optimizers, criterion, train_loader, epoch)
        evaluate(model, criterion, test_loader, epoch)


def kidding():
    print('wtf')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------