├── CIFAR_10
│   ├── network
│   │   ├── alexnet.py
│   │   ├── alexnet_half.py
│   │   ├── best_model.pth
│   │   └── dcgan_model.py
│   ├── train_generator
│   │   └── dfgan.py
│   ├── train_student
│   │   ├── Using_Data
│   │   │   └── KD_related_data.py
│   │   └── Using_GAN
│   │       └── KD_dfgan.py
│   └── train_teacher
│       └── train_teacher.py
├── LICENSE
├── Other_networks_used
│   ├── inceptionv3_teacher.py
│   ├── lenet.py
│   ├── lenet_half.py
│   └── resnet_18_student.py
├── README.md
└── Readme.txt

/CIFAR_10/network/alexnet.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | class AlexNet(nn.Module):
 5 |     def __init__(self, num_classes=10):
 6 |         super(AlexNet, self).__init__()
 7 | 
 8 |         self.conv1 = nn.Conv2d(3, 48, 5, stride=1, padding=2)
 9 |         self.conv1.bias.data.normal_(0, 0.01)
10 |         self.conv1.bias.data.fill_(0)
11 | 
12 |         self.relu = nn.ReLU()
13 |         self.lrn = nn.LocalResponseNorm(2)
14 |         self.pad = nn.MaxPool2d(3, stride=2)
15 | 
16 |         self.batch_norm1 = nn.BatchNorm2d(48, eps=0.001)
17 | 
18 |         self.conv2 = nn.Conv2d(48, 128, 5, stride=1, padding=2)
19 |         self.conv2.bias.data.normal_(0, 0.01)
20 |         self.conv2.bias.data.fill_(1.0)
21 | 
22 |         self.batch_norm2 = nn.BatchNorm2d(128, eps=0.001)
23 | 
24 |         self.conv3 = nn.Conv2d(128, 192, 3, stride=1, padding=1)
25 |         self.conv3.bias.data.normal_(0, 0.01)
26 |         self.conv3.bias.data.fill_(0)
27 | 
28 |         self.batch_norm3 = nn.BatchNorm2d(192, eps=0.001)
29 | 
30 |         self.conv4 = nn.Conv2d(192, 192, 3, stride=1, padding=1)
31 |         self.conv4.bias.data.normal_(0, 0.01)
32 |         self.conv4.bias.data.fill_(1.0)
33 | 
34 |         self.batch_norm4 = nn.BatchNorm2d(192, eps=0.001)
35 | 
36 |         self.conv5 = nn.Conv2d(192, 128, 3, stride=1, padding=1)
37 |         self.conv5.bias.data.normal_(0, 0.01)
38 |         self.conv5.bias.data.fill_(1.0)
39 | 
40 |         self.batch_norm5 = nn.BatchNorm2d(128, eps=0.001)
41 | 
42 |         self.fc1 = nn.Linear(1152,512)
43 |         self.fc1.bias.data.normal_(0, 0.01)
44 |         self.fc1.bias.data.fill_(0)
45 | 
46 |         self.drop = nn.Dropout(p=0.5)
47 | 
48 |         self.batch_norm6 = nn.BatchNorm1d(512, eps=0.001)
49 | 
50 |         self.fc2 = nn.Linear(512,256)
51 |         self.fc2.bias.data.normal_(0, 0.01)
52 |         self.fc2.bias.data.fill_(0)
53 | 
54 |         self.batch_norm7 = nn.BatchNorm1d(256, eps=0.001)
55 | 
56 |         self.fc3 = nn.Linear(256,10)
57 |         self.fc3.bias.data.normal_(0, 0.01)
58 |         self.fc3.bias.data.fill_(0)
59 | 
60 |         self.soft = nn.Softmax()
61 | 
62 |     def forward(self, x):
63 |         layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x)))))
64 |         layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1)))))
65 |         layer3 = self.batch_norm3(self.relu(self.conv3(layer2)))
66 |         layer4 = self.batch_norm4(self.relu(self.conv4(layer3)))
67 |         layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4))))
68 |         flatten = layer5.view(-1, 128*3*3)
69 |         fully1 = self.relu(self.fc1(flatten))
70 |         fully1 = self.batch_norm6(self.drop(fully1))
71 |         fully2 = self.relu(self.fc2(fully1))
72 |         fully2 = self.batch_norm7(self.drop(fully2))
73 |         logits = self.fc3(fully2)
74 |         #softmax_val = self.soft(logits)
75 | 
76 |         return logits
77 | 
78 | model = AlexNet(10)
79 | 
80 | 
--------------------------------------------------------------------------------
/CIFAR_10/network/alexnet_half.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | 
 4 | class AlexNet_half(nn.Module):
 5 |     def __init__(self, num_classes=10):
 6 |         super(AlexNet_half, self).__init__()
 7 | 
 8 |         self.conv1 = nn.Conv2d(3, 24, 5, stride=1, padding=2)
 9 | 
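# A minimal standalone sketch (not part of the files above), checking the shapes of the
# AlexNet teacher defined in alexnet.py: with 32x32 CIFAR-10 inputs, the three
# MaxPool2d(3, stride=2) stages shrink the feature map 32 -> 15 -> 7 -> 3, so conv5's
# 128 channels flatten to 128*3*3 = 1152, matching nn.Linear(1152, 512). It assumes
# CIFAR_10/network is on sys.path so that the `alexnet` module is importable.
import torch
from alexnet import AlexNet

def check_alexnet_shapes():
    net = AlexNet(num_classes=10).eval()        # eval() so BatchNorm uses running statistics
    logits = net(torch.randn(4, 3, 32, 32))     # a dummy CIFAR-10-sized batch
    assert logits.shape == (4, 10)              # one logit per class; softmax is left to the loss
    return logits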
self.conv1.bias.data.normal_(0, 0.01) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | self.lrn = nn.LocalResponseNorm(2) 14 | self.pad = nn.MaxPool2d(3, stride=2) 15 | 16 | self.batch_norm1 = nn.BatchNorm2d(24, eps=0.001) 17 | 18 | self.conv2 = nn.Conv2d(24, 64, 5, stride=1, padding=2) 19 | self.conv2.bias.data.normal_(0, 0.01) 20 | self.conv2.bias.data.fill_(1.0) 21 | 22 | self.batch_norm2 = nn.BatchNorm2d(64, eps=0.001) 23 | 24 | self.conv3 = nn.Conv2d(64, 96, 3, stride=1, padding=1) 25 | self.conv3.bias.data.normal_(0, 0.01) 26 | self.conv3.bias.data.fill_(0) 27 | 28 | self.batch_norm3 = nn.BatchNorm2d(96, eps=0.001) 29 | 30 | self.conv4 = nn.Conv2d(96, 96, 3, stride=1, padding=1) 31 | self.conv4.bias.data.normal_(0, 0.01) 32 | self.conv4.bias.data.fill_(1.0) 33 | 34 | self.batch_norm4 = nn.BatchNorm2d(96, eps=0.001) 35 | 36 | self.conv5 = nn.Conv2d(96, 64, 3, stride=1, padding=1) 37 | self.conv5.bias.data.normal_(0, 0.01) 38 | self.conv5.bias.data.fill_(1.0) 39 | 40 | self.batch_norm5 = nn.BatchNorm2d(64, eps=0.001) 41 | 42 | self.fc1 = nn.Linear(576,256) 43 | self.fc1.bias.data.normal_(0, 0.01) 44 | self.fc1.bias.data.fill_(0) 45 | 46 | self.drop = nn.Dropout(p=0.5) 47 | 48 | self.batch_norm6 = nn.BatchNorm1d(256, eps=0.001) 49 | 50 | self.fc2 = nn.Linear(256,128) 51 | self.fc2.bias.data.normal_(0, 0.01) 52 | self.fc2.bias.data.fill_(0) 53 | 54 | self.batch_norm7 = nn.BatchNorm1d(128, eps=0.001) 55 | 56 | self.fc3 = nn.Linear(128,10) 57 | self.fc3.bias.data.normal_(0, 0.01) 58 | self.fc3.bias.data.fill_(0) 59 | 60 | self.soft = nn.Softmax() 61 | 62 | def forward(self, x): 63 | layer1 = self.batch_norm1(self.pad(self.lrn(self.relu(self.conv1(x))))) 64 | layer2 = self.batch_norm2(self.pad(self.lrn(self.relu(self.conv2(layer1))))) 65 | layer3 = self.batch_norm3(self.relu(self.conv3(layer2))) 66 | layer4 = self.batch_norm4(self.relu(self.conv4(layer3))) 67 | layer5 = self.batch_norm5(self.pad(self.relu(self.conv5(layer4)))) 68 | flatten = layer5.view(-1, 64*3*3) 69 | fully1 = self.relu(self.fc1(flatten)) 70 | fully1 = self.batch_norm6(self.drop(fully1)) 71 | fully2 = self.relu(self.fc2(fully1)) 72 | fully2 = self.batch_norm7(self.drop(fully2)) 73 | logits = self.fc3(fully2) 74 | #softmax_val = self.soft(logits) 75 | 76 | return logits 77 | 78 | model = AlexNet_half(10) 79 | 80 | -------------------------------------------------------------------------------- /CIFAR_10/network/best_model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vcl-iisc/DeGAN/213e839bdd7679bb799c6b273ef49f5f2e7b98fc/CIFAR_10/network/best_model.pth -------------------------------------------------------------------------------- /CIFAR_10/network/dcgan_model.py: -------------------------------------------------------------------------------- 1 | # Network of DCGAN 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Generator(nn.Module): 6 | def __init__(self, ngpu, nc=3, nz=100, ngf=64): 7 | super(Generator, self).__init__() 8 | self.ngpu = ngpu 9 | self.main = nn.Sequential( 10 | nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), 11 | nn.BatchNorm2d(ngf * 8), 12 | nn.ReLU(True), 13 | 14 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 15 | nn.BatchNorm2d(ngf * 4), 16 | nn.ReLU(True), 17 | 18 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 19 | nn.BatchNorm2d(ngf * 2), 20 | nn.ReLU(True), 21 | 22 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 23 | nn.BatchNorm2d(ngf), 24 | 
nn.ReLU(True), 25 | 26 | nn.ConvTranspose2d( ngf, nc, kernel_size=1, stride=1, padding=0, bias=False), 27 | nn.Tanh() 28 | ) 29 | 30 | def forward(self, input): 31 | if input.is_cuda and self.ngpu > 1: 32 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 33 | else: 34 | output = self.main(input) 35 | return output 36 | 37 | 38 | class Discriminator(nn.Module): 39 | def __init__(self, ngpu, nc=3, ndf=64): 40 | super(Discriminator, self).__init__() 41 | self.ngpu = ngpu 42 | self.main = nn.Sequential( 43 | nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), 44 | nn.LeakyReLU(0.2, inplace=True), 45 | 46 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 47 | nn.BatchNorm2d(ndf * 2), 48 | nn.LeakyReLU(0.2, inplace=True), 49 | 50 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 51 | nn.BatchNorm2d(ndf * 4), 52 | nn.LeakyReLU(0.2, inplace=True), 53 | 54 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 55 | nn.BatchNorm2d(ndf * 8), 56 | nn.LeakyReLU(0.2, inplace=True), 57 | 58 | nn.Conv2d(ndf * 8, 1, 2, 2, 0, bias=False), 59 | nn.Sigmoid() 60 | ) 61 | 62 | def forward(self, input): 63 | if input.is_cuda and self.ngpu > 1: 64 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 65 | else: 66 | output = self.main(input) 67 | 68 | return output.view(-1, 1).squeeze(1) 69 | 70 | -------------------------------------------------------------------------------- /CIFAR_10/train_generator/dfgan.py: -------------------------------------------------------------------------------- 1 | # Data-enriching GAN (DeGAN)/ DCGAN for retrieving images from a trained classifier 2 | 3 | from __future__ import print_function 4 | import argparse 5 | import os 6 | import random 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torchvision.datasets as dset 14 | import torchvision.transforms as transforms 15 | import torch.nn.functional as F 16 | from tensorboardX import SummaryWriter 17 | import numpy as np 18 | from dcgan_model import Generator, Discriminator 19 | from alexnet import AlexNet 20 | 21 | writer = SummaryWriter() 22 | 23 | # CUDA_VISIBLE_DEVICES=0 python dfgan.py --dataroot ../../../datasets --imageSize 32 --cuda --outf out_cifar --manualSeed 108 --niter 200 --batchSize 2048 24 | 25 | if __name__ == '__main__': 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--dataroot', required=True, help='path to dataset') 29 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 30 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 31 | parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') 32 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') 33 | parser.add_argument('--ngf', type=int, default=64) 34 | parser.add_argument('--ndf', type=int, default=64) 35 | parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for') 36 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') 37 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 38 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 39 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') 40 | parser.add_argument('--netG', default='', help="path to netG (to continue training)") 41 | parser.add_argument('--netD', default='', help="path to netD (to continue training)") 42 | parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints') 43 | parser.add_argument('--manualSeed', type=int, help='manual seed') 44 | 45 | opt = parser.parse_args() 46 | print(opt) 47 | 48 | try: 49 | os.makedirs(opt.outf) 50 | except OSError: 51 | pass 52 | 53 | if opt.manualSeed is None: 54 | opt.manualSeed = random.randint(1, 10000) 55 | print("Random Seed: ", opt.manualSeed) 56 | random.seed(opt.manualSeed) 57 | torch.manual_seed(opt.manualSeed) 58 | 59 | cudnn.benchmark = True 60 | 61 | if torch.cuda.is_available() and not opt.cuda: 62 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 63 | 64 | 65 | dataset = dset.CIFAR100(root=opt.dataroot, download=True, 66 | transform=transforms.Compose([ 67 | transforms.Scale(opt.imageSize), 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 70 | ])) 71 | 72 | 73 | nc=3 74 | 75 | assert dataset 76 | 77 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 78 | shuffle=True, num_workers=int(opt.workers)) 79 | 80 | 81 | device = torch.device("cuda:0" if opt.cuda else "cpu") 82 | ngpu = int(opt.ngpu) 83 | nz = int(opt.nz) 84 | ngf = int(opt.ngf) 85 | ndf = int(opt.ndf) 86 | 87 | # custom weights initialization called on netG and netD 88 | def weights_init(m): 89 | classname = m.__class__.__name__ 90 | if classname.find('Conv') != -1: 91 | m.weight.data.normal_(0.0, 0.02) 92 | elif classname.find('BatchNorm') != -1: 93 | m.weight.data.normal_(1.0, 0.02) 94 | m.bias.data.fill_(0) 95 | 96 | netG = Generator(ngpu).to(device) 97 | netG.apply(weights_init) 98 | if opt.netG != '': 99 | netG.load_state_dict(torch.load(opt.netG)) 100 | print(netG) 101 | 102 | netC = AlexNet(ngpu).to(device) 103 | netC.load_state_dict(torch.load('../../train_student_KD/AlexNet/CIFAR10_data/best_model.pth')) 104 | print(netC) 105 | netC.eval() 106 | 107 | netD = Discriminator(ngpu).to(device) 108 | netD.apply(weights_init) 109 | if opt.netD != '': 110 | netD.load_state_dict(torch.load(opt.netD)) 111 | print(netD) 112 | 113 | criterion = nn.BCELoss() 114 | criterion_sum = nn.BCELoss(reduction = 'sum') 115 | 116 | fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device) 117 | real_label = 1 118 | fake_label = 0 119 | 120 | # setup optimizer 121 | optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 122 | optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 123 | threshold = [] 124 | 125 | # Used classes of CIFAR-100 (Background classes used here) 126 | inc_classes = [68, 23, 33, 49, 60, 71] 127 | for epoch in range(opt.niter): 128 | num_greater_thresh = 0 129 | count_class = [0]*10 130 | count_class_less = [0]*10 131 | count_class_hist = [0]*10 132 | count_class_less_hist = [0]*10 133 | classification_loss_sum = 0 134 | errD_real_sum = 0 135 | errD_fake_sum = 0 136 | errD_sum = 0 137 | errG_adv_sum = 0 138 | data_size = 0 139 | accD_real_sum = 0 140 | accD_fake_sum = 0 141 | accG_sum = 0 142 | accD_sum = 0 143 | div_loss_sum = 0 144 | for i, data in enumerate(dataloader, 0): 145 | ############################ 146 | # (1) Update D 
network: maximize log(D(x)) + log(1 - D(G(z))) 147 | ########################### 148 | # train with real 149 | netD.zero_grad() 150 | real_cpu = torch.from_numpy(data[0].numpy()[np.isin(data[1],inc_classes)]).to(device) 151 | batch_size = real_cpu.size(0) 152 | data_size = data_size + batch_size 153 | label = torch.full((batch_size,), real_label, device=device) 154 | output = netD(real_cpu) 155 | 156 | errD_real = criterion(output, label) 157 | errD_real_sum = errD_real_sum + (criterion_sum(output,label)).cpu().data.numpy() 158 | 159 | accD_real = (label[output>0.5]).shape[0] 160 | accD_real_sum = accD_real_sum + float(accD_real) 161 | 162 | errD_real.backward() 163 | 164 | D_x = output.mean().item() 165 | 166 | # train with fake 167 | noise = torch.randn(batch_size, nz, 1, 1, device=device) 168 | fake = netG(noise) 169 | fake_class = netC(fake) 170 | sm_fake_class = F.softmax(fake_class, dim=1) 171 | 172 | class_max = fake_class.max(1,keepdim=True)[0] 173 | class_argmax = fake_class.max(1,keepdim=True)[1] 174 | 175 | # Classification loss 176 | classification_loss = torch.mean(torch.sum(-sm_fake_class*torch.log(sm_fake_class+1e-5),dim=1)) 177 | classification_loss_add = torch.sum(-sm_fake_class*torch.log(sm_fake_class+1e-5)) 178 | classification_loss_sum = classification_loss_sum + (classification_loss_add).cpu().data.numpy() 179 | 180 | sm_batch_mean = torch.mean(sm_fake_class,dim=0) 181 | div_loss = torch.sum(sm_batch_mean*torch.log(sm_batch_mean)) # Maximize entropy across batch 182 | div_loss_sum = div_loss_sum + div_loss*batch_size 183 | 184 | label.fill_(fake_label) 185 | output = netD(fake.detach()) 186 | 187 | errD_fake = criterion(output, label) 188 | errD_fake_sum = errD_fake_sum + (criterion_sum(output, label)).cpu().data.numpy() 189 | 190 | accD_fake = (label[output<=0.5]).shape[0] 191 | accD_fake_sum = accD_fake_sum + float(accD_fake) 192 | 193 | errD_fake.backward() 194 | D_G_z1 = output.mean().item() 195 | 196 | errD = errD_real + errD_fake 197 | errD_sum = errD_real_sum + errD_fake_sum 198 | 199 | accD = accD_real + accD_fake 200 | accD_sum = accD_real_sum + accD_fake_sum 201 | 202 | optimizerD.step() 203 | 204 | ############################ 205 | # (2) Update G network: maximize log(D(G(z))) 206 | ########################### 207 | netG.zero_grad() 208 | label.fill_(real_label) # fake labels are real for generator cost 209 | output = netD(fake) 210 | c_l = 0 # Hyperparameter to weigh entropy loss 211 | d_l = 5 # Hyperparameter to weigh the diversity loss 212 | errG_adv = criterion(output, label) 213 | errG_adv_sum = errG_adv_sum + (criterion_sum(output, label)).cpu().data.numpy() 214 | 215 | accG = (label[output>0.5]).shape[0] 216 | accG_sum = accG_sum + float(accG) 217 | 218 | errG = errG_adv + c_l * classification_loss + d_l * div_loss 219 | errG_sum = errG_adv_sum + c_l * classification_loss_sum + d_l * div_loss_sum 220 | errG.backward() 221 | D_G_z2 = output.mean().item() 222 | optimizerG.step() 223 | 224 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 225 | % (epoch, opt.niter, i, len(dataloader), 226 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 227 | pred_class = F.softmax(fake_class,dim=1).max(1, keepdim=True)[0] 228 | pred_class_argmax = F.softmax(fake_class,dim=1).max(1, keepdim=True)[1] 229 | num_greater_thresh = num_greater_thresh + (torch.sum(pred_class > 0.9).cpu().data.numpy()) 230 | for argmax, val in zip(pred_class_argmax, pred_class): 231 | if val > 0.9: 232 | count_class_hist.append(argmax) 233 | count_class[argmax] 
= count_class[argmax] + 1 234 | else: 235 | count_class_less_hist.append(argmax) 236 | count_class_less[argmax] = count_class_less[argmax] + 1 237 | 238 | if i % 100 == 0: 239 | writer.add_image("Gen Imgs Training", (fake+1)/2, epoch) 240 | 241 | # do checkpointing 242 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch)) 243 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch)) 244 | 245 | # Generate fake samples for visualization 246 | 247 | test_size = 1000 248 | noise_test = torch.randn(test_size, nz, 1, 1, device=device) 249 | fake_test = netG(noise_test) 250 | fake_test_class = netC(fake_test) 251 | pred_test_class_max = F.softmax(fake_test_class,dim=1).max(1, keepdim=True)[0] 252 | pred_test_class_argmax = F.softmax(fake_test_class,dim=1).max(1, keepdim=True)[1] 253 | 254 | for i in range(10): 255 | print("Score>0.9: Class",i,":",torch.sum(((pred_test_class_argmax.view(test_size)==i) & (pred_test_class_max.view(test_size)>0.9)).float())) 256 | print("Score<0.9: Class",i,":",torch.sum(((pred_test_class_argmax.view(test_size)==i) & (pred_test_class_max.view(test_size)<0.9)).float())) 257 | 258 | if fake_test[pred_test_class_argmax.view(test_size)==0].shape[0] > 0: 259 | writer.add_image("Gen Imgs Test: Airplane", (fake_test[pred_test_class_argmax.view(test_size)==0]+1)/2, epoch) 260 | if fake_test[pred_test_class_argmax.view(test_size)==1].shape[0] > 0: 261 | writer.add_image("Gen Imgs Test: Automobile", (fake_test[pred_test_class_argmax.view(test_size)==1]+1)/2, epoch) 262 | if fake_test[pred_test_class_argmax.view(test_size)==2].shape[0] > 0: 263 | writer.add_image("Gen Imgs Test: Bird", (fake_test[pred_test_class_argmax.view(test_size)==2]+1)/2, epoch) 264 | if fake_test[pred_test_class_argmax.view(test_size)==3].shape[0] > 0: 265 | writer.add_image("Gen Imgs Test: Cat", (fake_test[pred_test_class_argmax.view(test_size)==3]+1)/2, epoch) 266 | if fake_test[pred_test_class_argmax.view(test_size)==4].shape[0] > 0: 267 | writer.add_image("Gen Imgs Test: Deer", (fake_test[pred_test_class_argmax.view(test_size)==4]+1)/2, epoch) 268 | if fake_test[pred_test_class_argmax.view(test_size)==5].shape[0] > 0: 269 | writer.add_image("Gen Imgs Test: Dog", (fake_test[pred_test_class_argmax.view(test_size)==5]+1)/2, epoch) 270 | if fake_test[pred_test_class_argmax.view(test_size)==6].shape[0] > 0: 271 | writer.add_image("Gen Imgs Test: Frog", (fake_test[pred_test_class_argmax.view(test_size)==6]+1)/2, epoch) 272 | if fake_test[pred_test_class_argmax.view(test_size)==7].shape[0] > 0: 273 | writer.add_image("Gen Imgs Test: Horse", (fake_test[pred_test_class_argmax.view(test_size)==7]+1)/2, epoch) 274 | if fake_test[pred_test_class_argmax.view(test_size)==8].shape[0] > 0: 275 | writer.add_image("Gen Imgs Test: Ship", (fake_test[pred_test_class_argmax.view(test_size)==8]+1)/2, epoch) 276 | if fake_test[pred_test_class_argmax.view(test_size)==9].shape[0] > 0: 277 | writer.add_image("Gen Imgs Test: Truck", (fake_test[pred_test_class_argmax.view(test_size)==9]+1)/2, epoch) 278 | 279 | print(count_class , "Above 0.9") 280 | print(count_class_less, "Below 0.9") 281 | writer.add_histogram("above 0.9", np.asarray(count_class), epoch, bins=10) 282 | writer.add_histogram("above 0.9", np.asarray(count_class), epoch, bins=10) 283 | threshold.append(num_greater_thresh) 284 | 285 | writer.add_scalar("1 Train Discriminator accuracy(all)", accD_sum/ (2*data_size), epoch) 286 | writer.add_scalar("2 Train Discriminator accuracy(fake)", accD_fake_sum/ data_size, epoch) 287 | 
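# The generator objective assembled in the training loop above, restated here as one
# standalone function for reference (a sketch only). `disc_out` is assumed to be the
# discriminator's sigmoid output on fake images and `teacher_logits` the frozen
# classifier's logits on the same images; c_l and d_l mirror the weighting
# hyperparameters used in the loop.
import torch
import torch.nn.functional as F

def degan_generator_loss(disc_out, teacher_logits, c_l=0.0, d_l=5.0, eps=1e-5):
    # Adversarial term: fakes should be scored as real by the discriminator.
    adv = F.binary_cross_entropy(disc_out, torch.ones_like(disc_out))
    probs = F.softmax(teacher_logits, dim=1)
    # Entropy term: low per-sample entropy encourages confident classifier outputs.
    entropy = torch.mean(torch.sum(-probs * torch.log(probs + eps), dim=1))
    # Diversity term: negative entropy of the batch-mean distribution, so minimizing
    # it spreads the generated samples across all classes.
    batch_mean = torch.mean(probs, dim=0)
    diversity = torch.sum(batch_mean * torch.log(batch_mean + eps))
    return adv + c_l * entropy + d_l * diversity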
writer.add_scalar("3 Train Discriminator accuracy(real)", accD_real_sum/ data_size, epoch) 288 | writer.add_scalar("4 Train Generator accuracy(fake)", accG_sum/ data_size, epoch) 289 | writer.add_scalar("5 Train Discriminator loss (real)", errD_real_sum/ data_size, epoch) 290 | writer.add_scalar("6 Train Discriminator loss (fake)", errD_fake_sum/ data_size, epoch) 291 | writer.add_scalar("7 Train Discriminator loss (all)", errD_sum/(2* data_size), epoch) 292 | writer.add_scalar("8 Train Generator loss (adv)", errG_adv_sum/ data_size, epoch) 293 | writer.add_scalar("9 Train Generator loss (classification)", classification_loss_sum/ data_size, epoch) 294 | writer.add_scalar("10 Train Generator loss (diversity)", div_loss_sum/ data_size, epoch) 295 | writer.add_scalar("11 Train Generator loss (all)", errG_sum/ data_size, epoch) 296 | 297 | writer.export_scalars_to_json("./all_scalars.json") 298 | 299 | writer.close() 300 | 301 | -------------------------------------------------------------------------------- /CIFAR_10/train_student/Using_Data/KD_related_data.py: -------------------------------------------------------------------------------- 1 | # Training a Student network on CIFAR-10 2 | # Teacher architecture: AlexNet, Student architecture: AlexNet half 3 | from __future__ import print_function 4 | import argparse 5 | import torch 6 | import os 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from alexnet import AlexNet 12 | from alexnet_half import AlexNet_half 13 | from torch.utils.data.sampler import SubsetRandomSampler 14 | from tensorboardX import SummaryWriter 15 | import numpy as np 16 | 17 | # CUDA_VISIBLE_DEVICES=0 python KD_related_data.py --batch-size 2048 --test-batch-size 1000 --epochs 5000 --lr 0.001 --seed 108 --log-interval 10 --temp 20 --lambda_ 1 18 | 19 | writer = SummaryWriter() 20 | if not os.path.exists("models"): 21 | os.makedirs("models") 22 | 23 | def train(args, model, netS, device, train_loader, optimizer, epoch, temp, inc_classes): 24 | model.eval() 25 | netS.train() 26 | loss_all_sum = 0 27 | tot = 0 28 | teacher_student_correct_sum = 0 29 | for batch_idx, (data, target) in enumerate(train_loader): 30 | data = torch.from_numpy(data.numpy()[np.isin(target,inc_classes)]).to(device) 31 | if data.shape[0] == 0: 32 | continue 33 | tot += data.shape[0] 34 | optimizer.zero_grad() 35 | data = data*2 - 1 36 | output_teacher_logits = model(data) 37 | output_student_logits = netS(data) 38 | output_teacher_logits_ht = output_teacher_logits / temp 39 | output_student_logits_ht = output_student_logits / temp 40 | sm_teacher_ht = F.softmax(output_teacher_logits_ht,dim=1) 41 | sm_student_ht = F.softmax(output_student_logits_ht,dim=1) 42 | sm_teacher = F.softmax(output_teacher_logits, dim=1) 43 | sm_student = F.softmax(output_student_logits, dim=1) 44 | loss_kd = nn.KLDivLoss(reduction='sum')(F.log_softmax(output_student_logits_ht, dim=1),F.softmax(output_teacher_logits_ht, dim=1)) 45 | pred_class_argmax_teacher = sm_teacher.max(1, keepdim=True)[1] 46 | loss_ce = F.cross_entropy(output_student_logits, pred_class_argmax_teacher.view(data.shape[0]),reduction='sum') 47 | loss_all = args.lambda_*temp*temp*loss_kd + (1-args.lambda_)*loss_ce 48 | loss_all.backward() 49 | loss_all_sum += loss_all 50 | pred_class_argmax_student = sm_student.max(1, keepdim=True)[1] 51 | pred_class_argmax_teacher = pred_class_argmax_teacher.view(sm_teacher.shape[0]) 52 | pred_class_argmax_student = 
pred_class_argmax_student.view(sm_teacher.shape[0]) 53 | teacher_student_correct = torch.sum(pred_class_argmax_student==pred_class_argmax_teacher) 54 | teacher_student_correct_sum = teacher_student_correct_sum + (teacher_student_correct).cpu().data.numpy() 55 | optimizer.step() 56 | if batch_idx % args.log_interval == 0: 57 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 58 | epoch, int((batch_idx+1) * len(data)), int(len(train_loader.dataset)*0.8), 59 | 100. * (batch_idx+1) / len(train_loader), (loss_all/data.shape[0]).item())) 60 | loss_all_mean = loss_all_sum / tot 61 | teacher_student_acc = 100. * teacher_student_correct_sum / tot 62 | print('Train set: Average loss: {:.4f}, Teacher-Student Accuracy: {}/{} ({:.0f}% )'.format( 63 | loss_all_mean, teacher_student_correct_sum, tot, teacher_student_acc)) 64 | torch.save(netS.state_dict(), "models/"+str(epoch)+".pth") 65 | return loss_all_mean, teacher_student_acc 66 | 67 | def val(args, model, netS, device, test_loader, epoch, val_test, temp, inc_classes): 68 | model.eval() 69 | netS.eval() 70 | loss_all_sum = 0 71 | tot = 0 72 | teacher_student_correct_sum = 0 73 | with torch.no_grad(): 74 | for batch_idx, (data, target) in enumerate(test_loader): 75 | if data.shape[0] == 0: 76 | continue 77 | data = torch.from_numpy(data.numpy()[np.isin(target,inc_classes)]).to(device) 78 | tot += data.shape[0] 79 | data = data*2 - 1 80 | output_teacher_logits = model(data) 81 | output_student_logits = netS(data) 82 | output_teacher_logits_ht = output_teacher_logits / temp 83 | output_student_logits_ht = output_student_logits / temp 84 | sm_teacher_ht = F.softmax(output_teacher_logits_ht,dim=1) 85 | sm_student_ht = F.softmax(output_student_logits_ht,dim=1) 86 | sm_teacher = F.softmax(output_teacher_logits, dim=1) 87 | sm_student = F.softmax(output_student_logits, dim=1) 88 | loss_kd = nn.KLDivLoss(reduction='sum')(F.log_softmax(output_student_logits_ht, dim=1),F.softmax(output_teacher_logits_ht, dim=1)) 89 | pred_class_argmax_teacher = sm_teacher.max(1, keepdim=True)[1] 90 | loss_ce = F.cross_entropy(output_student_logits, pred_class_argmax_teacher.view(data.shape[0]),reduction='sum') 91 | loss_all = args.lambda_*temp*temp*loss_kd + (1-args.lambda_)*loss_ce 92 | loss_all_sum += loss_all 93 | pred_class_argmax_student = sm_student.max(1, keepdim=True)[1] 94 | pred_class_argmax_teacher = pred_class_argmax_teacher.view(sm_teacher.shape[0]) 95 | pred_class_argmax_student = pred_class_argmax_student.view(sm_teacher.shape[0]) 96 | teacher_student_correct = torch.sum(pred_class_argmax_student==pred_class_argmax_teacher) 97 | teacher_student_correct_sum = teacher_student_correct_sum + (teacher_student_correct).cpu().data.numpy() 98 | loss_all_mean = loss_all_sum / tot 99 | teacher_student_acc = 100. 
* teacher_student_correct_sum / tot 100 | print('{} set: Average loss: {:.4f}, Teacher-Student Accuracy: {}/{} ({:.0f}% )'.format( 101 | val_test, loss_all_mean, teacher_student_correct_sum, tot, teacher_student_acc)) 102 | return loss_all_mean, teacher_student_acc 103 | 104 | 105 | def test(args, model, netS, device, test_loader, epoch, val_test, temp): 106 | model.eval() 107 | netS.eval() 108 | loss_all_sum = 0 109 | tot = 0 110 | student_correct_sum = 0 111 | teacher_student_correct_sum = 0 112 | with torch.no_grad(): 113 | for batch_idx, (data, target) in enumerate(test_loader): 114 | data, target = data.to(device), target.to(device) 115 | tot += data.shape[0] 116 | data = data*2 - 1 117 | output_teacher_logits = model(data) 118 | output_student_logits = netS(data) 119 | output_teacher_logits_ht = output_teacher_logits / temp 120 | output_student_logits_ht = output_student_logits / temp 121 | sm_teacher_ht = F.softmax(output_teacher_logits_ht,dim=1) 122 | sm_student_ht = F.softmax(output_student_logits_ht,dim=1) 123 | sm_teacher = F.softmax(output_teacher_logits, dim=1) 124 | sm_student = F.softmax(output_student_logits, dim=1) 125 | loss_kd = nn.KLDivLoss(reduction='sum')(F.log_softmax(output_student_logits_ht, dim=1),F.softmax(output_teacher_logits_ht, dim=1)) 126 | pred_class_argmax_teacher = sm_teacher.max(1, keepdim=True)[1] 127 | loss_ce = F.cross_entropy(output_student_logits, pred_class_argmax_teacher.view(data.shape[0]),reduction='sum') 128 | loss_all = args.lambda_*temp*temp*loss_kd + (1-args.lambda_)*loss_ce 129 | loss_all_sum += loss_all 130 | pred_class_argmax_student = sm_student.max(1, keepdim=True)[1] 131 | pred_class_argmax_teacher = pred_class_argmax_teacher.view(sm_teacher.shape[0]) 132 | pred_class_argmax_student = pred_class_argmax_student.view(sm_teacher.shape[0]) 133 | student_correct = torch.sum(pred_class_argmax_student==target) 134 | student_correct_sum = student_correct_sum + (student_correct).cpu().data.numpy() 135 | teacher_student_correct = torch.sum(pred_class_argmax_student==pred_class_argmax_teacher) 136 | teacher_student_correct_sum = teacher_student_correct_sum + (teacher_student_correct).cpu().data.numpy() 137 | loss_all_mean = loss_all_sum / tot 138 | student_acc = 100. * student_correct_sum / tot 139 | teacher_student_acc = 100. 
* teacher_student_correct_sum / tot 140 | print('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%), Teacher-Student Accuracy: {}/{} ({:.0f}% )'.format( 141 | val_test, loss_all_mean, student_correct_sum, tot, student_acc, teacher_student_correct_sum, tot, teacher_student_acc)) 142 | return loss_all_mean, student_acc, teacher_student_acc 143 | 144 | def main(): 145 | # Training settings 146 | parser = argparse.ArgumentParser(description='CIFAR Classifier training') 147 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 148 | help='input batch size for training (default: 64)') 149 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 150 | help='input batch size for testing (default: 1000)') 151 | parser.add_argument('--epochs', type=int, default=1000, metavar='N', 152 | help='number of epochs to train (default: 10)') 153 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 154 | help='learning rate (default: 0.001)') 155 | parser.add_argument('--no-cuda', action='store_true', default=False, 156 | help='disables CUDA training') 157 | parser.add_argument('--seed', type=int, default=1, metavar='S', 158 | help='random seed (default: 1)') 159 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 160 | help='how many batches to wait before logging training status') 161 | parser.add_argument('--temp', default=20, type=float, help='Temperature for KD') 162 | parser.add_argument('--lambda_', default=1, type=float, help='Weight of KD Loss during distillation') 163 | 164 | args = parser.parse_args() 165 | use_cuda = not args.no_cuda and torch.cuda.is_available() 166 | torch.manual_seed(args.seed) 167 | device = torch.device("cuda" if use_cuda else "cpu") 168 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 169 | classes = ('plane', 'car', 'bird', 'cat', 170 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 171 | tfm = transforms.Compose([ 172 | transforms.ToTensor() 173 | ]) 174 | train_dataset = datasets.CIFAR100( 175 | root='../../../../datasets', train=True, 176 | download=True, transform=tfm) 177 | val_dataset = datasets.CIFAR100( 178 | root='../../../../datasets', train=True, 179 | download=True, transform=tfm) 180 | num_train = len(train_dataset) 181 | indices = list(range(num_train)) 182 | split = int(np.floor(0.2 * num_train)) 183 | np.random.seed(args.seed) 184 | np.random.shuffle(indices) 185 | train_idx, val_idx = indices[split:], indices[:split] 186 | train_sampler = SubsetRandomSampler(train_idx) 187 | val_sampler = SubsetRandomSampler(val_idx) 188 | train_loader = torch.utils.data.DataLoader( 189 | train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs 190 | ) 191 | val_loader = torch.utils.data.DataLoader( 192 | val_dataset, batch_size=args.test_batch_size, sampler=val_sampler,**kwargs 193 | ) 194 | test_loader = torch.utils.data.DataLoader( 195 | datasets.CIFAR10('../../../../datasets', train=False, download=True, 196 | transform=transforms.Compose([ 197 | transforms.ToTensor(), 198 | ])), 199 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 200 | model = AlexNet().to(device) 201 | model.eval() 202 | model.load_state_dict(torch.load("../CIFAR10_data/best_model.pth")) 203 | netS = AlexNet_half().to(device) 204 | #netS.load_state_dict(torch.load("./models/best_model.pth")) 205 | optimizer = optim.Adam(netS.parameters(), lr=args.lr) 206 | best_val_acc = 0 207 | cnt = 0 208 | temp = args.temp 209 | # Used classes of CIFAR-100 210 | inc_classes = 
[70, 47, 49, 37, 86, 53, 16, 94, 54, 25] 211 | for epoch in range(1, args.epochs + 1): 212 | train_loss_kd, train_teacher_student_acc = train(args, model, netS, device, train_loader, optimizer, epoch, temp, inc_classes) 213 | val_loss_kd, val_teacher_student_acc = val(args, model, netS, device, val_loader, epoch, 'Validation', temp, inc_classes) 214 | test_loss_kd, test_student_acc, test_teacher_student_acc = test(args, model, netS, device, test_loader, epoch, 'Test', temp) 215 | if val_teacher_student_acc > best_val_acc: 216 | print("Saving best model...") 217 | torch.save(netS.state_dict(), "models/best_model_lr1.pth") 218 | best_val_acc = val_teacher_student_acc 219 | cnt = 0 220 | train_st_acc_lr1 = train_teacher_student_acc 221 | val_st_acc_lr1 = val_teacher_student_acc 222 | test_st_acc_lr1 = test_teacher_student_acc 223 | test_acc_lr1 = test_student_acc 224 | else: 225 | cnt += 1 226 | writer.add_scalar("1_Train loss", train_loss_kd, epoch) 227 | writer.add_scalar("2_Validation loss", val_loss_kd, epoch) 228 | writer.add_scalar("3_Test loss", test_loss_kd, epoch) 229 | writer.add_scalar("7_Test accuracy", test_student_acc, epoch) 230 | writer.add_scalar("4_Train Teacher-Student accuracy", train_teacher_student_acc, epoch) 231 | writer.add_scalar("5_Validation Teacher-Student accuracy", val_teacher_student_acc, epoch) 232 | writer.add_scalar("6_Test Teacher-Student accuracy", test_teacher_student_acc, epoch) 233 | if cnt > 100: 234 | print('Model has converged with learning rate = {}!'.format(args.lr)) 235 | break 236 | n_epochs_lr1 = epoch 237 | optimizer = optim.Adam(netS.parameters(), lr=args.lr*0.1) 238 | netS.load_state_dict(torch.load("models/best_model_lr1.pth")) 239 | cnt = 0 240 | train_st_acc_lr2 = train_st_acc_lr1 241 | val_st_acc_lr2 = val_st_acc_lr1 242 | test_st_acc_lr2 = test_st_acc_lr1 243 | test_acc_lr2 = test_acc_lr1 244 | torch.save(netS.state_dict(), "models/best_model_lr2.pth") 245 | for epoch in range(1, args.epochs + 1): 246 | train_loss_kd, train_teacher_student_acc = train(args, model, netS, device, train_loader, optimizer, epoch + n_epochs_lr1, temp, inc_classes) 247 | val_loss_kd, val_teacher_student_acc = val(args, model, netS, device, val_loader, epoch + n_epochs_lr1, 'Validation', temp, inc_classes) 248 | test_loss_kd, test_student_acc, test_teacher_student_acc = test(args, model, netS, device, test_loader, epoch + n_epochs_lr1, 'Test', temp) 249 | if val_teacher_student_acc > best_val_acc: 250 | print("Saving best model...") 251 | torch.save(netS.state_dict(), "models/best_model_lr2.pth") 252 | best_val_acc = val_teacher_student_acc 253 | cnt = 0 254 | train_st_acc_lr2 = train_teacher_student_acc 255 | val_st_acc_lr2 = val_teacher_student_acc 256 | test_st_acc_lr2 = test_teacher_student_acc 257 | test_acc_lr2 = test_student_acc 258 | else: 259 | cnt += 1 260 | writer.add_scalar("1_Train loss", train_loss_kd, epoch + n_epochs_lr1) 261 | writer.add_scalar("2_Validation loss", val_loss_kd, epoch + n_epochs_lr1) 262 | writer.add_scalar("3_Test loss", test_loss_kd, epoch + n_epochs_lr1) 263 | writer.add_scalar("7_Test accuracy", test_student_acc, epoch + n_epochs_lr1) 264 | writer.add_scalar("4_Train Teacher-Student accuracy", train_teacher_student_acc, epoch + n_epochs_lr1) 265 | writer.add_scalar("5_Validation Teacher-Student accuracy", val_teacher_student_acc, epoch + n_epochs_lr1) 266 | writer.add_scalar("6_Test Teacher-Student accuracy", test_teacher_student_acc, epoch + n_epochs_lr1) 267 | if cnt > 100: 268 | print('Model has converged with learning 
rate = {}!'.format(args.lr*0.1)) 269 | break 270 | 271 | n_epochs_lr2 = epoch 272 | print('Number of epochs with lr = {} are {} and number of epochs with lr = {} are {}'.format( 273 | args.lr, n_epochs_lr1, args.lr*0.1, n_epochs_lr2)) 274 | print('Accuracy with lr = {}: Train ST accuracy = {:.2f}%, Validation ST accuracy = {:.2f}%, Test ST accuracy = {:.2f}%, Test accuracy = {:.2f}%'.format( 275 | args.lr, train_st_acc_lr1, val_st_acc_lr1, test_st_acc_lr1, test_acc_lr1)) 276 | print('Accuracy with lr = {}: Train ST accuracy = {:.2f}%, Validation ST accuracy = {:.2f}%, Test ST accuracy = {:.2f}%, Test accuracy = {:.2f}%'.format( 277 | args.lr*0.1, train_st_acc_lr2, val_st_acc_lr2, test_st_acc_lr2, test_acc_lr2)) 278 | 279 | writer.close() 280 | 281 | if __name__ == '__main__': 282 | main() 283 | -------------------------------------------------------------------------------- /CIFAR_10/train_student/Using_GAN/KD_dfgan.py: -------------------------------------------------------------------------------- 1 | # Code to distill the knowledge from a Teacher to Student using data generated by a Generator 2 | 3 | from __future__ import print_function 4 | import argparse 5 | import os 6 | import random 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torch.nn.functional as F 15 | from tensorboardX import SummaryWriter 16 | import numpy as np 17 | from dcgan_model import Generator, Discriminator 18 | from alexnet import AlexNet 19 | from alexnet_half import AlexNet_half 20 | 21 | writer = SummaryWriter() 22 | 23 | # CUDA_VISIBLE_DEVICES=0 python KD_dfgan.py --dataroot ../../../../datasets --cuda --outf models --manualSeed 108 --niter 5000 --lambda_ 1 --temp 20 --netG ../../train_generator/out_cifar/netG_epoch_199.pth --netC ./best_model.pth 24 | 25 | if __name__ == '__main__': 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--dataroot', required=True, help='path to test dataset') 29 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 30 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') 31 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') 32 | parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for') 33 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.0002') 34 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 35 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 36 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') 37 | parser.add_argument('--netG', required=True, help="path to Generator network weights") 38 | parser.add_argument('--netC', required=True, help="path to Teacher network weights") 39 | parser.add_argument('--netS', default='', help="path to Student network weights (to continue training)") 40 | parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints') 41 | parser.add_argument('--manualSeed', type=int, help='manual seed') 42 | parser.add_argument('--temp', default=10, type=float, help='Temperature for KD') 43 | parser.add_argument('--lambda_', default=1, type=float, help='Weight of KD Loss during distillation') 44 | parser.add_argument('--nBatches', default=256, type=float, help='Number of Batches') 45 | 46 | opt = parser.parse_args() 47 | print(opt) 48 | 49 | try: 50 | os.makedirs(opt.outf) 51 | except OSError: 52 | pass 53 | 54 | if opt.manualSeed is None: 55 | opt.manualSeed = random.randint(1, 10000) 56 | print("Random Seed: ", opt.manualSeed) 57 | random.seed(opt.manualSeed) 58 | torch.manual_seed(opt.manualSeed) 59 | 60 | cudnn.benchmark = True 61 | 62 | if torch.cuda.is_available() and not opt.cuda: 63 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 64 | 65 | nc=3 66 | 67 | transform=transforms.Compose([ 68 | transforms.ToTensor(), 69 | ]) 70 | 71 | test_loader = torch.utils.data.DataLoader( 72 | dset.CIFAR10(opt.dataroot, train=False, download=True, transform=transform), 73 | batch_size=opt.batchSize, shuffle=False) 74 | 75 | device = torch.device("cuda:0" if opt.cuda else "cpu") 76 | ngpu = int(opt.ngpu) 77 | nz = int(opt.nz) 78 | 79 | netG = Generator(ngpu).to(device) 80 | netG.load_state_dict(torch.load(opt.netG)) 81 | print(netG) 82 | netG.eval() 83 | 84 | netC = AlexNet().to(device) 85 | netC.load_state_dict(torch.load(opt.netC)) 86 | print(netC) 87 | netC.eval() 88 | 89 | netS = AlexNet_half().to(device) 90 | if opt.netS != '': 91 | netS.load_state_dict(torch.load(opt.netS)) 92 | print(netS) 93 | temp = opt.temp 94 | batch_size = int(opt.batchSize) 95 | n_batches = int(opt.nBatches) 96 | # setup optimizer 97 | threshold = [] 98 | best_val_acc = 0 99 | cnt = 0 100 | n_epochs_lr1 = 0 101 | for lr_cnt in range(2): 102 | if lr_cnt == 0: 103 | lrate = opt.lr 104 | else: 105 | lrate = opt.lr * 0.1 106 | netS.load_state_dict(torch.load('models/best_model_lr1.pth')) 107 | train_st_acc_lr2 = train_st_acc_lr1 108 | val_st_acc_lr2 = val_st_acc_lr1 109 | test_st_acc_lr2 = test_st_acc_lr1 110 | test_acc_lr2 = test_acc_lr1 111 | torch.save(netS.state_dict(), "models/best_model_lr2.pth") 112 | optimizerS = optim.Adam(netS.parameters(), lr=lrate, betas=(opt.beta1, 0.999)) 113 | for epoch in range(1, opt.niter+1): 114 | loss_kd_sum = 0 115 | loss_ce_sum = 0 116 | loss_all_sum = 0 117 | teacher_student_correct_sum = 0 118 | netS.train() 119 | for i in range(n_batches): 120 | optimizerS.zero_grad() 121 | noise_rand = torch.randn(batch_size, nz, 1, 1, device=device) 122 | fake_train = netG(noise_rand) 123 | fake_train_class = netC(fake_train) 124 | fake_student_class = netS(fake_train) 125 | fake_train_class_ht = fake_train_class/temp 126 | fake_student_class_ht = fake_student_class/temp 127 | sm_teacher_ht = F.softmax(fake_train_class_ht, dim=1) 128 | sm_student_ht = F.softmax(fake_student_class_ht, dim=1) 129 | sm_teacher = 
F.softmax(fake_train_class, dim=1) 130 | sm_student = F.softmax(fake_student_class, dim=1) 131 | pred_class_argmax_teacher = sm_teacher.max(1, keepdim=True)[1] 132 | loss_kd = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(fake_student_class_ht, dim=1),F.softmax(fake_train_class_ht, dim=1)) 133 | loss_ce = F.cross_entropy(fake_student_class, pred_class_argmax_teacher.view(batch_size)) 134 | loss_all = opt.lambda_*temp*temp*loss_kd + (1-opt.lambda_)*loss_ce 135 | loss_kd_sum = loss_kd_sum + loss_kd 136 | loss_ce_sum = loss_ce_sum + loss_ce 137 | loss_all_sum = loss_all_sum + loss_all 138 | loss_all.backward() 139 | optimizerS.step() 140 | pred_class_argmax_student = sm_student.max(1, keepdim=True)[1] 141 | pred_class_argmax_teacher = pred_class_argmax_teacher.view(sm_teacher.shape[0]) 142 | pred_class_argmax_student = pred_class_argmax_student.view(sm_teacher.shape[0]) 143 | teacher_student_correct = torch.sum(pred_class_argmax_student==pred_class_argmax_teacher) 144 | teacher_student_correct_sum = teacher_student_correct_sum + (teacher_student_correct).cpu().data.numpy() 145 | # do checkpointing 146 | torch.save(netS.state_dict(), '%s/netS_epoch_%d.pth' % (opt.outf, epoch + n_epochs_lr1)) 147 | loss_kd_val_sum = 0 148 | loss_ce_val_sum = 0 149 | loss_all_val_sum = 0 150 | teacher_student_correct_val_sum = 0 151 | netS.eval() 152 | with torch.no_grad(): 153 | for i in range(int(np.floor(n_batches/4))): 154 | noise_rand = torch.randn(batch_size, nz, 1, 1, device=device) 155 | fake_train = netG(noise_rand) 156 | fake_train_class = netC(fake_train) 157 | fake_student_class = netS(fake_train) 158 | fake_train_class_ht = fake_train_class/temp 159 | fake_student_class_ht = fake_student_class/temp 160 | sm_teacher_ht = F.softmax(fake_train_class_ht, dim=1) 161 | sm_student_ht = F.softmax(fake_student_class_ht, dim=1) 162 | sm_teacher = F.softmax(fake_train_class, dim=1) 163 | sm_student = F.softmax(fake_student_class, dim=1) 164 | pred_class_argmax_teacher = sm_teacher.max(1, keepdim=True)[1] 165 | loss_kd = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(fake_student_class_ht, dim=1),F.softmax(fake_train_class_ht, dim=1)) 166 | loss_ce = F.cross_entropy(fake_student_class, pred_class_argmax_teacher.view(batch_size)) 167 | loss_all = opt.lambda_*temp*temp*loss_kd + (1-opt.lambda_)*loss_ce 168 | loss_kd_val_sum = loss_kd_val_sum + loss_kd 169 | loss_ce_val_sum = loss_ce_val_sum + loss_ce 170 | loss_all_val_sum = loss_all_val_sum + loss_all 171 | pred_class_argmax_student = sm_student.max(1, keepdim=True)[1] 172 | pred_class_argmax_teacher = pred_class_argmax_teacher.view(sm_teacher.shape[0]) 173 | pred_class_argmax_student = pred_class_argmax_student.view(sm_teacher.shape[0]) 174 | teacher_student_correct = torch.sum(pred_class_argmax_student==pred_class_argmax_teacher) 175 | teacher_student_correct_val_sum = teacher_student_correct_val_sum + (teacher_student_correct).cpu().data.numpy() 176 | teacher_acc_sum = 0.0 177 | student_acc_sum = 0.0 178 | teacher_student_correct_test_sum = 0.0 179 | num = 0.0 180 | for data, target in test_loader: 181 | data, target = data.to(device), target.to(device) 182 | data = data*2 - 1 183 | test_class_teacher = netC(data) 184 | test_class_student = netS(data) 185 | sm_teacher_test = F.softmax(test_class_teacher, dim=1) 186 | sm_student_test = F.softmax(test_class_student, dim=1) 187 | pred_class_argmax_teacher_test = sm_teacher_test.max(1, keepdim=True)[1] 188 | pred_class_argmax_student_test = sm_student_test.max(1, keepdim=True)[1] 189 | 
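# The distillation objective optimized in the training and validation loops above,
# restated as a standalone sketch. `student_logits` and `teacher_logits` are assumed
# (batch, num_classes) tensors; temp and lambda_ mirror the --temp and --lambda_
# options of this script.
import torch
import torch.nn as nn
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, temp=20.0, lambda_=1.0):
    # Soft-target term: KL divergence between temperature-softened softmaxes,
    # scaled by temp**2 so gradient magnitudes stay comparable across temperatures.
    soft = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(student_logits / temp, dim=1),
        F.softmax(teacher_logits / temp, dim=1))
    # Hard-target term: cross-entropy against the teacher's argmax predictions
    # (no ground-truth labels are available in the data-free setting).
    hard = F.cross_entropy(student_logits, teacher_logits.argmax(dim=1))
    return lambda_ * temp * temp * soft + (1.0 - lambda_) * hard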
pred_class_argmax_teacher_test = pred_class_argmax_teacher_test.view(target.shape[0]) 190 | pred_class_argmax_student_test = pred_class_argmax_student_test.view(target.shape[0]) 191 | teacher_acc = torch.sum(pred_class_argmax_teacher_test==target) 192 | student_acc = torch.sum(pred_class_argmax_student_test==target) 193 | teacher_acc_sum = teacher_acc_sum + teacher_acc 194 | student_acc_sum = student_acc_sum + student_acc 195 | num = num + target.shape[0] 196 | teacher_student_correct = torch.sum(pred_class_argmax_student_test==pred_class_argmax_teacher_test) 197 | teacher_student_correct_test_sum = teacher_student_correct_test_sum + (teacher_student_correct).cpu().data.numpy() 198 | teacher_acc_mean = float(teacher_acc_sum) / float(num) 199 | student_acc_mean = float(student_acc_sum) / float(num) 200 | teacher_student_correct_test_mean = float(teacher_student_correct_test_sum) / float(num) 201 | val_student_acc = teacher_student_correct_val_sum / (float(np.floor(n_batches/4))*batch_size) 202 | train_student_acc = teacher_student_correct_sum/ float(n_batches*batch_size) 203 | if val_student_acc > best_val_acc: 204 | print("Saving best model...") 205 | if lr_cnt ==0 : 206 | torch.save(netS.state_dict(), "models/best_model_lr1.pth") 207 | train_st_acc_lr1 = train_student_acc 208 | val_st_acc_lr1 = val_student_acc 209 | test_st_acc_lr1 = teacher_student_correct_test_mean 210 | test_acc_lr1 = student_acc_mean 211 | else: 212 | torch.save(netS.state_dict(), "models/best_model_lr2.pth") 213 | train_st_acc_lr2 = train_student_acc 214 | val_st_acc_lr2 = val_student_acc 215 | test_st_acc_lr2 = teacher_student_correct_test_mean 216 | test_acc_lr2 = student_acc_mean 217 | best_val_acc = val_student_acc 218 | cnt = 0 219 | 220 | else: 221 | cnt += 1 222 | print("Epoch",epoch + n_epochs_lr1,"/",opt.niter) 223 | print("Teacher accuracy=",round(teacher_acc_mean*100,2),"%, Student accuracy=",round(student_acc_mean*100,2),"%") 224 | writer.add_scalar("KD loss train", loss_kd_sum/ n_batches, epoch + n_epochs_lr1) 225 | writer.add_scalar("KD loss val", loss_kd_val_sum/ float(np.floor(n_batches/4)), epoch + n_epochs_lr1) 226 | writer.add_scalar("CE loss train", loss_ce_sum/ n_batches, epoch + n_epochs_lr1) 227 | writer.add_scalar("CE loss val", loss_ce_val_sum/ float(np.floor(n_batches/4)), epoch + n_epochs_lr1) 228 | writer.add_scalar("Total loss train", loss_all_sum/ n_batches, epoch + n_epochs_lr1) 229 | writer.add_scalar("Total loss val", loss_all_val_sum/ float(np.floor(n_batches/4)), epoch + n_epochs_lr1) 230 | writer.add_scalar("Student test accuracy", student_acc_mean, epoch + n_epochs_lr1) 231 | writer.add_scalar("Teacher-Student train accuracy", train_student_acc, epoch + n_epochs_lr1) 232 | writer.add_scalar("Teacher-Student val accuracy", val_student_acc, epoch + n_epochs_lr1) 233 | writer.add_scalar("Teacher-Student test accuracy", teacher_student_correct_test_mean, epoch + n_epochs_lr1) 234 | writer.export_scalars_to_json("./all_scalars.json") 235 | if cnt > 100: 236 | print('Model has converged with learning rate = {}!'.format(lrate)) 237 | cnt = 0 238 | break 239 | if lr_cnt == 0: 240 | n_epochs_lr1 = epoch 241 | else: 242 | n_epochs_lr2 = epoch 243 | print('Number of epochs with lr = {} are {} and number of epochs with lr = {} are {}'.format( 244 | opt.lr, n_epochs_lr1, opt.lr*0.1, n_epochs_lr2)) 245 | print('Accuracy with lr = {}: Train ST accuracy = {:.2f}%, Validation ST accuracy = {:.2f}%, Test ST accuracy = {:.2f}%, Test accuracy = {:.2f}%'.format( 246 | opt.lr, train_st_acc_lr1*100, 
val_st_acc_lr1*100, test_st_acc_lr1*100, test_acc_lr1*100)) 247 | print('Accuracy with lr = {}: Train ST accuracy = {:.2f}%, Validation ST accuracy = {:.2f}%, Test ST accuracy = {:.2f}% Test accuracy = {:.2f}%'.format( 248 | opt.lr*0.1, train_st_acc_lr2*100, val_st_acc_lr2*100, test_st_acc_lr2*100, test_acc_lr2*100)) 249 | 250 | writer.close() 251 | 252 | -------------------------------------------------------------------------------- /CIFAR_10/train_teacher/train_teacher.py: -------------------------------------------------------------------------------- 1 | # Training a Classifier on CIFAR-10 using AlexNet architecture 2 | from __future__ import print_function 3 | import argparse 4 | import torch 5 | import os 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | from alexnet import AlexNet 11 | from torch.utils.data.sampler import SubsetRandomSampler 12 | from tensorboardX import SummaryWriter 13 | import numpy as np 14 | 15 | # CUDA_VISIBLE_DEVICES=0 python train_teacher.py --batch-size 64 --test-batch-size 1000 --epochs 1000 --lr 0.001 --seed 108 --log-interval 10 16 | writer = SummaryWriter() 17 | if not os.path.exists("models"): 18 | os.makedirs("models") 19 | 20 | def train(args, model, device, train_loader, optimizer, epoch): 21 | model.train() 22 | train_loss = 0 23 | correct = 0 24 | tot = 0 25 | for batch_idx, (data, target) in enumerate(train_loader): 26 | data, target = data.to(device), target.to(device) 27 | tot += data.shape[0] 28 | optimizer.zero_grad() 29 | data = data*2 - 1 30 | logits = model(data) 31 | output = F.softmax(logits,dim=1) 32 | ce_loss = nn.CrossEntropyLoss() 33 | loss = ce_loss(logits, target) 34 | loss.backward() 35 | ce_loss_redn = nn.CrossEntropyLoss(reduction = 'sum') 36 | train_loss += ce_loss_redn(logits, target).item() 37 | pred = output.max(1, keepdim=True)[1] 38 | correct += pred.eq(target.view_as(pred)).sum().item() 39 | optimizer.step() 40 | if batch_idx % args.log_interval == 0: 41 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 42 | epoch, int((batch_idx+1) * len(data)), int(len(train_loader.dataset)*0.8), 43 | 100. * (batch_idx+1) / len(train_loader), loss.item())) 44 | train_loss /= tot 45 | train_acc = 100. * correct / tot 46 | print('Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 47 | train_loss, correct, tot,train_acc)) 48 | torch.save(model.state_dict(), "models/"+str(epoch)+".pth") 49 | return train_loss, train_acc 50 | 51 | def test(args, model, device, test_loader): 52 | model.eval() 53 | test_loss = 0 54 | correct = 0 55 | with torch.no_grad(): 56 | for data, target in test_loader: 57 | data, target = data.to(device), target.to(device) 58 | data = data*2 - 1 59 | logits = model(data) 60 | output = F.softmax(logits,dim=1) 61 | ce_loss_redn = nn.CrossEntropyLoss(reduction = 'sum') 62 | test_loss += ce_loss_redn(logits, target).item() # sum up batch loss 63 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 64 | correct += pred.eq(target.view_as(pred)).sum().item() 65 | test_loss /= len(test_loader.dataset) 66 | test_acc = 100. 
* correct / len(test_loader.dataset) 67 | print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 68 | test_loss, correct, len(test_loader.dataset), test_acc)) 69 | return test_loss, test_acc 70 | 71 | def val(args, model, device, val_loader): 72 | model.eval() 73 | val_loss = 0 74 | correct = 0 75 | tot = 0 76 | with torch.no_grad(): 77 | for batch_idx, (data, target) in enumerate(val_loader): 78 | data, target = data.to(device), target.to(device) 79 | data = data*2 - 1 80 | logits = model(data) 81 | output = F.softmax(logits,dim=1) 82 | ce_loss_redn = nn.CrossEntropyLoss(reduction = 'sum') 83 | val_loss += ce_loss_redn(logits, target).item() # sum up batch loss 84 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 85 | correct += pred.eq(target.view_as(pred)).sum().item() 86 | tot += data.shape[0] 87 | val_loss /= tot 88 | val_acc = 100. * correct / tot 89 | print('Validation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format( 90 | val_loss, correct, tot, val_acc)) 91 | return val_loss, val_acc 92 | 93 | def main(): 94 | # Training settings 95 | parser = argparse.ArgumentParser(description='CIFAR Classifier training') 96 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 97 | help='input batch size for training (default: 64)') 98 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 99 | help='input batch size for testing (default: 1000)') 100 | parser.add_argument('--epochs', type=int, default=1000, metavar='N', 101 | help='number of epochs to train (default: 10)') 102 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 103 | help='learning rate (default: 0.001)') 104 | parser.add_argument('--no-cuda', action='store_true', default=False, 105 | help='disables CUDA training') 106 | parser.add_argument('--seed', type=int, default=1, metavar='S', 107 | help='random seed (default: 1)') 108 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 109 | help='how many batches to wait before logging training status') 110 | args = parser.parse_args() 111 | use_cuda = not args.no_cuda and torch.cuda.is_available() 112 | torch.manual_seed(args.seed) 113 | device = torch.device("cuda" if use_cuda else "cpu") 114 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 115 | classes = ('plane', 'car', 'bird', 'cat', 116 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 117 | tfm = transforms.Compose([ 118 | transforms.ToTensor() 119 | ]) 120 | train_dataset = datasets.CIFAR10( 121 | root='../../../datasets', train=True, 122 | download=True, transform=tfm) 123 | val_dataset = datasets.CIFAR10( 124 | root='../../../datasets', train=True, 125 | download=True, transform=tfm) 126 | num_train = len(train_dataset) 127 | indices = list(range(num_train)) 128 | split = int(np.floor(0.2 * num_train)) 129 | np.random.seed(args.seed) 130 | np.random.shuffle(indices) 131 | train_idx, val_idx = indices[split:], indices[:split] 132 | train_sampler = SubsetRandomSampler(train_idx) 133 | val_sampler = SubsetRandomSampler(val_idx) 134 | train_loader = torch.utils.data.DataLoader( 135 | train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs 136 | ) 137 | val_loader = torch.utils.data.DataLoader( 138 | val_dataset, batch_size=args.batch_size, sampler=val_sampler,**kwargs 139 | ) 140 | test_loader = torch.utils.data.DataLoader( 141 | datasets.CIFAR10('../../../datasets', train=False, download=True, 142 | transform=transforms.Compose([ 143 | 
transforms.ToTensor(), 144 | ])), 145 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 146 | model = AlexNet().to(device) 147 | #model.load_state_dict(torch.load("./best_model.pth")) 148 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 149 | best_val_acc = 0 150 | cnt = 0 151 | for epoch in range(1, args.epochs + 1): 152 | train_loss, train_acc = train(args, model, device, train_loader, optimizer, epoch) 153 | val_loss, val_acc = val(args, model, device, val_loader) 154 | test_loss, test_acc = test(args, model, device, test_loader) 155 | if val_acc > best_val_acc: 156 | print("Saving best model...") 157 | torch.save(model.state_dict(), "models/best_model_lr1.pth") 158 | best_val_acc = val_acc 159 | cnt = 0 160 | train_acc_lr1 = train_acc 161 | val_acc_lr1 = val_acc 162 | test_acc_lr1 = test_acc 163 | else: 164 | cnt += 1 165 | writer.add_scalar("1_Train loss", train_loss, epoch) 166 | writer.add_scalar("2_Validation loss", val_loss, epoch) 167 | writer.add_scalar("3_Test loss", test_loss, epoch) 168 | writer.add_scalar("4_Train accuracy", train_acc, epoch) 169 | writer.add_scalar("5_Validation accuracy", val_acc, epoch) 170 | writer.add_scalar("6_Test accuracy", test_acc, epoch) 171 | writer.export_scalars_to_json("./all_scalars.json") 172 | if cnt > 100: 173 | print('Model has converged with learning rate = {}!'.format(args.lr)) 174 | break 175 | n_epochs_lr1 = epoch 176 | optimizer = optim.Adam(model.parameters(), lr=args.lr*0.1) 177 | model.load_state_dict(torch.load("models/best_model_lr1.pth")) 178 | cnt = 0 179 | for epoch in range(1, args.epochs + 1): 180 | train_loss, train_acc = train(args, model, device, train_loader, optimizer, epoch + n_epochs_lr1) 181 | val_loss, val_acc = val(args, model, device, val_loader) 182 | test_loss, test_acc = test(args, model, device, test_loader) 183 | if val_acc > best_val_acc: 184 | print("Saving best model...") 185 | torch.save(model.state_dict(), "models/best_model_lr2.pth") 186 | best_val_acc = val_acc 187 | cnt = 0 188 | train_acc_lr2 = train_acc 189 | val_acc_lr2 = val_acc 190 | test_acc_lr2 = test_acc 191 | 192 | else: 193 | cnt += 1 194 | writer.add_scalar("1_Train loss", train_loss, epoch + n_epochs_lr1) 195 | writer.add_scalar("2_Validation loss", val_loss, epoch + n_epochs_lr1) 196 | writer.add_scalar("3_Test loss", test_loss, epoch + n_epochs_lr1) 197 | writer.add_scalar("4_Train accuracy", train_acc, epoch + n_epochs_lr1) 198 | writer.add_scalar("5_Validation accuracy", val_acc, epoch + n_epochs_lr1) 199 | writer.add_scalar("6_Test accuracy", test_acc, epoch + n_epochs_lr1) 200 | writer.export_scalars_to_json("./all_scalars.json") 201 | if cnt > 100: 202 | print('Model has converged with learning rate = {}!'.format(args.lr*0.1)) 203 | break 204 | n_epochs_lr2 = epoch 205 | print('Number of epochs with lr = {} are {} and number of epochs with lr = {} are {}'.format( 206 | args.lr, n_epochs_lr1, args.lr*0.1, n_epochs_lr2)) 207 | print('Accuracy with lr = {}: Train accuracy = {:.2f}%, Validation accuracy = {:.2f}%, Test accuracy = {:.2f}%'.format( 208 | args.lr, train_acc_lr1, val_acc_lr1, test_acc_lr1)) 209 | print('Accuracy with lr = {}: Train accuracy = {:.2f}%, Validation accuracy = {:.2f}%, Test accuracy = {:.2f}%'.format( 210 | args.lr*0.1, train_acc_lr2, val_acc_lr2, test_acc_lr2)) 211 | 212 | writer.close() 213 | 214 | if __name__ == '__main__': 215 | main() 216 | 217 | 218 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Visual Computing Lab -- IISc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Other_networks_used/inceptionv3_teacher.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BasicConv2d(nn.Module): 7 | 8 | def __init__(self, input_channels, output_channels, **kwargs): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(input_channels, output_channels, bias=False, **kwargs) 11 | self.bn = nn.BatchNorm2d(output_channels) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | x = self.bn(x) 17 | x = self.relu(x) 18 | 19 | return x 20 | 21 | #same naive inception module 22 | class InceptionA(nn.Module): 23 | 24 | def __init__(self, input_channels, pool_features): 25 | super(InceptionA, self).__init__() 26 | self.branch1x1 = BasicConv2d(input_channels, 64, kernel_size=1) 27 | 28 | self.branch5x5 = nn.Sequential( 29 | BasicConv2d(input_channels, 48, kernel_size=1), 30 | BasicConv2d(48, 64, kernel_size=5, padding=2) 31 | ) 32 | 33 | self.branch3x3 = nn.Sequential( 34 | BasicConv2d(input_channels, 64, kernel_size=1), 35 | BasicConv2d(64, 96, kernel_size=3, padding=1), 36 | BasicConv2d(96, 96, kernel_size=3, padding=1) 37 | ) 38 | 39 | self.branchpool = nn.Sequential( 40 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 41 | BasicConv2d(input_channels, pool_features, kernel_size=3, padding=1) 42 | ) 43 | 44 | def forward(self, x): 45 | 46 | #x -> 1x1(same) 47 | branch1x1 = self.branch1x1(x) 48 | 49 | #x -> 1x1 -> 5x5(same) 50 | branch5x5 = self.branch5x5(x) 51 | #branch5x5 = self.branch5x5_2(branch5x5) 52 | 53 | #x -> 1x1 -> 3x3 -> 3x3(same) 54 | branch3x3 = self.branch3x3(x) 55 | 56 | #x -> pool -> 1x1(same) 57 | branchpool = self.branchpool(x) 58 | 59 | outputs = [branch1x1, branch5x5, branch3x3, branchpool] 60 | 61 | return torch.cat(outputs, 1) 62 | 63 | #downsample 64 | #Factorization into smaller convolutions 65 | class InceptionB(nn.Module): 66 | 67 | def __init__(self, input_channels): 68 | super(InceptionB, self).__init__() 69 | 70 | self.branch3x3 = BasicConv2d(input_channels, 384, kernel_size=3, stride=2) 71 | 72 | self.branch3x3stack = nn.Sequential( 73 | BasicConv2d(input_channels, 64, 
kernel_size=1), 74 | BasicConv2d(64, 96, kernel_size=3, padding=1), 75 | BasicConv2d(96, 96, kernel_size=3, stride=2) 76 | ) 77 | 78 | self.branchpool = nn.MaxPool2d(kernel_size=3, stride=2) 79 | 80 | def forward(self, x): 81 | 82 | #x - > 3x3(downsample) 83 | branch3x3 = self.branch3x3(x) 84 | 85 | #x -> 3x3 -> 3x3(downsample) 86 | branch3x3stack = self.branch3x3stack(x) 87 | 88 | #x -> avgpool(downsample) 89 | branchpool = self.branchpool(x) 90 | outputs = [branch3x3, branch3x3stack, branchpool] 91 | 92 | return torch.cat(outputs, 1) 93 | 94 | #Factorizing Convolutions with Large Filter Size 95 | class InceptionC(nn.Module): 96 | def __init__(self, input_channels, channels_7x7): 97 | super(InceptionC, self).__init__() 98 | self.branch1x1 = BasicConv2d(input_channels, 192, kernel_size=1) 99 | 100 | c7 = channels_7x7 101 | 102 | self.branch7x7 = nn.Sequential( 103 | BasicConv2d(input_channels, c7, kernel_size=1), 104 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 105 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 106 | ) 107 | 108 | self.branch7x7stack = nn.Sequential( 109 | BasicConv2d(input_channels, c7, kernel_size=1), 110 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 111 | BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)), 112 | BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)), 113 | BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 114 | ) 115 | 116 | self.branch_pool = nn.Sequential( 117 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 118 | BasicConv2d(input_channels, 192, kernel_size=1), 119 | ) 120 | 121 | def forward(self, x): 122 | 123 | #x -> 1x1(same) 124 | branch1x1 = self.branch1x1(x) 125 | 126 | #x -> 1layer 1*7 and 7*1 (same) 127 | branch7x7 = self.branch7x7(x) 128 | 129 | #x-> 2layer 1*7 and 7*1(same) 130 | branch7x7stack = self.branch7x7stack(x) 131 | 132 | #x-> avgpool (same) 133 | branchpool = self.branch_pool(x) 134 | 135 | outputs = [branch1x1, branch7x7, branch7x7stack, branchpool] 136 | 137 | return torch.cat(outputs, 1) 138 | 139 | class InceptionD(nn.Module): 140 | 141 | def __init__(self, input_channels): 142 | super(InceptionD, self).__init__() 143 | 144 | self.branch3x3 = nn.Sequential( 145 | BasicConv2d(input_channels, 192, kernel_size=1), 146 | BasicConv2d(192, 320, kernel_size=3, stride=2) 147 | ) 148 | 149 | self.branch7x7 = nn.Sequential( 150 | BasicConv2d(input_channels, 192, kernel_size=1), 151 | BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)), 152 | BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)), 153 | BasicConv2d(192, 192, kernel_size=3, stride=2) 154 | ) 155 | 156 | self.branchpool = nn.AvgPool2d(kernel_size=3, stride=2) 157 | 158 | def forward(self, x): 159 | 160 | #x -> 1x1 -> 3x3(downsample) 161 | branch3x3 = self.branch3x3(x) 162 | 163 | #x -> 1x1 -> 1x7 -> 7x1 -> 3x3 (downsample) 164 | branch7x7 = self.branch7x7(x) 165 | 166 | #x -> avgpool (downsample) 167 | branchpool = self.branchpool(x) 168 | 169 | outputs = [branch3x3, branch7x7, branchpool] 170 | 171 | return torch.cat(outputs, 1) 172 | 173 | 174 | #same 175 | class InceptionE(nn.Module): 176 | def __init__(self, input_channels): 177 | super(InceptionE, self).__init__() 178 | self.branch1x1 = BasicConv2d(input_channels, 320, kernel_size=1) 179 | 180 | self.branch3x3_1 = BasicConv2d(input_channels, 384, kernel_size=1) 181 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 182 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 183 | 184 | 
self.branch3x3stack_1 = BasicConv2d(input_channels, 448, kernel_size=1) 185 | self.branch3x3stack_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 186 | self.branch3x3stack_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 187 | self.branch3x3stack_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 188 | 189 | self.branch_pool = nn.Sequential( 190 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 191 | BasicConv2d(input_channels, 192, kernel_size=1) 192 | ) 193 | 194 | def forward(self, x): 195 | 196 | #x -> 1x1 (same) 197 | branch1x1 = self.branch1x1(x) 198 | 199 | # x -> 1x1 -> 3x1 200 | # x -> 1x1 -> 1x3 201 | # concatenate(3x1, 1x3) 202 | 203 | branch3x3 = self.branch3x3_1(x) 204 | branch3x3 = [ 205 | self.branch3x3_2a(branch3x3), 206 | self.branch3x3_2b(branch3x3) 207 | ] 208 | branch3x3 = torch.cat(branch3x3, 1) 209 | 210 | # x -> 1x1 -> 3x3 -> 1x3 211 | # x -> 1x1 -> 3x3 -> 3x1 212 | #concatenate(1x3, 3x1) 213 | branch3x3stack = self.branch3x3stack_1(x) 214 | branch3x3stack = self.branch3x3stack_2(branch3x3stack) 215 | branch3x3stack = [ 216 | self.branch3x3stack_3a(branch3x3stack), 217 | self.branch3x3stack_3b(branch3x3stack) 218 | ] 219 | branch3x3stack = torch.cat(branch3x3stack, 1) 220 | 221 | branchpool = self.branch_pool(x) 222 | 223 | outputs = [branch1x1, branch3x3, branch3x3stack, branchpool] 224 | 225 | return torch.cat(outputs, 1) 226 | 227 | class InceptionV3(nn.Module): 228 | 229 | def __init__(self, num_classes=100): 230 | super(InceptionV3, self).__init__() 231 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, padding=1) 232 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1) 233 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 234 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 235 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 236 | 237 | #naive inception module 238 | self.Mixed_5b = InceptionA(192, pool_features=32) 239 | self.Mixed_5c = InceptionA(256, pool_features=64) 240 | self.Mixed_5d = InceptionA(288, pool_features=64) 241 | 242 | #downsample 243 | self.Mixed_6a = InceptionB(288) 244 | 245 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 246 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 247 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 248 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 249 | 250 | #downsample 251 | self.Mixed_7a = InceptionD(768) 252 | 253 | self.Mixed_7b = InceptionE(1280) 254 | self.Mixed_7c = InceptionE(2048) 255 | 256 | #6*6 feature size 257 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 258 | self.dropout = nn.Dropout2d() 259 | self.linear = nn.Linear(2048, num_classes) 260 | 261 | def forward(self, x): 262 | 263 | #32 -> 30 264 | x = self.Conv2d_1a_3x3(x) 265 | x = self.Conv2d_2a_3x3(x) 266 | x = self.Conv2d_2b_3x3(x) 267 | x = self.Conv2d_3b_1x1(x) 268 | x = self.Conv2d_4a_3x3(x) 269 | 270 | #30 -> 30 271 | x = self.Mixed_5b(x) 272 | x = self.Mixed_5c(x) 273 | x = self.Mixed_5d(x) 274 | 275 | #30 -> 14 276 | #Efficient Grid Size Reduction to avoid representation 277 | #bottleneck 278 | x = self.Mixed_6a(x) 279 | 280 | #14 -> 14 281 | 282 | x = self.Mixed_6b(x) 283 | x = self.Mixed_6c(x) 284 | x = self.Mixed_6d(x) 285 | x = self.Mixed_6e(x) 286 | 287 | #14 -> 6 288 | #Efficient Grid Size Reduction 289 | x = self.Mixed_7a(x) 290 | 291 | #6 -> 6 292 | 293 | x = self.Mixed_7b(x) 294 | x = self.Mixed_7c(x) 295 | 296 | #6 -> 1 297 | x = self.avgpool(x) 298 | x = self.dropout(x) 299 | x = x.view(x.size(0), -1) 300 | x = 
self.linear(x) 301 | return x 302 | 303 | 304 | def inceptionv3(): 305 | return InceptionV3() 306 | 307 | 308 | 309 | -------------------------------------------------------------------------------- /Other_networks_used/lenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LeNet(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(LeNet, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(1, 6, 5, stride=1, padding=0) 9 | self.conv1.bias.data.normal_(0, 0.1) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | 14 | self.pad = nn.MaxPool2d(2, stride=2) 15 | 16 | self.conv2 = nn.Conv2d(6, 16, 5, stride=1, padding=0) 17 | self.conv2.bias.data.normal_(0, 0.1) 18 | self.conv2.bias.data.fill_(0) 19 | 20 | self.fc1 = nn.Linear(400,120) 21 | self.fc1.bias.data.normal_(0, 0.1) 22 | self.fc1.bias.data.fill_(0) 23 | 24 | self.fc2 = nn.Linear(120,84) 25 | self.fc2.bias.data.normal_(0, 0.1) 26 | self.fc2.bias.data.fill_(0) 27 | 28 | self.fc3 = nn.Linear(84,num_classes) 29 | self.fc3.bias.data.normal_(0, 0.1) 30 | self.fc3.bias.data.fill_(0) 31 | 32 | self.soft = nn.Softmax() 33 | 34 | def forward(self, x): 35 | layer1 = self.pad(self.relu(self.conv1(x))) 36 | layer2 = self.pad(self.relu(self.conv2(layer1))) 37 | 38 | flatten = layer2.view(-1, 16*5*5) 39 | fully1 = self.relu(self.fc1(flatten)) 40 | 41 | fully2 = self.relu(self.fc2(fully1)) 42 | 43 | logits = self.fc3(fully2) 44 | #softmax_val = self.soft(logits) 45 | 46 | return logits 47 | 48 | model = LeNet(num_classes=10) 49 | 50 | -------------------------------------------------------------------------------- /Other_networks_used/lenet_half.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LeNet_half(nn.Module): 5 | def __init__(self, num_classes=10): 6 | super(LeNet_half, self).__init__() 7 | 8 | self.conv1 = nn.Conv2d(1, 3, 5, stride=1, padding=0) 9 | self.conv1.bias.data.normal_(0, 0.1) 10 | self.conv1.bias.data.fill_(0) 11 | 12 | self.relu = nn.ReLU() 13 | 14 | self.pad = nn.MaxPool2d(2, stride=2) 15 | 16 | self.conv2 = nn.Conv2d(3, 8, 5, stride=1, padding=0) 17 | self.conv2.bias.data.normal_(0, 0.1) 18 | self.conv2.bias.data.fill_(0) 19 | 20 | self.fc1 = nn.Linear(200,120) 21 | self.fc1.bias.data.normal_(0, 0.1) 22 | self.fc1.bias.data.fill_(0) 23 | 24 | self.fc2 = nn.Linear(120,84) 25 | self.fc2.bias.data.normal_(0, 0.1) 26 | self.fc2.bias.data.fill_(0) 27 | 28 | self.fc3 = nn.Linear(84,num_classes) 29 | self.fc3.bias.data.normal_(0, 0.1) 30 | self.fc3.bias.data.fill_(0) 31 | 32 | self.soft = nn.Softmax() 33 | 34 | def forward(self, x): 35 | layer1 = self.pad(self.relu(self.conv1(x))) 36 | layer2 = self.pad(self.relu(self.conv2(layer1))) 37 | 38 | flatten = layer2.view(-1, 8*5*5) 39 | fully1 = self.relu(self.fc1(flatten)) 40 | 41 | fully2 = self.relu(self.fc2(fully1)) 42 | 43 | logits = self.fc3(fully2) 44 | #softmax_val = self.soft(logits) 45 | 46 | return logits 47 | 48 | model = LeNet_half(num_classes=10) 49 | 50 | -------------------------------------------------------------------------------- /Other_networks_used/resnet_18_student.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | class BasicBlock(nn.Module): 6 | """Basic Block for resnet 18 and resnet 34 7 | 8 | """ 9 | expansion = 1 10 | 11 | def __init__(self, in_channels, out_channels, 
stride=1): 12 | super(BasicBlock, self).__init__() 13 | 14 | #residual function 15 | self.residual_function = nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 21 | ) 22 | 23 | #shortcut 24 | self.shortcut = nn.Sequential() 25 | 26 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 30 | ) 31 | 32 | def forward(self, x): 33 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 34 | 35 | class BottleNeck(nn.Module): 36 | """Residual block for resnet over 50 layers 37 | 38 | """ 39 | expansion = 4 40 | def __init__(self, in_channels, out_channels, stride=1): 41 | super(BottleNeck, self).__init__() 42 | self.residual_function = nn.Sequential( 43 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 44 | nn.BatchNorm2d(out_channels), 45 | nn.ReLU(inplace=True), 46 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 47 | nn.BatchNorm2d(out_channels), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 50 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 51 | ) 52 | 53 | self.shortcut = nn.Sequential() 54 | 55 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 58 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 59 | ) 60 | 61 | def forward(self, x): 62 | return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x)) 63 | 64 | class ResNet(nn.Module): 65 | 66 | def __init__(self, block, num_block, num_classes=100): 67 | super(ResNet, self).__init__() 68 | 69 | self.in_channels = 64 70 | 71 | self.conv1 = nn.Sequential( 72 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 73 | nn.BatchNorm2d(64), 74 | nn.ReLU(inplace=True)) 75 | 76 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 77 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 78 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 79 | self.conv5_x = self._make_layer(block, 512, num_block[3], 2) 80 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 81 | self.fc = nn.Linear(512 * block.expansion, num_classes) 82 | 83 | def _make_layer(self, block, out_channels, num_blocks, stride): 84 | 85 | strides = [stride] + [1] * (num_blocks - 1) 86 | layers = [] 87 | for stride in strides: 88 | layers.append(block(self.in_channels, out_channels, stride)) 89 | self.in_channels = out_channels * block.expansion 90 | 91 | return nn.Sequential(*layers) 92 | 93 | def forward(self, x): 94 | output = self.conv1(x) 95 | output = self.conv2_x(output) 96 | output = self.conv3_x(output) 97 | output = self.conv4_x(output) 98 | output = self.conv5_x(output) 99 | output = self.avg_pool(output) 100 | output = output.view(output.size(0), -1) 101 | output = self.fc(output) 102 | 103 | return output 104 | 105 | def resnet18(): 106 | """ return a ResNet 18 object 107 | """ 108 | return ResNet(BasicBlock, [2, 
2, 2, 2])
109 | 
110 | def resnet34():
111 |     """ return a ResNet 34 object
112 |     """
113 |     return ResNet(BasicBlock, [3, 4, 6, 3])
114 | 
115 | def resnet50():
116 |     """ return a ResNet 50 object
117 |     """
118 |     return ResNet(BottleNeck, [3, 4, 6, 3])
119 | 
120 | def resnet101():
121 |     """ return a ResNet 101 object
122 |     """
123 |     return ResNet(BottleNeck, [3, 4, 23, 3])
124 | 
125 | def resnet152():
126 |     """ return a ResNet 152 object
127 |     """
128 |     return ResNet(BottleNeck, [3, 8, 36, 3])
129 | 
130 | 
131 | 
132 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeGAN
2 | Data-enriching GAN for retrieving Representative Samples from a Trained Classifier
3 | 
--------------------------------------------------------------------------------
/Readme.txt:
--------------------------------------------------------------------------------
1 | Readme:
2 | 
3 | Code with CIFAR-10 as the True Data and CIFAR-100 (select classes) as the Proxy Data is uploaded.
4 | Code for the other cases is very similar, with minor changes.
5 | The command to run each script is mentioned in the respective file. The path to which the datasets are downloaded can be modified in the command.
6 | 
7 | Folder organization:
8 | 
9 | > CIFAR_10:
10 |   Note: Before running any of the scripts, all files from the 'network' folder should be copied into the respective folder.
11 |   For example, to run scripts from the 'train_generator' folder, all files from the 'network' folder should be copied to the root of the 'train_generator' folder.
12 |   > network: Contains the network architecture definitions for the Teacher, Student and GAN models used for CIFAR-10, along with the saved weights of the trained Teacher network.
13 |     > alexnet.py: Teacher architecture with CIFAR-10 as True Data
14 |     > alexnet_half.py: Student architecture with CIFAR-10 as True Data
15 |     > dcgan_model.py: Architecture of the generator and discriminator used for DCGAN and DeGAN
16 |     > best_model.pth: Saved Teacher weights used for all CIFAR-10 experiments
17 |   > train_teacher:
18 |     > train_teacher.py: Code for training the Teacher network
19 |   > train_generator:
20 |     > dfgan.py: Code for training DCGAN/DeGAN. Setting 'c_l' and 'd_l' to 0 trains a DCGAN.
21 |       The classes in inc_classes can be modified to select the required CIFAR-100 classes for the Proxy Dataset.
22 |   > train_student:
23 |     > Using_Data
24 |       > KD_related_data.py: Code for Knowledge Distillation using related/unrelated data
25 |     > Using_GAN
26 |       The script 'dfgan.py' in the train_generator folder needs to be run before this.
27 |       > KD_dfgan.py: Code for Knowledge Distillation using DCGAN/DeGAN
28 | 
29 | > Other_networks_used:
30 |   > lenet.py: Teacher architecture with Fashion-MNIST as True Data
31 |   > lenet_half.py: Student architecture with Fashion-MNIST as True Data
32 |   > inceptionv3_teacher.py: Teacher architecture with CIFAR-100 as True Data
33 |   > resnet_18_student.py: Student architecture with CIFAR-100 as True Data
--------------------------------------------------------------------------------
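
Illustrative usage (not part of the repository): the Readme above describes a Teacher-Student Knowledge Distillation pipeline built around the AlexNet Teacher (best_model.pth) and the AlexNet_half Student. The sketch below shows what one distillation step with temperature-softened soft targets could look like. It assumes alexnet.py, alexnet_half.py and best_model.pth have been copied next to it, as the Readme instructs; the temperature value and the exact loss formulation are placeholders, and the actual losses in KD_dfgan.py and KD_related_data.py may differ.

import torch
import torch.nn.functional as F
from alexnet import AlexNet              # Teacher definition (CIFAR_10/network/alexnet.py)
from alexnet_half import AlexNet_half    # Student definition (CIFAR_10/network/alexnet_half.py)

def distillation_loss(student_logits, teacher_logits, T=20.0):
    # Soft-target distillation loss: KL divergence between temperature-softened
    # Teacher and Student distributions, scaled by T^2. The value of T is a placeholder.
    log_p_student = F.log_softmax(student_logits / T, dim=1)
    p_teacher = F.softmax(teacher_logits / T, dim=1)
    return F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (T * T)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

teacher = AlexNet(num_classes=10).to(device)
teacher.load_state_dict(torch.load('best_model.pth', map_location=device))
teacher.eval()                           # the Teacher stays frozen; only the Student is trained

student = AlexNet_half(num_classes=10).to(device)
student.train()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

# One illustrative step on a stand-in batch of 32x32 images scaled to [-1, 1],
# the input range the training scripts use (data*2 - 1 applied to [0, 1] tensors).
images = torch.rand(64, 3, 32, 32, device=device) * 2 - 1
with torch.no_grad():
    teacher_logits = teacher(images)
student_logits = student(images)
loss = distillation_loss(student_logits, teacher_logits)
optimizer.zero_grad()
loss.backward()
optimizer.step()

In the GAN-based setting, the input images would instead be sampled from the trained DCGAN/DeGAN generator produced by dfgan.py, which is why that script must be run before the Using_GAN distillation code.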