├── README.md
├── advdiffuser.py
├── classifier.py
├── config.py
├── diffusionNet.py
├── exponentialMovingAverage.py
├── main.py
├── unet.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Implementation of AdvDiffuser

## Reference

AdvDiffuser: Natural Adversarial Example Synthesis with Diffusion Models | [Paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Chen_AdvDiffuser_Natural_Adversarial_Example_Synthesis_with_Diffusion_Models_ICCV_2023_paper.pdf)
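
## Usage

A minimal run sketch (assuming `torch`, `torchvision`, `matplotlib`, and `tqdm` are installed; hyperparameters live in `config.py`):

```bash
python main.py
```

The script trains (or loads from `results/`) the diffusion model and the CNN classifier, then generates adversarial examples and reports accuracy on the original and adversarial samples.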
--------------------------------------------------------------------------------
/advdiffuser.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils


class AdvDiffuser(nn.Module):
    def __init__(self, diffusion_model, num_classes, device):
        super(AdvDiffuser, self).__init__()
        self.diffusion_model = diffusion_model
        self.device = device
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(diffusion_model.model.final_conv.out_channels, num_classes)
        self.target_layer = "model.final_conv"

    def forward(self, x):
        noise = torch.randn_like(x).to(self.device)
        x = self.diffusion_model(x, noise)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def classify(self, x):
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def pretrain_classify_part(self, epochs, train_loader, device):
        ckpt_path = "results/grad_cam_classifier.pt"
        if os.path.exists(ckpt_path):
            print("Loading grad_cam_classifier from checkpoint...")
            checkpoint = torch.load(ckpt_path, map_location=device)
            self.load_state_dict(checkpoint['model_state'])
            print("Loading grad_cam_classifier completed!")
            return

        print("Start training grad_cam_classifier.")
        # freeze diffusion model part
        # for param in self.diffusion_model.parameters():
        #     param.requires_grad = False

        # only the pooling/classification head is optimized
        optimizer = torch.optim.Adam(list(self.gap.parameters()) + list(self.fc.parameters()), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = self.forward(images)
                loss = criterion(outputs, labels)
                loss.backward()

                optimizer.step()

            print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

        # save checkpoint
        if not os.path.isdir('results'):
            os.makedirs('results')
        torch.save({
            'model_state': self.state_dict(),
        }, ckpt_path)

    def generate_grad_cam_heatmap(self, x, class_idx):
        target_layer = dict(self.diffusion_model.named_modules())[self.target_layer]
        activations = []
        gradients = []

        def save_activation(module, input, output):
            activations.append(output)

        def save_gradient(module, input_grad, output_grad):
            gradients.append(output_grad[0])

        # register hooks (register_full_backward_hook is the non-deprecated variant)
        hook_a = target_layer.register_forward_hook(save_activation)
        hook_b = target_layer.register_full_backward_hook(save_gradient)

        # forward
        output = self(x)
        if class_idx is None:
            class_idx = output.argmax(dim=1)

        # backward on the one-hot target scores
        self.diffusion_model.zero_grad()
        one_hot_output = torch.zeros_like(output)
        one_hot_output.scatter_(1, class_idx.unsqueeze(1), 1)
        output.backward(gradient=one_hot_output, retain_graph=True)

        # remove hooks
        hook_a.remove()
        hook_b.remove()

        # compute Grad-CAM: weight the activations by the pooled gradients
        activation = activations[0]
        grad = gradients[0]
        pooled_grad = torch.mean(grad, dim=[0, 2, 3], keepdim=True)
        grad_cam = torch.mul(activation, pooled_grad).sum(dim=1, keepdim=True)
        grad_cam = F.relu(grad_cam)
        grad_cam = grad_cam / (grad_cam.max() + 1e-8)  # small epsilon guards an all-zero map
        return grad_cam

    def sample_xt_minus_1(self, x, alpha_t_minus_1):
        # re-noise x to the given cumulative-alpha level:
        # x_{t-1} = sqrt(alpha_bar_{t-1}) * x + sqrt(1 - alpha_bar_{t-1}) * eps
        variance = (1 - alpha_t_minus_1)
        mean = torch.sqrt(alpha_t_minus_1) * x
        noise = torch.randn_like(x)
        x_t_minus_1_obj = mean + torch.sqrt(variance) * noise
        return x_t_minus_1_obj

    def generate_adversarial_example(self, x, class_idx, timesteps, iterations=2, clipped_reverse_diffusion=True):
        # Generate the inverse of the Grad-CAM heatmap
        heatmap = self.generate_grad_cam_heatmap(x, class_idx)
        utils.save_grad_cam_heatmap(heatmap, 'results/grad_cam_heatmap.png')
        inverse_heatmap = 1 - heatmap

        t = torch.randint(0, timesteps, (x.shape[0],)).to(self.device)
        noise = torch.randn_like(x).to(self.device)
        x_t = self.diffusion_model._forward_diffusion(x, t, noise)

        for batch_idx in range(x.shape[0]):
            current_t = t[batch_idx].item()
            current_x = x_t[batch_idx].unsqueeze(0)
            for time_step in range(current_t, -1, -1):
                # denoising
                current_noise = torch.randn_like(current_x).to(self.device)
                current_t_tensor = torch.tensor([time_step], device=current_x.device, dtype=torch.long)
                if clipped_reverse_diffusion:
                    z_t = self.diffusion_model._reverse_diffusion_with_clip(current_x, current_t_tensor, current_noise)
                else:
                    z_t = self.diffusion_model._reverse_diffusion(current_x, current_t_tensor, current_noise)

                # only perturb during the last quarter of the reverse trajectory
                if time_step < current_t / 4:
                    for _ in range(iterations):
                        # predict
                        z_t.requires_grad_()
                        logits = self.classify(z_t)
                        pred_class = logits.argmax(dim=1)
                        if pred_class != class_idx[batch_idx]:
                            break  # already misclassified, stop perturbing

                        # keep adding the perturbation (an FGSM-style signed-gradient step)
                        loss = F.cross_entropy(logits, class_idx[batch_idx].unsqueeze(0))
                        loss.backward()

                        with torch.no_grad():
                            perturbation = z_t.grad.sign()
                            z_t = z_t + perturbation

                        self.diffusion_model.zero_grad()
                        z_t.grad = None

                    # sample: re-noise the perturbed latent back toward the original noise level t
                    alpha = self.diffusion_model.alphas_cumprod.gather(-1, t)
                    x_sample = self.sample_xt_minus_1(z_t, alpha[batch_idx])

                    # combine: keep the perturbation only in low-saliency regions
                    current_x = x_sample * heatmap[batch_idx] + z_t * inverse_heatmap[batch_idx]
                else:
                    current_x = z_t

            # note: unlike denoise_test_data, no [-1, 1] -> [0, 1] rescaling happens here
            current_x = torch.clamp(current_x, min=0, max=1).detach()
            x_t[batch_idx] = current_x
        utils.save_adversarial_example(x_t, 'results/adversarial_example.png')
        return x_t

    def denoise_test_data(self, x, timesteps, clipped_reverse_diffusion=True):
        t = torch.randint(0, timesteps, (x.shape[0],)).to(self.device)
        noise = torch.randn_like(x).to(self.device)
        x_t = self.diffusion_model._forward_diffusion(x, t, noise)

        for batch_idx in range(x.shape[0]):
            current_t = t[batch_idx].item()
            current_x = x_t[batch_idx].unsqueeze(0)

            for time_step in range(current_t, -1, -1):
                # denoising
                current_noise = torch.randn_like(current_x).to(self.device)
                current_t_tensor = torch.tensor([time_step], device=current_x.device, dtype=torch.long)
                if clipped_reverse_diffusion:
                    current_x = self.diffusion_model._reverse_diffusion_with_clip(current_x, current_t_tensor,
                                                                                  current_noise)
                else:
                    current_x = self.diffusion_model._reverse_diffusion(current_x, current_t_tensor, current_noise)

            current_x = (current_x + 1.) / 2.
            current_x = torch.clamp(current_x, min=0, max=1).detach()
            x_t[batch_idx] = current_x

        utils.save_adversarial_example(x_t, 'results/denoise_test.png')
        return x_t
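
# A minimal usage sketch (hypothetical names; mirrors the flow in main.py):
#   adv = AdvDiffuser(diff_model, num_classes=10, device=device).to(device)
#   adv.pretrain_classify_part(epochs=2, train_loader=train_loader, device=device)
#   adv_batch = adv.generate_adversarial_example(images, labels, timesteps=50)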
--------------------------------------------------------------------------------
/classifier.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torch.optim as optim


class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc_net = nn.Sequential(
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.conv_net(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.fc_net(x)
        return x

    def train_model(self, device, train_loader, val_loader=None, epochs=10, lr=0.001):
        checkpoint_path = 'results/classifier.pt'
        if os.path.exists(checkpoint_path):
            print("Loading classifier from checkpoint...")
            state = torch.load(checkpoint_path, map_location=device)
            self.load_state_dict(state['model_state_dict'])
            print("Loading classifier completed!")
            return

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=lr)

        self.to(device)

        for epoch in range(epochs):
            self.train()
            running_loss = 0.0
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)

                optimizer.zero_grad()
                output = self(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                if batch_idx % 100 == 99:  # Print every 100 mini-batches
                    print(f'Epoch {epoch + 1}, Batch {batch_idx + 1}, Loss: {running_loss / 100:.4f}')
                    running_loss = 0.0

            if val_loader:
                self.evaluate_model(device, val_loader)

        if not os.path.isdir('results'):
            os.makedirs('results')
        torch.save({
            'model_state_dict': self.state_dict(),
        }, checkpoint_path)
        print(f"Model saved to {checkpoint_path}")

    def evaluate_model(self, device, val_loader):
        self.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                outputs = self(data)
                _, predicted = torch.max(outputs.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        print(f'Accuracy on the validation set: {100 * correct / total:.2f}%')
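
# Shape note: a 28x28 MNIST input becomes 14x14 after the first 2x2 max-pool
# and 7x7 after the second, with 64 channels, hence the 64 * 7 * 7 input
# features of fc_net.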
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
config_dict = {
    'total_epochs': 10,
    'classify_part_epochs': 2,
    'batch_size': 48,
    'timesteps': 50,
    'lr': 1e-2,  # peak learning rate for the OneCycleLR schedule in main.py
    'base_dim': 32,
    'model_ema_steps': 10,
    'model_ema_decay': 0.995,
    'n_samples': 36,
    'clip_flag': True,
    'log_freq': 50,
    'num_classes': 10
}
--------------------------------------------------------------------------------
/diffusionNet.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
from tqdm import tqdm

from unet import Unet


class DiffusionNet(nn.Module):
    def __init__(self, image_size, in_channels, time_embedding_dim=256, timesteps=500, base_dim=32,
                 dim_mults=[1, 2, 4, 8]):
        super().__init__()
        self.timesteps = timesteps
        self.in_channels = in_channels
        self.image_size = image_size

        betas = self._cosine_variance_schedule(timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=-1)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1. - alphas_cumprod))

        self.model = Unet(timesteps, time_embedding_dim, in_channels, in_channels, base_dim, dim_mults)

    def forward(self, x, noise):
        t = torch.randint(0, self.timesteps, (x.shape[0],)).to(x.device)
        x_t = self._forward_diffusion(x, t, noise)
        pred_noise = self.model(x_t, t)

        return pred_noise

    @torch.no_grad()
    def sampling(self, n_samples, clipped_reverse_diffusion=True, device="cuda"):
        x_t = torch.randn((n_samples, self.in_channels, self.image_size, self.image_size)).to(device)
        for i in tqdm(range(self.timesteps - 1, -1, -1), desc="Sampling"):
            noise = torch.randn_like(x_t).to(device)
            t = torch.tensor([i for _ in range(n_samples)]).to(device)

            if clipped_reverse_diffusion:
                x_t = self._reverse_diffusion_with_clip(x_t, t, noise)
            else:
                x_t = self._reverse_diffusion(x_t, t, noise)

        x_t = (x_t + 1.) / 2.  # [-1,1] to [0,1]

        return x_t

    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
        f_t = torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5) ** 2
        betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

        return betas

    def _forward_diffusion(self, x_0, t, noise):
        assert x_0.shape == noise.shape
        # marginal q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        return self.sqrt_alphas_cumprod.gather(-1, t).reshape(x_0.shape[0], 1, 1, 1) * x_0 + \
               self.sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(x_0.shape[0], 1, 1, 1) * noise
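
    # Reverse step (DDPM): the model predicts the noise eps_theta(x_t, t); the
    # posterior mean is (1/sqrt(alpha_t)) * (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t) * eps_theta)
    # with variance beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).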
    @torch.no_grad()
    def _reverse_diffusion(self, x_t, t, noise):
        pred = self.model(x_t, t)

        alpha_t = self.alphas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        alpha_t_cumprod = self.alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        beta_t = self.betas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        mean = (1. / torch.sqrt(alpha_t)) * (x_t - ((1.0 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * pred)

        if t.min() > 0:
            alpha_t_cumprod_prev = self.alphas_cumprod.gather(-1, t - 1).reshape(x_t.shape[0], 1, 1, 1)
            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            std = 0.0

        return mean + std * noise

    @torch.no_grad()
    def _reverse_diffusion_with_clip(self, x_t, t, noise):
        # predict x_0, clamp it to [-1, 1], then form the posterior q(x_{t-1} | x_t, x_0)
        pred = self.model(x_t, t)
        alpha_t = self.alphas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        alpha_t_cumprod = self.alphas_cumprod.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)
        beta_t = self.betas.gather(-1, t).reshape(x_t.shape[0], 1, 1, 1)

        x_0_pred = torch.sqrt(1. / alpha_t_cumprod) * x_t - torch.sqrt(1. / alpha_t_cumprod - 1.) * pred
        x_0_pred.clamp_(-1., 1.)

        if t.min() > 0:
            alpha_t_cumprod_prev = self.alphas_cumprod.gather(-1, t - 1).reshape(x_t.shape[0], 1, 1, 1)
            mean = (beta_t * torch.sqrt(alpha_t_cumprod_prev) / (1. - alpha_t_cumprod)) * x_0_pred + \
                   ((1. - alpha_t_cumprod_prev) * torch.sqrt(alpha_t) / (1. - alpha_t_cumprod)) * x_t

            std = torch.sqrt(beta_t * (1. - alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))
        else:
            mean = (beta_t / (1. - alpha_t_cumprod)) * x_0_pred  # alpha_t_cumprod_prev = 1 at t = 0
            std = 0.0

        return mean + std * noise
--------------------------------------------------------------------------------
/exponentialMovingAverage.py:
--------------------------------------------------------------------------------
import torch


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)
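
# Each update blends the running average toward the live weights:
#   avg <- decay * avg + (1 - decay) * param
# use_buffers=True also averages non-parameter buffers such as BatchNorm
# running statistics.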
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import math
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from torchvision.utils import save_image

import diffusionNet
import utils
from advdiffuser import AdvDiffuser
from classifier import Classifier
from config import config_dict
from exponentialMovingAverage import ExponentialMovingAverage

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# hyperparameters
batch_size = config_dict['batch_size']
epochs = config_dict['total_epochs']
classify_part_epochs = config_dict['classify_part_epochs']
timesteps = config_dict['timesteps']
base_dim = config_dict['base_dim']
model_ema_steps = config_dict['model_ema_steps']
model_ema_decay = config_dict['model_ema_decay']
lr = config_dict['lr']
n_samples = config_dict['n_samples']
clip_flag = config_dict['clip_flag']
log_freq = config_dict['log_freq']
num_classes = config_dict['num_classes']

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # scale MNIST to [-1, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


def main():
    model = diffusionNet.DiffusionNet(timesteps=timesteps,
                                      image_size=28,
                                      in_channels=1,
                                      base_dim=base_dim,
                                      dim_mults=[2, 4]).to(device)

    # scale the EMA decay by how often EMA updates actually happen
    adjust = 1 * batch_size * model_ema_steps / epochs
    alpha = 1.0 - model_ema_decay
    alpha = min(1.0, alpha * adjust)
    model_ema = ExponentialMovingAverage(model, device=device, decay=1.0 - alpha)
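
    # With the default config (batch_size=48, model_ema_steps=10, total_epochs=10):
    # adjust = 48 * 10 / 10 = 48, alpha = min(1.0, 0.005 * 48) = 0.24,
    # so the effective EMA decay is 1 - 0.24 = 0.76.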

    ckpt_path = "results/diff_model.pt"

    # checkpoint
    if os.path.exists(ckpt_path):
        print("Loading diffusion model from checkpoint...")
        ckpt = torch.load(ckpt_path)
        model.load_state_dict(ckpt['model'])
        model_ema.load_state_dict(ckpt['model_ema'])
        model_ema.eval()
        print("Loading diffusion model completed!")
    else:
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        scheduler = OneCycleLR(optimizer, lr, total_steps=epochs * len(train_loader), pct_start=0.25,
                               anneal_strategy='cos')
        loss_fn = nn.MSELoss(reduction='mean')

        global_steps = 0
        for i in range(epochs):
            model.train()
            for j, (image, target) in enumerate(train_loader):
                image = image.to(device)
                noise = torch.randn_like(image)  # inherits image's device and dtype
                pred = model(image, noise)
                loss = loss_fn(pred, noise)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                if global_steps % model_ema_steps == 0:
                    model_ema.update_parameters(model)
                global_steps += 1
                if j % log_freq == 0:
                    print("Epoch[{}/{}],Step[{}/{}],loss:{:.5f},lr:{:.5f}".format(i + 1, epochs, j,
                                                                                  len(train_loader),
                                                                                  loss.detach().cpu().item(),
                                                                                  scheduler.get_last_lr()[0]))

            model_ema.eval()
            samples = model_ema.module.sampling(n_samples, clipped_reverse_diffusion=clip_flag, device=device)
            save_image(samples, "results/steps_{:0>8}.png".format(global_steps), nrow=int(math.sqrt(n_samples)))

        ckpt = {"model": model.state_dict(),
                "model_ema": model_ema.state_dict()}

        os.makedirs("results", exist_ok=True)
        torch.save(ckpt, ckpt_path)

    # the classifier under attack
    classifier = Classifier().to(device)
    classifier.train_model(device, train_loader, test_loader, epochs=10)
    adv_ori_images, adv_labels = utils.random_loader_sampling(test_loader)
    adv_ori_images, adv_labels = adv_ori_images.to(device), adv_labels.to(device)
    dataset = TensorDataset(adv_ori_images, adv_labels)
    sampled_data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    accuracy = utils.evaluate_accuracy(classifier, sampled_data_loader)

    # adversarial attack
    adv_model = AdvDiffuser(model, num_classes, device).to(device)
    # pretrain the Grad-CAM classification head
    adv_model.pretrain_classify_part(classify_part_epochs, train_loader, device)
    adv_images = []
    # generate adversarial samples batch by batch
    for images_batch, labels_batch in sampled_data_loader:
        images_batch, labels_batch = images_batch.to(device), labels_batch.to(device)
        adv_model.denoise_test_data(images_batch, timesteps)  # sanity check; writes results/denoise_test.png
        adv_batch = adv_model.generate_adversarial_example(images_batch, labels_batch, timesteps)
        adv_images.append(adv_batch.cpu())
    adv_ori_images = torch.cat(adv_images, dim=0)
    # update adversarial dataset
    dataset = TensorDataset(adv_ori_images.to(device), adv_labels)
    sampled_data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # adversarial accuracy
    adv_accuracy = utils.evaluate_accuracy(classifier, sampled_data_loader)

    print(f'Accuracy for adversarial samples: {100 * adv_accuracy:.2f}%')
    print(f'Accuracy for original samples: {100 * accuracy:.2f}%')


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        n, c, h, w = x.shape
        x = x.view(n, self.groups, c // self.groups, h, w)  # group
        x = x.transpose(1, 2).contiguous().view(n, -1, h, w)  # shuffle
        return x


class ConvBnSiLu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.module = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
                                    nn.BatchNorm2d(out_channels),
                                    nn.SiLU(inplace=True))

    def forward(self, x):
        return self.module(x)


class ResidualBottleneck(nn.Module):
    # ShuffleNet-style block: split the channels, process each half, shuffle
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.branch1 = nn.Sequential(nn.Conv2d(in_channels // 2, in_channels // 2, 3, 1, 1, groups=in_channels // 2),
                                     nn.BatchNorm2d(in_channels // 2),
                                     ConvBnSiLu(in_channels // 2, out_channels // 2, 1, 1, 0))
        self.branch2 = nn.Sequential(ConvBnSiLu(in_channels // 2, in_channels // 2, 1, 1, 0),
                                     nn.Conv2d(in_channels // 2, in_channels // 2, 3, 1, 1, groups=in_channels // 2),
                                     nn.BatchNorm2d(in_channels // 2),
                                     ConvBnSiLu(in_channels // 2, out_channels // 2, 1, 1, 0))
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        x = torch.cat([self.branch1(x1), self.branch2(x2)], dim=1)
        x = self.channel_shuffle(x)  # shuffle two branches

        return x


class ResidualDownsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.branch1 = nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, 2, 1, groups=in_channels),
                                     nn.BatchNorm2d(in_channels),
                                     ConvBnSiLu(in_channels, out_channels // 2, 1, 1, 0))
        self.branch2 = nn.Sequential(ConvBnSiLu(in_channels, out_channels // 2, 1, 1, 0),
                                     nn.Conv2d(out_channels // 2, out_channels // 2, 3, 2, 1, groups=out_channels // 2),
                                     nn.BatchNorm2d(out_channels // 2),
                                     ConvBnSiLu(out_channels // 2, out_channels // 2, 1, 1, 0))
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        x = torch.cat([self.branch1(x), self.branch2(x)], dim=1)
        x = self.channel_shuffle(x)  # shuffle two branches

        return x


class TimeMLP(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(embedding_dim, hidden_dim),
                                 nn.SiLU(),
                                 nn.Linear(hidden_dim, out_dim))
        self.act = nn.SiLU()

    def forward(self, x, t):
        t_emb = self.mlp(t).unsqueeze(-1).unsqueeze(-1)
        x = x + t_emb

        return self.act(x)
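
# TimeMLP broadcast: t has shape (B, embedding_dim); after the MLP and the two
# unsqueeze calls, t_emb has shape (B, out_dim, 1, 1), which broadcast-adds
# over the (B, out_dim, H, W) feature map.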

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        self.conv0 = nn.Sequential(*[ResidualBottleneck(in_channels, in_channels) for i in range(3)],
                                   ResidualBottleneck(in_channels, out_channels // 2))

        self.time_mlp = TimeMLP(embedding_dim=time_embedding_dim, hidden_dim=out_channels, out_dim=out_channels // 2)
        self.conv1 = ResidualDownsample(out_channels // 2, out_channels)

    def forward(self, x, t=None):
        x_shortcut = self.conv0(x)
        if t is not None:
            x = self.time_mlp(x_shortcut, t)
        else:
            x = x_shortcut  # without a timestep, pass the features through unchanged
        x = self.conv1(x)

        return [x, x_shortcut]


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_embedding_dim):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv0 = nn.Sequential(*[ResidualBottleneck(in_channels, in_channels) for i in range(3)],
                                   ResidualBottleneck(in_channels, in_channels // 2))

        self.time_mlp = TimeMLP(embedding_dim=time_embedding_dim, hidden_dim=in_channels, out_dim=in_channels // 2)
        self.conv1 = ResidualBottleneck(in_channels // 2, out_channels // 2)

    def forward(self, x, x_shortcut, t=None):
        x = self.upsample(x)
        x = torch.cat([x, x_shortcut], dim=1)
        x = self.conv0(x)
        if t is not None:
            x = self.time_mlp(x, t)
        x = self.conv1(x)

        return x


class Unet(nn.Module):
    def __init__(self, timesteps, time_embedding_dim, in_channels=3, out_channels=2, base_dim=32,
                 dim_mults=[2, 4, 8, 16]):
        super().__init__()
        assert isinstance(dim_mults, (list, tuple))
        assert base_dim % 2 == 0

        channels = self._cal_channels(base_dim, dim_mults)

        self.init_conv = ConvBnSiLu(in_channels, base_dim, 3, 1, 1)
        self.time_embedding = nn.Embedding(timesteps, time_embedding_dim)

        self.encoder_blocks = nn.ModuleList([EncoderBlock(c[0], c[1], time_embedding_dim) for c in channels])
        self.decoder_blocks = nn.ModuleList([DecoderBlock(c[1], c[0], time_embedding_dim) for c in channels[::-1]])

        self.mid_block = nn.Sequential(*[ResidualBottleneck(channels[-1][1], channels[-1][1]) for i in range(2)],
                                       ResidualBottleneck(channels[-1][1], channels[-1][1] // 2))

        self.final_conv = nn.Conv2d(in_channels=channels[0][0] // 2, out_channels=out_channels, kernel_size=1)

    def forward(self, x, t=None):
        x = self.init_conv(x)
        if t is not None:
            t = self.time_embedding(t)
        encoder_shortcuts = []
        for encoder_block in self.encoder_blocks:
            x, x_shortcut = encoder_block(x, t)
            encoder_shortcuts.append(x_shortcut)
        x = self.mid_block(x)
        encoder_shortcuts.reverse()
        for decoder_block, shortcut in zip(self.decoder_blocks, encoder_shortcuts):
            x = decoder_block(x, shortcut, t)
        x = self.final_conv(x)

        return x

    def _cal_channels(self, base_dim, dim_mults):
        dims = [base_dim * x for x in dim_mults]
        dims.insert(0, base_dim)
        channels = []
        for i in range(len(dims) - 1):
            channels.append((dims[i], dims[i + 1]))  # (in_channel, out_channel)

        return channels
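
# Example: with base_dim=32 and dim_mults=[2, 4] (as in main.py),
# _cal_channels yields [(32, 64), (64, 128)]: two encoder blocks going
# 32 -> 64 -> 128 channels, mirrored by two decoder blocks.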
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch


def random_loader_sampling(test_loader, num_samples=10):
    images, labels = [], []
    # fix the seed for reproducibility (the test loader is not shuffled,
    # so this takes the first num_samples batches)
    torch.manual_seed(0)

    for i, (image, label) in enumerate(test_loader):
        if i == num_samples:
            break
        images.append(image)
        labels.append(label)

    # list of batches to a single tensor
    images = torch.cat(images, dim=0)
    labels = torch.cat(labels, dim=0)

    return images, labels


def evaluate_accuracy(model, data_loader):
    # assumes the batches already live on the model's device
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images_batch, labels_batch in data_loader:
            outputs = model(images_batch)
            _, predicted_labels = torch.max(outputs, 1)
            correct += (predicted_labels == labels_batch).sum().item()
            total += labels_batch.size(0)
    accuracy = correct / total  # a fraction in [0, 1]
    return accuracy


def save_grad_cam_heatmap(heatmap, filename):
    num_samples = min(heatmap.size(0), 36)
    heatmap = heatmap[:num_samples]
    fig, axs = plt.subplots(6, 6, figsize=(15, 15))

    for i, ax in enumerate(axs.flat):
        if i >= num_samples:
            break
        hm = heatmap[i].squeeze().detach().cpu().numpy()
        ax.imshow(hm, cmap='viridis')
        ax.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig(filename)
    plt.close()


def save_adversarial_example(adv_example, filename):
    num_samples = min(adv_example.size(0), 36)
    adv_example = adv_example[:num_samples]
    fig, axs = plt.subplots(6, 6, figsize=(15, 15))

    for i, ax in enumerate(axs.flat):
        if i >= num_samples:
            break
        adv_ex = adv_example[i].squeeze().detach().cpu().numpy()
        ax.imshow(adv_ex, cmap='gray')
        ax.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.savefig(filename)
    plt.close()
--------------------------------------------------------------------------------