├── README.md
└── hinton
    ├── utils.py
    └── main.py

/README.md:
--------------------------------------------------------------------------------
# distillation.pytorch

This project is a work in progress.

This repository will include implementations of several knowledge distillation techniques:

* [Hinton et al. 2015: Distilling the Knowledge in a Neural Network](http://arxiv.org/abs/1503.02531)


## Requirements

* PyTorch >= 1.0
* torchvision
* homura >= 0.4.0 (`pip install -U git+https://github.com/moskomule/homura@v0.4`)

--------------------------------------------------------------------------------
/hinton/utils.py:
--------------------------------------------------------------------------------
from homura import trainers
from homura.callbacks import metric_callback_decorator
from homura.utils import Map
from homura.vision.models.classification import resnet20, wrn28_10, wrn28_2, resnet56
from torch.nn import functional as F

MODELS = {"resnet20": resnet20,
          "wrn28_10": wrn28_10,
          "wrn28_2": wrn28_2,
          "resnet56": resnet56}


@metric_callback_decorator
def kl_loss(data):
    # report the distillation (KL) term logged by DistillationTrainer.iteration
    return data['kl_loss']


class DistillationTrainer(trainers.SupervisedTrainer):
    def __init__(self, model, optimizer, loss_f, callbacks, scheduler, teacher_model, temperature, lambda_factor=1):
        super(DistillationTrainer, self).__init__(model, optimizer, loss_f, callbacks=callbacks, scheduler=scheduler)
        # the teacher is frozen and only used to produce soft targets
        self.teacher = teacher_model
        for p in self.teacher.parameters():
            p.requires_grad_(False)
        self.teacher.to(self.device)
        self.temperature = temperature
        self.lambda_factor = lambda_factor

    def iteration(self, data):
        input, target = data
        output = self.model(input)

        if self.is_train:
            self.optimizer.zero_grad()
            lesson = self.teacher(input)
            # KL divergence between temperature-softened student and teacher distributions
            kl_loss = F.kl_div((output / self.temperature).log_softmax(dim=1),
                               (lesson / self.temperature).softmax(dim=1),
                               reduction="batchmean")
            # hard-label cross entropy + T^2-scaled soft-target term (Hinton et al. 2015)
            loss = self.loss_f(output, target) + self.lambda_factor * (self.temperature ** 2) * kl_loss
            loss.backward()
            self.optimizer.step()
        else:
            loss = self.loss_f(output, target)
            kl_loss = 0
        results = Map(loss=loss, output=output, kl_loss=kl_loss)
        return results

--------------------------------------------------------------------------------
/hinton/main.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from homura import optim, lr_scheduler, callbacks, trainers, reporters
from homura.vision.data.loaders import cifar10_loaders
from tqdm import trange

from utils import DistillationTrainer, kl_loss, MODELS


def main():
    # stage 1: train the teacher model with plain supervised learning
    model = MODELS[args.teacher_model](num_classes=10)
    train_loader, test_loader = cifar10_loaders(args.batch_size)
    weight_decay = 1e-4 if "resnet" in args.teacher_model else 5e-4
    lr_decay = 0.1 if "resnet" in args.teacher_model else 0.2
    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=weight_decay)
    scheduler = lr_scheduler.MultiStepLR([50, 80], gamma=lr_decay)

    trainer = trainers.SupervisedTrainer(model, optimizer, F.cross_entropy, scheduler=scheduler)
    trainer.logger.info("Train the teacher model!")
    for _ in trange(args.teacher_epochs, ncols=80):
        trainer.train(train_loader)
        trainer.test(test_loader)

    teacher_model = model.eval()

    # stage 2: train the student model, distilling from the frozen teacher
    weight_decay = 1e-4 if "resnet" in args.student_model else 5e-4
    lr_decay = 0.1 if "resnet" in args.student_model else 0.2
    optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=weight_decay)
    scheduler = lr_scheduler.MultiStepLR([50, 80], gamma=lr_decay)
    model = MODELS[args.student_model](num_classes=10)

    c = [callbacks.AccuracyCallback(), callbacks.LossCallback(), kl_loss]
    with reporters.TQDMReporter(range(args.student_epochs), callbacks=c) as tq, reporters.TensorboardReporter(c) as tb:
        trainer = DistillationTrainer(model, optimizer, F.cross_entropy, callbacks=[tq, tb],
                                      scheduler=scheduler, teacher_model=teacher_model,
                                      temperature=args.temperature)
        trainer.logger.info("Train the student model!")
        for _ in tq:
            trainer.train(train_loader)
            trainer.test(test_loader)


if __name__ == '__main__':
    import miniargs

    p = miniargs.ArgumentParser()
    p.add_int("--batch_size", default=256)
    p.add_str("--teacher_model", choices=list(MODELS.keys()))
    p.add_str("--student_model", choices=list(MODELS.keys()))
    p.add_float("--temperature", default=0.1)
    p.add_int("--teacher_epochs", default=100)
    p.add_int("--student_epochs", default=100)

    args = p.parse()
    main()
--------------------------------------------------------------------------------
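For reference, the objective implemented in `DistillationTrainer.iteration` above is the standard Hinton et al. (2015) loss: cross entropy on the hard labels plus a T^2-scaled KL divergence between the temperature-softened student and teacher distributions. Below is a minimal, homura-free sketch of that same computation; the function name `distillation_loss`, the temperature value 4.0, and the toy tensors are illustrative assumptions, not part of this repository.

import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, target, temperature=4.0, lambda_factor=1.0):
    # hard-label term: ordinary cross entropy against the ground-truth labels
    hard = F.cross_entropy(student_logits, target)
    # soft-label term: KL divergence between the temperature-softened student and teacher
    # distributions; the T^2 factor keeps its gradient scale comparable across temperatures
    soft = F.kl_div((student_logits / temperature).log_softmax(dim=1),
                    (teacher_logits / temperature).softmax(dim=1),
                    reduction="batchmean")
    return hard + lambda_factor * (temperature ** 2) * soft


# toy example: batch of 8 samples, 10 classes (random logits, for illustration only)
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
target = torch.randint(0, 10, (8,))
print(distillation_loss(student_logits, teacher_logits, target))

Using the flag names defined by the `miniargs` parser in `main.py`, a typical (untested) invocation from the `hinton` directory would look like `python main.py --teacher_model resnet56 --student_model resnet20 --temperature 4`, where the temperature value is again only illustrative.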