├── core
│   ├── __init__.py
│   ├── post_calibration
│   │   ├── __init__.py
│   │   └── temperature_scaling.py
│   ├── feature
│   │   ├── dex2img
│   │   │   ├── __init__.py
│   │   │   └── dex2img.py
│   │   ├── drebin
│   │   │   └── __init__.py
│   │   ├── opcodeseq
│   │   │   ├── __init__.py
│   │   │   └── opcodeseq.py
│   │   ├── apiseq
│   │   │   ├── __init__.py
│   │   │   └── apiseq.py
│   │   ├── multimodality
│   │   │   ├── __init__.py
│   │   │   └── multimodality.py
│   │   └── __init__.py
│   └── ensemble
│       ├── __init__.py
│       ├── ensemble.py
│       ├── mc_dropout.py
│       ├── anchor_ensemble_test.py
│       ├── dataset_lib.py
│       ├── bayesian_ensemble_test.py
│       ├── model_hp.py
│       ├── anchor_ensemble.py
│       ├── bayesian_ensemble.py
│       └── deep_ensemble.py
├── tools
│   ├── __init__.py
│   ├── progressbar
│   │   ├── compat.py
│   │   ├── __init__.py
│   │   ├── examples.py
│   │   ├── progressbar.py
│   │   └── widgets.py
│   ├── progressbar_wrapper.py
│   ├── temporal.py
│   └── metrics.py
├── experiments
│   ├── __init__.py
│   ├── drebin_ood_test.py
│   ├── adv_main.py
│   ├── oos_main.py
│   ├── androzoo_main.py
│   ├── drebin_main.py
│   ├── oos.py
│   ├── adv.py
│   ├── drebin_ood.py
│   └── drebin_dataset.py
├── .idea
│   ├── .gitignore
│   ├── misc.xml
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   └── malware-uncertainty.iml
├── main.py
├── requirements
├── requirements.txt
├── test
│   ├── dataset_lib_test.py
│   ├── feature_extraction_test.py
│   ├── model_lib_test.py
│   ├── vanilla_test.py
│   ├── mc_dropout_test.py
│   ├── bayesian_ensemble_test.py
│   └── deep_ensemble_test.py
├── config.py
├── conf
├── conf-server-ubuntu
├── README.md
└── run.sh
/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/core/post_calibration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Default ignored files
3 | /workspace.xml
--------------------------------------------------------------------------------
/core/feature/dex2img/__init__.py:
--------------------------------------------------------------------------------
1 | from core.feature.dex2img.dex2img import dex2img
2 |
--------------------------------------------------------------------------------
/core/feature/drebin/__init__.py:
--------------------------------------------------------------------------------
1 | from core.feature.drebin.drebin import AxplorerMapping, \
2 | get_drebin_feature, \
3 | wrapper_load_features
4 |
--------------------------------------------------------------------------------
/core/feature/opcodeseq/__init__.py:
--------------------------------------------------------------------------------
1 | from core.feature.opcodeseq.opcodeseq import feature_extr_wrapper, \
2 | read_opcode, \
3 | read_opcode_wrapper
4 |
--------------------------------------------------------------------------------
/core/feature/apiseq/__init__.py:
--------------------------------------------------------------------------------
1 | from core.feature.apiseq.apiseq import get_api_sequence, \
2 | load_feature, \
3 | wrapper_load_feature, \
4 | wrapper_mapping
5 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 |
5 | def _main():
6 | print("hello")
7 | return 0
8 |
9 |
10 | if __name__ == '__main__':
11 | print(_main())
--------------------------------------------------------------------------------
/core/feature/multimodality/__init__.py:
--------------------------------------------------------------------------------
1 | from core.feature.multimodality.multimodality import API_LIST, \
2 | get_multimod_feature, \
3 | load_feature, \
4 | wrapper_load_features
5 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/requirements:
--------------------------------------------------------------------------------
1 | tensorflow >= 2.1.0
2 | tensorflow-probability==0.9.0
3 | numpy >= 1.17.3
4 | scikit-learn >= 0.21.3
5 | pandas >= 1.0.4
6 | androguard == 3.3.5
7 | absl-py == 0.8.1
8 | Pillow
9 | pyelftools
10 | capstone
11 | python-magic
12 | pyyaml
13 | multiprocessing-logging
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==2.2.0
2 | tensorflow-probability==0.10.1
3 | numpy>=1.17.3
4 | scikit-learn==0.23.2
5 | pandas>=1.0.4
6 | androguard==3.3.5
7 | absl-py==0.8.1
8 | # python-magic-bin==0.4.14
9 | seaborn
10 | Pillow
11 | pyelftools
12 | capstone
13 | python-magic
14 | pyyaml
15 | multiprocessing-logging
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/malware-uncertainty.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/test/dataset_lib_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
8 |
9 |
10 | class MyTestCase(tf.test.TestCase, parameterized.TestCase):
11 | def test_something(self):
12 | x = np.random.rand(5,3)
13 | data = build_dataset_from_numerical_data(x, batch_size=1)
14 | for i, _x in enumerate(data):
15 | self.assertAllClose(_x[0, :], x[i, :])
16 |
17 | y = np.random.choice(2, 5)
18 | dataset = build_dataset_from_numerical_data((x, y), batch_size=1)
19 | for i, (_x, _y) in enumerate(dataset):
20 | self.assertAllClose(_x[0, :], x[i, :])
21 | self.assertEqual(_y, y[i])
22 |
23 |
24 | if __name__ == '__main__':
25 | absltest.main()
26 |
--------------------------------------------------------------------------------
/core/ensemble/__init__.py:
--------------------------------------------------------------------------------
1 | from tools.utils import ensemble_method_scope as _ensemble_method_names
2 | from core.ensemble.bayesian_ensemble import BayesianEnsemble
3 | from core.ensemble.mc_dropout import MCDropout
4 | from core.ensemble.deep_ensemble import DeepEnsemble, WeightedDeepEnsemble
5 | from core.ensemble.vanilla import Vanilla
6 | from collections import namedtuple
7 |
8 | _Ensemble_methods = namedtuple('ensemble_methods', _ensemble_method_names)
9 | _ensemble_methods = _Ensemble_methods(vanilla=Vanilla,
10 | mc_dropout=MCDropout,
11 | bayesian=BayesianEnsemble,
12 | deep_ensemble=DeepEnsemble,
13 | weighted_ensemble=WeightedDeepEnsemble
14 | )
15 | ensemble_method_scope_dict = dict(_ensemble_methods._asdict())
16 |
--------------------------------------------------------------------------------
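
Usage sketch (illustrative only, not a file in the repository): the name-to-class mapping above lets callers pick an ensemble by its configuration name; the constructor arguments below mirror those used in the test files, and the model directory is a placeholder path.

    from core.ensemble import ensemble_method_scope_dict

    ensemble_cls = ensemble_method_scope_dict['mc_dropout']      # e.g., MCDropout
    model = ensemble_cls(architecture_type='dnn', model_directory='./save/drebin/mc_dropout')
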
/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 | import time
6 | import logging
7 | import multiprocessing_logging
8 |
9 | if sys.version_info[0] < 3:
10 | import ConfigParser as configparser
11 | else:
12 | import configparser
13 |
14 | config = configparser.ConfigParser()
15 |
16 | get = config.get
17 | config_dir = os.path.dirname(__file__)
18 |
19 |
20 | def parser_config():
21 | config_file = os.path.join(config_dir, "conf")
22 |
23 | if not os.path.exists(config_file):
24 | sys.stderr.write("Error: Unable to find the config file!\n")
25 | sys.exit(1)
26 |
27 | # parse the configuration
28 | global config
29 | config.read_file(open(config_file))
30 |
31 |
32 | parser_config()
33 |
34 | logging.basicConfig(level=logging.INFO, filename=os.path.join(config_dir, "log"), filemode="a",
35 | format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s',
36 | datefmt='%Y/%m/%d %H:%M:%S')
37 | ErrorHandler = logging.StreamHandler()
38 | ErrorHandler.setFormatter(logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s: %(message)s'))
39 | multiprocessing_logging.install_mp_handler()
40 |
--------------------------------------------------------------------------------
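
Usage sketch (illustrative only, not a file in the repository): parser_config() runs at import time, so other modules read settings through the shared parser, as the test files do; both keys below exist in the bundled conf files.

    from config import config

    malware_dir = config.get('drebin', 'malware_dir')
    drebin_save_dir = config.get('experiments', 'drebin')
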
/core/feature/__init__.py:
--------------------------------------------------------------------------------
1 | from core.feature.feature_extraction import DrebinFeature, \
2 | OpcodeSeq, \
3 | MultiModality, \
4 | DexToImage, \
5 | APISequence
6 |
7 | from collections import namedtuple
8 | from core.ensemble.model_lib import model_name_type_dict
9 |
10 | feature_type_scope_dict = {
11 | 'drebin': DrebinFeature,
12 | 'multimodality': MultiModality,
13 | 'opcodeseq': OpcodeSeq,
14 | 'dex2img': DexToImage,
15 | 'apiseq': APISequence
16 | }
17 |
18 | # bridge the gap between the feature extraction and dnn architecture
19 | _ARCH_TYPE = namedtuple('architectures', model_name_type_dict.keys())
20 | _architecture_feature_extraction = _ARCH_TYPE(dnn='drebin',
21 | multimodalitynn='multimodality',
22 | text_cnn='opcodeseq',
23 | r2d2='dex2img',
24 | droidectc='apiseq'
25 | )
26 | _architecture_feature_extraction_dict = dict(_architecture_feature_extraction._asdict())
27 | feature_type_vs_architecture = dict(zip(_architecture_feature_extraction_dict.values(),
28 | _architecture_feature_extraction_dict.keys()))
29 |
--------------------------------------------------------------------------------
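
Usage sketch (illustrative only, not a file in the repository): selecting a feature extractor by name and looking up the matching network architecture; the constructor arguments follow test/feature_extraction_test.py.

    from core.feature import feature_type_scope_dict, feature_type_vs_architecture
    from config import config

    feature_type = 'drebin'
    architecture = feature_type_vs_architecture[feature_type]     # -> 'dnn'
    extractor_cls = feature_type_scope_dict[feature_type]         # -> DrebinFeature
    extractor = extractor_cls(config.get('metadata', 'naive_data_pool'),
                              config.get('drebin', 'intermediate_directory'),
                              update=False)
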
/experiments/drebin_ood_test.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | import itertools
5 |
6 | from tensorflow.compat.v1 import flags
7 | from absl.testing import absltest
8 | from absl.testing import parameterized
9 | from experiments.drebin_ood import run_experiment, feature_type_scope_dict, ensemble_method_scope_dict
10 |
11 | flags.DEFINE_integer('n_members', 1,
12 | 'The number of members for deep ensemble, weighted deep ensemble, and anchor ensemble')
13 | flags.DEFINE_integer('proc_numbers', 2,
14 |                      'The number of threads for feature extraction')
15 |
16 |
17 | # feature_types = feature_type_scope_dict.keys()
18 | # ensemble_types = ensemble_method_scope_dict.keys()
19 | feature_types = ['dex2img'] # 'drebin', 'opcodeseq', 'multimodality',
20 | ensemble_types = ['bayesian']
21 |
22 |
23 | class TestOODExperiments(parameterized.TestCase):
24 | @parameterized.named_parameters(
25 | [('%s_%s' % p,) + p for p in itertools.product(feature_types, ensemble_types)])
26 | def test_run(self, feature, ensemble):
27 | print('testing:', feature, ensemble)
28 | run_experiment(feature, ensemble, flags.FLAGS.n_members, flags.FLAGS.proc_numbers)
29 |
30 |
31 | if __name__ == '__main__':
32 | absltest.main()
33 |
--------------------------------------------------------------------------------
/conf:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | project_root = F:\projects\malware-det\malware-uncertainty\
3 | database_dir = F:\dataSet\android\
4 |
5 | [drebin]
6 | dataset_name = drebin
7 | malware_dir = %(database_dir)s\%(dataset_name)s\malicious_samples
8 | benware_dir = %(database_dir)s\%(dataset_name)s\benign_samples
9 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s
10 |
11 | [androzoo_tesseract]
12 | dataset_name = androzoo_tesseract
13 | malware_dir = %(database_dir)s\%(dataset_name)s\malicious_samples
14 | benware_dir = %(database_dir)s\%(dataset_name)s\benign_samples
15 | date_stamp = %(database_dir)s\%(dataset_name)s\date_stamp.json
16 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s
17 |
18 | [oos]
19 | dataset_name = VirusShare_Android_APK_2013
20 | malware_dir = %(database_dir)s\%(dataset_name)s\malicious_samples
21 | benware_dir = %(database_dir)s\%(dataset_name)s\benign_samples
22 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s
23 |
24 | [adv]
25 | dataset_name = drebin_adv
26 | pristine_apk_dir = %(database_dir)s\%(dataset_name)s\pristine_samples
27 | perturbed_apk_dir = %(database_dir)s\%(dataset_name)s\perturbed_samples
28 | intermediate_directory = %(project_root)s\datasets\%(dataset_name)s
29 |
30 | [metadata]
31 | naive_data_pool = %(database_dir)s\naive_data
32 |
33 | [experiments]
34 | oos = %(project_root)s/save/oos/
35 | adv = %(project_root)s/save/adv/
36 | drebin = %(project_root)s/save/drebin/
37 | androzoo = %(project_root)s/save/androzoo/
--------------------------------------------------------------------------------
/conf-server-ubuntu:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | project_root = /absolute/path/to/malware-uncertainty/
3 | database_dir = /absolute/path/to/datasets
4 |
5 | [drebin]
6 | dataset_name = drebin
7 | malware_dir = %(database_dir)s/%(dataset_name)s/malicious_samples
8 | benware_dir = %(database_dir)s/%(dataset_name)s/benign_samples
9 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s
10 |
11 | [androzoo_tesseract]
12 | dataset_name = androzoo_tesseract
13 | malware_dir = %(database_dir)s/%(dataset_name)s/malicious_samples
14 | benware_dir = %(database_dir)s/%(dataset_name)s/benign_samples
15 | date_stamp = %(database_dir)s/%(dataset_name)s/date_stamp.json
16 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s
17 |
18 | [oos]
19 | dataset_name = VirusShare_Android_APK_2013
20 | malware_dir = %(database_dir)s/%(dataset_name)s/malicious_samples
21 | benware_dir = %(database_dir)s/%(dataset_name)s/benign_samples
22 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s
23 |
24 | [adv]
25 | dataset_name = drebin_adv
26 | pristine_apk_dir = %(database_dir)s/%(dataset_name)s/pristine_samples
27 | perturbed_apk_dir = %(database_dir)s/%(dataset_name)s/perturbed_samples
28 | intermediate_directory = %(project_root)s/datasets/%(dataset_name)s
29 |
30 | [metadata]
31 | naive_data_pool = %(database_dir)s/naive_data/
32 |
33 | [experiments]
34 | oos = %(project_root)s/save/oos/
35 | adv = %(project_root)s/save/adv/
36 | drebin = %(project_root)s/save/drebin/
37 | androzoo = %(project_root)s/save/androzoo/
--------------------------------------------------------------------------------
/test/feature_extraction_test.py:
--------------------------------------------------------------------------------
1 | """Tests for malware-uncertainty.core.feature"""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from absl.testing import absltest
8 | from absl.testing import parameterized
9 |
10 | import numpy as np
11 |
12 | from core.feature import feature_type_scope_dict
13 | from config import config
14 |
15 |
16 | class MyTestCase(parameterized.TestCase):
17 | @parameterized.named_parameters(
18 | [("%s" % mtd_name,) + (mtd_name,) for mtd_name in [feature_type_scope_dict['apiseq']]])
19 | def test_feature_extraction(self, feature_type):
20 | malware_dir_name = config.get('drebin', 'malware_dir')
21 | benware_dir_name = config.get('drebin', 'benware_dir')
22 | meta_data_saving_dir = config.get('drebin', 'intermediate_directory')
23 | naive_data_saving_dir = config.get('metadata', 'naive_data_pool')
24 | feature_extractor = feature_type(naive_data_saving_dir, meta_data_saving_dir, update=False)
25 | malware_features = feature_extractor.feature_extraction(malware_dir_name)
26 | benign_features = feature_extractor.feature_extraction(benware_dir_name)
27 | features = malware_features + benign_features
28 | gt_labels = np.zeros((len(malware_features) + len(benign_features)), dtype=np.int32)
29 | gt_labels[:len(malware_features)] = 1
30 | feature_extractor.feature_preprocess(features, gt_labels)
31 | feature_extractor.feature2ipt(features)
32 |
33 |
34 | if __name__ == '__main__':
35 | absltest.main()
36 |
--------------------------------------------------------------------------------
/tools/progressbar/compat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # progressbar - Text progress bar library for Python.
5 | # Copyright (c) 2005 Nilton Volpato
6 | #
7 | # This library is free software; you can redistribute it and/or
8 | # modify it under the terms of the GNU Lesser General Public
9 | # License as published by the Free Software Foundation; either
10 | # version 2.1 of the License, or (at your option) any later version.
11 | #
12 | # This library is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | # Lesser General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU Lesser General Public
18 | # License along with this library; if not, write to the Free Software
19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 |
21 | """Compatibility methods and classes for the progressbar module."""
22 |
23 |
24 | # Python 3.x (and backports) use a modified iterator syntax
25 | # This will allow 2.x to behave with 3.x iterators
26 | try:
27 | next
28 | except NameError:
29 | def next(iter):
30 | try:
31 | # Try new style iterators
32 | return iter.__next__()
33 | except AttributeError:
34 | # Fallback in case of a "native" iterator
35 | return iter.next()
36 |
37 |
38 | # Python < 2.5 does not have "any"
39 | try:
40 | any
41 | except NameError:
42 | def any(iterator):
43 | for item in iterator:
44 | if item: return True
45 | return False
46 |
--------------------------------------------------------------------------------
/core/feature/opcodeseq/opcodeseq.py:
--------------------------------------------------------------------------------
1 | from androguard.misc import AnalyzeAPK
2 | from tools import utils
3 |
4 | from config import logging
5 |
6 | logger = logging.getLogger('feature.opcodeseq')
7 |
8 |
9 | def get_opcode_sequences(apk_path, save_path):
10 | _1, _2, dx = AnalyzeAPK(apk_path)
11 |
12 | opcode_chunks = []
13 | for method in dx.get_methods():
14 | if method.is_external():
15 | continue
16 | mth_body = method.get_method()
17 | sequence = []
18 | for ins in mth_body.get_instructions():
19 | opcode = ins.get_op_value()
20 | if opcode < 0:
21 | opcode = 0
22 | elif opcode >= 256:
23 | opcode = 0
24 | else:
25 | opcode = opcode
26 | sequence.append(opcode) # list of 'int'
27 | if len(sequence) > 0:
28 | opcode_chunks.append(sequence)
29 | dump_opcode(opcode_chunks, save_path)
30 |
31 | return save_path
32 |
33 |
34 | def dump_opcode(opcode_chunks, save_path):
35 | utils.dump_json(opcode_chunks, save_path)
36 | return
37 |
38 |
39 | def read_opcode(save_path):
40 | return utils.load_json(save_path)
41 |
42 |
43 | def read_opcode_wrapper(save_path):
44 | try:
45 | return read_opcode(save_path)
46 | except Exception as e:
47 | return e
48 |
49 |
50 | def feature_extr_wrapper(*args):
51 | """
52 | A helper function to catch the exception
53 |     :param args: arguments for feature extraction
54 | :return: feature or Exception
55 | """
56 | try:
57 | return get_opcode_sequences(*args)
58 | except Exception as e:
59 | return e
60 |
--------------------------------------------------------------------------------
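
Usage sketch (illustrative only, not a file in the repository): extracting and reloading the opcode sequences of a single APK; the paths are placeholders.

    from core.feature.opcodeseq import feature_extr_wrapper, read_opcode_wrapper

    saved = feature_extr_wrapper('/path/to/sample.apk', '/path/to/sample.opcode.json')
    if not isinstance(saved, Exception):
        opcode_chunks = read_opcode_wrapper(saved)   # list of per-method opcode lists
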
/core/ensemble/ensemble.py:
--------------------------------------------------------------------------------
1 | """
2 | base ensemble class
3 | """
4 |
5 |
6 | class Ensemble(object):
7 | def __init__(self, architecture_type, base_model, n_members, model_directory):
8 | """
9 | initialization
10 | :param architecture_type: e.g., 'dnn'
11 | :param base_model: an instantiated object of base model
12 | :param n_members: number of base models
13 | :param model_directory: a folder for saving ensemble weights
14 | """
15 | self.architecture_type = architecture_type
16 | self.base_model = base_model
17 | self.n_members = n_members
18 | self.model_directory = model_directory
19 | self.weights_list = [] # a model's parameters
20 | self._optimizers_dict = dict()
21 |
22 | def build_model(self):
23 | """Build an ensemble model"""
24 | raise NotImplementedError
25 |
26 | def predict(self, x):
27 | """conduct prediction"""
28 | raise NotImplementedError
29 |
30 | def get_basic_layers(self):
31 | """ construct the basic layers"""
32 | raise NotImplementedError
33 |
34 | def fit(self, train_x, train_y, val_x=None, val_y=None, **kwargs):
35 | """ tune the model parameters upon given dataset"""
36 | raise NotImplementedError
37 |
38 | def get_model_number(self):
39 | """ get the number of base models"""
40 | raise NotImplementedError
41 |
42 | def reset(self):
43 | self.weights_list = []
44 |
45 | def save_ensemble_weights(self):
46 | """ save the model parameters"""
47 | raise NotImplementedError
48 |
49 | def load_ensemble_weights(self):
50 | """ Load the model parameters """
51 | raise NotImplementedError
52 |
53 | def gradient_loss_wrt_input(self, x):
54 | """ obtain gradients of loss function with respect to the input."""
55 | raise NotImplementedError
56 |
--------------------------------------------------------------------------------
/experiments/adv_main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | import argparse
5 | from tensorflow.compat.v1 import app
6 | from experiments.adv import run_experiment, run_temperature_scaling
7 |
8 | args = argparse.ArgumentParser(description='test malware detectors on the adversarial dataset')
9 | args.add_argument('--detector', type=str, default='drebin', choices=[
10 | 'drebin', # deepdrebin
11 | 'opcodeseq', # deepDroid
12 | 'multimodality', # multimodalnn
13 |     'dex2img',  # r2d2; excluded owing to effectiveness issues
14 |     'apiseq',  # droidetec; excluded owing to effectiveness issues
15 | ], help='malware detection method')
16 |
17 | args.add_argument('--calibration', type=str, default='vanilla', choices=[
18 | 'vanilla',
19 |     'temp_scaling',
20 | 'mc_dropout',
21 | 'deep_ensemble',
22 | 'weighted_ensemble',
23 | 'bayesian' # variational bayesian inference
24 | ], help='calibration method')
25 | args.add_argument('--proc_numbers', type=int, default=2,
26 |                   help='number of threads for parallelizing feature extraction.')
27 |
28 | option = args.parse_args()
29 |
30 | non_used_methods = ['dex2img', 'apiseq']
31 | assert option.detector not in non_used_methods
32 |
33 |
34 | def main(_):
35 | if option.calibration != 'temp_scaling':
36 | run_experiment(option.detector,
37 | option.calibration,
38 | option.proc_numbers
39 | )
40 | else:
41 | run_temperature_scaling(option.detector, 'vanilla', option.proc_numbers)
42 |
43 |
44 | if __name__ == '__main__':
45 | app.run()
46 |
--------------------------------------------------------------------------------
/experiments/oos_main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | from tensorflow.compat.v1 import app
7 | from experiments.oos import run_experiment, run_temperature_scaling
8 |
9 | args = argparse.ArgumentParser(description='test malware detectors on the oos dataset')
10 | args.add_argument('--detector', type=str, default='drebin', choices=[
11 | 'drebin', # deepdrebin
12 | 'opcodeseq', # deepDroid
13 | 'multimodality', # multimodalnn
14 |     'dex2img',  # r2d2; excluded owing to effectiveness issues
15 |     'apiseq',  # droidetec; excluded owing to effectiveness issues
16 | ], help='malware detection method')
17 |
18 | args.add_argument('--calibration', type=str, default='vanilla', choices=[
19 | 'vanilla',
20 |     'temp_scaling',
21 | 'mc_dropout',
22 | 'deep_ensemble',
23 | 'weighted_ensemble',
24 | 'bayesian' # variational bayesian inference
25 | ], help='calibration method')
26 | args.add_argument('--proc_numbers', type=int, default=2,
27 |                   help='number of threads for parallelizing feature extraction.')
28 |
29 | option = args.parse_args()
30 |
31 | non_used_methods = ['dex2img', 'apiseq']
32 | assert option.detector not in non_used_methods
33 |
34 |
35 | def main(_):
36 | if option.calibration != 'temp_scaling':
37 | run_experiment(option.detector,
38 | option.calibration,
39 | option.proc_numbers
40 | )
41 | else:
42 | run_temperature_scaling(option.detector, 'vanilla', option.proc_numbers)
43 |
44 |
45 | if __name__ == '__main__':
46 | app.run()
47 |
48 |
--------------------------------------------------------------------------------
/experiments/androzoo_main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | from tensorflow.compat.v1 import app
7 | from experiments.androzoo_dataset import run_experiment, run_temperature_scaling
8 |
9 | args = argparse.ArgumentParser(description='learning malware detectors on the AndroZoo dataset')
10 | args.add_argument('--detector', type=str, default='drebin', choices=[
11 | 'drebin', # deepdrebin
12 | 'opcodeseq', # deepDroid
13 | 'multimodality', # multimodalnn
14 |     'dex2img',  # r2d2; excluded owing to effectiveness issues
15 |     'apiseq',  # droidetec; excluded owing to effectiveness issues
16 | ], help='malware detection method')
17 |
18 | args.add_argument('--calibration', type=str, default='vanilla', choices=[
19 | 'vanilla',
20 |     'temp_scaling',
21 | 'mc_dropout',
22 | 'deep_ensemble',
23 | 'weighted_ensemble',
24 | 'bayesian' # variational bayesian inference
25 | ], help='calibration method')
26 | args.add_argument('--n_members', type=int, default=10, help='number of members in ensemble or weighted ensemble.')
27 | args.add_argument('--proc_numbers', type=int, default=2,
28 |                   help='number of threads for parallelizing feature extraction.')
29 |
30 | option = args.parse_args()
31 |
32 | non_used_methods = ['dex2img', 'apiseq']
33 | assert option.detector not in non_used_methods
34 |
35 |
36 | def main(_):
37 | if option.calibration != 'temp_scaling':
38 | run_experiment(option.detector,
39 | option.calibration,
40 | option.n_members,
41 | option.proc_numbers
42 | )
43 | else:
44 | run_temperature_scaling(option.detector, 'vanilla')
45 |
46 |
47 | if __name__ == '__main__':
48 | app.run()
49 |
--------------------------------------------------------------------------------
/tools/progressbar/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # progressbar - Text progress bar library for Python.
5 | # Copyright (c) 2005 Nilton Volpato
6 | #
7 | # This library is free software; you can redistribute it and/or
8 | # modify it under the terms of the GNU Lesser General Public
9 | # License as published by the Free Software Foundation; either
10 | # version 2.1 of the License, or (at your option) any later version.
11 | #
12 | # This library is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | # Lesser General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU Lesser General Public
18 | # License along with this library; if not, write to the Free Software
19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 |
21 | """Text progress bar library for Python.
22 |
23 | A text progress bar is typically used to display the progress of a long
24 | running operation, providing a visual cue that processing is underway.
25 |
26 | The ProgressBar class manages the current progress, and the format of the line
27 | is given by a number of widgets. A widget is an object that may display
28 | differently depending on the state of the progress bar. There are three types
29 | of widgets:
30 | - a string, which always shows itself
31 |
32 | - a ProgressBarWidget, which may return a different value every time its
33 | update method is called
34 |
35 | - a ProgressBarWidgetHFill, which is like ProgressBarWidget, except it
36 | expands to fill the remaining width of the line.
37 |
38 | The progressbar module is very easy to use, yet very powerful. It will also
39 | automatically enable features like auto-resizing when the system supports it.
40 | """
41 |
42 | __author__ = 'Nilton Volpato'
43 | __author_email__ = 'first-name dot last-name @ gmail.com'
44 | __date__ = '2011-05-14'
45 | __version__ = '2.3'
46 |
47 |
48 | from .compat import *
49 | from .widgets import *
50 | from .progressbar import *
51 |
--------------------------------------------------------------------------------
/core/feature/dex2img/dex2img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | from PIL import Image
5 | import zipfile
6 |
7 | import numpy as np
8 |
9 | from config import logging
10 |
11 | current_dir = os.path.dirname(os.path.realpath(__file__))
12 | logger = logging.getLogger('feature.dex2img')
13 |
14 |
15 | def dex2img(apk_path, save_path, num_channels=3):
16 | """
17 |     convert a dex file to an RGB image
18 | :param apk_path: an apk path
19 | :param save_path: a path for saving the resulting image
20 | :param num_channels: r, g, b channels
21 |     :return: save_path on success, otherwise the caught Exception
22 | """
23 | try:
24 | print("Processing " + apk_path)
25 | start_time = time.time()
26 | with zipfile.ZipFile(apk_path, 'r') as fh_apk:
27 | dex2num_list = []
28 | for name in fh_apk.namelist():
29 | if name.endswith('dex'):
30 | with fh_apk.open(name, 'r') as fr:
31 | hex_string = fr.read().hex()
32 | dex2num = [int(hex_string[i:i + 2], base=16) for i in \
33 | range(0, len(hex_string), 2)]
34 | dex2num_list.extend(dex2num)
35 |
36 | # extend to three channels (e.g., r,g,b)
37 | num_appending_zero = num_channels - len(dex2num_list) % num_channels
38 | dex2num_list += [0] * num_appending_zero
39 | # shape: [3, -1]
40 | dex2array = np.array([dex2num_list[0::3], dex2num_list[1::3], dex2num_list[2::3]], dtype=np.uint8)
41 | # get image matrix
42 | from math import sqrt, ceil
43 | _length = int(pow(ceil(sqrt(dex2array.shape[1])), 2))
44 | if _length > dex2array.shape[1]:
45 | padding_zero = np.zeros((3, _length - dex2array.shape[1]), dtype=np.uint8)
46 | dex2array = np.concatenate([dex2array, padding_zero], axis=1)
47 | dex2mat = np.reshape(dex2array, (-1, int(sqrt(_length)), int(sqrt(_length))))
48 | dex2mat_img = np.transpose(dex2mat, (1, 2, 0))
49 | img_handler = Image.fromarray(dex2mat_img)
50 | img_handler.save(save_path)
51 | except Exception as e:
52 | return e
53 | else:
54 | return save_path
55 |
--------------------------------------------------------------------------------
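
Usage sketch (illustrative only, not a file in the repository): converting the dex code of one APK into an RGB image and reading it back as an array; the paths are placeholders.

    import numpy as np
    from PIL import Image
    from core.feature.dex2img import dex2img

    result = dex2img('/path/to/sample.apk', '/path/to/sample.png')   # save_path on success, Exception on failure
    if not isinstance(result, Exception):
        img = np.asarray(Image.open(result))                         # square H x W x 3 uint8 matrix
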
/core/ensemble/mc_dropout.py:
--------------------------------------------------------------------------------
1 | from core.ensemble.vanilla import Vanilla, model_builder
2 | from core.ensemble.model_hp import train_hparam, mc_dropout_hparam
3 | from tools import utils
4 | from config import logging, ErrorHandler
5 |
6 | logger = logging.getLogger('ensemble.mc_dropout')
7 | logger.addHandler(ErrorHandler)
8 |
9 |
10 | class MCDropout(Vanilla):
11 | def __init__(self,
12 | architecture_type='dnn',
13 | base_model=None,
14 | n_members=1,
15 | model_directory=None,
16 | name='MC_DROPOUT'
17 | ):
18 | super(MCDropout, self).__init__(architecture_type,
19 | base_model,
20 | n_members,
21 | model_directory,
22 | name)
23 | self.hparam = utils.merge_namedtuples(train_hparam, mc_dropout_hparam)
24 | self.ensemble_type = 'mc_dropout'
25 |
26 | def build_model(self, input_dim=None):
27 | """
28 | Build an ensemble model -- only the homogeneous structure is considered
29 | :param input_dim: integer or list, input dimension shall be set in some cases under eager mode
30 | """
31 | callable_graph = model_builder(self.architecture_type)
32 |
33 | @callable_graph(input_dim, use_mc_dropout=True)
34 | def _builder():
35 | return utils.produce_layer(self.ensemble_type, dropout_rate=self.hparam.dropout_rate)
36 |
37 | self.base_model = _builder()
38 | return
39 |
40 | def model_generator(self):
41 | try:
42 | if len(self.weights_list) <= 0:
43 | self.load_ensemble_weights()
44 | except Exception as e:
45 | raise Exception("Cannot load model weights:{}.".format(str(e)))
46 |
47 | assert len(self.weights_list) == self.n_members
48 | self.base_model.set_weights(weights=self.weights_list[self.n_members - 1])
49 | # if len(self._optimizers_dict) > 0 and self.base_model.optimizer is not None:
50 | # self.base_model.optimizer.set_weights(self._optimizers_dict[self.n_members - 1])
51 | for _ in range(self.hparam.n_sampling):
52 | yield self.base_model
53 |
--------------------------------------------------------------------------------
/core/post_calibration/temperature_scaling.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import scipy.optimize
6 | import scipy.special
7 | import numpy as np
8 | import tensorflow as tf
9 | import tensorflow_probability as tfp
10 |
11 |
12 | def find_scaling_temperature(labels, logits, temp_range=(1e-5, 1e5)):
13 | """
14 | Code is adapted from https://github.com/google-research/google-research/tree/master/uq_benchmark_2019/
15 | Find max likelihood scaling temperature using binary search.
16 |
17 | Args:
18 | labels: Integer labels (shape=[num_samples]).
19 | logits: Floating point softmax inputs (shape=[num_samples, num_classes]).
20 | temp_range: 2-tuple range of temperatures to consider.
21 | Returns:
22 | Floating point temperature value.
23 | """
24 | if not tf.executing_eagerly():
25 | raise NotImplementedError(
26 | 'find_scaling_temperature() not implemented for graph-mode TF')
27 | if len(labels.shape) != 1:
28 | raise ValueError('Invalid labels shape=%s' % str(labels.shape))
29 | if len(logits.shape) not in (1, 2):
30 | raise ValueError('Invalid logits shape=%s' % str(logits.shape))
31 | if len(labels.shape) != 1 or len(labels) != len(logits):
32 | raise ValueError('Incompatible shapes for logits (%s) vs labels (%s).' %
33 | (logits.shape, labels.shape))
34 |
35 | @tf.function(autograph=False, experimental_relax_shapes=True)
36 | def grad_fn(temperature):
37 | """Returns gradient of log-likelihood WRT a logits-scaling temperature."""
38 | dist = tfp.distributions.Bernoulli(logits=logits / temperature)
39 | nll = -dist.log_prob(labels)
40 | nll = tf.reduce_sum(nll, axis=0)
41 | grad, = tf.gradients(nll, [temperature])
42 | return grad
43 |
44 | tmin, tmax = temp_range
45 | return scipy.optimize.bisect(lambda t: grad_fn(t * tf.ones([])).numpy(), tmin, tmax)
46 |
47 |
48 | def inverse_sigmoid(probs, eps=1e-7):
49 | """ compute the logit"""
50 | uniform = np.ones_like(probs) / probs.shape[-1]
51 | probs = eps * uniform + (1. - eps) * probs
52 | return -tf.math.log(1. / probs - 1.)
53 |
54 |
55 | def apply_temperature_scaling(temperature, probs):
56 | """Apply temperature scaling to an array of probabilities."""
57 | logits_t = inverse_sigmoid(probs) / temperature
58 | return tf.sigmoid(logits_t).numpy()
59 |
--------------------------------------------------------------------------------
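
Usage sketch (illustrative only, not a file in the repository): fitting the temperature on held-out validation outputs and applying it at test time; val_probs, val_labels and test_probs are assumed 1-D NumPy arrays produced by a trained detector.

    from core.post_calibration.temperature_scaling import (
        find_scaling_temperature, inverse_sigmoid, apply_temperature_scaling)

    val_logits = inverse_sigmoid(val_probs).numpy()                  # probabilities -> logits
    temperature = find_scaling_temperature(val_labels, val_logits)
    calibrated_probs = apply_temperature_scaling(temperature, test_probs)
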
/experiments/drebin_main.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | from tensorflow.compat.v1 import app
7 | from experiments.drebin_dataset import run_experiment, run_temperature_scaling
8 |
9 | args = argparse.ArgumentParser(description='learning malware detectors on the Drebin dataset')
10 | args.add_argument('--detector', type=str, default='drebin', choices=[
11 | 'drebin', # deepdrebin
12 | 'opcodeseq', # deepDroid
13 | 'multimodality', # multimodalnn
14 |     'dex2img',  # r2d2; excluded owing to effectiveness issues
15 |     'apiseq',  # droidetec; excluded owing to effectiveness issues
16 | ], help='malware detection method')
17 |
18 | args.add_argument('--calibration', type=str, default='vanilla', choices=[
19 | 'vanilla',
20 |     'temp_scaling',
21 | 'mc_dropout',
22 | 'deep_ensemble',
23 | 'weighted_ensemble',
24 | 'bayesian' # variational bayesian inference
25 | ], help='calibration method')
26 | args.add_argument('--seed', type=int, default=7727, help='random seed')
27 | # the seed above has not yet been applied to the NN models; it could be, but we left it out when conducting the experiments.
28 | args.add_argument('--n_members', type=int, default=10, help='number of members in ensemble or weighted ensemble.')
29 | args.add_argument('--proc_numbers', type=int, default=2,
30 |                   help='number of threads for parallelizing feature extraction.')
31 | args.add_argument('--ratio', type=float, default=3.0, help='ratio of the number of benign files to malware ones.')
32 |
33 | option = args.parse_args()
34 |
35 | non_used_methods = ['dex2img', 'apiseq']
36 | assert option.detector not in non_used_methods
37 |
38 |
39 | def main(_):
40 | if option.calibration != 'temp_scaling':
41 | run_experiment(option.detector,
42 | option.calibration,
43 | option.seed,
44 | option.n_members,
45 | option.ratio,
46 | option.proc_numbers
47 | )
48 | else:
49 | run_temperature_scaling(option.detector, 'vanilla')
50 |
51 |
52 | if __name__ == '__main__':
53 | app.run()
54 |
--------------------------------------------------------------------------------
/test/model_lib_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 |
4 | from core.ensemble.model_lib import build_models
5 | import numpy as np
6 |
7 |
8 | class Model_lib_test(parameterized.TestCase):
9 | @parameterized.named_parameters(
10 | [(ens_type, ens_type) for ens_type in \
11 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']])
12 | def test_dnn_graph(self, ensemble_type):
13 | x = np.array([[1., 2., 3.], [4., 5., 6.]])
14 |         if ensemble_type != 'mc_dropout':
15 | build_models(x, 'dnn', ensemble_type, input_dim=3)
16 | else:
17 | build_models(x, 'dnn', ensemble_type, input_dim=3, use_mc_dropout=True)
18 |
19 | @parameterized.named_parameters(
20 | [(ens_type, ens_type) for ens_type in \
21 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']])
22 | def test_text_cnn_graph(self, ensemble_type):
23 | x = np.random.randint(0, 256, (2, 16))
24 |         if ensemble_type != 'mc_dropout':
25 | out = build_models(x, 'text_cnn', ensemble_type, input_dim=5)
26 | else:
27 | out = build_models(x, 'text_cnn', ensemble_type, input_dim=5, use_mc_dropout=True)
28 |
29 | self.assertTrue(out.shape == (2, 1))
30 |
31 | @parameterized.named_parameters(
32 | [(ens_type, ens_type) for ens_type in \
33 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']])
34 | def test_multimodalitynn(self, ensemble_type):
35 | _x = np.array([[1., 2., 3.], [4., 5., 6.]])
36 | x = [_x, _x, _x, _x, _x]
37 |         if ensemble_type != 'mc_dropout':
38 | out = build_models(x, 'multimodalitynn', ensemble_type, input_dim=[3, 3, 3, 3, 3])
39 | else:
40 | out = build_models(x, 'multimodalitynn', ensemble_type, input_dim=[3, 3, 3, 3, 3], use_mc_dropout=True)
41 | self.assertTrue(out.shape == (2, 1))
42 |
43 | @parameterized.named_parameters(
44 | [(ens_type, ens_type) for ens_type in \
45 | ['vanilla', 'deep_ensemble', 'weighted_ensemble', 'mc_dropout', 'bayesian']])
46 | def test_droidectc(self, ensemble_type):
47 | _x = np.random.randint(0, 10000, [2, 1000])
48 |         if ensemble_type != 'mc_dropout':
49 | out = build_models(_x, 'droidectc', ensemble_type)
50 | else:
51 | out = build_models(_x, 'droidectc', ensemble_type, use_mc_dropout=True)
52 | self.assertTrue(out.shape == (2, 1))
53 |
54 |
55 | if __name__ == '__main__':
56 | absltest.main()
57 |
--------------------------------------------------------------------------------
/core/ensemble/anchor_ensemble_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | import tempfile
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | from sklearn.datasets import load_breast_cancer
8 |
9 | from core.ensemble.anchor_ensemble import AnchorEnsemble
10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
11 |
12 | architectures = ['dnn', 'text_cnn', 'multimodalitynn', 'r2d2', 'droidectc']
13 |
14 | class MyTestCaseAnchorEnsemble(parameterized.TestCase):
15 | def setUp(self):
16 | self.x_dict, self.y_dict = dict(), dict()
17 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
18 | self.x_dict['dnn'] = self.x_np
19 | self.y_dict['dnn'] = self.y_np
20 | x = np.random.randint(0, 256, (10, 10))
21 | y = np.random.choice(2, 10)
22 | self.x_dict['text_cnn'] = x
23 | self.y_dict['text_cnn'] = y
24 | x = [self.x_np] * 5
25 | self.x_dict['multimodalitynn'] = x
26 | self.y_dict['multimodalitynn'] = self.y_np
27 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
28 | y = np.random.choice(2, 10)
29 | self.x_dict['r2d2'] = x
30 | self.y_dict['r2d2'] = y
31 | x = np.random.randint(0, 10000, size=(10, 1000))
32 | y = np.random.choice(2, 10)
33 | self.x_dict['droidectc'] = x
34 | self.y_dict['droidectc'] = y
35 |
36 | @parameterized.named_parameters([(arc_type, arc_type) for arc_type in architectures])
37 | def test_anchor_ensemble(self, arc_type):
38 | with tempfile.TemporaryDirectory() as output_dir:
39 | x = self.x_dict[arc_type]
40 | y = self.y_dict[arc_type]
41 |             if arc_type != 'multimodalitynn':
42 | train_dataset = build_dataset_from_numerical_data((x, y))
43 | val_dataset = build_dataset_from_numerical_data((x, y))
44 | n_samples = x.shape[0]
45 | input_dim = x.shape[1:]
46 | else:
47 | train_data = build_dataset_from_numerical_data(tuple(x))
48 | train_y = build_dataset_from_numerical_data(self.y_np)
49 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
50 | val_data = build_dataset_from_numerical_data(tuple(x))
51 | val_y = build_dataset_from_numerical_data(self.y_np)
52 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
53 | n_samples = x[0].shape[0]
54 | input_dim = [x[i].shape[1] for i in range(len(x))]
55 |
56 | anchor_ensemble = AnchorEnsemble(architecture_type=arc_type,
57 | model_directory=output_dir)
58 | anchor_ensemble.fit(train_dataset, val_dataset, input_dim=input_dim)
59 |
60 | res = anchor_ensemble.predict(x)
61 | self.assertEqual(anchor_ensemble.get_n_members(), anchor_ensemble.n_members)
62 | self.assertTrue(res.shape == (n_samples, anchor_ensemble.n_members, 1))
63 |
64 | anchor_ensemble.evaluate(x, y)
65 |
66 |
67 | if __name__ == '__main__':
68 | absltest.main()
69 |
--------------------------------------------------------------------------------
/core/ensemble/dataset_lib.py:
--------------------------------------------------------------------------------
1 | """ This script is for building dataset """
2 |
3 | import tensorflow as tf
4 | from core.ensemble.model_hp import train_hparam
5 |
6 |
7 | def build_dataset_from_numerical_data(data, batch_size=None):
8 | """
9 | serialize the data to accommodate the format of model input
10 |     :param data: tuple or np.ndarray
11 |     :param batch_size: scalar or None; defaults to the training hyperparameter if not provided
12 | """
13 | batch_size = train_hparam.batch_size if batch_size is None else batch_size
14 | return tf.data.Dataset.from_tensor_slices(data). \
15 | cache(). \
16 | batch(batch_size). \
17 | prefetch(tf.data.experimental.AUTOTUNE)
18 |
19 |
20 | def build_dataset_via_generator(generator, y=None, path='', batch_size=None):
21 | batch_size = train_hparam.batch_size if batch_size is None else batch_size
22 | if y is not None:
23 | return tf.data.Dataset.from_generator(generator,
24 | output_types=(tf.int32, tf.int32),
25 | output_shapes=(tf.TensorShape([None]), tf.TensorShape([]))
26 | ). \
27 | padded_batch(batch_size, padded_shapes=([None], [])). \
28 | cache(path). \
29 | shuffle(buffer_size=100). \
30 | prefetch(tf.data.experimental.AUTOTUNE)
31 | else:
32 | return tf.data.Dataset.from_generator(generator,
33 | output_types=tf.int32,
34 | output_shapes=tf.TensorShape([None])
35 | ). \
36 | padded_batch(batch_size, padded_shapes=([None])). \
37 | cache(path). \
38 | prefetch(tf.data.experimental.AUTOTUNE)
39 |
40 |
41 | def build_dataset_from_img_generator(generator, input_dim, y=None, is_training=False):
42 | if is_training and y is not None:
43 | return tf.data.Dataset.from_generator(generator,
44 | output_types=(tf.float32, tf.float32, tf.float32),
45 | output_shapes=(tf.TensorShape([None, *input_dim]),
46 | tf.TensorShape([None, ]),
47 | tf.TensorShape([None, ]))
48 | )
49 | elif not is_training and y is not None:
50 | return tf.data.Dataset.from_generator(generator,
51 | output_types=(tf.float32, tf.float32),
52 | output_shapes=(tf.TensorShape([None, *input_dim]),
53 | tf.TensorShape([None, ]))
54 | )
55 | else:
56 | return tf.data.Dataset.from_generator(generator,
57 | output_types=tf.float32,
58 | output_shapes=tf.TensorShape([None, *input_dim])
59 | )
60 |
--------------------------------------------------------------------------------
/core/ensemble/bayesian_ensemble_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | import tempfile
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | from sklearn.datasets import load_breast_cancer
8 |
9 | from core.ensemble.bayesian_ensemble import BayesianEnsemble
10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
11 |
12 | architectures = ['dnn', 'text_cnn', 'multimodalitynn', 'r2d2', 'droidectc']
13 |
14 |
15 | class MyTestCaseBayesianEnsemble(parameterized.TestCase):
16 | def setUp(self):
17 | self.x_dict, self.y_dict = dict(), dict()
18 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
19 | self.x_dict['dnn'] = self.x_np
20 | self.y_dict['dnn'] = self.y_np
21 | x = np.random.randint(0, 256, (10, 10))
22 | y = np.random.choice(2, 10)
23 | self.x_dict['text_cnn'] = x
24 | self.y_dict['text_cnn'] = y
25 | x = [self.x_np] * 5
26 | self.x_dict['multimodalitynn'] = x
27 | self.y_dict['multimodalitynn'] = self.y_np
28 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
29 | y = np.random.choice(2, 10)
30 | self.x_dict['r2d2'] = x
31 | self.y_dict['r2d2'] = y
32 | x = np.random.randint(0, 10000, size=(10, 1000))
33 | y = np.random.choice(2, 10)
34 | self.x_dict['droidectc'] = x
35 | self.y_dict['droidectc'] = y
36 |
37 | @parameterized.named_parameters([(arc_type, arc_type) for arc_type in architectures])
38 |     def test_bayesian_ensemble(self, arc_type):
39 | with tempfile.TemporaryDirectory() as output_dir:
40 | x = self.x_dict[arc_type]
41 | y = self.y_dict[arc_type]
42 |             if arc_type != 'multimodalitynn':
43 | train_dataset = build_dataset_from_numerical_data((x, y))
44 | val_dataset = build_dataset_from_numerical_data((x, y))
45 | n_samples = x.shape[0]
46 | input_dim = x.shape[1:]
47 | else:
48 | train_data = build_dataset_from_numerical_data(tuple(x))
49 | train_y = build_dataset_from_numerical_data(self.y_np)
50 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
51 | val_data = build_dataset_from_numerical_data(tuple(x))
52 | val_y = build_dataset_from_numerical_data(self.y_np)
53 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
54 | n_samples = x[0].shape[0]
55 | input_dim = [x[i].shape[1] for i in range(len(x))]
56 |
57 | bayesian_ensemble = BayesianEnsemble(architecture_type=arc_type,
58 | model_directory=output_dir)
59 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=input_dim)
60 |
61 | res = bayesian_ensemble.predict(x)
62 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members)
63 | self.assertTrue(res.shape == (n_samples, bayesian_ensemble.hparam.n_sampling, 1))
64 |
65 | bayesian_ensemble.evaluate(x, y)
66 |
67 |
68 | if __name__ == '__main__':
69 | absltest.main()
70 |
--------------------------------------------------------------------------------
/tools/progressbar_wrapper.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import time
3 |
4 | import tools.progressbar.progressbar as progressbar
5 | import tools.progressbar.widgets as progressbar_widgets
6 |
7 |
8 | class ProgressBar(object):
9 | def __init__(self):
10 | self.TotalResults = []
11 | self.NumberOfFinishedResults = 0
12 |
13 | def Update(self):
14 | self.ProgressBar.update(self.NumberOfFinishedResults)
15 | return
16 |
17 | def CallbackForProgressBar(self, res=''):
18 | """
19 | Callback function for pool.async if the progress bar needs to be displayed.
20 | Must use with DisplayProgressBar function.
21 |
22 | :param multiprocessing.pool.AsyncResult res: Result got from callback function in pool.async.
23 | """
24 | self.NumberOfFinishedResults += 1
25 | self.TotalResults.append(res)
26 | return
27 |
28 | def DisplayProgressBar(self, ProcessingResults, ExpectedResultsSize, CheckInterval=1, type="minute"):
29 | '''
30 |         Display a progress bar for multiprocessing. This function should be used after pool.close() (no need to call pool.join any more).
31 |         The callback function for pool.async should be set as CallbackForProgressBar.
32 |
33 |         :param multiprocessing.pool.AsyncResult ProcessingResults: Processing results returned by pool.async.
34 |         :param int ExpectedResultsSize: How many results you will receive, i.e. the total length of the progress bar.
35 |         :param float CheckInterval: How many seconds between progress bar updates. When it is too large, the main program may hang.
36 |         :param String type: One of "minute", "hour", "second"; corresponding to displaying iters/minute, iters/hour, or iters/second.
37 | '''
38 | self.ProcessingResults = ProcessingResults
39 | ProgressBarWidgets = [progressbar_widgets.Percentage(),
40 | ' ', progressbar_widgets.Bar(),
41 | ' ', progressbar_widgets.SimpleProgress(),
42 | ' ', progressbar_widgets.Timer(),
43 | ' ', progressbar_widgets.AdaptiveETA()]
44 | self.ProgressBar = progressbar.ProgressBar(ExpectedResultsSize, ProgressBarWidgets)
45 | self.StartTime = time.time()
46 | PreviousNumberOfResults = 0
47 | self.ProgressBar.start()
48 | while self.ProcessingResults.ready() == False:
49 | self.Update()
50 | time.sleep(CheckInterval)
51 | time.sleep(CheckInterval)
52 | self.Update()
53 | self.ProgressBar.finish()
54 | self.EndTime = time.time()
55 |         print("Processing finished.")
56 |         print("Processing results: " + str(self.TotalResults))
57 |         print("Time Elapsed: %.2fs, or %.2fmins, or %.2fhours" % (
58 |             (self.EndTime - self.StartTime), (self.EndTime - self.StartTime) / 60,
59 |             (self.EndTime - self.StartTime) / 3600))
60 |         return
--------------------------------------------------------------------------------
/test/vanilla_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | import tempfile
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from sklearn.datasets import load_breast_cancer
8 |
9 | from core.ensemble.vanilla import Vanilla
10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
11 |
12 |
13 | class MyTestCaseVanilla(parameterized.TestCase):
14 | def setUp(self):
15 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
16 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
17 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
18 |
19 | def test_vanilla_dnn(self):
20 | with tempfile.TemporaryDirectory() as output_dir:
21 | vanilla = Vanilla(architecture_type='dnn',
22 | model_directory=output_dir)
23 | vanilla.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1])
24 |
25 | res = vanilla.predict(self.x_np)
26 | self.assertEqual(vanilla.get_n_members(), vanilla.n_members)
27 | self.assertTrue(res.shape == (self.x_np.shape[0], vanilla.n_members, 1))
28 |
29 | vanilla.evaluate(self.x_np, self.y_np)
30 |
31 | def test_vanilla_textcnn(self):
32 | with tempfile.TemporaryDirectory() as output_dir:
33 | vanilla = Vanilla(architecture_type='text_cnn',
34 | model_directory=output_dir)
35 | x = np.random.randint(0, 256, (10, 10))
36 | y = np.random.choice(2, 10)
37 | train_dataset = build_dataset_from_numerical_data((x,y))
38 | val_dataset = build_dataset_from_numerical_data((x,y))
39 | vanilla.fit(train_dataset, val_dataset)
40 | res = vanilla.predict(x)
41 | self.assertTrue(res.shape == (x.shape[0], vanilla.n_members, 1))
42 | vanilla.evaluate(x, y)
43 |
44 | def test_vanilla_multimodalitynn(self):
45 | with tempfile.TemporaryDirectory() as output_dir:
46 | vanilla = Vanilla(architecture_type='multimodalitynn',
47 | model_directory=output_dir)
48 | x = [self.x_np] * 5
49 | train_data = build_dataset_from_numerical_data(tuple(x))
50 | train_y = build_dataset_from_numerical_data(self.y_np)
51 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
52 | val_data = build_dataset_from_numerical_data(tuple(x))
53 | val_y = build_dataset_from_numerical_data(self.y_np)
54 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
55 | vanilla.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]]*5)
56 | res = vanilla.predict(x)
57 | self.assertTrue(res.shape == (self.x_np.shape[0], vanilla.n_members, 1))
58 | vanilla.evaluate(x, self.y_np)
59 |
60 | def test_vanilla_r2d2(self):
61 | with tempfile.TemporaryDirectory() as output_dir:
62 | vanilla = Vanilla(architecture_type='r2d2',
63 | model_directory=output_dir)
64 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
65 | y = np.random.choice(2, 10)
66 | train_dataset = build_dataset_from_numerical_data((x, y))
67 | val_dataset = build_dataset_from_numerical_data((x, y))
68 | vanilla.fit(train_dataset, val_dataset, input_dim=(299, 299, 3))
69 | res = vanilla.predict(x)
70 | self.assertTrue(res.shape == (x.shape[0], vanilla.n_members, 1))
71 | vanilla.evaluate(x, y)
72 |
73 | def test_vanilla_droidectc(self):
74 | with tempfile.TemporaryDirectory() as output_dir:
75 | vanilla = Vanilla(architecture_type='droidectc',
76 | model_directory=output_dir)
77 |
78 | x = np.random.randint(0, 10000, size=(10, 1000))
79 | y = np.random.choice(2, 10)
80 | train_dataset = build_dataset_from_numerical_data((x, y))
81 | val_dataset = build_dataset_from_numerical_data((x, y))
82 | vanilla.fit(train_dataset, val_dataset)
83 | res = vanilla.predict(x)
84 | self.assertTrue(res.shape == (x.shape[0], vanilla.n_members, 1))
85 | vanilla.evaluate(x, y)
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
--------------------------------------------------------------------------------
/test/mc_dropout_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | import tempfile
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from sklearn.datasets import load_breast_cancer
8 |
9 | from core.ensemble.mc_dropout import MCDropout
10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
11 |
12 |
13 | class MyTestCaseMCDropout(parameterized.TestCase):
14 | def setUp(self):
15 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
16 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
17 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
18 |
19 | def test_dnn(self):
20 | with tempfile.TemporaryDirectory() as output_dir:
21 | mcdropout = MCDropout(architecture_type='dnn',
22 | model_directory=output_dir)
23 | mcdropout.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1])
24 |
25 | res = mcdropout.predict(self.x_np)
26 | self.assertEqual(mcdropout.get_n_members(), mcdropout.n_members)
27 | self.assertTrue(res.shape == (self.x_np.shape[0], mcdropout.hparam.n_sampling, 1))
28 |
29 | mcdropout.evaluate(self.x_np, self.y_np)
30 |
31 | def test_textcnn(self):
32 | with tempfile.TemporaryDirectory() as output_dir:
33 | mcdropout = MCDropout(architecture_type='text_cnn',
34 | model_directory=output_dir)
35 | x = np.random.randint(0, 256, (10, 10))
36 | y = np.random.choice(2, 10)
37 | train_dataset = build_dataset_from_numerical_data((x, y))
38 | val_dataset = build_dataset_from_numerical_data((x, y))
39 | mcdropout.fit(train_dataset, val_dataset)
40 | res = mcdropout.predict(x)
41 | self.assertTrue(res.shape == (x.shape[0], mcdropout.hparam.n_sampling, 1))
42 | mcdropout.evaluate(x, y)
43 |
44 | def test_multimodalitynn(self):
45 | with tempfile.TemporaryDirectory() as output_dir:
46 | mcdropout = MCDropout(architecture_type='multimodalitynn',
47 | model_directory=output_dir)
48 | x = [self.x_np] * 5
49 | train_data = build_dataset_from_numerical_data(tuple(x))
50 | train_y = build_dataset_from_numerical_data(self.y_np)
51 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
52 | val_data = build_dataset_from_numerical_data(tuple(x))
53 | val_y = build_dataset_from_numerical_data(self.y_np)
54 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
55 | mcdropout.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]] * 5)
56 | res = mcdropout.predict(x)
57 | self.assertTrue(res.shape == (self.x_np.shape[0], mcdropout.hparam.n_sampling, 1))
58 | mcdropout.evaluate(x, self.y_np)
59 |
60 | def test_r2d2(self):
61 | with tempfile.TemporaryDirectory() as output_dir:
62 | mcdropout = MCDropout(architecture_type='r2d2',
63 | model_directory=output_dir)
64 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
65 | y = np.random.choice(2, 10)
66 | train_dataset = build_dataset_from_numerical_data((x, y))
67 | val_dataset = build_dataset_from_numerical_data((x, y))
68 | mcdropout.fit(train_dataset, val_dataset, input_dim=(299, 299, 3))
69 | res = mcdropout.predict(x)
70 | self.assertTrue(res.shape == (x.shape[0], mcdropout.hparam.n_sampling, 1))
71 | mcdropout.evaluate(x, y)
72 |
73 | def test_droidectc(self):
74 | with tempfile.TemporaryDirectory() as output_dir:
75 | mcdropout = MCDropout(architecture_type='droidectc',
76 | model_directory=output_dir)
77 |
78 | x = np.random.randint(0, 10000, size=(10, 1000))
79 | y = np.random.choice(2, 10)
80 | train_dataset = build_dataset_from_numerical_data((x, y))
81 | val_dataset = build_dataset_from_numerical_data((x, y))
82 | mcdropout.fit(train_dataset, val_dataset)
83 | res = mcdropout.predict(x)
84 | self.assertTrue(res.shape == (x.shape[0], mcdropout.hparam.n_sampling, 1))
85 | mcdropout.evaluate(x, y)
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
--------------------------------------------------------------------------------
/core/ensemble/model_hp.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | _TRAIN_HP_TEMPLATE = namedtuple('training',
4 | ['random_seed', 'n_epochs', 'batch_size', 'learning_rate', 'clipvalue', 'interval'])
5 | train_hparam = _TRAIN_HP_TEMPLATE(random_seed=23456,
6 | n_epochs=30,
7 | batch_size=16,
8 | learning_rate=0.001,
9 | clipvalue=100.,
10 | interval=2 # saving model weights
11 | )
12 |
13 | # these hyper-parameters are used for fine-tuning the r2d2 model, which is deprecated owing to its limited effectiveness
14 | _FINETUNE_TRAIN_HP_TEMPLATE = namedtuple('finetuning',
15 | ['random_seed', 'n_epochs', 'batch_size', 'learning_rate', 'n_epochs_ft',
16 | 'learning_rate_ft', 'unfreezed_layers', 'clipvalue', 'interval'])
17 | finetuning_hparam = _FINETUNE_TRAIN_HP_TEMPLATE(random_seed=23456,
18 | n_epochs=10,
19 | batch_size=16,
20 | learning_rate=0.0001,
21 | n_epochs_ft=20,
22 | learning_rate_ft=0.00001,
23 | unfreezed_layers=3,
24 | clipvalue=100.,
25 | interval=2 # saving model weights
26 | )
27 |
28 | _MC_DROPOUT_HP_TEMPLATE = namedtuple('mc_dropout', ['dropout_rate', 'n_sampling'])
29 | mc_dropout_hparam = _MC_DROPOUT_HP_TEMPLATE(dropout_rate=0.4,
30 | n_sampling=10
31 | )
32 | _BAYESIAN_HP_TEMPLATE = namedtuple('bayesian', ['n_sampling'])
33 | bayesian_ensemble_hparam = _BAYESIAN_HP_TEMPLATE(n_sampling=10)
34 |
35 | _DNN_HP_TEMPLATE = namedtuple('DNN',
36 | ['hidden_units', 'dropout_rate', 'activation', 'output_dim'])
37 |
38 | dnn_hparam = _DNN_HP_TEMPLATE(hidden_units=[200, 200], # DNN has two hidden layers with each having 200 neurons
39 | dropout_rate=0.4,
40 | activation='relu',
41 |                               output_dim=1  # binary classification
42 | )
43 |
44 | _TEXT_CNN_HP_TEMPLATE = namedtuple('textCNN',
45 | ['hidden_units', 'dropout_rate', 'activation', 'output_dim',
46 | 'vocab_size', 'n_embedding_dim', 'n_conv_filters', 'kernel_size',
47 | 'max_sequence_length',
48 | 'use_spatial_dropout', 'use_conv_dropout'])
49 |
50 | text_cnn_hparam = _TEXT_CNN_HP_TEMPLATE(hidden_units=[200, 200],
51 | dropout_rate=0.4,
52 | activation='relu',
53 | vocab_size=256,
54 | n_embedding_dim=8,
55 | n_conv_filters=64,
56 | kernel_size=8,
57 | max_sequence_length=700000, # shall be large for promoting accuracy
58 | use_spatial_dropout=False,
59 | use_conv_dropout=False,
60 | output_dim=1
61 | )
62 |
63 | _MULTIMOD_HP_TEMPLATE = namedtuple('multimodalitynn',
64 | ['hidden_units', 'dropout_rate', 'activation', 'output_dim',
65 | 'n_modalities', 'initial_hidden_units',
66 | ])
67 |
68 | multimodalitynn_hparam = _MULTIMOD_HP_TEMPLATE(
69 | hidden_units=[200, 200],
70 | dropout_rate=0.4,
71 | activation='relu',
72 | n_modalities=5,
73 | initial_hidden_units=[500, 500],
74 | output_dim=1
75 | )
76 |
77 | _R2D2_HP_TEMPLATE = namedtuple('r2d2',
78 | ['use_small_model', 'trainable', 'unfreezed_layers', 'dropout_rate', 'output_dim'])
79 |
80 | r2d2_hparam = _R2D2_HP_TEMPLATE(
81 | use_small_model=False, # if True : return MobileNetV2 else InceptionV3
82 | trainable=False, # if False : the learned parameters of InceptionV3 or MobileNetV2 will be frozen
83 | unfreezed_layers=8,
84 | dropout_rate=0.4,
85 | output_dim=1
86 | )
87 |
88 | _DROIDETEC_HP_TEMPLATE = namedtuple('droidetec',
89 | ['vocab_size', 'n_embedding_dim', 'lstm_units', 'hidden_units',
90 | 'dropout_rate', 'max_sequence_length', 'output_dim'])
91 |
92 | droidetec_hparam = _DROIDETEC_HP_TEMPLATE(
93 | vocab_size=100000, # owing to the GPU memory size, we set 100,000
94 | n_embedding_dim=8,
95 | lstm_units=64,
96 | hidden_units=[200],
97 | dropout_rate=0.4,
98 | max_sequence_length=1000000,
99 | output_dim=1
100 | )
101 |
--------------------------------------------------------------------------------
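Note: the hyper-parameter namedtuples defined above are combined with `train_hparam` inside the ensemble classes via `tools.utils.merge_namedtuples` (e.g., `utils.merge_namedtuples(train_hparam, bayesian_ensemble_hparam)` in `core/ensemble/bayesian_ensemble.py`). The snippet below is only a rough sketch of what such a merge presumably does, not the repository's implementation:

```python
# Sketch only (assumption): combine the fields of two hyper-parameter namedtuples into
# one object, so that e.g. hparam.learning_rate (from train_hparam) and hparam.n_sampling
# (from bayesian_ensemble_hparam) can be read from the same instance.
from collections import namedtuple

def merge_namedtuples_sketch(nt_a, nt_b):
    fields = list(nt_a._fields) + [f for f in nt_b._fields if f not in nt_a._fields]
    Merged = namedtuple('merged_hparam', fields)
    return Merged(**{**nt_a._asdict(), **nt_b._asdict()})
```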
/core/feature/apiseq/apiseq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import warnings
4 |
5 | from collections import defaultdict
6 | from androguard.misc import AnalyzeAPK
7 | from tools import utils
8 | from config import logging
9 |
10 | current_dir = os.path.dirname(os.path.realpath(__file__))
11 | logger = logging.getLogger("feature.apiseq")
12 |
13 | REMOVE_CLASS_HEAD_LIST = [
14 | 'Ljava/', 'Ljavax/'
15 | ]
16 |
17 | RETAIN_CLASS_HEAD_LIST = [
18 | 'Landroid/',
19 | 'Lcom/android/internal/util/',
20 | 'Ldalvik/',
21 | 'Lorg/apache/',
22 | 'Lorg/json/',
23 | 'Lorg/w3c/dom/',
24 | 'Lorg/xml/sax',
25 | 'Lorg/xmlpull/v1/',
26 | 'Ljunit/'
27 | ]
28 |
29 |
30 | def _check_class(class_name):
31 | for cls_head in REMOVE_CLASS_HEAD_LIST:
32 | if class_name.startswith(cls_head):
33 | return False
34 | for cls_head in RETAIN_CLASS_HEAD_LIST:
35 | if class_name.startswith(cls_head):
36 | return True
37 |
38 | return False
39 |
40 |
41 | def _dfs(api, nodes, seq, visited):  # callers always pass fresh lists; avoid mutable default arguments
42 | if api not in nodes.keys():
43 | seq.append(api)
44 | else:
45 | visited.append(api)
46 | for elem in nodes[api]:
47 | if elem in visited:
48 | seq.append(elem)
49 | else:
50 | _dfs(elem, nodes, seq, visited)
51 |
52 |
53 | def get_api_sequence(apk_path, save_path):
54 | """
55 | produce an api call sequence for an apk
56 | :param apk_path: an apk path
57 | :param save_path: path for saving resulting feature
58 |     :return: the save_path on success, or the raised exception otherwise
59 | """
60 | try:
61 |         # obtain two dictionaries: xref_from and xref_to, whose keys are class-method names
62 |         # and whose values are the callers or callees.
63 | _, _, dx = AnalyzeAPK(apk_path)
64 | mth_callers = defaultdict(list)
65 | mth_callees = defaultdict(list)
66 | for cls_obj in dx.get_classes(): # ClassAnalysis
67 | if cls_obj.is_external():
68 | continue
69 | cls_name = cls_obj.name
70 | for mth_obj in cls_obj.get_methods():
71 | if mth_obj.is_external():
72 | continue
73 |
74 | m = mth_obj.get_method() # dvm.EncodedMethod
75 | cls_mth_name = cls_name + '->' + m.name + m.proto
76 | # get callers
77 | mth_callers[cls_mth_name] = []
78 | for _, call, _ in mth_obj.get_xref_from():
79 | if _check_class(call.class_name):
80 | mth_callers[cls_mth_name].append(call.class_name + '->' + call.name + call.proto)
81 | # get callees sequentially
82 | for instruction in m.get_instructions():
83 | opcode = instruction.get_name()
84 | if 'invoke-' in opcode:
85 | code_body = instruction.get_output()
86 | if '->' not in code_body:
87 | continue
88 | head_part, rear_part = code_body.split('->')
89 | class_name = head_part.strip().split(' ')[-1]
90 | mth_name_callee = class_name + '->' + rear_part
91 | if _check_class(mth_name_callee):
92 | mth_callees[cls_mth_name].append(mth_name_callee)
93 |
94 | # look for the root call
95 | root_calls = []
96 | num_of_calls = len(mth_callers.items())
97 | if num_of_calls == 0:
98 | raise ValueError("No callers")
99 | for k in mth_callers.keys():
100 | if (len(mth_callers[k]) <= 0) and (len(mth_callees[k]) > 0):
101 | root_calls.append(k)
102 |
103 | if len(root_calls) == 0:
104 | warnings.warn("Cannot find a root call, instead, randomly pick up one.")
105 | import random
106 |             idx = random.choice(range(num_of_calls))
107 |             root_calls.append(list(mth_callers.keys())[idx])
108 |
109 | # generate sequence
110 | api_sequence = []
111 | for root_call in root_calls:
112 | sub_seq = []
113 | visited_nodes = []
114 | _dfs(root_call, mth_callees, sub_seq, visited_nodes)
115 | api_sequence.extend(sub_seq)
116 | # dump feature
117 | utils.dump_txt('\n'.join(api_sequence), save_path)
118 | return save_path
119 | except Exception as e:
120 | if len(e.args) > 0:
121 | e.args = e.args + (apk_path,)
122 | return e
123 |
124 |
125 | def load_feature(save_path):
126 | return utils.read_txt(save_path)
127 |
128 |
129 | def wrapper_load_feature(save_path):
130 | try:
131 | return load_feature(save_path)
132 | except Exception as e:
133 | return e
134 |
135 |
136 | def _mapping(path_to_feature, dictionary):
137 | features = load_feature(path_to_feature)
138 | _feature = [idx for idx in list(map(dictionary.get, features)) if idx is not None]
139 | if len(_feature) == 0:
140 | _feature = [dictionary.get('sos')]
141 | warnings.warn("Produce zero feature vector.")
142 | return _feature
143 |
144 |
145 | def wrapper_mapping(ptuple):
146 | try:
147 | return _mapping(*ptuple)
148 | except Exception as e:
149 | return e
150 |
--------------------------------------------------------------------------------
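To illustrate how `_dfs` in `apiseq.py` above flattens the callee graph into an API call sequence, here is a small self-contained walk-through; the call graph and method names are made up purely for illustration, and the `_dfs` body mirrors the function defined above:

```python
from collections import defaultdict

def _dfs(api, nodes, seq, visited):
    # same logic as core.feature.apiseq.apiseq._dfs
    if api not in nodes:
        seq.append(api)            # leaf call: emit it
    else:
        visited.append(api)        # mark as visited to guard against cycles
        for elem in nodes[api]:
            if elem in visited:
                seq.append(elem)   # back-edge: emit the already-visited callee
            else:
                _dfs(elem, nodes, seq, visited)

# toy callee graph (hypothetical method names)
mth_callees = defaultdict(list, {
    'Lcom/example/App;->onCreate()V': ['Landroid/util/Log;->d(...)I',
                                       'Lcom/example/App;->work()V'],
    'Lcom/example/App;->work()V': ['Lorg/apache/http/HttpGet;-><init>()V'],
})
seq, visited = [], []
_dfs('Lcom/example/App;->onCreate()V', mth_callees, seq, visited)
print(seq)  # ['Landroid/util/Log;->d(...)I', 'Lorg/apache/http/HttpGet;-><init>()V']
```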
/test/bayesian_ensemble_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | import tempfile
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | from sklearn.datasets import load_breast_cancer
8 |
9 | from core.ensemble.bayesian_ensemble import BayesianEnsemble
10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
11 |
12 | architectures = ['dnn', 'text_cnn', 'multimodalitynn', 'r2d2', 'droidectc']
13 |
14 |
15 | class MyTestCaseBayesianEnsemble(parameterized.TestCase):
16 | def setUp(self):
17 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
18 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
19 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
20 |
21 | def test_dnn(self):
22 | with tempfile.TemporaryDirectory() as output_dir:
23 | bayesian_ensemble = BayesianEnsemble(architecture_type='dnn',
24 | model_directory=output_dir)
25 | bayesian_ensemble.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1])
26 |
27 | res = bayesian_ensemble.predict(self.x_np)
28 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members)
29 | self.assertTrue(res.shape == (self.x_np.shape[0], bayesian_ensemble.hparam.n_sampling, 1))
30 |
31 | bayesian_ensemble.evaluate(self.x_np, self.y_np)
32 |
33 | def test_textcnn(self):
34 | with tempfile.TemporaryDirectory() as output_dir:
35 | x = np.random.randint(0, 256, (10, 10))
36 | y = np.random.choice(2, 10)
37 | train_dataset = build_dataset_from_numerical_data((x, y))
38 | val_dataset = build_dataset_from_numerical_data((x, y))
39 | bayesian_ensemble = BayesianEnsemble(architecture_type='text_cnn',
40 | model_directory=output_dir)
41 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=x.shape[1:])
42 | res = bayesian_ensemble.predict(x)
43 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members)
44 |             self.assertTrue(res.shape == (x.shape[0], bayesian_ensemble.hparam.n_sampling, 1))
45 |
46 | bayesian_ensemble.evaluate(x, y)
47 |
48 | def test_multimodalitynn(self):
49 | with tempfile.TemporaryDirectory() as output_dir:
50 | x = [self.x_np] * 5
51 | train_data = build_dataset_from_numerical_data(tuple(x))
52 | train_y = build_dataset_from_numerical_data(self.y_np)
53 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
54 | val_data = build_dataset_from_numerical_data(tuple(x))
55 | val_y = build_dataset_from_numerical_data(self.y_np)
56 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
57 | bayesian_ensemble = BayesianEnsemble(architecture_type='multimodalitynn',
58 | model_directory=output_dir)
59 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=[x[i].shape[1] for i in range(len(x))])
60 |
61 | res = bayesian_ensemble.predict(x)
62 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members)
63 | self.assertTrue(res.shape == (x[0].shape[0], bayesian_ensemble.hparam.n_sampling, 1))
64 |
65 | bayesian_ensemble.evaluate(x, self.y_np)
66 |
67 | def test_r2d2(self):
68 | with tempfile.TemporaryDirectory() as output_dir:
69 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
70 | y = np.random.choice(2, 10)
71 | train_dataset = build_dataset_from_numerical_data((x, y))
72 | val_dataset = build_dataset_from_numerical_data((x, y))
73 |
74 | bayesian_ensemble = BayesianEnsemble(architecture_type='r2d2',
75 | model_directory=output_dir)
76 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=x.shape[1:])
77 |
78 | res = bayesian_ensemble.predict(x)
79 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members)
80 | self.assertTrue(res.shape == (x.shape[0], bayesian_ensemble.hparam.n_sampling, 1))
81 |
82 | bayesian_ensemble.evaluate(x, y)
83 |
84 | def test_droidectc(self):
85 | with tempfile.TemporaryDirectory() as output_dir:
86 | x = np.random.randint(0, 10000, size=(10, 1000))
87 | y = np.random.choice(2, 10)
88 | train_dataset = build_dataset_from_numerical_data((x, y))
89 | val_dataset = build_dataset_from_numerical_data((x, y))
90 | bayesian_ensemble = BayesianEnsemble(architecture_type='droidectc',
91 | model_directory=output_dir)
92 | bayesian_ensemble.fit(train_dataset, val_dataset, input_dim=x.shape[1:])
93 |
94 | res = bayesian_ensemble.predict(x)
95 | self.assertEqual(bayesian_ensemble.get_n_members(), bayesian_ensemble.n_members)
96 |             self.assertTrue(res.shape == (x.shape[0], bayesian_ensemble.hparam.n_sampling, 1))
97 |
98 | bayesian_ensemble.evaluate(x, y)
99 |
100 |
101 | if __name__ == '__main__':
102 | absltest.main()
103 |
--------------------------------------------------------------------------------
/experiments/oos.py:
--------------------------------------------------------------------------------
1 | # conduct the group of 'out of source' experiments using models trained on the drebin dataset
2 | import os
3 | import sys
4 | import random
5 | from collections import Counter
6 |
7 | import numpy as np
8 | from sklearn.model_selection import train_test_split
9 |
10 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture
11 | from core.ensemble import ensemble_method_scope_dict
12 | from core.post_calibration.temperature_scaling import apply_temperature_scaling
13 | from tools import utils
14 | from config import config, logging
15 |
16 | logger = logging.getLogger('experiment.drebin_ood')
17 |
18 |
19 | # procedure of out-of-source (oos) experiments
20 | # 1. build dataset
21 | # 2. preprocess data
22 | # 3. conduct prediction
23 | # 4. save results for statistical analysis
24 |
25 |
26 | def run_experiment(feature_type, ensemble_type, proc_numbers=2):
27 | """
28 | run this group of experiments
29 |     :param feature_type: the type of features (e.g., drebin, opcode, etc.); the feature type determines the model architecture
30 |     :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc.)
31 | :return: None
32 | """
33 | ood_data, ood_y, input_dim = data_preprocessing(feature_type, proc_numbers)
34 |
35 | ensemble_obj = get_ensemble_object(ensemble_type)
36 | # instantiation
37 | arch_type = feature_type_vs_architecture.get(feature_type)
38 | model_saving_dir = config.get('experiments', 'drebin')
39 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']:
40 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
41 | else:
42 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
43 |
44 | ood_results = ensemble_model.predict(ood_data)
45 | utils.dump_joblib((ood_results, ood_y), os.path.join(config.get('experiments', 'oos'),
46 | '{}_{}_drebin_oos.res'.format(feature_type, ensemble_type)))
47 |
48 |
49 | def run_temperature_scaling(feature_type, ensemble_type, proc_numbers=2):
50 | ood_data, ood_y, input_dim = data_preprocessing(feature_type, proc_numbers)
51 |
52 | ensemble_obj = get_ensemble_object(ensemble_type)
53 | # instantiation
54 | arch_type = feature_type_vs_architecture.get(feature_type)
55 | model_saving_dir = config.get('experiments', 'drebin')
56 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
57 | # temperature scaling
58 |
59 | temp_save_dir = os.path.join(config.get('drebin', 'intermediate_directory'),
60 | "{}_{}_temp.json".format(feature_type, ensemble_type))
61 | if not os.path.exists(temp_save_dir):
62 | raise FileNotFoundError
63 |
64 | temperature = utils.load_json(temp_save_dir)['temperature']
65 | probs = ensemble_model.predict(ood_data, use_prob=True)
66 | probs_scaling = apply_temperature_scaling(temperature, probs)
67 | utils.dump_joblib((probs_scaling, ood_y), os.path.join(config.get('experiments', 'oos'),
68 | '{}_{}_temperature_drebin_oos.res'.format(feature_type, ensemble_type)))
69 |
70 |
71 | def data_preprocessing(feature_type='drebin', proc_numbers=2):
72 |     assert feature_type in feature_type_scope_dict.keys(), 'Expected one of {}, but got {}.'.format(
73 |         feature_type_scope_dict.keys(), feature_type)
74 | benware_dir = config.get('oos', 'benware_dir')
75 | malware_dir = config.get('oos', 'malware_dir')
76 | android_features_saving_dir = config.get('metadata', 'naive_data_pool')
77 | intermediate_data_saving_dir = config.get('oos', 'intermediate_directory')
78 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir,
79 | intermediate_data_saving_dir,
80 | update=False,
81 | proc_number=proc_numbers)
82 |
83 | save_path = os.path.join(intermediate_data_saving_dir, 'oos_database.' + feature_type)
84 | if os.path.exists(save_path):
85 | oos_filenames, oos_y = utils.read_joblib(save_path)
86 | oos_features = [os.path.join(android_features_saving_dir, filename) for filename in oos_filenames]
87 | else:
88 | mal_feature_list = feature_extractor.feature_extraction(malware_dir)
89 | n_malware = len(mal_feature_list)
90 | ben_feature_list = feature_extractor.feature_extraction(benware_dir)
91 | n_benware = len(ben_feature_list)
92 | oos_features = mal_feature_list + ben_feature_list
93 | oos_y = np.zeros((n_malware + n_benware,), dtype=np.int32)
94 | oos_y[:n_malware] = 1
95 |
96 | oos_filenames = [os.path.basename(path) for path in oos_features]
97 | utils.dump_joblib((oos_filenames, oos_y), save_path)
98 |
99 | # obtain data in a format for ML algorithms
100 | ood_data, input_dim = feature_extractor.feature2ipt(oos_features)
101 | return ood_data, oos_y, input_dim
102 |
103 |
104 | def get_ensemble_object(ensemble_type):
105 |     assert ensemble_type in ensemble_method_scope_dict.keys(), 'Expected one of {}, but got {}.'.format(
106 |         ','.join(ensemble_method_scope_dict.keys()),
107 |         ensemble_type
108 |     )
109 | return ensemble_method_scope_dict[ensemble_type]
110 |
--------------------------------------------------------------------------------
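Background for `run_temperature_scaling` above: temperature scaling divides the logit by a scalar temperature T (fitted elsewhere and stored in `{feature_type}_{ensemble_type}_temp.json`) before re-applying the sigmoid, so T > 1 softens over-confident probabilities. The snippet below is a self-contained sketch of this idea; it is not the repository's `apply_temperature_scaling`, whose actual implementation lives in `core/post_calibration/temperature_scaling.py`:

```python
import numpy as np

def temperature_scale_sketch(probs, temperature):
    # recover the logit from a sigmoid probability, divide by T, re-apply the sigmoid
    probs = np.clip(probs, 1e-8, 1.0 - 1e-8)
    logits = np.log(probs / (1.0 - probs))
    return 1.0 / (1.0 + np.exp(-logits / temperature))
```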
/core/ensemble/anchor_ensemble.py:
--------------------------------------------------------------------------------
1 | import os.path as path
2 | import time
3 |
4 | import tensorflow as tf
5 | import numpy as np
6 |
7 | from core.ensemble.vanilla import Vanilla
8 | from core.ensemble.model_hp import train_hparam, anchor_hparam
9 | from core.ensemble.model_lib import model_builder
10 | from tools import utils
11 | from config import logging
12 |
13 | logger = logging.getLogger('ensemble.anchor_ensemble')
14 |
15 |
16 | class AnchorEnsemble(Vanilla):
17 | def __init__(self,
18 | architecture_type='dnn',
19 | base_model=None,
20 | n_members=2,
21 | model_directory=None,
22 | name='ANCHOR'):
23 | """
24 | initialization
25 | :param architecture_type: the type of base model
26 | :param base_model: an object of base model
27 | :param n_members: number of base models
28 | :param model_directory: a folder for saving ensemble weights
29 | """
30 | super(AnchorEnsemble, self).__init__(architecture_type, base_model, n_members, model_directory)
31 | self.hparam = utils.merge_namedtuples(train_hparam, anchor_hparam)
32 | self.ensemble_type = 'anchor'
33 | self.name = name.lower()
34 | self.save_dir = path.join(self.model_directory, self.name)
35 |
36 | def build_model(self, input_dim=None):
37 | """
38 | Build an ensemble model -- only the homogeneous structure is considered
39 | :param input_dim: integer or list, input dimension shall be set in some cases under eager mode
40 | """
41 | callable_graph = model_builder(self.architecture_type)
42 |
43 | @callable_graph(input_dim)
44 | def _builder():
45 | seed = np.random.choice(self.hparam.random_seed)
46 | return utils.produce_layer(self.ensemble_type,
47 | scale=self.hparam.scale,
48 | batch_size=self.hparam.batch_size,
49 | seed=seed)
50 |
51 | self.base_model = _builder()
52 |
53 | def model_generator(self):
54 | try:
55 | for m in range(self.n_members):
56 | self.base_model = None
57 | self.load_ensemble_weights(m)
58 | yield self.base_model
59 | except Exception as e:
60 | raise Exception("Cannot load model:{}.".format(str(e)))
61 |
62 | def fit(self, train_set, validation_set=None, input_dim=None, **kwargs):
63 | """
64 |         fit the ensemble by producing a list of model weights
65 | :param train_set: tf.data.Dataset, the type shall accommodate to the input format of Tensorflow models
66 | :param validation_set: validation data, optional
67 | :param input_dim: integer or list, input dimension except for the batch size
68 | """
69 | # training
70 | logger.info("hyper-parameters:")
71 | logger.info(dict(self.hparam._asdict()))
72 | logger.info("...training start!")
73 | np.random.seed(self.hparam.random_seed)
74 | train_set = train_set.shuffle(buffer_size=100, reshuffle_each_iteration=True)
75 | for member_idx in range(self.n_members):
76 | self.base_model = None
77 | self.build_model(input_dim=input_dim)
78 |
79 | self.base_model.compile(
80 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate),
81 | loss=tf.keras.losses.BinaryCrossentropy(),
82 | metrics=[tf.keras.metrics.BinaryAccuracy()],
83 | )
84 | for epoch in range(self.hparam.n_epochs):
85 | total_time = 0.
86 | msg = 'Epoch {}/{}, and member {}/{}'.format(epoch + 1,
87 | self.hparam.n_epochs, member_idx + 1,
88 | self.n_members)
89 | print(msg)
90 | start_time = time.time()
91 | self.base_model.fit(train_set,
92 | epochs=epoch + 1,
93 | initial_epoch=epoch,
94 | validation_data=validation_set
95 | )
96 | end_time = time.time()
97 | total_time += end_time - start_time
98 | # saving
99 | logger.info('Training ensemble costs {} seconds at this epoch'.format(total_time))
100 | if (epoch + 1) % self.hparam.interval == 0:
101 | self.save_ensemble_weights(member_idx)
102 |
103 | def save_ensemble_weights(self, member_idx=0):
104 | if not path.exists(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))):
105 | utils.mkdir(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx)))
106 | # save model configuration
107 | self.base_model.save(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx)))
108 | print("Save the model to directory {}".format(self.save_dir))
109 |
110 | def load_ensemble_weights(self, member_idx=0):
111 | if path.exists(path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx))):
112 | self.base_model = tf.keras.models.load_model(
113 | path.join(self.save_dir, self.architecture_type + '_{}'.format(member_idx)))
114 |
115 | def get_n_members(self):
116 | return self.n_members
117 |
118 | def update_weights(self, member_idx, model_weights, optimizer_weights=None):
119 | raise NotImplementedError
120 |
--------------------------------------------------------------------------------
/core/ensemble/bayesian_ensemble.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import tensorflow as tf
4 |
5 | from core.ensemble.vanilla import model_builder
6 | from core.ensemble.mc_dropout import MCDropout
7 | from core.ensemble.model_hp import train_hparam, bayesian_ensemble_hparam
8 | from config import logging, ErrorHandler
9 | from tools import utils
10 |
11 | logger = logging.getLogger('ensemble.bayesian_ensemble')
12 | logger.addHandler(ErrorHandler)
13 |
14 |
15 | class BayesianEnsemble(MCDropout):
16 | def __init__(self,
17 | architecture_type='dnn',
18 | base_model=None,
19 | n_members=1,
20 | model_directory=None,
21 | name='BAYESIAN_ENSEMBLE'
22 | ):
23 | super(BayesianEnsemble, self).__init__(architecture_type,
24 | base_model,
25 | n_members,
26 | model_directory,
27 | name)
28 | self.hparam = utils.merge_namedtuples(train_hparam, bayesian_ensemble_hparam)
29 | self.ensemble_type = 'bayesian'
30 |
31 | def build_model(self, input_dim=None, scaler=1. / 10000):
32 | """
33 | Build an ensemble model -- only the homogeneous structure is considered
34 | :param input_dim: integer or list, input dimension shall be set in some cases under eager mode
35 |         :param scaler: float value in the range of [0, 1], the weight on the KL-divergence term
36 | """
37 | callable_graph = model_builder(self.architecture_type)
38 |
39 | @callable_graph(input_dim)
40 | def _builder():
41 | return utils.produce_layer(self.ensemble_type, kl_scaler=scaler)
42 |
43 | self.base_model = _builder()
44 | return
45 |
46 | def fit(self, train_set, validation_set=None, input_dim=None, **kwargs):
47 | """
48 |         fit the ensemble by producing a list of model weights
49 | :param train_set: tf.data.Dataset, the type shall accommodate to the input format of Tensorflow models
50 | :param validation_set: validation data, optional
51 | :param input_dim: integer or list, input dimension except for the batch size
52 | """
53 |
54 | # training preparation
55 | if self.base_model is None:
56 | # scaler = 1. / (len(list(train_set)) * self.hparam.batch_size) # time-consuming
57 | scaler = 1. / 50000.
58 | self.build_model(input_dim=input_dim, scaler=scaler)
59 |
60 | self.base_model.compile(
61 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate,
62 | clipvalue=self.hparam.clipvalue),
63 | loss=tf.keras.losses.BinaryCrossentropy(),
64 | metrics=[tf.keras.metrics.BinaryAccuracy()],
65 | experimental_run_tf_function=False
66 | )
67 |
68 | # training
69 | logger.info("hyper-parameters:")
70 | logger.info(dict(self.hparam._asdict()))
71 | logger.info("...training start!")
72 |
73 | best_val_accuracy = 0.
74 | total_time = 0.
75 | for epoch in range(self.hparam.n_epochs):
76 | train_acc = 0.
77 | val_acc = 0.
78 |
79 | for member_idx in range(self.n_members):
80 | if member_idx < len(self.weights_list): # loading former weights
81 | self.base_model.set_weights(self.weights_list[member_idx])
82 | self.base_model.optimizer.set_weights(self._optimizers_dict[member_idx])
83 | elif member_idx == 0:
84 | pass # do nothing
85 | else:
86 | self.reinitialize_base_model()
87 |
88 | msg = 'Epoch {}/{}, member {}/{}, and {} member(s) in list'.format(epoch + 1,
89 | self.hparam.n_epochs,
90 | member_idx + 1,
91 | self.n_members,
92 | len(self.weights_list))
93 | print(msg)
94 | start_time = time.time()
95 | history = self.base_model.fit(train_set,
96 | epochs=epoch + 1,
97 | initial_epoch=epoch,
98 | validation_data=validation_set
99 | )
100 | train_acc += history.history['binary_accuracy'][0]
101 | val_acc += history.history['val_binary_accuracy'][0]
102 | self.update_weights(member_idx,
103 | self.base_model.get_weights(),
104 | self.base_model.optimizer.get_weights())
105 | end_time = time.time()
106 | total_time += end_time - start_time
107 | # saving
108 |             logger.info('Training ensemble costs {} seconds in total (including validation).'.format(total_time))
109 | train_acc = train_acc / self.n_members
110 | val_acc = val_acc / self.n_members
111 | msg = 'Epoch {}/{}: training accuracy {:.5f}, validation accuracy {:.5f}.'.format(
112 | epoch + 1, self.hparam.n_epochs, train_acc, val_acc
113 | )
114 | logger.info(msg)
115 | if (epoch + 1) % self.hparam.interval == 0:
116 | if val_acc >= best_val_accuracy:
117 | self.save_ensemble_weights()
118 | best_val_accuracy = val_acc
119 | msg = '\t The best validation accuracy is {:.5f}, obtained at epoch {}/{}'.format(
120 | best_val_accuracy, epoch + 1, self.hparam.n_epochs
121 | )
122 | logger.info(msg)
123 | return
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Uncertainty Quantification for Android Malware Detectors
2 |
3 | This code repository is for our ACSAC 2021 paper (to appear), entitled **Can We Leverage Predictive Uncertainty to Detect Dataset Shift and Adversarial Examples in Android Malware Detection?**.
4 |
5 | ## Overview
6 | Our aim is to explore uncertainty quantification for hardening malware detectors in realistic environments (i.e., where natural adversaries exist).
7 | This approach is rarely investigated in the context of malware detection, where the properties of dataset shift differ from those in other domains (e.g., images).
8 | We are therefore motivated to evaluate the quality of the predictive uncertainties inherent in malware detectors under dataset shift.
9 | Specifically, we consider 4 Android malware detectors, including [DeepDrebin](http://patrickmcdaniel.org/pubs/esorics17.pdf), [MultimodalNN](https://ieeexplore.ieee.org/document/8443370), [DeepDroid](https://dl.acm.org/doi/10.1145/3029806.3029823) and [Droidetec](https://arxiv.org/abs/2002.03594), and 6 calibration methods, including [Vanilla](https://arxiv.org/abs/1906.02530), [Temp scaling](https://arxiv.org/abs/1706.04599), [Monte Carlo dropout](https://arxiv.org/abs/1506.02142), [variational Bayesian inference](https://arxiv.org/abs/1906.02530), [Deep Ensemble](https://arxiv.org/abs/1612.01474) and Weighted deep ensemble.
10 | The dataset shift is specified as *out of source*, *temporal covariate shift* or *adversarial evasion attacks*.
11 |
12 | ## Dependencies:
13 | We developed the code on the Windows operating system and run it on Ubuntu 18.04. The code depends on Python 3.6; other packages (e.g., TensorFlow) can be found in `./requirements.txt`.
14 |
15 | ## Configuration & Usage
16 | #### 1. Datasets
17 | * Three datasets are leveraged, namely [Drebin](https://www.sec.cs.tu-bs.de/~danarp/drebin/), [VirusShare_Android_APK_2013](https://virusshare.com/) and [Androzoo](https://androzoo.uni.lu/).
18 | Note that, for security reasons, obtaining the Android applications requires following the policies of each of these three datasets.
19 |
20 | For Drebin, the malicious APKs can be downloaded from the official website, and we provide sha256 codes for a portion of the Drebin benign APKs, whose corresponding APKs can be downloaded from [Androzoo](https://androzoo.uni.lu/).
21 |
22 | For Androzoo, we use the dataset built by researchers [Pendlebury et al.](https://www.usenix.org/conference/usenixsecurity19/presentation/pendlebury) All APKs can be downloaded from Androzoo.
23 |
24 | For Virusshare, we use the file named `VirusShare_Android_APK_2013.zip`.
25 |
26 | For adversarial APKs, we resort to this [repository](https://github.com/deqangss/adv-dnn-ens-malware).
27 |
28 | * We additionally provide the preprocessed data files, which are available at an anonymous [url](https://mega.nz/folder/bF8RQAAI#HeIhpUzDdqCdWdh4bAIZbg) (the size of the unzipped folder is ~213GB).
29 |
30 | #### 2. Configure
31 | For convenience, we provide a `conf` (Windows) / `conf-server` (Ubuntu) file to assist the customization (please pick one and rename it to `config`; do not use both). Before running, make the following changes (a sketch of the relevant entries appears at the end of this subsection):
32 | * Modify `project_root=/absolute/path/to/malware-uncertainty/`.
33 |
34 | * Modify `database_dir=/absolute/path/to/datasets`. Optionally, for more detail, the `Drebin` or `Androzoo` malware datasets reside in this directory with the following structure:
35 | ```
36 | datasets
37 | |---drebin
38 | |---malicious_samples % malicious apps folder
39 |      |---benign_samples % benign apps folder
40 | |---androzoo_tesseract
41 | |---malicious_samples
42 | |---benign_samples
43 |      | date_stamp.json % date stamp for each app (we will provide this file)
44 | |---VirusShare_Android_APK_2013
45 | |---malicious_samples
46 | |---benign_samples
47 | |---naive_data % saving the preprocessed data files
48 | ...
49 | ```
50 | If no real apps are considered, the preprocessed data files are enough to make the project work. In this case, we need to continue with the following configuration:
51 | * Download `datasets` from the anonymous [url](https://mega.nz/folder/bF8RQAAI#HeIhpUzDdqCdWdh4bAIZbg) and put the folder in the project root directory, namely `malware-uncertainty`. Please note that this `datasets` folder is not necessarily the same as the `database_dir` directory in the second step.
52 | * Download `naive_data` from the anonymous [url](https://mega.nz/folder/bF8RQAAI#HeIhpUzDdqCdWdh4bAIZbg) and put the folder in the `database_dir` directory configured in the second step (it needs to be unzipped: `mv naive_data.tar.gz database_dir; cd database_dir; tar -xvzf naive_data.tar.gz ./`).
53 |
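The `config` file is read via `config.get(section, key)` calls in the experiment scripts, which suggests a standard INI-style layout. As an illustration only (the authoritative template is the provided `conf` / `conf-server` file, from which the exact section and key names should be taken), the entries mentioned above look roughly like:
```
; illustrative sketch -- copy the shipped conf / conf-server template and adjust the paths
[DEFAULT]
project_root = /absolute/path/to/malware-uncertainty/
database_dir = /absolute/path/to/datasets

[drebin]
malware_dir = %(database_dir)s/drebin/malicious_samples
benware_dir = %(database_dir)s/drebin/benign_samples

[metadata]
naive_data_pool = %(database_dir)s/naive_data
```
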
54 | #### 3. Usage
55 | We suggest that users create a conda environment to run the code. In this spirit, the following instructions may be helpful:
56 | 1. Create a new environment: `conda create -n mal-uncertainty python=3.6`
57 | 2. Activate the environment and install dependencies: `conda activate mal-uncertainty` and `pip install -r requirements.txt`
58 | 3. Next step:
59 |    * For training, all scripts are listed in `./run.sh` (an example command is shown below)
60 |    * Then, for producing figures and table data, the Python code is `./experiments/table-figures.py` (we have not implemented this part for the malware detector `Droidetec`)
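For example, the first entry of `./run.sh` trains the DeepDrebin detector with the vanilla calibration setting:
```
python -m experiments.drebin_main --detector drebin --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2
```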
61 |
62 | ## Warning
63 | * It is usually time-consuming to perform feature extraction on Android applications.
64 | * Two detectors (DeepDroid and Droidetec) are both RAM- and computation-intensive because very long sequences are used to promote detection accuracy.
65 |
66 |
67 | ## License && Issues
68 |
69 | We will make our code publicly available under a formal license. For now, this is still ongoing work and we plan to report more results in the future. It is worth noting that we found two issues when checking our code:
70 | * No random seed is set for reproducing results exactly as in the paper; nevertheless, similar results can be achieved.
71 | * The training, validation, and test datasets are split randomly, which leads to results varying across runs.
72 |
73 | ## Contact
74 |
75 | For any questions, please do not hesitate to contact us (`Shouhuai Xu`, email: `sxu@uccs.edu`; `Deqiang Li`, email: `lideqiang@njust.edu.cn`).
76 |
77 |
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
/experiments/adv.py:
--------------------------------------------------------------------------------
1 | # conduct the group of 'adversarial example' experiments on the drebin dataset
2 | import os
3 | import sys
4 | import random
5 | from collections import Counter
6 |
7 | import numpy as np
8 | from sklearn.model_selection import train_test_split
9 |
10 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture
11 | from core.ensemble import ensemble_method_scope_dict
12 | from core.post_calibration.temperature_scaling import apply_temperature_scaling
13 | from tools import utils
14 | from config import config, logging
15 |
16 | logger = logging.getLogger('experiment.drebin_adv')
17 |
18 |
19 | # procedure of adversarial-example experiments
20 | # 1. build dataset
21 | # 2. preprocess data
22 | # 3. conduct prediction
23 | # 4. save results for statistical analysis
24 |
25 |
26 | def run_experiment(feature_type, ensemble_type, proc_numbers=2):
27 | """
28 | run this group of experiments
29 |     :param feature_type: the type of features (e.g., drebin, opcode, etc.); the feature type determines the model architecture
30 |     :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc.)
31 | :return: None
32 | """
33 | prist_data, adv_data, prist_y, adv_y, input_dim = data_preprocessing(feature_type, proc_numbers)
34 |
35 | ensemble_obj = get_ensemble_object(ensemble_type)
36 | # instantiation
37 | arch_type = feature_type_vs_architecture.get(feature_type)
38 | model_saving_dir = config.get('experiments', 'drebin')
39 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']:
40 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
41 | else:
42 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
43 |
44 | prist_results = ensemble_model.predict(prist_data)
45 | adv_results = ensemble_model.predict(adv_data)
46 | utils.dump_joblib((prist_results, adv_results, prist_y, adv_y),
47 | os.path.join(config.get('experiments', 'adv'), '{}_{}_drebin_adv.res'.format(feature_type, ensemble_type)))
48 | # ensemble_model.evaluate(prist_data, prist_y)
49 | # ensemble_model.evaluate(adv_data, adv_y)
50 |
51 |
52 | def run_temperature_scaling(feature_type, ensemble_type, proc_numbers=2):
53 | prist_data, adv_data, prist_y, adv_y, input_dim = data_preprocessing(feature_type, proc_numbers)
54 |
55 | ensemble_obj = get_ensemble_object(ensemble_type)
56 | # instantiation
57 | arch_type = feature_type_vs_architecture.get(feature_type)
58 | model_saving_dir = config.get('experiments', 'drebin')
59 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
60 |
61 | temp_save_dir = os.path.join(config.get('drebin', 'intermediate_directory'),
62 | "{}_{}_temp.json".format(feature_type, ensemble_type))
63 | if not os.path.exists(temp_save_dir):
64 | raise FileNotFoundError
65 |
66 | temperature = utils.load_json(temp_save_dir)['temperature']
67 |
68 | prist_probs = ensemble_model.predict(prist_data, use_prob=True)
69 | prist_prob_t = apply_temperature_scaling(temperature, prist_probs)
70 | adv_probs = ensemble_model.predict(adv_data, use_prob=True)
71 | adv_prob_t = apply_temperature_scaling(temperature, adv_probs)
72 |
73 | utils.dump_joblib((prist_prob_t, adv_prob_t, prist_y, adv_y), os.path.join(config.get('experiments', 'adv'),
74 | '{}_{}_temperature_drebin_adv.res'.format(feature_type, ensemble_type)))
75 |
76 |
77 | def data_preprocessing(feature_type='drebin', proc_numbers=2):
78 |     assert feature_type in feature_type_scope_dict.keys(), 'Expected one of {}, but got {}.'.format(
79 |         feature_type_scope_dict.keys(), feature_type)
80 | malware_pst_dir = config.get('adv', 'pristine_apk_dir')
81 |     malware_adv_dir = config.get('adv', 'perturbed_apk_dir')
82 | android_features_saving_dir = config.get('metadata', 'naive_data_pool')
83 | intermediate_data_saving_dir = config.get('adv', 'intermediate_directory')
84 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir,
85 | intermediate_data_saving_dir,
86 | update=False,
87 | proc_number=proc_numbers)
88 |
89 | save_path = os.path.join(intermediate_data_saving_dir, 'adv_database.' + feature_type)
90 | if os.path.exists(save_path):
91 | prist_filenames, adv_filenames, prist_y, adv_y = utils.read_joblib(save_path)
92 | prist_features = [os.path.join(android_features_saving_dir, filename) for filename in prist_filenames]
93 | adv_features = [os.path.join(android_features_saving_dir, filename) for filename in adv_filenames]
94 | else:
95 | prist_features = feature_extractor.feature_extraction(malware_pst_dir)
96 | prist_y = np.ones((len(prist_features),), dtype=np.int32)
97 |
98 | adv_features = feature_extractor.feature_extraction(malware_adv_dir)
99 | adv_y = np.ones((len(adv_features),), dtype=np.int32)
100 |
101 | prist_filenames = [os.path.basename(path) for path in prist_features]
102 | adv_filenames = [os.path.basename(path) for path in adv_features]
103 | utils.dump_joblib((prist_filenames, adv_filenames, prist_y, adv_y), save_path)
104 |
105 | # obtain data in a format for ML algorithms
106 | prist_data, input_dim = feature_extractor.feature2ipt(prist_features)
107 | adv_data, input_dim = feature_extractor.feature2ipt(adv_features)
108 | return prist_data, adv_data, prist_y, adv_y, input_dim
109 |
110 |
111 | def get_ensemble_object(ensemble_type):
112 |     assert ensemble_type in ensemble_method_scope_dict.keys(), 'Expected one of {}, but got {}.'.format(
113 |         ','.join(ensemble_method_scope_dict.keys()),
114 |         ensemble_type
115 |     )
116 | return ensemble_method_scope_dict[ensemble_type]
117 |
118 |
119 | def _main():
120 | # build_data()
121 | print(get_ensemble_object('vanilla'))
122 |
123 |
124 | if __name__ == '__main__':
125 | sys.exit(_main())
126 |
--------------------------------------------------------------------------------
/tools/progressbar/examples.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 |
4 | import sys
5 | import time
6 |
7 | #from widgets import AnimatedMarker, Bar, BouncingBar, Counter, ETA, \
8 | # AdaptiveETA, FileTransferSpeed, FormatLabel, Percentage, \
9 | # ProgressBar, ReverseBar, RotatingMarker, \
10 | # SimpleProgress, Timer
11 |
12 | from compat import *
13 | from widgets import *
14 | from progressbar import *
15 |
16 | examples = []
17 | def example(fn):
18 | try: name = 'Example %d' % int(fn.__name__[7:])
19 | except: name = fn.__name__
20 |
21 | def wrapped():
22 | try:
23 | sys.stdout.write('Running: %s\n' % name)
24 | fn()
25 | sys.stdout.write('\n')
26 | except KeyboardInterrupt:
27 | sys.stdout.write('\nSkipping example.\n\n')
28 |
29 | examples.append(wrapped)
30 | return wrapped
31 |
32 | @example
33 | def example0():
34 | pbar = ProgressBar(widgets=[Percentage(), Bar()], maxval=300).start()
35 | for i in range(300):
36 | time.sleep(0.01)
37 | pbar.update(i+1)
38 | pbar.finish()
39 |
40 | @example
41 | def example1():
42 | widgets = ['Test: ', Percentage(), ' ', Bar(marker=RotatingMarker()),
43 | ' ', ETA(), ' ', FileTransferSpeed()]
44 | pbar = ProgressBar(widgets=widgets, maxval=10000000).start()
45 | for i in range(1000000):
46 | # do something
47 | pbar.update(10*i+1)
48 | pbar.finish()
49 |
50 | @example
51 | def example2():
52 | class CrazyFileTransferSpeed(FileTransferSpeed):
53 | """It's bigger between 45 and 80 percent."""
54 | def update(self, pbar):
55 | if 45 < pbar.percentage() < 80:
56 | return 'Bigger Now ' + FileTransferSpeed.update(self,pbar)
57 | else:
58 | return FileTransferSpeed.update(self,pbar)
59 |
60 | widgets = [CrazyFileTransferSpeed(),' <<<', Bar(), '>>> ',
61 | Percentage(),' ', ETA()]
62 | pbar = ProgressBar(widgets=widgets, maxval=10000000)
63 | # maybe do something
64 | pbar.start()
65 | for i in range(2000000):
66 | # do something
67 | pbar.update(5*i+1)
68 | pbar.finish()
69 |
70 | @example
71 | def example3():
72 | widgets = [Bar('>'), ' ', ETA(), ' ', ReverseBar('<')]
73 | pbar = ProgressBar(widgets=widgets, maxval=10000000).start()
74 | for i in range(1000000):
75 | # do something
76 | pbar.update(10*i+1)
77 | pbar.finish()
78 |
79 | @example
80 | def example4():
81 | widgets = ['Test: ', Percentage(), ' ',
82 | Bar(marker='0',left='[',right=']'),
83 | ' ', ETA(), ' ', FileTransferSpeed()]
84 | pbar = ProgressBar(widgets=widgets, maxval=500)
85 | pbar.start()
86 | for i in range(100,500+1,50):
87 | time.sleep(0.2)
88 | pbar.update(i)
89 | pbar.finish()
90 |
91 | @example
92 | def example5():
93 | pbar = ProgressBar(widgets=[SimpleProgress()], maxval=17).start()
94 | for i in range(17):
95 | time.sleep(0.2)
96 | pbar.update(i + 1)
97 | pbar.finish()
98 |
99 | @example
100 | def example6():
101 | pbar = ProgressBar().start()
102 | for i in range(100):
103 | time.sleep(0.01)
104 | pbar.update(i + 1)
105 | pbar.finish()
106 |
107 | @example
108 | def example7():
109 | pbar = ProgressBar() # Progressbar can guess maxval automatically.
110 | for i in pbar(range(80)):
111 | time.sleep(0.01)
112 |
113 | @example
114 | def example8():
115 | pbar = ProgressBar(maxval=80) # Progressbar can't guess maxval.
116 | for i in pbar((i for i in range(80))):
117 | time.sleep(0.01)
118 |
119 | @example
120 | def example9():
121 | pbar = ProgressBar(widgets=['Working: ', AnimatedMarker()])
122 | for i in pbar((i for i in range(50))):
123 | time.sleep(.08)
124 |
125 | @example
126 | def example10():
127 | widgets = ['Processed: ', Counter(), ' lines (', Timer(), ')']
128 | pbar = ProgressBar(widgets=widgets)
129 | for i in pbar((i for i in range(150))):
130 | time.sleep(0.1)
131 |
132 | @example
133 | def example11():
134 | widgets = [FormatLabel('Processed: %(value)d lines (in: %(elapsed)s)')]
135 | pbar = ProgressBar(widgets=widgets)
136 | for i in pbar((i for i in range(150))):
137 | time.sleep(0.1)
138 |
139 | @example
140 | def example12():
141 | widgets = ['Balloon: ', AnimatedMarker(markers='.oO@* ')]
142 | pbar = ProgressBar(widgets=widgets)
143 | for i in pbar((i for i in range(24))):
144 | time.sleep(0.3)
145 |
146 | @example
147 | def example13():
148 | # You may need python 3.x to see this correctly
149 | try:
150 | widgets = ['Arrows: ', AnimatedMarker(markers='←↖↑↗→↘↓↙')]
151 | pbar = ProgressBar(widgets=widgets)
152 | for i in pbar((i for i in range(24))):
153 | time.sleep(0.3)
154 | except UnicodeError: sys.stdout.write('Unicode error: skipping example')
155 |
156 | @example
157 | def example14():
158 | # You may need python 3.x to see this correctly
159 | try:
160 | widgets = ['Arrows: ', AnimatedMarker(markers='◢◣◤◥')]
161 | pbar = ProgressBar(widgets=widgets)
162 | for i in pbar((i for i in range(24))):
163 | time.sleep(0.3)
164 | except UnicodeError: sys.stdout.write('Unicode error: skipping example')
165 |
166 | @example
167 | def example15():
168 | # You may need python 3.x to see this correctly
169 | try:
170 | widgets = ['Wheels: ', AnimatedMarker(markers='◐◓◑◒')]
171 | pbar = ProgressBar(widgets=widgets)
172 | for i in pbar((i for i in range(24))):
173 | time.sleep(0.3)
174 | except UnicodeError: sys.stdout.write('Unicode error: skipping example')
175 |
176 | @example
177 | def example16():
178 | widgets = [FormatLabel('Bouncer: value %(value)d - '), BouncingBar()]
179 | pbar = ProgressBar(widgets=widgets)
180 | for i in pbar((i for i in range(180))):
181 | time.sleep(0.05)
182 |
183 | @example
184 | def example17():
185 | widgets = [FormatLabel('Animated Bouncer: value %(value)d - '),
186 | BouncingBar(marker=RotatingMarker())]
187 |
188 | pbar = ProgressBar(widgets=widgets)
189 | for i in pbar((i for i in range(180))):
190 | time.sleep(0.05)
191 |
192 | @example
193 | def example18():
194 | widgets = [Percentage(),
195 | ' ', Bar(),
196 | ' ', Timer(),
197 | ' ', AdaptiveETA()]
198 | pbar = ProgressBar(widgets=widgets, maxval=500)
199 | pbar.start()
200 | for i in range(500):
201 |
202 | #print
203 | time.sleep(0.01 + (i < 100) * 0.01 + (i > 400) * 0.9)
204 | pbar.update(i + 1)
205 | pbar.finish()
206 |
207 | @example
208 | def example19():
209 | pbar = ProgressBar()
210 | for i in pbar([]):
211 | pass
212 | pbar.finish()
213 |
214 | if __name__ == '__main__':
215 | try:
216 | example18()
217 | except KeyboardInterrupt:
218 |         sys.stdout.write('\nQuitting examples.\n')
219 |
--------------------------------------------------------------------------------
/experiments/drebin_ood.py:
--------------------------------------------------------------------------------
1 | # conduct the group of 'out of distribution' experiments on drebin dataset
2 | import os
3 | import sys
4 | import random
5 | from collections import Counter
6 |
7 | import numpy as np
8 | from sklearn.model_selection import train_test_split
9 |
10 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture
11 | from core.ensemble import ensemble_method_scope_dict
12 | from tools import utils
13 | from config import config, logging
14 |
15 | logger = logging.getLogger('experiment.drebin_ood')
16 |
17 | # procedure of ood experiments
18 | # 1. build dataset
19 | # 2. preprocess data
20 | # 3. learn models
21 | # 4. save results for statistical analysis
22 |
23 | def run_experiment(feature_type, ensemble_type, n_members=1, proc_numbers=2):
24 | """
25 | run this group of experiments
26 |     :param feature_type: the type of features (e.g., drebin, opcode, etc.); the feature type determines the model architecture
27 |     :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc.)
28 | :return: None
29 | """
30 | mal_folder, ben_folder, ood_data_paths = build_data()
31 |
32 | train_dataset, validation_dataset, test_data, test_y, ood_data, ood_y, input_dim = \
33 | data_preprocessing(feature_type, mal_folder, ben_folder, ood_data_paths, proc_numbers)
34 |
35 | ensemble_obj = get_ensemble_object(ensemble_type)
36 | # instantiation
37 | arch_type = feature_type_vs_architecture.get(feature_type)
38 | saving_dir = config.get('experiments', 'ood')
39 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']:
40 |         ensemble_model = ensemble_obj(arch_type, base_model=None, n_members=1, model_directory=saving_dir)
41 |     else:
42 |         ensemble_model = ensemble_obj(arch_type, base_model=None, n_members=n_members, model_directory=saving_dir)
43 |
44 | ensemble_model.fit(train_dataset, validation_dataset, input_dim=input_dim)
45 |
46 | test_results = ensemble_model.predict(test_data)
47 | utils.dump_joblib(test_results, os.path.join(saving_dir, '{}_{}_test.res'.format(feature_type, ensemble_type)))
48 | ensemble_model.evaluate(test_data, test_y)
49 | ood_results = ensemble_model.predict(ood_data)
50 | utils.dump_joblib(ood_results, os.path.join(saving_dir, '{}_{}_ood.res'.format(feature_type, ensemble_type)))
51 | ensemble_model.evaluate(ood_data, ood_y, is_single_class=True, name='ood')
52 |
53 | def build_data():
54 | malware_dir = config.get('drebin', 'malware_dir')
55 | benware_dir = config.get('drebin', 'benware_dir')
56 |     malware_paths, ood_data_paths = produce_ood_data(malware_dir)
57 |
58 |     return malware_paths, benware_dir, ood_data_paths
59 |
60 |
61 | def produce_ood_data(malware_dir, top_frequency=30, n_selection=5, minimum_samples=1, maximum_samples=1000):
62 | import pandas as pd
63 | malware_family_pd = pd.read_csv(config.get('drebin', 'malware_family'))
64 | counter = dict(Counter(malware_family_pd['family']).most_common(top_frequency))
65 | i = 0
66 | while i <= 1e5:
67 | random.seed(i)
68 | selected_families = random.sample(counter.keys(), n_selection)
69 | number_of_malware = sum([counter[f] for f in selected_families])
70 | if minimum_samples <= number_of_malware <= maximum_samples:
71 | break
72 | else:
73 | i = i + 1
74 | else:
75 | random.seed(1)
76 | selected_families = random.sample(counter.keys(), n_selection)
77 | number_of_malware = sum([counter[f] for f in selected_families])
78 | logger.info('The number of selected ood malware samples: {}'.format(number_of_malware))
79 | logger.info("The selected families are {}".format(','.join(selected_families)))
80 |
81 | selected_sha256_codes = list(malware_family_pd[malware_family_pd.family.isin(selected_families)]['sha256'])
82 | assert len(selected_sha256_codes) == number_of_malware
83 |
84 | malware_paths = utils.retrive_files_set(malware_dir, "", ".apk|")
85 | ood_data_paths = []
86 | for mal_path in malware_paths:
87 | sha256 = utils.get_sha256(mal_path)
88 | if sha256 in selected_sha256_codes:
89 | ood_data_paths.append(mal_path)
90 |
91 | return list(set(malware_paths) - set(ood_data_paths)), ood_data_paths
92 |
93 |
94 | def data_preprocessing(feature_type='drebin', malware_dir=None, benware_dir=None, ood_data_path=None, proc_numbers=2):
95 |     assert feature_type in feature_type_scope_dict.keys(), 'Expected one of {}, but got {}.'.format(
96 |         feature_type_scope_dict.keys(), feature_type)
97 |
98 | android_features_saving_dir = config.get('metadata', 'naive_data_directory')
99 | intermediate_data_saving_dir = config.get('metadata', 'meta_data_directory')
100 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir,
101 | intermediate_data_saving_dir,
102 | update=False,
103 | proc_number = proc_numbers)
104 | mal_feature_list = feature_extractor.feature_extraction(malware_dir)
105 | n_malware = len(mal_feature_list)
106 | ben_feature_list = feature_extractor.feature_extraction(benware_dir)
107 | n_benware = len(ben_feature_list)
108 | feature_list = mal_feature_list + ben_feature_list
109 | gt_labels = np.zeros((len(feature_list),), dtype=np.int32)
110 | gt_labels[:n_malware] = 1
111 |
112 | # data split
113 | train_features, test_features, train_y, test_y = train_test_split(feature_list, gt_labels,
114 | test_size=0.2, random_state=0)
115 | feature_extractor.feature_preprocess(train_features, train_y) # produce intermediate products
116 | # obtain validation data
117 | train_features, validation_features, train_y, validation_y = train_test_split(train_features,
118 | train_y,
119 | test_size=0.25,
120 | random_state=0
121 | )
122 |
123 | # obtain data in a format for ML algorithms
124 | train_dataset, input_dim = feature_extractor.feature2ipt(train_features, train_y)
125 | test_data, _ = feature_extractor.feature2ipt(test_features)
126 | validation_dataset, _ = feature_extractor.feature2ipt(validation_features, validation_y)
127 |
128 | ood_features = feature_extractor.feature_extraction(ood_data_path)
129 | ood_y = np.ones((len(ood_features),))
130 | ood_data, _ = feature_extractor.feature2ipt(ood_features)
131 |
132 | return train_dataset, validation_dataset, test_data, test_y, ood_data, ood_y, input_dim
133 |
134 |
135 | def get_ensemble_object(ensemble_type):
136 | assert ensemble_type in ensemble_method_scope_dict.keys(), '{} expected, but {} are supported'.format(
137 | ensemble_type,
138 | ','.join(ensemble_method_scope_dict.keys())
139 | )
140 | return ensemble_method_scope_dict[ensemble_type]
141 |
142 | def _main():
143 | # build_data()
144 | print(get_ensemble_object('vanilla'))
145 |
146 |
147 | if __name__ == '__main__':
148 | sys.exit(_main())
149 |
--------------------------------------------------------------------------------
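The out-of-distribution split built by produce_ood_data above hinges on Python's while/else: candidate seeds are tried until a random draw of n_selection families yields a total sample count inside [minimum_samples, maximum_samples], and the else branch fires only if no seed up to 1e5 produces such a draw. A minimal sketch of that control flow, with made-up family names and counts:

# Illustrative only; the families and counts below are not taken from the dataset.
import random

counter = {'family_a': 400, 'family_b': 350, 'family_c': 300, 'family_d': 250, 'family_e': 200}
i = 0
while i <= 10:
    random.seed(i)
    selected = random.sample(list(counter.keys()), 3)
    if 1 <= sum(counter[f] for f in selected) <= 1000:
        break          # a suitable draw was found; the else branch below is skipped
    i += 1
else:
    # runs only when the loop exhausts every seed without hitting `break`
    random.seed(1)
    selected = random.sample(list(counter.keys()), 3)
print(selected)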
/run.sh:
--------------------------------------------------------------------------------
1 | python -m experiments.drebin_main --detector drebin --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2
2 | python -m experiments.drebin_main --detector drebin --calibration temp_scaling --n_members 1 --ratio 3.0 --proc_numbers 2
3 | python -m experiments.drebin_main --detector drebin --calibration mc_dropout --n_members 10 --ratio 3.0 --proc_numbers 2
4 | python -m experiments.drebin_main --detector drebin --calibration bayesian --n_members 10 --ratio 3.0 --proc_numbers 2
5 | python -m experiments.drebin_main --detector drebin --calibration deep_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
6 | python -m experiments.drebin_main --detector drebin --calibration weighted_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
7 | python -m experiments.drebin_main --detector opcodeseq --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2
8 | python -m experiments.drebin_main --detector opcodeseq --calibration temp_scaling --n_members 1 --ratio 3.0 --proc_numbers 2
9 | python -m experiments.drebin_main --detector opcodeseq --calibration mc_dropout --n_members 10 --ratio 3.0 --proc_numbers 2
10 | python -m experiments.drebin_main --detector opcodeseq --calibration bayesian --n_members 10 --ratio 3.0 --proc_numbers 2
11 | python -m experiments.drebin_main --detector opcodeseq --calibration deep_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
12 | python -m experiments.drebin_main --detector opcodeseq --calibration weighted_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
13 | python -m experiments.drebin_main --detector multimodality --calibration vanilla --n_members 1 --ratio 3.0 --proc_numbers 2
14 | python -m experiments.drebin_main --detector multimodality --calibration temp_scaling --n_members 1 --ratio 3.0 --proc_numbers 2
15 | python -m experiments.drebin_main --detector multimodality --calibration mc_dropout --n_members 10 --ratio 3.0 --proc_numbers 2
16 | python -m experiments.drebin_main --detector multimodality --calibration bayesian --n_members 10 --ratio 3.0 --proc_numbers 2
17 | python -m experiments.drebin_main --detector multimodality --calibration deep_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
18 | python -m experiments.drebin_main --detector multimodality --calibration weighted_ensemble --n_members 10 --ratio 3.0 --proc_numbers 2
19 | python -m experiments.androzoo_main --detector drebin --calibration vanilla --n_members 1 --proc_numbers 2
20 | python -m experiments.androzoo_main --detector drebin --calibration temp_scaling --n_members 1 --proc_numbers 2
21 | python -m experiments.androzoo_main --detector drebin --calibration mc_dropout --n_members 10 --proc_numbers 2
22 | python -m experiments.androzoo_main --detector drebin --calibration bayesian --n_members 10 --proc_numbers 2
23 | python -m experiments.androzoo_main --detector drebin --calibration deep_ensemble --n_members 10 --proc_numbers 2
24 | python -m experiments.androzoo_main --detector drebin --calibration weighted_ensemble --n_members 10 --proc_numbers 2
25 | python -m experiments.androzoo_main --detector opcodeseq --calibration vanilla --n_members 1 --proc_numbers 2
26 | python -m experiments.androzoo_main --detector opcodeseq --calibration temp_scaling --n_members 1 --proc_numbers 2
27 | python -m experiments.androzoo_main --detector opcodeseq --calibration mc_dropout --n_members 10 --proc_numbers 2
28 | python -m experiments.androzoo_main --detector opcodeseq --calibration bayesian --n_members 10 --proc_numbers 2
29 | python -m experiments.androzoo_main --detector opcodeseq --calibration deep_ensemble --n_members 10 --proc_numbers 2
30 | python -m experiments.androzoo_main --detector opcodeseq --calibration weighted_ensemble --n_members 10 --proc_numbers 2
31 | python -m experiments.androzoo_main --detector multimodality --calibration vanilla --n_members 1 --proc_numbers 2
32 | python -m experiments.androzoo_main --detector multimodality --calibration temp_scaling --n_members 1 --proc_numbers 2
33 | python -m experiments.androzoo_main --detector multimodality --calibration mc_dropout --n_members 10 --proc_numbers 2
34 | python -m experiments.androzoo_main --detector multimodality --calibration bayesian --n_members 10 --proc_numbers 2
35 | python -m experiments.androzoo_main --detector multimodality --calibration deep_ensemble --n_members 10 --proc_numbers 2
36 | python -m experiments.androzoo_main --detector multimodality --calibration weighted_ensemble --n_members 10 --proc_numbers 2
37 | python -m experiments.oos_main --detector drebin --calibration vanilla --proc_numbers 2
38 | python -m experiments.oos_main --detector drebin --calibration temp_scaling --proc_numbers 2
39 | python -m experiments.oos_main --detector drebin --calibration mc_dropout --proc_numbers 2
40 | python -m experiments.oos_main --detector drebin --calibration bayesian --proc_numbers 2
41 | python -m experiments.oos_main --detector drebin --calibration deep_ensemble --proc_numbers 2
42 | python -m experiments.oos_main --detector drebin --calibration weighted_ensemble --proc_numbers 2
43 | python -m experiments.oos_main --detector opcodeseq --calibration vanilla --proc_numbers 2
44 | python -m experiments.oos_main --detector opcodeseq --calibration temp_scaling --proc_numbers 2
45 | python -m experiments.oos_main --detector opcodeseq --calibration mc_dropout --proc_numbers 2
46 | python -m experiments.oos_main --detector opcodeseq --calibration bayesian --proc_numbers 2
47 | python -m experiments.oos_main --detector opcodeseq --calibration deep_ensemble --proc_numbers 2
48 | python -m experiments.oos_main --detector opcodeseq --calibration weighted_ensemble --proc_numbers 2
49 | python -m experiments.oos_main --detector multimodality --calibration vanilla --proc_numbers 2
50 | python -m experiments.oos_main --detector multimodality --calibration temp_scaling --proc_numbers 2
51 | python -m experiments.oos_main --detector multimodality --calibration mc_dropout --proc_numbers 2
52 | python -m experiments.oos_main --detector multimodality --calibration bayesian --proc_numbers 2
53 | python -m experiments.oos_main --detector multimodality --calibration deep_ensemble --proc_numbers 2
54 | python -m experiments.oos_main --detector multimodality --calibration weighted_ensemble --proc_numbers 2
55 | # python -m experiments.adv_main --detector drebin --calibration vanilla --proc_numbers 2
56 | # python -m experiments.adv_main --detector drebin --calibration temp_scaling --proc_numbers 2
57 | # python -m experiments.adv_main --detector drebin --calibration mc_dropout --proc_numbers 2
58 | # python -m experiments.adv_main --detector drebin --calibration bayesian --proc_numbers 2
59 | # python -m experiments.adv_main --detector drebin --calibration deep_ensemble --proc_numbers 2
60 | # python -m experiments.adv_main --detector drebin --calibration weighted_ensemble --proc_numbers 2
61 | # python -m experiments.adv_main --detector opcodeseq --calibration vanilla --proc_numbers 2
62 | # python -m experiments.adv_main --detector opcodeseq --calibration temp_scaling --proc_numbers 2
63 | # python -m experiments.adv_main --detector opcodeseq --calibration mc_dropout --proc_numbers 2
64 | # python -m experiments.adv_main --detector opcodeseq --calibration bayesian --proc_numbers 2
65 | # python -m experiments.adv_main --detector opcodeseq --calibration deep_ensemble --proc_numbers 2
66 | # python -m experiments.adv_main --detector opcodeseq --calibration weighted_ensemble --proc_numbers 2
67 | # python -m experiments.adv_main --detector multimodality --calibration vanilla --proc_numbers 2
68 | # python -m experiments.adv_main --detector multimodality --calibration temp_scaling --proc_numbers 2
69 | # python -m experiments.adv_main --detector multimodality --calibration mc_dropout --proc_numbers 2
70 | # python -m experiments.adv_main --detector multimodality --calibration bayesian --proc_numbers 2
71 | # python -m experiments.adv_main --detector multimodality --calibration deep_ensemble --proc_numbers 2
72 | # python -m experiments.adv_main --detector multimodality --calibration weighted_ensemble --proc_numbers 2
73 |
--------------------------------------------------------------------------------
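Several of the commands above pass --calibration temp_scaling. As applied later in experiments/drebin_dataset.py (run_temperature_scaling), this post-hoc step maps the model's sigmoid outputs back to logits, divides them by a temperature fitted on validation data, and re-applies the sigmoid. The helper below is a stand-in sketch of that transformation, not the repository's apply_temperature_scaling implementation:

import numpy as np

def scale_probabilities(prob, temperature, eps=1e-10):
    # inverse sigmoid, divide by the fitted temperature, squash again
    logits = np.log(prob + eps) - np.log(1. - prob + eps)
    return 1. / (1. + np.exp(-logits / temperature))

print(scale_probabilities(np.array([0.05, 0.40, 0.95]), temperature=1.5))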
/test/deep_ensemble_test.py:
--------------------------------------------------------------------------------
1 | from absl.testing import absltest
2 | from absl.testing import parameterized
3 | import tempfile
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | from sklearn.datasets import load_breast_cancer
8 |
9 | from core.ensemble.deep_ensemble import DeepEnsemble, WeightedDeepEnsemble
10 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
11 |
12 |
13 | class MyTestCaseDeepEnsemble(parameterized.TestCase):
14 | def setUp(self):
15 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
16 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
17 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
18 |
19 | def test_dnn(self):
20 | with tempfile.TemporaryDirectory() as output_dir:
21 | deepensemble = DeepEnsemble(architecture_type='dnn',
22 | model_directory=output_dir)
23 | deepensemble.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1])
24 |
25 | res = deepensemble.predict(self.x_np)
26 | self.assertEqual(deepensemble.get_n_members(), deepensemble.n_members)
27 | self.assertTrue(res.shape == (self.x_np.shape[0], deepensemble.n_members, 1))
28 |
29 | deepensemble.evaluate(self.x_np, self.y_np)
30 |
31 | def test_textcnn(self):
32 | with tempfile.TemporaryDirectory() as output_dir:
33 | deepensemble = DeepEnsemble(architecture_type='text_cnn',
34 | model_directory=output_dir)
35 | x = np.random.randint(0, 256, (2, 10))
36 | y = np.random.choice(2, 2)
37 | train_dataset = build_dataset_from_numerical_data((x, y))
38 | val_dataset = build_dataset_from_numerical_data((x, y))
39 | deepensemble.fit(train_dataset, val_dataset)
40 | res = deepensemble.predict(x)
41 | self.assertTrue(res.shape == (x.shape[0], deepensemble.n_members, 1))
42 | deepensemble.evaluate(x, y)
43 |
44 | def test_multimodalitynn(self):
45 | with tempfile.TemporaryDirectory() as output_dir:
46 | deepensemble = DeepEnsemble(architecture_type='multimodalitynn',
47 | model_directory=output_dir)
48 | x = [self.x_np] * 5
49 | train_data = build_dataset_from_numerical_data(tuple(x))
50 | train_y = build_dataset_from_numerical_data(self.y_np)
51 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
52 | val_data = build_dataset_from_numerical_data(tuple(x))
53 | val_y = build_dataset_from_numerical_data(self.y_np)
54 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
55 | deepensemble.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]] * 5)
56 | res = deepensemble.predict(x)
57 | self.assertTrue(res.shape == (self.x_np.shape[0], deepensemble.n_members, 1))
58 | deepensemble.evaluate(x, self.y_np)
59 |
60 | def test_r2d2(self):
61 | with tempfile.TemporaryDirectory() as output_dir:
62 | deepensemble = DeepEnsemble(architecture_type='r2d2',
63 | model_directory=output_dir)
64 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
65 | y = np.random.choice(2, 10)
66 | train_dataset = build_dataset_from_numerical_data((x, y))
67 | val_dataset = build_dataset_from_numerical_data((x, y))
68 | deepensemble.fit(train_dataset, val_dataset, input_dim=(299, 299, 3))
69 | res = deepensemble.predict(x)
70 | self.assertTrue(res.shape == (x.shape[0], deepensemble.n_members, 1))
71 | deepensemble.evaluate(x, y)
72 |
73 | def test_droidectc(self):
74 | with tempfile.TemporaryDirectory() as output_dir:
75 | deepensemble = DeepEnsemble(architecture_type='droidectc',
76 | model_directory=output_dir)
77 |
78 | x = np.random.randint(0, 10000, size=(10, 1000))
79 | y = np.random.choice(2, 10)
80 | train_dataset = build_dataset_from_numerical_data((x, y))
81 | val_dataset = build_dataset_from_numerical_data((x, y))
82 | deepensemble.fit(train_dataset, val_dataset)
83 | res = deepensemble.predict(x)
84 | self.assertTrue(res.shape == (x.shape[0], deepensemble.n_members, 1))
85 | deepensemble.evaluate(x, y)
86 |
87 |
88 | class MyTestCaseWeightedDeepEnsemble(parameterized.TestCase):
89 | def setUp(self):
90 | self.x_np, self.y_np = load_breast_cancer(return_X_y=True)
91 | self.train_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
92 | self.val_dataset_v1 = build_dataset_from_numerical_data((self.x_np, self.y_np))
93 |
94 | def test_dnn(self):
95 | with tempfile.TemporaryDirectory() as output_dir:
96 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='dnn',
97 | model_directory=output_dir)
98 | weighteddeepensemble.fit(self.train_dataset_v1, self.val_dataset_v1, input_dim=self.x_np.shape[1])
99 |
100 | out, w = weighteddeepensemble.predict(self.x_np)
101 | self.assertEqual(weighteddeepensemble.get_n_members(), weighteddeepensemble.n_members)
102 | self.assertTrue(out.shape == (self.x_np.shape[0], weighteddeepensemble.n_members, 1))
103 |             self.assertTrue(np.isclose(np.sum(w), 1.))
104 |
105 | weighteddeepensemble.evaluate(self.x_np, self.y_np)
106 |
107 | def test_textcnn(self):
108 | with tempfile.TemporaryDirectory() as output_dir:
109 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='text_cnn',
110 | model_directory=output_dir)
111 | x = np.random.randint(0, 256, (10, 10))
112 | y = np.random.choice(2, 10)
113 | train_dataset = build_dataset_from_numerical_data((x, y))
114 | val_dataset = build_dataset_from_numerical_data((x, y))
115 | weighteddeepensemble.fit(train_dataset, val_dataset)
116 | out, w = weighteddeepensemble.predict(x)
117 |             self.assertTrue(out.shape == (x.shape[0], weighteddeepensemble.n_members, 1))
118 |             self.assertTrue(np.isclose(np.sum(w), 1.))
119 | weighteddeepensemble.evaluate(x, y)
120 |
121 | def test_multimodalitynn(self):
122 | with tempfile.TemporaryDirectory() as output_dir:
123 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='multimodalitynn',
124 | model_directory=output_dir)
125 | x = [self.x_np] * 5
126 | train_data = build_dataset_from_numerical_data(tuple(x))
127 | train_y = build_dataset_from_numerical_data(self.y_np)
128 | train_dataset = tf.data.Dataset.zip((train_data, train_y))
129 | val_data = build_dataset_from_numerical_data(tuple(x))
130 | val_y = build_dataset_from_numerical_data(self.y_np)
131 | val_dataset = tf.data.Dataset.zip((val_data, val_y))
132 | weighteddeepensemble.fit(train_dataset, val_dataset, input_dim=[self.x_np.shape[1]] * 5)
133 | out, w = weighteddeepensemble.predict(x)
134 | self.assertTrue(out.shape == (self.x_np.shape[0], weighteddeepensemble.n_members, 1))
135 |             self.assertTrue(np.isclose(np.sum(w), 1.))
136 | weighteddeepensemble.evaluate(x, self.y_np)
137 |
138 | def test_r2d2(self):
139 | with tempfile.TemporaryDirectory() as output_dir:
140 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='r2d2',
141 | model_directory=output_dir)
142 | x = np.random.uniform(0., 1., size=(10, 299, 299, 3))
143 | y = np.random.choice(2, 10)
144 | train_dataset = build_dataset_from_numerical_data((x, y))
145 | val_dataset = build_dataset_from_numerical_data((x, y))
146 | weighteddeepensemble.fit(train_dataset, val_dataset, input_dim=(299, 299, 3))
147 | out, w = weighteddeepensemble.predict(x)
148 | self.assertTrue(out.shape == (x.shape[0], weighteddeepensemble.n_members, 1))
149 |             self.assertTrue(np.isclose(np.sum(w), 1.))
150 | weighteddeepensemble.evaluate(x, y)
151 |
152 | def test_droidectc(self):
153 | with tempfile.TemporaryDirectory() as output_dir:
154 | weighteddeepensemble = WeightedDeepEnsemble(architecture_type='droidectc',
155 | model_directory=output_dir)
156 |
157 | x = np.random.randint(0, 10000, size=(10, 1000))
158 | y = np.random.choice(2, 10)
159 | train_dataset = build_dataset_from_numerical_data((x, y))
160 | val_dataset = build_dataset_from_numerical_data((x, y))
161 | weighteddeepensemble.fit(train_dataset, val_dataset)
162 | out, w = weighteddeepensemble.predict(x)
163 | self.assertTrue(out.shape == (x.shape[0], weighteddeepensemble.n_members, 1))
164 |             self.assertTrue(np.isclose(np.sum(w), 1.))
165 | weighteddeepensemble.evaluate(x, y)
166 |
167 |
168 | if __name__ == '__main__':
169 | absltest.main()
170 |
--------------------------------------------------------------------------------
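The tests above repeatedly assert that predict returns a tensor of shape (n_samples, n_members, 1). Assuming those entries are per-member sigmoid probabilities, a caller could collapse them into a single prediction plus a per-sample disagreement score as in the following sketch (random numbers stand in for real model output):

import numpy as np

res = np.random.uniform(0., 1., size=(4, 10, 1))   # stand-in for ensemble.predict(x)
member_probs = np.squeeze(res, axis=-1)            # (n_samples, n_members)
mean_prob = member_probs.mean(axis=1)              # ensemble-averaged probability
spread = member_probs.std(axis=1)                  # disagreement between members
print(mean_prob.shape, spread.shape)               # (4,) (4,)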
/tools/temporal.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | temporal.py
5 | ~~~~~~~~~~~
6 |
7 | A module for working with and running time-aware evaluations. Most of the
8 | functionality of this module falls into one of two categories: working with
9 | arrays of datetimes or datetime-aligned series of data, and aggregating the
10 | steps of the ML pipeline needed to conduct sound, time-aware evaluations.
11 |
12 |
13 | @inproceedings{pendlebury2019,
14 | author = {Feargus Pendlebury and Fabio Pierazzi and Roberto Jordaney and Johannes Kinder and Lorenzo Cavallaro},
15 | title = {{TESSERACT: Eliminating Experimental Bias in Malware Classification across Space and Time}},
16 | booktitle = {28th USENIX Security Symposium},
17 | year = {2019},
18 | address = {Santa Clara, CA},
19 | publisher = {USENIX Association},
20 | note = {USENIX Sec}
21 | }
22 |
23 | """
24 | import bisect
25 | import operator
26 | from datetime import datetime, date
27 |
28 | import numpy as np
29 | from dateutil.relativedelta import relativedelta
30 |
31 |
32 | def assert_train_test_temporal_consistency(t_train, t_test):
33 | """Helper function to assert train-test temporal constraint (C1).
34 |
35 | All objects in the training set need to be temporally anterior to all
36 | objects in the testing set. Violating this constraint will positively bias
37 | the results by integrating "future" knowledge into the classifier.
38 |
39 | Args:
40 | t_train: An array of datetimes corresponding to the training set.
41 |         t_test: An array of datetimes corresponding to the testing set.
42 |
43 | Returns:
44 | bool: False if the partitioned dataset does _not_ adhere to C1,
45 | True otherwise.
46 |
47 | """
48 | for train_date in t_train:
49 | for test_date in t_test:
50 | if train_date > test_date:
51 | return False
52 | return True
53 |
54 |
55 | def assert_positive_negative_temporal_consistency(y, t, month_variance=1):
56 | """Helper function to assert malware-goodware temporal constraint (C2).
57 |
58 | In any given testing period, all testing objects must be from the time
59 | window under test. In the malware domain this constraint has often been
60 | violated so that malware and goodware come from different time periods.
61 |
62 | If this is the case, it becomes impossible to tell whether a
63 | high-performing classifier is discriminating between malicious and benign
64 | objects or between old and new applications.
65 |
66 | Args:
67 | y: An array of ground-truth labels for each observation.
68 | t: An array of datetimes for each observation (aligned with y).
69 |         month_variance: The maximum number of months by which any malware
70 |             and goodware observations may differ.
71 |
72 | Returns:
73 | bool: False if the malware and goodware do not adhere to C2,
74 | True otherwise
75 |
76 | """
77 | positive = np.where(y == 1)[0]
78 | negative = np.where(y != 1)[0]
79 | positive_dates = t[positive]
80 | negative_dates = t[negative]
81 |
82 | for pos_date in positive_dates:
83 | for neg_date in negative_dates:
84 | if month_difference(pos_date, neg_date) > month_variance:
85 | return False
86 | return True
87 |
88 |
89 | def month_difference(d1, d2):
90 | """Get the difference in months between two datetimes."""
91 | return (d1.year - d2.year) * 12 + d1.month - d2.month
92 |
93 |
94 | def resolve_date(d):
95 | """Convert a str or date to an appropriate datetime.
96 |
97 |     Strings should be of the format '%Y', '%Y-%m' or '%Y-%m-%d', for example:
98 | '2012', '1994-02' or '1991-12-11'. Date objects with no time information
99 | will be rounded down to the midnight beginning that date.
100 |
101 | Args:
102 | d (Union[str, date]): The string or date to convert.
103 |
104 | Returns:
105 | datetime: The parsed datetime equivalent of d.
106 | """
107 | if isinstance(d, datetime):
108 | return d
109 |
110 | if isinstance(d, date):
111 | return datetime.combine(d, datetime.min.time())
112 |
113 | for fmt in ('%Y', '%Y-%m', '%Y-%m-%d'):
114 | try:
115 | return datetime.strptime(d, fmt)
116 | except ValueError:
117 | pass
118 |
119 | raise ValueError('date string format not recognized.')
120 |
121 |
122 | def time_aware_train_test_split(X, y, t, train_size, test_size,
123 | granularity, start_date=None):
124 | """Partition a dataset composed of time-labelled objects.
125 |
126 | Args:
127 | X (np.ndarray): Multi-dimensional array of predictors.
128 | y (np.ndarray): Array of output labels.
129 | t (np.ndarray): Array of timestamp tags.
130 | train_size (int): The training window size W (in τ).
131 | test_size (int): The testing window size Δ (in τ).
132 | granularity (str): The unit of time τ, used to denote the window size.
133 | Acceptable values are 'year|quarter|month|week|day'.
134 |         start_date (date): The date to begin partitioning from (e.g. to align
135 |             with the start of the year).
136 |
137 | Returns:
138 |         (np.ndarray, list, np.ndarray, list, np.ndarray, list):
139 |             Training partition of predictors X.
140 |             List of testing partitions of predictors X.
141 |             Training partition and list of testing partitions of output labels y.
142 |             Training partition and list of testing partitions of timestamps t.
143 |
144 | """
145 | # Get partitioned indexes
146 | train, tests = time_aware_indexes(t, train_size, test_size,
147 | granularity, start_date)
148 |
149 | # Partition predictors and labels
150 | X_actual, y_actual, t_actual = X[train], y[train], t[train]
151 |
152 | X_tests = [X[index_set] for index_set in tests]
153 | y_tests = [y[index_set] for index_set in tests]
154 | t_tests = [t[index_set] for index_set in tests]
155 |
156 | return X_actual, X_tests, y_actual, y_tests, t_actual, t_tests
157 |
158 |
159 | def time_aware_indexes(t, train_size, test_size, granularity, start_date=None):
160 | """Return a list of indexes that partition the list t by time.
161 |
162 | Sorts the list of dates t before dividing into training and testing
163 | partitions, ensuring a 'history-aware' split in the ensuing classification
164 | task.
165 |
166 |
167 | Args:
168 | t (np.ndarray): Array of timestamp tags.
169 | train_size (int): The training window size W (in τ).
170 | test_size (int): The testing window size Δ (in τ).
171 | granularity (str): The unit of time τ, used to denote the window size.
172 | Acceptable values are 'year|quarter|month|week|day'.
173 |         start_date (date): The date to begin partitioning from (e.g. to align
174 |             with the start of the year).
175 |
176 | Returns:
177 | (list, list):
178 | Indexing for the training partition.
179 | List of indexings for the testing partitions.
180 |
181 | """
182 | # Order the dates as well as their original positions
183 | with_indexes = zip(t, range(len(t)))
184 | ordered = sorted(with_indexes, key=operator.itemgetter(0))
185 |
186 | # Split out the dates from the indexes
187 | dates = [tup[0] for tup in ordered]
188 | indexes = [tup[1] for tup in ordered]
189 |
190 | # Get earliest date
191 | start_date = resolve_date(start_date) if start_date else ordered[0][0]
192 |
193 | # Slice out training partition
194 | boundary = start_date + get_relative_delta(train_size, granularity)
195 | to_idx = bisect.bisect_left(dates, boundary)
196 | train = indexes[:to_idx]
197 |
198 | tests = []
199 | # Slice out testing partitions
200 | while to_idx < len(indexes):
201 | boundary += get_relative_delta(test_size, granularity)
202 | from_idx = to_idx
203 | to_idx = bisect.bisect_left(dates, boundary)
204 | tests.append(indexes[from_idx:to_idx])
205 |
206 | return train, tests
207 |
208 |
209 | def time_aware_partition(t, proportion):
210 | """Partition an array of dates based on the given proportion.
211 |
212 | The set of timestamps will be bisected with the left bisection sized by
213 | the given proportion.
214 |
215 | Args:
216 | t: An array of datetimes.
217 | proportion: The proportion by which to split the array.
218 |
219 | Returns:
220 | tuple: The two bisections of the array.
221 | """
222 | # Order the dates as well as their original positions
223 | indexes = np.argsort(t)
224 |
225 | # Divide ordered set in two
226 | boundary = int(proportion * len(indexes))
227 |
228 | return indexes[:boundary], indexes[boundary:]
229 |
230 |
231 | def temporal_slice(X, y, t):
232 | raise NotImplementedError
233 |
234 |
235 | def get_relative_delta(offset, granularity):
236 | """Get delta of size 'granularity'.
237 |
238 | Args:
239 | offset: The number of time units to offset by.
240 | granularity: The unit of time to offset by, expects one of
241 | 'year', 'quarter', 'month', 'week', 'day'.
242 |
243 | Returns:
244 | The timedelta equivalent to offset * granularity.
245 |
246 | """
247 | # Make allowances for year(s), quarter(s), month(s), week(s), day(s)
248 | granularity = granularity[:-1] if granularity[-1] == 's' else granularity
249 | try:
250 | return {
251 | 'year': relativedelta(years=offset),
252 | 'quarter': relativedelta(months=3 * offset),
253 | 'month': relativedelta(months=offset),
254 | 'week': relativedelta(weeks=offset),
255 | 'day': relativedelta(days=offset),
256 | }[granularity]
257 | except KeyError:
258 | raise ValueError('granularity not recognised, try: '
259 | 'year|quarter|month|week|day')
260 |
--------------------------------------------------------------------------------
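A short usage sketch for the splitting helpers above, on synthetic weekly timestamps (not data from the repository): six months of training data followed by a sequence of one-month testing windows.

import numpy as np
from datetime import datetime, timedelta

from tools.temporal import resolve_date, time_aware_train_test_split

n = 100
X = np.random.rand(n, 5)
y = np.random.choice(2, n)
t = np.array([datetime(2014, 1, 1) + timedelta(days=7 * i) for i in range(n)])

X_train, X_tests, y_train, y_tests, t_train, t_tests = time_aware_train_test_split(
    X, y, t, train_size=6, test_size=1, granularity='month',
    start_date=resolve_date('2014-01'))
print(len(X_train), [len(x_test) for x_test in X_tests])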
/tools/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def _check_probablities(p, q=None):
5 |     assert np.all(np.asarray(p) >= 0.) and np.all(np.asarray(p) <= 1.)
6 | if q is not None:
7 | assert len(p) == len(q), \
8 | 'Probabilies and ground truth must have the same number of elements.'
9 |
10 |
11 | def entropy(p, base=2, eps=1e-10):
12 | """
13 |     calculate the entropy element-wise
14 |     :param p: probabilities
15 |     :param base: the base of the logarithm (default 2; natural log if None)
16 |     :return: element-wise entropy values
17 | """
18 | p_arr = np.asarray(p)
19 | _check_probablities(p)
20 | enc = -(p_arr * np.log(p_arr + eps) + (1. - p_arr) * np.log(1. - p_arr + eps))
21 | if base is not None:
22 | enc = np.clip(enc / np.log(base), a_min=0., a_max=1000)
23 | return enc
24 |
25 |
26 | def predictive_kld(p, number, w=None, base=2, eps=1e-10):
27 | """
28 |     calculate the Kullback-Leibler divergence element-wise
29 |     :param p: probabilities
30 |     :param number: the number of likelihood values for each sample
31 |     :param w: weights for probabilities
32 |     :param base: the base of the logarithm (default 2; natural log if None)
33 |     :return: weighted KL divergence per sample
34 | """
35 | if number <= 1:
36 | return np.zeros_like(p)
37 |
38 | p_arr = np.asarray(p).reshape((-1, number))
39 | _check_probablities(p)
40 | q_arr = np.tile(np.mean(p_arr, axis=-1, keepdims=True), [1, number])
41 | if w is None:
42 |         w_arr = np.ones(shape=(number, 1), dtype=np.float64) / number
43 | else:
44 | w_arr = np.asarray(w).reshape((number, 1))
45 |
46 | kld_elem = p_arr * np.log((p_arr + eps) / (q_arr + eps)) + (1. - p_arr) * np.log(
47 | (1. - p_arr + eps) / (1. - q_arr + eps))
48 | if base is not None:
49 | kld_elem = kld_elem / np.log(base)
50 | kld = np.matmul(kld_elem, w_arr)
51 | return kld
52 |
53 |
54 | def predictive_std(p, number, w=None):
55 | """
56 |     calculate the standard deviation of the probabilities
57 | :param p: probabilities
58 | :param number: the number of probabilities applied to each sample
59 |     :param w: weights for probabilities (uniform if None)
60 |     :return: the weighted standard deviation per sample, with the correction
61 |         factor number / (number - 1) applied
62 | """
63 | if number <= 1:
64 | return np.zeros_like(p)
65 |
66 | ps_arr = np.asarray(p).reshape((-1, number))
67 | _check_probablities(ps_arr)
68 | if w is None:
69 |         w = np.ones(shape=(number, 1), dtype=np.float64) / number
70 | else:
71 | w = np.asarray(w).reshape((number, 1))
72 |     assert np.all(w >= 0.) and np.all(w <= 1.)
73 | mean = np.matmul(ps_arr, w)
74 |     std = np.sqrt(np.matmul(np.square(ps_arr - mean), w) * (float(number) / float(number - 1)))
75 |     return std
76 |
77 |
78 | def nll(p, q, eps=1e-10, base=2):
79 | """
80 | negative log likelihood (NLL)
81 |     :param p: predicted probabilities
82 |     :param q: ground truth labels
83 |     :param eps: a small value to prevent numerical overflow
84 |     :param base: the base of the log function
85 | :return: the mean of NLL
86 | """
87 | _check_probablities(p, q)
88 | nll = -(q * np.log(p + eps) + (1. - q) * np.log(1. - p + eps))
89 | if base is not None:
90 | nll = np.clip(nll / np.log(base), a_min=0., a_max=1000)
91 | return np.mean(nll)
92 |
93 |
94 | def b_nll(p, q, eps=1e-10, base=2):
95 | """
96 | balanced negative log likelihood (NLL)
97 |     :param p: 1-D array, predicted probabilities
98 |     :param q: 1-D array, ground truth labels
99 |     :param eps: a small value to prevent numerical overflow
100 |     :param base: the base of the log function
101 | :return: the mean of NLL
102 | """
103 | _check_probablities(p, q)
104 | pos_indicator = (q == 1)
105 | pos_nll = nll(p[pos_indicator], q[pos_indicator], eps, base)
106 | neg_indicator = (q == 0)
107 | neg_nll = nll(p[neg_indicator], q[neg_indicator], eps, base)
108 | return 1. / 2. * (pos_nll + neg_nll)
109 |
110 |
111 | def brier_score(p, q, pos_label=1):
112 | """
113 | brier score
114 |     :param p: predicted probabilities
115 | :param q: ground truth labels
116 | :param pos_label: the positive class
117 | :return:
118 | """
119 | from sklearn.metrics import brier_score_loss
120 | _check_probablities(p, q)
121 | return brier_score_loss(q, p, pos_label=pos_label)
122 |
123 |
124 | def b_brier_score(p, q):
125 | """
126 | balanced brier score
127 |     :param p: predicted probabilities
128 | :param q: ground truth labels
129 | :return:
130 | """
131 | pos_indicator = (q == 1)
132 | pos_bs = brier_score(p[pos_indicator], q[pos_indicator], pos_label=None)
133 | neg_indicator = (q == 0)
134 | neg_bs = brier_score(p[neg_indicator], q[neg_indicator], pos_label=None)
135 | return 1. / 2. * (pos_bs + neg_bs)
136 |
137 |
138 | def expected_calibration_error(probabilities, ground_truth, bins=10, use_unweighted_version=True):
139 | """
140 | Code is adapted from https://github.com/google-research/google-research/tree/master/uq_benchmark_2019/metrics_lib.py
141 |     Compute the expected calibration error of a set of predictions in [0, 1].
142 | Args:
143 | probabilities: A numpy vector of N probabilities assigned to each prediction
144 | ground_truth: A numpy vector of N ground truth labels in {0,1, True, False}
145 | bins: Number of equal width bins to bin predictions into in [0, 1], or
146 | an array representing bin edges.
147 | Returns:
148 | Float: the expected calibration error.
149 | """
150 |
151 | def bin_predictions_and_accuracies(probabilities, ground_truth, bins=10):
152 | """A helper function which histograms a vector of probabilities into bins.
153 |
154 | Args:
155 | probabilities: A numpy vector of N probabilities assigned to each prediction
156 | ground_truth: A numpy vector of N ground truth labels in {0,1}
157 | bins: Number of equal width bins to bin predictions into in [0, 1], or an
158 | array representing bin edges.
159 |
160 | Returns:
161 | bin_edges: Numpy vector of floats containing the edges of the bins
162 | (including leftmost and rightmost).
163 | accuracies: Numpy vector of floats for the average accuracy of the
164 | predictions in each bin.
165 | counts: Numpy vector of ints containing the number of examples per bin.
166 | """
167 | _check_probablities(probabilities, ground_truth)
168 |
169 | if isinstance(bins, int):
170 | num_bins = bins
171 | else:
172 | num_bins = bins.size - 1
173 |
174 | # Ensure probabilities are never 0, since the bins in np.digitize are open on
175 | # one side.
176 | probabilities = np.where(probabilities == 0, 1e-8, probabilities)
177 | counts, bin_edges = np.histogram(probabilities, bins=bins, range=[0., 1.])
178 | indices = np.digitize(probabilities, bin_edges, right=True)
179 | accuracies = np.array([np.mean(ground_truth[indices == i])
180 | for i in range(1, num_bins + 1)])
181 | return bin_edges, accuracies, counts
182 |
183 | def bin_centers_of_mass(probabilities, bin_edges):
184 | probabilities = np.where(probabilities == 0, 1e-8, probabilities)
185 | indices = np.digitize(probabilities, bin_edges, right=True)
186 | return np.array([np.mean(probabilities[indices == i])
187 | for i in range(1, len(bin_edges))])
188 |
189 | probabilities = probabilities.flatten()
190 | ground_truth = ground_truth.flatten()
191 |
192 | bin_edges, accuracies, counts = bin_predictions_and_accuracies(
193 | probabilities, ground_truth, bins)
194 | bin_centers = bin_centers_of_mass(probabilities, bin_edges)
195 | num_examples = np.sum(counts)
196 | if not use_unweighted_version:
197 | ece = np.sum([(counts[i] / float(num_examples)) * np.sum(
198 | np.abs(bin_centers[i] - accuracies[i]))
199 | for i in range(bin_centers.size) if counts[i] > 0])
200 | else:
201 | ece = np.sum([(1. / float(bins)) * np.sum(
202 | np.abs(bin_centers[i] - accuracies[i]))
203 | for i in range(bin_centers.size) if counts[i] > 0])
204 | return ece
205 |
206 |
207 | def _main():
208 | gt_labels = np.random.choice(2, (1000,))
209 | prob = np.random.uniform(0., 1., (1000,))
210 | # print("Entropy:", entropy(prob))
211 | # print("negative log likelihood:", nll(prob, gt_labels))
212 | # print('balanced nll:', b_nll(prob, gt_labels))
213 | # print("brier score:", brier_score(prob, gt_labels))
214 | # print("balanced brier score:", b_brier_score(prob, gt_labels))
215 | # print("expected calibration error:", expected_calibration_error(prob, gt_labels))
216 | #
217 | # prob2 = np.random.uniform(0., 1., (1000, 10))
218 | # w = np.random.uniform(0., 1., (10,))
219 | # w = w/np.sum(w)
220 | # print("Standard deviation:", predictive_std(prob2, weights=w, number=10))
221 | # print("Kl divergence:", predictive_kld(prob2, w, number=10))
222 | # print("Standard deviation:", predictive_std(np.zeros(shape=(10, 10)), weights=w, number=10))
223 | # print("Kl divergence:", predictive_kld(np.zeros(shape=(10, 10)), w, number=10))
224 |
225 | gt_labels = np.array([1., 1., 0., 0, 0])
226 | prob = gt_labels
227 | # print("Entropy:", entropy(prob))
228 | # print("negative log likelihood:", nll(prob, gt_labels))
229 | # print('balanced nll:', b_nll(prob, gt_labels))
230 |
231 | # print("Standard deviation:", predictive_std(prob, number=2))
232 | # print("Kl divergence:", predictive_kld(prob, number=2))
233 | # print("brier score:", brier_score(prob, gt_labels))
234 | # print("expected calibration error:", expected_calibration_error(prob, gt_labels))
235 |     print("unweighted ece:", expected_calibration_error(prob, gt_labels, bins=2, use_unweighted_version=True))
236 |
237 | # gt_labels = np.array([1., 0.])
238 | # prob = 1. - gt_labels
239 | # print("Entropy:", entropy(prob))
240 | # print("negative log likelihood:", nll(prob, gt_labels))
241 | # print("brier score:", brier_score(prob, gt_labels))
242 | # print("expected calibration error:", expected_calibration_error(prob, gt_labels))
243 | return 0
244 |
245 |
246 | if __name__ == '__main__':
247 | _main()
248 |
--------------------------------------------------------------------------------
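A brief sketch of how the metrics above compose, evaluated on random probabilities rather than real detector output:

import numpy as np

from tools.metrics import entropy, expected_calibration_error, nll, predictive_std

gt = np.random.choice(2, 1000)                 # ground-truth labels in {0, 1}
prob = np.random.uniform(0., 1., 1000)         # single-model probabilities
print('mean entropy:', entropy(prob).mean())
print('nll:', nll(prob, gt))
print('ece:', expected_calibration_error(prob, gt, bins=10))

member_prob = np.random.uniform(0., 1., (1000, 10))   # 10 ensemble members per sample
print('predictive std of first sample:', predictive_std(member_prob, number=10)[0])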
/tools/progressbar/progressbar.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # progressbar - Text progress bar library for Python.
5 | # Copyright (c) 2005 Nilton Volpato
6 | #
7 | # This library is free software; you can redistribute it and/or
8 | # modify it under the terms of the GNU Lesser General Public
9 | # License as published by the Free Software Foundation; either
10 | # version 2.1 of the License, or (at your option) any later version.
11 | #
12 | # This library is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | # Lesser General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU Lesser General Public
18 | # License along with this library; if not, write to the Free Software
19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 |
21 | """Main ProgressBar class."""
22 |
23 | from __future__ import division
24 |
25 | import math
26 | import os
27 | import signal
28 | import sys
29 | import time
30 |
31 | try:
32 | from fcntl import ioctl
33 | from array import array
34 | import termios
35 | except ImportError:
36 | pass
37 |
38 | from tools.progressbar.compat import * # for: any, next
39 | from . import widgets
40 | # import .widgets
41 |
42 |
43 | class UnknownLength: pass
44 |
45 |
46 | class ProgressBar(object):
47 | """The ProgressBar class which updates and prints the bar.
48 |
49 | A common way of using it is like:
50 | >>> pbar = ProgressBar().start()
51 | >>> for i in range(100):
52 | ... # do something
53 | ... pbar.update(i+1)
54 | ...
55 | >>> pbar.finish()
56 |
57 | You can also use a ProgressBar as an iterator:
58 | >>> progress = ProgressBar()
59 | >>> for i in progress(some_iterable):
60 | ... # do something
61 | ...
62 |
63 | Since the progress bar is incredibly customizable you can specify
64 | different widgets of any type in any order. You can even write your own
65 | widgets! However, since there are already a good number of widgets you
66 | should probably play around with them before moving on to create your own
67 | widgets.
68 |
69 | The term_width parameter represents the current terminal width. If the
70 | parameter is set to an integer then the progress bar will use that,
71 | otherwise it will attempt to determine the terminal width falling back to
72 | 80 columns if the width cannot be determined.
73 |
74 | When implementing a widget's update method you are passed a reference to
75 | the current progress bar. As a result, you have access to the
76 | ProgressBar's methods and attributes. Although there is nothing preventing
77 | you from changing the ProgressBar you should treat it as read only.
78 |
79 | Useful methods and attributes include (Public API):
80 | - currval: current progress (0 <= currval <= maxval)
81 | - maxval: maximum (and final) value
82 | - finished: True if the bar has finished (reached 100%)
83 | - start_time: the time when start() method of ProgressBar was called
84 | - seconds_elapsed: seconds elapsed since start_time and last call to
85 | update
86 | - percentage(): progress in percent [0..100]
87 | """
88 |
89 | __slots__ = ('currval', 'fd', 'finished', 'last_update_time',
90 | 'left_justify', 'maxval', 'next_update', 'num_intervals',
91 | 'poll', 'seconds_elapsed', 'signal_set', 'start_time',
92 | 'term_width', 'update_interval', 'widgets', '_time_sensitive',
93 | '__iterable')
94 |
95 | _DEFAULT_MAXVAL = 100
96 | _DEFAULT_TERMSIZE = 80
97 | _DEFAULT_WIDGETS = [widgets.Percentage(), ' ', widgets.Bar()]
98 |
99 | def __init__(self, maxval=None, widgets=None, term_width=None, poll=1,
100 | left_justify=True, fd=sys.stderr):
101 | """Initializes a progress bar with sane defaults."""
102 |
103 | # Don't share a reference with any other progress bars
104 | if widgets is None:
105 | widgets = list(self._DEFAULT_WIDGETS)
106 |
107 | self.maxval = maxval
108 | self.widgets = widgets
109 | self.fd = fd
110 | self.left_justify = left_justify
111 |
112 | self.signal_set = False
113 | if term_width is not None:
114 | self.term_width = term_width
115 | else:
116 | try:
117 | self._handle_resize()
118 | signal.signal(signal.SIGWINCH, self._handle_resize)
119 | self.signal_set = True
120 | except (SystemExit, KeyboardInterrupt): raise
121 | except:
122 | self.term_width = self._env_size()
123 |
124 | self.__iterable = None
125 | self._update_widgets()
126 | self.currval = 0
127 | self.finished = False
128 | self.last_update_time = None
129 | self.poll = poll
130 | self.seconds_elapsed = 0
131 | self.start_time = None
132 | self.update_interval = 1
133 | self.next_update = 0
134 |
135 |
136 | def __call__(self, iterable):
137 | """Use a ProgressBar to iterate through an iterable."""
138 |
139 | try:
140 | self.maxval = len(iterable)
141 | except:
142 | if self.maxval is None:
143 | self.maxval = UnknownLength
144 |
145 | self.__iterable = iter(iterable)
146 | return self
147 |
148 |
149 | def __iter__(self):
150 | return self
151 |
152 |
153 | def __next__(self):
154 | try:
155 | value = next(self.__iterable)
156 | if self.start_time is None:
157 | self.start()
158 | else:
159 | self.update(self.currval + 1)
160 | return value
161 | except StopIteration:
162 | if self.start_time is None:
163 | self.start()
164 | self.finish()
165 | raise
166 |
167 |
168 | # Create an alias so that Python 2.x won't complain about not being
169 | # an iterator.
170 | next = __next__
171 |
172 |
173 | def _env_size(self):
174 | """Tries to find the term_width from the environment."""
175 |
176 | return int(os.environ.get('COLUMNS', self._DEFAULT_TERMSIZE)) - 1
177 |
178 |
179 | def _handle_resize(self, signum=None, frame=None):
180 | """Tries to catch resize signals sent from the terminal."""
181 |
182 | h, w = array('h', ioctl(self.fd, termios.TIOCGWINSZ, '\0' * 8))[:2]
183 | self.term_width = w
184 |
185 |
186 | def percentage(self):
187 | """Returns the progress as a percentage."""
188 | if self.currval >= self.maxval:
189 | return 100.0
190 | return self.currval * 100.0 / self.maxval
191 |
192 | percent = property(percentage)
193 |
194 |
195 | def _format_widgets(self):
196 | result = []
197 | expanding = []
198 | width = self.term_width
199 |
200 | for index, widget in enumerate(self.widgets):
201 | if isinstance(widget, widgets.WidgetHFill):
202 | result.append(widget)
203 | expanding.insert(0, index)
204 | else:
205 | widget = widgets.format_updatable(widget, self)
206 | result.append(widget)
207 | width -= len(widget)
208 |
209 | count = len(expanding)
210 | while count:
211 | portion = max(int(math.ceil(width * 1. / count)), 0)
212 | index = expanding.pop()
213 | count -= 1
214 |
215 | widget = result[index].update(self, portion)
216 | width -= len(widget)
217 | result[index] = widget
218 |
219 | return result
220 |
221 |
222 | def _format_line(self):
223 | """Joins the widgets and justifies the line."""
224 |
225 | widgets = ''.join(self._format_widgets())
226 |
227 | if self.left_justify: return widgets.ljust(self.term_width)
228 | else: return widgets.rjust(self.term_width)
229 |
230 |
231 | def _need_update(self):
232 | """Returns whether the ProgressBar should redraw the line."""
233 | if self.currval >= self.next_update or self.finished: return True
234 |
235 | delta = time.time() - self.last_update_time
236 | return self._time_sensitive and delta > self.poll
237 |
238 |
239 | def _update_widgets(self):
240 | """Checks all widgets for the time sensitive bit."""
241 |
242 | self._time_sensitive = any(getattr(w, 'TIME_SENSITIVE', False)
243 | for w in self.widgets)
244 |
245 |
246 | def update(self, value=None):
247 | """Updates the ProgressBar to a new value."""
248 |
249 | if value is not None and value is not UnknownLength:
250 | if (self.maxval is not UnknownLength
251 | and not 0 <= value <= self.maxval):
252 |
253 | raise ValueError('Value out of range')
254 |
255 | self.currval = value
256 |
257 |
258 | if not self._need_update(): return
259 | if self.start_time is None:
260 | raise RuntimeError('You must call "start" before calling "update"')
261 |
262 | now = time.time()
263 | self.seconds_elapsed = now - self.start_time
264 | self.next_update = self.currval + self.update_interval
265 | self.fd.write(self._format_line() + '\r')
266 | self.last_update_time = now
267 |
268 |
269 | def start(self):
270 | """Starts measuring time, and prints the bar at 0%.
271 |
272 | It returns self so you can use it like this:
273 | >>> pbar = ProgressBar().start()
274 | >>> for i in range(100):
275 | ... # do something
276 | ... pbar.update(i+1)
277 | ...
278 | >>> pbar.finish()
279 | """
280 |
281 | if self.maxval is None:
282 | self.maxval = self._DEFAULT_MAXVAL
283 |
284 | self.num_intervals = max(100, self.term_width)
285 | self.next_update = 0
286 |
287 | if self.maxval is not UnknownLength:
288 | if self.maxval < 0: raise ValueError('Value out of range')
289 | self.update_interval = self.maxval / self.num_intervals
290 |
291 |
292 | self.start_time = self.last_update_time = time.time()
293 | self.update(0)
294 |
295 | return self
296 |
297 |
298 | def finish(self):
299 | """Puts the ProgressBar bar in the finished state."""
300 |
301 | if self.finished:
302 | return
303 | self.finished = True
304 | self.update(self.maxval)
305 | self.fd.write('\n')
306 | if self.signal_set:
307 | signal.signal(signal.SIGWINCH, signal.SIG_DFL)
308 |
--------------------------------------------------------------------------------
/experiments/drebin_dataset.py:
--------------------------------------------------------------------------------
1 | # conduct the first group experiments on drebin dataset
2 | import os
3 |
4 | import numpy as np
5 | from sklearn.model_selection import train_test_split
6 |
7 | from core.feature import feature_type_scope_dict, feature_type_vs_architecture
8 | from core.ensemble import ensemble_method_scope_dict
9 | from tools import utils
10 | from config import config, logging
11 |
12 | logger = logging.getLogger('experiment.drebin')
13 |
14 |
15 | # procedure of drebin experiments
16 | # 1. build dataset
17 | # 2. preprocess data
18 | # 3. learn models
19 | # 4. save results for statistical analysis
20 |
21 | def run_experiment(feature_type, ensemble_type, random_seed=0, n_members=1, ratio=3.0, proc_numbers=2):
22 | """
23 | run this group of experiments
24 | :param feature_type: the type of feature (e.g., drebin, opcode, etc.), feature type associates to the model architecture
25 |     :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc.)
26 | :param random_seed: an integer
27 | :param n_members: the number of base models enclosed by an ensemble
28 | :param ratio: the ratio of benign files to malware files
29 | :param proc_numbers: the number of threads
30 | :return: None
31 | """
32 | mal_folder = config.get('drebin', 'malware_dir')
33 | ben_folder = config.get('drebin', 'benware_dir')
34 | logger.info('testing:{},{}'.format(feature_type, ensemble_type))
35 | logger.info('The seed is :{}'.format(random_seed))
36 |
37 | train_dataset, validation_dataset, test_data, test_y, input_dim = \
38 | data_preprocessing(feature_type, mal_folder, ben_folder, ratio, proc_numbers, random_seed)
39 |
40 | ensemble_obj = get_ensemble_object(ensemble_type)
41 | # instantiation
42 | arch_type = feature_type_vs_architecture.get(feature_type)
43 | saving_dir = config.get('experiments', 'drebin')
44 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']:
45 | ensemble_model = ensemble_obj(arch_type, base_model=None, n_members=1, model_directory=saving_dir)
46 | else:
47 | ensemble_model = ensemble_obj(arch_type, base_model=None, n_members=n_members, model_directory=saving_dir)
48 |
49 | ensemble_model.fit(train_dataset, validation_dataset, input_dim=input_dim)
50 |
51 | test_results = ensemble_model.predict(test_data)
52 | utils.dump_joblib((test_results, test_y),
53 | os.path.join(saving_dir, '{}_{}_test.res'.format(feature_type, ensemble_type)))
54 | ensemble_model.evaluate(test_data, test_y)
55 | return
56 |
57 |
58 | def run_temperature_scaling(feature_type, ensemble_type):
59 | """
60 | Run temperature scaling
61 | :param feature_type: the type of feature (e.g., drebin, opcode, etc.), feature type associates to the model architecture
62 |     :param ensemble_type: the ensemble method (e.g., vanilla, deep_ensemble, etc.)
63 | """
64 | from core.post_calibration.temperature_scaling import find_scaling_temperature, \
65 | apply_temperature_scaling, inverse_sigmoid
66 | logger.info('run temperature scaling:{},{}'.format(feature_type, ensemble_type))
67 |
68 | # load dataset
69 | def data_load(feature_type='drebin'):
70 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format(
71 | feature_type, list(feature_type_scope_dict.keys()))
72 |
73 | android_features_saving_dir = config.get('metadata', 'naive_data_pool')
74 | intermediate_data_saving_dir = config.get('drebin', 'intermediate_directory')
75 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir,
76 | intermediate_data_saving_dir,
77 | update=False,
78 | proc_number=1)
79 |
80 | save_path = os.path.join(intermediate_data_saving_dir, 'drebin_database.' + feature_type)
81 | if os.path.exists(save_path):
82 | _1, test_filenames, validation_filenames, \
83 | _2, test_y, validation_y = utils.read_joblib(save_path)
84 | validation_features = [os.path.join(android_features_saving_dir, filename) for filename in
85 | validation_filenames]
86 | test_features = [os.path.join(android_features_saving_dir, filename) for filename in
87 | test_filenames]
88 | else:
89 | raise ValueError
90 |
91 | test_data, _ = feature_extractor.feature2ipt(test_features)
92 | validation_data, _ = feature_extractor.feature2ipt(validation_features)
93 |
94 | return validation_data, test_data, validation_y, test_y
95 |
96 |     # the malware/benign folders are not needed here; data_load reads the
97 |     # intermediate split produced by data_preprocessing
98 |     val_data, test_data, val_label, test_y = \
99 |         data_load(feature_type)
100 |
101 | # load model
102 | ensemble_obj = get_ensemble_object(ensemble_type)
103 | arch_type = feature_type_vs_architecture.get(feature_type)
104 | model_saving_dir = config.get('experiments', 'drebin')
105 | if ensemble_type in ['vanilla', 'mc_dropout', 'bayesian']:
106 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
107 | else:
108 | ensemble_model = ensemble_obj(arch_type, base_model=None, model_directory=model_saving_dir)
109 |
110 | temp_save_dir = os.path.join(config.get('drebin', 'intermediate_directory'),
111 | "{}_{}_temp.json".format(feature_type, ensemble_type))
112 | if not os.path.exists(temp_save_dir):
113 | prob = np.squeeze(ensemble_model.predict(val_data, use_prob=True))
114 | logits = inverse_sigmoid(prob)
115 | temperature = find_scaling_temperature(val_label, logits)
116 | utils.dump_json({'temperature': temperature}, temp_save_dir)
117 | temperature = utils.load_json(temp_save_dir)['temperature']
118 |
119 | prob_test = ensemble_model.predict(test_data, use_prob=True)
120 | prob_t = apply_temperature_scaling(temperature, prob_test)
121 | utils.dump_joblib((prob_t, test_y),
122 | os.path.join(model_saving_dir, '{}_{}_temperature_test.res'.format(feature_type, ensemble_type)))
123 |
124 |
125 | def data_preprocessing(feature_type='drebin', malware_dir=None, benware_dir=None, ratio=3.0, proc_numbers=2,
126 | random_seed=0):
127 | assert feature_type in feature_type_scope_dict.keys(), 'Expected {}, but {} are supported.'.format(
128 | feature_type, list(feature_type_scope_dict.keys()))
129 |
130 | android_features_saving_dir = config.get('metadata', 'naive_data_pool')
131 | intermediate_data_saving_dir = config.get('drebin', 'intermediate_directory')
132 | feature_extractor = feature_type_scope_dict[feature_type](android_features_saving_dir,
133 | intermediate_data_saving_dir,
134 | update=False,
135 | proc_number=proc_numbers)
136 | save_path = os.path.join(intermediate_data_saving_dir, 'drebin_database.' + feature_type)
137 |
138 | if os.path.exists(save_path):
139 | train_filenames, test_filenames, validation_filenames, \
140 | train_y, test_y, validation_y = utils.read_joblib(save_path)
141 |
142 | train_features = [os.path.join(config.get('metadata', 'naive_data_pool'), filename) for filename in
143 | train_filenames]
144 | validation_features = [os.path.join(config.get('metadata', 'naive_data_pool'), filename) for filename in
145 | validation_filenames]
146 | test_features = [os.path.join(config.get('metadata', 'naive_data_pool'), filename) for filename in
147 | test_filenames]
148 | else:
149 | def train_test_val_split(data):
150 | train, test = train_test_split(data, test_size=0.2, random_state=random_seed)
151 | train, val = train_test_split(train, test_size=0.25, random_state=random_seed)
152 | return train, val, test
153 |
154 | def merge_mal_ben(mal, ben):
155 | mal_feature_list = feature_extractor.feature_extraction(mal)
156 | n_malware = len(mal_feature_list)
157 | ben_feature_list = feature_extractor.feature_extraction(ben)
158 | n_benware = len(ben_feature_list)
159 | feature_list = mal_feature_list + ben_feature_list
160 | gt_labels = np.zeros((n_malware + n_benware,), dtype=np.int32)
161 | gt_labels[:n_malware] = 1
162 | import random
163 | random.seed(0)
164 | random.shuffle(feature_list)
165 | random.seed(0)
166 | random.shuffle(gt_labels)
167 | return feature_list, gt_labels
168 |
169 | malware_path_list = utils.retrive_files_set(malware_dir, "", ".apk|")
170 | mal_train, mal_val, mal_test = train_test_val_split(malware_path_list)
171 | benware_path_list = utils.retrive_files_set(benware_dir, "", ".apk|")
172 | ben_train, ben_val, ben_test = train_test_val_split(benware_path_list)
173 |
174 | # undersampling the benign files
175 | if ratio >= 1.:
176 | ben_train = undersampling(ben_train, len(mal_train), ratio)
177 | logger.info('Training set, the number of benign files vs. malware files: {} vs. {}'.format(len(ben_train),
178 | len(mal_train)))
179 | train_features, train_y = merge_mal_ben(mal_train, ben_train)
180 | validation_features, validation_y = merge_mal_ben(mal_val, ben_val)
181 | test_features, test_y = merge_mal_ben(mal_test, ben_test)
182 |
183 | train_filenames = [os.path.basename(path) for path in train_features]
184 | validation_filenames = [os.path.basename(path) for path in validation_features]
185 | test_filenames = [os.path.basename(path) for path in test_features]
186 | utils.dump_joblib(
187 | (train_filenames, test_filenames, validation_filenames, train_y, test_y, validation_y),
188 | save_path)
189 | # obtain data in a format for ML algorithms
190 | feature_extractor.feature_preprocess(train_features, train_y) # produce datasets products
191 | train_dataset, input_dim = feature_extractor.feature2ipt(train_features, train_y, is_training_set=True)
192 | test_data, _ = feature_extractor.feature2ipt(test_features)
193 | validation_dataset, _ = feature_extractor.feature2ipt(validation_features, validation_y)
194 | return train_dataset, validation_dataset, test_data, test_y, input_dim
195 |
196 |
197 | def undersampling(ben_train, num_of_mal, ratio):
198 | number_of_choice = int(num_of_mal * ratio) if int(num_of_mal * ratio) <= len(ben_train) else len(ben_train)
199 | import random
200 | random.seed(0)
201 | random.shuffle(ben_train)
202 | return ben_train[:number_of_choice]
203 |
204 |
205 | def get_ensemble_object(ensemble_type):
206 | assert ensemble_type in ensemble_method_scope_dict.keys(), '{} expected, but {} are supported'.format(
207 | ensemble_type,
208 | ','.join(ensemble_method_scope_dict.keys())
209 | )
210 | return ensemble_method_scope_dict[ensemble_type]
211 |
--------------------------------------------------------------------------------
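The merge_mal_ben helper above keeps features and labels aligned by re-seeding the random module with the same value before each shuffle; for two sequences of equal length this applies the identical permutation to both. A minimal sketch of that trick with toy values:

import random

import numpy as np

features = ['mal_0', 'mal_1', 'ben_0', 'ben_1', 'ben_2']
labels = np.array([1, 1, 0, 0, 0])

random.seed(0)
random.shuffle(features)
random.seed(0)
random.shuffle(labels)      # same seed and same length => same permutation as above
print(features, labels)      # each label still matches its original feature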
/core/feature/multimodality/multimodality.py:
--------------------------------------------------------------------------------
1 | """
2 | - String feature
3 | - Permission feature
4 | - Component feature
5 | - Environmental feature
6 | - Method opcode feature
7 | - Method API feature
8 | - Shared library function opcode feature
9 | """
10 |
11 | import os
12 | import time
13 |
14 | import magic
15 | import zipfile
16 | import hashlib
17 | from collections import defaultdict
18 |
19 | from androguard.misc import AnalyzeAPK, APK
20 | import lxml.etree as etree
21 | from xml.dom import minidom
22 |
23 | from elftools.elf.elffile import ELFFile
24 | from elftools.common.py3compat import BytesIO
25 | from capstone import Cs, CS_ARCH_ARM64, CS_ARCH_ARM, CS_MODE_ARM
26 |
27 | from tools import utils
28 | from sys import platform as _platform
29 | from config import logging
30 |
31 | if _platform == "win32" or _platform == "win64":
32 | TMP_DIR = 'C:\\TEMP\\'
33 | else:  # linux, linux2, darwin, and other POSIX-like platforms
34 | TMP_DIR = '/tmp/'
35 | logger = logging.getLogger('feature.multimodality')
36 | current_dir = os.path.dirname(os.path.realpath(__file__))
37 | API_LIST = [api.decode('utf-8') for api in \
38 | utils.read_txt(os.path.join(current_dir, 'res/api_list.txt'), mode='rb')]
39 |
40 |
41 | def get_multimod_feature(apk_path, api_list, save_path):
42 | """
43 | extract the multi-modality features of an apk
44 | :param apk_path: an absolute apk path
45 | :param api_list: api list, api level 22 is considered
46 | :param save_path: a path for saving the feature
47 | :return: the save_path on success, or the caught Exception object on failure
48 | """
49 | try:
50 | print("Processing " + apk_path)
51 | start_time = time.time()
52 | data_list = []
53 | # permission, components, environment information
54 | data_list.append(get_dict_feature_xml(apk_path))
55 | # string, method api, method opcodes
56 | string_feature_dict, api_feature_dict, opcode_feature_dict = get_feature_dex(apk_path, api_list)
57 | data_list.extend([string_feature_dict, api_feature_dict, opcode_feature_dict])
58 | # shared library feature
59 | data_list.append(get_feature_shared_lib(apk_path))
60 |
61 | utils.dump_json(data_list, save_path)
62 | except Exception as e:
63 | return e
64 | else:
65 | return save_path
66 |
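# Usage sketch (illustrative only; the apk path below is a placeholder):
#
#   result = get_multimod_feature('/path/to/sample.apk', API_LIST,
#                                 os.path.join(TMP_DIR, 'sample.data'))
#   if isinstance(result, Exception):
#       print('feature extraction failed:', result)
#   else:
#       feature_dicts = load_feature(result)  # five dicts: manifest, string,
#                                             # api, opcode, shared-library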
67 |
68 | def get_dict_feature_xml(apk_path):
69 | """
70 | get requested feature from manifest file
71 | :param apk_path: absolute path of an apk file
72 | :return: a dict of elements {feature:occurrence, ..., }
73 | """
74 | permission_component_envinfo = defaultdict(int)
75 | xml_tmp_dir = os.path.join(TMP_DIR, 'xml_dir')
76 | if not os.path.exists(xml_tmp_dir):
77 | os.mkdir(xml_tmp_dir)
78 | apk_name = os.path.splitext(os.path.basename(apk_path))[0]
79 | try:
80 | apk_path = os.path.abspath(apk_path)
81 | a = APK(apk_path)
82 | f = open(os.path.join(xml_tmp_dir, apk_name + '.xml'), 'wb')
83 | xmlstreaming = etree.tostring(a.xml['AndroidManifest.xml'], pretty_print=True, encoding='utf-8')
84 | f.write(xmlstreaming)
85 | f.close()
86 | except Exception as e:
87 | raise Exception("Fail to load xml file of apk {}:{}".format(apk_path, str(e)))
88 |
89 | # extract permissions, components, and environment information
90 | try:
91 | with open(os.path.join(xml_tmp_dir, apk_name + '.xml'), 'rb') as f:
92 | dom_xml = minidom.parse(f)
93 | dom_elements = dom_xml.documentElement
94 |
95 | dom_permissions = dom_elements.getElementsByTagName('uses-permission')
96 | for permission in dom_permissions:
97 | if permission.hasAttribute('android:name'):
98 | permission_component_envinfo[permission.getAttribute('android:name')] = 1
99 |
100 | dom_activities = dom_elements.getElementsByTagName('activity')
101 | for activity in dom_activities:
102 | if activity.hasAttribute('android:name'):
103 | permission_component_envinfo[activity.getAttribute('android:name')] = 1
104 |
105 | dom_services = dom_elements.getElementsByTagName("service")
106 | for service in dom_services:
107 | if service.hasAttribute("android:name"):
108 | permission_component_envinfo[service.getAttribute("android:name")] = 1
109 |
110 | dom_contentproviders = dom_elements.getElementsByTagName("provider")
111 | for provider in dom_contentproviders:
112 | if provider.hasAttribute("android:name"):
113 | permission_component_envinfo[provider.getAttribute("android:name")] = 1
114 | # uri
115 | dom_uris = provider.getElementsByTagName('grant-uri-permission')
116 | # intents --- action
117 | intent_actions = provider.getElementsByTagName('action')
118 | # we do not compose the uri feature from the paired provider name and intents;
119 | # instead, the uri path is used directly
120 | for uri in dom_uris:
121 | if uri.hasAttribute('android:path'):
122 | permission_component_envinfo[uri.getAttribute('android:path')] = 1
123 |
124 | dom_broadcastreceivers = dom_elements.getElementsByTagName("receiver")
125 | for receiver in dom_broadcastreceivers:
126 | if receiver.hasAttribute("android:name"):
127 | permission_component_envinfo[receiver.getAttribute("android:name")] = 1
128 |
129 | dom_intentfilter_actions = dom_elements.getElementsByTagName("action")
130 | for action in dom_intentfilter_actions:
131 | if action.hasAttribute("android:name"):
132 | permission_component_envinfo[action.getAttribute("android:name")] = 1
133 |
134 | dom_hardwares = dom_elements.getElementsByTagName("uses-feature")
135 | for hardware in dom_hardwares:
136 | if hardware.hasAttribute("android:name"):
137 | permission_component_envinfo[hardware.getAttribute("android:name")] = 1
138 |
139 | dom_libraries = dom_elements.getElementsByTagName("android:name")
140 | for lib in dom_libraries:
141 | if lib.hasAttribute('android:name'):
142 | permission_component_envinfo[lib.getAttribute('android:name')] = 1
143 |
144 | dom_sdk_versions = dom_elements.getElementsByTagName("uses-sdk")
145 | for sdk in dom_sdk_versions:
146 | if sdk.hasAttribute('android:minSdkVersion'):
147 | permission_component_envinfo[sdk.getAttribute('android:minSdkVersion')] = 1
148 | if sdk.hasAttribute('android:targetSdkVersion'):
149 | permission_component_envinfo[sdk.getAttribute('android:targetSdkVersion')] = 1
150 | if sdk.hasAttribute('android:maxSdkVersion'):
151 | permission_component_envinfo[sdk.getAttribute('android:maxSdkVersion')] = 1
152 |
153 | return permission_component_envinfo
154 | except Exception as e:
155 | raise Exception("Fail to process xml file of apk {}:{}".format(apk_path, str(e)))
156 |
157 |
158 | def get_feature_dex(apk_path, api_list):
159 | """
160 | extract opcode, API, and string features from .dex files
161 | :param apk_path: an absolute path of an apk
162 | :param api_list: a list of apis
163 | :return: three dicts of elements {feature: occurrence, ...} for the string, API, and opcode features, respectively
164 | """
165 | string_feature_dict = defaultdict(int)
166 | opcode_feature_dict = defaultdict(int)
167 | api_feature_dict = defaultdict(int)
168 |
169 | try:
170 | _1, _2, dx = AnalyzeAPK(apk_path)
171 | except Exception as e:
172 | raise Exception('Failed to load the dex file of apk {}: {}'.format(apk_path, str(e)))
173 |
174 | for method in dx.get_methods():
175 | if method.is_external():
176 | continue
177 | method_obj = method.get_method()
178 | for instruction in method_obj.get_instructions():
179 | opcode = instruction.get_name()
180 | opcode_feature_dict[opcode] += 1
181 | if 'invoke-' in opcode:
182 | code_body = instruction.get_output()
183 | if ';->' not in code_body:
184 | continue
185 | head_part, rear_part = code_body.split(';->')
186 | class_name = head_part.strip().split(' ')[-1]
187 | method_name = rear_part.strip().split('(')[0]
188 | if class_name + ';->' + method_name in api_list:
189 | api_feature_dict[class_name + ';->' + method_name] += 1
190 |
191 | if (opcode == 'const-string') or (opcode == 'const-string/jumbo'):
192 | code_body = instruction.get_output()
193 | ss_string = code_body.split(' ')[-1].strip('\'').strip('\"').encode('utf-8')
194 | if not ss_string:
195 | continue
196 | hashing_string = hashlib.sha512(ss_string).hexdigest()
197 | string_feature_dict[hashing_string] = 1
198 |
199 | return string_feature_dict, api_feature_dict, opcode_feature_dict
200 |
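# Worked example (hypothetical instruction output; the exact register
# formatting may vary between androguard versions): for an 'invoke-virtual'
# whose output is
#
#   'v1, Ljava/net/URL;->openConnection()Ljava/net/URLConnection;'
#
# splitting on ';->' gives head 'v1, Ljava/net/URL' and rear
# 'openConnection()Ljava/net/URLConnection;', so class_name becomes
# 'Ljava/net/URL', method_name becomes 'openConnection', and the key looked up
# in api_list is 'Ljava/net/URL;->openConnection'.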
201 |
202 | def get_feature_shared_lib(apk_path):
203 | """
204 | get feature from ELF files
205 | :param apk_path: an absolute path of an apk
206 | :return: a dict of elements {feature:occurrence, ..., }
207 | """
208 | feature_dict = defaultdict(int)
209 | try:
210 | with zipfile.ZipFile(apk_path, 'r') as apk_zipf:
211 | handled_name_list = []
212 | for name in apk_zipf.namelist():
213 | if os.path.basename(name) in handled_name_list:
214 | continue
215 | brief_file_info = magic.from_buffer(apk_zipf.read(name))
216 | if 'ELF' not in brief_file_info:
217 | continue
218 | if 'ELF 64' in brief_file_info:
219 | arch_info = CS_ARCH_ARM64
220 | elif 'ELF 32' in brief_file_info:
221 | arch_info = CS_ARCH_ARM
222 | else:
223 | raise ValueError("Android ABIs with Arm 64-bit or Arm 32 bit are support.")
224 |
225 | with apk_zipf.open(name, 'r') as fr:
226 | bt_stream = BytesIO(fr.read())
227 | elf_fhander = ELFFile(bt_stream)
228 | # arm opcodes (i.e., instructions)
229 | bt_code_text = elf_fhander.get_section_by_name('.text')
230 | if (bt_code_text is not None) and (bt_code_text.data() is not None):
231 | md = Cs(arch_info, CS_MODE_ARM)
232 | for _1, _2, op_code, _3 in md.disasm_lite(bt_code_text.data(), 0x1000):
233 | feature_dict[op_code] += 1
234 | # functions
235 | # we only record whether a function occurs, rather than its frequency,
236 | # to simplify the implementation.
237 | # If the code only needs to run on Linux, the ARM-aware tool 'objdump' can be used
238 | # to count function frequencies conveniently. We use 'pyelftools' here
239 | # to accommodate different development platforms.
240 | REL_PLT_section = elf_fhander.get_section_by_name('.rela.plt')
241 | if REL_PLT_section is None:
242 | continue
243 | symtable = elf_fhander.get_section(REL_PLT_section['sh_link'])
244 |
245 | for rel_plt in REL_PLT_section.iter_relocations():
246 | func_name = symtable.get_symbol(rel_plt['r_info_sym']).name
247 | feature_dict[func_name] += 1
248 |
249 | handled_name_list.append(os.path.basename(name))
250 | return feature_dict
251 | except Exception as e:
252 | raise Exception("Fail to process shared library file of apk {}:{}".format(apk_path, str(e)))
253 |
254 |
255 | def dump_feaure_list(feature_list, save_path):
256 | utils.dump_json(feature_list, save_path)
257 | return
258 |
259 |
260 | def load_feature(save_path):
261 | return utils.load_json(save_path)
262 |
263 |
264 | def wrapper_load_features(save_path):
265 | try:
266 | return load_feature(save_path)
267 | except Exception as e:
268 | return e
269 |
--------------------------------------------------------------------------------
/tools/progressbar/widgets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # progressbar - Text progress bar library for Python.
5 | # Copyright (c) 2005 Nilton Volpato
6 | #
7 | # This library is free software; you can redistribute it and/or
8 | # modify it under the terms of the GNU Lesser General Public
9 | # License as published by the Free Software Foundation; either
10 | # version 2.1 of the License, or (at your option) any later version.
11 | #
12 | # This library is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
15 | # Lesser General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU Lesser General Public
18 | # License along with this library; if not, write to the Free Software
19 | # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
20 |
21 | """Default ProgressBar widgets."""
22 |
23 | from __future__ import division
24 |
25 | import datetime
26 | import math
27 |
28 | try:
29 | from abc import ABCMeta, abstractmethod
30 | except ImportError:
31 | AbstractWidget = object
32 | abstractmethod = lambda fn: fn
33 | else:
34 | AbstractWidget = ABCMeta('AbstractWidget', (object,), {})
35 |
36 |
37 | def format_updatable(updatable, pbar):
38 | if hasattr(updatable, 'update'): return updatable.update(pbar)
39 | else: return updatable
40 |
41 |
42 | class Widget(AbstractWidget):
43 | """The base class for all widgets.
44 |
45 | The ProgressBar will call the widget's update value when the widget should
46 | be updated. The widget's size may change between calls, but the widget may
47 | display incorrectly if the size changes drastically and repeatedly.
48 |
49 | The boolean TIME_SENSITIVE informs the ProgressBar that it should be
50 | updated more often because it is time sensitive.
51 | """
52 |
53 | TIME_SENSITIVE = False
54 | __slots__ = ()
55 |
56 | @abstractmethod
57 | def update(self, pbar):
58 | """Updates the widget.
59 |
60 | pbar - a reference to the calling ProgressBar
61 | """
62 |
63 |
64 | class WidgetHFill(Widget):
65 | """The base class for all variable width widgets.
66 |
67 | This widget is much like the \\hfill command in TeX, it will expand to
68 | fill the line. You can use more than one in the same line, and they will
69 | all have the same width, and together will fill the line.
70 | """
71 |
72 | @abstractmethod
73 | def update(self, pbar, width):
74 | """Updates the widget providing the total width the widget must fill.
75 |
76 | pbar - a reference to the calling ProgressBar
77 | width - The total width the widget must fill
78 | """
79 |
80 |
81 | class Timer(Widget):
82 | """Widget which displays the elapsed seconds."""
83 |
84 | __slots__ = ('format_string',)
85 | TIME_SENSITIVE = True
86 |
87 | def __init__(self, format='Elapsed Time: %s'):
88 | self.format_string = format
89 |
90 | @staticmethod
91 | def format_time(seconds):
92 | """Formats time as the string "HH:MM:SS"."""
93 |
94 | return str(datetime.timedelta(seconds=int(seconds)))
95 |
96 |
97 | def update(self, pbar):
98 | """Updates the widget to show the elapsed time."""
99 |
100 | return self.format_string % self.format_time(pbar.seconds_elapsed)
101 |
102 |
103 | class ETA(Timer):
104 | """Widget which attempts to estimate the time of arrival."""
105 |
106 | TIME_SENSITIVE = True
107 |
108 | def update(self, pbar):
109 | """Updates the widget to show the ETA or total time when finished."""
110 |
111 | if pbar.currval == 0:
112 | return 'ETA: --:--:--'
113 | elif pbar.finished:
114 | return 'Time: %s' % self.format_time(pbar.seconds_elapsed)
115 | else:
116 | elapsed = pbar.seconds_elapsed
117 | eta = elapsed * pbar.maxval / pbar.currval - elapsed
118 | return 'ETA: %s' % self.format_time(eta)
119 |
120 |
121 | class AdaptiveETA(Timer):
122 | """Widget which attempts to estimate the time of arrival.
123 |
124 | Uses a weighted average of two estimates:
125 | 1) ETA based on the total progress and time elapsed so far
126 | 2) ETA based on the progress as per the last 10 update reports
127 |
128 | The weight depends on the current progress so that to begin with the
129 | total progress is used and at the end only the most recent progress is
130 | used.
131 | """
132 |
133 | TIME_SENSITIVE = True
134 | NUM_SAMPLES = 10
135 |
136 | def _update_samples(self, currval, elapsed):
137 | sample = (currval, elapsed)
138 | if not hasattr(self, 'samples'):
139 | self.samples = [sample] * (self.NUM_SAMPLES + 1)
140 | else:
141 | self.samples.append(sample)
142 | return self.samples.pop(0)
143 |
144 | def _eta(self, maxval, currval, elapsed):
145 | return elapsed * maxval / float(currval) - elapsed
146 |
147 | def update(self, pbar):
148 | """Updates the widget to show the ETA or total time when finished."""
149 | if pbar.currval == 0:
150 | return 'Remaining Time: --:--:--'
151 | elif pbar.finished:
152 | #return 'Time: %s' % self.format_time(pbar.seconds_elapsed)
153 | return ''
154 | else:
155 | elapsed = pbar.seconds_elapsed
156 | currval1, elapsed1 = self._update_samples(pbar.currval, elapsed)
157 | eta = self._eta(pbar.maxval, pbar.currval, elapsed)
158 | if pbar.currval > currval1:
159 | etasamp = self._eta(pbar.maxval - currval1,
160 | pbar.currval - currval1,
161 | elapsed - elapsed1)
162 | weight = (pbar.currval / float(pbar.maxval)) ** 0.5
163 | eta = (1 - weight) * eta + weight * etasamp
164 | return 'Remaining Time: %s' % self.format_time(eta)
165 |
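# Illustration (hypothetical numbers, not part of the original widget): with
# maxval=1000, currval=800 and elapsed=80s, the overall estimate is
# 80*1000/800 - 80 = 20s; if the sample taken NUM_SAMPLES updates ago was
# (currval1=700, elapsed1=76), the recent estimate is 4*300/100 - 4 = 8s.
# The blending weight is (800/1000)**0.5 ≈ 0.894, so the reported ETA is
# (1 - 0.894)*20 + 0.894*8 ≈ 9.3s, favouring the recent rate as progress
# approaches completion.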
166 |
167 | class FileTransferSpeed(Widget):
168 | """Widget for showing the transfer speed (useful for file transfers)."""
169 |
170 | FORMAT = '%6.2f %s%s/s'
171 | PREFIXES = ' kMGTPEZY'
172 | __slots__ = ('unit',)
173 |
174 | def __init__(self, unit='B'):
175 | self.unit = unit
176 |
177 | def update(self, pbar):
178 | """Updates the widget with the current SI prefixed speed."""
179 |
180 | if pbar.seconds_elapsed < 2e-6 or pbar.currval < 2e-6: # =~ 0
181 | scaled = power = 0
182 | else:
183 | speed = pbar.currval / pbar.seconds_elapsed
184 | power = int(math.log(speed, 1000))
185 | scaled = speed / 1000.**power
186 |
187 | return self.FORMAT % (scaled, self.PREFIXES[power], self.unit)
188 |
189 |
190 | class AnimatedMarker(Widget):
191 | """An animated marker for the progress bar which defaults to appear as if
192 | it were rotating.
193 | """
194 |
195 | __slots__ = ('markers', 'curmark')
196 |
197 | def __init__(self, markers='|/-\\'):
198 | self.markers = markers
199 | self.curmark = -1
200 |
201 | def update(self, pbar):
202 | """Updates the widget to show the next marker or the first marker when
203 | finished"""
204 |
205 | if pbar.finished: return self.markers[0]
206 |
207 | self.curmark = (self.curmark + 1) % len(self.markers)
208 | return self.markers[self.curmark]
209 |
210 | # Alias for backwards compatibility
211 | RotatingMarker = AnimatedMarker
212 |
213 |
214 | class Counter(Widget):
215 | """Displays the current count."""
216 |
217 | __slots__ = ('format_string',)
218 |
219 | def __init__(self, format='%d'):
220 | self.format_string = format
221 |
222 | def update(self, pbar):
223 | return self.format_string % pbar.currval
224 |
225 |
226 | class Percentage(Widget):
227 | """Displays the current percentage as a number with a percent sign."""
228 |
229 | def update(self, pbar):
230 | return '%3d%%' % pbar.percentage()
231 |
232 |
233 | class FormatLabel(Timer):
234 | """Displays a formatted label."""
235 |
236 | mapping = {
237 | 'elapsed': ('seconds_elapsed', Timer.format_time),
238 | 'finished': ('finished', None),
239 | 'last_update': ('last_update_time', None),
240 | 'max': ('maxval', None),
241 | 'seconds': ('seconds_elapsed', None),
242 | 'start': ('start_time', None),
243 | 'value': ('currval', None)
244 | }
245 |
246 | __slots__ = ('format_string',)
247 | def __init__(self, format):
248 | self.format_string = format
249 |
250 | def update(self, pbar):
251 | context = {}
252 | for name, (key, transform) in self.mapping.items():
253 | try:
254 | value = getattr(pbar, key)
255 |
256 | if transform is None:
257 | context[name] = value
258 | else:
259 | context[name] = transform(value)
260 | except: pass
261 |
262 | return self.format_string % context
263 |
264 |
265 | class SimpleProgress(Widget):
266 | """Returns progress as a count of the total (e.g.: "5 of 47")."""
267 |
268 | __slots__ = ('sep',)
269 |
270 | def __init__(self, sep=' of '):
271 | self.sep = sep
272 |
273 | def update(self, pbar):
274 | return '%d%s%d' % (pbar.currval, self.sep, pbar.maxval)
275 |
276 |
277 | class Bar(WidgetHFill):
278 | """A progress bar which stretches to fill the line."""
279 |
280 | __slots__ = ('marker', 'left', 'right', 'fill', 'fill_left')
281 |
282 | def __init__(self, marker='#', left='|', right='|', fill=' ',
283 | fill_left=True):
284 | """Creates a customizable progress bar.
285 |
286 | marker - string or updatable object to use as a marker
287 | left - string or updatable object to use as a left border
288 | right - string or updatable object to use as a right border
289 | fill - character to use for the empty part of the progress bar
290 | fill_left - whether to fill from the left or the right
291 | """
292 | self.marker = marker
293 | self.left = left
294 | self.right = right
295 | self.fill = fill
296 | self.fill_left = fill_left
297 |
298 |
299 | def update(self, pbar, width):
300 | """Updates the progress bar and its subcomponents."""
301 |
302 | left, marked, right = (format_updatable(i, pbar) for i in
303 | (self.left, self.marker, self.right))
304 |
305 | width -= len(left) + len(right)
306 | # Marked must *always* have length of 1
307 | if pbar.maxval:
308 | marked *= int(pbar.currval / pbar.maxval * width)
309 | else:
310 | marked = ''
311 |
312 | if self.fill_left:
313 | return '%s%s%s' % (left, marked.ljust(width, self.fill), right)
314 | else:
315 | return '%s%s%s' % (left, marked.rjust(width, self.fill), right)
316 |
317 |
318 | class ReverseBar(Bar):
319 | """A bar which has a marker which bounces from side to side."""
320 |
321 | def __init__(self, marker='#', left='|', right='|', fill=' ',
322 | fill_left=False):
323 | """Creates a customizable progress bar.
324 |
325 | marker - string or updatable object to use as a marker
326 | left - string or updatable object to use as a left border
327 | right - string or updatable object to use as a right border
328 | fill - character to use for the empty part of the progress bar
329 | fill_left - whether to fill from the left or the right
330 | """
331 | self.marker = marker
332 | self.left = left
333 | self.right = right
334 | self.fill = fill
335 | self.fill_left = fill_left
336 |
337 |
338 | class BouncingBar(Bar):
339 | def update(self, pbar, width):
340 | """Updates the progress bar and its subcomponents."""
341 |
342 | left, marker, right = (format_updatable(i, pbar) for i in
343 | (self.left, self.marker, self.right))
344 |
345 | width -= len(left) + len(right)
346 |
347 | if pbar.finished: return '%s%s%s' % (left, width * marker, right)
348 |
349 | position = int(pbar.currval % (width * 2 - 1))
350 | if position > width: position = width * 2 - position
351 | lpad = self.fill * (position - 1)
352 | rpad = self.fill * (width - len(marker) - len(lpad))
353 |
354 | # Swap if we want to bounce the other way
355 | if not self.fill_left: rpad, lpad = lpad, rpad
356 |
357 | return '%s%s%s%s%s' % (left, lpad, marker, rpad, right)
358 |
--------------------------------------------------------------------------------
/core/ensemble/deep_ensemble.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import time
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from core.ensemble.vanilla import Vanilla
8 | from core.ensemble.model_hp import train_hparam
9 | from core.ensemble.dataset_lib import build_dataset_from_numerical_data
10 | from tools import utils
11 | from config import logging, ErrorHandler
12 |
13 | logger = logging.getLogger('ensemble.deep_ensemble')
14 | logger.addHandler(ErrorHandler)
15 |
16 |
17 | class DeepEnsemble(Vanilla):
18 | def __init__(self,
19 | architecture_type='dnn',
20 | base_model=None,
21 | n_members=5,
22 | model_directory=None,
23 | name='DEEPENSEMBLE'
24 | ):
25 | super(DeepEnsemble, self).__init__(architecture_type,
26 | base_model,
27 | n_members,
28 | model_directory,
29 | name)
30 | self.hparam = train_hparam
31 | self.ensemble_type = 'deep_ensemble'
32 |
33 |
34 | class WeightedDeepEnsemble(Vanilla):
35 | def __init__(self,
36 | architecture_type='dnn',
37 | base_model=None,
38 | n_members=5,
39 | model_directory=None,
40 | name='WEIGHTEDDEEPENSEMBLE'
41 | ):
42 | super(WeightedDeepEnsemble, self).__init__(architecture_type,
43 | base_model,
44 | n_members,
45 | model_directory,
46 | name)
47 | self.hparam = train_hparam
48 | self.ensemble_type = 'deep_ensemble'
49 | self.weight_modular = None
50 |
51 | def get_weight_modular(self):
52 | class Simplex(tf.keras.constraints.Constraint):
53 | def __call__(self, w):
54 | return tf.math.softmax(w - tf.math.reduce_max(w), axis=0)
55 |
56 | inputs = tf.keras.Input(shape=(self.n_members,))
57 | outs = tf.keras.layers.Dense(1, use_bias=False, activation=None, kernel_constraint=Simplex(), name='simplex')(
58 | inputs)
59 | return tf.keras.Model(inputs=inputs, outputs=outs)
60 |
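# Note on the Simplex constraint above (illustrative sketch, not part of the
# original implementation): the constraint re-projects the Dense layer's
# (n_members, 1) kernel onto the probability simplex after every update, so the
# ensemble output is a convex combination of the member predictions, e.g.:
#
#   import tensorflow as tf
#   w = tf.constant([[2.0], [0.5], [-1.0]])
#   simplex_w = tf.math.softmax(w - tf.math.reduce_max(w), axis=0)
#   # simplex_w is non-negative and sums to 1.0 along axis 0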
61 | def predict(self, x, use_prob=False):
62 | """ conduct prediction """
63 | self.base_model = None
64 | self.weight_modular = None
65 | self.weights_list = []
66 | self._optimizers_dict = []
67 | self.load_ensemble_weights()
68 | output_list = []
69 | start_time = time.time()
70 | for base_model in self.model_generator():
71 | if isinstance(x, tf.data.Dataset):
72 | output_list.append(base_model.predict(x, verbose=1))
73 | elif isinstance(x, (np.ndarray, list)):
74 | output_list.append(base_model.predict(x, batch_size=self.hparam.batch_size, verbose=1))
75 | else:
76 | raise ValueError('unsupported input type: {}'.format(type(x)))
77 | total_time = time.time() - start_time
78 | logger.info('Inference costs {} seconds.'.format(total_time))
79 | assert self.weight_modular is not None
80 | output = self.weight_modular(np.hstack(output_list)).numpy()
81 | if not use_prob:
82 | return np.stack(output_list, axis=1), self.weight_modular.get_layer('simplex').get_weights()
83 | else:
84 | return output
85 |
86 | def fit(self, train_set, validation_set=None, input_dim=None, **kwargs):
87 | """
88 | fit the ensemble by producing a lists of model weights
89 | :param train_set: tf.data.Dataset; the type must match the input format expected by TensorFlow models
90 | :param validation_set: validation data, optional
91 | :param input_dim: integer or list, the input dimension excluding the batch size
92 | """
93 | # training preparation
94 | if self.base_model is None:
95 | self.build_model(input_dim=input_dim)
96 | if self.weight_modular is None:
97 | self.weight_modular = self.get_weight_modular()
98 |
99 | self.base_model.compile(
100 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate,
101 | clipvalue=self.hparam.clipvalue),
102 | loss=tf.keras.losses.BinaryCrossentropy(),
103 | metrics=[tf.keras.metrics.BinaryAccuracy()],
104 | )
105 |
106 | self.weight_modular.compile(
107 | optimizer=tf.keras.optimizers.Adam(learning_rate=self.hparam.learning_rate,
108 | clipvalue=self.hparam.clipvalue),
109 | loss=tf.keras.losses.BinaryCrossentropy(),
110 | metrics=[tf.keras.metrics.BinaryAccuracy()],
111 | )
112 |
113 | # training
114 | logger.info("hyper-parameters:")
115 | logger.info(dict(self.hparam._asdict()))
116 | logger.info("...training start!")
117 |
118 | best_val_accuracy = 0.
119 | total_time = 0.
120 | for epoch in range(self.hparam.n_epochs):
121 | for member_idx in range(self.n_members):
122 | if member_idx < len(self.weights_list): # loading former weights
123 | self.base_model.set_weights(self.weights_list[member_idx])
124 | self.base_model.optimizer.set_weights(self._optimizers_dict[member_idx])
125 | elif member_idx == 0:
126 | pass # do nothing
127 | else:
128 | self.reinitialize_base_model()
129 |
130 | msg = 'Epoch {}/{}, member {}/{}, and {} member(s) in list'.format(epoch + 1,
131 | self.hparam.n_epochs, member_idx + 1,
132 | self.n_members,
133 | len(self.weights_list))
134 | print(msg)
135 | start_time = time.time()
136 | self.base_model.fit(train_set,
137 | epochs=epoch + 1,
138 | initial_epoch=epoch,
139 | validation_data=validation_set
140 | )
141 | self.update_weights(member_idx,
142 | self.base_model.get_weights(),
143 | self.base_model.optimizer.get_weights())
144 |
145 | end_time = time.time()
146 | total_time += end_time - start_time
147 | # training weight modular
148 | msg = "train the weight modular at epoch {}/{}"
149 | print(msg.format(epoch, self.hparam.n_epochs))
150 | start_time = time.time()
151 | history = self.fit_weight_modular(train_set, validation_set, epoch)
152 | end_time = time.time()
153 | total_time += end_time - start_time
154 | # saving
155 | logger.info('Training the ensemble has cost {} seconds so far (including validation).'.format(total_time))
156 | train_acc = history.history['binary_accuracy'][0]
157 | val_acc = history.history['val_binary_accuracy'][0]
158 | msg = 'Epoch {}/{}: training accuracy {:.5f}, validation accuracy {:.5f}.'.format(
159 | epoch + 1, self.hparam.n_epochs, train_acc, val_acc
160 | )
161 | logger.info(msg)
162 | if (epoch + 1) % self.hparam.interval == 0:
163 | if val_acc > best_val_accuracy:
164 | self.save_ensemble_weights()
165 | best_val_accuracy = val_acc
166 | msg = '\t The best validation accuracy is {:.5f}, obtained at epoch {}/{}'.format(
167 | best_val_accuracy, epoch + 1, self.hparam.n_epochs
168 | )
169 | logger.info(msg)
170 | return
171 |
172 | def fit_weight_modular(self, train_set, validation_set, epoch):
173 | """
174 | fit weight modular
175 | :param train_set: training set
176 | :param validation_set: validation set
177 | :param epoch: integer, training epoch
178 | :return: the tf.keras.callbacks.History object produced by fitting the weight modular
179 | """
180 |
181 | # obtain data
182 | def get_data(x_y_set):
183 | tsf_x = []
184 | tsf_y = []
185 | for _x, _y in x_y_set:
186 | _x_list = []
187 | for base_model in self.model_generator():
188 | _x_pred = base_model(_x)
189 | _x_list.append(_x_pred)
190 | tsf_x.append(np.hstack(_x_list))
191 | tsf_y.append(_y)
192 | return np.vstack(tsf_x), np.concatenate(tsf_y)
193 |
194 | transform_train_set = build_dataset_from_numerical_data(get_data(train_set))
195 | transform_val_set = build_dataset_from_numerical_data(get_data(validation_set))
196 |
197 | history = self.weight_modular.fit(transform_train_set,
198 | epochs=epoch + 1,
199 | initial_epoch=epoch,
200 | validation_data=transform_val_set
201 | )
202 | return history
203 |
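# Shape sketch (assumed, for illustration): each base model emits a
# (batch_size, 1) probability tensor, so np.hstack over the n_members members
# yields (batch_size, n_members) per batch and np.vstack over all batches
# yields (N, n_members) -- matching the Input(shape=(n_members,)) of the weight
# modular, whose single constrained Dense unit then learns the convex weights.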
204 | def save_ensemble_weights(self):
205 | if not path.exists(self.save_dir):
206 | utils.mkdir(self.save_dir)
207 | # save model configuration
208 | try:
209 | config = self.base_model.to_json()
210 | utils.dump_json(config, path.join(self.save_dir,
211 | self.architecture_type + '.json'))  # lightweight method for saving the model configuration
212 | except Exception as e:
213 | pass
214 | finally:
215 | if not path.exists(path.join(self.save_dir, self.architecture_type)):
216 | utils.mkdir(path.join(self.save_dir, self.architecture_type))
217 | self.base_model.save(path.join(self.save_dir, self.architecture_type))
218 | print("Save the model configuration to directory {}".format(self.save_dir))
219 |
220 | # save model weights
221 | utils.dump_joblib(self.weights_list, path.join(self.save_dir, self.architecture_type + '.model'))
222 | utils.dump_joblib(self._optimizers_dict, path.join(self.save_dir, self.architecture_type + '.model.metadata'))
223 | print("Save the model weights to directory {}".format(self.save_dir))
224 |
225 | # save weight modular
226 | self.weight_modular.save(path.join(self.save_dir, self.architecture_type + '_weight_modular'))
227 | print("Save the weight modular weights to directory {}".format(self.save_dir))
228 | return
229 |
230 | def load_ensemble_weights(self):
231 | if path.exists(path.join(self.save_dir, self.architecture_type + '.json')):
232 | config = utils.load_json(path.join(self.save_dir, self.architecture_type + '.json'))
233 | self.base_model = tf.keras.models.model_from_json(config)
234 | elif path.exists(path.join(self.save_dir, self.architecture_type)):
235 | self.base_model = tf.keras.models.load_model(path.join(self.save_dir, self.architecture_type))
236 | else:
237 | logger.error("File not found: ".format(path.join(self.save_dir, self.architecture_type + '.json')))
238 | raise FileNotFoundError
239 | print("Load model config from {}.".format(self.save_dir))
240 |
241 | if path.exists(path.join(self.save_dir, self.architecture_type + '.model')):
242 | self.weights_list = utils.read_joblib(path.join(self.save_dir, self.architecture_type + '.model'))
243 | else:
244 | logger.error("File not found: ".format(path.join(self.save_dir, self.architecture_type + '.model')))
245 | raise FileNotFoundError
246 | print("Load model weights from {}.".format(self.save_dir))
247 |
248 | if path.exists(path.join(self.save_dir, self.architecture_type + '.model.metadata')):
249 | self._optimizers_dict = utils.read_joblib(
250 | path.join(self.save_dir, self.architecture_type + '.model.metadata'))
251 | else:
252 | self._optimizers_dict = [None] * len(self.weights_list)
253 |
254 | if path.exists(path.join(self.save_dir, self.architecture_type + '_weight_modular')):
255 | self.weight_modular = tf.keras.models.load_model(
256 | path.join(self.save_dir, self.architecture_type + '_weight_modular'))
257 | return
258 |
--------------------------------------------------------------------------------