├── LICENSE
├── artifact_README.md
├── benchmark.py
├── dataset
│   ├── caltech256.py
│   ├── cub200.py
│   ├── flower102.py
│   ├── gtsrb.py
│   ├── imagenet.py
│   ├── mit67.py
│   ├── seqimagenet.py
│   ├── stanford_40.py
│   ├── stanford_dog.py
│   └── vis_da.py
├── evaluate.ipynb
├── finetune.py
├── finetuner.py
├── load_model.py
├── model
│   ├── __init__.py
│   ├── fe_mobilenet.py
│   ├── fe_resnet.py
│   └── fe_vgg16.py
├── modeldiff.py
├── readme.md
├── utils.py
└── weight_pruner.py

/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 Yuanchun Li
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/artifact_README.md:
--------------------------------------------------------------------------------
  1 | # ModelDiff: Testing-based DNN Similarity Comparison for Model Reuse Detection
  2 | 
  3 | ## About
  4 | This is the artifact associated with our ISSTA paper "ModelDiff: Testing-based DNN Similarity Comparison for Model Reuse Detection".
  5 | 
  6 | ModelDiff is a testing-based approach to deep learning model similarity comparison. Instead of directly comparing the weights, activations, or outputs of two models, ModelDiff compares their behavioral patterns on the same set of test inputs. Specifically, the behavioral pattern of a model is represented as a decision distance vector (DDV), in which each element is the distance between the model's reactions to a pair of inputs. The knowledge similarity between two models is measured with the cosine similarity between their DDVs.
  7 | To evaluate ModelDiff, we created a benchmark that contains 144 pairs of models that cover most popular model reuse methods, including transfer learning, model compression, and model stealing. Our method achieved 91.7% correctness on the benchmark, which demonstrates the effectiveness of using ModelDiff for model reuse detection. A study on mobile deep learning apps has shown the feasibility of ModelDiff on real-world models.
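For intuition, the sketch below illustrates the DDV idea in isolation. It is an editorial illustration, not the artifact's code: the Euclidean distance, the all-pairs scheme, and the epsilon guard are simplifying assumptions, and the actual implementation lives in modeldiff.py and evaluate.ipynb.

```
import numpy as np

def decision_distance_vector(outputs):
    # outputs: (N, num_classes) array of a model's reactions to N test inputs.
    # Each DDV element is the distance between the reactions to one pair of inputs.
    n = len(outputs)
    return np.array([np.linalg.norm(outputs[i] - outputs[j])
                     for i in range(n) for j in range(i + 1, n)])

def cosine_similarity(a, b):
    # epsilon guard is an illustrative assumption to avoid division by zero
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

# The similarity score between two models is the cosine similarity of their DDVs,
# computed from their outputs on the same set of test inputs:
#   sim = cosine_similarity(decision_distance_vector(out1),
#                           decision_distance_vector(out2))
```

With this representation, two models that reuse the same knowledge tend to produce DDVs pointing in similar directions even when their raw outputs differ.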
  8 | 
  9 | ## Environment
 10 | - Ubuntu 16.04
 11 | - CUDA 10.2
 12 | 
 13 | ## Dependencies
 14 | - PyTorch 1.5.0
 15 | - TorchVision 0.6.0
 16 | - AdverTorch 0.2.0
 17 | 
 18 | ## Getting started
 19 | - You should have a GPU on your device, because the adversarial example computation is slow otherwise
 20 | - You should first install CUDA 10.2 on your device (if you don't have it already) from [here](https://developer.nvidia.com/cuda-downloads)
 21 | - Install [Anaconda](https://www.anaconda.com/), then create a new environment and enter it
 22 | ```
 23 | conda create --name modeldiff python=3.6
 24 | ```
 25 | - Install PyTorch in the new environment (activate it first with ``conda activate modeldiff``)
 26 | ```
 27 | conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.2 -c pytorch
 28 | ```
 29 | - Install AdverTorch
 30 | ```
 31 | pip install advertorch
 32 | ```
 33 | - Install other packages
 34 | 
 35 | ```
 36 | pip install scipy
 37 | ```
 38 | - Make a new directory called ``data`` and download the three datasets listed below into the ``data`` directory
 39 | ```
 40 | data/
 41 | |--- CUB_200_2011/
 42 | |--- stanford_dog/
 43 | |--- MIT_67/
 44 | ```
 45 | 
 46 | 
 47 | 
 48 | ## Prepare datasets
 49 | 
 50 | ### [Caltech-UCSD 200 Birds](http://www.vision.caltech.edu/visipedia/CUB-200.html)
 51 | The layout should be as follows for the dataloader to load correctly:
 52 | 
 53 | ```
 54 | CUB_200_2011/
 55 | | README
 56 | | bounding_boxes.txt
 57 | | classes.txt
 58 | | image_class_labels.txt
 59 | | images.txt
 60 | | train_test_split.txt
 61 | |--- attributes
 62 | |--- images/
 63 | |--- parts/
 64 | |--- train/
 65 | |--- test/
 66 | ```
 67 | 
 68 | ### [Stanford 120 Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/)
 69 | ```
 70 | stanford_dog/
 71 | | file_list.mat
 72 | | test_list.mat
 73 | | train_list.mat
 74 | |--- train/
 75 | |--- test/
 76 | |--- Images/
 77 | |--- Annotation/
 78 | ```
 79 | 
 80 | 
 81 | ### [MIT 67 Indoor Scenes](http://web.mit.edu/torralba/www/indoor.html)
 82 | ```
 83 | MIT_67/
 84 | | TrainImages.txt
 85 | | TestImages.txt
 86 | |--- Annotations/
 87 | |--- Images/
 88 | |--- test/
 89 | |--- train/
 90 | ```
 91 | 
 92 | ## Prepare models
 93 | You can change the size of the benchmark and the number of models to use in benchmark.py. The models used in the paper are MobileNetV2 and ResNet18 trained on the Flower102 and StanfordDogs120 datasets. You can add other architectures and datasets to the ImageBenchmark class in benchmark.py (lines 487 to 503, shown below).
 94 | ```
 95 | # Used in the paper
 96 | self.datasets = ['Flower102', 'SDog120']
 97 | self.archs = ['mbnetv2', 'resnet18']
 98 | # Other archs
 99 | # self.datasets = ['MIT67', 'Flower102', 'SDog120']
100 | # self.archs = ['mbnetv2', 'resnet18', 'vgg16_bn', 'vgg11_bn', 'resnet34', 'resnet50']
101 | # For debug
102 | # self.datasets = ['Flower102']
103 | # self.archs = ['resnet18']
104 | ```
105 | 
106 | We also provide the benchmark used in the paper; you can download it from [Google Drive](https://drive.google.com/file/d/1UfhnPB2V2bpwpWxnne1bodI1cIT3q98c/view?usp=sharing).
107 | 
108 | ## Evaluation
109 | The code that compares model similarity with DDVs (decision distance vectors) is in evaluate.ipynb. It loads the benchmark models from benchmark.py and compares their similarity.
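For orientation, the sketch below shows how the benchmark models can be enumerated and queried with the APIs defined in benchmark.py; evaluate.ipynb performs the same steps before computing DDVs. It is an editorial illustration and assumes the benchmark models have already been generated or downloaded into ``models/`` and that the datasets are in ``data/``.

```
from benchmark import ImageBenchmark

bench = ImageBenchmark(datasets_dir='data', models_dir='models')
models = list(bench.list_models())                # ModelWrapper instances

wrapper = models[0]
inputs = wrapper.get_seed_inputs(50, rand=True)   # seed inputs used for testing
outputs = wrapper.batch_forward(inputs)           # model reactions used to build its DDV
print(wrapper.name(), outputs.shape)
```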
110 | 111 | ## Authors 112 | - Yuanchun Li (Github ID: ylimit, email: pkulyc@gmail.com) 113 | - Ziqi Zhang (Github ID: ziqi-zhang, email: ziqi_zhang@pku.edu.cn) 114 | - Bingyan Liu (email: lby_cs@pku.edu.cn) 115 | - Ziyue Yang (email: Ziyue.Yang@microsoft.com) 116 | - Yunxing Liu (email: Yunxin.Liu@microsoft.com) 117 | 118 | ## DOI 119 | We put our code at https://zenodo.org/record/4723301#.YIf-rH0zYUE with a public DOI: 10.5281/zenodo.4723301 120 | 121 | ## Acknowledgement 122 | Some of the code is referred from [Renofeation](https://github.com/cmu-enyac/Renofeation) and we thank the authors for sharing the code. -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import json 6 | import random 7 | import logging 8 | import pathlib 9 | import re 10 | import functools 11 | import torch 12 | import torch.nn as nn 13 | import torchvision 14 | import torchvision.models as models 15 | import numpy as np 16 | from pdb import set_trace as st 17 | import copy 18 | 19 | 20 | from dataset.mit67 import MIT67 21 | from dataset.stanford_dog import SDog120 22 | from dataset.flower102 import Flower102 23 | from dataset.caltech256 import Caltech257Data 24 | from dataset.stanford_40 import Stanford40Data 25 | from dataset.cub200 import CUB200Data 26 | 27 | from model.fe_resnet import resnet18_dropout, resnet34_dropout, resnet50_dropout, resnet101_dropout 28 | from model.fe_mobilenet import mbnetv2_dropout 29 | from model.fe_resnet import feresnet18, feresnet34, feresnet50, feresnet101 30 | from model.fe_mobilenet import fembnetv2 31 | from model.fe_vgg16 import * 32 | from finetuner import Finetuner 33 | from weight_pruner import WeightPruner 34 | 35 | 36 | SEED = 98 37 | INPUT_SHAPE = (3, 224, 224) 38 | BATCH_SIZE = 64 39 | TRAIN_ITERS = 100000 40 | DEFAULT_ITERS = 10000 41 | TRANSFER_ITERS = DEFAULT_ITERS 42 | QUANTIZE_ITERS = DEFAULT_ITERS # may be useless 43 | PRUNE_ITERS = DEFAULT_ITERS 44 | DISTILL_ITERS = DEFAULT_ITERS 45 | STEAL_ITERS = DEFAULT_ITERS 46 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 47 | CONTINUE_TRAIN = False # whether to continue previous training 48 | 49 | 50 | def lazy_property(func): 51 | attribute = '_lazy_' + func.__name__ 52 | 53 | @property 54 | @functools.wraps(func) 55 | def wrapper(self): 56 | if not hasattr(self, attribute): 57 | setattr(self, attribute, func(self)) 58 | return getattr(self, attribute) 59 | 60 | return wrapper 61 | 62 | 63 | def base_args(): 64 | args = argparse.Namespace() 65 | args.const_lr = False 66 | args.batch_size = BATCH_SIZE 67 | args.lr = 5e-3 68 | args.print_freq = 100 69 | args.label_smoothing = 0 70 | args.vgg_output_distill = False 71 | args.reinit = False 72 | args.l2sp_lmda = 0 73 | args.train_all = False 74 | args.ft_begin_module = None 75 | args.momentum = 0 76 | args.weight_decay = 1e-4 77 | args.beta = 1e-2 78 | args.feat_lmda = 0 79 | args.test_interval = 1000 80 | args.adv_test_interval = -1 81 | args.feat_layers = '1234' 82 | args.no_save = False 83 | args.steal = False 84 | return args 85 | 86 | 87 | class ModelWrapper: 88 | def __init__(self, benchmark, teacher_wrapper, trans_str, 89 | arch_id=None, dataset_id=None, iters=100, fc=True): 90 | self.logger = logging.getLogger('ModelWrapper') 91 | self.benchmark = benchmark 92 | self.teacher_wrapper = teacher_wrapper 93 | self.trans_str = trans_str 94 | self.arch_id = arch_id if 
arch_id else teacher_wrapper.arch_id 95 | self.dataset_id = dataset_id if dataset_id else teacher_wrapper.dataset_id 96 | self.torch_model_path = os.path.join(benchmark.models_dir, f'{self.__str__()}') 97 | self.iters = iters 98 | self.fc = fc 99 | assert self.arch_id is not None 100 | assert self.dataset_id is not None 101 | 102 | def __str__(self): 103 | teacher_str = '' if self.teacher_wrapper is None else self.teacher_wrapper.__str__() 104 | return f'{teacher_str}{self.trans_str}-' 105 | 106 | def name(self): 107 | return self.__str__() 108 | 109 | def torch_model_exists(self): 110 | ckpt_path = os.path.join(self.torch_model_path, 'final_ckpt.pth') 111 | return os.path.exists(ckpt_path) 112 | 113 | def save_torch_model(self, torch_model): 114 | if not os.path.exists(self.torch_model_path): 115 | os.makedirs(self.torch_model_path) 116 | ckpt_path = os.path.join(self.torch_model_path, 'final_ckpt.pth') 117 | torch.save( 118 | {'state_dict': torch_model.state_dict()}, 119 | ckpt_path, 120 | ) 121 | 122 | @lazy_property 123 | def torch_model(self): 124 | """ 125 | load the model object from torch_model_path 126 | :return: torch.nn.Module object 127 | """ 128 | if self.dataset_id == 'ImageNet': 129 | num_classes = 1000 130 | else: 131 | num_classes = self.benchmark.get_dataloader(self.dataset_id).dataset.num_classes 132 | 133 | if self.fc: 134 | torch_model = eval(f'{self.arch_id}_dropout')( 135 | pretrained=False, 136 | num_classes=num_classes 137 | ) 138 | else: 139 | torch_model = eval(f'fe{self.arch_id}')( 140 | pretrained=False, 141 | num_classes=num_classes 142 | ) 143 | 144 | m = re.match(r'(\S+)\((\S*)\)', self.trans_str) 145 | method = m.group(1) 146 | params = m.group(2).split(',') 147 | if method == 'quantize': 148 | dtype = params[0] 149 | dtype = torch.qint8 if dtype == 'qint8' else torch.float16 150 | torch_model = torch.quantization.quantize_dynamic(torch_model, dtype=dtype) 151 | ckpt = torch.load(os.path.join(self.torch_model_path, 'final_ckpt.pth')) 152 | torch_model.load_state_dict(ckpt['state_dict']) 153 | return torch_model 154 | 155 | @lazy_property 156 | def torch_model_on_device(self): 157 | m = re.match(r'(\S+)\((\S*)\)', self.trans_str) 158 | method = m.group(1) 159 | if method == "quantize": 160 | return self.torch_model.to("cpu") 161 | else: 162 | return self.torch_model.to(DEVICE) 163 | 164 | def load_saved_weights(self, torch_model): 165 | """ 166 | load weights in the latest checkpoint to torch_model 167 | """ 168 | ckpt_path = os.path.join(self.torch_model_path, 'ckpt.pth') 169 | if os.path.exists(ckpt_path): 170 | ckpt = torch.load(ckpt_path) 171 | torch_model.load_state_dict(ckpt['state_dict']) 172 | self.logger.info('load_saved_weights: loaded a previous checkpoint') 173 | else: 174 | self.logger.info('load_saved_weights: no previous checkpoint found') 175 | return torch_model 176 | 177 | @lazy_property 178 | def input_shape(self): 179 | return INPUT_SHAPE 180 | 181 | def get_seed_inputs(self, n, rand=False): 182 | if rand: 183 | batch_input_size = (n, *INPUT_SHAPE) 184 | images = np.random.normal(size=batch_input_size).astype(np.float32) 185 | else: 186 | dataset_id = 'MIT67' if self.dataset_id == 'ImageNet' else self.dataset_id 187 | train_loader = self.benchmark.get_dataloader( 188 | dataset_id, split='train', batch_size=n, shuffle=True) 189 | images, labels = next(iter(train_loader)) 190 | images = images.to('cpu').numpy() 191 | return images 192 | 193 | def batch_forward(self, inputs): 194 | if isinstance(inputs, np.ndarray): 195 | inputs = 
torch.from_numpy(inputs) 196 | m = re.match(r'(\S+)\((\S*)\)', self.trans_str) 197 | method = m.group(1) 198 | if method == "quantize": 199 | inputs = inputs.to("cpu") 200 | else: 201 | inputs = inputs.to(DEVICE) 202 | self.torch_model_on_device.eval() 203 | with torch.no_grad(): 204 | return self.torch_model_on_device(inputs) 205 | 206 | def list_tensors(self): 207 | pass 208 | 209 | def batch_forward_with_ir(self, inputs): 210 | if isinstance(inputs, np.ndarray): 211 | inputs = torch.from_numpy(inputs) 212 | idx = 0 213 | hook_handles = [] 214 | module_ir = {} 215 | model = self.torch_model 216 | 217 | def register_hooks(module): 218 | def hook(module, input, output): 219 | global idx 220 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 221 | module_name = f"{class_name}/{idx:03d}" 222 | idx += 1 223 | module_ir[module_name] = output.numpy() 224 | 225 | if len(list(module.children())) == 0: 226 | handle = module.register_forward_hook(hook) 227 | hook_handles.append(handle) 228 | 229 | def remove_hooks(): 230 | for h in hook_handles: 231 | h.remove() 232 | 233 | model.eval() 234 | with torch.no_grad(): 235 | model.apply(register_hooks) 236 | outputs = model(inputs) 237 | remove_hooks() 238 | return module_ir 239 | 240 | def gen_model(self, regenerate=False): 241 | """ 242 | generate the torch model 243 | :return: 244 | """ 245 | trans_str = self.trans_str 246 | if not regenerate and self.torch_model_exists(): 247 | self.logger.info(f'model already exists: {self.__str__()}') 248 | return 249 | self.logger.info(f'generating model for: {self.__str__()}') 250 | m = re.match(r'(\S+)\((\S*)\)', trans_str) 251 | method = m.group(1) 252 | params = m.group(2).split(',') 253 | 254 | if regenerate and os.path.exists(self.torch_model_path): 255 | import shutil 256 | shutil.rmtree(self.torch_model_path) 257 | if not os.path.exists(self.torch_model_path): 258 | os.makedirs(self.torch_model_path) 259 | 260 | teacher_model = None 261 | if self.teacher_wrapper: 262 | self.teacher_wrapper.gen_model() 263 | teacher_model = self.teacher_wrapper.torch_model 264 | train_loader = self.benchmark.get_dataloader(self.dataset_id, split='train') 265 | test_loader = self.benchmark.get_dataloader(self.dataset_id, split='test') 266 | 267 | args = base_args() 268 | args.iterations = self.iters 269 | args.output_dir = self.torch_model_path 270 | 271 | if method == 'pretrain': 272 | # load pretrained model as specified by arch_id and save it to model path 273 | arch_id = params[0] 274 | dataset_id = params[1] 275 | if dataset_id != 'ImageNet': 276 | self.logger.warning(f'gen_model: pretrained model on {dataset_id} not supported') 277 | torch_model = eval(f'{arch_id}_dropout')( 278 | pretrained=True, 279 | num_classes=1000 280 | ) 281 | self.save_torch_model(torch_model) 282 | elif method == 'train': 283 | # train the model from scratch 284 | arch_id = params[0] 285 | dataset_id = params[1] 286 | torch_model = eval(f'{arch_id}_dropout')( 287 | pretrained=False, 288 | num_classes=train_loader.dataset.num_classes 289 | ) 290 | args.network = self.arch_id 291 | args.ft_ratio = 1 292 | args.reinit = True 293 | args.lr = 1e-2 294 | args.weight_decay = 5e-3 295 | args.momentum = 0.9 296 | 297 | if CONTINUE_TRAIN: 298 | torch_model = self.load_saved_weights(torch_model) # continue training 299 | finetuner = Finetuner( 300 | args, 301 | torch_model, torch_model, 302 | train_loader, test_loader, 303 | ) 304 | finetuner.train() 305 | self.save_torch_model(torch_model) 306 | elif method == 'transfer': 307 | # 
transfer the teacher to a dataset as specified by dataset_id, fine-tune the last tune_ratio% layers 308 | dataset_id = params[0] 309 | tune_ratio = float(params[1]) 310 | student_model = eval(f'{self.arch_id}_dropout')( 311 | pretrained=True, 312 | num_classes=train_loader.dataset.num_classes 313 | ) 314 | # FIXME copy state_dict from teacher to student, ignore the final layer 315 | # student_model.load_state_dict(teacher_model.state_dict(), strict=False) 316 | 317 | args.network = self.arch_id 318 | args.ft_ratio = tune_ratio 319 | 320 | if CONTINUE_TRAIN: 321 | student_model = self.load_saved_weights(student_model) # continue training 322 | finetuner = Finetuner( 323 | args, 324 | student_model, teacher_model, 325 | train_loader, test_loader, 326 | ) 327 | finetuner.train() 328 | self.save_torch_model(student_model) 329 | elif method == 'quantize': 330 | dtype = params[0] 331 | dtype = torch.qint8 if dtype == 'qint8' else torch.float16 332 | student_model = torch.quantization.quantize_dynamic(teacher_model, dtype=dtype) 333 | self.save_torch_model(student_model) 334 | elif method == 'prune': 335 | prune_ratio = float(params[0]) 336 | student_model = copy.deepcopy(teacher_model) 337 | 338 | args.network = self.arch_id 339 | args.method = "weight" 340 | args.weight_ratio = prune_ratio 341 | 342 | if CONTINUE_TRAIN: 343 | student_model = self.load_saved_weights(student_model) # continue training 344 | 345 | finetuner = WeightPruner( 346 | args, 347 | student_model, teacher_model, 348 | train_loader, test_loader, 349 | ) 350 | finetuner.train() 351 | self.save_torch_model(student_model) 352 | finetuner.final_check_param_num() 353 | elif method == 'distill': 354 | student_model = eval(f'{self.arch_id}_dropout')( 355 | pretrained=False, 356 | num_classes=train_loader.dataset.num_classes 357 | ) 358 | args.network = self.arch_id 359 | args.feat_lmda = 5e0 360 | args.reinit = True 361 | args.lr = 1e-2 362 | args.weight_decay = 5e-3 363 | args.momentum = 0.9 364 | 365 | if CONTINUE_TRAIN: 366 | student_model = self.load_saved_weights(student_model) # continue training 367 | 368 | finetuner = Finetuner( 369 | args, 370 | student_model, teacher_model, 371 | train_loader, test_loader, 372 | ) 373 | finetuner.train() 374 | self.save_torch_model(student_model) 375 | elif method == 'steal': 376 | arch_id = params[0] 377 | # use output distillation to transfer teacher knowledge to another architecture 378 | student_model = eval(f'{arch_id}_dropout')( 379 | pretrained=False, 380 | num_classes=train_loader.dataset.num_classes 381 | ) 382 | 383 | args.network = arch_id 384 | args.steal = True 385 | args.reinit = True 386 | args.steal_alpha = 1 387 | args.temperature = 1 388 | args.lr = 1e-2 389 | args.weight_decay = 5e-3 390 | args.momentum = 0.9 391 | 392 | if CONTINUE_TRAIN: 393 | student_model = self.load_saved_weights(student_model) # continue training 394 | 395 | finetuner = Finetuner( 396 | args, 397 | student_model, teacher_model, 398 | train_loader, test_loader, 399 | ) 400 | finetuner.train() 401 | self.save_torch_model(student_model) 402 | else: 403 | raise RuntimeError(f'unknown transformation: {method}') 404 | 405 | def transfer(self, dataset_id, tune_ratio=0.1, iters=TRANSFER_ITERS): 406 | trans_str = f'transfer({dataset_id},{tune_ratio})' 407 | # model_wrapper is the wrapper of the student model 408 | model_wrapper = ModelWrapper( 409 | benchmark=self.benchmark, 410 | teacher_wrapper=self, 411 | trans_str=trans_str, 412 | dataset_id=dataset_id, 413 | iters=iters 414 | ) 415 | return 
model_wrapper 416 | 417 | def quantize(self, dtype='qint8'): 418 | """ 419 | do post-training quantization on the model 420 | :param dtype: qint8 or float16 421 | :return: 422 | """ 423 | trans_str = f'quantize({dtype})' 424 | model_wrapper = ModelWrapper( 425 | benchmark=self.benchmark, 426 | teacher_wrapper=self, 427 | trans_str=trans_str 428 | ) 429 | return model_wrapper 430 | 431 | def prune(self, prune_ratio=0.1, iters=PRUNE_ITERS): 432 | trans_str = f'prune({prune_ratio})' 433 | model_wrapper = ModelWrapper( 434 | benchmark=self.benchmark, 435 | teacher_wrapper=self, 436 | trans_str=trans_str, 437 | iters=iters 438 | ) 439 | return model_wrapper 440 | 441 | def distill(self, iters=DISTILL_ITERS): 442 | trans_str = f'distill()' 443 | model_wrapper = ModelWrapper( 444 | benchmark=self.benchmark, 445 | teacher_wrapper=self, 446 | trans_str=trans_str, 447 | iters=iters 448 | ) 449 | return model_wrapper 450 | 451 | def steal(self, arch_id, iters=STEAL_ITERS): 452 | trans_str = f'steal({arch_id})' 453 | model_wrapper = ModelWrapper( 454 | benchmark=self.benchmark, 455 | teacher_wrapper=self, 456 | trans_str=trans_str, 457 | arch_id=arch_id, 458 | iters=iters 459 | ) 460 | return model_wrapper 461 | 462 | @lazy_property 463 | def accuracy(self): 464 | """ 465 | evaluate the model accuracy on the dataset 466 | :return: a float number 467 | """ 468 | # TODO implement this 469 | model = self.torch_model.to(DEVICE) 470 | test_loader = self.benchmark.get_dataloader(self.dataset_id, split='test') 471 | 472 | with torch.no_grad(): 473 | model.eval() 474 | total = 0 475 | top1 = 0 476 | for i, (batch, label) in enumerate(test_loader): 477 | batch, label = batch.to(DEVICE), label.to(DEVICE) 478 | total += batch.size(0) 479 | out = model(batch) 480 | _, pred = out.max(dim=1) 481 | top1 += int(pred.eq(label).sum().item()) 482 | # print(top1, total) 483 | return float(top1) / total * 100 484 | 485 | 486 | class ImageBenchmark: 487 | def __init__(self, datasets_dir='data', models_dir='models'): 488 | self.logger = logging.getLogger('ImageBench') 489 | self.datasets_dir = datasets_dir 490 | self.models_dir = models_dir 491 | """ 492 | Available datasets are MIT67, Flower102, SDog120 493 | Available models are mbnetv2, resnet18, resnet34, resnet50, vgg11_bn, vgg16_bn 494 | """ 495 | # Used in the paper 496 | self.datasets = ['Flower102', 'SDog120'] 497 | self.archs = ['mbnetv2', 'resnet18'] 498 | # Other archs 499 | # self.datasets = ['MIT67', 'Flower102', 'SDog120'] 500 | # self.archs = ['mbnetv2', 'resnet18', 'vgg16_bn', 'vgg11_bn', 'resnet34', 'resnet50'] 501 | # For debug 502 | # self.datasets = ['Flower102'] 503 | # self.archs = ['resnet18'] 504 | 505 | def get_dataloader(self, dataset_id, split='train', batch_size=BATCH_SIZE, shuffle=True, seed=SEED, shot=-1): 506 | """ 507 | Get the torch Dataset object 508 | :param dataset_id: the name of the dataset, should also be the dir name and the class name 509 | :param split: train or test 510 | :param batch_size: batch size 511 | :param shot: number of training samples per class for the training dataset. 
-1 indicates using the full dataset 512 | :return: torch.utils.data.DataLoader instance 513 | """ 514 | try: 515 | datapath = os.path.join(self.datasets_dir, dataset_id) 516 | normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 517 | 518 | from torchvision import transforms 519 | if split == 'train': 520 | dataset = eval(dataset_id)( 521 | datapath, True, transforms.Compose([ 522 | transforms.RandomResizedCrop(224), 523 | transforms.RandomHorizontalFlip(), 524 | transforms.ToTensor(), 525 | normalize, 526 | ]), 527 | shot, seed, preload=False 528 | ) 529 | else: 530 | dataset = eval(dataset_id)( 531 | datapath, False, transforms.Compose([ 532 | transforms.Resize(256), 533 | transforms.CenterCrop(224), 534 | transforms.ToTensor(), 535 | normalize, 536 | ]), 537 | shot, seed, preload=False 538 | ) 539 | 540 | data_loader = torch.utils.data.DataLoader( 541 | dataset, 542 | batch_size=batch_size, shuffle=shuffle, 543 | num_workers=8, pin_memory=False 544 | ) 545 | return data_loader 546 | except Exception as e: 547 | self.logger.warning(f'get_dataloader failed: {e}') 548 | return None 549 | 550 | def load_pretrained(self, arch_id, fc=True): 551 | """ 552 | Get the model pretrained on imagenet 553 | :param arch_id: the name of the arch 554 | :return: a ModelWrapper instance 555 | """ 556 | model_wrapper = ModelWrapper( 557 | benchmark=self, 558 | teacher_wrapper=None, 559 | trans_str=f'pretrain({arch_id},ImageNet)', 560 | arch_id=arch_id, 561 | dataset_id='ImageNet', 562 | fc=fc, 563 | ) 564 | return model_wrapper 565 | 566 | def load_trained(self, arch_id, dataset_id, iters=TRAIN_ITERS, fc=True): 567 | """ 568 | Get the model with architecture arch_id trained on dataset dataset_id 569 | :param arch_id: the name of the arch 570 | :param dataset_id: the name of the dataset 571 | :param iters: number of iterations 572 | :return: a ModelWrapper instance 573 | """ 574 | model_wrapper = ModelWrapper( 575 | benchmark=self, 576 | teacher_wrapper=None, 577 | trans_str=f'train({arch_id},{dataset_id})', 578 | arch_id=arch_id, 579 | dataset_id=dataset_id, 580 | iters=iters, 581 | fc=fc, 582 | ) 583 | return model_wrapper 584 | 585 | def list_models(self, fc=True): 586 | """ 587 | list the models in the benchmark dataset 588 | :return: a stream of ModelWrapper instances 589 | """ 590 | source_models = [] 591 | 592 | quantization_dtypes = ['qint8', 'float16'] 593 | prune_ratios = [0.2, 0.5, 0.8] 594 | transfer_tune_ratios = [0.1, 0.5, 1] 595 | 596 | # load pretrained source models 597 | for arch in self.archs: 598 | source_model = self.load_pretrained(arch, fc=fc) 599 | source_models.append(source_model) 600 | yield source_model 601 | 602 | # retrain models 603 | retrain_models = [] 604 | for arch_id in self.archs: 605 | for dataset_id in self.datasets: 606 | retrain_model = self.load_trained(arch_id, dataset_id, TRAIN_ITERS, fc=fc) 607 | retrain_models.append(retrain_model) 608 | yield retrain_model 609 | 610 | # for debug 611 | # prune_ratios = [0.2] 612 | # transfer_tune_ratios = [0.5, 1] 613 | 614 | transfer_models = [] 615 | # - M_{i,x}/{trans-y,l} -- Transfer M_{i,x} to D_y by fine-tuning from l-st layer 616 | for source_model in source_models: 617 | for dataset_id in self.datasets: 618 | if dataset_id == source_model.dataset_id: 619 | continue 620 | for tune_ratio in transfer_tune_ratios: 621 | transfer_model = source_model.transfer(dataset_id=dataset_id, tune_ratio=tune_ratio) 622 | transfer_models.append(transfer_model) 623 | yield transfer_model 624 
| 625 | # - M_{i,x}/{quant-qint8/float16} -- Compress M_{i,x} with integer / float16 quantization 626 | for transfer_model in transfer_models: 627 | for quantization_dtype in quantization_dtypes: 628 | yield transfer_model.quantize(dtype=quantization_dtype) 629 | 630 | # - M_{i,x}/{prune-p} -- Prune M_{i,x} with pruning ratio = p 631 | for transfer_model in transfer_models: 632 | for pr in prune_ratios: 633 | yield transfer_model.prune(prune_ratio=pr) 634 | 635 | # - M_{i,x}/{distill} -- Distill M_{i,x} 636 | for transfer_model in transfer_models: 637 | yield transfer_model.distill() 638 | 639 | # - M_{i,x}/{steal-j} -- Steal M_{i,x} to A_j 640 | for transfer_model in transfer_models: 641 | for arch_id in self.archs: 642 | yield transfer_model.steal(arch_id=arch_id) 643 | 644 | # variations of retrained models 645 | # - M_{i,x}/{prune-p} -- Prune M_{i,x} with pruning ratio = p 646 | for retrain_model in retrain_models: 647 | for pr in prune_ratios: 648 | yield retrain_model.prune(prune_ratio=pr) 649 | 650 | # - M_{i,x}/{distill} -- Distill M_{i,x} 651 | for retrain_model in retrain_models: 652 | yield retrain_model.distill() 653 | 654 | # - M_{i,x}/{steal-j} -- Steal M_{i,x} to A_j 655 | for retrain_model in retrain_models: 656 | for arch_id in self.archs: 657 | yield retrain_model.steal(arch_id=arch_id) 658 | 659 | 660 | def parse_args(): 661 | """ 662 | Parse command line input 663 | :return: 664 | """ 665 | parser = argparse.ArgumentParser(description="Build micro benchmark.") 666 | 667 | parser.add_argument("-datasets_dir", action="store", dest="datasets_dir", default='data', 668 | help="Path to the dir of datasets.") 669 | parser.add_argument("-models_dir", action="store", dest="models_dir", default='models', 670 | help="Path to the dir of benchmark models.") 671 | parser.add_argument("-mask", action="store", dest="mask", default="", 672 | help="The mask to filter the models to generate, split with +") 673 | parser.add_argument("-phase", action="store", dest="phase", type=str, default="", 674 | help="The phase to run. 
Use a prefix to filter the phases.") 675 | parser.add_argument("-regenerate", action="store_true", dest="regenerate", default=False, 676 | help="Whether to regenerate the models.") 677 | args, unknown = parser.parse_known_args() 678 | return args 679 | 680 | def check_param_num(model, name): 681 | total = sum([module.weight.nelement() for module in model.modules() if isinstance(module, nn.Conv2d) ]) 682 | num = total 683 | for m in model.modules(): 684 | if ( isinstance(m, nn.Conv2d) ): 685 | num -= int((m.weight.data == 0).sum()) 686 | ratio = (total - num) / total 687 | log = f"===>{name}: Total {total}, current {num}, prune ratio {ratio:2f}" 688 | print(log) 689 | 690 | if __name__ == '__main__': 691 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s") 692 | 693 | seed = 98 694 | torch.backends.cudnn.deterministic = True 695 | torch.backends.cudnn.benchmark = False 696 | torch.manual_seed(seed) 697 | np.random.seed(seed) 698 | random.seed(seed) 699 | 700 | args = parse_args() 701 | bench = ImageBenchmark(datasets_dir=args.datasets_dir, models_dir=args.models_dir) 702 | models_to_gen = [] 703 | mask_substrs = args.mask.strip().split('+') 704 | for model_wrapper in bench.list_models(): 705 | # print(f'loaded model: {model_wrapper}') 706 | model_str_tokens = model_wrapper.__str__().split('-') 707 | if len(model_str_tokens) >= 2 and model_str_tokens[-2].startswith(args.phase): 708 | to_gen = True 709 | model_str = re.sub(r'[^A-Za-z0-9.]+', '_', model_wrapper.__str__()) 710 | for mask_substr in mask_substrs: 711 | if not mask_substr: 712 | continue 713 | if mask_substr not in f'_{model_str}_': 714 | to_gen = False 715 | break 716 | if to_gen: 717 | models_to_gen.append(model_wrapper) 718 | models_to_gen_str = "\n".join([model_wrapper.__str__() for model_wrapper in models_to_gen]) 719 | print(f'{len(models_to_gen)} models to generate: \n{models_to_gen_str}') 720 | for model_wrapper in models_to_gen: 721 | 722 | model_wrapper.gen_model(regenerate=args.regenerate) 723 | 724 | -------------------------------------------------------------------------------- /dataset/caltech256.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import glob 4 | import time 5 | import numpy as np 6 | import random 7 | import os 8 | 9 | 10 | class Caltech257Data(data.Dataset): 11 | def __init__(self, root, is_train=True, transform=None, shots=5, seed=0, preload=False): 12 | self.num_classes = 257 13 | self.transform = transform 14 | self.preload = preload 15 | cls = glob.glob(os.path.join(root, 'Images', '*')) 16 | 17 | self.labels = [] 18 | self.image_path = [] 19 | 20 | test_samples = 20 21 | 22 | #random.seed(int(time.time())) 23 | 24 | for idx, cls_path in enumerate(cls): 25 | cls_label = int(cls_path.split('/')[-1][:3])-1 26 | imgs = glob.glob(os.path.join(cls_path, '*.jpg')) 27 | imgs = np.array(imgs) 28 | indices = np.arange(0, len(imgs)) 29 | random.seed(99+idx) 30 | random.shuffle(indices) 31 | 32 | if is_train: 33 | trainval_ind = indices[:int(shots)] 34 | np.concatenate((trainval_ind, indices[int(shots+test_samples):]), axis=0) 35 | random.seed(seed+idx) 36 | random.shuffle(trainval_ind) 37 | cur_img_paths = imgs[trainval_ind[:int(shots)]] 38 | 39 | else: 40 | cur_img_paths = imgs[indices[shots:shots+test_samples]] 41 | self.image_path.extend(cur_img_paths) 42 | self.labels.extend([cls_label for _ in range(len(cur_img_paths))]) 43 | 44 | 
random.seed(int(time.time())) 45 | 46 | if preload: 47 | self.imgs = {} 48 | for idx, path in enumerate(self.image_path): 49 | if idx % 100 == 0: 50 | print('Loading {}/{}...'.format(idx+1, len(self.image_path))) 51 | img = Image.open(path).convert('RGB') 52 | self.imgs[idx] = img 53 | 54 | def __getitem__(self, index): 55 | if self.preload: 56 | img = self.imgs[index] 57 | else: 58 | img = Image.open(self.image_path[index]).convert('RGB') 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | return img, self.labels[index] 62 | 63 | def __len__(self): 64 | return len(self.labels) 65 | 66 | if __name__ == '__main__': 67 | seed= int(time.time()) 68 | data_train = Caltech257Data('/data/caltech_256', 'train', shots=30, seed=seed) 69 | data_test = Caltech257Data('/data/caltech_256', 'test', shots=30, seed=seed) 70 | for i in data_train.image_path: 71 | if i in data_test.image_path: 72 | print('Test in training...') 73 | print(data_train.image_path[:5]) 74 | print(data_test.image_path[:5]) 75 | -------------------------------------------------------------------------------- /dataset/cub200.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import random 4 | import time 5 | import numpy as np 6 | import os 7 | 8 | class CUB200Data(data.Dataset): 9 | def __init__(self, root, is_train=True, transform=None, shots=-1, seed=0, preload=False): 10 | self.num_classes = 200 11 | self.transform = transform 12 | self.preload = preload 13 | mapfile = os.path.join(root, 'images.txt') 14 | imgset_desc = os.path.join(root, 'train_test_split.txt') 15 | labelfile = os.path.join(root, 'image_class_labels.txt') 16 | 17 | assert os.path.exists(mapfile), 'Mapping txt is missing ({})'.format(mapfile) 18 | assert os.path.exists(imgset_desc), 'Split txt is missing ({})'.format(imgset_desc) 19 | assert os.path.exists(labelfile), 'Label txt is missing ({})'.format(labelfile) 20 | 21 | self.img_ids = [] 22 | max_id = 0 23 | with open(imgset_desc) as f: 24 | for line in f: 25 | i = int(line.split(' ')[0]) 26 | s = int(line.split(' ')[1].strip()) 27 | if s == is_train: 28 | self.img_ids.append(i) 29 | if max_id < i: 30 | max_id = i 31 | 32 | self.id_to_path = {} 33 | with open(mapfile) as f: 34 | for line in f: 35 | i = int(line.split(' ')[0]) 36 | path = line.split(' ')[1].strip() 37 | self.id_to_path[i] = os.path.join(root, 'images', path) 38 | 39 | self.id_to_label = -1*np.ones(max_id+1, dtype=np.int64) # ID starts from 1 40 | with open(labelfile) as f: 41 | for line in f: 42 | i = int(line.split(' ')[0]) 43 | #NOTE: In the network, class start from 0 instead of 1 44 | c = int(line.split(' ')[1].strip())-1 45 | self.id_to_label[i] = c 46 | 47 | if is_train: 48 | self.img_ids = np.array(self.img_ids) 49 | new_img_ids = [] 50 | for c in range(self.num_classes): 51 | ids = np.where(self.id_to_label == c)[0] 52 | random.seed(seed) 53 | random.shuffle(ids) 54 | count = 0 55 | for i in ids: 56 | if i in self.img_ids: 57 | new_img_ids.append(i) 58 | count += 1 59 | if count == shots: 60 | break 61 | self.img_ids = np.array(new_img_ids) 62 | 63 | self.imgs = {} 64 | if preload: 65 | for idx, id in enumerate(self.img_ids): 66 | if idx % 100 == 0: 67 | print('Loading {}/{}...'.format(idx+1, len(self.img_ids))) 68 | img = Image.open(self.id_to_path[id]).convert('RGB') 69 | self.imgs[id] = img 70 | 71 | def __getitem__(self, index): 72 | img_id = self.img_ids[index] 73 | img_label = self.id_to_label[img_id] 74 | 75 | 
if self.preload: 76 | img = self.imgs[img_id] 77 | else: 78 | img = Image.open(self.id_to_path[img_id]).convert('RGB') 79 | 80 | if self.transform is not None: 81 | img = self.transform(img) 82 | 83 | return img, img_label 84 | 85 | def __len__(self): 86 | return len(self.img_ids) 87 | 88 | 89 | if __name__ == '__main__': 90 | seed= int(time.time()) 91 | data_train = CUB200Data('/data/CUB_200_2011', True, shots=10, seed=seed) 92 | print(len(data_train)) 93 | data_test = CUB200Data('/data/CUB_200_2011', False, shots=10, seed=seed) 94 | print(len(data_test)) 95 | for i in data_train.img_ids: 96 | if i in data_test.img_ids: 97 | print('Test in training...') 98 | print('Test PASS!') 99 | print('Train', data_train.img_ids[:5]) 100 | print('Test', data_test.img_ids[:5]) 101 | -------------------------------------------------------------------------------- /dataset/flower102.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import scipy.io as sio 3 | from PIL import Image 4 | import random 5 | import os 6 | import glob 7 | import numpy as np 8 | 9 | 10 | class Flower102(data.Dataset): 11 | def __init__(self, root, is_train=True, transform=None, shots=-1, seed=0, preload=False): 12 | self.preload = preload 13 | self.num_classes = 102 14 | self.transform = transform 15 | imglabel_map = os.path.join(root, 'imagelabels.mat') 16 | setid_map = os.path.join(root, 'setid.mat') 17 | assert os.path.exists(imglabel_map), 'Mapping txt is missing ({})'.format(imglabel_map) 18 | assert os.path.exists(setid_map), 'Mapping txt is missing ({})'.format(setid_map) 19 | 20 | imagelabels = sio.loadmat(imglabel_map)['labels'][0] 21 | setids = sio.loadmat(setid_map) 22 | 23 | if is_train: 24 | ids = np.concatenate([setids['trnid'][0], setids['valid'][0]]) 25 | else: 26 | ids = setids['tstid'][0] 27 | 28 | self.labels = [] 29 | self.image_path = [] 30 | 31 | for i in ids: 32 | # Original label start from 1, we shift it to 0 33 | self.labels.append(int(imagelabels[i-1])-1) 34 | self.image_path.append( os.path.join(root, 'jpg', 'image_{:05d}.jpg'.format(i)) ) 35 | 36 | 37 | self.labels = np.array(self.labels) 38 | 39 | new_img_path = [] 40 | new_img_labels = [] 41 | if is_train: 42 | if shots != -1: 43 | self.image_path = np.array(self.image_path) 44 | for c in range(self.num_classes): 45 | ids = np.where(self.labels == c)[0] 46 | random.seed(seed) 47 | random.shuffle(ids) 48 | count = 0 49 | new_img_path.extend(self.image_path[ids[:shots]]) 50 | new_img_labels.extend([c for i in range(shots)]) 51 | self.image_path = new_img_path 52 | self.labels = new_img_labels 53 | 54 | if self.preload: 55 | self.imgs = {} 56 | for idx in range(len(self.image_path)): 57 | if idx % 100 == 0: 58 | print('Loading {}/{}...'.format(idx+1, len(self.image_path))) 59 | img = Image.open(self.image_path[idx]).convert('RGB') 60 | self.imgs[idx] = img 61 | 62 | def __getitem__(self, index): 63 | if self.preload: 64 | img = self.imgs[index] 65 | else: 66 | img = Image.open(self.image_path[index]).convert('RGB') 67 | 68 | if self.transform is not None: 69 | img = self.transform(img) 70 | 71 | return img, self.labels[index] 72 | 73 | def __len__(self): 74 | return len(self.labels) 75 | 76 | if __name__ == '__main__': 77 | # seed= int(time.time()) 78 | seed= int(98) 79 | data_train = Flower102('/data/Flower_102', True, shots=5, seed=seed) 80 | print(len(data_train)) 81 | data_test = Flower102('/data/Flower_102', False, shots=5, seed=seed) 82 | print(len(data_test)) 83 | for i 
in data_train.image_path: 84 | if i in data_test.image_path: 85 | print('Test in training...') 86 | print('Test PASS!') 87 | print('Train', data_train.image_path[:5]) 88 | print('Test', data_test.image_path[:5]) -------------------------------------------------------------------------------- /dataset/gtsrb.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import glob 4 | import time 5 | import numpy as np 6 | import random 7 | import os 8 | import csv 9 | from pdb import set_trace as st 10 | 11 | classnames = ['Speed limit (20km/h)', 'Speed limit (30km/h)', 'Speed limit (50km/h)', 'Speed limit (60km/h)', 12 | 'Speed limit (70km/h)', 'Speed limit (80km/h)', 'End of speed limit (80km/h)', 'Speed limit (100km/h)', 13 | 'Speed limit (120km/h)', 'No passing', 'No passing for vehicles over 3.5 metric tons', 14 | 'Right-of-way at the next intersection', 'Priority road', 'Yield', 'Stop', 'No vehicles', 15 | 'Vehicles over 3.5 metric tons prohibited', 'No entry', 'General caution', 'Dangerous curve to the left', 16 | 'Dangerous curve to the right', 'Double curve', 'Bumpy road', 'Slippery road', 17 | 'Road narrows on the right', 'Road work', 'Traffic signals', 'Pedestrians', 'Children crossing', 18 | 'Bicycles crossing', 'Beware of ice/snow', 'Wild animals crossing', 'End of all speed and passing limits', 19 | 'Turn right ahead', 'Turn left ahead', 'Ahead only', 'Go straight or right', 'Go straight or left', 20 | 'Keep right', 'Keep left', 'Roundabout mandatory', 'End of no passing', 21 | 'End of no passing by vehicles over 3.5 metric tons'] 22 | 23 | 24 | class GTSRBData(data.Dataset): 25 | def __init__(self, root, is_train=False, transform=None, shots=-1, seed=0, preload=False): 26 | self.num_classes = 43 27 | self.transform = transform 28 | self.preload = preload 29 | self.cls_names = classnames 30 | 31 | self.labels = [] 32 | self.image_path = [] 33 | 34 | if is_train: 35 | for i in range(43): 36 | mapdir = os.path.join(root, 'Final_Training', 'Images', '{:0>5d}'.format(i)) 37 | mapfile = os.path.join(mapdir, 'GT-{:0>5d}.csv'.format(i)) 38 | assert os.path.exists(mapfile), 'Mapping csv is missing ({})'.format(mapfile) 39 | 40 | with open(mapfile, 'r') as f: 41 | reader = csv.reader(f) 42 | first = 1 43 | for line in reader: 44 | if first: 45 | first = 0 46 | continue 47 | self.labels.append(int(line[0].split(';')[-1])) 48 | self.image_path.append(os.path.join(mapdir, line[0].split(';')[0])) 49 | 50 | assert len(self.image_path) == len(self.labels) 51 | 52 | else: 53 | mapdir = os.path.join(root, 'Final_Test', 'Images') 54 | mapfile = os.path.join(mapdir, 'GT-final_test.csv') 55 | assert os.path.exists(mapfile), 'Mapping txt is missing ({})'.format(mapfile) 56 | with open(mapfile, 'r') as f: 57 | reader = csv.reader(f) 58 | first = 1 59 | for line in reader: 60 | if first: 61 | first = 0 62 | continue 63 | self.labels.append(int(line[0].split(';')[-1])) 64 | self.image_path.append(os.path.join(mapdir, line[0].split(';')[0])) 65 | 66 | if is_train: 67 | indices = np.arange(0, len(self.image_path)) 68 | random.seed(seed) 69 | random.shuffle(indices) 70 | self.image_path = np.array(self.image_path)[indices] 71 | self.labels = np.array(self.labels)[indices] 72 | 73 | if shots > 0: 74 | new_img_path = [] 75 | new_labels = [] 76 | for c in range(self.num_classes): 77 | ids = np.where(self.labels == c)[0] 78 | count = 0 79 | for i in ids: 80 | new_img_path.append(self.image_path[i]) 81 | 
new_labels.append(c) 82 | count += 1 83 | if count == shots: 84 | break 85 | self.image_path = np.array(new_img_path) 86 | self.labels = np.array(new_labels) 87 | 88 | self.imgs = [] 89 | if preload: 90 | for idx, p in enumerate(self.image_path): 91 | if idx % 100 == 0: 92 | print('Loading {}/{}...'.format(idx + 1, len(self.image_path))) 93 | self.imgs.append(Image.open(p).convert('RGB')) 94 | 95 | def __getitem__(self, index): 96 | if len(self.imgs) > 0: 97 | img = self.imgs[index] 98 | else: 99 | img = Image.open(self.image_path[index]).convert('RGB') 100 | 101 | if self.transform is not None: 102 | img = self.transform(img) 103 | 104 | return img, self.labels[index] 105 | 106 | def __len__(self): 107 | return len(self.labels) 108 | 109 | 110 | if __name__ == '__main__': 111 | seed = int(98) 112 | data_train = GTSRBData('../data/GTSRB', True, shots=10, seed=seed) 113 | print(len(data_train)) 114 | data_test = GTSRBData('../data/GTSRB', False, shots=10, seed=seed) 115 | print(len(data_test)) 116 | for i in data_train.image_path: 117 | if i in data_test.image_path: 118 | print('Test in training...') 119 | print('Test PASS!') 120 | print('Train', data_train.image_path[:5]) 121 | print('Test', data_test.image_path[:5]) 122 | -------------------------------------------------------------------------------- /dataset/imagenet.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | 5 | 6 | class ImageNet(data.Dataset): 7 | def __init__(self, root, is_train=True, transform=None): 8 | self.transform = transform 9 | self.num_classes = 1000 10 | 11 | if is_train: 12 | mapfile = os.path.join(root, 'train.txt') 13 | imageset = 'train' 14 | else: 15 | mapfile = os.path.join(root, 'val.txt') 16 | imageset = 'val' 17 | assert os.path.exists(mapfile), 'The mapping file does not exist!' 
18 | 19 | self.datapaths = [] 20 | self.labels = [] 21 | with open(mapfile) as f: 22 | for l in f: 23 | self.datapaths.append('{}/{}/{}'.format(root, imageset, l.split(' ')[0].strip())) 24 | self.labels.append(int(l.split(' ')[1].strip())) 25 | 26 | 27 | def __getitem__(self, index): 28 | img_label = self.labels[index] 29 | #img = self.data[index] 30 | img = Image.open(self.datapaths[index]).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, img_label 36 | 37 | def __len__(self): 38 | return len(self.labels) 39 | 40 | -------------------------------------------------------------------------------- /dataset/mit67.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import glob 4 | import time 5 | import numpy as np 6 | import random 7 | import os 8 | from pdb import set_trace as st 9 | 10 | 11 | class MIT67(data.Dataset): 12 | def __init__(self, root, is_train=False, transform=None, shots=-1, seed=0, preload=False): 13 | self.num_classes = 67 14 | self.transform = transform 15 | cls = glob.glob(os.path.join(root, 'Images', '*')) 16 | self.cls_names = [name.split('/')[-1] for name in cls] 17 | 18 | if is_train: 19 | mapfile = os.path.join(root, 'TrainImages.txt') 20 | else: 21 | mapfile = os.path.join(root, 'TestImages.txt') 22 | 23 | assert os.path.exists(mapfile), 'Mapping txt is missing ({})'.format(mapfile) 24 | 25 | self.labels = [] 26 | self.image_path = [] 27 | 28 | with open(mapfile) as f: 29 | for line in f: 30 | self.image_path.append(os.path.join(root, 'Images', line.strip())) 31 | cls = line.split('/')[-2] 32 | self.labels.append(self.cls_names.index(cls)) 33 | 34 | if is_train: 35 | indices = np.arange(0, len(self.image_path)) 36 | random.seed(seed) 37 | random.shuffle(indices) 38 | self.image_path = np.array(self.image_path)[indices] 39 | self.labels = np.array(self.labels)[indices] 40 | 41 | if shots > 0: 42 | new_img_path = [] 43 | new_labels = [] 44 | for c in range(self.num_classes): 45 | ids = np.where(self.labels == c)[0] 46 | count = 0 47 | for i in ids: 48 | new_img_path.append(self.image_path[i]) 49 | new_labels.append(c) 50 | count += 1 51 | if count == shots: 52 | break 53 | self.image_path = np.array(new_img_path) 54 | self.labels = np.array(new_labels) 55 | 56 | self.imgs = [] 57 | if preload: 58 | for idx, p in enumerate(self.image_path): 59 | if idx % 100 == 0: 60 | print('Loading {}/{}...'.format(idx+1, len(self.image_path))) 61 | self.imgs.append(Image.open(p).convert('RGB')) 62 | 63 | def __getitem__(self, index): 64 | if len(self.imgs) > 0: 65 | img = self.imgs[index] 66 | else: 67 | img = Image.open(self.image_path[index]).convert('RGB') 68 | 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | 72 | return img, self.labels[index] 73 | 74 | def __len__(self): 75 | return len(self.labels) 76 | 77 | if __name__ == '__main__': 78 | # seed= int(time.time()) 79 | seed= int(98) 80 | data_train = MIT67('/data/MIT_67', True, shots=10, seed=seed) 81 | print(len(data_train)) 82 | data_test = MIT67('/data/MIT_67', False, shots=10, seed=seed) 83 | print(len(data_test)) 84 | for i in data_train.image_path: 85 | if i in data_test.image_path: 86 | print('Test in training...') 87 | print('Test PASS!') 88 | print('Train', data_train.image_path[:5]) 89 | print('Test', data_test.image_path[:5]) 90 | -------------------------------------------------------------------------------- /dataset/seqimagenet.py: 
-------------------------------------------------------------------------------- 1 | # dataloader respecting the PyTorch conventions, but using tensorpack to load and process 2 | # includes typical augmentations for ImageNet training 3 | 4 | import os 5 | import io 6 | 7 | import cv2 8 | import torch 9 | from PIL import Image 10 | 11 | import numpy as np 12 | import tensorpack.dataflow as td 13 | from tensorpack import imgaug 14 | from tensorpack.dataflow import (AugmentImageComponent, PrefetchDataZMQ, 15 | BatchData, MultiThreadMapData) 16 | 17 | ##################################################################################################### 18 | # copied from: https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/imagenet_utils.py # 19 | ##################################################################################################### 20 | class GoogleNetResize(imgaug.ImageAugmentor): 21 | """ 22 | crop 8%~100% of the original image 23 | See `Going Deeper with Convolutions` by Google. 24 | """ 25 | def __init__(self, crop_area_fraction=0.08, 26 | aspect_ratio_low=0.75, aspect_ratio_high=1.333, 27 | target_shape=224): 28 | self._init(locals()) 29 | 30 | def _augment(self, img, _): 31 | h, w = img.shape[:2] 32 | area = h * w 33 | for _ in range(10): 34 | targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area 35 | aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high) 36 | ww = int(np.sqrt(targetArea * aspectR) + 0.5) 37 | hh = int(np.sqrt(targetArea / aspectR) + 0.5) 38 | if self.rng.uniform() < 0.5: 39 | ww, hh = hh, ww 40 | if hh <= h and ww <= w: 41 | x1 = 0 if w == ww else self.rng.randint(0, w - ww) 42 | y1 = 0 if h == hh else self.rng.randint(0, h - hh) 43 | out = img[y1:y1 + hh, x1:x1 + ww] 44 | out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC) 45 | return out 46 | out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img) 47 | out = imgaug.CenterCrop(self.target_shape).augment(out) 48 | return out 49 | 50 | def fbresnet_augmentor(isTrain): 51 | """ 52 | Augmentor used in fb.resnet.torch, for BGR images input in range [0,255]. 
53 | """ 54 | if isTrain: 55 | augmentors = [ 56 | GoogleNetResize(), 57 | imgaug.RandomOrderAug( 58 | [imgaug.BrightnessScale((0.6, 1.4), clip=False), 59 | imgaug.Contrast((0.6, 1.4), clip=False), 60 | imgaug.Saturation(0.4, rgb=False), 61 | # rgb-bgr conversion for the constants copied from fb.resnet.torch 62 | imgaug.Lighting(0.1, 63 | eigval=np.asarray( 64 | [0.2175, 0.0188, 0.0045][::-1]) * 255.0, 65 | eigvec=np.array( 66 | [[-0.5675, 0.7192, 0.4009], 67 | [-0.5808, -0.0045, -0.8140], 68 | [-0.5836, -0.6948, 0.4203]], 69 | dtype='float32')[::-1, ::-1] 70 | )]), 71 | imgaug.Flip(horiz=True), 72 | ] 73 | else: 74 | augmentors = [ 75 | imgaug.ResizeShortestEdge(256, cv2.INTER_LINEAR), 76 | #imgaug.Resize(256, cv2.INTER_LINEAR), 77 | imgaug.CenterCrop((224, 224)), 78 | ] 79 | return augmentors 80 | ##################################################################################################### 81 | ##################################################################################################### 82 | 83 | 84 | numpy_type_map = { 85 | 'float64': torch.DoubleTensor, 86 | 'float32': torch.FloatTensor, 87 | 'float16': torch.HalfTensor, 88 | 'int64': torch.LongTensor, 89 | 'int32': torch.IntTensor, 90 | 'int16': torch.ShortTensor, 91 | 'int8': torch.CharTensor, 92 | 'uint8': torch.ByteTensor, 93 | } 94 | 95 | 96 | def default_collate(batch): 97 | "Puts each data field into a tensor with outer dimension batch size" 98 | 99 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 100 | elem_type = type(batch[0]) 101 | if torch.is_tensor(batch[0]): 102 | out = None 103 | if _use_shared_memory: 104 | # If we're in a background process, concatenate directly into a 105 | # shared memory tensor to avoid an extra copy 106 | numel = sum([x.numel() for x in batch]) 107 | storage = batch[0].storage()._new_shared(numel) 108 | out = batch[0].new(storage) 109 | return torch.stack(batch, 0, out=out) 110 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 111 | and elem_type.__name__ != 'string_': 112 | elem = batch[0] 113 | if elem_type.__name__ == 'ndarray': 114 | # array of string classes and object 115 | if re.search('[SaUO]', elem.dtype.str) is not None: 116 | raise TypeError(error_msg.format(elem.dtype)) 117 | 118 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 119 | if elem.shape == (): # scalars 120 | py_type = float if elem.dtype.name.startswith('float') else int 121 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 122 | elif isinstance(batch[0], int): 123 | return torch.LongTensor(batch) 124 | elif isinstance(batch[0], float): 125 | return torch.DoubleTensor(batch) 126 | elif isinstance(batch[0], string_classes): 127 | return batch 128 | elif isinstance(batch[0], collections.Mapping): 129 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 130 | elif isinstance(batch[0], collections.Sequence): 131 | transposed = zip(*batch) 132 | return [default_collate(samples) for samples in transposed] 133 | 134 | raise TypeError((error_msg.format(type(batch[0])))) 135 | 136 | 137 | class ImgAugTVCompose(imgaug.ImageAugmentor): 138 | def __init__(self, transform): 139 | self.transform = transform 140 | def _augment(self, img, _): 141 | img = self.transform(Image.fromarray(img)) 142 | return np.asarray(img) 143 | 144 | class SeqImageNetLoader(object): 145 | """ 146 | Data loader. Combines a dataset and a sampler, and provides 147 | single- or multi-process iterators over the dataset. 
148 | 149 | Arguments: 150 | mode (str, required): mode of dataset to operate in, one of ['train', 'val'] 151 | batch_size (int, optional): how many samples per batch to load 152 | (default: 1). 153 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 154 | at every epoch (default: False). 155 | num_workers (int, optional): how many subprocesses to use for data 156 | loading. 0 means that the data will be loaded in the main process 157 | (default: 0) 158 | cache (int, optional): cache size to use when loading data, 159 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 160 | if the dataset size is not divisible by the batch size. If ``False`` and 161 | the size of dataset is not divisible by the batch size, then the last batch 162 | will be smaller. (default: False) 163 | cuda (bool, optional): set to ``True`` and the PyTorch tensors will get preloaded 164 | to the GPU for you (necessary because this lets us to uint8 conversion on the 165 | GPU, which is faster). 166 | """ 167 | 168 | def __init__(self, mode, batch_size=256, shuffle=False, num_workers=25, cache=50000, 169 | collate_fn=default_collate, remainder=False, cuda=False, transform=None): 170 | # enumerate standard imagenet augmentors 171 | #imagenet_augmentors = fbresnet_augmentor(mode == 'train') 172 | imagenet_augmentors = [ImgAugTVCompose(transform)] 173 | 174 | # load the lmdb if we can find it 175 | lmdb_loc = os.path.join(os.environ['IMAGENET'],'ILSVRC-%s.lmdb'%mode) 176 | ds = td.LMDBData(lmdb_loc, shuffle=False) 177 | if mode == 'train': 178 | ds = td.LocallyShuffleData(ds, cache) 179 | ds = td.PrefetchData(ds, 5000, 1) 180 | ds = td.LMDBDataPoint(ds) 181 | #ds = td.MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0) 182 | ds = td.MapDataComponent(ds, lambda x: np.asarray(Image.open(io.BytesIO(x)).convert('RGB')), 0) 183 | ds = td.AugmentImageComponent(ds, imagenet_augmentors) 184 | ds = td.PrefetchDataZMQ(ds, num_workers) 185 | self.ds = td.BatchData(ds, batch_size, remainder=remainder) 186 | self.ds.reset_state() 187 | 188 | self.batch_size = batch_size 189 | self.num_workers = num_workers 190 | self.cuda = cuda 191 | 192 | def __iter__(self): 193 | for x, y in self.ds.get_data(): 194 | if self.cuda: 195 | # images come out as uint8, which are faster to copy onto the gpu 196 | x = torch.ByteTensor(x).cuda() 197 | y = torch.IntTensor(y).cuda() 198 | # but once they're on the gpu, we'll need them in 199 | yield uint8_to_float(x, self.cuda), y.long() 200 | #yield x, y.long() 201 | else: 202 | yield uint8_to_float(torch.ByteTensor(x), self.cuda), torch.IntTensor(y).long() 203 | 204 | def __len__(self): 205 | return self.ds.size() 206 | 207 | def uint8_to_float(x, cuda=False): 208 | rgb = x.float().div(255) 209 | if cuda: 210 | rgb = (rgb - torch.tensor([0.485, 0.456, 0.406]).cuda()) / torch.tensor([0.229, 0.224, 0.225]).cuda() 211 | else: 212 | rgb = (rgb - torch.tensor([0.485, 0.456, 0.406])) / torch.tensor([0.229, 0.224, 0.225]) 213 | rgb = rgb.permute(0,3,1,2) # pytorch is (n,c,h,w) 214 | return rgb 215 | 216 | if __name__ == '__main__': 217 | from tqdm import tqdm 218 | dl = Loader('train', cuda=True) 219 | for x in tqdm(dl, total=len(dl)): 220 | pass 221 | -------------------------------------------------------------------------------- /dataset/stanford_40.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import glob 4 | import time 5 | import numpy as np 
6 | import random 7 | import os 8 | 9 | 10 | class Stanford40Data(data.Dataset): 11 | def __init__(self, root, is_train=False, transform=None, shots=-1, seed=0, preload=False): 12 | self.num_classes = 40 13 | self.transform = transform 14 | first_line = True 15 | self.cls_names = [] 16 | with open(os.path.join(root, 'ImageSplits', 'actions.txt')) as f: 17 | for line in f: 18 | if first_line: 19 | first_line = False 20 | continue 21 | self.cls_names.append(line.split('\t')[0].strip()) 22 | 23 | if is_train: 24 | post = 'train' 25 | else: 26 | post = 'test' 27 | 28 | self.labels = [] 29 | self.image_path = [] 30 | 31 | for label, cls_name in enumerate(self.cls_names): 32 | with open(os.path.join(root, 'ImageSplits', '{}_{}.txt'.format(cls_name, post))) as f: 33 | for line in f: 34 | self.labels.append(label) 35 | self.image_path.append(os.path.join(root, 'JPEGImages', line.strip())) 36 | 37 | 38 | if is_train: 39 | self.labels = np.array(self.labels) 40 | new_image_path = [] 41 | new_labels = [] 42 | for c in range(self.num_classes): 43 | ids = np.where(self.labels == c)[0] 44 | random.seed(seed) 45 | random.shuffle(ids) 46 | count = 0 47 | for i in ids: 48 | new_image_path.append(self.image_path[i]) 49 | new_labels.append(self.labels[i]) 50 | count += 1 51 | if count == shots: 52 | break 53 | self.labels = new_labels 54 | self.image_path = new_image_path 55 | 56 | self.imgs = [] 57 | if preload: 58 | for idx, p in enumerate(self.image_path): 59 | if idx % 100 == 0: 60 | print('Loading {}/{}...'.format(idx+1, len(self.image_path))) 61 | self.imgs.append(Image.open(p).convert('RGB')) 62 | 63 | def __getitem__(self, index): 64 | if len(self.imgs) > 0: 65 | img = self.imgs[index] 66 | else: 67 | img = Image.open(self.image_path[index]).convert('RGB') 68 | 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | 72 | return img, self.labels[index] 73 | 74 | def __len__(self): 75 | return len(self.labels) 76 | 77 | if __name__ == '__main__': 78 | seed= int(98) 79 | data_train = Stanford40Data('/data/stanford_40', True, shots=10, seed=seed) 80 | print(len(data_train)) 81 | data_test = Stanford40Data('/data/stanford_40', False, shots=10, seed=seed) 82 | print(len(data_test)) 83 | for i in data_train.image_path: 84 | if i in data_test.image_path: 85 | print('Test in training...') 86 | print('Test PASS!') 87 | print('Train', data_train.image_path[:5]) 88 | print('Test', data_test.image_path[:5]) 89 | -------------------------------------------------------------------------------- /dataset/stanford_dog.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import glob 4 | import time 5 | import numpy as np 6 | import random 7 | import os 8 | import scipy.io as sio 9 | 10 | 11 | class SDog120(data.Dataset): 12 | def __init__(self, root, is_train=True, transform=None, shots=5, seed=0, preload=False): 13 | self.num_classes = 120 14 | self.transform = transform 15 | self.preload=preload 16 | 17 | if is_train: 18 | mapfile = os.path.join(root, 'train_list.mat') 19 | else: 20 | mapfile = os.path.join(root, 'test_list.mat') 21 | assert os.path.exists(mapfile), 'Mapping txt is missing ({})'.format(mapfile) 22 | dset_list = sio.loadmat(mapfile) 23 | 24 | self.labels = [] 25 | self.image_path = [] 26 | 27 | for idx, f in enumerate(dset_list['file_list']): 28 | self.image_path.append(os.path.join(root, 'Images', f[0][0])) 29 | # Stanford Dog starts 1 30 | 
self.labels.append(dset_list['labels'][idx][0]-1) 31 | 32 | if is_train: 33 | self.image_path = np.array(self.image_path) 34 | self.labels = np.array(self.labels) 35 | 36 | if shots > 0: 37 | new_img_path = [] 38 | new_labels = [] 39 | for c in range(self.num_classes): 40 | ids = np.where(self.labels == c)[0] 41 | random.seed(seed) 42 | random.shuffle(ids) 43 | count = 0 44 | for i in ids: 45 | new_img_path.append(self.image_path[i]) 46 | new_labels.append(c) 47 | count += 1 48 | if count == shots: 49 | break 50 | self.image_path = np.array(new_img_path) 51 | self.labels = np.array(new_labels) 52 | 53 | self.imgs = {} 54 | if preload: 55 | self.imgs = {} 56 | for idx, path in enumerate(self.image_path): 57 | if idx % 100 == 0: 58 | print('Loading {}/{}...'.format(idx+1, len(self.image_path))) 59 | img = Image.open(path).convert('RGB') 60 | self.imgs[idx] = img 61 | 62 | def __getitem__(self, index): 63 | if self.preload: 64 | img = self.imgs[index] 65 | else: 66 | img = Image.open(self.image_path[index]).convert('RGB') 67 | if self.transform is not None: 68 | img = self.transform(img) 69 | 70 | return img, self.labels[index] 71 | 72 | def __len__(self): 73 | return len(self.labels) 74 | 75 | 76 | if __name__ == '__main__': 77 | seed= int(time.time()) 78 | data_train = SDog120('/data/stanford_dog', True, shots=10, seed=seed) 79 | print(len(data_train)) 80 | data_test = SDog120('/data/stanford_dog', False, shots=10, seed=seed) 81 | print(len(data_test)) 82 | for i in data_train.image_path: 83 | if i in data_test.image_path: 84 | print('Test in training...') 85 | print('Test PASS!') 86 | print('Train', data_train.image_path[:5]) 87 | print('Test', data_test.image_path[:5]) 88 | -------------------------------------------------------------------------------- /dataset/vis_da.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import random 4 | import time 5 | import numpy as np 6 | import os 7 | import os.path as osp 8 | from pdb import set_trace as st 9 | 10 | class VisDaDATA(data.Dataset): 11 | def __init__(self, root, is_train=True, transform=None, shots=-1, seed=0, preload=False): 12 | self.transform = transform 13 | self.num_classes = 12 14 | train_dir = osp.join(root, "train") 15 | 16 | mapfile = os.path.join(train_dir, 'image_list.txt') 17 | self.train_data_list = [] 18 | self.test_data_list = [] 19 | with open(mapfile) as f: 20 | for i, line in enumerate(f): 21 | path, class_idx = line.split() 22 | class_idx = int(class_idx) 23 | path = osp.join(train_dir, path) 24 | if i%10 == 0: 25 | self.test_data_list.append((path, class_idx)) 26 | else: 27 | self.train_data_list.append((path, class_idx)) 28 | 29 | if is_train: 30 | self.data_list = self.train_data_list 31 | if shots > 0: 32 | random.shuffle(self.train_data_list) 33 | self.data_list = self.train_data_list[:shots] 34 | else: 35 | self.data_list = self.test_data_list 36 | 37 | def __len__(self): 38 | return len(self.data_list) 39 | 40 | def __getitem__(self, index): 41 | path, label = self.data_list[index] 42 | 43 | img = Image.open(path).convert('RGB') 44 | 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | 48 | return img, label 49 | 50 | -------------------------------------------------------------------------------- /evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 
| "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torchvision\n", 11 | "import numpy as np\n", 12 | "import logging\n", 13 | "import random\n", 14 | "from scipy import spatial\n", 15 | "from utils import Utils\n", 16 | "import os\n", 17 | "\n", 18 | "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s %(name)-12s %(levelname)-8s %(message)s\")\n", 19 | "\n", 20 | "from benchmark import ImageBenchmark\n", 21 | "bench = ImageBenchmark()\n", 22 | "models = list(bench.list_models())\n", 23 | "models_dict = {}\n", 24 | "for i, model in enumerate(models):\n", 25 | " if not model.torch_model_exists():\n", 26 | " continue\n", 27 | " print(f'{i}\\t {model.__str__()}')\n", 28 | " models_dict[model.__str__()] = model" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "DEVICE = 'cuda'\n", 38 | "\n", 39 | "def gen_adv_inputs(model, inputs):\n", 40 | " from advertorch.attacks import LinfPGDAttack\n", 41 | " def myloss(yhat, y):\n", 42 | " return -((yhat[:,0]-y[:,0])**2 + 0.1*((yhat[:,1:]-y[:,1:])**2).mean(1)).mean()\n", 43 | " \n", 44 | " model = model.to(DEVICE)\n", 45 | " inputs = torch.from_numpy(inputs).to(DEVICE)\n", 46 | " with torch.no_grad():\n", 47 | " model.eval()\n", 48 | " clean_outputs = model(inputs)\n", 49 | " \n", 50 | " output_shape = clean_outputs.shape\n", 51 | " batch_size = output_shape[0]\n", 52 | " num_classes = output_shape[1]\n", 53 | " \n", 54 | " output_mean = clean_outputs.mean(axis=0)\n", 55 | " target_outputs = output_mean - clean_outputs\n", 56 | " \n", 57 | " y = torch.zeros(size=output_shape).to(DEVICE)\n", 58 | " y[:, :] = 100000\n", 59 | " # more diversity\n", 60 | " y = target_outputs * 1000\n", 61 | "# rand_idx = torch.randint(low=0, high=num_classes, size=(batch_size,))\n", 62 | "# y = torch.nn.functional.one_hot(rand_idx, num_classes=num_classes).to(DEVICE) * 10\n", 63 | "# print(y)\n", 64 | " \n", 65 | " adversary = LinfPGDAttack(\n", 66 | " model, loss_fn=myloss, eps=0.1,\n", 67 | " nb_iter=50, eps_iter=0.01, \n", 68 | " rand_init=True, clip_min=inputs.min().item(), clip_max=inputs.max().item(),\n", 69 | " targeted=True\n", 70 | " )\n", 71 | " \n", 72 | " adv_inputs = adversary.perturb(inputs, y)\n", 73 | " \n", 74 | " with torch.no_grad():\n", 75 | " model.eval()\n", 76 | " adv_outputs = model(adv_inputs).to('cpu').numpy()\n", 77 | "# print(adv_outputs)\n", 78 | " torch.cuda.empty_cache()\n", 79 | " return adv_inputs.to('cpu').numpy()\n", 80 | "\n", 81 | "\n", 82 | "model = models_dict['pretrain(mbnetv2,ImageNet)-transfer(Flower102,0.1)-prune(0.2)-']\n", 83 | "model.torch_model.to(DEVICE)\n", 84 | "seed_inputs = model.get_seed_inputs(100, rand=False)\n", 85 | "seed_outputs = model.batch_forward(seed_inputs)\n", 86 | "_, seed_preds = seed_outputs.to('cpu').data.max(1)\n", 87 | "\n", 88 | "from datetime import datetime\n", 89 | "start_time = datetime.now()\n", 90 | "adv_inputs = gen_adv_inputs(model.torch_model, seed_inputs)\n", 91 | "adv_outputs = model.batch_forward(adv_inputs)\n", 92 | "_, adv_preds = adv_outputs.to('cpu').data.max(1)\n", 93 | "model.torch_model.cpu()\n", 94 | "\n", 95 | "time_spent = (datetime.now() - start_time).total_seconds()\n", 96 | "print(f'spent {time_spent} seconds')\n", 97 | "\n", 98 | "print(f\"seed_preds={seed_preds}, adv_preds={adv_preds}, seed_outputs={seed_outputs}, adv_outputs={adv_outputs}\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "scrolled": 
true 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "DEVICE = 'cuda'\n", 110 | "\n", 111 | "\n", 112 | "class bcolors:\n", 113 | " HEADER = '\\033[95m'\n", 114 | " OKBLUE = '\\033[94m'\n", 115 | " OKGREEN = '\\033[92m'\n", 116 | " WARNING = '\\033[93m'\n", 117 | " FAIL = '\\033[91m'\n", 118 | " ENDC = '\\033[0m'\n", 119 | " BOLD = '\\033[1m'\n", 120 | " UNDERLINE = '\\033[4m'\n", 121 | " # Background colors:\n", 122 | " GREYBG = '\\033[100m'\n", 123 | " REDBG = '\\033[101m'\n", 124 | " GREENBG = '\\033[102m'\n", 125 | " YELLOWBG = '\\033[103m'\n", 126 | " BLUEBG = '\\033[104m'\n", 127 | " PINKBG = '\\033[105m'\n", 128 | " CYANBG = '\\033[106m'\n", 129 | "\n", 130 | "\n", 131 | "def gen_adv_inputs(model, inputs):\n", 132 | " from advertorch.attacks import LinfPGDAttack\n", 133 | " def myloss(yhat, y):\n", 134 | " return -((yhat[:,0]-y[:,0])**2 + 0.1*((yhat[:,1:]-y[:,1:])**2).mean(1)).mean()\n", 135 | " \n", 136 | " model = model.to(DEVICE)\n", 137 | " inputs = torch.from_numpy(inputs).to(DEVICE)\n", 138 | " with torch.no_grad():\n", 139 | " model.eval()\n", 140 | " clean_outputs = model(inputs)\n", 141 | " \n", 142 | " output_shape = clean_outputs.shape\n", 143 | " batch_size = output_shape[0]\n", 144 | " num_classes = output_shape[1]\n", 145 | " \n", 146 | " output_mean = clean_outputs.mean(axis=0)\n", 147 | " target_outputs = output_mean - clean_outputs\n", 148 | " \n", 149 | "# # No diversity, high divergence\n", 150 | "# y = torch.zeros(size=output_shape).to(DEVICE)\n", 151 | "# y[:, 0] = 100000 # Low diversity, high divergence\n", 152 | "# y[:] = output_mean # Low diversity, low divergence\n", 153 | " \n", 154 | "# y = target_outputs\n", 155 | "# y = target_outputs * 0.1 # High diversity, low divergence \n", 156 | " y = target_outputs * 1000 # High diversity, high divergence\n", 157 | " \n", 158 | " adversary = LinfPGDAttack(\n", 159 | " model, loss_fn=myloss, eps=0.06,\n", 160 | " nb_iter=50, eps_iter=0.01, \n", 161 | " rand_init=True, clip_min=inputs.min().item(), clip_max=inputs.max().item(),\n", 162 | " targeted=True\n", 163 | " )\n", 164 | " \n", 165 | " adv_inputs = adversary.perturb(inputs, y)\n", 166 | " \n", 167 | " with torch.no_grad():\n", 168 | " model.eval()\n", 169 | " adv_outputs = model(adv_inputs).to('cpu').numpy()\n", 170 | "# print(adv_outputs)\n", 171 | " torch.cuda.empty_cache()\n", 172 | " return adv_inputs.to('cpu').numpy()\n", 173 | "\n", 174 | "\n", 175 | "def get_comparable_models(target_model):\n", 176 | " target_model_name = target_model.__str__()\n", 177 | " target_model_segs = target_model_name.split('-')\n", 178 | " parent_model_name = '-'.join(target_model_segs[:-2]) + '-'\n", 179 | " parent_model = models_dict[parent_model_name]\n", 180 | " # print(f'parent_model: {parent_model}')\n", 181 | " reference_models = []\n", 182 | " for model in models:\n", 183 | " if not model.__str__().startswith(target_model_segs[0]):\n", 184 | " reference_models.append(model)\n", 185 | " # print(f'reference_model: {model}')\n", 186 | " return parent_model, reference_models\n", 187 | "\n", 188 | "\n", 189 | "def compute_ddv(model, normal_inputs, adv_inputs):\n", 190 | " normal_outputs = model.batch_forward(normal_inputs).cpu().numpy()\n", 191 | " adv_outputs = model.batch_forward(adv_inputs).cpu().numpy()\n", 192 | " output_pairs = zip(normal_outputs, adv_outputs)\n", 193 | " # print(list(output_pairs)[0])\n", 194 | " ddv = [] # DDV is short for decision distance vector\n", 195 | " for i, (ya, yb) in enumerate(output_pairs):\n", 196 | " dist = 
spatial.distance.cosine(ya, yb)\n", 197 | " ddv.append(dist)\n", 198 | " ddv = Utils.normalize(np.array(ddv))\n", 199 | " return ddv\n", 200 | "\n", 201 | "\n", 202 | "def load_inputs(model):\n", 203 | " inputs_path = os.path.join(model.torch_model_path, 'inputs.npz')\n", 204 | " npzfile = np.load(inputs_path, allow_pickle=True)\n", 205 | " seed_inputs = npzfile['seed_inputs']\n", 206 | " adv_inputs = npzfile['adv_inputs']\n", 207 | " saved_inputs = npzfile['saved_inputs'].item()\n", 208 | " # print(saved_inputs)\n", 209 | " return seed_inputs, adv_inputs\n", 210 | " \n", 211 | "\n", 212 | "skip_steal_homo = True\n", 213 | "skip_quant_float = True\n", 214 | "\n", 215 | "for i, model in enumerate(models):\n", 216 | " if not model.torch_model_exists():\n", 217 | " continue\n", 218 | " model_name = model.__str__()\n", 219 | " if not model_name.startswith('pretrain'):\n", 220 | " continue\n", 221 | " if len(model_name.split('-')) < 3:\n", 222 | " continue\n", 223 | " if i < 6:\n", 224 | " continue\n", 225 | " if 'quantize(float16)' in model_name and skip_quant_float:\n", 226 | " continue\n", 227 | " if 'steal' in model_name and skip_steal_homo:\n", 228 | " arch1 = model_name[model_name.find('(')+1:model_name.find(',')]\n", 229 | " arch2 = model_name[model_name.rfind('(')+1:model_name.rfind(')')]\n", 230 | " if arch1 == arch2:\n", 231 | " continue\n", 232 | "\n", 233 | " \n", 234 | " parent_model, ref_models = get_comparable_models(model)\n", 235 | " print(f'{i}\\t {model_name} testing')\n", 236 | " \n", 237 | " source_model = model\n", 238 | " if 'quantize' in model_name:\n", 239 | " model.torch_model.cpu()\n", 240 | " source_model = parent_model\n", 241 | " \n", 242 | " seed_inputs = model.get_seed_inputs(100, rand=False)\n", 243 | "# seed_inputs = np.array([seed_inputs[0]] * 100) # remove diversity of inputs\n", 244 | " adv_inputs = gen_adv_inputs(source_model.torch_model, seed_inputs)\n", 245 | " \n", 246 | "# adv_inputs = model.get_seed_inputs(100, rand=False) # all normal inputs\n", 247 | "\n", 248 | " # all adversarial inputs\n", 249 | "# adv_inputs2 = gen_adv_inputs(source_model.torch_model, seed_inputs)\n", 250 | "# seed_inputs = adv_inputs2\n", 251 | "\n", 252 | "# seed_inputs, adv_inputs = load_inputs(model) # load saved inputs\n", 253 | " source_model.torch_model.to(DEVICE)\n", 254 | " ddv = compute_ddv(model, seed_inputs, adv_inputs)\n", 255 | " # print(f'self_sim: {spatial.distance.cosine(ddv, ddv):.4f}')\n", 256 | " \n", 257 | " parent_sim = 0\n", 258 | " gap_min = 100\n", 259 | " for i, ref_model in enumerate([parent_model] + ref_models):\n", 260 | " if 'quantize' in ref_model.__str__(): # quantized models are equivalent to its teacher model\n", 261 | " continue\n", 262 | " try:\n", 263 | " ref_model.torch_model.to(DEVICE)\n", 264 | " ref_ddv = compute_ddv(ref_model, seed_inputs, adv_inputs)\n", 265 | " ref_sim = spatial.distance.cosine(ddv, ref_ddv)\n", 266 | " ref_model.torch_model.cpu()\n", 267 | " if i == 0:\n", 268 | " parent_sim = ref_sim\n", 269 | " gap = 1\n", 270 | " print(f'parent_sim: {ref_sim:.4f} {ref_model}')\n", 271 | " else:\n", 272 | " gap = ref_sim - parent_sim\n", 273 | " if gap > 0:\n", 274 | " print(f'ref_sim: {ref_sim:.4f} gap={gap:.4f} {ref_model}')\n", 275 | " else:\n", 276 | " print(f'{bcolors.WARNING}[ERROR] ref_sim: {ref_sim:.4f} gap={gap:.4f} {ref_model}{bcolors.ENDC}')\n", 277 | " if gap < gap_min:\n", 278 | " gap_min = gap\n", 279 | " except Exception as e:\n", 280 | " print(f'failed to compare: {ref_model}')\n", 281 | " print(f'exception: {e}')\n", 
282 | " print(f'--gap_min:{gap_min:.4f}, parent_sim:{parent_sim:.4f}, correct:{gap_min>0}, model:{model_name}')\n", 283 | " # break\n", 284 | " source_model.torch_model.cpu()" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "x = np.random.randn(10, 3, 3)\n", 294 | "y = np.array([x[0]] * 10)\n", 295 | "print(x, y)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [] 304 | } 305 | ], 306 | "metadata": { 307 | "kernelspec": { 308 | "display_name": "Python 3", 309 | "language": "python", 310 | "name": "python3" 311 | }, 312 | "language_info": { 313 | "codemirror_mode": { 314 | "name": "ipython", 315 | "version": 3 316 | }, 317 | "file_extension": ".py", 318 | "mimetype": "text/x-python", 319 | "name": "python", 320 | "nbconvert_exporter": "python", 321 | "pygments_lexer": "ipython3", 322 | "version": "3.7.7" 323 | } 324 | }, 325 | "nbformat": 4, 326 | "nbformat_minor": 4 327 | } -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | import argparse 6 | from pdb import set_trace as st 7 | import json 8 | import random 9 | 10 | import torch 11 | import numpy as np 12 | import torchvision 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | 17 | from torchvision import transforms 18 | 19 | from dataset.cub200 import CUB200Data 20 | from dataset.mit67 import MIT67 21 | from dataset.stanford_dog import SDog120 22 | from dataset.caltech256 import Caltech257Data 23 | from dataset.stanford_40 import Stanford40Data 24 | from dataset.flower102 import Flower102 25 | 26 | from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout 27 | from model.fe_mobilenet import mbnetv2_dropout 28 | from model.fe_resnet import feresnet18, feresnet50, feresnet101 29 | from model.fe_mobilenet import fembnetv2 30 | from model.fe_vgg16 import * 31 | 32 | 33 | from utils import * 34 | from finetuner import Finetuner 35 | from weight_pruner import WeightPruner 36 | 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--datapath", type=str, default='/data', help='path to the dataset') 41 | parser.add_argument("--dataset", type=str, default='CUB200Data', help='Target dataset. 
Currently support: \{SDog120Data, CUB200Data, Stanford40Data, MIT67Data, Flower102Data\}') 42 | parser.add_argument("--iterations", type=int, default=30000, help='Iterations to train') 43 | parser.add_argument("--print_freq", type=int, default=100, help='Frequency of printing training logs') 44 | parser.add_argument("--test_interval", type=int, default=1000, help='Frequency of testing') 45 | parser.add_argument("--adv_test_interval", type=int, default=1000) 46 | parser.add_argument("--name", type=str, default='test', help='Name for the checkpoint') 47 | parser.add_argument("--batch_size", type=int, default=64) 48 | parser.add_argument("--lr", type=float, default=1e-2) 49 | parser.add_argument("--const_lr", action='store_true', default=False, help='Use constant learning rate') 50 | parser.add_argument("--weight_decay", type=float, default=0) 51 | parser.add_argument("--momentum", type=float, default=0.9) 52 | parser.add_argument("--beta", type=float, default=1e-2, help='The strength of the L2 regularization on the last linear layer') 53 | parser.add_argument("--dropout", type=float, default=0, help='Dropout rate for spatial dropout') 54 | parser.add_argument("--l2sp_lmda", type=float, default=0) 55 | parser.add_argument("--feat_lmda", type=float, default=0) 56 | parser.add_argument("--feat_layers", type=str, default='1234', help='Used for DELTA (which layers or stages to match), ResNets should be 1234 and MobileNetV2 should be 12345') 57 | parser.add_argument("--reinit", action='store_true', default=False, help='Reinitialize before training') 58 | parser.add_argument("--no_save", action='store_true', default=False, help='Do not save checkpoints') 59 | parser.add_argument("--swa", action='store_true', default=False, help='Use SWA') 60 | parser.add_argument("--swa_freq", type=int, default=500, help='Frequency of averaging models in SWA') 61 | parser.add_argument("--swa_start", type=int, default=0, help='Start SWA since which iterations') 62 | parser.add_argument("--label_smoothing", type=float, default=0) 63 | parser.add_argument("--checkpoint", type=str, default='', help='Load a previously trained checkpoint') 64 | parser.add_argument("--teacher_ckpt", type=str, default='') 65 | parser.add_argument("--network", type=str, default='resnet18', help='Network architecture. Currently support: \{resnet18, resnet50, resnet101, mbnetv2\}') 66 | parser.add_argument("--shot", type=int, default=-1, help='Number of training samples per class for the training dataset. 
-1 indicates using the full dataset.') 67 | parser.add_argument("--log", action='store_true', default=False, help='Redirect the output to log/args.name.log') 68 | parser.add_argument("--output_dir", default="results") 69 | parser.add_argument("--B", type=float, default=0.1, help='Attack budget') 70 | parser.add_argument("--m", type=float, default=1000, help='Hyper-parameter for task-agnostic attack') 71 | parser.add_argument("--pgd_iter", type=int, default=40) 72 | parser.add_argument("--method", default=None, 73 | choices=[None, "weight"] 74 | ) 75 | parser.add_argument("--train_all", default=False, action="store_true") 76 | parser.add_argument("--ft_begin_module", default=None) 77 | parser.add_argument("--vgg_output_distill", default=False, action="store_true") 78 | # Weight prune 79 | parser.add_argument("--weight_ratio", default=-1, type=float) 80 | args = parser.parse_args() 81 | 82 | if args.feat_lmda > 0: 83 | assert args.teacher_ckpt is not None 84 | args.family_output_dir = args.output_dir 85 | args.output_dir = osp.join( 86 | args.output_dir, 87 | args.name 88 | ) 89 | if not os.path.exists(args.output_dir): 90 | os.makedirs(args.output_dir) 91 | params_out_path = osp.join(args.output_dir, 'params.json') 92 | with open(params_out_path, 'w') as jf: 93 | json.dump(vars(args), jf, indent=True) 94 | print(args) 95 | 96 | return args 97 | 98 | if __name__=="__main__": 99 | seed = 98 100 | torch.backends.cudnn.deterministic = True 101 | torch.backends.cudnn.benchmark = False 102 | torch.manual_seed(seed) 103 | np.random.seed(seed) 104 | random.seed(seed) 105 | 106 | args = get_args() 107 | 108 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 109 | std=[0.229, 0.224, 0.225]) 110 | # Used to make sure we sample the same image for few-shot scenarios 111 | seed = 98 112 | 113 | train_set = eval(args.dataset)( 114 | args.datapath, True, transforms.Compose([ 115 | transforms.RandomResizedCrop(224), 116 | transforms.RandomHorizontalFlip(), 117 | transforms.ToTensor(), 118 | normalize, 119 | ]), 120 | args.shot, seed, preload=False 121 | ) 122 | test_set = eval(args.dataset)( 123 | args.datapath, False, transforms.Compose([ 124 | transforms.Resize(256), 125 | transforms.CenterCrop(224), 126 | transforms.ToTensor(), 127 | normalize, 128 | ]), 129 | args.shot, seed, preload=False 130 | ) 131 | 132 | train_loader = torch.utils.data.DataLoader( 133 | train_set, 134 | batch_size=args.batch_size, shuffle=True, 135 | num_workers=8, pin_memory=False 136 | ) 137 | test_loader = torch.utils.data.DataLoader( 138 | test_set, 139 | batch_size=args.batch_size, shuffle=False, 140 | num_workers=8, pin_memory=False 141 | ) 142 | 143 | model = eval('{}_dropout'.format(args.network))( 144 | pretrained=True, 145 | dropout=args.dropout, 146 | num_classes=train_loader.dataset.num_classes 147 | ).cuda() 148 | if args.checkpoint != '': 149 | checkpoint = torch.load(args.checkpoint) 150 | model.load_state_dict(checkpoint['state_dict']) 151 | print(f"Loaded checkpoint from {args.checkpoint}") 152 | # Pre-trained model 153 | teacher = eval('{}_dropout'.format(args.network))( 154 | pretrained=True, 155 | dropout=0, 156 | num_classes=train_loader.dataset.num_classes 157 | ).cuda() 158 | if args.teacher_ckpt is not "": 159 | checkpoint = torch.load(args.teacher_ckpt) 160 | teacher.load_state_dict(checkpoint['state_dict']) 161 | print(f"Loaded teacher_ckpt from {args.teacher_ckpt}") 162 | 163 | if args.reinit: 164 | for m in model.modules(): 165 | if type(m) in [nn.Linear, nn.BatchNorm2d, nn.Conv2d]: 166 | 
m.reset_parameters() 167 | 168 | if args.method is None: 169 | finetune_machine = Finetuner( 170 | args, 171 | model, teacher, 172 | train_loader, test_loader, 173 | ) 174 | elif args.method == "weight": 175 | finetune_machine = WeightPruner( 176 | args, 177 | model, teacher, 178 | train_loader, test_loader, 179 | ) 180 | else: 181 | raise RuntimeError 182 | 183 | 184 | finetune_machine.train() 185 | # finetune_machine.adv_eval() 186 | 187 | if args.method is not None: 188 | finetune_machine.final_check_param_num() 189 | 190 | 191 | -------------------------------------------------------------------------------- /finetuner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | import argparse 6 | from pdb import set_trace as st 7 | import json 8 | import random 9 | from functools import partial 10 | 11 | import torch 12 | import numpy as np 13 | import torchvision 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | from torchvision import transforms 19 | 20 | from dataset.cub200 import CUB200Data 21 | from dataset.mit67 import MIT67 22 | from dataset.stanford_dog import SDog120 23 | from dataset.caltech256 import Caltech257Data 24 | from dataset.stanford_40 import Stanford40Data 25 | from dataset.flower102 import Flower102 26 | 27 | from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout 28 | from model.fe_mobilenet import mbnetv2_dropout 29 | from model.fe_resnet import feresnet18, feresnet50, feresnet101 30 | from model.fe_mobilenet import fembnetv2 31 | from model.fe_vgg16 import * 32 | 33 | from utils import * 34 | 35 | 36 | class Finetuner(object): 37 | def __init__( 38 | self, 39 | args, 40 | model, 41 | teacher, 42 | train_loader, 43 | test_loader, 44 | ): 45 | self.args = args 46 | self.model = model.to('cuda') 47 | self.teacher = teacher.to('cuda') 48 | self.train_loader = train_loader 49 | self.test_loader = test_loader 50 | 51 | self.init_models() 52 | 53 | def init_models(self): 54 | args = self.args 55 | model = self.model 56 | teacher = self.teacher 57 | 58 | # Used to matching features 59 | def record_act(self, input, output): 60 | self.out = output 61 | 62 | if 'mbnetv2' in args.network: 63 | reg_layers = {0: [model.layer1], 1: [model.layer2], 2: [model.layer3], 3: [model.layer4]} 64 | model.layer1.register_forward_hook(record_act) 65 | model.layer2.register_forward_hook(record_act) 66 | model.layer3.register_forward_hook(record_act) 67 | model.layer4.register_forward_hook(record_act) 68 | if '5' in args.feat_layers: 69 | reg_layers[4] = [model.layer5] 70 | model.layer5.register_forward_hook(record_act) 71 | elif 'resnet' in args.network: 72 | reg_layers = {0: [model.layer1], 1: [model.layer2], 2: [model.layer3], 3: [model.layer4]} 73 | model.layer1.register_forward_hook(record_act) 74 | model.layer2.register_forward_hook(record_act) 75 | model.layer3.register_forward_hook(record_act) 76 | model.layer4.register_forward_hook(record_act) 77 | elif 'vgg' in args.network: 78 | cnt = 0 79 | reg_layers = {} 80 | for name, module in model.named_modules(): 81 | if isinstance(module, nn.MaxPool2d) : 82 | reg_layers[name] = [module] 83 | module.register_forward_hook(record_act) 84 | print(name, module) 85 | 86 | # Stored pre-trained weights for computing L2SP 87 | for m in model.modules(): 88 | if hasattr(m, 'weight') and not hasattr(m, 'old_weight'): 89 | m.old_weight = m.weight.data.clone().detach() 
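# (Explanatory note, added for exposition.) This old_weight snapshot, together with
# the old_bias snapshot stored just below, is what the L2SP regularizer compares
# against later. Conceptually,
#     L2SP(model) = l2sp_lmda * sum_m ||m.weight - m.old_weight||_2^2   (+ bias terms)
# i.e. fine-tuned parameters are pulled back toward their pre-trained starting
# point. The exact weighting lives in utils.l2sp(); the formula above is only a sketch.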
90 | # all_weights = torch.cat([all_weights.reshape(-1), m.weight.data.abs().reshape(-1)], dim=0) 91 | if hasattr(m, 'bias') and not hasattr(m, 'old_bias') and m.bias is not None: 92 | m.old_bias = m.bias.data.clone().detach() 93 | 94 | if args.reinit: 95 | for m in model.modules(): 96 | if type(m) in [nn.Linear, nn.BatchNorm2d, nn.Conv2d]: 97 | m.reset_parameters() 98 | 99 | if 'vgg' not in args.network: 100 | reg_layers[0].append(teacher.layer1) 101 | teacher.layer1.register_forward_hook(record_act) 102 | reg_layers[1].append(teacher.layer2) 103 | teacher.layer2.register_forward_hook(record_act) 104 | reg_layers[2].append(teacher.layer3) 105 | teacher.layer3.register_forward_hook(record_act) 106 | reg_layers[3].append(teacher.layer4) 107 | teacher.layer4.register_forward_hook(record_act) 108 | 109 | if '5' in args.feat_layers: 110 | reg_layers[4].append(teacher.layer5) 111 | teacher.layer5.register_forward_hook(record_act) 112 | else: 113 | cnt = 0 114 | for name, module in teacher.named_modules(): 115 | if isinstance(module, nn.MaxPool2d) : 116 | reg_layers[name].append(module) 117 | module.register_forward_hook(record_act) 118 | # print(name, module) 119 | 120 | self.reg_layers = reg_layers 121 | # Check self.model 122 | # st() 123 | 124 | def compute_steal_loss(self, batch, label): 125 | def CXE(predicted, target): 126 | return -(target * torch.log(predicted)).sum(dim=1).mean() 127 | model = self.model 128 | teacher = self.teacher 129 | alpha = self.args.steal_alpha 130 | T = self.args.temperature 131 | 132 | teacher_out = teacher(batch) 133 | out = model(batch) 134 | _, pred = out.max(dim=1) 135 | 136 | # _, teacher_pred = teacher_out.max(dim=1) 137 | # KD_loss = F.cross_entropy(out, teacher_pred) 138 | # soft_loss, hard_loss = 0, 0 139 | 140 | out = F.softmax(out) 141 | teacher_out = F.softmax(teacher_out) 142 | KD_loss = CXE(out, teacher_out) 143 | soft_loss, hard_loss = 0, 0 144 | 145 | # soft_loss = nn.KLDivLoss()( 146 | # F.log_softmax(out/T, dim=1), 147 | # F.softmax(teacher_out/T, dim=1) 148 | # ) * (alpha * T * T) 149 | # hard_loss = F.cross_entropy(out, label) * (1. - alpha) 150 | # KD_loss = soft_loss + hard_loss 151 | 152 | top1 = float(pred.eq(label).sum().item()) / label.shape[0] * 100. 153 | 154 | return KD_loss, top1, soft_loss, hard_loss 155 | 156 | 157 | def compute_loss(self, batch, label, ce, featloss): 158 | model = self.model 159 | teacher = self.teacher 160 | args = self.args 161 | l2sp_lmda = self.args.l2sp_lmda 162 | reg_layers = self.reg_layers 163 | feat_loss, l2sp_loss = 0, 0 164 | 165 | out = model(batch) 166 | _, pred = out.max(dim=1) 167 | 168 | top1 = float(pred.eq(label).sum().item()) / label.shape[0] * 100. 169 | # top1_meter.update(float(pred.eq(label).sum().item()) / label.shape[0] * 100.) 170 | 171 | loss = 0. 
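# (Explanatory note, added for exposition.) The terms accumulated below compose the
# fine-tuning objective, roughly:
#     loss = CE(out, label)
#            + feat_lmda  * sum_l MSE(student_feat_l, teacher_feat_l)   # DELTA-style feature distillation
#            + beta       * ||W_classifier||_2^2                        # linear_l2 on the last linear layer
#            + l2sp_lmda  * ||theta - theta_pretrained||_2^2            # L2SP, see init_models()
# The feature and L2SP terms are only computed when their coefficients are non-zero.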
172 | loss += ce(out, label) 173 | 174 | ce_loss = loss.item() 175 | # ce_loss_meter.update(loss.item()) 176 | 177 | with torch.no_grad(): 178 | tout = teacher(batch) 179 | 180 | # Compute the feature distillation loss only when needed 181 | if args.feat_lmda != 0: 182 | regloss = 0 183 | for key in reg_layers.keys(): 184 | # key = int(layer)-1 185 | 186 | src_x = reg_layers[key][0].out 187 | tgt_x = reg_layers[key][1].out 188 | regloss += featloss(src_x, tgt_x.detach()) 189 | 190 | regloss = args.feat_lmda * regloss 191 | loss += regloss 192 | feat_loss = regloss.item() 193 | # feat_loss_meter.update(regloss.item()) 194 | 195 | beta_loss, linear_norm = linear_l2(model, args.beta) 196 | loss = loss + beta_loss 197 | linear_loss = beta_loss.item() 198 | # linear_loss_meter.update(beta_loss.item()) 199 | 200 | if l2sp_lmda != 0: 201 | reg, _ = l2sp(model, l2sp_lmda) 202 | l2sp_loss = reg.item() 203 | # l2sp_loss_meter.update(reg.item()) 204 | loss = loss + reg 205 | 206 | total_loss = loss.item() 207 | # total_loss_meter.update(loss.item()) 208 | 209 | return loss, top1, ce_loss, feat_loss, linear_loss, l2sp_loss, total_loss 210 | 211 | def steal_test(self): 212 | model = self.model 213 | teacher = self.teacher 214 | loader = self.test_loader 215 | alpha = self.args.steal_alpha 216 | T = self.args.temperature 217 | 218 | with torch.no_grad(): 219 | model.eval() 220 | teacher.eval() 221 | 222 | total_soft, total_hard, total_kd = 0, 0, 0 223 | total = 0 224 | top1 = 0 225 | 226 | for i, (batch, label) in enumerate(loader): 227 | batch, label = batch.to('cuda'), label.to('cuda') 228 | total += batch.size(0) 229 | 230 | teacher_out = teacher(batch) 231 | out = model(batch) 232 | _, pred = out.max(dim=1) 233 | 234 | soft_loss = nn.KLDivLoss()( 235 | F.log_softmax(out/T, dim=1), 236 | F.softmax(teacher_out/T, dim=1) 237 | ) * (alpha * T * T) 238 | hard_loss = F.cross_entropy(out, label) * (1. 
- alpha) 239 | KD_loss = soft_loss + hard_loss 240 | 241 | total_soft += soft_loss.item() 242 | total_hard += hard_loss.item() 243 | total_kd += KD_loss.item() 244 | top1 += int(pred.eq(label).sum().item()) 245 | 246 | return float(top1)/total*100, total_kd/(i+1), total_soft/(i+1), total_hard/(i+1) 247 | 248 | def test(self): 249 | model = self.model 250 | teacher = self.teacher 251 | loader = self.test_loader 252 | reg_layers = self.reg_layers 253 | args = self.args 254 | loss = True 255 | 256 | with torch.no_grad(): 257 | model.eval() 258 | 259 | if loss: 260 | teacher.eval() 261 | 262 | ce = CrossEntropyLabelSmooth(loader.dataset.num_classes, args.label_smoothing).to('cuda') 263 | featloss = torch.nn.MSELoss(reduction='none') 264 | 265 | total_ce = 0 266 | total_feat_reg = np.zeros(len(reg_layers)) 267 | total_l2sp_reg = 0 268 | 269 | total = 0 270 | top1 = 0 271 | for i, (batch, label) in enumerate(loader): 272 | batch, label = batch.to('cuda'), label.to('cuda') 273 | 274 | total += batch.size(0) 275 | out = model(batch) 276 | _, pred = out.max(dim=1) 277 | top1 += int(pred.eq(label).sum().item()) 278 | 279 | if loss: 280 | total_ce += ce(out, label).item() 281 | if teacher is not None: 282 | with torch.no_grad(): 283 | tout = teacher(batch) 284 | 285 | for i, key in enumerate(reg_layers): 286 | # print(key, len(reg_layers[key])) 287 | src_x = reg_layers[key][0].out 288 | tgt_x = reg_layers[key][1].out 289 | # print(src_x.shape, tgt_x.shape) 290 | 291 | regloss = featloss(src_x, tgt_x.detach()).mean() 292 | 293 | total_feat_reg[i] += regloss.item() 294 | 295 | _, unweighted = l2sp(model, 0) 296 | total_l2sp_reg += unweighted.item() 297 | # break 298 | 299 | return float(top1)/total*100, total_ce/(i+1), np.sum(total_feat_reg)/(i+1), total_l2sp_reg/(i+1), total_feat_reg/(i+1) 300 | 301 | def get_fine_tuning_parameters(self): 302 | model = self.model 303 | parameters = [] 304 | 305 | ft_begin_module = self.args.ft_begin_module 306 | ft_ratio = self.args.ft_ratio if 'ft_ratio' in self.args else None 307 | 308 | if ft_ratio: 309 | all_params = [param for param in model.parameters()] 310 | num_tune_params = int(len(all_params) * ft_ratio) 311 | for v in all_params[-num_tune_params:]: 312 | parameters.append({'params': v}) 313 | 314 | all_names = [name for name, _ in model.named_parameters()] 315 | with open(osp.join(self.args.output_dir, "finetune.log"), "w") as f: 316 | f.write(f"Fixed layers:\n") 317 | for name in all_names[:-num_tune_params]: 318 | f.write(name+"\n") 319 | f.write(f"\n\nFinetuned layers:\n") 320 | for name in all_names[-num_tune_params:]: 321 | f.write(name+"\n") 322 | 323 | return parameters 324 | 325 | if not ft_begin_module: 326 | return model.parameters() 327 | 328 | add_flag = False 329 | for k, v in model.named_parameters(): 330 | # if ft_begin_module == k: 331 | if ft_begin_module in k: 332 | add_flag = True 333 | 334 | if add_flag: 335 | # print(k) 336 | parameters.append({'params': v}) 337 | if ft_begin_module and not add_flag: 338 | raise RuntimeError("wrong ft_begin_module, no module to finetune") 339 | 340 | return parameters 341 | 342 | def train(self): 343 | model = self.model 344 | train_loader = self.train_loader 345 | test_loader = self.test_loader 346 | iterations = self.args.iterations 347 | lr = self.args.lr 348 | output_dir = self.args.output_dir 349 | l2sp_lmda = self.args.l2sp_lmda 350 | teacher = self.teacher 351 | reg_layers = self.reg_layers 352 | args = self.args 353 | 354 | model_params = self.get_fine_tuning_parameters() 355 | 356 | if 
l2sp_lmda == 0: 357 | optimizer = optim.SGD( 358 | model_params, 359 | lr=lr, 360 | momentum=args.momentum, 361 | weight_decay=args.weight_decay, 362 | ) 363 | else: 364 | optimizer = optim.SGD( 365 | model_params, 366 | lr=lr, 367 | momentum=args.momentum, 368 | weight_decay=0, 369 | ) 370 | 371 | end_iter = iterations 372 | 373 | if args.const_lr: 374 | scheduler = None 375 | else: 376 | scheduler = optim.lr_scheduler.CosineAnnealingLR( 377 | optimizer, 378 | end_iter, 379 | ) 380 | 381 | teacher.eval() 382 | ce = CrossEntropyLabelSmooth(train_loader.dataset.num_classes, args.label_smoothing).to('cuda') 383 | featloss = torch.nn.MSELoss() 384 | 385 | batch_time = MovingAverageMeter('Time', ':6.3f') 386 | data_time = MovingAverageMeter('Data', ':6.3f') 387 | ce_loss_meter = MovingAverageMeter('CE Loss', ':6.3f') 388 | feat_loss_meter = MovingAverageMeter('Feat. Loss', ':6.3f') 389 | l2sp_loss_meter = MovingAverageMeter('L2SP Loss', ':6.3f') 390 | linear_loss_meter = MovingAverageMeter('LinearL2 Loss', ':6.3f') 391 | total_loss_meter = MovingAverageMeter('Total Loss', ':6.3f') 392 | top1_meter = MovingAverageMeter('Acc@1', ':6.2f') 393 | 394 | train_path = osp.join(output_dir, "train.tsv") 395 | with open(train_path, 'w') as wf: 396 | columns = ['time', 'iter', 'Acc', 'celoss', 'featloss', 'l2sp'] 397 | wf.write('\t'.join(columns) + '\n') 398 | test_path = osp.join(output_dir, "test.tsv") 399 | with open(test_path, 'w') as wf: 400 | columns = ['time', 'iter', 'Acc', 'celoss', 'featloss', 'l2sp'] 401 | wf.write('\t'.join(columns) + '\n') 402 | adv_path = osp.join(output_dir, "adv.tsv") 403 | with open(adv_path, 'w') as wf: 404 | columns = ['time', 'iter', 'Acc', 'AdvAcc', 'ASR'] 405 | wf.write('\t'.join(columns) + '\n') 406 | 407 | dataloader_iterator = iter(train_loader) 408 | for i in range(iterations): 409 | 410 | model.train() 411 | optimizer.zero_grad() 412 | 413 | end = time.time() 414 | try: 415 | batch, label = next(dataloader_iterator) 416 | except: 417 | dataloader_iterator = iter(train_loader) 418 | batch, label = next(dataloader_iterator) 419 | batch, label = batch.to('cuda'), label.to('cuda') 420 | data_time.update(time.time() - end) 421 | 422 | if args.steal: 423 | loss, top1, soft_loss, hard_loss = self.compute_steal_loss(batch, label) 424 | total_loss = loss 425 | ce_loss = hard_loss 426 | feat_loss = soft_loss 427 | linear_loss, l2sp_loss = 0, 0 428 | else: 429 | loss, top1, ce_loss, feat_loss, linear_loss, l2sp_loss, total_loss = self.compute_loss( 430 | batch, label, 431 | ce, featloss, 432 | ) 433 | top1_meter.update(top1) 434 | ce_loss_meter.update(ce_loss) 435 | feat_loss_meter.update(feat_loss) 436 | linear_loss_meter.update(linear_loss) 437 | l2sp_loss_meter.update(l2sp_loss) 438 | total_loss_meter.update(total_loss) 439 | 440 | loss.backward() 441 | #----------------------------------------- 442 | for k, m in enumerate(model.modules()): 443 | # print(k, m) 444 | if isinstance(m, nn.Conv2d): 445 | weight_copy = m.weight.data.abs().clone() 446 | mask = weight_copy.gt(0).float().cuda() 447 | m.weight.grad.data.mul_(mask) 448 | if isinstance(m, nn.Linear): 449 | weight_copy = m.weight.data.abs().clone() 450 | mask = weight_copy.gt(0).float().cuda() 451 | m.weight.grad.data.mul_(mask) 452 | #----------------------------------------- 453 | optimizer.step() 454 | for param_group in optimizer.param_groups: 455 | current_lr = param_group['lr'] 456 | if scheduler is not None: 457 | scheduler.step() 458 | 459 | batch_time.update(time.time() - end) 460 | 461 | if (i % 
args.print_freq == 0) or (i == iterations-1): 462 | progress = ProgressMeter( 463 | iterations, 464 | [batch_time, data_time, top1_meter, total_loss_meter, ce_loss_meter, feat_loss_meter, l2sp_loss_meter, linear_loss_meter], 465 | prefix="LR: {:6.3f}".format(current_lr), 466 | output_dir=output_dir, 467 | ) 468 | progress.display(i) 469 | 470 | if (i % args.test_interval == 0) or (i == iterations-1): 471 | if self.args.steal: 472 | test_top1, test_ce_loss, test_feat_loss, test_weight_loss = self.steal_test( 473 | # model, teacher, test_loader, loss=True 474 | ) 475 | train_top1, train_ce_loss, train_feat_loss, train_weight_loss = self.steal_test( 476 | # model, teacher, train_loader, loss=True 477 | ) 478 | test_feat_layer_loss, train_feat_layer_loss = 0, 0 479 | else: 480 | test_top1, test_ce_loss, test_feat_loss, test_weight_loss, test_feat_layer_loss = self.test( 481 | # model, teacher, test_loader, loss=True 482 | ) 483 | train_top1, train_ce_loss, train_feat_loss, train_weight_loss, train_feat_layer_loss = self.test( 484 | # model, teacher, train_loader, loss=True 485 | ) 486 | 487 | print( 488 | 'Eval Train | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, train_top1, train_ce_loss, train_feat_loss, train_weight_loss)) 489 | print( 490 | 'Eval Test | Iteration {}/{} | Top-1: {:.2f} | CE Loss: {:.3f} | Feat Reg Loss: {:.6f} | L2SP Reg Loss: {:.3f}'.format(i+1, iterations, test_top1, test_ce_loss, test_feat_loss, test_weight_loss)) 491 | localtime = time.asctime( time.localtime(time.time()) )[4:-6] 492 | with open(train_path, 'a') as af: 493 | train_cols = [ 494 | localtime, 495 | i, 496 | round(train_top1,2), 497 | round(train_ce_loss,2), 498 | round(train_feat_loss,2), 499 | round(train_weight_loss,2), 500 | ] 501 | af.write('\t'.join([str(c) for c in train_cols]) + '\n') 502 | with open(test_path, 'a') as af: 503 | test_cols = [ 504 | localtime, 505 | i, 506 | round(test_top1,2), 507 | round(test_ce_loss,2), 508 | round(test_feat_loss,2), 509 | round(test_weight_loss,2), 510 | ] 511 | af.write('\t'.join([str(c) for c in test_cols]) + '\n') 512 | if not args.no_save: 513 | ckpt_path = osp.join( 514 | args.output_dir, 515 | "ckpt.pth" 516 | ) 517 | torch.save( 518 | {'state_dict': model.state_dict()}, 519 | ckpt_path, 520 | ) 521 | 522 | if ( hasattr(self, "iterative_prune") and i % args.prune_interval == 0 ): 523 | self.iterative_prune(i) 524 | 525 | return model.to('cpu') 526 | 527 | def countWeightInfo(self): 528 | ... 
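# NOTE: standalone sketch added for exposition; it is not called anywhere in this
# repository. Finetuner.train() above multiplies every Conv2d/Linear weight gradient
# by a 0/1 mask of the current weights, so parameters that a pruner has already set
# to exactly zero receive no gradient and therefore stay pruned during fine-tuning.
# A minimal reimplementation of that step (call it after loss.backward()):
def mask_pruned_gradients(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)) and m.weight.grad is not None:
            # 1.0 where the weight is currently non-zero, 0.0 where it has been pruned
            mask = m.weight.data.abs().gt(0).float()
            m.weight.grad.data.mul_(mask)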
529 | -------------------------------------------------------------------------------- /load_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pdb import set_trace as st 3 | 4 | import torch 5 | from torchvision import transforms 6 | 7 | from dataset.cub200 import CUB200Data 8 | from dataset.mit67 import MIT67 9 | from dataset.stanford_dog import SDog120 10 | from dataset.caltech256 import Caltech257Data 11 | from dataset.stanford_40 import Stanford40Data 12 | from dataset.flower102 import Flower102 13 | 14 | from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout 15 | from model.fe_mobilenet import mbnetv2_dropout 16 | from model.fe_resnet import feresnet18, feresnet50, feresnet101 17 | from model.fe_mobilenet import fembnetv2 18 | 19 | def get_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--datapath", type=str, default='/data', help='path to the dataset') 22 | parser.add_argument("--dataset", type=str, default='CUB200Data', help='Target dataset. Currently support: \{SDog120Data, CUB200Data, Stanford40Data, MIT67Data, Flower102Data\}') 23 | parser.add_argument("--checkpoint", type=str, default='', help='Load a previously trained checkpoint') 24 | parser.add_argument("--network", type=str, default='resnet18', help='Network architecture. Currently support: \{resnet18, resnet50, resnet101, mbnetv2\}') 25 | args = parser.parse_args() 26 | 27 | return args 28 | 29 | def main(): 30 | args = get_args() 31 | 32 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225]) 34 | test_set = eval(args.dataset)( 35 | args.datapath, False, transforms.Compose([ 36 | transforms.Resize(256), 37 | transforms.CenterCrop(224), 38 | transforms.ToTensor(), 39 | normalize, 40 | ]), 41 | ) 42 | 43 | 44 | model = eval('{}_dropout'.format(args.network))( 45 | pretrained=True, 46 | dropout=0, 47 | num_classes=test_set.num_classes 48 | ) 49 | 50 | checkpoint = torch.load(args.checkpoint) 51 | model.load_state_dict(checkpoint['state_dict']) 52 | print(f"Loaded checkpoint from {args.checkpoint}") 53 | 54 | 55 | if __name__=="__main__": 56 | main() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanchun-li/ModelDiff/f509bd2a1de20138aeb5cf105f99597a279f6f0b/model/__init__.py -------------------------------------------------------------------------------- /model/fe_mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | 5 | 6 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 7 | 8 | 9 | model_urls = { 10 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 11 | } 12 | 13 | 14 | def _make_divisible(v, divisor, min_value=None): 15 | """ 16 | This function is taken from the original tf repo. 
17 | It ensures that all layers have a channel number that is divisible by 8 18 | It can be seen here: 19 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 20 | :param v: 21 | :param divisor: 22 | :param min_value: 23 | :return: 24 | """ 25 | if min_value is None: 26 | min_value = divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < 0.9 * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | class ConvBNReLU(nn.Sequential): 35 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 36 | padding = (kernel_size - 1) // 2 37 | super(ConvBNReLU, self).__init__( 38 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 39 | nn.BatchNorm2d(out_planes), 40 | nn.ReLU6(inplace=True) 41 | ) 42 | 43 | 44 | class InvertedResidual(nn.Module): 45 | def __init__(self, inp, oup, stride, expand_ratio): 46 | super(InvertedResidual, self).__init__() 47 | self.stride = stride 48 | assert stride in [1, 2] 49 | 50 | hidden_dim = int(round(inp * expand_ratio)) 51 | self.use_res_connect = self.stride == 1 and inp == oup 52 | 53 | layers = [] 54 | if expand_ratio != 1: 55 | # pw 56 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 57 | layers.extend([ 58 | # dw 59 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 60 | # pw-linear 61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 62 | nn.BatchNorm2d(oup), 63 | ]) 64 | self.conv = nn.Sequential(*layers) 65 | 66 | def forward(self, x): 67 | if self.use_res_connect: 68 | return x + self.conv(x) 69 | else: 70 | return self.conv(x) 71 | 72 | 73 | class MobileNetV2(nn.Module): 74 | def __init__(self, 75 | num_classes=1000, 76 | width_mult=1.0, 77 | inverted_residual_setting=None, 78 | round_nearest=8, 79 | dropout=0, 80 | haslinear=False, 81 | block=None): 82 | """ 83 | MobileNet V2 main class 84 | Args: 85 | num_classes (int): Number of classes 86 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 87 | inverted_residual_setting: Network structure 88 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 89 | Set to 1 to turn off rounding 90 | block: Module specifying inverted residual building block for mobilenet 91 | """ 92 | super(MobileNetV2, self).__init__() 93 | 94 | self.haslinear = haslinear 95 | if dropout > 0: 96 | self.dropout_layer = nn.Dropout2d(dropout) 97 | else: 98 | self.dropout_layer = None 99 | 100 | if block is None: 101 | block = InvertedResidual 102 | input_channel = 32 103 | last_channel = 1280 104 | 105 | if inverted_residual_setting is None: 106 | inverted_residual_setting = [ 107 | # t, c, n, s 108 | [1, 16, 1, 1], 109 | [6, 24, 2, 2], 110 | [6, 32, 3, 2], 111 | [6, 64, 4, 2], 112 | [6, 96, 3, 1], 113 | [6, 160, 3, 2], 114 | [6, 320, 1, 1], 115 | ] 116 | 117 | # only check the first element, assuming user knows t,c,n,s are required 118 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 119 | raise ValueError("inverted_residual_setting should be non-empty " 120 | "or a 4-element list, got {}".format(inverted_residual_setting)) 121 | 122 | # building first layer 123 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 124 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 125 | features = [ConvBNReLU(3, 
input_channel, stride=2)] 126 | # building inverted residual blocks 127 | layers = [] 128 | for t, c, n, s in inverted_residual_setting: 129 | output_channel = _make_divisible(c * width_mult, round_nearest) 130 | for i in range(n): 131 | stride = s if i == 0 else 1 132 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 133 | input_channel = output_channel 134 | if s == 2: 135 | layers.append(features) 136 | features = [] 137 | # building last several layers 138 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 139 | layers.append(features) 140 | 141 | self.layer1 = nn.Sequential(*layers[0]) 142 | self.layer2 = nn.Sequential(*layers[1]) 143 | self.layer3 = nn.Sequential(*layers[2]) 144 | self.layer4 = nn.Sequential(*layers[3]) 145 | self.layer5 = nn.Sequential(*layers[4]) 146 | 147 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 148 | 149 | # building classifier 150 | self.classifier = nn.Sequential( 151 | nn.Dropout(0.2), 152 | nn.Linear(self.last_channel, num_classes), 153 | ) 154 | 155 | # weight initialization 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 159 | if m.bias is not None: 160 | nn.init.zeros_(m.bias) 161 | elif isinstance(m, nn.BatchNorm2d): 162 | nn.init.ones_(m.weight) 163 | nn.init.zeros_(m.bias) 164 | elif isinstance(m, nn.Linear): 165 | nn.init.normal_(m.weight, 0, 0.01) 166 | nn.init.zeros_(m.bias) 167 | 168 | def _forward_impl(self, x): 169 | # This exists since TorchScript doesn't support inheritance, so the superclass method 170 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 171 | x = self.layer1(x) 172 | if self.dropout_layer is not None: 173 | x = self.dropout_layer(x) 174 | x = self.layer2(x) 175 | if self.dropout_layer is not None: 176 | x = self.dropout_layer(x) 177 | x = self.layer3(x) 178 | if self.dropout_layer is not None: 179 | x = self.dropout_layer(x) 180 | x = self.layer4(x) 181 | if self.dropout_layer is not None: 182 | x = self.dropout_layer(x) 183 | x = self.layer5(x) 184 | if self.dropout_layer is not None: 185 | x = self.dropout_layer(x) 186 | x = self.avgpool(x) 187 | x = torch.flatten(x, 1) 188 | 189 | if self.haslinear: 190 | x = self.classifier(x) 191 | return x 192 | 193 | def forward(self, x): 194 | return self._forward_impl(x) 195 | 196 | 197 | def fembnetv2(pretrained=False, progress=True, **kwargs): 198 | """ 199 | Constructs a MobileNetV2 architecture from 200 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | progress (bool): If True, displays a progress bar of the download to stderr 204 | """ 205 | model = MobileNetV2(haslinear=False, **kwargs) 206 | if pretrained: 207 | 208 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 209 | progress=progress) 210 | del state_dict['classifier.1.weight'] 211 | del state_dict['classifier.1.bias'] 212 | 213 | new_dict = {} 214 | cumulative_layer = 0 215 | for li, layers in enumerate([4, 3, 4, 6, 2]): 216 | for l in range(layers): 217 | key = 'features.{}.'.format(l+cumulative_layer) 218 | new_key = 'layer{}.{}.'.format(li+1, l) 219 | 220 | for k in state_dict: 221 | if key in k: 222 | new_k = k.replace(key, new_key) 223 | new_dict[new_k] = state_dict[k] 224 | cumulative_layer += layers 225 | 226 | new_params = model.state_dict() 227 | new_params.update(new_dict) 228 | model.load_state_dict(new_params) 229 | return model 230 | 231 | def mbnetv2_dropout(pretrained=False, progress=True, **kwargs): 232 | """ 233 | Constructs a MobileNetV2 architecture from 234 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | progress (bool): If True, displays a progress bar of the download to stderr 238 | """ 239 | model = MobileNetV2(haslinear=True, **kwargs) 240 | if pretrained: 241 | 242 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 243 | progress=progress) 244 | del state_dict['classifier.1.weight'] 245 | del state_dict['classifier.1.bias'] 246 | 247 | new_dict = {} 248 | cumulative_layer = 0 249 | for li, layers in enumerate([4, 3, 4, 6, 2]): 250 | for l in range(layers): 251 | key = 'features.{}.'.format(l+cumulative_layer) 252 | new_key = 'layer{}.{}.'.format(li+1, l) 253 | 254 | for k in state_dict: 255 | if key in k: 256 | new_k = k.replace(key, new_key) 257 | new_dict[new_k] = state_dict[k] 258 | cumulative_layer += layers 259 | 260 | new_params = model.state_dict() 261 | new_params.update(new_dict) 262 | model.load_state_dict(new_params) 263 | return model 264 | -------------------------------------------------------------------------------- /model/fe_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with 
padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | __constants__ = ['downsample'] 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 40 | base_width=64, dilation=1, norm_layer=None): 41 | super(BasicBlock, self).__init__() 42 | if norm_layer is None: 43 | norm_layer = nn.BatchNorm2d 44 | if groups != 1 or base_width != 64: 45 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 46 | if dilation > 1: 47 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride) 50 | self.bn1 = norm_layer(planes) 51 | self.relu = nn.ReLU(inplace=True) 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | identity = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | identity = self.downsample(x) 69 | 70 | out += identity 71 | out = self.relu(out) 72 | 73 | return out 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | __constants__ = ['downsample'] 79 | 80 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 81 | base_width=64, dilation=1, norm_layer=None): 82 | super(Bottleneck, self).__init__() 83 | if norm_layer is None: 84 | norm_layer = nn.BatchNorm2d 85 | width = int(planes * (base_width / 64.)) * groups 86 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 87 | self.conv1 = conv1x1(inplanes, width) 88 | self.bn1 = norm_layer(width) 89 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 90 | self.bn2 = norm_layer(width) 91 | self.conv3 = conv1x1(width, planes * self.expansion) 92 | self.bn3 = norm_layer(planes * self.expansion) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.downsample = downsample 95 | self.stride = stride 96 | 97 | def forward(self, x): 98 | identity = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | identity = self.downsample(x) 113 | 114 | out += identity 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class ResNet(nn.Module): 121 | 122 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 123 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 124 | norm_layer=None, dropout=0., haslinear=False): 125 | super(ResNet, self).__init__() 126 | if norm_layer is None: 127 | norm_layer = nn.BatchNorm2d 128 | self._norm_layer = norm_layer 129 | 130 | self.haslinear = haslinear 131 | if dropout > 0: 132 | self.dropout_layer = nn.Dropout2d(dropout) 133 | else: 134 | self.dropout_layer = None 135 | 136 | self.inplanes = 64 137 | self.dilation = 1 138 | if replace_stride_with_dilation is None: 139 | # each element in the tuple 
indicates if we should replace 140 | # the 2x2 stride with a dilated convolution instead 141 | replace_stride_with_dilation = [False, False, False] 142 | if len(replace_stride_with_dilation) != 3: 143 | raise ValueError("replace_stride_with_dilation should be None " 144 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 145 | self.groups = groups 146 | self.base_width = width_per_group 147 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 148 | bias=False) 149 | self.bn1 = norm_layer(self.inplanes) 150 | self.relu = nn.ReLU(inplace=True) 151 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 152 | 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 155 | dilate=replace_stride_with_dilation[0]) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 157 | dilate=replace_stride_with_dilation[1]) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 159 | dilate=replace_stride_with_dilation[2]) 160 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 161 | self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | 170 | # Zero-initialize the last BN in each residual branch, 171 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 173 | if zero_init_residual: 174 | for m in self.modules(): 175 | if isinstance(m, Bottleneck): 176 | nn.init.constant_(m.bn3.weight, 0) 177 | elif isinstance(m, BasicBlock): 178 | nn.init.constant_(m.bn2.weight, 0) 179 | 180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 181 | norm_layer = self._norm_layer 182 | downsample = None 183 | previous_dilation = self.dilation 184 | if dilate: 185 | self.dilation *= stride 186 | stride = 1 187 | if stride != 1 or self.inplanes != planes * block.expansion: 188 | downsample = nn.Sequential( 189 | conv1x1(self.inplanes, planes * block.expansion, stride), 190 | norm_layer(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 195 | self.base_width, previous_dilation, norm_layer)) 196 | self.inplanes = planes * block.expansion 197 | for _ in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, groups=self.groups, 199 | base_width=self.base_width, dilation=self.dilation, 200 | norm_layer=norm_layer)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def _forward_impl(self, x): 205 | # See note [TorchScript super()] 206 | x1 = self.conv1(x) 207 | x1 = self.bn1(x1) 208 | x1 = self.relu(x1) 209 | x1 = self.maxpool(x1) 210 | 211 | x1 = self.layer1(x1) 212 | if self.dropout_layer is not None: 213 | x1 = self.dropout_layer(x1) 214 | x1 = self.layer2(x1) 215 | if self.dropout_layer is not None: 216 | x1 = self.dropout_layer(x1) 217 | x1 = self.layer3(x1) 218 | if self.dropout_layer is not None: 219 | x1 = self.dropout_layer(x1) 220 | x1 = self.layer4(x1) 221 | if self.dropout_layer is not None: 222 | x1 = self.dropout_layer(x1) 223 | x1 = self.avgpool(x1) 224 | x1 = torch.flatten(x1, 1) 225 | if self.haslinear: 226 | x1 = self.fc(x1) 227 
|
228 |         return x1
229 | 
230 |     def forward(self, x):
231 |         return self._forward_impl(x)
232 | 
233 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
234 |     model = ResNet(block, layers, **kwargs)
235 |     if pretrained:
236 |         state_dict = load_state_dict_from_url(model_urls[arch],
237 |                                               progress=progress)
238 |         del state_dict['fc.weight']
239 |         del state_dict['fc.bias']
240 |         new_dict = dict(state_dict)
241 |         new_params = model.state_dict()
242 |         new_params.update(new_dict)
243 |         model.load_state_dict(new_params)
244 |     return model
245 | 
246 | 
247 | def resnet18_dropout(pretrained=False, progress=True, **kwargs):
248 |     r"""ResNet-18 model from
249 |     `"Deep Residual Learning for Image Recognition" `_
250 |     Args:
251 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
252 |         progress (bool): If True, displays a progress bar of the download to stderr
253 |     """
254 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, haslinear=True,
255 |                    **kwargs)
256 | 
257 | def resnet34_dropout(pretrained=False, progress=True, **kwargs):
258 |     r"""ResNet-34 model from
259 |     `"Deep Residual Learning for Image Recognition" `_
260 |     Args:
261 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
262 |         progress (bool): If True, displays a progress bar of the download to stderr
263 |     """
264 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, haslinear=True,
265 |                    **kwargs)
266 | 
267 | def resnet50_dropout(pretrained=False, progress=True, **kwargs):
268 |     r"""ResNet-50 model from
269 |     `"Deep Residual Learning for Image Recognition" `_
270 |     Args:
271 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
272 |         progress (bool): If True, displays a progress bar of the download to stderr
273 |     """
274 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, haslinear=True,
275 |                    **kwargs)
276 | 
277 | 
278 | def resnet101_dropout(pretrained=False, progress=True, **kwargs):
279 |     r"""ResNet-101 model from
280 |     `"Deep Residual Learning for Image Recognition" `_
281 |     Args:
282 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
283 |         progress (bool): If True, displays a progress bar of the download to stderr
284 |     """
285 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, haslinear=True,
286 |                    **kwargs)
287 | 
288 | def feresnet18(pretrained=False, progress=True, **kwargs):
289 |     r"""ResNet-18 model from
290 |     `"Deep Residual Learning for Image Recognition" `_
291 |     Args:
292 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
293 |         progress (bool): If True, displays a progress bar of the download to stderr
294 |     """
295 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
296 |                    **kwargs)
297 | 
298 | 
299 | def feresnet34(pretrained=False, progress=True, **kwargs):
300 |     r"""ResNet-34 model from
301 |     `"Deep Residual Learning for Image Recognition" `_
302 |     Args:
303 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
304 |         progress (bool): If True, displays a progress bar of the download to stderr
305 |     """
306 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
307 |                    **kwargs)
308 | 
309 | 
310 | def feresnet50(pretrained=False, progress=True, **kwargs):
311 |     r"""ResNet-50 model from
312 |     `"Deep Residual Learning for Image Recognition" `_
313 |     Args:
314 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
315 |         progress (bool): If True,
displays a progress bar of the download to stderr 316 | """ 317 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 318 | **kwargs) 319 | 320 | 321 | def feresnet101(pretrained=False, progress=True, **kwargs): 322 | r"""ResNet-101 model from 323 | `"Deep Residual Learning for Image Recognition" `_ 324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 329 | **kwargs) 330 | 331 | 332 | def feresnet152(pretrained=False, progress=True, **kwargs): 333 | r"""ResNet-152 model from 334 | `"Deep Residual Learning for Image Recognition" `_ 335 | Args: 336 | pretrained (bool): If True, returns a model pre-trained on ImageNet 337 | progress (bool): If True, displays a progress bar of the download to stderr 338 | """ 339 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 340 | **kwargs) 341 | 342 | 343 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 344 | r"""ResNeXt-50 32x4d model from 345 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 346 | Args: 347 | pretrained (bool): If True, returns a model pre-trained on ImageNet 348 | progress (bool): If True, displays a progress bar of the download to stderr 349 | """ 350 | kwargs['groups'] = 32 351 | kwargs['width_per_group'] = 4 352 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 353 | pretrained, progress, **kwargs) 354 | 355 | 356 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 357 | r"""ResNeXt-101 32x8d model from 358 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 359 | Args: 360 | pretrained (bool): If True, returns a model pre-trained on ImageNet 361 | progress (bool): If True, displays a progress bar of the download to stderr 362 | """ 363 | kwargs['groups'] = 32 364 | kwargs['width_per_group'] = 8 365 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 366 | pretrained, progress, **kwargs) 367 | 368 | 369 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 370 | r"""Wide ResNet-50-2 model from 371 | `"Wide Residual Networks" `_ 372 | The model is the same as ResNet except for the bottleneck number of channels 373 | which is twice larger in every block. The number of channels in outer 1x1 374 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 375 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 376 | Args: 377 | pretrained (bool): If True, returns a model pre-trained on ImageNet 378 | progress (bool): If True, displays a progress bar of the download to stderr 379 | """ 380 | kwargs['width_per_group'] = 64 * 2 381 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 382 | pretrained, progress, **kwargs) 383 | 384 | 385 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 386 | r"""Wide ResNet-101-2 model from 387 | `"Wide Residual Networks" `_ 388 | The model is the same as ResNet except for the bottleneck number of channels 389 | which is twice larger in every block. The number of channels in outer 1x1 390 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 391 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
392 | Args: 393 | pretrained (bool): If True, returns a model pre-trained on ImageNet 394 | progress (bool): If True, displays a progress bar of the download to stderr 395 | """ 396 | kwargs['width_per_group'] = 64 * 2 397 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 398 | pretrained, progress, **kwargs) 399 | -------------------------------------------------------------------------------- /model/fe_vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | from pdb import set_trace as st 5 | 6 | __all__ = [ 7 | 'VGG', 'vgg16', 'vgg16_bn', 'fevgg16_bn', 'vgg16_bn_dropout', 8 | 'fevgg11_bn', 'vgg11_bn_dropout', 9 | ] 10 | 11 | 12 | model_urls = { 13 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 14 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 15 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 16 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 17 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 18 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 19 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 20 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 21 | } 22 | 23 | 24 | class VGG(nn.Module): 25 | 26 | def __init__(self, features, num_classes=1000, init_weights=True, haslinear=False, 27 | dropout=0): 28 | super(VGG, self).__init__() 29 | self.features = features 30 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 31 | self.haslinear = haslinear 32 | self.classifier = nn.Sequential( 33 | nn.Linear(512 * 7 * 7, 4096), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(4096, 4096), 37 | nn.ReLU(True), 38 | nn.Dropout(), 39 | nn.Linear(4096, num_classes), 40 | ) 41 | if not haslinear: 42 | self.classifier = self.classifier[0:-3] 43 | 44 | if init_weights: 45 | self._initialize_weights() 46 | 47 | def forward(self, x): 48 | x = self.features(x) 49 | x = self.avgpool(x) 50 | x = torch.flatten(x, 1) 51 | x = self.classifier(x) 52 | return x 53 | 54 | def _initialize_weights(self): 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 58 | if m.bias is not None: 59 | nn.init.constant_(m.bias, 0) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | nn.init.constant_(m.weight, 1) 62 | nn.init.constant_(m.bias, 0) 63 | elif isinstance(m, nn.Linear): 64 | nn.init.normal_(m.weight, 0, 0.01) 65 | nn.init.constant_(m.bias, 0) 66 | 67 | 68 | def make_layers(cfg, batch_norm=False): 69 | layers = [] 70 | in_channels = 3 71 | for v in cfg: 72 | if v == 'M': 73 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 74 | else: 75 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 76 | if batch_norm: 77 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 78 | else: 79 | layers += [conv2d, nn.ReLU(inplace=True)] 80 | in_channels = v 81 | return nn.Sequential(*layers) 82 | 83 | 84 | cfgs = { 85 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 86 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 87 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 88 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 
89 | }
90 | 
91 | 
92 | def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
93 |     if pretrained:
94 |         kwargs['init_weights'] = False
95 |     model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
96 |     if pretrained:
97 |         state_dict = load_state_dict_from_url(model_urls[arch],
98 |                                                progress=progress)
99 |         del state_dict['classifier.6.weight']
100 |         del state_dict['classifier.6.bias']
101 |         new_dict = dict(state_dict)
102 |         new_params = model.state_dict()
103 |         new_params.update(new_dict)
104 |         model.load_state_dict(new_params)
105 |     return model
106 | 
107 | 
108 | 
109 | def vgg16(pretrained=False, progress=True, **kwargs):
110 |     r"""VGG 16-layer model (configuration "D")
111 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
112 | 
113 |     Args:
114 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
115 |         progress (bool): If True, displays a progress bar of the download to stderr
116 |     """
117 |     return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
118 | 
119 | 
120 | 
121 | def vgg16_bn(pretrained=False, progress=True, **kwargs):
122 |     r"""VGG 16-layer model (configuration "D") with batch normalization
123 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
124 | 
125 |     Args:
126 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
127 |         progress (bool): If True, displays a progress bar of the download to stderr
128 |     """
129 |     return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
130 | 
131 | def fevgg16_bn(pretrained=False, progress=True, **kwargs):
132 |     r"""VGG 16-layer model (configuration "D") with batch normalization
133 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
134 | 
135 |     Args:
136 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
137 |         progress (bool): If True, displays a progress bar of the download to stderr
138 |     """
139 |     return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
140 | 
141 | def vgg16_bn_dropout(pretrained=False, progress=True, **kwargs):
142 |     r"""VGG 16-layer model (configuration "D") with batch normalization
143 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
144 | 
145 |     Args:
146 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
147 |         progress (bool): If True, displays a progress bar of the download to stderr
148 |     """
149 |     return _vgg('vgg16_bn', 'D', True, pretrained, progress, haslinear=True, **kwargs)
150 | 
151 | def fevgg11_bn(pretrained=False, progress=True, **kwargs):
152 |     r"""VGG 11-layer model (configuration "A") with batch normalization
153 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
154 | 
155 |     Args:
156 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
157 |         progress (bool): If True, displays a progress bar of the download to stderr
158 |     """
159 |     return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
160 | 
161 | def vgg11_bn_dropout(pretrained=False, progress=True, **kwargs):
162 |     r"""VGG 11-layer model (configuration "A") with batch normalization
163 |     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_
164 | 
165 |     Args:
166 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
167 |         progress (bool): If True, displays a progress bar of the download to stderr
168 |     """
169 |     return _vgg('vgg11_bn', 'A', True, pretrained, progress, haslinear=True, **kwargs)
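Note: the `fe*` constructors above (`fembnetv2`, `feresnet18`, `fevgg16_bn`, ...) build the backbone with `haslinear=False`, i.e. without the final classification layer, while the `*_dropout` variants keep it. A minimal usage sketch, assuming it is run from the repository root so the `model` package is importable, with `pretrained=False` only to skip the checkpoint download and the feature sizes assuming the default width settings:
```
import torch

from model.fe_mobilenet import fembnetv2
from model.fe_resnet import feresnet18

# Both constructors omit the final classifier (haslinear=False), so a forward
# pass returns pooled feature vectors rather than class logits.
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    feats_resnet = feresnet18(pretrained=False).eval()(x)   # shape (2, 512)
    feats_mbnet = fembnetv2(pretrained=False).eval()(x)     # shape (2, 1280) for width_mult=1.0
print(feats_resnet.shape, feats_mbnet.shape)
```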
-------------------------------------------------------------------------------- /modeldiff.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import os 6 | import argparse 7 | import time 8 | import logging 9 | import pathlib 10 | import tempfile 11 | import copy 12 | import random 13 | import torch 14 | import numpy as np 15 | # import tensorflow as tf 16 | from scipy import spatial 17 | from abc import ABC, abstractmethod 18 | from pdb import set_trace as st 19 | import torch.nn as nn 20 | 21 | from utils import lazy_property, Utils 22 | 23 | 24 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 25 | 26 | 27 | class ModelComparison(ABC): 28 | def __init__(self, model1, model2): 29 | self.logger = logging.getLogger('ModelComparison') 30 | self.model1 = model1 31 | self.model2 = model2 32 | 33 | @abstractmethod 34 | def compare(self): 35 | pass 36 | 37 | 38 | class ModelDiff(ModelComparison): 39 | N_INPUT_PAIRS = 300 40 | MAX_VAL = 256 41 | 42 | def __init__(self, model1, model2, gen_inputs=None, input_metrics=None, compute_decision_dist=None, compare_ddv=None): 43 | super().__init__(model1, model2) 44 | self.logger = logging.getLogger('ModelDiff') 45 | self.logger.info(f'comparing {model1} and {model2}') 46 | self.logger.debug(f'initialize comparison: {self.model1} {self.model2}') 47 | self.logger.debug(f'input shapes: {self.model1.input_shape} {self.model2.input_shape}') 48 | self.input_shape = model1.input_shape 49 | if list(model1.input_shape) != list(model2.input_shape): 50 | self.logger.warning('input shapes do not match') 51 | self.gen_inputs = gen_inputs if gen_inputs else ModelDiff._gen_profiling_inputs_search 52 | self.input_metrics = input_metrics if input_metrics else ModelDiff.metrics_output_diversity 53 | self.compute_decision_dist = compute_decision_dist if compute_decision_dist else ModelDiff._compute_decision_dist_output_cos 54 | self.compare_ddv = compare_ddv if compare_ddv else ModelDiff._compare_ddv_cos 55 | 56 | def get_seed_inputs(self, rand=False): 57 | seed_inputs = np.concatenate([ 58 | self.model1.get_seed_inputs(self.N_INPUT_PAIRS, rand=rand), 59 | self.model2.get_seed_inputs(self.N_INPUT_PAIRS, rand=rand) 60 | ]) 61 | 62 | return seed_inputs 63 | 64 | def compare(self, use_torch=True): 65 | self.logger.info(f'generating seed inputs') 66 | seed_inputs = list(self.get_seed_inputs()) 67 | np.random.shuffle(seed_inputs) 68 | seed_inputs = np.array(seed_inputs) 69 | if use_torch: 70 | seed_inputs = torch.from_numpy(seed_inputs) 71 | self.logger.info(f' seed inputs generated with shape {seed_inputs.shape}') 72 | 73 | self.logger.info(f'generating profiling inputs') 74 | profiling_inputs = self.gen_inputs(self, seed_inputs, use_torch=use_torch) 75 | # input_pairs = [] 76 | # for i in range(int(len(profiling_inputs) / 2)): 77 | # xa = profiling_inputs[2 * i] 78 | # xb = profiling_inputs[2 * i + 1] 79 | # xa = np.expand_dims(xa, axis=0) 80 | # xb = np.expand_dims(xb, axis=0) 81 | # input_pairs.append((xa, xb)) 82 | self.logger.info(f' profiling inputs generated with shape {profiling_inputs.shape}') 83 | 84 | self.logger.info(f'computing metrics') 85 | input_metrics_1 = self.input_metrics(self.model1, profiling_inputs, use_torch=use_torch) 86 | input_metrics_2 = self.input_metrics(self.model2, profiling_inputs, use_torch=use_torch) 87 | self.logger.info(f' input metrics: model1={input_metrics_1} model2={input_metrics_2}') 88 | 89 | model_similarity = 
self.compute_similarity_with_ddm(profiling_inputs) 90 | return model_similarity 91 | 92 | def compute_similarity_with_IPGuard(self, profiling_inputs): 93 | n_pairs = int(len(list(profiling_inputs)) / 2) 94 | normal_input = profiling_inputs[:n_pairs] 95 | adv_input = profiling_inputs[n_pairs:] 96 | 97 | out = self.model1.batch_forward(adv_input).to("cpu").numpy() 98 | normal_pred = out.argmax(axis=1) 99 | out = self.model2.batch_forward(adv_input).to("cpu").numpy() 100 | adv_pred = out.argmax(axis=1) 101 | 102 | consist = int( (normal_pred == adv_pred).sum() ) 103 | sim = consist / n_pairs 104 | self.logger.info(f' model similarity: {sim}') 105 | return sim 106 | 107 | def compute_similarity_with_weight(self): 108 | name_to_modules = {} 109 | for name, module in self.model1.torch_model.named_modules(): 110 | if isinstance(module, nn.Conv2d): 111 | name_to_modules[name] = [module.weight] 112 | for name, module in self.model2.torch_model.named_modules(): 113 | if isinstance(module, nn.Conv2d): 114 | name_to_modules[name].append(module.weight) 115 | layer_dist = [] 116 | for name, pack in name_to_modules.items(): 117 | weight1, weight2 = pack 118 | # print(name, float((weight1==0).sum() / weight1.numel())) 119 | weight1 = weight1.view(-1) 120 | weight2 = weight2.view(-1) 121 | dist = nn.CosineSimilarity(dim=0)(weight1, weight2) 122 | layer_dist.append(dist.item()) 123 | sim = np.mean(layer_dist) 124 | 125 | self.logger.info(f' model similarity: {sim}') 126 | return sim 127 | 128 | def compute_similarity_with_abs_weight(self): 129 | name_to_modules = {} 130 | for name, module in self.model1.torch_model.named_modules(): 131 | if isinstance(module, nn.Conv2d): 132 | name_to_modules[name] = [module.weight] 133 | for name, module in self.model2.torch_model.named_modules(): 134 | if isinstance(module, nn.Conv2d): 135 | name_to_modules[name].append(module.weight) 136 | layer_dist = [] 137 | for name, pack in name_to_modules.items(): 138 | weight1, weight2 = pack 139 | dist = 1 - ((weight1-weight2)).abs().mean() 140 | layer_dist.append(dist.item()) 141 | sim = np.mean(layer_dist) 142 | 143 | self.logger.info(f' model similarity: {sim}') 144 | return sim 145 | 146 | def compute_similarity_with_bn_weight(self): 147 | name_to_modules = {} 148 | for name, module in self.model1.torch_model.named_modules(): 149 | if isinstance(module, nn.BatchNorm2d): 150 | name_to_modules[name] = [module.weight] 151 | for name, module in self.model2.torch_model.named_modules(): 152 | if isinstance(module, nn.BatchNorm2d): 153 | name_to_modules[name].append(module.weight) 154 | layer_dist = [] 155 | for name, pack in name_to_modules.items(): 156 | weight1, weight2 = pack 157 | weight1 = weight1.view(-1) 158 | weight2 = weight2.view(-1) 159 | dist = nn.CosineSimilarity(dim=0)(weight1, weight2) 160 | layer_dist.append(dist.item()) 161 | sim = np.mean(layer_dist) 162 | 163 | self.logger.info(f' model similarity: {sim}') 164 | return sim 165 | 166 | def compute_similarity_with_conv_bn_weight(self): 167 | name_to_modules = {} 168 | for name, module in self.model1.torch_model.named_modules(): 169 | if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.Conv2d): 170 | name_to_modules[name] = [module.weight] 171 | for name, module in self.model2.torch_model.named_modules(): 172 | if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.Conv2d): 173 | name_to_modules[name].append(module.weight) 174 | layer_dist = [] 175 | for name, pack in name_to_modules.items(): 176 | weight1, weight2 = pack 177 | weight1 = 
weight1.view(-1) 178 | weight2 = weight2.view(-1) 179 | dist = nn.CosineSimilarity(dim=0)(weight1, weight2) 180 | layer_dist.append(dist.item()) 181 | sim = np.mean(layer_dist) 182 | 183 | self.logger.info(f' model similarity: {sim}') 184 | return sim 185 | 186 | def compute_similarity_with_identical_weight(self): 187 | name_to_modules = {} 188 | for name, module in self.model1.torch_model.named_modules(): 189 | if isinstance(module, nn.Conv2d): 190 | name_to_modules[name] = [module.weight] 191 | for name, module in self.model2.torch_model.named_modules(): 192 | if isinstance(module, nn.Conv2d): 193 | name_to_modules[name].append(module.weight) 194 | layer_dist = [] 195 | for name, pack in name_to_modules.items(): 196 | weight1, weight2 = pack 197 | identical = (weight1==weight2).sum() 198 | dist = float(identical / weight1.numel()) 199 | layer_dist.append(dist) 200 | 201 | sim = np.mean(layer_dist) 202 | 203 | self.logger.info(f' model similarity: {sim}') 204 | return sim 205 | 206 | def compute_similarity_with_whole_weight(self): 207 | name_to_modules = {} 208 | for name, module in self.model1.torch_model.named_modules(): 209 | if isinstance(module, nn.Conv2d): 210 | name_to_modules[name] = [module.weight] 211 | for name, module in self.model2.torch_model.named_modules(): 212 | if isinstance(module, nn.Conv2d): 213 | name_to_modules[name].append(module.weight) 214 | model1_weight, model2_weight = [], [] 215 | for name, pack in name_to_modules.items(): 216 | weight1, weight2 = pack 217 | if (weight1==weight2).all(): 218 | continue 219 | weight1 = weight1.view(-1) 220 | weight2 = weight2.view(-1) 221 | model1_weight.append(weight1) 222 | model2_weight.append(weight2) 223 | model1_weight = torch.cat(model1_weight) 224 | model2_weight = torch.cat(model2_weight) 225 | sim = nn.CosineSimilarity(dim=0)(model1_weight, model2_weight).item() 226 | 227 | self.logger.info(f' model similarity: {sim}') 228 | return sim 229 | 230 | def compute_similarity_with_feature(self, profiling_inputs): 231 | # Used to matching features 232 | def record_act(self, input, output): 233 | self.out = output 234 | 235 | name_to_modules = {} 236 | for name, module in self.model1.torch_model.named_modules(): 237 | if isinstance(module, nn.Conv2d): 238 | name_to_modules[name] = [module] 239 | module.register_forward_hook(record_act) 240 | for name, module in self.model2.torch_model.named_modules(): 241 | if isinstance(module, nn.Conv2d): 242 | name_to_modules[name].append(module) 243 | module.register_forward_hook(record_act) 244 | # print(name_to_modules.keys()) 245 | self.model1.batch_forward(profiling_inputs) 246 | self.model2.batch_forward(profiling_inputs) 247 | 248 | feature_dists = [] 249 | b = profiling_inputs.shape[0] 250 | for name, pack in name_to_modules.items(): 251 | module1, module2 = pack 252 | feature1 = module1.out.view(-1) 253 | feature2 = module2.out.view(-1) 254 | dist = nn.CosineSimilarity(dim=0)(feature1, feature2).item() 255 | feature_dists.append(dist) 256 | del module1.out, module2.out, feature1, feature2 257 | sim = np.mean(feature_dists) 258 | 259 | 260 | self.logger.info(f' model similarity: {sim}') 261 | return sim 262 | 263 | def compute_similarity_with_last_feature(self, profiling_inputs): 264 | # Used to matching features 265 | def record_act(self, input, output): 266 | self.out = output 267 | 268 | name_to_modules = {} 269 | for name, module in self.model1.torch_model.named_modules(): 270 | if isinstance(module, nn.Conv2d): 271 | module1 = module 272 | for name, module in 
self.model2.torch_model.named_modules(): 273 | if isinstance(module, nn.Conv2d): 274 | module2 = module 275 | module1.register_forward_hook(record_act) 276 | module2.register_forward_hook(record_act) 277 | # print(name_to_modules.keys()) 278 | self.model1.batch_forward(profiling_inputs) 279 | self.model2.batch_forward(profiling_inputs) 280 | 281 | feature1 = module1.out.view(-1) 282 | feature2 = module2.out.view(-1) 283 | dist = nn.CosineSimilarity(dim=0)(feature1, feature2).item() 284 | del module1.out, module2.out, feature1, feature2 285 | sim = dist 286 | 287 | self.logger.info(f' model similarity: {sim}') 288 | return sim 289 | 290 | def compute_similarity_with_last_feature_svd(self, profiling_inputs): 291 | # Used to matching features 292 | def record_act(self, input, output): 293 | self.out = output 294 | 295 | name_to_modules = {} 296 | for name, module in self.model1.torch_model.named_modules(): 297 | if isinstance(module, nn.Conv2d): 298 | module1 = module 299 | for name, module in self.model2.torch_model.named_modules(): 300 | if isinstance(module, nn.Conv2d): 301 | module2 = module 302 | module1.register_forward_hook(record_act) 303 | module2.register_forward_hook(record_act) 304 | # print(name_to_modules.keys()) 305 | self.model1.batch_forward(profiling_inputs) 306 | self.model2.batch_forward(profiling_inputs) 307 | 308 | 309 | feature1 = module1.out 310 | feature2 = module2.out 311 | b, c, _, _ = feature1.shape 312 | feature1 = feature1.view(b,c,-1) 313 | feature2 = feature2.view(b,c,-1) 314 | for i in range(b): 315 | u1,s1,v1 = torch.svd(feature1[i]) 316 | u2,s2,v2 = torch.svd(feature2[i]) 317 | st() 318 | dist = nn.CosineSimilarity(dim=0)(feature1, feature2).item() 319 | del module1.out, module2.out, feature1, feature2 320 | sim = dist 321 | 322 | self.logger.info(f' model similarity: {sim}') 323 | return sim 324 | 325 | def compute_similarity_with_ddv(self, profiling_inputs): 326 | self.logger.info(f'computing DDVs') 327 | ddv1 = self.compute_ddv(self.model1, profiling_inputs) 328 | ddv2 = self.compute_ddv(self.model2, profiling_inputs) 329 | self.logger.info(f' DDV computed: shape={ddv1.shape} and {ddv2.shape}') 330 | # print(f' ddv1={ddv1}\n ddv2={ddv2}') 331 | 332 | self.logger.info(f'measuring model similarity') 333 | ddv1 = Utils.normalize(np.array(ddv1)) 334 | ddv2 = Utils.normalize(np.array(ddv2)) 335 | self.logger.debug(f' ddv1={ddv1}\n ddv2={ddv2}') 336 | ddv_distance = self.compare_ddv(ddv1, ddv2) 337 | model_similarity = 1 - ddv_distance 338 | 339 | self.logger.info(f' model similarity: {model_similarity}') 340 | return model_similarity 341 | 342 | def compute_ddv(self, model, inputs): 343 | dists = [] 344 | outputs = model.batch_forward(inputs).to('cpu').numpy() 345 | self.logger.debug(f'{model}: \n profiling_outputs={outputs.shape}\n{outputs}\n') 346 | n_pairs = int(len(list(inputs)) / 2) 347 | for i in range(n_pairs): 348 | ya = outputs[i] 349 | yb = outputs[i + n_pairs] 350 | # dist = spatial.distance.euclidean(ya, yb) 351 | dist = spatial.distance.cosine(ya, yb) 352 | dists.append(dist) 353 | return np.array(dists) 354 | 355 | def compute_similarity_with_ddm(self, profiling_inputs): 356 | self.logger.info(f'computing DDMs') 357 | ddm1 = self.compute_ddm(self.model1, profiling_inputs) 358 | ddm2 = self.compute_ddm(self.model2, profiling_inputs) 359 | self.logger.info(f' DDM computed: shape={ddm1.shape} and {ddm2.shape}') 360 | # print(f' ddv1={ddv1}\n ddv2={ddv2}') 361 | 362 | self.logger.info(f'measuring model similarity') 363 | ddm_distance = 
ModelDiff.mtx_similar1(ddm1, ddm2) 364 | model_similarity = 1 - ddm_distance 365 | 366 | self.logger.info(f' model similarity: {model_similarity}') 367 | return model_similarity 368 | 369 | def compute_ddm(self, model, inputs): 370 | outputs = model.batch_forward(inputs).to('cpu').numpy() 371 | # outputs = outputs[:, :10] 372 | outputs_list = list(outputs) 373 | ddm = spatial.distance.cdist(outputs_list, outputs_list) 374 | return ddm 375 | 376 | @staticmethod 377 | def metrics_output_diversity(model, inputs, use_torch=False): 378 | outputs = model.batch_forward(inputs).to('cpu').numpy() 379 | # output_dists = [] 380 | # for i in range(0, len(outputs) - 1): 381 | # for j in range(i + 1, len(outputs)): 382 | # output_dist = spatial.distance.euclidean(outputs[i], outputs[j]) 383 | # output_dists.append(output_dist) 384 | # diversity = sum(output_dists) / len(output_dists) 385 | output_dists = spatial.distance.cdist(list(outputs), list(outputs), p=2.0) 386 | diversity = np.mean(output_dists) 387 | return diversity 388 | 389 | @staticmethod 390 | def metrics_output_variance(model, inputs, use_torch=False): 391 | batch_output = model.batch_forward(inputs).to('cpu').numpy() 392 | mean_axis = tuple(list(range(len(batch_output.shape)))[2:]) 393 | batch_output_mean = np.mean(batch_output, axis=mean_axis) 394 | # print(batch_output_mean.shape) 395 | output_variances = np.var(batch_output_mean, axis=0) 396 | # print(output_variances) 397 | return np.mean(output_variances) 398 | 399 | @staticmethod 400 | def metrics_output_range(model, inputs, use_torch=False): 401 | batch_output = model.batch_forward(inputs).to('cpu').numpy() 402 | mean_axis = tuple(list(range(len(batch_output.shape)))[2:]) 403 | batch_output_mean = np.mean(batch_output, axis=mean_axis) 404 | output_ranges = np.max(batch_output_mean, axis=0) - np.min(batch_output_mean, axis=0) 405 | return np.mean(output_ranges) 406 | 407 | @staticmethod 408 | def metrics_neuron_coverage(model, inputs, use_torch=False): 409 | module_irs = model.batch_forward_with_ir(inputs) 410 | neurons = [] 411 | neurons_covered = [] 412 | for module in module_irs: 413 | ir = module_irs[module] 414 | # print(f'{tensor["name"]} {batch_tensor_value.shape}') 415 | # if 'relu' not in tensor["name"].lower(): 416 | # continue 417 | squeeze_axis = tuple(list(range(len(ir.shape)))[:-1]) 418 | squeeze_ir = np.max(ir, axis=squeeze_axis) 419 | for i in range(squeeze_ir.shape[-1]): 420 | neuron_name = f'{module}-{i}' 421 | neurons.append(neuron_name) 422 | neuron_value = squeeze_ir[i] 423 | covered = neuron_value > 0.1 424 | if covered: 425 | neurons_covered.append(neuron_name) 426 | neurons_not_covered = [neuron for neuron in neurons if neuron not in neurons_covered] 427 | print(f'{len(neurons_not_covered)} neurons not covered: {neurons_not_covered}') 428 | return float(len(neurons_covered)) / len(neurons) 429 | 430 | @staticmethod 431 | def _compute_decision_dist_output_cos(model, xa, xb): 432 | ya = model.batch_forward(xa) 433 | yb = model.batch_forward(xb) 434 | return spatial.distance.cosine(ya, yb) 435 | 436 | @staticmethod 437 | def _gen_profiling_inputs_none(comparator, seed_inputs, use_torch=False): 438 | return seed_inputs 439 | 440 | @staticmethod 441 | def _gen_profiling_inputs_random(comparator, seed_inputs, use_torch=False): 442 | return np.random.normal(size=seed_inputs.shape).astype(np.float32) 443 | 444 | # @staticmethod 445 | # def _gen_profiling_inputs_1pixel(comparator, seed_inputs): 446 | # input_shape = seed_inputs[0].shape 447 | # for i in 
range(len(seed_inputs)):
448 |     #         x = np.zeros(input_shape, dtype=np.float32)
449 |     #         random_index = np.unravel_index(np.argmax(np.random.normal(size=input_shape)), input_shape)
450 |     #         x[random_index] = 1
451 |     #         yield x
452 | 
453 |     @staticmethod
454 |     def _gen_profiling_inputs_search(comparator, seed_inputs, use_torch=False, epsilon=0.2):
455 |         input_shape = seed_inputs[0].shape
456 |         n_inputs = seed_inputs.shape[0]
457 |         max_iterations = 1000
458 |         max_steps = 10
459 |         model1 = comparator.model1
460 |         model2 = comparator.model2
461 | 
462 |         ndims = np.prod(input_shape)
463 |         # mutate_positions = torch.randperm(ndims)
464 | 
465 |         initial_outputs1 = model1.batch_forward(seed_inputs).to('cpu').numpy()
466 |         initial_outputs2 = model2.batch_forward(seed_inputs).to('cpu').numpy()
467 | 
468 |         def evaluate_inputs(inputs):
469 |             outputs1 = model1.batch_forward(inputs).to('cpu').numpy()
470 |             outputs2 = model2.batch_forward(inputs).to('cpu').numpy()
471 |             metrics1 = comparator.input_metrics(comparator.model1, inputs)
472 |             metrics2 = comparator.input_metrics(comparator.model2, inputs)
473 | 
474 |             output_dist1 = np.mean(spatial.distance.cdist(
475 |                 list(outputs1),
476 |                 list(initial_outputs1),
477 |                 p=2).diagonal())
478 |             output_dist2 = np.mean(spatial.distance.cdist(
479 |                 list(outputs2),
480 |                 list(initial_outputs2),
481 |                 p=2).diagonal())
482 |             print(f' output distance: {output_dist1},{output_dist2}')
483 |             print(f' metrics: {metrics1},{metrics2}')
484 |             # if mutated_metrics <= metrics:
485 |             #     break
486 |             return output_dist1 * output_dist2 * metrics1 * metrics2
487 | 
488 |         inputs = seed_inputs
489 |         score = evaluate_inputs(inputs)
490 |         print(f'score={score}')
491 | 
492 |         for i in range(max_iterations):
493 |             comparator._compute_distance(inputs)
494 |             print(f'mutation {i}-th iteration')
495 |             # mutation_idx = random.randint(0, len(inputs))
496 |             # mutation = np.random.random_sample(size=input_shape).astype(np.float32)
497 | 
498 |             mutation_pos = np.random.randint(0, ndims)
499 |             mutation = np.zeros(ndims).astype(np.float32)
500 |             mutation[mutation_pos] = epsilon
501 |             mutation = np.reshape(mutation, input_shape)
502 | 
503 |             mutation_batch = np.zeros(shape=inputs.shape).astype(np.float32)
504 |             mutation_idx = np.random.randint(0, n_inputs)
505 |             mutation_batch[mutation_idx] = mutation
506 | 
507 |             # print(f'{inputs.shape} {mutation_perturbation.shape}')
508 |             # for j in range(max_steps):
509 |             #     mutated_inputs = np.clip(inputs + mutation, 0, 1)
510 |             # print(f'{list(inputs)[0].shape}')
511 |             mutate_right_inputs = inputs + mutation_batch
512 |             mutate_right_score = evaluate_inputs(mutate_right_inputs)
513 |             mutate_left_inputs = inputs - mutation_batch
514 |             mutate_left_score = evaluate_inputs(mutate_left_inputs)
515 | 
516 |             if mutate_right_score <= score and mutate_left_score <= score:
517 |                 continue
518 |             if mutate_right_score > mutate_left_score:
519 |                 print(f'mutate right: {score}->{mutate_right_score}')
520 |                 inputs = mutate_right_inputs
521 |                 score = mutate_right_score
522 |             else:
523 |                 print(f'mutate left: {score}->{mutate_left_score}')
524 |                 inputs = mutate_left_inputs
525 |                 score = mutate_left_score
526 |         return inputs
527 | 
528 |     @staticmethod
529 |     def _compare_ddv_cos(ddv1, ddv2):
530 |         return spatial.distance.cosine(ddv1, ddv2)
531 | 
532 |     @staticmethod
533 |     def mtx_similar1(arr1:np.ndarray, arr2:np.ndarray) -> float:
534 |         '''
535 |         One way to measure matrix similarity: flatten both matrices into vectors, then divide their inner product by the product of their norms.
536 |         Note that a flattening step is involved.
537 |         :param arr1: matrix 1
538 |         :param arr2: matrix 2
539 |         :return: effectively the cosine of the angle between the two vectors, ret = (cos+1)/2
540 |         '''
541 |         farr1 = arr1.ravel()
542 |         farr2 = arr2.ravel()
543 |         len1 = len(farr1)
544 |         len2 = len(farr2)
545 |         if len1 > len2:
546 |             farr1 = farr1[:len2]
547 |         else:
548 |             farr2 = farr2[:len1]
549 | 
550 |         numer = np.sum(farr1 * farr2)
551 |         denom = np.sqrt(np.sum(farr1**2) * np.sum(farr2**2))
552 |         similar = numer / denom  # this is effectively the cosine of the angle
553 |         return (similar+1) / 2  # roughly treat the cosine as linear to map it into [0, 1]
554 | 
555 |     def mtx_similar2(arr1:np.ndarray, arr2:np.ndarray) -> float:
556 |         '''
557 |         Compute similarity relative to matrix 1: subtract the two matrices, square the elements and sum them; the more similar the matrices are, the more entries are close to zero.
558 |         If the matrices differ in size, they are aligned at the top-left corner and cropped to the smallest common region.
559 |         :param arr1: matrix 1
560 |         :param arr2: matrix 2
561 |         :return: similarity (between 0 and 1)
562 |         '''
563 |         if arr1.shape != arr2.shape:
564 |             minx = min(arr1.shape[0],arr2.shape[0])
565 |             miny = min(arr1.shape[1],arr2.shape[1])
566 |             differ = arr1[:minx,:miny] - arr2[:minx,:miny]
567 |         else:
568 |             differ = arr1 - arr2
569 |         numera = np.sum(differ**2)
570 |         denom = np.sum(arr1**2)
571 |         similar = 1 - (numera / denom)
572 |         return similar
573 | 
574 |     def mtx_similar3(arr1:np.ndarray, arr2:np.ndarray) -> float:
575 |         '''
576 |         From CS231n: There are many ways to decide whether
577 |         two matrices are similar; one of the simplest is the Frobenius norm. In case
578 |         you haven't seen it before, the Frobenius norm of two matrices is the square
579 |         root of the squared sum of differences of all elements; in other words, reshape
580 |         the matrices into vectors and compute the Euclidean distance between them.
581 |         difference = np.linalg.norm(dists - dists_one, ord='fro')
582 |         :param arr1: matrix 1
583 |         :param arr2: matrix 2
584 |         :return: similarity (between 0 and 1)
585 |         '''
586 |         if arr1.shape != arr2.shape:
587 |             minx = min(arr1.shape[0],arr2.shape[0])
588 |             miny = min(arr1.shape[1],arr2.shape[1])
589 |             differ = arr1[:minx,:miny] - arr2[:minx,:miny]
590 |         else:
591 |             differ = arr1 - arr2
592 |         dist = np.linalg.norm(differ, ord='fro')
593 |         len1 = np.linalg.norm(arr1)
594 |         len2 = np.linalg.norm(arr2)  # plain norm
595 |         denom = (len1 + len2) / 2
596 |         similar = 1 - (dist / denom)
597 |         return similar
598 | 
599 | 
600 | def parse_args():
601 |     """
602 |     Parse command line input
603 |     :return:
604 |     """
605 |     parser = argparse.ArgumentParser(description="Compare similarity between two models.")
606 | 
607 |     parser.add_argument("-benchmark_dir", action="store", dest="benchmark_dir",
608 |                         required=False, default=".", help="Path to the benchmark.")
609 |     parser.add_argument("-model1", action="store", dest="model1",
610 |                         required=True, help="model 1.")
611 |     parser.add_argument("-model2", action="store", dest="model2",
612 |                         required=True, help="model 2.")
613 |     args, unknown = parser.parse_known_args()
614 |     return args
615 | 
616 | 
617 | def evaluate_micro_benchmark():
618 |     lines = pathlib.Path('benchmark_models/model_pairs.txt').read_text().splitlines()
619 |     eval_lines = []
620 |     for line in lines:
621 |         model1_str = line.split()[0]
622 |         model2_str = line.split()[2]
623 |         model1_path = os.path.join('benchmark_models', f'{model1_str}.h5')
624 |         model2_path = os.path.join('benchmark_models', f'{model2_str}.h5')
625 |         model1 = Model(model1_path)
626 |         model2 = Model(model2_path)
627 |         comparison = ModelDiff(model1, model2)
628 |         similarity = comparison.compare()
629 |         eval_line = f'{model1_str} {model2_str} {similarity}'
630 |         eval_lines.append(eval_line)
631 |         print(eval_line)
632 |     pathlib.Path('benchmark_models/model_pairs_eval.txt').write_text('\n'.join(eval_lines))
633 | 
634 | 
635 | def main():
636 |     logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
637 |     args = parse_args()
638 | from benchmark import ImageBenchmark 639 | bench = ImageBenchmark( 640 | datasets_dir=os.path.join(args.benchmark_dir, 'data'), 641 | models_dir=os.path.join(args.benchmark_dir, 'models') 642 | ) 643 | model1 = None 644 | model2 = None 645 | model_strs = [] 646 | for model_wrapper in bench.list_models(): 647 | if not model_wrapper.torch_model_exists(): 648 | continue 649 | if model_wrapper.__str__() == args.model1: 650 | model1 = model_wrapper 651 | if model_wrapper.__str__() == args.model2: 652 | model2 = model_wrapper 653 | model_strs.append(model_wrapper.__str__()) 654 | if model1 is None or model2 is None: 655 | print(f'model not found: {args.model1} {args.model2}') 656 | print(f'find models in the list:') 657 | print('\n'.join(model_strs)) 658 | return 659 | comparison = ModelDiff(model1, model2) 660 | similarity = comparison.compare() 661 | print(f'the similarity is {similarity}') 662 | # evaluate_micro_benchmark() 663 | 664 | 665 | if __name__ == '__main__': 666 | main() 667 | 668 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## About 2 | This is the artifact associated with our paper "ModelDiff: Testing-based DNN Similarity Comparison for Model Reuse Detection". 3 | 4 | ModelDiff is a testing-based approach to deep learning model similarity comparison. Instead of directly comparing the weights, activations, or outputs of two models, ModelDiff compares their behavioral patterns on the same set of test inputs. Specifically, the behavioral pattern of a model is represented as a decision distance vector (DDV), in which each element is the distance between the model's reactions to a pair of inputs. The knowledge similarity between two models is measured with the cosine similarity between their DDVs. 
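For readers who prefer code, a minimal NumPy/SciPy sketch of this computation, mirroring `compute_ddv` and `_compare_ddv_cos` in modeldiff.py; the `outputs1`/`outputs2` arrays stand in for the two models' outputs on the same set of 2*n_pairs profiling inputs:
```
import numpy as np
from scipy import spatial

def compute_ddv(outputs):
    # Element i of the DDV is the distance between the model's outputs
    # for the i-th input pair (inputs i and i + n_pairs).
    n_pairs = len(outputs) // 2
    return np.array([spatial.distance.cosine(outputs[i], outputs[i + n_pairs])
                     for i in range(n_pairs)])

def ddv_similarity(outputs1, outputs2):
    # Normalize each DDV, then report 1 - cosine distance as the similarity score.
    ddv1 = compute_ddv(outputs1)
    ddv2 = compute_ddv(outputs2)
    ddv1 = ddv1 / np.linalg.norm(ddv1)
    ddv2 = ddv2 / np.linalg.norm(ddv2)
    return 1 - spatial.distance.cosine(ddv1, ddv2)
```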
5 | 6 | ## Environment 7 | - Ubuntu 16.04 8 | - CUDA 10.0 9 | 10 | ## Dependencies 11 | - PyTorch 1.5.0 12 | - TorchVision 0.6.0 13 | - AdverTorch 0.2.0 14 | 15 | ## Get start 16 | - You should have a GPU on your device because the adversarial sample computation is pretty slow 17 | - You should first install CUDA 10.2 on your device (if you don't have) from [here](https://developer.nvidia.com/cuda-downloads) 18 | - Install [Anaconda](https://www.anaconda.com/) and create a new environment and enter the environment 19 | ``` 20 | conda create --name modeldiff python=3.6 21 | ``` 22 | - Install pytorch in the new environment 23 | ``` 24 | conda install pytorch==1.5.0 torchvision==0.6.0 cudatoolkit=10.2 -c pytorch 25 | ``` 26 | - Install AdvTorch 27 | ``` 28 | pip install advertorch 29 | ``` 30 | - Install other packages 31 | 32 | ``` 33 | pip install scipy 34 | ``` 35 | - Make a new directory called ``data`` and Download all three datasets listed below in the ``data`` directory 36 | ``` 37 | data\ 38 | |--- CUB_200_2011/ 39 | |--- stanford_dog/ 40 | |--- MIT_67/ 41 | ``` 42 | 43 | 44 | 45 | ## Prepare dataset 46 | 47 | ### [Caltech-UCSD 200 Birds](http://www.vision.caltech.edu/visipedia/CUB-200.html) 48 | Layout should be the following for the dataloader to load correctly 49 | 50 | ``` 51 | CUB_200_2011/ 52 | | README 53 | | bounding_boxes.txt 54 | | classes.txt 55 | | image_class_labels.txt 56 | | images.txt 57 | | train_test_split.txt 58 | |--- attributes 59 | |--- images/ 60 | |--- parts/ 61 | |--- train/ 62 | |--- test/ 63 | ``` 64 | 65 | ### [Stanford 120 Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/) 66 | ``` 67 | stanford_dog/ 68 | | file_list.mat 69 | | test_list.mat 70 | | train_list.mat 71 | |--- train/ 72 | |--- test/ 73 | |--- Images/ 74 | |--- Annotation/ 75 | ``` 76 | 77 | 78 | ### [MIT 67 Indoor Scenes](http://web.mit.edu/torralba/www/indoor.html) 79 | ``` 80 | MIT_67/ 81 | | TrainImages.txt 82 | | TestImages.txt 83 | |--- Annotations/ 84 | |--- Images/ 85 | |--- test/ 86 | |--- train/ 87 | ``` 88 | 89 | ## Prepare models 90 | You can change the size of the benchmark and the number of models to use in benchmark.py. The models used in the paper are MobileNetV2 and ResNet18 trained on Flower102 and StanfordDogs120 datasets. You can add other architectures and datasets the ImageBenchmark class of benchmark.py (line 487 to line 503 as following). 91 | ``` 92 | # Used in the paper 93 | self.datasets = ['Flower102', 'SDog120'] 94 | self.archs = ['mbnetv2', 'resnet18'] 95 | # Other archs 96 | # self.datasets = ['MIT67', 'Flower102', 'SDog120'] 97 | # self.archs = ['mbnetv2', 'resnet18', 'vgg16_bn', 'vgg11_bn', 'resnet34', 'resnet50'] 98 | # For debug 99 | # self.datasets = ['Flower102'] 100 | # self.archs = ['resnet18'] 101 | ``` 102 | 103 | We also provide the benchmark used in the paper and you can download it from [google drive](https://drive.google.com/file/d/1UfhnPB2V2bpwpWxnne1bodI1cIT3q98c/view?usp=sharing). 104 | 105 | ## Evaluation 106 | The code to compare DDV (decision distance vector) model similarity is in evaluate.ipynb. It loads the benchmark models from benchmark.py and compare similarity. 
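A single pair of benchmark models can also be compared from the command line with modeldiff.py; the model name strings below are placeholders, and the script prints the list of available model names if one is not found.
```
python modeldiff.py -benchmark_dir . -model1 <model1_name> -model2 <model2_name>
```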
107 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | import argparse 6 | from pdb import set_trace as st 7 | import json 8 | import functools 9 | 10 | import torch 11 | import numpy as np 12 | import torchvision 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | 17 | from torchvision import transforms 18 | 19 | 20 | class MovingAverageMeter(object): 21 | """Computes and stores the average and current value""" 22 | def __init__(self, name, fmt=':f', momentum=0.9): 23 | self.name = name 24 | self.fmt = fmt 25 | self.momentum = momentum 26 | self.reset() 27 | 28 | def reset(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | 33 | def update(self, val, n=1): 34 | self.val = val 35 | self.avg = self.momentum*self.avg + (1-self.momentum)*val 36 | 37 | def __str__(self): 38 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 39 | return fmtstr.format(**self.__dict__) 40 | 41 | class ProgressMeter(object): 42 | def __init__(self, num_batches, meters, prefix="", output_dir=None): 43 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 44 | self.meters = meters 45 | self.prefix = prefix 46 | if output_dir is not None: 47 | self.filepath = osp.join(output_dir, "progress") 48 | 49 | def display(self, batch): 50 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 51 | entries += [str(meter) for meter in self.meters] 52 | log_str = '\t'.join(entries) 53 | print(log_str) 54 | # if self.filepath is not None: 55 | # with open(self.filepath, "a") as f: 56 | # f.write(log_str+"\n") 57 | 58 | def _get_batch_fmtstr(self, num_batches): 59 | num_digits = len(str(num_batches // 1)) 60 | fmt = '{:' + str(num_digits) + 'd}' 61 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 62 | 63 | class CrossEntropyLabelSmooth(nn.Module): 64 | def __init__(self, num_classes, epsilon = 0.1): 65 | super(CrossEntropyLabelSmooth, self).__init__() 66 | self.num_classes = num_classes 67 | self.epsilon = epsilon 68 | self.logsoftmax = nn.LogSoftmax(dim=1) 69 | 70 | def forward(self, inputs, targets): 71 | log_probs = self.logsoftmax(inputs) 72 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 73 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 74 | loss = (-targets * log_probs).sum(1) 75 | return loss.mean() 76 | 77 | 78 | def linear_l2(model, beta_lmda): 79 | beta_loss = 0 80 | for m in model.modules(): 81 | if isinstance(m, nn.Linear): 82 | beta_loss += (m.weight).pow(2).sum() 83 | beta_loss += (m.bias).pow(2).sum() 84 | return 0.5*beta_loss*beta_lmda, beta_loss 85 | 86 | 87 | def l2sp(model, reg): 88 | reg_loss = 0 89 | dist = 0 90 | for m in model.modules(): 91 | if hasattr(m, 'weight') and hasattr(m, 'old_weight'): 92 | diff = (m.weight - m.old_weight).pow(2).sum() 93 | dist += diff 94 | reg_loss += diff 95 | 96 | if hasattr(m, 'bias') and hasattr(m, 'old_bias'): 97 | diff = (m.bias - m.old_bias).pow(2).sum() 98 | dist += diff 99 | reg_loss += diff 100 | 101 | if dist > 0: 102 | dist = dist.sqrt() 103 | 104 | loss = (reg * reg_loss) 105 | return loss, dist 106 | 107 | 108 | def advtest_fast(model, loader, adversary, args): 109 | advDataset = torch.load(args.adv_data_dir) 110 | test_loader = torch.utils.data.DataLoader( 111 | advDataset, 112 | batch_size=4, shuffle=False, 113 | 
num_workers=0, pin_memory=False) 114 | model.eval() 115 | 116 | total_ce = 0 117 | total = 0 118 | top1 = 0 119 | 120 | total = 0 121 | top1_clean = 0 122 | top1_adv = 0 123 | adv_success = 0 124 | adv_trial = 0 125 | for i, (batch, label, adv_batch, adv_label) in enumerate(test_loader): 126 | batch, label = batch.to('cuda'), label.to('cuda') 127 | adv_batch = adv_batch.to('cuda') 128 | 129 | total += batch.size(0) 130 | out_clean = model(batch) 131 | 132 | # if 'mbnetv2' in args.network: 133 | # y = torch.zeros(batch.shape[0], model.classifier[1].in_features).cuda() 134 | # else: 135 | # y = torch.zeros(batch.shape[0], model.fc.in_features).cuda() 136 | 137 | # y[:,0] = args.m 138 | # advbatch = adversary.perturb(batch, y) 139 | 140 | out_adv = model(adv_batch) 141 | 142 | _, pred_clean = out_clean.max(dim=1) 143 | _, pred_adv = out_adv.max(dim=1) 144 | 145 | clean_correct = pred_clean.eq(label) 146 | adv_trial += int(clean_correct.sum().item()) 147 | adv_success += int(pred_adv[clean_correct].eq(label[clean_correct]).sum().detach().item()) 148 | top1_clean += int(pred_clean.eq(label).sum().detach().item()) 149 | top1_adv += int(pred_adv.eq(label).sum().detach().item()) 150 | 151 | # print('{}/{}...'.format(i+1, len(test_loader))) 152 | print(f"Finish adv test fast") 153 | del test_loader 154 | del advDataset 155 | return float(top1_clean)/total*100, float(top1_adv)/total*100, float(adv_trial-adv_success) / adv_trial *100 156 | 157 | 158 | def lazy_property(func): 159 | attribute = '_lazy_' + func.__name__ 160 | 161 | @property 162 | @functools.wraps(func) 163 | def wrapper(self): 164 | if not hasattr(self, attribute): 165 | setattr(self, attribute, func(self)) 166 | return getattr(self, attribute) 167 | 168 | return wrapper 169 | 170 | 171 | class Utils: 172 | _instance = None 173 | 174 | def __init__(self): 175 | self.cache = {} 176 | 177 | @staticmethod 178 | def _get_instance(): 179 | if Utils._instance is None: 180 | Utils._instance = Utils() 181 | return Utils._instance 182 | 183 | @staticmethod 184 | def show_images(images, labels, title='examples'): 185 | plt.figure(figsize=(10,10)) 186 | plt.subplots_adjust(hspace=0.2) 187 | for n in range(25): 188 | plt.subplot(5,5,n+1) 189 | img = images[n] 190 | img = img.numpy().squeeze() 191 | plt.imshow(img) 192 | plt.title(f'{labels[n]}') 193 | plt.axis('off') 194 | _ = plt.suptitle(title) 195 | plt.show() 196 | 197 | @staticmethod 198 | def copy_weights(source_model, target_model): 199 | # print(source_model.summary()) 200 | # print(target_model.summary()) 201 | for i, layer in enumerate(target_model.layers): 202 | if not layer.get_weights(): 203 | continue 204 | source_layer = source_model.get_layer(layer.name) 205 | # print(layer) 206 | # print(source_layer) 207 | layer.set_weights(source_layer.get_weights()) 208 | return target_model 209 | 210 | @staticmethod 211 | def normalize(v): 212 | norm = np.linalg.norm(v) 213 | if norm == 0: 214 | return v 215 | return v / norm 216 | 217 | -------------------------------------------------------------------------------- /weight_pruner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import time 5 | import argparse 6 | from pdb import set_trace as st 7 | import json 8 | import random 9 | 10 | import torch 11 | import numpy as np 12 | import torchvision 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | 17 | from torchvision import transforms 18 | 19 | 
from dataset.cub200 import CUB200Data
from dataset.mit67 import MIT67
from dataset.stanford_dog import SDog120
from dataset.caltech256 import Caltech257Data
from dataset.stanford_40 import Stanford40Data
from dataset.flower102 import Flower102

from model.fe_resnet import resnet18_dropout, resnet50_dropout, resnet101_dropout
from model.fe_mobilenet import mbnetv2_dropout
from model.fe_resnet import feresnet18, feresnet50, feresnet101
from model.fe_mobilenet import fembnetv2

from utils import *
from finetuner import Finetuner


class WeightPruner(Finetuner):
    """Finetuner that first zeroes out the smallest-magnitude conv weights globally
    (or a random subset of the same size, if requested) before fine-tuning."""
    def __init__(
        self,
        args,
        model,
        teacher,
        train_loader,
        test_loader,
    ):
        super(WeightPruner, self).__init__(
            args, model, teacher, train_loader, test_loader
        )
        assert self.args.weight_ratio >= 0
        self.log_path = osp.join(self.args.output_dir, "prune.log")
        self.logger = open(self.log_path, "w")
        self.init_prune()
        self.logger.close()

    def prune_record(self, log):
        print(log)
        self.logger.write(log + "\n")

    def init_prune(self):
        ratio = self.args.weight_ratio
        log = f"Init prune ratio {ratio:.2f}"
        self.prune_record(log)
        self.weight_prune(ratio)
        self.check_param_num()

    def check_param_num(self):
        model = self.model
        total = sum([module.weight.nelement() for module in model.modules() if isinstance(module, nn.Conv2d)])
        num = total
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                num -= int((m.weight.data == 0).sum())
        ratio = (total - num) / total
        log = f"===>Check: Total {total}, current {num}, prune ratio {ratio:.2f}"
        self.prune_record(log)

    def weight_prune(
        self,
        prune_ratio,
        random_prune=False,
    ):
        model = self.model.cpu()

        # Collect the magnitudes of all conv weights into a single vector.
        total = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                total += module.weight.data.numel()

        conv_weights = torch.zeros(total)
        index = 0
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                size = module.weight.data.numel()
                conv_weights[index:(index+size)] = module.weight.data.view(-1).abs().clone()
                index += size

        # The pruning threshold is the magnitude at the prune_ratio quantile.
        y, i = torch.sort(conv_weights)
        thre_index = int(total * prune_ratio)
        thre = y[thre_index]
        log = f"Pruning threshold: {thre:.4f}"
        self.prune_record(log)

        pruned = 0
        zero_flag = False

        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d):
                weight_copy = module.weight.data.abs().clone()
                mask = weight_copy.gt(thre).float()

                if random_prune:
                    # Random-pruning baseline: zero out the same fraction of
                    # weights, but choose them uniformly at random.
                    print(f"Random prune {name}")
                    mask = np.zeros(weight_copy.numel()) + 1
                    prune_number = round(prune_ratio * weight_copy.numel())
                    mask[:prune_number] = 0
                    np.random.shuffle(mask)
                    mask = mask.reshape(weight_copy.shape)
                    mask = torch.Tensor(mask)

                pruned = pruned + mask.numel() - torch.sum(mask)
                module.weight.data.mul_(mask)
                if int(torch.sum(mask)) == 0:
                    zero_flag = True
                remain_ratio = int(torch.sum(mask)) / mask.numel()
                log = (f"layer {name} \t total params: {mask.numel()} \t "
                       f"remaining params: {int(torch.sum(mask))}({remain_ratio:.2f})")
                self.prune_record(log)

        if zero_flag:
            raise RuntimeError("There exists a layer with 0 parameters left.")
        log = (f"Total conv params: {total}, Pruned conv params: {pruned}, "
               f"Pruned ratio: {pruned/total:.2f}")
        self.prune_record(log)
        self.model = model.cuda()

    def final_check_param_num(self):
        self.logger = open(self.log_path, "a")
        self.check_param_num()
        self.logger.close()
--------------------------------------------------------------------------------
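
For readers who want the pruning step in isolation: the snippet below is a minimal sketch of the global magnitude pruning that `WeightPruner.weight_prune` performs (take the prune-ratio quantile of all conv weight magnitudes as a threshold, then zero everything at or below it). The helper name `prune_conv_weights` and the torchvision `resnet18` are illustrative choices, not part of the repository; the class above additionally logs per-layer statistics and supports a random-pruning baseline.

```
import torch
import torch.nn as nn
from torchvision.models import resnet18

def prune_conv_weights(model, prune_ratio):
    # Gather the magnitudes of all Conv2d weights into one flat vector.
    magnitudes = torch.cat([m.weight.data.abs().view(-1)
                            for m in model.modules() if isinstance(m, nn.Conv2d)])
    # The global threshold is the magnitude at the prune_ratio quantile.
    threshold = magnitudes.sort()[0][int(magnitudes.numel() * prune_ratio)]
    # Zero out every conv weight whose magnitude is at or below the threshold.
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.mul_(m.weight.data.abs().gt(threshold).float())
    return model

pruned_model = prune_conv_weights(resnet18(pretrained=False), prune_ratio=0.5)
```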
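Similarly, the regularizers in utils.py (`CrossEntropyLabelSmooth`, `l2sp`, `linear_l2`) are meant to be combined into a single fine-tuning loss. The snippet below is only a usage sketch under our own assumptions: `l2sp` expects `old_weight`/`old_bias` snapshots of the source parameters to be attached to each module (the repository's fine-tuning code handles this; here we register them by hand), and the model, class count, and hyperparameter values are placeholders.

```
import torch
import torch.nn as nn
from torchvision.models import resnet18

from utils import CrossEntropyLabelSmooth, l2sp, linear_l2

model = resnet18(pretrained=False)

# l2sp() compares each module's current weight/bias against `old_weight`/`old_bias`,
# so snapshot the starting (source) parameters before touching the model.
for m in model.modules():
    if hasattr(m, 'weight') and isinstance(getattr(m, 'weight'), torch.Tensor):
        m.old_weight = m.weight.data.clone()
    if hasattr(m, 'bias') and isinstance(getattr(m, 'bias'), torch.Tensor):
        m.old_bias = m.bias.data.clone()

# Replace the head for the target task; the new layer has no old_weight snapshot,
# so l2sp skips it and linear_l2 regularizes it instead.
model.fc = nn.Linear(model.fc.in_features, 102)   # e.g. Flower102

criterion = CrossEntropyLabelSmooth(num_classes=102)
inputs = torch.randn(8, 3, 224, 224)
targets = torch.randint(0, 102, (8,))

ce_loss = criterion(model(inputs), targets)
sp_loss, dist = l2sp(model, reg=1e-2)             # pull backbone weights toward the source model
fc_loss, _ = linear_l2(model, beta_lmda=1e-2)     # plain L2 on the classifier head
loss = ce_loss + sp_loss + fc_loss
loss.backward()
```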