├── core ├── __init__.py ├── post_calibration │ ├── __init__.py │ └── temperature_scaling.py ├── feature │ ├── dex2img │ │ ├── __init__.py │ │ └── dex2img.py │ ├── drebin │ │ └── __init__.py │ ├── opcodeseq │ │ ├── __init__.py │ │ └── opcodeseq.py │ ├── apiseq │ │ ├── __init__.py │ │ └── apiseq.py │ ├── multimodality │ │ ├── __init__.py │ │ └── multimodality.py │ └── __init__.py └── ensemble │ ├── __init__.py │ ├── ensemble.py │ ├── mc_dropout.py │ ├── anchor_ensemble_test.py │ ├── dataset_lib.py │ ├── bayesian_ensemble_test.py │ ├── model_hp.py │ ├── anchor_ensemble.py │ ├── bayesian_ensemble.py │ └── deep_ensemble.py ├── tools ├── __init__.py ├── progressbar │ ├── compat.py │ ├── __init__.py │ ├── examples.py │ ├── progressbar.py │ └── widgets.py ├── progressbar_wrapper.py ├── temporal.py └── metrics.py ├── experiments ├── __init__.py ├── drebin_ood_test.py ├── adv_main.py ├── oos_main.py ├── androzoo_main.py ├── drebin_main.py ├── oos.py ├── adv.py ├── drebin_ood.py └── drebin_dataset.py ├── .idea ├── .gitignore ├── misc.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── malware-uncertainty.iml ├── main.py ├── requirements ├── requirements.txt ├── test ├── dataset_lib_test.py ├── feature_extraction_test.py ├── model_lib_test.py ├── vanilla_test.py ├── mc_dropout_test.py ├── bayesian_ensemble_test.py └── deep_ensemble_test.py ├── config.py ├── conf ├── conf-server-ubuntu ├── README.md └── run.sh /core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /core/post_calibration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Default ignored files 3 | /workspace.xml -------------------------------------------------------------------------------- /core/feature/dex2img/__init__.py: -------------------------------------------------------------------------------- 1 | from core.feature.dex2img.dex2img import dex2img 2 | -------------------------------------------------------------------------------- /core/feature/drebin/__init__.py: -------------------------------------------------------------------------------- 1 | from core.feature.drebin.drebin import AxplorerMapping, \ 2 | get_drebin_feature, \ 3 | wrapper_load_features 4 | -------------------------------------------------------------------------------- /core/feature/opcodeseq/__init__.py: -------------------------------------------------------------------------------- 1 | from core.feature.opcodeseq.opcodeseq import feature_extr_wrapper, \ 2 | read_opcode, \ 3 | read_opcode_wrapper 4 | -------------------------------------------------------------------------------- /core/feature/apiseq/__init__.py: -------------------------------------------------------------------------------- 1 | from core.feature.apiseq.apiseq import get_api_sequence, \ 2 | load_feature, \ 3 | 
wrapper_load_feature, \ 4 | wrapper_mapping 5 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | 5 | def _main(): 6 | print("hello") 7 | return 0 8 | 9 | 10 | if __name__ == '__main__': 11 | print(_main()) -------------------------------------------------------------------------------- /core/feature/multimodality/__init__.py: -------------------------------------------------------------------------------- 1 | from core.feature.multimodality.multimodality import API_LIST, \ 2 | get_multimod_feature, \ 3 | load_feature, \ 4 | wrapper_load_features 5 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /requirements: -------------------------------------------------------------------------------- 1 | Tensorflow >= 2.1.0 2 | tensorflow-probability==0.9.0 3 | numpy >= 1.17.3 4 | scikit-learn >= 0.21.3 5 | pandas >= 1.0.4 6 | androguard == 3.3.5 7 | absl-py == 0.8.1 8 | Pillow 9 | pyelftools 10 | capstone 11 | python-magic 12 | pyyaml 13 | multiprocessing-logging -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.2.0 2 | tensorflow-probability==0.10.1 3 | numpy>=1.17.3 4 | scikit-learn==0.23.2 5 | pandas>=1.0.4 6 | androguard==3.3.5 7 | absl-py==0.8.1 8 | # python-magic-bin==0.4.14 9 | seaborn 10 | Pillow 11 | pyelftools 12 | capstone 13 | python-magic 14 | pyyaml 15 | multiprocessing-logging -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/malware-uncertainty.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /test/dataset_lib_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 8 | 9 | 10 | class MyTestCase(tf.test.TestCase, parameterized.TestCase): 11 | def test_something(self): 12 | x = np.random.rand(5,3) 13 | data = build_dataset_from_numerical_data(x, batch_size=1) 14 | for i, _x in enumerate(data): 15 | self.assertAllClose(_x[0, :], x[i, :]) 16 | 17 | y = np.random.choice(2, 5) 18 | dataset = build_dataset_from_numerical_data((x, y), batch_size=1) 19 | for i, (_x, _y) in 
enumerate(dataset): 20 | self.assertAllClose(_x[0, :], x[i, :]) 21 | self.assertEqual(_y, y[i]) 22 | 23 | 24 | if __name__ == '__main__': 25 | absltest.main() 26 | -------------------------------------------------------------------------------- /core/ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | from tools.utils import ensemble_method_scope as _ensemble_method_names 2 | from core.ensemble.bayesian_ensemble import BayesianEnsemble 3 | from core.ensemble.mc_dropout import MCDropout 4 | from core.ensemble.deep_ensemble import DeepEnsemble, WeightedDeepEnsemble 5 | from core.ensemble.vanilla import Vanilla 6 | from collections import namedtuple 7 | 8 | _Ensemble_methods = namedtuple('ensemble_methods', _ensemble_method_names) 9 | _ensemble_methods = _Ensemble_methods(vanilla=Vanilla, 10 | mc_dropout=MCDropout, 11 | bayesian=BayesianEnsemble, 12 | deep_ensemble=DeepEnsemble, 13 | weighted_ensemble=WeightedDeepEnsemble 14 | ) 15 | ensemble_method_scope_dict = dict(_ensemble_methods._asdict()) 16 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | import time 6 | import logging 7 | import multiprocessing_logging 8 | 9 | if sys.version_info[0] < 3: 10 | import ConfigParser as configparser 11 | else: 12 | import configparser 13 | 14 | config = configparser.ConfigParser() 15 | 16 | get = config.get 17 | config_dir = os.path.dirname(__file__) 18 | 19 | 20 | def parser_config(): 21 | config_file = os.path.join(config_dir, "conf") 22 | 23 | if not os.path.exists(config_file): 24 | sys.stderr.write("Error: Unable to find the config file!\n") 25 | sys.exit(1) 26 | 27 | # parse the configuration 28 | global config 29 | config.read_file(open(config_file)) 30 | 31 | 32 | parser_config() 33 | 34 | logging.basicConfig(level=logging.INFO, filename=os.path.join(config_dir, "log"), filemode="a", 35 | format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s', 36 | datefmt='%Y/%m/%d %H:%M:%S') 37 | ErrorHandler = logging.StreamHandler() 38 | ErrorHandler.setFormatter(logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s')) 39 | multiprocessing_logging.install_mp_handler() 40 | -------------------------------------------------------------------------------- /core/feature/__init__.py: -------------------------------------------------------------------------------- 1 | from core.feature.feature_extraction import DrebinFeature, \ 2 | OpcodeSeq, \ 3 | MultiModality, \ 4 | DexToImage, \ 5 | APISequence 6 | 7 | from collections import namedtuple 8 | from core.ensemble.model_lib import model_name_type_dict 9 | 10 | feature_type_scope_dict = { 11 | 'drebin': DrebinFeature, 12 | 'multimodality': MultiModality, 13 | 'opcodeseq': OpcodeSeq, 14 | 'dex2img': DexToImage, 15 | 'apiseq': APISequence 16 | } 17 | 18 | # bridge the gap between the feature extraction and dnn architecture 19 | _ARCH_TYPE = namedtuple('architectures', model_name_type_dict.keys()) 20 | _architecture_feature_extraction = _ARCH_TYPE(dnn='drebin', 21 | multimodalitynn='multimodality', 22 | text_cnn='opcodeseq', 23 | r2d2='dex2img', 24 | droidectc='apiseq' 25 | ) 26 | _architecture_feature_extraction_dict = dict(_architecture_feature_extraction._asdict()) 27 | feature_type_vs_architecture = dict(zip(_architecture_feature_extraction_dict.values(), 
28 | _architecture_feature_extraction_dict.keys())) 29 | -------------------------------------------------------------------------------- /experiments/drebin_ood_test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import itertools 5 | 6 | from tensorflow.compat.v1 import flags 7 | from absl.testing import absltest 8 | from absl.testing import parameterized 9 | from experiments.drebin_ood import run_experiment, feature_type_scope_dict, ensemble_method_scope_dict 10 | 11 | flags.DEFINE_integer('n_members', 1, 12 | 'The number of members for deep ensemble, weighted deep ensemble, and anchor ensemble') 13 | flags.DEFINE_integer('proc_numbers', 2, 14 | 'The number of threads for features extraction') 15 | 16 | 17 | # feature_types = feature_type_scope_dict.keys() 18 | # ensemble_types = ensemble_method_scope_dict.keys() 19 | feature_types = ['dex2img'] # 'drebin', 'opcodeseq', 'multimodality', 20 | ensemble_types = ['bayesian'] 21 | 22 | 23 | class TestOODExperiments(parameterized.TestCase): 24 | @parameterized.named_parameters( 25 | [('%s_%s' % p,) + p for p in itertools.product(feature_types, ensemble_types)]) 26 | def test_run(self, feature, ensemble): 27 | print('testing:', feature, ensemble) 28 | run_experiment(feature, ensemble, flags.FLAGS.n_members, flags.FLAGS.proc_numbers) 29 | 30 | 31 | if __name__ == '__main__': 32 | absltest.main() 33 | -------------------------------------------------------------------------------- /conf: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | project_root = F:\projects\malware-det\malware-uncertainty\ 3 | database_dir = F:\dataSet\android\ 4 | 5 | [drebin] 6 | dataset_name = drebin 7 | malware_dir = %(database_dir)s\%(dataset_name)s\malicious_samples 8 | benware_dir = %(database_dir)s\%(dataset_name)s\benign_samples 9 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s 10 | 11 | [androzoo_tesseract] 12 | dataset_name = androzoo_tesseract 13 | malware_dir = %(database_dir)s\%(dataset_name)s\malicious_samples 14 | benware_dir = %(database_dir)s\%(dataset_name)s\benign_samples 15 | date_stamp = %(database_dir)s\%(dataset_name)s\date_stamp.json 16 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s 17 | 18 | [oos] 19 | dataset_name = VirusShare_Android_APK_2013 20 | malware_dir = %(database_dir)s\%(dataset_name)s\malicious_samples 21 | benware_dir = %(database_dir)s\%(dataset_name)s\benign_samples 22 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s 23 | 24 | [adv] 25 | dataset_name = drebin_adv 26 | pristine_apk_dir = %(database_dir)s\%(dataset_name)s\pristine_samples 27 | perturbed_apk_dir = %(database_dir)s\%(dataset_name)s\perturbed_samples 28 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s 29 | 30 | [metadata] 31 | naive_data_pool = %(database_dir)s\naive_data 32 | 33 | [experiments] 34 | oos = %(project_root)s/save/oos/ 35 | adv = %(project_root)s/save/adv/ 36 | drebin = %(project_root)s/save/drebin/ 37 | androzoo = %(project_root)s/save/androzoo/ -------------------------------------------------------------------------------- /conf-server-ubuntu: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | project_root = /absolute/path/to/malware-uncertainty/ 3 | database_dir = /absolute/path/to/datasets 4 | 5 | [drebin] 6 | 
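# Note (added comment, not in the original conf): values such as
# %(database_dir)s and %(dataset_name)s are resolved by configparser's basic
# interpolation, and the [DEFAULT] keys above are visible in every section,
# so normally only project_root and database_dir need editing per machine.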
dataset_name = drebin 7 | malware_dir = %(database_dir)s/%(dataset_name)s/malicious_samples 8 | benware_dir = %(database_dir)s/%(dataset_name)s/benign_samples 9 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s 10 | 11 | [androzoo_tesseract] 12 | dataset_name = androzoo_tesseract 13 | malware_dir = %(database_dir)s/%(dataset_name)s/malicious_samples 14 | benware_dir = %(database_dir)s/%(dataset_name)s/benign_samples 15 | date_stamp = %(database_dir)s/%(dataset_name)s/date_stamp.json 16 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s 17 | 18 | [oos] 19 | dataset_name = VirusShare_Android_APK_2013 20 | malware_dir = %(database_dir)s/%(dataset_name)s/malicious_samples 21 | benware_dir = %(database_dir)s/%(dataset_name)s/benign_samples 22 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s 23 | 24 | [adv] 25 | dataset_name = drebin_adv 26 | pristine_apk_dir = %(database_dir)s/%(dataset_name)s/pristine_samples 27 | perturbed_apk_dir = %(database_dir)s/%(dataset_name)s/perturbed_samples 28 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s 29 | 30 | [metadata] 31 | naive_data_pool = %(database_dir)s/naive_data/ 32 | 33 | [experiments] 34 | oos = %(project_root)s/save/oos/ 35 | adv = %(project_root)s/save/adv/ 36 | drebin = %(project_root)s/save/drebin/ 37 | androzoo = %(project_root)s/save/androzoo/ -------------------------------------------------------------------------------- /test/feature_extraction_test.py: -------------------------------------------------------------------------------- 1 | """Tests for malware-uncertainty.core.feature""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from absl.testing import absltest 8 | from absl.testing import parameterized 9 | 10 | import numpy as np 11 | 12 | from core.feature import feature_type_scope_dict 13 | from config import config 14 | 15 | 16 | class MyTestCase(parameterized.TestCase): 17 | @parameterized.named_parameters( 18 | [("%s" % mtd_name,) + (mtd_name,) for mtd_name in [feature_type_scope_dict['apiseq']]]) 19 | def test_feature_extraction(self, feature_type): 20 | malware_dir_name = config.get('drebin', 'malware_dir') 21 | benware_dir_name = config.get('drebin', 'benware_dir') 22 | meta_data_saving_dir = config.get('drebin', 'intermediate_directory') 23 | naive_data_saving_dir = config.get('metadata', 'naive_data_pool') 24 | feature_extractor = feature_type(naive_data_saving_dir, meta_data_saving_dir, update=False) 25 | malware_features = feature_extractor.feature_extraction(malware_dir_name) 26 | benign_features = feature_extractor.feature_extraction(benware_dir_name) 27 | features = malware_features + benign_features 28 | gt_labels = np.zeros((len(malware_features) + len(benign_features)), dtype=np.int32) 29 | gt_labels[:len(malware_features)] = 1 30 | feature_extractor.feature_preprocess(features, gt_labels) 31 | feature_extractor.feature2ipt(features) 32 | 33 | 34 | if __name__ == '__main__': 35 | absltest.main() 36 | -------------------------------------------------------------------------------- /tools/progressbar/compat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # progressbar - Text progress bar library for Python. 
5 | # Copyright (c) 2005 Nilton Volpato 6 | # 7 | # This library is free software; you can redistribute it and/or 8 | # modify it under the terms of the GNU Lesser General Public 9 | # License as published by the Free Software Foundation; either 10 | # version 2.1 of the License, or (at your option) any later version. 11 | # 12 | # This library is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | # Lesser General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU Lesser General Public 18 | # License along with this library; if not, write to the Free Software 19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 20 | 21 | """Compatibility methods and classes for the progressbar module.""" 22 | 23 | 24 | # Python 3.x (and backports) use a modified iterator syntax 25 | # This will allow 2.x to behave with 3.x iterators 26 | try: 27 | next 28 | except NameError: 29 | def next(iter): 30 | try: 31 | # Try new style iterators 32 | return iter.__next__() 33 | except AttributeError: 34 | # Fallback in case of a "native" iterator 35 | return iter.next() 36 | 37 | 38 | # Python < 2.5 does not have "any" 39 | try: 40 | any 41 | except NameError: 42 | def any(iterator): 43 | for item in iterator: 44 | if item: return True 45 | return False 46 | -------------------------------------------------------------------------------- /core/feature/opcodeseq/opcodeseq.py: -------------------------------------------------------------------------------- 1 | from androguard.misc import AnalyzeAPK 2 | from tools import utils 3 | 4 | from config import logging 5 | 6 | logger = logging.getLogger('feature.opcodeseq') 7 | 8 | 9 | def get_opcode_sequences(apk_path, save_path): 10 | _1, _2, dx = AnalyzeAPK(apk_path) 11 | 12 | opcode_chunks = [] 13 | for method in dx.get_methods(): 14 | if method.is_external(): 15 | continue 16 | mth_body = method.get_method() 17 | sequence = [] 18 | for ins in mth_body.get_instructions(): 19 | opcode = ins.get_op_value() 20 | if opcode < 0: 21 | opcode = 0 22 | elif opcode >= 256: 23 | opcode = 0 24 | else: 25 | opcode = opcode 26 | sequence.append(opcode) # list of 'int' 27 | if len(sequence) > 0: 28 | opcode_chunks.append(sequence) 29 | dump_opcode(opcode_chunks, save_path) 30 | 31 | return save_path 32 | 33 | 34 | def dump_opcode(opcode_chunks, save_path): 35 | utils.dump_json(opcode_chunks, save_path) 36 | return 37 | 38 | 39 | def read_opcode(save_path): 40 | return utils.load_json(save_path) 41 | 42 | 43 | def read_opcode_wrapper(save_path): 44 | try: 45 | return read_opcode(save_path) 46 | except Exception as e: 47 | return e 48 | 49 | 50 | def feature_extr_wrapper(*args): 51 | """ 52 | A helper function to catch the exception 53 | :param element: argurments for feature extraction 54 | :return: feature or Exception 55 | """ 56 | try: 57 | return get_opcode_sequences(*args) 58 | except Exception as e: 59 | return e 60 | -------------------------------------------------------------------------------- /core/ensemble/ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | base ensemble class 3 | """ 4 | 5 | 6 | class Ensemble(object): 7 | def __init__(self, architecture_type, base_model, n_members, model_directory): 8 | """ 9 | initialization 10 | :param architecture_type: e.g., 'dnn' 11 | :param base_model: an instantiated object 
of base model 12 | :param n_members: number of base models 13 | :param model_directory: a folder for saving ensemble weights 14 | """ 15 | self.architecture_type = architecture_type 16 | self.base_model = base_model 17 | self.n_members = n_members 18 | self.model_directory = model_directory 19 | self.weights_list = [] # a model's parameters 20 | self._optimizers_dict = dict() 21 | 22 | def build_model(self): 23 | """Build an ensemble model""" 24 | raise NotImplementedError 25 | 26 | def predict(self, x): 27 | """conduct prediction""" 28 | raise NotImplementedError 29 | 30 | def get_basic_layers(self): 31 | """ construct the basic layers""" 32 | raise NotImplementedError 33 | 34 | def fit(self, train_x, train_y, val_x=None, val_y=None, **kwargs): 35 | """ tune the model parameters upon given dataset""" 36 | raise NotImplementedError 37 | 38 | def get_model_number(self): 39 | """ get the number of base models""" 40 | raise NotImplementedError 41 | 42 | def reset(self): 43 | self.weights_list = [] 44 | 45 | def save_ensemble_weights(self): 46 | """ save the model parameters""" 47 | raise NotImplementedError 48 | 49 | def load_ensemble_weights(self): 50 | """ Load the model parameters """ 51 | raise NotImplementedError 52 | 53 | def gradient_loss_wrt_input(self, x): 54 | """ obtain gradients of loss function with respect to the input.""" 55 | raise NotImplementedError 56 | -------------------------------------------------------------------------------- /experiments/adv_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import argparse 5 | from tensorflow.compat.v1 import app 6 | from experiments.adv import run_experiment, run_temperature_scaling 7 | 8 | args = argparse.ArgumentParser(description='test malware detectors on the adversarial dataset') 9 | args.add_argument('--detector', type=str, default='drebin', choices=[ 10 | 'drebin', # deepdrebin 11 | 'opcodeseq', # deepDroid 12 | 'multimodality', # multimodalnn 13 | 'dex2img', # r2d2, due to the issue of effectiveness, we neglect this method 14 | 'apiseq', # droidetec, due to the issue of effectiveness, we neglect this method 15 | ], help='malware detection method') 16 | 17 | args.add_argument('--calibration', type=str, default='vanilla', choices=[ 18 | 'vanilla', 19 | 'temp_scaling' 20 | 'mc_dropout', 21 | 'deep_ensemble', 22 | 'weighted_ensemble', 23 | 'bayesian' # variational bayesian inference 24 | ], help='calibration method') 25 | args.add_argument('--proc_numbers', type=int, default=2, 26 | help='number of threads for parallelizing features extraction.') 27 | 28 | option = args.parse_args() 29 | 30 | non_used_methods = ['dex2img', 'apiseq'] 31 | assert option.detector not in non_used_methods 32 | 33 | 34 | def main(_): 35 | if option.calibration != 'temp_scaling': 36 | run_experiment(option.detector, 37 | option.calibration, 38 | option.proc_numbers 39 | ) 40 | else: 41 | run_temperature_scaling(option.detector, 'vanilla', option.proc_numbers) 42 | 43 | 44 | if __name__ == '__main__': 45 | app.run() 46 | -------------------------------------------------------------------------------- /experiments/oos_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | from tensorflow.compat.v1 import app 7 
| from experiments.oos import run_experiment, run_temperature_scaling 8 | 9 | args = argparse.ArgumentParser(description='test malware detectors on the adversarial dataset') 10 | args.add_argument('--detector', type=str, default='drebin', choices=[ 11 | 'drebin', # deepdrebin 12 | 'opcodeseq', # deepDroid 13 | 'multimodality', # multimodalnn 14 | 'dex2img', # r2d2, due to the issue of effectiveness, we neglect this method 15 | 'apiseq', # droidetec, due to the issue of effectiveness, we neglect this method 16 | ], help='malware detection method') 17 | 18 | args.add_argument('--calibration', type=str, default='vanilla', choices=[ 19 | 'vanilla', 20 | 'temp_scaling' 21 | 'mc_dropout', 22 | 'deep_ensemble', 23 | 'weighted_ensemble', 24 | 'bayesian' # variational bayesian inference 25 | ], help='calibration method') 26 | args.add_argument('--proc_numbers', type=int, default=2, 27 | help='number of threads for parallelizing features extraction.') 28 | 29 | option = args.parse_args() 30 | 31 | non_used_methods = ['dex2img', 'apiseq'] 32 | assert option.detector not in non_used_methods 33 | 34 | 35 | def main(_): 36 | if option.calibration != 'temp_scaling': 37 | run_experiment(option.detector, 38 | option.calibration, 39 | option.proc_numbers 40 | ) 41 | else: 42 | run_temperature_scaling(option.detector, 'vanilla', option.proc_numbers) 43 | 44 | 45 | if __name__ == '__main__': 46 | app.run() 47 | 48 | -------------------------------------------------------------------------------- /experiments/androzoo_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | from tensorflow.compat.v1 import app 7 | from experiments.androzoo_dataset import run_experiment, run_temperature_scaling 8 | 9 | args = argparse.ArgumentParser(description='learning malware detectors on the Drebin dataset') 10 | args.add_argument('--detector', type=str, default='drebin', choices=[ 11 | 'drebin', # deepdrebin 12 | 'opcodeseq', # deepDroid 13 | 'multimodality', # multimodalnn 14 | 'dex2img', # r2d2, due to the issue of effectiveness, we neglect this method 15 | 'apiseq', # droidetec, due to the issue of effectiveness, we neglect this method 16 | ], help='malware detection method') 17 | 18 | args.add_argument('--calibration', type=str, default='vanilla', choices=[ 19 | 'vanilla', 20 | 'temp_scaling' 21 | 'mc_dropout', 22 | 'deep_ensemble', 23 | 'weighted_ensemble', 24 | 'bayesian' # variational bayesian inference 25 | ], help='calibration method') 26 | args.add_argument('--n_members', type=int, default=10, help='number of members in ensemble or weighted ensemble.') 27 | args.add_argument('--proc_numbers', type=int, default=2, 28 | help='number of threads for parallelizing features extraction.') 29 | 30 | option = args.parse_args() 31 | 32 | non_used_methods = ['dex2img', 'apiseq'] 33 | assert option.detector not in non_used_methods 34 | 35 | 36 | def main(_): 37 | if option.calibration != 'temp_scaling': 38 | run_experiment(option.detector, 39 | option.calibration, 40 | option.n_members, 41 | option.proc_numbers 42 | ) 43 | else: 44 | run_temperature_scaling(option.detector, 'vanilla') 45 | 46 | 47 | if __name__ == '__main__': 48 | app.run() 49 | -------------------------------------------------------------------------------- /tools/progressbar/__init__.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # progressbar - Text progress bar library for Python. 5 | # Copyright (c) 2005 Nilton Volpato 6 | # 7 | # This library is free software; you can redistribute it and/or 8 | # modify it under the terms of the GNU Lesser General Public 9 | # License as published by the Free Software Foundation; either 10 | # version 2.1 of the License, or (at your option) any later version. 11 | # 12 | # This library is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | # Lesser General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU Lesser General Public 18 | # License along with this library; if not, write to the Free Software 19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 20 | 21 | """Text progress bar library for Python. 22 | 23 | A text progress bar is typically used to display the progress of a long 24 | running operation, providing a visual cue that processing is underway. 25 | 26 | The ProgressBar class manages the current progress, and the format of the line 27 | is given by a number of widgets. A widget is an object that may display 28 | differently depending on the state of the progress bar. There are three types 29 | of widgets: 30 | - a string, which always shows itself 31 | 32 | - a ProgressBarWidget, which may return a different value every time its 33 | update method is called 34 | 35 | - a ProgressBarWidgetHFill, which is like ProgressBarWidget, except it 36 | expands to fill the remaining width of the line. 37 | 38 | The progressbar module is very easy to use, yet very powerful. It will also 39 | automatically enable feature like auto-resizing when the system supports it. 
40 | """ 41 | 42 | __author__ = 'Nilton Volpato' 43 | __author_email__ = 'first-name dot last-name @ gmail.com' 44 | __date__ = '2011-05-14' 45 | __version__ = '2.3' 46 | 47 | 48 | from .compat import * 49 | from .widgets import * 50 | from .progressbar import * 51 | -------------------------------------------------------------------------------- /core/feature/dex2img/dex2img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | from PIL import Image 5 | import zipfile 6 | 7 | import numpy as np 8 | 9 | from config import logging 10 | 11 | current_dir = os.path.dirname(os.path.realpath(__file__)) 12 | logger = logging.getLogger('feature.dex2img') 13 | 14 | 15 | def dex2img(apk_path, save_path, num_channels=3): 16 | """ 17 | convert dex file to rbg images 18 | :param apk_path: an apk path 19 | :param save_path: a path for saving the resulting image 20 | :param num_channels: r, g, b channels 21 | :return: (status, save_path) 22 | """ 23 | try: 24 | print("Processing " + apk_path) 25 | start_time = time.time() 26 | with zipfile.ZipFile(apk_path, 'r') as fh_apk: 27 | dex2num_list = [] 28 | for name in fh_apk.namelist(): 29 | if name.endswith('dex'): 30 | with fh_apk.open(name, 'r') as fr: 31 | hex_string = fr.read().hex() 32 | dex2num = [int(hex_string[i:i + 2], base=16) for i in \ 33 | range(0, len(hex_string), 2)] 34 | dex2num_list.extend(dex2num) 35 | 36 | # extend to three channels (e.g., r,g,b) 37 | num_appending_zero = num_channels - len(dex2num_list) % num_channels 38 | dex2num_list += [0] * num_appending_zero 39 | # shape: [3, -1] 40 | dex2array = np.array([dex2num_list[0::3], dex2num_list[1::3], dex2num_list[2::3]], dtype=np.uint8) 41 | # get image matrix 42 | from math import sqrt, ceil 43 | _length = int(pow(ceil(sqrt(dex2array.shape[1])), 2)) 44 | if _length > dex2array.shape[1]: 45 | padding_zero = np.zeros((3, _length - dex2array.shape[1]), dtype=np.uint8) 46 | dex2array = np.concatenate([dex2array, padding_zero], axis=1) 47 | dex2mat = np.reshape(dex2array, (-1, int(sqrt(_length)), int(sqrt(_length)))) 48 | dex2mat_img = np.transpose(dex2mat, (1, 2, 0)) 49 | img_handler = Image.fromarray(dex2mat_img) 50 | img_handler.save(save_path) 51 | except Exception as e: 52 | return e 53 | else: 54 | return save_path 55 | -------------------------------------------------------------------------------- /core/ensemble/mc_dropout.py: -------------------------------------------------------------------------------- 1 | from core.ensemble.vanilla import Vanilla, model_builder 2 | from core.ensemble.model_hp import train_hparam, mc_dropout_hparam 3 | from tools import utils 4 | from config import logging, ErrorHandler 5 | 6 | logger = logging.getLogger('ensemble.mc_dropout') 7 | logger.addHandler(ErrorHandler) 8 | 9 | 10 | class MCDropout(Vanilla): 11 | def __init__(self, 12 | architecture_type='dnn', 13 | base_model=None, 14 | n_members=1, 15 | model_directory=None, 16 | name='MC_DROPOUT' 17 | ): 18 | super(MCDropout, self).__init__(architecture_type, 19 | base_model, 20 | n_members, 21 | model_directory, 22 | name) 23 | self.hparam = utils.merge_namedtuples(train_hparam, mc_dropout_hparam) 24 | self.ensemble_type = 'mc_dropout' 25 | 26 | def build_model(self, input_dim=None): 27 | """ 28 | Build an ensemble model -- only the homogeneous structure is considered 29 | :param input_dim: integer or list, input dimension shall be set in some cases under eager mode 30 | """ 31 | callable_graph = model_builder(self.architecture_type) 
32 | 33 | @callable_graph(input_dim, use_mc_dropout=True) 34 | def _builder(): 35 | return utils.produce_layer(self.ensemble_type, dropout_rate=self.hparam.dropout_rate) 36 | 37 | self.base_model = _builder() 38 | return 39 | 40 | def model_generator(self): 41 | try: 42 | if len(self.weights_list) <= 0: 43 | self.load_ensemble_weights() 44 | except Exception as e: 45 | raise Exception("Cannot load model weights:{}.".format(str(e))) 46 | 47 | assert len(self.weights_list) == self.n_members 48 | self.base_model.set_weights(weights=self.weights_list[self.n_members - 1]) 49 | # if len(self._optimizers_dict) > 0 and self.base_model.optimizer is not None: 50 | # self.base_model.optimizer.set_weights(self._optimizers_dict[self.n_members - 1]) 51 | for _ in range(self.hparam.n_sampling): 52 | yield self.base_model 53 | -------------------------------------------------------------------------------- /core/post_calibration/temperature_scaling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import scipy.optimize 6 | import scipy.special 7 | import numpy as np 8 | import tensorflow as tf 9 | import tensorflow_probability as tfp 10 | 11 | 12 | def find_scaling_temperature(labels, logits, temp_range=(1e-5, 1e5)): 13 | """ 14 | Code is adapted from https://github.com/google-research/google-research/tree/master/uq_benchmark_2019/ 15 | Find max likelihood scaling temperature using binary search. 16 | 17 | Args: 18 | labels: Integer labels (shape=[num_samples]). 19 | logits: Floating point softmax inputs (shape=[num_samples, num_classes]). 20 | temp_range: 2-tuple range of temperatures to consider. 21 | Returns: 22 | Floating point temperature value. 23 | """ 24 | if not tf.executing_eagerly(): 25 | raise NotImplementedError( 26 | 'find_scaling_temperature() not implemented for graph-mode TF') 27 | if len(labels.shape) != 1: 28 | raise ValueError('Invalid labels shape=%s' % str(labels.shape)) 29 | if len(logits.shape) not in (1, 2): 30 | raise ValueError('Invalid logits shape=%s' % str(logits.shape)) 31 | if len(labels.shape) != 1 or len(labels) != len(logits): 32 | raise ValueError('Incompatible shapes for logits (%s) vs labels (%s).' % 33 | (logits.shape, labels.shape)) 34 | 35 | @tf.function(autograph=False, experimental_relax_shapes=True) 36 | def grad_fn(temperature): 37 | """Returns gradient of log-likelihood WRT a logits-scaling temperature.""" 38 | dist = tfp.distributions.Bernoulli(logits=logits / temperature) 39 | nll = -dist.log_prob(labels) 40 | nll = tf.reduce_sum(nll, axis=0) 41 | grad, = tf.gradients(nll, [temperature]) 42 | return grad 43 | 44 | tmin, tmax = temp_range 45 | return scipy.optimize.bisect(lambda t: grad_fn(t * tf.ones([])).numpy(), tmin, tmax) 46 | 47 | 48 | def inverse_sigmoid(probs, eps=1e-7): 49 | """ compute the logit""" 50 | uniform = np.ones_like(probs) / probs.shape[-1] 51 | probs = eps * uniform + (1. - eps) * probs 52 | return -tf.math.log(1. / probs - 1.) 
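# Usage sketch (illustrative only; the variable names below are hypothetical).
# Assuming `val_labels`/`val_probs` come from a held-out calibration split and
# `test_probs` from the evaluation split, the intended flow appears to be:
#
#   temperature = find_scaling_temperature(val_labels, inverse_sigmoid(val_probs))
#   calibrated_probs = apply_temperature_scaling(temperature, test_probs)
#
# find_scaling_temperature() bisects on the gradient of the negative
# log-likelihood with respect to the temperature, so the value it returns is
# the maximum-likelihood scaler for the calibration split.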
53 | 54 | 55 | def apply_temperature_scaling(temperature, probs): 56 | """Apply temperature scaling to an array of probabilities.""" 57 | logits_t = inverse_sigmoid(probs) / temperature 58 | return tf.sigmoid(logits_t).numpy() 59 | -------------------------------------------------------------------------------- /experiments/drebin_main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | from tensorflow.compat.v1 import app 7 | from experiments.drebin_dataset import run_experiment, run_temperature_scaling 8 | 9 | args = argparse.ArgumentParser(description='learning malware detectors on the Drebin dataset') 10 | args.add_argument('--detector', type=str, default='drebin', choices=[ 11 | 'drebin', # deepdrebin 12 | 'opcodeseq', # deepDroid 13 | 'multimodality', # multimodalnn 14 | 'dex2img', # r2d2, due to the issue of effectiveness, we neglect this method 15 | 'apiseq', # droidetec, due to the issue of effectiveness, we neglect this method 16 | ], help='malware detection method') 17 | 18 | args.add_argument('--calibration', type=str, default='vanilla', choices=[ 19 | 'vanilla', 20 | 'temp_scaling' 21 | 'mc_dropout', 22 | 'deep_ensemble', 23 | 'weighted_ensemble', 24 | 'bayesian' # variational bayesian inference 25 | ], help='calibration method') 26 | args.add_argument('--seed', type=int, default=7727, help='random seed') 27 | # this seed above has not been applied to NN models yet. It could, but we neglect when conducting experiments. 28 | args.add_argument('--n_members', type=int, default=10, help='number of members in ensemble or weighted ensemble.') 29 | args.add_argument('--proc_numbers', type=int, default=2, 30 | help='number of threads for paralleling features extraction.') 31 | args.add_argument('--ratio', type=float, default=3.0, help='ratio of the number of benign files to malware ones.') 32 | 33 | option = args.parse_args() 34 | 35 | non_used_methods = ['dex2img', 'apiseq'] 36 | assert option.detector not in non_used_methods 37 | 38 | 39 | def main(_): 40 | if option.calibration != 'temp_scaling': 41 | run_experiment(option.detector, 42 | option.calibration, 43 | option.seed, 44 | option.n_members, 45 | option.ratio, 46 | option.proc_numbers 47 | ) 48 | else: 49 | run_temperature_scaling(option.detector, 'vanilla') 50 | 51 | 52 | if __name__ == '__main__': 53 | app.run() 54 | -------------------------------------------------------------------------------- /test/model_lib_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | 4 | from core.ensemble.model_lib import build_models 5 | import numpy as np 6 | 7 | 8 | class Model_lib_test(parameterized.TestCase): 9 | @parameterized.named_parameters( 10 | [(ens_type, ens_type) for ens_type in \ 11 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']]) 12 | def test_dnn_graph(self, ensemble_type): 13 | x = np.array([[1., 2., 3.], [4., 5., 6.]]) 14 | if ensemble_type is not 'mc_dropout': 15 | build_models(x, 'dnn', ensemble_type, input_dim=3) 16 | else: 17 | build_models(x, 'dnn', ensemble_type, input_dim=3, use_mc_dropout=True) 18 | 19 | @parameterized.named_parameters( 20 | [(ens_type, ens_type) for ens_type in \ 21 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']]) 22 | def 
test_text_cnn_graph(self, ensemble_type): 23 | x = np.random.randint(0, 256, (2, 16)) 24 | if ensemble_type is not 'mc_dropout': 25 | out = build_models(x, 'text_cnn', ensemble_type, input_dim=5) 26 | else: 27 | out = build_models(x, 'text_cnn', ensemble_type, input_dim=5, use_mc_dropout=True) 28 | 29 | self.assertTrue(out.shape == (2, 1)) 30 | 31 | @parameterized.named_parameters( 32 | [(ens_type, ens_type) for ens_type in \ 33 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']]) 34 | def test_multimodalitynn(self, ensemble_type): 35 | _x = np.array([[1., 2., 3.], [4., 5., 6.]]) 36 | x = [_x, _x, _x, _x, _x] 37 | if ensemble_type is not 'mc_dropout': 38 | out = build_models(x, 'multimodalitynn', ensemble_type, input_dim=[3, 3, 3, 3, 3]) 39 | else: 40 | out = build_models(x, 'multimodalitynn', ensemble_type, input_dim=[3, 3, 3, 3, 3], use_mc_dropout=True) 41 | self.assertTrue(out.shape == (2, 1)) 42 | 43 | @parameterized.named_parameters( 44 | [(ens_type, ens_type) for ens_type in \ 45 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']]) 46 | def test_droidectc(self, ensemble_type): 47 | _x = np.random.randint(0, 10000, [2, 1000]) 48 | if ensemble_type is not 'mc_dropout': 49 | out = build_models(_x, 'droidectc', ensemble_type) 50 | else: 51 | out = build_models(_x, 'droidectc', ensemble_type, use_mc_dropout=True) 52 | self.assertTrue(out.shape == (2, 1)) 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /core/ensemble/anchor_ensemble_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | import tempfile 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | from sklearn.datasets import load_breast_cancer 8 | 9 | from core.ensemble.anchor_ensemble import AnchorEnsemble 10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 11 | 12 | architectures = ['dnn', 'text_cnn', 'multimodalitynn', 'r2d2', 'droidectc'] 13 | 14 | class MyTestCaseAnchorEnsemble(parameterized.TestCase): 15 | def setUp(self): 16 | self.x_dict, self.y_dict = dict(), dict() 17 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 18 | self.x_dict['dnn'] = self.x_np 19 | self.y_dict['dnn'] = self.y_np 20 | x = np.random.randint(0, 256, (10, 10)) 21 | y = np.random.choice(2, 10) 22 | self.x_dict['text_cnn'] = x 23 | self.y_dict['text_cnn'] = y 24 | x = [self.x_np] * 5 25 | self.x_dict['multimodalitynn'] = x 26 | self.y_dict['multimodalitynn'] = self.y_np 27 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3)) 28 | y = np.random.choice(2, 10) 29 | self.x_dict['r2d2'] = x 30 | self.y_dict['r2d2'] = y 31 | x = np.random.randint(0, 10000, size=(10, 1000)) 32 | y = np.random.choice(2, 10) 33 | self.x_dict['droidectc'] = x 34 | self.y_dict['droidectc'] = y 35 | 36 | @parameterized.named_parameters([(arc_type, arc_type) for arc_type in architectures]) 37 | def test_anchor_ensemble(self, arc_type): 38 | with tempfile.TemporaryDirectory() as output_dir: 39 | x = self.x_dict[arc_type] 40 | y = self.y_dict[arc_type] 41 | if arc_type is not 'multimodalitynn': 42 | train_dataset = build_dataset_from_numerical_data((x, y)) 43 | val_dataset = build_dataset_from_numerical_data((x, y)) 44 | n_samples = x.shape[0] 45 | input_dim = x.shape[1:] 46 | else: 47 | train_data = build_dataset_from_numerical_data(tuple(x)) 48 | train_y = 
build_dataset_from_numerical_data(self.y_np) 49 | train_dataset = tf.data.Dataset.zip((train_data, train_y)) 50 | val_data = build_dataset_from_numerical_data(tuple(x)) 51 | val_y = build_dataset_from_numerical_data(self.y_np) 52 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 53 | n_samples = x[0].shape[0] 54 | input_dim = [x[i].shape[1] for i in range(len(x))] 55 | 56 | anchor_ensemble = AnchorEnsemble(architecture_type=arc_type, 57 | model_directory=output_dir) 58 | anchor_ensemble.fit(train_dataset, val_dataset, input_dim=input_dim) 59 | 60 | res = anchor_ensemble.predict(x) 61 | self.assertEqual(anchor_ensemble.get_n_members(), anchor_ensemble.n_members) 62 | self.assertTrue(res.shape == (n_samples, anchor_ensemble.n_members, 1)) 63 | 64 | anchor_ensemble.evaluate(x, y) 65 | 66 | 67 | if __name__ == '__main__': 68 | absltest.main() 69 | -------------------------------------------------------------------------------- /core/ensemble/dataset_lib.py: -------------------------------------------------------------------------------- 1 | """ This script is for building dataset """ 2 | 3 | import tensorflow as tf 4 | from core.ensemble.model_hp import train_hparam 5 | 6 | 7 | def build_dataset_from_numerical_data(data, batch_size=None): 8 | """ 9 | serialize the data to accommodate the format of model input 10 | :param data, tuple or np.ndarray 11 | :param batch_size, scalar or none, the train paramemeter is default if none provided 12 | """ 13 | batch_size = train_hparam.batch_size if batch_size is None else batch_size 14 | return tf.data.Dataset.from_tensor_slices(data). \ 15 | cache(). \ 16 | batch(batch_size). \ 17 | prefetch(tf.data.experimental.AUTOTUNE) 18 | 19 | 20 | def build_dataset_via_generator(generator, y=None, path='', batch_size=None): 21 | batch_size = train_hparam.batch_size if batch_size is None else batch_size 22 | if y is not None: 23 | return tf.data.Dataset.from_generator(generator, 24 | output_types=(tf.int32, tf.int32), 25 | output_shapes=(tf.TensorShape([None]), tf.TensorShape([])) 26 | ). \ 27 | padded_batch(batch_size, padded_shapes=([None], [])). \ 28 | cache(path). \ 29 | shuffle(buffer_size=100). \ 30 | prefetch(tf.data.experimental.AUTOTUNE) 31 | else: 32 | return tf.data.Dataset.from_generator(generator, 33 | output_types=tf.int32, 34 | output_shapes=tf.TensorShape([None]) 35 | ). \ 36 | padded_batch(batch_size, padded_shapes=([None])). \ 37 | cache(path). 
\ 38 | prefetch(tf.data.experimental.AUTOTUNE) 39 | 40 | 41 | def build_dataset_from_img_generator(generator, input_dim, y=None, is_training=False): 42 | if is_training and y is not None: 43 | return tf.data.Dataset.from_generator(generator, 44 | output_types=(tf.float32, tf.float32, tf.float32), 45 | output_shapes=(tf.TensorShape([None, *input_dim]), 46 | tf.TensorShape([None, ]), 47 | tf.TensorShape([None, ])) 48 | ) 49 | elif not is_training and y is not None: 50 | return tf.data.Dataset.from_generator(generator, 51 | output_types=(tf.float32, tf.float32), 52 | output_shapes=(tf.TensorShape([None, *input_dim]), 53 | tf.TensorShape([None, ])) 54 | ) 55 | else: 56 | return tf.data.Dataset.from_generator(generator, 57 | output_types=tf.float32, 58 | output_shapes=tf.TensorShape([None, *input_dim]) 59 | ) 60 | -------------------------------------------------------------------------------- /core/ensemble/bayesian_ensemble_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | import tempfile 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | from sklearn.datasets import load_breast_cancer 8 | 9 | from core.ensemble.bayesian_ensemble import BayesianEnsemble 10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 11 | 12 | architectures = ['dnn', 'text_cnn', 'multimodalitynn', 'r2d2', 'droidectc'] 13 | 14 | 15 | class MyTestCaseBayesianEnsemble(parameterized.TestCase): 16 | def setUp(self): 17 | self.x_dict, self.y_dict = dict(), dict() 18 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 19 | self.x_dict['dnn'] = self.x_np 20 | self.y_dict['dnn'] = self.y_np 21 | x = np.random.randint(0, 256, (10, 10)) 22 | y = np.random.choice(2, 10) 23 | self.x_dict['text_cnn'] = x 24 | self.y_dict['text_cnn'] = y 25 | x = [self.x_np] * 5 26 | self.x_dict['multimodalitynn'] = x 27 | self.y_dict['multimodalitynn'] = self.y_np 28 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3)) 29 | y = np.random.choice(2, 10) 30 | self.x_dict['r2d2'] = x 31 | self.y_dict['r2d2'] = y 32 | x = np.random.randint(0, 10000, size=(10, 1000)) 33 | y = np.random.choice(2, 10) 34 | self.x_dict['droidectc'] = x 35 | self.y_dict['droidectc'] = y 36 | 37 | @parameterized.named_parameters([(arc_type, arc_type) for arc_type in architectures]) 38 | def test_beyasian_ensemble(self, arc_type): 39 | with tempfile.TemporaryDirectory() as output_dir: 40 | x = self.x_dict[arc_type] 41 | y = self.y_dict[arc_type] 42 | if arc_type is not 'multimodalitynn': 43 | train_dataset = build_dataset_from_numerical_data((x, y)) 44 | val_dataset = build_dataset_from_numerical_data((x, y)) 45 | n_samples = x.shape[0] 46 | input_dim = x.shape[1:] 47 | else: 48 | train_data = build_dataset_from_numerical_data(tuple(x)) 49 | train_y = build_dataset_from_numerical_data(self.y_np) 50 | train_dataset = tf.data.Dataset.zip((train_data, train_y)) 51 | val_data = build_dataset_from_numerical_data(tuple(x)) 52 | val_y = build_dataset_from_numerical_data(self.y_np) 53 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 54 | n_samples = x[0].shape[0] 55 | input_dim = [x[i].shape[1] for i in range(len(x))] 56 | 57 | bayesian_ensemble = BayesianEnsemble(architecture_type=arc_type, 58 | model_directory=output_dir) 59 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=input_dim) 60 | 61 | res = bayesian_ensemble.predict(x) 62 | self.assertEqual(bayesian_ensemble.get_n_members(), 
bayesian_ensemble.n_members) 63 | self.assertTrue(res.shape == (n_samples, bayesian_ensemble.hparam.n_sampling, 1)) 64 | 65 | bayesian_ensemble.evaluate(x, y) 66 | 67 | 68 | if __name__ == '__main__': 69 | absltest.main() 70 | -------------------------------------------------------------------------------- /tools/progressbar_wrapper.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import time 3 | 4 | import tools.progressbar.progressbar as progressbar 5 | import tools.progressbar.widgets as progressbar_widgets 6 | 7 | 8 | class ProgressBar(object): 9 | def __init__(self): 10 | self.TotalResults = [] 11 | self.NumberOfFinishedResults = 0 12 | 13 | def Update(self): 14 | self.ProgressBar.update(self.NumberOfFinishedResults) 15 | return 16 | 17 | def CallbackForProgressBar(self, res=''): 18 | """ 19 | Callback function for pool.async if the progress bar needs to be displayed. 20 | Must use with DisplayProgressBar function. 21 | 22 | :param multiprocessing.pool.AsyncResult res: Result got from callback function in pool.async. 23 | """ 24 | self.NumberOfFinishedResults += 1 25 | self.TotalResults.append(res) 26 | return 27 | 28 | def DisplayProgressBar(self, ProcessingResults, ExpectedResultsSize, CheckInterval=1, type="minute"): 29 | ''' 30 | Display a progress bar for multiprocessing. This function should be used after pool.close(No need to use pool.join anymore). 31 | The call back function for pool.async should be set as CallbackForProgressBar. 32 | 33 | :param multiprocessing.pool.AsyncResult ProcessingResults: Processing results returned by pool.async. 34 | :param int ExpectedResultsSize: How many result you will reveive, i.e. the total length of progress bar. 35 | :param float CheckInterval: How many seconds will the progress bar be updated. When it's too large, the _main program may hang there. 36 | :param String type: Three types: "minute", "hour", "second"; corresponds displaying iters/minute iters/hour and iters/second. 
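        Illustrative usage (hypothetical names; one task per sample is assumed):

            pb = ProgressBar()
            pool = multiprocessing.Pool(4)
            results = [pool.apply_async(job, (s,), callback=pb.CallbackForProgressBar)
                       for s in samples]
            pool.close()
            pb.DisplayProgressBar(results[-1], len(samples))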
37 | ''' 38 | self.ProcessingResults = ProcessingResults 39 | ProgressBarWidgets = [progressbar_widgets.Percentage(), 40 | ' ', progressbar_widgets.Bar(), 41 | ' ', progressbar_widgets.SimpleProgress(), 42 | ' ', progressbar_widgets.Timer(), 43 | ' ', progressbar_widgets.AdaptiveETA()] 44 | self.ProgressBar = progressbar.ProgressBar(ExpectedResultsSize, ProgressBarWidgets) 45 | self.StartTime = time.time() 46 | PreviousNumberOfResults = 0 47 | self.ProgressBar.start() 48 | while self.ProcessingResults.ready() == False: 49 | self.Update() 50 | time.sleep(CheckInterval) 51 | time.sleep(CheckInterval) 52 | self.Update() 53 | self.ProgressBar.finish() 54 | self.EndTime = time.time() 55 | print("Processing finished.") 56 | # print "Processing results: ", self.TotalResults 57 | print("Time Elapsed: %.2fs, or %.2fmins, or %.2fhours" % ( 58 | (self.EndTime - self.StartTime), (self.EndTime - self.StartTime) / 60, 59 | (self.EndTime - self.StartTime) / 3600)) 60 | print("Processing finished.") 61 | print("Processing results: " + str(self.TotalResults)) 62 | print("Time Elapsed: %.2fs, or %.2fmins, or %.2fhours" % ( 63 | (self.EndTime - self.StartTime), (self.EndTime - self.StartTime) / 60, (self.EndTime - self.StartTime) / 3600)) 64 | return -------------------------------------------------------------------------------- /test/vanilla_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | import tempfile 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.datasets import load_breast_cancer 8 | 9 | from core.ensemble.vanilla import Vanilla 10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 11 | 12 | 13 | class MyTestCaseVanilla(parameterized.TestCase): 14 | def setUp(self): 15 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 16 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 17 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 18 | 19 | def test_vanilla_dnn(self): 20 | with tempfile.TemporaryDirectory() as output_dir: 21 | vanilla = Vanilla(architecture_type='dnn', 22 | model_directory=output_dir) 23 | vanilla.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1]) 24 | 25 | res = vanilla.predict(self.x_np) 26 | self.assertEqual(vanilla.get_n_members(), vanilla.n_members) 27 | self.assertTrue(res.shape == (self.x_np.shape[0], vanilla.n_members, 1)) 28 | 29 | vanilla.evaluate(self.x_np, self.y_np) 30 | 31 | def test_vanilla_textcnn(self): 32 | with tempfile.TemporaryDirectory() as output_dir: 33 | vanilla = Vanilla(architecture_type='text_cnn', 34 | model_directory=output_dir) 35 | x = np.random.randint(0, 256, (10, 10)) 36 | y = np.random.choice(2, 10) 37 | train_dataset = build_dataset_from_numerical_data((x,y)) 38 | val_dataset = build_dataset_from_numerical_data((x,y)) 39 | vanilla.fit(train_dataset, val_dataset) 40 | res = vanilla.predict(x) 41 | self.assertTrue(res.shape == (x.shape[0], vanilla.n_members, 1)) 42 | vanilla.evaluate(x, y) 43 | 44 | def test_vanilla_multimodalitynn(self): 45 | with tempfile.TemporaryDirectory() as output_dir: 46 | vanilla = Vanilla(architecture_type='multimodalitynn', 47 | model_directory=output_dir) 48 | x = [self.x_np] * 5 49 | train_data = build_dataset_from_numerical_data(tuple(x)) 50 | train_y = build_dataset_from_numerical_data(self.y_np) 51 | train_dataset = tf.data.Dataset.zip((train_data, 
train_y)) 52 | val_data = build_dataset_from_numerical_data(tuple(x)) 53 | val_y = build_dataset_from_numerical_data(self.y_np) 54 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 55 | vanilla.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]]*5) 56 | res = vanilla.predict(x) 57 | self.assertTrue(res.shape == (self.x_np.shape[0], vanilla.n_members, 1)) 58 | vanilla.evaluate(x, self.y_np) 59 | 60 | def test_vanilla_r2d2(self): 61 | with tempfile.TemporaryDirectory() as output_dir: 62 | vanilla = Vanilla(architecture_type='r2d2', 63 | model_directory=output_dir) 64 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3)) 65 | y = np.random.choice(2, 10) 66 | train_dataset = build_dataset_from_numerical_data((x, y)) 67 | val_dataset = build_dataset_from_numerical_data((x, y)) 68 | vanilla.fit(train_dataset, val_dataset, input_dim=(299, 299, 3)) 69 | res = vanilla.predict(x) 70 | self.assertTrue(res.shape == (x.shape[0], vanilla.n_members, 1)) 71 | vanilla.evaluate(x, y) 72 | 73 | def test_vanilla_droidectc(self): 74 | with tempfile.TemporaryDirectory() as output_dir: 75 | vanilla = Vanilla(architecture_type='droidectc', 76 | model_directory=output_dir) 77 | 78 | x = np.random.randint(0, 10000, size=(10, 1000)) 79 | y = np.random.choice(2, 10) 80 | train_dataset = build_dataset_from_numerical_data((x, y)) 81 | val_dataset = build_dataset_from_numerical_data((x, y)) 82 | vanilla.fit(train_dataset, val_dataset) 83 | res = vanilla.predict(x) 84 | self.assertTrue(res.shape == (x.shape[0], vanilla.n_members, 1)) 85 | vanilla.evaluate(x, y) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /test/mc_dropout_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | import tempfile 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.datasets import load_breast_cancer 8 | 9 | from core.ensemble.mc_dropout import MCDropout 10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 11 | 12 | 13 | class MyTestCaseMCDropout(parameterized.TestCase): 14 | def setUp(self): 15 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 16 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 17 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 18 | 19 | def test_dnn(self): 20 | with tempfile.TemporaryDirectory() as output_dir: 21 | mcdropout = MCDropout(architecture_type='dnn', 22 | model_directory=output_dir) 23 | mcdropout.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1]) 24 | 25 | res = mcdropout.predict(self.x_np) 26 | self.assertEqual(mcdropout.get_n_members(), mcdropout.n_members) 27 | self.assertTrue(res.shape == (self.x_np.shape[0], mcdropout.hparam.n_sampling, 1)) 28 | 29 | mcdropout.evaluate(self.x_np, self.y_np) 30 | 31 | def test_textcnn(self): 32 | with tempfile.TemporaryDirectory() as output_dir: 33 | mcdropout = MCDropout(architecture_type='text_cnn', 34 | model_directory=output_dir) 35 | x = np.random.randint(0, 256, (10, 10)) 36 | y = np.random.choice(2, 10) 37 | train_dataset = build_dataset_from_numerical_data((x, y)) 38 | val_dataset = build_dataset_from_numerical_data((x, y)) 39 | mcdropout.fit(train_dataset, val_dataset) 40 | res = mcdropout.predict(x) 41 | self.assertTrue(res.shape == (x.shape[0], mcdropout.hparam.n_sampling, 1)) 42 
| mcdropout.evaluate(x, y) 43 | 44 | def test_multimodalitynn(self): 45 | with tempfile.TemporaryDirectory() as output_dir: 46 | mcdropout = MCDropout(architecture_type='multimodalitynn', 47 | model_directory=output_dir) 48 | x = [self.x_np] * 5 49 | train_data = build_dataset_from_numerical_data(tuple(x)) 50 | train_y = build_dataset_from_numerical_data(self.y_np) 51 | train_dataset = tf.data.Dataset.zip((train_data, train_y)) 52 | val_data = build_dataset_from_numerical_data(tuple(x)) 53 | val_y = build_dataset_from_numerical_data(self.y_np) 54 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 55 | mcdropout.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]] * 5) 56 | res = mcdropout.predict(x) 57 | self.assertTrue(res.shape == (self.x_np.shape[0], mcdropout.hparam.n_sampling, 1)) 58 | mcdropout.evaluate(x, self.y_np) 59 | 60 | def test_r2d2(self): 61 | with tempfile.TemporaryDirectory() as output_dir: 62 | mcdropout = MCDropout(architecture_type='r2d2', 63 | model_directory=output_dir) 64 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3)) 65 | y = np.random.choice(2, 10) 66 | train_dataset = build_dataset_from_numerical_data((x, y)) 67 | val_dataset = build_dataset_from_numerical_data((x, y)) 68 | mcdropout.fit(train_dataset, val_dataset, input_dim=(299, 299, 3)) 69 | res = mcdropout.predict(x) 70 | self.assertTrue(res.shape == (x.shape[0], mcdropout.hparam.n_sampling, 1)) 71 | mcdropout.evaluate(x, y) 72 | 73 | def test_droidectc(self): 74 | with tempfile.TemporaryDirectory() as output_dir: 75 | mcdropout = MCDropout(architecture_type='droidectc', 76 | model_directory=output_dir) 77 | 78 | x = np.random.randint(0, 10000, size=(10, 1000)) 79 | y = np.random.choice(2, 10) 80 | train_dataset = build_dataset_from_numerical_data((x, y)) 81 | val_dataset = build_dataset_from_numerical_data((x, y)) 82 | mcdropout.fit(train_dataset, val_dataset) 83 | res = mcdropout.predict(x) 84 | self.assertTrue(res.shape == (x.shape[0], mcdropout.hparam.n_sampling, 1)) 85 | mcdropout.evaluate(x, y) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /core/ensemble/model_hp.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | _TRAIN_HP_TEMPLATE = namedtuple('training', 4 | ['random_seed', 'n_epochs', 'batch_size', 'learning_rate', 'clipvalue', 'interval']) 5 | train_hparam = _TRAIN_HP_TEMPLATE(random_seed=23456, 6 | n_epochs=30, 7 | batch_size=16, 8 | learning_rate=0.001, 9 | clipvalue=100., 10 | interval=2 # saving model weights 11 | ) 12 | 13 | # these hyper-paramters of finetuning are used for tuning r2d2 model, which is deprecated owing to the issue of effectiveness 14 | _FINETUNE_TRAIN_HP_TEMPLATE = namedtuple('finetuning', 15 | ['random_seed', 'n_epochs', 'batch_size', 'learning_rate', 'n_epochs_ft', 16 | 'learning_rate_ft', 'unfreezed_layers', 'clipvalue', 'interval']) 17 | finetuning_hparam = _FINETUNE_TRAIN_HP_TEMPLATE(random_seed=23456, 18 | n_epochs=10, 19 | batch_size=16, 20 | learning_rate=0.0001, 21 | n_epochs_ft=20, 22 | learning_rate_ft=0.00001, 23 | unfreezed_layers=3, 24 | clipvalue=100., 25 | interval=2 # saving model weights 26 | ) 27 | 28 | _MC_DROPOUT_HP_TEMPLATE = namedtuple('mc_dropout', ['dropout_rate', 'n_sampling']) 29 | mc_dropout_hparam = _MC_DROPOUT_HP_TEMPLATE(dropout_rate=0.4, 30 | n_sampling=10 31 | ) 32 | _BAYESIAN_HP_TEMPLATE = namedtuple('bayesian', ['n_sampling']) 33 | 
bayesian_ensemble_hparam = _BAYESIAN_HP_TEMPLATE(n_sampling=10) 34 | 35 | _DNN_HP_TEMPLATE = namedtuple('DNN', 36 | ['hidden_units', 'dropout_rate', 'activation', 'output_dim']) 37 | 38 | dnn_hparam = _DNN_HP_TEMPLATE(hidden_units=[200, 200], # DNN has two hidden layers with each having 200 neurons 39 | dropout_rate=0.4, 40 | activation='relu', 41 | output_dim=1 # binary classification# 42 | ) 43 | 44 | _TEXT_CNN_HP_TEMPLATE = namedtuple('textCNN', 45 | ['hidden_units', 'dropout_rate', 'activation', 'output_dim', 46 | 'vocab_size', 'n_embedding_dim', 'n_conv_filters', 'kernel_size', 47 | 'max_sequence_length', 48 | 'use_spatial_dropout', 'use_conv_dropout']) 49 | 50 | text_cnn_hparam = _TEXT_CNN_HP_TEMPLATE(hidden_units=[200, 200], 51 | dropout_rate=0.4, 52 | activation='relu', 53 | vocab_size=256, 54 | n_embedding_dim=8, 55 | n_conv_filters=64, 56 | kernel_size=8, 57 | max_sequence_length=700000, # shall be large for promoting accuracy 58 | use_spatial_dropout=False, 59 | use_conv_dropout=False, 60 | output_dim=1 61 | ) 62 | 63 | _MULTIMOD_HP_TEMPLATE = namedtuple('multimodalitynn', 64 | ['hidden_units', 'dropout_rate', 'activation', 'output_dim', 65 | 'n_modalities', 'initial_hidden_units', 66 | ]) 67 | 68 | multimodalitynn_hparam = _MULTIMOD_HP_TEMPLATE( 69 | hidden_units=[200, 200], 70 | dropout_rate=0.4, 71 | activation='relu', 72 | n_modalities=5, 73 | initial_hidden_units=[500, 500], 74 | output_dim=1 75 | ) 76 | 77 | _R2D2_HP_TEMPLATE = namedtuple('r2d2', 78 | ['use_small_model', 'trainable', 'unfreezed_layers', 'dropout_rate', 'output_dim']) 79 | 80 | r2d2_hparam = _R2D2_HP_TEMPLATE( 81 | use_small_model=False, # if True : return MobileNetV2 else InceptionV3 82 | trainable=False, # if False : the learned parameters of InceptionV3 or MobileNetV2 will be frozen 83 | unfreezed_layers=8, 84 | dropout_rate=0.4, 85 | output_dim=1 86 | ) 87 | 88 | _DROIDETEC_HP_TEMPLATE = namedtuple('droidetec', 89 | ['vocab_size', 'n_embedding_dim', 'lstm_units', 'hidden_units', 90 | 'dropout_rate', 'max_sequence_length', 'output_dim']) 91 | 92 | droidetec_hparam = _DROIDETEC_HP_TEMPLATE( 93 | vocab_size=100000, # owing to the GPU memory size, we set 100,000 94 | n_embedding_dim=8, 95 | lstm_units=64, 96 | hidden_units=[200], 97 | dropout_rate=0.4, 98 | max_sequence_length=1000000, 99 | output_dim=1 100 | ) 101 | -------------------------------------------------------------------------------- /core/feature/apiseq/apiseq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import warnings 4 | 5 | from collections import defaultdict 6 | from androguard.misc import AnalyzeAPK 7 | from tools import utils 8 | from config import logging 9 | 10 | current_dir = os.path.dirname(os.path.realpath(__file__)) 11 | logger = logging.getLogger("feature.apiseq") 12 | 13 | REMOVE_CLASS_HEAD_LIST = [ 14 | 'Ljava/', 'Ljavax/' 15 | ] 16 | 17 | RETAIN_CLASS_HEAD_LIST = [ 18 | 'Landroid/', 19 | 'Lcom/android/internal/util/', 20 | 'Ldalvik/', 21 | 'Lorg/apache/', 22 | 'Lorg/json/', 23 | 'Lorg/w3c/dom/', 24 | 'Lorg/xml/sax', 25 | 'Lorg/xmlpull/v1/', 26 | 'Ljunit/' 27 | ] 28 | 29 | 30 | def _check_class(class_name): 31 | for cls_head in REMOVE_CLASS_HEAD_LIST: 32 | if class_name.startswith(cls_head): 33 | return False 34 | for cls_head in RETAIN_CLASS_HEAD_LIST: 35 | if class_name.startswith(cls_head): 36 | return True 37 | 38 | return False 39 | 40 | 41 | def _dfs(api, nodes, seq=[], visited=[]): 42 | if api not in nodes.keys(): 43 | seq.append(api) 44 | else: 45 
| visited.append(api) 46 | for elem in nodes[api]: 47 | if elem in visited: 48 | seq.append(elem) 49 | else: 50 | _dfs(elem, nodes, seq, visited) 51 | 52 | 53 | def get_api_sequence(apk_path, save_path): 54 | """ 55 | produce an api call sequence for an apk 56 | :param apk_path: an apk path 57 | :param save_path: path for saving resulting feature 58 | :return: (status, back_path_name) 59 | """ 60 | try: 61 | # obtain tow dictionaries: xref_from and xref_to, of which key is class-method name 62 | # and value is the caller or callee. 63 | _, _, dx = AnalyzeAPK(apk_path) 64 | mth_callers = defaultdict(list) 65 | mth_callees = defaultdict(list) 66 | for cls_obj in dx.get_classes(): # ClassAnalysis 67 | if cls_obj.is_external(): 68 | continue 69 | cls_name = cls_obj.name 70 | for mth_obj in cls_obj.get_methods(): 71 | if mth_obj.is_external(): 72 | continue 73 | 74 | m = mth_obj.get_method() # dvm.EncodedMethod 75 | cls_mth_name = cls_name + '->' + m.name + m.proto 76 | # get callers 77 | mth_callers[cls_mth_name] = [] 78 | for _, call, _ in mth_obj.get_xref_from(): 79 | if _check_class(call.class_name): 80 | mth_callers[cls_mth_name].append(call.class_name + '->' + call.name + call.proto) 81 | # get callees sequentially 82 | for instruction in m.get_instructions(): 83 | opcode = instruction.get_name() 84 | if 'invoke-' in opcode: 85 | code_body = instruction.get_output() 86 | if '->' not in code_body: 87 | continue 88 | head_part, rear_part = code_body.split('->') 89 | class_name = head_part.strip().split(' ')[-1] 90 | mth_name_callee = class_name + '->' + rear_part 91 | if _check_class(mth_name_callee): 92 | mth_callees[cls_mth_name].append(mth_name_callee) 93 | 94 | # look for the root call 95 | root_calls = [] 96 | num_of_calls = len(mth_callers.items()) 97 | if num_of_calls == 0: 98 | raise ValueError("No callers") 99 | for k in mth_callers.keys(): 100 | if (len(mth_callers[k]) <= 0) and (len(mth_callees[k]) > 0): 101 | root_calls.append(k) 102 | 103 | if len(root_calls) == 0: 104 | warnings.warn("Cannot find a root call, instead, randomly pick up one.") 105 | import random 106 | id = random.choice(range(num_of_calls)) 107 | root_calls.append(mth_callers.keys()[id]) 108 | 109 | # generate sequence 110 | api_sequence = [] 111 | for root_call in root_calls: 112 | sub_seq = [] 113 | visited_nodes = [] 114 | _dfs(root_call, mth_callees, sub_seq, visited_nodes) 115 | api_sequence.extend(sub_seq) 116 | # dump feature 117 | utils.dump_txt('\n'.join(api_sequence), save_path) 118 | return save_path 119 | except Exception as e: 120 | if len(e.args) > 0: 121 | e.args = e.args + (apk_path,) 122 | return e 123 | 124 | 125 | def load_feature(save_path): 126 | return utils.read_txt(save_path) 127 | 128 | 129 | def wrapper_load_feature(save_path): 130 | try: 131 | return load_feature(save_path) 132 | except Exception as e: 133 | return e 134 | 135 | 136 | def _mapping(path_to_feature, dictionary): 137 | features = load_feature(path_to_feature) 138 | _feature = [idx for idx in list(map(dictionary.get, features)) if idx is not None] 139 | if len(_feature) == 0: 140 | _feature = [dictionary.get('sos')] 141 | warnings.warn("Produce zero feature vector.") 142 | return _feature 143 | 144 | 145 | def wrapper_mapping(ptuple): 146 | try: 147 | return _mapping(*ptuple) 148 | except Exception as e: 149 | return e 150 | -------------------------------------------------------------------------------- /test/bayesian_ensemble_test.py: -------------------------------------------------------------------------------- 1 | 
from absl.testing import absltest 2 | from absl.testing import parameterized 3 | import tempfile 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | from sklearn.datasets import load_breast_cancer 8 | 9 | from core.ensemble.bayesian_ensemble import BayesianEnsemble 10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 11 | 12 | architectures = ['dnn', 'text_cnn', 'multimodalitynn', 'r2d2', 'droidectc'] 13 | 14 | 15 | class MyTestCaseBayesianEnsemble(parameterized.TestCase): 16 | def setUp(self): 17 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 18 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 19 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 20 | 21 | def test_dnn(self): 22 | with tempfile.TemporaryDirectory() as output_dir: 23 | bayesian_ensemble = BayesianEnsemble(architecture_type='dnn', 24 | model_directory=output_dir) 25 | bayesian_ensemble.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1]) 26 | 27 | res = bayesian_ensemble.predict(self.x_np) 28 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members) 29 | self.assertTrue(res.shape == (self.x_np.shape[0], bayesian_ensemble.hparam.n_sampling, 1)) 30 | 31 | bayesian_ensemble.evaluate(self.x_np, self.y_np) 32 | 33 | def test_textcnn(self): 34 | with tempfile.TemporaryDirectory() as output_dir: 35 | x = np.random.randint(0, 256, (10, 10)) 36 | y = np.random.choice(2, 10) 37 | train_dataset = build_dataset_from_numerical_data((x, y)) 38 | val_dataset = build_dataset_from_numerical_data((x, y)) 39 | bayesian_ensemble = BayesianEnsemble(architecture_type='text_cnn', 40 | model_directory=output_dir) 41 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=x.shape[1:]) 42 | res = bayesian_ensemble.predict(x) 43 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members) 44 | self.assertTrue(res.shape == (x[0].shape[0], bayesian_ensemble.hparam.n_sampling, 1)) 45 | 46 | bayesian_ensemble.evaluate(x, y) 47 | 48 | def test_multimodalitynn(self): 49 | with tempfile.TemporaryDirectory() as output_dir: 50 | x = [self.x_np] * 5 51 | train_data = build_dataset_from_numerical_data(tuple(x)) 52 | train_y = build_dataset_from_numerical_data(self.y_np) 53 | train_dataset = tf.data.Dataset.zip((train_data, train_y)) 54 | val_data = build_dataset_from_numerical_data(tuple(x)) 55 | val_y = build_dataset_from_numerical_data(self.y_np) 56 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 57 | bayesian_ensemble = BayesianEnsemble(architecture_type='multimodalitynn', 58 | model_directory=output_dir) 59 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=[x[i].shape[1] for i in range(len(x))]) 60 | 61 | res = bayesian_ensemble.predict(x) 62 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members) 63 | self.assertTrue(res.shape == (x[0].shape[0], bayesian_ensemble.hparam.n_sampling, 1)) 64 | 65 | bayesian_ensemble.evaluate(x, self.y_np) 66 | 67 | def test_r2d2(self): 68 | with tempfile.TemporaryDirectory() as output_dir: 69 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3)) 70 | y = np.random.choice(2, 10) 71 | train_dataset = build_dataset_from_numerical_data((x, y)) 72 | val_dataset = build_dataset_from_numerical_data((x, y)) 73 | 74 | bayesian_ensemble = BayesianEnsemble(architecture_type='r2d2', 75 | model_directory=output_dir) 76 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=x.shape[1:]) 77 | 78 | res 
= bayesian_ensemble.predict(x) 79 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members) 80 | self.assertTrue(res.shape == (x.shape[0], bayesian_ensemble.hparam.n_sampling, 1)) 81 | 82 | bayesian_ensemble.evaluate(x, y) 83 | 84 | def test_droidectc(self): 85 | with tempfile.TemporaryDirectory() as output_dir: 86 | x = np.random.randint(0, 10000, size=(10, 1000)) 87 | y = np.random.choice(2, 10) 88 | train_dataset = build_dataset_from_numerical_data((x, y)) 89 | val_dataset = build_dataset_from_numerical_data((x, y)) 90 | bayesian_ensemble = BayesianEnsemble(architecture_type='droidectc', 91 | model_directory=output_dir) 92 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=x.shape[1:]) 93 | 94 | res = bayesian_ensemble.predict(x) 95 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members) 96 | self.assertTrue(res.shape == (x[0].shape[0], bayesian_ensemble.hparam.n_sampling, 1)) 97 | 98 | bayesian_ensemble.evaluate(x, y) 99 | 100 | 101 | if __name__ == '__main__': 102 | absltest.main() 103 | -------------------------------------------------------------------------------- /experiments/oos.py: -------------------------------------------------------------------------------- 1 | # conduct the group of 'out of distribution' experiments on drebin dataset 2 | import os 3 | import sys 4 | import random 5 | from collections import Counter 6 | 7 | import numpy as np 8 | from sklearn.model_selection import train_test_split 9 | 10 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture 11 | from core.ensemble import ensemble_method_scope_dict 12 | from core.post_calibration.temperature_scaling import apply_temperature_scaling 13 | from tools import utils 14 | from config import config, logging 15 | 16 | logger = logging.getLogger('experiment.drebin_ood') 17 | 18 | 19 | # procedure of ood experiments 20 | # 1. build dataset 21 | # 2. preprocess data 22 | # 3. conduct prediction 23 | # 4. save results for statistical analysis 24 | 25 | 26 | def run_experiment(feature_type, ensemble_type, proc_numbers=2): 27 | """ 28 | run this group of experiments 29 | :param feature_type: the type of features (e.g., drebin, opcode, etc.), feature type associates to the model architecture 30 | :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc. 
31 | :return: None 32 | """ 33 | ood_data, ood_y, input_dim = data_preprocessing(feature_type, proc_numbers) 34 | 35 | ensemble_obj = get_ensemble_object(ensemble_type) 36 | # instantiation 37 | arch_type = feature_type_vs_architecture.get(feature_type) 38 | model_saving_dir = config.get('experiments', 'drebin') 39 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']: 40 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 41 | else: 42 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 43 | 44 | ood_results = ensemble_model.predict(ood_data) 45 | utils.dump_joblib((ood_results, ood_y), os.path.join(config.get('experiments', 'oos'), 46 | '{}_{}_drebin_oos.res'.format(feature_type, ensemble_type))) 47 | 48 | 49 | def run_temperature_scaling(feature_type, ensemble_type, proc_numbers=2): 50 | ood_data, ood_y, input_dim = data_preprocessing(feature_type, proc_numbers) 51 | 52 | ensemble_obj = get_ensemble_object(ensemble_type) 53 | # instantiation 54 | arch_type = feature_type_vs_architecture.get(feature_type) 55 | model_saving_dir = config.get('experiments', 'drebin') 56 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 57 | # temperature scaling 58 | 59 | temp_save_dir = os.path.join(config.get('drebin', 'intermediate_directory'), 60 | "{}_{}_temp.json".format(feature_type, ensemble_type)) 61 | if not os.path.exists(temp_save_dir): 62 | raise FileNotFoundError 63 | 64 | temperature = utils.load_json(temp_save_dir)['temperature'] 65 | probs = ensemble_model.predict(ood_data, use_prob=True) 66 | probs_scaling = apply_temperature_scaling(temperature, probs) 67 | utils.dump_joblib((probs_scaling, ood_y), os.path.join(config.get('experiments', 'oos'), 68 | '{}_{}_temperature_drebin_oos.res'.format(feature_type, ensemble_type))) 69 | 70 | 71 | def data_preprocessing(feature_type='drebin', proc_numbers=2): 72 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format( 73 | feature_type, feature_type_scope_dict.keys()) 74 | benware_dir = config.get('oos', 'benware_dir') 75 | malware_dir = config.get('oos', 'malware_dir') 76 | android_features_saving_dir = config.get('metadata', 'naive_data_pool') 77 | intermediate_data_saving_dir = config.get('oos', 'intermediate_directory') 78 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir, 79 | intermediate_data_saving_dir, 80 | update=False, 81 | proc_number=proc_numbers) 82 | 83 | save_path = os.path.join(intermediate_data_saving_dir, 'oos_database.' 
+ feature_type) 84 | if os.path.exists(save_path): 85 | oos_filenames, oos_y = utils.read_joblib(save_path) 86 | oos_features = [os.path.join(android_features_saving_dir, filename) for filename in oos_filenames] 87 | else: 88 | mal_feature_list = feature_extractor.feature_extraction(malware_dir) 89 | n_malware = len(mal_feature_list) 90 | ben_feature_list = feature_extractor.feature_extraction(benware_dir) 91 | n_benware = len(ben_feature_list) 92 | oos_features = mal_feature_list + ben_feature_list 93 | oos_y = np.zeros((n_malware + n_benware,), dtype=np.int32) 94 | oos_y[:n_malware] = 1 95 | 96 | oos_filenames = [os.path.basename(path) for path in oos_features] 97 | utils.dump_joblib((oos_filenames, oos_y), save_path) 98 | 99 | # obtain data in a format for ML algorithms 100 | ood_data, input_dim = feature_extractor.feature2ipt(oos_features) 101 | return ood_data, oos_y, input_dim 102 | 103 | 104 | def get_ensemble_object(ensemble_type): 105 | assert ensemble_type in ensemble_method_scope_dict.keys(), '{} expected, but {} are supported'.format( 106 | ensemble_type, 107 | ','.join(ensemble_method_scope_dict.keys()) 108 | ) 109 | return ensemble_method_scope_dict[ensemble_type] 110 | -------------------------------------------------------------------------------- /core/ensemble/anchor_ensemble.py: -------------------------------------------------------------------------------- 1 | import os.path as path 2 | import time 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | 7 | from core.ensemble.vanilla import Vanilla 8 | from core.ensemble.model_hp import train_hparam, anchor_hparam 9 | from core.ensemble.model_lib import model_builder 10 | from tools import utils 11 | from config import logging 12 | 13 | logger = logging.getLogger('ensemble.vanilla') 14 | 15 | 16 | class AnchorEnsemble(Vanilla): 17 | def __init__(self, 18 | architecture_type='dnn', 19 | base_model=None, 20 | n_members=2, 21 | model_directory=None, 22 | name='ANCHOR'): 23 | """ 24 | initialization 25 | :param architecture_type: the type of base model 26 | :param base_model: an object of base model 27 | :param n_members: number of base models 28 | :param model_directory: a folder for saving ensemble weights 29 | """ 30 | super(AnchorEnsemble, self).__init__(architecture_type, base_model, n_members, model_directory) 31 | self.hparam = utils.merge_namedtuples(train_hparam, anchor_hparam) 32 | self.ensemble_type = 'anchor' 33 | self.name = name.lower() 34 | self.save_dir = path.join(self.model_directory, self.name) 35 | 36 | def build_model(self, input_dim=None): 37 | """ 38 | Build an ensemble model -- only the homogeneous structure is considered 39 | :param input_dim: integer or list, input dimension shall be set in some cases under eager mode 40 | """ 41 | callable_graph = model_builder(self.architecture_type) 42 | 43 | @callable_graph(input_dim) 44 | def _builder(): 45 | seed = np.random.choice(self.hparam.random_seed) 46 | return utils.produce_layer(self.ensemble_type, 47 | scale=self.hparam.scale, 48 | batch_size=self.hparam.batch_size, 49 | seed=seed) 50 | 51 | self.base_model = _builder() 52 | 53 | def model_generator(self): 54 | try: 55 | for m in range(self.n_members): 56 | self.base_model = None 57 | self.load_ensemble_weights(m) 58 | yield self.base_model 59 | except Exception as e: 60 | raise Exception("Cannot load model:{}.".format(str(e))) 61 | 62 | def fit(self, train_set, validation_set=None, input_dim=None, **kwargs): 63 | """ 64 | fit the ensemble by producing a lists of model weights 65 | :param 
train_set: tf.data.Dataset, the type shall accommodate to the input format of Tensorflow models 66 | :param validation_set: validation data, optional 67 | :param input_dim: integer or list, input dimension except for the batch size 68 | """ 69 | # training 70 | logger.info("hyper-parameters:") 71 | logger.info(dict(self.hparam._asdict())) 72 | logger.info("...training start!") 73 | np.random.seed(self.hparam.random_seed) 74 | train_set = train_set.shuffle(buffer_size=100, reshuffle_each_iteration=True) 75 | for member_idx in range(self.n_members): 76 | self.base_model = None 77 | self.build_model(input_dim=input_dim) 78 | 79 | self.base_model.compile( 80 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate), 81 | loss=tf.keras.losses.BinaryCrossentropy(), 82 | metrics=[tf.keras.metrics.BinaryAccuracy()], 83 | ) 84 | for epoch in range(self.hparam.n_epochs): 85 | total_time = 0. 86 | msg = 'Epoch {}/{}, and member {}/{}'.format(epoch + 1, 87 | self.hparam.n_epochs, member_idx + 1, 88 | self.n_members) 89 | print(msg) 90 | start_time = time.time() 91 | self.base_model.fit(train_set, 92 | epochs=epoch + 1, 93 | initial_epoch=epoch, 94 | validation_data=validation_set 95 | ) 96 | end_time = time.time() 97 | total_time += end_time - start_time 98 | # saving 99 | logger.info('Training ensemble costs {} seconds at this epoch'.format(total_time)) 100 | if (epoch + 1) % self.hparam.interval == 0: 101 | self.save_ensemble_weights(member_idx) 102 | 103 | def save_ensemble_weights(self, member_idx=0): 104 | if not path.exists(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))): 105 | utils.mkdir(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))) 106 | # save model configuration 107 | self.base_model.save(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))) 108 | print("Save the model to directory {}".format(self.save_dir)) 109 | 110 | def load_ensemble_weights(self, member_idx=0): 111 | if path.exists(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))): 112 | self.base_model = tf.keras.models.load_model( 113 | path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))) 114 | 115 | def get_n_members(self): 116 | return self.n_members 117 | 118 | def update_weights(self, member_idx, model_weights, optimizer_weights=None): 119 | raise NotImplementedError 120 | -------------------------------------------------------------------------------- /core/ensemble/bayesian_ensemble.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tensorflow as tf 4 | 5 | from core.ensemble.vanilla import model_builder 6 | from core.ensemble.mc_dropout import MCDropout 7 | from core.ensemble.model_hp import train_hparam, bayesian_ensemble_hparam 8 | from config import logging, ErrorHandler 9 | from tools import utils 10 | 11 | logger = logging.getLogger('ensemble.bayesian_ensemble') 12 | logger.addHandler(ErrorHandler) 13 | 14 | 15 | class BayesianEnsemble(MCDropout): 16 | def __init__(self, 17 | architecture_type='dnn', 18 | base_model=None, 19 | n_members=1, 20 | model_directory=None, 21 | name='BAYESIAN_ENSEMBLE' 22 | ): 23 | super(BayesianEnsemble, self).__init__(architecture_type, 24 | base_model, 25 | n_members, 26 | model_directory, 27 | name) 28 | self.hparam = utils.merge_namedtuples(train_hparam, bayesian_ensemble_hparam) 29 | self.ensemble_type = 'bayesian' 30 | 31 | def build_model(self, input_dim=None, scaler=1. 
/ 10000): 32 | """ 33 | Build an ensemble model -- only the homogeneous structure is considered 34 | :param input_dim: integer or list, input dimension shall be set in some cases under eager mode 35 | :param scaler: float value in the rage of [0, 1], weighted kl divergence 36 | """ 37 | callable_graph = model_builder(self.architecture_type) 38 | 39 | @callable_graph(input_dim) 40 | def _builder(): 41 | return utils.produce_layer(self.ensemble_type, kl_scaler=scaler) 42 | 43 | self.base_model = _builder() 44 | return 45 | 46 | def fit(self, train_set, validation_set=None, input_dim=None, **kwargs): 47 | """ 48 | fit the ensemble by producing a lists of model weights 49 | :param train_set: tf.data.Dataset, the type shall accommodate to the input format of Tensorflow models 50 | :param validation_set: validation data, optional 51 | :param input_dim: integer or list, input dimension except for the batch size 52 | """ 53 | 54 | # training preparation 55 | if self.base_model is None: 56 | # scaler = 1. / (len(list(train_set)) * self.hparam.batch_size) # time-consuming 57 | scaler = 1. / 50000. 58 | self.build_model(input_dim=input_dim, scaler=scaler) 59 | 60 | self.base_model.compile( 61 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate, 62 | clipvalue=self.hparam.clipvalue), 63 | loss=tf.keras.losses.BinaryCrossentropy(), 64 | metrics=[tf.keras.metrics.BinaryAccuracy()], 65 | experimental_run_tf_function=False 66 | ) 67 | 68 | # training 69 | logger.info("hyper-parameters:") 70 | logger.info(dict(self.hparam._asdict())) 71 | logger.info("...training start!") 72 | 73 | best_val_accuracy = 0. 74 | total_time = 0. 75 | for epoch in range(self.hparam.n_epochs): 76 | train_acc = 0. 77 | val_acc = 0. 78 | 79 | for member_idx in range(self.n_members): 80 | if member_idx < len(self.weights_list): # loading former weights 81 | self.base_model.set_weights(self.weights_list[member_idx]) 82 | self.base_model.optimizer.set_weights(self._optimizers_dict[member_idx]) 83 | elif member_idx == 0: 84 | pass # do nothing 85 | else: 86 | self.reinitialize_base_model() 87 | 88 | msg = 'Epoch {}/{}, member {}/{}, and {} member(s) in list'.format(epoch + 1, 89 | self.hparam.n_epochs, 90 | member_idx + 1, 91 | self.n_members, 92 | len(self.weights_list)) 93 | print(msg) 94 | start_time = time.time() 95 | history = self.base_model.fit(train_set, 96 | epochs=epoch + 1, 97 | initial_epoch=epoch, 98 | validation_data=validation_set 99 | ) 100 | train_acc += history.history['binary_accuracy'][0] 101 | val_acc += history.history['val_binary_accuracy'][0] 102 | self.update_weights(member_idx, 103 | self.base_model.get_weights(), 104 | self.base_model.optimizer.get_weights()) 105 | end_time = time.time() 106 | total_time += end_time - start_time 107 | # saving 108 | logger.info('Training ensemble costs {} in total (including validation).'.format(total_time)) 109 | train_acc = train_acc / self.n_members 110 | val_acc = val_acc / self.n_members 111 | msg = 'Epoch {}/{}: training accuracy {:.5f}, validation accuracy {:.5f}.'.format( 112 | epoch + 1, self.hparam.n_epochs, train_acc, val_acc 113 | ) 114 | logger.info(msg) 115 | if (epoch + 1) % self.hparam.interval == 0: 116 | if val_acc >= best_val_accuracy: 117 | self.save_ensemble_weights() 118 | best_val_accuracy = val_acc 119 | msg = '\t The best validation accuracy is {:.5f}, obtained at epoch {}/{}'.format( 120 | best_val_accuracy, epoch + 1, self.hparam.n_epochs 121 | ) 122 | logger.info(msg) 123 | return 124 | 
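# --- illustrative usage sketch (added commentary, not part of the original module) ---
# The call pattern below mirrors test/bayesian_ensemble_test.py; the dataset variables
# and the model directory are placeholders chosen for illustration, not an authoritative
# API reference:
#
#   train_set = build_dataset_from_numerical_data((x, y))
#   val_set = build_dataset_from_numerical_data((x, y))
#   ensemble = BayesianEnsemble(architecture_type='dnn', model_directory='./models')
#   ensemble.fit(train_set, val_set, input_dim=x.shape[1])
#   res = ensemble.predict(x)   # shape: (n_samples, hparam.n_sampling, 1)
#   ensemble.evaluate(x, y)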
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Uncertainty Quantification for Android Malware Detectors
2 |
3 | This code repository accompanies our ACSAC 2021 paper (to appear), entitled **Can We Leverage Predictive Uncertainty to Detect Dataset Shift and Adversarial Examples in Android Malware Detection?**.
4 |
5 | ## Overview
6 | Our aim is to explore uncertainty quantification as a means of hardening malware detectors in realistic environments (i.e., where natural adversaries exist).
7 | This approach is rarely investigated in the context of malware detection, where the properties of dataset shift differ from other domains (e.g., images).
8 | We are therefore motivated to evaluate the quality of the predictive uncertainties inherent in malware detectors under dataset shift.
9 | Specifically, we consider 4 Android malware detectors, namely [DeepDrebin](http://patrickmcdaniel.org/pubs/esorics17.pdf), [MultimodalNN](https://ieeexplore.ieee.org/document/8443370), [DeepDroid](https://dl.acm.org/doi/10.1145/3029806.3029823) and [Droidetec](https://arxiv.org/abs/2002.03594), and 6 calibration methods, namely [Vanilla](https://arxiv.org/abs/1906.02530), [Temp scaling](https://arxiv.org/abs/1706.04599), [Monte Carlo dropout](https://arxiv.org/abs/1506.02142), [variational Bayesian inference](https://arxiv.org/abs/1906.02530), [Deep Ensemble](https://arxiv.org/abs/1612.01474) and Weighted deep ensemble.
10 | Dataset shift is instantiated as *out of source*, *temporal covariate shift* or *adversarial evasion attacks*.
11 |
12 | ## Dependencies
13 | We developed the code on Windows and run it on Ubuntu 18.04. The code requires Python 3.6; the remaining packages (e.g., TensorFlow) are listed in `./requirements.txt`.
14 |
15 | ## Configuration & Usage
16 | #### 1. Datasets
17 | * Three datasets are leveraged, namely [Drebin](https://www.sec.cs.tu-bs.de/~danarp/drebin/), [VirusShare_Android_APK_2013](https://virusshare.com/) and [Androzoo](https://androzoo.uni.lu/).
18 | Note that, for security reasons, each of these datasets requires you to follow its own policy to obtain the Android applications.
19 |
20 |   For Drebin, the malicious APKs can be downloaded from the official website; we provide sha256 codes for a portion of the Drebin benign APKs, whose corresponding APKs can be downloaded from [Androzoo](https://androzoo.uni.lu/).
21 |
22 |   For Androzoo, we use the dataset built by [Pendlebury et al.](https://www.usenix.org/conference/usenixsecurity19/presentation/pendlebury) All APKs can be downloaded from Androzoo.
23 |
24 |   For VirusShare, we use the file named `VirusShare_Android_APK_2013.zip`.
25 |
26 |   For adversarial APKs, we resort to this [repository](https://github.com/deqangss/adv-dnn-ens-malware).
27 |
28 | * We additionally provide the preprocessed data files, available at an anonymous [url](https://mega.nz/folder/bF8RQAAI#HeIhpUzDdqCdWdh4bAIZbg) (the unzipped folder is ~213GB).
29 |
30 | #### 2. Configure
31 | For convenience, we provide a `conf` (Windows) / `conf-server` (Ubuntu) file to assist customization (please pick one and rename it `config`; do not use both). Before running, make the following changes:
32 | * Modify the `project_root=/absolute/path/to/malware-uncertainty/`.
33 |
34 | * Modify the `database_dir=/absolute/path/to/datasets`. Optionally, for more detail, the `Drebin` or `Androzoo` malware datasets reside in this directory with the following structure:
35 | ```
36 | datasets
37 | |---drebin
38 | |---malicious_samples % malicious apps folder
39 | |---benign_samples % benign apps folder
40 | |---androzoo_tesseract
41 | |---malicious_samples
42 | |---benign_samples
43 | | date_stamp.json % date stamp for each app, we will provide
44 | |---VirusShare_Android_APK_2013
45 | |---malicious_samples
46 | |---benign_samples
47 | |---naive_data % saving the preprocessed data files
48 | ...
49 | ```
50 | If no real apps are considered, the preprocessed data files make the project work as well. In this case, we need to continue with the following configuration:
51 | * Download the `datasets` folder from the anonymous [url](https://mega.nz/folder/bF8RQAAI#HeIhpUzDdqCdWdh4bAIZbg) and put it in the project root directory, namely `malware-uncertainty`. Please note that this `datasets` folder is not necessarily the same as the `database_dir` directory of the second step.
52 | * Download `naive_data` from the anonymous [url](https://mega.nz/folder/bF8RQAAI#HeIhpUzDdqCdWdh4bAIZbg) and put it in the `database_dir` directory configured in the second step (it needs unzipping: `mv naive_data.tar.gz database_dir; cd database_dir; tar -xvzf naive_data.tar.gz ./`).
53 |
54 | #### 3. Usage
55 | We suggest creating a conda environment to run the code. In this spirit, the following instructions may be helpful:
56 | 1. Create a new environment: `conda create -n mal-uncertainty python=3.6`
57 | 2. Activate the environment and install dependencies: `conda activate mal-uncertainty` and `pip install -r requirements.txt`
58 | 3. Next step:
59 | * For training, all scripts are listed in `./run.sh`
60 | * For producing figures and table data, the python code is `./experiments/table-figures.py` (this part has not been implemented for the malware detector `Droidetec`)
61 |
62 | ## Warning
63 | * Performing feature extraction on Android applications is usually time-consuming.
64 | * Two detectors (DeepDroid and Droidetec) are both RAM- and computation-intensive because very long sequences are used to promote detection accuracy.
65 |
66 |
67 | ## License && Issues
68 |
69 | We will make our code publicly available under a formal license. For now, this is still ongoing work and we plan to report more results in future work. It is worth noting that we found two issues when checking our code:
70 | * No random seed is set for reproducing the results exactly as in the paper; nevertheless, similar results can be achieved.
71 | * The training, validation, and test datasets are split randomly, which makes results vary across runs.
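As a workaround for the first issue, one can pin the random seeds at the top of the experiment entry points (e.g., `experiments/drebin_main.py`). The snippet below is an illustrative sketch only, not part of the released code; the seed value simply reuses `random_seed` from `core/ensemble/model_hp.py`:
```
import os
import random

import numpy as np
import tensorflow as tf

SEED = 23456  # e.g., reuse random_seed from core/ensemble/model_hp.py

os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)         # Python RNG (e.g., sampling OOD malware families)
np.random.seed(SEED)      # NumPy RNG (e.g., data splits, synthetic test inputs)
tf.random.set_seed(SEED)  # TensorFlow RNG (weight initialization, dropout masks)
```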
72 | 73 | ## Contact 74 | 75 | Any questions, please do not hesitate to contact us (`Shouhuai Xu` email: `sxu@uccs.edu`, `Deqiang Li` email: `lideqiang@njust.edu.cn`) 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /experiments/adv.py: -------------------------------------------------------------------------------- 1 | # conduct the group of 'out of distribution' experiments on drebin dataset 2 | import os 3 | import sys 4 | import random 5 | from collections import Counter 6 | 7 | import numpy as np 8 | from sklearn.model_selection import train_test_split 9 | 10 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture 11 | from core.ensemble import ensemble_method_scope_dict 12 | from core.post_calibration.temperature_scaling import apply_temperature_scaling 13 | from tools import utils 14 | from config import config, logging 15 | 16 | logger = logging.getLogger('experiment.drebin_adv') 17 | 18 | 19 | # procedure of ood experiments 20 | # 1. build dataset 21 | # 2. preprocess data 22 | # 3. conduct prediction 23 | # 4. save results for statistical analysis 24 | 25 | 26 | def run_experiment(feature_type, ensemble_type, proc_numbers=2): 27 | """ 28 | run this group of experiments 29 | :param feature_type: the type of features (e.g., drebin, opcode, etc.), feature type associates to the model architecture 30 | :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc. 31 | :return: None 32 | """ 33 | prist_data, adv_data, prist_y, adv_y, input_dim = data_preprocessing(feature_type, proc_numbers) 34 | 35 | ensemble_obj = get_ensemble_object(ensemble_type) 36 | # instantiation 37 | arch_type = feature_type_vs_architecture.get(feature_type) 38 | model_saving_dir = config.get('experiments', 'drebin') 39 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']: 40 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 41 | else: 42 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 43 | 44 | prist_results = ensemble_model.predict(prist_data) 45 | adv_results = ensemble_model.predict(adv_data) 46 | utils.dump_joblib((prist_results, adv_results, prist_y, adv_y), 47 | os.path.join(config.get('experiments', 'adv'), '{}_{}_drebin_adv.res'.format(feature_type, ensemble_type))) 48 | # ensemble_model.evaluate(prist_data, prist_y) 49 | # ensemble_model.evaluate(adv_data, adv_y) 50 | 51 | 52 | def run_temperature_scaling(feature_type, ensemble_type, proc_numbers=2): 53 | prist_data, adv_data, prist_y, adv_y, input_dim = data_preprocessing(feature_type, proc_numbers) 54 | 55 | ensemble_obj = get_ensemble_object(ensemble_type) 56 | # instantiation 57 | arch_type = feature_type_vs_architecture.get(feature_type) 58 | model_saving_dir = config.get('experiments', 'drebin') 59 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 60 | 61 | temp_save_dir = os.path.join(config.get('drebin', 'intermediate_directory'), 62 | "{}_{}_temp.json".format(feature_type, ensemble_type)) 63 | if not os.path.exists(temp_save_dir): 64 | raise FileNotFoundError 65 | 66 | temperature = utils.load_json(temp_save_dir)['temperature'] 67 | 68 | prist_probs = ensemble_model.predict(prist_data, use_prob=True) 69 | prist_prob_t = apply_temperature_scaling(temperature, prist_probs) 70 | adv_probs = ensemble_model.predict(adv_data, use_prob=True) 71 | adv_prob_t = apply_temperature_scaling(temperature, adv_probs) 
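    # Added note (not in the original source): apply_temperature_scaling is expected to
    # soften the predictive probabilities with the temperature T fitted on in-distribution
    # validation data, conceptually p_T = sigmoid(logit(p) / T) for this binary task, so that
    # over-confident predictions on pristine and adversarial inputs are rescaled before being
    # dumped below; this reading should be checked against core/post_calibration/temperature_scaling.py.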
72 | 73 | utils.dump_joblib((prist_prob_t, adv_prob_t, prist_y, adv_y), os.path.join(config.get('experiments', 'adv'), 74 | '{}_{}_temperature_drebin_adv.res'.format(feature_type, ensemble_type))) 75 | 76 | 77 | def data_preprocessing(feature_type='drebin', proc_numbers=2): 78 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format( 79 | feature_type, feature_type_scope_dict.keys()) 80 | malware_pst_dir = config.get('adv', 'pristine_apk_dir') 81 | malware_adv_dir = os.path.join('adv', 'perturbed_apk_dir') 82 | android_features_saving_dir = config.get('metadata', 'naive_data_pool') 83 | intermediate_data_saving_dir = config.get('adv', 'intermediate_directory') 84 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir, 85 | intermediate_data_saving_dir, 86 | update=False, 87 | proc_number=proc_numbers) 88 | 89 | save_path = os.path.join(intermediate_data_saving_dir, 'adv_database.' + feature_type) 90 | if os.path.exists(save_path): 91 | prist_filenames, adv_filenames, prist_y, adv_y = utils.read_joblib(save_path) 92 | prist_features = [os.path.join(android_features_saving_dir, filename) for filename in prist_filenames] 93 | adv_features = [os.path.join(android_features_saving_dir, filename) for filename in adv_filenames] 94 | else: 95 | prist_features = feature_extractor.feature_extraction(malware_pst_dir) 96 | prist_y = np.ones((len(prist_features),), dtype=np.int32) 97 | 98 | adv_features = feature_extractor.feature_extraction(malware_adv_dir) 99 | adv_y = np.ones((len(adv_features),), dtype=np.int32) 100 | 101 | prist_filenames = [os.path.basename(path) for path in prist_features] 102 | adv_filenames = [os.path.basename(path) for path in adv_features] 103 | utils.dump_joblib((prist_filenames, adv_filenames, prist_y, adv_y), save_path) 104 | 105 | # obtain data in a format for ML algorithms 106 | prist_data, input_dim = feature_extractor.feature2ipt(prist_features) 107 | adv_data, input_dim = feature_extractor.feature2ipt(adv_features) 108 | return prist_data, adv_data, prist_y, adv_y, input_dim 109 | 110 | 111 | def get_ensemble_object(ensemble_type): 112 | assert ensemble_type in ensemble_method_scope_dict.keys(), '{} expected, but {} are supported'.format( 113 | ensemble_type, 114 | ','.join(ensemble_method_scope_dict.keys()) 115 | ) 116 | return ensemble_method_scope_dict[ensemble_type] 117 | 118 | 119 | def _main(): 120 | # build_data() 121 | print(get_ensemble_object('vanilla')) 122 | 123 | 124 | if __name__ == '__main__': 125 | sys.exit(_main()) 126 | -------------------------------------------------------------------------------- /tools/progressbar/examples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import sys 5 | import time 6 | 7 | #from widgets import AnimatedMarker, Bar, BouncingBar, Counter, ETA, \ 8 | # AdaptiveETA, FileTransferSpeed, FormatLabel, Percentage, \ 9 | # ProgressBar, ReverseBar, RotatingMarker, \ 10 | # SimpleProgress, Timer 11 | 12 | from compat import * 13 | from widgets import * 14 | from progressbar import * 15 | 16 | examples = [] 17 | def example(fn): 18 | try: name = 'Example %d' % int(fn.__name__[7:]) 19 | except: name = fn.__name__ 20 | 21 | def wrapped(): 22 | try: 23 | sys.stdout.write('Running: %s\n' % name) 24 | fn() 25 | sys.stdout.write('\n') 26 | except KeyboardInterrupt: 27 | sys.stdout.write('\nSkipping example.\n\n') 28 | 29 | examples.append(wrapped) 30 | return 
wrapped 31 | 32 | @example 33 | def example0(): 34 | pbar = ProgressBar(widgets=[Percentage(), Bar()], maxval=300).start() 35 | for i in range(300): 36 | time.sleep(0.01) 37 | pbar.update(i+1) 38 | pbar.finish() 39 | 40 | @example 41 | def example1(): 42 | widgets = ['Test: ', Percentage(), ' ', Bar(marker=RotatingMarker()), 43 | ' ', ETA(), ' ', FileTransferSpeed()] 44 | pbar = ProgressBar(widgets=widgets, maxval=10000000).start() 45 | for i in range(1000000): 46 | # do something 47 | pbar.update(10*i+1) 48 | pbar.finish() 49 | 50 | @example 51 | def example2(): 52 | class CrazyFileTransferSpeed(FileTransferSpeed): 53 | """It's bigger between 45 and 80 percent.""" 54 | def update(self, pbar): 55 | if 45 < pbar.percentage() < 80: 56 | return 'Bigger Now ' + FileTransferSpeed.update(self,pbar) 57 | else: 58 | return FileTransferSpeed.update(self,pbar) 59 | 60 | widgets = [CrazyFileTransferSpeed(),' <<<', Bar(), '>>> ', 61 | Percentage(),' ', ETA()] 62 | pbar = ProgressBar(widgets=widgets, maxval=10000000) 63 | # maybe do something 64 | pbar.start() 65 | for i in range(2000000): 66 | # do something 67 | pbar.update(5*i+1) 68 | pbar.finish() 69 | 70 | @example 71 | def example3(): 72 | widgets = [Bar('>'), ' ', ETA(), ' ', ReverseBar('<')] 73 | pbar = ProgressBar(widgets=widgets, maxval=10000000).start() 74 | for i in range(1000000): 75 | # do something 76 | pbar.update(10*i+1) 77 | pbar.finish() 78 | 79 | @example 80 | def example4(): 81 | widgets = ['Test: ', Percentage(), ' ', 82 | Bar(marker='0',left='[',right=']'), 83 | ' ', ETA(), ' ', FileTransferSpeed()] 84 | pbar = ProgressBar(widgets=widgets, maxval=500) 85 | pbar.start() 86 | for i in range(100,500+1,50): 87 | time.sleep(0.2) 88 | pbar.update(i) 89 | pbar.finish() 90 | 91 | @example 92 | def example5(): 93 | pbar = ProgressBar(widgets=[SimpleProgress()], maxval=17).start() 94 | for i in range(17): 95 | time.sleep(0.2) 96 | pbar.update(i + 1) 97 | pbar.finish() 98 | 99 | @example 100 | def example6(): 101 | pbar = ProgressBar().start() 102 | for i in range(100): 103 | time.sleep(0.01) 104 | pbar.update(i + 1) 105 | pbar.finish() 106 | 107 | @example 108 | def example7(): 109 | pbar = ProgressBar() # Progressbar can guess maxval automatically. 110 | for i in pbar(range(80)): 111 | time.sleep(0.01) 112 | 113 | @example 114 | def example8(): 115 | pbar = ProgressBar(maxval=80) # Progressbar can't guess maxval. 
116 | for i in pbar((i for i in range(80))): 117 | time.sleep(0.01) 118 | 119 | @example 120 | def example9(): 121 | pbar = ProgressBar(widgets=['Working: ', AnimatedMarker()]) 122 | for i in pbar((i for i in range(50))): 123 | time.sleep(.08) 124 | 125 | @example 126 | def example10(): 127 | widgets = ['Processed: ', Counter(), ' lines (', Timer(), ')'] 128 | pbar = ProgressBar(widgets=widgets) 129 | for i in pbar((i for i in range(150))): 130 | time.sleep(0.1) 131 | 132 | @example 133 | def example11(): 134 | widgets = [FormatLabel('Processed: %(value)d lines (in: %(elapsed)s)')] 135 | pbar = ProgressBar(widgets=widgets) 136 | for i in pbar((i for i in range(150))): 137 | time.sleep(0.1) 138 | 139 | @example 140 | def example12(): 141 | widgets = ['Balloon: ', AnimatedMarker(markers='.oO@* ')] 142 | pbar = ProgressBar(widgets=widgets) 143 | for i in pbar((i for i in range(24))): 144 | time.sleep(0.3) 145 | 146 | @example 147 | def example13(): 148 | # You may need python 3.x to see this correctly 149 | try: 150 | widgets = ['Arrows: ', AnimatedMarker(markers='←↖↑↗→↘↓↙')] 151 | pbar = ProgressBar(widgets=widgets) 152 | for i in pbar((i for i in range(24))): 153 | time.sleep(0.3) 154 | except UnicodeError: sys.stdout.write('Unicode error: skipping example') 155 | 156 | @example 157 | def example14(): 158 | # You may need python 3.x to see this correctly 159 | try: 160 | widgets = ['Arrows: ', AnimatedMarker(markers='◢◣◤◥')] 161 | pbar = ProgressBar(widgets=widgets) 162 | for i in pbar((i for i in range(24))): 163 | time.sleep(0.3) 164 | except UnicodeError: sys.stdout.write('Unicode error: skipping example') 165 | 166 | @example 167 | def example15(): 168 | # You may need python 3.x to see this correctly 169 | try: 170 | widgets = ['Wheels: ', AnimatedMarker(markers='◐◓◑◒')] 171 | pbar = ProgressBar(widgets=widgets) 172 | for i in pbar((i for i in range(24))): 173 | time.sleep(0.3) 174 | except UnicodeError: sys.stdout.write('Unicode error: skipping example') 175 | 176 | @example 177 | def example16(): 178 | widgets = [FormatLabel('Bouncer: value %(value)d - '), BouncingBar()] 179 | pbar = ProgressBar(widgets=widgets) 180 | for i in pbar((i for i in range(180))): 181 | time.sleep(0.05) 182 | 183 | @example 184 | def example17(): 185 | widgets = [FormatLabel('Animated Bouncer: value %(value)d - '), 186 | BouncingBar(marker=RotatingMarker())] 187 | 188 | pbar = ProgressBar(widgets=widgets) 189 | for i in pbar((i for i in range(180))): 190 | time.sleep(0.05) 191 | 192 | @example 193 | def example18(): 194 | widgets = [Percentage(), 195 | ' ', Bar(), 196 | ' ', Timer(), 197 | ' ', AdaptiveETA()] 198 | pbar = ProgressBar(widgets=widgets, maxval=500) 199 | pbar.start() 200 | for i in range(500): 201 | 202 | #print 203 | time.sleep(0.01 + (i < 100) * 0.01 + (i > 400) * 0.9) 204 | pbar.update(i + 1) 205 | pbar.finish() 206 | 207 | @example 208 | def example19(): 209 | pbar = ProgressBar() 210 | for i in pbar([]): 211 | pass 212 | pbar.finish() 213 | 214 | if __name__ == '__main__': 215 | try: 216 | example18() 217 | except KeyboardInterrupt: 218 | sys.stdout('\nQuitting examples.\n') 219 | -------------------------------------------------------------------------------- /experiments/drebin_ood.py: -------------------------------------------------------------------------------- 1 | # conduct the group of 'out of distribution' experiments on drebin dataset 2 | import os 3 | import sys 4 | import random 5 | from collections import Counter 6 | 7 | import numpy as np 8 | from sklearn.model_selection 
import train_test_split 9 | 10 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture 11 | from core.ensemble import ensemble_method_scope_dict 12 | from tools import utils 13 | from config import config, logging 14 | 15 | logger = logging.getLogger('experiment.drebin_ood') 16 | 17 | # procedure of ood experiments 18 | # 1. build dataset 19 | # 2. preprocess data 20 | # 3. learn models 21 | # 4. save results for statistical analysis 22 | 23 | def run_experiment(feature_type, ensemble_type, n_members = 1, proc_numbers=2): 24 | """ 25 | run this group of experiments 26 | :param feature_type: the type of features (e.g., drebin, opcode, etc.), feature type associates to the model architecture 27 | :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc. 28 | :return: None 29 | """ 30 | mal_folder, ben_folder, ood_data_paths = build_data() 31 | 32 | train_dataset, validation_dataset, test_data, test_y, ood_data, ood_y, input_dim = \ 33 | data_preprocessing(feature_type, mal_folder, ben_folder, ood_data_paths, proc_numbers) 34 | 35 | ensemble_obj = get_ensemble_object(ensemble_type) 36 | # instantiation 37 | arch_type = feature_type_vs_architecture.get(feature_type) 38 | saving_dir = config.get('experiments', 'ood') 39 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']: 40 | ensemble_model = ensemble_obj(arch_type, base_model=None, n_members = 1, model_directory = saving_dir) 41 | else: 42 | ensemble_model = ensemble_obj(arch_type, base_model=None, n_members = n_members, model_directory = saving_dir) 43 | 44 | ensemble_model.fit(train_dataset, validation_dataset, input_dim=input_dim) 45 | 46 | test_results = ensemble_model.predict(test_data) 47 | utils.dump_joblib(test_results, os.path.join(saving_dir, '{}_{}_test.res'.format(feature_type, ensemble_type))) 48 | ensemble_model.evaluate(test_data, test_y) 49 | ood_results = ensemble_model.predict(ood_data) 50 | utils.dump_joblib(ood_results, os.path.join(saving_dir, '{}_{}_ood.res'.format(feature_type, ensemble_type))) 51 | ensemble_model.evaluate(ood_data, ood_y, is_single_class=True, name='ood') 52 | 53 | def build_data(): 54 | malware_dir = config.get('drebin', 'malware_dir') 55 | benware_dir = config.get('drebin', 'benware_dir') 56 | malare_paths, ood_data_paths = produce_ood_data(malware_dir) 57 | 58 | return malare_paths, benware_dir, ood_data_paths 59 | 60 | 61 | def produce_ood_data(malware_dir, top_frequency=30, n_selection=5, minimum_samples = 1, maximum_samples=1000): 62 | import pandas as pd 63 | malware_family_pd = pd.read_csv(config.get('drebin', 'malware_family')) 64 | counter = dict(Counter(malware_family_pd['family']).most_common(top_frequency)) 65 | i = 0 66 | while i <= 1e5: 67 | random.seed(i) 68 | selected_families = random.sample(counter.keys(), n_selection) 69 | number_of_malware = sum([counter[f] for f in selected_families]) 70 | if minimum_samples <= number_of_malware <= maximum_samples: 71 | break 72 | else: 73 | i = i + 1 74 | else: 75 | random.seed(1) 76 | selected_families = random.sample(counter.keys(), n_selection) 77 | number_of_malware = sum([counter[f] for f in selected_families]) 78 | logger.info('The number of selected ood malware samples: {}'.format(number_of_malware)) 79 | logger.info("The selected families are {}".format(','.join(selected_families))) 80 | 81 | selected_sha256_codes = list(malware_family_pd[malware_family_pd.family.isin(selected_families)]['sha256']) 82 | assert len(selected_sha256_codes) == number_of_malware 83 | 84 | malware_paths = 
utils.retrive_files_set(malware_dir, "", ".apk|") 85 | ood_data_paths = [] 86 | for mal_path in malware_paths: 87 | sha256 = utils.get_sha256(mal_path) 88 | if sha256 in selected_sha256_codes: 89 | ood_data_paths.append(mal_path) 90 | 91 | return list(set(malware_paths) - set(ood_data_paths)), ood_data_paths 92 | 93 | 94 | def data_preprocessing(feature_type='drebin', malware_dir=None, benware_dir=None, ood_data_path = None, proc_numbers = 2): 95 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format( 96 | feature_type, feature_type_scope_dict.keys()) 97 | 98 | android_features_saving_dir = config.get('metadata', 'naive_data_directory') 99 | intermediate_data_saving_dir = config.get('metadata', 'meta_data_directory') 100 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir, 101 | intermediate_data_saving_dir, 102 | update=False, 103 | proc_number = proc_numbers) 104 | mal_feature_list = feature_extractor.feature_extraction(malware_dir) 105 | n_malware = len(mal_feature_list) 106 | ben_feature_list = feature_extractor.feature_extraction(benware_dir) 107 | n_benware = len(ben_feature_list) 108 | feature_list = mal_feature_list + ben_feature_list 109 | gt_labels = np.zeros((len(feature_list),), dtype=np.int32) 110 | gt_labels[:n_malware] = 1 111 | 112 | # data split 113 | train_features, test_features, train_y, test_y = train_test_split(feature_list, gt_labels, 114 | test_size=0.2, random_state=0) 115 | feature_extractor.feature_preprocess(train_features, train_y) # produce intermediate products 116 | # obtain validation data 117 | train_features, validation_features, train_y, validation_y = train_test_split(train_features, 118 | train_y, 119 | test_size=0.25, 120 | random_state=0 121 | ) 122 | 123 | # obtain data in a format for ML algorithms 124 | train_dataset, input_dim = feature_extractor.feature2ipt(train_features, train_y) 125 | test_data, _ = feature_extractor.feature2ipt(test_features) 126 | validation_dataset, _ = feature_extractor.feature2ipt(validation_features, validation_y) 127 | 128 | ood_features = feature_extractor.feature_extraction(ood_data_path) 129 | ood_y = np.ones((len(ood_features),)) 130 | ood_data, _ = feature_extractor.feature2ipt(ood_features) 131 | 132 | return train_dataset, validation_dataset, test_data, test_y, ood_data, ood_y, input_dim 133 | 134 | 135 | def get_ensemble_object(ensemble_type): 136 | assert ensemble_type in ensemble_method_scope_dict.keys(), '{} expected, but {} are supported'.format( 137 | ensemble_type, 138 | ','.join(ensemble_method_scope_dict.keys()) 139 | ) 140 | return ensemble_method_scope_dict[ensemble_type] 141 | 142 | def _main(): 143 | # build_data() 144 | print(get_ensemble_object('vanilla')) 145 | 146 | 147 | if __name__ == '__main__': 148 | sys.exit(_main()) 149 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python -m experiments.drebin_main --detector drebin --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2 2 | python -m experiments.drebin_main --detector drebin --calibration temp_scaling --n_members 1 --ratio 3.0 --proc_numbers 2 3 | python -m experiments.drebin_main --detector drebin --calibration mc_dropout --n_members 10 --ratio 3.0 --proc_numbers 2 4 | python -m experiments.drebin_main --detector drebin --calibration bayesian --n_members 10 --ratio 3.0 --proc_numbers 2 5 | python -m 
experiments.drebin_main --detector drebin --calibration deep_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
6 | python -m experiments.drebin_main --detector drebin --calibration weighted_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
7 | python -m experiments.drebin_main --detector opcodeseq --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2
8 | python -m experiments.drebin_main --detector opcodeseq --calibration temp_scaling --n_members 1 --ratio 3.0 --proc_numbers 2
9 | python -m experiments.drebin_main --detector opcodeseq --calibration mc_dropout --n_members 10 --ratio 3.0 --proc_numbers 2
10 | python -m experiments.drebin_main --detector opcodeseq --calibration bayesian --n_members 10 --ratio 3.0 --proc_numbers 2
11 | python -m experiments.drebin_main --detector opcodeseq --calibration deep_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
12 | python -m experiments.drebin_main --detector opcodeseq --calibration weighted_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
13 | python -m experiments.drebin_main --detector multimodality --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2
14 | python -m experiments.drebin_main --detector multimodality --calibration temp_scaling --n_members 1 --ratio 3.0 --proc_numbers 2
15 | python -m experiments.drebin_main --detector multimodality --calibration mc_dropout --n_members 10 --ratio 3.0 --proc_numbers 2
16 | python -m experiments.drebin_main --detector multimodality --calibration bayesian --n_members 10 --ratio 3.0 --proc_numbers 2
17 | python -m experiments.drebin_main --detector multimodality --calibration deep_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
18 | python -m experiments.drebin_main --detector multimodality --calibration weighted_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
19 | python -m experiments.androzoo_main --detector drebin --calibration vanilla --n_members 1 --proc_numbers 2
20 | python -m experiments.androzoo_main --detector drebin --calibration temp_scaling --n_members 1 --proc_numbers 2
21 | python -m experiments.androzoo_main --detector drebin --calibration mc_dropout --n_members 10 --proc_numbers 2
22 | python -m experiments.androzoo_main --detector drebin --calibration bayesian --n_members 10 --proc_numbers 2
23 | python -m experiments.androzoo_main --detector drebin --calibration deep_ensemble --n_members 10 --proc_numbers 2
24 | python -m experiments.androzoo_main --detector drebin --calibration weighted_ensemble --n_members 10 --proc_numbers 2
25 | python -m experiments.androzoo_main --detector opcodeseq --calibration vanilla --n_members 1 --proc_numbers 2
26 | python -m experiments.androzoo_main --detector opcodeseq --calibration temp_scaling --n_members 1 --proc_numbers 2
27 | python -m experiments.androzoo_main --detector opcodeseq --calibration mc_dropout --n_members 10 --proc_numbers 2
28 | python -m experiments.androzoo_main --detector opcodeseq --calibration bayesian --n_members 10 --proc_numbers 2
29 | python -m experiments.androzoo_main --detector opcodeseq --calibration deep_ensemble --n_members 10 --proc_numbers 2
30 | python -m experiments.androzoo_main --detector opcodeseq --calibration weighted_ensemble --n_members 10 --proc_numbers 2
31 | python -m experiments.androzoo_main --detector multimodality --calibration vanilla --n_members 1 --proc_numbers 2
32 | python -m experiments.androzoo_main --detector multimodality --calibration temp_scaling --n_members 1 --proc_numbers 2
33 | python -m experiments.androzoo_main --detector multimodality
--calibration mc_dropout --n_members 10 --proc_numbers 2 34 | python -m experiments.androzoo_main --detector multimodality --calibration bayesian --n_members 10 --proc_numbers 2 35 | python -m experiments.androzoo_main --detector multimodality --calibration deep_ensemble --n_members 10 --proc_numbers 2 36 | python -m experiments.androzoo_main --detector multimodality --calibration weighted_ensemble --n_members 10 --proc_numbers 2 37 | python -m experiments.oos_main --detector drebin --calibration vanilla --proc_numbers 2 38 | python -m experiments.oos_main --detector drebin --calibration temp_scaling --proc_numbers 2 39 | python -m experiments.oos_main --detector drebin --calibration mc_dropout --proc_numbers 2 40 | python -m experiments.oos_main --detector drebin --calibration bayesian --proc_numbers 2 41 | python -m experiments.oos_main --detector drebin --calibration deep_ensemble --proc_numbers 2 42 | python -m experiments.oos_main --detector drebin --calibration weighted_ensemble --proc_numbers 2 43 | python -m experiments.oos_main --detector opcodeseq --calibration vanilla --proc_numbers 2 44 | python -m experiments.oos_main --detector opcodeseq --calibration temp_scaling --proc_numbers 2 45 | python -m experiments.oos_main --detector opcodeseq --calibration mc_dropout --proc_numbers 2 46 | python -m experiments.oos_main --detector opcodeseq --calibration bayesian --proc_numbers 2 47 | python -m experiments.oos_main --detector opcodeseq --calibration deep_ensemble --proc_numbers 2 48 | python -m experiments.oos_main --detector opcodeseq --calibration weighted_ensemble --proc_numbers 2 49 | python -m experiments.oos_main --detector multimodality --calibration vanilla --proc_numbers 2 50 | python -m experiments.oos_main --detector multimodality --calibration temp_scaling --proc_numbers 2 51 | python -m experiments.oos_main --detector multimodality --calibration mc_dropout --proc_numbers 2 52 | python -m experiments.oos_main --detector multimodality --calibration bayesian --proc_numbers 2 53 | python -m experiments.oos_main --detector multimodality --calibration deep_ensemble --proc_numbers 2 54 | python -m experiments.oos_main --detector multimodality --calibration weighted_ensemble --proc_numbers 2 55 | # python -m experiments.adv_main --detector drebin --calibration vanilla --proc_numbers 2 56 | # python -m experiments.adv_main --detector drebin --calibration temp_scaling --proc_numbers 2 57 | # python -m experiments.adv_main --detector drebin --calibration mc_dropout --proc_numbers 2 58 | # python -m experiments.adv_main --detector drebin --calibration bayesian --proc_numbers 2 59 | # python -m experiments.adv_main --detector drebin --calibration deep_ensemble --proc_numbers 2 60 | # python -m experiments.adv_main --detector drebin --calibration weighted_ensemble --proc_numbers 2 61 | # python -m experiments.adv_main --detector opcodeseq --calibration vanilla --proc_numbers 2 62 | # python -m experiments.adv_main --detector opcodeseq --calibration temp_scaling --proc_numbers 2 63 | # python -m experiments.adv_main --detector opcodeseq --calibration mc_dropout --proc_numbers 2 64 | # python -m experiments.adv_main --detector opcodeseq --calibration bayesian --proc_numbers 2 65 | # python -m experiments.adv_main --detector opcodeseq --calibration deep_ensemble --proc_numbers 2 66 | # python -m experiments.adv_main --detector opcodeseq --calibration weighted_ensemble --proc_numbers 2 67 | # python -m experiments.adv_main --detector multimodality --calibration vanilla --proc_numbers 2 
68 | # python -m experiments.adv_main --detector multimodality --calibration temp_scaling --proc_numbers 2 69 | # python -m experiments.adv_main --detector multimodality --calibration mc_dropout --proc_numbers 2 70 | # python -m experiments.adv_main --detector multimodality --calibration bayesian --proc_numbers 2 71 | # python -m experiments.adv_main --detector multimodality --calibration deep_ensemble --proc_numbers 2 72 | # python -m experiments.adv_main --detector multimodality --calibration weighted_ensemble --proc_numbers 2 73 | -------------------------------------------------------------------------------- /test/deep_ensemble_test.py: -------------------------------------------------------------------------------- 1 | from absl.testing import absltest 2 | from absl.testing import parameterized 3 | import tempfile 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from sklearn.datasets import load_breast_cancer 8 | 9 | from core.ensemble.deep_ensemble import DeepEnsemble, WeightedDeepEnsemble 10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 11 | 12 | 13 | class MyTestCaseDeepEnsemble(parameterized.TestCase): 14 | def setUp(self): 15 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 16 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 17 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 18 | 19 | def test_dnn(self): 20 | with tempfile.TemporaryDirectory() as output_dir: 21 | deepensemble = DeepEnsemble(architecture_type='dnn', 22 | model_directory=output_dir) 23 | deepensemble.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1]) 24 | 25 | res = deepensemble.predict(self.x_np) 26 | self.assertEqual(deepensemble.get_n_members(), deepensemble.n_members) 27 | self.assertTrue(res.shape == (self.x_np.shape[0], deepensemble.n_members, 1)) 28 | 29 | deepensemble.evaluate(self.x_np, self.y_np) 30 | 31 | def test_textcnn(self): 32 | with tempfile.TemporaryDirectory() as output_dir: 33 | deepensemble = DeepEnsemble(architecture_type='text_cnn', 34 | model_directory=output_dir) 35 | x = np.random.randint(0, 256, (2, 10)) 36 | y = np.random.choice(2, 2) 37 | train_dataset = build_dataset_from_numerical_data((x, y)) 38 | val_dataset = build_dataset_from_numerical_data((x, y)) 39 | deepensemble.fit(train_dataset, val_dataset) 40 | res = deepensemble.predict(x) 41 | self.assertTrue(res.shape == (x.shape[0], deepensemble.n_members, 1)) 42 | deepensemble.evaluate(x, y) 43 | 44 | def test_multimodalitynn(self): 45 | with tempfile.TemporaryDirectory() as output_dir: 46 | deepensemble = DeepEnsemble(architecture_type='multimodalitynn', 47 | model_directory=output_dir) 48 | x = [self.x_np] * 5 49 | train_data = build_dataset_from_numerical_data(tuple(x)) 50 | train_y = build_dataset_from_numerical_data(self.y_np) 51 | train_dataset = tf.data.Dataset.zip((train_data, train_y)) 52 | val_data = build_dataset_from_numerical_data(tuple(x)) 53 | val_y = build_dataset_from_numerical_data(self.y_np) 54 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 55 | deepensemble.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]] * 5) 56 | res = deepensemble.predict(x) 57 | self.assertTrue(res.shape == (self.x_np.shape[0], deepensemble.n_members, 1)) 58 | deepensemble.evaluate(x, self.y_np) 59 | 60 | def test_r2d2(self): 61 | with tempfile.TemporaryDirectory() as output_dir: 62 | deepensemble = DeepEnsemble(architecture_type='r2d2', 63 | model_directory=output_dir) 64 | x 
= np.random.uniform(0., 1., size=(10, 299, 299, 3)) 65 | y = np.random.choice(2, 10) 66 | train_dataset = build_dataset_from_numerical_data((x, y)) 67 | val_dataset = build_dataset_from_numerical_data((x, y)) 68 | deepensemble.fit(train_dataset, val_dataset, input_dim=(299, 299, 3)) 69 | res = deepensemble.predict(x) 70 | self.assertTrue(res.shape == (x.shape[0], deepensemble.n_members, 1)) 71 | deepensemble.evaluate(x, y) 72 | 73 | def test_droidectc(self): 74 | with tempfile.TemporaryDirectory() as output_dir: 75 | deepensemble = DeepEnsemble(architecture_type='droidectc', 76 | model_directory=output_dir) 77 | 78 | x = np.random.randint(0, 10000, size=(10, 1000)) 79 | y = np.random.choice(2, 10) 80 | train_dataset = build_dataset_from_numerical_data((x, y)) 81 | val_dataset = build_dataset_from_numerical_data((x, y)) 82 | deepensemble.fit(train_dataset, val_dataset) 83 | res = deepensemble.predict(x) 84 | self.assertTrue(res.shape == (x.shape[0], deepensemble.n_members, 1)) 85 | deepensemble.evaluate(x, y) 86 | 87 | 88 | class MyTestCaseWeightedDeepEnsemble(parameterized.TestCase): 89 | def setUp(self): 90 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True) 91 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 92 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np)) 93 | 94 | def test_dnn(self): 95 | with tempfile.TemporaryDirectory() as output_dir: 96 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='dnn', 97 | model_directory=output_dir) 98 | weighteddeepensemble.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1]) 99 | 100 | out, w = weighteddeepensemble.predict(self.x_np) 101 | self.assertEqual(weighteddeepensemble.get_n_members(), weighteddeepensemble.n_members) 102 | self.assertTrue(out.shape == (self.x_np.shape[0], weighteddeepensemble.n_members, 1)) 103 | self.assertTrue(np.sum(w) == 1.) 104 | 105 | weighteddeepensemble.evaluate(self.x_np, self.y_np) 106 | 107 | def test_textcnn(self): 108 | with tempfile.TemporaryDirectory() as output_dir: 109 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='text_cnn', 110 | model_directory=output_dir) 111 | x = np.random.randint(0, 256, (10, 10)) 112 | y = np.random.choice(2, 10) 113 | train_dataset = build_dataset_from_numerical_data((x, y)) 114 | val_dataset = build_dataset_from_numerical_data((x, y)) 115 | weighteddeepensemble.fit(train_dataset, val_dataset) 116 | out, w = weighteddeepensemble.predict(x) 117 | self.assertTrue(out.shape == (out.shape[0], weighteddeepensemble.n_members, 1)) 118 | self.assertTrue(np.sum(w) == 1.) 
119 | weighteddeepensemble.evaluate(x, y) 120 | 121 | def test_multimodalitynn(self): 122 | with tempfile.TemporaryDirectory() as output_dir: 123 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='multimodalitynn', 124 | model_directory=output_dir) 125 | x = [self.x_np] * 5 126 | train_data = build_dataset_from_numerical_data(tuple(x)) 127 | train_y = build_dataset_from_numerical_data(self.y_np) 128 | train_dataset = tf.data.Dataset.zip((train_data, train_y)) 129 | val_data = build_dataset_from_numerical_data(tuple(x)) 130 | val_y = build_dataset_from_numerical_data(self.y_np) 131 | val_dataset = tf.data.Dataset.zip((val_data, val_y)) 132 | weighteddeepensemble.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]] * 5) 133 | out, w = weighteddeepensemble.predict(x) 134 | self.assertTrue(out.shape == (self.x_np.shape[0], weighteddeepensemble.n_members, 1)) 135 | self.assertTrue(np.sum(w) == 1.) 136 | weighteddeepensemble.evaluate(x, self.y_np) 137 | 138 | def test_r2d2(self): 139 | with tempfile.TemporaryDirectory() as output_dir: 140 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='r2d2', 141 | model_directory=output_dir) 142 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3)) 143 | y = np.random.choice(2, 10) 144 | train_dataset = build_dataset_from_numerical_data((x, y)) 145 | val_dataset = build_dataset_from_numerical_data((x, y)) 146 | weighteddeepensemble.fit(train_dataset, val_dataset, input_dim=(299, 299, 3)) 147 | out, w = weighteddeepensemble.predict(x) 148 | self.assertTrue(out.shape == (x.shape[0], weighteddeepensemble.n_members, 1)) 149 | self.assertTrue(np.sum(w) == 1.) 150 | weighteddeepensemble.evaluate(x, y) 151 | 152 | def test_droidectc(self): 153 | with tempfile.TemporaryDirectory() as output_dir: 154 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='droidectc', 155 | model_directory=output_dir) 156 | 157 | x = np.random.randint(0, 10000, size=(10, 1000)) 158 | y = np.random.choice(2, 10) 159 | train_dataset = build_dataset_from_numerical_data((x, y)) 160 | val_dataset = build_dataset_from_numerical_data((x, y)) 161 | weighteddeepensemble.fit(train_dataset, val_dataset) 162 | out, w = weighteddeepensemble.predict(x) 163 | self.assertTrue(out.shape == (x.shape[0], weighteddeepensemble.n_members, 1)) 164 | self.assertTrue(np.sum(w) == 1.) 165 | weighteddeepensemble.evaluate(x, y) 166 | 167 | 168 | if __name__ == '__main__': 169 | absltest.main() 170 | -------------------------------------------------------------------------------- /tools/temporal.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | temporal.py 5 | ~~~~~~~~~~~ 6 | 7 | A module for working with and running time-aware evaluations. Most of the 8 | functionality of this module falls into one of two categories: working with 9 | arrays of datetimes or datetime-aligned series of data, and aggregating the 10 | steps of the ML pipeline needed to conduct sound, time-aware evaluations. 
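A minimal usage sketch (illustrative only; `X`, `y` and `t` are assumed to be aligned arrays of
features, labels and datetimes):

    >>> X_train, X_test_windows, y_train, y_test_windows, t_train, t_test_windows = \
    ...     time_aware_train_test_split(X, y, t, train_size=12, test_size=1,
    ...                                 granularity='month')
    >>> assert_train_test_temporal_consistency(t_train, t_test_windows[0])  # True if C1 holds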
11 | 12 | 13 | @inproceedings{pendlebury2019, 14 | author = {Feargus Pendlebury, Fabio Pierazzi, Roberto Jordaney, Johannes Kinder, and Lorenzo Cavallaro}, 15 | title = {{TESSERACT: Eliminating Experimental Bias in Malware Classification across Space and Time}}, 16 | booktitle = {28th USENIX Security Symposium}, 17 | year = {2019}, 18 | address = {Santa Clara, CA}, 19 | publisher = {USENIX Association}, 20 | note = {USENIX Sec} 21 | } 22 | 23 | """ 24 | import bisect 25 | import operator 26 | from datetime import datetime, date 27 | 28 | import numpy as np 29 | from dateutil.relativedelta import relativedelta 30 | 31 | 32 | def assert_train_test_temporal_consistency(t_train, t_test): 33 | """Helper function to assert train-test temporal constraint (C1). 34 | 35 | All objects in the training set need to be temporally anterior to all 36 | objects in the testing set. Violating this constraint will positively bias 37 | the results by integrating "future" knowledge into the classifier. 38 | 39 | Args: 40 | t_train: An array of datetimes corresponding to the training set. 41 | t_test: An array of datetime corresponding to the testing set. 42 | 43 | Returns: 44 | bool: False if the partitioned dataset does _not_ adhere to C1, 45 | True otherwise. 46 | 47 | """ 48 | for train_date in t_train: 49 | for test_date in t_test: 50 | if train_date > test_date: 51 | return False 52 | return True 53 | 54 | 55 | def assert_positive_negative_temporal_consistency(y, t, month_variance=1): 56 | """Helper function to assert malware-goodware temporal constraint (C2). 57 | 58 | In any given testing period, all testing objects must be from the time 59 | window under test. In the malware domain this constraint has often been 60 | violated so that malware and goodware come from different time periods. 61 | 62 | If this is the case, it becomes impossible to tell whether a 63 | high-performing classifier is discriminating between malicious and benign 64 | objects or between old and new applications. 65 | 66 | Args: 67 | y: An array of ground-truth labels for each observation. 68 | t: An array of datetimes for each observation (aligned with y). 69 | month_variance: All malware and goodware should be between this many 70 | months. 71 | 72 | Returns: 73 | bool: False if the malware and goodware do not adhere to C2, 74 | True otherwise 75 | 76 | """ 77 | positive = np.where(y == 1)[0] 78 | negative = np.where(y != 1)[0] 79 | positive_dates = t[positive] 80 | negative_dates = t[negative] 81 | 82 | for pos_date in positive_dates: 83 | for neg_date in negative_dates: 84 | if month_difference(pos_date, neg_date) > month_variance: 85 | return False 86 | return True 87 | 88 | 89 | def month_difference(d1, d2): 90 | """Get the difference in months between two datetimes.""" 91 | return (d1.year - d2.year) * 12 + d1.month - d2.month 92 | 93 | 94 | def resolve_date(d): 95 | """Convert a str or date to an appropriate datetime. 96 | 97 | Strings should be of the format '%Y', '%Y-%m or '%Y-%m-%d', for example: 98 | '2012', '1994-02' or '1991-12-11'. Date objects with no time information 99 | will be rounded down to the midnight beginning that date. 100 | 101 | Args: 102 | d (Union[str, date]): The string or date to convert. 103 | 104 | Returns: 105 | datetime: The parsed datetime equivalent of d. 
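    Illustrative examples (doctest-style; `date` as imported at module level):

        >>> resolve_date('2012')
        datetime.datetime(2012, 1, 1, 0, 0)
        >>> resolve_date('1994-02')
        datetime.datetime(1994, 2, 1, 0, 0)
        >>> resolve_date(date(1991, 12, 11))
        datetime.datetime(1991, 12, 11, 0, 0)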
106 | """ 107 | if isinstance(d, datetime): 108 | return d 109 | 110 | if isinstance(d, date): 111 | return datetime.combine(d, datetime.min.time()) 112 | 113 | for fmt in ('%Y', '%Y-%m', '%Y-%m-%d'): 114 | try: 115 | return datetime.strptime(d, fmt) 116 | except ValueError: 117 | pass 118 | 119 | raise ValueError('date string format not recognized.') 120 | 121 | 122 | def time_aware_train_test_split(X, y, t, train_size, test_size, 123 | granularity, start_date=None): 124 | """Partition a dataset composed of time-labelled objects. 125 | 126 | Args: 127 | X (np.ndarray): Multi-dimensional array of predictors. 128 | y (np.ndarray): Array of output labels. 129 | t (np.ndarray): Array of timestamp tags. 130 | train_size (int): The training window size W (in τ). 131 | test_size (int): The testing window size Δ (in τ). 132 | granularity (str): The unit of time τ, used to denote the window size. 133 | Acceptable values are 'year|quarter|month|week|day'. 134 | start_date (date): The date to begin partioning from (eg. to align with 135 | the start of the year). 136 | 137 | Returns: 138 | (np.ndarray, list, np.ndarray, list): 139 | Training partition of predictors X. 140 | List of testing partitions of predictors X. 141 | Training partition of output variables y. 142 | List of testing partitions of predictors y. 143 | 144 | """ 145 | # Get partitioned indexes 146 | train, tests = time_aware_indexes(t, train_size, test_size, 147 | granularity, start_date) 148 | 149 | # Partition predictors and labels 150 | X_actual, y_actual, t_actual = X[train], y[train], t[train] 151 | 152 | X_tests = [X[index_set] for index_set in tests] 153 | y_tests = [y[index_set] for index_set in tests] 154 | t_tests = [t[index_set] for index_set in tests] 155 | 156 | return X_actual, X_tests, y_actual, y_tests, t_actual, t_tests 157 | 158 | 159 | def time_aware_indexes(t, train_size, test_size, granularity, start_date=None): 160 | """Return a list of indexes that partition the list t by time. 161 | 162 | Sorts the list of dates t before dividing into training and testing 163 | partitions, ensuring a 'history-aware' split in the ensuing classification 164 | task. 165 | 166 | 167 | Args: 168 | t (np.ndarray): Array of timestamp tags. 169 | train_size (int): The training window size W (in τ). 170 | test_size (int): The testing window size Δ (in τ). 171 | granularity (str): The unit of time τ, used to denote the window size. 172 | Acceptable values are 'year|quarter|month|week|day'. 173 | start_date (date): The date to begin partioning from (eg. to align with 174 | the start of the year). 175 | 176 | Returns: 177 | (list, list): 178 | Indexing for the training partition. 179 | List of indexings for the testing partitions. 
180 | 181 | """ 182 | # Order the dates as well as their original positions 183 | with_indexes = zip(t, range(len(t))) 184 | ordered = sorted(with_indexes, key=operator.itemgetter(0)) 185 | 186 | # Split out the dates from the indexes 187 | dates = [tup[0] for tup in ordered] 188 | indexes = [tup[1] for tup in ordered] 189 | 190 | # Get earliest date 191 | start_date = resolve_date(start_date) if start_date else ordered[0][0] 192 | 193 | # Slice out training partition 194 | boundary = start_date + get_relative_delta(train_size, granularity) 195 | to_idx = bisect.bisect_left(dates, boundary) 196 | train = indexes[:to_idx] 197 | 198 | tests = [] 199 | # Slice out testing partitions 200 | while to_idx < len(indexes): 201 | boundary += get_relative_delta(test_size, granularity) 202 | from_idx = to_idx 203 | to_idx = bisect.bisect_left(dates, boundary) 204 | tests.append(indexes[from_idx:to_idx]) 205 | 206 | return train, tests 207 | 208 | 209 | def time_aware_partition(t, proportion): 210 | """Partition an array of dates based on the given proportion. 211 | 212 | The set of timestamps will be bisected with the left bisection sized by 213 | the given proportion. 214 | 215 | Args: 216 | t: An array of datetimes. 217 | proportion: The proportion by which to split the array. 218 | 219 | Returns: 220 | tuple: The two bisections of the array. 221 | """ 222 | # Order the dates as well as their original positions 223 | indexes = np.argsort(t) 224 | 225 | # Divide ordered set in two 226 | boundary = int(proportion * len(indexes)) 227 | 228 | return indexes[:boundary], indexes[boundary:] 229 | 230 | 231 | def temporal_slice(X, y, t): 232 | raise NotImplementedError 233 | 234 | 235 | def get_relative_delta(offset, granularity): 236 | """Get delta of size 'granularity'. 237 | 238 | Args: 239 | offset: The number of time units to offset by. 240 | granularity: The unit of time to offset by, expects one of 241 | 'year', 'quarter', 'month', 'week', 'day'. 242 | 243 | Returns: 244 | The timedelta equivalent to offset * granularity. 245 | 246 | """ 247 | # Make allowances for year(s), quarter(s), month(s), week(s), day(s) 248 | granularity = granularity[:-1] if granularity[-1] == 's' else granularity 249 | try: 250 | return { 251 | 'year': relativedelta(years=offset), 252 | 'quarter': relativedelta(months=3 * offset), 253 | 'month': relativedelta(months=offset), 254 | 'week': relativedelta(weeks=offset), 255 | 'day': relativedelta(days=offset), 256 | }[granularity] 257 | except KeyError: 258 | raise ValueError('granularity not recognised, try: ' 259 | 'year|quarter|month|week|day') 260 | -------------------------------------------------------------------------------- /tools/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _check_probablities(p, q=None): 5 | assert 0. <= np.all(p) <= 1. 6 | if q is not None: 7 | assert len(p) == len(q), \ 8 | 'Probabilies and ground truth must have the same number of elements.' 9 | 10 | 11 | def entropy(p, base=2, eps=1e-10): 12 | """ 13 | calculate entropy in element-wise 14 | :param p: probabilities 15 | :param base: default exp 16 | :return: average entropy value 17 | """ 18 | p_arr = np.asarray(p) 19 | _check_probablities(p) 20 | enc = -(p_arr * np.log(p_arr + eps) + (1. - p_arr) * np.log(1. 
- p_arr + eps)) 21 | if base is not None: 22 | enc = np.clip(enc / np.log(base), a_min=0., a_max=1000) 23 | return enc 24 | 25 | 26 | def predictive_kld(p, number, w=None, base=2, eps=1e-10): 27 | """ 28 | calculate Kullback-Leibler divergence in element-wise 29 | :param p: probabilities 30 | :param number: the number of likelihood values for each sample 31 | :param w: weights for probabilities 32 | :param base: default exp 33 | :return: average entropy value 34 | """ 35 | if number <= 1: 36 | return np.zeros_like(p) 37 | 38 | p_arr = np.asarray(p).reshape((-1, number)) 39 | _check_probablities(p) 40 | q_arr = np.tile(np.mean(p_arr, axis=-1, keepdims=True), [1, number]) 41 | if w is None: 42 | w_arr = np.ones(shape=(number, 1), dtype=np.float) / number 43 | else: 44 | w_arr = np.asarray(w).reshape((number, 1)) 45 | 46 | kld_elem = p_arr * np.log((p_arr + eps) / (q_arr + eps)) + (1. - p_arr) * np.log( 47 | (1. - p_arr + eps) / (1. - q_arr + eps)) 48 | if base is not None: 49 | kld_elem = kld_elem / np.log(base) 50 | kld = np.matmul(kld_elem, w_arr) 51 | return kld 52 | 53 | 54 | def predictive_std(p, number, w=None): 55 | """ 56 | calculate the probabilities deviation 57 | :param p: probabilities 58 | :param number: the number of probabilities applied to each sample 59 | :param w: weights for probabilities 60 | :param axis: the axis along which the calculation is conducted 61 | :return: 62 | """ 63 | if number <= 1: 64 | return np.zeros_like(p) 65 | 66 | ps_arr = np.asarray(p).reshape((-1, number)) 67 | _check_probablities(ps_arr) 68 | if w is None: 69 | w = np.ones(shape=(number, 1), dtype=np.float) / number 70 | else: 71 | w = np.asarray(w).reshape((number, 1)) 72 | assert 0 <= np.all(w) <= 1. 73 | mean = np.matmul(ps_arr, w) 74 | var = np.sqrt(np.matmul(np.square(ps_arr - mean), w) * (float(number) / float(number - 1))) 75 | return var 76 | 77 | 78 | def nll(p, q, eps=1e-10, base=2): 79 | """ 80 | negative log likelihood (NLL) 81 | :param p: predictive labels 82 | :param q: ground truth labels 83 | :param eps: a small value prevents the overflow 84 | :param base: the base of log function 85 | :return: the mean of NLL 86 | """ 87 | _check_probablities(p, q) 88 | nll = -(q * np.log(p + eps) + (1. - q) * np.log(1. - p + eps)) 89 | if base is not None: 90 | nll = np.clip(nll / np.log(base), a_min=0., a_max=1000) 91 | return np.mean(nll) 92 | 93 | 94 | def b_nll(p, q, eps=1e-10, base=2): 95 | """ 96 | balanced negative log likelihood (NLL) 97 | :param p: 1-D array, predictive labels 98 | :param q: 1-D array, ground truth labels 99 | :param eps: a small value prevents the overflow 100 | :param base: the base of log function 101 | :return: the mean of NLL 102 | """ 103 | _check_probablities(p, q) 104 | pos_indicator = (q == 1) 105 | pos_nll = nll(p[pos_indicator], q[pos_indicator], eps, base) 106 | neg_indicator = (q == 0) 107 | neg_nll = nll(p[neg_indicator], q[neg_indicator], eps, base) 108 | return 1. / 2. 
* (pos_nll + neg_nll) 109 | 110 | 111 | def brier_score(p, q, pos_label=1): 112 | """ 113 | brier score 114 | :param p: predictive labels 115 | :param q: ground truth labels 116 | :param pos_label: the positive class 117 | :return: 118 | """ 119 | from sklearn.metrics import brier_score_loss 120 | _check_probablities(p, q) 121 | return brier_score_loss(q, p, pos_label=pos_label) 122 | 123 | 124 | def b_brier_score(p, q): 125 | """ 126 | balanced brier score 127 | :param p: predictive labels 128 | :param q: ground truth labels 129 | :return: 130 | """ 131 | pos_indicator = (q == 1) 132 | pos_bs = brier_score(p[pos_indicator], q[pos_indicator], pos_label=None) 133 | neg_indicator = (q == 0) 134 | neg_bs = brier_score(p[neg_indicator], q[neg_indicator], pos_label=None) 135 | return 1. / 2. * (pos_bs + neg_bs) 136 | 137 | 138 | def expected_calibration_error(probabilities, ground_truth, bins=10, use_unweighted_version=True): 139 | """ 140 | Code is adapted from https://github.com/google-research/google-research/tree/master/uq_benchmark_2019/metrics_lib.py 141 | Compute the expected calibration error of a set of preditions in [0, 1]. 142 | Args: 143 | probabilities: A numpy vector of N probabilities assigned to each prediction 144 | ground_truth: A numpy vector of N ground truth labels in {0,1, True, False} 145 | bins: Number of equal width bins to bin predictions into in [0, 1], or 146 | an array representing bin edges. 147 | Returns: 148 | Float: the expected calibration error. 149 | """ 150 | 151 | def bin_predictions_and_accuracies(probabilities, ground_truth, bins=10): 152 | """A helper function which histograms a vector of probabilities into bins. 153 | 154 | Args: 155 | probabilities: A numpy vector of N probabilities assigned to each prediction 156 | ground_truth: A numpy vector of N ground truth labels in {0,1} 157 | bins: Number of equal width bins to bin predictions into in [0, 1], or an 158 | array representing bin edges. 159 | 160 | Returns: 161 | bin_edges: Numpy vector of floats containing the edges of the bins 162 | (including leftmost and rightmost). 163 | accuracies: Numpy vector of floats for the average accuracy of the 164 | predictions in each bin. 165 | counts: Numpy vector of ints containing the number of examples per bin. 166 | """ 167 | _check_probablities(probabilities, ground_truth) 168 | 169 | if isinstance(bins, int): 170 | num_bins = bins 171 | else: 172 | num_bins = bins.size - 1 173 | 174 | # Ensure probabilities are never 0, since the bins in np.digitize are open on 175 | # one side. 
176 | probabilities = np.where(probabilities == 0, 1e-8, probabilities) 177 | counts, bin_edges = np.histogram(probabilities, bins=bins, range=[0., 1.]) 178 | indices = np.digitize(probabilities, bin_edges, right=True) 179 | accuracies = np.array([np.mean(ground_truth[indices == i]) 180 | for i in range(1, num_bins + 1)]) 181 | return bin_edges, accuracies, counts 182 | 183 | def bin_centers_of_mass(probabilities, bin_edges): 184 | probabilities = np.where(probabilities == 0, 1e-8, probabilities) 185 | indices = np.digitize(probabilities, bin_edges, right=True) 186 | return np.array([np.mean(probabilities[indices == i]) 187 | for i in range(1, len(bin_edges))]) 188 | 189 | probabilities = probabilities.flatten() 190 | ground_truth = ground_truth.flatten() 191 | 192 | bin_edges, accuracies, counts = bin_predictions_and_accuracies( 193 | probabilities, ground_truth, bins) 194 | bin_centers = bin_centers_of_mass(probabilities, bin_edges) 195 | num_examples = np.sum(counts) 196 | if not use_unweighted_version: 197 | ece = np.sum([(counts[i] / float(num_examples)) * np.sum( 198 | np.abs(bin_centers[i] - accuracies[i])) 199 | for i in range(bin_centers.size) if counts[i] > 0]) 200 | else: 201 | ece = np.sum([(1. / float(bins)) * np.sum( 202 | np.abs(bin_centers[i] - accuracies[i])) 203 | for i in range(bin_centers.size) if counts[i] > 0]) 204 | return ece 205 | 206 | 207 | def _main(): 208 | gt_labels = np.random.choice(2, (1000,)) 209 | prob = np.random.uniform(0., 1., (1000,)) 210 | # print("Entropy:", entropy(prob)) 211 | # print("negative log likelihood:", nll(prob, gt_labels)) 212 | # print('balanced nll:', b_nll(prob, gt_labels)) 213 | # print("brier score:", brier_score(prob, gt_labels)) 214 | # print("balanced brier score:", b_brier_score(prob, gt_labels)) 215 | # print("expected calibration error:", expected_calibration_error(prob, gt_labels)) 216 | # 217 | # prob2 = np.random.uniform(0., 1., (1000, 10)) 218 | # w = np.random.uniform(0., 1., (10,)) 219 | # w = w/np.sum(w) 220 | # print("Standard deviation:", predictive_std(prob2, weights=w, number=10)) 221 | # print("Kl divergence:", predictive_kld(prob2, w, number=10)) 222 | # print("Standard deviation:", predictive_std(np.zeros(shape=(10, 10)), weights=w, number=10)) 223 | # print("Kl divergence:", predictive_kld(np.zeros(shape=(10, 10)), w, number=10)) 224 | 225 | gt_labels = np.array([1., 1., 0., 0, 0]) 226 | prob = gt_labels 227 | # print("Entropy:", entropy(prob)) 228 | # print("negative log likelihood:", nll(prob, gt_labels)) 229 | # print('balanced nll:', b_nll(prob, gt_labels)) 230 | 231 | # print("Standard deviation:", predictive_std(prob, number=2)) 232 | # print("Kl divergence:", predictive_kld(prob, number=2)) 233 | # print("brier score:", brier_score(prob, gt_labels)) 234 | # print("expected calibration error:", expected_calibration_error(prob, gt_labels)) 235 | print("balanced ece:", expected_calibration_error(prob, gt_labels, bins=2, use_undersampling=True)) 236 | 237 | # gt_labels = np.array([1., 0.]) 238 | # prob = 1. 
- gt_labels 239 | # print("Entropy:", entropy(prob)) 240 | # print("negative log likelihood:", nll(prob, gt_labels)) 241 | # print("brier score:", brier_score(prob, gt_labels)) 242 | # print("expected calibration error:", expected_calibration_error(prob, gt_labels)) 243 | return 0 244 | 245 | 246 | if __name__ == '__main__': 247 | _main() 248 | -------------------------------------------------------------------------------- /tools/progressbar/progressbar.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # progressbar - Text progress bar library for Python. 5 | # Copyright (c) 2005 Nilton Volpato 6 | # 7 | # This library is free software; you can redistribute it and/or 8 | # modify it under the terms of the GNU Lesser General Public 9 | # License as published by the Free Software Foundation; either 10 | # version 2.1 of the License, or (at your option) any later version. 11 | # 12 | # This library is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | # Lesser General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU Lesser General Public 18 | # License along with this library; if not, write to the Free Software 19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 20 | 21 | """Main ProgressBar class.""" 22 | 23 | from __future__ import division 24 | 25 | import math 26 | import os 27 | import signal 28 | import sys 29 | import time 30 | 31 | try: 32 | from fcntl import ioctl 33 | from array import array 34 | import termios 35 | except ImportError: 36 | pass 37 | 38 | from tools.progressbar.compat import * # for: any, next 39 | from . import widgets 40 | # import .widgets 41 | 42 | 43 | class UnknownLength: pass 44 | 45 | 46 | class ProgressBar(object): 47 | """The ProgressBar class which updates and prints the bar. 48 | 49 | A common way of using it is like: 50 | >>> pbar = ProgressBar().start() 51 | >>> for i in range(100): 52 | ... # do something 53 | ... pbar.update(i+1) 54 | ... 55 | >>> pbar.finish() 56 | 57 | You can also use a ProgressBar as an iterator: 58 | >>> progress = ProgressBar() 59 | >>> for i in progress(some_iterable): 60 | ... # do something 61 | ... 62 | 63 | Since the progress bar is incredibly customizable you can specify 64 | different widgets of any type in any order. You can even write your own 65 | widgets! However, since there are already a good number of widgets you 66 | should probably play around with them before moving on to create your own 67 | widgets. 68 | 69 | The term_width parameter represents the current terminal width. If the 70 | parameter is set to an integer then the progress bar will use that, 71 | otherwise it will attempt to determine the terminal width falling back to 72 | 80 columns if the width cannot be determined. 73 | 74 | When implementing a widget's update method you are passed a reference to 75 | the current progress bar. As a result, you have access to the 76 | ProgressBar's methods and attributes. Although there is nothing preventing 77 | you from changing the ProgressBar you should treat it as read only. 
78 | 79 | Useful methods and attributes include (Public API): 80 | - currval: current progress (0 <= currval <= maxval) 81 | - maxval: maximum (and final) value 82 | - finished: True if the bar has finished (reached 100%) 83 | - start_time: the time when start() method of ProgressBar was called 84 | - seconds_elapsed: seconds elapsed since start_time and last call to 85 | update 86 | - percentage(): progress in percent [0..100] 87 | """ 88 | 89 | __slots__ = ('currval', 'fd', 'finished', 'last_update_time', 90 | 'left_justify', 'maxval', 'next_update', 'num_intervals', 91 | 'poll', 'seconds_elapsed', 'signal_set', 'start_time', 92 | 'term_width', 'update_interval', 'widgets', '_time_sensitive', 93 | '__iterable') 94 | 95 | _DEFAULT_MAXVAL = 100 96 | _DEFAULT_TERMSIZE = 80 97 | _DEFAULT_WIDGETS = [widgets.Percentage(), ' ', widgets.Bar()] 98 | 99 | def __init__(self, maxval=None, widgets=None, term_width=None, poll=1, 100 | left_justify=True, fd=sys.stderr): 101 | """Initializes a progress bar with sane defaults.""" 102 | 103 | # Don't share a reference with any other progress bars 104 | if widgets is None: 105 | widgets = list(self._DEFAULT_WIDGETS) 106 | 107 | self.maxval = maxval 108 | self.widgets = widgets 109 | self.fd = fd 110 | self.left_justify = left_justify 111 | 112 | self.signal_set = False 113 | if term_width is not None: 114 | self.term_width = term_width 115 | else: 116 | try: 117 | self._handle_resize() 118 | signal.signal(signal.SIGWINCH, self._handle_resize) 119 | self.signal_set = True 120 | except (SystemExit, KeyboardInterrupt): raise 121 | except: 122 | self.term_width = self._env_size() 123 | 124 | self.__iterable = None 125 | self._update_widgets() 126 | self.currval = 0 127 | self.finished = False 128 | self.last_update_time = None 129 | self.poll = poll 130 | self.seconds_elapsed = 0 131 | self.start_time = None 132 | self.update_interval = 1 133 | self.next_update = 0 134 | 135 | 136 | def __call__(self, iterable): 137 | """Use a ProgressBar to iterate through an iterable.""" 138 | 139 | try: 140 | self.maxval = len(iterable) 141 | except: 142 | if self.maxval is None: 143 | self.maxval = UnknownLength 144 | 145 | self.__iterable = iter(iterable) 146 | return self 147 | 148 | 149 | def __iter__(self): 150 | return self 151 | 152 | 153 | def __next__(self): 154 | try: 155 | value = next(self.__iterable) 156 | if self.start_time is None: 157 | self.start() 158 | else: 159 | self.update(self.currval + 1) 160 | return value 161 | except StopIteration: 162 | if self.start_time is None: 163 | self.start() 164 | self.finish() 165 | raise 166 | 167 | 168 | # Create an alias so that Python 2.x won't complain about not being 169 | # an iterator. 
170 | next = __next__ 171 | 172 | 173 | def _env_size(self): 174 | """Tries to find the term_width from the environment.""" 175 | 176 | return int(os.environ.get('COLUMNS', self._DEFAULT_TERMSIZE)) - 1 177 | 178 | 179 | def _handle_resize(self, signum=None, frame=None): 180 | """Tries to catch resize signals sent from the terminal.""" 181 | 182 | h, w = array('h', ioctl(self.fd, termios.TIOCGWINSZ, '\0' * 8))[:2] 183 | self.term_width = w 184 | 185 | 186 | def percentage(self): 187 | """Returns the progress as a percentage.""" 188 | if self.currval >= self.maxval: 189 | return 100.0 190 | return self.currval * 100.0 / self.maxval 191 | 192 | percent = property(percentage) 193 | 194 | 195 | def _format_widgets(self): 196 | result = [] 197 | expanding = [] 198 | width = self.term_width 199 | 200 | for index, widget in enumerate(self.widgets): 201 | if isinstance(widget, widgets.WidgetHFill): 202 | result.append(widget) 203 | expanding.insert(0, index) 204 | else: 205 | widget = widgets.format_updatable(widget, self) 206 | result.append(widget) 207 | width -= len(widget) 208 | 209 | count = len(expanding) 210 | while count: 211 | portion = max(int(math.ceil(width * 1. / count)), 0) 212 | index = expanding.pop() 213 | count -= 1 214 | 215 | widget = result[index].update(self, portion) 216 | width -= len(widget) 217 | result[index] = widget 218 | 219 | return result 220 | 221 | 222 | def _format_line(self): 223 | """Joins the widgets and justifies the line.""" 224 | 225 | widgets = ''.join(self._format_widgets()) 226 | 227 | if self.left_justify: return widgets.ljust(self.term_width) 228 | else: return widgets.rjust(self.term_width) 229 | 230 | 231 | def _need_update(self): 232 | """Returns whether the ProgressBar should redraw the line.""" 233 | if self.currval >= self.next_update or self.finished: return True 234 | 235 | delta = time.time() - self.last_update_time 236 | return self._time_sensitive and delta > self.poll 237 | 238 | 239 | def _update_widgets(self): 240 | """Checks all widgets for the time sensitive bit.""" 241 | 242 | self._time_sensitive = any(getattr(w, 'TIME_SENSITIVE', False) 243 | for w in self.widgets) 244 | 245 | 246 | def update(self, value=None): 247 | """Updates the ProgressBar to a new value.""" 248 | 249 | if value is not None and value is not UnknownLength: 250 | if (self.maxval is not UnknownLength 251 | and not 0 <= value <= self.maxval): 252 | 253 | raise ValueError('Value out of range') 254 | 255 | self.currval = value 256 | 257 | 258 | if not self._need_update(): return 259 | if self.start_time is None: 260 | raise RuntimeError('You must call "start" before calling "update"') 261 | 262 | now = time.time() 263 | self.seconds_elapsed = now - self.start_time 264 | self.next_update = self.currval + self.update_interval 265 | self.fd.write(self._format_line() + '\r') 266 | self.last_update_time = now 267 | 268 | 269 | def start(self): 270 | """Starts measuring time, and prints the bar at 0%. 271 | 272 | It returns self so you can use it like this: 273 | >>> pbar = ProgressBar().start() 274 | >>> for i in range(100): 275 | ... # do something 276 | ... pbar.update(i+1) 277 | ... 
278 | >>> pbar.finish() 279 | """ 280 | 281 | if self.maxval is None: 282 | self.maxval = self._DEFAULT_MAXVAL 283 | 284 | self.num_intervals = max(100, self.term_width) 285 | self.next_update = 0 286 | 287 | if self.maxval is not UnknownLength: 288 | if self.maxval < 0: raise ValueError('Value out of range') 289 | self.update_interval = self.maxval / self.num_intervals 290 | 291 | 292 | self.start_time = self.last_update_time = time.time() 293 | self.update(0) 294 | 295 | return self 296 | 297 | 298 | def finish(self): 299 | """Puts the ProgressBar bar in the finished state.""" 300 | 301 | if self.finished: 302 | return 303 | self.finished = True 304 | self.update(self.maxval) 305 | self.fd.write('\n') 306 | if self.signal_set: 307 | signal.signal(signal.SIGWINCH, signal.SIG_DFL) 308 | -------------------------------------------------------------------------------- /experiments/drebin_dataset.py: -------------------------------------------------------------------------------- 1 | # conduct the first group experiments on drebin dataset 2 | import os 3 | 4 | import numpy as np 5 | from sklearn.model_selection import train_test_split 6 | 7 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture 8 | from core.ensemble import ensemble_method_scope_dict 9 | from tools import utils 10 | from config import config, logging 11 | 12 | logger = logging.getLogger('experiment.drebin') 13 | 14 | 15 | # procedure of drebin experiments 16 | # 1. build dataset 17 | # 2. preprocess data 18 | # 3. learn models 19 | # 4. save results for statistical analysis 20 | 21 | def run_experiment(feature_type, ensemble_type, random_seed=0, n_members=1, ratio=3.0, proc_numbers=2): 22 | """ 23 | run this group of experiments 24 | :param feature_type: the type of feature (e.g., drebin, opcode, etc.), feature type associates to the model architecture 25 | :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc. 
26 | :param random_seed: an integer 27 | :param n_members: the number of base models enclosed by an ensemble 28 | :param ratio: the ratio of benign files to malware files 29 | :param proc_numbers: the number of threads 30 | :return: None 31 | """ 32 | mal_folder = config.get('drebin', 'malware_dir') 33 | ben_folder = config.get('drebin', 'benware_dir') 34 | logger.info('testing:{},{}'.format(feature_type, ensemble_type)) 35 | logger.info('The seed is :{}'.format(random_seed)) 36 | 37 | train_dataset, validation_dataset, test_data, test_y, input_dim = \ 38 | data_preprocessing(feature_type, mal_folder, ben_folder, ratio, proc_numbers, random_seed) 39 | 40 | ensemble_obj = get_ensemble_object(ensemble_type) 41 | # instantiation 42 | arch_type = feature_type_vs_architecture.get(feature_type) 43 | saving_dir = config.get('experiments', 'drebin') 44 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']: 45 | ensemble_model = ensemble_obj(arch_type, base_model=None, n_members=1, model_directory=saving_dir) 46 | else: 47 | ensemble_model = ensemble_obj(arch_type, base_model=None, n_members=n_members, model_directory=saving_dir) 48 | 49 | ensemble_model.fit(train_dataset, validation_dataset, input_dim=input_dim) 50 | 51 | test_results = ensemble_model.predict(test_data) 52 | utils.dump_joblib((test_results, test_y), 53 | os.path.join(saving_dir, '{}_{}_test.res'.format(feature_type, ensemble_type))) 54 | ensemble_model.evaluate(test_data, test_y) 55 | return 56 | 57 | 58 | def run_temperature_scaling(feature_type, ensemble_type): 59 | """ 60 | Run temperature scaling 61 | :param feature_type: the type of feature (e.g., drebin, opcode, etc.), feature type associates to the model architecture 62 | :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc. 63 | """ 64 | from core.post_calibration.temperature_scaling import find_scaling_temperature, \ 65 | apply_temperature_scaling, inverse_sigmoid 66 | logger.info('run temperature scaling:{},{}'.format(feature_type, ensemble_type)) 67 | 68 | # load dataset 69 | def data_load(feature_type='drebin'): 70 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format( 71 | feature_type, list(feature_type_scope_dict.keys())) 72 | 73 | android_features_saving_dir = config.get('metadata', 'naive_data_pool') 74 | intermediate_data_saving_dir = config.get('drebin', 'intermediate_directory') 75 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir, 76 | intermediate_data_saving_dir, 77 | update=False, 78 | proc_number=1) 79 | 80 | save_path = os.path.join(intermediate_data_saving_dir, 'drebin_database.' 
+ feature_type) 81 | if os.path.exists(save_path): 82 | _1, test_filenames, validation_filenames, \ 83 | _2, test_y, validation_y = utils.read_joblib(save_path) 84 | validation_features = [os.path.join(android_features_saving_dir, filename) for filename in 85 | validation_filenames] 86 | test_features = [os.path.join(android_features_saving_dir, filename) for filename in 87 | test_filenames] 88 | else: 89 | raise ValueError 90 | 91 | test_data, _ = feature_extractor.feature2ipt(test_features) 92 | validation_data, _ = feature_extractor.feature2ipt(validation_features) 93 | 94 | return validation_data, test_data, validation_y, test_y 95 | 96 | mal_folder = config.get('drebin', 'malware_dir') 97 | ben_folder = config.get('drebin', 'benware_dir') 98 | val_data, test_data, val_label, test_y = \ 99 | data_load(feature_type, mal_folder, ben_folder) 100 | 101 | # load model 102 | ensemble_obj = get_ensemble_object(ensemble_type) 103 | arch_type = feature_type_vs_architecture.get(feature_type) 104 | model_saving_dir = config.get('experiments', 'drebin') 105 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']: 106 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 107 | else: 108 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir) 109 | 110 | temp_save_dir = os.path.join(config.get('drebin', 'intermediate_directory'), 111 | "{}_{}_temp.json".format(feature_type, ensemble_type)) 112 | if not os.path.exists(temp_save_dir): 113 | prob = np.squeeze(ensemble_model.predict(val_data, use_prob=True)) 114 | logits = inverse_sigmoid(prob) 115 | temperature = find_scaling_temperature(val_label, logits) 116 | utils.dump_json({'temperature': temperature}, temp_save_dir) 117 | temperature = utils.load_json(temp_save_dir)['temperature'] 118 | 119 | prob_test = ensemble_model.predict(test_data, use_prob=True) 120 | prob_t = apply_temperature_scaling(temperature, prob_test) 121 | utils.dump_joblib((prob_t, test_y), 122 | os.path.join(model_saving_dir, '{}_{}_temperature_test.res'.format(feature_type, ensemble_type))) 123 | 124 | 125 | def data_preprocessing(feature_type='drebin', malware_dir=None, benware_dir=None, ratio=3.0, proc_numbers=2, 126 | random_seed=0): 127 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format( 128 | feature_type, list(feature_type_scope_dict.keys())) 129 | 130 | android_features_saving_dir = config.get('metadata', 'naive_data_pool') 131 | intermediate_data_saving_dir = config.get('drebin', 'intermediate_directory') 132 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir, 133 | intermediate_data_saving_dir, 134 | update=False, 135 | proc_number=proc_numbers) 136 | save_path = os.path.join(intermediate_data_saving_dir, 'drebin_database.' 
+ feature_type) 137 | 138 | if os.path.exists(save_path): 139 | train_filenames, test_filenames, validation_filenames, \ 140 | train_y, test_y, validation_y = utils.read_joblib(save_path) 141 | 142 | train_features = [os.path.join(config.get('metadata', 'naive_data_pool'), filename) for filename in 143 | train_filenames] 144 | validation_features = [os.path.join(config.get('metadata', 'naive_data_pool'), filename) for filename in 145 | validation_filenames] 146 | test_features = [os.path.join(config.get('metadata', 'naive_data_pool'), filename) for filename in 147 | test_filenames] 148 | else: 149 | def train_test_val_split(data): 150 | train, test = train_test_split(data, test_size=0.2, random_state=random_seed) 151 | train, val = train_test_split(train, test_size=0.25, random_state=random_seed) 152 | return train, val, test 153 | 154 | def merge_mal_ben(mal, ben): 155 | mal_feature_list = feature_extractor.feature_extraction(mal) 156 | n_malware = len(mal_feature_list) 157 | ben_feature_list = feature_extractor.feature_extraction(ben) 158 | n_benware = len(ben_feature_list) 159 | feature_list = mal_feature_list + ben_feature_list 160 | gt_labels = np.zeros((n_malware + n_benware,), dtype=np.int32) 161 | gt_labels[:n_malware] = 1 162 | import random 163 | random.seed(0) 164 | random.shuffle(feature_list) 165 | random.seed(0) 166 | random.shuffle(gt_labels) 167 | return feature_list, gt_labels 168 | 169 | malware_path_list = utils.retrive_files_set(malware_dir, "", ".apk|") 170 | mal_train, mal_val, mal_test = train_test_val_split(malware_path_list) 171 | benware_path_list = utils.retrive_files_set(benware_dir, "", ".apk|") 172 | ben_train, ben_val, ben_test = train_test_val_split(benware_path_list) 173 | 174 | # undersampling the benign files 175 | if ratio >= 1.: 176 | ben_train = undersampling(ben_train, len(mal_train), ratio) 177 | logger.info('Training set, the number of benign files vs. malware files: {} vs. 
{}'.format(len(ben_train), 178 | len(mal_train))) 179 | train_features, train_y = merge_mal_ben(mal_train, ben_train) 180 | validation_features, validation_y = merge_mal_ben(mal_val, ben_val) 181 | test_features, test_y = merge_mal_ben(mal_test, ben_test) 182 | 183 | train_filenames = [os.path.basename(path) for path in train_features] 184 | validation_filenames = [os.path.basename(path) for path in validation_features] 185 | test_filenames = [os.path.basename(path) for path in test_features] 186 | utils.dump_joblib( 187 | (train_filenames, test_filenames, validation_filenames, train_y, test_y, validation_y), 188 | save_path) 189 | # obtain data in a format for ML algorithms 190 | feature_extractor.feature_preprocess(train_features, train_y) # produce datasets products 191 | train_dataset, input_dim = feature_extractor.feature2ipt(train_features, train_y, is_training_set=True) 192 | test_data, _ = feature_extractor.feature2ipt(test_features) 193 | validation_dataset, _ = feature_extractor.feature2ipt(validation_features, validation_y) 194 | return train_dataset, validation_dataset, test_data, test_y, input_dim 195 | 196 | 197 | def undersampling(ben_train, num_of_mal, ratio): 198 | number_of_choice = int(num_of_mal * ratio) if int(num_of_mal * ratio) <= len(ben_train) else len(ben_train) 199 | import random 200 | random.seed(0) 201 | random.shuffle(ben_train) 202 | return ben_train[:number_of_choice] 203 | 204 | 205 | def get_ensemble_object(ensemble_type): 206 | assert ensemble_type in ensemble_method_scope_dict.keys(), '{} expected, but {} are supported'.format( 207 | ensemble_type, 208 | ','.join(ensemble_method_scope_dict.keys()) 209 | ) 210 | return ensemble_method_scope_dict[ensemble_type] 211 | -------------------------------------------------------------------------------- /core/feature/multimodality/multimodality.py: -------------------------------------------------------------------------------- 1 | """ 2 | - String feature 3 | - Permission feature 4 | - Component feature 5 | - Environmental feature 6 | - Method opcode feature 7 | - Method API feature 8 | - Shared library function opcode feature 9 | """ 10 | 11 | import os 12 | import time 13 | 14 | import magic 15 | import zipfile 16 | import hashlib 17 | from collections import defaultdict 18 | 19 | from androguard.misc import AnalyzeAPK, APK 20 | import lxml.etree as etree 21 | from xml.dom import minidom 22 | 23 | from elftools.elf.elffile import ELFFile 24 | from elftools.common.py3compat import BytesIO 25 | from capstone import Cs, CS_ARCH_ARM64, CS_ARCH_ARM, CS_MODE_ARM 26 | 27 | from tools import utils 28 | from sys import platform as _platform 29 | from config import logging 30 | 31 | if _platform == "linux" or _platform == "linux2": 32 | TMP_DIR = '/tmp/' 33 | elif _platform == "win32" or _platform == "win64": 34 | TMP_DIR = 'C:\\TEMP\\' 35 | logger = logging.getLogger('feature.multimodality') 36 | current_dir = os.path.dirname(os.path.realpath(__file__)) 37 | API_LIST = [api.decode('utf-8') for api in \ 38 | utils.read_txt(os.path.join(current_dir, 'res/api_list.txt'), mode='rb')] 39 | 40 | 41 | def get_multimod_feature(apk_path, api_list, save_path): 42 | """ 43 | get feature 44 | :param apk_path: an absolute apk path 45 | :param api_list: api list, api level 22 is considered 46 | :param save_path: a path for saving feature 47 | :return: (status, result_file_path), e.g., (True, back_path_name+apk_name+'.data') 48 | """ 49 | try: 50 | print("Processing " + apk_path) 51 | start_time = time.time() 52 | data_list = [] 53 
| # permission, components, environment information 54 | data_list.append(get_dict_feature_xml(apk_path)) 55 | # string, method api, method opcodes 56 | string_feature_dict, api_feature_dict, opcode_feature_dict = get_feature_dex(apk_path, api_list) 57 | data_list.extend([string_feature_dict, api_feature_dict, opcode_feature_dict]) 58 | # shared library feature 59 | data_list.append(get_feature_shared_lib(apk_path)) 60 | 61 | utils.dump_json(data_list, save_path) 62 | except Exception as e: 63 | return e 64 | else: 65 | return save_path 66 | 67 | 68 | def get_dict_feature_xml(apk_path): 69 | """ 70 | get requested feature from manifest file 71 | :param apk_path: absolute path of an apk file 72 | :return: a dict of elements {feature:occurrence, ..., } 73 | """ 74 | permission_component_envinfo = defaultdict(int) 75 | xml_tmp_dir = os.path.join(TMP_DIR, 'xml_dir') 76 | if not os.path.exists(xml_tmp_dir): 77 | os.mkdir(xml_tmp_dir) 78 | apk_name = os.path.splitext(os.path.basename(apk_path))[0] 79 | try: 80 | apk_path = os.path.abspath(apk_path) 81 | a = APK(apk_path) 82 | f = open(os.path.join(xml_tmp_dir, apk_name + '.xml'), 'wb') 83 | xmlstreaming = etree.tostring(a.xml['AndroidManifest.xml'], pretty_print=True, encoding='utf-8') 84 | f.write(xmlstreaming) 85 | f.close() 86 | except Exception as e: 87 | raise Exception("Fail to load xml file of apk {}:{}".format(apk_path, str(e))) 88 | 89 | # start obtain feature permission, components, environment information 90 | try: 91 | with open(os.path.join(xml_tmp_dir, apk_name + '.xml'), 'rb') as f: 92 | dom_xml = minidom.parse(f) 93 | dom_elements = dom_xml.documentElement 94 | 95 | dom_permissions = dom_elements.getElementsByTagName('uses-permission') 96 | for permission in dom_permissions: 97 | if permission.hasAttribute('android:name'): 98 | permission_component_envinfo[permission.getAttribute('android:name')] = 1 99 | 100 | dom_activities = dom_elements.getElementsByTagName('activity') 101 | for activity in dom_activities: 102 | if activity.hasAttribute('android:name'): 103 | permission_component_envinfo[activity.getAttribute('android:name')] = 1 104 | 105 | dom_services = dom_elements.getElementsByTagName("service") 106 | for service in dom_services: 107 | if service.hasAttribute("android:name"): 108 | permission_component_envinfo[service.getAttribute("android:name")] = 1 109 | 110 | dom_contentproviders = dom_elements.getElementsByTagName("provider") 111 | for provider in dom_contentproviders: 112 | if provider.hasAttribute("android:name"): 113 | permission_component_envinfo[provider.getAttribute("android:name")] = 1 114 | # uri 115 | dom_uris = provider.getElementsByTagName('grant-uri-permission') 116 | # intents --- action 117 | intent_actions = provider.getElementsByTagName('action') 118 | # we neglect to compose the uri feature by so-called paired provider name and intents 119 | # instead, the path of uri is used directly 120 | for uri in dom_uris: 121 | if uri.hasAttribute('android:path'): 122 | permission_component_envinfo[uri.getAttribute('android:path')] = 1 123 | 124 | dom_broadcastreceivers = dom_elements.getElementsByTagName("receiver") 125 | for receiver in dom_broadcastreceivers: 126 | if receiver.hasAttribute("android:name"): 127 | permission_component_envinfo[receiver.getAttribute("android:name")] = 1 128 | 129 | dom_intentfilter_actions = dom_elements.getElementsByTagName("action") 130 | for action in dom_intentfilter_actions: 131 | if action.hasAttribute("android:name"): 132 | 
permission_component_envinfo[action.getAttribute("android:name")] = 1 133 | 134 | dom_hardwares = dom_elements.getElementsByTagName("uses-feature") 135 | for hardware in dom_hardwares: 136 | if hardware.hasAttribute("android:name"): 137 | permission_component_envinfo[hardware.getAttribute("android:name")] = 1 138 | 139 | dom_libraries = dom_elements.getElementsByTagName("android:name") 140 | for lib in dom_libraries: 141 | if lib.hasAttribute('android:name'): 142 | permission_component_envinfo[lib.getAttribute('android:name')] = 1 143 | 144 | dom_sdk_versions = dom_elements.getElementsByTagName("uses-sdk") 145 | for sdk in dom_sdk_versions: 146 | if sdk.hasAttribute('android:minSdkVersion'): 147 | permission_component_envinfo[sdk.getAttribute('android:minSdkVersion')] = 1 148 | if sdk.hasAttribute('android:targetSdkVersion'): 149 | permission_component_envinfo[sdk.getAttribute('android:targetSdkVersion')] = 1 150 | if sdk.hasAttribute('android:maxSdkVersion'): 151 | permission_component_envinfo[sdk.getAttribute('android:maxSdkVersion')] = 1 152 | 153 | return permission_component_envinfo 154 | except Exception as e: 155 | raise Exception("Fail to process xml file of apk {}:{}".format(apk_path, str(e))) 156 | 157 | 158 | def get_feature_dex(apk_path, api_list): 159 | """ 160 | get feature about opcode and functions, and string from .dex files 161 | :param apk_path: an absolute path of an apk 162 | :param api_list: a list of apis 163 | :return: two dicts of elements {feature:occurrence, ..., } corresponds to functionality feature and string feature 164 | """ 165 | string_feature_dict = defaultdict(int) 166 | opcode_feature_dict = defaultdict(int) 167 | api_feature_dict = defaultdict(int) 168 | 169 | try: 170 | _1, _2, dx = AnalyzeAPK(apk_path) 171 | except Exception as e: 172 | raise Exception('Fail to load dex file of apk {}:{}'.format(apk_path, str(e))) 173 | 174 | for method in dx.get_methods(): 175 | if method.is_external(): 176 | continue 177 | method_obj = method.get_method() 178 | for instruction in method_obj.get_instructions(): 179 | opcode = instruction.get_name() 180 | opcode_feature_dict[opcode] += 1 181 | if 'invoke-' in opcode: 182 | code_body = instruction.get_output() 183 | if ';->' not in code_body: 184 | continue 185 | head_part, rear_part = code_body.split(';->') 186 | class_name = head_part.strip().split(' ')[-1] 187 | method_name = rear_part.strip().split('(')[0] 188 | if class_name + ';->' + method_name in api_list: 189 | api_feature_dict[class_name + ';->' + method_name] += 1 190 | 191 | if (opcode == 'const-string') or (opcode == 'const-string/jumbo'): 192 | code_body = instruction.get_output() 193 | ss_string = code_body.split(' ')[-1].strip('\'').strip('\"').encode('utf-8') 194 | if not ss_string: 195 | continue 196 | hashing_string = hashlib.sha512(ss_string).hexdigest() 197 | string_feature_dict[hashing_string] = 1 198 | 199 | return string_feature_dict, api_feature_dict, opcode_feature_dict 200 | 201 | 202 | def get_feature_shared_lib(apk_path): 203 | """ 204 | get feature from ELF files 205 | :param apk_path: an absolute path of an apk 206 | :return: a dict of elements {feature:occurrence, ..., } 207 | """ 208 | feature_dict = defaultdict(int) 209 | try: 210 | with zipfile.ZipFile(apk_path, 'r') as apk_zipf: 211 | handled_name_list = [] 212 | for name in apk_zipf.namelist(): 213 | if os.path.basename(name) in handled_name_list: 214 | continue 215 | brief_file_info = magic.from_buffer(apk_zipf.read(name)) 216 | if 'ELF' not in brief_file_info: 217 | continue 218 
| if 'ELF 64' in brief_file_info: 219 | arch_info = CS_ARCH_ARM64 220 | elif 'ELF 32' in brief_file_info: 221 | arch_info = CS_ARCH_ARM 222 | else: 223 | raise ValueError("Only ARM 64-bit and ARM 32-bit Android ABIs are supported.") 224 | 225 | with apk_zipf.open(name, 'r') as fr: 226 | bt_stream = BytesIO(fr.read()) 227 | elf_fhander = ELFFile(bt_stream) 228 | # arm opcodes (i.e., instructions) 229 | bt_code_text = elf_fhander.get_section_by_name('.text') 230 | if (bt_code_text is not None) and (bt_code_text.data() is not None): 231 | md = Cs(arch_info, CS_MODE_ARM) 232 | for _1, _2, op_code, _3 in md.disasm_lite(bt_code_text.data(), 0x1000): 233 | feature_dict[op_code] += 1 234 | # functions 235 | # we only record whether a function occurs, not its frequency, 236 | # to simplify the implementation. 237 | # If the code runs only on Linux, an ARM-aware 'objdump' can be used to count 238 | # the frequency of functions conveniently; we use 'pyelftools' here to 239 | # accommodate different development platforms. 240 | REL_PLT_section = elf_fhander.get_section_by_name('.rela.plt') 241 | if REL_PLT_section is None: 242 | continue 243 | symtable = elf_fhander.get_section(REL_PLT_section['sh_link']) 244 | 245 | for rel_plt in REL_PLT_section.iter_relocations(): 246 | func_name = symtable.get_symbol(rel_plt['r_info_sym']).name 247 | feature_dict[func_name] += 1 248 | 249 | handled_name_list.append(os.path.basename(name)) 250 | return feature_dict 251 | except Exception as e: 252 | raise Exception("Fail to process shared library file of apk {}:{}".format(apk_path, str(e))) 253 | 254 | 255 | def dump_feaure_list(feature_list, save_path): 256 | utils.dump_json(save_path, feature_list) 257 | return 258 | 259 | 260 | def load_feature(save_path): 261 | return utils.load_json(save_path) 262 | 263 | 264 | def wrapper_load_features(save_path): 265 | try: 266 | return load_feature(save_path) 267 | except Exception as e: 268 | return e 269 | -------------------------------------------------------------------------------- /tools/progressbar/widgets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # progressbar - Text progress bar library for Python. 5 | # Copyright (c) 2005 Nilton Volpato 6 | # 7 | # This library is free software; you can redistribute it and/or 8 | # modify it under the terms of the GNU Lesser General Public 9 | # License as published by the Free Software Foundation; either 10 | # version 2.1 of the License, or (at your option) any later version. 11 | # 12 | # This library is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 15 | # Lesser General Public License for more details. 
16 | # 17 | # You should have received a copy of the GNU Lesser General Public 18 | # License along with this library; if not, write to the Free Software 19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 20 | 21 | """Default ProgressBar widgets.""" 22 | 23 | from __future__ import division 24 | 25 | import datetime 26 | import math 27 | 28 | try: 29 | from abc import ABCMeta, abstractmethod 30 | except ImportError: 31 | AbstractWidget = object 32 | abstractmethod = lambda fn: fn 33 | else: 34 | AbstractWidget = ABCMeta('AbstractWidget', (object,), {}) 35 | 36 | 37 | def format_updatable(updatable, pbar): 38 | if hasattr(updatable, 'update'): return updatable.update(pbar) 39 | else: return updatable 40 | 41 | 42 | class Widget(AbstractWidget): 43 | """The base class for all widgets. 44 | 45 | The ProgressBar will call the widget's update value when the widget should 46 | be updated. The widget's size may change between calls, but the widget may 47 | display incorrectly if the size changes drastically and repeatedly. 48 | 49 | The boolean TIME_SENSITIVE informs the ProgressBar that it should be 50 | updated more often because it is time sensitive. 51 | """ 52 | 53 | TIME_SENSITIVE = False 54 | __slots__ = () 55 | 56 | @abstractmethod 57 | def update(self, pbar): 58 | """Updates the widget. 59 | 60 | pbar - a reference to the calling ProgressBar 61 | """ 62 | 63 | 64 | class WidgetHFill(Widget): 65 | """The base class for all variable width widgets. 66 | 67 | This widget is much like the \\hfill command in TeX, it will expand to 68 | fill the line. You can use more than one in the same line, and they will 69 | all have the same width, and together will fill the line. 70 | """ 71 | 72 | @abstractmethod 73 | def update(self, pbar, width): 74 | """Updates the widget providing the total width the widget must fill. 75 | 76 | pbar - a reference to the calling ProgressBar 77 | width - The total width the widget must fill 78 | """ 79 | 80 | 81 | class Timer(Widget): 82 | """Widget which displays the elapsed seconds.""" 83 | 84 | __slots__ = ('format_string',) 85 | TIME_SENSITIVE = True 86 | 87 | def __init__(self, format='Elapsed Time: %s'): 88 | self.format_string = format 89 | 90 | @staticmethod 91 | def format_time(seconds): 92 | """Formats time as the string "HH:MM:SS".""" 93 | 94 | return str(datetime.timedelta(seconds=int(seconds))) 95 | 96 | 97 | def update(self, pbar): 98 | """Updates the widget to show the elapsed time.""" 99 | 100 | return self.format_string % self.format_time(pbar.seconds_elapsed) 101 | 102 | 103 | class ETA(Timer): 104 | """Widget which attempts to estimate the time of arrival.""" 105 | 106 | TIME_SENSITIVE = True 107 | 108 | def update(self, pbar): 109 | """Updates the widget to show the ETA or total time when finished.""" 110 | 111 | if pbar.currval == 0: 112 | return 'ETA: --:--:--' 113 | elif pbar.finished: 114 | return 'Time: %s' % self.format_time(pbar.seconds_elapsed) 115 | else: 116 | elapsed = pbar.seconds_elapsed 117 | eta = elapsed * pbar.maxval / pbar.currval - elapsed 118 | return 'ETA: %s' % self.format_time(eta) 119 | 120 | 121 | class AdaptiveETA(Timer): 122 | """Widget which attempts to estimate the time of arrival. 
123 | 124 | Uses a weighted average of two estimates: 125 | 1) ETA based on the total progress and time elapsed so far 126 | 2) ETA based on the progress as per tha last 10 update reports 127 | 128 | The weight depends on the current progress so that to begin with the 129 | total progress is used and at the end only the most recent progress is 130 | used. 131 | """ 132 | 133 | TIME_SENSITIVE = True 134 | NUM_SAMPLES = 10 135 | 136 | def _update_samples(self, currval, elapsed): 137 | sample = (currval, elapsed) 138 | if not hasattr(self, 'samples'): 139 | self.samples = [sample] * (self.NUM_SAMPLES + 1) 140 | else: 141 | self.samples.append(sample) 142 | return self.samples.pop(0) 143 | 144 | def _eta(self, maxval, currval, elapsed): 145 | return elapsed * maxval / float(currval) - elapsed 146 | 147 | def update(self, pbar): 148 | """Updates the widget to show the ETA or total time when finished.""" 149 | if pbar.currval == 0: 150 | return 'Remaining Time: --:--:--' 151 | elif pbar.finished: 152 | #return 'Time: %s' % self.format_time(pbar.seconds_elapsed) 153 | return '' 154 | else: 155 | elapsed = pbar.seconds_elapsed 156 | currval1, elapsed1 = self._update_samples(pbar.currval, elapsed) 157 | eta = self._eta(pbar.maxval, pbar.currval, elapsed) 158 | if pbar.currval > currval1: 159 | etasamp = self._eta(pbar.maxval - currval1, 160 | pbar.currval - currval1, 161 | elapsed - elapsed1) 162 | weight = (pbar.currval / float(pbar.maxval)) ** 0.5 163 | eta = (1 - weight) * eta + weight * etasamp 164 | return 'Remaining Time: %s' % self.format_time(eta) 165 | 166 | 167 | class FileTransferSpeed(Widget): 168 | """Widget for showing the transfer speed (useful for file transfers).""" 169 | 170 | FORMAT = '%6.2f %s%s/s' 171 | PREFIXES = ' kMGTPEZY' 172 | __slots__ = ('unit',) 173 | 174 | def __init__(self, unit='B'): 175 | self.unit = unit 176 | 177 | def update(self, pbar): 178 | """Updates the widget with the current SI prefixed speed.""" 179 | 180 | if pbar.seconds_elapsed < 2e-6 or pbar.currval < 2e-6: # =~ 0 181 | scaled = power = 0 182 | else: 183 | speed = pbar.currval / pbar.seconds_elapsed 184 | power = int(math.log(speed, 1000)) 185 | scaled = speed / 1000.**power 186 | 187 | return self.FORMAT % (scaled, self.PREFIXES[power], self.unit) 188 | 189 | 190 | class AnimatedMarker(Widget): 191 | """An animated marker for the progress bar which defaults to appear as if 192 | it were rotating. 
193 | """ 194 | 195 | __slots__ = ('markers', 'curmark') 196 | 197 | def __init__(self, markers='|/-\\'): 198 | self.markers = markers 199 | self.curmark = -1 200 | 201 | def update(self, pbar): 202 | """Updates the widget to show the next marker or the first marker when 203 | finished""" 204 | 205 | if pbar.finished: return self.markers[0] 206 | 207 | self.curmark = (self.curmark + 1) % len(self.markers) 208 | return self.markers[self.curmark] 209 | 210 | # Alias for backwards compatibility 211 | RotatingMarker = AnimatedMarker 212 | 213 | 214 | class Counter(Widget): 215 | """Displays the current count.""" 216 | 217 | __slots__ = ('format_string',) 218 | 219 | def __init__(self, format='%d'): 220 | self.format_string = format 221 | 222 | def update(self, pbar): 223 | return self.format_string % pbar.currval 224 | 225 | 226 | class Percentage(Widget): 227 | """Displays the current percentage as a number with a percent sign.""" 228 | 229 | def update(self, pbar): 230 | return '%3d%%' % pbar.percentage() 231 | 232 | 233 | class FormatLabel(Timer): 234 | """Displays a formatted label.""" 235 | 236 | mapping = { 237 | 'elapsed': ('seconds_elapsed', Timer.format_time), 238 | 'finished': ('finished', None), 239 | 'last_update': ('last_update_time', None), 240 | 'max': ('maxval', None), 241 | 'seconds': ('seconds_elapsed', None), 242 | 'start': ('start_time', None), 243 | 'value': ('currval', None) 244 | } 245 | 246 | __slots__ = ('format_string',) 247 | def __init__(self, format): 248 | self.format_string = format 249 | 250 | def update(self, pbar): 251 | context = {} 252 | for name, (key, transform) in self.mapping.items(): 253 | try: 254 | value = getattr(pbar, key) 255 | 256 | if transform is None: 257 | context[name] = value 258 | else: 259 | context[name] = transform(value) 260 | except: pass 261 | 262 | return self.format_string % context 263 | 264 | 265 | class SimpleProgress(Widget): 266 | """Returns progress as a count of the total (e.g.: "5 of 47").""" 267 | 268 | __slots__ = ('sep',) 269 | 270 | def __init__(self, sep=' of '): 271 | self.sep = sep 272 | 273 | def update(self, pbar): 274 | return '%d%s%d' % (pbar.currval, self.sep, pbar.maxval) 275 | 276 | 277 | class Bar(WidgetHFill): 278 | """A progress bar which stretches to fill the line.""" 279 | 280 | __slots__ = ('marker', 'left', 'right', 'fill', 'fill_left') 281 | 282 | def __init__(self, marker='#', left='|', right='|', fill=' ', 283 | fill_left=True): 284 | """Creates a customizable progress bar. 
285 | 286 | marker - string or updatable object to use as a marker 287 | left - string or updatable object to use as a left border 288 | right - string or updatable object to use as a right border 289 | fill - character to use for the empty part of the progress bar 290 | fill_left - whether to fill from the left or the right 291 | """ 292 | self.marker = marker 293 | self.left = left 294 | self.right = right 295 | self.fill = fill 296 | self.fill_left = fill_left 297 | 298 | 299 | def update(self, pbar, width): 300 | """Updates the progress bar and its subcomponents.""" 301 | 302 | left, marked, right = (format_updatable(i, pbar) for i in 303 | (self.left, self.marker, self.right)) 304 | 305 | width -= len(left) + len(right) 306 | # Marked must *always* have length of 1 307 | if pbar.maxval: 308 | marked *= int(pbar.currval / pbar.maxval * width) 309 | else: 310 | marked = '' 311 | 312 | if self.fill_left: 313 | return '%s%s%s' % (left, marked.ljust(width, self.fill), right) 314 | else: 315 | return '%s%s%s' % (left, marked.rjust(width, self.fill), right) 316 | 317 | 318 | class ReverseBar(Bar): 319 | """A bar which has a marker which bounces from side to side.""" 320 | 321 | def __init__(self, marker='#', left='|', right='|', fill=' ', 322 | fill_left=False): 323 | """Creates a customizable progress bar. 324 | 325 | marker - string or updatable object to use as a marker 326 | left - string or updatable object to use as a left border 327 | right - string or updatable object to use as a right border 328 | fill - character to use for the empty part of the progress bar 329 | fill_left - whether to fill from the left or the right 330 | """ 331 | self.marker = marker 332 | self.left = left 333 | self.right = right 334 | self.fill = fill 335 | self.fill_left = fill_left 336 | 337 | 338 | class BouncingBar(Bar): 339 | def update(self, pbar, width): 340 | """Updates the progress bar and its subcomponents.""" 341 | 342 | left, marker, right = (format_updatable(i, pbar) for i in 343 | (self.left, self.marker, self.right)) 344 | 345 | width -= len(left) + len(right) 346 | 347 | if pbar.finished: return '%s%s%s' % (left, width * marker, right) 348 | 349 | position = int(pbar.currval % (width * 2 - 1)) 350 | if position > width: position = width * 2 - position 351 | lpad = self.fill * (position - 1) 352 | rpad = self.fill * (width - len(marker) - len(lpad)) 353 | 354 | # Swap if we want to bounce the other way 355 | if not self.fill_left: rpad, lpad = lpad, rpad 356 | 357 | return '%s%s%s%s%s' % (left, lpad, marker, rpad, right) 358 | -------------------------------------------------------------------------------- /core/ensemble/deep_ensemble.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import time 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from core.ensemble.vanilla import Vanilla 8 | from core.ensemble.model_hp import train_hparam 9 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data 10 | from tools import utils 11 | from config import logging, ErrorHandler 12 | 13 | logger = logging.getLogger('ensemble.deep_ensemble') 14 | logger.addHandler(ErrorHandler) 15 | 16 | 17 | class DeepEnsemble(Vanilla): 18 | def __init__(self, 19 | architecture_type='dnn', 20 | base_model=None, 21 | n_members=5, 22 | model_directory=None, 23 | name='DEEPENSEMBLE' 24 | ): 25 | super(DeepEnsemble, self).__init__(architecture_type, 26 | base_model, 27 | n_members, 28 | model_directory, 29 | name) 30 | self.hparam = 
train_hparam 31 | self.ensemble_type = 'deep_ensemble' 32 | 33 | 34 | class WeightedDeepEnsemble(Vanilla): 35 | def __init__(self, 36 | architecture_type='dnn', 37 | base_model=None, 38 | n_members=5, 39 | model_directory=None, 40 | name='WEIGHTEDDEEPENSEMBLE' 41 | ): 42 | super(WeightedDeepEnsemble, self).__init__(architecture_type, 43 | base_model, 44 | n_members, 45 | model_directory, 46 | name) 47 | self.hparam = train_hparam 48 | self.ensemble_type = 'deep_ensemble' 49 | self.weight_modular = None 50 | 51 | def get_weight_modular(self): 52 | class Simplex(tf.keras.constraints.Constraint): 53 | def __call__(self, w): 54 | return tf.math.softmax(w - tf.math.reduce_max(w), axis=0) 55 | 56 | inputs = tf.keras.Input(shape=(self.n_members,)) 57 | outs = tf.keras.layers.Dense(1, use_bias=False, activation=None, kernel_constraint=Simplex(), name='simplex')( 58 | inputs) 59 | return tf.keras.Model(inputs=inputs, outputs=outs) 60 | 61 | def predict(self, x, use_prob=False): 62 | """ conduct prediction """ 63 | self.base_model = None 64 | self.weight_modular = None 65 | self.weights_list = [] 66 | self._optimizers_dict = [] 67 | self.load_ensemble_weights() 68 | output_list = [] 69 | start_time = time.time() 70 | for base_model in self.model_generator(): 71 | if isinstance(x, tf.data.Dataset): 72 | output_list.append(base_model.predict(x, verbose=1)) 73 | elif isinstance(x, (np.ndarray, list)): 74 | output_list.append(base_model.predict(x, batch_size=self.hparam.batch_size, verbose=1)) 75 | else: 76 | raise ValueError('x should be a tf.data.Dataset, np.ndarray, or list.') 77 | total_time = time.time() - start_time 78 | logger.info('Inference costs {} seconds.'.format(total_time)) 79 | assert self.weight_modular is not None 80 | output = self.weight_modular(np.hstack(output_list)).numpy() 81 | if not use_prob: 82 | return np.stack(output_list, axis=1), self.weight_modular.get_layer('simplex').get_weights() 83 | else: 84 | return output 85 | 86 | def fit(self, train_set, validation_set=None, input_dim=None, **kwargs): 87 | """ 88 | fit the ensemble by producing a list of model weights 89 | :param train_set: tf.data.Dataset, whose element type shall accommodate the input format of TensorFlow models 90 | :param validation_set: validation data, optional 91 | :param input_dim: integer or list, the input dimension excluding the batch size 92 | """ 93 | # training preparation 94 | if self.base_model is None: 95 | self.build_model(input_dim=input_dim) 96 | if self.weight_modular is None: 97 | self.weight_modular = self.get_weight_modular() 98 | 99 | self.base_model.compile( 100 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate, 101 | clipvalue=self.hparam.clipvalue), 102 | loss=tf.keras.losses.BinaryCrossentropy(), 103 | metrics=[tf.keras.metrics.BinaryAccuracy()], 104 | ) 105 | 106 | self.weight_modular.compile( 107 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate, 108 | clipvalue=self.hparam.clipvalue), 109 | loss=tf.keras.losses.BinaryCrossentropy(), 110 | metrics=[tf.keras.metrics.BinaryAccuracy()], 111 | ) 112 | 113 | # training 114 | logger.info("hyper-parameters:") 115 | logger.info(dict(self.hparam._asdict())) 116 | logger.info("...training start!") 117 | 118 | best_val_accuracy = 0. 119 | total_time = 0. 
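        # Summary of the training scheme implemented below: in every epoch, each of
        # the n_members base models is trained for one epoch, with its weights and
        # optimizer state swapped in and out of self.weights_list via update_weights()
        # so that a single Keras model object can stand in for all members; the weight
        # modular is then fitted on the members' stacked predictions, and the whole
        # ensemble is checkpointed every `interval` epochs whenever the weight
        # modular's validation accuracy improves.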
120 | for epoch in range(self.hparam.n_epochs): 121 | for member_idx in range(self.n_members): 122 | if member_idx < len(self.weights_list): # loading former weights 123 | self.base_model.set_weights(self.weights_list[member_idx]) 124 | self.base_model.optimizer.set_weights(self._optimizers_dict[member_idx]) 125 | elif member_idx == 0: 126 | pass # do nothing 127 | else: 128 | self.reinitialize_base_model() 129 | 130 | msg = 'Epoch {}/{}, member {}/{}, and {} member(s) in list'.format(epoch + 1, 131 | self.hparam.n_epochs, member_idx + 1, 132 | self.n_members, 133 | len(self.weights_list)) 134 | print(msg) 135 | start_time = time.time() 136 | self.base_model.fit(train_set, 137 | epochs=epoch + 1, 138 | initial_epoch=epoch, 139 | validation_data=validation_set 140 | ) 141 | self.update_weights(member_idx, 142 | self.base_model.get_weights(), 143 | self.base_model.optimizer.get_weights()) 144 | 145 | end_time = time.time() 146 | total_time += end_time - start_time 147 | # training weight modular 148 | msg = "train the weight modular at epoch {}/{}" 149 | print(msg.format(epoch + 1, self.hparam.n_epochs)) 150 | start_time = time.time() 151 | history = self.fit_weight_modular(train_set, validation_set, epoch) 152 | end_time = time.time() 153 | total_time += end_time - start_time 154 | # saving 155 | logger.info('Training the ensemble costs {} seconds in total (including validation).'.format(total_time)) 156 | train_acc = history.history['binary_accuracy'][0] 157 | val_acc = history.history['val_binary_accuracy'][0] 158 | msg = 'Epoch {}/{}: training accuracy {:.5f}, validation accuracy {:.5f}.'.format( 159 | epoch + 1, self.hparam.n_epochs, train_acc, val_acc 160 | ) 161 | logger.info(msg) 162 | if (epoch + 1) % self.hparam.interval == 0: 163 | if val_acc > best_val_accuracy: 164 | self.save_ensemble_weights() 165 | best_val_accuracy = val_acc 166 | msg = '\t The best validation accuracy is {:.5f}, obtained at epoch {}/{}'.format( 167 | best_val_accuracy, epoch + 1, self.hparam.n_epochs 168 | ) 169 | logger.info(msg) 170 | return 171 | 172 | def fit_weight_modular(self, train_set, validation_set, epoch): 173 | """ 174 | fit the weight modular 175 | :param train_set: training set 176 | :param validation_set: validation set 177 | :param epoch: integer, training epoch 178 | :return: the tf.keras History object produced by fitting the weight modular 179 | """ 180 | 181 | # obtain data 182 | def get_data(x_y_set): 183 | tsf_x = [] 184 | tsf_y = [] 185 | for _x, _y in x_y_set: 186 | _x_list = [] 187 | for base_model in self.model_generator(): 188 | _x_pred = base_model(_x) 189 | _x_list.append(_x_pred) 190 | tsf_x.append(np.hstack(_x_list)) 191 | tsf_y.append(_y) 192 | return np.vstack(tsf_x), np.concatenate(tsf_y) 193 | 194 | transform_train_set = build_dataset_from_numerical_data(get_data(train_set)) 195 | transform_val_set = build_dataset_from_numerical_data(get_data(validation_set)) 196 | 197 | history = self.weight_modular.fit(transform_train_set, 198 | epochs=epoch + 1, 199 | initial_epoch=epoch, 200 | validation_data=transform_val_set 201 | ) 202 | return history 203 | 204 | def save_ensemble_weights(self): 205 | if not path.exists(self.save_dir): 206 | utils.mkdir(self.save_dir) 207 | # save model configuration 208 | try: 209 | config = self.base_model.to_json() 210 | utils.dump_json(config, path.join(self.save_dir, 211 | self.architecture_type + '.json')) # lightweight method for saving the model configuration 212 | except Exception as e: 213 | pass 214 | finally: 215 | if not path.exists(path.join(self.save_dir, self.architecture_type)): 216 | 
utils.mkdir(path.join(self.save_dir, self.architecture_type)) 217 | self.base_model.save(path.join(self.save_dir, self.architecture_type)) 218 | print("Save the model configuration to directory {}".format(self.save_dir)) 219 | 220 | # save model weights 221 | utils.dump_joblib(self.weights_list, path.join(self.save_dir, self.architecture_type + '.model')) 222 | utils.dump_joblib(self._optimizers_dict, path.join(self.save_dir, self.architecture_type + '.model.metadata')) 223 | print("Save the model weights to directory {}".format(self.save_dir)) 224 | 225 | # save weight modular 226 | self.weight_modular.save(path.join(self.save_dir, self.architecture_type + '_weight_modular')) 227 | print("Save the weight modular weights to directory {}".format(self.save_dir)) 228 | return 229 | 230 | def load_ensemble_weights(self): 231 | if path.exists(path.join(self.save_dir, self.architecture_type + '.json')): 232 | config = utils.load_json(path.join(self.save_dir, self.architecture_type + '.json')) 233 | self.base_model = tf.keras.models.model_from_json(config) 234 | elif path.exists(path.join(self.save_dir, self.architecture_type)): 235 | self.base_model = tf.keras.models.load_model(path.join(self.save_dir, self.architecture_type)) 236 | else: 237 | logger.error("File not found: {}".format(path.join(self.save_dir, self.architecture_type + '.json'))) 238 | raise FileNotFoundError 239 | print("Load model config from {}.".format(self.save_dir)) 240 | 241 | if path.exists(path.join(self.save_dir, self.architecture_type + '.model')): 242 | self.weights_list = utils.read_joblib(path.join(self.save_dir, self.architecture_type + '.model')) 243 | else: 244 | logger.error("File not found: {}".format(path.join(self.save_dir, self.architecture_type + '.model'))) 245 | raise FileNotFoundError 246 | print("Load model weights from {}.".format(self.save_dir)) 247 | 248 | if path.exists(path.join(self.save_dir, self.architecture_type + '.model.metadata')): 249 | self._optimizers_dict = utils.read_joblib( 250 | path.join(self.save_dir, self.architecture_type + '.model.metadata')) 251 | else: 252 | self._optimizers_dict = [None] * len(self.weights_list) 253 | 254 | if path.exists(path.join(self.save_dir, self.architecture_type + '_weight_modular')): 255 | self.weight_modular = tf.keras.models.load_model( 256 | path.join(self.save_dir, self.architecture_type + '_weight_modular')) 257 | return 258 | --------------------------------------------------------------------------------
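A minimal usage sketch for the ensemble classes above, assuming the default 'dnn' architecture supported by the Vanilla base class and purely numerical features; the feature dimension, sample counts and the './save' directory are illustrative placeholders rather than values taken from the repository.

import numpy as np

from core.ensemble.dataset_lib import build_dataset_from_numerical_data
from core.ensemble.deep_ensemble import WeightedDeepEnsemble

# Placeholder data: 100-dimensional features with binary labels.
x_train = np.random.rand(256, 100).astype(np.float32)
y_train = np.random.choice(2, 256).astype(np.float32)
x_val = np.random.rand(64, 100).astype(np.float32)
y_val = np.random.choice(2, 64).astype(np.float32)

train_set = build_dataset_from_numerical_data((x_train, y_train))
val_set = build_dataset_from_numerical_data((x_val, y_val))

# model_directory (placeholder path) is where save_ensemble_weights() and
# load_ensemble_weights() write and read the checkpointed ensemble.
ensemble = WeightedDeepEnsemble(architecture_type='dnn',
                                n_members=5,
                                model_directory='./save')
ensemble.fit(train_set, validation_set=val_set, input_dim=100)

# predict() reloads the saved ensemble; use_prob=True returns the combined
# probabilities, otherwise the per-member predictions are returned together
# with the simplex weights.
probs = ensemble.predict(x_val, use_prob=True)
member_preds, simplex_weights = ensemble.predict(x_val, use_prob=False)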