├── .gitignore
├── LICENSE
├── README.md
├── evaluate.py
├── fast_dense_feature_extractor.py
├── models.py
├── mvtec_dataset.py
├── res.jpg
├── student_train.py
└── teacher_train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
work_dir
.vscode

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 LuyaooChen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Uninformed Students
![result](https://raw.githubusercontent.com/LuyaooChen/uninformed-students-pytorch/main/res.jpg)
## Introduction
A simple and incomplete PyTorch implementation of the paper:
Bergmann et al. (MVTec), [Uninformed Students: Student–Teacher Anomaly Detection with Discriminative Latent Embeddings.](https://ieeexplore.ieee.org/document/9157778/) CVPR, 2020.
[arXiv:1911.02357](https://arxiv.org/abs/1911.02357)

Another implementation repo: https://github.com/denguir/student-teacher-anomaly-detection

This reimplementation was written mainly for my own learning. It may contain all sorts of problems, and it is no longer maintained!

## Requirements
python3
pytorch~=1.3
torchvision
numpy
opencv-python

## Usage
### Prepare datasets
imagenet (or any large image dataset), used to pretrain the teacher
MVTec_AD, laid out as shown below
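The loaders expect the standard MVTec AD directory structure; the paths are illustrative (`grid` and `leather` are two of the dataset's categories):

```
MVTec_AD/data/<category>/          # e.g. grid/, leather/
├── train/good/                    # anomaly-free training images
├── test/good/                     # anomaly-free test images
├── test/<defect_type>/            # defective test images
└── ground_truth/<defect_type>/    # masks named <image>_mask.png
```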
### Train a teacher network
Choose a `patch_size` (17, 33 or 65) in the script, then run
`python teacher_train.py`
### Train a student network
Choose a `patch_size` (a teacher net with that patch size must already be trained) and set `st_id`, then run
`python student_train.py`
### Evaluate
`python evaluate.py`
Result images are written to the `tmp/` directory (create it first); `res.jpg` above is an example output.

## TODO
metric learning in teacher_train.py (descriptor compactness is already implemented)
complete evaluate.py
...

## Reference
https://github.com/erezposner/Fast_Dense_Feature_Extraction
https://github.com/denguir/student-teacher-anomaly-detection

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
"""
Evaluation on the MVTec AD dataset.
Adapted from https://github.com/denguir/student-teacher-anomaly-detection.

Author: Luyao Chen
Date: 2020.10
"""

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from PIL import Image
from models import _Teacher, TeacherOrStudent
from mvtec_dataset import MVTec_AD


def error(student_outputs, teacher_output):
    # Per-pixel regression error between the students' mean prediction
    # and the (normalized) teacher descriptor.
    # student_outputs: (batch, num_students, imH, imW, d)
    # teacher_output:  (batch, imH, imW, d)
    # returns:         (batch, imH, imW)
    s_mean = torch.mean(student_outputs, dim=1)
    return torch.norm(s_mean - teacher_output, dim=3)


def variance(student_outputs):
    # Per-pixel predictive variance of the student ensemble, computed as
    # E[||s||^2] - ||E[s]||^2 with the expectation taken over the students.
    # student_outputs: (batch, num_students, imH, imW, d)
    # returns:         (batch, imH, imW)
    sse = torch.sum(student_outputs ** 2, dim=4)
    msse = torch.mean(sse, dim=1)
    s_mean = torch.mean(student_outputs, dim=1)
    var = msse - torch.sum(s_mean ** 2, dim=3)
    return var


def increment_mean_and_var(mu_N, var_N, N, batch):
    '''Incrementally update a running channel-wise mean and (biased)
    variance with a new batch, without storing earlier batches.
    '''
    # batch: (batch, h, w, vector)
    B = batch.size()[0]  # batch size
    # we want a descriptor vector -> mean over batch and pixels
    mu_B = torch.mean(batch, dim=[0, 1, 2])
    S_B = B * torch.var(batch, dim=[0, 1, 2], unbiased=False)
    S_N = N * var_N
    mu_NB = N / (N + B) * mu_N + B / (N + B) * mu_B
    S_NB = S_N + S_B + B * mu_B**2 + N * mu_N**2 - (N + B) * mu_NB**2
    var_NB = S_NB / (N + B)
    return mu_NB, var_NB, N + B
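
# NOTE: a quick self-check of the incremental update above (a hypothetical
# helper, not used elsewhere in the repo). Folding in two batches one at a
# time should match the statistics of the concatenated data:
def _sanity_check_increment():
    a, b = torch.randn(4, 8, 8, 128), torch.randn(6, 8, 8, 128)
    mu, var, n = increment_mean_and_var(0, 0, 0, a)
    mu, var, n = increment_mean_and_var(mu, var, n, b)
    ab = torch.cat([a, b])
    assert torch.allclose(mu, torch.mean(ab, dim=[0, 1, 2]), atol=1e-4)
    assert torch.allclose(var, torch.var(ab, dim=[0, 1, 2], unbiased=False),
                          atol=1e-4)
    assert n == 10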


if __name__ == "__main__":
    patch_sizes = [33]  # add more sizes for multi-scale segmentation
    num_students = 3  # number of students per teacher
    imH = 256  # image height and width should be multiples of sL1*sL2*sL3...
    imW = 256
    batch_size = 1
    work_dir = 'work_dir/'
    class_dir = 'grid/'
    train_dataset_dir = '/home/cly/data_disk/MVTec_AD/data/' + class_dir + 'train/'
    test_dataset_dir = '/home/cly/data_disk/MVTec_AD/data/' + class_dir
    device = torch.device('cuda:1')

    N_scale = len(patch_sizes)

    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]

    trans = transforms.Compose([
        # transforms.RandomCrop((imH, imW)),
        transforms.Resize((imH, imW)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    mask_trans = transforms.Compose([
        # transforms.RandomCrop((imH, imW)),
        transforms.Resize((imH, imW), Image.NEAREST),
        transforms.ToTensor(),
    ])
    anomaly_free_dataset = datasets.ImageFolder(
        train_dataset_dir, transform=trans)
    af_dataloader = DataLoader(anomaly_free_dataset, batch_size=batch_size)
    test_dataset = MVTec_AD(test_dataset_dir, transform=trans,
                            mask_transform=mask_trans, phase='test')
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    teachers = []
    students = []
    for patch_size in patch_sizes:
        _teacher = _Teacher(patch_size)
        checkpoint = torch.load(work_dir + '_teacher' +
                                str(patch_size) + '.pth', torch.device('cpu'))
        _teacher.load_state_dict(checkpoint)
        teacher = TeacherOrStudent(patch_size, _teacher, imH, imW).to(device)
        teacher.eval()
        teachers.append(teacher)

        s_t = []
        for i in range(num_students):
            # see issue #2: each student needs a freshly constructed _Teacher.
            _teacher = _Teacher(patch_size)
            student = TeacherOrStudent(
                patch_size, _teacher, imH, imW).to(device)
            checkpoint = torch.load(work_dir + class_dir + 'student' +
                                    str(patch_size) + '_' + str(i) +
                                    '.pth', torch.device('cpu'))
            student.load_state_dict(checkpoint)
            student.eval()
            s_t.append(student)
        students.append(s_t)

    with torch.no_grad():
        t_mu, t_var, t_N = [0 for i in range(N_scale)], [0 for i in range(N_scale)], [
            0 for i in range(N_scale)]
        print('Calibrating teacher on train dataset.')
        for data, _ in tqdm(af_dataloader):
            data = data.to(device)
            for i in range(N_scale):
                t_out = teachers[i](data)
                t_mu[i], t_var[i], t_N[i] = increment_mean_and_var(
                    t_mu[i], t_var[i], t_N[i], t_out)

        # mu_err, var_err = torch.tensor([4.7920308113098145]), torch.tensor([3.410670280456543])
        # mu_var, var_var = torch.tensor([4.074430465698242]), torch.tensor([1.5367100238800049])

        max_err, max_var = [0 for i in range(N_scale)], [0 for i in range(N_scale)]
        mu_err, var_err, N_err = [0 for i in range(N_scale)], [0 for i in range(N_scale)], [0 for i in range(N_scale)]
        mu_var, var_var, N_var = [0 for i in range(N_scale)], [0 for i in range(N_scale)], [0 for i in range(N_scale)]
        print('Calibrating scoring parameters on train dataset.')
        for data, _ in tqdm(af_dataloader):
            data = data.to(device)
            for i in range(N_scale):
                teacher_output = (teachers[i](
                    data) - t_mu[i]) / torch.sqrt(t_var[i])
                student_outputs = []
                for j in range(num_students):
                    student_outputs.append(students[i][j](data))
                student_outputs = torch.stack(student_outputs, dim=1)
                e = error(student_outputs, teacher_output)
                v = variance(student_outputs)
                mu_err[i], var_err[i], N_err[i] = increment_mean_and_var(
                    mu_err[i], var_err[i], N_err[i], e)
                mu_var[i], var_var[i], N_var[i] = increment_mean_and_var(
                    mu_var[i], var_var[i], N_var[i], v)

                max_err[i] = max(max_err[i], torch.max(e))
                max_var[i] = max(max_var[i], torch.max(v))

        # max_score = 29.9642391204834
        max_score = 0
        for i in range(N_scale):
            print('mu_err:{}, var_err:{}, mu_var:{}, var_var:{}'.format(
                mu_err[i], var_err[i], mu_var[i], var_var[i]
            ))
            max_score += (max_err[i] - mu_err[i]) / torch.sqrt(var_err[i]) + \
                (max_var[i] - mu_var[i]) / torch.sqrt(var_var[i])
        max_score /= N_scale
        print('max_score:{}'.format(max_score))

        score_map_list = []
        gt_mask_list = []
        img_id = 0
        for data, gt_mask, _ in tqdm(test_dataloader):
            plt_list = []
            ori_imgs = data
            data = data.to(device)
            gt_mask_list.append(gt_mask.data.numpy())
            anomaly_score = 0
            for i in range(N_scale):
                teacher_output = (teachers[i](
                    data) - t_mu[i]) / torch.sqrt(t_var[i])
                plt_list.append(teacher_output)
                student_outputs = []
                for j in range(num_students):
                    student_outputs.append(students[i][j](data))
                    plt_list.append(student_outputs[j])
                student_outputs = torch.stack(student_outputs, dim=1)
                e = error(student_outputs, teacher_output)
                v = variance(student_outputs)
                anomaly_score += (e - mu_err[i]) / torch.sqrt(var_err[i]) + \
                    (v - mu_var[i]) / torch.sqrt(var_var[i])

            anomaly_score /= N_scale
            score_map_list.append(anomaly_score.cpu().detach().numpy())
            # print('max:{:.2f},min:{:.2f},avg:{:.2f}'.format(torch.max(anomaly_score),
            #                                                 torch.min(anomaly_score),
            #                                                 torch.mean(anomaly_score)))

            # plt.figure()
            # plt.subplot(2, 2, 1)
            # plt_img = plt_list[1].cpu().detach().numpy()[0]
            # plt_img = np.mean(plt_img, axis=2)
            # plt_img = np.expand_dims(plt_img, 2)
            # plt.imshow(plt_img, cmap='jet')
            # plt.colorbar()
            # plt.subplot(2, 2, 2)
            # plt_img = plt_list[2].cpu().detach().numpy()[0]
            # plt_img = np.mean(plt_img, axis=2)
            # plt_img = np.expand_dims(plt_img, 2)
            # plt.imshow(plt_img, cmap='jet')
            # plt.colorbar()
            # plt.subplot(2, 2, 3)
            # plt_img = plt_list[3].cpu().detach().numpy()[0]
            # plt_img = np.mean(plt_img, axis=2)
            # plt_img = np.expand_dims(plt_img, 2)
            # plt.imshow(plt_img, cmap='jet')
            # plt.colorbar()
            # plt.subplot(2, 2, 4)
            # plt_img = plt_list[0].cpu().detach().numpy()[0]
            # plt_img = np.mean(plt_img, axis=2)
            # plt_img = np.expand_dims(plt_img, 2)
            # plt.imshow(plt_img, cmap='jet')
            # plt.colorbar()
            # plt.savefig('cmp.png')
            # plt.close()

            # px = 118
            # py = 132
            # plt.figure(figsize=(6, 3))
            # plt_vec = plt_list[1].cpu().detach().numpy()[0, px, py]
            # plt_vec -= plt_list[0].cpu().detach().numpy()[0, px, py]
            # plt.plot(plt_vec, label='s1')
            # plt_vec = plt_list[2].cpu().detach().numpy()[0, px, py]
            # plt.plot(plt_vec, label='s2')
            # plt_vec = plt_list[3].cpu().detach().numpy()[0, px, py]
            # plt.plot(plt_vec, label='s3')
            # plt_vec = plt_list[0].cpu().detach().numpy()[0, px, py]
            # plt.plot(plt_vec, label='t')
            # plt.legend()
            # plt.savefig('vec.png')
            # plt.close()

            anomaly_score -= torch.min(anomaly_score)
            # anomaly_score /= torch.max(anomaly_score)
            anomaly_score /= max_score
            # anomaly_score /= 30
            score_map = anomaly_score.cpu().detach().numpy()[0, :, :]
            score_map = np.minimum(score_map, 1)
            score_map = cv2.applyColorMap(
                np.uint8(score_map * 255), cv2.COLORMAP_JET)
            # cv2.imwrite('score.jpg', score_map)
            ori_img = ori_imgs.permute(0, 2, 3, 1).detach().numpy()[0, :, :, :]
            for c in range(3):
                ori_img[:, :, c] = ori_img[:, :, c] * std[c] + mean[c]
            ori_img = cv2.cvtColor(ori_img, cv2.COLOR_RGB2BGR)
            # cv2.imwrite('ori.jpg', np.uint8(ori_img * 255))
            save_img = np.concatenate(
                (np.uint8(ori_img * 255), score_map), axis=1)
            # cv2.imwrite('res.jpg', save_img)
            # note: the tmp/ directory must already exist.
            cv2.imwrite('tmp/' + str(img_id) + '.jpg', save_img)
            img_id += 1

        flatten_gt_mask_list = np.concatenate(gt_mask_list).ravel()
        flatten_score_map_list = np.concatenate(score_map_list).ravel()
        per_pixel_rocauc = roc_auc_score(
            flatten_gt_mask_list, flatten_score_map_list)
        print('pixel ROCAUC:{}'.format(per_pixel_rocauc))

--------------------------------------------------------------------------------
/fast_dense_feature_extractor.py:
--------------------------------------------------------------------------------
"""
Implementation of this paper:
Christian Bailer, Tewodros A. Habtegebrial, Kiran Varanasi, and
Didier Stricker. Fast Dense Feature Extraction with CNNs that have
Pooling or Striding Layers. In British Machine Vision Conference
(BMVC), 2017.

Adapted from: https://github.com/erezposner/Fast_Dense_Feature_Extraction
"""

from torch import nn
import torch
import numpy as np
import torch.nn.functional as F


# (N, C, H, W)


class multiPoolPrepare(nn.Module):
    def __init__(self, patchY, patchX):
        super(multiPoolPrepare, self).__init__()
        pady = patchY - 1
        padx = patchX - 1

        self.pad_top = np.ceil(pady / 2).astype(int)
        self.pad_bottom = np.floor(pady / 2).astype(int)
        self.pad_left = np.ceil(padx / 2).astype(int)
        self.pad_right = np.floor(padx / 2).astype(int)

    def forward(self, x):
        # pad so that every input pixel becomes the center of a full patch
        y = F.pad(x, [self.pad_left, self.pad_right,
                      self.pad_top, self.pad_bottom], mode='reflect')
        # value=0)
        return y


class unwrapPrepare(nn.Module):
    def __init__(self):
        super(unwrapPrepare, self).__init__()

    def forward(self, x):
        x_ = F.pad(x, [0, -1, 0, -1], value=0)  # crop the last row and column
        y = x_.contiguous().view(x_.shape[0], -1)
        y = y.transpose(0, 1)
        return y.contiguous()


class unwrapPool(nn.Module):
    def __init__(self, outChans, curImgW, curImgH, dW, dH):
        super(unwrapPool, self).__init__()
        self.outChans = int(outChans)
        self.curImgW = int(curImgW)
        self.curImgH = int(curImgH)
        self.dW = int(dW)
        self.dH = int(dH)

    def forward(self, x):
        # regroup the dW*dH shifted output grids back into one dense grid
        y = x.view((self.outChans, self.curImgW,
                    self.curImgH, self.dH, self.dW, -1))
        y = y.transpose(2, 3)

        return y.contiguous()
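
# NOTE: multiMaxPooling below runs the same strided max-pool at every offset
# of the pooling grid and stacks the results along the batch dimension, so no
# input position is lost to striding. A minimal 1-D sketch of the idea
# (a hypothetical demo, not used by the repo):
def _dense_pool_demo():
    x = torch.tensor([[[1., 3., 2., 5., 4., 6.]]])  # (N, C, W)
    grid0 = F.max_pool1d(x, kernel_size=2, stride=2)              # [3., 5., 6.]
    grid1 = F.max_pool1d(x[:, :, 1:-1], kernel_size=2, stride=2)  # [3., 5.]
    # Interleaving grid0 and grid1 gives the pooled response at every input
    # position; unwrapPrepare/unwrapPool perform this regrouping for the
    # 2-D, multi-stage case.
    return grid0, grid1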


class multiMaxPooling(nn.Module):
    def __init__(self, kW, kH, dW, dH):
        super(multiMaxPooling, self).__init__()
        layers = []
        self.padd = []
        for i in range(0, dH):
            for j in range(0, dW):
                # negative padding crops j columns / i rows, which shifts
                # the pooling grid to the next offset
                self.padd.append((-j, -i))
                layers.append(nn.MaxPool2d(
                    kernel_size=(kW, kH), stride=(dW, dH)))
        self.max_layers = nn.ModuleList(layers)
        self.s = dH

    def forward(self, x):

        hh = []
        ww = []
        res = []

        for i in range(0, len(self.max_layers)):
            pad_left, pad_top = self.padd[i]
            _x = F.pad(x, [pad_left, pad_left, pad_top, pad_top], value=0)
            _x = self.max_layers[i](_x)
            h, w = _x.size()[2], _x.size()[3]
            hh.append(h)
            ww.append(w)
            res.append(_x)
        # pad all shifted outputs to a common size, then stack along batch
        max_h, max_w = np.max(hh), np.max(ww)
        for i in range(0, len(self.max_layers)):
            _x = res[i]
            h, w = _x.size()[2], _x.size()[3]
            pad_top = np.floor((max_h - h) / 2).astype(int)
            pad_bottom = np.ceil((max_h - h) / 2).astype(int)
            pad_left = np.floor((max_w - w) / 2).astype(int)
            pad_right = np.ceil((max_w - w) / 2).astype(int)
            _x = F.pad(_x, [pad_left, pad_right, pad_top, pad_bottom], value=0)
            res[i] = _x
        return torch.cat(res, 0)


class multiConv(nn.Module):
    def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW, dH):
        super(multiConv, self).__init__()
        layers = []
        self.padd = []
        for i in range(0, dH):
            for j in range(0, dW):
                self.padd.append((-j, -i))
                # fixed seed so that all shifted copies share identical weights
                torch.manual_seed(10)
                torch.cuda.manual_seed(10)
                a = nn.Conv2d(nInputPlane, nOutputPlane, kernel_size=(
                    kW, kH), stride=(dW, dH), padding=0)
                layers.append(a)
        self.max_layers = nn.ModuleList(layers)
        self.s = dW

    def forward(self, x):
        hh = []
        ww = []
        res = []

        for i in range(0, len(self.max_layers)):
            pad_left, pad_top = self.padd[i]
            _x = F.pad(x, [pad_left, pad_left, pad_top, pad_top], value=0)
            _x = self.max_layers[i](_x)
            h, w = _x.size()[2], _x.size()[3]
            hh.append(h)
            ww.append(w)
            res.append(_x)
        max_h, max_w = np.max(hh), np.max(ww)
        for i in range(0, len(self.max_layers)):
            _x = res[i]
            h, w = _x.size()[2], _x.size()[3]
            pad_top = np.ceil((max_h - h) / 2).astype(int)
            pad_bottom = np.floor((max_h - h) / 2).astype(int)
            pad_left = np.ceil((max_w - w) / 2).astype(int)
            pad_right = np.floor((max_w - w) / 2).astype(int)
            _x = F.pad(_x, [pad_left, pad_right, pad_top, pad_bottom], value=0)
            res[i] = _x
        return torch.cat(res, 0)

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
"""
Implementation of the models in the paper:
Bergmann et al. (MVTec),
Uninformed Students: Student–Teacher Anomaly Detection with Discriminative Latent Embeddings.
CVPR, 2020.

Author: Luyao Chen
Date: 2020.10
"""

import torch
import numpy as np
from torch import nn
from fast_dense_feature_extractor import *
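
# NOTE (overview): each _TeacherXX below maps a single p x p patch to a 128-d
# feature that a linear layer decodes to a 512-d descriptor, matching the
# ResNet-18 target used in teacher_train.py. The TeacherXX wrappers reuse the
# same trained weights but swap plain pooling for the fast-dense layers, so a
# full image yields one descriptor per pixel in a single forward pass.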


class _Teacher17(nn.Module):
    """
    T^ net for patch size 17.
    """

    def __init__(self):
        super(_Teacher17, self).__init__()
        self.net = nn.Sequential(
            # Input n*3*17*17
            # NOTE: the paper may specify kernel_size=5 here; 6 is kept so
            # the shapes annotated below work out (17 - 6 + 1 = 12).
            nn.Conv2d(3, 128, kernel_size=6, stride=1),
            nn.LeakyReLU(5e-3),
            # n*128*12*12
            nn.Conv2d(128, 256, kernel_size=5, stride=1),
            nn.LeakyReLU(5e-3),
            # n*256*8*8
            nn.Conv2d(256, 256, kernel_size=5, stride=1),
            nn.LeakyReLU(5e-3),
            # n*256*4*4
            nn.Conv2d(256, 128, kernel_size=4, stride=1),
            # n*128*1*1
        )
        self.decode = nn.Linear(128, 512)

    def forward(self, x):
        x = self.net(x)
        x = x.view(-1, 128)
        x = self.decode(x)
        return x


class _Teacher33(nn.Module):
    """
    T^ net for patch size 33.
    """

    def __init__(self):
        super(_Teacher33, self).__init__()
        self.net = nn.Sequential(
            # Input n*3*33*33
            nn.Conv2d(3, 128, kernel_size=3, stride=1),
            # nn.BatchNorm2d(128),
            nn.LeakyReLU(5e-3),
            # n*128*31*31
            nn.MaxPool2d(kernel_size=2, stride=2),
            # n*128*15*15
            nn.Conv2d(128, 256, kernel_size=5, stride=1),
            # nn.BatchNorm2d(256),
            nn.LeakyReLU(5e-3),
            # n*256*11*11
            nn.MaxPool2d(kernel_size=2, stride=2),
            # n*256*5*5
            nn.Conv2d(256, 256, kernel_size=2, stride=1),
            # nn.BatchNorm2d(256),
            nn.LeakyReLU(5e-3),
            # n*256*4*4
            nn.Conv2d(256, 128, kernel_size=4, stride=1),
            # n*128*1*1
        )
        self.decode = nn.Linear(128, 512)

    def forward(self, x):
        x = self.net(x)
        x = x.view(-1, 128)
        x = self.decode(x)
        return x


class _Teacher65(nn.Module):
    """
    T^ net for patch size 65.
    """

    def __init__(self):
        super(_Teacher65, self).__init__()
        self.net = nn.Sequential(
            # Input n*3*65*65
            nn.Conv2d(3, 128, kernel_size=5, stride=1),
            nn.LeakyReLU(5e-3),
            # n*128*61*61
            nn.MaxPool2d(kernel_size=2, stride=2),
            # n*128*30*30
            nn.Conv2d(128, 128, kernel_size=5, stride=1),
            nn.LeakyReLU(5e-3),
            # n*128*26*26
            nn.MaxPool2d(kernel_size=2, stride=2),
            # n*128*13*13
            nn.Conv2d(128, 128, kernel_size=5, stride=1),
            nn.LeakyReLU(5e-3),
            # n*128*9*9
            nn.MaxPool2d(kernel_size=2, stride=2),
            # n*128*4*4
            nn.Conv2d(128, 256, kernel_size=4, stride=1),
            nn.LeakyReLU(5e-3),
            # n*256*1*1
            # NOTE: the paper may specify kernel_size=3 here; only 1 keeps
            # the 1*1 spatial size at this point.
            nn.Conv2d(256, 128, kernel_size=1, stride=1),
            # n*128*1*1
        )
        self.decode = nn.Linear(128, 512)

    def forward(self, x):
        x = self.net(x)
        x = x.view(-1, 128)
        x = self.decode(x)
        return x


class Teacher17(nn.Module):
    """
    Teacher network with patch size 17.
    It has the same architecture as T^17, since there are no striding or
    pooling layers to rewrite; the input only needs multiPoolPrepare padding.
    """

    def __init__(self, base_net: _Teacher17):
        super(Teacher17, self).__init__()
        self.multiPoolPrepare = multiPoolPrepare(17, 17)
        self.net = base_net.net

    def forward(self, x):
        x = self.multiPoolPrepare(x)
        x = self.net(x)
        x = x.permute(0, 2, 3, 1)
        return x
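
# NOTE: the fast-dense wrappers below require imH and imW to be divisible by
# the product of all pooling strides, because unwrapPool reassembles that many
# interleaved output grids. A hypothetical helper making the check explicit
# (the classes below enforce it with asserts):
def _check_image_size(imH, imW, strides):
    s = int(np.prod(strides))  # e.g. 2*2 = 4 for Teacher33, 2*2*2 = 8 for Teacher65
    return imH % s == 0 and imW % s == 0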


class Teacher33(nn.Module):
    """
    Teacher network with patch size 33.
    """

    def __init__(self, base_net: _Teacher33, imH, imW):
        super(Teacher33, self).__init__()
        self.imH = imH
        self.imW = imW
        self.sL1 = 2
        self.sL2 = 2
        # image height and width should be multiples of sL1*sL2*sL3...
        # self.imW = int(np.ceil(imW / (self.sL1 * self.sL2)) * self.sL1 * self.sL2)
        # self.imH = int(np.ceil(imH / (self.sL1 * self.sL2)) * self.sL1 * self.sL2)
        assert imH % (self.sL1 * self.sL2) == 0, \
            "image height should be a multiple of (sL1*sL2), which is " + \
            str(self.sL1 * self.sL2)
        assert imW % (self.sL1 * self.sL2) == 0, \
            "image width should be a multiple of (sL1*sL2), which is " + \
            str(self.sL1 * self.sL2)

        self.outChans = base_net.net[-1].out_channels
        self.net = nn.Sequential(
            multiPoolPrepare(33, 33),
            base_net.net[0],
            base_net.net[1],
            multiMaxPooling(self.sL1, self.sL1, self.sL1, self.sL1),
            base_net.net[3],
            base_net.net[4],
            multiMaxPooling(self.sL2, self.sL2, self.sL2, self.sL2),
            base_net.net[6],
            base_net.net[7],
            base_net.net[8],
            unwrapPrepare(),
            unwrapPool(self.outChans, imH / (self.sL1 * self.sL2),
                       imW / (self.sL1 * self.sL2), self.sL2, self.sL2),
            unwrapPool(self.outChans, imH / self.sL1,
                       imW / self.sL1, self.sL1, self.sL1),
        )

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.shape[0], self.imH, self.imW, -1)
        x = x.permute(3, 1, 2, 0)
        return x


class Teacher65(nn.Module):
    """
    Teacher network with patch size 65.
    """

    def __init__(self, base_net: _Teacher65, imH, imW):
        super(Teacher65, self).__init__()
        self.imH = imH
        self.imW = imW
        self.sL1 = 2
        self.sL2 = 2
        self.sL3 = 2
        # image height and width should be multiples of sL1*sL2*sL3...
        # self.imW = int(np.ceil(imW / (self.sL1 * self.sL2)) * self.sL1 * self.sL2)
        # self.imH = int(np.ceil(imH / (self.sL1 * self.sL2)) * self.sL1 * self.sL2)
        assert imH % (self.sL1 * self.sL2 * self.sL3) == 0, \
            'image height should be a multiple of (sL1*sL2*sL3), which is ' + \
            str(self.sL1 * self.sL2 * self.sL3) + '.'
        assert imW % (self.sL1 * self.sL2 * self.sL3) == 0, \
            'image width should be a multiple of (sL1*sL2*sL3), which is ' + \
            str(self.sL1 * self.sL2 * self.sL3) + '.'

        self.outChans = base_net.net[-1].out_channels
        self.net = nn.Sequential(
            multiPoolPrepare(65, 65),
            base_net.net[0],
            base_net.net[1],
            multiMaxPooling(self.sL1, self.sL1, self.sL1, self.sL1),
            base_net.net[3],
            base_net.net[4],
            multiMaxPooling(self.sL2, self.sL2, self.sL2, self.sL2),
            base_net.net[6],
            base_net.net[7],
            multiMaxPooling(self.sL3, self.sL3, self.sL3, self.sL3),
            base_net.net[9],
            base_net.net[10],
            base_net.net[11],
            unwrapPrepare(),
            unwrapPool(self.outChans, imH / (self.sL1 * self.sL2 * self.sL3),
                       imW / (self.sL1 * self.sL2 * self.sL3), self.sL3, self.sL3),
            unwrapPool(self.outChans, imH / (self.sL1 * self.sL2),
                       imW / (self.sL1 * self.sL2), self.sL2, self.sL2),
            unwrapPool(self.outChans, imH / self.sL1,
                       imW / self.sL1, self.sL1, self.sL1),
        )

    def forward(self, x):
        x = self.net(x)
        # print(x.shape)
        x = x.view(x.shape[0], self.imH, self.imW, -1)
        x = x.permute(3, 1, 2, 0)
        return x


def _Teacher(patch_size):
    if patch_size == 17:
        return _Teacher17()
    elif patch_size == 33:
        return _Teacher33()
    elif patch_size == 65:
        return _Teacher65()
    else:
        print('No implementation of net with patch_size: ' + str(patch_size))
        return None


def TeacherOrStudent(patch_size, base_net, imH=None, imW=None):
    if patch_size == 17:
        return Teacher17(base_net)
    elif patch_size == 33:
        if imH is None or imW is None:
            print('imH and imW are necessary.')
            return None
        return Teacher33(base_net, imH, imW)
    elif patch_size == 65:
        if imH is None or imW is None:
            print('imH and imW are necessary.')
            return None
        return Teacher65(base_net, imH, imW)
    else:
        print('No implementation of net with patch_size: ' + str(patch_size))
        return None


if __name__ == "__main__":
    net = _Teacher17()
    imH = 256
    imW = 256

    T = Teacher17(net)
    # T = Teacher33(net, imH, imW)
    x = torch.ones((2, 3, imH, imW))

    x_ = torch.ones((2, 3, 17, 17))

    y = T(x)
    y_ = net(x_)

    # print(y)
    print(y.shape)   # expected: torch.Size([2, 256, 256, 128])
    print(y_.shape)  # expected: torch.Size([2, 512])
    # print(T)

--------------------------------------------------------------------------------
/mvtec_dataset.py:
--------------------------------------------------------------------------------
from PIL import Image

import os
import os.path
import sys
import torch
import torch.utils.data as data


IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
                  '.pgm', '.tif', '.tiff', '.webp')


def has_file_allowed_extension(filename, extensions):
    """Checks if a file has an allowed extension.

    Args:
        filename (string): path to a file
        extensions (tuple of strings): extensions to consider (lowercase)

    Returns:
        bool: True if the filename ends with one of the given extensions
    """
    return filename.lower().endswith(extensions)


def is_image_file(filename):
    """Checks if a file has an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    return has_file_allowed_extension(filename, IMG_EXTENSIONS)


class MVTec_AD(data.Dataset):
    """A generic data loader where the samples are arranged in this way: ::

        root/class_x/xxx.ext
        root/class_x/xxy.ext
        root/class_x/xxz.ext

        root/class_y/123.ext
        root/class_y/nsdf3.ext
        root/class_y/asd932_.ext

    Args:
        root (string): Root directory path.
        loader (callable): A function to load a sample given its path.
        extensions (tuple[string]): A list of allowed extensions.
            Both extensions and is_valid_file should not be passed.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
            E.g, ``transforms.RandomCrop`` for images.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        is_valid_file (callable, optional): A function that takes the path of an
            image file and checks whether the file is valid (used to filter out
            corrupt files). Both extensions and is_valid_file should not be passed.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        samples (list): List of (sample path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def make_dataset(self, dir, class_to_idx, extensions=None, is_valid_file=None):
        images = []
        dir = os.path.expanduser(dir)
        if self.phase == 'test':
            gt_dir = os.path.join(dir, 'ground_truth')
        dir = os.path.join(dir, self.phase)
        if not ((extensions is None) ^ (is_valid_file is None)):
            raise ValueError(
                "Both extensions and is_valid_file cannot be None or not None at the same time")
        if extensions is not None:
            def is_valid_file(x):
                return has_file_allowed_extension(x, extensions)
        for target in sorted(class_to_idx.keys()):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if self.phase == 'test':
                        if target == 'good':
                            gt_path = None
                        else:
                            # masks live under ground_truth/<defect>/<image>_mask.png
                            gt_fname = fname.split('.')[0] + '_mask.png'
                            gt_path = os.path.join(gt_dir, target, gt_fname)
                    if is_valid_file(path):
                        if self.phase == 'test':
                            item = (path, gt_path, class_to_idx[target])
                        else:
                            item = (path, class_to_idx[target])
                        images.append(item)

        return images

    def __init__(self, root, transform=None,
                 mask_transform=None, extensions=IMG_EXTENSIONS,
                 is_valid_file=None, phase='train'):
        # torch._six was removed in newer PyTorch; this works with the
        # pinned torch~=1.3 from the README.
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root
        if phase not in ('train', 'test'):
            raise (RuntimeError(
                'phase of MVTec_AD dataset must be "train" or "test".'))
        self.phase = phase
        data_dir = os.path.join(self.root, phase)
        classes, class_to_idx = self._find_classes(data_dir)
        samples = self.make_dataset(
            self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            raise (RuntimeError("Found 0 files in subfolders of: " + data_dir + "\n"
                                "Supported extensions are: " + ",".join(extensions)))
        self.extensions = extensions
        self.transform = transform
        self.mask_transform = mask_transform
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.imgs = self.samples
        self.targets = [s[1] for s in samples]

    def pil_loader(self, path):
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        if sys.version_info >= (3, 5):
            # Faster and available in Python 3.5 and above
            classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        else:
            classes = [d for d in os.listdir(
                dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        if self.phase == 'train':
            path, target = self.samples[index]
            sample = self.pil_loader(path)
            if self.transform is not None:
                sample = self.transform(sample)
            # if self.target_transform is not None:
            #     target = self.target_transform(target)

            return sample, target
        else:
            path, gt_path, target = self.samples[index]
            sample = self.pil_loader(path)
            if gt_path is None:
                # anomaly-free test images get an all-zero (black) mask
                gt_mask = Image.new('L', sample.size)
            else:
                gt_mask = Image.open(gt_path)
            if self.transform is not None:
                sample = self.transform(sample)
            if self.mask_transform is not None:
                gt_mask = self.mask_transform(gt_mask)
            # if self.target_transform is not None:
            #     target = self.target_transform(target)

            return sample, gt_mask, target

    def __len__(self):
        return len(self.samples)


if __name__ == "__main__":
    from torchvision import transforms
    from torch.utils.data import DataLoader
    imH = 512
    imW = 512
    class_dir = 'leather/'
    test_dataset_dir = '/home/cly/data_disk/MVTec_AD/data/' + class_dir
    std = [0.229, 0.224, 0.225]
    mean = [0.485, 0.456, 0.406]
    trans = transforms.Compose([
        # transforms.RandomCrop((imH, imW)),
        transforms.Resize((imH, imW)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    trans2 = transforms.Compose([
        # transforms.RandomCrop((imH, imW)),
        transforms.Resize((imH, imW), Image.NEAREST),
        transforms.ToTensor(),
        # transforms.Normalize(mean, std)
    ])
    test_dataset = MVTec_AD(test_dataset_dir, transform=trans,
                            mask_transform=trans2, phase='test')
    test_dataloader = DataLoader(test_dataset, batch_size=1)

    img, gt_mask, _ = next(iter(test_dataloader))
    print(img.shape)
    print(gt_mask.shape)

--------------------------------------------------------------------------------
/res.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LuyaooChen/uninformed-students-pytorch/3c5661ea84c70dbf646ac5a1d2549762e05dcc16/res.jpg

--------------------------------------------------------------------------------
/student_train.py:
--------------------------------------------------------------------------------
"""
Implementation of Sect. 3.2 (training the student networks) of the
'Uninformed Students' paper.

Author: Luyao Chen
Date: 2020.10
"""

import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
from models import _Teacher, TeacherOrStudent


def increment_mean_and_var(mu_N, var_N, N, batch):
    '''Incrementally update a running channel-wise mean and (biased)
    variance with a new batch, without storing earlier batches.
    '''
    # batch: (batch, h, w, vector)
    B = batch.size()[0]  # batch size
    # we want a descriptor vector -> mean over batch and pixels
    mu_B = torch.mean(batch, dim=[0, 1, 2])
    S_B = B * torch.var(batch, dim=[0, 1, 2], unbiased=False)
    S_N = N * var_N
    mu_NB = N / (N + B) * mu_N + B / (N + B) * mu_B
    S_NB = S_N + S_B + B * mu_B**2 + N * mu_N**2 - (N + B) * mu_NB**2
    var_NB = S_NB / (N + B)
    return mu_NB, var_NB, N + B


if __name__ == "__main__":

    st_id = 0  # student id, starting from 0.
    # image height and width should be multiples of sL1*sL2*sL3...
    imH = 512
    imW = 512
    patch_size = 17
    batch_size = 1
    epochs = 20
    lr = 1e-4
    weight_decay = 1e-5
    work_dir = 'work_dir/'
    class_dir = 'leather/'
    dataset_dir = '/home/cly/data_disk/MVTec_AD/data/' + class_dir + 'train/'
    # dataset_dir = '/home/cly/data_disk/印花布/normal/3/'
    device = torch.device('cuda:1')

    trans = transforms.Compose([
        transforms.Resize((imH, imW)),
        # transforms.RandomCrop((imH, imW)),
        # transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(dataset_dir, transform=trans)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=True, num_workers=8, pin_memory=True)

    _teacher = _Teacher(patch_size)
    student = TeacherOrStudent(patch_size, _teacher, imH, imW).to(device)

    _teacher = _Teacher(patch_size)
    checkpoint = torch.load(work_dir + '_teacher' +
                            str(patch_size) + '.pth', torch.device('cpu'))
    _teacher.load_state_dict(checkpoint)
    teacher = TeacherOrStudent(patch_size, _teacher, imH, imW).to(device)
    teacher.eval()

    # gather the teacher's descriptor statistics over the training set
    with torch.no_grad():
        t_mu, t_var, N = 0, 0, 0
        for data, _ in tqdm(dataloader):
            data = data.to(device)
            t_out = teacher(data)
            t_mu, t_var, N = increment_mean_and_var(t_mu, t_var, N, t_out)
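
    # NOTE: the teacher's outputs are standardized with the statistics
    # gathered above, so every descriptor dimension contributes equally to
    # the student's regression target (cf. Sect. 3.2 of the paper); the same
    # statistics are recomputed from the training set in evaluate.py.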

    optim = torch.optim.Adam(student.parameters(), lr=lr,
                             weight_decay=weight_decay)

    iter_num = 0
    for i in range(epochs):
        for data, labels in dataloader:
            data = data.to(device)
            # labels = labels.to(device)
            with torch.no_grad():
                teacher_output = (teacher(data) - t_mu) / torch.sqrt(t_var)

            student_output = student(data)
            loss = F.mse_loss(student_output, teacher_output)

            optim.zero_grad()
            loss.backward()
            optim.step()

            iter_num += 1
            if iter_num % 10 == 0:
                print('epoch: {}, iter: {}, loss: {:.4f}'.format(
                    i + 1, iter_num, loss))
        iter_num = 0

    if not os.path.exists(work_dir):
        os.mkdir(work_dir)
    if not os.path.exists(work_dir + class_dir):
        os.mkdir(work_dir + class_dir)
    print('Saving model to work_dir...')

    torch.save(student.state_dict(), work_dir + class_dir +
               'student' + str(patch_size) + '_' + str(st_id) + '.pth')

--------------------------------------------------------------------------------
/teacher_train.py:
--------------------------------------------------------------------------------
"""
Implementation of Sect. 3.1 (Learning Local Patch Descriptors) of the
'Uninformed Students' paper: knowledge distillation and descriptor
compactness (metric learning is not implemented yet).

Author: Luyao Chen
Date: 2020.10
"""

import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
from models import _Teacher

from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def distillation_loss(output, target):
    # dim: (batch, vector)
    # mean squared L2 distance between the 512-d descriptors
    err = torch.norm(output - target, dim=1)**2
    loss = torch.mean(err)
    return loss


def compactness_loss(output):
    # dim: (batch, vector)
    _, n = output.size()
    avg = torch.mean(output, dim=1)
    std = torch.std(output, dim=1)
    zt = output.T - avg
    zt /= std
    # corr has shape (batch, batch): correlations between sample descriptors
    corr = torch.matmul(zt.T, zt) / (n - 1)
    loss = torch.sum(torch.triu(corr, diagonal=1)**2)
    return loss
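
# NOTE: as written, compactness_loss standardizes each sample and then
# correlates *samples* with each other. The paper's descriptor-compactness
# term is commonly read as decorrelating the *descriptor dimensions* instead;
# a minimal sketch of that reading (an assumption, not the author's confirmed
# intent):
def compactness_loss_dimwise(output):
    # output: (batch, vector)
    b, _ = output.size()
    z = (output - torch.mean(output, dim=0)) / torch.std(output, dim=0)
    corr = torch.matmul(z.T, z) / (b - 1)  # (vector, vector)
    return torch.sum(torch.triu(corr, diagonal=1)**2)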


if __name__ == "__main__":
    patch_size = 65
    batch_size = 64
    lr = 2e-4
    weight_decay = 1e-5
    epochs = 2
    # alpha = 0.9
    # temperature = 20
    work_dir = 'work_dir/'
    device = torch.device('cuda:1')

    trans = transforms.Compose([
        transforms.RandomResizedCrop(patch_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(
        '/home/cly/data_disk/imagenet1k/train/', transform=trans)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=True, num_workers=8, pin_memory=True)

    model = _Teacher(patch_size).to(device)
    # a pretrained ResNet-18 (minus its fc layer) provides the 512-d targets
    resnet18 = models.resnet18(pretrained=True)
    resnet18 = nn.Sequential(*list(resnet18.children())[:-1]).to(device)
    resnet18.eval()

    optim = torch.optim.Adam(model.parameters(), lr=lr,
                             weight_decay=weight_decay)

    iter_num = 0
    for i in range(epochs):
        for data, labels in dataloader:
            data = data.to(device)
            # labels = labels.to(device)
            output = model(data)
            with torch.no_grad():
                resnet_output = resnet18(data).view(-1, 512)

            # knowledge distillation loss
            # loss_k = F.smooth_l1_loss(output, resnet_output, reduction='sum')
            loss_k = distillation_loss(output, resnet_output)
            # metric learning is not implemented yet.
            loss_c = compactness_loss(output)
            loss = loss_k + loss_c
            optim.zero_grad()
            loss.backward()
            optim.step()

            iter_num += 1
            if iter_num % 10 == 0:
                print('epoch:{}, iter:{}, loss_k:{:.3f}, loss_c:{:.3f}, loss:{:.3f}'.format(
                    i + 1, iter_num, loss_k, loss_c, loss))
        iter_num = 0

    if not os.path.exists(work_dir):
        os.mkdir(work_dir)
    print('Saving model to work_dir...')
    torch.save(model.state_dict(), work_dir +
               '_teacher' + str(patch_size) + '.pth')
--------------------------------------------------------------------------------