├── figures └── intro.png ├── quantized_ops ├── __init__.py └── zs_quantized_ops.py ├── faultinjection_ops ├── __init__.py └── zs_faultinjection_ops.py ├── faultsMap.py ├── config.py ├── README.md ├── faultmodels └── randomfault.py ├── zs_train.py ├── models ├── vggf_pytorch.py ├── vgg.py ├── resnet.py ├── vggf.py ├── __init__.py ├── resnetf.py ├── resnetf_pytorch.py └── generator.py ├── zs_train_input_transform_eval.py ├── LICENSE ├── zs_main.py └── zs_train_input_transform_eopm_gen.py /figures/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/NeuralFuse/HEAD/figures/intro.png -------------------------------------------------------------------------------- /quantized_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .zs_quantized_ops import nnConv2dSymQuant_op 2 | from .zs_quantized_ops import nnLinearSymQuant_op 3 | -------------------------------------------------------------------------------- /faultinjection_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .zs_faultinjection_ops import nnConv2dPerturbWeight_op 2 | from .zs_faultinjection_ops import nnLinearPerturbWeight_op 3 | -------------------------------------------------------------------------------- /faultsMap.py: -------------------------------------------------------------------------------- 1 | # global variable: 2 | # For storing the result in randomfault 3 | BitErrorMap0 = None 4 | BitErrorMap1 = None 5 | 6 | # For storing the result in zs_faultinjection_ops 7 | BitErrorMap0to1 = None 8 | BitErrorMap1to0 = None 9 | 10 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict 2 | 3 | cfg = EasyDict() 4 | 5 | cfg.faulty_layers = ["linear", "conv"] 6 | 7 | cfg.batch_size = 512 8 | cfg.test_batch_size = 100 9 | cfg.epochs = 2 10 | cfg.precision = 8 11 | 12 | cfg.data_dir = ( 13 | "../dataset" 14 | ) 15 | cfg.model_dir = ( 16 | "model_weights/symmetric_signed/" 17 | ) 18 | cfg.model_dir_resnetft = ( 19 | "model_weights/symmetric_signed/resnetft/" 20 | ) 21 | cfg.model_dir_vggft = ( 22 | "model_weights/symmetric_signed/vggft/" 23 | ) 24 | cfg.save_dir = ( 25 | "~/tmp/" 26 | ) 27 | cfg.save_dir_curve = ( 28 | "~/tmp_curve/" 29 | ) 30 | 31 | cfg.channels = 3 32 | 33 | cfg.w1 = 32 # 28 #224 34 | cfg.h1 = 32 # 28 #224 35 | cfg.w2 = 32 # 32 28 #224 36 | cfg.h2 = 32 # 32 28 #224 37 | cfg.lmd = 5e-7 38 | cfg.learning_rate = 1 39 | cfg.decay = 0.96 40 | cfg.max_epoch = 1 41 | cfg.lb = 1 42 | cfg.device = None 43 | cfg.seed = 0 44 | 45 | 46 | # For EOPM 47 | cfg.N = 100 48 | cfg.randomRange = 30000 49 | cfg.totalRandom = True # True: Sample perturbed models in the range cfg.randomRange 50 | cfg.G = 'ConvL' 51 | 52 | # For transform generalization testing: 53 | cfg.beginSeed = 50000 54 | cfg.endSeed = 50010 55 | 56 | # For transform_eval 57 | cfg.testing_mode = 'generator_base' # clean / generator_base 58 | cfg.G_PATH = '' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuralFuse 2 | 3 | Official repo to reproduce the paper "[NeuralFuse: Learning to Recover the Accuracy of Access-Limited Neural Network Inference in Low-Voltage Regimes](https://arxiv.org/abs/2306.16869)." 4 | 5 | ![NeuralFuse](figures/intro.png) 6 | 7 | Deep neural networks (DNNs) have become ubiquitous in machine learning, but their energy consumption remains a notable issue. Lowering the supply voltage is an effective strategy for reducing energy consumption. However, aggressively scaling down the supply voltage can lead to accuracy degradation due to random bit flips in static random access memory (SRAM) where model parameters are stored. To address this challenge, we introduce NeuralFuse, a novel add-on module that addresses the accuracy-energy tradeoff in low-voltage regimes by learning input transformations to generate error-resistant data representations. NeuralFuse protects DNN accuracy in both nominal and low-voltage scenarios. Moreover, NeuralFuse is easy to implement and can be readily applied to DNNs with limited access, such as non-configurable hardware or remote access to cloud-based APIs. Experimental results demonstrate that, at a 1% bit error rate, NeuralFuse can reduce SRAM memory access energy by up to 24% while recovering accuracy by up to 57%. To the best of our knowledge, this is the first model-agnostic approach (i.e., no model retraining) to address low-voltage-induced bit errors. 8 | 9 | ## Run Base Model Training with QAT: 10 | ```python 11 | python zs_main.py [resnet18 | resnet50 | vgg11 | vgg16 | vgg19] train [cifar10 | gtsrb | cifar100 | imagenet224] -E 300 -LR 0.001 -BS 256 12 | ``` 13 | 14 | ## Run NerualFuse Training: 15 | ```python 16 | python zs_main.py [resnet18 | resnet50 | vgg11 | vgg16 | vgg19] transform_eopm_gen [cifar10 | gtsrb | cifar100 | imagenet224] -ber 0.01 -cp [please input the model path here] -E 300 -LR 0.001 -BS 256 -LM 5 -N 10 -G [ConvL | ConvS | DeConvL | DeConvS | UNetL | UNetS] 17 | ``` 18 | 19 | ## Run NerualFuse Evaluation with Perturbed Base Model: 20 | Please set the setting in config.py first. 21 | 22 | For example, in config.py: 23 | 24 | ```python 25 | ''' 26 | cfg.testing_mode: Choosing the evaluation mode. 27 | 1. clean: Evaluate the clean accuracy of specific base model. 28 | 2. generator_base: Evalueate the improved accuracy by NeuralFuse on specific perturbed base model. 29 | cfg.G_PATH: The path of the NeuralFuse model. 30 | ''' 31 | cfg.testing_mode = 'generator_base' # clean / generator_base 32 | cfg.G_PATH = '' # Only work in generator_base mode. 33 | ``` 34 | 35 | ```python 36 | python zs_main.py [resnet18 | resnet50 | vgg11 | vgg16 | vgg19] transform_eval [cifar10 | gtsrb | cifar100 | imagenet224] -ber 0.01 -cp [please input the model path here] -BS 256 -TBS 256 -G [ConvL | ConvS | DeConvL | DeConvS | UNetL | UNetS] 37 | ``` 38 | 39 | ## Arguments: 40 | * ```-ber```: Random bit error rate. Should be in range (0, 1). 41 | * ```-cp``` : Checkpoint. The path of the checkpoints of the base model. 42 | * ```-E``` : Epoch. The training epoch. 43 | * ```-LR``` : Learning rate. The learning rate for training. 44 | * ```-LM``` : Lambda. Controling the tradeoff between clean loss and perturbed loss. 45 | * ```-N``` : The number of perturbed models to calculate the loss proposed in our EOPM algorithm. 46 | * ```-BS``` : Batch size for training. 47 | * ```-TBS```: Batch size for testing. 48 | * ```-G``` : Generator. The architecture of the NeuralFuse generator for training. 49 | 50 | ## Choose Dataset: 51 | * Cifar10 52 | * Cifar100 53 | * Gtsrb 54 | * Imagenet10 55 | 56 | ## Notes: 57 | We also adopt the Pytorch offical architecture settings for all of the base models. To use the version of Pytorch offical implementation, please change the args into [resnet18Py | resnet50Py | vgg11Py | vgg16Py | vgg19Py] instead. 58 | 59 | ## Citation 60 | If you find this helpful for your research, please cite our paper as follows: 61 | 62 | @article{sun2023neuralfuse, 63 | title={{NeuralFuse: Learning to Recover the Accuracy of Access-Limited Neural Network Inference in Low-Voltage Regimes}}, 64 | author={Hao-Lun Sun and Lei Hsiung and Nandhini Chandramoorthy and Pin-Yu Chen and Tsung-Yi Ho}, 65 | journal={arXiv preprint arXiv:2306.16869}, 66 | year={2023} 67 | } 68 | 69 | -------------------------------------------------------------------------------- /faultmodels/randomfault.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # This class of model is characterized by the following parameters - 4 | # ber - bit error rate = count of faulty bits / total bits at a given voltage 5 | # prob - likelihood of a faulty bit cell being faulty -- i.e likelihood of a 6 | # bit cell being faulty on repeated access -- is it a transient fault ? 7 | # ber0 - fraction of faulty bit cells that default to 0. (ber1 = ber - ber0) 8 | # Assume each bit is likely to be faulty; ie sample from a uniform 9 | # distribution to generate a spatial distribution of faults 10 | 11 | 12 | class RandomFaultModel: 13 | MEM_ROWS = 8192 14 | MEM_COLS = 128 15 | prob = 1.0 # temporal likelihood of a given bit failing for a given access 16 | ber0 = 0.5 17 | voltage = 0 18 | # BitErrorRate = [0.01212883, 0.00397706, 0.001214473, 0.00015521, 19 | # 0.000126225, 4.06934E-05, 1.3119E-05] # Count of faulty 20 | # bits / total bits for 7 operating points. 21 | 22 | def __init__(self, ber, prec, pos, seed): 23 | self.ber = ber 24 | self.ber0 = RandomFaultModel.ber0 25 | self.precision = prec 26 | self.MEM_ROWS = RandomFaultModel.MEM_ROWS 27 | self.MEM_COLS = RandomFaultModel.MEM_COLS 28 | #print( 29 | # "Bit Error Rate %.3f Precision %d Position %d" 30 | # % (self.ber, self.precision, pos) 31 | #) 32 | if pos == -1: 33 | # self.BitErrorMap_flip0, self.BitErrorMap_flip1 = 34 | # self.ReadBitErrorMap() 35 | ( 36 | self.BitErrorMap_flip0, 37 | self.BitErrorMap_flip1, 38 | ) = self.GenBitErrorMap(seed) 39 | else: 40 | ( 41 | self.BitErrorMap_flip0, 42 | self.BitErrorMap_flip1, 43 | ) = self.GenBitPositionErrorMap(pos) 44 | 45 | def GenBitErrorMap(self, seed): 46 | bitmap = np.zeros((self.MEM_ROWS, self.MEM_COLS)) 47 | bitmap_flip0 = np.zeros((self.MEM_ROWS, self.MEM_COLS)) 48 | bitmap_flip1 = np.zeros((self.MEM_ROWS, self.MEM_COLS)) 49 | 50 | if seed is not None: 51 | np.random.seed(seed) 52 | bitmap_t = np.random.rand(self.MEM_ROWS, self.MEM_COLS) 53 | bitmap[bitmap_t < self.ber] = 1 54 | 55 | # print(bitmap) 56 | if seed is not None: 57 | np.random.seed(seed + 1) 58 | bitmap_flip = np.random.rand(self.MEM_ROWS, self.MEM_COLS) 59 | 60 | bitmap_flip0[bitmap_flip < self.ber0] = 1 61 | bitmap_flip1[bitmap_flip >= self.ber0] = 1 62 | # print(bitmap_flip0) 63 | # print(bitmap_flip1) 64 | bitmap_flip0 = bitmap * bitmap_flip0 65 | bitmap_flip1 = bitmap * bitmap_flip1 66 | 67 | bitmap_flip0 = bitmap_flip0.astype(np.int64) 68 | bitmap_flip1 = bitmap_flip1.astype(np.int64) 69 | # print(bitmap_flip0) 70 | # print(bitmap_flip1) 71 | bitcells = self.MEM_ROWS * self.MEM_COLS 72 | # print("Read 0 Bit Error Rate", sum(sum(bitmap_flip0)) / bitcells) 73 | # print("Read 1 Bit Error Rate", sum(sum(bitmap_flip1)) / bitcells) 74 | return bitmap_flip0, bitmap_flip1 75 | 76 | def GenBitPositionErrorMap(self, pos): 77 | 78 | bitmap = np.zeros((self.MEM_ROWS, self.MEM_COLS)) 79 | bitmap_flip0 = np.zeros((self.MEM_ROWS, self.MEM_COLS)) 80 | bitmap_flip1 = np.zeros((self.MEM_ROWS, self.MEM_COLS)) 81 | 82 | # Generate errors at rate ber in a specific bit position, 83 | # maximum of one error per weight in the specified position 84 | weights_per_row = int(self.MEM_COLS / self.precision) 85 | bitmap_pos = np.zeros((self.MEM_ROWS, weights_per_row)) 86 | bitmap_t = np.random.rand(self.MEM_ROWS, weights_per_row) 87 | bitmap_pos[bitmap_t < self.ber] = 1 88 | # Insert the faulty column in bit error map 89 | for k in range(0, weights_per_row): 90 | bitmap[:, k * self.precision + pos] = bitmap_pos[:, k] 91 | # print(bitmap) 92 | 93 | bitmap_flip = np.random.rand(self.MEM_ROWS, self.MEM_COLS) 94 | bitmap_flip0[bitmap_flip < self.ber0] = 1 95 | bitmap_flip1[bitmap_flip >= self.ber0] = 1 96 | # print(bitmap_flip0) 97 | # print(bitmap_flip1) 98 | bitmap_flip0 = bitmap * bitmap_flip0 99 | bitmap_flip1 = bitmap * bitmap_flip1 100 | 101 | # print(bitmap_flip0) 102 | # print(bitmap_flip1) 103 | bitmap_flip0 = bitmap_flip0.astype(np.uint32) 104 | bitmap_flip1 = bitmap_flip1.astype(np.uint32) 105 | bitcells = self.MEM_ROWS * self.MEM_COLS 106 | print( 107 | "Bit Error Rate", 108 | sum(sum(bitmap_flip0)) / bitcells 109 | + sum(sum(bitmap_flip1)) / bitcells, 110 | ) 111 | return bitmap_flip0, bitmap_flip1 112 | 113 | def ReadBitErrorMap(self): 114 | mem_voltage = self.voltage 115 | chip = "n" 116 | fname = ( 117 | "./faultmaps_chip_" 118 | + chip 119 | + "/fmap_sa0_v_" 120 | + str(mem_voltage) 121 | + ".txt" 122 | ) 123 | some_arr = np.genfromtxt(fname, dtype="uint32", delimiter=",") 124 | bitmap_flip0 = some_arr[0 : self.MEM_ROWS, 0 : self.MEM_COLS] 125 | print( 126 | "SA 0 Bit error rate", 127 | (bitmap_flip0.sum() / (self.MEM_ROWS * self.MEM_COLS)), 128 | ) 129 | fname = ( 130 | "./faultmaps_chip_" 131 | + chip 132 | + "/fmap_sa1_v_" 133 | + str(mem_voltage) 134 | + ".txt" 135 | ) 136 | some_arr = np.genfromtxt(fname, dtype="uint32", delimiter=",") 137 | bitmap_flip1 = some_arr[0 : self.MEM_ROWS, 0 : self.MEM_COLS] 138 | print( 139 | "SA 1 Bit error rate", 140 | (bitmap_flip1.sum() / (self.MEM_ROWS * self.MEM_COLS)), 141 | ) 142 | return bitmap_flip0, bitmap_flip1 143 | -------------------------------------------------------------------------------- /zs_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import torch.optim as optim 6 | from torch import nn 7 | 8 | from config import cfg 9 | from models import default_model_path, init_models_faulty, init_models 10 | 11 | __all__ = ["training"] 12 | 13 | debug = False 14 | torch.manual_seed(0) 15 | 16 | class WarmUpLR(optim.lr_scheduler._LRScheduler): 17 | """warmup_training learning rate scheduler 18 | Args: 19 | optimizer: optimzier(e.g. SGD) 20 | total_iters: totoal_iters of warmup phase 21 | """ 22 | def __init__(self, optimizer, total_iters, last_epoch=-1): 23 | 24 | self.total_iters = total_iters 25 | super().__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | """we will use the first m batches, and set the learning 29 | rate to base_lr * m / total_iters 30 | """ 31 | return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs] 32 | 33 | 34 | def training( 35 | trainloader, 36 | arch, 37 | dataset, 38 | in_channels, 39 | precision, 40 | retrain, 41 | checkpoint_path, 42 | force, 43 | device, 44 | fl, 45 | ber, 46 | pos, 47 | ): 48 | 49 | """ 50 | Apply quantization aware training. 51 | :param trainloader: The loader of training data. 52 | :param arch: A string. The architecture of the model would be used. 53 | :param dataset: A string. The name of the training data. 54 | :param in_channels: An int. The input channels of the training data. 55 | :param precision: An int. The number of bits would be used to quantize 56 | the model. 57 | :param retrain: A boolean. Start from checkpoint. 58 | :param checkpoint_path: A string. The path that stores the models. 59 | :param force: Overwrite checkpoint. 60 | :param device: A string. Specify using GPU or CPU. 61 | """ 62 | 63 | model, checkpoint_epoch = init_models(arch, 3, precision, retrain, checkpoint_path, dataset) # Quantization Aware Training without using bit error! 64 | 65 | print("Training with Learning rate %.4f" % (cfg.learning_rate)) 66 | 67 | if dataset == 'cifar100': 68 | print('cifar100') 69 | opt = optim.SGD(model.parameters(), lr=cfg.learning_rate, momentum=0.9) 70 | #iter_per_epoch = len(trainloader) 71 | #warmup_scheduler = WarmUpLR(opt, iter_per_epoch * 1) # warmup = 1 72 | #train_scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=[60, 120, 160], gamma=0.2) 73 | else: 74 | opt = optim.SGD(model.parameters(), lr=cfg.learning_rate, momentum=0.9) 75 | 76 | model = model.to(device) 77 | from torchsummary import summary 78 | if dataset == 'imagenet128': 79 | print('imagenet128') 80 | summary(model, (3, 128, 128)) 81 | elif dataset == 'imagenet224': 82 | print('imagenet224') 83 | summary(model, (3, 224, 224)) 84 | else: 85 | summary(model, (3, 32, 32)) 86 | # model = torch.nn.DataParallel(model) 87 | torch.backends.cudnn.benchmark = True 88 | 89 | for x in range(checkpoint_epoch + 1, cfg.epochs): 90 | 91 | print("Epoch: %03d" % x) 92 | 93 | running_loss = 0.0 94 | running_correct = 0 95 | for batch_id, (inputs, outputs) in enumerate(trainloader): 96 | 97 | inputs = inputs.to(device) 98 | outputs = outputs.to(device) 99 | 100 | opt.zero_grad() 101 | 102 | # Store original model parameters before 103 | # quantization/perturbation, detached from graph 104 | if precision > 0: 105 | list_init_params = [] 106 | with torch.no_grad(): 107 | for init_params in model.parameters(): 108 | list_init_params.append(init_params.clone().detach()) 109 | 110 | if debug: 111 | if batch_id % 100 == 0: 112 | print("initial params") 113 | print(model.fc2.weight[0:3, 0:3]) 114 | print(model.conv1.weight[0, 0, :, :]) 115 | 116 | model.train() 117 | model_outputs = model(inputs) # pylint: disable=E1102 118 | 119 | _, preds = torch.max(model_outputs, 1) 120 | outputs = outputs.view( 121 | outputs.size(0) 122 | ) # changing the size from (batch_size,1) to batch_size. 123 | 124 | if precision > 0: 125 | if debug: 126 | if batch_id % 100 == 0: 127 | print("quantized params") 128 | print(model.fc2.weight[0:3, 0:3]) 129 | print(model.conv1.weight[0, 0, :, :]) 130 | 131 | loss = nn.CrossEntropyLoss()(model_outputs, outputs) 132 | 133 | # Compute gradient of perturbed weights with perturbed loss 134 | loss.backward() 135 | 136 | # restore model weights with unquantized value 137 | # This step is not important because list_init_params == model.parameters() 138 | # Therefore, apply gradients on model.parameters() directly is OK. 139 | if precision > 0: 140 | with torch.no_grad(): 141 | for i, restored_params in enumerate(model.parameters()): 142 | restored_params.copy_(list_init_params[i]) 143 | 144 | if debug: 145 | if batch_id % 100 == 0: 146 | print("restored params") 147 | print(model.fc2.weight[0:3, 0:3]) 148 | print(model.conv1.weight[0, 0, :, :]) 149 | 150 | # update restored weights with gradient 151 | opt.step() 152 | #if dataset == 'cifar100': 153 | # if x <= 1: # warmup = 1 154 | # warmup_scheduler.step() 155 | # else: 156 | # train_scheduler.step() 157 | # lr_scheduler.step() 158 | 159 | running_loss += loss.item() 160 | running_correct += torch.sum(preds == outputs.data) 161 | 162 | accuracy = running_correct.double() / (len(trainloader.dataset)) 163 | print("For epoch: {}, loss: {:.6f}, accuracy: {:.5f}".format( 164 | x, 165 | running_loss / len(trainloader.dataset), 166 | accuracy 167 | ) 168 | ) 169 | if (x+1)%10 == 0: 170 | 171 | model_path = default_model_path( 172 | cfg.model_dir, arch, dataset, precision, fl, ber, pos, x+1 173 | ) 174 | 175 | if not os.path.exists(os.path.dirname(model_path)): 176 | os.makedirs(os.path.dirname(model_path)) 177 | 178 | if os.path.exists(model_path) and not force: 179 | print("Checkpoint already present ('%s')" % model_path) 180 | sys.exit(1) 181 | 182 | torch.save( 183 | { 184 | "epoch": x, 185 | "model_state_dict": model.state_dict(), 186 | "optimizer_state_dict": opt.state_dict(), 187 | "loss": running_loss / batch_id, 188 | "accuracy": accuracy, 189 | }, 190 | model_path, 191 | ) 192 | -------------------------------------------------------------------------------- /models/vggf_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from faultinjection_ops import zs_faultinjection_ops 5 | from quantized_ops import zs_quantized_ops 6 | 7 | # Per layer clamping currently based on manual values set 8 | weight_clamp_values = [ 9 | 0.2, 10 | 0.2, 11 | 0.15, 12 | 0.13, 13 | 0.1, 14 | 0.1, 15 | 0.1, 16 | 0.05, 17 | 0.05, 18 | 0.05, 19 | 0.05, 20 | 0.05, 21 | 0.05, 22 | 0.05, 23 | 0.05, 24 | 0.05, 25 | 0.05, 26 | 0.05, 27 | 0.05, 28 | ] 29 | fc_weight_clamp = 0.1 30 | 31 | 32 | class VGGPy(nn.Module): 33 | def __init__(self, features, classifier, classes=10, init_weights=True): 34 | super(VGGPy, self).__init__() 35 | self.features = features 36 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 37 | # self.avgpool = nn.AvgPool2d(kernel_size=2, stride=1) 38 | self.classifier = classifier 39 | # self.classifier = nn.Sequential( 40 | # nn.Linear(512 * 7 * 7, 4096), 41 | # nn.ReLU(True), 42 | # nn.Dropout(), 43 | # nn.Linear(4096, 4096), 44 | # nn.ReLU(True), 45 | # nn.Dropout(), 46 | # nn.Linear(4096, classes), 47 | # ) 48 | if init_weights: 49 | self._initialize_weights() 50 | 51 | def forward(self, x): 52 | x = self.features(x) 53 | x = self.avgpool(x) 54 | x = torch.flatten(x, 1) 55 | x = self.classifier(x) 56 | return x 57 | 58 | def _initialize_weights(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | nn.init.kaiming_normal_( 62 | m.weight, mode="fan_out", nonlinearity="relu" 63 | ) 64 | if m.bias is not None: 65 | nn.init.constant_(m.bias, 0) 66 | elif isinstance(m, nn.BatchNorm2d): 67 | nn.init.constant_(m.weight, 1) 68 | nn.init.constant_(m.bias, 0) 69 | elif isinstance(m, nn.Linear): 70 | nn.init.normal_(m.weight, 0, 0.01) 71 | nn.init.constant_(m.bias, 0) 72 | 73 | 74 | def make_classifierPy( 75 | classes, 76 | precision, 77 | ber, 78 | position, 79 | faulty_layers, 80 | ): 81 | if "linear" in faulty_layers: 82 | classifier = nn.Sequential( 83 | zs_faultinjection_ops.nnLinearPerturbWeight_op( 84 | 25088, 85 | 4096, 86 | precision, 87 | fc_weight_clamp, 88 | ), 89 | nn.ReLU(inplace=True), 90 | nn.Dropout(p=0.5, inplace=False), 91 | zs_faultinjection_ops.nnLinearPerturbWeight_op( 92 | 4096, 93 | 4096, 94 | precision, 95 | fc_weight_clamp, 96 | ), 97 | nn.ReLU(inplace=True), 98 | nn.Dropout(p=0.5, inplace=False), 99 | zs_faultinjection_ops.nnLinearPerturbWeight_op( 100 | 4096, 101 | classes, 102 | precision, 103 | fc_weight_clamp, 104 | ) 105 | ) 106 | else: 107 | classifier = nn.Sequential( 108 | zs_quantized_ops.nnLinearSymQuant_op( 109 | 25088, 4096, precision, fc_weight_clamp 110 | ), 111 | nn.ReLU(inplace=True), 112 | nn.Dropout(p=0.5, inplace=False), 113 | zs_quantized_ops.nnLinearSymQuant_op( 114 | 4096, 4096, precision, fc_weight_clamp 115 | ), 116 | nn.ReLU(inplace=True), 117 | nn.Dropout(p=0.5, inplace=False), 118 | zs_quantized_ops.nnLinearSymQuant_op( 119 | 4096, classes, precision, fc_weight_clamp 120 | ) 121 | ) 122 | 123 | return classifier 124 | 125 | 126 | def make_layersPy( 127 | cfg, 128 | in_channels, 129 | batch_norm, 130 | precision, 131 | ber, 132 | position, 133 | faulty_layers, 134 | ): 135 | layers = [] 136 | # in_channels = 3 137 | cl = 0 138 | # pdb.set_trace() 139 | for v in cfg: 140 | if v == "M": 141 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 142 | else: 143 | if "conv" in faulty_layers: 144 | conv2d = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 145 | in_channels, 146 | v, 147 | kernel_size=3, 148 | stride=1, 149 | padding=1, 150 | bias=True, 151 | precision=precision, 152 | clamp_val=weight_clamp_values[cl], 153 | ) 154 | else: 155 | conv2d = zs_quantized_ops.nnConv2dSymQuant_op( 156 | in_channels, 157 | v, 158 | kernel_size=3, 159 | stride=1, 160 | padding=1, 161 | bias=True, 162 | precision=precision, 163 | clamp_val=weight_clamp_values[cl], 164 | ) 165 | cl = cl + 1 166 | 167 | if batch_norm: 168 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 169 | else: 170 | layers += [conv2d, nn.ReLU(inplace=True)] 171 | in_channels = v 172 | return nn.Sequential(*layers) 173 | 174 | 175 | cfgs = { 176 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 177 | "B": [ 178 | 64, 179 | 64, 180 | "M", 181 | 128, 182 | 128, 183 | "M", 184 | 256, 185 | 256, 186 | "M", 187 | 512, 188 | 512, 189 | "M", 190 | 512, 191 | 512, 192 | "M", 193 | ], 194 | "D": [ 195 | 64, 196 | 64, 197 | "M", 198 | 128, 199 | 128, 200 | "M", 201 | 256, 202 | 256, 203 | 256, 204 | "M", 205 | 512, 206 | 512, 207 | 512, 208 | "M", 209 | 512, 210 | 512, 211 | 512, 212 | "M", 213 | ], 214 | "E": [ 215 | 64, 216 | 64, 217 | "M", 218 | 128, 219 | 128, 220 | "M", 221 | 256, 222 | 256, 223 | 256, 224 | 256, 225 | "M", 226 | 512, 227 | 512, 228 | 512, 229 | 512, 230 | "M", 231 | 512, 232 | 512, 233 | 512, 234 | 512, 235 | "M", 236 | ], 237 | } 238 | 239 | 240 | def vggfPy( 241 | cfg, 242 | input_channels, 243 | classes, 244 | batch_norm, 245 | precision, 246 | ber, 247 | position, 248 | faulty_layers, 249 | ): 250 | model = VGGPy( 251 | make_layersPy( 252 | cfgs[cfg], 253 | in_channels=input_channels, 254 | batch_norm=batch_norm, 255 | precision=precision, 256 | ber=ber, 257 | position=position, 258 | faulty_layers=faulty_layers, 259 | ), 260 | make_classifierPy( 261 | classes, 262 | precision, 263 | ber, 264 | position, 265 | faulty_layers, 266 | ), 267 | classes, 268 | True, 269 | ) 270 | return model -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from quantized_ops import zs_quantized_ops 5 | 6 | # Per layer clamping currently based on manual values set 7 | weight_clamp_values = [ 8 | 0.2, 9 | 0.2, 10 | 0.15, 11 | 0.13, 12 | 0.1, 13 | 0.1, 14 | 0.1, 15 | 0.05, 16 | 0.05, 17 | 0.05, 18 | 0.05, 19 | 0.05, 20 | 0.05, 21 | ] 22 | fc_weight_clamp = 0.1 23 | 24 | 25 | class VGG(nn.Module): 26 | def __init__(self, features, classes=10, precision=-1, init_weights=True): 27 | super(VGG, self).__init__() 28 | self.features = features 29 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 30 | self.avgpool = nn.AvgPool2d(kernel_size=1, stride=1) 31 | if precision < 0: 32 | self.classifier = nn.Linear(512, classes) 33 | else: 34 | self.classifier = zs_quantized_ops.nnLinearSymQuant_op( 35 | 512, classes, precision, fc_weight_clamp 36 | ) 37 | # self.classifier = nn.Sequential( 38 | # nn.Linear(512 * 7 * 7, 4096), 39 | # nn.ReLU(True), 40 | # nn.Dropout(), 41 | # nn.Linear(4096, 4096), 42 | # nn.ReLU(True), 43 | # nn.Dropout(), 44 | # nn.Linear(4096, classes), 45 | # ) 46 | if init_weights: 47 | self._initialize_weights() 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = self.avgpool(x) 52 | x = torch.flatten(x, 1) 53 | x = self.classifier(x) 54 | return x 55 | 56 | def _initialize_weights(self): 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | nn.init.kaiming_normal_( 60 | m.weight, mode="fan_out", nonlinearity="relu" 61 | ) 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, nn.BatchNorm2d): 65 | nn.init.constant_(m.weight, 1) 66 | nn.init.constant_(m.bias, 0) 67 | elif isinstance(m, nn.Linear): 68 | nn.init.normal_(m.weight, 0, 0.01) 69 | nn.init.constant_(m.bias, 0) 70 | 71 | 72 | def make_layers(cfg, in_channels, batch_norm=False, precision=-1): 73 | layers = [] 74 | # in_channels = 3 75 | cl = 0 76 | for v in cfg: 77 | if v == "M": 78 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 79 | else: 80 | if precision < 0: 81 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 82 | else: 83 | conv2d = zs_quantized_ops.nnConv2dSymQuant_op( 84 | in_channels, 85 | v, 86 | kernel_size=3, 87 | stride=1, 88 | padding=1, 89 | bias=True, 90 | precision=precision, 91 | clamp_val=weight_clamp_values[cl], 92 | ) 93 | cl = cl + 1 94 | 95 | if batch_norm: 96 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 97 | else: 98 | layers += [conv2d, nn.ReLU(inplace=True)] 99 | in_channels = v 100 | return nn.Sequential(*layers) 101 | 102 | 103 | cfgs = { 104 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 105 | "B": [ 106 | 64, 107 | 64, 108 | "M", 109 | 128, 110 | 128, 111 | "M", 112 | 256, 113 | 256, 114 | "M", 115 | 512, 116 | 512, 117 | "M", 118 | 512, 119 | 512, 120 | "M", 121 | ], 122 | "D": [ 123 | 64, 124 | 64, 125 | "M", 126 | 128, 127 | 128, 128 | "M", 129 | 256, 130 | 256, 131 | 256, 132 | "M", 133 | 512, 134 | 512, 135 | 512, 136 | "M", 137 | 512, 138 | 512, 139 | 512, 140 | "M", 141 | ], 142 | "E": [ 143 | 64, 144 | 64, 145 | "M", 146 | 128, 147 | 128, 148 | "M", 149 | 256, 150 | 256, 151 | 256, 152 | 256, 153 | "M", 154 | 512, 155 | 512, 156 | 512, 157 | 512, 158 | "M", 159 | 512, 160 | 512, 161 | 512, 162 | 512, 163 | "M", 164 | ], 165 | } 166 | 167 | 168 | # def vgg(cfg,batch_norm,**kwargs): 169 | # kwargs['num_classes'] = 10 170 | # model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 171 | # return model 172 | 173 | 174 | def vgg(cfg, input_channels, classes, batch_norm, precision): 175 | model = VGG( 176 | make_layers( 177 | cfgs[cfg], 178 | in_channels=input_channels, 179 | batch_norm=batch_norm, 180 | precision=precision, 181 | ), 182 | classes, 183 | precision, 184 | True, 185 | ) 186 | return model 187 | 188 | 189 | # def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 190 | # if pretrained: 191 | # kwargs['init_weights'] = False 192 | # model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 193 | # if pretrained: 194 | # state_dict = load_state_dict_from_url(model_urls[arch], 195 | # progress=progress) 196 | # model.load_state_dict(state_dict) 197 | # return model 198 | # 199 | 200 | # def vgg11(pretrained=False, progress=True, **kwargs): 201 | # r"""VGG 11-layer model (configuration "A") from 202 | # `"Very Deep Convolutional Networks For Large-Scale 203 | # Image Recognition" `_ 204 | # Args: 205 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 206 | # progress (bool): If True, displays a progress bar of the download 207 | # to stderr 208 | # """ 209 | # return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 210 | # 211 | # 212 | # def vgg11_bn(pretrained=False, progress=True, **kwargs): 213 | # r"""VGG 11-layer model (configuration "A") with batch normalization 214 | # `"Very Deep Convolutional Networks For Large-Scale 215 | # Image Recognition" `_ 216 | # Args: 217 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | # progress (bool): If True, displays a progress bar of the download 219 | # to stderr 220 | # """ 221 | # return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 222 | # 223 | # 224 | # def vgg13(pretrained=False, progress=True, **kwargs): 225 | # r"""VGG 13-layer model (configuration "B") 226 | # `"Very Deep Convolutional Networks For Large-Scale Image 227 | # Recognition" `_ 228 | # Args: 229 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | # progress (bool): If True, displays a progress bar of the download 231 | # to stderr 232 | # """ 233 | # return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 234 | # 235 | # 236 | # def vgg13_bn(pretrained=False, progress=True, **kwargs): 237 | # r"""VGG 13-layer model (configuration "B") with batch normalization 238 | # `"Very Deep Convolutional Networks For Large-Scale Image 239 | # Recognition" `_ 240 | # Args: 241 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | # progress (bool): If True, displays a progress bar of the download 243 | # to stderr 244 | # """ 245 | # return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 246 | # 247 | # 248 | # def vgg16(pretrained=False, progress=True, **kwargs): 249 | # r"""VGG 16-layer model (configuration "D") 250 | # `"Very Deep Convolutional Networks For Large-Scale Image 251 | # Recognition" `_ 252 | # Args: 253 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 254 | # progress (bool): If True, displays a progress bar of the download 255 | # to stderr 256 | # """ 257 | # return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 258 | # 259 | # 260 | # def vgg16_bn(pretrained=False, progress=True, **kwargs): 261 | # r"""VGG 16-layer model (configuration "D") with batch normalization 262 | # `"Very Deep Convolutional Networks For Large-Scale Image 263 | # Recognition" `_ 264 | # Args: 265 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 266 | # progress (bool): If True, displays a progress bar of the download 267 | # to stderr 268 | # """ 269 | # return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 270 | # 271 | # 272 | # def vgg19(pretrained=False, progress=True, **kwargs): 273 | # r"""VGG 19-layer model (configuration "E") 274 | # `"Very Deep Convolutional Networks For Large-Scale Image 275 | # Recognition" `_ 276 | # Args: 277 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | # progress (bool): If True, displays a progress bar of the download 279 | # to stderr 280 | # """ 281 | # return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 282 | # 283 | # 284 | # def vgg19_bn(pretrained=False, progress=True, **kwargs): 285 | # r"""VGG 19-layer model (configuration 'E') with batch normalization 286 | # `"Very Deep Convolutional Networks For Large-Scale Image 287 | # Recognition" `_ 288 | # Args: 289 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 290 | # progress (bool): If True, displays a progress bar of the download 291 | # to stderr 292 | # """ 293 | # return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 294 | # 295 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | """ 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from quantized_ops import zs_quantized_ops 11 | 12 | conv_clamp_val = 0.05 13 | fc_clamp_val = 0.1 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1, precision=-1): 20 | super(BasicBlock, self).__init__() 21 | if precision < 0: 22 | self.conv1 = nn.Conv2d( 23 | in_planes, 24 | planes, 25 | kernel_size=3, 26 | stride=stride, 27 | padding=1, 28 | bias=False, 29 | ) 30 | else: 31 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 32 | in_planes, 33 | planes, 34 | kernel_size=3, 35 | stride=stride, 36 | padding=1, 37 | bias=False, 38 | precision=precision, 39 | clamp_val=conv_clamp_val, 40 | ) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | if precision < 0: 43 | self.conv2 = nn.Conv2d( 44 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False 45 | ) 46 | else: 47 | self.conv2 = zs_quantized_ops.nnConv2dSymQuant_op( 48 | planes, 49 | planes, 50 | kernel_size=3, 51 | stride=1, 52 | padding=1, 53 | bias=False, 54 | precision=precision, 55 | clamp_val=conv_clamp_val, 56 | ) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.relu1 = nn.ReLU() 59 | self.relu2 = nn.ReLU() 60 | self.shortcut = nn.Sequential() 61 | if stride != 1 or in_planes != self.expansion * planes: 62 | if precision < 0: 63 | self.shortcut = nn.Sequential( 64 | nn.Conv2d( 65 | in_planes, 66 | self.expansion * planes, 67 | kernel_size=1, 68 | stride=stride, 69 | bias=False, 70 | ), 71 | nn.BatchNorm2d(self.expansion * planes), 72 | ) 73 | else: 74 | self.shortcut = nn.Sequential( 75 | zs_quantized_ops.nnConv2dSymQuant_op( 76 | in_planes, 77 | self.expansion * planes, 78 | kernel_size=1, 79 | stride=stride, 80 | padding=0, 81 | bias=False, 82 | precision=precision, 83 | clamp_val=conv_clamp_val, 84 | ), 85 | nn.BatchNorm2d(self.expansion * planes), 86 | ) 87 | 88 | def forward(self, x): 89 | out = self.relu1(self.bn1(self.conv1(x))) 90 | out = self.bn2(self.conv2(out)) 91 | out += self.shortcut(x) 92 | out = self.relu2(out) 93 | return out 94 | 95 | class Bottleneck(nn.Module): 96 | expansion = 4 97 | 98 | def __init__( 99 | self, 100 | in_planes, 101 | planes, 102 | stride, 103 | precision, 104 | ): 105 | super(Bottleneck, self).__init__() 106 | if precision < 0: 107 | # print('In block') 108 | self.conv1 = nn.Conv2d( 109 | in_planes, 110 | planes, 111 | kernel_size=1, 112 | stride=1, 113 | padding=0, 114 | bias=False, 115 | ) 116 | else: 117 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 118 | in_planes, 119 | planes, 120 | kernel_size=1, 121 | stride=1, 122 | padding=0, 123 | bias=False, 124 | precision=precision, 125 | clamp_val=conv_clamp_val, 126 | ) 127 | self.bn1 = nn.BatchNorm2d(planes) 128 | if precision < 0: 129 | self.conv2 = nn.Conv2d( 130 | planes, 131 | planes, 132 | kernel_size=3, 133 | stride=stride, 134 | padding=1, 135 | bias=False, 136 | ) 137 | else: 138 | self.conv2 = zs_quantized_ops.nnConv2dSymQuant_op( 139 | planes, 140 | planes, 141 | kernel_size=3, 142 | stride=stride, 143 | padding=1, 144 | bias=False, 145 | precision=precision, 146 | clamp_val=conv_clamp_val, 147 | ) 148 | self.bn2 = nn.BatchNorm2d(planes) 149 | if precision < 0: 150 | self.conv3 = nn.Conv2d( 151 | planes, 152 | self.expansion * planes, 153 | kernel_size=1, 154 | stride=1, 155 | padding=0, 156 | bias=False, 157 | ) 158 | else: 159 | self.conv3 = zs_quantized_ops.nnConv2dSymQuant_op( 160 | planes, 161 | self.expansion * planes, 162 | kernel_size=1, 163 | stride=1, 164 | padding=0, 165 | bias=False, 166 | precision=precision, 167 | clamp_val=conv_clamp_val, 168 | ) 169 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 170 | self.relu1 = nn.ReLU() 171 | self.relu2 = nn.ReLU() 172 | self.relu3 = nn.ReLU() 173 | self.shortcut = nn.Sequential() 174 | if stride != 1 or in_planes != self.expansion * planes: 175 | if precision < 0: 176 | self.shortcut = nn.Sequential( 177 | nn.Conv2d( 178 | in_planes, 179 | self.expansion * planes, 180 | kernel_size=1, 181 | stride=stride, 182 | padding=0, 183 | bias=False, 184 | ), 185 | nn.BatchNorm2d(self.expansion * planes), 186 | ) 187 | else: 188 | self.shortcut = nn.Sequential( 189 | zs_quantized_ops.nnConv2dSymQuant_op( 190 | in_planes, 191 | self.expansion * planes, 192 | kernel_size=1, 193 | stride=stride, 194 | padding=0, 195 | bias=False, 196 | precision=precision, 197 | clamp_val=conv_clamp_val, 198 | ), 199 | nn.BatchNorm2d(self.expansion * planes), 200 | ) 201 | 202 | def forward(self, x): 203 | out = self.relu1(self.bn1(self.conv1(x))) 204 | out = self.relu2(self.bn2(self.conv2(out))) 205 | out = self.bn3(self.conv3(out)) 206 | out += self.shortcut(x) 207 | out = self.relu3(out) 208 | return out 209 | 210 | 211 | class ResNet(nn.Module): 212 | def __init__(self, block, num_blocks, num_classes=10, precision=-1): 213 | super(ResNet, self).__init__() 214 | self.in_planes = 64 215 | 216 | if precision < 0: 217 | self.conv1 = nn.Conv2d( 218 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 219 | ) 220 | else: 221 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 222 | 3, 223 | 64, 224 | kernel_size=3, 225 | stride=1, 226 | padding=1, 227 | bias=False, 228 | precision=precision, 229 | clamp_val=conv_clamp_val, 230 | ) 231 | self.bn1 = nn.BatchNorm2d(64) 232 | self.layer1 = self._make_layer( 233 | block, 64, num_blocks[0], stride=1, precision=precision 234 | ) 235 | self.layer2 = self._make_layer( 236 | block, 128, num_blocks[1], stride=2, precision=precision 237 | ) 238 | self.layer3 = self._make_layer( 239 | block, 256, num_blocks[2], stride=2, precision=precision 240 | ) 241 | self.layer4 = self._make_layer( 242 | block, 512, num_blocks[3], stride=2, precision=precision 243 | ) 244 | if precision < 0: 245 | self.linear = nn.Linear(512 * block.expansion, num_classes) 246 | else: 247 | self.linear = zs_quantized_ops.nnLinearSymQuant_op( 248 | 512 * block.expansion, num_classes, precision, fc_clamp_val 249 | ) 250 | 251 | def _make_layer(self, block, planes, num_blocks, stride, precision=-1): 252 | strides = [stride] + [1] * (num_blocks - 1) 253 | layers = [] 254 | for stride in strides: 255 | layers.append(block(self.in_planes, planes, stride, precision)) 256 | self.in_planes = planes * block.expansion 257 | return nn.Sequential(*layers) 258 | 259 | def forward(self, x): 260 | out = F.relu(self.bn1(self.conv1(x))) 261 | out = self.layer1(out) 262 | out = self.layer2(out) 263 | out = self.layer3(out) 264 | out = self.layer4(out) 265 | out = F.avg_pool2d(out, 4) 266 | out = out.view(out.size(0), -1) 267 | out = self.linear(out) 268 | return out 269 | 270 | 271 | def ResNet18(classes, precision): 272 | return ResNet(BasicBlock, [2, 2, 2, 2], classes, precision) 273 | 274 | 275 | def ResNet34(classes, precision): 276 | return ResNet(BasicBlock, [3, 4, 6, 3], classes, precision) 277 | 278 | def ResNet50(classes, precision): 279 | return ResNet(Bottleneck, [3, 4, 6, 3], classes, precision) 280 | 281 | 282 | def resnet(arch, classes, precision): 283 | if arch == "resnet18": 284 | return ResNet18(classes, precision) 285 | elif arch == "resnet34": 286 | return ResNet34(classes, precision) 287 | elif arch == "resnet50": 288 | return ResNet50(classes, precision) 289 | 290 | 291 | # def test(): 292 | # net = ResNet18() 293 | # y = net(torch.randn(1,3,32,32)) 294 | # print(y.size()) 295 | 296 | # test() 297 | -------------------------------------------------------------------------------- /quantized_ops/zs_quantized_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | quantized version of nn.Linear 3 | Followed the explanation in 4 | https://pytorch.org/docs/stable/notes/extending.html on how to write a custom 5 | autograd function and extend the nn.Linear class 6 | https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html 7 | """ 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from config import cfg 12 | from torch import nn 13 | from torch.nn.modules.utils import _pair 14 | 15 | # from torch.nn.modules.utils import _single 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | dtype = torch.float32 19 | 20 | 21 | class SymmetricQuantizeDequantize(torch.autograd.Function): 22 | 23 | # Quantize and dequantize in the forward pass 24 | @staticmethod 25 | def forward(ctx, input, precision, clamp_val, use_max=True, q_method='symmetric_signed'): 26 | """ 27 | Quantize and dequantize the model weights. 28 | The gradients will be applied to origin weights. 29 | :param ctx: Bulid-in parameter. 30 | :param input: Model weights. 31 | :param precision: The number of bits would be used to quantize the 32 | model. 33 | :param clamp_val: The range to be used to clip the weights. 34 | """ 35 | ctx.save_for_backward(input) 36 | # ctx.mark_dirty(input) 37 | 38 | if q_method == 'symmetric_unsigned' or q_method == 'asymmetric_unsigned': 39 | """ 40 | Compute quantization step size. 41 | Mapping (min_val, max_val) linearly to (0, 2^m - 1) 42 | """ 43 | if use_max and q_method == 'symmetric_unsigned': 44 | max_val = torch.max(torch.abs(input)) 45 | min_val = -max_val 46 | input = torch.clamp(input, min_val, max_val) 47 | delta = max_val / (2 ** (precision - 1) - 1) 48 | input_q = torch.round((input / delta)) 49 | input_q = input_q.to(torch.int32) + (2 ** (precision - 1) - 1) 50 | """ 51 | Dequantize introducing a quantization error in the data 52 | """ 53 | input_dq = (input_q - (2 ** (precision - 1) - 1)) * delta 54 | input_dq = input_dq.to(torch.float32) 55 | 56 | elif not use_max and q_method == 'asymmetric_unsigned': 57 | max_val = torch.max(input) 58 | min_val = torch.min(input) 59 | delta = 1 / (2 ** (precision - 1) - 1) 60 | input = (input - min_val) / (max_val - min_val) 61 | input = input * 2 - 1 # To -1 ~ 1 62 | input_q = torch.round((input / delta)) 63 | input_q = input_q.to(torch.int32) + (2 ** (precision - 1) - 1) 64 | """ 65 | Dequantize introducing a quantization error in the data 66 | """ 67 | input_dq = (input_q - (2 ** (precision - 1) - 1)) * delta 68 | input_dq = ((input_dq + 1) / 2) * (max_val - min_val) + min_val 69 | input_dq = input_dq.to(torch.float32) 70 | 71 | 72 | elif q_method == 'symmetric_signed': 73 | """ 74 | Compute quantization step size. 75 | Mapping (-max_val, max_val) linearly to (-127,127) 76 | """ 77 | if use_max: 78 | max_val = torch.max(torch.abs(input)) 79 | else: 80 | max_val = clamp_val 81 | 82 | delta = max_val / (2 ** (precision - 1) - 1) 83 | input_clamped = torch.clamp(input, -max_val, max_val) 84 | input_q = torch.round((input_clamped / delta)) 85 | if precision > 0 and precision <= 8: 86 | input_q = input_q.to(torch.int8) 87 | elif precision == 16: 88 | input_q = input_q.to(torch.int16) 89 | else: 90 | input_q = input_q.to(torch.int32) 91 | 92 | """ 93 | Dequantize introducing a quantization error in the data 94 | """ 95 | input_dq = input_q * delta 96 | input_dq = input_dq.to(torch.float32) 97 | # Return the dequantized weights tensor. 98 | # We want to update the original weights(not quantized weights) under 99 | # quantization aware training. 100 | # So, we don't use input.copy_(input_dq) to replace self.weight with 101 | # input_dq. 102 | 103 | return input_dq 104 | 105 | # Straight-through-estimator in backward pass 106 | # https://discuss.pytorch.org/t/ 107 | # integrating-a-new-loss-function-to-autograd/3684/2 108 | # Without using this will cause gradients problems. 109 | @staticmethod 110 | def backward(ctx, grad_output): 111 | (input,) = ctx.saved_tensors 112 | return grad_output, None, None 113 | 114 | 115 | class nnLinearSymQuant(nn.Linear): 116 | """Applies a linear transformation to the incoming data: y = xA^T + b 117 | Along with the linear transform, the learnable weights are quantized and 118 | dequantized introducing a quantization error in the data. 119 | """ 120 | 121 | def __init__( 122 | self, in_features, out_features, bias, precision=-1, clamp_val=0.1 123 | ): 124 | super().__init__(in_features, out_features, bias) 125 | self.in_features = in_features 126 | self.out_features = out_features 127 | # self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 128 | # if bias: 129 | # self.bias = nn.Parameter(torch.Tensor(out_features)) 130 | # else: 131 | # self.register_parameter('bias', None) 132 | self.precision = precision 133 | self.clamp_val = clamp_val 134 | self.reset_parameters() 135 | 136 | def forward(self, input): 137 | if self.precision > 0: 138 | quantWeight = SymmetricQuantizeDequantize.apply 139 | weight = quantWeight(self.weight, self.precision, self.clamp_val) 140 | return F.linear(input, weight, self.bias) 141 | 142 | def extra_repr(self) -> str: 143 | return "in_features={}, out_features={}, bias={}, precision={}".format( 144 | self.in_features, 145 | self.out_features, 146 | self.bias is not None, 147 | self.precision, 148 | ) 149 | 150 | 151 | def nnLinearSymQuant_op(in_features, out_features, precision, clamp_val): 152 | return nnLinearSymQuant( 153 | in_features, out_features, True, precision, clamp_val 154 | ) 155 | 156 | 157 | class nnConv2dSymQuant(nn.Conv2d): 158 | """ 159 | Computes 2d conv output 160 | Weights are quantized and dequantized introducing a quantization error 161 | """ 162 | 163 | def __init__( 164 | self, 165 | in_channels, 166 | out_channels, 167 | kernel_size, 168 | stride=1, 169 | padding=1, 170 | dilation=1, 171 | groups=1, 172 | bias=True, 173 | padding_mode="zeros", 174 | precision=-1, 175 | clamp_val=0.5, 176 | ): 177 | kernel_size = _pair(kernel_size) 178 | stride = _pair(stride) 179 | padding = _pair(padding) 180 | dilation = _pair(dilation) 181 | super().__init__( 182 | in_channels, 183 | out_channels, 184 | kernel_size, 185 | stride, 186 | padding, 187 | dilation, 188 | groups, 189 | bias, 190 | padding_mode, 191 | ) 192 | self.precision = precision 193 | self.clamp_val = clamp_val 194 | 195 | def forward(self, input): 196 | if self.precision > 0: 197 | quantWeight = SymmetricQuantizeDequantize.apply 198 | quant_weight = quantWeight( 199 | self.weight, self.precision, self.clamp_val 200 | ) 201 | 202 | return F.conv2d( 203 | input, 204 | quant_weight, 205 | self.bias, 206 | self.stride, 207 | self.padding, 208 | self.dilation, 209 | self.groups, 210 | ) 211 | 212 | 213 | def nnConv2dSymQuant_op( 214 | in_channels, 215 | out_channels, 216 | kernel_size, 217 | stride, 218 | padding, 219 | bias, 220 | precision, 221 | clamp_val, 222 | ): 223 | return nnConv2dSymQuant( 224 | in_channels, 225 | out_channels, 226 | kernel_size, 227 | stride=stride, 228 | padding=padding, 229 | dilation=1, 230 | groups=1, 231 | bias=bias, 232 | padding_mode="zeros", 233 | precision=precision, 234 | clamp_val=clamp_val, 235 | ) 236 | 237 | class nnConvTranspose2dSymQuant(nn.ConvTranspose2d): 238 | """ 239 | Computes 2d convtranspose output 240 | Weights are quantized and dequantized introducing a quantization error 241 | """ 242 | 243 | def __init__( 244 | self, 245 | in_channels, 246 | out_channels, 247 | kernel_size, 248 | stride=1, 249 | padding=1, 250 | dilation=1, 251 | groups=1, 252 | bias=True, 253 | precision=-1, 254 | clamp_val=0.5, 255 | ): 256 | kernel_size = _pair(kernel_size) 257 | stride = _pair(stride) 258 | padding = _pair(padding) 259 | dilation = _pair(dilation) 260 | 261 | super().__init__( 262 | in_channels, 263 | out_channels, 264 | kernel_size, 265 | stride, 266 | padding, 267 | dilation, 268 | groups, 269 | bias, 270 | ) 271 | self.precision = precision 272 | self.clamp_val = clamp_val 273 | 274 | def forward(self, input): 275 | if self.precision > 0: 276 | quantWeight = SymmetricQuantizeDequantize.apply 277 | quant_weight = quantWeight( 278 | self.weight, self.precision, self.clamp_val 279 | ) 280 | 281 | return F.conv_transpose2d( 282 | input = input, 283 | weight = quant_weight, 284 | bias = self.bias, 285 | stride = self.stride, 286 | padding = self.padding, 287 | groups = self.groups, 288 | dilation = self.dilation 289 | ) 290 | 291 | def nnConvTranspose2dSymQuant_op( 292 | in_channels, 293 | out_channels, 294 | kernel_size, 295 | stride, 296 | padding, 297 | bias, 298 | precision, 299 | clamp_val, 300 | ): 301 | return nnConvTranspose2dSymQuant( 302 | in_channels, 303 | out_channels, 304 | kernel_size, 305 | stride=stride, 306 | padding=padding, 307 | dilation=1, 308 | groups=1, 309 | bias=bias, 310 | precision=precision, 311 | clamp_val=clamp_val, 312 | ) 313 | -------------------------------------------------------------------------------- /models/vggf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from faultinjection_ops import zs_faultinjection_ops 5 | from quantized_ops import zs_quantized_ops 6 | 7 | # Per layer clamping currently based on manual values set 8 | weight_clamp_values = [ 9 | 0.2, 10 | 0.2, 11 | 0.15, 12 | 0.13, 13 | 0.1, 14 | 0.1, 15 | 0.1, 16 | 0.05, 17 | 0.05, 18 | 0.05, 19 | 0.05, 20 | 0.05, 21 | 0.05, 22 | 0.05, 23 | 0.05, 24 | 0.05, 25 | 0.05, 26 | 0.05, 27 | 0.05, 28 | ] 29 | fc_weight_clamp = 0.1 30 | 31 | 32 | class VGG(nn.Module): 33 | def __init__(self, features, classifier, classes=10, init_weights=True): 34 | super(VGG, self).__init__() 35 | self.features = features 36 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 37 | # self.avgpool = nn.AvgPool2d(kernel_size=2, stride=1) 38 | self.classifier = classifier 39 | # self.classifier = nn.Sequential( 40 | # nn.Linear(512 * 7 * 7, 4096), 41 | # nn.ReLU(True), 42 | # nn.Dropout(), 43 | # nn.Linear(4096, 4096), 44 | # nn.ReLU(True), 45 | # nn.Dropout(), 46 | # nn.Linear(4096, classes), 47 | # ) 48 | if init_weights: 49 | self._initialize_weights() 50 | 51 | def forward(self, x): 52 | x = self.features(x) 53 | # x = self.avgpool(x) 54 | x = nn.AvgPool2d(kernel_size=x.shape[3], stride=1)(x) 55 | x = torch.flatten(x, 1) 56 | x = self.classifier(x) 57 | return x 58 | 59 | def _initialize_weights(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | nn.init.kaiming_normal_( 63 | m.weight, mode="fan_out", nonlinearity="relu" 64 | ) 65 | if m.bias is not None: 66 | nn.init.constant_(m.bias, 0) 67 | elif isinstance(m, nn.BatchNorm2d): 68 | nn.init.constant_(m.weight, 1) 69 | nn.init.constant_(m.bias, 0) 70 | elif isinstance(m, nn.Linear): 71 | nn.init.normal_(m.weight, 0, 0.01) 72 | nn.init.constant_(m.bias, 0) 73 | 74 | 75 | def make_classifier( 76 | classes, 77 | precision, 78 | ber, 79 | position, 80 | faulty_layers, 81 | ): 82 | if "linear" in faulty_layers: 83 | classifier = zs_faultinjection_ops.nnLinearPerturbWeight_op( 84 | 512, 85 | classes, 86 | precision, 87 | fc_weight_clamp, 88 | ) 89 | else: 90 | classifier = zs_quantized_ops.nnLinearSymQuant_op( 91 | 512, classes, precision, fc_weight_clamp 92 | ) 93 | 94 | return classifier 95 | 96 | 97 | def make_layers( 98 | cfg, 99 | in_channels, 100 | batch_norm, 101 | precision, 102 | ber, 103 | position, 104 | faulty_layers, 105 | ): 106 | layers = [] 107 | # in_channels = 3 108 | cl = 0 109 | # pdb.set_trace() 110 | for v in cfg: 111 | if v == "M": 112 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 113 | else: 114 | if "conv" in faulty_layers: 115 | conv2d = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 116 | in_channels, 117 | v, 118 | kernel_size=3, 119 | stride=1, 120 | padding=1, 121 | bias=True, 122 | precision=precision, 123 | clamp_val=weight_clamp_values[cl], 124 | ) 125 | else: 126 | conv2d = zs_quantized_ops.nnConv2dSymQuant_op( 127 | in_channels, 128 | v, 129 | kernel_size=3, 130 | stride=1, 131 | padding=1, 132 | bias=True, 133 | precision=precision, 134 | clamp_val=weight_clamp_values[cl], 135 | ) 136 | cl = cl + 1 137 | 138 | if batch_norm: 139 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 140 | else: 141 | layers += [conv2d, nn.ReLU(inplace=True)] 142 | in_channels = v 143 | return nn.Sequential(*layers) 144 | 145 | 146 | cfgs = { 147 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 148 | "B": [ 149 | 64, 150 | 64, 151 | "M", 152 | 128, 153 | 128, 154 | "M", 155 | 256, 156 | 256, 157 | "M", 158 | 512, 159 | 512, 160 | "M", 161 | 512, 162 | 512, 163 | "M", 164 | ], 165 | "D": [ 166 | 64, 167 | 64, 168 | "M", 169 | 128, 170 | 128, 171 | "M", 172 | 256, 173 | 256, 174 | 256, 175 | "M", 176 | 512, 177 | 512, 178 | 512, 179 | "M", 180 | 512, 181 | 512, 182 | 512, 183 | "M", 184 | ], 185 | "E": [ 186 | 64, 187 | 64, 188 | "M", 189 | 128, 190 | 128, 191 | "M", 192 | 256, 193 | 256, 194 | 256, 195 | 256, 196 | "M", 197 | 512, 198 | 512, 199 | 512, 200 | 512, 201 | "M", 202 | 512, 203 | 512, 204 | 512, 205 | 512, 206 | "M", 207 | ], 208 | } 209 | 210 | 211 | # def vgg(cfg,batch_norm,**kwargs): 212 | # kwargs['num_classes'] = 10 213 | # model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 214 | # return model 215 | 216 | 217 | def vggf( 218 | cfg, 219 | input_channels, 220 | classes, 221 | batch_norm, 222 | precision, 223 | ber, 224 | position, 225 | faulty_layers, 226 | ): 227 | model = VGG( 228 | make_layers( 229 | cfgs[cfg], 230 | in_channels=input_channels, 231 | batch_norm=batch_norm, 232 | precision=precision, 233 | ber=ber, 234 | position=position, 235 | faulty_layers=faulty_layers, 236 | ), 237 | make_classifier( 238 | classes, 239 | precision, 240 | ber, 241 | position, 242 | faulty_layers, 243 | ), 244 | classes, 245 | True, 246 | ) 247 | return model 248 | 249 | 250 | # def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): 251 | # if pretrained: 252 | # kwargs['init_weights'] = False 253 | # model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 254 | # if pretrained: 255 | # state_dict = load_state_dict_from_url(model_urls[arch], 256 | # progress=progress) 257 | # model.load_state_dict(state_dict) 258 | # return model 259 | # 260 | 261 | # def vgg11(pretrained=False, progress=True, **kwargs): 262 | # r"""VGG 11-layer model (configuration "A") from 263 | # `"Very Deep Convolutional Networks For Large-Scale Image 264 | # Recognition" `_ 265 | # Args: 266 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | # progress (bool): If True, displays a progress bar of the download to 268 | # stderr 269 | # """ 270 | # return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 271 | # 272 | # 273 | # def vgg11_bn(pretrained=False, progress=True, **kwargs): 274 | # r"""VGG 11-layer model (configuration "A") with batch normalization 275 | # `"Very Deep Convolutional Networks For Large-Scale Image 276 | # Recognition" `_ 277 | # Args: 278 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 279 | # progress (bool): If True, displays a progress bar of the download to 280 | # stderr 281 | # """ 282 | # return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 283 | # 284 | # 285 | # def vgg13(pretrained=False, progress=True, **kwargs): 286 | # r"""VGG 13-layer model (configuration "B") 287 | # `"Very Deep Convolutional Networks For Large-Scale Image 288 | # Recognition" `_ 289 | # Args: 290 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 291 | # progress (bool): If True, displays a progress bar of the download to 292 | # stderr 293 | # """ 294 | # return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 295 | # 296 | # 297 | # def vgg13_bn(pretrained=False, progress=True, **kwargs): 298 | # r"""VGG 13-layer model (configuration "B") with batch normalization 299 | # `"Very Deep Convolutional Networks For Large-Scale Image 300 | # Recognition" `_ 301 | # Args: 302 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 303 | # progress (bool): If True, displays a progress bar of the download to 304 | # stderr 305 | # """ 306 | # return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 307 | # 308 | # 309 | # def vgg16(pretrained=False, progress=True, **kwargs): 310 | # r"""VGG 16-layer model (configuration "D") 311 | # `"Very Deep Convolutional Networks For Large-Scale Image 312 | # Recognition" `_ 313 | # Args: 314 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 315 | # progress (bool): If True, displays a progress bar of the download to 316 | # stderr 317 | # """ 318 | # return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 319 | # 320 | # 321 | # def vgg16_bn(pretrained=False, progress=True, **kwargs): 322 | # r"""VGG 16-layer model (configuration "D") with batch normalization 323 | # `"Very Deep Convolutional Networks For Large-Scale Image 324 | # Recognition" `_ 325 | # Args: 326 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 327 | # progress (bool): If True, displays a progress bar of the download to 328 | # stderr 329 | # """ 330 | # return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 331 | # 332 | # 333 | # def vgg19(pretrained=False, progress=True, **kwargs): 334 | # r"""VGG 19-layer model (configuration "E") 335 | # `"Very Deep Convolutional Networks For Large-Scale Image 336 | # Recognition" `_ 337 | # Args: 338 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | # progress (bool): If True, displays a progress bar of the download to 340 | # stderr 341 | # """ 342 | # return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 343 | # 344 | # 345 | # def vgg19_bn(pretrained=False, progress=True, **kwargs): 346 | # r"""VGG 19-layer model (configuration 'E') with batch normalization 347 | # `"Very Deep Convolutional Networks For Large-Scale Image 348 | # Recognition" `_ 349 | # Args: 350 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 351 | # progress (bool): If True, displays a progress bar of the download to 352 | # stderr 353 | # """ 354 | # return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) 355 | # 356 | -------------------------------------------------------------------------------- /zs_train_input_transform_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | from config import cfg 6 | 7 | from models import vggfPy 8 | from models import resnetfPy 9 | from models import init_models_pairs, create_faults, init_models 10 | from models.generator import * 11 | import faultsMap as fmap 12 | 13 | from collections import OrderedDict 14 | from sklearn.manifold import TSNE 15 | import matplotlib.pyplot as plt 16 | import itertools 17 | import numpy as np 18 | import tqdm 19 | import copy 20 | 21 | 22 | torch.manual_seed(0) 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | EPS = 1e-20 25 | PGD_STEP = 2 26 | 27 | 28 | def accuracy_checking_clean(model_orig, trainloader, testloader, device): 29 | """ 30 | Calculating the accuracy with given clean model and perturbed model. 31 | :param model_orig: Clean model. 32 | :param model_p: Perturbed model. 33 | :param trainloader: The loader of training data. 34 | :param testloader: The loader of testing data. 35 | :param transform_model: The object of transformation model. 36 | :param device: Specify GPU usage. 37 | :use_transform: Should apply input transformation or not. 38 | """ 39 | cfg.replaceWeight = False 40 | total_train = 0 41 | total_test = 0 42 | correct_orig_train = 0 43 | correct_p_train = 0 44 | correct_orig_test = 0 45 | correct_p_test = 0 46 | 47 | # For training data: 48 | for x, y in trainloader: 49 | total_train += 1 50 | x, y = x.to(device), y.to(device) 51 | out_orig = model_orig(x) 52 | _, pred_orig = out_orig.max(1) 53 | y = y.view(y.size(0)) 54 | correct_orig_train += torch.sum(pred_orig == y.data).item() 55 | accuracy_orig_train = correct_orig_train / (len(trainloader.dataset)) 56 | 57 | # For testing data: 58 | for x, y in testloader: 59 | total_test += 1 60 | x, y = x.to(device), y.to(device) 61 | out_orig = model_orig(x) 62 | _, pred_orig = out_orig.max(1) 63 | y = y.view(y.size(0)) 64 | correct_orig_test += torch.sum(pred_orig == y.data).item() 65 | accuracy_orig_test = correct_orig_test / (len(testloader.dataset)) 66 | 67 | print("Accuracy of training data: clean model: {:5f}".format(accuracy_orig_train)) 68 | print("Accuracy of testing data: clean model: {:5f}".format(accuracy_orig_test)) 69 | 70 | 71 | def accuracy_checking(model_orig, model_p, trainloader, testloader, transform_model, device, use_transform=False): 72 | """ 73 | Calculating the accuracy with given clean model and perturbed model. 74 | :param model_orig: Clean model. 75 | :param model_p: Perturbed model. 76 | :param trainloader: The loader of training data. 77 | :param testloader: The loader of testing data. 78 | :param transform_model: The object of transformation model. 79 | :param device: Specify GPU usage. 80 | :use_transform: Should apply input transformation or not. 81 | """ 82 | cfg.replaceWeight = False 83 | total_train = 0 84 | total_test = 0 85 | correct_orig_train = 0 86 | correct_p_train = 0 87 | correct_orig_test = 0 88 | correct_p_test = 0 89 | 90 | # For training data: 91 | for x, y in trainloader: 92 | total_train += 1 93 | x, y = x.to(device), y.to(device) 94 | if use_transform: 95 | x_adv = transform_model(x) 96 | out_orig = model_orig(x_adv) 97 | out_p = model_p(x_adv) 98 | else: 99 | out_orig = model_orig(x) 100 | out_p = model_p(x) 101 | _, pred_orig = out_orig.max(1) 102 | _, pred_p = out_p.max(1) 103 | y = y.view(y.size(0)) 104 | correct_orig_train += torch.sum(pred_orig == y.data).item() 105 | correct_p_train += torch.sum(pred_p == y.data).item() 106 | accuracy_orig_train = correct_orig_train / (len(trainloader.dataset)) 107 | accuracy_p_train = correct_p_train / (len(trainloader.dataset)) 108 | 109 | # For testing data: 110 | for x, y in testloader: 111 | total_test += 1 112 | x, y = x.to(device), y.to(device) 113 | if use_transform: 114 | x_adv = transform_model(x) 115 | out_orig = model_orig(x_adv) 116 | out_p = model_p(x_adv) 117 | else: 118 | out_orig = model_orig(x) 119 | out_p = model_p(x) 120 | _, pred_orig = out_orig.max(1) 121 | _, pred_p = out_p.max(1) 122 | y = y.view(y.size(0)) 123 | correct_orig_test += torch.sum(pred_orig == y.data).item() 124 | correct_p_test += torch.sum(pred_p == y.data).item() 125 | accuracy_orig_test = correct_orig_test / (len(testloader.dataset)) 126 | accuracy_p_test = correct_p_test / (len(testloader.dataset)) 127 | 128 | print("Accuracy of training data: clean model: {:5f}, perturbed model: {:5f}".format( 129 | accuracy_orig_train, 130 | accuracy_p_train 131 | ) 132 | ) 133 | print("Accuracy of testing data: clean model: {:5f}, perturbed model: {:5f}".format( 134 | accuracy_orig_test, 135 | accuracy_p_test 136 | ) 137 | ) 138 | 139 | return accuracy_orig_train, accuracy_p_train, accuracy_orig_test, accuracy_p_test 140 | 141 | def transform_eval( 142 | trainloader, 143 | testloader, 144 | arch, 145 | dataset, 146 | in_channels, 147 | precision, 148 | checkpoint_path, 149 | force, 150 | device, 151 | fl, 152 | ber, 153 | pos, 154 | ): 155 | """ 156 | Apply quantization aware training. 157 | :param trainloader: The loader of training data. 158 | :param in_channels: An int. The input channels of the training data. 159 | :param arch: A string. The architecture of the model would be used. 160 | :param dataset: A string. The name of the training data. 161 | :param ber: A float. How many rate of bits would be attacked. 162 | :param precision: An int. The number of bits would be used to quantize 163 | the model. 164 | :param position: 165 | :param checkpoint_path: A string. The path that stores the models. 166 | :param device: Specify GPU usage. 167 | """ 168 | torch.backends.cudnn.benchmark = True 169 | 170 | if(cfg.testing_mode == 'clean'): 171 | 172 | model, checkpoint_epoch = init_models(arch, 3, precision, True, checkpoint_path, dataset) 173 | 174 | model = model.to(device) 175 | model.eval() 176 | accuracy_checking_clean(model, trainloader, testloader, device) 177 | 178 | if cfg.testing_mode == 'generator_base': 179 | 180 | if cfg.G == 'ConvL': 181 | Gen = GeneratorConvLQ(precision) 182 | elif cfg.G == 'ConvS': 183 | Gen = GeneratorConvSQ(precision) 184 | elif cfg.G == 'DeConvL': 185 | Gen = GeneratorDeConvLQ(precision) 186 | elif cfg.G == 'DeConvS': 187 | Gen = GeneratorDeConvSQ(precision) 188 | elif cfg.G == 'UNetL': 189 | Gen = GeneratorUNetLQ(precision) 190 | elif cfg.G == 'UNetS': 191 | Gen = GeneratorUNetSQ(precision) 192 | 193 | Gen.load_state_dict(torch.load(cfg.G_PATH)) 194 | Gen = Gen.to(device) 195 | print('Successfully loading the generator model.') 196 | 197 | print('========== Start checking the accuracy with different perturbed model: bit error mode ==========') 198 | # Setting without input transformation 199 | accuracy_orig_train_list = [] 200 | accuracy_p_train_list = [] 201 | accuracy_orig_test_list = [] 202 | accuracy_p_test_list = [] 203 | 204 | # Setting with input transformation 205 | accuracy_orig_train_list_with_transformation = [] 206 | accuracy_p_train_list_with_transformation = [] 207 | accuracy_orig_test_list_with_transformation = [] 208 | accuracy_p_test_list_with_transformation = [] 209 | 210 | for i in range(50000, 50010): 211 | print(' ********** For seed: {} ********** '.format(i)) 212 | (model, checkpoint_epoch, model_perturbed, checkpoint_epoch_perturbed) = init_models_pairs( 213 | arch, in_channels, precision, True, checkpoint_path, fl, ber, pos, seed=i, dataset=dataset) 214 | model, model_perturbed = model.to(device), model_perturbed.to(device), 215 | fmap.BitErrorMap0to1 = None 216 | fmap.BitErrorMap1to0 = None 217 | create_faults(precision, ber, pos, seed=i) 218 | model.eval() 219 | model_perturbed.eval() 220 | Gen.eval() 221 | 222 | # Without using transform 223 | accuracy_orig_train, accuracy_p_train, accuracy_orig_test, accuracy_p_test = accuracy_checking(model, model_perturbed, trainloader, testloader, Gen, device, use_transform=False) 224 | accuracy_orig_train_list.append(accuracy_orig_train) 225 | accuracy_p_train_list.append(accuracy_p_train) 226 | accuracy_orig_test_list.append(accuracy_orig_test) 227 | accuracy_p_test_list.append(accuracy_p_test) 228 | 229 | # With input transform 230 | accuracy_orig_train, accuracy_p_train, accuracy_orig_test, accuracy_p_test = accuracy_checking(model, model_perturbed, trainloader, testloader, Gen, device, use_transform=True) 231 | accuracy_orig_train_list_with_transformation.append(accuracy_orig_train) 232 | accuracy_p_train_list_with_transformation.append(accuracy_p_train) 233 | accuracy_orig_test_list_with_transformation.append(accuracy_orig_test) 234 | accuracy_p_test_list_with_transformation.append(accuracy_p_test) 235 | 236 | 237 | # Without using transform 238 | print('The average results without input transformation -> accuracy_orig_train: {:5f}, accuracy_p_train: {:5f}, accuracy_orig_test: {:5f}, accuracy_p_test: {:5f}'.format( 239 | np.mean(accuracy_orig_train_list), 240 | np.mean(accuracy_p_train_list), 241 | np.mean(accuracy_orig_test_list), 242 | np.mean(accuracy_p_test_list) 243 | ) 244 | ) 245 | print('The average results without input transformation -> std_accuracy_orig_train: {:5f}, std_accuracy_p_train: {:5f}, std_accuracy_orig_test: {:5f}, std_accuracy_p_test: {:5f}'.format( 246 | np.std(accuracy_orig_train_list), 247 | np.std(accuracy_p_train_list), 248 | np.std(accuracy_orig_test_list), 249 | np.std(accuracy_p_test_list) 250 | ) 251 | ) 252 | 253 | print() 254 | 255 | # With input transform 256 | print('The average results with input transformation -> accuracy_orig_train: {:5f}, accuracy_p_train: {:5f}, accuracy_orig_test: {:5f}, accuracy_p_test: {:5f}'.format( 257 | np.mean(accuracy_orig_train_list_with_transformation), 258 | np.mean(accuracy_p_train_list_with_transformation), 259 | np.mean(accuracy_orig_test_list_with_transformation), 260 | np.mean(accuracy_p_test_list_with_transformation) 261 | ) 262 | ) 263 | print('The average results with input transformation -> std_accuracy_orig_train: {:5f}, std_accuracy_p_train: {:5f}, std_accuracy_orig_test: {:5f}, std_accuracy_p_test: {:5f}'.format( 264 | np.std(accuracy_orig_train_list_with_transformation), 265 | np.std(accuracy_p_train_list_with_transformation), 266 | np.std(accuracy_orig_test_list_with_transformation), 267 | np.std(accuracy_p_test_list_with_transformation) 268 | ) 269 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .resnetf_pytorch import resnetfPy 5 | from .resnetf import resnetf 6 | from .resnet import resnet 7 | from .vggf_pytorch import vggfPy 8 | from .vggf import vggf 9 | from .vgg import vgg 10 | 11 | from faultmodels import randomfault 12 | import faultsMap as fmap 13 | from config import cfg 14 | 15 | # Create the fault map from randomfault module. 16 | def create_faults(precision, bit_error_rate, position, seed=0): 17 | rf = randomfault.RandomFaultModel( 18 | bit_error_rate, precision, position, seed=seed 19 | ) 20 | fmap.BitErrorMap0 = ( 21 | torch.tensor(rf.BitErrorMap_flip0).to(torch.int32).to(cfg.device) 22 | ) 23 | fmap.BitErrorMap1 = ( 24 | torch.tensor(rf.BitErrorMap_flip1).to(torch.int32).to(cfg.device) 25 | ) 26 | 27 | 28 | def init_models(arch, in_channels, precision, retrain, checkpoint_path, dataset='cifar10'): 29 | 30 | """ 31 | Default model loader 32 | """ 33 | classes = 10 34 | if dataset == 'cifar10': 35 | classes = 10 36 | elif dataset == 'cifar100': 37 | classes = 100 38 | elif dataset == 'imagenet': 39 | classes = 10 40 | elif dataset == 'gtsrb': 41 | classes = 43 42 | else: 43 | classes = 10 44 | 45 | if arch == "vgg11": 46 | if precision > 0: 47 | model = vggf("A", in_channels, classes, True, precision, 0, 0, []) 48 | else: 49 | model = vgg("A", in_channels, classes, True, precision) 50 | elif arch == "vgg16": 51 | if precision > 0: 52 | model = vggf("D", in_channels, classes, True, precision, 0, 0, []) 53 | else: 54 | model = vgg("D", in_channels, classes, True, precision) 55 | elif arch == "vgg19": 56 | if precision > 0: 57 | model = vggf("E", in_channels, classes, True, precision, 0, 0, []) 58 | else: 59 | model = vgg("E", in_channels, classes, True, precision) 60 | elif arch == "resnet18": 61 | if precision > 0: 62 | model = resnetf("resnet18", classes, precision, 0, 0, []) 63 | else: 64 | model = resnet("resnet18", classes, precision) 65 | elif arch == "resnet34": 66 | if precision > 0: 67 | model = resnetf("resnet34", classes, precision, 0, 0, []) 68 | else: 69 | model = resnet("resnet34", classes, precision) 70 | elif arch == "resnet50": 71 | if precision > 0: 72 | model = resnetf("resnet50", classes, precision, 0, 0, []) 73 | else: 74 | model = resnet("resnet50", classes, precision) 75 | elif arch == "resnet101": 76 | if precision > 0: 77 | model = resnetf("resnet101", classes, precision, 0, 0, []) 78 | else: 79 | model = resnet("resnet101", classes, precision) 80 | elif "resnet18Py" in arch: 81 | if precision > 0: 82 | model = resnetfPy("resnet18", classes, precision, ber=0, position=0, faulty_layers=[]) 83 | elif "resnet50Py" in arch: 84 | if precision > 0: 85 | model = resnetfPy("resnet50", classes, precision, ber=0, position=0, faulty_layers=[]) 86 | elif "vgg11Py" in arch: 87 | if precision > 0: 88 | model = vggfPy("A", 3, classes, True, precision, 0, 0, []) 89 | elif "vgg16Py" in arch: 90 | if precision > 0: 91 | model = vggfPy("D", 3, classes, True, precision, 0, 0, []) 92 | elif "vgg19Py" in arch: 93 | if precision > 0: 94 | model = vggfPy("E", 3, classes, True, precision, 0, 0, []) 95 | else: 96 | raise NotImplementedError 97 | 98 | # print(model) 99 | checkpoint_epoch = -1 100 | 101 | if retrain: 102 | if not os.path.exists(checkpoint_path): 103 | for x in range(cfg.epochs, -1, -1): 104 | if os.path.exists(model_path_from_base(checkpoint_path, x)): 105 | checkpoint_path = model_path_from_base(checkpoint_path, x) 106 | break 107 | 108 | if not os.path.exists(checkpoint_path): 109 | print("Checkpoint path not exists") 110 | return model, checkpoint_epoch 111 | 112 | # print("Restoring model from checkpoint", checkpoint_path) 113 | checkpoint = torch.load(checkpoint_path) 114 | 115 | model.load_state_dict(checkpoint["model_state_dict"]) 116 | # print("restored checkpoint at epoch - ", checkpoint["epoch"]) 117 | # print("Training loss =", checkpoint["loss"]) 118 | # print("Training accuracy =", checkpoint["accuracy"]) 119 | checkpoint_epoch = checkpoint["epoch"] 120 | 121 | return model, checkpoint_epoch 122 | 123 | 124 | def init_models_faulty(arch, in_channels, precision, retrain, checkpoint_path, faulty_layers, bit_error_rate, position, seed=0, dataset='cifar'): 125 | 126 | """ 127 | Perturbed (if needed) model loader. 128 | """ 129 | 130 | if not cfg.faulty_layers or len(cfg.faulty_layers) == 0: 131 | return init_models( 132 | arch, in_channels, precision, retrain, checkpoint_path 133 | ) 134 | else: 135 | """Perturbed models, where the weights are injected with bit 136 | errors at the rate of ber at the specified positions""" 137 | 138 | classes = 10 139 | if dataset == 'cifar10': 140 | classes = 10 141 | elif dataset == 'cifar100': 142 | classes = 100 143 | elif dataset == 'gtsrb': 144 | classes = 43 145 | else: 146 | classes = 10 # Include ImageNet-10 147 | 148 | if arch == "vgg11": 149 | model = vggf( 150 | "A", 151 | in_channels, 152 | classes, 153 | True, 154 | precision, 155 | bit_error_rate, 156 | position, 157 | faulty_layers, 158 | ) 159 | elif arch == "vgg16": 160 | model = vggf( 161 | "D", 162 | in_channels, 163 | classes, 164 | True, 165 | precision, 166 | bit_error_rate, 167 | position, 168 | faulty_layers, 169 | ) 170 | elif arch == "vgg19": 171 | model = vggf( 172 | "E", 173 | in_channels, 174 | classes, 175 | True, 176 | precision, 177 | bit_error_rate, 178 | position, 179 | faulty_layers, 180 | ) 181 | elif arch == "resnet18": 182 | model = resnetf( 183 | "resnet18", 184 | classes, 185 | precision, 186 | bit_error_rate, 187 | position, 188 | faulty_layers, 189 | ) 190 | elif arch == "resnet34": 191 | model = resnetf( 192 | "resnet34", 193 | classes, 194 | precision, 195 | bit_error_rate, 196 | position, 197 | faulty_layers, 198 | ) 199 | elif arch == "resnet50": 200 | model = resnetf( 201 | "resnet50", 202 | classes, 203 | precision, 204 | bit_error_rate, 205 | position, 206 | faulty_layers, 207 | ) 208 | elif arch == "resnet101": 209 | model = resnetf( 210 | "resnet101", 211 | classes, 212 | precision, 213 | bit_error_rate, 214 | position, 215 | faulty_layers, 216 | ) 217 | elif "resnet18Py" in arch: 218 | model = resnetfPy( 219 | "resnet18", 220 | classes, 221 | precision, 222 | bit_error_rate, 223 | position, 224 | faulty_layers, 225 | ) 226 | elif "resnet50Py" in arch: 227 | model = resnetfPy( 228 | "resnet50", 229 | classes, 230 | precision, 231 | bit_error_rate, 232 | position, 233 | faulty_layers, 234 | ) 235 | elif "vgg11Py" in arch: 236 | model = vggfPy( 237 | "A", 238 | in_channels, 239 | classes, 240 | True, 241 | precision, 242 | bit_error_rate, 243 | position, 244 | faulty_layers, 245 | ) 246 | elif "vgg16Py" in arch: 247 | model = vggfPy( 248 | "D", 249 | in_channels, 250 | classes, 251 | True, 252 | precision, 253 | bit_error_rate, 254 | position, 255 | faulty_layers, 256 | ) 257 | elif "vgg19Py" in arch: 258 | model = vggfPy( 259 | "E", in_channels, 260 | classes, 261 | True, 262 | precision, 263 | bit_error_rate, 264 | position, 265 | faulty_layers, 266 | ) 267 | else: 268 | raise NotImplementedError 269 | 270 | # print(model) 271 | checkpoint_epoch = -1 272 | 273 | if retrain: 274 | if not os.path.exists(checkpoint_path): 275 | for x in range(cfg.epochs, -1, -1): 276 | if os.path.exists(model_path_from_base(checkpoint_path, x)): 277 | checkpoint_path = model_path_from_base(checkpoint_path, x) 278 | break 279 | 280 | if not os.path.exists(checkpoint_path): 281 | print("Checkpoint path not exists") 282 | return model, checkpoint_epoch 283 | 284 | # print("Restoring model from checkpoint", checkpoint_path) 285 | checkpoint = torch.load(checkpoint_path) 286 | 287 | model.load_state_dict(checkpoint["model_state_dict"]) 288 | # print("restored checkpoint at epoch - ", checkpoint["epoch"]) 289 | # print("Training loss =", checkpoint["loss"]) 290 | # print("Training accuracy =", checkpoint["accuracy"]) 291 | checkpoint_epoch = checkpoint["epoch"] 292 | 293 | return model, checkpoint_epoch 294 | 295 | 296 | def init_models_pairs(arch, in_channels, precision, retrain, checkpoint_path, faulty_layers, bit_error_rate, position, seed=0, dataset='cifar'): 297 | 298 | """Load the default model as well as the corresponding perturbed model""" 299 | 300 | model, checkpoint_epoch = init_models( 301 | arch, in_channels, precision, retrain, checkpoint_path, dataset=dataset 302 | ) 303 | model_p, checkpoint_epoch_p = init_models_faulty( 304 | arch, 305 | in_channels, 306 | precision, 307 | retrain, 308 | checkpoint_path, 309 | faulty_layers, 310 | bit_error_rate, 311 | position, 312 | seed=seed, 313 | dataset=dataset, 314 | ) 315 | 316 | return model, checkpoint_epoch, model_p, checkpoint_epoch_p 317 | 318 | 319 | def default_base_model_path(data_dir, arch, dataset, precision, fl, ber, pos): 320 | extra = [arch, dataset, "p", str(precision), "model"] 321 | if len(fl) != 0: 322 | arch = arch + "f" 323 | extra[0] = arch 324 | extra.append("fl") 325 | extra.append("-".join(fl)) 326 | extra.append("ber") 327 | extra.append("%03.3f" % ber) 328 | extra.append("pos") 329 | extra.append(str(pos)) 330 | return os.path.join(data_dir, arch, dataset, "_".join(extra)) 331 | 332 | 333 | def default_model_path(data_dir, arch, dataset, precision, fl, ber, pos, epoch): 334 | extra = [arch, dataset, "p", str(precision), "model"] 335 | if len(fl) != 0: 336 | arch = arch + "f" 337 | extra[0] = arch 338 | extra.append("fl") 339 | extra.append("-".join(fl)) 340 | extra.append("ber") 341 | extra.append("%03.3f" % ber) 342 | extra.append("pos") 343 | extra.append(str(pos)) 344 | extra.append(str(epoch)) 345 | return os.path.join(data_dir, arch, dataset, "_".join(extra) + ".pth") 346 | 347 | 348 | def model_path_from_base(basename, epoch): 349 | return basename + "_" + str(epoch) + ".pth" 350 | -------------------------------------------------------------------------------- /models/resnetf.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | """ 7 | 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from faultinjection_ops import zs_faultinjection_ops 12 | from quantized_ops import zs_quantized_ops 13 | 14 | conv_clamp_val = 0.05 15 | fc_clamp_val = 0.1 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__( 22 | self, 23 | in_planes, 24 | planes, 25 | stride, 26 | precision, 27 | ber, 28 | position, 29 | faulty_layers, 30 | ): 31 | super(BasicBlock, self).__init__() 32 | if "conv" in faulty_layers: 33 | # print('In block') 34 | self.conv1 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 35 | in_planes, 36 | planes, 37 | kernel_size=3, 38 | stride=stride, 39 | padding=1, 40 | bias=False, 41 | precision=precision, 42 | clamp_val=conv_clamp_val, 43 | ) 44 | else: 45 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 46 | in_planes, 47 | planes, 48 | kernel_size=3, 49 | stride=stride, 50 | padding=1, 51 | bias=False, 52 | precision=precision, 53 | clamp_val=conv_clamp_val, 54 | ) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | if "conv" in faulty_layers: 57 | self.conv2 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 58 | planes, 59 | planes, 60 | kernel_size=3, 61 | stride=1, 62 | padding=1, 63 | bias=False, 64 | precision=precision, 65 | clamp_val=conv_clamp_val, 66 | ) 67 | else: 68 | self.conv2 = zs_quantized_ops.nnConv2dSymQuant_op( 69 | planes, 70 | planes, 71 | kernel_size=3, 72 | stride=1, 73 | padding=1, 74 | bias=False, 75 | precision=precision, 76 | clamp_val=conv_clamp_val, 77 | ) 78 | self.bn2 = nn.BatchNorm2d(planes) 79 | self.relu1 = nn.ReLU() 80 | self.relu2 = nn.ReLU() 81 | self.shortcut = nn.Sequential() 82 | if stride != 1 or in_planes != self.expansion * planes: 83 | if "conv" in faulty_layers: 84 | self.shortcut = nn.Sequential( 85 | zs_faultinjection_ops.nnConv2dPerturbWeight_op( 86 | in_planes, 87 | self.expansion * planes, 88 | kernel_size=1, 89 | stride=stride, 90 | padding=0, 91 | bias=False, 92 | precision=precision, 93 | clamp_val=conv_clamp_val, 94 | ), 95 | nn.BatchNorm2d(self.expansion * planes), 96 | ) 97 | else: 98 | self.shortcut = nn.Sequential( 99 | zs_quantized_ops.nnConv2dSymQuant_op( 100 | in_planes, 101 | self.expansion * planes, 102 | kernel_size=1, 103 | stride=stride, 104 | padding=0, 105 | bias=False, 106 | precision=precision, 107 | clamp_val=conv_clamp_val, 108 | ), 109 | nn.BatchNorm2d(self.expansion * planes), 110 | ) 111 | 112 | def forward(self, x): 113 | out = self.relu1(self.bn1(self.conv1(x))) 114 | out = self.bn2(self.conv2(out)) 115 | out += self.shortcut(x) 116 | out = self.relu2(out) 117 | return out 118 | 119 | class Bottleneck(nn.Module): 120 | expansion = 4 121 | 122 | def __init__( 123 | self, 124 | in_planes, 125 | planes, 126 | stride, 127 | precision, 128 | ber, 129 | position, 130 | faulty_layers, 131 | ): 132 | super(Bottleneck, self).__init__() 133 | if "conv" in faulty_layers: 134 | # print('In block') 135 | self.conv1 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 136 | in_planes, 137 | planes, 138 | kernel_size=1, 139 | stride=1, 140 | padding=0, 141 | bias=False, 142 | precision=precision, 143 | clamp_val=conv_clamp_val, 144 | ) 145 | else: 146 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 147 | in_planes, 148 | planes, 149 | kernel_size=1, 150 | stride=1, 151 | padding=0, 152 | bias=False, 153 | precision=precision, 154 | clamp_val=conv_clamp_val, 155 | ) 156 | self.bn1 = nn.BatchNorm2d(planes) 157 | if "conv" in faulty_layers: 158 | self.conv2 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 159 | planes, 160 | planes, 161 | kernel_size=3, 162 | stride=stride, 163 | padding=1, 164 | bias=False, 165 | precision=precision, 166 | clamp_val=conv_clamp_val, 167 | ) 168 | else: 169 | self.conv2 = zs_quantized_ops.nnConv2dSymQuant_op( 170 | planes, 171 | planes, 172 | kernel_size=3, 173 | stride=stride, 174 | padding=1, 175 | bias=False, 176 | precision=precision, 177 | clamp_val=conv_clamp_val, 178 | ) 179 | self.bn2 = nn.BatchNorm2d(planes) 180 | if "conv" in faulty_layers: 181 | self.conv3 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 182 | planes, 183 | self.expansion * planes, 184 | kernel_size=1, 185 | stride=1, 186 | padding=0, 187 | bias=False, 188 | precision=precision, 189 | clamp_val=conv_clamp_val, 190 | ) 191 | else: 192 | self.conv3 = zs_quantized_ops.nnConv2dSymQuant_op( 193 | planes, 194 | self.expansion * planes, 195 | kernel_size=1, 196 | stride=1, 197 | padding=0, 198 | bias=False, 199 | precision=precision, 200 | clamp_val=conv_clamp_val, 201 | ) 202 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 203 | self.relu1 = nn.ReLU() 204 | self.relu2 = nn.ReLU() 205 | self.relu3 = nn.ReLU() 206 | self.shortcut = nn.Sequential() 207 | if stride != 1 or in_planes != self.expansion * planes: 208 | if "conv" in faulty_layers: 209 | self.shortcut = nn.Sequential( 210 | zs_faultinjection_ops.nnConv2dPerturbWeight_op( 211 | in_planes, 212 | self.expansion * planes, 213 | kernel_size=1, 214 | stride=stride, 215 | padding=0, 216 | bias=False, 217 | precision=precision, 218 | clamp_val=conv_clamp_val, 219 | ), 220 | nn.BatchNorm2d(self.expansion * planes), 221 | ) 222 | else: 223 | self.shortcut = nn.Sequential( 224 | zs_quantized_ops.nnConv2dSymQuant_op( 225 | in_planes, 226 | self.expansion * planes, 227 | kernel_size=1, 228 | stride=stride, 229 | padding=0, 230 | bias=False, 231 | precision=precision, 232 | clamp_val=conv_clamp_val, 233 | ), 234 | nn.BatchNorm2d(self.expansion * planes), 235 | ) 236 | 237 | def forward(self, x): 238 | out = self.relu1(self.bn1(self.conv1(x))) 239 | out = self.relu2(self.bn2(self.conv2(out))) 240 | out = self.bn3(self.conv3(out)) 241 | out += self.shortcut(x) 242 | out = self.relu3(out) 243 | return out 244 | 245 | 246 | class ResNet(nn.Module): 247 | def __init__( 248 | self, 249 | block, 250 | num_blocks, 251 | num_classes, 252 | precision, 253 | ber, 254 | position, 255 | faulty_layers, 256 | ): 257 | super(ResNet, self).__init__() 258 | self.in_planes = 64 259 | 260 | if "conv" in faulty_layers: 261 | # print('In first') 262 | self.conv1 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 263 | 3, 264 | 64, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | bias=False, 269 | precision=precision, 270 | clamp_val=conv_clamp_val, 271 | ) 272 | else: 273 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 274 | 3, 275 | 64, 276 | kernel_size=3, 277 | stride=1, 278 | padding=1, 279 | bias=False, 280 | precision=precision, 281 | clamp_val=conv_clamp_val, 282 | ) 283 | self.bn1 = nn.BatchNorm2d(64) 284 | self.layer1 = self._make_layer( 285 | block, 286 | 64, 287 | num_blocks[0], 288 | stride=1, 289 | precision=precision, 290 | ber=ber, 291 | position=position, 292 | faulty_layers=faulty_layers, 293 | ) 294 | self.layer2 = self._make_layer( 295 | block, 296 | 128, 297 | num_blocks[1], 298 | stride=2, 299 | precision=precision, 300 | ber=ber, 301 | position=position, 302 | faulty_layers=faulty_layers, 303 | ) 304 | self.layer3 = self._make_layer( 305 | block, 306 | 256, 307 | num_blocks[2], 308 | stride=2, 309 | precision=precision, 310 | ber=ber, 311 | position=position, 312 | faulty_layers=faulty_layers, 313 | ) 314 | self.layer4 = self._make_layer( 315 | block, 316 | 512, 317 | num_blocks[3], 318 | stride=2, 319 | precision=precision, 320 | ber=ber, 321 | position=position, 322 | faulty_layers=faulty_layers, 323 | ) 324 | if "linear" in faulty_layers: 325 | self.linear = zs_faultinjection_ops.nnLinearPerturbWeight_op( 326 | 512 * block.expansion, 327 | num_classes, 328 | precision, 329 | fc_clamp_val, 330 | ) 331 | else: 332 | self.linear = zs_quantized_ops.nnLinearSymQuant_op( 333 | 512 * block.expansion, num_classes, precision, fc_clamp_val 334 | ) 335 | 336 | def _make_layer( 337 | self, 338 | block, 339 | planes, 340 | num_blocks, 341 | stride, 342 | precision, 343 | ber, 344 | position, 345 | faulty_layers, 346 | ): 347 | strides = [stride] + [1] * (num_blocks - 1) 348 | layers = [] 349 | for stride in strides: 350 | layers.append( 351 | block( 352 | self.in_planes, 353 | planes, 354 | stride, 355 | precision, 356 | ber, 357 | position, 358 | faulty_layers, 359 | ) 360 | ) 361 | self.in_planes = planes * block.expansion 362 | return nn.Sequential(*layers) 363 | 364 | def forward(self, x): 365 | out = F.relu(self.bn1(self.conv1(x))) 366 | out = self.layer1(out) 367 | out = self.layer2(out) 368 | out = self.layer3(out) 369 | out = self.layer4(out) 370 | out = F.avg_pool2d(out, out.shape[3]) 371 | out = out.view(out.size(0), -1) 372 | out = self.linear(out) 373 | return out 374 | 375 | 376 | def ResNet18( 377 | classes, 378 | precision, 379 | ber, 380 | position, 381 | faulty_layers, 382 | ): 383 | return ResNet( 384 | BasicBlock, 385 | [2, 2, 2, 2], 386 | classes, 387 | precision, 388 | ber, 389 | position, 390 | faulty_layers, 391 | ) 392 | 393 | 394 | def ResNet34( 395 | classes, 396 | precision, 397 | ber, 398 | position, 399 | faulty_layers, 400 | ): 401 | return ResNet( 402 | BasicBlock, 403 | [3, 4, 6, 3], 404 | classes, 405 | precision, 406 | ber, 407 | position, 408 | faulty_layers, 409 | ) 410 | 411 | def ResNet50( 412 | classes, 413 | precision, 414 | ber, 415 | position, 416 | faulty_layers, 417 | ): 418 | return ResNet( 419 | Bottleneck, 420 | [3, 4, 6, 3], 421 | classes, 422 | precision, 423 | ber, 424 | position, 425 | faulty_layers, 426 | ) 427 | 428 | def ResNet101( 429 | classes, 430 | precision, 431 | ber, 432 | position, 433 | faulty_layers, 434 | ): 435 | return ResNet( 436 | Bottleneck, 437 | [3, 4, 23, 3], 438 | classes, 439 | precision, 440 | ber, 441 | position, 442 | faulty_layers, 443 | ) 444 | 445 | 446 | def resnetf( 447 | arch, 448 | classes, 449 | precision, 450 | ber, 451 | position, 452 | faulty_layers, 453 | ): 454 | if arch == "resnet18": 455 | return ResNet18( 456 | classes, 457 | precision, 458 | ber, 459 | position, 460 | faulty_layers, 461 | ) 462 | elif arch == "resnet34": 463 | return ResNet34( 464 | classes, 465 | precision, 466 | ber, 467 | position, 468 | faulty_layers, 469 | ) 470 | elif arch == "resnet50": 471 | return ResNet50( 472 | classes, 473 | precision, 474 | ber, 475 | position, 476 | faulty_layers, 477 | ) 478 | elif arch == "resnet101": 479 | return ResNet101( 480 | classes, 481 | precision, 482 | ber, 483 | position, 484 | faulty_layers, 485 | ) 486 | -------------------------------------------------------------------------------- /models/resnetf_pytorch.py: -------------------------------------------------------------------------------- 1 | """ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | """ 7 | 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from faultinjection_ops import zs_faultinjection_ops 12 | from quantized_ops import zs_quantized_ops 13 | 14 | conv_clamp_val = 0.05 15 | fc_clamp_val = 0.1 16 | 17 | 18 | class BasicBlockPy(nn.Module): 19 | expansion = 1 20 | 21 | def __init__( 22 | self, 23 | in_planes, 24 | planes, 25 | stride, 26 | precision, 27 | ber, 28 | position, 29 | faulty_layers, 30 | ): 31 | super(BasicBlockPy, self).__init__() 32 | if "conv" in faulty_layers: 33 | # print('In block') 34 | self.conv1 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 35 | in_planes, 36 | planes, 37 | kernel_size=3, 38 | stride=stride, 39 | padding=1, 40 | bias=False, 41 | precision=precision, 42 | clamp_val=conv_clamp_val, 43 | ) 44 | else: 45 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 46 | in_planes, 47 | planes, 48 | kernel_size=3, 49 | stride=stride, 50 | padding=1, 51 | bias=False, 52 | precision=precision, 53 | clamp_val=conv_clamp_val, 54 | ) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.relu = nn.ReLU() 57 | if "conv" in faulty_layers: 58 | self.conv2 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 59 | planes, 60 | planes, 61 | kernel_size=3, 62 | stride=1, 63 | padding=1, 64 | bias=False, 65 | precision=precision, 66 | clamp_val=conv_clamp_val, 67 | ) 68 | else: 69 | self.conv2 = zs_quantized_ops.nnConv2dSymQuant_op( 70 | planes, 71 | planes, 72 | kernel_size=3, 73 | stride=1, 74 | padding=1, 75 | bias=False, 76 | precision=precision, 77 | clamp_val=conv_clamp_val, 78 | ) 79 | self.bn2 = nn.BatchNorm2d(planes) 80 | 81 | self.downsample = nn.Sequential() 82 | if stride != 1 or in_planes != self.expansion * planes: 83 | if "conv" in faulty_layers: 84 | self.downsample = nn.Sequential( 85 | zs_faultinjection_ops.nnConv2dPerturbWeight_op( 86 | in_planes, 87 | self.expansion * planes, 88 | kernel_size=1, 89 | stride=stride, 90 | padding=0, 91 | bias=False, 92 | precision=precision, 93 | clamp_val=conv_clamp_val, 94 | ), 95 | nn.BatchNorm2d(self.expansion * planes), 96 | ) 97 | else: 98 | self.downsample = nn.Sequential( 99 | zs_quantized_ops.nnConv2dSymQuant_op( 100 | in_planes, 101 | self.expansion * planes, 102 | kernel_size=1, 103 | stride=stride, 104 | padding=0, 105 | bias=False, 106 | precision=precision, 107 | clamp_val=conv_clamp_val, 108 | ), 109 | nn.BatchNorm2d(self.expansion * planes), 110 | ) 111 | 112 | def forward(self, x): 113 | out = self.relu(self.bn1(self.conv1(x))) 114 | out = self.bn2(self.conv2(out)) 115 | out += self.downsample(x) 116 | return out 117 | 118 | class BottleneckPy(nn.Module): 119 | expansion = 4 120 | 121 | def __init__( 122 | self, 123 | in_planes, 124 | planes, 125 | stride, 126 | precision, 127 | ber, 128 | position, 129 | faulty_layers, 130 | ): 131 | super(BottleneckPy, self).__init__() 132 | if "conv" in faulty_layers: 133 | # print('In block') 134 | self.conv1 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 135 | in_planes, 136 | planes, 137 | kernel_size=1, 138 | stride=1, 139 | padding=0, 140 | bias=False, 141 | precision=precision, 142 | clamp_val=conv_clamp_val, 143 | ) 144 | else: 145 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 146 | in_planes, 147 | planes, 148 | kernel_size=1, 149 | stride=1, 150 | padding=0, 151 | bias=False, 152 | precision=precision, 153 | clamp_val=conv_clamp_val, 154 | ) 155 | self.bn1 = nn.BatchNorm2d(planes) 156 | if "conv" in faulty_layers: 157 | self.conv2 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 158 | planes, 159 | planes, 160 | kernel_size=3, 161 | stride=stride, 162 | padding=1, 163 | bias=False, 164 | precision=precision, 165 | clamp_val=conv_clamp_val, 166 | ) 167 | else: 168 | self.conv2 = zs_quantized_ops.nnConv2dSymQuant_op( 169 | planes, 170 | planes, 171 | kernel_size=3, 172 | stride=stride, 173 | padding=1, 174 | bias=False, 175 | precision=precision, 176 | clamp_val=conv_clamp_val, 177 | ) 178 | self.bn2 = nn.BatchNorm2d(planes) 179 | if "conv" in faulty_layers: 180 | self.conv3 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 181 | planes, 182 | self.expansion * planes, 183 | kernel_size=1, 184 | stride=1, 185 | padding=0, 186 | bias=False, 187 | precision=precision, 188 | clamp_val=conv_clamp_val, 189 | ) 190 | else: 191 | self.conv3 = zs_quantized_ops.nnConv2dSymQuant_op( 192 | planes, 193 | self.expansion * planes, 194 | kernel_size=1, 195 | stride=1, 196 | padding=0, 197 | bias=False, 198 | precision=precision, 199 | clamp_val=conv_clamp_val, 200 | ) 201 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 202 | self.relu1 = nn.ReLU() 203 | self.relu2 = nn.ReLU() 204 | self.relu3 = nn.ReLU() 205 | self.downsample = nn.Sequential() 206 | if stride != 1 or in_planes != self.expansion * planes: 207 | if "conv" in faulty_layers: 208 | self.downsample = nn.Sequential( 209 | zs_faultinjection_ops.nnConv2dPerturbWeight_op( 210 | in_planes, 211 | self.expansion * planes, 212 | kernel_size=1, 213 | stride=stride, 214 | padding=0, 215 | bias=False, 216 | precision=precision, 217 | clamp_val=conv_clamp_val, 218 | ), 219 | nn.BatchNorm2d(self.expansion * planes), 220 | ) 221 | else: 222 | self.downsample = nn.Sequential( 223 | zs_quantized_ops.nnConv2dSymQuant_op( 224 | in_planes, 225 | self.expansion * planes, 226 | kernel_size=1, 227 | stride=stride, 228 | padding=0, 229 | bias=False, 230 | precision=precision, 231 | clamp_val=conv_clamp_val, 232 | ), 233 | nn.BatchNorm2d(self.expansion * planes), 234 | ) 235 | 236 | def forward(self, x): 237 | out = self.relu1(self.bn1(self.conv1(x))) 238 | out = self.relu2(self.bn2(self.conv2(out))) 239 | out = self.bn3(self.conv3(out)) 240 | out += self.downsample(x) 241 | out = self.relu3(out) 242 | return out 243 | 244 | 245 | class ResNetPy(nn.Module): 246 | def __init__( 247 | self, 248 | block, 249 | num_blocks, 250 | num_classes, 251 | precision, 252 | ber, 253 | position, 254 | faulty_layers, 255 | ): 256 | super(ResNetPy, self).__init__() 257 | self.in_planes = 64 258 | 259 | if "conv" in faulty_layers: 260 | # print('In first') 261 | self.conv1 = zs_faultinjection_ops.nnConv2dPerturbWeight_op( 262 | 3, 263 | 64, 264 | kernel_size=7, 265 | stride=2, 266 | padding=3, 267 | bias=False, 268 | precision=precision, 269 | clamp_val=conv_clamp_val, 270 | ) 271 | else: 272 | self.conv1 = zs_quantized_ops.nnConv2dSymQuant_op( 273 | 3, 274 | 64, 275 | kernel_size=7, 276 | stride=2, 277 | padding=3, 278 | bias=False, 279 | precision=precision, 280 | clamp_val=conv_clamp_val, 281 | ) 282 | self.bn1 = nn.BatchNorm2d(64) 283 | self.maxpool = m = nn.MaxPool2d( 284 | kernel_size=3, 285 | stride=2, 286 | padding=1, 287 | dilation=1) 288 | self.layer1 = self._make_layerPy( 289 | block, 290 | 64, 291 | num_blocks[0], 292 | stride=1, 293 | precision=precision, 294 | ber=ber, 295 | position=position, 296 | faulty_layers=faulty_layers, 297 | ) 298 | self.layer2 = self._make_layerPy( 299 | block, 300 | 128, 301 | num_blocks[1], 302 | stride=2, 303 | precision=precision, 304 | ber=ber, 305 | position=position, 306 | faulty_layers=faulty_layers, 307 | ) 308 | self.layer3 = self._make_layerPy( 309 | block, 310 | 256, 311 | num_blocks[2], 312 | stride=2, 313 | precision=precision, 314 | ber=ber, 315 | position=position, 316 | faulty_layers=faulty_layers, 317 | ) 318 | self.layer4 = self._make_layerPy( 319 | block, 320 | 512, 321 | num_blocks[3], 322 | stride=2, 323 | precision=precision, 324 | ber=ber, 325 | position=position, 326 | faulty_layers=faulty_layers, 327 | ) 328 | if "linear" in faulty_layers: 329 | self.fc = zs_faultinjection_ops.nnLinearPerturbWeight_op( 330 | 512 * block.expansion, 331 | num_classes, 332 | precision, 333 | fc_clamp_val, 334 | ) 335 | else: 336 | self.fc = zs_quantized_ops.nnLinearSymQuant_op( 337 | 512 * block.expansion, num_classes, precision, fc_clamp_val 338 | ) 339 | 340 | def _make_layerPy( 341 | self, 342 | block, 343 | planes, 344 | num_blocks, 345 | stride, 346 | precision, 347 | ber, 348 | position, 349 | faulty_layers, 350 | ): 351 | strides = [stride] + [1] * (num_blocks - 1) 352 | layers = [] 353 | for stride in strides: 354 | layers.append( 355 | block( 356 | self.in_planes, 357 | planes, 358 | stride, 359 | precision, 360 | ber, 361 | position, 362 | faulty_layers, 363 | ) 364 | ) 365 | self.in_planes = planes * block.expansion 366 | return nn.Sequential(*layers) 367 | 368 | def forward(self, x): 369 | out = self.maxpool(F.relu(self.bn1(self.conv1(x)))) 370 | out = self.layer1(out) 371 | out = self.layer2(out) 372 | out = self.layer3(out) 373 | out = self.layer4(out) 374 | out = F.avg_pool2d(out, out.shape[3]) 375 | out = out.view(out.size(0), -1) 376 | out = self.fc(out) 377 | return out 378 | 379 | 380 | def ResNet18Py( 381 | classes, 382 | precision, 383 | ber, 384 | position, 385 | faulty_layers, 386 | ): 387 | return ResNetPy( 388 | BasicBlockPy, 389 | [2, 2, 2, 2], 390 | classes, 391 | precision, 392 | ber, 393 | position, 394 | faulty_layers, 395 | ) 396 | 397 | 398 | def ResNet34Py( 399 | classes, 400 | precision, 401 | ber, 402 | position, 403 | faulty_layers, 404 | ): 405 | return ResNetPy( 406 | BasicBlockPy, 407 | [3, 4, 6, 3], 408 | classes, 409 | precision, 410 | ber, 411 | position, 412 | faulty_layers, 413 | ) 414 | 415 | def ResNet50Py( 416 | classes, 417 | precision, 418 | ber, 419 | position, 420 | faulty_layers, 421 | ): 422 | return ResNetPy( 423 | BottleneckPy, 424 | [3, 4, 6, 3], 425 | classes, 426 | precision, 427 | ber, 428 | position, 429 | faulty_layers, 430 | ) 431 | 432 | def ResNet101Py( 433 | classes, 434 | precision, 435 | ber, 436 | position, 437 | faulty_layers, 438 | ): 439 | return ResNetPy( 440 | BottleneckPy, 441 | [3, 4, 23, 3], 442 | classes, 443 | precision, 444 | ber, 445 | position, 446 | faulty_layers, 447 | ) 448 | 449 | 450 | def resnetfPy( 451 | arch, 452 | classes, 453 | precision, 454 | ber, 455 | position, 456 | faulty_layers, 457 | ): 458 | if arch == "resnet18": 459 | return ResNet18Py( 460 | classes, 461 | precision, 462 | ber, 463 | position, 464 | faulty_layers, 465 | ) 466 | elif arch == "resnet34": 467 | return ResNet34Py( 468 | classes, 469 | precision, 470 | ber, 471 | position, 472 | faulty_layers, 473 | ) 474 | elif arch == "resnet50": 475 | return ResNet50Py( 476 | classes, 477 | precision, 478 | ber, 479 | position, 480 | faulty_layers, 481 | ) 482 | elif arch == "resnet101": 483 | return ResNet101Py( 484 | classes, 485 | precision, 486 | ber, 487 | position, 488 | faulty_layers, 489 | ) -------------------------------------------------------------------------------- /zs_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | 10 | # Sinlge model 11 | import zs_train as train 12 | 13 | # EOPM-Based 14 | import zs_train_input_transform_eopm_gen as transform_eopm_gen 15 | 16 | # Evaluation 17 | import zs_train_input_transform_eval as transform_eval 18 | 19 | from config import cfg 20 | from models import default_base_model_path 21 | 22 | np.set_printoptions(threshold=sys.maxsize) 23 | torch.manual_seed(0) 24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | 26 | 27 | def main(): 28 | 29 | print("Running command:", str(sys.argv)) 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "arch", 34 | help="Input network architecture", 35 | choices=[ 36 | "resnet18", 37 | "resnet50", 38 | "resnet18Py", 39 | "resnet50Py", 40 | "vgg11", 41 | "vgg16", 42 | "vgg19", 43 | "vgg11Py", 44 | "vgg16Py", 45 | "vgg19Py", 46 | ], 47 | default="resnet18", 48 | ) 49 | parser.add_argument( 50 | "mode", 51 | help="Specify operation to perform", 52 | default="train", 53 | choices=[ 54 | "train", 55 | "transform_eval", 56 | "transform_eopm_gen", 57 | ], 58 | ) 59 | parser.add_argument( 60 | "dataset", 61 | help="Specify dataset", 62 | choices=[ 63 | "cifar10", 64 | "cifar100", 65 | "gtsrb", 66 | "imagenet128", 67 | "imagenet224" 68 | ], 69 | default="cifar10", 70 | ) 71 | group = parser.add_argument_group( 72 | "Reliability/Error control Options", 73 | "Options to control the fault injection details.", 74 | ) 75 | group.add_argument( 76 | "-ber", 77 | "--bit_error_rate", 78 | type=float, 79 | help="Bit error rate for training corresponding to known voltage.", 80 | default=0.01, 81 | ) 82 | group.add_argument( 83 | "-pos", 84 | "--position", 85 | type=int, 86 | help="Position of bit errors.", 87 | default=-1, 88 | ) 89 | group = parser.add_argument_group( 90 | "Initialization options", "Options to control the initial state." 91 | ) 92 | group.add_argument( 93 | "-rt", 94 | "--retrain", 95 | action="store_true", 96 | help="Continue training on top of already trained model." 97 | "It will start the " 98 | "process from the provided checkpoint.", 99 | default=False, 100 | ) 101 | group.add_argument( 102 | "-cp", 103 | "--checkpoint", 104 | help="Name of the stored checkpoint that needs to be " 105 | "retrained or used for test (only used if -rt flag is set).", 106 | default=None, 107 | ) 108 | group.add_argument( 109 | "-F", 110 | "--force", 111 | action="store_true", 112 | help="Do not fail if checkpoint already exists. Overwrite it.", 113 | default=False, 114 | ) 115 | group = parser.add_argument_group( 116 | "Other options", "Options to control training/validation process." 117 | ) 118 | group.add_argument( 119 | "-E", 120 | "--epochs", 121 | type=int, 122 | help="Maxium number of epochs to train.", 123 | default=5, 124 | ) 125 | group.add_argument( 126 | "-LR", 127 | "--learning_rate", 128 | type=float, 129 | help="Learning rate for training input transformation of training clean model.", 130 | default=5, 131 | ) 132 | group.add_argument( 133 | "-LM", 134 | "--lambdaVal", 135 | type=float, 136 | help="Lambda value between two loss function", 137 | default=1, 138 | ) 139 | group.add_argument( 140 | "-BS", 141 | "--batch-size", 142 | type=int, 143 | help="Training batch size.", 144 | default=128, 145 | ) 146 | group.add_argument( 147 | "-TBS", 148 | "--test-batch-size", 149 | type=int, 150 | help="Test batch size.", 151 | default=100, 152 | ) 153 | group.add_argument( 154 | "-N", 155 | "--N_perturbed_model", 156 | type=int, 157 | help="How many perturbed model will be used for training.", 158 | default=100, 159 | ) 160 | group.add_argument( 161 | "-G", 162 | "--Generator", 163 | type=str, 164 | help="Which generator to be used.", 165 | default='large', 166 | ) 167 | 168 | args = parser.parse_args() 169 | cfg.epochs = args.epochs 170 | cfg.learning_rate = args.learning_rate 171 | cfg.batch_size = args.batch_size 172 | cfg.test_batch_size = args.test_batch_size 173 | cfg.lb = args.lambdaVal 174 | cfg.N = args.N_perturbed_model 175 | cfg.G = args.Generator 176 | 177 | print("Preparing data..", args.dataset) 178 | if args.dataset == "cifar10": 179 | dataset = "cifar10" 180 | in_channels = 3 181 | 182 | transform_train = transforms.Compose( 183 | [ 184 | transforms.RandomCrop(32, padding=4), 185 | transforms.RandomHorizontalFlip(), 186 | transforms.ToTensor(), 187 | transforms.Lambda(lambda t: t * 2 - 1), 188 | ] 189 | ) 190 | 191 | transform_test = transforms.Compose( 192 | [ 193 | transforms.ToTensor(), 194 | transforms.Lambda(lambda t: t * 2 - 1), 195 | ] 196 | ) 197 | 198 | trainset = torchvision.datasets.CIFAR10( 199 | root=cfg.data_dir, 200 | train=True, 201 | download=True, 202 | transform=transform_train, 203 | ) 204 | trainloader = torch.utils.data.DataLoader( 205 | trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=8 206 | ) 207 | 208 | testset = torchvision.datasets.CIFAR10( 209 | root=cfg.data_dir, 210 | train=False, 211 | download=True, 212 | transform=transform_test, 213 | ) 214 | testloader = torch.utils.data.DataLoader( 215 | testset, 216 | batch_size=cfg.test_batch_size, 217 | shuffle=False, 218 | num_workers=2, 219 | ) 220 | elif args.dataset == "cifar100": 221 | dataset = "cifar100" 222 | in_channels = 3 223 | transform_train = transforms.Compose( 224 | [ 225 | transforms.RandomCrop(32, padding=4), 226 | transforms.RandomHorizontalFlip(), 227 | transforms.ToTensor(), 228 | transforms.Lambda(lambda t: t * 2 - 1), 229 | ] 230 | ) 231 | 232 | transform_test = transforms.Compose( 233 | [ 234 | transforms.ToTensor(), 235 | transforms.Lambda(lambda t: t * 2 - 1), 236 | ] 237 | ) 238 | 239 | trainset = torchvision.datasets.CIFAR100( 240 | root=cfg.data_dir, 241 | train=True, 242 | download=True, 243 | transform=transform_train, 244 | ) 245 | trainloader = torch.utils.data.DataLoader( 246 | trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=8 247 | ) 248 | 249 | testset = torchvision.datasets.CIFAR100( 250 | root=cfg.data_dir, 251 | train=False, 252 | download=True, 253 | transform=transform_test, 254 | ) 255 | testloader = torch.utils.data.DataLoader( 256 | testset, 257 | batch_size=cfg.test_batch_size, 258 | shuffle=False, 259 | num_workers=2, 260 | ) 261 | elif args.dataset == "gtsrb": 262 | dataset = "gtsrb" 263 | in_channels = 3 264 | transform_train = transforms.Compose( 265 | [ 266 | transforms.Resize((32, 32)), 267 | transforms.RandomAffine(degrees = 0, translate=(0.35, 0.35), scale=(0.65, 1.35)), 268 | transforms.RandomCrop(32, padding=4), 269 | transforms.RandomHorizontalFlip(), 270 | transforms.ToTensor(), 271 | transforms.Lambda(lambda t: t * 2 - 1), 272 | ] 273 | ) 274 | 275 | transform_test = transforms.Compose( 276 | [ 277 | transforms.Resize((32, 32)), 278 | transforms.ToTensor(), 279 | transforms.Lambda(lambda t: t * 2 - 1), 280 | ] 281 | ) 282 | 283 | trainset = torchvision.datasets.GTSRB( 284 | root=cfg.data_dir, 285 | split="train", 286 | download=True, 287 | transform=transform_train, 288 | ) 289 | trainloader = torch.utils.data.DataLoader( 290 | trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=8 291 | ) 292 | 293 | testset = torchvision.datasets.GTSRB( 294 | root=cfg.data_dir, 295 | split='test', 296 | download=True, 297 | transform=transform_test, 298 | ) 299 | testloader = torch.utils.data.DataLoader( 300 | testset, 301 | batch_size=cfg.test_batch_size, 302 | shuffle=False, 303 | num_workers=2, 304 | ) 305 | elif args.dataset == "imagenet128": 306 | dataset = "imagenet128" 307 | in_channels = 3 308 | transform_train = transforms.Compose( 309 | [ 310 | transforms.RandomResizedCrop(128), 311 | transforms.RandomHorizontalFlip(), 312 | transforms.ToTensor(), 313 | transforms.Lambda(lambda t: t * 2 - 1), 314 | ] 315 | ) 316 | 317 | transform_test = transforms.Compose( 318 | [ 319 | transforms.Resize((128, 128)), 320 | transforms.ToTensor(), 321 | transforms.Lambda(lambda t: t * 2 - 1), 322 | ] 323 | ) 324 | 325 | trainset = torchvision.datasets.ImageNet( 326 | root='data/imagenet-10/', 327 | split="train", 328 | transform=transform_train, 329 | ) 330 | trainloader = torch.utils.data.DataLoader( 331 | trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=8 332 | ) 333 | 334 | testset = torchvision.datasets.ImageNet( 335 | root='data/imagenet-10/', 336 | split="val", 337 | transform=transform_test, 338 | ) 339 | testloader = torch.utils.data.DataLoader( 340 | testset, 341 | batch_size=cfg.test_batch_size, 342 | shuffle=False, 343 | num_workers=2, 344 | ) 345 | elif args.dataset == "imagenet224": 346 | dataset = "imagenet224" 347 | in_channels = 3 348 | transform_train = transforms.Compose( 349 | [ 350 | transforms.RandomResizedCrop(224), 351 | transforms.RandomHorizontalFlip(), 352 | transforms.ToTensor(), 353 | transforms.Lambda(lambda t: t * 2 - 1), 354 | ] 355 | ) 356 | 357 | transform_test = transforms.Compose( 358 | [ 359 | transforms.Resize((224, 224)), 360 | transforms.ToTensor(), 361 | transforms.Lambda(lambda t: t * 2 - 1), 362 | ] 363 | ) 364 | 365 | trainset = torchvision.datasets.ImageNet( 366 | root='data/imagenet-10/', 367 | split="train", 368 | transform=transform_train, 369 | ) 370 | trainloader = torch.utils.data.DataLoader( 371 | trainset, batch_size=cfg.batch_size, shuffle=True, num_workers=8 372 | ) 373 | 374 | testset = torchvision.datasets.ImageNet( 375 | root='data/imagenet-10/', 376 | split="val", 377 | transform=transform_test, 378 | ) 379 | testloader = torch.utils.data.DataLoader( 380 | testset, 381 | batch_size=cfg.test_batch_size, 382 | shuffle=False, 383 | num_workers=2, 384 | ) 385 | 386 | print("Device", device) 387 | cfg.device = device 388 | 389 | assert isinstance(cfg.faulty_layers, list) 390 | 391 | if args.checkpoint is None and args.mode != "transform": 392 | args.checkpoint = default_base_model_path( 393 | cfg.data_dir, 394 | args.arch, 395 | dataset, 396 | cfg.precision, 397 | cfg.faulty_layers, 398 | args.bit_error_rate, 399 | args.position, 400 | ) 401 | elif args.checkpoint is None and args.mode == "transform": 402 | args.checkpoint = [] 403 | args.checkpoint.append( 404 | default_base_model_path( 405 | cfg.data_dir, 406 | args.arch, 407 | dataset, 408 | cfg.precision, 409 | [], 410 | args.bit_error_rate, 411 | args.position, 412 | ) 413 | ) 414 | args.checkpoint.append( 415 | default_base_model_path( 416 | cfg.data_dir, 417 | args.arch, 418 | dataset, 419 | cfg.precision, 420 | cfg.faulty_layers, 421 | args.bit_error_rate, 422 | args.position, 423 | ) 424 | ) 425 | 426 | if args.mode == "train": 427 | print("training args", args) 428 | train.training( 429 | trainloader, 430 | args.arch, 431 | dataset, 432 | in_channels, 433 | cfg.precision, 434 | args.retrain, 435 | args.checkpoint, 436 | args.force, 437 | device, 438 | cfg.faulty_layers, 439 | args.bit_error_rate, 440 | args.position, 441 | ) 442 | elif args.mode == "transform_eopm_gen": 443 | print("input_transform_train_eopm_gen", args) 444 | cfg.save_dir = 'eopm_p_gen/' 445 | cfg.save_dir_curve = 'eopm_curve_gen/' 446 | transform_eopm_gen.transform_train( 447 | trainloader, 448 | testloader, 449 | args.arch, 450 | dataset, 451 | in_channels, 452 | cfg.precision, 453 | args.checkpoint, 454 | args.force, 455 | device, 456 | cfg.faulty_layers, 457 | args.bit_error_rate, 458 | args.position, 459 | ) 460 | elif args.mode == "transform_eval": 461 | print("input_transform_train_eval", args) 462 | transform_eval.transform_eval( 463 | trainloader, 464 | testloader, 465 | args.arch, 466 | dataset, 467 | in_channels, 468 | cfg.precision, 469 | args.checkpoint, 470 | args.force, 471 | device, 472 | cfg.faulty_layers, 473 | args.bit_error_rate, 474 | args.position, 475 | ) 476 | else: 477 | raise NotImplementedError 478 | 479 | if __name__ == "__main__": 480 | main() 481 | -------------------------------------------------------------------------------- /faultinjection_ops/zs_faultinjection_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | quantized version of nn.Linear and nn.Conv with bit errors injected in weights 3 | Currently bit error injection is supported only in 8-bit type 4 | Followed the explanation in 5 | https://pytorch.org/docs/stable/notes/extending.html on how to write a 6 | custom autograd function and extend the nn.Linear class 7 | https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html 8 | """ 9 | 10 | from torch.nn.modules.utils import _pair 11 | import torch.nn.functional as F 12 | from torch import nn 13 | import torch 14 | 15 | import copy 16 | import math 17 | 18 | import faultsMap as fmap 19 | 20 | debug = False 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | dtype = torch.float32 23 | 24 | 25 | class FaultInject(torch.autograd.Function): 26 | """ 27 | Perturb the model weights. 28 | Later, quantize and dequantize the model weights in the forward pass. 29 | """ 30 | 31 | @staticmethod 32 | def forward(ctx, input, precision, clamp_val, BitErrorMap0to1, BitErrorMap1to0, use_max=True, q_method='symmetric_signed'): 33 | ctx.save_for_backward(input) 34 | # ctx.mark_dirty(input) 35 | 36 | """ 37 | Compute quantization step size. Mapping (-max_val, max_val) 38 | linearly to (-127,127) 39 | """ 40 | 41 | if q_method == 'symmetric_unsigned' or q_method == 'asymmetric_unsigned': 42 | """ 43 | Compute quantization step size. 44 | Mapping (min_val, max_val) linearly to (0, 2^m - 1) 45 | """ 46 | if use_max and q_method == 'symmetric_unsigned': 47 | max_val = torch.max(torch.abs(input)) 48 | min_val = -max_val 49 | input = torch.clamp(input, min_val, max_val) 50 | delta = max_val / (2 ** (precision - 1) - 1) 51 | input_q = torch.round((input / delta)) 52 | input_q = input_q.to(torch.int32) + (2 ** (precision - 1) - 1) # [0, 255] => 0....011111111 for 32 bits 53 | elif not use_max and q_method == 'asymmetric_unsigned': 54 | max_val = torch.max(input) 55 | min_val = torch.min(input) 56 | delta = 1 / (2 ** (precision - 1) - 1) 57 | input = (input - min_val) / (max_val - min_val) 58 | input = input * 2 - 1 # To -1 ~ 1 59 | input_q = torch.round((input / delta)) 60 | input_q = input_q.to(torch.int32) + (2 ** (precision - 1) - 1) # [0, 255] => 0....011111111 for 32 bits 61 | 62 | elif q_method == 'symmetric_signed': 63 | """ 64 | Compute quantization step size. 65 | Mapping (-max_val, max_val) linearly to (-127,127) 66 | """ 67 | # fix me : Need to add a parameter for this one ! 68 | if use_max: 69 | max_val = torch.max(torch.abs(input)) 70 | else: 71 | max_val = clamp_val 72 | 73 | delta = max_val / (2 ** (precision - 1) - 1) 74 | input_clamped = torch.clamp(input, -max_val, max_val) 75 | input_q = torch.round((input_clamped / delta)) 76 | if precision > 0 and precision <= 8: 77 | input_q = input_q.to(torch.int8) 78 | elif precision == 16: 79 | input_q = input_q.to(torch.int16) 80 | else: 81 | input_q = input_q.to(torch.int32) 82 | 83 | """ 84 | Inject faults in the quantized weight 85 | as determined by the bit error map 86 | 87 | BitErrorMap0, BitErrorMap1 = self.MapWeightsToBitErrors(numWeights) 88 | convert to int8, since this becomes uint8 by default 89 | after bitwise op with unsigned biterrormaps 90 | 91 | """ 92 | 93 | if q_method == 'symmetric_unsigned' or q_method == 'asymmetric_unsigned': 94 | input_qand = torch.bitwise_and(BitErrorMap1to0, input_q) 95 | input_qor = torch.bitwise_or(BitErrorMap0to1, input_qand) 96 | elif q_method == 'symmetric_signed': 97 | input_qand = torch.bitwise_and(BitErrorMap1to0, input_q).to(torch.int8) 98 | input_qor = torch.bitwise_or(BitErrorMap0to1, input_qand).to(torch.int8) 99 | 100 | 101 | """ 102 | Dequantize introducing a quantization error in the 103 | data along with the weight perturbation 104 | """ 105 | if q_method == 'symmetric_unsigned': 106 | input_dq = (input_qor - (2 ** (precision - 1) - 1)) * delta 107 | input_dq = torch.clamp(input_dq, min_val, max_val) 108 | input_dq = input_dq.to(torch.float32) 109 | elif q_method == 'asymmetric_unsigned': 110 | input_dq = (input_qor - (2 ** (precision - 1) - 1)) * delta 111 | input_dq = ((input_dq + 1) / 2) * (max_val - min_val) + min_val 112 | input_dq = torch.clamp(input_dq, min_val, max_val) 113 | input_dq = input_dq.to(torch.float32) 114 | elif q_method == 'symmetric_signed': 115 | input_dq = input_qor * delta 116 | input_dq = input_dq.to(torch.float32) 117 | 118 | """ For symmetric quantization, it is possible for dequantized tensors to get out of range compared to the original model""" 119 | 120 | input_dq_clamped = torch.clamp(input_dq, torch.min(input), torch.max(input)) 121 | 122 | # Return the perturbed-dequantized weights tensor. 123 | # We want to perturb the weights just for one time. 124 | # So, we don't use input.copy_(input_dq) to replace 125 | # self.weight with input_dq. 126 | return input_dq 127 | 128 | # Straight-through-estimator in backward pass 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | (input,) = ctx.saved_tensors 132 | return grad_output, None, None, None, None 133 | 134 | 135 | class nnLinearPerturbWeight(nn.Linear): 136 | """Applies a linear transformation to the incoming data: y = xA^T + b 137 | Along with the linear transform, the learnable weights are quantized, 138 | and then a bit error perturbation is introduced. Then the weights are 139 | dequantized. 140 | 141 | """ 142 | 143 | def __init__(self, in_features, out_features, bias, precision=-1, clamp_val=0.1): 144 | super().__init__(in_features, out_features, bias) 145 | self.in_features = in_features 146 | self.out_features = out_features 147 | # self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 148 | # if bias: 149 | # self.bias = nn.Parameter(torch.Tensor(out_features)) 150 | # else: 151 | # self.register_parameter('bias', None) 152 | self.precision = precision 153 | self.clamp_val = clamp_val 154 | self.reset_parameters() 155 | 156 | def forward(self, input): 157 | if self.precision > 0: 158 | BitErrorMap0to1, BitErrorMap1to0 = self.genFaultMap( 159 | fmap.BitErrorMap0, 160 | fmap.BitErrorMap1, 161 | self.precision, 162 | self.weight, 163 | ) 164 | perturbweight = FaultInject.apply 165 | perturbed_weights = perturbweight( 166 | self.weight, 167 | self.precision, 168 | self.clamp_val, 169 | BitErrorMap0to1, 170 | BitErrorMap1to0, 171 | ) 172 | if debug: 173 | print(torch.min(self.weight).detach().cpu().numpy(), torch.max(self.weight).detach().cpu().numpy()) 174 | print(torch.min(perturbed_weights).detach().cpu().numpy(), torch.max(perturbed_weights).detach().cpu().numpy()) 175 | return F.linear(input, perturbed_weights, self.bias) 176 | 177 | def extra_repr(self) -> str: 178 | return "in_features={}, out_features={}, bias={}, precision={}".format( 179 | self.in_features, 180 | self.out_features, 181 | self.bias is not None, 182 | self.precision, 183 | ) 184 | 185 | def genFaultMap(self, BitErrorMap_flip0to1, BitErrorMap_flip1to0, precision, weights): 186 | 187 | numweights = torch.numel(weights) 188 | 189 | mem_array_rows = BitErrorMap_flip0to1.shape[0] 190 | mem_array_cols = BitErrorMap_flip0to1.shape[1] 191 | 192 | weights_per_row = math.floor(mem_array_cols / precision) 193 | 194 | # Because in random bit error, BitErrorMap0to1 and BitErrorMap1to0 will be the same for every layers in the same perturbed model 195 | # Therefore, calculate the for loop for once is enough. 196 | if fmap.BitErrorMap0to1 is None and fmap.BitErrorMap1to0 is None: 197 | BitErrorMap0to1 = torch.zeros( 198 | (mem_array_rows, weights_per_row), dtype=torch.uint8 199 | ).to(device) 200 | BitErrorMap1to0 = torch.zeros( 201 | (mem_array_rows, weights_per_row), dtype=torch.uint8 202 | ).to(device) 203 | 204 | # Reshaping bit error map to map weights 205 | for k in range(0, weights_per_row): 206 | for j in range(0, precision): 207 | BitErrorMap0to1[:, k] += ( 208 | BitErrorMap_flip0to1[:, k * precision + j] << j 209 | ) 210 | BitErrorMap1to0[:, k] += ( 211 | BitErrorMap_flip1to0[:, k * precision + j] << j 212 | ) 213 | fmap.BitErrorMap0to1 = copy.deepcopy(BitErrorMap0to1) 214 | fmap.BitErrorMap1to0 = copy.deepcopy(BitErrorMap1to0) 215 | else: 216 | BitErrorMap0to1 = copy.deepcopy(fmap.BitErrorMap0to1) 217 | BitErrorMap1to0 = copy.deepcopy(fmap.BitErrorMap1to0) 218 | 219 | rows = math.ceil(numweights / weights_per_row) 220 | cols = weights_per_row 221 | num_banks = math.ceil(rows / mem_array_rows) 222 | BitErrorMap0to1 = torch.tile(BitErrorMap0to1, (num_banks, cols)) 223 | # invert this one, since it needs to be And-ed 224 | BitErrorMap1to0 = torch.tile(~BitErrorMap1to0, (num_banks, cols)) 225 | 226 | # This mapping is highly dependent on data flow 227 | BitErrorMap0to1 = BitErrorMap0to1.view(-1)[0:numweights] 228 | BitErrorMap1to0 = BitErrorMap1to0.view(-1)[0:numweights] 229 | 230 | BitErrorMap0to1 = torch.reshape(BitErrorMap0to1, weights.size()) 231 | BitErrorMap1to0 = torch.reshape(BitErrorMap1to0, weights.size()) 232 | 233 | return BitErrorMap0to1, BitErrorMap1to0 234 | 235 | 236 | def nnLinearPerturbWeight_op(in_features, out_features, precision, clamp_val): 237 | return nnLinearPerturbWeight( 238 | in_features, 239 | out_features, 240 | True, 241 | precision, 242 | clamp_val, 243 | ) 244 | 245 | 246 | class nnConv2dPerturbWeight(nn.Conv2d): 247 | """ 248 | Computes 2d conv output 249 | Weights are quantized and dequantized introducing a quantization error 250 | """ 251 | 252 | def __init__( 253 | self, 254 | in_channels, 255 | out_channels, 256 | kernel_size, 257 | stride=1, 258 | padding=1, 259 | dilation=1, 260 | groups=1, 261 | bias=True, 262 | padding_mode="zeros", 263 | precision=-1, 264 | clamp_val=0.5, 265 | ): 266 | kernel_size = _pair(kernel_size) 267 | stride = _pair(stride) 268 | padding = _pair(padding) 269 | dilation = _pair(dilation) 270 | super().__init__( 271 | in_channels, 272 | out_channels, 273 | kernel_size, 274 | stride, 275 | padding, 276 | dilation, 277 | groups, 278 | bias, 279 | padding_mode, 280 | ) 281 | self.precision = precision 282 | self.clamp_val = clamp_val 283 | 284 | def forward(self, input): 285 | if self.precision > 0: 286 | BitErrorMap0to1, BitErrorMap1to0 = self.genFaultMap( 287 | fmap.BitErrorMap0, 288 | fmap.BitErrorMap1, 289 | self.precision, 290 | self.weight, 291 | ) 292 | perturbweight = FaultInject.apply 293 | perturbed_weights = perturbweight( 294 | self.weight, 295 | self.precision, 296 | self.clamp_val, 297 | BitErrorMap0to1, 298 | BitErrorMap1to0, 299 | ) 300 | if debug: 301 | print(torch.min(self.weight).detach().cpu().numpy(), torch.max(self.weight).detach().cpu().numpy()) 302 | print(torch.min(perturbed_weights).detach().cpu().numpy(), torch.max(perturbed_weights).detach().cpu().numpy()) 303 | return F.conv2d( 304 | input, 305 | perturbed_weights, 306 | self.bias, 307 | self.stride, 308 | self.padding, 309 | self.dilation, 310 | self.groups, 311 | ) 312 | 313 | def genFaultMap(self, BitErrorMap_flip0to1, BitErrorMap_flip1to0, precision, weights): 314 | 315 | numweights = torch.numel(weights) 316 | mem_array_rows = BitErrorMap_flip0to1.shape[0] 317 | mem_array_cols = BitErrorMap_flip0to1.shape[1] 318 | 319 | weights_per_row = math.floor(mem_array_cols / precision) 320 | 321 | # Because in random bit error, BitErrorMap0to1 and BitErrorMap1to0 will be the same for every layers in the same perturbed model 322 | # Therefore, calculate the for loop for once is enough. 323 | if fmap.BitErrorMap0to1 is None and fmap.BitErrorMap1to0 is None: 324 | BitErrorMap0to1 = torch.zeros( 325 | (mem_array_rows, weights_per_row), dtype=torch.uint8 326 | ).to(device) 327 | BitErrorMap1to0 = torch.zeros( 328 | (mem_array_rows, weights_per_row), dtype=torch.uint8 329 | ).to(device) 330 | 331 | # Reshaping bit error map to map weights 332 | for k in range(0, weights_per_row): 333 | for j in range(0, precision): 334 | BitErrorMap0to1[:, k] += ( 335 | BitErrorMap_flip0to1[:, k * precision + j] << j 336 | ) 337 | BitErrorMap1to0[:, k] += ( 338 | BitErrorMap_flip1to0[:, k * precision + j] << j 339 | ) 340 | fmap.BitErrorMap0to1 = copy.deepcopy(BitErrorMap0to1) 341 | fmap.BitErrorMap1to0 = copy.deepcopy(BitErrorMap1to0) 342 | else: 343 | BitErrorMap0to1 = copy.deepcopy(fmap.BitErrorMap0to1) 344 | BitErrorMap1to0 = copy.deepcopy(fmap.BitErrorMap1to0) 345 | 346 | rows = math.ceil(numweights / weights_per_row) 347 | cols = weights_per_row 348 | num_banks = math.ceil(rows / mem_array_rows) 349 | BitErrorMap0to1 = torch.tile(BitErrorMap0to1, (num_banks, cols)) 350 | # invert this one, since it needs to be And-ed 351 | BitErrorMap1to0 = torch.tile(~BitErrorMap1to0, (num_banks, cols)) 352 | 353 | # This mapping is highly dependent on data flow 354 | BitErrorMap0to1 = BitErrorMap0to1.view(-1)[0:numweights] 355 | BitErrorMap1to0 = BitErrorMap1to0.view(-1)[0:numweights] 356 | 357 | BitErrorMap0to1 = torch.reshape(BitErrorMap0to1, weights.size()) 358 | BitErrorMap1to0 = torch.reshape(BitErrorMap1to0, weights.size()) 359 | 360 | return BitErrorMap0to1, BitErrorMap1to0 361 | 362 | 363 | def nnConv2dPerturbWeight_op(in_channels, out_channels, kernel_size, stride, padding, bias, precision, clamp_val): 364 | return nnConv2dPerturbWeight( 365 | in_channels, 366 | out_channels, 367 | kernel_size, 368 | stride=stride, 369 | padding=padding, 370 | dilation=1, 371 | groups=1, 372 | bias=bias, 373 | padding_mode="zeros", 374 | precision=precision, 375 | clamp_val=clamp_val, 376 | ) 377 | -------------------------------------------------------------------------------- /zs_train_input_transform_eopm_gen.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parameter import Parameter 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | 6 | from collections import OrderedDict 7 | from torchsummary import summary 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import random 11 | import tqdm 12 | import copy 13 | 14 | from models import init_models_pairs, create_faults 15 | from models.generator import * 16 | import faultsMap as fmap 17 | from config import cfg 18 | 19 | torch.manual_seed(0) 20 | device = "cuda" if torch.cuda.is_available() else "cpu" 21 | 22 | def compute_loss(model_outputs, labels): 23 | _, preds = torch.max(model_outputs, 1) 24 | labels = labels.view(labels.size(0)) # changing the size from (batch_size,1) to batch_size. 25 | loss = nn.CrossEntropyLoss()(model_outputs, labels) 26 | return loss, preds 27 | 28 | def accuracy_checking(model_orig, model_p, trainloader, testloader, gen, device, use_transform=False): 29 | """ 30 | Check the accuracy for both training data and testing data. 31 | :param model_orig: The clean model. 32 | :param model_p: The perturbed model. 33 | :param trainloader: The loader of training data. 34 | :param testloader: The loader of testing data. 35 | :param gen: Generator object to generate the perturbation based on the input images. 36 | :param device: Specify GPU usage. 37 | :param use_transform: Whether to apply input transfomation or not. 38 | """ 39 | total_train = 0 40 | total_test = 0 41 | correct_orig_train = 0 42 | correct_p_train = 0 43 | correct_orig_test = 0 44 | correct_p_test = 0 45 | 46 | # For training data: 47 | for x, y in trainloader: 48 | total_train += 1 49 | x, y = x.to(device), y.to(device) 50 | if use_transform: 51 | x_adv = gen(x) 52 | out_orig = model_orig(x_adv) 53 | out_p = model_p(x_adv) 54 | else: 55 | out_orig = model_orig(x) 56 | out_p = model_p(x) 57 | _, pred_orig = out_orig.max(1) 58 | _, pred_p = out_p.max(1) 59 | y = y.view(y.size(0)) 60 | correct_orig_train += torch.sum(pred_orig == y.data).item() 61 | correct_p_train += torch.sum(pred_p == y.data).item() 62 | accuracy_orig_train = correct_orig_train / (len(trainloader.dataset)) 63 | accuracy_p_train = correct_p_train / (len(trainloader.dataset)) 64 | 65 | # For testing data: 66 | for x, y in testloader: 67 | total_test += 1 68 | x, y = x.to(device), y.to(device) 69 | if use_transform: 70 | x_adv = gen(x) 71 | out_orig = model_orig(x_adv) 72 | out_p = model_p(x_adv) 73 | else: 74 | out_orig = model_orig(x) 75 | out_p = model_p(x) 76 | _, pred_orig = out_orig.max(1) 77 | _, pred_p = out_p.max(1) 78 | y = y.view(y.size(0)) 79 | correct_orig_test += torch.sum(pred_orig == y.data).item() 80 | correct_p_test += torch.sum(pred_p == y.data).item() 81 | accuracy_orig_test = correct_orig_test / (len(testloader.dataset)) 82 | accuracy_p_test = correct_p_test / (len(testloader.dataset)) 83 | 84 | print("Accuracy of training data: clean model: {:5f}, perturbed model: {:5f}".format( 85 | accuracy_orig_train, accuracy_p_train 86 | ) 87 | ) 88 | print("Accuracy of testing data: clean model: {:5f}, perturbed model: {:5f}".format( 89 | accuracy_orig_test, accuracy_p_test 90 | ) 91 | ) 92 | 93 | return accuracy_orig_train, accuracy_p_train, accuracy_orig_test, accuracy_p_test 94 | 95 | def init_dict(model_transform, grad_dict): 96 | for param_name, param_weight in model_transform.named_parameters(): 97 | if param_weight.requires_grad: 98 | grad_dict[param_name] = 0 99 | 100 | def sum_grads(model_transform, grad_dict): 101 | for param_name, param_weight in model_transform.named_parameters(): 102 | if param_weight.requires_grad: 103 | grad_dict[param_name] += param_weight.grad 104 | 105 | def mean_grads(grad_dict, nums): 106 | param_names = grad_dict.keys() 107 | for param_name in param_names: 108 | grad_dict[param_name] /= nums 109 | 110 | def transform_train(trainloader, testloader, arch, dataset, in_channels, precision, checkpoint_path, force, device, fl, ber, pos, seed=0): 111 | """ 112 | Apply quantization aware training. 113 | :param trainloader: The loader of training data. 114 | :param in_channels: An int. The input channels of the training data. 115 | :param arch: A string. The architecture of the model would be used. 116 | :param dataset: A string. The name of the training data. 117 | :param ber: A float. How many rate of bits would be attacked. 118 | :param precision: An int. The number of bits would be used to quantize 119 | the model. 120 | :param position: 121 | :param checkpoint_path: A string. The path that stores the models. 122 | :param device: Specify GPU usage. 123 | """ 124 | torch.backends.cudnn.benchmark = True 125 | 126 | storeLoss = [] 127 | 128 | if cfg.G == 'ConvL': 129 | Gen = GeneratorConvLQ(precision) 130 | elif cfg.G == 'ConvS': 131 | Gen = GeneratorConvSQ(precision) 132 | elif cfg.G == 'DeConvL': 133 | Gen = GeneratorDeConvLQ(precision) 134 | elif cfg.G == 'DeConvS': 135 | Gen = GeneratorDeConvSQ(precision) 136 | elif cfg.G == 'UNetL': 137 | Gen = GeneratorUNetLQ(precision) 138 | elif cfg.G == 'UNetS': 139 | Gen = GeneratorUNetSQ(precision) 140 | 141 | Gen = Gen.to(device) 142 | 143 | # Using Adam: 144 | optimizer = torch.optim.Adam( 145 | filter(lambda p: p.requires_grad, Gen.parameters()), 146 | lr=cfg.learning_rate, 147 | betas=(0.5, 0.999), 148 | # weight_decay=5e-4, 149 | ) 150 | 151 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=cfg.decay) 152 | lb = cfg.lb # Lambda 153 | 154 | for name, param in Gen.named_parameters(): 155 | print("Param name: {}, grads is: {}".format( 156 | name, 157 | param.requires_grad 158 | ) 159 | ) 160 | 161 | print('========== Check setting: Epoch: {}, Batch_size: {}, N perturbed models: {}, Lambda: {}, BitErrorRate: {}, LR: {}, G: {}, Random Training: {}=========='.format( 162 | cfg.epochs, 163 | cfg.batch_size, 164 | cfg.N, 165 | lb, 166 | ber, 167 | cfg.learning_rate, 168 | cfg.G, 169 | cfg.totalRandom 170 | ) 171 | ) 172 | 173 | print("========== Start training the parameter of the input transform by using EOT attack ==========") 174 | 175 | # Initialization clean and perturbed model 176 | create_faults(precision, ber, pos, seed=0) 177 | model, _, model_perturbed, _ = init_models_pairs(arch, in_channels, precision, True, checkpoint_path, fl, ber, pos, dataset=dataset) 178 | model, model_perturbed = model.to(device), model_perturbed.to(device) 179 | 180 | 181 | if 'Py' in arch: # For Larger Image Size 182 | summary(Gen, (3, 224, 224)) 183 | summary(model, (3, 224, 224)) 184 | elif 'imagenet128' in dataset: 185 | summary(Gen, (3, 128, 128)) 186 | summary(model, (3, 128, 128)) 187 | elif 'imagenet224' in dataset: 188 | summary(Gen, (3, 224, 224)) 189 | summary(model, (3, 224, 224)) 190 | else: 191 | summary(Gen, (3, 32, 32)) 192 | summary(model, (3, 32, 32)) 193 | 194 | for epoch in range(cfg.epochs): 195 | running_loss = 0 196 | running_correct_orig = 0 197 | running_correct_p = 0 198 | each_c_pred = [0] * cfg.N 199 | each_p_pred = [0] * cfg.N 200 | 201 | # For each epoch, we will use N perturbed model for training. 202 | for batch_id, (image, label) in tqdm.tqdm(enumerate(trainloader)): 203 | gradDict = OrderedDict() 204 | init_dict(Gen, gradDict) 205 | 206 | image, label = image.to(device), label.to(device) 207 | 208 | for k in range(cfg.N): 209 | 210 | loss = 0 211 | 212 | image_adv = Gen(image) 213 | image_adv = image_adv.to(device) 214 | 215 | # Random test 216 | if cfg.totalRandom: 217 | j = random.randint(0, cfg.randomRange) 218 | else: 219 | j = k 220 | 221 | # Reset the BitErrorMap for different perturbed models. 222 | fmap.BitErrorMap0to1 = None 223 | fmap.BitErrorMap1to0 = None 224 | 225 | # Create faultMap for faultinjection_ops. 226 | create_faults(precision, ber, pos, seed=j) 227 | 228 | model.eval() 229 | model_perturbed.eval() 230 | Gen.train() 231 | 232 | # Inference the clean model and perturbed model 233 | out_biterror_without_p = model_perturbed(image) 234 | _, pred_without_p = torch.max(out_biterror_without_p, 1) 235 | each_c_pred[k] += torch.sum(pred_without_p == label.data).item() 236 | 237 | out = model(image_adv) # pylint: disable=E1102 238 | out_biterror = model_perturbed(image_adv) # pylint: disable=E1102 239 | 240 | # Compute the loss for clean model and perturbed model 241 | loss_orig, pred_orig = compute_loss(out, label) 242 | loss_p, pred_p = compute_loss(out_biterror, label) 243 | 244 | each_p_pred[k] += torch.sum(pred_p == label.data).item() 245 | 246 | # Keep the running accuracy of clean model and perturbed model. 247 | running_correct_orig += torch.sum(pred_orig == label.data).item() 248 | running_correct_p += torch.sum(pred_p == label.data).item() 249 | 250 | loss = loss_orig + lb * loss_p 251 | 252 | # Keep the overal loss for whole batches 253 | running_loss += loss.item() 254 | 255 | # Calculate the gradients 256 | optimizer.zero_grad() 257 | loss.backward() 258 | 259 | # Sum all of the gradients 260 | sum_grads(Gen, gradDict) 261 | 262 | # Average the gradients 263 | mean_grads(gradDict, cfg.N) 264 | 265 | # Set gradients back to P 266 | for param_name, param_weight in Gen.named_parameters(): 267 | param_weight.grad = gradDict[param_name] 268 | 269 | # Apply gradients by optimizer to parameter 270 | optimizer.step() 271 | lr_scheduler.step() 272 | 273 | print('Each pred w/o transformation: {}'.format( 274 | [np.round(x/len(trainloader.dataset), decimals=4) for x in each_c_pred] 275 | ) 276 | ) 277 | print('Each pred with transformation: {}'.format( 278 | [np.round(x/len(trainloader.dataset), decimals=4) for x in each_p_pred] 279 | ) 280 | ) 281 | 282 | # Keep the running accuracy of clean model and perturbed model for all mini-batch. 283 | accuracy_orig = running_correct_orig / (len(trainloader.dataset) * cfg.N) 284 | accuracy_p = running_correct_p / (len(trainloader.dataset) * cfg.N) 285 | print("For epoch: {}, loss: {:.6f}, accuracy for {} clean model: {:.5f}, accuracy for {} perturbed model: {:.5f}".format( 286 | epoch + 1, 287 | running_loss / cfg.N, 288 | cfg.N, 289 | accuracy_orig, 290 | cfg.N, 291 | accuracy_p, 292 | ) 293 | ) 294 | 295 | storeLoss.append(running_loss / cfg.N) 296 | 297 | if (epoch + 1) % 50 == 0 or (epoch + 1) == cfg.epochs: 298 | # Saving the result of the generator! 299 | torch.save(Gen.state_dict(), 300 | cfg.save_dir + 'EOPM_Generator{}Q_{}_arch_{}_LR{}_E_{}_ber_{}_lb_{}_N_{}_step500_NOWE_{}.pt'.format( 301 | cfg.G, 302 | dataset, 303 | arch, 304 | cfg.learning_rate, 305 | cfg.epochs, 306 | ber, 307 | lb, 308 | cfg.N, 309 | epoch+1 310 | ) 311 | ) 312 | 313 | print('========== Start checking the accuracy with different perturbed model ==========') 314 | # Setting without input transformation 315 | accuracy_orig_train_list = [] 316 | accuracy_p_train_list = [] 317 | accuracy_orig_test_list = [] 318 | accuracy_p_test_list = [] 319 | 320 | # Setting with input transformation 321 | accuracy_orig_train_list_with_transformation = [] 322 | accuracy_p_train_list_with_transformation = [] 323 | accuracy_orig_test_list_with_transformation = [] 324 | accuracy_p_test_list_with_transformation = [] 325 | 326 | model, _, model_perturbed, _ = init_models_pairs(arch, in_channels, precision, True, checkpoint_path, fl, ber, pos, dataset=dataset) 327 | model, model_perturbed = model.to(device), model_perturbed.to(device) 328 | for i in range(cfg.beginSeed, cfg.endSeed): 329 | print(' ********** For seed: {} ********** '.format(i)) 330 | fmap.BitErrorMap0to1 = None 331 | fmap.BitErrorMap1to0 = None 332 | create_faults(precision, ber, pos, seed=i) 333 | 334 | model.eval() 335 | model_perturbed.eval() 336 | Gen.eval() 337 | 338 | # Without using transform 339 | accuracy_orig_train, accuracy_p_train, accuracy_orig_test, accuracy_p_test = accuracy_checking(model, model_perturbed, trainloader, testloader, Gen, device, use_transform=False) 340 | accuracy_orig_train_list.append(accuracy_orig_train) 341 | accuracy_p_train_list.append(accuracy_p_train) 342 | accuracy_orig_test_list.append(accuracy_orig_test) 343 | accuracy_p_test_list.append(accuracy_p_test) 344 | 345 | # With input transform 346 | accuracy_orig_train, accuracy_p_train, accuracy_orig_test, accuracy_p_test = accuracy_checking(model, model_perturbed, trainloader, testloader, Gen, device, use_transform=True) 347 | accuracy_orig_train_list_with_transformation.append(accuracy_orig_train) 348 | accuracy_p_train_list_with_transformation.append(accuracy_p_train) 349 | accuracy_orig_test_list_with_transformation.append(accuracy_orig_test) 350 | accuracy_p_test_list_with_transformation.append(accuracy_p_test) 351 | 352 | # Without using transform 353 | print('The average results without input transformation -> accuracy_orig_train: {:5f}, accuracy_p_train: {:5f}, accuracy_orig_test: {:5f}, accuracy_p_test: {:5f}'.format( 354 | np.mean(accuracy_orig_train_list), 355 | np.mean(accuracy_p_train_list), 356 | np.mean(accuracy_orig_test_list), 357 | np.mean(accuracy_p_test_list) 358 | ) 359 | ) 360 | print('The average results without input transformation -> std_accuracy_orig_train: {:5f}, std_accuracy_p_train: {:5f}, std_accuracy_orig_test: {:5f}, std_accuracy_p_test: {:5f}'.format( 361 | np.std(accuracy_orig_train_list), 362 | np.std(accuracy_p_train_list), 363 | np.std(accuracy_orig_test_list), 364 | np.std(accuracy_p_test_list) 365 | ) 366 | ) 367 | 368 | print() 369 | 370 | # With input transform 371 | print('The average results with input transformation -> accuracy_orig_train: {:5f}, accuracy_p_train: {:5f}, accuracy_orig_test: {:5f}, accuracy_p_test: {:5f}'.format( 372 | np.mean(accuracy_orig_train_list_with_transformation), 373 | np.mean(accuracy_p_train_list_with_transformation), 374 | np.mean(accuracy_orig_test_list_with_transformation), 375 | np.mean(accuracy_p_test_list_with_transformation) 376 | ) 377 | ) 378 | print('The average results with input transformation -> std_accuracy_orig_train: {:5f}, std_accuracy_p_train: {:5f}, std_accuracy_orig_test: {:5f}, std_accuracy_p_test: {:5f}'.format( 379 | np.std(accuracy_orig_train_list_with_transformation), 380 | np.std(accuracy_p_train_list_with_transformation), 381 | np.std(accuracy_orig_test_list_with_transformation), 382 | np.std(accuracy_p_test_list_with_transformation)) 383 | ) 384 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | 6 | from quantized_ops import zs_quantized_ops 7 | from config import cfg 8 | 9 | # ++++++++++++++++++++ Generator V1 ++++++++++++++++++++ 10 | 11 | class GeneratorConvLQ(nn.Module): 12 | """ 13 | Apply reprogramming. 14 | """ 15 | def __init__(self, precision): 16 | super(GeneratorConvLQ, self).__init__() 17 | 18 | if precision > 0: 19 | 20 | self.conv1_1 = zs_quantized_ops.nnConv2dSymQuant_op( 21 | in_channels = 3, 22 | out_channels = 32, 23 | kernel_size = 3, 24 | stride = 1, 25 | padding = 1, 26 | bias = True, 27 | precision = precision, 28 | clamp_val = None, 29 | ) 30 | self.bn1_1 = nn.BatchNorm2d(32) 31 | self.relu1_1 = nn.ReLU() 32 | self.conv1_2 = zs_quantized_ops.nnConv2dSymQuant_op( 33 | in_channels = 32, 34 | out_channels = 32, 35 | kernel_size = 3, 36 | stride = 1, 37 | padding = 1, 38 | bias = True, 39 | precision = precision, 40 | clamp_val = None, 41 | ) 42 | self.bn1_2 = nn.BatchNorm2d(32) 43 | self.relu1_2 = nn.ReLU() 44 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 45 | self.conv2_1 = zs_quantized_ops.nnConv2dSymQuant_op( 46 | in_channels = 32, 47 | out_channels = 64, 48 | kernel_size = 3, 49 | stride = 1, 50 | padding = 1, 51 | bias = True, 52 | precision = precision, 53 | clamp_val = None, 54 | ) 55 | self.bn2_1 = nn.BatchNorm2d(64) 56 | self.relu2_1 = nn.ReLU() 57 | self.conv2_2 = zs_quantized_ops.nnConv2dSymQuant_op( 58 | in_channels = 64, 59 | out_channels = 64, 60 | kernel_size = 3, 61 | stride = 1, 62 | padding = 1, 63 | bias = True, 64 | precision = precision, 65 | clamp_val = None, 66 | ) 67 | self.bn2_2 = nn.BatchNorm2d(64) 68 | self.relu2_2 = nn.ReLU() 69 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 70 | self.conv3_1 = zs_quantized_ops.nnConv2dSymQuant_op( 71 | in_channels = 64, 72 | out_channels = 128, 73 | kernel_size = 3, 74 | stride = 1, 75 | padding = 1, 76 | bias = True, 77 | precision = precision, 78 | clamp_val = None, 79 | ) 80 | self.bn3_1 = nn.BatchNorm2d(128) 81 | self.relu3_1 = nn.ReLU() 82 | self.conv3_2 = zs_quantized_ops.nnConv2dSymQuant_op( 83 | in_channels = 128, 84 | out_channels = 128, 85 | kernel_size = 3, 86 | stride = 1, 87 | padding = 1, 88 | bias = True, 89 | precision = precision, 90 | clamp_val = None, 91 | ) 92 | self.bn3_2 = nn.BatchNorm2d(128) 93 | self.relu3_2 = nn.ReLU() 94 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 95 | self.conv4_1 = zs_quantized_ops.nnConv2dSymQuant_op( 96 | in_channels = 128, 97 | out_channels = 128, 98 | kernel_size = 3, 99 | stride = 1, 100 | padding = 1, 101 | bias = True, 102 | precision = precision, 103 | clamp_val = None, 104 | ) 105 | self.bn4_1 = nn.BatchNorm2d(128) 106 | self.relu4_1 = nn.ReLU() 107 | self.upsample4_1 = nn.Upsample(scale_factor=2) 108 | self.conv4_2 = zs_quantized_ops.nnConv2dSymQuant_op( 109 | in_channels = 128, 110 | out_channels = 128, 111 | kernel_size = 3, 112 | stride = 1, 113 | padding = 1, 114 | bias = True, 115 | precision = precision, 116 | clamp_val = None, 117 | ) 118 | self.bn4_2 = nn.BatchNorm2d(128) 119 | self.relu4_2 = nn.ReLU() 120 | self.conv5_1 = zs_quantized_ops.nnConv2dSymQuant_op( 121 | in_channels = 128, 122 | out_channels = 64, 123 | kernel_size = 3, 124 | stride = 1, 125 | padding = 1, 126 | bias = True, 127 | precision = precision, 128 | clamp_val = None, 129 | ) 130 | self.bn5_1 = nn.BatchNorm2d(64) 131 | self.relu5_1 = nn.ReLU() 132 | self.upsample5_1 = nn.Upsample(scale_factor=2) 133 | self.conv5_2 = zs_quantized_ops.nnConv2dSymQuant_op( 134 | in_channels = 64, 135 | out_channels = 64, 136 | kernel_size = 3, 137 | stride = 1, 138 | padding = 1, 139 | bias = True, 140 | precision = precision, 141 | clamp_val = None, 142 | ) 143 | self.bn5_2 = nn.BatchNorm2d(64) 144 | self.relu5_2 = nn.ReLU() 145 | self.conv6_1 = zs_quantized_ops.nnConv2dSymQuant_op( 146 | in_channels = 64, 147 | out_channels = 32, 148 | kernel_size = 3, 149 | stride = 1, 150 | padding = 1, 151 | bias = True, 152 | precision = precision, 153 | clamp_val = None, 154 | ) 155 | self.relu6_1 = nn.ReLU() 156 | self.bn6_1 = nn.BatchNorm2d(32) 157 | self.upsample6_1 = nn.Upsample(scale_factor=2) 158 | self.conv6_2 = zs_quantized_ops.nnConv2dSymQuant_op( 159 | in_channels = 32, 160 | out_channels = 32, 161 | kernel_size = 3, 162 | stride = 1, 163 | padding = 1, 164 | bias = True, 165 | precision = precision, 166 | clamp_val = None, 167 | ) 168 | self.bn6_2 = nn.BatchNorm2d(32) 169 | self.relu6_2 = nn.ReLU() 170 | self.convout = zs_quantized_ops.nnConv2dSymQuant_op( 171 | in_channels = 32, 172 | out_channels = 3, 173 | kernel_size = 3, 174 | stride = 1, 175 | padding = 1, 176 | bias = True, 177 | precision = precision, 178 | clamp_val = None, 179 | ) 180 | self.bnout = nn.BatchNorm2d(3) 181 | self.tanh = torch.nn.Tanh() 182 | 183 | else: 184 | self.conv1_1 = nn.Conv2d( 185 | in_channels = 3, 186 | out_channels = 32, 187 | kernel_size = 3, 188 | stride = 1, 189 | padding = 1, 190 | ) 191 | self.bn1_1 = nn.BatchNorm2d(32) 192 | self.relu1_1 = nn.ReLU() 193 | self.conv1_2 = nn.Conv2d( 194 | in_channels = 32, 195 | out_channels = 32, 196 | kernel_size = 3, 197 | stride = 1, 198 | padding = 1, 199 | ) 200 | self.bn1_2 = nn.BatchNorm2d(32) 201 | self.relu1_2 = nn.ReLU() 202 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 203 | self.conv2_1 = nn.Conv2d( 204 | in_channels = 32, 205 | out_channels = 64, 206 | kernel_size = 3, 207 | stride = 1, 208 | padding = 1, 209 | ) 210 | self.bn2_1 = nn.BatchNorm2d(64) 211 | self.relu2_1 = nn.ReLU() 212 | self.conv2_2 = nn.Conv2d( 213 | in_channels = 64, 214 | out_channels = 64, 215 | kernel_size = 3, 216 | stride = 1, 217 | padding = 1, 218 | ) 219 | self.bn2_2 = nn.BatchNorm2d(64) 220 | self.relu2_2 = nn.ReLU() 221 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 222 | self.conv3_1 = nn.Conv2d( 223 | in_channels = 64, 224 | out_channels = 128, 225 | kernel_size = 3, 226 | stride = 1, 227 | padding = 1, 228 | ) 229 | self.bn3_1 = nn.BatchNorm2d(128) 230 | self.relu3_1 = nn.ReLU() 231 | self.conv3_2 = nn.Conv2d( 232 | in_channels = 128, 233 | out_channels = 128, 234 | kernel_size = 3, 235 | stride = 1, 236 | padding = 1, 237 | ) 238 | self.bn3_2 = nn.BatchNorm2d(128) 239 | self.relu3_2 = nn.ReLU() 240 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 241 | self.conv4_1 = nn.Conv2d( 242 | in_channels = 128, 243 | out_channels = 128, 244 | kernel_size = 3, 245 | stride = 1, 246 | padding = 1, 247 | ) 248 | self.bn4_1 = nn.BatchNorm2d(128) 249 | self.relu4_1 = nn.ReLU() 250 | self.upsample4_1 = nn.Upsample(scale_factor=2) 251 | self.conv4_2 = nn.Conv2d( 252 | in_channels = 128, 253 | out_channels = 128, 254 | kernel_size = 3, 255 | stride = 1, 256 | padding = 1, 257 | ) 258 | self.bn4_2 = nn.BatchNorm2d(128) 259 | self.relu4_2 = nn.ReLU() 260 | self.conv5_1 = nn.Conv2d( 261 | in_channels = 128, 262 | out_channels = 64, 263 | kernel_size = 3, 264 | stride = 1, 265 | padding = 1, 266 | ) 267 | self.bn5_1 = nn.BatchNorm2d(64) 268 | self.relu5_1 = nn.ReLU() 269 | self.upsample5_1 = nn.Upsample(scale_factor=2) 270 | self.conv5_2 = nn.Conv2d( 271 | in_channels = 64, 272 | out_channels = 64, 273 | kernel_size = 3, 274 | stride = 1, 275 | padding = 1, 276 | ) 277 | self.bn5_2 = nn.BatchNorm2d(64) 278 | self.relu5_2 = nn.ReLU() 279 | self.conv6_1 = nn.Conv2d( 280 | in_channels = 64, 281 | out_channels = 32, 282 | kernel_size = 3, 283 | stride = 1, 284 | padding = 1, 285 | ) 286 | self.relu6_1 = nn.ReLU() 287 | self.bn6_1 = nn.BatchNorm2d(32) 288 | self.upsample6_1 = nn.Upsample(scale_factor=2) 289 | self.conv6_2 = nn.Conv2d( 290 | in_channels = 32, 291 | out_channels = 32, 292 | kernel_size = 3, 293 | stride = 1, 294 | padding = 1, 295 | ) 296 | self.bn6_2 = nn.BatchNorm2d(32) 297 | self.relu6_2 = nn.ReLU() 298 | self.convout = nn.Conv2d( 299 | in_channels = 32, 300 | out_channels = 3, 301 | kernel_size = 3, 302 | stride = 1, 303 | padding = 1, 304 | ) 305 | self.bnout = nn.BatchNorm2d(3) 306 | self.tanh = torch.nn.Tanh() 307 | 308 | def forward(self, image): 309 | img = image.data.clone() 310 | # Encoder 311 | x = self.relu1_1(self.bn1_1(self.conv1_1(img))) 312 | x = self.relu1_2(self.bn1_2(self.conv1_2(x))) 313 | x = self.maxpool1(x) 314 | x = self.relu2_1(self.bn2_1(self.conv2_1(x))) 315 | x = self.relu2_2(self.bn2_2(self.conv2_2(x))) 316 | x = self.maxpool2(x) 317 | x = self.relu3_1(self.bn3_1(self.conv3_1(x))) 318 | x = self.relu3_2(self.bn3_2(self.conv3_2(x))) 319 | x = self.maxpool3(x) 320 | 321 | # Decoder 322 | x = self.relu4_1(self.bn4_1(self.conv4_1(x))) 323 | x = self.upsample4_1(x) 324 | x = self.relu4_2(self.bn4_2(self.conv4_2(x))) 325 | x = self.relu5_1(self.bn5_1(self.conv5_1(x))) 326 | x = self.upsample5_1(x) 327 | x = self.relu5_2(self.bn5_2(self.conv5_2(x))) 328 | x = self.relu6_1(self.bn6_1(self.conv6_1(x))) 329 | x = self.upsample6_1(x) 330 | x = self.relu6_2(self.bn6_2(self.conv6_2(x))) 331 | x = self.bnout(self.convout(x)) 332 | out = self.tanh(x) 333 | 334 | x_adv = torch.clamp(img + out, min=-1, max=1) 335 | 336 | return x_adv 337 | 338 | # ++++++++++++++++++++ Generator ConvSmall ++++++++++++++++++++ 339 | 340 | class GeneratorConvSQ(nn.Module): 341 | """ 342 | Apply reprogramming. 343 | """ 344 | def __init__(self, precision): 345 | super(GeneratorConvSQ, self).__init__() 346 | 347 | if precision > 0: 348 | self.conv1_1 = zs_quantized_ops.nnConv2dSymQuant_op( 349 | in_channels = 3, 350 | out_channels = 32, 351 | kernel_size = 3, 352 | stride = 1, 353 | padding = 1, 354 | bias = True, 355 | precision = precision, 356 | clamp_val = None, 357 | ) 358 | self.bn1_1 = nn.BatchNorm2d(32) 359 | self.relu1_1 = nn.ReLU() 360 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 361 | 362 | self.conv2_1 = zs_quantized_ops.nnConv2dSymQuant_op( 363 | in_channels = 32, 364 | out_channels = 64, 365 | kernel_size = 3, 366 | stride = 1, 367 | padding = 1, 368 | bias = True, 369 | precision = precision, 370 | clamp_val = None, 371 | ) 372 | self.bn2_1 = nn.BatchNorm2d(64) 373 | self.relu2_1 = nn.ReLU() 374 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 375 | 376 | self.conv3_1 = zs_quantized_ops.nnConv2dSymQuant_op( 377 | in_channels = 64, 378 | out_channels = 64, 379 | kernel_size = 3, 380 | stride = 1, 381 | padding = 1, 382 | bias = True, 383 | precision = precision, 384 | clamp_val = None, 385 | ) 386 | self.bn3_1 = nn.BatchNorm2d(64) 387 | self.relu3_1 = nn.ReLU() 388 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 389 | 390 | self.conv4_1 = zs_quantized_ops.nnConv2dSymQuant_op( 391 | in_channels = 64, 392 | out_channels = 64, 393 | kernel_size = 3, 394 | stride = 1, 395 | padding = 1, 396 | bias = True, 397 | precision = precision, 398 | clamp_val = None, 399 | ) 400 | self.bn4_1 = nn.BatchNorm2d(64) 401 | self.relu4_1 = nn.ReLU() 402 | self.upsample4_1 = nn.Upsample(scale_factor=2) 403 | 404 | self.conv5_1 = zs_quantized_ops.nnConv2dSymQuant_op( 405 | in_channels = 64, 406 | out_channels = 32, 407 | kernel_size = 3, 408 | stride = 1, 409 | padding = 1, 410 | bias = True, 411 | precision = precision, 412 | clamp_val = None, 413 | ) 414 | self.bn5_1 = nn.BatchNorm2d(32) 415 | self.relu5_1 = nn.ReLU() 416 | self.upsample5_1 = nn.Upsample(scale_factor=2) 417 | 418 | self.conv6_1 = zs_quantized_ops.nnConv2dSymQuant_op( 419 | in_channels = 32, 420 | out_channels = 3, 421 | kernel_size = 3, 422 | stride = 1, 423 | padding = 1, 424 | bias = True, 425 | precision = precision, 426 | clamp_val = None, 427 | ) 428 | self.relu6_1 = nn.ReLU() 429 | self.bn6_1 = nn.BatchNorm2d(3) 430 | self.upsample6_1 = nn.Upsample(scale_factor=2) 431 | 432 | self.convout = zs_quantized_ops.nnConv2dSymQuant_op( 433 | in_channels = 3, 434 | out_channels = 3, 435 | kernel_size = 3, 436 | stride = 1, 437 | padding = 1, 438 | bias = True, 439 | precision = precision, 440 | clamp_val = None, 441 | ) 442 | self.bnout = nn.BatchNorm2d(3) 443 | self.tanh = torch.nn.Tanh() 444 | 445 | else: 446 | self.conv1_1 = nn.Conv2d( 447 | in_channels = 3, 448 | out_channels = 32, 449 | kernel_size = 3, 450 | stride = 1, 451 | padding = 1, 452 | ) 453 | self.bn1_1 = nn.BatchNorm2d(32) 454 | self.relu1_1 = nn.ReLU() 455 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 456 | 457 | self.conv2_1 = nn.Conv2d( 458 | in_channels = 32, 459 | out_channels = 64, 460 | kernel_size = 3, 461 | stride = 1, 462 | padding = 1, 463 | ) 464 | self.bn2_1 = nn.BatchNorm2d(64) 465 | self.relu2_1 = nn.ReLU() 466 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 467 | 468 | self.conv3_1 = nn.Conv2d( 469 | in_channels = 64, 470 | out_channels = 64, 471 | kernel_size = 3, 472 | stride = 1, 473 | padding = 1, 474 | ) 475 | self.bn3_1 = nn.BatchNorm2d(64) 476 | self.relu3_1 = nn.ReLU() 477 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 478 | 479 | self.conv4_1 = nn.Conv2d( 480 | in_channels = 64, 481 | out_channels = 64, 482 | kernel_size = 3, 483 | stride = 1, 484 | padding = 1, 485 | ) 486 | self.bn4_1 = nn.BatchNorm2d(64) 487 | self.relu4_1 = nn.ReLU() 488 | self.upsample4_1 = nn.Upsample(scale_factor=2) 489 | 490 | self.conv5_1 = nn.Conv2d( 491 | in_channels = 64, 492 | out_channels = 32, 493 | kernel_size = 3, 494 | stride = 1, 495 | padding = 1, 496 | ) 497 | self.bn5_1 = nn.BatchNorm2d(32) 498 | self.relu5_1 = nn.ReLU() 499 | self.upsample5_1 = nn.Upsample(scale_factor=2) 500 | 501 | self.conv6_1 = nn.Conv2d( 502 | in_channels = 32, 503 | out_channels = 3, 504 | kernel_size = 3, 505 | stride = 1, 506 | padding = 1, 507 | ) 508 | self.relu6_1 = nn.ReLU() 509 | self.bn6_1 = nn.BatchNorm2d(3) 510 | self.upsample6_1 = nn.Upsample(scale_factor=2) 511 | 512 | self.convout = nn.Conv2d( 513 | in_channels = 3, 514 | out_channels = 3, 515 | kernel_size = 3, 516 | stride = 1, 517 | padding = 1, 518 | ) 519 | self.bnout = nn.BatchNorm2d(3) 520 | self.tanh = torch.nn.Tanh() 521 | 522 | def forward(self, image): 523 | img = image.data.clone() 524 | # Encoder 525 | x = self.relu1_1(self.bn1_1(self.conv1_1(img))) 526 | x = self.maxpool1(x) 527 | x = self.relu2_1(self.bn2_1(self.conv2_1(x))) 528 | x = self.maxpool2(x) 529 | x = self.relu3_1(self.bn3_1(self.conv3_1(x))) 530 | x = self.maxpool3(x) 531 | 532 | # Decoder 533 | x = self.relu4_1(self.bn4_1(self.conv4_1(x))) 534 | x = self.upsample4_1(x) 535 | x = self.relu5_1(self.bn5_1(self.conv5_1(x))) 536 | x = self.upsample5_1(x) 537 | x = self.relu6_1(self.bn6_1(self.conv6_1(x))) 538 | x = self.upsample6_1(x) 539 | x = self.bnout(self.convout(x)) 540 | out = self.tanh(x) 541 | 542 | x_adv = torch.clamp(img + out, min=-1, max=1) 543 | 544 | return x_adv 545 | 546 | # ++++++++++++++++++++ Generator DeConv Large ++++++++++++++++++++ 547 | 548 | class GeneratorDeConvLQ(nn.Module): 549 | """ 550 | Apply reprogramming. 551 | """ 552 | def __init__(self, precision): 553 | super(GeneratorDeConvLQ, self).__init__() 554 | 555 | if precision > 0: 556 | self.conv1_1 = zs_quantized_ops.nnConv2dSymQuant_op( 557 | in_channels = 3, 558 | out_channels = 32, 559 | kernel_size = 3, 560 | stride = 1, 561 | padding = 1, 562 | bias = True, 563 | precision = precision, 564 | clamp_val = None, 565 | ) 566 | self.bn1_1 = nn.BatchNorm2d(32) 567 | self.relu1_1 = nn.ReLU() 568 | self.conv1_2 = zs_quantized_ops.nnConv2dSymQuant_op( 569 | in_channels = 32, 570 | out_channels = 32, 571 | kernel_size = 3, 572 | stride = 1, 573 | padding = 1, 574 | bias = True, 575 | precision = precision, 576 | clamp_val = None, 577 | ) 578 | self.bn1_2 = nn.BatchNorm2d(32) 579 | self.relu1_2 = nn.ReLU() 580 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 581 | self.conv2_1 = zs_quantized_ops.nnConv2dSymQuant_op( 582 | in_channels = 32, 583 | out_channels = 64, 584 | kernel_size = 3, 585 | stride = 1, 586 | padding = 1, 587 | bias = True, 588 | precision = precision, 589 | clamp_val = None, 590 | ) 591 | self.bn2_1 = nn.BatchNorm2d(64) 592 | self.relu2_1 = nn.ReLU() 593 | self.conv2_2 = zs_quantized_ops.nnConv2dSymQuant_op( 594 | in_channels = 64, 595 | out_channels = 64, 596 | kernel_size = 3, 597 | stride = 1, 598 | padding = 1, 599 | bias = True, 600 | precision = precision, 601 | clamp_val = None, 602 | ) 603 | self.bn2_2 = nn.BatchNorm2d(64) 604 | self.relu2_2 = nn.ReLU() 605 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 606 | self.conv3_1 = zs_quantized_ops.nnConv2dSymQuant_op( 607 | in_channels = 64, 608 | out_channels = 128, 609 | kernel_size = 3, 610 | stride = 1, 611 | padding = 1, 612 | bias = True, 613 | precision = precision, 614 | clamp_val = None, 615 | ) 616 | self.bn3_1 = nn.BatchNorm2d(128) 617 | self.relu3_1 = nn.ReLU() 618 | self.conv3_2 = zs_quantized_ops.nnConv2dSymQuant_op( 619 | in_channels = 128, 620 | out_channels = 128, 621 | kernel_size = 3, 622 | stride = 1, 623 | padding = 1, 624 | bias = True, 625 | precision = precision, 626 | clamp_val = None, 627 | ) 628 | self.bn3_2 = nn.BatchNorm2d(128) 629 | self.relu3_2 = nn.ReLU() 630 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 631 | 632 | self.conv_mid = zs_quantized_ops.nnConv2dSymQuant_op( 633 | in_channels = 128, 634 | out_channels = 128, 635 | kernel_size = 3, 636 | stride = 1, 637 | padding = 1, 638 | bias = True, 639 | precision = precision, 640 | clamp_val = None, 641 | ) 642 | self.bn_mid = nn.BatchNorm2d(128) 643 | self.relu_mid = nn.ReLU() 644 | 645 | self.dconv4_1 = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 646 | in_channels = 128, 647 | out_channels = 64, 648 | kernel_size = 4, 649 | stride = 2, 650 | padding = 1, 651 | bias = True, 652 | precision = precision, 653 | clamp_val = None, 654 | ) 655 | self.bn4_1 = nn.BatchNorm2d(64) 656 | self.relu4_1 = nn.ReLU() 657 | self.conv4_2 = zs_quantized_ops.nnConv2dSymQuant_op( 658 | in_channels = 64, 659 | out_channels = 64, 660 | kernel_size = 3, 661 | stride = 1, 662 | padding = 1, 663 | bias = True, 664 | precision = precision, 665 | clamp_val = None, 666 | ) 667 | self.bn4_2 = nn.BatchNorm2d(64) 668 | self.relu4_2 = nn.ReLU() 669 | self.dconv5_1 = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 670 | in_channels = 64, 671 | out_channels = 32, 672 | kernel_size = 4, 673 | stride = 2, 674 | padding = 1, 675 | bias = True, 676 | precision = precision, 677 | clamp_val = None, 678 | ) 679 | self.bn5_1 = nn.BatchNorm2d(32) 680 | self.relu5_1 = nn.ReLU() 681 | self.conv5_2 = zs_quantized_ops.nnConv2dSymQuant_op( 682 | in_channels = 32, 683 | out_channels = 32, 684 | kernel_size = 3, 685 | stride = 1, 686 | padding = 1, 687 | bias = True, 688 | precision = precision, 689 | clamp_val = None, 690 | ) 691 | self.bn5_2 = nn.BatchNorm2d(32) 692 | self.relu5_2 = nn.ReLU() 693 | self.dconv6_1 = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 694 | in_channels = 32, 695 | out_channels = 3, 696 | kernel_size = 4, 697 | stride = 2, 698 | padding = 1, 699 | bias = True, 700 | precision = precision, 701 | clamp_val = None, 702 | ) 703 | self.bn6_1 = nn.BatchNorm2d(3) 704 | self.tanh = torch.nn.Tanh() 705 | else: 706 | self.conv1_1 = nn.Conv2d( 707 | in_channels = 3, 708 | out_channels = 32, 709 | kernel_size = 3, 710 | stride = 1, 711 | padding = 1, 712 | ) 713 | self.bn1_1 = nn.BatchNorm2d(32) 714 | self.relu1_1 = nn.ReLU() 715 | self.conv1_2 = nn.Conv2d( 716 | in_channels = 32, 717 | out_channels = 32, 718 | kernel_size = 3, 719 | stride = 1, 720 | padding = 1, 721 | ) 722 | self.bn1_2 = nn.BatchNorm2d(32) 723 | self.relu1_2 = nn.ReLU() 724 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 725 | self.conv2_1 = nn.Conv2d( 726 | in_channels = 32, 727 | out_channels = 64, 728 | kernel_size = 3, 729 | stride = 1, 730 | padding = 1, 731 | ) 732 | self.bn2_1 = nn.BatchNorm2d(64) 733 | self.relu2_1 = nn.ReLU() 734 | self.conv2_2 = nn.Conv2d( 735 | in_channels = 64, 736 | out_channels = 64, 737 | kernel_size = 3, 738 | stride = 1, 739 | padding = 1, 740 | ) 741 | self.bn2_2 = nn.BatchNorm2d(64) 742 | self.relu2_2 = nn.ReLU() 743 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 744 | self.conv3_1 = nn.Conv2d( 745 | in_channels = 64, 746 | out_channels = 128, 747 | kernel_size = 3, 748 | stride = 1, 749 | padding = 1, 750 | ) 751 | self.bn3_1 = nn.BatchNorm2d(128) 752 | self.relu3_1 = nn.ReLU() 753 | self.conv3_2 = nn.Conv2d( 754 | in_channels = 128, 755 | out_channels = 128, 756 | kernel_size = 3, 757 | stride = 1, 758 | padding = 1, 759 | ) 760 | self.bn3_2 = nn.BatchNorm2d(128) 761 | self.relu3_2 = nn.ReLU() 762 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 763 | 764 | self.conv_mid = nn.Conv2d( 765 | in_channels = 128, 766 | out_channels = 128, 767 | kernel_size = 3, 768 | stride = 1, 769 | padding = 1, 770 | ) 771 | self.bn_mid = nn.BatchNorm2d(128) 772 | self.relu_mid = nn.ReLU() 773 | 774 | self.dconv4_1 = nn.ConvTranspose2d( 775 | in_channels = 128, 776 | out_channels = 64, 777 | kernel_size = 4, 778 | stride = 2, 779 | padding = 1, 780 | ) 781 | self.bn4_1 = nn.BatchNorm2d(64) 782 | self.relu4_1 = nn.ReLU() 783 | self.conv4_2 = nn.Conv2d( 784 | in_channels = 64, 785 | out_channels = 64, 786 | kernel_size = 3, 787 | stride = 1, 788 | padding = 1, 789 | ) 790 | self.bn4_2 = nn.BatchNorm2d(64) 791 | self.relu4_2 = nn.ReLU() 792 | self.dconv5_1 = nn.ConvTranspose2d( 793 | in_channels = 64, 794 | out_channels = 32, 795 | kernel_size = 4, 796 | stride = 2, 797 | padding = 1, 798 | ) 799 | self.bn5_1 = nn.BatchNorm2d(32) 800 | self.relu5_1 = nn.ReLU() 801 | self.conv5_2 = nn.Conv2d( 802 | in_channels = 32, 803 | out_channels = 32, 804 | kernel_size = 3, 805 | stride = 1, 806 | padding = 1, 807 | ) 808 | self.bn5_2 = nn.BatchNorm2d(32) 809 | self.relu5_2 = nn.ReLU() 810 | self.dconv6_1 = nn.ConvTranspose2d( 811 | in_channels = 32, 812 | out_channels = 3, 813 | kernel_size = 4, 814 | stride = 2, 815 | padding = 1, 816 | ) 817 | self.bn6_1 = nn.BatchNorm2d(3) 818 | self.tanh = torch.nn.Tanh() 819 | 820 | 821 | def forward(self, image): 822 | img = image.data.clone() 823 | # Encoder 824 | x = self.relu1_1(self.bn1_1(self.conv1_1(img))) 825 | x = self.relu1_2(self.bn1_2(self.conv1_2(x))) 826 | x = self.maxpool1(x) 827 | x = self.relu2_1(self.bn2_1(self.conv2_1(x))) 828 | x = self.relu2_2(self.bn2_2(self.conv2_2(x))) 829 | x = self.maxpool2(x) 830 | x = self.relu3_1(self.bn3_1(self.conv3_1(x))) 831 | x = self.relu3_2(self.bn3_2(self.conv3_2(x))) 832 | x = self.maxpool3(x) 833 | # Decoder 834 | x = self.relu_mid(self.bn_mid(self.conv_mid(x))) 835 | x = self.relu4_1(self.bn4_1(self.dconv4_1(x))) 836 | x = self.relu4_2(self.bn4_2(self.conv4_2(x))) 837 | x = self.relu5_1(self.bn5_1(self.dconv5_1(x))) 838 | x = self.relu5_2(self.bn5_2(self.conv5_2(x))) 839 | x = self.bn6_1(self.dconv6_1(x)) 840 | out = self.tanh(x) 841 | 842 | x_adv = torch.clamp(image + out, min=-1, max=1) 843 | 844 | return x_adv 845 | 846 | 847 | # ++++++++++++++++++++ Generator DeConv Small ++++++++++++++++++++ 848 | 849 | class GeneratorDeConvSQ(nn.Module): 850 | """ 851 | Apply reprogramming. 852 | """ 853 | def __init__(self, precision): 854 | super(GeneratorDeConvSQ, self).__init__() 855 | 856 | if precision > 0: 857 | self.conv1_1 = zs_quantized_ops.nnConv2dSymQuant_op( 858 | in_channels = 3, 859 | out_channels = 32, 860 | kernel_size = 3, 861 | stride = 1, 862 | padding = 1, 863 | bias = True, 864 | precision = precision, 865 | clamp_val = None, 866 | ) 867 | self.bn1_1 = nn.BatchNorm2d(32) 868 | self.relu1_1 = nn.ReLU() 869 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 870 | self.conv2_1 = zs_quantized_ops.nnConv2dSymQuant_op( 871 | in_channels = 32, 872 | out_channels = 64, 873 | kernel_size = 3, 874 | stride = 1, 875 | padding = 1, 876 | bias = True, 877 | precision = precision, 878 | clamp_val = None, 879 | ) 880 | self.bn2_1 = nn.BatchNorm2d(64) 881 | self.relu2_1 = nn.ReLU() 882 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 883 | self.conv3_1 = zs_quantized_ops.nnConv2dSymQuant_op( 884 | in_channels = 64, 885 | out_channels = 64, 886 | kernel_size = 3, 887 | stride = 1, 888 | padding = 1, 889 | bias = True, 890 | precision = precision, 891 | clamp_val = None, 892 | ) 893 | self.bn3_1 = nn.BatchNorm2d(64) 894 | self.relu3_1 = nn.ReLU() 895 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 896 | self.dconv4_1 = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 897 | in_channels = 64, 898 | out_channels = 64, 899 | kernel_size = 4, 900 | stride = 2, 901 | padding = 1, 902 | bias = True, 903 | precision = precision, 904 | clamp_val = None, 905 | ) 906 | self.bn4_1 = nn.BatchNorm2d(64) 907 | self.relu4_1 = nn.ReLU() 908 | self.dconv5_1 = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 909 | in_channels = 64, 910 | out_channels = 32, 911 | kernel_size = 4, 912 | stride = 2, 913 | padding = 1, 914 | bias = True, 915 | precision = precision, 916 | clamp_val = None, 917 | ) 918 | self.bn5_1 = nn.BatchNorm2d(32) 919 | self.relu5_1 = nn.ReLU() 920 | self.dconv6_1 = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 921 | in_channels = 32, 922 | out_channels = 3, 923 | kernel_size = 4, 924 | stride = 2, 925 | padding = 1, 926 | bias = True, 927 | precision = precision, 928 | clamp_val = None, 929 | ) 930 | self.bn6_1 = nn.BatchNorm2d(3) 931 | self.tanh = torch.nn.Tanh() 932 | else: 933 | self.conv1_1 = nn.Conv2d( 934 | in_channels = 3, 935 | out_channels = 32, 936 | kernel_size = 3, 937 | stride = 1, 938 | padding = 1, 939 | ) 940 | self.bn1_1 = nn.BatchNorm2d(32) 941 | self.relu1_1 = nn.ReLU() 942 | self.maxpool1 = nn.MaxPool2d(kernel_size=2) 943 | self.conv2_1 = nn.Conv2d( 944 | in_channels = 32, 945 | out_channels = 64, 946 | kernel_size = 3, 947 | stride = 1, 948 | padding = 1, 949 | ) 950 | self.bn2_1 = nn.BatchNorm2d(64) 951 | self.relu2_1 = nn.ReLU() 952 | self.maxpool2 = nn.MaxPool2d(kernel_size=2) 953 | self.conv3_1 = nn.Conv2d( 954 | in_channels = 64, 955 | out_channels = 64, 956 | kernel_size = 3, 957 | stride = 1, 958 | padding = 1, 959 | ) 960 | self.bn3_1 = nn.BatchNorm2d(64) 961 | self.relu3_1 = nn.ReLU() 962 | self.maxpool3 = nn.MaxPool2d(kernel_size=2) 963 | self.dconv4_1 = nn.ConvTranspose2d( 964 | in_channels = 64, 965 | out_channels = 64, 966 | kernel_size = 4, 967 | stride = 2, 968 | padding = 1, 969 | ) 970 | self.bn4_1 = nn.BatchNorm2d(64) 971 | self.relu4_1 = nn.ReLU() 972 | self.dconv5_1 = nn.ConvTranspose2d( 973 | in_channels = 64, 974 | out_channels = 32, 975 | kernel_size = 4, 976 | stride = 2, 977 | padding = 1, 978 | ) 979 | self.bn5_1 = nn.BatchNorm2d(32) 980 | self.relu5_1 = nn.ReLU() 981 | self.dconv6_1 = nn.ConvTranspose2d( 982 | in_channels = 32, 983 | out_channels = 3, 984 | kernel_size = 4, 985 | stride = 2, 986 | padding = 1, 987 | ) 988 | self.bn6_1 = nn.BatchNorm2d(3) 989 | self.tanh = torch.nn.Tanh() 990 | 991 | def forward(self, image): 992 | img = image.data.clone() 993 | # Encoder 994 | x = self.relu1_1(self.bn1_1(self.conv1_1(img))) 995 | x = self.maxpool1(x) 996 | x = self.relu2_1(self.bn2_1(self.conv2_1(x))) 997 | x = self.maxpool2(x) 998 | x = self.relu3_1(self.bn3_1(self.conv3_1(x))) 999 | x = self.maxpool3(x) 1000 | # Decoder 1001 | x = self.relu4_1(self.bn4_1(self.dconv4_1(x))) 1002 | x = self.relu5_1(self.bn5_1(self.dconv5_1(x))) 1003 | x = self.bn6_1(self.dconv6_1(x)) 1004 | out = self.tanh(x) 1005 | 1006 | x_adv = torch.clamp(image + out, min=-1, max=1) 1007 | 1008 | return x_adv 1009 | 1010 | # ++++++++++++++++++++ GeneratorV14 UNet ++++++++++++++++++++ 1011 | 1012 | class DoubleConvQ(nn.Module): 1013 | """(convolution => [BN] => ReLU) * 2""" 1014 | def __init__(self, in_channels, out_channels, precision, mid_channels=None): 1015 | super().__init__() 1016 | if precision > 0: 1017 | if not mid_channels: 1018 | mid_channels = out_channels 1019 | self.double_conv = nn.Sequential( 1020 | zs_quantized_ops.nnConv2dSymQuant_op( 1021 | in_channels = in_channels, 1022 | out_channels = mid_channels, 1023 | kernel_size = 3, 1024 | stride = 1, 1025 | padding = 1, 1026 | bias = False, 1027 | precision = precision, 1028 | clamp_val = None, 1029 | ), 1030 | nn.BatchNorm2d(mid_channels), 1031 | nn.ReLU(), 1032 | zs_quantized_ops.nnConv2dSymQuant_op( 1033 | in_channels = mid_channels, 1034 | out_channels = out_channels, 1035 | kernel_size = 3, 1036 | stride = 1, 1037 | padding = 1, 1038 | bias = False, 1039 | precision = precision, 1040 | clamp_val = None, 1041 | ), 1042 | nn.BatchNorm2d(out_channels), 1043 | nn.ReLU() 1044 | ) 1045 | else: 1046 | if not mid_channels: 1047 | mid_channels = out_channels 1048 | self.double_conv = nn.Sequential( 1049 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 1050 | nn.BatchNorm2d(mid_channels), 1051 | nn.ReLU(), 1052 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 1053 | nn.BatchNorm2d(out_channels), 1054 | nn.ReLU() 1055 | ) 1056 | def forward(self, x): 1057 | return self.double_conv(x) 1058 | 1059 | class DownQ(nn.Module): 1060 | """Downscaling with maxpool then double conv""" 1061 | 1062 | def __init__(self, in_channels, out_channels, precision): 1063 | super().__init__() 1064 | self.maxpool_conv = nn.Sequential( 1065 | nn.MaxPool2d(2), 1066 | DoubleConvQ(in_channels, out_channels, precision) 1067 | ) 1068 | 1069 | def forward(self, x): 1070 | return self.maxpool_conv(x) 1071 | 1072 | class UpQ(nn.Module): 1073 | """Upscaling then double conv""" 1074 | 1075 | def __init__(self, in_channels, out_channels, precision, bilinear=True): 1076 | super().__init__() 1077 | 1078 | # if bilinear, use the normal convolutions to reduce the number of channels 1079 | if precision > 0: 1080 | if bilinear: 1081 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 1082 | self.conv = DoubleConvQ(in_channels, out_channels, in_channels // 2, precision) 1083 | else: 1084 | self.up = zs_quantized_ops.nnConvTranspose2dSymQuant_op( 1085 | in_channels = in_channels, 1086 | out_channels = in_channels // 2, 1087 | kernel_size = 2, 1088 | stride = 2, 1089 | padding = 1, 1090 | bias = True, 1091 | precision = precision, 1092 | clamp_val = None, 1093 | ) 1094 | self.conv = DoubleConvQ(in_channels, out_channels, precision) 1095 | else: 1096 | # if bilinear, use the normal convolutions to reduce the number of channels 1097 | if bilinear: 1098 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 1099 | self.conv = DoubleConvQ(in_channels, out_channels, in_channels // 2, precision) 1100 | else: 1101 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 1102 | self.conv = DoubleConvQ(in_channels, out_channels, precision) 1103 | 1104 | def forward(self, x1, x2): 1105 | x1 = self.up(x1) 1106 | # input is CHW 1107 | diffY = x2.size()[2] - x1.size()[2] 1108 | diffX = x2.size()[3] - x1.size()[3] 1109 | 1110 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 1111 | diffY // 2, diffY - diffY // 2]) 1112 | 1113 | x = torch.cat([x2, x1], dim=1) 1114 | return self.conv(x) 1115 | 1116 | 1117 | class OutConvQ(nn.Module): 1118 | def __init__(self, in_channels, out_channels, precision): 1119 | super(OutConvQ, self).__init__() 1120 | if precision > 0: 1121 | self.conv = zs_quantized_ops.nnConv2dSymQuant_op( 1122 | in_channels = in_channels, 1123 | out_channels = out_channels, 1124 | kernel_size = 1, 1125 | stride = 1, 1126 | padding = 0, 1127 | bias = True, 1128 | precision = precision, 1129 | clamp_val = None, 1130 | ) 1131 | else: 1132 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 1133 | 1134 | def forward(self, x): 1135 | return self.conv(x) 1136 | 1137 | class GeneratorUNetLQ(nn.Module): 1138 | def __init__(self, precision, bilinear=False): 1139 | super(GeneratorUNetLQ, self).__init__() 1140 | self.bilinear = bilinear 1141 | 1142 | self.inc = DoubleConvQ(3, 16, precision) 1143 | self.down1 = DownQ(16, 32, precision) 1144 | self.down2 = DownQ(32, 64, precision) 1145 | factor = 2 if bilinear else 1 1146 | self.down3 = DownQ(64, 128 // factor, precision) 1147 | self.up1 = UpQ(128, 64 // factor, precision, bilinear) 1148 | self.up2 = UpQ(64, 32 // factor, precision, bilinear) 1149 | self.up3 = UpQ(32, 16 // factor, precision, bilinear) 1150 | self.outc = OutConvQ(16, 3, precision) 1151 | 1152 | 1153 | def forward(self, image): 1154 | img = image.data.clone() 1155 | x1 = self.inc(img) 1156 | x2 = self.down1(x1) 1157 | x3 = self.down2(x2) 1158 | x4 = self.down3(x3) 1159 | x = self.up1(x4, x3) 1160 | x = self.up2(x, x2) 1161 | x = self.up3(x, x1) 1162 | out = self.outc(x) 1163 | x_adv = torch.clamp(image + out, min=-1, max=1) 1164 | return x_adv 1165 | 1166 | class GeneratorUNetSQ(nn.Module): 1167 | def __init__(self, precision, bilinear=False): 1168 | super(GeneratorUNetSQ, self).__init__() 1169 | self.bilinear = bilinear 1170 | 1171 | self.inc = DoubleConvQ(3, 8, precision) 1172 | self.down1 = DownQ(8, 16, precision) 1173 | self.down2 = DownQ(16, 32, precision) 1174 | factor = 2 if bilinear else 1 1175 | self.down3 = DownQ(32, 64 // factor, precision) 1176 | self.up1 = UpQ(64, 32 // factor, precision, bilinear) 1177 | self.up2 = UpQ(32, 16 // factor, precision, bilinear) 1178 | self.up3 = UpQ(16, 8 // factor, precision, bilinear) 1179 | self.outc = OutConvQ(8, 3, precision) 1180 | 1181 | def forward(self, image): 1182 | img = image.data.clone() 1183 | x1 = self.inc(img) 1184 | x2 = self.down1(x1) 1185 | x3 = self.down2(x2) 1186 | x4 = self.down3(x3) 1187 | x = self.up1(x4, x3) 1188 | x = self.up2(x, x2) 1189 | x = self.up3(x, x1) 1190 | out = self.outc(x) 1191 | x_adv = torch.clamp(image + out, min=-1, max=1) 1192 | return x_adv --------------------------------------------------------------------------------