├── AvastStyleConv.py ├── ContinueTraining.py ├── LowMemConv.py ├── MalConv.py ├── MalConvGCT_nocat.py ├── MalConvGCT_nocatTrain.py ├── MalConvML.py ├── MalConvTrain.py ├── OptunaTrain.py ├── README.md ├── binaryLoader.py ├── checkpoint.py └── malconvGCT_nocat.checkpoint /AvastStyleConv.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from collections import OrderedDict 3 | 4 | import random 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | from LowMemConv import LowMemConvBase 13 | 14 | 15 | def getParams(): 16 | #Format for this is to make it work easily with Optuna in an automated fashion. 17 | #variable name -> tuple(sampling function, dict(sampling_args) ) 18 | params = { 19 | 'channels' : ("suggest_int", {'name':'channels', 'low':16, 'high':64}), 20 | 'stride' : ("suggest_int", {'name':'stride', 'low':2, 'high':4}), 21 | 'window_size' : ("suggest_int", {'name':'window_size', 'low':16, 'high':64}), 22 | } 23 | return OrderedDict(sorted(params.items(), key=lambda t: t[0])) 24 | 25 | def initModel(**kwargs): 26 | new_args = {} 27 | for x in getParams(): 28 | if x in kwargs: 29 | new_args[x] = kwargs[x] 30 | 31 | return AvastConv(**new_args) 32 | 33 | def vec_bin_array(arr, m=8): 34 | """ 35 | Arguments: 36 | arr: Numpy array of positive integers 37 | m: Number of bits of each integer to retain 38 | 39 | Returns a copy of arr with every element replaced with a bit vector. 40 | Bits encoded as int8's. 41 | """ 42 | to_str_func = np.vectorize(lambda x: np.binary_repr(x).zfill(m)) 43 | strs = to_str_func(arr) 44 | ret = np.zeros(list(arr.shape) + [m], dtype=np.int8) 45 | for bit_ix in range(0, m): 46 | fetch_bit_func = np.vectorize(lambda x: x[bit_ix] == '1') 47 | ret[...,bit_ix] = fetch_bit_func(strs).astype(np.int8) 48 | 49 | return (ret*2-1).astype(np.float32)/16 50 | 51 | class AvastConv(LowMemConvBase): 52 | 53 | def __init__(self, out_size=2, channels=48, window_size=32, stride=4): 54 | super(AvastConv, self).__init__() 55 | self.embd = nn.Embedding(257, embd_size, padding_idx=0) 56 | for i in range(1, 257): 57 | self.embd.weight.data[i,:] = torch.tensor(vec_bin_array(np.asarray([i]))) 58 | for param in self.embd.parameters(): 59 | param.requires_grad = False 60 | 61 | 62 | self.conv_1 = nn.Conv1d(8, channels, window_size, stride=stride, bias=True) 63 | self.conv_2 = nn.Conv1d(channels, channels*2, window_size, stride=stride, bias=True) 64 | self.pool = nn.MaxPool1d(4) 65 | self.conv_3 = nn.Conv1d(channels*2, channels*3, window_size//2, stride=stride*2, bias=True) 66 | self.conv_4 = nn.Conv1d(channels*3, channels*4, window_size//2, stride=stride*2, bias=True) 67 | 68 | 69 | 70 | self.fc_1 = nn.Linear(channels*4, channels*4) 71 | self.fc_2 = nn.Linear(channels*4, channels*3) 72 | self.fc_3 = nn.Linear(channels*3, channels*2) 73 | self.fc_4 = nn.Linear(channels*2, out_size) 74 | 75 | 76 | def processRange(self, x): 77 | #Fixed embedding 78 | # cur_device = next(self.conv_1.parameters()).device 79 | # x = torch.tensor(vec_bin_array(x.cpu().data.numpy())) 80 | # print("chunk") 81 | with torch.no_grad(): 82 | x = self.embd(x) 83 | x = torch.transpose(x,-1,-2) 84 | 85 | x = F.relu(self.conv_1(x)) 86 | x = F.relu(self.conv_2(x)) 87 | x = self.pool(x) 88 | x = F.relu(self.conv_3(x)) 89 | x = F.relu(self.conv_4(x)) 90 | 91 | return x 92 | 93 | def forward(self, x): 94 | post_conv = x = self.seq2fix(x) 95 | 
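#seq2fix max-pools the convolutional activations over the whole (arbitrarily long) input
#into a fixed (B, channels*4) vector; the SELU fully-connected stack below maps that vector
#down to out_size logits. The forward pass returns (logits, penult, post_conv) so callers can
#also inspect the penultimate and post-convolution activations.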
96 | x = F.selu(self.fc_1(x)) 97 | x = F.selu(self.fc_2(x)) 98 | penult = x = F.selu(self.fc_3(x)) 99 | x = self.fc_4(x) 100 | 101 | return x, penult, post_conv -------------------------------------------------------------------------------- /ContinueTraining.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | from collections import OrderedDict 4 | 5 | import random 6 | import numpy as np 7 | 8 | #from tqdm import tqdm_notebook as tqdm 9 | from tqdm import tqdm 10 | import multiprocessing 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.utils.checkpoint import checkpoint 16 | 17 | from torch.optim.lr_scheduler import StepLR 18 | 19 | import torch.optim as optim 20 | 21 | from torch.utils import data 22 | 23 | from torch.utils.data import Dataset, DataLoader, Subset 24 | 25 | from binaryLoader import BinaryDataset, RandomChunkSampler, pad_collate_func 26 | from sklearn.metrics import roc_auc_score 27 | 28 | import optuna 29 | 30 | import argparse 31 | 32 | #Check if the input is a valid directory 33 | def dir_path(string): 34 | if os.path.isdir(string): 35 | return string 36 | else: 37 | raise NotADirectoryError(string) 38 | 39 | def is_file(string): 40 | if os.path.isfile(string): 41 | return string 42 | else: 43 | raise NotADirectoryError(string) 44 | 45 | parser = argparse.ArgumentParser(description='Train a Model model') 46 | 47 | parser.add_argument('--epochs', type=int, default=300, help='How many training epochs to perform') 48 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size during training') 49 | #Default is set ot 16 MB! 50 | parser.add_argument('--max_len', type=int, default=16000000, help='Maximum length of input file in bytes, at which point files will be truncated') 51 | 52 | parser.add_argument('--save-every', type=int, default=25, help='Batch size during training') 53 | 54 | parser.add_argument('--gpus', nargs='+', type=int) 55 | 56 | parser.add_argument('--checkpoint', type=is_file, help='File to load and use') 57 | parser.add_argument('--log', default="long_train", type=str, help='Log file location') 58 | 59 | 60 | names_in_check_order= ["Avast", "MalConvML", "MalConvGCT", "MalConv"] 61 | 62 | parser.add_argument('--model', type=str, default=None, choices=names_in_check_order, help='Type of model to train') 63 | 64 | parser.add_argument('mal_train', type=dir_path, help='Path to directory containing malware files for training') 65 | parser.add_argument('ben_train', type=dir_path, help='Path to directory containing benign files for training') 66 | parser.add_argument('mal_test', type=dir_path, help='Path to directory containing malware files for testing') 67 | parser.add_argument('ben_test', type=dir_path, help='Path to directory containing benign files for testing') 68 | 69 | args = parser.parse_args() 70 | 71 | GPUS = args.gpus 72 | 73 | torch.backends.cudnn.enabled = False 74 | 75 | EPOCHS = args.epochs 76 | MAX_FILE_LEN = args.max_len 77 | 78 | BATCH_SIZE = args.batch_size 79 | 80 | if args.model is not None: 81 | MODEL_NAME = args.model 82 | else: 83 | #Noe model name type was specified. Can we infer it from the file path of the checkpoint? 84 | for option in names_in_check_order: 85 | if option in args.checkpoint: 86 | MODEL_NAME = option 87 | break 88 | 89 | #First we define our own random split, b/c we want to keep data shuffle order in 90 | #tact b/c it will make trainin faster. 
This is because we kept things orded by size, so batches 91 | #can be as small as possible. 92 | def random_split(dataset, lengths): 93 | """ 94 | Randomly split a dataset into non-overlapping new datasets of given lengths. 95 | 96 | Arguments: 97 | dataset (Dataset): Dataset to be split 98 | lengths (sequence): lengths of splits to be produced 99 | """ 100 | #if sum(lengths) != len(dataset): 101 | # raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 102 | 103 | indices = torch.randperm(sum(lengths)).tolist() 104 | to_ret = [] 105 | for offset, length in zip(torch._utils._accumulate(lengths), lengths): 106 | selected = indices[offset - length:offset] 107 | selected.sort() 108 | to_ret.append( Subset(dataset, selected) ) 109 | return to_ret 110 | 111 | 112 | 113 | if MODEL_NAME.lower() == "MalConv".lower(): 114 | from MalConv import getParams, initModel 115 | elif MODEL_NAME.lower() == "Avast".lower(): 116 | from AvastStyleConv import getParams, initModel 117 | elif MODEL_NAME.lower() == "MalConvML".lower(): 118 | from MalConvML import getParams, initModel 119 | elif MODEL_NAME.lower() == "MalConvGCT".lower(): 120 | from MalConvGCT import getParams, initModel 121 | print("CORRECT GCT") 122 | 123 | 124 | if GPUS is None:#use ALL of them! (Default) 125 | device_str = "cuda:0" 126 | else: 127 | if GPUS[0] < 0: 128 | device_str = "cpu" 129 | else: 130 | device_str = "cuda:{}".format(GPUS[0]) 131 | 132 | 133 | device = torch.device(device_str if torch.cuda.is_available() else "cpu") 134 | 135 | 136 | checkpoint = torch.load(args.checkpoint, map_location=device) 137 | print([key for key in checkpoint.keys()]) 138 | 139 | NON_NEG = checkpoint['non_neg'] 140 | 141 | #Create model of same type 142 | model = initModel(**checkpoint).to(device) 143 | #optimizer = optim.AdamW(model.parameters(), lr=checkpoint['lr']) 144 | optimizer = optim.AdamW(model.parameters()) 145 | 146 | #Restore weights and parameters 147 | model.load_state_dict(checkpoint['model_state_dict']) 148 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 149 | 150 | 151 | del checkpoint['model_state_dict'] 152 | del checkpoint['optimizer_state_dict'] 153 | args_to_use = checkpoint 154 | 155 | whole_dataset = BinaryDataset(args.ben_train, args.mal_train, sort_by_size=True, max_len=MAX_FILE_LEN ) 156 | test_dataset = BinaryDataset(args.ben_test, args.mal_test, sort_by_size=True, max_len=MAX_FILE_LEN ) 157 | 158 | #Sub sample for testing purposes, not use when you want to do real work 159 | #whole_dataset = random_split(whole_dataset, [1000])[0] 160 | #test_dataset = random_split(test_dataset, [1000])[0] 161 | 162 | loader_threads = max(multiprocessing.cpu_count()-4, multiprocessing.cpu_count()//2+1) 163 | 164 | train_loader = DataLoader(whole_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 165 | sampler=RandomChunkSampler(whole_dataset,BATCH_SIZE)) 166 | 167 | test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 168 | sampler=RandomChunkSampler(test_dataset,BATCH_SIZE)) 169 | 170 | headers = ['epoch', 'train_acc', 'train_auc', 'test_acc', 'test_auc'] 171 | 172 | base_name = args.log 173 | 174 | if not os.path.exists(base_name): 175 | os.makedirs(base_name) 176 | file_name = os.path.join(base_name, base_name) 177 | 178 | with open(base_name + ".csv", 'w') as csv_log_out: 179 | csv_log_out.write(",".join(headers) + "\n") 180 | 181 | criterion = nn.CrossEntropyLoss() 182 | 183 | 
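#StepLR decays the learning rate by gamma=0.5 every EPOCHS//10 epochs; with the default
#--epochs 300 this halves the learning rate every 30 epochs over the course of the run.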
scheduler = StepLR(optimizer, step_size=EPOCHS//10, gamma=0.5) 184 | 185 | for epoch in tqdm(range(EPOCHS)): 186 | 187 | preds = [] 188 | truths = [] 189 | running_loss = 0.0 190 | 191 | 192 | train_correct = 0 193 | train_total = 0 194 | 195 | epoch_stats = {'epoch':epoch} 196 | 197 | model.train() 198 | for inputs, labels in tqdm(train_loader): 199 | 200 | #inputs, labels = inputs.to(device), labels.to(device) 201 | #Keep inputs on CPU, the model will load chunks of input onto device as needed 202 | labels = labels.to(device) 203 | 204 | optimizer.zero_grad() 205 | 206 | # outputs, penultimate_activ, conv_active = model.forward_extra(inputs) 207 | outputs, penult, post_conv = model(inputs) 208 | loss = criterion(outputs, labels) 209 | loss = loss #+ decov_lambda*(decov_penalty(penultimate_activ) + decov_penalty(conv_active)) 210 | # loss = loss + decov_lambda*(decov_penalty(conv_active)) 211 | loss.backward() 212 | optimizer.step() 213 | if NON_NEG: 214 | for p in model.parameters(): 215 | p.data.clamp_(0) 216 | 217 | 218 | running_loss += loss.item() 219 | 220 | _, predicted = torch.max(outputs.data, 1) 221 | 222 | with torch.no_grad(): 223 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 224 | truths.extend(labels.detach().cpu().numpy().ravel()) 225 | 226 | train_total += labels.size(0) 227 | train_correct += (predicted == labels).sum().item() 228 | 229 | #end train loop 230 | 231 | #print("Training Accuracy: {}".format(train_correct*100.0/train_total)) 232 | 233 | epoch_stats['train_acc'] = train_correct*1.0/train_total 234 | epoch_stats['train_auc'] = roc_auc_score(truths, preds) 235 | #epoch_stats['train_loss'] = roc_auc_score(truths, preds) 236 | 237 | #Save the model and current state! 238 | model_path = os.path.join(base_name, "epoch_{}.checkpoint".format(epoch)) 239 | 240 | 241 | #Have to handle model state special if multi-gpu was used 242 | if type(model).__name__ is "DataParallel": 243 | mstd = model.module.state_dict() 244 | else: 245 | mstd = model.state_dict() 246 | 247 | #Copy dict, and add extra info to save off 248 | if epoch % args.save_every == 0 or epoch == EPOCHS-1: 249 | check_dict = args_to_use.copy() 250 | check_dict['epoch'] = epoch 251 | check_dict['model_state_dict'] = mstd 252 | check_dict['optimizer_state_dict'] = optimizer.state_dict() 253 | check_dict['non_neg'] = NON_NEG 254 | torch.save(check_dict, model_path) 255 | 256 | 257 | #Test Set Eval 258 | model.eval() 259 | eval_train_correct = 0 260 | eval_train_total = 0 261 | 262 | preds = [] 263 | truths = [] 264 | with torch.no_grad(): 265 | for inputs, labels in tqdm(test_loader): 266 | 267 | inputs, labels = inputs.to(device), labels.to(device) 268 | 269 | outputs, _, _ = model(inputs) 270 | 271 | _, predicted = torch.max(outputs.data, 1) 272 | 273 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 274 | truths.extend(labels.detach().cpu().numpy().ravel()) 275 | 276 | eval_train_total += labels.size(0) 277 | eval_train_correct += (predicted == labels).sum().item() 278 | 279 | epoch_stats['test_acc'] = eval_train_correct*1.0/eval_train_total 280 | epoch_stats['test_auc'] = roc_auc_score(truths, preds) 281 | 282 | csv_log_out.write(",".join([str(epoch_stats[h]) for h in headers]) + "\n") 283 | csv_log_out.flush() 284 | 285 | scheduler.step() 286 | -------------------------------------------------------------------------------- /LowMemConv.py: -------------------------------------------------------------------------------- 1 | from 
collections import deque 2 | 3 | import random 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.checkpoint import checkpoint 10 | 11 | def drop_zeros_hook(module, grad_input, grad_out): 12 | """ 13 | This function is used to replace gradients that are all zeros with None 14 | In pyTorch None will not get back-propogated 15 | So we use this as a approximation to saprse BP to avoid redundant and useless work 16 | """ 17 | grads = [] 18 | with torch.no_grad(): 19 | for g in grad_input: 20 | if torch.nonzero(g).shape[0] == 0:#ITS ALL EMPTY! 21 | grads.append(g.to_sparse()) 22 | else: 23 | grads.append(g) 24 | 25 | return tuple(grads) 26 | 27 | class CatMod(torch.nn.Module): 28 | def __init__(self): 29 | super(CatMod, self).__init__() 30 | 31 | def forward(self, x): 32 | return torch.cat(x, dim=2) 33 | 34 | 35 | 36 | class LowMemConvBase(nn.Module): 37 | 38 | def __init__(self, chunk_size=65536, overlap=512, min_chunk_size=1024): 39 | """ 40 | chunk_size: how many bytes at a time to process. Increasing may improve compute efficent, but use more memory. Total memory use will be a function of chunk_size, and not of the length of the input sequence L 41 | 42 | overlap: how many bytes of overlap to use between chunks 43 | 44 | """ 45 | super(LowMemConvBase, self).__init__() 46 | self.chunk_size = chunk_size 47 | self.overlap = overlap 48 | self.min_chunk_size = min_chunk_size 49 | 50 | #Used for pooling over time in a meory efficent way 51 | self.pooling = nn.AdaptiveMaxPool1d(1) 52 | # self.pooling.register_backward_hook(drop_zeros_hook) 53 | self.cat = CatMod() 54 | self.cat.register_backward_hook(drop_zeros_hook) 55 | self.receptive_field = None 56 | 57 | #Used to force checkpoint code to behave correctly due to poor design https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/11 58 | self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True) 59 | 60 | def processRange(self, x, **kwargs): 61 | """ 62 | This method does the work to convert an LongTensor input x of shape (B, L) , where B is the batch size and L is the length of the input. The output of this functoin should be a tensor of (B, C, L), where C is the number of channels, and L is again the input length (though its OK if it got a little shorter due to convs without padding or something). 63 | 64 | """ 65 | pass 66 | 67 | def determinRF(self): 68 | """ 69 | Lets determine the receptive field & stride of our sub-network 70 | """ 71 | 72 | if self.receptive_field is not None: 73 | return self.receptive_field, self.stride, self.out_channels 74 | #else, figure this out! 75 | 76 | if not hasattr(self, "device_ids"): 77 | #We are training with just one device. Lets find out where we should move the data 78 | cur_device = next(self.embd.parameters()).device 79 | else: 80 | cur_device = "cpu" 81 | 82 | #Lets do a simple binary search to figure out how large our RF is. 83 | #It can't be larger than our chunk size! 
So use that as upper bound 84 | min_rf = 1 85 | max_rf = self.chunk_size 86 | 87 | with torch.no_grad(): 88 | 89 | tmp = torch.zeros((1,max_rf)).long().to(cur_device) 90 | 91 | while True: 92 | test_size = (min_rf+max_rf)//2 93 | is_valid = True 94 | try: 95 | self.processRange(tmp[:,0:test_size]) 96 | except: 97 | is_valid = False 98 | 99 | if is_valid: 100 | max_rf = test_size 101 | else: 102 | min_rf = test_size+1 103 | 104 | #print(is_valid, test_size, min_rf, max_rf) 105 | 106 | if max_rf == min_rf: 107 | self.receptive_field = min_rf 108 | out_shape = self.processRange(tmp).shape 109 | self.stride = self.chunk_size//out_shape[2] 110 | self.out_channels = out_shape[1] 111 | break 112 | 113 | 114 | return self.receptive_field, self.stride, self.out_channels 115 | 116 | 117 | def pool_group(self, *args): 118 | #x = torch.cat(args[0:-1], dim=2) 119 | x = self.cat(args) 120 | x = self.pooling(x) 121 | return x 122 | 123 | def seq2fix(self, x, pr_args={}): 124 | """ 125 | Takes in an input LongTensor of (B, L) that will be converted to a fixed length representation (B, C), where C is the number of channels provided by the base_network given at construction. 126 | """ 127 | 128 | receptive_window, stride, out_channels = self.determinRF() 129 | 130 | if x.shape[1] < receptive_window: #This is a tiny input! pad it out please 131 | x = F.pad(x, (0, receptive_window-x.shape[1]), value=0)#0 is the pad value we use 132 | 133 | batch_size = x.shape[0] 134 | length = x.shape[1] 135 | 136 | 137 | 138 | #Lets go through the input data without gradients first, and find the positions that "win" 139 | #the max-pooling. Most of the gradients will be zero, and we don't want to waste valuable 140 | #memory and time computing them. 141 | #Once we know the winners, we will go back and compute the forward activations on JUST 142 | #the subset of positions that won! 143 | winner_values = np.zeros((batch_size, out_channels))-1.0 144 | winner_indices = np.zeros((batch_size, out_channels), dtype=np.int64) 145 | 146 | if not hasattr(self, "device_ids"): 147 | #We are training with just one device. Lets find out where we should move the data 148 | cur_device = next(self.embd.parameters()).device 149 | else: 150 | cur_device = None 151 | 152 | step = self.chunk_size #- self.overlap 153 | #step = length 154 | start = 0 155 | end = start+step 156 | 157 | 158 | #TODO, I'm being a little sloppy on picking exact range, and selecting more context than i need 159 | #Future, should figure out precisely which bytes won and only include that range 160 | 161 | #print("Starting Search") 162 | with torch.no_grad(): 163 | while start < end and (end-start) >= max(self.min_chunk_size, receptive_window): 164 | #print("Range {}:{}/{}".format(start,end,length)) 165 | x_sub = x[:,start:end] 166 | if cur_device is not None: 167 | x_sub = x_sub.to(cur_device) 168 | activs = self.processRange(x_sub.long(), **pr_args) 169 | activ_win, activ_indx = F.max_pool1d(activs, kernel_size=activs.shape[2], return_indices=True) 170 | #print(activ_win.shape) 171 | #Python for this code loop is WAY too slow! Numpy it! 172 | #for b in range(batch_size): 173 | # for c in range(out_channels): 174 | # if winner_values[b,c] < activ_win[b,c]: 175 | # winner_indices[b, c] = activ_indx[b, c]*stride + start + receptive_window//2 176 | # winner_values[b,c] = activ_win[b,c] 177 | #We want to remove only last dimension, but if batch size is 1, np.squeeze 178 | #will screw us up and remove first dime too. 
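#Instead we index the singleton pooling dimension directly ([:,:,0]) and do a vectorized
#winner update: activ_indx*stride + start converts each chunk-local pooling index back into
#an absolute byte offset within the full input.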
179 | #activ_win = np.squeeze(activ_win.cpu().numpy()) 180 | #activ_indx = np.squeeze(activ_indx.cpu().numpy()) 181 | activ_win = activ_win.cpu().numpy()[:,:,0] 182 | activ_indx = activ_indx.cpu().numpy()[:,:,0] 183 | selected = winner_values < activ_win 184 | winner_indices[selected] = activ_indx[selected]*stride + start 185 | winner_values[selected] = activ_win[selected] 186 | start = end 187 | end = min(start+step, length) 188 | 189 | #Now we know every index that won, we need to compute values and with gradients! 190 | 191 | #Find unique winners for every batch 192 | final_indices = [np.unique(winner_indices[b,:]) for b in range(batch_size)] 193 | 194 | #Collect inputs that won for each batch 195 | chunk_list = [[x[b:b+1,max(i-receptive_window,0):min(i+receptive_window,length)] for i in final_indices[b]] for b in range(batch_size)] 196 | #Convert to a torch tensor of the bytes 197 | chunk_list = [torch.cat(c, dim=1)[0,:] for c in chunk_list] 198 | 199 | #Padd out shorter sequences to the longest one 200 | x_selected = torch.nn.utils.rnn.pad_sequence(chunk_list, batch_first=True) 201 | 202 | #Shape is not (B, L) Lets compute! 203 | 204 | if cur_device is not None: 205 | x_selected = x_selected.to(cur_device) 206 | x_selected = self.processRange(x_selected.long(), **pr_args) 207 | x_selected = self.pooling(x_selected) 208 | x_selected = x_selected.view(x_selected.size(0), -1) 209 | 210 | return x_selected 211 | -------------------------------------------------------------------------------- /MalConv.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from collections import OrderedDict 3 | 4 | import random 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | 13 | from LowMemConv import LowMemConvBase 14 | 15 | 16 | def getParams(): 17 | #Format for this is to make it work easily with Optuna in an automated fashion. 
18 | #variable name -> tuple(sampling function, dict(sampling_args) ) 19 | params = { 20 | 'channels' : ("suggest_int", {'name':'channels', 'low':32, 'high':1024}), 21 | 'log_stride' : ("suggest_int", {'name':'log2_stride', 'low':2, 'high':9}), 22 | 'window_size' : ("suggest_int", {'name':'window_size', 'low':32, 'high':512}), 23 | 'embd_size' : ("suggest_int", {'name':'embd_size', 'low':4, 'high':64}), 24 | } 25 | return OrderedDict(sorted(params.items(), key=lambda t: t[0])) 26 | 27 | def initModel(**kwargs): 28 | new_args = {} 29 | for x in getParams(): 30 | if x in kwargs: 31 | new_args[x] = kwargs[x] 32 | 33 | return MalConv(**new_args) 34 | 35 | 36 | class MalConv(LowMemConvBase): 37 | 38 | def __init__(self, out_size=2, channels=128, window_size=512, stride=512, embd_size=8, log_stride=None): 39 | super(MalConv, self).__init__() 40 | self.embd = nn.Embedding(257, embd_size, padding_idx=0) 41 | if not log_stride is None: 42 | stride = 2**log_stride 43 | 44 | self.conv_1 = nn.Conv1d(embd_size, channels, window_size, stride=stride, bias=True) 45 | self.conv_2 = nn.Conv1d(embd_size, channels, window_size, stride=stride, bias=True) 46 | 47 | 48 | self.fc_1 = nn.Linear(channels, channels) 49 | self.fc_2 = nn.Linear(channels, out_size) 50 | 51 | 52 | def processRange(self, x): 53 | x = self.embd(x) 54 | x = torch.transpose(x,-1,-2) 55 | 56 | cnn_value = self.conv_1(x) 57 | gating_weight = torch.sigmoid(self.conv_2(x)) 58 | 59 | x = cnn_value * gating_weight 60 | 61 | return x 62 | 63 | def forward(self, x): 64 | post_conv = x = self.seq2fix(x) 65 | 66 | penult = x = F.relu(self.fc_1(x)) 67 | x = self.fc_2(x) 68 | 69 | return x, penult, post_conv -------------------------------------------------------------------------------- /MalConvGCT_nocat.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from collections import OrderedDict 3 | 4 | import random 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | #from torch.utils.checkpoint import checkpoint 11 | import checkpoint #This checkpoint implementation is faster than PyTorch's when using multiple GPUs 12 | 13 | 14 | from LowMemConv import LowMemConvBase 15 | from MalConvML import MalConvML 16 | 17 | def getParams(): 18 | #Format for this is to make it work easily with Optuna in an automated fashion. 
19 | #variable name -> tuple(sampling function, dict(sampling_args) ) 20 | params = { 21 | 'channels' : ("suggest_int", {'name':'channels', 'low':32, 'high':1024}), 22 | 'log_stride' : ("suggest_int", {'name':'log2_stride', 'low':2, 'high':9}), 23 | 'window_size' : ("suggest_int", {'name':'window_size', 'low':32, 'high':256}), 24 | 'layers' : ("suggest_int", {'name':'layers', 'low':1, 'high':3}), 25 | 'embd_size' : ("suggest_int", {'name':'embd_size', 'low':4, 'high':16}), 26 | } 27 | return OrderedDict(sorted(params.items(), key=lambda t: t[0])) 28 | 29 | def initModel(**kwargs): 30 | new_args = {} 31 | for x in getParams(): 32 | if x in kwargs: 33 | new_args[x] = kwargs[x] 34 | 35 | return MalConvGCT(**new_args) 36 | 37 | 38 | class MalConvGCT(LowMemConvBase): 39 | 40 | def __init__(self, out_size=2, channels=128, window_size=512, stride=512, layers=1, embd_size=8, log_stride=None, low_mem=True): 41 | super(MalConvGCT, self).__init__() 42 | self.low_mem = low_mem 43 | self.embd = nn.Embedding(257, embd_size, padding_idx=0) 44 | if not log_stride is None: 45 | stride = 2**log_stride 46 | 47 | self.context_net = MalConvML(out_size=channels, channels=channels, window_size=window_size, stride=stride, layers=layers, embd_size=embd_size) 48 | self.convs = nn.ModuleList([nn.Conv1d(embd_size, channels*2, window_size, stride=stride, bias=True)] + [nn.Conv1d(channels, channels*2, window_size, stride=1, bias=True) for i in range(layers-1)]) 49 | 50 | #These two objs are not used. They were originally present before the F.glu function existed, and then were accidently left in when we switched over. So the state file provided has unusued states in it. They are left in this definition so that there are no issues loading the file that MalConv was trained on. 51 | #If you are going to train from scratch, you can delete these two lines. 52 | #self.convs_1 = nn.ModuleList([nn.Conv1d(channels*2, channels, 1, bias=True) for i in range(layers)]) 53 | #self.convs_atn = nn.ModuleList([nn.Conv1d(channels*2, channels, 1, bias=True) for i in range(layers)]) 54 | 55 | self.linear_atn = nn.ModuleList([nn.Linear(channels, channels) for i in range(layers)]) 56 | 57 | #one-by-one cons to perform information sharing 58 | self.convs_share = nn.ModuleList([nn.Conv1d(channels, channels, 1, bias=True) for i in range(layers)]) 59 | 60 | 61 | self.fc_1 = nn.Linear(channels, channels) 62 | self.fc_2 = nn.Linear(channels, out_size) 63 | 64 | 65 | #Over-write the determinRF call to use the base context_net to detemrin RF. We should have the same totla RF, and this will simplify logic significantly. 
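#context_net is built with the same window_size and stride as the gated convolutions, so its
#receptive field matches, and probing just the context network is enough for chunk sizing.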
66 | def determinRF(self): 67 | return self.context_net.determinRF() 68 | 69 | def processRange(self, x, gct=None): 70 | if gct is None: 71 | raise Exception("No Global Context Given") 72 | 73 | x = self.embd(x) 74 | #x = torch.transpose(x,-1,-2) 75 | x = x.permute(0,2,1) 76 | 77 | for conv_glu, linear_cntx, conv_share in zip(self.convs, self.linear_atn, self.convs_share): 78 | x = F.glu(conv_glu(x), dim=1) 79 | x = F.leaky_relu(conv_share(x)) 80 | x_len = x.shape[2] 81 | B = x.shape[0] 82 | C = x.shape[1] 83 | 84 | sqrt_dim = np.sqrt(x.shape[1]) 85 | #we are going to need a version of GCT with a time dimension, which we will adapt as needed to the right length 86 | ctnx = torch.tanh(linear_cntx(gct)) 87 | 88 | #Size is (B, C), but we need (B, C, 1) to use as a 1d conv filter 89 | ctnx = torch.unsqueeze(ctnx, dim=2) 90 | #roll the batches into the channels 91 | x_tmp = x.view(1,B*C,-1) 92 | #Now we can apply a conv with B groups, so that each batch gets its own context applied only to what was needed 93 | x_tmp = F.conv1d(x_tmp, ctnx, groups=B) 94 | #x_tmp will have a shape of (1, B, L), now we just need to re-order the data back to (B, 1, L) 95 | x_gates = x_tmp.view(B, 1, -1) 96 | 97 | #Now we effectively apply σ(x_t^T tanh(W c)) 98 | gates = torch.sigmoid( x_gates ) 99 | x = x * gates 100 | 101 | return x 102 | 103 | def forward(self, x): 104 | 105 | if self.low_mem: 106 | global_context = checkpoint.CheckpointFunction.apply(self.context_net.seq2fix,1, x) 107 | else: 108 | global_context = self.context_net.seq2fix(x) 109 | 110 | post_conv = x = self.seq2fix(x, pr_args={'gct':global_context}) 111 | 112 | penult = x = F.leaky_relu(self.fc_1( x )) 113 | x = self.fc_2(x) 114 | 115 | return x, penult, post_conv 116 | -------------------------------------------------------------------------------- /MalConvGCT_nocatTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | 4 | import random 5 | import numpy as np 6 | 7 | #from tqdm import tqdm_notebook as tqdm 8 | from tqdm import tqdm 9 | import multiprocessing 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | import torch.optim as optim 17 | 18 | from torch.utils import data 19 | 20 | from torch.utils.data import Dataset, DataLoader, Subset 21 | 22 | #from MalConv import MalConv 23 | from MalConvGCT_nocat import MalConvGCT 24 | 25 | from binaryLoader import BinaryDataset, RandomChunkSampler, pad_collate_func 26 | from sklearn.metrics import roc_auc_score 27 | 28 | import argparse 29 | 30 | #Check if the input is a valid directory 31 | def dir_path(string): 32 | if os.path.isdir(string): 33 | return string 34 | else: 35 | raise NotADirectoryError(string) 36 | 37 | parser = argparse.ArgumentParser(description='Train a MalConv model') 38 | 39 | parser.add_argument('--filter_size', type=int, default=256, help='How wide should the filter be') 40 | parser.add_argument('--filter_stride', type=int, default=64, help='Filter Stride') 41 | parser.add_argument('--embd_size', type=int, default=8, help='Size of embedding layer') 42 | parser.add_argument('--num_channels', type=int, default=128, help='Total number of channels in output') 43 | parser.add_argument('--epochs', type=int, default=30, help='How many training epochs to perform') 44 | parser.add_argument('--non-neg', type=bool, default=False, help='Should non-negative training be used') 45 | 
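#Example invocation (directory paths below are placeholders):
#  python MalConvGCT_nocatTrain.py --filter_size 256 --filter_stride 64 --num_channels 128 \
#      --epochs 30 --gpus 0 mal_train/ ben_train/ mal_test/ ben_test/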
parser.add_argument('--batch_size', type=int, default=128, help='Batch size during training') 46 | #Default is set ot 16 MB! 47 | parser.add_argument('--max_len', type=int, default=16000000, help='Maximum length of input file in bytes, at which point files will be truncated') 48 | 49 | parser.add_argument('--gpus', nargs='+', type=int) 50 | 51 | 52 | parser.add_argument('mal_train', type=dir_path, help='Path to directory containing malware files for training') 53 | parser.add_argument('ben_train', type=dir_path, help='Path to directory containing benign files for training') 54 | parser.add_argument('mal_test', type=dir_path, help='Path to directory containing malware files for testing') 55 | parser.add_argument('ben_test', type=dir_path, help='Path to directory containing benign files for testing') 56 | 57 | args = parser.parse_args() 58 | 59 | GPUS = args.gpus 60 | 61 | NON_NEG = args.non_neg 62 | EMBD_SIZE = args.embd_size 63 | FILTER_SIZE = args.filter_size 64 | FILTER_STRIDE = args.filter_stride 65 | NUM_CHANNELS= args.num_channels 66 | EPOCHS = args.epochs 67 | MAX_FILE_LEN = args.max_len 68 | 69 | BATCH_SIZE = args.batch_size 70 | 71 | 72 | 73 | whole_dataset = BinaryDataset(args.ben_train, args.mal_train, sort_by_size=True, max_len=MAX_FILE_LEN ) 74 | test_dataset = BinaryDataset(args.ben_test, args.mal_test, sort_by_size=True, max_len=MAX_FILE_LEN ) 75 | 76 | loader_threads = max(multiprocessing.cpu_count()-4, multiprocessing.cpu_count()//2+1) 77 | 78 | train_loader = DataLoader(whole_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 79 | sampler=RandomChunkSampler(whole_dataset,BATCH_SIZE)) 80 | 81 | test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 82 | sampler=RandomChunkSampler(test_dataset,BATCH_SIZE)) 83 | 84 | if GPUS is None:#use ALL of them! 
(Default) 85 | device_str = "cuda:0" 86 | else: 87 | if GPUS[0] < 0: 88 | device_str = "cpu" 89 | else: 90 | device_str = "cuda:{}".format(GPUS[0]) 91 | 92 | 93 | device = torch.device(device_str if torch.cuda.is_available() else "cpu") 94 | 95 | model = MalConvGCT(channels=NUM_CHANNELS, window_size=FILTER_SIZE, stride=FILTER_STRIDE, embd_size=EMBD_SIZE, low_mem=False).to(device) 96 | 97 | base_name = "nocat_{}_channels_{}_filterSize_{}_stride_{}_embdSize_{}".format( 98 | type(model).__name__, 99 | NUM_CHANNELS, 100 | FILTER_SIZE, 101 | FILTER_STRIDE, 102 | EMBD_SIZE, 103 | ) 104 | 105 | if NON_NEG: 106 | base_name = "NonNeg_" + base_name 107 | 108 | if GPUS is None or len(GPUS) > 1: 109 | model = nn.DataParallel(model, device_ids=GPUS) 110 | 111 | if not os.path.exists(base_name): 112 | os.makedirs(base_name) 113 | file_name = os.path.join(base_name, base_name) 114 | 115 | 116 | headers = ['epoch', 'train_acc', 'train_auc', 'test_acc', 'test_auc'] 117 | 118 | csv_log_out = open(file_name + ".csv", 'w') 119 | csv_log_out.write(",".join(headers) + "\n") 120 | 121 | criterion = nn.CrossEntropyLoss() 122 | #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 123 | optimizer = optim.Adam(model.parameters()) 124 | 125 | for epoch in tqdm(range(EPOCHS)): 126 | 127 | preds = [] 128 | truths = [] 129 | running_loss = 0.0 130 | 131 | 132 | train_correct = 0 133 | train_total = 0 134 | 135 | epoch_stats = {'epoch':epoch} 136 | 137 | model.train() 138 | for inputs, labels in tqdm(train_loader): 139 | 140 | #inputs, labels = inputs.to(device), labels.to(device) 141 | #Keep inputs on CPU, the model will load chunks of input onto device as needed 142 | labels = labels.to(device) 143 | 144 | optimizer.zero_grad() 145 | 146 | # outputs, penultimate_activ, conv_active = model.forward_extra(inputs) 147 | outputs, penultimate_activ, conv_active = model(inputs) 148 | loss = criterion(outputs, labels) 149 | loss = loss #+ decov_lambda*(decov_penalty(penultimate_activ) + decov_penalty(conv_active)) 150 | # loss = loss + decov_lambda*(decov_penalty(conv_active)) 151 | loss.backward() 152 | optimizer.step() 153 | if NON_NEG: 154 | for p in model.parameters(): 155 | p.data.clamp_(0) 156 | 157 | 158 | running_loss += loss.item() 159 | 160 | _, predicted = torch.max(outputs.data, 1) 161 | 162 | with torch.no_grad(): 163 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 164 | truths.extend(labels.detach().cpu().numpy().ravel()) 165 | 166 | train_total += labels.size(0) 167 | train_correct += (predicted == labels).sum().item() 168 | 169 | 170 | 171 | #print("Training Accuracy: {}".format(train_correct*100.0/train_total)) 172 | 173 | epoch_stats['train_acc'] = train_correct*1.0/train_total 174 | epoch_stats['train_auc'] = roc_auc_score(truths, preds) 175 | #epoch_stats['train_loss'] = roc_auc_score(truths, preds) 176 | 177 | #Save the model and current state! 
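#When nn.DataParallel was used we save model.module's state dict, so the checkpoint can later
#be loaded into a single-device model without the "module." prefix on every parameter key.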
178 | model_path = os.path.join(base_name, "epoch_{}.checkpoint".format(epoch)) 179 | 180 | 181 | #Have to handle model state special if multi-gpu was used 182 | if type(model).__name__ is "DataParallel": 183 | mstd = model.module.state_dict() 184 | else: 185 | mstd = model.state_dict() 186 | 187 | torch.save({ 188 | 'epoch': epoch, 189 | 'model_state_dict': mstd, 190 | 'optimizer_state_dict': optimizer.state_dict(), 191 | 'channels': NUM_CHANNELS, 192 | 'filter_size': FILTER_SIZE, 193 | 'stride': FILTER_STRIDE, 194 | 'embd_dim': EMBD_SIZE, 195 | 'non_neg': NON_NEG, 196 | }, model_path) 197 | 198 | 199 | #Test Set Eval 200 | model.eval() 201 | eval_train_correct = 0 202 | eval_train_total = 0 203 | 204 | preds = [] 205 | truths = [] 206 | with torch.no_grad(): 207 | for inputs, labels in tqdm(test_loader): 208 | 209 | inputs, labels = inputs.to(device), labels.to(device) 210 | 211 | outputs, penultimate_activ, conv_active = model(inputs) 212 | 213 | _, predicted = torch.max(outputs.data, 1) 214 | 215 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 216 | truths.extend(labels.detach().cpu().numpy().ravel()) 217 | 218 | eval_train_total += labels.size(0) 219 | eval_train_correct += (predicted == labels).sum().item() 220 | 221 | epoch_stats['test_acc'] = eval_train_correct*1.0/eval_train_total 222 | epoch_stats['test_auc'] = roc_auc_score(truths, preds) 223 | 224 | csv_log_out.write(",".join([str(epoch_stats[h]) for h in headers]) + "\n") 225 | csv_log_out.flush() 226 | 227 | 228 | 229 | csv_log_out.close() 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /MalConvML.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from collections import OrderedDict 3 | 4 | import random 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | 13 | from LowMemConv import LowMemConvBase 14 | 15 | 16 | def getParams(): 17 | #Format for this is to make it work easily with Optuna in an automated fashion. 
18 | #variable name -> tuple(sampling function, dict(sampling_args) ) 19 | params = { 20 | 'channels' : ("suggest_int", {'name':'channels', 'low':32, 'high':1024}), 21 | 'log_stride' : ("suggest_int", {'name':'log2_stride', 'low':2, 'high':9}), 22 | 'window_size' : ("suggest_int", {'name':'window_size', 'low':32, 'high':256}), 23 | 'layers' : ("suggest_int", {'name':'layers', 'low':1, 'high':6}), 24 | 'embd_size' : ("suggest_int", {'name':'embd_size', 'low':4, 'high':64}), 25 | } 26 | return OrderedDict(sorted(params.items(), key=lambda t: t[0])) 27 | 28 | def initModel(**kwargs): 29 | new_args = {} 30 | for x in getParams(): 31 | if x in kwargs: 32 | new_args[x] = kwargs[x] 33 | 34 | return MalConvML(**new_args) 35 | 36 | 37 | class MalConvML(LowMemConvBase): 38 | 39 | def __init__(self, out_size=2, channels=128, window_size=512, stride=512, layers=1, embd_size=8, log_stride=None): 40 | super(MalConvML, self).__init__() 41 | self.embd = nn.Embedding(257, embd_size, padding_idx=0) 42 | if not log_stride is None: 43 | stride = 2**log_stride 44 | 45 | self.convs = nn.ModuleList([nn.Conv1d(embd_size, channels*2, window_size, stride=stride, bias=True)] + [nn.Conv1d(channels, channels*2, window_size, stride=1, bias=True) for i in range(layers-1)]) 46 | #one-by-one cons to perform information sharing 47 | self.convs_1 = nn.ModuleList([nn.Conv1d(channels, channels, 1, bias=True) for i in range(layers)]) 48 | 49 | 50 | self.fc_1 = nn.Linear(channels, channels) 51 | self.fc_2 = nn.Linear(channels, out_size) 52 | 53 | 54 | def processRange(self, x): 55 | x = self.embd(x) 56 | #x = torch.transpose(x,-1,-2) 57 | x = x.permute(0,2,1).contiguous() 58 | 59 | for conv_glu, conv_share in zip(self.convs, self.convs_1): 60 | x = F.leaky_relu(conv_share(F.glu(conv_glu(x.contiguous()), dim=1))) 61 | 62 | return x 63 | 64 | def forward(self, x): 65 | post_conv = x = self.seq2fix(x) 66 | 67 | penult = x = F.relu(self.fc_1(x)) 68 | x = self.fc_2(x) 69 | 70 | return x, penult, post_conv 71 | -------------------------------------------------------------------------------- /MalConvTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | 4 | import random 5 | import numpy as np 6 | 7 | #from tqdm import tqdm_notebook as tqdm 8 | from tqdm import tqdm 9 | import multiprocessing 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | import torch.optim as optim 17 | 18 | from torch.utils import data 19 | 20 | from torch.utils.data import Dataset, DataLoader, Subset 21 | 22 | from MalConv import MalConv 23 | #from MalConv2A import MalConv2A 24 | 25 | from binaryLoader import BinaryDataset, RandomChunkSampler, pad_collate_func 26 | from sklearn.metrics import roc_auc_score 27 | 28 | import argparse 29 | 30 | #Check if the input is a valid directory 31 | def dir_path(string): 32 | if os.path.isdir(string): 33 | return string 34 | else: 35 | raise NotADirectoryError(string) 36 | 37 | parser = argparse.ArgumentParser(description='Train a MalConv model') 38 | 39 | parser.add_argument('--filter_size', type=int, default=512, help='How wide should the filter be') 40 | parser.add_argument('--filter_stride', type=int, default=512, help='Filter Stride') 41 | parser.add_argument('--embd_size', type=int, default=8, help='Size of embedding layer') 42 | parser.add_argument('--num_channels', type=int, default=128, help='Total number of channels in output') 43 | 
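#Example invocation (directory paths below are placeholders):
#  python MalConvTrain.py --filter_size 512 --filter_stride 512 --num_channels 128 \
#      --epochs 10 mal_train/ ben_train/ mal_test/ ben_test/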
parser.add_argument('--epochs', type=int, default=10, help='How many training epochs to perform') 44 | parser.add_argument('--non-neg', type=bool, default=False, help='Should non-negative training be used') 45 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size during training') 46 | #Default is set ot 16 MB! 47 | parser.add_argument('--max_len', type=int, default=16000000, help='Maximum length of input file in bytes, at which point files will be truncated') 48 | 49 | parser.add_argument('--gpus', nargs='+', type=int) 50 | 51 | 52 | parser.add_argument('mal_train', type=dir_path, help='Path to directory containing malware files for training') 53 | parser.add_argument('ben_train', type=dir_path, help='Path to directory containing benign files for training') 54 | parser.add_argument('mal_test', type=dir_path, help='Path to directory containing malware files for testing') 55 | parser.add_argument('ben_test', type=dir_path, help='Path to directory containing benign files for testing') 56 | 57 | args = parser.parse_args() 58 | 59 | #GPUS = args.gpus 60 | GPUS = None 61 | 62 | NON_NEG = args.non_neg 63 | EMBD_SIZE = args.embd_size 64 | FILTER_SIZE = args.filter_size 65 | FILTER_STRIDE = args.filter_stride 66 | NUM_CHANNELS= args.num_channels 67 | EPOCHS = args.epochs 68 | MAX_FILE_LEN = args.max_len 69 | 70 | BATCH_SIZE = args.batch_size 71 | 72 | 73 | 74 | whole_dataset = BinaryDataset(args.ben_train, args.mal_train, sort_by_size=True, max_len=MAX_FILE_LEN ) 75 | test_dataset = BinaryDataset(args.ben_test, args.mal_test, sort_by_size=True, max_len=MAX_FILE_LEN ) 76 | 77 | loader_threads = max(multiprocessing.cpu_count()-4, multiprocessing.cpu_count()//2+1) 78 | 79 | train_loader = DataLoader(whole_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 80 | sampler=RandomChunkSampler(whole_dataset,BATCH_SIZE)) 81 | 82 | test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 83 | sampler=RandomChunkSampler(test_dataset,BATCH_SIZE)) 84 | 85 | if GPUS is None:#use ALL of them! 
(Default) 86 | device_str = "cuda:0" 87 | else: 88 | if GPUS[0] < 0: 89 | device_str = "cpu" 90 | else: 91 | device_str = "cuda:{}".format(GPUS[0]) 92 | 93 | 94 | device = torch.device(device_str if torch.cuda.is_available() else "cpu") 95 | 96 | print("Using device ", device) 97 | model = MalConv(channels=NUM_CHANNELS, window_size=FILTER_SIZE, stride=FILTER_STRIDE, embd_size=EMBD_SIZE).to(device) 98 | 99 | base_name = "{}_channels_{}_filterSize_{}_stride_{}_embdSize_{}".format( 100 | type(model).__name__, 101 | NUM_CHANNELS, 102 | FILTER_SIZE, 103 | FILTER_STRIDE, 104 | EMBD_SIZE, 105 | ) 106 | 107 | if NON_NEG: 108 | base_name = "NonNeg_" + base_name 109 | 110 | if GPUS is None or len(GPUS) > 1: 111 | model = nn.DataParallel(model, device_ids=GPUS) 112 | 113 | if not os.path.exists(base_name): 114 | os.makedirs(base_name) 115 | file_name = os.path.join(base_name, base_name) 116 | 117 | 118 | headers = ['epoch', 'train_acc', 'train_auc', 'test_acc', 'test_auc'] 119 | 120 | csv_log_out = open(file_name + ".csv", 'w') 121 | csv_log_out.write(",".join(headers) + "\n") 122 | 123 | criterion = nn.CrossEntropyLoss() 124 | #optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 125 | optimizer = optim.Adam(model.parameters()) 126 | 127 | for epoch in tqdm(range(EPOCHS)): 128 | 129 | preds = [] 130 | truths = [] 131 | running_loss = 0.0 132 | 133 | 134 | train_correct = 0 135 | train_total = 0 136 | 137 | epoch_stats = {'epoch':epoch} 138 | 139 | model.train() 140 | for inputs, labels in tqdm(train_loader): 141 | 142 | #inputs, labels = inputs.to(device), labels.to(device) 143 | #Keep inputs on CPU, the model will load chunks of input onto device as needed 144 | labels = labels.to(device) 145 | 146 | optimizer.zero_grad() 147 | 148 | # outputs, penultimate_activ, conv_active = model.forward_extra(inputs) 149 | outputs, _, _ = model(inputs) 150 | loss = criterion(outputs, labels) 151 | loss = loss #+ decov_lambda*(decov_penalty(penultimate_activ) + decov_penalty(conv_active)) 152 | # loss = loss + decov_lambda*(decov_penalty(conv_active)) 153 | loss.backward() 154 | optimizer.step() 155 | if NON_NEG: 156 | for p in model.parameters(): 157 | p.data.clamp_(0) 158 | 159 | 160 | running_loss += loss.item() 161 | 162 | _, predicted = torch.max(outputs.data, 1) 163 | 164 | with torch.no_grad(): 165 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 166 | truths.extend(labels.detach().cpu().numpy().ravel()) 167 | 168 | train_total += labels.size(0) 169 | train_correct += (predicted == labels).sum().item() 170 | 171 | 172 | #print("Training Accuracy: {}".format(train_correct*100.0/train_total)) 173 | epoch_stats['train_acc'] = train_correct*1.0/train_total 174 | epoch_stats['train_auc'] = roc_auc_score(truths, preds) 175 | #epoch_stats['train_loss'] = roc_auc_score(truths, preds) 176 | 177 | #Save the model and current state! 
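#The saved dict bundles the hyper-parameters (channels, filter_size, stride, embd_dim) alongside
#the model and optimizer state, so each epoch's checkpoint is self-describing.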
178 | model_path = os.path.join(base_name, "epoch_{}.checkpoint".format(epoch)) 179 | 180 | 181 | #Have to handle model state special if multi-gpu was used 182 | if type(model).__name__ is "DataParallel": 183 | mstd = model.module.state_dict() 184 | else: 185 | mstd = model.state_dict() 186 | 187 | torch.save({ 188 | 'epoch': epoch, 189 | 'model_state_dict': mstd, 190 | 'optimizer_state_dict': optimizer.state_dict(), 191 | 'channels': NUM_CHANNELS, 192 | 'filter_size': FILTER_SIZE, 193 | 'stride': FILTER_STRIDE, 194 | 'embd_dim': EMBD_SIZE, 195 | 'non_neg': NON_NEG, 196 | }, model_path) 197 | 198 | 199 | #Test Set Eval 200 | model.eval() 201 | eval_train_correct = 0 202 | eval_train_total = 0 203 | 204 | preds = [] 205 | truths = [] 206 | with torch.no_grad(): 207 | for inputs, labels in tqdm(test_loader): 208 | 209 | inputs, labels = inputs.to(device), labels.to(device) 210 | 211 | outputs, _, _ = model(inputs) 212 | 213 | _, predicted = torch.max(outputs.data, 1) 214 | 215 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 216 | truths.extend(labels.detach().cpu().numpy().ravel()) 217 | 218 | eval_train_total += labels.size(0) 219 | eval_train_correct += (predicted == labels).sum().item() 220 | 221 | epoch_stats['test_acc'] = eval_train_correct*1.0/eval_train_total 222 | epoch_stats['test_auc'] = roc_auc_score(truths, preds) 223 | 224 | csv_log_out.write(",".join([str(epoch_stats[h]) for h in headers]) + "\n") 225 | csv_log_out.flush() 226 | 227 | 228 | 229 | csv_log_out.close() 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /OptunaTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import deque 3 | from collections import OrderedDict 4 | 5 | import random 6 | import numpy as np 7 | 8 | #from tqdm import tqdm_notebook as tqdm 9 | from tqdm import tqdm 10 | import multiprocessing 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.utils.checkpoint import checkpoint 16 | 17 | import torch.optim as optim 18 | 19 | from torch.utils import data 20 | 21 | from torch.utils.data import Dataset, DataLoader, Subset 22 | 23 | from binaryLoader import BinaryDataset, RandomChunkSampler, pad_collate_func 24 | from sklearn.metrics import roc_auc_score 25 | 26 | import optuna 27 | 28 | import argparse 29 | 30 | #Check if the input is a valid directory 31 | def dir_path(string): 32 | if os.path.isdir(string): 33 | return string 34 | else: 35 | raise NotADirectoryError(string) 36 | 37 | parser = argparse.ArgumentParser(description='Train a Model model') 38 | 39 | parser.add_argument('--epochs', type=int, default=20, help='How many training epochs to perform') 40 | parser.add_argument('--non-neg', type=bool, default=False, help='Should non-negative training be used') 41 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size during training') 42 | parser.add_argument('--trials', type=int, default=50, help='Number of hyper-parameter tuning trials to perform') 43 | #Default is set ot 16 MB! 
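#(16,000,000 bytes; BinaryDataset truncates any file longer than --max_len when loading it)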
44 | parser.add_argument('--max_len', type=int, default=16000000, help='Maximum length of input file in bytes, at which point files will be truncated') 45 | parser.add_argument('--val-split', type=float, default=0.1, help='Batch size during training') 46 | 47 | parser.add_argument('--model', type=str, default="MalConv", choices=["MalConv", "Avast", "MalConvML", "MalConvGCT"], help='Type of model to train') 48 | 49 | parser.add_argument('--gpus', nargs='+', type=int) 50 | 51 | 52 | parser.add_argument('mal_train', type=dir_path, help='Path to directory containing malware files for training') 53 | parser.add_argument('ben_train', type=dir_path, help='Path to directory containing benign files for training') 54 | parser.add_argument('mal_test', type=dir_path, help='Path to directory containing malware files for testing') 55 | parser.add_argument('ben_test', type=dir_path, help='Path to directory containing benign files for testing') 56 | 57 | args = parser.parse_args() 58 | 59 | GPUS = args.gpus 60 | 61 | NON_NEG = args.non_neg 62 | EPOCHS = args.epochs 63 | MAX_FILE_LEN = args.max_len 64 | MODEL_NAME = args.model 65 | TRIALS = args.trials 66 | 67 | BATCH_SIZE = args.batch_size 68 | 69 | if MODEL_NAME.lower() == "MalConv".lower(): 70 | from MalConv import getParams, initModel 71 | elif MODEL_NAME.lower() == "Avast".lower(): 72 | from AvastStyleConv import getParams, initModel 73 | elif MODEL_NAME.lower() == "MalConvML".lower(): 74 | from MalConvML import getParams, initModel 75 | elif MODEL_NAME.lower() == "MalConvGCT".lower(): 76 | from MalConvGCT import getParams, initModel 77 | 78 | 79 | 80 | whole_dataset = BinaryDataset(args.ben_train, args.mal_train, sort_by_size=True, max_len=MAX_FILE_LEN ) 81 | test_dataset = BinaryDataset(args.ben_test, args.mal_test, sort_by_size=True, max_len=MAX_FILE_LEN ) 82 | 83 | loader_threads = max(multiprocessing.cpu_count()-4, multiprocessing.cpu_count()//2+1) 84 | 85 | #Create train & validation split 86 | 87 | #First we define our own random split, b/c we want to keep data shuffle order in 88 | #tact b/c it will make trainin faster. This is because we kept things orded by size, so batches 89 | #can be as small as possible. 90 | def random_split(dataset, lengths): 91 | """ 92 | Randomly split a dataset into non-overlapping new datasets of given lengths. 
93 | 94 | Arguments: 95 | dataset (Dataset): Dataset to be split 96 | lengths (sequence): lengths of splits to be produced 97 | """ 98 | #if sum(lengths) != len(dataset): 99 | # raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 100 | 101 | indices = torch.randperm(sum(lengths)).tolist() 102 | to_ret = [] 103 | for offset, length in zip(torch._utils._accumulate(lengths), lengths): 104 | selected = indices[offset - length:offset] 105 | selected.sort() 106 | to_ret.append( Subset(dataset, selected) ) 107 | return to_ret 108 | 109 | 110 | whole_len = len(whole_dataset) 111 | train_size = int(whole_len*(1-args.val_split)) 112 | validation_size = whole_len-train_size 113 | dataset_train_split, dataset_val_split = random_split(whole_dataset, (train_size, validation_size)) 114 | 115 | train_loader = DataLoader(dataset_train_split, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 116 | sampler=RandomChunkSampler(dataset_train_split,BATCH_SIZE)) 117 | 118 | val_loader = DataLoader(dataset_val_split, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 119 | sampler=RandomChunkSampler(dataset_val_split,BATCH_SIZE)) 120 | 121 | # train_loader = DataLoader(whole_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 122 | # sampler=RandomChunkSampler(whole_dataset,BATCH_SIZE)) 123 | 124 | test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=loader_threads, collate_fn=pad_collate_func, 125 | sampler=RandomChunkSampler(test_dataset,BATCH_SIZE)) 126 | 127 | if GPUS is None:#use ALL of them! (Default) 128 | device_str = "cuda:0" 129 | else: 130 | if GPUS[0] < 0: 131 | device_str = "cpu" 132 | else: 133 | device_str = "cuda:{}".format(GPUS[0]) 134 | 135 | 136 | device = torch.device(device_str if torch.cuda.is_available() else "cpu") 137 | 138 | def objective(trial): 139 | 140 | args_to_use = { 141 | 'lr':0.001, 142 | } 143 | 144 | if not trial is None: 145 | for param, (sample_func, sample_args) in getParams().items(): 146 | args_to_use[param] = getattr(trial, sample_func)(**sample_args) 147 | args_to_use['lr'] = trial.suggest_loguniform('lr', low=1e-4, high=1e-2) 148 | 149 | args_to_use = OrderedDict(sorted(args_to_use.items(), key=lambda t: t[0])) 150 | 151 | model = initModel(**args_to_use).to(device) 152 | 153 | 154 | base_name = MODEL_NAME + "_".join([a + "_" + str(b) for (a, b) in args_to_use.items()]) 155 | 156 | if NON_NEG: 157 | base_name = "NonNeg_" + base_name 158 | 159 | if GPUS is None or len(GPUS) > 1: 160 | model = nn.DataParallel(model, device_ids=GPUS) 161 | 162 | if not os.path.exists(base_name): 163 | os.makedirs(base_name) 164 | file_name = os.path.join(base_name, base_name) 165 | 166 | 167 | headers = ['epoch', 'train_acc', 'train_auc', 'test_acc', 'test_auc', 'val_acc', 'val_auc'] 168 | 169 | # csv_log_out = open(file_name + ".csv", 'w') 170 | with open(file_name + ".csv", 'w') as csv_log_out: 171 | csv_log_out.write(",".join(headers) + "\n") 172 | 173 | criterion = nn.CrossEntropyLoss() 174 | optimizer = optim.AdamW(model.parameters(), lr=args_to_use['lr']) 175 | 176 | for epoch in tqdm(range(EPOCHS)): 177 | 178 | preds = [] 179 | truths = [] 180 | running_loss = 0.0 181 | 182 | 183 | train_correct = 0 184 | train_total = 0 185 | 186 | epoch_stats = {'epoch':epoch} 187 | 188 | model.train() 189 | for inputs, labels in tqdm(train_loader): 190 | 191 | #inputs, labels = inputs.to(device), labels.to(device) 192 | #Keep inputs on CPU, the 
model will load chunks of input onto device as needed 193 | labels = labels.to(device) 194 | 195 | optimizer.zero_grad() 196 | 197 | # outputs, penultimate_activ, conv_active = model.forward_extra(inputs) 198 | outputs, penult, post_conv = model(inputs) 199 | loss = criterion(outputs, labels) 200 | loss = loss #+ decov_lambda*(decov_penalty(penultimate_activ) + decov_penalty(conv_active)) 201 | # loss = loss + decov_lambda*(decov_penalty(conv_active)) 202 | loss.backward() 203 | optimizer.step() 204 | if NON_NEG: 205 | for p in model.parameters(): 206 | p.data.clamp_(0) 207 | 208 | 209 | running_loss += loss.item() 210 | 211 | _, predicted = torch.max(outputs.data, 1) 212 | 213 | with torch.no_grad(): 214 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 215 | truths.extend(labels.detach().cpu().numpy().ravel()) 216 | 217 | train_total += labels.size(0) 218 | train_correct += (predicted == labels).sum().item() 219 | 220 | #end train loop 221 | 222 | #print("Training Accuracy: {}".format(train_correct*100.0/train_total)) 223 | 224 | epoch_stats['train_acc'] = train_correct*1.0/train_total 225 | epoch_stats['train_auc'] = roc_auc_score(truths, preds) 226 | #epoch_stats['train_loss'] = roc_auc_score(truths, preds) 227 | 228 | #Save the model and current state! 229 | model_path = os.path.join(base_name, "epoch_{}.checkpoint".format(epoch)) 230 | 231 | 232 | #Have to handle model state special if multi-gpu was used 233 | if type(model).__name__ is "DataParallel": 234 | mstd = model.module.state_dict() 235 | else: 236 | mstd = model.state_dict() 237 | 238 | #Copy dict, and add extra info to save off 239 | check_dict = args_to_use.copy() 240 | check_dict['epoch'] = epoch 241 | check_dict['model_state_dict'] = mstd 242 | check_dict['optimizer_state_dict'] = optimizer.state_dict() 243 | check_dict['non_neg'] = NON_NEG 244 | torch.save(check_dict, model_path) 245 | 246 | 247 | #Test Set Eval 248 | model.eval() 249 | eval_train_correct = 0 250 | eval_train_total = 0 251 | 252 | preds = [] 253 | truths = [] 254 | with torch.no_grad(): 255 | for inputs, labels in tqdm(test_loader): 256 | 257 | inputs, labels = inputs.to(device), labels.to(device) 258 | 259 | outputs, _, _ = model(inputs) 260 | 261 | _, predicted = torch.max(outputs.data, 1) 262 | 263 | preds.extend(F.softmax(outputs, dim=-1).data[:,1].detach().cpu().numpy().ravel()) 264 | truths.extend(labels.detach().cpu().numpy().ravel()) 265 | 266 | eval_train_total += labels.size(0) 267 | eval_train_correct += (predicted == labels).sum().item() 268 | 269 | epoch_stats['test_acc'] = eval_train_correct*1.0/eval_train_total 270 | epoch_stats['test_auc'] = roc_auc_score(truths, preds) 271 | 272 | #We've now done an epoch of training. 
15 | ### checkpoint.py
16 | 
17 | This contains the code used to perform gradient checkpointing for reduced memory usage. It is optional and generally not necessary for our MalConv* models now, but was used during experimentation.
18 | 
19 | ### LowMemConv.py
20 | 
21 | LowMemConv provides the base class, `LowMemConvBase`, that implementations extend to obtain the fixed-memory pooling we introduced. The pooling is provided by the `seq2fix` function, which applies the convolution in chunks, tracks the winning activations, and then groups the winning slices into one final pass with gradient calculations turned on.
22 | 
23 | 
24 | The user extends `LowMemConvBase`, implementing the `processRange` function, which applies whatever convolutional strategy they desire to a range of bytes. The `determinRF` function is used to determine the receptive field size by iteratively testing
25 | for the smallest input size that does not error, so that we know how to size our chunks later. A minimal sketch of this pattern is shown below.
26 | 
27 | 
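The following is a hypothetical, illustrative subclass showing the shape of that API; the `TinyByteConv` name, layer sizes, and the three-value return from `forward` (matching how the training scripts unpack `outputs, penult, post_conv`) are assumptions for the sketch, not code from the repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from LowMemConv import LowMemConvBase


class TinyByteConv(LowMemConvBase):
    """Illustrative only: the layer sizes and names here are arbitrary."""

    def __init__(self, out_size=2, channels=32, window_size=16, stride=4):
        super(TinyByteConv, self).__init__()
        # 256 byte values + 1, with index 0 reserved for padding (binaryLoader.py's encoding)
        self.embd = nn.Embedding(257, 8, padding_idx=0)
        self.conv = nn.Conv1d(8, channels, window_size, stride=stride, bias=True)
        self.fc = nn.Linear(channels, out_size)

    def processRange(self, x):
        # Called on one chunk of bytes at a time; apply whatever conv stack you like here.
        x = self.embd(x)
        x = torch.transpose(x, -1, -2)
        return F.relu(self.conv(x))

    def forward(self, x):
        # seq2fix runs processRange over the (arbitrarily long) input in chunks and
        # pools the winners down to a fixed-size vector, regardless of input length.
        post_conv = self.seq2fix(x)
        logits = self.fc(post_conv)
        # The training scripts in this repo unpack three outputs from a model.
        return logits, post_conv, post_conv
```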
This is provided by `seq2fix` function, which does the work of applying the convolution in chunks, tracking the winners, and then grouping the 22 | winning slices to run over with gradient calculations on. 23 | 24 | The user extends `LowMemConvBase`, implementing the `processRange` function, which applies whatever convolutional strategy they desire to a range of bytes. The `determinRF` function is used to determine the receptive field size by iteratively testing 25 | for the smallest input size that does not error, so that we know how to size our chunk sizes later. 26 | 27 | 28 | ### MalConvGCT_nocat.py & MalConvGCTTrain.py 29 | 30 | MalConvGCT_nocat implements the new contribution of our paper, using the GCT attention. An older file, MalConvGCT uses this pooling but with a concatenation at the end. 31 | 32 | The `MalConvGCTTrain.py` is the sister file that will train a `MalConvGCT` object. 33 | 34 | The associated "*Train.py" functions allow for training these models. AvastTyleConv implements the max pool version of the Avast architecture, and MalConvML implement a multiple layer version of MalConv that were used in ablation testing. MalConv.py 35 | implements the original MalConv using our new low memory approach. 36 | 37 | ### malconvGCT_nocat.checkpoint 38 | 39 | This file contains the weights for the GCT model from our paper’s results. It has some extra parameters that were never used due to some lines left commented in durning model training. It also has an off-by-one “bug” that says its the 21’st epoch 40 | instead of the 20’th. 41 | 42 | To load this file, you want to have code that looks like: 43 | 44 | ```python 45 | from MalConvGCT_nocat import MalConvGCT 46 | 47 | mlgct = MalConvGCT(channels=256, window_size=256, stride=64,) 48 | x = torch.load("malconvGCT_nocat.checkpoint.checkpoint") 49 | mlgct.load_state_dict(x['model_state_dict'], strict=False) 50 | ``` 51 | 52 | ### AvastStyleConv.py 53 | 54 | This implements a network in the style of Avast’s CNN from 2018, but replacing average pooling with our temporal max pooling for speed. 55 | 56 | ### MalConv.py 57 | 58 | Implements the original MalConv network with our faster training/pooling. 59 | 60 | 61 | ### MalConvML.py 62 | 63 | This file contains an alternative experiment approach to training with more layers, but never worked well. 64 | 65 | ### ContinueTraining.py 66 | 67 | This file can be used to resume the training of models from a given checkpoint. 68 | 69 | ### OptunaTrain.py 70 | 71 | This file is used to do training why a hyper-parameter search. 72 | 73 | ### Non-Neg options 74 | 75 | The non-negative training currently present is faulty, as it allows you to do such training with a softmax output, which is technically incorrect. Please do not use it. 76 | 77 | 78 | ## Citations 79 | 80 | If you use the MalConv GCT algorithm or code, please cite our work! 
53 | ### AvastStyleConv.py
54 | 
55 | This implements a network in the style of Avast’s CNN from 2018, but replaces average pooling with our temporal max pooling for speed.
56 | 
57 | ### MalConv.py
58 | 
59 | Implements the original MalConv network with our faster training/pooling.
60 | 
61 | 
62 | ### MalConvML.py
63 | 
64 | This file contains an alternative, experimental approach to training with more layers, but it never worked well.
65 | 
66 | ### ContinueTraining.py
67 | 
68 | This file can be used to resume the training of models from a given checkpoint.
69 | 
70 | ### OptunaTrain.py
71 | 
72 | This file is used to do training with a hyper-parameter search, storing the results in an Optuna study backed by a SQLite database.
73 | 
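Because the study is persisted to SQLite, it can be inspected from a separate process during or after the search. A small sketch, assuming the study name below matches the `MODEL_NAME` (plus any `NonNeg_` prefix) used for the run:

```python
import optuna

# Placeholder: use the same MODEL_NAME (and "NonNeg_" prefix, if set) as your training run
study_name = "MalConvGCT"
study = optuna.load_study(study_name=study_name,
                          storage='sqlite:///{}.db'.format(study_name))

best = study.best_trial   # lowest validation error, since the objective minimizes it
print(best.params)        # sampled hyper-parameters, e.g. lr and the model's architecture settings
print(best.user_attrs)    # val/test accuracy and AUC recorded by the objective
```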
74 | ### Non-Neg options
75 | 
76 | The non-negative training currently present is faulty, as it allows you to do such training with a softmax output, which is technically incorrect. Please do not use it.
77 | 
78 | 
79 | ## Citations
80 | 
81 | If you use the MalConv GCT algorithm or code, please cite our work!
82 | 
83 | ```
84 | @inproceedings{malconvGCT,
85 | author = {Raff, Edward and Fleshman, William and Zak, Richard and Anderson, Hyrum and Filar, Bobby and Mclean, Mark},
86 | booktitle = {The Thirty-Fifth AAAI Conference on Artificial Intelligence},
87 | title = {{Classifying Sequences of Extreme Length with Constant Memory Applied to Malware Detection}},
88 | year = {2021},
89 | url={https://arxiv.org/abs/2012.09390},
90 | }
91 | ```
92 | 
93 | ## Contact
94 | 
95 | If you have questions, please contact:
96 | 
97 | Mark Mclean
98 | Edward Raff
99 | Richard Zak
100 | 
--------------------------------------------------------------------------------
/binaryLoader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import deque
3 | 
4 | import random
5 | import numpy as np
6 | 
7 | from tqdm import tqdm_notebook as tqdm
8 | 
9 | import gzip
10 | 
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.checkpoint import checkpoint
15 | 
16 | import torch.optim as optim
17 | 
18 | from torch.utils import data
19 | 
20 | from torch.utils.data import Dataset, DataLoader, Subset
21 | 
22 | 
23 | class BinaryDataset(data.Dataset):
24 |     """
25 |     Loader for binary files.
26 | 
27 |     If you use the sort_by_size option, the dataset will store files from smallest to largest. This is meant to be used with RandomChunkSampler to sample batches of similarly sized files to maximize performance.
28 | 
29 |     TODO: Auto un-gzip files if they have g-zip compression
30 |     """
31 |     def __init__(self, good_dir, bad_dir, sort_by_size=False, max_len=4000000):
32 | 
33 |         #Tuple (file_path, label, file_size)
34 |         self.all_files = []
35 |         self.max_len = max_len
36 | 
37 |         for root_dir, dirs, files in os.walk(good_dir):
38 |             for file in files:
39 |                 to_add = os.path.join(root_dir, file)
40 |                 self.all_files.append( (to_add, 0, os.path.getsize(to_add)) )
41 | 
42 |         for root_dir, dirs, files in os.walk(bad_dir):
43 |             for file in files:
44 |                 to_add = os.path.join(root_dir, file)
45 |                 self.all_files.append( (to_add, 1, os.path.getsize(to_add)) )
46 | 
47 |         if sort_by_size:
48 |             self.all_files.sort(key=lambda filename: filename[2])
49 | 
50 |     def __len__(self):
51 |         return len(self.all_files)
52 | 
53 |     def __getitem__(self, index):
54 | 
55 |         to_load, y, _ = self.all_files[index]
56 | 
57 |         try:
58 |             with gzip.open(to_load, 'rb') as f:
59 |                 x = f.read(self.max_len)
60 |                 #Need to use frombuffer b/c it's a byte array, otherwise np.asarray will get wonked trying to convert to ints
61 |                 #So decode as uint8 (1 byte per value), and then convert
62 |                 x = np.frombuffer(x, dtype=np.uint8).astype(np.int16)+1 #index 0 will be special padding index
63 |         except OSError:
64 |             #OK, you are not a gziped file. Just read in raw bytes from disk.
65 |             with open(to_load, 'rb') as f:
66 |                 x = f.read(self.max_len)
67 |                 #Need to use frombuffer b/c it's a byte array, otherwise np.asarray will get wonked trying to convert to ints
68 |                 #So decode as uint8 (1 byte per value), and then convert
69 |                 x = np.frombuffer(x, dtype=np.uint8).astype(np.int16)+1 #index 0 will be special padding index
70 | 
71 |         #x = np.pad(x, self.max_len-x.shape[0], 'constant')
72 |         x = torch.tensor(x)
73 | 
74 |         return x, torch.tensor([y])
75 | 
76 | class RandomChunkSampler(torch.utils.data.sampler.Sampler):
77 |     """
78 |     Samples random "chunks" of a dataset, so that items within a chunk are always loaded together. Useful to keep chunks in similar size groups to reduce runtime.
79 |     """
80 |     def __init__(self, data_source, batch_size):
81 |         """
82 |         data_source: the source pytorch dataset object
83 |         batch_size: the size of the chunks to keep together. Should generally be set to the desired batch size during training to minimize runtime.
84 |         """
85 |         self.data_source = data_source
86 |         self.batch_size = batch_size
87 | 
88 |     def __iter__(self):
89 |         n = len(self.data_source)
90 | 
91 |         data = [x for x in range(n)]
92 | 
93 |         # Create blocks
94 |         blocks = [data[i:i+self.batch_size] for i in range(0, len(data), self.batch_size)]
95 |         # shuffle the blocks
96 |         random.shuffle(blocks)
97 |         # concatenate the shuffled blocks
98 |         data[:] = [b for bs in blocks for b in bs]
99 | 
100 |         return iter(data)
101 | 
102 |     def __len__(self):
103 |         return len(self.data_source)
104 | 
105 | #We want to handle true variable length
106 | #Data loaders need equal length, so use a special function to pad all the data in a single batch to be of equal length
107 | #to the longest item in the batch
108 | def pad_collate_func(batch):
109 |     """
110 |     This should be used as the collate_fn=pad_collate_func for a pytorch DataLoader object in order to pad out files in a batch to the length of the longest item in the batch.
111 |     """
112 |     vecs = [x[0] for x in batch]
113 |     labels = [x[1] for x in batch]
114 | 
115 |     x = torch.nn.utils.rnn.pad_sequence(vecs, batch_first=True)
116 |     #stack will give us (B, 1), so index [:,0] to get to just (B)
117 |     y = torch.stack(labels)[:,0]
118 | 
119 |     return x, y
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | 
4 | 
5 | def detach_variable(inputs):
6 |     if isinstance(inputs, tuple):
7 |         out = []
8 |         for inp in inputs:
9 |             x = inp.detach()
10 |             x.requires_grad = inp.requires_grad
11 |             out.append(x)
12 |         return tuple(out)
13 |     else:
14 |         raise RuntimeError(
15 |             "Only tuple of tensors is supported. Got unsupported input type: ", type(inputs).__name__)
16 | 
17 | 
18 | def check_backward_validity(inputs):
19 |     if not any(inp.requires_grad for inp in inputs):
20 |         warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
21 | 
22 | 
23 | class CheckpointFunction(torch.autograd.Function):
24 |     @staticmethod
25 |     def forward(ctx, run_function, length, *args):
26 |         ctx.run_function = run_function
27 |         ctx.input_tensors = list(args[:length])
28 |         ctx.input_params = list(args[length:])
29 |         with torch.no_grad():
30 |             output_tensors = ctx.run_function(*ctx.input_tensors)
31 |         return output_tensors
32 | 
33 |     @staticmethod
34 |     def backward(ctx, *output_grads):
35 |         for i in range(len(ctx.input_tensors)):
36 |             temp = ctx.input_tensors[i]
37 |             ctx.input_tensors[i] = temp.detach()
38 |             ctx.input_tensors[i].requires_grad = temp.requires_grad
39 |         with torch.enable_grad():
40 |             output_tensors = ctx.run_function(*ctx.input_tensors)
41 |         input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True)
42 |         return (None, None) + input_grads
43 | 
--------------------------------------------------------------------------------
/malconvGCT_nocat.checkpoint:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FutureComputing4AI/MalConv2/b6ff10fe14dfb49956914326f8d4c287fa1525f8/malconvGCT_nocat.checkpoint
--------------------------------------------------------------------------------