├── .gitignore ├── LICENSE ├── README.md ├── blindspot_coverage ├── __init__.py └── covering_number.py ├── datasets ├── __init__.py └── datasets.py ├── framework.py ├── generate_vectors.py ├── helper_files ├── imported_function_to_index_mapping_dict.p ├── lief_parseable_cnet_programs.p ├── linux_environment.yml └── osx_environment.yml ├── inner_maximizers ├── __init__.py └── inner_maximizers.py ├── nets ├── __init__.py └── ff_classifier.py ├── parameters.ini ├── requirements.txt ├── run_experiments.py └── utils ├── __init__.py ├── script_functions.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *pycache* 3 | .idea* 4 | *.pt 5 | *result_files 6 | *.DS_Store 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 ALFA-group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # robust-adv-malware-detection 2 | Code repository for the paper [Adversarial Deep Learning for Robust Detection of Binary Encoded Malware](https://arxiv.org/pdf/1801.02950.pdf), A. Al-Dujaili *et al.*, 2018 3 | 4 | - The visualization tool from the paper [On Visual Hallmarks of Robustness to Adversarial Malware](https://arxiv.org/pdf/1805.03553.pdf) can be found [here](https://github.com/ALFA-group/adv-malware-viz). 5 | 6 | - A series of related blog posts can be found [here](http://ash-aldujaili.github.io/blog/2018/08/29/evasive-malware/). 7 | 8 | - The dataset can be shared upon request: please fill in the form https://goo.gl/forms/hn5Dfiset1Y1BkMr1 and we will send you a link to the dataset. 9 | 10 | ## Installation 11 | 12 | All the required packages are specified in the yml files under `helper_files`. If you have `conda` installed, you can just `cd` to the main directory and execute the following with `osx_environment.yml` or `linux_environment.yml` on macOS or Linux, respectively. 13 | 14 | ``` 15 | conda env create -f ./helper_files/(osx|linux)_environment.yml 16 | ``` 17 | This will create an environment called `nn_mal`.
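For example, on Linux the command above resolves to:
```
conda env create -f ./helper_files/linux_environment.yml
```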
18 | 19 | 20 | To activate this environment, execute: 21 | ``` 22 | source activate nn_mal 23 | ``` 24 | 25 | **Note**: If you're going to use losswise, you may run into an issue with one print statement whose argument is not enclosed in parentheses; 26 | just add the parentheses if this error shows up and you're good to go. 27 | 28 | **Note**: If you're running the code on macOS with CUDA, then according to pytorch.org, 29 | "macOS binaries don't support CUDA, install from source if CUDA is needed". 30 | 31 | 32 | ## Running 33 | 34 | 35 | 1. Configure your experiment as desired by modifying the `parameters.ini` file. Among the things you may want to specify: 36 | a - dataset filepath 37 | b - GPU device, if any 38 | c - name of the experiment 39 | d - training method (inner maximizer) 40 | e - evasion method 41 | 42 | **Note**: In case you do not have access to the dataset, you can still run the code on a synthetic dataset with 8-dimensional binary feature vectors, whose bits are set with probability 0.2 for the malicious class and 0.8 for the benign class. 43 | 44 | 45 | 2. Execute `framework.py`: 46 | 47 | ``` 48 | python framework.py 49 | ``` 50 | 51 | **Note**: all the experiments can be logged and monitored using losswise. 52 | To activate logging, set `losswise_api_key` to your API key in `parameters.ini` 53 | and set `is_losswise` to **True**. 54 | 55 | 56 | ## Reproducing Paper Results 57 | 58 | In order to reproduce the results in the paper, set the filepaths to the malicious and benign saved feature vectors (these can be re-generated with `generate_vectors.py`) and execute the `run_experiments.py` script: 59 | 60 | ``` 61 | python run_experiments.py 62 | ``` 63 | 64 | Results (accuracy metrics, bscn measures, and evasion rates) will be populated under the (to-be-generated) `result_files` directory, while the trained models will be saved under `helper_files`. 65 | 66 | The results can be compiled into LaTeX tables saved under `result_files` by running the function `create_tex_tables()` in `utils/script_functions.py` with a valid filepath to the result files. By default, you can do the following: 67 | ``` 68 | cd utils/ 69 | python script_functions.py 70 | ``` 71 | 72 | 73 | **NOTE**: On Linux, you may run into trouble running `source` from within Python's `os.system()`.
A workaround is to replace the `os.system()` command in `run_experiments.py` with the following line: 74 | ``` 75 | system('/bin/bash -c "source activate nn_mal;python framework.py"') 76 | ``` 77 | 78 | ## Citation 79 | 80 | If you make use of this code and you'd like to cite us, please consider the following: 81 | 82 | ``` 83 | @article{al2018adversarial, 84 | title={Adversarial Deep Learning for Robust Detection of Binary Encoded Malware}, 85 | author={Al-Dujaili, Abdullah and Huang, Alex and Hemberg, Erik and O'Reilly, Una-May}, 86 | journal={arXiv preprint arXiv:1801.02950}, 87 | year={2018} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /blindspot_coverage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/blindspot_coverage/__init__.py -------------------------------------------------------------------------------- /blindspot_coverage/covering_number.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | Python module for computing the covering number 4 | """ 5 | import pybloomfilter 6 | import torch  # needed at module level by update_denominator 7 | 8 | class CoveringNumber(object): 9 | """ 10 | Class to implement Eq.(4) of the paper 11 | """ 12 | 13 | def __init__(self, num_samples, expected_num_points, batch_size, error_rate=0.000001): 14 | self._num_samples = num_samples 15 | self._batch_size = batch_size 16 | self._expected_num_points = expected_num_points 17 | self._denominator = [0.] * num_samples 18 | self._numerator = [0.] * num_samples 19 | 20 | self._bf = pybloomfilter.BloomFilter(expected_num_points, error_rate) 21 | self._actual_num_points = 0. 22 | 23 | def update_denominator(self, sample_idx, sample): 24 | """ 25 | Computes the denominator of the sample_idxth sample of the training data. 26 | This method needs to be called only once, as the denominator is constant regardless of the adversarial learning 27 | process. 28 | :param sample_idx: index of the training sample 29 | :param sample: the original form of the training sample 30 | :return: 31 | """ 32 | self._denominator[sample_idx] = len(sample) - torch.dot(sample, torch.ones(sample.size())) 33 | 34 | def update_numerator(self, sample_idx, point): 35 | """ 36 | update the numerator counter for the sample_idxth point by testing whether the point has already been visited 37 | :param sample_idx: index of the training sample 38 | :param point: current version of the training sample 39 | :return: 40 | """ 41 | pt_np = point.numpy() 42 | is_not_in = not self._bf.add(hash(str(pt_np))) 43 | self._numerator[sample_idx] += int(is_not_in) 44 | self._actual_num_points += int(is_not_in) 45 | 46 | def update_numerator_batch(self, batch_idx, batch): 47 | """ 48 | update the covering number measure with the new batch 49 | :param batch_idx: current batch index 50 | :param batch: batch features as a float tensor (these will be hashed against the bloom filter) 51 | :return: 52 | """ 53 | for point_idx, point in enumerate(batch): 54 | sample_idx = point_idx + self._batch_size * batch_idx 55 | self.update_numerator(sample_idx, point) 56 | 57 | def ratio(self): 58 | """ 59 | :return: the ratio of the visited samples to the maximum expected ones 60 | """ 61 | return self._actual_num_points * 1.
/ self._expected_num_points 62 | 63 | 64 | if __name__ == "__main__": 65 | print("I am just a module to be called by others, testing here") 66 | import torch 67 | _num_samples = 10 68 | epochs = 100 69 | _expected_num_points = _num_samples * epochs 70 | _batch_size = 2 71 | 72 | bscn = CoveringNumber(_num_samples, _expected_num_points, _batch_size) 73 | 74 | _batch = torch.rand(_batch_size, 1) 75 | 76 | print(_batch) 77 | bscn.update_numerator_batch(2, _batch) 78 | print("done") 79 | print(bscn.ratio()) 80 | print(_batch) 81 | bscn.update_numerator_batch(2, _batch) 82 | 83 | print(bscn.ratio()) 84 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Python module for handling datasets for training and testing""" 3 | import torch 4 | import os 5 | import pickle 6 | import lief 7 | import traceback 8 | import time 9 | 10 | from torch.utils.data import Dataset, DataLoader 11 | from sklearn.model_selection import train_test_split 12 | 13 | MALICIOUS_LABEL = 1 14 | BENIGN_LABEL = 0 15 | 16 | 17 | def create_import_to_index_mapping(parameters): 18 | """ 19 | Creates a mapping of all the lib imports within benign and malicious samples into their corresponding indices in the 20 | feature vector. The mapping is pickled to a file. 21 | While we could do this per dataset, calling this function ahead of time also gives us the dimensionality of the feature 22 | vector. 23 |
23 | :param parameters: a json-like structure of the system parameters and configurations 24 | :return: 25 | """ 26 | print("Creating import to index mapping") 27 | 28 | if eval(parameters['dataset']['load_mapping_from_pickle']): 29 | print(" -- import-index mapping pickle file exists, skipping this") 30 | return 31 | 32 | malicious_filepath = get_helper_filepath(parameters, "malicious_filepath") 33 | benign_filepath = get_helper_filepath(parameters, "benign_filepath") 34 | 35 | imported_function_to_index = {} 36 | index = 0 37 | 38 | # List of parseable files 39 | if parameters['dataset']['malicious_files_list'] == 'None': 40 | malicious_files = os.listdir(malicious_filepath) 41 | else: 42 | malicious_files = pickle.load( 43 | open(get_helper_filepath(parameters, "malicious_files_list"), 'rb')) 44 | 45 | if parameters['dataset']['benign_files_list'] == 'None': 46 | benign_files = os.listdir(benign_filepath) 47 | else: 48 | benign_files = pickle.load(open(get_helper_filepath(parameters, "benign_files_list"), 'rb')) 49 | 50 | # Add the filepath for both malicious and benign files 51 | malicious_files = [malicious_filepath + hash_str for hash_str in malicious_files] 52 | benign_files = [benign_filepath + hash_str for hash_str in benign_files] 53 | 54 | print("Malicious files:", len(malicious_files)) 55 | print("Benign files:", len(benign_files)) 56 | print("Total number of files:", len(malicious_files + benign_files)) 57 | 58 | previous_time = time.time() 59 | 60 | for i, hash_filepath in enumerate(malicious_files + benign_files): 61 | 62 | if i % 100 == 0: 63 | current_time = time.time() 64 | print(i, "Time:", current_time - previous_time, " seconds") 65 | previous_time = time.time() 66 | 67 | try: 68 | binary = lief.parse(hash_filepath) 69 | 70 | # With imports includes the library (DLL) the function comes from 71 | imports_with_library = [ 72 | lib.name.lower() + ':' + e.name for lib in binary.imports for e in lib.entries 73 | ] 74 | 75 | for lib_import in imports_with_library: 76 | 77 | if lib_import not in imported_function_to_index: 78 | imported_function_to_index[lib_import] = index 79 | index += 1 80 | 81 | except: 82 | traceback.print_exc() 83 | pass 84 | 85 | pickle.dump(imported_function_to_index, 86 | open(get_helper_filepath(parameters, "pickle_mapping_file"), 'wb')) 87 | 88 | 89 | class PortableExecutableDataset(Dataset): 90 | def __init__(self, file_abs_locations, is_malicious, parameters): 91 | """ 92 | file_abs_locations: PE file names including path to 93 | is_malicious: either benign or malicious files in a single dataset 94 | """ 95 | 96 | self.file_abs_locations = file_abs_locations 97 | self.is_malicious = is_malicious 98 | self._num_features = eval(parameters["dataset"]["num_features_to_use"]) 99 | self._is_synthetic = eval(parameters["general"]["is_synthetic_dataset"]) 100 | self.use_pickle = eval(parameters["dataset"]["use_saved_feature_vectors"]) 101 | # returns the filepath, necessary when generating pickle vector files 102 | self.return_filepath = eval(parameters['dataset']['generate_feature_vector_files']) 103 | 104 | if self._is_synthetic or self.use_pickle: 105 | pass 106 | else: 107 | # Need to create an index mapping or load a preloaded one 108 | try: 109 | self.imported_function_to_index = pickle.load( 110 | open(get_helper_filepath(parameters, "pickle_mapping_file"), "rb")) 111 | except: 112 | imported_function_to_index = {} 113 | index = 0 114 | 115 | previous_time = time.time() 116 | for i, filepath in enumerate(file_abs_locations): 117 | 118 | if i % 
1000 == 0: 119 | current_time = time.time() 120 | print("Time for last 1000:", current_time - previous_time) 121 | previous_time = time.time() 122 | 123 | print(i) 124 | 125 | try: 126 | imports_with_library = self.__get_imports_with_library(filepath) 127 | 128 | for lib_import in imports_with_library: 129 | 130 | if lib_import not in imported_function_to_index: 131 | imported_function_to_index[lib_import] = index 132 | index += 1 133 | 134 | except: 135 | print(i, "FAILED") 136 | pass 137 | 138 | self.imported_function_to_index = imported_function_to_index 139 | 140 | def __get_imports_with_library(self, filepath): 141 | """ 142 | Helper function to get the list of imported function calls for a binary 143 | :param filepath: binary's absolute filepath 144 | :return: the list of functions called/imported appended by their libraries. 145 | """ 146 | binary = lief.parse(filepath) 147 | return [lib.name.lower() + ':' + e.name for lib in binary.imports for e in lib.entries] 148 | 149 | def __len__(self): 150 | return len(self.file_abs_locations) 151 | 152 | def __getitem__(self, idx): 153 | 154 | if self.is_malicious: 155 | label = MALICIOUS_LABEL 156 | else: 157 | label = BENIGN_LABEL 158 | 159 | filepath = None 160 | if self._is_synthetic: 161 | threshold = 0.2 if self.is_malicious else 0.8 162 | feature_vector = (torch.rand(8) < threshold).float() 163 | else: 164 | filepath = self.file_abs_locations[idx] 165 | 166 | if self.use_pickle: 167 | feature_vector = pickle.load(open(filepath, 'rb')) 168 | feature_vector = feature_vector.squeeze() 169 | 170 | else: 171 | feature_vector = [0] * len(self.imported_function_to_index) 172 | 173 | try: 174 | # Vector of 0's initially, switch to 1's at each location corresponding to an imported function 175 | imports_with_library = self.__get_imports_with_library(filepath) 176 | 177 | for lib_import in imports_with_library: 178 | index = self.imported_function_to_index[lib_import] 179 | 180 | feature_vector[index] = 1 181 | 182 | except: 183 | raise Exception("%s is not parseable!" 
% filepath) 184 | 185 | feature_vector = torch.Tensor(feature_vector[:self._num_features]) 186 | 187 | if self.return_filepath: 188 | return feature_vector, label, filepath 189 | else: 190 | return feature_vector, label 191 | 192 | 193 | def get_helper_filepath(parameters, filename): 194 | """ 195 | Return the absolute file of the 'filename' helper file 196 | :param parameters: 197 | :param filename: file name 198 | :return: 199 | """ 200 | filename = os.path.join(parameters["dataset"]["helper_filepath"], 201 | parameters["dataset"][filename]) 202 | print("-- accessing file:", filename) 203 | return filename 204 | 205 | 206 | def load_data(parameters): 207 | """ 208 | Load the training/test datasets 209 | :param parameters: 210 | :return: dictionaries of train and test dataloaders 211 | """ 212 | print("Starting data loading") 213 | if eval(parameters["general"]["is_synthetic_dataset"]): 214 | num_files = int(parameters['dataset']['num_files_to_use']) 215 | malicious_files_abs_locs = ["1"] * num_files 216 | benign_files_abs_locs = ["2"] * num_files 217 | else: 218 | # generate the global index mapping file 219 | create_import_to_index_mapping(parameters) 220 | 221 | # get absolute filenames for path malicious and benign files 222 | malicious_filepath = get_helper_filepath(parameters, "malicious_filepath") 223 | benign_filepath = get_helper_filepath(parameters, "benign_filepath") 224 | 225 | if parameters['dataset']['malicious_files_list'] == 'None': 226 | malicious_files = os.listdir(malicious_filepath) 227 | else: 228 | malicious_files_list_file = get_helper_filepath(parameters, "malicious_files_list") 229 | print("Getting malicious files from ", malicious_files_list_file) 230 | malicious_files = pickle.load(open(malicious_files_list_file, 'rb')) 231 | 232 | if parameters['dataset']['benign_files_list'] == 'None': 233 | benign_files = os.listdir(benign_filepath) 234 | else: 235 | benign_files_list_file = get_helper_filepath(parameters, "benign_files_list") 236 | print("Getting benign files from ", benign_files_list_file) 237 | benign_files = pickle.load(open(benign_files_list_file, 'rb')) 238 | 239 | malicious_files_abs_locs = [malicious_filepath + _hash for _hash in malicious_files] 240 | benign_files_abs_locs = [benign_filepath + _hash for _hash in benign_files] 241 | 242 | # set the datasets 243 | if eval(parameters['dataset']['use_subset_of_data']): 244 | num_files = int(parameters['dataset']['num_files_to_use']) 245 | malicious_files_abs_locs = malicious_files_abs_locs[:num_files] 246 | benign_files_abs_locs = benign_files_abs_locs[:num_files] 247 | 248 | print("Malware Files:", len(malicious_files_abs_locs)) 249 | print("Benign Files:", len(benign_files_abs_locs)) 250 | 251 | # our malicious and benign datasets have the same size and have the same batch size 252 | # assert len(malicious_files_abs_locs) == len( 253 | # benign_files_abs_locs), "It is assumed that malicious and benign dataset are of the same size" 254 | 255 | training_batch_size = int(parameters["hyperparam"]["training_batch_size"]) 256 | test_batch_size = int(parameters["hyperparam"]["test_batch_size"]) 257 | test_size_percent = float(parameters["dataset"]["test_size_percent"]) 258 | 259 | train_malicious_files_abs_locs, test_malicious_files_abs_locs = train_test_split( 260 | malicious_files_abs_locs, test_size=test_size_percent) 261 | train_malicious_files_abs_locs, valid_malicious_files_abs_locs = train_test_split( 262 | train_malicious_files_abs_locs, test_size=0.25) 263 | 264 | train_benign_files_abs_locs, 
test_benign_files_abs_locs = train_test_split( 265 | benign_files_abs_locs, test_size=test_size_percent) 266 | train_benign_files_abs_locs, valid_benign_files_abs_locs = train_test_split( 267 | train_benign_files_abs_locs, test_size=0.25) 268 | 269 | print("Preparing training datasets") 270 | train_malicious_dataset = PortableExecutableDataset( 271 | train_malicious_files_abs_locs, is_malicious=True, parameters=parameters) 272 | train_benign_dataset = PortableExecutableDataset( 273 | train_benign_files_abs_locs, is_malicious=False, parameters=parameters) 274 | 275 | print("Preparing validation datasets") 276 | valid_malicious_dataset = PortableExecutableDataset( 277 | valid_malicious_files_abs_locs, is_malicious=True, parameters=parameters) 278 | valid_benign_dataset = PortableExecutableDataset( 279 | valid_benign_files_abs_locs, is_malicious=False, parameters=parameters) 280 | 281 | print("Preparing testing datasets") 282 | test_malicious_dataset = PortableExecutableDataset( 283 | test_malicious_files_abs_locs, is_malicious=True, parameters=parameters) 284 | test_benign_dataset = PortableExecutableDataset( 285 | test_benign_files_abs_locs, is_malicious=False, parameters=parameters) 286 | 287 | # assertion 288 | assert train_malicious_dataset[0][0].size() == train_benign_dataset[0][ 289 | 0].size(), "malicious and benign are of the same feature space" 290 | assert test_malicious_dataset[0][0].size() == test_benign_dataset[0][ 291 | 0].size(), "malicious and benign are of the same feature space" 292 | assert test_malicious_dataset[0][0].size() == train_benign_dataset[0][ 293 | 0].size(), "malicious and benign are of the same feature space" 294 | assert train_malicious_dataset[0][0].size() == test_benign_dataset[0][ 295 | 0].size(), "malicious and benign are of the same feature space" 296 | assert valid_malicious_dataset[0][0].size() == valid_benign_dataset[0][ 297 | 0].size(), "malicious and benign are of the same feature space" 298 | assert valid_malicious_dataset[0][0].size() == train_benign_dataset[0][ 299 | 0].size(), "malicious and benign are of the same feature space" 300 | 301 | _num_features = train_benign_dataset[0][0].size()[0] 302 | 303 | # set the dataloaders 304 | num_workers = int(parameters['general']['num_workers']) 305 | 306 | malicious_trainloader = DataLoader( 307 | train_malicious_dataset, 308 | batch_size=training_batch_size, 309 | shuffle=True, 310 | num_workers=num_workers) 311 | benign_trainloader = DataLoader( 312 | train_benign_dataset, batch_size=training_batch_size, shuffle=True, num_workers=num_workers) 313 | 314 | malicious_validloader = DataLoader( 315 | valid_malicious_dataset, 316 | batch_size=training_batch_size, 317 | shuffle=True, 318 | num_workers=num_workers) 319 | benign_validloader = DataLoader( 320 | valid_benign_dataset, batch_size=training_batch_size, shuffle=True, num_workers=num_workers) 321 | 322 | malicious_testloader = DataLoader( 323 | test_malicious_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers) 324 | benign_testloader = DataLoader( 325 | test_benign_dataset, batch_size=test_batch_size, shuffle=False, num_workers=num_workers) 326 | 327 | train_dataloaders = {"malicious": malicious_trainloader, "benign": benign_trainloader} 328 | valid_dataloaders = {"malicious": malicious_validloader, "benign": benign_validloader} 329 | test_dataloaders = {"malicious": malicious_testloader, "benign": benign_testloader} 330 | 331 | # return the dataloaders in a dictionary 332 | return train_dataloaders, valid_dataloaders, 
test_dataloaders, _num_features 333 | 334 | 335 | if __name__ == "__main__": 336 | print("I am a module to be imported by others, testing some functionalities here") 337 | from utils.utils import load_parameters 338 | 339 | _parameters = load_parameters("../parameters.ini") 340 | train_data, valid_data, test_data, num_features = load_data(parameters=_parameters) 341 | dset_1 = train_data["malicious"].dataset 342 | dset_2 = train_data["benign"].dataset 343 | print("A sample from malicious dataset has ", sum(dset_1[0][0]), " features, with label", 344 | dset_1[0][1]) 345 | print("A sample from benign dataset has ", sum(dset_2[0][0]), " features, with label", 346 | dset_2[0][1]) 347 | print("Feature space is of %d-dimensionality" % dset_1[10][0].size(), num_features) 348 | -------------------------------------------------------------------------------- /framework.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | Python module for performing adversarial training for malware detection 4 | """ 5 | import os 6 | import torch 7 | import random 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | from utils.utils import load_parameters, stack_tensors 12 | from datasets.datasets import load_data 13 | from inner_maximizers.inner_maximizers import inner_maximizer 14 | from nets.ff_classifier import build_ff_classifier 15 | from blindspot_coverage.covering_number import CoveringNumber 16 | import losswise 17 | import time 18 | import json 19 | import numpy as np 20 | 21 | # Step 1. Load configuration 22 | parameters = load_parameters("parameters.ini") 23 | is_cuda = eval(parameters["general"]["is_cuda"]) 24 | if is_cuda: 25 | os.environ["CUDA_VISIBLE_DEVICES"] = parameters["general"]["gpu_device"] 26 | 27 | assertion_message = "Set this flag off to train models." 
28 | assert eval(parameters['dataset']['generate_feature_vector_files']) is False, assertion_message 29 | 30 | log_interval = int(parameters["general"]["log_interval"]) 31 | num_epochs = int(parameters["hyperparam"]["ff_num_epochs"]) 32 | is_losswise = eval(parameters["general"]["is_losswise"]) 33 | is_synthetic_dataset = eval(parameters["general"]["is_synthetic_dataset"]) 34 | 35 | training_method = parameters["general"]["training_method"] 36 | evasion_method = parameters["general"]["evasion_method"] 37 | experiment_suffix = parameters["general"]["experiment_suffix"] 38 | experiment_name = "[training:%s|evasion:%s]_%s" % (training_method, evasion_method, 39 | experiment_suffix) 40 | 41 | print("Training Method:%s, Evasion Method:%s" % (training_method, evasion_method)) 42 | 43 | seed_val = int(parameters["general"]["seed"]) 44 | 45 | random.seed(seed_val) 46 | torch.manual_seed(seed_val) 47 | np.random.seed(seed_val) 48 | 49 | if is_losswise: 50 | losswise_key = parameters['general']['losswise_api_key'] 51 | 52 | if losswise_key == 'None': 53 | raise Exception("Must set API key in the parameters file to use losswise") 54 | 55 | losswise.set_api_key(losswise_key) 56 | 57 | session = losswise.Session(tag=experiment_name, max_iter=200) 58 | graph_loss = session.graph("loss", kind="min") 59 | graph_accuracy = session.graph("accuracy", kind="max") 60 | graph_coverage = session.graph("coverage", kind="max") 61 | graph_evasion = session.graph("evasion", kind="min") 62 | 63 | evasion_iterations = int(parameters['hyperparam']['evasion_iterations']) 64 | 65 | save_every_epoch = eval(parameters['general']['save_every_epoch']) 66 | 67 | train_model_from_scratch = eval(parameters['general']['train_model_from_scratch']) 68 | load_model_weights = eval(parameters['general']['load_model_weights']) 69 | model_weights_path = parameters['general']['model_weights_path'] 70 | 71 | # Step 2. Load training and test data 72 | train_dataloader_dict, valid_dataloader_dict, test_dataloader_dict, num_features = load_data( 73 | parameters) 74 | 75 | # set the bscn metric 76 | num_samples = len(train_dataloader_dict["malicious"].dataset) 77 | bscn = CoveringNumber(num_samples, num_epochs * num_samples, 78 | train_dataloader_dict["malicious"].batch_size) 79 | 80 | if load_model_weights: 81 | print("Loading Model Weights From: {path}".format(path=model_weights_path)) 82 | model = torch.load(model_weights_path) 83 | 84 | else: 85 | # Step 3. Construct neural net (N) - this can be replaced with any model of interest 86 | model = build_ff_classifier( 87 | input_size=num_features, 88 | hidden_1_size=int(parameters["hyperparam"]["ff_h1"]), 89 | hidden_2_size=int(parameters["hyperparam"]["ff_h2"]), 90 | hidden_3_size=int(parameters["hyperparam"]["ff_h3"])) 91 | # gpu related setups 92 | if is_cuda: 93 | torch.cuda.manual_seed(int(parameters["general"]["seed"])) 94 | model = model.cuda() 95 | 96 | # Step 4. Define loss function and optimizer for training (back propagation block in Fig 2.) 97 | loss_fct = nn.NLLLoss(reduce=False) 98 | optimizer = optim.Adam(model.parameters(), lr=float(parameters["hyperparam"]["ff_learning_rate"])) 99 | 100 | 101 | def train(epoch): 102 | model.train() 103 | total_correct = 0. 104 | total_loss = 0. 105 | total = 0. 
106 | 107 | current_time = time.time() 108 | 109 | if is_synthetic_dataset: 110 | # since generation of synthetic data set is random, we'd like them to be the same over epochs 111 | torch.manual_seed(seed_val) 112 | random.seed(seed_val) 113 | 114 | for batch_idx, ((bon_x, bon_y), (mal_x, mal_y)) in enumerate( 115 | zip(train_dataloader_dict["benign"], train_dataloader_dict["malicious"])): 116 | # Check for adversarial learning 117 | mal_x = inner_maximizer( 118 | mal_x, mal_y, model, loss_fct, iterations=evasion_iterations, method=training_method) 119 | 120 | # stack input 121 | if is_cuda: 122 | x = Variable(stack_tensors(bon_x, mal_x).cuda()) 123 | y = Variable(stack_tensors(bon_y, mal_y).cuda()) 124 | else: 125 | x = Variable(stack_tensors(bon_x, mal_x)) 126 | y = Variable(stack_tensors(bon_y, mal_y)) 127 | 128 | # forward pass 129 | y_model = model(x) 130 | 131 | # backward pass 132 | optimizer.zero_grad() 133 | loss = loss_fct(y_model, y).mean() 134 | loss.backward() 135 | optimizer.step() 136 | 137 | # predict pass 138 | _, predicted = torch.topk(y_model, k=1) 139 | correct = predicted.data.eq(y.data.view_as(predicted.data)).cpu().sum() 140 | 141 | # metrics 142 | total_loss += loss.data[0] * len(y) 143 | total_correct += correct 144 | total += len(y) 145 | 146 | bscn.update_numerator_batch(batch_idx, mal_x) 147 | 148 | if batch_idx % log_interval == 0: 149 | 150 | print("Time Taken:", time.time() - current_time) 151 | current_time = time.time() 152 | 153 | print( 154 | "Train Epoch ({}) | Batch ({}) | [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}\tBatch Accuracy: {:.1f}%\t BSCN: {:.12f}". 155 | format(epoch, batch_idx, batch_idx * len(x), 156 | len(train_dataloader_dict["malicious"].dataset) + 157 | len(train_dataloader_dict["benign"].dataset), 158 | 100. * batch_idx / len(train_dataloader_dict["benign"]), loss.data[0], 159 | 100. * correct / len(y), bscn.ratio())) 160 | 161 | if is_losswise: 162 | graph_accuracy.append(epoch, { 163 | "train_accuracy_%s" % experiment_name: 100. 
* total_correct / total 164 | }) 165 | graph_loss.append(epoch, {"train_loss_%s" % experiment_name: total_loss / total}) 166 | graph_coverage.append(epoch, {"train_coverage_%s" % experiment_name: bscn.ratio()}) 167 | 168 | model_filename = "{name}_epoch_{e}".format(name=experiment_name, e=epoch) 169 | 170 | if save_every_epoch: 171 | torch.save(model, os.path.join("model_weights", model_filename)) 172 | 173 | 174 | def check_one_category(category="benign", is_validate=False, is_evade=False, 175 | evade_method='dfgsm_k'): 176 | """ 177 | test the model in terms of loss and accuracy on category, this function also allows to perform perturbation 178 | with respect to loss to evade 179 | :param category: benign or malicious dataset 180 | :param is_validate: validation or testing dataset 181 | :param is_evade: to perform evasion or not 182 | :param evade_method: evasion method (we can use on of the inner maximier methods), it is only relevant if is_evade 183 | is True 184 | :return: 185 | """ 186 | model.eval() 187 | total_loss = 0 188 | total_correct = 0 189 | total = 0 190 | evasion_mode = "" 191 | 192 | if is_synthetic_dataset: 193 | # since generation of synthetic data set is random, we'd like them to be the same over epochs 194 | torch.manual_seed(seed_val) 195 | random.seed(seed_val) 196 | 197 | if is_validate: 198 | dataloader = valid_dataloader_dict[category] 199 | else: 200 | dataloader = test_dataloader_dict[category] 201 | 202 | for batch_idx, (x, y) in enumerate(dataloader): 203 | # 204 | if is_evade: 205 | x = inner_maximizer( 206 | x, y, model, loss_fct, iterations=evasion_iterations, method=evade_method) 207 | evasion_mode = "(evasion using %s)" % evade_method 208 | # stack input 209 | if is_cuda: 210 | x = Variable(x.cuda()) 211 | y = Variable(y.cuda()) 212 | else: 213 | x = Variable(x) 214 | y = Variable(y) 215 | 216 | # forward pass 217 | y_model = model(x) 218 | 219 | # loss pass 220 | loss = loss_fct(y_model, y).mean() 221 | 222 | # predict pass 223 | _, predicted = torch.topk(y_model, k=1) 224 | correct = predicted.data.eq(y.data.view_as(predicted.data)).cpu().sum() 225 | 226 | # metrics 227 | total_loss += loss.data[0] * len(y) 228 | total_correct += correct 229 | total += len(y) 230 | 231 | print("{} set for {} {}: Average Loss: {:.4f}, Accuracy: {:.2f}%".format( 232 | "Valid" if is_validate else "Test", category, evasion_mode, total_loss / total, 233 | total_correct * 100. 
/ total)) 234 | 235 | return total_loss, total_correct, total 236 | 237 | 238 | def test(epoch, is_validate=False): 239 | """ 240 | Function to be used for both testing and validation 241 | :param epoch: current epoch 242 | :param is_validate: is the testing done on the validation dataset 243 | :return: average total loss, dictionary of the metrics for both bon and mal samples 244 | """ 245 | # test for accuracy and loss 246 | bon_total_loss, bon_total_correct, bon_total = check_one_category( 247 | category="benign", is_evade=False, is_validate=is_validate) 248 | mal_total_loss, mal_total_correct, mal_total = check_one_category( 249 | category="malicious", is_evade=False, is_validate=is_validate) 250 | 251 | # test for evasion on malicious samples 252 | evade_mal_total_loss, evade_mal_total_correct, evade_mal_total = check_one_category( 253 | category="malicious", is_evade=True, evade_method=evasion_method, is_validate=is_validate) 254 | 255 | total_loss = bon_total_loss + mal_total_loss 256 | total_correct = bon_total_correct + mal_total_correct 257 | total = bon_total + mal_total 258 | 259 | dataset_type = "valid" if is_validate else "test" 260 | 261 | print("{} set overall: Average Loss: {:.4f}, Accuracy: {:.2f}%".format( 262 | dataset_type, total_loss / total, total_correct * 100. / total)) 263 | 264 | if is_losswise: 265 | graph_accuracy.append( 266 | epoch, { 267 | "%s_accuracy_%s" % (dataset_type, experiment_name): 100. * total_correct / total 268 | }) 269 | graph_loss.append(epoch, { 270 | "%s_loss_%s" % (dataset_type, experiment_name): total_loss / total 271 | }) 272 | graph_evasion.append( 273 | epoch, { 274 | "%s_evasion_%s" % (dataset_type, experiment_name): 275 | 100 * (evade_mal_total - evade_mal_total_correct) / evade_mal_total 276 | }) 277 | 278 | metrics = { 279 | "bscn_ratio": bscn.ratio(), 280 | "mal": { 281 | "total_loss": mal_total_loss, 282 | "total_correct": mal_total_correct, 283 | "total": mal_total, 284 | "evasion": { 285 | "total_loss": evade_mal_total_loss, 286 | "total_correct": evade_mal_total_correct, 287 | "total": evade_mal_total 288 | } 289 | }, 290 | "bon": { 291 | "total_loss": bon_total_loss, 292 | "total_correct": bon_total_correct, 293 | "total_evade": None, 294 | "total": bon_total 295 | } 296 | } 297 | print(metrics) 298 | 299 | return (bon_total_loss + max(mal_total_loss, evade_mal_total_loss)) / total, metrics 300 | 301 | 302 | if __name__ == "__main__": 303 | 304 | if not os.path.exists("result_files"): 305 | os.mkdir("result_files") 306 | 307 | _metrics = None 308 | # note: do not reset `session` here; the losswise session (if enabled) was created at module level 309 | if train_model_from_scratch: 310 | best_valid_loss = float("inf") 311 | for _epoch in range(num_epochs): 312 | # train 313 | train(_epoch) 314 | 315 | # validate 316 | valid_loss, _ = test(_epoch, is_validate=True) 317 | 318 | # keep the best parameters w.r.t validation and check the test set 319 | if best_valid_loss > valid_loss: 320 | best_valid_loss = valid_loss 321 | _, _metrics = test(_epoch, is_validate=False) 322 | 323 | bscn_to_save = bscn.ratio() 324 | with open(os.path.join("result_files", "%s_bscn.txt" % experiment_name), "w") as f: 325 | f.write(str(bscn_to_save)) 326 | 327 | torch.save(model, os.path.join("helper_files", "%s-model.pt" % experiment_name)) 328 | elif _epoch % log_interval == 0: 329 | test(_epoch, is_validate=False) 330 | 331 | else: 332 | _, _metrics = test(0) 333 | 334 | with open(os.path.join("result_files", experiment_name + ".json"), "w") as result_file: 335 | json.dump(_metrics, result_file) 336 | 337 | if is_losswise: 338 |
session.done() 339 | -------------------------------------------------------------------------------- /generate_vectors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from datasets.datasets import load_data 5 | from utils.utils import load_parameters 6 | 7 | malicious_vector_filepath = "/home/alexmalmeng/saved_feature_vectors/malicious/" 8 | benign_vector_filepath = "/home/alexmalmeng/saved_feature_vectors/benign/" 9 | 10 | parameters = load_parameters("parameters.ini") 11 | 12 | assertion_message = "Flag must be on to generate. Changes return of PortableExecutableDataset." 13 | assert eval(parameters['dataset']['generate_feature_vector_files']) is True, assertion_message 14 | 15 | train_dataloader_dict, valid_dataloader_dict, test_dataloader_dict, num_features = load_data( 16 | parameters) 17 | 18 | print( 19 | len(train_dataloader_dict['malicious'].dataset) + len(test_dataloader_dict['malicious'].dataset) 20 | ) 21 | print(len(train_dataloader_dict['benign'].dataset) + len(test_dataloader_dict['benign'].dataset)) 22 | 23 | for data_dict in [train_dataloader_dict, valid_dataloader_dict, test_dataloader_dict]: 24 | 25 | for filetype in data_dict: 26 | print(filetype) 27 | dataloader = data_dict[filetype] 28 | 29 | for index, data in enumerate(dataloader): 30 | print(index, filetype) 31 | vector, label, filepath = data 32 | 33 | filename = filepath[0].split("/")[-1] 34 | 35 | if filetype == 'malicious': 36 | pickle.dump(vector, open(os.path.join(malicious_vector_filepath, filename), 'wb')) 37 | else: 38 | pickle.dump(vector, open(os.path.join(benign_vector_filepath, filename), 'wb')) 39 | -------------------------------------------------------------------------------- /helper_files/imported_function_to_index_mapping_dict.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/helper_files/imported_function_to_index_mapping_dict.p -------------------------------------------------------------------------------- /helper_files/lief_parseable_cnet_programs.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/helper_files/lief_parseable_cnet_programs.p -------------------------------------------------------------------------------- /helper_files/linux_environment.yml: -------------------------------------------------------------------------------- 1 | name: nn_mal 2 | dependencies: 3 | - python=3.6 4 | - numpy 5 | - pandas 6 | - pip: 7 | - pybloomfiltermmap 8 | - losswise 9 | - sklearn 10 | - lief 11 | - scipy 12 | - http://download.pytorch.org/whl/cu80/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl 13 | - torchvision -------------------------------------------------------------------------------- /helper_files/osx_environment.yml: -------------------------------------------------------------------------------- 1 | name: nn_mal 2 | dependencies: 3 | - python=3.6 4 | - numpy 5 | - pandas 6 | - pip: 7 | - pybloomfiltermmap 8 | - losswise 9 | - sklearn 10 | - lief 11 | - scipy 12 | - http://download.pytorch.org/whl/torch-0.3.0.post4-cp36-cp36m-macosx_10_7_x86_64.whl 13 | - torchvision -------------------------------------------------------------------------------- /inner_maximizers/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/inner_maximizers/__init__.py -------------------------------------------------------------------------------- /inner_maximizers/inner_maximizers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | Python module for implementing inner maximizers for robust adversarial training 4 | (Table I in the paper) 5 | """ 6 | import torch 7 | from torch.autograd import Variable 8 | from utils.utils import or_float_tensors, xor_float_tensors, clip_tensor 9 | import numpy as np 10 | 11 | 12 | # helper function 13 | def round_x(x, alpha=0.5): 14 | """ 15 | rounds x by thresholding it according to alpha which can be a scalar or vector 16 | :param x: 17 | :param alpha: threshold parameter 18 | :return: a float tensor of 0s and 1s. 19 | """ 20 | return (x > alpha).float() 21 | 22 | 23 | def get_x0(x, is_sample=False): 24 | """ 25 | Helper function to randomly initialize the the inner maximizer algos 26 | randomize such that the functionality is preserved. 27 | Functionality is preserved by maintaining the features present in x 28 | :param x: training sample 29 | :param is_sample: flag to sample randomly from feasible area or return just x 30 | :return: randomly sampled feasible version of x 31 | """ 32 | if is_sample: 33 | rand_x = round_x(torch.rand(x.size())) 34 | if x.is_cuda: 35 | rand_x = rand_x.cuda() 36 | return or_float_tensors(x, rand_x) 37 | else: 38 | return x 39 | 40 | 41 | def dfgsm_k(x, 42 | y, 43 | model, 44 | loss_fct, 45 | k=25, 46 | epsilon=0.02, 47 | alpha=0.5, 48 | is_report_loss_diff=True, 49 | use_sample=False): 50 | """ 51 | FGSM^k with deterministic rounding 52 | :param y: 53 | :param x: (tensor) feature vector 54 | :param model: nn model 55 | :param loss_fct: loss function 56 | :param k: num of steps 57 | :param epsilon: update value in each direction 58 | :param alpha: 59 | :param is_report_loss_diff: 60 | :param use_sample: 61 | :return: the adversarial version of x according to dfgsm_k (tensor) 62 | """ 63 | # some book-keeping 64 | if next(model.parameters()).is_cuda: 65 | x = x.cuda() 66 | y = y.cuda() 67 | y = Variable(y) 68 | 69 | # compute natural loss 70 | loss_natural = loss_fct(model(Variable(x)), y).data 71 | 72 | # initialize starting point 73 | x_next = get_x0(x, use_sample) 74 | 75 | # multi-step 76 | for t in range(k): 77 | # forward pass 78 | x_var = Variable(x_next, requires_grad=True) 79 | y_model = model(x_var) 80 | loss = loss_fct(y_model, y) 81 | 82 | # compute gradient 83 | grad_vars = torch.autograd.grad(loss.mean(), x_var) 84 | 85 | # find the next sample 86 | x_next = x_next + epsilon * torch.sign(grad_vars[0].data) 87 | 88 | # projection 89 | x_next = clip_tensor(x_next) 90 | 91 | # rounding step 92 | x_next = round_x(x_next, alpha=alpha) 93 | 94 | # feasible projection 95 | x_next = or_float_tensors(x_next, x) 96 | 97 | # compute adversarial loss 98 | loss_adv = loss_fct(model(Variable(x_next)), y).data 99 | 100 | if is_report_loss_diff: 101 | print("Natural loss (%.4f) vs Adversarial loss (%.4f), Difference: (%.4f)" % 102 | (loss_natural.mean(), loss_adv.mean(), loss_adv.mean() - loss_natural.mean())) 103 | 104 | replace_flag = (loss_adv < loss_natural).unsqueeze(1).expand_as(x_next) 105 | x_next[replace_flag] = x[replace_flag] 106 | 107 | if x_next.is_cuda: 108 | x_next = x_next.cpu() 109 | 110 | 
return x_next 111 | 112 | 113 | def rfgsm_k(x, y, model, loss_fct, k=25, epsilon=0.02, is_report_loss_diff=True, use_sample=False): 114 | """ 115 | FGSM^k with randomized rounding 116 | :param x: (tensor) feature vector 117 | :param y: 118 | :param model: nn model 119 | :param loss_fct: loss function 120 | :param k: num of steps 121 | :param epsilon: update value in each direction 122 | :param is_report_loss_diff: 123 | :param use_sample: 124 | :return: the adversarial version of x according to rfgsm_k (tensor) 125 | """ 126 | # some book-keeping 127 | if next(model.parameters()).is_cuda: 128 | x = x.cuda() 129 | y = y.cuda() 130 | y = Variable(y) 131 | 132 | # compute natural loss 133 | loss_natural = loss_fct(model(Variable(x)), y).data 134 | 135 | # initialize starting point 136 | x_next = get_x0(x, use_sample) 137 | 138 | # multi-step with gradients 139 | for t in range(k): 140 | # forward pass 141 | x_var = Variable(x_next, requires_grad=True) 142 | y_model = model(x_var) 143 | loss = loss_fct(y_model, y) 144 | 145 | # compute gradient 146 | grad_vars = torch.autograd.grad(loss.mean(), x_var) 147 | 148 | # find the next sample 149 | x_next = x_next + epsilon * torch.sign(grad_vars[0].data) 150 | 151 | # projection 152 | x_next = clip_tensor(x_next) 153 | 154 | # rounding step 155 | alpha = torch.rand(x_next.size()) 156 | if x_next.is_cuda: 157 | alpha = alpha.cuda() 158 | x_next = round_x(x_next, alpha=alpha) 159 | 160 | # feasible projection 161 | x_next = or_float_tensors(x_next, x) 162 | 163 | # compute adversarial loss 164 | loss_adv = loss_fct(model(Variable(x_next)), y).data 165 | 166 | if is_report_loss_diff: 167 | print("Natural loss (%.4f) vs Adversarial loss (%.4f), Difference: (%.4f)" % 168 | (loss_natural.mean(), loss_adv.mean(), loss_adv.mean() - loss_natural.mean())) 169 | 170 | replace_flag = (loss_adv < loss_natural).unsqueeze(1).expand_as(x_next) 171 | x_next[replace_flag] = x[replace_flag] 172 | 173 | if x_next.is_cuda: 174 | x_next = x_next.cpu() 175 | 176 | return x_next 177 | 178 | 179 | def bga_k(x, y, model, loss_fct, k=25, is_report_loss_diff=True, use_sample=False): 180 | """ 181 | Multi-step bit gradient ascent 182 | :param x: (tensor) feature vector 183 | :param y: 184 | :param model: nn model 185 | :param loss_fct: loss function 186 | :param k: num of steps 187 | :param is_report_loss_diff: 188 | :param use_sample: 189 | :return: the adversarial version of x according to bga_k (tensor) 190 | """ 191 | # some book-keeping 192 | sqrt_m = torch.from_numpy(np.sqrt([x.size()[1]])).float() 193 | 194 | if next(model.parameters()).is_cuda: 195 | x = x.cuda() 196 | y = y.cuda() 197 | sqrt_m = sqrt_m.cuda() 198 | 199 | y = Variable(y) 200 | 201 | # compute natural loss 202 | loss_natural = loss_fct(model(Variable(x)), y).data 203 | 204 | # keeping worst loss 205 | loss_worst = loss_natural.clone() 206 | x_worst = x.clone() 207 | 208 | # multi-step with gradients 209 | loss = None 210 | x_var = None 211 | x_next = None 212 | for t in range(k): 213 | if t == 0: 214 | # initialize starting point 215 | x_next = get_x0(x, use_sample) 216 | else: 217 | # compute gradient 218 | grad_vars = torch.autograd.grad(loss.mean(), x_var) 219 | grad_data = grad_vars[0].data 220 | 221 | # compute the updates 222 | x_update = (sqrt_m * (1. - 2. 
* x_next) * grad_data >= torch.norm( 223 | grad_data, 2, 1).unsqueeze(1).expand_as(x_next)).float() 224 | 225 | # find the next sample with projection to the feasible set 226 | x_next = xor_float_tensors(x_update, x_next) 227 | x_next = or_float_tensors(x_next, x) 228 | 229 | # forward pass 230 | x_var = Variable(x_next, requires_grad=True) 231 | y_model = model(x_var) 232 | loss = loss_fct(y_model, y) 233 | 234 | # update worst loss and adversarial samples 235 | replace_flag = (loss.data > loss_worst) 236 | loss_worst[replace_flag] = loss.data[replace_flag] 237 | x_worst[replace_flag.unsqueeze(1).expand_as(x_worst)] = x_next[replace_flag.unsqueeze(1) 238 | .expand_as(x_worst)] 239 | 240 | if is_report_loss_diff: 241 | print("Natural loss (%.4f) vs Adversarial loss (%.4f), Difference: (%.4f)" % 242 | (loss_natural.mean(), loss_worst.mean(), loss_worst.mean() - loss_natural.mean())) 243 | 244 | if x_worst.is_cuda: 245 | x_worst = x_worst.cpu() 246 | 247 | return x_worst 248 | 249 | 250 | def bca_k(x, y, model, loss_fct, k=25, is_report_loss_diff=True, use_sample=False): 251 | """ 252 | Multi-step bit coordinate ascent 253 | :param use_sample: 254 | :param is_report_loss_diff: 255 | :param y: 256 | :param x: (tensor) feature vector 257 | :param model: nn model 258 | :param loss_fct: loss function 259 | :param k: num of steps 260 | :return: the adversarial version of x according to bca_k (tensor) 261 | """ 262 | if next(model.parameters()).is_cuda: 263 | x = x.cuda() 264 | y = y.cuda() 265 | 266 | y = Variable(y) 267 | 268 | # compute natural loss 269 | loss_natural = loss_fct(model(Variable(x)), y).data 270 | 271 | # keeping worst loss 272 | loss_worst = loss_natural.clone() 273 | x_worst = x.clone() 274 | 275 | # multi-step with gradients 276 | loss = None 277 | x_var = None 278 | x_next = None 279 | for t in range(k): 280 | if t == 0: 281 | # initialize starting point 282 | x_next = get_x0(x, use_sample) 283 | else: 284 | # compute gradient 285 | grad_vars = torch.autograd.grad(loss.mean(), x_var) 286 | grad_data = grad_vars[0].data 287 | 288 | # compute the updates (can be made more efficient than this) 289 | aug_grad = (1. - 2. 
* x_next) * grad_data 290 | val, _ = torch.topk(aug_grad, 1) 291 | x_update = (aug_grad >= val.expand_as(aug_grad)).float() 292 | 293 | # find the next sample with projection to the feasible set 294 | x_next = xor_float_tensors(x_update, x_next) 295 | x_next = or_float_tensors(x_next, x) 296 | 297 | # forward pass 298 | x_var = Variable(x_next, requires_grad=True) 299 | y_model = model(x_var) 300 | loss = loss_fct(y_model, y) 301 | 302 | # update worst loss and adversarial samples 303 | replace_flag = (loss.data > loss_worst) 304 | loss_worst[replace_flag] = loss.data[replace_flag] 305 | x_worst[replace_flag.unsqueeze(1).expand_as(x_worst)] = x_next[replace_flag.unsqueeze(1) 306 | .expand_as(x_worst)] 307 | 308 | if is_report_loss_diff: 309 | print("Natural loss (%.4f) vs Adversarial loss (%.4f), Difference: (%.4f)" % 310 | (loss_natural.mean(), loss_worst.mean(), loss_worst.mean() - loss_natural.mean())) 311 | 312 | if x_worst.is_cuda: 313 | x_worst = x_worst.cpu() 314 | 315 | return x_worst 316 | 317 | 318 | def grosse_k(x, y, model, loss_fct, k=25, is_report_loss_diff=True, use_sample=False): 319 | """ 320 | Multi-step bit coordinate ascent using gradient of output, advancing in direction of maximal change 321 | :param use_sample: 322 | :param is_report_loss_diff: 323 | :param loss_fct: 324 | :param y: 325 | :param x: (tensor) feature vector 326 | :param model: nn model 327 | :param k: num of steps 328 | :return adversarial version of x (tensor) 329 | """ 330 | 331 | if next(model.parameters()).is_cuda: 332 | x = x.cuda() 333 | y = y.cuda() 334 | 335 | y = Variable(y) 336 | 337 | # compute natural loss 338 | loss_natural = loss_fct(model(Variable(x)), y).data 339 | 340 | # keeping worst loss 341 | loss_worst = loss_natural.clone() 342 | x_worst = x.clone() 343 | 344 | output = None 345 | x_var = None 346 | x_next = None 347 | for t in range(k): 348 | if t == 0: 349 | # initialize starting point 350 | x_next = get_x0(x, use_sample) 351 | else: 352 | grad_vars = torch.autograd.grad(output[:, 0].mean(), x_var) 353 | grad_data = grad_vars[0].data 354 | 355 | # Only consider gradients for points of 0 value 356 | aug_grad = (1. 
- x_next) * grad_data 357 | val, _ = torch.topk(aug_grad, 1) 358 | x_update = (aug_grad >= val.expand_as(aug_grad)).float() 359 | 360 | # find the next sample with projection to the feasible set 361 | x_next = xor_float_tensors(x_update, x_next) 362 | x_next = or_float_tensors(x_next, x) 363 | 364 | x_var = Variable(x_next, requires_grad=True) 365 | output = model(x_var) 366 | 367 | loss = loss_fct(output, y) 368 | 369 | # update worst loss and adversarial samples 370 | replace_flag = (loss.data > loss_worst) 371 | loss_worst[replace_flag] = loss.data[replace_flag] 372 | x_worst[replace_flag.unsqueeze(1).expand_as(x_worst)] = x_next[replace_flag.unsqueeze(1) 373 | .expand_as(x_worst)] 374 | 375 | if is_report_loss_diff: 376 | print("Natural loss (%.4f) vs Adversarial loss (%.4f), Difference: (%.4f)" % 377 | (loss_natural.mean(), loss_worst.mean(), loss_worst.mean() - loss_natural.mean())) 378 | 379 | if x_worst.is_cuda: 380 | x_worst = x_worst.cpu() 381 | 382 | return x_worst 383 | 384 | 385 | def inner_maximizer(x, y, model, loss_fct, iterations=100, method='natural'): 386 | """ 387 | A wrapper function for the above algorithms 388 | :param iterations: 389 | :param x: 390 | :param y: 391 | :param model: 392 | :param loss_fct: 393 | :param method: one of 'dfgsm_k', 'rfgsm_k', 'bga_k', 'bca_k', 'grosse', or 'natural' 394 | :return: adversarial examples 395 | """ 396 | if method == 'dfgsm_k': 397 | return dfgsm_k(x, y, model, loss_fct, k=iterations) 398 | elif method == 'rfgsm_k': 399 | return rfgsm_k(x, y, model, loss_fct, k=iterations) 400 | elif method == 'bga_k': 401 | return bga_k(x, y, model, loss_fct, k=iterations) 402 | elif method == 'bca_k': 403 | return bca_k(x, y, model, loss_fct, k=iterations) 404 | elif method == 'grosse': 405 | return grosse_k(x, y, model, loss_fct, k=iterations) 406 | elif method == 'natural': 407 | return x 408 | else: 409 | raise Exception('No such inner maximizer algorithm') 410 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/nets/__init__.py -------------------------------------------------------------------------------- /nets/ff_classifier.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | Python module for softmax binary classifier neural network 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | 9 | def init_weights(net): 10 | """ 11 | initialize the weights of a network 12 | :param net: 13 | :return: 14 | """ 15 | 16 | # init parameters 17 | def init_module(m): 18 | if type(m) == nn.Linear: 19 | nn.init.xavier_normal(m.weight.data) 20 | m.bias.data.zero_()  # xavier init is undefined for 1-D bias tensors 21 | 22 | net.apply(init_module) 23 | 24 | return net 25 | 26 | 27 | def build_ff_classifier(input_size, hidden_1_size, hidden_2_size, hidden_3_size, num_labels=2): 28 | """ 29 | Constructs a neural net binary classifier 30 | :param input_size: 31 | :param hidden_1_size: 32 | :param hidden_2_size: 33 | :param hidden_3_size: 34 | :param num_labels: 35 | :return: 36 | """ 37 | net = nn.Sequential( 38 | nn.Linear(input_size, hidden_1_size), 39 | nn.ReLU(), 40 | nn.Linear(hidden_1_size, hidden_2_size), 41 | nn.ReLU(), 42 | nn.Linear(hidden_2_size, hidden_3_size), 43 | nn.ReLU(), 44 | nn.Linear(hidden_3_size, num_labels), 45 | nn.LogSoftmax(dim=1)) 46 | 47 | return net 48 |
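A minimal smoke test (not part of the repository) for wiring the pieces above together: it builds the feed-forward classifier from `nets/ff_classifier.py` and perturbs a synthetic malicious batch with one of the inner maximizers from `inner_maximizers/inner_maximizers.py`. This is a sketch assuming the torch 0.3-era API pinned by the environment files; the batch size, feature count, and method choice are illustrative.

```
import torch
import torch.nn as nn

from nets.ff_classifier import build_ff_classifier
from inner_maximizers.inner_maximizers import inner_maximizer

num_features = 8
model = build_ff_classifier(num_features, 300, 300, 300)
loss_fct = nn.NLLLoss(reduce=False)  # per-sample losses, as in framework.py

# a batch of 4 synthetic "malicious" feature vectors (bits set w.p. 0.2),
# mirroring the synthetic dataset described in the README
x = (torch.rand(4, num_features) < 0.2).float()
y = torch.ones(4).long()  # MALICIOUS_LABEL = 1

# perturb the batch with one of the inner maximizers (Table I of the paper)
x_adv = inner_maximizer(x, y, model, loss_fct, iterations=50, method='dfgsm_k')
print((x_adv != x).float().sum())  # number of bits flipped by the attack
```

Note that each maximizer projects back onto the feasible set with `or_float_tensors`, so bits that are set in `x` stay set in `x_adv` and the malware's imports (functionality) are preserved.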
-------------------------------------------------------------------------------- /parameters.ini: -------------------------------------------------------------------------------- 1 | [dataset] 2 | benign_filepath = "PATH TO BENIGN SAVED FEATURE VECTORS" 3 | malicious_filepath = "PATH TO MALICIOUS SAVED FEATURE VECTORS" 4 | helper_filepath = ./helper_files/ 5 | malicious_files_list = None 6 | benign_files_list = lief_parseable_cnet_programs.p 7 | load_mapping_from_pickle = True 8 | pickle_mapping_file = imported_function_to_index_mapping_dict.p 9 | use_subset_of_data = True 10 | num_files_to_use = 19000 11 | num_features_to_use = None 12 | test_size_percent = 0.2 13 | generate_feature_vector_files = False 14 | use_saved_feature_vectors = True 15 | 16 | [general] 17 | is_synthetic_dataset = True 18 | is_cuda = False 19 | gpu_device = 0 20 | log_interval = 10 21 | seed = 1 22 | is_losswise = False 23 | losswise_api_key = None 24 | training_method = dfgsm_k 25 | evasion_method = dfgsm_k 26 | experiment_suffix = run_experiments 27 | save_every_epoch = False 28 | train_model_from_scratch = True 29 | load_model_weights = False 30 | model_weights_path = ./helper_files/[training:bga_k|evasion:natural]_run_experiments-model.pt 31 | num_workers = 10 32 | 33 | [hyperparam] 34 | starting_epoch = 0 35 | ff_h1 = 300 36 | ff_h2 = 300 37 | ff_h3 = 300 38 | ff_learning_rate = .001 39 | ff_num_epochs = 200 40 | evasion_iterations = 50 41 | training_batch_size = 8 42 | test_batch_size = 8 43 | 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | losswise==1.3 2 | numpy==1.13.3 3 | pandas==0.21.0 4 | ipdb==0.10.3 5 | lief==0.8.3.post3 6 | pybloomfiltermmap==0.3.15 7 | scikit_learn==0.19.1 8 | torch==0.3.0.post4 9 | -------------------------------------------------------------------------------- /run_experiments.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from utils.script_functions import set_parameter 3 | from os import system 4 | 5 | if __name__ == "__main__": 6 | parameters_filepath = "parameters.ini" 7 | 8 | # Keep all 6 training/evasion methods 9 | train_methods = ['natural', 'rfgsm_k', 'dfgsm_k', 'bga_k', 'bca_k', 'grosse'] 10 | evasion_methods = ['natural', 'rfgsm_k', 'dfgsm_k', 'bga_k', 'bca_k', 'grosse'] 11 | 12 | for train_method in train_methods: 13 | 14 | set_parameter(parameters_filepath, "general", "train_model_from_scratch", "True") 15 | set_parameter(parameters_filepath, "general", "load_model_weights", "False") 16 | set_parameter(parameters_filepath, "general", "experiment_suffix", "run_experiments") 17 | 18 | set_parameter(parameters_filepath, "general", "training_method", train_method) 19 | set_parameter(parameters_filepath, "general", "evasion_method", train_method) 20 | system("source activate nn_mal;python framework.py") 21 | 22 | for train_method in train_methods: 23 | model_filepath = "./helper_files/[training:{train_meth}|evasion:{train_meth}]_run_experiments-model.pt".format( 24 | train_meth=train_method) 25 | 26 | set_parameter(parameters_filepath, "general", "training_method", train_method) 27 | set_parameter(parameters_filepath, "general", "train_model_from_scratch", "False") 28 | set_parameter(parameters_filepath, "general", "load_model_weights", "True") 29 | set_parameter(parameters_filepath, "general", "model_weights_path", model_filepath) 30 | 31 | for evasion_method in evasion_methods: 32 |
            set_parameter(parameters_filepath, "general", "evasion_method", evasion_method)
            system("source activate nn_mal;python framework.py")
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ALFA-group/robust-adv-malware-detection/7f0761d5d1905374f12b426249625496424584a3/utils/__init__.py
--------------------------------------------------------------------------------
/utils/script_functions.py:
--------------------------------------------------------------------------------
"""
Python module for scripting helper functions
"""
from glob import glob
import configparser
import os
import re
import json
import pandas as pd


def set_parameter(parameters_filepath, section_name, parameter_name, parameter_value):
    """
    Set the specified parameter to the specified value and write back to the *.ini file
    :param parameters_filepath: filename (absolute path)
    :param section_name: section name under which the parameter lives
    :param parameter_name: parameter name
    :param parameter_value: target value
    :return:
    """
    conf_parameters = configparser.ConfigParser()
    conf_parameters.read(parameters_filepath, encoding="UTF-8")
    conf_parameters.set(section_name, parameter_name, parameter_value)
    with open(parameters_filepath, 'w') as config_file:
        conf_parameters.write(config_file)


def df_2_tex(df, filepath):
    """
    Write a dataframe to a standalone .tex file
    :param df: dataframe to be converted into a tex table
    :param filepath: tex filepath
    :return:
    """
    tex_prefix = r"""\documentclass{standalone}
\usepackage{booktabs}
\begin{document}"""

    tex_suffix = r"""\end{document}"""

    with open(filepath, "w") as f:
        f.write(tex_prefix)
        f.write(df.to_latex(float_format="%.1f"))
        f.write(tex_suffix)


def file_rank(filename):
    """
    Assign a rank to a result file so files can be sorted in a fixed method order
    :param filename: result filename containing the [training:...|evasion:...] tag
    :return: integer rank
    """
    order = {'natural': 0, 'rfgsm_k': 2, 'dfgsm_k': 1, 'bga_k': 3, 'bca_k': 4, 'grosse': 5}

    training_method = re.search(r"\[training:.*\|", filename).group(0)[:-1].split(':')[-1]
    evasion_method = re.search(r"\|evasion:.*\]", filename).group(0)[:-1].split(':')[-1]

    return order[training_method] * 6 + order[evasion_method]


def create_tex_tables(filespath="../result_files"):
    """
    Create TeX tables from the results populated under `result_files`,
    which are generated by running `framework.py`.
    The tex files are stored under `result_files`.
    :param filespath: the path where the results in json are stored and the tex files are created
    :return:
    """

    # read the bscn files
    bscn_files = sorted(glob(os.path.join(filespath, "*.txt")), key=file_rank)

    # read the results files
    files = sorted(glob(os.path.join(filespath, "*.json")), key=file_rank)

    # dataframes
    bscn_df = pd.DataFrame()
    evasion_df = pd.DataFrame()
    accuracy_df = pd.DataFrame()
    afp_df = pd.DataFrame()
    bon_accuracy_df = pd.DataFrame()
    mal_accuracy_df = pd.DataFrame()
    mal_loss_df = pd.DataFrame()

    for idx, filename in enumerate(bscn_files):
        training_method = re.search(r"\[training:.*\|", filename).group(0)[:-1].split(':')[-1]
        evasion_method = re.search(r"\|evasion:.*\]", filename).group(0)[:-1].split(':')[-1]
        with open(filename, 'r') as f:
            bscn_val = float(f.read())
            print(training_method, evasion_method, bscn_val)
            bscn_df.loc[training_method, "bsn_ratio"] = bscn_val

    # normalize each BSCN value by that of the naturally trained model
    bscn_df = bscn_df.div(bscn_df.loc['natural'], axis=1)
    for idx, filename in enumerate(files):
        training_method = re.search(r"\[training:.*\|", filename).group(0)[:-1].split(':')[-1]
        evasion_method = re.search(r"\|evasion:.*\]", filename).group(0)[:-1].split(':')[-1]
        with open(filename, 'r') as f:
            metrics = json.load(f)
            evasion_df.loc[training_method, evasion_method] = (
                1 - metrics["mal"]["evasion"]["total_correct"] / metrics["mal"]["total"]) * 100.
            if training_method == evasion_method:
                afp_df.loc[training_method, 'accuracy'] = (
                    metrics["mal"]["total_correct"] + metrics["bon"]["total_correct"]) * 100. / (
                    metrics["mal"]["total"] + metrics["bon"]["total"])
                fp = (metrics["bon"]["total"] - metrics["bon"]["total_correct"])
                fn = (metrics["mal"]["total"] - metrics["mal"]["total_correct"])
                tp = (metrics["mal"]["total_correct"])
                tn = (metrics["bon"]["total_correct"])
                afp_df.loc[training_method, 'fpr'] = fp * 100. / (fp + tn)
                afp_df.loc[training_method, 'fnr'] = fn * 100. / (fn + tp)

            accuracy_df.loc[training_method, evasion_method] = (
                metrics["mal"]["total_correct"] + metrics["bon"]["total_correct"]) * 100. / (
                metrics["mal"]["total"] + metrics["bon"]["total"])

            bon_accuracy_df.loc[training_method, evasion_method] = (
                metrics["bon"]["total_correct"]) * 100. / (metrics["bon"]["total"])
            mal_accuracy_df.loc[training_method, evasion_method] = (
                metrics["mal"]["total_correct"]) * 100. / (metrics["mal"]["total"])
            mal_loss_df.loc[training_method, evasion_method] = metrics["mal"]["total_loss"] / (
                metrics["mal"]["total"])

    # tex file names
    bscn_tbl_file = os.path.join(filespath, "bscn_tbl.tex")
    evasion_tbl_file = os.path.join(filespath, "evasion_tbl.tex")
    accuracy_tbl_file = os.path.join(filespath, "accuracy_tbl.tex")
    afp_tbl_file = os.path.join(filespath, "afp_tbl.tex")
    bon_accuracy_tbl_file = os.path.join(filespath, "bon_accuracy_tbl.tex")
    mal_accuracy_tbl_file = os.path.join(filespath, "mal_accuracy_tbl.tex")
    mal_loss_tbl_file = os.path.join(filespath, "mal_loss_tbl.tex")
    # write the tex files
    df_2_tex(bscn_df, bscn_tbl_file)
    df_2_tex(evasion_df, evasion_tbl_file)
    df_2_tex(accuracy_df, accuracy_tbl_file)
    df_2_tex(afp_df, afp_tbl_file)
    df_2_tex(bon_accuracy_df, bon_accuracy_tbl_file)
    df_2_tex(mal_accuracy_df, mal_accuracy_tbl_file)
    df_2_tex(mal_loss_df, mal_loss_tbl_file)


if __name__ == '__main__':
    create_tex_tables()
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
# coding=utf-8
"""Python module for handy functions"""
import configparser
import torch


def load_parameters(parameters_filepath):
    """
    Load parameters from an *.ini file
    :param parameters_filepath: filename (absolute path)
    :return: nested dictionary of parameters
    """
    conf_parameters = configparser.ConfigParser()
    conf_parameters.read(parameters_filepath, encoding="UTF-8")
    nested_parameters = {s: dict(conf_parameters.items(s)) for s in conf_parameters.sections()}
    return nested_parameters


def stack_tensors(*args):
    """
    Stack an arbitrary number of tensors along the first dimension
    :param args: list of tensors
    :return: tensor stacking all the input tensors
    """
    return torch.cat(args, dim=0)


def or_float_tensors(x_1, x_2):
    """
    ORs two float tensors by converting them to byte and back.
    Note that byte() truncates the float to its integer part modulo 256,
    e.g., 0.0 ==> 0
          0.1 ==> 0
          1.1 ==> 1
          255.1 ==> 255
          256.1 ==> 0
    Consequently, this function maps float tensors of 1s to 1 and those of
    0s to 0, i.e., it is meant to be used on tensors of 0s and 1s.

    :param x_1: tensor one
    :param x_2: tensor two
    :return: float tensor of 0s and 1s.
    """
    return (x_1.byte() | x_2.byte()).float()


def xor_float_tensors(x_1, x_2):
    """
    XORs two float tensors by converting them to byte and back.
    Note that byte() truncates the float to its integer part modulo 256,
    e.g., 0.0 ==> 0
          0.1 ==> 0
          1.1 ==> 1
          255.1 ==> 255
          256.1 ==> 0
    Consequently, this function maps float tensors of 1s to 1 and those of
    0s to 0, i.e., it is meant to be used on tensors of 0s and 1s.

    :param x_1: tensor one
    :param x_2: tensor two
    :return: float tensor of 0s and 1s.
63 | """ 64 | return (x_1.byte() ^ x_2.byte()).float() 65 | 66 | 67 | def clip_tensor(x, lb=0., ub=1.): 68 | """ 69 | Clip a tensor to be within lb and ub 70 | :param x: 71 | :param lb: lower bound (scalar) 72 | :param ub: upper bound (scalar) 73 | :return: clipped version of x 74 | """ 75 | return torch.clamp(x, min=lb, max=ub) 76 | 77 | 78 | if __name__ == "__main__": 79 | print("a module to be imported by others, testing here") 80 | 81 | parameters = load_parameters("../parameters.ini") 82 | 83 | stacked_tensor = stack_tensors(torch.ones(5, 2), torch.zeros(5, 2)) 84 | 85 | ored_tensor = or_float_tensors(torch.ones(5), torch.zeros(5)) 86 | 87 | clipped_tensor = clip_tensor(2 * torch.rand(5) - 1) 88 | 89 | print(clipped_tensor) 90 | --------------------------------------------------------------------------------