├── .gitignore
├── .project-root
├── README.md
├── about
│   ├── download.md
│   └── issues.md
├── configs
│   ├── .gitkeep
│   ├── __init__.py
│   ├── callbacks
│   │   ├── default.yaml
│   │   ├── early_stopping.yaml
│   │   ├── model_checkpoint.yaml
│   │   ├── model_summary.yaml
│   │   ├── none.yaml
│   │   └── rich_progress_bar.yaml
│   ├── data
│   │   ├── bytes.yaml
│   │   └── ember.yaml
│   ├── debug
│   │   ├── default.yaml
│   │   ├── fdr.yaml
│   │   ├── limit.yaml
│   │   ├── overfit.yaml
│   │   └── profiler.yaml
│   ├── eval.yaml
│   ├── experiment
│   │   ├── example.yaml
│   │   ├── malconv-bytes-test.yaml
│   │   ├── malconv-bytes-train.yaml
│   │   ├── mlp-ember-test.yaml
│   │   └── mlp-ember-train.yaml
│   ├── extras
│   │   └── default.yaml
│   ├── hparams_search
│   │   └── mnist_optuna.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── local
│   │   └── .gitkeep
│   ├── logger
│   │   └── wandb.yaml
│   ├── model
│   │   ├── malconv.yaml
│   │   └── mlp.yaml
│   ├── paths
│   │   └── default.yaml
│   ├── train.yaml
│   └── trainer
│       ├── cpu.yaml
│       ├── ddp.yaml
│       ├── ddp_sim.yaml
│       ├── default.yaml
│       ├── gpu.yaml
│       └── mps.yaml
├── detect
│   └── mlp_ember.py
├── requirements.txt
├── scripts
│   ├── detect_mlp_ember_drift.sh
│   ├── test_gbdt_ember.sh
│   ├── test_malconv_bytes.sh
│   ├── test_mlp_ember.sh
│   ├── train_gbdt_ember.sh
│   ├── train_malconv_bytes.sh
│   └── train_mlp_ember.sh
└── src
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── bytes.py
    │   ├── ember.py
    │   └── mfc.py
    ├── eval.py
    ├── eval_gbdt.py
    ├── models
    │   ├── __init__.py
    │   ├── gbdt.py
    │   ├── malconv.py
    │   ├── malconv_module.py
    │   ├── mlp.py
    │   └── mlp_module.py
    ├── train.py
    ├── train_gbdt.py
    └── utils
        ├── __init__.py
        ├── instantiators.py
        ├── logging_utils.py
        ├── pylogger.py
        ├── rich_utils.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temporary and binary files
2 | *~
3 | *.py[cod]
4 | *.so
5 | *.cfg
6 | !.isort.cfg
7 | !setup.cfg
8 | *.orig
9 | *.log
10 | *.pot
11 | __pycache__/*
12 | .cache/*
13 | .*.swp
14 | */.ipynb_checkpoints/*
15 | .DS_Store
16 |
17 | # Project files
18 | .ropeproject
19 | .project
20 | .pydevproject
21 | .settings
22 | .idea
23 | .vscode
24 | tags
25 |
26 | # Package files
27 | *.egg
28 | *.eggs/
29 | .installed.cfg
30 | *.egg-info
31 |
32 | # Unittest and coverage
33 | htmlcov/*
34 | .coverage
35 | .coverage.*
36 | .tox
37 | junit*.xml
38 | coverage.xml
39 | .pytest_cache/
40 |
41 | # Build and docs folder/files
42 | build/*
43 | dist/*
44 | sdist/*
45 | docs/api/*
46 | docs/_rst/*
47 | docs/_build/*
48 | cover/*
49 | MANIFEST
50 |
51 | # Per-project virtualenvs
52 | .venv*/
53 | .conda*/
54 | logs/*
55 |
56 | # data
57 | detect/*.csv
--------------------------------------------------------------------------------
/.project-root:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/.project-root
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BenchMFC
2 |
3 | A Benchmark Dataset for Trustworthy Malware Family Classification under Concept Drift
4 |
18 | ## Abstract
19 |
20 | Concept drift poses a critical challenge in deploying machine learning models to mitigate practical malware threats. It refers to the phenomenon in which the distribution of test data changes over time, gradually deviating from the original training data and degrading model performance. A promising direction for addressing concept drift is to detect drift samples and then retrain the model. However, this field currently lacks a unified, well-curated, and comprehensive benchmark, which often leads to unfair comparisons and inconclusive outcomes. To improve evaluation and advance the field, this paper presents a new Benchmark dataset for trustworthy Malware Family Classification (BenchMFC), which includes 223K samples of 526 families that evolve over the years. BenchMFC provides clear family, packer, and timestamp tags for each sample; it can thus support research on three types of malware concept drift: 1) unseen families, 2) packed families, and 3) evolved families. To collect unpacked family samples from large-scale candidates, we introduce a novel crowdsourcing malware annotation pipeline, which unifies packing detection and family annotation as a consensus inference problem to avoid costly packing detection. Moreover, we provide two case studies to illustrate the application of BenchMFC in 1) concept drift detection and 2) model retraining. The first case demonstrates the impact of three types of malware concept drift and compares nine notable concept drift detectors. The results show that existing detectors have their own advantages in dealing with different types of malware concept drift, and there is still room for improvement in malware concept drift detection. The second case explores how static feature-based machine learning operates on packed samples when retraining a model. The experiments illustrate that packers do preserve some signals that appear to be “effective” for machine learning models, but the robustness of these signals requires further research. BenchMFC has been released to the community at https://github.com/crowdma/benchmfc.
21 |
22 |
23 | ## Reference
24 | This paper has been accepted by Computers & Security:
25 | ```
26 | @article{jiang_2024,
27 | title = {BenchMFC: A benchmark dataset for trustworthy malware family classification under concept drift},
28 | author = {Yongkang Jiang and Gaolei Li and Shenghong Li and Ying Guo},
29 | journal = {Computers & Security},
30 | volume = {139},
31 | pages = {103706},
32 | year = {2024},
33 | }
34 | ```
35 |
36 |
37 | ## Dataset
38 |
39 | ### Size
40 |
41 | ```
42 | ├── benchmfc_meta.csv (Metadata file for the dataset ~17M)
43 | ├── benchmfc.tar.gz (Samples ~83G)
44 | └── mfc (Experimental data used in the paper)
45 | ├── mfc_features.tar.gz (Ember features ~39M)
46 | ├── mfc_meta.csv (Metadata file ~1M)
47 | └── mfc_samples.tar.gz (Samples ~7G)
48 | ```
49 |
50 | ### Download
51 | Please visit this [link](about/download.md) for more details.
52 |
53 |
54 | ## Getting Started
55 |
56 | ### Installation
57 |
58 | - Run the following commands:
59 | ```sh
60 | # requires python >=3.9,<3.10
61 | git clone https://github.com/crowdma/benchmfc.git
62 | cd benchmfc
63 | pip install -r requirements.txt
64 | ```
65 |
66 |
67 | ## Usage Examples
68 |
69 | - Env
70 |
71 | ```sh
72 | export MFC_ROOT=/path/to/MFC  # set to the root of the extracted MFC data
73 | # MFC structure
74 | ├── feature-ember-npy
75 | │   ├── malicious
76 | │   ├── malicious-unseen
77 | │   ├── malicious-evolving
78 | │   ├── malicious-aes
79 | │   ├── malicious-mpress
80 | │   └── malicious-upx
81 | └── samples
82 |     ├── malicious
83 |     ├── malicious-unseen
84 |     ├── malicious-evolving
85 |     ├── malicious-aes
86 |     ├── malicious-mpress
87 |     └── malicious-upx
88 | ```
89 |
90 | - Train
91 | ```sh
92 | /bin/bash scripts/train_mlp_ember.sh
93 | ```
94 |
95 | - Test
96 | ```sh
97 | /bin/bash scripts/test_mlp_ember.sh
98 | ```
99 | - Detect Drift
100 | ```sh
101 | /bin/bash scripts/detect_mlp_ember_drift.sh
102 | ```
103 |
104 | ## Issues
105 |
106 | Please visit this [link](about/issues.md) for known issues.
107 |
108 |
109 |
110 | ## License
111 |
112 | Distributed under the MIT License.
--------------------------------------------------------------------------------
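The `MFC_ROOT` layout shown under Usage Examples can be navigated programmatically. A minimal sketch (the subset directory names come from the tree above; the root path is a hypothetical placeholder, not a real location):

```python
import os
from pathlib import Path

# Hypothetical default; point MFC_ROOT at your extracted MFC data.
mfc_root = Path(os.environ.get("MFC_ROOT", "/path/to/MFC"))

# Subset names taken from the MFC structure in the README.
subsets = [
    "malicious", "malicious-unseen", "malicious-evolving",
    "malicious-aes", "malicious-mpress", "malicious-upx",
]
feature_dirs = {s: mfc_root / "feature-ember-npy" / s for s in subsets}
sample_dirs = {s: mfc_root / "samples" / s for s in subsets}
```

Each entry in `feature_dirs` then points at the Ember feature files for one subset, and `sample_dirs` at the corresponding raw samples.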
/about/download.md:
--------------------------------------------------------------------------------
1 | ## Download
2 |
3 | The samples in this dataset have not been disarmed. To avoid misuse, please read and agree to the following conditions before emailing us.
4 |
5 | - Please email Yongkang (jiangyongkang@alumni.sjtu.edu.cn).
6 | - Do not share the data with anyone else (except your co-authors on the project). We are happy to share it with other researchers upon request.
7 | - Explain in a few sentences what you plan to do with these binaries. It does not have to be a precise plan.
8 | - If you are in academia, contact us using your institution email and provide a webpage registered under the university domain that contains your name and affiliation.
9 | - If you are in a research (industrial) lab, email us from your company’s email account and introduce yourself and your company. In the email, please attach a justification letter (in PDF format) on official letterhead. The letter needs to state clearly the reasons why this dataset is being requested.
10 |
11 | Please note that an email not following these conditions might be ignored. We maintain a public list of organizations that have accessed these samples at the bottom of this page.
12 |
13 |
14 | ## Organizations Requested Our Dataset
15 |
16 | 1. Wuhan University
17 | 2. Huazhong University of Science and Technology
18 | 3. Southeast University
19 | 4. Taibah University
20 | 5. Université catholique de Louvain
21 | 6. University of Alberta
22 | 7. IMDEA Networks Institute, Universidad Carlos III de Madrid (UC3M)
23 | 8. Indian Institute of Technology, Indore
24 | 9. Hebei Normal University
25 | 10. Ludwig-Maximilians-Universität München (LMU)
26 | 11. Fast University Karachi
27 | 12. Beijing University of Posts and Telecommunications
28 | 13. Queen's University Belfast
29 | 14. Korea University
30 | 15. University of Palermo
31 | 16. University of Luxembourg
32 |
--------------------------------------------------------------------------------
/about/issues.md:
--------------------------------------------------------------------------------
1 | ## About test
2 |
3 | - In Fig. 9, we illustrate only 40% of the test samples.
4 |
5 |
--------------------------------------------------------------------------------
/configs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/configs/.gitkeep
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # this file is needed here to include configs when building project as a package
2 |
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model_checkpoint.yaml
3 | - early_stopping.yaml
4 | - model_summary.yaml
5 | - rich_progress_bar.yaml
6 | - _self_
7 |
8 | model_checkpoint:
9 | dirpath: ${paths.output_dir}/checkpoints
10 | filename: "epoch_{epoch:03d}"
11 | monitor: "val/acc"
12 | mode: "max"
13 | save_last: True
14 | auto_insert_metric_name: False
15 |
16 | early_stopping:
17 | monitor: "val/acc"
18 | patience: 100
19 | mode: "max"
20 |
21 | model_summary:
22 | max_depth: -1
23 |
--------------------------------------------------------------------------------
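For intuition: with `monitor: "val/acc"` and `mode: "max"` above (and `save_top_k: 1` inherited from `model_checkpoint.yaml`), the kept checkpoint is the epoch with the highest validation accuracy, named by the `epoch_{epoch:03d}` pattern. A toy sketch of that selection (the accuracy values are made up):

```python
# (epoch, val/acc) pairs as a checkpoint callback might observe them.
history = [(0, 0.61), (1, 0.74), (2, 0.72), (3, 0.79)]

# mode "max": the best checkpoint is the one with the highest monitored value.
best_epoch, best_acc = max(history, key=lambda e: e[1])

# filename: "epoch_{epoch:03d}" -> zero-padded epoch number.
ckpt_name = f"epoch_{best_epoch:03d}"
print(ckpt_name)  # epoch_003
```

This is also why the experiment configs later reference checkpoint files such as `checkpoints/epoch_017.ckpt`.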
/configs/callbacks/early_stopping.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
2 |
3 | early_stopping:
4 | _target_: lightning.pytorch.callbacks.EarlyStopping
5 | monitor: ??? # quantity to be monitored, must be specified !!!
6 | min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
7 | patience: 3 # number of checks with no improvement after which training will be stopped
8 | verbose: False # verbosity mode
9 | mode: "min" # "max" means higher metric value is better, can be also "min"
10 | strict: True # whether to crash the training if monitor is not found in the validation metrics
11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite
12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold
13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold
14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch
15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version
16 |
--------------------------------------------------------------------------------
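The interaction of `monitor`, `patience`, `min_delta`, and `mode` above can be sketched in a few lines. This is a simplification of Lightning's actual callback (which runs once per validation check), not its implementation:

```python
def should_stop(history, patience=3, min_delta=0.0, mode="min"):
    """Return True if training would stop after seeing `history` of monitored values."""
    best, wait = None, 0
    for value in history:
        if mode == "max":
            improved = best is None or value > best + min_delta
        else:
            improved = best is None or value < best - min_delta
        if improved:
            best, wait = value, 0  # reset the patience counter on improvement
        else:
            wait += 1
            if wait >= patience:
                return True
    return False

# With the project's "val/acc" monitor (mode "max"), accuracy plateauing for
# `patience` checks in a row triggers the stop:
print(should_stop([0.50, 0.60, 0.60, 0.60, 0.60], mode="max"))  # True
```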
/configs/callbacks/model_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
2 |
3 | model_checkpoint:
4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
5 | dirpath: null # directory to save the model file
6 | filename: null # checkpoint filename
7 | monitor: null # name of the logged metric which determines when model is improving
8 | verbose: False # verbosity mode
9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt
10 | save_top_k: 1 # save k best models (determined by above metric)
11 | mode: "min" # "max" means higher metric value is better, can be also "min"
12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name
13 | save_weights_only: False # if True, then only the model’s weights will be saved
14 | every_n_train_steps: null # number of training steps between checkpoints
15 | train_time_interval: null # checkpoints are monitored at the specified time interval
16 | every_n_epochs: null # number of epochs between checkpoints
17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
18 |
--------------------------------------------------------------------------------
/configs/callbacks/model_summary.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
2 |
3 | model_summary:
4 | _target_: lightning.pytorch.callbacks.RichModelSummary
5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include
6 |
--------------------------------------------------------------------------------
/configs/callbacks/none.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/configs/callbacks/none.yaml
--------------------------------------------------------------------------------
/configs/callbacks/rich_progress_bar.yaml:
--------------------------------------------------------------------------------
1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html
2 |
3 | rich_progress_bar:
4 | _target_: lightning.pytorch.callbacks.RichProgressBar
5 |
--------------------------------------------------------------------------------
/configs/data/bytes.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datasets.bytes.BytesDataModule
2 |
3 | num_workers: 16
--------------------------------------------------------------------------------
/configs/data/ember.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.datasets.ember.EmberDataModule
2 |
3 | train_size: 0.6
4 | val_size: 0.2
5 | test_size: 0.2
6 | batch_size: 32
--------------------------------------------------------------------------------
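The 60/20/20 fractions above can be sanity-checked and turned into concrete subset sizes. A sketch assuming a simple proportional split with the remainder assigned to test (the actual `EmberDataModule` implementation may differ):

```python
def split_counts(n_samples, train_size=0.6, val_size=0.2, test_size=0.2):
    """Turn split fractions into sample counts, giving rounding remainder to test."""
    assert abs(train_size + val_size + test_size - 1.0) < 1e-9, "fractions must sum to 1"
    n_train = int(n_samples * train_size)
    n_val = int(n_samples * val_size)
    n_test = n_samples - n_train - n_val  # absorbs any rounding remainder
    return n_train, n_val, n_test

print(split_counts(10_000))  # (6000, 2000, 2000)
```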
/configs/debug/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # default debugging setup, runs 1 full epoch
4 | # other debugging configs can inherit from this one
5 |
6 | # overwrite task name so debugging logs are stored in separate folder
7 | task_name: "debug"
8 |
9 | # disable callbacks and loggers during debugging
10 | callbacks: null
11 | logger: null
12 |
13 | extras:
14 | ignore_warnings: False
15 | enforce_tags: False
16 |
17 | # sets level of all command line loggers to 'DEBUG'
18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
19 | hydra:
20 | job_logging:
21 | root:
22 | level: DEBUG
23 |
24 | # use this to also set hydra loggers to 'DEBUG'
25 | # verbose: True
26 |
27 | trainer:
28 | max_epochs: 1
29 | accelerator: cpu # debuggers don't like gpus
30 | devices: 1 # debuggers don't like multiprocessing
31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor
32 |
33 | data:
34 | num_workers: 0 # debuggers don't like multiprocessing
35 | pin_memory: False # disable gpu memory pin
36 |
--------------------------------------------------------------------------------
/configs/debug/fdr.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs 1 train, 1 validation and 1 test step
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | fast_dev_run: true
10 |
--------------------------------------------------------------------------------
/configs/debug/limit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # uses only 1% of the training data and 5% of validation/test data
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 3
10 | limit_train_batches: 0.01
11 | limit_val_batches: 0.05
12 | limit_test_batches: 0.05
13 |
--------------------------------------------------------------------------------
/configs/debug/overfit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # overfits to 3 batches
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 20
10 | overfit_batches: 3
11 |
12 | # model ckpt and early stopping need to be disabled during overfitting
13 | callbacks: null
14 |
--------------------------------------------------------------------------------
/configs/debug/profiler.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # runs with execution time profiling
4 |
5 | defaults:
6 | - default.yaml
7 |
8 | trainer:
9 | max_epochs: 1
10 | profiler: "simple"
11 | # profiler: "advanced"
12 | # profiler: "pytorch"
13 |
--------------------------------------------------------------------------------
/configs/eval.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - data: ember.yaml # choose datamodule with `test_dataloader()` for evaluation
6 | - model: mlp.yaml
7 | - logger: null
8 | - trainer: default.yaml
9 | - paths: default.yaml
10 | - extras: default.yaml
11 | - hydra: default.yaml
12 |
13 | - experiment: null
14 |
15 | task_name: "default"
16 | train_eval: "eval"
17 |
18 | tags: ["dev"]
19 |
20 | # passing checkpoint path is necessary for evaluation
21 | ckpt_path: ???
22 |
--------------------------------------------------------------------------------
/configs/experiment/example.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/train.py experiment=example
5 |
6 | defaults:
7 | - override /data: ember.yaml
8 | - override /model: ember.yaml
9 | - override /callbacks: default.yaml
10 | - override /trainer: default.yaml
11 |
12 | # all parameters below will be merged with parameters from default configurations set above
13 | # this allows you to overwrite only specified parameters
14 |
15 | tags: ["ember", "simple_dense_net"]
16 |
17 | seed: 12345
18 |
19 | trainer:
20 | min_epochs: 10
21 | max_epochs: 10
22 | gradient_clip_val: 0.5
23 |
24 | model:
25 | optimizer:
26 | lr: 0.002
27 | net:
28 | lin1_size: 128
29 | lin2_size: 256
30 | lin3_size: 64
31 |
32 | data:
33 | batch_size: 64
34 |
35 | logger:
36 | wandb:
37 | tags: ${tags}
38 | group: "ember"
39 | aim:
40 | experiment: "ember"
41 |
--------------------------------------------------------------------------------
/configs/experiment/malconv-bytes-test.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/eval.py experiment=malconv-bytes-test
5 |
6 | defaults:
7 | - override /data: bytes.yaml
8 | - override /model: malconv.yaml
9 | - override /trainer: default.yaml
10 |
11 | # all parameters below will be merged with parameters from default configurations set above
12 | # this allows you to overwrite only specified parameters
13 |
14 | # name of the run determines folder name in logs
15 | data_name: MFC
16 | pack_ratio: 0.0
17 | task_name: malconv-bytes-MFC-0.0
18 | seed: 42
19 |
20 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"]
21 |
22 | trainer:
23 | accelerator: gpu
24 | min_epochs: 20
25 | max_epochs: 50
26 | gradient_clip_val: 0.5
27 |
28 | model:
29 | optimizer:
30 | lr: 0.001
31 | network:
32 | input_length: 1_048_576
33 | window_size: 500
34 | stride: 500
35 | channels: 128
36 | embed_size: 8
37 | output_size: 8
38 |
39 | data:
40 | data_name: ${data_name}
41 | train_size: 0.6
42 | val_size: 0.2
43 | test_size: 0.2
44 | batch_size: 32
45 | num_workers: 16
46 | pack_ratio: ${pack_ratio}
47 | first_n_byte: 1_048_576
48 |
49 | ckpt_path: ${paths.root_dir}/logs/malconv-bytes-MFC-0.0/train/runs/2023-07-30_18-23-58/checkpoints/epoch_020.ckpt
--------------------------------------------------------------------------------
/configs/experiment/malconv-bytes-train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/train.py experiment=malconv-bytes-train
5 |
6 | defaults:
7 | - override /data: bytes.yaml
8 | - override /model: malconv.yaml
9 | - override /callbacks: default.yaml
10 | - override /logger: wandb.yaml
11 | - override /trainer: default.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | # name of the run determines folder name in logs
17 | data_name: MFC
18 | pack_ratio: 0.0
19 | task_name: malconv-bytes-MFC-0.0
20 | seed: 42
21 |
22 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"]
23 |
24 | trainer:
25 | accelerator: gpu
26 | min_epochs: 20
27 | max_epochs: 50
28 | gradient_clip_val: 0.5
29 |
30 | model:
31 | optimizer:
32 | lr: 0.001
33 | network:
34 | input_length: 1_048_576
35 | window_size: 500
36 | stride: 500
37 | channels: 128
38 | embed_size: 8
39 | output_size: 8
40 |
41 | data:
42 | data_name: ${data_name}
43 | train_size: 0.6
44 | val_size: 0.2
45 | test_size: 0.2
46 | batch_size: 32
47 | num_workers: 16
48 | pack_ratio: ${pack_ratio}
49 | first_n_byte: 1_048_576
50 |
51 | logger:
52 | wandb:
53 | name: ${task_name}
54 | group: malconv-bytes
55 | project: lab-benchmfc
56 |
--------------------------------------------------------------------------------
/configs/experiment/mlp-ember-test.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/eval.py experiment=mlp-ember-test
5 |
6 | defaults:
7 | - override /data: ember.yaml
8 | - override /model: mlp.yaml
9 | - override /trainer: default.yaml
10 |
11 | # all parameters below will be merged with parameters from default configurations set above
12 | # this allows you to overwrite only specified parameters
13 |
14 | # name of the run determines folder name in logs
15 | data_name: MFC
16 | pack_ratio: 0.0
17 | task_name: mlp-ember-MFC-0.0
18 | seed: 42
19 |
20 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"]
21 |
22 | trainer:
23 | accelerator: gpu
24 | min_epochs: 20
25 | max_epochs: 50
26 | gradient_clip_val: 0.5
27 |
28 | model:
29 | optimizer:
30 | lr: 0.001
31 | network:
32 | input_size: 2381
33 | hidden_units: [1024, 512, 256]
34 | output_size: 8
35 |
36 | data:
37 | data_name: ${data_name}
38 | train_size: 0.6
39 | val_size: 0.2
40 | test_size: 0.2
41 | batch_size: 32
42 | pack_ratio: ${pack_ratio}
43 |
44 | ckpt_path: ${paths.root_dir}/logs/mlp-ember-MFC-0.0/train/runs/2023-07-30_11-54-21/checkpoints/epoch_017.ckpt
--------------------------------------------------------------------------------
/configs/experiment/mlp-ember-train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python src/train.py experiment=mlp-ember-train
5 |
6 | defaults:
7 | - override /data: ember.yaml
8 | - override /model: mlp.yaml
9 | - override /callbacks: default.yaml
10 | - override /logger: wandb.yaml
11 | - override /trainer: default.yaml
12 |
13 | # all parameters below will be merged with parameters from default configurations set above
14 | # this allows you to overwrite only specified parameters
15 |
16 | # name of the run determines folder name in logs
17 | data_name: MFC
18 | pack_ratio: 0.0
19 | task_name: mlp-ember-MFC-0.0
20 | seed: 42
21 |
22 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"]
23 |
24 | trainer:
25 | accelerator: gpu
26 | min_epochs: 20
27 | max_epochs: 50
28 | gradient_clip_val: 0.5
29 |
30 | model:
31 | optimizer:
32 | lr: 0.001
33 | network:
34 | input_size: 2381
35 | hidden_units: [1024, 512, 256]
36 | output_size: 8
37 |
38 | data:
39 | data_name: ${data_name}
40 | train_size: 0.6
41 | val_size: 0.2
42 | test_size: 0.2
43 | batch_size: 32
44 | pack_ratio: ${pack_ratio}
45 |
46 | logger:
47 | wandb:
48 | name: ${task_name}
49 | group: mlp-ember
50 | project: lab-benchmfc
51 |
--------------------------------------------------------------------------------
/configs/extras/default.yaml:
--------------------------------------------------------------------------------
1 | # disable python warnings if they annoy you
2 | ignore_warnings: False
3 |
4 | # ask user for tags if none are provided in the config
5 | enforce_tags: True
6 |
7 | # pretty print config tree at the start of the run using Rich library
8 | print_config: True
9 |
--------------------------------------------------------------------------------
/configs/hparams_search/mnist_optuna.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # example hyperparameter optimization of some experiment with Optuna:
4 | # python src/train.py -m hparams_search=mnist_optuna experiment=example
5 |
6 | defaults:
7 | - override /hydra/sweeper: optuna
8 |
9 | # choose metric which will be optimized by Optuna
10 | # make sure this is the correct name of some metric logged in lightning module!
11 | optimized_metric: "val/acc_best"
12 |
13 | # here we define Optuna hyperparameter search
14 | # it optimizes for value returned from function with @hydra.main decorator
15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper
16 | hydra:
17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached
18 |
19 | sweeper:
20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
21 |
22 | # storage URL to persist optimization results
23 | # for example, you can use SQLite if you set 'sqlite:///example.db'
24 | storage: null
25 |
26 | # name of the study to persist optimization results
27 | study_name: null
28 |
29 | # number of parallel workers
30 | n_jobs: 1
31 |
32 | # 'minimize' or 'maximize' the objective
33 | direction: maximize
34 |
35 | # total number of runs that will be executed
36 | n_trials: 20
37 |
38 | # choose Optuna hyperparameter sampler
39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others
40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html
41 | sampler:
42 | _target_: optuna.samplers.TPESampler
43 | seed: 1234
44 | n_startup_trials: 10 # number of random sampling runs before optimization starts
45 |
46 | # define hyperparameter search space
47 | params:
48 | model.optimizer.lr: interval(0.0001, 0.1)
49 | data.batch_size: choice(32, 64, 128, 256)
50 | model.net.lin1_size: choice(64, 128, 256)
51 | model.net.lin2_size: choice(64, 128, 256)
52 | model.net.lin3_size: choice(32, 64, 128, 256)
53 |
--------------------------------------------------------------------------------
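The `params` block above maps each hyperparameter to an Optuna distribution: `interval(a, b)` defines a continuous range and `choice(...)` a categorical set. A stdlib sketch of what a single random trial drawn from that space looks like (plain random sampling, not Optuna's TPE sampler):

```python
import random

random.seed(1234)  # mirrors the sampler seed above

def sample_trial():
    """Draw one configuration from the search space defined in `params`."""
    return {
        "model.optimizer.lr": random.uniform(0.0001, 0.1),    # interval(0.0001, 0.1)
        "data.batch_size": random.choice([32, 64, 128, 256]),  # choice(32, 64, 128, 256)
        "model.net.lin1_size": random.choice([64, 128, 256]),
        "model.net.lin2_size": random.choice([64, 128, 256]),
        "model.net.lin3_size": random.choice([32, 64, 128, 256]),
    }

trial = sample_trial()
```

Optuna's TPE sampler replaces the uniform draws with a model of which regions of this space have produced good `optimized_metric` values so far.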
/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | # https://hydra.cc/docs/configure_hydra/intro/
2 |
3 | # enable color logging
4 | defaults:
5 | - override hydra_logging: colorlog
6 | - override job_logging: colorlog
7 |
8 | # output directory, generated dynamically on each run
9 | run:
10 | dir: ${paths.log_dir}/${task_name}/${train_eval}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
11 | sweep:
12 | dir: ${paths.log_dir}/${task_name}/${train_eval}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
13 | subdir: ${hydra.job.num}
14 |
--------------------------------------------------------------------------------
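The `run.dir` pattern above resolves `${now:...}` at launch time, which is why the experiment configs reference checkpoints under paths like `logs/<task_name>/train/runs/<timestamp>/checkpoints/`. A sketch of how the pattern resolves (task name and timestamp are illustrative, taken from the experiment configs):

```python
from datetime import datetime

log_dir, task_name, train_eval = "logs", "mlp-ember-MFC-0.0", "train"
now = datetime(2023, 7, 30, 11, 54, 21)  # a fixed, illustrative launch time

# Mirrors: ${paths.log_dir}/${task_name}/${train_eval}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
run_dir = f"{log_dir}/{task_name}/{train_eval}/runs/{now:%Y-%m-%d}_{now:%H-%M-%S}"
print(run_dir)  # logs/mlp-ember-MFC-0.0/train/runs/2023-07-30_11-54-21
```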
/configs/local/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/configs/local/.gitkeep
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # https://wandb.ai
2 |
3 | wandb:
4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger
5 | # name: "" # name of the run (normally generated by wandb)
6 | save_dir: "${paths.output_dir}"
7 | offline: False
8 | id: null # pass correct id to resume experiment!
9 | anonymous: null # enable anonymous logging
10 | project: "lightning-hydra-template"
11 | log_model: True # upload lightning ckpts
12 | prefix: "" # a string to put at the beginning of metric keys
13 | # entity: "" # set to name of your wandb team
14 | group: ""
15 | tags: []
16 | job_type: ""
17 |
--------------------------------------------------------------------------------
/configs/model/malconv.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.malconv_module.MalconvModule
2 |
3 | optimizer:
4 | _target_: torch.optim.Adam
5 | _partial_: true
6 | lr: 0.001
7 | weight_decay: 0.0
8 |
9 | scheduler:
10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11 | _partial_: true
12 | mode: min
13 | factor: 0.1
14 | patience: 10
15 |
16 | network:
17 | _target_: src.models.malconv.MalConv
18 | input_length: 1_048_576
19 | window_size: 500
20 | stride: 500
21 | channels: 128
22 | embed_size: 8
23 | output_size: 8
24 |
--------------------------------------------------------------------------------
/configs/model/mlp.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.models.mlp_module.MLPModule
2 |
3 | optimizer:
4 | _target_: torch.optim.Adam
5 | _partial_: true
6 | lr: 0.001
7 | weight_decay: 0.0
8 |
9 | scheduler:
10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
11 | _partial_: true
12 | mode: min
13 | factor: 0.1
14 | patience: 10
15 |
16 | network:
17 | _target_: src.models.mlp.MLP
18 | input_size: 2381
19 | hidden_units: [1024, 512, 256]
20 |
--------------------------------------------------------------------------------
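The MLP config above fixes the hidden stack as 2381 → 1024 → 512 → 256; the output layer is left out here because `output_size` is supplied per experiment (e.g. 8 families). A quick shape and parameter-count check of that stack:

```python
# Dimensions from the config: input_size 2381 (Ember feature vector),
# hidden_units [1024, 512, 256].
input_size, hidden_units = 2381, [1024, 512, 256]

dims = [input_size] + hidden_units
layer_shapes = list(zip(dims[:-1], dims[1:]))

# Parameter count of the hidden stack (weights + biases); the experiment-set
# output layer (256 -> output_size) is excluded.
n_params = sum(n_in * n_out + n_out for n_in, n_out in layer_shapes)

print(layer_shapes)  # [(2381, 1024), (1024, 512), (512, 256)]
```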
/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | # path to root directory
2 | # this requires PROJECT_ROOT environment variable to exist
3 | # you can replace it with "." if you want the root to be the current working directory
4 | root_dir: ${oc.env:PROJECT_ROOT}
5 |
6 | # path to data directory
7 | data_dir: ${paths.root_dir}/data/
8 |
9 | # path to logging directory
10 | log_dir: ${paths.root_dir}/logs/
11 |
12 | # path to output directory, created dynamically by hydra
13 | # path generation pattern is specified in `configs/hydra/default.yaml`
14 | # use it to store all files generated during the run, like ckpts and metrics
15 | output_dir: ${hydra:runtime.output_dir}
16 |
17 | # path to working directory
18 | work_dir: ${hydra:runtime.cwd}
19 |
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default configuration
4 | # order of defaults determines the order in which configs override each other
5 | defaults:
6 | - _self_
7 |   - data: ember.yaml
8 |   - model: mlp.yaml
9 | - callbacks: default.yaml
10 | - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
11 | - trainer: default.yaml
12 | - paths: default.yaml
13 | - extras: default.yaml
14 | - hydra: default.yaml
15 |
16 | # experiment configs allow for version control of specific hyperparameters
17 | # e.g. best hyperparameters for given model and datamodule
18 | - experiment: null
19 |
20 | # config for hyperparameter optimization
21 | - hparams_search: null
22 |
23 | # optional local config for machine/user specific settings
24 | # it's optional since it doesn't need to exist and is excluded from version control
25 | - optional local: default.yaml
26 |
27 |   # debugging config (enable through command line, e.g. `python train.py debug=default`)
28 | - debug: null
29 |
30 | # task name, determines output directory path
31 | task_name: "train"
32 | train_eval: "train"
33 |
34 | # tags to help you identify your experiments
35 | # you can overwrite this in experiment configs
36 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
37 | tags: ["dev"]
38 |
39 | # set False to skip model training
40 | train: True
41 |
42 | # evaluate on test set, using best model weights achieved during training
43 | # lightning chooses best weights based on the metric specified in checkpoint callback
44 | test: True
45 |
46 | # compile model for faster training with pytorch 2.0
47 | compile: False
48 |
49 | # simply provide checkpoint path to resume training
50 | ckpt_path: null
51 |
52 | # seed for random number generators in pytorch, numpy and python.random
53 | seed: null
54 |
--------------------------------------------------------------------------------
/configs/trainer/cpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: cpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/configs/trainer/ddp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # use "ddp_spawn" instead of "ddp",
5 | # it's slower, but plain "ddp" currently doesn't work well with hydra
6 | # https://github.com/facebookresearch/hydra/issues/2070
7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
8 | strategy: ddp_spawn
9 |
10 | accelerator: gpu
11 | devices: 4
12 | num_nodes: 1
13 | sync_batchnorm: True
14 |
--------------------------------------------------------------------------------
/configs/trainer/ddp_sim.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | # simulate DDP on CPU, useful for debugging
5 | accelerator: cpu
6 | devices: 2
7 | strategy: ddp_spawn
8 |
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: lightning.pytorch.trainer.Trainer
2 |
3 | default_root_dir: ${paths.output_dir}
4 |
5 | min_epochs: 1 # prevents early stopping
6 | max_epochs: 10
7 |
8 | accelerator: cpu
9 | devices: 1
10 |
11 | # mixed precision for extra speed-up
12 | # precision: 16
13 |
14 | # perform a validation loop every N training epochs
15 | check_val_every_n_epoch: 1
16 |
17 | # set True to ensure deterministic results
18 | # makes training slower but gives more reproducibility than just setting seeds
19 | deterministic: False
20 |
--------------------------------------------------------------------------------
/configs/trainer/gpu.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: gpu
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/configs/trainer/mps.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | accelerator: mps
5 | devices: 1
6 |
--------------------------------------------------------------------------------
/detect/mlp_ember.py:
--------------------------------------------------------------------------------
1 | """detect concept drift on mlp-ember"""
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import pyrootutils
8 | import torch
9 | import typer
10 | from loguru import logger
11 | from pytorch_ood.detector import (
12 | ODIN,
13 | EnergyBased,
14 | Entropy,
15 | KLMatching,
16 | Mahalanobis,
17 | MaxLogit,
18 | MaxSoftmax,
19 | ViM,
20 | )
21 | from pytorch_ood.utils import OODMetrics
22 | from torch.utils.data import ConcatDataset, DataLoader, Dataset
23 |
24 | ROOT = pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
25 |
26 | from src.datasets.ember import EmberDataModule
27 | from src.models.mlp import MLP
28 |
29 | app = typer.Typer(add_completion=False)
30 |
31 |
32 | class CDDataset(Dataset):
33 | def __init__(self, X: list[np.ndarray], y: list[int]):
34 | self.X = X
35 | self.y = y
36 |
37 | def __len__(self) -> int:
38 | return len(self.y)
39 |
40 | def __getitem__(self, index: int) -> tuple:
41 | data, target = self.X[index], self.y[index]
42 | return data, target
43 |
44 |
45 | def seed_everything(seed: int):
46 | os.environ["PL_GLOBAL_SEED"] = str(seed)
47 | random.seed(seed)
48 | np.random.seed(seed)
49 | torch.manual_seed(seed)
50 | torch.cuda.manual_seed_all(seed)
51 |
52 |
53 | @app.command()
54 | def main(
55 | model_file: str = None,
56 | data_name: str = None,
57 | pack_ratio: float = 0.0,
58 | device: str = "cuda:0",
59 | seed: int = 42,
60 | ):
61 | # seed
62 | logger.info(f"Seeding everything with {seed}")
63 | seed_everything(seed)
64 | # load model
65 | logger.info(f"Loading model from {model_file}")
66 | model = MLP()
67 | model.load_state_dict(torch.load(ROOT / model_file))
68 | model = model.eval().to(device)
69 | last_layer = model.model[-1]
70 |
71 | # load training data
72 | logger.info("Loading ID data (MFC)")
73 | ID = EmberDataModule(data_name="MFC")
74 | ID.setup()
75 | ID_train_loader = ID.train_dataloader()
76 | ID_test = ID.data_test
77 |
78 | # hit CD data
79 | logger.info(f"Collecting CD data from {data_name}")
80 | if data_name in ["MFCUnseen", "MFCUnseenPacking"]:
81 | CD = EmberDataModule(data_name=data_name, pack_ratio=pack_ratio)
82 | CD.setup()
83 | CD_test = CD.data_test
84 | CD_test.y = [-1 for _ in CD_test.y]
85 | elif data_name in ["MFCPacking", "MFCEvolving"]:
86 | # mine misclassified samples as CD, capped to len(ID_test) below
87 | # find cd data
88 | CD = EmberDataModule(data_name=data_name, pack_ratio=pack_ratio)
89 | CD.setup()
90 | CD_train_loader = CD.train_dataloader()
91 | CD_X = []
92 | with torch.no_grad():
93 | for x, y in CD_train_loader:
94 | logits = model(x.to(device))
95 | preds = torch.argmax(logits, dim=1).tolist()
96 | for i, p in enumerate(preds):
97 | if p != y[i]:
98 | CD_X.append(x[i].detach().numpy())
99 | CD_X = CD_X[: len(ID_test)]
100 | CD_y = [-1 for _ in CD_X]
101 | CD_test = CDDataset(CD_X, CD_y)
102 | else:
103 | logger.error(f"Unknown data_name {data_name}")
104 | raise typer.Exit()
105 |
107 | assert len(CD_test)
108 | logger.info(f"TestData | ID: {len(ID_test)}, CD: {len(CD_test)}")
109 | # concatenate ID and CD data
110 | test_data = ConcatDataset([ID_test, CD_test])
111 | test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
112 |
113 | # create detector
114 | std = [1]
115 | logger.info("Creating detectors")
116 | detectors = {}
117 | detectors["MaxSoftmax"] = MaxSoftmax(model)
118 | detectors["ODIN"] = ODIN(model, norm_std=std, eps=0.002)
119 | detectors["Mahalanobis"] = Mahalanobis(model.features, norm_std=std, eps=0.002)
120 | detectors["EnergyBased"] = EnergyBased(model)
121 | detectors["Entropy"] = Entropy(model)
122 | detectors["MaxLogit"] = MaxLogit(model)
123 | detectors["KLMatching"] = KLMatching(model)
124 | detectors["ViM"] = ViM(model.features, d=64, w=last_layer.weight, b=last_layer.bias)
125 |
126 | # fit detectors to training data (some require this, some do not)
127 | logger.info(f"> Fitting {len(detectors)} detectors")
128 | for name, detector in detectors.items():
129 | logger.info(f"--> Fitting {name}")
130 | detector.fit(ID_train_loader, device=device)
131 |
132 | print(
133 | f"STAGE 3: Evaluating {len(detectors)} detectors on {data_name} concept drifts."
134 | )
135 | results = []
136 |
137 | with torch.no_grad():
138 | for detector_name, detector in detectors.items():
139 | print(f"> Evaluating {detector_name}")
140 | metrics = OODMetrics()
141 | for x, y in test_loader:
142 | metrics.update(detector(x.to(device)), y.to(device))
143 |
144 | r = {"Detector": detector_name}
145 | d = {k: round(v * 100, 2) for k, v in metrics.compute().items()}
146 | r.update(d)
147 | results.append(r)
148 |
149 | df = pd.DataFrame(
150 | results, columns=["Detector", "AUROC", "FPR95TPR", "AUPR-IN", "AUPR-OUT"]
151 | )
152 | df.to_csv(ROOT / f"detect/mlp-ember-{data_name}.csv", index=False)
153 | mean_scores = df.groupby("Detector").mean()
154 | logger.info(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))
155 |
156 |
157 | if __name__ == "__main__":
158 | app()
159 |
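Several of the detectors created above reduce to a scalar score over the network's logits; `EnergyBased`, for instance, scores each sample by the negative log-sum-exp of its logits, with higher (less negative) energy suggesting a drifted/OOD input. A dependency-free sketch of that score (illustrative only; pytorch-ood's implementation operates on batched tensors):

```python
import math

def energy_score(logits: list[float], temperature: float = 1.0) -> float:
    """Energy E(x) = -T * logsumexp(logits / T).

    Lower (more negative) energy means a more confident, in-distribution
    prediction; higher energy suggests out-of-distribution input.
    """
    m = max(logits)  # subtract the max for numerical stability
    lse = m / temperature + math.log(
        sum(math.exp(z / temperature - m / temperature) for z in logits)
    )
    return -temperature * lse

# A confident prediction yields lower energy than a flat one.
assert energy_score([10.0, 0.0, 0.0]) < energy_score([1.0, 1.0, 1.0])
```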
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | lightgbm==3.3.3
2 | lightning==2.0.2
3 | torch==2.0.1
4 | torchmetrics==0.11.4
5 | loguru==0.7.2
6 | numpy==1.24.4
7 | omegaconf==2.3.0
8 | optuna==2.10.1
9 | pandas==2.0.1
10 | rich==13.6.0
11 | scikit-learn==1.2.2
12 | typer==0.9.0
13 | wandb==0.15.2
14 | matplotlib
15 | pyrootutils
16 | # pytorch-ood requires a compatible torchmetrics version; adjust the pin above before installing it
--------------------------------------------------------------------------------
/scripts/detect_mlp_ember_drift.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python detect/mlp_ember.py \
4 | --model-file logs/mlp-ember-MFC-0.0/train/runs/*/checkpoints/best.pt \
5 | --data-name MFCUnseen \
6 | --pack-ratio 1.0
--------------------------------------------------------------------------------
/scripts/test_gbdt_ember.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # GBDT | Train and Test on MFC, MFCEvolving, MFCPacking, MFCUnseen
4 |
5 | python src/eval_gbdt.py \
6 | --data-name=MFC \
7 | --pack-ratio=0.0 \
8 | --ckpt-path=logs/gbdt-ember-MFC-0.0///gbdt-ember-*.lbg
9 |
--------------------------------------------------------------------------------
/scripts/test_malconv_bytes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## MalConv | Train and Test on MFCEvolving, MFCPacking, MFCUnseen
4 | python src/eval.py \
5 | experiment=malconv-bytes-test \
6 | task_name=malconv-bytes-MFC-0.0 \
7 | data_name=MFC \
8 | pack_ratio=0.0 \
9 | ckpt_path=logs/malconv-bytes-MFC-0.0///
--------------------------------------------------------------------------------
/scripts/test_mlp_ember.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ## Train on MFC and Test on MFCUnseen, MFCPacking, MFCEvolving
4 |
5 | python src/eval.py \
6 | experiment=mlp-ember-test \
7 | task_name=mlp-ember-MFC-0.0 \
8 | data_name=MFCUnseen \
9 | pack_ratio=0.0 \
10 | ckpt_path=logs/mlp-ember-MFC-0.0///
--------------------------------------------------------------------------------
/scripts/train_gbdt_ember.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python src/train_gbdt.py \
4 | --data-name MFC \
5 | --pack-ratio 0.0 \
6 | --do-wandb
--------------------------------------------------------------------------------
/scripts/train_malconv_bytes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python src/train.py \
4 | experiment=malconv-bytes-train \
5 | task_name=malconv-bytes-MFC-0.0 \
6 | data_name=MFC \
7 | pack_ratio=0.0
8 |
--------------------------------------------------------------------------------
/scripts/train_mlp_ember.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python src/train.py \
3 | experiment=mlp-ember-train \
4 | task_name=mlp-ember-MFC-0.0 \
5 | data_name=MFC \
6 | pack_ratio=0.0
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/src/__init__.py
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/src/datasets/__init__.py
--------------------------------------------------------------------------------
/src/datasets/bytes.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 | from lightning import LightningDataModule
5 | from torch.utils.data import DataLoader, Dataset
6 |
7 | from src import utils
8 |
9 | from .mfc import Feature, MalconvByteLoader, get_dataloader
10 |
11 | log = utils.get_pylogger(__name__)
12 |
13 |
14 | class BytesDataset(Dataset):
15 | """Ember Feature in-memory Dataset"""
16 |
17 | def __init__(self, X: list[Path], y: np.ndarray, load_fn: callable):
18 | self.X = X
19 | self.y = y
20 | self.load_fn = load_fn
21 |
22 | def __len__(self) -> int:
23 | return len(self.y)
24 |
25 | def __getitem__(self, index: int) -> tuple:
26 | data = self.load_fn(self.X[index])
27 | target = self.y[index]
28 | return data, target
29 |
30 |
31 | class BytesDataModule(LightningDataModule):
32 | """LightningDataModule for Ember Feature dataset.
33 |
34 | A DataModule implements 5 key methods:
35 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
36 | - setup (things to do on every accelerator in distributed mode)
37 | - train_dataloader (the training dataloader)
38 | - val_dataloader (the validation dataloader(s))
39 | - test_dataloader (the test dataloader(s))
40 |
41 | This allows you to share a full dataset without explaining how to download,
42 | split, transform and process the data.
43 |
44 | Read the docs:
45 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
46 | """
47 |
48 | def __init__(
49 | self,
50 | data_name: str,
51 | train_size: float = 0.6,
52 | val_size: float = 0.2,
53 | test_size: float = 0.2,
54 | batch_size: int = 32,
55 | num_workers: int = 16,
56 | pack_ratio: float = 0.0,
57 | first_n_byte: int = 2**20,
58 | ):
59 | super().__init__()
60 |
61 | # this line allows to access init params with 'self.hparams' attribute
62 | self.save_hyperparameters()
63 |
64 | self.data_train: Dataset = None
65 | self.data_val: Dataset = None
66 | self.data_test: Dataset = None
67 |
68 | self.mfc = get_dataloader(data_name)
69 | self.load_fn = MalconvByteLoader(first_n_byte=first_n_byte)
70 |
71 | def summary(self) -> dict:
72 | if self.mfc.X_train is None:
73 | self.mfc.setup()
74 | return self.mfc.summary()
75 |
76 | def setup(self, stage: str = None):
77 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
78 |
79 | This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
80 | so be careful not to execute the random split twice! The `stage` can be used to
81 | differentiate whether it's called before `trainer.fit()` or `trainer.test()`.
82 | """
83 |
84 | # load datasets only if they're not loaded already
85 | if not self.data_train and not self.data_val and not self.data_test:
86 | mfc = self.mfc
87 | mfc.setup(
88 | feature=Feature.SAMPLES,
89 | pack_ratio=self.hparams.pack_ratio,
90 | train_size=self.hparams.train_size,
91 | val_size=self.hparams.val_size,
92 | test_size=self.hparams.test_size,
93 | )
94 | log.info(f"Summary: {self.summary()}")
95 | (X_train, X_val, X_test, y_train, y_val, y_test) = (
96 | mfc.X_train,
97 | mfc.X_val,
98 | mfc.X_test,
99 | mfc.y_train,
100 | mfc.y_val,
101 | mfc.y_test,
102 | )
103 |
104 | load_fn = self.load_fn
105 | self.data_train = BytesDataset(X_train, y_train, load_fn)
106 | self.data_val = BytesDataset(X_val, y_val, load_fn)
107 | self.data_test = BytesDataset(X_test, y_test, load_fn)
108 |
109 | def train_dataloader(self):
110 | return DataLoader(
111 | dataset=self.data_train,
112 | batch_size=self.hparams.batch_size,
113 | num_workers=self.hparams.num_workers,
114 | shuffle=True,
115 | )
116 |
117 | def val_dataloader(self):
118 | return DataLoader(
119 | dataset=self.data_val,
120 | batch_size=self.hparams.batch_size,
121 | num_workers=self.hparams.num_workers,
122 | shuffle=False,
123 | )
124 |
125 | def test_dataloader(self):
126 | return DataLoader(
127 | dataset=self.data_test,
128 | batch_size=self.hparams.batch_size,
129 | num_workers=self.hparams.num_workers,
130 | shuffle=False,
131 | )
132 |
--------------------------------------------------------------------------------
/src/datasets/ember.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lightning import LightningDataModule
3 | from rich.progress import track
4 | from sklearn.preprocessing import StandardScaler
5 | from torch.utils.data import DataLoader, Dataset
6 |
7 | from src import utils
8 |
9 | from .mfc import Feature, get_dataloader
10 |
11 | log = utils.get_pylogger(__name__)
12 |
13 |
14 | class EmberDataset(Dataset):
15 | """Ember Feature in-memory Dataset"""
16 |
17 | def __init__(self, X: np.ndarray, y: np.ndarray):
18 | self.X = X
19 | self.y = y
20 |
21 | def __len__(self) -> int:
22 | return len(self.y)
23 |
24 | def __getitem__(self, index: int) -> tuple:
25 | data, target = self.X[index, :].astype(np.float32), self.y[index]
26 | return data, target
27 |
28 |
29 | class EmberDataModule(LightningDataModule):
30 | """LightningDataModule for Ember Feature dataset.
31 |
32 | A DataModule implements 5 key methods:
33 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
34 | - setup (things to do on every accelerator in distributed mode)
35 | - train_dataloader (the training dataloader)
36 | - val_dataloader (the validation dataloader(s))
37 | - test_dataloader (the test dataloader(s))
38 |
39 | This allows you to share a full dataset without explaining how to download,
40 | split, transform and process the data.
41 |
42 | Read the docs:
43 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
44 | """
45 |
46 | def __init__(
47 | self,
48 | data_name: str,
49 | train_size: float = 0.6,
50 | val_size: float = 0.2,
51 | test_size: float = 0.2,
52 | batch_size: int = 32,
53 | pack_ratio: float = 0.0,
54 | ):
55 | super().__init__()
56 |
57 | # this line allows to access init params with 'self.hparams' attribute
58 | self.save_hyperparameters()
59 |
60 | self.data_train: Dataset = None
61 | self.data_val: Dataset = None
62 | self.data_test: Dataset = None
63 |
64 | self.mfc = get_dataloader(data_name)
65 |
66 | def summary(self) -> dict:
67 | if self.mfc.X_train is None:
68 | self.mfc.setup()
69 | return self.mfc.summary()
70 |
71 | def setup(self, stage: str = None):
72 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
73 |
74 | This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
75 | so be careful not to execute the random split twice! The `stage` can be used to
76 | differentiate whether it's called before `trainer.fit()` or `trainer.test()`.
77 | """
78 |
79 | # load datasets only if they're not loaded already
80 | if not self.data_train and not self.data_val and not self.data_test:
81 | mfc = self.mfc
82 | mfc.setup(
83 | feature=Feature.EMBER_NUMPY,
84 | pack_ratio=self.hparams.pack_ratio,
85 | train_size=self.hparams.train_size,
86 | val_size=self.hparams.val_size,
87 | test_size=self.hparams.test_size,
88 | )
89 | log.info(f"Summary: {self.summary()}")
90 | (X_train, X_val, X_test, y_train, y_val, y_test) = (
91 | mfc.X_train,
92 | mfc.X_val,
93 | mfc.X_test,
94 | mfc.y_train,
95 | mfc.y_val,
96 | mfc.y_test,
97 | )
98 | X_train = [
99 | np.load(i)
100 | for i in track(
101 | X_train, total=len(X_train), description="Loading train..."
102 | )
103 | ]
104 | X_val = [
105 | np.load(i)
106 | for i in track(X_val, total=len(X_val), description="Loading val...")
107 | ]
108 | X_test = [
109 | np.load(i)
110 | for i in track(X_test, total=len(X_test), description="Loading test...")
111 | ]
112 |
113 | log.info("StandardScalering...")
114 | scaler = StandardScaler()
115 | scaler.fit(X_train + X_val + X_test)
116 | X_train = scaler.transform(X_train)
117 | X_val = scaler.transform(X_val)
118 | X_test = scaler.transform(X_test)
119 |
120 | self.data_train = EmberDataset(X_train, y_train)
121 | self.data_val = EmberDataset(X_val, y_val)
122 | self.data_test = EmberDataset(X_test, y_test)
123 |
124 | def train_dataloader(self):
125 | return DataLoader(
126 | dataset=self.data_train,
127 | batch_size=self.hparams.batch_size,
128 | shuffle=True,
129 | drop_last=True,
130 | )
131 |
132 | def val_dataloader(self):
133 | return DataLoader(
134 | dataset=self.data_val,
135 | batch_size=self.hparams.batch_size,
136 | shuffle=False,
137 | )
138 |
139 | def test_dataloader(self):
140 | return DataLoader(
141 | dataset=self.data_test,
142 | batch_size=self.hparams.batch_size,
143 | shuffle=False,
144 | )
145 |
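One caveat in `setup()` above: the `StandardScaler` is fit on the train, val, and test splits jointly, which leaks test-set statistics into preprocessing. A leakage-free variant fits on the training split only and reuses those statistics everywhere; a minimal dependency-free sketch (illustrative, not the project's code):

```python
import math

def fit_standardizer(rows: list[list[float]]) -> tuple[list[float], list[float]]:
    """Compute per-feature mean and std from the training rows only."""
    n, d = len(rows), len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(d)]
    stds = [
        math.sqrt(sum((r[j] - means[j]) ** 2 for r in rows) / n) or 1.0
        for j in range(d)  # `or 1.0` guards constant features (zero std)
    ]
    return means, stds

def transform(rows, means, stds):
    return [[(v - m) / s for v, m, s in zip(r, means, stds)] for r in rows]

X_train = [[1.0, 10.0], [3.0, 30.0]]
X_test = [[2.0, 20.0]]
means, stds = fit_standardizer(X_train)      # statistics from train only
X_test_std = transform(X_test, means, stds)  # test reuses train statistics
```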
--------------------------------------------------------------------------------
/src/datasets/mfc.py:
--------------------------------------------------------------------------------
1 | """Malware Family Classification Data"""
2 | import os
3 | from collections import Counter, defaultdict
4 | from pathlib import Path
5 |
6 | import numpy as np
7 | from sklearn.model_selection import train_test_split
8 |
9 | MFC_ROOT = Path(os.getenv("MFC_ROOT"))
10 |
11 | PACKERS = ["upx", "mpress", "aes"]
12 |
13 |
14 | class Feature:
15 | EMBER_NUMPY = "feature-ember-npy"
16 | SAMPLES = "samples"
17 |
18 |
19 | FEATURE_SUFFIX = {
20 | Feature.SAMPLES: "",
21 | Feature.EMBER_NUMPY: "ember.npy",
22 | }
23 |
24 |
25 | class Group:
26 | MALICIOUS = "malicious"
27 | MALICIOUS_UNSEEN = "malicious-unseen"
28 | MALICIOUS_EVOLVING = "malicious-evolving"
29 | MALICIOUS_UPX = "malicious-upx"
30 | MALICIOUS_MPRESS = "malicious-mpress"
31 | MALICIOUS_AES = "malicious-aes"
32 |
33 |
34 | class FeatureLoader:
35 | name = "feature_loader"
36 | dtype = "float32"
37 |
38 | def __call__(self, file_path: str) -> np.ndarray:
39 | raise NotImplementedError
40 |
41 |
42 | class NumpyLoader(FeatureLoader):
43 | name = "numpy"
44 | dtype = "float32"
45 |
46 | def __call__(self, file_path: str) -> np.ndarray:
47 | data: np.ndarray = np.load(file_path)
48 | return data.astype(np.float32)
49 |
50 |
51 | class MalconvByteLoader(FeatureLoader):
52 | name = "malconv_byte"
53 | dtype = "int32"
54 |
55 | def __init__(self, first_n_byte: int = 2**20) -> None:
56 | self.first_n_byte = first_n_byte
57 |
58 | def __call__(self, file_path: str) -> np.ndarray:
59 | with open(file_path, "rb") as f:
60 | # index 0 will be special padding index
61 | data = [i + 1 for i in f.read()[: self.first_n_byte]]
62 | data = data + [0] * (self.first_n_byte - len(data))
63 | return np.array(data).astype(np.intc)
64 |
65 |
66 | def class_counter(data: list[str]) -> dict:
67 | return dict(sorted(Counter(data).items(), key=lambda x: x[1], reverse=True))
68 |
69 |
70 | class MFCSample:
71 | """MFC Sample is orginazied by:
72 | ```
73 | root/samples///xxx
74 |
75 | >>> For example:
76 | MFC
77 | ├── samples
78 | │ ├── malicious
79 | │ │ ├── fareit
80 | │ │ ├── gandcrab
81 | │ │ ├── hotbar
82 | │ │ ├── parite
83 | │ │ ├── simda
84 | │ │ ├── upatre
85 | │ │ ├── yuner
86 | │ │ └── zbot
87 | ....
88 | ```
89 | """
90 |
91 | root: str = MFC_ROOT
92 | group: str = Group.MALICIOUS
93 |
94 | def get(
95 | self,
96 | group: str = None,
97 | root: str = None,
98 | ) -> tuple[list[str], list[str]]:
99 | group = group or self.group
100 | root = root or self.root
101 | data_path = Path(root) / "samples"
102 |
103 | X = []
104 | y = []
105 | for r, _, files in os.walk(data_path / group):
106 | for f in files:
107 | file_path = Path(r, f)
108 | X.append(file_path.name)
109 | y.append(file_path.parent.name)
110 |
111 | if len(X) == 0:
112 | raise ValueError(f"Empty: {root}/samples/{group}")
113 |
114 | return X, y
115 |
116 |
117 | class MFCLoader:
118 | root: Path = MFC_ROOT
119 | group: str = Group.MALICIOUS
120 | class_map: dict = {
121 | "fareit": 0,
122 | "gandcrab": 1,
123 | "hotbar": 2,
124 | "parite": 3,
125 | "simda": 4,
126 | "upatre": 5,
127 | "yuner": 6,
128 | "zbot": 7,
129 | }
130 | packers: list[str] = None
131 |
132 | def __init__(self):
133 | self.X_train = None
134 | self.X_test = None
135 | self.X_val = None
136 | self.y_train = None
137 | self.y_test = None
138 | self.y_val = None
139 | self.feature = None
140 | self.pack_ratio = None
141 | self.load_fn: FeatureLoader = None
142 |
143 | def get_path(self, X: list[str], y: list[str]) -> tuple[list[Path], list[int]]:
144 | X_path = []
145 | y_id = []
146 |
147 | for name, family in zip(X, y):
148 | if "_" in name:
149 | group = "-".join([self.group] + name.split("_")[1:])
150 | else:
151 | group = self.group
152 | suffix = FEATURE_SUFFIX[self.feature]
153 | if suffix:
154 | name = f"{name}_{FEATURE_SUFFIX[self.feature]}"
155 | p = self.root / self.feature / group / family / name
156 | if p.exists():
157 | X_path.append(p)
158 | y_id.append(self.class_map[family])
159 | assert len(X_path) > 0
160 | return X_path, y_id
161 |
162 | def pack(
163 | self,
164 | X: list[str],
165 | y: list[int],
166 | pack_ratio: float,
167 | ) -> tuple[list[str], list[int]]:
168 | if pack_ratio == 1.0:
169 | X_packed, y_packed = X, y
170 | X_unpacked, y_unpacked = [], []
171 | else:
172 | X_packed, X_unpacked, y_packed, y_unpacked = train_test_split(
173 | X, y, train_size=pack_ratio, stratify=y, random_state=42
174 | )
175 | num = len(X_packed)
176 | packers = self.packers
177 | m = len(packers)
178 | n = num // m
179 | packed = []
180 | for i, j in enumerate(range(0, num, n)):
181 | packed.extend(
182 | [f"{k}_{packers[i%m]}" for k in X_packed[j : min(num, j + n)]]
183 | )
184 | return packed + X_unpacked, y_packed + y_unpacked
185 |
186 | def setup(
187 | self,
188 | feature: str = None,
189 | pack_ratio: float = None,
190 | train_size: float = 0.6,
191 | val_size: float = 0.2,
192 | test_size: float = 0.2,
193 | ) -> tuple[list[Path], list[Path], list[Path], list[int], list[int], list[int]]:
194 | """
195 | Returns
196 | -------
197 | X_train, X_val, X_test, y_train, y_val, y_test
198 | """
199 | assert abs(train_size + val_size + test_size - 1.0) < 1e-6
200 |
201 | group = self.group
202 | root = self.root
203 |
204 | self.feature = feature
205 | self.pack_ratio = pack_ratio
206 |
207 | if feature == Feature.EMBER_NUMPY:
208 | self.load_fn = NumpyLoader()
209 | elif feature == Feature.SAMPLES:
210 | self.load_fn = MalconvByteLoader()
211 | else:
212 | raise ValueError(f"Unknown feature: {feature}")
213 |
214 | X, y = MFCSample().get(group, root)
215 | X_train, X_test, y_train, y_test = train_test_split(
216 | X, y, train_size=train_size, stratify=y, random_state=42
217 | )
218 | # # for 40% test samples
219 | # new_size = test_size
220 | new_size = test_size / (test_size + val_size)
221 | X_test, X_val, y_test, y_val = train_test_split(
222 | X_test, y_test, train_size=new_size, stratify=y_test, random_state=42
223 | )
224 |
225 | # pack
226 | if pack_ratio is None:
227 | pack_ratio = getattr(type(self), "pack_ratio", None)
228 | if pack_ratio is not None and 0.1 <= pack_ratio <= 1.0:
229 | X_train, y_train = self.pack(X_train, y_train, pack_ratio)
230 | X_test, y_test = self.pack(X_test, y_test, pack_ratio)
231 | X_val, y_val = self.pack(X_val, y_val, pack_ratio)
232 |
233 | # path
234 | X_train, y_train = self.get_path(X_train, y_train)
235 | X_test, y_test = self.get_path(X_test, y_test)
236 | X_val, y_val = self.get_path(X_val, y_val)
237 |
238 | self.X_train = X_train
239 | self.X_test = X_test
240 | self.X_val = X_val
241 | self.y_train = y_train
242 | self.y_test = y_test
243 | self.y_val = y_val
244 |
245 | return (X_train, X_val, X_test, y_train, y_val, y_test)
246 |
247 | def is_packed(self, x: Path) -> bool:
248 | return any([i in x.name for i in PACKERS])
249 |
250 | def get_packed_ratio(self, X: list[Path]) -> float:
251 | ratio = sum([self.is_packed(i) for i in X]) / len(X)
252 | return round(ratio, 2)
253 |
254 | def get_packer_dist(self, X: list[list[Path]]) -> dict[str, int]:
255 | packers = defaultdict(int)
256 | for f in X:
257 | hit = False
258 | for p in PACKERS:
259 | if p in f.name:
260 | hit = True
261 | packers[p] += 1
262 | if not hit:
263 | packers["none"] += 1
264 | return dict(packers)
265 |
266 | def load(self, x: Path) -> np.ndarray:
267 | return self.load_fn(x)
268 |
269 | def summary(self) -> dict:
270 | X_train, X_val, X_test = self.X_train, self.X_val, self.X_test
271 | y_train, y_val, y_test = self.y_train, self.y_val, self.y_test
272 | num_train, num_val, num_test = len(X_train), len(X_val), len(X_test)
273 |
274 | ratio_train = self.get_packed_ratio(X_train)
275 | ratio_val = self.get_packed_ratio(X_val)
276 | ratio_test = self.get_packed_ratio(X_test)
277 |
278 | packers_train = self.get_packer_dist(X_train)
279 | packers_val = self.get_packer_dist(X_val)
280 | packers_test = self.get_packer_dist(X_test)
281 |
282 | def data_class(y: list[int]) -> dict[str, int]:
283 | return dict(sorted(Counter(y).items()))
284 |
285 | # data
286 | data = {
287 | "train": {
288 | "total": num_train,
289 | "packer": packers_train,
290 | "packed_ratio": ratio_train,
291 | "class": data_class(y_train),
292 | },
293 | "val": {
294 | "total": num_val,
295 | "packer": packers_val,
296 | "packed_ratio": ratio_val,
297 | "class": data_class(y_val),
298 | },
299 | "test": {
300 | "total": num_test,
301 | "packer": packers_test,
302 | "packed_ratio": ratio_test,
303 | "class": data_class(y_test),
304 | },
305 | }
306 | # feature
307 | x_data = self.load(X_train[0])
308 | x_path = str(X_train[0].relative_to(self.root))
309 | feature = {
310 | "names": self.feature,
311 | "loader": self.load_fn.name,
312 | "dtype": self.load_fn.dtype,
313 | "example": x_path,
314 | "dimension": len(x_data),
315 | }
316 | return {"data": data, "feature": feature}
317 |
318 |
319 | class MFC(MFCLoader):
320 | group: str = Group.MALICIOUS
321 | class_map: dict = {
322 | "fareit": 0,
323 | "gandcrab": 1,
324 | "hotbar": 2,
325 | "parite": 3,
326 | "simda": 4,
327 | "upatre": 5,
328 | "yuner": 6,
329 | "zbot": 7,
330 | }
331 |
332 |
333 | class MFCEvolving(MFC):
334 | group: str = Group.MALICIOUS_EVOLVING
335 |
336 |
337 | class MFCUnseen(MFCLoader):
338 | group: str = Group.MALICIOUS_UNSEEN
339 | class_map: dict = {
340 | "hupigon": 0,
341 | "imali": 1,
342 | "lydra": 2,
343 | "onlinegames": 3,
344 | "virut": 4,
345 | "vobfus": 5,
346 | "wannacry": 6,
347 | "zlob": 7,
348 | }
349 |
350 |
351 | class MFCPacking(MFC):
352 | packers: list[str] = ["upx", "mpress", "aes"]
353 | pack_ratio: float = 1.0
354 |
355 |
356 | class MFCAes(MFC):
357 | packers: list[str] = ["aes"]
358 | pack_ratio: float = 1.0
359 |
360 |
361 | MFC_LOADER: dict[str, MFCLoader] = {
362 | "MFC": MFC(),
363 | "MFCAes": MFCAes(),
364 | "MFCEvolving": MFCEvolving(),
365 | "MFCPacking": MFCPacking(),
366 | "MFCUnseen": MFCUnseen(),
367 | }
368 |
369 |
370 | def get_dataloader(name: str) -> MFCLoader:
371 | return MFC_LOADER[name]
372 |
373 |
374 | if __name__ == "__main__":
375 | import rich
376 |
377 | mfc = MFCPacking()
378 | mfc.setup()
379 | rich.print(mfc.summary())
380 |
--------------------------------------------------------------------------------
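The `data_class` helper in `summary()` above reduces a label list to a per-class histogram sorted by class id. A minimal stand-alone sketch of that reduction (the sample labels are hypothetical):

```python
from collections import Counter


def data_class(y: list[int]) -> dict[int, int]:
    # count occurrences of each class id, sorted ascending by class id
    return dict(sorted(Counter(y).items()))


# toy labels for a 3-class split
labels = [2, 0, 1, 1, 2, 2]
print(data_class(labels))  # -> {0: 1, 1: 2, 2: 3}
```

Sorting before building the dict keeps the printed summary stable across runs, which makes `data_summary.log` files diffable.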
/src/eval.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import hydra
4 | import numpy as np
5 | import pyrootutils
6 | from lightning import LightningDataModule, LightningModule, Trainer
7 | from lightning.pytorch.loggers import Logger
8 | from omegaconf import DictConfig
9 |
10 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
11 | # ------------------------------------------------------------------------------------ #
12 | # the setup_root above is equivalent to:
13 | # - adding project root dir to PYTHONPATH
14 | # (so you don't need to force user to install project as a package)
15 | # (necessary before importing any local modules e.g. `from src import utils`)
16 | # - setting up PROJECT_ROOT environment variable
17 | # (which is used as a base for paths in "configs/paths/default.yaml")
18 | # (this way all filepaths are the same no matter where you run the code)
19 | # - loading environment variables from ".env" in root dir
20 | #
21 | # you can remove it if you:
22 | # 1. either install project as a package or move entry files to project root dir
23 | # 2. set `root_dir` to "." in "configs/paths/default.yaml"
24 | #
25 | # more info: https://github.com/ashleve/pyrootutils
26 | # ------------------------------------------------------------------------------------ #
27 |
28 | from src import utils
29 |
30 | log = utils.get_pylogger(__name__)
31 |
32 |
33 | @utils.task_wrapper
34 | def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
35 |     """Evaluates the given checkpoint on the datamodule's test set.
36 | 
37 |     This method is wrapped in the optional @task_wrapper decorator, which controls behavior on
38 |     failure. Useful for multiruns, saving info about the crash, etc.
39 |
40 | Args:
41 | cfg (DictConfig): Configuration composed by Hydra.
42 |
43 | Returns:
44 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
45 | """
46 |
47 |     assert cfg.ckpt_path, "ckpt_path must be provided for evaluation!"
48 |
49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>")
50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
51 |
52 | log.info(f"Instantiating model <{cfg.model._target_}>")
53 | model: LightningModule = hydra.utils.instantiate(cfg.model)
54 |
55 | log.info("Instantiating loggers...")
56 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
57 |
58 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
59 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
60 |
61 | object_dict = {
62 | "cfg": cfg,
63 | "datamodule": datamodule,
64 | "model": model,
65 | "logger": logger,
66 | "trainer": trainer,
67 | }
68 |
69 | if logger:
70 | log.info("Logging hyperparameters!")
71 | utils.log_hyperparameters(object_dict)
72 |
73 | # log data summary
74 | datamodule.setup()
75 | log.info("Logging data summary!")
76 | utils.log_data_summary(cfg, datamodule)
77 |
78 | log.info("Starting testing!")
79 | test_loader = datamodule.test_dataloader()
80 | y_true = []
81 | for _, y in test_loader:
82 | if isinstance(y, list):
83 | y_true.extend(y[-1].numpy())
84 | else:
85 | y_true.extend(y.numpy())
86 | y_true = np.hstack(y_true)
87 | y_pred = trainer.predict(
88 | model=model, dataloaders=test_loader, ckpt_path=cfg.ckpt_path
89 | )
90 | y_pred = np.hstack(y_pred)
91 | utils.log_test_results(cfg, y_true, y_pred)
92 |
93 | metric_dict = trainer.callback_metrics
94 |
95 | return metric_dict, object_dict
96 |
97 |
98 | @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
99 | def main(cfg: DictConfig) -> None:
100 | # apply extra utilities
101 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
102 | utils.extras(cfg)
103 |
104 | evaluate(cfg)
105 |
106 |
107 | if __name__ == "__main__":
108 | main()
109 |
--------------------------------------------------------------------------------
/src/eval_gbdt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pyrootutils
7 | import rich
8 | import typer
9 | from rich.progress import track
10 | from sklearn.metrics import classification_report
11 |
12 | ROOT = pyrootutils.setup_root(
13 | __file__,
14 | indicator=".project-root",
15 | pythonpath=True,
16 | )
17 |
18 | from src.datasets.mfc import MFC_LOADER
19 | from src.models.gbdt import GBDTClassifier
20 |
21 | app = typer.Typer()
22 |
23 |
24 | def pprint(data):
25 | return json.dumps(data, indent=2)
26 |
27 |
28 | def seed_everything(seed: int) -> None:
29 | np.random.seed(seed)
30 | random.seed(seed)
31 |
32 |
33 | @app.command()
34 | def main(
35 | data_name: str = None,
36 | ckpt_path: str = None,
37 | feature_name: str = "feature-ember-npy",
38 | pack_ratio: float = 0.0,
39 | train_size: float = 0.6,
40 | val_size: float = 0.2,
41 | test_size: float = 0.2,
42 | seed: int = 42,
43 | ):
44 | seed_everything(seed)
45 | gbdt = GBDTClassifier()
46 | gbdt.load(ckpt_path)
47 |
48 | ckpt_path = Path(ckpt_path)
49 | log_dir = ckpt_path.parent / f"{data_name}"
50 | log_dir.mkdir(parents=True, exist_ok=True)
51 |
52 | # prepare data
53 | mfc = MFC_LOADER[data_name]
54 | mfc.setup(
55 | feature=feature_name,
56 | train_size=train_size,
57 | val_size=val_size,
58 | test_size=test_size,
59 | pack_ratio=pack_ratio,
60 | )
61 | rich.print(mfc.summary())
62 | with open(log_dir / "data_summary.log", "w") as file:
63 | rich.print(mfc.summary(), file=file)
64 |
65 | _, _, X_test, _, _, y_test = (
66 | mfc.X_train,
67 | mfc.X_val,
68 | mfc.X_test,
69 | mfc.y_train,
70 | mfc.y_val,
71 | mfc.y_test,
72 | )
73 | X_test = [
74 | np.load(i)
75 | for i in track(X_test, total=len(X_test), description="Loading test...")
76 | ]
77 | X_test = np.vstack(X_test)
78 | y_test = np.array(y_test)
79 |
80 | # predict
81 | predict = gbdt.predict(X_test)
82 | predict = [np.argmax(i) for i in predict]
83 |
84 | result = classification_report(y_true=y_test, y_pred=predict, digits=4)
85 | rich.print(f"Test Report: {result}")
86 | with open(log_dir / "test_results.log", "w") as file:
87 | rich.print(result, file=file)
88 |
89 |
90 | if __name__ == "__main__":
91 | app()
92 |
--------------------------------------------------------------------------------
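For a multiclass objective, `gbdt.predict` returns one row of class probabilities per sample; the `np.argmax` loop in `eval_gbdt.py` above collapses each row to a predicted class id. A small sketch with made-up probabilities:

```python
import numpy as np

# two samples, three classes: each row is a (hypothetical) probability distribution
proba = np.array([
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
])

# per-row argmax yields the predicted class ids
pred = [int(np.argmax(row)) for row in proba]
print(pred)  # -> [1, 0]
```

These integer ids are what `classification_report` compares against `y_test`.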
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/gbdt.py:
--------------------------------------------------------------------------------
1 | import lightgbm as lgb
2 |
3 |
4 | class GBDTClassifier:
5 | def __init__(
6 | self,
7 | boosting: str = "gbdt",
8 | objective: str = "multiclass",
9 | num_class: int = 8,
10 | metric: str = "multi_logloss",
11 | num_iterations: int = 1_000,
12 | learning_rate: float = 0.05,
13 | num_leaves: int = 2048,
14 | max_depth: int = 15,
15 | min_data_in_leaf: int = 50,
16 | feature_fraction: float = 0.5,
17 | device: str = "cpu",
18 | num_threads: int = 24,
19 | verbosity: int = -1,
20 | ):
21 | self.hparams = {
22 | "boosting": boosting,
23 | "objective": objective,
24 | "num_class": num_class,
25 | "metric": metric,
26 | "num_iterations": num_iterations,
27 | "learning_rate": learning_rate,
28 | "num_leaves": num_leaves,
29 | "max_depth": max_depth,
30 | "min_data_in_leaf": min_data_in_leaf,
31 | "feature_fraction": feature_fraction,
32 | "device": device,
33 | "num_threads": num_threads,
34 |             # verbosity: < 0 means Fatal, 0 means Error (Warning), 1 means Info, > 1 means Debug
35 | "verbosity": verbosity,
36 | }
37 | self.model = None
38 |
39 | def load(self, model_file: str) -> None:
40 | self.model = lgb.Booster(model_file=model_file)
41 |
42 | def fit(self, X_train, y_train, callbacks=None) -> None:
43 | lgbm_dataset = lgb.Dataset(X_train, y_train)
44 | self.model = lgb.train(self.hparams, lgbm_dataset, callbacks=callbacks)
45 |
46 | def predict(self, X_test):
47 | return self.model.predict(X_test)
48 |
--------------------------------------------------------------------------------
/src/models/malconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class MalConv(nn.Module):
6 | def __init__(
7 | self,
8 | input_length: int = 2**20,
9 | window_size: int = 500,
10 | stride: int = 500,
11 | channels: int = 128,
12 | embed_size: int = 8,
13 | output_size: int = 8,
14 | ):
15 | super().__init__()
16 | self.channels = channels
17 | self.embed = nn.Embedding(257, embed_size, padding_idx=0)
18 | in_channels = int(embed_size / 2)
19 | self.conv_1 = nn.Conv1d(
20 | in_channels, channels, window_size, stride=stride, bias=True
21 | )
22 | self.conv_2 = nn.Conv1d(
23 | in_channels, channels, window_size, stride=stride, bias=True
24 | )
25 | self.pooling = nn.MaxPool1d(int(input_length / window_size))
26 | self.fc_1 = nn.Linear(channels, channels)
27 | self.fc_2 = nn.Linear(channels, output_size)
28 | self.sigmoid = nn.Sigmoid()
29 | # num_classes
30 | self.num_classes = output_size
31 |
32 | def forward(self, x):
33 | x = self.embed(x)
34 | # Channel first
35 | x = torch.transpose(x, -1, -2)
36 |         cnn_value = self.conv_1(x.narrow(-2, 0, 4))  # first half of embedding channels (hard-codes embed_size == 8)
37 |         gating_weight = self.sigmoid(self.conv_2(x.narrow(-2, 4, 4)))  # second half gates the first
38 | x = cnn_value * gating_weight
39 | x = self.pooling(x)
40 | x = x.view(-1, self.channels)
41 | x = self.fc_1(x)
42 | x = self.fc_2(x)
43 | return x
44 |
45 | def features(self, x):
46 | """
47 | Extracts (flattened) features before the last fully connected layer.
48 | """
49 | x = self.embed(x)
50 | # Channel first
51 | x = torch.transpose(x, -1, -2)
52 | cnn_value = self.conv_1(x.narrow(-2, 0, 4))
53 | gating_weight = self.sigmoid(self.conv_2(x.narrow(-2, 4, 4)))
54 | x = cnn_value * gating_weight
55 | x = self.pooling(x)
56 | x = x.view(-1, self.channels)
57 | x = self.fc_1(x)
58 | return x
59 |
--------------------------------------------------------------------------------
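With the defaults above (`input_length=2**20`, `window_size=stride=500`, `channels=128`), the gated convolution plus max-pool collapses each file to a single 128-channel vector before the fully connected layers. The length arithmetic can be checked directly:

```python
# Conv1d output length: floor((L_in - kernel) / stride) + 1
input_length, window_size, stride = 2**20, 500, 500
conv_len = (input_length - window_size) // stride + 1

# matches nn.MaxPool1d(int(input_length / window_size)) in MalConv.__init__
pool_kernel = input_length // window_size

# MaxPool1d output length (stride defaults to kernel_size)
pooled_len = (conv_len - pool_kernel) // pool_kernel + 1

print(conv_len, pool_kernel, pooled_len)  # -> 2097 2097 1
```

Since the pooled length is 1, `x.view(-1, self.channels)` safely flattens the pooled tensor to `(batch, 128)`.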
/src/models/malconv_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from lightning import LightningModule
5 | from torchmetrics import Accuracy, MaxMetric, MeanMetric
6 |
7 |
8 | class MalconvModule(LightningModule):
9 |     """MalConv Module.
10 |
11 | A LightningModule organizes your PyTorch code into 5 sections:
12 | - Computations (init).
13 | - Train loop (training_step)
14 | - Validation loop (validation_step)
15 | - Test loop (test_step)
16 | - Optimizers (configure_optimizers)
17 |
18 | Read the docs:
19 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
20 | """
21 |
22 | def __init__(
23 | self,
24 | network: torch.nn.Module,
25 | optimizer: torch.optim.Optimizer,
26 | scheduler: torch.optim.lr_scheduler,
27 | ):
28 | super().__init__()
29 |
30 | # this line allows to access init params with 'self.hparams' attribute
31 | # it also ensures init params will be stored in ckpt
32 | self.save_hyperparameters(logger=False)
33 |
34 | self.network = network
35 | num_classes = network.num_classes
36 |
37 | # loss function
38 | self.criterion = torch.nn.CrossEntropyLoss()
39 |
40 | # metric objects for calculating and averaging accuracy across batches
41 | self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
42 | self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
43 | self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)
44 |
45 | # for averaging loss across batches
46 | self.train_loss = MeanMetric()
47 | self.val_loss = MeanMetric()
48 | self.test_loss = MeanMetric()
49 |
50 | # for logging best so far validation accuracy
51 | self.val_acc_best = MaxMetric()
52 |
53 | def forward(self, x: torch.Tensor):
54 | return self.network(x)
55 |
56 | def on_train_start(self):
57 | # by default lightning executes validation step sanity checks before training starts,
58 |         # so it's worth making sure validation metrics don't store results from these checks
59 | self.val_loss.reset()
60 | self.val_acc.reset()
61 | self.val_acc_best.reset()
62 |
63 | def model_step(self, batch):
64 | x, y = batch
65 | logits = self.forward(x)
66 | loss = self.criterion(logits, y)
67 | preds = torch.argmax(logits, dim=1)
68 | return loss, preds, y
69 |
70 | def training_step(self, batch, batch_idx: int):
71 | loss, preds, targets = self.model_step(batch)
72 |
73 | # update and log metrics
74 | self.train_loss(loss)
75 | self.train_acc(preds, targets)
76 | self.log(
77 | "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True
78 | )
79 | self.log(
80 | "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
81 | )
82 |
83 | # return loss or backpropagation will fail
84 | return loss
85 |
86 | def on_train_epoch_end(self):
87 | pass
88 |
89 | def validation_step(self, batch, batch_idx: int):
90 | loss, preds, targets = self.model_step(batch)
91 |
92 | # update and log metrics
93 | self.val_loss(loss)
94 | self.val_acc(preds, targets)
95 | self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
96 | self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
97 |
98 | def on_validation_epoch_end(self):
99 | acc = self.val_acc.compute() # get current val acc
100 | self.val_acc_best(acc) # update best so far val acc
101 | # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
102 | # otherwise metric would be reset by lightning after each epoch
103 | self.log("val/acc_best", self.val_acc_best.compute(), prog_bar=True)
104 |
105 | def test_step(self, batch, batch_idx: int):
106 | loss, preds, targets = self.model_step(batch)
107 |
108 | # update and log metrics
109 | self.test_loss(loss)
110 | self.test_acc(preds, targets)
111 | self.log(
112 | "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
113 | )
114 | self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
115 |
116 | def on_test_epoch_end(self):
117 | pass
118 |
119 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
120 |         _, preds, _ = self.model_step(batch)  # loss and targets are unused at predict time
121 | return preds.tolist()
122 |
123 | def configure_optimizers(self):
124 | """Choose what optimizers and learning-rate schedulers to use in your optimization.
125 | Normally you'd need one. But in the case of GANs or similar you might have multiple.
126 |
127 | Examples:
128 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
129 | """
130 | optimizer = self.hparams.optimizer(params=self.parameters())
131 | if self.hparams.scheduler is not None:
132 | scheduler = self.hparams.scheduler(optimizer=optimizer)
133 | return {
134 | "optimizer": optimizer,
135 | "lr_scheduler": {
136 | "scheduler": scheduler,
137 | "monitor": "val/loss",
138 | "interval": "epoch",
139 | "frequency": 1,
140 | },
141 | }
142 | return {"optimizer": optimizer}
143 |
--------------------------------------------------------------------------------
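The modules above log `MeanMetric`/`MaxMetric` objects so that Lightning averages the loss across batches and `val/acc_best` tracks the best validation accuracy seen so far. The bookkeeping those torchmetrics objects perform amounts to the following plain-Python analogy (a sketch of the semantics, not the real torchmetrics API):

```python
class RunningMean:
    # averages values across update() calls, like torchmetrics.MeanMetric
    def __init__(self):
        self.total, self.count = 0.0, 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count


class RunningMax:
    # keeps the best value seen so far, like torchmetrics.MaxMetric
    def __init__(self):
        self.best = float("-inf")

    def update(self, value: float) -> None:
        self.best = max(self.best, value)


# three hypothetical validation epochs
val_loss, val_acc_best = RunningMean(), RunningMax()
for loss, acc in [(0.9, 0.55), (0.7, 0.62), (0.8, 0.60)]:
    val_loss.update(loss)
    val_acc_best.update(acc)

print(round(val_loss.compute(), 2), val_acc_best.best)  # -> 0.8 0.62
```

This is also why `on_train_start` resets the validation metrics: the sanity-check batches Lightning runs before training would otherwise leak into these accumulators.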
/src/models/mlp.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from torch import nn
4 |
5 |
6 | class MLP(nn.Module):
7 | def __init__(
8 | self,
9 | output_size: int = 8,
10 | input_size: int = 2381,
11 | hidden_units: list[int] = [1024, 512, 256],
12 | ):
13 | super().__init__()
14 | all_layers = []
15 | for hidden in hidden_units:
16 | all_layers.append(nn.Linear(input_size, hidden))
17 |             all_layers.append(nn.BatchNorm1d(hidden))
18 | all_layers.append(nn.ReLU())
19 | input_size = hidden
20 | all_layers.append(nn.Linear(hidden_units[-1], output_size))
21 | self.model = nn.Sequential(*all_layers)
22 | # num_classes
23 | self.num_classes = output_size
24 |
25 | def forward(self, x):
26 | batch_size, _ = x.size()
27 |         # ensure a flat (batch, features) layout
28 | x = x.view(batch_size, -1)
29 |
30 | return self.model(x)
31 |
32 | def features(self, x):
33 | """
34 | Extracts (flattened) features before the last fully connected layer.
35 | """
36 | batch_size, _ = x.size()
37 |         # ensure a flat (batch, features) layout
38 | x = x.view(batch_size, -1)
39 |
40 | fea = self.model[:-1]
41 | return fea(x)
42 |
--------------------------------------------------------------------------------
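The `MLP` constructor above chains Linear/BatchNorm1d/ReLU blocks, rebinding `input_size` each iteration so every hidden layer consumes the previous layer's output. The resulting layer widths for the defaults can be traced without torch:

```python
# defaults from MLP.__init__
input_size, hidden_units, output_size = 2381, [1024, 512, 256], 8

shapes = []
in_features = input_size
for hidden in hidden_units:
    shapes.append((in_features, hidden))  # nn.Linear(in_features, hidden)
    in_features = hidden                  # next layer consumes this layer's output
shapes.append((hidden_units[-1], output_size))  # final classifier head

print(shapes)  # -> [(2381, 1024), (1024, 512), (512, 256), (256, 8)]
```

`features()` returns the 256-dim activations by slicing off that final `(256, 8)` head with `self.model[:-1]`.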
/src/models/mlp_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from lightning import LightningModule
5 | from torchmetrics import Accuracy, MaxMetric, MeanMetric
6 |
7 | from src import utils
8 |
9 | log = utils.get_pylogger(__name__)
10 |
11 |
12 | class MLPModule(LightningModule):
13 | """MLP Module.
14 |
15 | A LightningModule organizes your PyTorch code into 5 sections:
16 | - Computations (init).
17 | - Train loop (training_step)
18 | - Validation loop (validation_step)
19 | - Test loop (test_step)
20 | - Optimizers (configure_optimizers)
21 |
22 | Read the docs:
23 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
24 | """
25 |
26 | def __init__(
27 | self,
28 | network: torch.nn.Module,
29 | optimizer: torch.optim.Optimizer,
30 | scheduler: torch.optim.lr_scheduler,
31 | ):
32 | super().__init__()
33 |
34 | # this line allows to access init params with 'self.hparams' attribute
35 | # it also ensures init params will be stored in ckpt
36 | self.save_hyperparameters(logger=False)
37 |
38 | self.network = network
39 | num_classes = network.num_classes
40 |
41 | # loss function
42 | self.criterion = torch.nn.CrossEntropyLoss()
43 |
44 | # metric objects for calculating and averaging accuracy across batches
45 | self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
46 | self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
47 | self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)
48 |
49 | # for averaging loss across batches
50 | self.train_loss = MeanMetric()
51 | self.val_loss = MeanMetric()
52 | self.test_loss = MeanMetric()
53 |
54 | # for logging best so far validation accuracy
55 | self.val_acc_best = MaxMetric()
56 |
57 | def forward(self, x: torch.Tensor):
58 | return self.network(x)
59 |
60 | def on_train_start(self):
61 | # by default lightning executes validation step sanity checks before training starts,
62 |         # so it's worth making sure validation metrics don't store results from these checks
63 | self.val_loss.reset()
64 | self.val_acc.reset()
65 | self.val_acc_best.reset()
66 |
67 | def model_step(self, batch):
68 | x, y = batch
69 | logits = self.forward(x)
70 | loss = self.criterion(logits, y)
71 | preds = torch.argmax(logits, dim=1)
72 | return loss, preds, y
73 |
74 | def training_step(self, batch, batch_idx: int):
75 | loss, preds, targets = self.model_step(batch)
76 |
77 | # update and log metrics
78 | self.train_loss(loss)
79 | self.train_acc(preds, targets)
80 | self.log(
81 | "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True
82 | )
83 | self.log(
84 | "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
85 | )
86 |
87 | # return loss or backpropagation will fail
88 | return loss
89 |
90 | def on_train_epoch_end(self):
91 | pass
92 |
93 | def validation_step(self, batch, batch_idx: int):
94 | loss, preds, targets = self.model_step(batch)
95 |
96 | # update and log metrics
97 | self.val_loss(loss)
98 | self.val_acc(preds, targets)
99 | self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
100 | self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
101 |
102 | def on_validation_epoch_end(self):
103 | acc = self.val_acc.compute() # get current val acc
104 | self.val_acc_best(acc) # update best so far val acc
105 | # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
106 | # otherwise metric would be reset by lightning after each epoch
107 | self.log("val/acc_best", self.val_acc_best.compute(), prog_bar=True)
108 |
109 | def test_step(self, batch, batch_idx: int):
110 | loss, preds, targets = self.model_step(batch)
111 |
112 | # update and log metrics
113 | self.test_loss(loss)
114 | self.test_acc(preds, targets)
115 | self.log(
116 | "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True
117 | )
118 | self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
119 |
120 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
121 |         _, preds, _ = self.model_step(batch)  # loss and targets are unused at predict time
122 | return preds.tolist()
123 |
124 | def on_test_epoch_end(self):
125 | pass
126 |
127 | def configure_optimizers(self):
128 | """Choose what optimizers and learning-rate schedulers to use in your optimization.
129 | Normally you'd need one. But in the case of GANs or similar you might have multiple.
130 |
131 | Examples:
132 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
133 | """
134 | optimizer = self.hparams.optimizer(params=self.parameters())
135 | if self.hparams.scheduler is not None:
136 | scheduler = self.hparams.scheduler(optimizer=optimizer)
137 | return {
138 | "optimizer": optimizer,
139 | "lr_scheduler": {
140 | "scheduler": scheduler,
141 | "monitor": "val/loss",
142 | "interval": "epoch",
143 | "frequency": 1,
144 | },
145 | }
146 | return {"optimizer": optimizer}
147 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Optional, Tuple
3 |
4 | import hydra
5 | import lightning as L
6 | import numpy as np
7 | import pyrootutils
8 | import torch
9 | from lightning import Callback, LightningDataModule, LightningModule, Trainer
10 | from lightning.pytorch.loggers import Logger
11 | from omegaconf import DictConfig
12 |
13 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
14 | # ------------------------------------------------------------------------------------ #
15 | # the setup_root above is equivalent to:
16 | # - adding project root dir to PYTHONPATH
17 | # (so you don't need to force user to install project as a package)
18 | # (necessary before importing any local modules e.g. `from src import utils`)
19 | # - setting up PROJECT_ROOT environment variable
20 | # (which is used as a base for paths in "configs/paths/default.yaml")
21 | # (this way all filepaths are the same no matter where you run the code)
22 | # - loading environment variables from ".env" in root dir
23 | #
24 | # you can remove it if you:
25 | # 1. either install project as a package or move entry files to project root dir
26 | # 2. set `root_dir` to "." in "configs/paths/default.yaml"
27 | #
28 | # more info: https://github.com/ashleve/pyrootutils
29 | # ------------------------------------------------------------------------------------ #
30 |
31 | from src import utils
32 |
33 | log = utils.get_pylogger(__name__)
34 |
35 |
36 | @utils.task_wrapper
37 | def train(cfg: DictConfig) -> Tuple[dict, dict]:
38 |     """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
39 |     during training.
40 | 
41 |     This method is wrapped in the optional @task_wrapper decorator, which controls behavior on
42 |     failure. Useful for multiruns, saving info about the crash, etc.
43 |
44 | Args:
45 | cfg (DictConfig): Configuration composed by Hydra.
46 |
47 | Returns:
48 | Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
49 | """
50 |
51 | # set seed for random number generators in pytorch, numpy and python.random
52 | if cfg.get("seed"):
53 | L.seed_everything(cfg.seed, workers=True)
54 |
55 | log.info(f"Instantiating datamodule <{cfg.data._target_}>")
56 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
57 |
58 | log.info(f"Instantiating model <{cfg.model._target_}>")
59 | model: LightningModule = hydra.utils.instantiate(cfg.model)
60 |
61 | log.info("Instantiating callbacks...")
62 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
63 |
64 | log.info("Instantiating loggers...")
65 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
66 |
67 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
68 | trainer: Trainer = hydra.utils.instantiate(
69 | cfg.trainer, callbacks=callbacks, logger=logger
70 | )
71 |
72 | object_dict = {
73 | "cfg": cfg,
74 | "datamodule": datamodule,
75 | "model": model,
76 | "callbacks": callbacks,
77 | "logger": logger,
78 | "trainer": trainer,
79 | }
80 |
81 | if logger:
82 | log.info("Logging hyperparameters!")
83 | utils.log_hyperparameters(object_dict)
84 |
85 | # log data summary
86 | datamodule.setup()
87 | log.info("Logging data summary!")
88 | utils.log_data_summary(cfg, datamodule)
89 |
90 | if cfg.get("compile"):
91 | log.info("Compiling model!")
92 | model = torch.compile(model)
93 |
94 | if cfg.get("train"):
95 | log.info("Starting training!")
96 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
97 | # save state_dict
98 | ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
99 | log.info(f"Best ckpt path: {ckpt_path}")
100 | model = model.load_from_checkpoint(ckpt_path)
101 | torch.save(model.network.state_dict(), ckpt_path.parent / "best.pt")
102 |
103 | train_metrics = trainer.callback_metrics
104 |
105 | if cfg.get("test"):
106 | log.info("Starting testing!")
107 | ckpt_path = trainer.checkpoint_callback.best_model_path
108 | log.info(f"Best ckpt path: {ckpt_path}")
109 | test_loader = datamodule.test_dataloader()
110 | y_true = []
111 | for _, y in test_loader:
112 | if isinstance(y, list):
113 | y_true.extend(y[-1].numpy())
114 | else:
115 | y_true.extend(y.numpy())
116 | y_true = np.hstack(y_true)
117 | trainer.test(model=model, dataloaders=test_loader, ckpt_path=ckpt_path)
118 | y_pred = trainer.predict(
119 | model=model, dataloaders=test_loader, ckpt_path=ckpt_path
120 | )
121 | y_pred = np.hstack(y_pred)
122 | utils.log_test_results(cfg, y_true, y_pred)
123 |
124 | test_metrics = trainer.callback_metrics
125 |
126 | # merge train and test metrics
127 | metric_dict = {**train_metrics, **test_metrics}
128 |
129 | return metric_dict, object_dict
130 |
131 |
132 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
133 | def main(cfg: DictConfig) -> Optional[float]:
134 | # apply extra utilities
135 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
136 | utils.extras(cfg)
137 |
138 | # train the model
139 | metric_dict, _ = train(cfg)
140 |
141 | # safely retrieve metric value for hydra-based hyperparameter optimization
142 | metric_value = utils.get_metric_value(
143 | metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
144 | )
145 |
146 | # return optimized metric
147 | return metric_value
148 |
149 |
150 | if __name__ == "__main__":
151 | main()
152 |
--------------------------------------------------------------------------------
/src/train_gbdt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | from datetime import datetime
4 | from timeit import default_timer as timer
5 |
6 | import numpy as np
7 | import pyrootutils
8 | import typer
9 | import wandb
10 | from loguru import logger
11 | from rich.progress import track
12 | from sklearn.metrics import classification_report, top_k_accuracy_score
13 | from wandb.lightgbm import log_summary, wandb_callback
14 |
15 | ROOT = pyrootutils.setup_root(
16 | __file__,
17 | indicator=".project-root",
18 | pythonpath=True,
19 | )
20 |
21 | from src.datasets.mfc import MFC_LOADER
22 | from src.models.gbdt import GBDTClassifier
23 |
24 | app = typer.Typer()
25 |
26 |
27 | def pprint(data):
28 | return json.dumps(data, indent=2)
29 |
30 |
31 | def seed_everything(seed: int) -> None:
32 | np.random.seed(seed)
33 | random.seed(seed)
34 |
35 |
36 | @app.command()
37 | def main(
38 | data_name: str = None,
39 | feature_name: str = "feature-ember-npy",
40 | task_group: str = "gbdt-ember",
41 | pack_ratio: float = 0.0,
42 | train_size: float = 0.6,
43 | val_size: float = 0.2,
44 | test_size: float = 0.2,
45 | boosting: str = "gbdt",
46 | objective: str = "multiclass",
47 | num_class: int = 8,
48 | metric: str = "multi_logloss",
49 | num_iterations: int = 1_000,
50 | learning_rate: float = 0.05,
51 | num_leaves: int = 2048,
52 | max_depth: int = 15,
53 | min_data_in_leaf: int = 50,
54 | feature_fraction: float = 0.5,
55 | verbosity: int = -1,
56 | device: str = "cpu",
57 | num_threads: int = 20,
58 | seed: int = 42,
59 | do_wandb: bool = False,
60 | project: str = "lab-benchmfc",
61 | ):
62 | seed_everything(seed)
63 | # gbdt_config
64 | gbdt_params = {
65 | "boosting": boosting,
66 | "objective": objective,
67 | "num_class": num_class,
68 | "metric": metric,
69 | "num_iterations": num_iterations,
70 | "learning_rate": learning_rate,
71 | "num_leaves": num_leaves,
72 | "max_depth": max_depth,
73 | "min_data_in_leaf": min_data_in_leaf,
74 | "feature_fraction": feature_fraction,
75 | "device": device,
76 | "num_threads": num_threads,
77 | "verbosity": verbosity,
78 | }
79 | config = locals()
80 |
81 | gbdt = GBDTClassifier(**gbdt_params)
82 | # time
83 | now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
84 | # log_dir
85 | task_name = f"{task_group}-{data_name}-{pack_ratio}"
86 | log_dir = ROOT / f"logs/{task_name}/runs/{now}"
87 | log_dir.mkdir(parents=True, exist_ok=True)
88 | logger.add(f"{log_dir}/train.log", level="INFO")
89 |
90 | logger.info(f"[-] Global seed = {seed}")
91 | logger.info(f"[-] Config: {pprint(config)}")
92 | logger.info(f"[-] GBDT Params: {pprint(gbdt_params)}")
93 |
94 | start = timer()
95 | # prepare data
96 |     mfc = MFC_LOADER[data_name]
97 |     mfc.setup(
98 |         feature=feature_name, train_size=train_size, val_size=val_size,
99 |         test_size=test_size, pack_ratio=pack_ratio,
100 |     )
101 |     X_train, X_test = mfc.X_train, mfc.X_test
102 |     y_train, y_test = mfc.y_train, mfc.y_test
103 | X_train = [
104 | np.load(i)
105 | for i in track(X_train, total=len(X_train), description="Loading train...")
106 | ]
107 | X_test = [
108 | np.load(i)
109 | for i in track(X_test, total=len(X_test), description="Loading test...")
110 | ]
111 |
112 | X_train = np.vstack(X_train)
113 | y_train = np.array(y_train)
114 | X_test = np.vstack(X_test)
115 | y_test = np.array(y_test)
116 | # train
117 | name = task_name
118 | if do_wandb:
119 | wandb.login()
120 | wandb.init(
121 | project=project, name=name, group=task_group, config=config, dir=log_dir
122 | )
123 | gbdt.fit(X_train, y_train, callbacks=[wandb_callback()])
124 | log_summary(gbdt.model, feature_importance=True)
125 | else:
126 | gbdt.fit(X_train, y_train)
127 | # save model
128 |     model_file = log_dir / f"{task_name}.lgb"
129 | logger.info(f"[-] save model: {model_file}")
130 | gbdt.model.save_model(model_file)
131 | # test
132 | predict = gbdt.predict(X_test)
133 | acc = top_k_accuracy_score(y_true=y_test, y_score=predict, k=1)
134 | logger.info(f"[*] Top-1 accuracy: {acc}")
135 | if do_wandb:
136 | wandb.log({"test/acc": acc})
137 | wandb.finish()
138 |
139 | predict = [np.argmax(i) for i in predict]
140 | result = classification_report(
141 | y_true=y_test, y_pred=predict, digits=4, output_dict=True
142 | )
143 | logger.info(f"[*] Classification_report (macro avg): {pprint(result['macro avg'])}")
144 |
145 | end = timer()
146 | logger.info(f"[-] timecost: {end - start} s")
147 |
148 |
149 | if __name__ == "__main__":
150 | app()
151 |
--------------------------------------------------------------------------------
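The training script above scores the test split with `top_k_accuracy_score(..., k=1)` on the raw probability output and then argmaxes the same scores for the classification report. As a sanity check (a minimal sketch using only NumPy; the function name is illustrative), top-1 accuracy over class-probability scores is simply plain accuracy over the per-row argmax:

```python
import numpy as np

def top1_accuracy(y_true, y_score) -> float:
    # prediction = class with the highest score per row, which is exactly
    # what `[np.argmax(i) for i in predict]` computes before the report
    preds = np.argmax(np.asarray(y_score), axis=1)
    return float(np.mean(preds == np.asarray(y_true)))

scores = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.3, 0.3, 0.4]])

print(top1_accuracy([0, 1, 2], scores))  # all three argmaxes match -> 1.0
```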
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
2 | from src.utils.logging_utils import log_hyperparameters
3 | from src.utils.pylogger import get_pylogger
4 | from src.utils.rich_utils import (
5 | enforce_tags,
6 | log_data_summary,
7 | log_test_results,
8 | print_config_tree,
9 | )
10 | from src.utils.utils import extras, get_metric_value, log_confusion_matrix, task_wrapper
11 |
12 | __all__ = [
13 | "instantiate_callbacks",
14 | "instantiate_loggers",
15 | "log_hyperparameters",
16 | "get_pylogger",
17 | "enforce_tags",
18 | "log_test_results",
19 | "print_config_tree",
20 | "log_data_summary",
21 | "extras",
22 | "get_metric_value",
23 | "log_confusion_matrix",
24 | "task_wrapper",
25 | ]
26 |
--------------------------------------------------------------------------------
/src/utils/instantiators.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import hydra
4 | from omegaconf import DictConfig
5 | from lightning.pytorch import Callback
6 | from lightning.pytorch.loggers import Logger
7 |
8 | from src.utils import pylogger
9 |
10 | log = pylogger.get_pylogger(__name__)
11 |
12 |
13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
14 | """Instantiates callbacks from config."""
15 |
16 | callbacks: List[Callback] = []
17 |
18 | if not callbacks_cfg:
19 | log.warning("No callback configs found! Skipping...")
20 | return callbacks
21 |
22 | if not isinstance(callbacks_cfg, DictConfig):
23 | raise TypeError("Callbacks config must be a DictConfig!")
24 |
25 | for _, cb_conf in callbacks_cfg.items():
26 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
27 | log.info(f"Instantiating callback <{cb_conf._target_}>")
28 | callbacks.append(hydra.utils.instantiate(cb_conf))
29 |
30 | return callbacks
31 |
32 |
33 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
34 | """Instantiates loggers from config."""
35 |
36 | logger: List[Logger] = []
37 |
38 | if not logger_cfg:
39 | log.warning("No logger configs found! Skipping...")
40 | return logger
41 |
42 | if not isinstance(logger_cfg, DictConfig):
43 | raise TypeError("Logger config must be a DictConfig!")
44 |
45 | for _, lg_conf in logger_cfg.items():
46 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
47 | log.info(f"Instantiating logger <{lg_conf._target_}>")
48 | logger.append(hydra.utils.instantiate(lg_conf))
49 |
50 | return logger
51 |
--------------------------------------------------------------------------------
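Both `instantiate_callbacks` and `instantiate_loggers` delegate to `hydra.utils.instantiate`, which imports the dotted path in a config's `_target_` key and calls it with the remaining keys as keyword arguments. A minimal stdlib-only stand-in (not Hydra's actual implementation, which also handles recursion, positional args, and partials) looks like:

```python
import importlib
from typing import Any, Mapping

def instantiate(conf: Mapping[str, Any]) -> Any:
    """Import the dotted path in `_target_` and call it with the rest."""
    kwargs = dict(conf)
    module_path, _, attr_name = kwargs.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)

# a callback config such as the early-stopping YAML resolves the same way;
# a stdlib target keeps this example self-contained
delta = instantiate({"_target_": "datetime.timedelta", "hours": 2})
print(delta.total_seconds())  # 7200.0
```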
/src/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | from lightning.pytorch.utilities import rank_zero_only
2 |
3 | from src.utils import pylogger
4 |
5 | log = pylogger.get_pylogger(__name__)
6 |
7 |
8 | @rank_zero_only
9 | def log_hyperparameters(object_dict: dict) -> None:
10 | """Controls which config parts are saved by lightning loggers.
11 |
12 | Additionally saves:
13 | - Number of model parameters
14 | """
15 |
16 | hparams = {}
17 |
18 | cfg = object_dict["cfg"]
19 | model = object_dict["model"]
20 | trainer = object_dict["trainer"]
21 |
22 | if not trainer.logger:
23 | log.warning("Logger not found! Skipping hyperparameter logging...")
24 | return
25 |
26 | hparams["model"] = cfg["model"]
27 |
28 | # save number of model parameters
29 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
30 | hparams["model/params/trainable"] = sum(
31 | p.numel() for p in model.parameters() if p.requires_grad
32 | )
33 | hparams["model/params/non_trainable"] = sum(
34 | p.numel() for p in model.parameters() if not p.requires_grad
35 | )
36 |
37 | hparams["data"] = cfg["data"]
38 | hparams["trainer"] = cfg["trainer"]
39 |
40 | hparams["callbacks"] = cfg.get("callbacks")
41 | hparams["extras"] = cfg.get("extras")
42 |
43 | hparams["task_name"] = cfg.get("task_name")
44 | hparams["tags"] = cfg.get("tags")
45 | hparams["ckpt_path"] = cfg.get("ckpt_path")
46 | hparams["seed"] = cfg.get("seed")
47 |
48 | # send hparams to all loggers
49 | for logger in trainer.loggers:
50 | logger.log_hyperparams(hparams)
51 |
--------------------------------------------------------------------------------
/src/utils/pylogger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from lightning.pytorch.utilities import rank_zero_only
4 |
5 |
6 | def get_pylogger(name=__name__) -> logging.Logger:
7 | """Initializes multi-GPU-friendly python command line logger."""
8 |
9 | logger = logging.getLogger(name)
10 |
11 | # this ensures all logging levels get marked with the rank zero decorator
12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup
13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
14 | for level in logging_levels:
15 | setattr(logger, level, rank_zero_only(getattr(logger, level)))
16 |
17 | return logger
18 |
--------------------------------------------------------------------------------
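The loop in `get_pylogger` replaces each level method with a `rank_zero_only`-wrapped version, so in a multi-GPU launch only the rank-0 process actually emits log records. A stripped-down sketch of the same idea (the rank is passed explicitly here; Lightning's real `rank_zero_only` reads it from the launch environment):

```python
import io
import logging

def rank_zero_only(fn, rank: int):
    # no-op wrapper on every rank except the designated one
    def wrapped(*args, **kwargs):
        if rank == 0:
            return fn(*args, **kwargs)
    return wrapped

def get_rank_logger(name: str, rank: int) -> logging.Logger:
    logger = logging.getLogger(name)
    for level in ("debug", "info", "warning", "error", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level), rank))
    return logger

rank0_buf, rank1_buf = io.StringIO(), io.StringIO()
for rank, buf in ((0, rank0_buf), (1, rank1_buf)):
    lg = get_rank_logger(f"demo-rank{rank}", rank)
    lg.addHandler(logging.StreamHandler(buf))
    lg.warning("checkpoint saved")

print(repr(rank0_buf.getvalue()))  # 'checkpoint saved\n'
print(repr(rank1_buf.getvalue()))  # '' -- rank 1 emits nothing
```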
/src/utils/rich_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Sequence
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import rich
7 | import rich.syntax
8 | import rich.tree
9 | from hydra.core.hydra_config import HydraConfig
10 | from lightning.pytorch.utilities import rank_zero_only
11 | from omegaconf import DictConfig, OmegaConf, open_dict
12 | from rich.prompt import Prompt
13 | from sklearn.metrics import (
14 | ConfusionMatrixDisplay,
15 | classification_report,
16 | confusion_matrix,
17 | )
18 |
19 | from src.datasets.ember import EmberDataModule
20 | from src.utils import pylogger
21 |
22 | log = pylogger.get_pylogger(__name__)
23 |
24 |
25 | @rank_zero_only
26 | def print_config_tree(
27 | cfg: DictConfig,
28 | print_order: Sequence[str] = (
29 | "data",
30 | "model",
31 | "callbacks",
32 | "logger",
33 | "trainer",
34 | "paths",
35 | "extras",
36 | ),
37 | resolve: bool = False,
38 | save_to_file: bool = False,
39 | ) -> None:
40 | """Prints content of DictConfig using Rich library and its tree structure.
41 |
42 | Args:
43 | cfg (DictConfig): Configuration composed by Hydra.
44 | print_order (Sequence[str], optional): Determines in what order config components are printed.
45 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
46 | save_to_file (bool, optional): Whether to export config to the hydra output folder.
47 | """
48 |
49 | style = "dim"
50 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
51 |
52 | queue = []
53 |
54 | # add fields from `print_order` to queue
55 | for field in print_order:
56 | queue.append(field) if field in cfg else log.warning(
57 | f"Field '{field}' not found in config. Skipping '{field}' config printing..."
58 | )
59 |
60 | # add all the other fields to queue (not specified in `print_order`)
61 | for field in cfg:
62 | if field not in queue:
63 | queue.append(field)
64 |
65 | # generate config tree from queue
66 | for field in queue:
67 | branch = tree.add(field, style=style, guide_style=style)
68 |
69 | config_group = cfg[field]
70 | if isinstance(config_group, DictConfig):
71 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
72 | else:
73 | branch_content = str(config_group)
74 |
75 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
76 |
77 | # print config tree
78 | rich.print(tree)
79 |
80 | # save config tree to file
81 | if save_to_file:
82 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
83 | rich.print(tree, file=file)
84 |
85 |
86 | @rank_zero_only
87 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
88 | """Prompts user to input tags from command line if no tags are provided in config."""
89 |
90 | if not cfg.get("tags"):
91 | if "id" in HydraConfig().cfg.hydra.job:
92 | raise ValueError("Specify tags before launching a multirun!")
93 |
94 | log.warning("No tags provided in config. Prompting user to input tags...")
95 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
96 | tags = [t.strip() for t in tags.split(",") if t != ""]
97 |
98 | with open_dict(cfg):
99 | cfg.tags = tags
100 |
101 | log.info(f"Tags: {cfg.tags}")
102 |
103 | if save_to_file:
104 | tags = [str(i) for i in cfg.tags]
105 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
106 | rich.print(tags, file=file)
107 |
108 |
109 | @rank_zero_only
110 | def log_test_results(cfg: DictConfig, y_true: list, y_pred: list) -> None:
111 | y_true = np.array(y_true)
112 | y_pred = np.array(y_pred)
113 |
114 | # classification report
115 | cm = confusion_matrix(y_true, y_pred)
116 | cr = classification_report(y_true, y_pred, digits=4, zero_division=0)
117 | rich.print("Test Classification Report:")
118 | rich.print(cr)
119 |
120 | test_path = Path(cfg.paths.output_dir) / "test"
121 | test_path.mkdir(parents=True, exist_ok=True)
122 |
123 | cm_fig = ConfusionMatrixDisplay(cm).plot().figure_
124 | cm_fig.savefig(test_path / "confusion_matrix.png")
125 | np.savetxt(test_path / "confusion_matrix.txt", cm, fmt="%d")
126 |
127 | with open(test_path / "classification_report.log", "w") as file:
128 | rich.print(cr, file=file)
129 |
130 |
131 | def plot_features(
132 | X: np.ndarray,
133 | labels: np.ndarray,
134 | max_length: int = 4096,
135 | ) -> plt.Figure:
136 | fig, ax = plt.subplots()
137 | for i, j in zip(X, labels):
138 | ax.plot(i[:max_length], label=j)
139 | ax.set_title("Example Features of Classes")
140 | ax.legend(title="class")
141 | ax.set_xlabel("id")
142 | ax.set_ylabel("value")
143 | return fig
144 |
145 |
146 | @rank_zero_only
147 | def log_data_summary(cfg: DictConfig, data_module: EmberDataModule) -> None:
148 | data_summary = data_module.summary()
149 | save_path = Path(cfg.paths.output_dir) / "data"
150 | save_path.mkdir(parents=True, exist_ok=True)
151 |
152 | with open(save_path / "data_summary.log", "w") as file:
153 | rich.print(data_summary, file=file)
154 |
155 | data_loader = data_module.test_dataloader()
156 | X, y = next(iter(data_loader))
157 |
158 | if isinstance(y, list):
159 | y = y[-1]
160 | if isinstance(X, list):
161 | X = X[0]
162 | X, y = X.numpy(), y.numpy()
163 |
164 | batch_size = X.shape[0]
165 | X = X.reshape(batch_size, -1)
166 |
167 | labels, indices = np.unique(y, return_index=True)
168 | X = X[indices]
169 | fig = plot_features(X, labels)
170 |
171 | fig.savefig(save_path / "data_example.png")
172 |
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from importlib.util import find_spec
3 | from pathlib import Path
4 | from typing import Callable
5 |
6 | import numpy as np
7 | from omegaconf import DictConfig
8 |
9 | from src.utils import pylogger, rich_utils
10 |
11 | log = pylogger.get_pylogger(__name__)
12 |
13 |
14 | def extras(cfg: DictConfig) -> None:
15 | """Applies optional utilities before the task is started.
16 |
17 | Utilities:
18 | - Ignoring python warnings
19 | - Setting tags from command line
20 | - Rich config printing
21 | """
22 |
23 | # return if no `extras` config
24 | if not cfg.get("extras"):
25 | log.warning("Extras config not found!")
26 | return
27 |
28 | # disable python warnings
29 | if cfg.extras.get("ignore_warnings"):
30 | log.info("Disabling python warnings!")
31 | warnings.filterwarnings("ignore")
32 |
33 | # prompt user to input tags from command line if none are provided in the config
34 | if cfg.extras.get("enforce_tags"):
35 | log.info("Enforcing tags!")
36 | rich_utils.enforce_tags(cfg, save_to_file=True)
37 |
38 | # pretty print config tree using Rich library
39 | if cfg.extras.get("print_config"):
40 | log.info("Printing config tree with Rich!")
41 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
42 |
43 |
44 | def task_wrapper(task_func: Callable) -> Callable:
45 | """Optional decorator that controls the failure behavior when executing the task function.
46 |
47 | This wrapper can be used to:
48 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
49 | - save the exception to a `.log` file
50 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
51 | - etc. (adjust depending on your needs)
52 |
53 | Example:
54 | ```
55 | @utils.task_wrapper
56 | def train(cfg: DictConfig) -> Tuple[dict, dict]:
57 |
58 | ...
59 |
60 | return metric_dict, object_dict
61 | ```
62 | """
63 |
64 | def wrap(cfg: DictConfig):
65 | # execute the task
66 | try:
67 | metric_dict, object_dict = task_func(cfg=cfg)
68 |
69 | # things to do if exception occurs
70 | except Exception as ex:
71 | # save exception to `.log` file
72 | log.exception("")
73 |
74 | # some hyperparameter combinations might be invalid or cause out-of-memory errors
75 | # so when using hparam search plugins like Optuna, you might want to disable
76 | # raising the below exception to avoid multirun failure
77 | raise ex
78 |
79 | # things to always do after either success or exception
80 | finally:
81 | # display output dir path in terminal
82 | log.info(f"Output dir: {cfg.paths.output_dir}")
83 |
84 | # always close wandb run (even if exception occurs so multirun won't fail)
85 | if find_spec("wandb"): # check if wandb is installed
86 | import wandb
87 |
88 | if wandb.run:
89 | log.info("Closing wandb!")
90 | wandb.finish()
91 |
92 | return metric_dict, object_dict
93 |
94 | return wrap
95 |
96 |
97 | def get_metric_value(metric_dict: dict, metric_name: str) -> float:
98 | """Safely retrieves value of the metric logged in LightningModule."""
99 |
100 | if not metric_name:
101 | log.info("Metric name is None! Skipping metric value retrieval...")
102 | return None
103 |
104 | if metric_name not in metric_dict:
105 | raise Exception(
106 | f"Metric value not found! <metric_name={metric_name}>\n"
107 | "Make sure metric name logged in LightningModule is correct!\n"
108 | "Make sure `optimized_metric` name in `hparams_search` config is correct!"
109 | )
110 |
111 | metric_value = metric_dict[metric_name].item()
112 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
113 |
114 | return metric_value
115 |
116 |
117 | def log_confusion_matrix(cfg: DictConfig, cm: np.ndarray) -> None:
118 | save_file = Path(cfg.paths.output_dir) / "test_confusion_matrix.txt"
119 | with open(save_file, "w") as file:
120 | np.savetxt(file, cm, fmt="%d")
121 |
122 |
123 | def log_data_counter(cfg: DictConfig, data: list) -> None:
124 | save_file = Path(cfg.paths.output_dir) / "test_data_counter.txt"
125 | with open(save_file, "w") as file:
126 | for k, v in data:
127 | file.write(f"{k}: {v}\n")
128 |
--------------------------------------------------------------------------------
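`task_wrapper` exists so that one crashing hyperparameter combination doesn't take a whole multirun down with it: the exception is logged and re-raised, but the `finally` block still closes any live wandb run. The control flow reduces to this pattern (a simplified sketch; the cleanup here just records a flag instead of calling `wandb.finish()`):

```python
from typing import Callable

def task_wrapper(task_func: Callable) -> Callable:
    def wrap(cfg: dict):
        try:
            return task_func(cfg)
        finally:
            # always runs: on success, on exception, even while re-raising
            cfg["closed"] = True
    return wrap

@task_wrapper
def failing_task(cfg: dict):
    raise RuntimeError("bad hyperparameter combination")

cfg = {}
try:
    failing_task(cfg)
except RuntimeError:
    pass  # the run failed, but cleanup still happened
print(cfg["closed"])  # True
```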