├── shadow_model_ckpt
│   └── .keep
├── meta_classifier_ckpt
│   └── .keep
├── .gitignore
├── requirements.txt
├── model_lib
│   ├── rtNLP_dataset.py
│   ├── audio_dataset.py
│   ├── mnist_cnn_model.py
│   ├── cifar10_cnn_model.py
│   ├── audio_rnn_model.py
│   └── rtNLP_cnn_model.py
├── README.md
├── meta_classifier.py
├── train_basic_jumbo.py
├── train_basic_benign.py
├── audio_preprocess.py
├── run_meta_oneclass.py
├── train_basic_trojaned.py
├── rtNLP_preprocess.py
├── run_meta.py
├── utils_meta.py
└── utils_basic.py

--------------------------------------------------------------------------------
/shadow_model_ckpt/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/meta_classifier_ckpt/.keep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.*.swp
.*.swo
*.pyc

__pycache__/*
shadow_model_ckpt/*
meta_classifier_ckpt/*
raw_data/*
!.keep

!raw_data/rt_polarity
raw_data/rt_polarity/*
!rt-polarity.pos
!rt-polarity.neg

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
audioread==2.1.8
certifi==2020.6.20
chardet==3.0.4
decorator==4.4.2
future==0.18.2
gensim==3.8.3
idna==2.10
joblib==0.16.0
librosa==0.6.3
llvmlite==0.31.0
numba==0.48.0
numpy==1.19.2
Pillow==7.2.0
pkg-resources==0.0.0
requests==2.24.0
resampy==0.2.2
scikit-learn==0.23.2
scipy==1.5.2
six==1.15.0
sklearn==0.0
smart-open==2.2.0
threadpoolctl==2.1.0
tqdm==4.49.0
urllib3==1.25.10

--------------------------------------------------------------------------------
/model_lib/rtNLP_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data
import json

class RTNLP(torch.utils.data.Dataset):
    def __init__(self, train, path='./raw_data/rt_polarity/'):
        self.train = train
        self.path = path
        if train:
            self.Xs = np.load(path+'train_data.npy')
            self.ys = np.load(path+'train_label.npy')
        else:
            self.Xs = np.load(path+'dev_data.npy')
            self.ys = np.load(path+'dev_label.npy')
        with open(path+'dict.json') as inf:
            info = json.load(inf)
        self.tok2idx = info['tok2idx']
        self.idx2tok = info['idx2tok']

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, idx):
        return torch.LongTensor(self.Xs[idx]), self.ys[idx]

--------------------------------------------------------------------------------
/model_lib/audio_dataset.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data
from torch.nn.utils.rnn import pack_padded_sequence
import torch.nn.functional as F
import librosa
from audio_preprocess import ALL_CLS

USED_CLS = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']

class SpeechCommand(torch.utils.data.Dataset):
    def __init__(self, split, path='./raw_data/speech_command/processed'):
        self.split = split  # 0: train; 1: val; 2: test
        self.path = path
        split_name = {0:'train', 1:'val', 2:'test'}[split]
        all_Xs = np.load(self.path+'/%s_data.npy'%split_name)
        all_ys = np.load(self.path+'/%s_label.npy'%split_name)

        # Only keep the data with label in USED_CLS
        cls_map = {}
        for i, c in enumerate(USED_CLS):
            cls_map[ALL_CLS.index(c)] = i
        self.Xs = []
        self.ys = []
        for X, y in zip(all_Xs, all_ys):
            if y in cls_map:
                self.Xs.append(X)
                self.ys.append(cls_map[y])

    def __len__(self):
        return len(self.Xs)

    def __getitem__(self, idx):
        return torch.FloatTensor(self.Xs[idx]), self.ys[idx]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Meta Neural Trojan Detection

This repo provides an implementation for detecting Trojans in machine learning models, as introduced [here](https://arxiv.org/abs/1910.03137).

## Installation

The code runs on Python 3.6 and PyTorch 1.6.0. The PyTorch package needs to be installed manually, following the instructions [here](https://pytorch.org/) for your platform and CUDA driver. The other required packages can be installed with:
```bash
pip install -r requirements.txt
```

The MNIST and CIFAR-10 datasets are downloaded automatically at run time. To run the audio task, download the [SpeechCommand v0.02 dataset](http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz), extract it under the `raw_data/speech_command` folder and run `python audio_preprocess.py`. To run the NLP task, download the [pretrained GoogleNews word embedding](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit), extract it under the `raw_data/rt_polarity` folder and run `python rtNLP_preprocess.py`; the movie review data is already in that folder. The Irish Smart Meter Electricity data is private and therefore not included.
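For the audio task, the preparation steps might look like the following sketch (any downloader works; the URL and target folder are the ones given above, laid out as `audio_preprocess.py` expects):

```bash
mkdir -p raw_data/speech_command
wget http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz
tar -xzf speech_commands_v0.02.tar.gz -C raw_data/speech_command
python audio_preprocess.py
```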
## Training Shadow Models and Target Models

Training the shadow models and target models consists of three parts: training the benign models (`train_basic_benign.py`), training the shadow models with jumbo learning (`train_basic_jumbo.py`), and training the target models with specific Trojans (`train_basic_trojaned.py`).

An example of running on the MNIST task:

```bash
python train_basic_benign.py --task mnist
python train_basic_jumbo.py --task mnist
python train_basic_trojaned.py --task mnist --troj_type M
python train_basic_trojaned.py --task mnist --troj_type B
```

## Training and Evaluating the Meta-Classifier

`run_meta.py` trains and evaluates the meta-classifier using jumbo learning; `run_meta_oneclass.py` does the same using one-class learning. An example of training the meta-classifier with jumbo learning on the MNIST task and evaluating it against the modification attack:

```bash
python run_meta.py --task mnist --troj_type M
```
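The one-class script accepts the same flags (see its argparse setup), so the matching one-class run is:

```bash
python run_meta_oneclass.py --task mnist --troj_type M
```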
--------------------------------------------------------------------------------
/meta_classifier.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaClassifier(nn.Module):
    def __init__(self, input_size, class_num, N_in=10, gpu=False):
        super(MetaClassifier, self).__init__()
        self.input_size = input_size
        self.class_num = class_num
        self.N_in = N_in
        self.N_h = 20
        # N_in tunable query inputs (the "query set"), trained jointly with the classifier
        self.inp = nn.Parameter(torch.zeros(self.N_in, *input_size).normal_()*1e-3)
        self.fc = nn.Linear(self.N_in*self.class_num, self.N_h)
        self.output = nn.Linear(self.N_h, 1)

        self.gpu = gpu
        if self.gpu:
            self.cuda()

    def forward(self, pred):
        # pred: the shadow model's outputs on the N_in queries, flattened into one feature vector
        emb = F.relu(self.fc(pred.view(self.N_in*self.class_num)))
        score = self.output(emb)
        return score

    def loss(self, score, y):
        y_var = torch.FloatTensor([y])
        if self.gpu:
            y_var = y_var.cuda()
        l = F.binary_cross_entropy_with_logits(score, y_var)
        return l


class MetaClassifierOC(nn.Module):
    def __init__(self, input_size, class_num, N_in=10, gpu=False):
        super(MetaClassifierOC, self).__init__()
        self.N_in = N_in
        self.N_h = 20
        self.v = 0.1
        self.input_size = input_size
        self.class_num = class_num

        self.inp = nn.Parameter(torch.zeros(self.N_in, *input_size).normal_()*1e-3)
        self.fc = nn.Linear(self.N_in*self.class_num, self.N_h)
        self.w = nn.Parameter(torch.zeros(self.N_h).normal_()*1e-3)
        self.r = 1.0  # one-class decision radius, updated from score percentiles

    def forward(self, pred, ret_feature=False):
        emb = F.relu(self.fc(pred.view(self.N_in*self.class_num)))
        if ret_feature:
            return emb
        score = torch.dot(emb, self.w)
        return score

    def loss(self, score):
        reg = (self.w**2).sum()/2
        for p in self.fc.parameters():
            reg = reg + (p**2).sum()/2
        hinge_loss = F.relu(self.r - score)
        loss = reg + hinge_loss / self.v - self.r
        return loss

    def update_r(self, scores):
        # Set r to the v-th percentile of the training scores (one-class SVM style)
        self.r = np.asscalar(np.percentile(scores, 100*self.v))
        return
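A minimal sketch of one meta-training step (mirroring `epoch_meta_train` in `utils_meta.py`); it assumes an MNIST shadow checkpoint produced by `train_basic_jumbo.py` already exists at the path below:

```python
import torch
from meta_classifier import MetaClassifier
from model_lib.mnist_cnn_model import Model

# Load one shadow model; the path pattern follows train_basic_jumbo.py.
shadow = Model(gpu=False)
shadow.load_state_dict(torch.load('./shadow_model_ckpt/mnist/models/shadow_jumbo_0.model'))

meta = MetaClassifier(input_size=(1, 28, 28), class_num=10, gpu=False)
out = shadow.forward(meta.inp)  # (N_in, class_num) logits on the N_in tunable queries
score = meta.forward(out)       # scalar Trojan score for this shadow model
loss = meta.loss(score, y=1)    # y=1: this shadow model is trojaned
loss.backward()                 # gradients also reach meta.inp, i.e. query tuning
```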
--------------------------------------------------------------------------------
/model_lib/mnist_cnn_model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, gpu=False):
        super(Model, self).__init__()
        self.gpu = gpu

        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=0)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(32*4*4, 512)
        self.output = nn.Linear(512, 10)

        if gpu:
            self.cuda()

    def forward(self, x):
        if self.gpu:
            x = x.cuda()
        B = x.size()[0]

        x = self.max_pool(F.relu(self.conv1(x)))
        x = self.max_pool(F.relu(self.conv2(x)))
        x = F.relu(self.fc(x.view(B, 32*4*4)))
        x = self.output(x)

        return x

    def loss(self, pred, label):
        if self.gpu:
            label = label.cuda()
        return F.cross_entropy(pred, label)


def random_troj_setting(troj_type):
    MAX_SIZE = 28
    CLASS_NUM = 10

    if troj_type == 'jumbo':
        p_size = np.random.choice([2,3,4,5,MAX_SIZE], 1)[0]
        if p_size < MAX_SIZE:
            alpha = np.random.uniform(0.2, 0.6)
            if alpha > 0.5:
                alpha = 1.0
        else:
            alpha = np.random.uniform(0.05, 0.2)
    elif troj_type == 'M':
        p_size = np.random.choice([2,3,4,5], 1)[0]
        alpha = 1.0
    elif troj_type == 'B':
        p_size = MAX_SIZE
        alpha = np.random.uniform(0.05, 0.2)

    if p_size < MAX_SIZE:
        loc_x = np.random.randint(MAX_SIZE-p_size)
        loc_y = np.random.randint(MAX_SIZE-p_size)
        loc = (loc_x, loc_y)
    else:
        loc = (0, 0)

    pattern_num = np.random.randint(1, p_size**2)
    one_idx = np.random.choice(list(range(p_size**2)), pattern_num, replace=False)
    pattern_flat = np.zeros((p_size**2))
    pattern_flat[one_idx] = 1
    pattern = np.reshape(pattern_flat, (p_size, p_size))
    target_y = np.random.randint(CLASS_NUM)
    inject_p = np.random.uniform(0.05, 0.5)

    return p_size, pattern, loc, alpha, target_y, inject_p

def troj_gen_func(X, y, atk_setting):
    p_size, pattern, loc, alpha, target_y, inject_p = atk_setting

    w, h = loc
    X_new = X.clone()
    X_new[0, w:w+p_size, h:h+p_size] = alpha * torch.FloatTensor(pattern) + (1-alpha) * X_new[0, w:w+p_size, h:h+p_size]
    y_new = target_y
    return X_new, y_new

--------------------------------------------------------------------------------
/model_lib/cifar10_cnn_model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, gpu=False):
        super(Model, self).__init__()
        self.gpu = gpu

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.linear = nn.Linear(64*8*8, 256)
        self.fc = nn.Linear(256, 256)
        self.output = nn.Linear(256, 10)

        if gpu:
            self.cuda()

    def forward(self, x):
        if self.gpu:
            x = x.cuda()
        B = x.size()[0]

        x = F.relu(self.conv1(x))
        x = self.max_pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.max_pool(F.relu(self.conv4(x)))
        x = F.relu(self.linear(x.view(B, 64*8*8)))
        x = F.dropout(F.relu(self.fc(x)), 0.5, training=self.training)
        x = self.output(x)

        return x

    def loss(self, pred, label):
        if self.gpu:
            label = label.cuda()
        return F.cross_entropy(pred, label)

def random_troj_setting(troj_type):
    MAX_SIZE = 32
    CLASS_NUM = 10

    if troj_type == 'jumbo':
        p_size = np.random.choice([2,3,4,5,MAX_SIZE], 1)[0]
        if p_size < MAX_SIZE:
            alpha = np.random.uniform(0.2, 0.6)
            if alpha > 0.5:
                alpha = 1.0
        else:
            alpha = np.random.uniform(0.05, 0.2)
    elif troj_type == 'M':
        p_size = np.random.choice([2,3,4,5], 1)[0]
        alpha = 1.0
    elif troj_type == 'B':
        p_size = MAX_SIZE
        alpha = np.random.uniform(0.05, 0.2)

    if p_size < MAX_SIZE:
        loc_x = np.random.randint(MAX_SIZE-p_size)
        loc_y = np.random.randint(MAX_SIZE-p_size)
        loc = (loc_x, loc_y)
    else:
        loc = (0, 0)

    eps = np.random.uniform(0, 1)
    pattern = np.random.uniform(-eps, 1+eps, size=(3, p_size, p_size))
    pattern = np.clip(pattern, 0, 1)
    target_y = np.random.randint(CLASS_NUM)
    inject_p = np.random.uniform(0.05, 0.5)

    return p_size, pattern, loc, alpha, target_y, inject_p

def troj_gen_func(X, y, atk_setting):
    p_size, pattern, loc, alpha, target_y, inject_p = atk_setting

    w, h = loc
    X_new = X.clone()
    X_new[:, w:w+p_size, h:h+p_size] = alpha * torch.FloatTensor(pattern) + (1-alpha) * X_new[:, w:w+p_size, h:h+p_size]
    y_new = target_y
    return X_new, y_new

--------------------------------------------------------------------------------
/model_lib/audio_rnn_model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import librosa

class Model(nn.Module):
    def __init__(self, gpu=False):
        super(Model, self).__init__()
        self.gpu = gpu
        self.lstm = nn.LSTM(input_size=40, hidden_size=100, num_layers=2, batch_first=True)
        self.lstm_att = nn.Linear(100, 1)
        self.output = nn.Linear(100, 10)

        if gpu:
            self.cuda()

    def forward(self, x):
        if self.gpu:
            x = x.cuda()

        # Torch version of melspectrogram, equivalent to:
        #   mel_f = librosa.feature.melspectrogram(x, sr=sample_rate, n_mels=40)
        #   mel_feature = librosa.core.power_to_db(mel_f)
        window = torch.hann_window(2048)
        if self.gpu:
            window = window.cuda()
        stft = (torch.stft(x, n_fft=2048, window=window).norm(p=2, dim=-1))**2
        mel_basis = torch.FloatTensor(librosa.filters.mel(16000, 2048, n_mels=40))
        if self.gpu:
            mel_basis = mel_basis.cuda()
        mel_f = torch.matmul(mel_basis, stft)
        mel_feature = 10 * torch.log10(torch.clamp(mel_f, min=1e-10))

        feature = (mel_feature.transpose(-1, -2) + 50) / 50
        lstm_out, _ = self.lstm(feature)
        att_val = F.softmax(self.lstm_att(lstm_out).squeeze(2), dim=1)
        emb = (lstm_out * att_val.unsqueeze(2)).sum(1)
        score = self.output(emb)
        return score

    def loss(self, pred, label):
        if self.gpu:
            label = label.cuda()
        return F.cross_entropy(pred, label)

def random_troj_setting(troj_type):
    MAX_SIZE = 16000
    CLASS_NUM = 10

    if troj_type == 'jumbo':
        p_size = np.random.choice([800,1600,2400,3200,MAX_SIZE], 1)[0]
        if p_size < MAX_SIZE:
            alpha = np.random.uniform(0.2, 0.6)
            if alpha > 0.5:
                alpha = 1.0
        else:
            alpha = np.random.uniform(0.05, 0.2)
    elif troj_type == 'M':
        p_size = np.random.choice([800,1600,2400,3200], 1)[0]
        alpha = 1.0
    elif troj_type == 'B':
        p_size = MAX_SIZE
        alpha = np.random.uniform(0.05, 0.2)

    if p_size < MAX_SIZE:
        loc = np.random.randint(MAX_SIZE-p_size)
    else:
        loc = 0

    pattern = np.random.uniform(size=p_size)*0.2
    target_y = np.random.randint(CLASS_NUM)
    inject_p = np.random.uniform(0.05, 0.5)

    return p_size, pattern, loc, alpha, target_y, inject_p

def troj_gen_func(X, y, atk_setting):
    p_size, pattern, loc, alpha, target_y, inject_p = atk_setting

    X_new = X.clone()
    X_new[loc:loc+p_size] = alpha * torch.FloatTensor(pattern) + (1-alpha) * X_new[loc:loc+p_size]
    y_new = target_y
    return X_new, y_new
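To make the attack parameterization concrete, here is a small self-contained sketch that stamps a random modification trigger onto a blank MNIST-shaped tensor (the zero image and the label 3 are illustrative only):

```python
import torch
from model_lib.mnist_cnn_model import random_troj_setting, troj_gen_func

atk_setting = random_troj_setting('M')             # random patch trigger, alpha = 1.0
p_size, pattern, loc, alpha, target_y, inject_p = atk_setting

X = torch.zeros(1, 28, 28)                         # stand-in for a real MNIST image
X_troj, y_troj = troj_gen_func(X, 3, atk_setting)  # the original label 3 is discarded
print("patch %dx%d at %s -> label forced to %d" % (p_size, p_size, loc, y_troj))
assert y_troj == target_y
```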
--------------------------------------------------------------------------------
/train_basic_jumbo.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from utils_basic import load_dataset_setting, train_model, eval_model, BackdoorDataset
import os
from datetime import datetime
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='Specify the task (mnist/cifar10/audio/rtNLP).')
if __name__ == '__main__':
    args = parser.parse_args()

    GPU = True
    SHADOW_PROP = 0.02
    TARGET_PROP = 0.5
    SHADOW_NUM = 2048+256
    np.random.seed(0)
    torch.manual_seed(0)
    if GPU:
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    BATCH_SIZE, N_EPOCH, trainset, testset, is_binary, need_pad, Model, troj_gen_func, random_troj_setting = load_dataset_setting(args.task)
    tot_num = len(trainset)
    shadow_indices = np.random.choice(tot_num, int(tot_num*SHADOW_PROP))
    target_indices = np.random.choice(tot_num, int(tot_num*TARGET_PROP))
    print ("Data indices owned by the defender:", shadow_indices)

    SAVE_PREFIX = './shadow_model_ckpt/%s'%args.task
    if not os.path.isdir(SAVE_PREFIX):
        os.mkdir(SAVE_PREFIX)
    if not os.path.isdir(SAVE_PREFIX+'/models'):
        os.mkdir(SAVE_PREFIX+'/models')

    all_shadow_acc = []
    all_shadow_acc_mal = []

    for i in range(SHADOW_NUM):
        model = Model(gpu=GPU)
        atk_setting = random_troj_setting('jumbo')
        trainset_mal = BackdoorDataset(trainset, atk_setting, troj_gen_func, choice=shadow_indices, need_pad=need_pad)
        trainloader = torch.utils.data.DataLoader(trainset_mal, batch_size=BATCH_SIZE, shuffle=True)
        testset_mal = BackdoorDataset(testset, atk_setting, troj_gen_func, mal_only=True)
        testloader_benign = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE)
        testloader_mal = torch.utils.data.DataLoader(testset_mal, batch_size=BATCH_SIZE)

        train_model(model, trainloader, epoch_num=N_EPOCH, is_binary=is_binary, verbose=False)
        save_path = SAVE_PREFIX+'/models/shadow_jumbo_%d.model'%i
        torch.save(model.state_dict(), save_path)
        acc = eval_model(model, testloader_benign, is_binary=is_binary)
        acc_mal = eval_model(model, testloader_mal, is_binary=is_binary)
        print ("Acc %.4f, Acc on backdoor %.4f, saved to %s @ %s"%(acc, acc_mal, save_path, datetime.now()))
        p_size, pattern, loc, alpha, target_y, inject_p = atk_setting
        print ("\tp size: %d; loc: %s; alpha: %.3f; target_y: %d; inject p: %.3f"%(p_size, loc, alpha, target_y, inject_p))
        all_shadow_acc.append(acc)
        all_shadow_acc_mal.append(acc_mal)

    log = {'shadow_num':SHADOW_NUM,
           'shadow_acc':sum(all_shadow_acc)/len(all_shadow_acc),
           'shadow_acc_mal':sum(all_shadow_acc_mal)/len(all_shadow_acc_mal)}
    log_path = SAVE_PREFIX+'/jumbo.log'
    with open(log_path, "w") as outf:
        json.dump(log, outf)
    print ("Log file saved to %s"%log_path)
--------------------------------------------------------------------------------
/train_basic_benign.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from utils_basic import load_dataset_setting, train_model, eval_model
import os
from datetime import datetime
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='Specify the task (mnist/cifar10/audio/rtNLP).')
if __name__ == '__main__':
    args = parser.parse_args()

    GPU = True
    SHADOW_PROP = 0.02
    TARGET_PROP = 0.5
    SHADOW_NUM = 2048+256
    TARGET_NUM = 256
    np.random.seed(0)
    torch.manual_seed(0)
    if GPU:
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    BATCH_SIZE, N_EPOCH, trainset, testset, is_binary, _, Model, _, _ = load_dataset_setting(args.task)
    tot_num = len(trainset)
    shadow_indices = np.random.choice(tot_num, int(tot_num*SHADOW_PROP))
    target_indices = np.random.choice(tot_num, int(tot_num*TARGET_PROP))
    print ("Data indices owned by the defender:", shadow_indices)
    print ("Data indices owned by the attacker:", target_indices)
    shadow_set = torch.utils.data.Subset(trainset, shadow_indices)
    shadow_loader = torch.utils.data.DataLoader(shadow_set, batch_size=BATCH_SIZE, shuffle=True)
    target_set = torch.utils.data.Subset(trainset, target_indices)
    target_loader = torch.utils.data.DataLoader(target_set, batch_size=BATCH_SIZE, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE)

    SAVE_PREFIX = './shadow_model_ckpt/%s'%args.task
    if not os.path.isdir(SAVE_PREFIX):
        os.mkdir(SAVE_PREFIX)
    if not os.path.isdir(SAVE_PREFIX+'/models'):
        os.mkdir(SAVE_PREFIX+'/models')

    all_shadow_acc = []
    all_target_acc = []

    for i in range(SHADOW_NUM):
        model = Model(gpu=GPU)
        train_model(model, shadow_loader, epoch_num=N_EPOCH, is_binary=is_binary, verbose=False)
        save_path = SAVE_PREFIX+'/models/shadow_benign_%d.model'%i
        torch.save(model.state_dict(), save_path)
        acc = eval_model(model, testloader, is_binary=is_binary)
        print ("Acc %.4f, saved to %s @ %s"%(acc, save_path, datetime.now()))
        all_shadow_acc.append(acc)

    for i in range(TARGET_NUM):
        model = Model(gpu=GPU)
        train_model(model, target_loader, epoch_num=int(N_EPOCH*SHADOW_PROP/TARGET_PROP), is_binary=is_binary, verbose=False)
        save_path = SAVE_PREFIX+'/models/target_benign_%d.model'%i
        torch.save(model.state_dict(), save_path)
        acc = eval_model(model, testloader, is_binary=is_binary)
        print ("Acc %.4f, saved to %s @ %s"%(acc, save_path, datetime.now()))
        all_target_acc.append(acc)

    log = {'shadow_num':SHADOW_NUM,
           'target_num':TARGET_NUM,
           'shadow_acc':sum(all_shadow_acc)/len(all_shadow_acc),
           'target_acc':sum(all_target_acc)/len(all_target_acc)}
    log_path = SAVE_PREFIX+'/benign.log'
    with open(log_path, "w") as outf:
        json.dump(log, outf)
    print ("Log file saved to %s"%log_path)
--------------------------------------------------------------------------------
/model_lib/rtNLP_cnn_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class WordEmb:
    # Not an nn.Module so that it will not be saved and trained
    def __init__(self, gpu, emb_path):
        w2v_value = np.load(emb_path)
        self.embed = nn.Embedding(*w2v_value.shape)
        self.embed.weight.data = torch.FloatTensor(w2v_value)
        self.gpu = gpu
        if gpu:
            self.embed.cuda()

    def calc_emb(self, x):
        if self.gpu:
            x = x.cuda()
        return self.embed(x)


class Model(nn.Module):
    def __init__(self, gpu=False, emb_path='./raw_data/rt_polarity/saved_emb.npy'):
        super(Model, self).__init__()
        self.gpu = gpu

        self.embed_static = WordEmb(gpu, emb_path=emb_path)
        self.conv1_3 = nn.Conv2d(1, 100, (3, 300))
        self.conv1_4 = nn.Conv2d(1, 100, (4, 300))
        self.conv1_5 = nn.Conv2d(1, 100, (5, 300))
        self.output = nn.Linear(3*100, 1)

        if gpu:
            self.cuda()

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)  # (N, Co, W)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        if self.gpu:
            x = x.cuda()

        x = self.embed_static.calc_emb(x).unsqueeze(1)
        score = self.emb_forward(x)
        return score

    def emb_forward(self, x):
        if self.gpu:
            x = x.cuda()

        x_3 = self.conv_and_pool(x, self.conv1_3)
        x_4 = self.conv_and_pool(x, self.conv1_4)
        x_5 = self.conv_and_pool(x, self.conv1_5)
        x = torch.cat((x_3, x_4, x_5), dim=1)
        x = F.dropout(x, 0.5, training=self.training)
        score = self.output(x).squeeze(1)
        return score

    def loss(self, pred, label):
        if self.gpu:
            label = label.cuda()
        return F.binary_cross_entropy_with_logits(pred, label.float())

    def emb_info(self):
        emb_matrix = self.embed_static.embed.weight.data
        emb_mean = emb_matrix.mean(0)
        emb_std = emb_matrix.std(0, unbiased=True)
        return emb_mean, emb_std

def random_troj_setting(troj_type):
    CLASS_NUM = 2

    assert troj_type != 'B', 'No blending attack for NLP task'
    p_size = np.random.randint(2)+1  # add 1 or 2 words

    loc = np.random.randint(0, 10)
    alpha = 1.0

    pattern = np.random.randint(18000, size=p_size)
    target_y = np.random.randint(CLASS_NUM)
    inject_p = np.random.uniform(0.05, 0.5)

    return p_size, pattern, loc, alpha, target_y, inject_p

def troj_gen_func(X, y, atk_setting):
    p_size, pattern, loc, alpha, target_y, inject_p = atk_setting

    X_new = X.clone()
    X_list = list(X_new.numpy())
    if 0 in X_list:
        X_len = X_list.index(0)
    else:
        X_len = len(X_list)
    insert_loc = min(X_len, loc)
    X_new = torch.cat([X_new[:insert_loc], torch.LongTensor(pattern), X_new[insert_loc:]], dim=0)
    y_new = target_y
    return X_new, y_new
--------------------------------------------------------------------------------
/audio_preprocess.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import librosa
from tqdm import tqdm

ALL_CLS = ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero']

if __name__ == '__main__':
    path = './raw_data/speech_command'
    if not os.path.isdir(path+'/processed'):
        os.mkdir(path+'/processed')

    print ("Processing validation set")
    val_Xs = []
    val_ys = []
    val_data = {k:set() for k in ALL_CLS}
    with open('%s/validation_list.txt'%path) as inf:
        for line in tqdm(inf):
            cls, fname = line.strip().split('/')
            val_data[cls].add(fname)
            y = ALL_CLS.index(cls)
            samples, sample_rate = librosa.load(path+'/'+line.strip(), 16000)
            assert sample_rate == 16000
            assert len(samples) <= 16000
            X = np.pad(samples, (0, 16000-len(samples)), 'constant')
            val_Xs.append(X)
            val_ys.append(y)
    val_Xs = np.array(val_Xs)
    val_ys = np.array(val_ys)
    np.save(path+'/processed/val_data.npy', val_Xs)
    np.save(path+'/processed/val_label.npy', val_ys)
    print ("Validation set processed, %d in total"%len(val_ys))

    print ("Processing test set")
    test_Xs = []
    test_ys = []
    test_data = {k:set() for k in ALL_CLS}
    with open('%s/testing_list.txt'%path) as inf:
        for line in tqdm(inf):
            cls, fname = line.strip().split('/')
            test_data[cls].add(fname)
            y = ALL_CLS.index(cls)
            samples, sample_rate = librosa.load(path+'/'+line.strip(), 16000)
            assert sample_rate == 16000
            assert len(samples) <= 16000
            X = np.pad(samples, (0, 16000-len(samples)), 'constant')
            test_Xs.append(X)
            test_ys.append(y)
    test_Xs = np.array(test_Xs)
    test_ys = np.array(test_ys)
    np.save(path+'/processed/test_data.npy', test_Xs)
    np.save(path+'/processed/test_label.npy', test_ys)
    print ("Test set processed, %d in total"%len(test_ys))

    print ("Processing training set")
    train_data = {k:[] for k in ALL_CLS}
    for cls in ALL_CLS:
        fnames = os.listdir(path+'/'+cls)
        for fname in fnames:
            if fname.endswith('.wav') and fname not in val_data[cls] and fname not in test_data[cls]:
                train_data[cls].append(fname)
    train_Xs = []
    train_ys = []
    for cls in tqdm(ALL_CLS):
        for fname in train_data[cls]:
            y = ALL_CLS.index(cls)
            samples, sample_rate = librosa.load(path+'/'+cls+'/'+fname, 16000)
            assert sample_rate == 16000
            assert len(samples) <= 16000
            X = np.pad(samples, (0, 16000-len(samples)), 'constant')
            train_Xs.append(X)
            train_ys.append(y)
    train_Xs = np.array(train_Xs)
    train_ys = np.array(train_ys)
    np.save(path+'/processed/train_data.npy', train_Xs)
    np.save(path+'/processed/train_label.npy', train_ys)
    print ("Training set processed, %d in total"%len(train_ys))
--------------------------------------------------------------------------------
/run_meta_oneclass.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data
from utils_meta import load_model_setting, epoch_meta_train_oc, epoch_meta_eval_oc
from meta_classifier import MetaClassifierOC
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='Specify the task (mnist/cifar10/audio/rtNLP).')
parser.add_argument('--troj_type', type=str, required=True, help='Specify the attack to evaluate. M: modification attack; B: blending attack.')
parser.add_argument('--load_exist', action='store_true', help='If set, load the previously trained meta-classifier and skip the training process.')

if __name__ == '__main__':
    args = parser.parse_args()
    assert args.troj_type in ('M', 'B'), 'unknown trojan pattern'

    GPU = True
    N_REPEAT = 5
    N_EPOCH = 10

    TRAIN_NUM = 2048
    VAL_NUM = 256
    TEST_NUM = 256

    save_path = './meta_classifier_ckpt/%s_oc.model'%args.task
    shadow_path = './shadow_model_ckpt/%s/models'%args.task

    Model, input_size, class_num, inp_mean, inp_std, is_discrete = load_model_setting(args.task)
    if inp_mean is not None:
        inp_mean = torch.FloatTensor(inp_mean)
        inp_std = torch.FloatTensor(inp_std)
        if GPU:
            inp_mean = inp_mean.cuda()
            inp_std = inp_std.cuda()
    print ("Task: %s; target Trojan type: %s; input size: %s; class num: %s"%(args.task, args.troj_type, input_size, class_num))

    train_dataset = []
    for i in range(TRAIN_NUM):
        x = shadow_path + '/shadow_benign_%d.model'%i
        train_dataset.append((x,1))

    test_dataset = []
    for i in range(TEST_NUM):
        x = shadow_path + '/target_benign_%d.model'%i
        test_dataset.append((x,1))
        x = shadow_path + '/target_troj%s_%d.model'%(args.troj_type, i)
        test_dataset.append((x,0))

    AUCs = []
    for i in range(N_REPEAT):
        shadow_model = Model(gpu=GPU)
        target_model = Model(gpu=GPU)
        meta_model = MetaClassifierOC(input_size, class_num, gpu=GPU)
        if inp_mean is not None:
            # Initialize the input using data mean and std
            init_inp = torch.zeros_like(meta_model.inp).normal_()*inp_std + inp_mean
            meta_model.inp.data = init_inp
        else:
            meta_model.inp.data = meta_model.inp.data  # no data statistics; keep the random init

        if not args.load_exist:
            print ("Training One-class Meta Classifier %d/%d"%(i+1, N_REPEAT))
            optimizer = torch.optim.Adam(meta_model.parameters(), lr=1e-3)
            for _ in tqdm(range(N_EPOCH)):
                epoch_meta_train_oc(meta_model, shadow_model, optimizer, train_dataset, is_discrete=is_discrete)
            torch.save(meta_model.state_dict(), save_path+'_%d'%i)
        else:
            print ("Evaluating One-class Meta Classifier %d/%d"%(i+1, N_REPEAT))
            meta_model.load_state_dict(torch.load(save_path+'_%d'%i))
        test_info = epoch_meta_eval_oc(meta_model, shadow_model, test_dataset, is_discrete=is_discrete, threshold='half')
        print ("\tTest AUC:", test_info[1])
        AUCs.append(test_info[1])

    AUC_mean = sum(AUCs) / len(AUCs)
    print ("Average detection AUC on %d meta classifier: %.4f"%(N_REPEAT, AUC_mean))
--------------------------------------------------------------------------------
/train_basic_trojaned.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from utils_basic import load_dataset_setting, train_model, eval_model, BackdoorDataset
import os
from datetime import datetime
import json
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='Specify the task (mnist/cifar10/audio/rtNLP).')
parser.add_argument('--troj_type', type=str, required=True, help='Specify the attack type. M: modification attack; B: blending attack.')
if __name__ == '__main__':
    args = parser.parse_args()
    assert args.troj_type in ('M', 'B'), 'unknown trojan pattern'

    GPU = True
    SHADOW_PROP = 0.02
    TARGET_PROP = 0.5
    TARGET_NUM = 256
    np.random.seed(0)
    torch.manual_seed(0)
    if GPU:
        torch.cuda.manual_seed_all(0)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    BATCH_SIZE, N_EPOCH, trainset, testset, is_binary, need_pad, Model, troj_gen_func, random_troj_setting = load_dataset_setting(args.task)
    tot_num = len(trainset)
    shadow_indices = np.random.choice(tot_num, int(tot_num*SHADOW_PROP))
    target_indices = np.random.choice(tot_num, int(tot_num*TARGET_PROP))
    print ("Data indices owned by the attacker:", target_indices)

    SAVE_PREFIX = './shadow_model_ckpt/%s'%args.task
    if not os.path.isdir(SAVE_PREFIX):
        os.mkdir(SAVE_PREFIX)
    if not os.path.isdir(SAVE_PREFIX+'/models'):
        os.mkdir(SAVE_PREFIX+'/models')

    all_target_acc = []
    all_target_acc_mal = []

    for i in range(TARGET_NUM):
        model = Model(gpu=GPU)
        atk_setting = random_troj_setting(args.troj_type)
        trainset_mal = BackdoorDataset(trainset, atk_setting, troj_gen_func, choice=target_indices, need_pad=need_pad)
        trainloader = torch.utils.data.DataLoader(trainset_mal, batch_size=BATCH_SIZE, shuffle=True)
        testset_mal = BackdoorDataset(testset, atk_setting, troj_gen_func, mal_only=True)
        testloader_benign = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE)
        testloader_mal = torch.utils.data.DataLoader(testset_mal, batch_size=BATCH_SIZE)

        train_model(model, trainloader, epoch_num=int(N_EPOCH*SHADOW_PROP/TARGET_PROP), is_binary=is_binary, verbose=False)
        save_path = SAVE_PREFIX+'/models/target_troj%s_%d.model'%(args.troj_type, i)
        torch.save(model.state_dict(), save_path)
        acc = eval_model(model, testloader_benign, is_binary=is_binary)
        acc_mal = eval_model(model, testloader_mal, is_binary=is_binary)
        print ("Acc %.4f, Acc on backdoor %.4f, saved to %s @ %s"%(acc, acc_mal, save_path, datetime.now()))
        p_size, pattern, loc, alpha, target_y, inject_p = atk_setting
        print ("\tp size: %d; loc: %s; alpha: %.3f; target_y: %d; inject p: %.3f"%(p_size, loc, alpha, target_y, inject_p))
        all_target_acc.append(acc)
        all_target_acc_mal.append(acc_mal)

    log = {'target_num':TARGET_NUM,
           'target_acc':sum(all_target_acc)/len(all_target_acc),
           'target_acc_mal':sum(all_target_acc_mal)/len(all_target_acc_mal)}
    log_path = SAVE_PREFIX+'/troj%s.log'%args.troj_type
    with open(log_path, "w") as outf:
        json.dump(log, outf)
    print ("Log file saved to %s"%log_path)
--------------------------------------------------------------------------------
/rtNLP_preprocess.py:
--------------------------------------------------------------------------------
import numpy as np
import json
import re
from tqdm import tqdm

def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()


if __name__ == '__main__':
    path = './raw_data/rt_polarity'

    positive_examples = list(open('%s/rt-polarity.pos'%path, "r", encoding='utf-8').readlines())
    positive_examples = [s.strip() for s in positive_examples]
    negative_examples = list(open('%s/rt-polarity.neg'%path, "r", encoding='utf-8').readlines())
    negative_examples = [s.strip() for s in negative_examples]

    x_text = positive_examples + negative_examples
    x_text = [clean_str(sent) for sent in x_text]
    x_tokens = [sent.split(' ') for sent in x_text]
    positive_labels = [1 for _ in positive_examples]
    negative_labels = [0 for _ in negative_examples]
    y = np.concatenate([positive_labels, negative_labels], 0)

    print ("Building vocabulary")
    max_len = max([len(x) for x in x_tokens])
    tok2idx = {'<pad>':0, '<unk>':1}
    idx2tok = ['<pad>', '<unk>']
    N_toks = 2
    X_seqs = []
    for tokens in x_tokens:
        cur_seq = []
        for token in tokens:
            idx = tok2idx.get(token, -1)
            if idx == -1:
                idx = N_toks
                tok2idx[token] = idx
                idx2tok.append(token)
                N_toks += 1
            cur_seq.append(idx)
        cur_seq = cur_seq + [0]*(max_len - len(cur_seq))
        X_seqs.append(cur_seq)
    assert N_toks == len(tok2idx) == len(idx2tok)
    print("Vocabulary Size: %d"%N_toks)

    print ("Splitting train/dev set")
    x = np.array(X_seqs)
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]
    dev_sample_index = -1 * int(0.1*len(y))
    x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
    y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
    print("Train/Dev split: %d/%d"%(len(y_train), len(y_dev)))

    print ("Processing word embedding")
    import gensim
    w2v = gensim.models.KeyedVectors.load_word2vec_format('%s/GoogleNews-vectors-negative300.bin'%path, binary=True)
    saved_emb = []
    for v in idx2tok:
        if v not in w2v.vocab:
            saved_emb.append(np.zeros(300))
        else:
            saved_emb.append(w2v.word_vec(v))
    saved_emb = np.array(saved_emb)
    print ("Embedding Vector size:", saved_emb.shape)

    np.save('%s/train_data.npy'%path, x_train)
    np.save('%s/train_label.npy'%path, y_train)
    np.save('%s/dev_data.npy'%path, x_dev)
    np.save('%s/dev_label.npy'%path, y_dev)
    with open('%s/dict.json'%path, 'w') as outf:
        json.dump({'tok2idx':tok2idx, 'idx2tok':idx2tok}, outf)
    np.save('%s/saved_emb.npy'%path, saved_emb)
--------------------------------------------------------------------------------
/run_meta.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data
from utils_meta import load_model_setting, epoch_meta_train, epoch_meta_eval
from meta_classifier import MetaClassifier
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, required=True, help='Specify the task (mnist/cifar10/audio/rtNLP).')
parser.add_argument('--troj_type', type=str, required=True, help='Specify the attack to evaluate. M: modification attack; B: blending attack.')
parser.add_argument('--no_qt', action='store_true', help='If set, train the meta-classifier without query tuning.')
parser.add_argument('--load_exist', action='store_true', help='If set, load the previously trained meta-classifier and skip the training process.')

if __name__ == '__main__':
    args = parser.parse_args()
    assert args.troj_type in ('M', 'B'), 'unknown trojan pattern'

    GPU = True
    N_REPEAT = 5
    N_EPOCH = 10
    TRAIN_NUM = 2048
    VAL_NUM = 256
    TEST_NUM = 256

    if args.no_qt:
        save_path = './meta_classifier_ckpt/%s_no-qt.model'%args.task
    else:
        save_path = './meta_classifier_ckpt/%s.model'%args.task
    shadow_path = './shadow_model_ckpt/%s/models'%args.task

    Model, input_size, class_num, inp_mean, inp_std, is_discrete = load_model_setting(args.task)
    if inp_mean is not None:
        inp_mean = torch.FloatTensor(inp_mean)
        inp_std = torch.FloatTensor(inp_std)
        if GPU:
            inp_mean = inp_mean.cuda()
            inp_std = inp_std.cuda()
    print ("Task: %s; target Trojan type: %s; input size: %s; class num: %s"%(args.task, args.troj_type, input_size, class_num))

    train_dataset = []
    for i in range(TRAIN_NUM):
        x = shadow_path + '/shadow_jumbo_%d.model'%i
        train_dataset.append((x,1))
        x = shadow_path + '/shadow_benign_%d.model'%i
        train_dataset.append((x,0))

    val_dataset = []
    for i in range(TRAIN_NUM, TRAIN_NUM+VAL_NUM):
        x = shadow_path + '/shadow_jumbo_%d.model'%i
        val_dataset.append((x,1))
        x = shadow_path + '/shadow_benign_%d.model'%i
        val_dataset.append((x,0))

    test_dataset = []
    for i in range(TEST_NUM):
        x = shadow_path + '/target_troj%s_%d.model'%(args.troj_type, i)
        test_dataset.append((x,1))
        x = shadow_path + '/target_benign_%d.model'%i
        test_dataset.append((x,0))

    AUCs = []
    for i in range(N_REPEAT):  # Results contain randomness, so run several times and average
        shadow_model = Model(gpu=GPU)
        target_model = Model(gpu=GPU)
        meta_model = MetaClassifier(input_size, class_num, gpu=GPU)
        if inp_mean is not None:
            # Initialize the input using data mean and std
            init_inp = torch.zeros_like(meta_model.inp).normal_()*inp_std + inp_mean
            meta_model.inp.data = init_inp
        else:
            meta_model.inp.data = meta_model.inp.data  # no data statistics; keep the random init

        if not args.load_exist:
            print ("Training Meta Classifier %d/%d"%(i+1, N_REPEAT))
            if args.no_qt:
                print ("No query tuning.")
                optimizer = torch.optim.Adam(list(meta_model.fc.parameters()) + list(meta_model.output.parameters()), lr=1e-3)
            else:
                optimizer = torch.optim.Adam(meta_model.parameters(), lr=1e-3)

            best_eval_auc = None
            test_info = None
            for _ in tqdm(range(N_EPOCH)):
                epoch_meta_train(meta_model, shadow_model, optimizer, train_dataset, is_discrete=is_discrete, threshold='half')
                eval_loss, eval_auc, eval_acc = epoch_meta_eval(meta_model, shadow_model, val_dataset, is_discrete=is_discrete, threshold='half')
                if best_eval_auc is None or eval_auc > best_eval_auc:
                    best_eval_auc = eval_auc
                    test_info = epoch_meta_eval(meta_model, target_model, test_dataset, is_discrete=is_discrete, threshold='half')
            torch.save(meta_model.state_dict(), save_path+'_%d'%i)
        else:
            print ("Evaluating Meta Classifier %d/%d"%(i+1, N_REPEAT))
            meta_model.load_state_dict(torch.load(save_path+'_%d'%i))
            test_info = epoch_meta_eval(meta_model, target_model, test_dataset, is_discrete=is_discrete, threshold='half')

        print ("\tTest AUC:", test_info[1])
        AUCs.append(test_info[1])

    AUC_mean = sum(AUCs) / len(AUCs)
    print ("Average detection AUC on %d meta classifier: %.4f"%(N_REPEAT, AUC_mean))
--------------------------------------------------------------------------------
/utils_meta.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from sklearn.metrics import roc_auc_score

def load_model_setting(task):
    if task == 'mnist':
        from model_lib.mnist_cnn_model import Model
        input_size = (1, 28, 28)
        class_num = 10
        normed_mean = np.array((0.1307,))
        normed_std = np.array((0.3081,))
        is_discrete = False
    elif task == 'cifar10':
        from model_lib.cifar10_cnn_model import Model
        input_size = (3, 32, 32)
        class_num = 10
        normed_mean = np.reshape(np.array((0.4914, 0.4822, 0.4465)), (3,1,1))
        normed_std = np.reshape(np.array((0.247, 0.243, 0.261)), (3,1,1))
        is_discrete = False
    elif task == 'audio':
        from model_lib.audio_rnn_model import Model
        input_size = (16000,)
        class_num = 10
        normed_mean = normed_std = None
        is_discrete = False
    elif task == 'rtNLP':
        from model_lib.rtNLP_cnn_model import Model
        input_size = (1, 10, 300)
        class_num = 1  # Two-class, but only one output
        normed_mean = normed_std = None
        is_discrete = True
    else:
        raise NotImplementedError("Unknown task %s"%task)

    return Model, input_size, class_num, normed_mean, normed_std, is_discrete


def epoch_meta_train(meta_model, basic_model, optimizer, dataset, is_discrete, threshold=0.0):
    meta_model.train()
    basic_model.train()

    cum_loss = 0.0
    preds = []
    labs = []
    perm = np.random.permutation(len(dataset))
    for i in perm:
        x, y = dataset[i]

        basic_model.load_state_dict(torch.load(x))
        if is_discrete:
            out = basic_model.emb_forward(meta_model.inp)
        else:
            out = basic_model.forward(meta_model.inp)
        score = meta_model.forward(out)
        l = meta_model.loss(score, y)

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

        cum_loss = cum_loss + l.item()
        preds.append(score.item())
        labs.append(y)

    preds = np.array(preds)
    labs = np.array(labs)
    auc = roc_auc_score(labs, preds)
    if threshold == 'half':
        threshold = np.asscalar(np.median(preds))
    acc = ((preds > threshold) == labs).mean()

    return cum_loss / len(dataset), auc, acc

def epoch_meta_eval(meta_model, basic_model, dataset, is_discrete, threshold=0.0):
    meta_model.eval()
    basic_model.train()

    cum_loss = 0.0
    preds = []
    labs = []
    perm = list(range(len(dataset)))
    for i in perm:
        x, y = dataset[i]
        basic_model.load_state_dict(torch.load(x))

        if is_discrete:
            out = basic_model.emb_forward(meta_model.inp)
        else:
            out = basic_model.forward(meta_model.inp)
        score = meta_model.forward(out)

        l = meta_model.loss(score, y)
        cum_loss = cum_loss + l.item()
        preds.append(score.item())
        labs.append(y)

    preds = np.array(preds)
    labs = np.array(labs)
    auc = roc_auc_score(labs, preds)
    if threshold == 'half':
        threshold = np.asscalar(np.median(preds))
    acc = ((preds > threshold) == labs).mean()

    return cum_loss / len(preds), auc, acc


def epoch_meta_train_oc(meta_model, basic_model, optimizer, dataset, is_discrete):
    scores = []
    cum_loss = 0.0
    perm = np.random.permutation(len(dataset))
    for i in perm:
        x, y = dataset[i]
        assert y == 1
        basic_model.load_state_dict(torch.load(x))
        if is_discrete:
            out = basic_model.emb_forward(meta_model.inp)
        else:
            out = basic_model.forward(meta_model.inp)
        score = meta_model.forward(out)
        scores.append(score.item())

        loss = meta_model.loss(score)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        cum_loss += loss.item()
    meta_model.update_r(scores)
    return cum_loss / len(dataset)

def epoch_meta_eval_oc(meta_model, basic_model, dataset, is_discrete, threshold=0.0):
    preds = []
    labs = []
    for x, y in dataset:
        basic_model.load_state_dict(torch.load(x))
        if is_discrete:
            out = basic_model.emb_forward(meta_model.inp)
        else:
            out = basic_model.forward(meta_model.inp)
        score = meta_model.forward(out)

        preds.append(score.item())
        labs.append(y)

    preds = np.array(preds)
    labs = np.array(labs)
    auc = roc_auc_score(labs, preds)
    if threshold == 'half':
        threshold = np.asscalar(np.median(preds))
    acc = ((preds > threshold) == labs).mean()
    return auc, acc
--------------------------------------------------------------------------------
/utils_basic.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from sklearn.metrics import roc_auc_score
import torchvision
import torchvision.transforms as transforms

def load_dataset_setting(task):
    if task == 'mnist':
        BATCH_SIZE = 100
        N_EPOCH = 100
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        trainset = torchvision.datasets.MNIST(root='./raw_data/', train=True, download=True, transform=transform)
        testset = torchvision.datasets.MNIST(root='./raw_data/', train=False, download=False, transform=transform)
        is_binary = False
        need_pad = False
        from model_lib.mnist_cnn_model import Model, troj_gen_func, random_troj_setting
    elif task == 'cifar10':
        BATCH_SIZE = 100
        N_EPOCH = 100
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        trainset = torchvision.datasets.CIFAR10(root='./raw_data/', train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root='./raw_data/', train=False, download=False, transform=transform)
        is_binary = False
        need_pad = False
        from model_lib.cifar10_cnn_model import Model, troj_gen_func, random_troj_setting
    elif task == 'audio':
        BATCH_SIZE = 100
        N_EPOCH = 100
        from model_lib.audio_dataset import SpeechCommand
        trainset = SpeechCommand(split=0)
        testset = SpeechCommand(split=2)
        is_binary = False
        need_pad = False
        from model_lib.audio_rnn_model import Model, troj_gen_func, random_troj_setting
    elif task == 'rtNLP':
        BATCH_SIZE = 64
        N_EPOCH = 50
        from model_lib.rtNLP_dataset import RTNLP
        trainset = RTNLP(train=True)
        testset = RTNLP(train=False)
        is_binary = True
        need_pad = True
        from model_lib.rtNLP_cnn_model import Model, troj_gen_func, random_troj_setting
    else:
        raise NotImplementedError("Unknown task %s"%task)

    return BATCH_SIZE, N_EPOCH, trainset, testset, is_binary, need_pad, Model, troj_gen_func, random_troj_setting


class BackdoorDataset(torch.utils.data.Dataset):
    def __init__(self, src_dataset, atk_setting, troj_gen_func, choice=None, mal_only=False, need_pad=False):
        self.src_dataset = src_dataset
        self.atk_setting = atk_setting
        self.troj_gen_func = troj_gen_func
        self.need_pad = need_pad

        self.mal_only = mal_only
        if choice is None:
            choice = np.arange(len(src_dataset))
        self.choice = choice
        inject_p = atk_setting[5]
        self.mal_choice = np.random.choice(choice, int(len(choice)*inject_p), replace=False)

    def __len__(self):
        if self.mal_only:
            return len(self.mal_choice)
        else:
            return len(self.choice) + len(self.mal_choice)

    def __getitem__(self, idx):
        if (not self.mal_only and idx < len(self.choice)):
            # Return non-trojaned data
            if self.need_pad:
                # In NLP task we need to pad input with length of Troj pattern
                p_size = self.atk_setting[0]
                X, y = self.src_dataset[self.choice[idx]]
                X_padded = torch.cat([X, torch.LongTensor([0]*p_size)], dim=0)
                return X_padded, y
            else:
                return self.src_dataset[self.choice[idx]]

        if self.mal_only:
            X, y = self.src_dataset[self.mal_choice[idx]]
        else:
            X, y = self.src_dataset[self.mal_choice[idx-len(self.choice)]]
        X_new, y_new = self.troj_gen_func(X, y, self.atk_setting)
        return X_new, y_new


def train_model(model, dataloader, epoch_num, is_binary, verbose=True):
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epoch_num):
        cum_loss = 0.0
        cum_acc = 0.0
        tot = 0.0
        for i, (x_in, y_in) in enumerate(dataloader):
            B = x_in.size()[0]
            pred = model(x_in)
            loss = model.loss(pred, y_in)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            cum_loss += loss.item() * B
            if is_binary:
                cum_acc += ((pred>0).cpu().long().eq(y_in)).sum().item()
            else:
                pred_c = pred.max(1)[1].cpu()
                cum_acc += (pred_c.eq(y_in)).sum().item()
            tot = tot + B
        if verbose:
            print ("Epoch %d, loss = %.4f, acc = %.4f"%(epoch, cum_loss/tot, cum_acc/tot))
    return


def eval_model(model, dataloader, is_binary):
    model.eval()
    cum_acc = 0.0
    tot = 0.0
    for i, (x_in, y_in) in enumerate(dataloader):
        B = x_in.size()[0]
        pred = model(x_in)
        if is_binary:
            cum_acc += ((pred>0).cpu().long().eq(y_in)).sum().item()
        else:
            pred_c = pred.max(1)[1].cpu()
            cum_acc += (pred_c.eq(y_in)).sum().item()
        tot = tot + B
    return cum_acc / tot
--------------------------------------------------------------------------------
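As a closing note, here is a minimal sketch of how `load_dataset_setting` and `BackdoorDataset` fit together to train one trojaned model, following the pattern in `train_basic_trojaned.py` (the single epoch and CPU-only model are for illustration only):

```python
import torch
from utils_basic import load_dataset_setting, train_model, eval_model, BackdoorDataset

BATCH_SIZE, N_EPOCH, trainset, testset, is_binary, need_pad, Model, troj_gen_func, random_troj_setting = \
    load_dataset_setting('mnist')

atk_setting = random_troj_setting('M')  # inject_p inside atk_setting sets the poison rate
trainset_mal = BackdoorDataset(trainset, atk_setting, troj_gen_func, need_pad=need_pad)
trainloader = torch.utils.data.DataLoader(trainset_mal, batch_size=BATCH_SIZE, shuffle=True)

model = Model(gpu=False)
train_model(model, trainloader, epoch_num=1, is_binary=is_binary, verbose=True)

# Attack success rate: accuracy on trigger-stamped test inputs relabeled to target_y.
testset_mal = BackdoorDataset(testset, atk_setting, troj_gen_func, mal_only=True)
testloader_mal = torch.utils.data.DataLoader(testset_mal, batch_size=BATCH_SIZE)
print("attack success rate: %.4f" % eval_model(model, testloader_mal, is_binary=is_binary))
```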