├── .gitignore ├── .project-root ├── README.md ├── about ├── download.md └── issues.md ├── configs ├── .gitkeep ├── __init__.py ├── callbacks │ ├── default.yaml │ ├── early_stopping.yaml │ ├── model_checkpoint.yaml │ ├── model_summary.yaml │ ├── none.yaml │ └── rich_progress_bar.yaml ├── data │ ├── bytes.yaml │ └── ember.yaml ├── debug │ ├── default.yaml │ ├── fdr.yaml │ ├── limit.yaml │ ├── overfit.yaml │ └── profiler.yaml ├── eval.yaml ├── experiment │ ├── example.yaml │ ├── malconv-bytes-test.yaml │ ├── malconv-bytes-train.yaml │ ├── mlp-ember-test.yaml │ └── mlp-ember-train.yaml ├── extras │ └── default.yaml ├── hparams_search │ └── mnist_optuna.yaml ├── hydra │ └── default.yaml ├── local │ └── .gitkeep ├── logger │ └── wandb.yaml ├── model │ ├── malconv.yaml │ └── mlp.yaml ├── paths │ └── default.yaml ├── train.yaml └── trainer │ ├── cpu.yaml │ ├── ddp.yaml │ ├── ddp_sim.yaml │ ├── default.yaml │ ├── gpu.yaml │ └── mps.yaml ├── detect └── mlp_ember.py ├── requirements.txt ├── scripts ├── detect_mlp_ember_drift.sh ├── test_gbdt_ember.sh ├── test_malconv_bytes.sh ├── test_mlp_ember.sh ├── train_gbdt_ember.sh ├── train_malconv_bytes.sh └── train_mlp_ember.sh └── src ├── __init__.py ├── datasets ├── __init__.py ├── bytes.py ├── ember.py └── mfc.py ├── eval.py ├── eval_gbdt.py ├── models ├── __init__.py ├── gbdt.py ├── malconv.py ├── malconv_module.py ├── mlp.py └── mlp_module.py ├── train.py ├── train_gbdt.py └── utils ├── __init__.py ├── instantiators.py ├── logging_utils.py ├── pylogger.py ├── rich_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | .DS_Store 16 | 17 | # Project files 18 | .ropeproject 19 | .project 20 | .pydevproject 21 | .settings 22 | .idea 23 | .vscode 24 | tags 25 | 26 
| # Package files 27 | *.egg 28 | *.eggs/ 29 | .installed.cfg 30 | *.egg-info 31 | 32 | # Unittest and coverage 33 | htmlcov/* 34 | .coverage 35 | .coverage.* 36 | .tox 37 | junit*.xml 38 | coverage.xml 39 | .pytest_cache/ 40 | 41 | # Build and docs folder/files 42 | build/* 43 | dist/* 44 | sdist/* 45 | docs/api/* 46 | docs/_rst/* 47 | docs/_build/* 48 | cover/* 49 | MANIFEST 50 | 51 | # Per-project virtualenvs 52 | .venv*/ 53 | .conda*/ 54 | logs/* 55 | 56 | # data 57 | logs/* 58 | detect/*.csv -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/.project-root -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 |
5 | 6 | 7 | 8 |

BenchMFC

9 | 10 |

11 | A Benchmark Dataset for Trustworthy Malware Family Classification under Concept Drift 12 |

13 |
14 | 15 | 16 | 17 | 18 | ## Abstract 19 | 20 | Concept drift poses a critical challenge in deploying machine learning models to mitigate practical malware threats. It refers to the phenomenon that the distribution of test data changes over time, gradually deviating from the original training data and degrading model performance. A promising direction for addressing concept drift is to detect drift samples and then retrain the model. However, this field currently lacks a unified, well-curated, and comprehensive benchmark, which often leads to unfair comparisons and inconclusive outcomes. To improve the evaluation and advance further, this paper presents a new Benchmark dataset for trustworthy Malware Family Classification (BenchMFC), which includes 223K samples of 526 families that evolve over years. BenchMFC provides clear family, packer, and timestamp tags for each sample, it thus can support research on three types of malware concept drift: 1) unseen families, 2) packed families, and 3) evolved families. To collect unpacked family samples from large-scale candidates, we introduce a novel crowdsourcing malware annotation pipeline, which unifies packing detection and family annotation as a consensus inference problem to prevent costly packing detection. Moreover, we provide two case studies to illustrate the application of BenchMFC in 1) concept drift detection and 2) model retraining. The first case demonstrates the impact of three types of malware concept drift and compares nine notable concept drift detectors. The results show that existing detectors have their own advantages in dealing with different types of malware concept drift, and there is still room for improvement in malware concept drift detection. The second case explores how static feature-based machine learning operates on packed samples when retraining a model. 
The experiments illustrate that packers do preserve some kind of signals that appear to be “effective” for machine learning models, but the robustness of these signals requires further research. BenchMFC has been released to the community at https://github.com/crowdma/benchmfc. 21 | 22 | 23 | ## Reference 24 | This paper has been accepted by Computers & Security: 25 | ``` 26 | @article{jiang_2024, 27 | title = {BenchMFC: A benchmark dataset for trustworthy malware family classification under concept drift}, 28 | author = {Yongkang Jiang and Gaolei Li and Shenghong Li and Ying Guo}, 29 | journal = {Computers & Security}, 30 | volume = {139}, 31 | pages = {103706}, 32 | year = {2024}, 33 | } 34 | ``` 35 | 36 | 37 | ## Dataset 38 | 39 | ### Size 40 | 41 | ``` 42 | ├── benchmfc_meta.csv (Metadata file for the dataset ~17M) 43 | ├── benchmfc.tar.gz (Samples ~83G) 44 | └── mfc (Experimental data used in the paper) 45 | ├── mfc_features.tar.gz (Ember features ~39M) 46 | ├── mfc_meta.csv (Metadata file ~1M) 47 | └── mfc_samples.tar.gz (Samples ~7G) 48 | ``` 49 | 50 | ### Download 51 | Please visit this [link](about/download.md) for more details. 
52 | 53 | 54 | ## Getting Started 55 | 56 | ### Installation 57 | 58 | - Run the following commands: 59 | ```sh 60 | # python = "<3.10 >=3.9" 61 | git clone https://github.com/crowdma/benchmfc.git 62 | cd benchmfc 63 | pip install -r requirements.txt 64 | ``` 65 | 66 | 67 | ## Usage Examples 68 | 69 | - Env 70 | 71 | ```sh 72 | export MFC_ROOT=// 73 | # MFC structure 74 | ├── feature-ember-npy 75 | │   ├── malicious 76 | │   ├── malicious-unseen 77 | │   ├── malicious-evolving 78 | │   ├── malicious-aes 79 | │   ├── malicious-mpress 80 | │   └── malicious-upx 81 | └── samples 82 | ├── malicious 83 | ├── malicious-unseen 84 | ├── malicious-evolving 85 | ├── malicious-aes 86 | ├── malicious-mpress 87 | └── malicious-upx 88 | ``` 89 | 90 | - Train 91 | ```sh 92 | /bin/bash scripts/train_mlp_ember.sh 93 | ``` 94 | 95 | - Test 96 | ```sh 97 | /bin/bash scripts/test_mlp_ember.sh 98 | ``` 99 | - Detect Drift 100 | ```sh 101 | /bin/bash scripts/detect_mlp_ember_drift.sh 102 | ``` 103 | 104 | ## Issues 105 | 106 | Please visit this [link](about/issues.md) for known issues. 107 | 108 | 109 | 110 | ## License 111 | 112 | Distributed under the MIT License. -------------------------------------------------------------------------------- /about/download.md: -------------------------------------------------------------------------------- 1 | ## Download 2 | 3 | All samples in the dataset were not disarmed. To avoid misuse, please read and agree to the following conditions before sending us emails. 4 | 5 | - Please email Yongkang (jiangyongkang@alumni.sjtu.edu.cn). 6 | - Do not share the data with any others (except your co-authors for the project). We are happy to share with other researchers based upon their requests. 7 | - Explain in a few sentences of your plan to do with these binaries. It should not be a precise plan. 
8 | - If you are in academia, contact us using your institution email and provide us with a webpage registered at the university domain that contains your name and affiliation. 9 | - If you are in research (industrial) labs, email us from your company’s email account and introduce yourself and your company. In the email, please attach a justification letter (in PDF format) in official letterhead. The letter needs to state clearly the reasons why this dataset is being requested. 10 | 11 | Please note that an email not following the conditions might be ignored. And we will keep the public list of organizations accessing these samples at the bottom. 12 | 13 | 14 | ## Organizations That Requested Our Dataset 15 | 16 | 1. Wuhan University 17 | 2. Huazhong University of Science and Technology 18 | 3. Southeast University 19 | 4. Taibah University 20 | 5. Université catholique de Louvain 21 | 6. University of Alberta 22 | 7. IMDEA Networks Institute, Universidad Carlos III de Madrid (UC3M) 23 | 8. Indian Institute of Technology, Indore 24 | 9. Hebei Normal University 25 | 10. Ludwig-Maximilians-Universität München (LMU) 26 | 11. Fast University Karachi 27 | 12. Beijing University of Post and Telecommunications 28 | 13. Queen's University Belfast 29 | 14. Korea University 30 | 15. University of Palermo 31 | 16. University of Luxembourg 32 | -------------------------------------------------------------------------------- /about/issues.md: -------------------------------------------------------------------------------- 1 | ## About test 2 | 3 | - In Fig.9, we only illustrated 40% of the test samples. 
4 | 5 | -------------------------------------------------------------------------------- /configs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/configs/.gitkeep -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - early_stopping.yaml 4 | - model_summary.yaml 5 | - rich_progress_bar.yaml 6 | - _self_ 7 | 8 | model_checkpoint: 9 | dirpath: ${paths.output_dir}/checkpoints 10 | filename: "epoch_{epoch:03d}" 11 | monitor: "val/acc" 12 | mode: "max" 13 | save_last: True 14 | auto_insert_metric_name: False 15 | 16 | early_stopping: 17 | monitor: "val/acc" 18 | patience: 100 19 | mode: "max" 20 | 21 | model_summary: 22 | max_depth: -1 23 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html 2 | 3 | early_stopping: 4 | _target_: lightning.pytorch.callbacks.EarlyStopping 5 | monitor: ??? # quantity to be monitored, must be specified !!! 6 | min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement 7 | patience: 3 # number of checks with no improvement after which training will be stopped 8 | verbose: False # verbosity mode 9 | mode: "min" # "max" means higher metric value is better, can be also "min" 10 | strict: True # whether to crash the training if monitor is not found in the validation metrics 11 | check_finite: True # when set True, stops training when the monitor becomes NaN or infinite 12 | stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold 13 | divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold 14 | check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch 15 | # log_rank_zero_only: False # this keyword argument isn't available in stable version 16 | -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: null # directory to save the model file 6 | filename: null # checkpoint filename 7 | monitor: null # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 1 # save k best models (determined by above metric) 11 | mode: "min" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training 
steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: null # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 1 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /configs/data/bytes.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datasets.bytes.BytesDataModule 2 | 3 | num_workers: 16 -------------------------------------------------------------------------------- /configs/data/ember.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datasets.ember.EmberDataModule 2 | 3 | 
train_size: 0.6 4 | val_size: 0.2 5 | test_size: 0.2 6 | batch_size: 32 -------------------------------------------------------------------------------- /configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | callbacks: null 11 | logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 3 
10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default.yaml 7 | 8 | trainer: 9 | max_epochs: 1 10 | profiler: "simple" 11 | # profiler: "advanced" 12 | # profiler: "pytorch" 13 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: ember.yaml # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mlp.yaml 7 | - logger: null 8 | - trainer: default.yaml 9 | - paths: default.yaml 10 | - extras: default.yaml 11 | - hydra: default.yaml 12 | 13 | - experiment: null 14 | 15 | task_name: "default" 16 | train_eval: "eval" 17 | 18 | tags: ["dev"] 19 | 20 | # passing checkpoint path is necessary for evaluation 21 | ckpt_path: ??? 
22 | -------------------------------------------------------------------------------- /configs/experiment/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: ember.yaml 8 | - override /model: ember.yaml 9 | - override /callbacks: default.yaml 10 | - override /trainer: default.yaml 11 | 12 | # all parameters below will be merged with parameters from default configurations set above 13 | # this allows you to overwrite only specified parameters 14 | 15 | tags: ["ember", "simple_dense_net"] 16 | 17 | seed: 12345 18 | 19 | trainer: 20 | min_epochs: 10 21 | max_epochs: 10 22 | gradient_clip_val: 0.5 23 | 24 | model: 25 | optimizer: 26 | lr: 0.002 27 | net: 28 | lin1_size: 128 29 | lin2_size: 256 30 | lin3_size: 64 31 | 32 | data: 33 | batch_size: 64 34 | 35 | logger: 36 | wandb: 37 | tags: ${tags} 38 | group: "ember" 39 | aim: 40 | experiment: "ember" 41 | -------------------------------------------------------------------------------- /configs/experiment/malconv-bytes-test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: bytes.yaml 8 | - override /model: malconv.yaml 9 | - override /trainer: default.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | data_name: MFC 16 | pack_ratio: 0.0 17 | task_name: malconv-bytes-MFC-0.0 18 | seed: 42 19 | 20 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"] 21 | 22 | trainer: 23 | accelerator: gpu 24 | min_epochs: 20 25 | max_epochs: 50 26 | gradient_clip_val: 0.5 27 | 28 | model: 
29 | optimizer: 30 | lr: 0.001 31 | network: 32 | input_length: 1_048_576 33 | window_size: 500 34 | stride: 500 35 | channels: 128 36 | embed_size: 8 37 | output_size: 8 38 | 39 | data: 40 | data_name: ${data_name} 41 | train_size: 0.6 42 | val_size: 0.2 43 | test_size: 0.2 44 | batch_size: 32 45 | num_workers: 16 46 | pack_ratio: ${pack_ratio} 47 | first_n_byte: 1_048_576 48 | 49 | ckpt_path: ${paths.root_dir}/logs/malconv-bytes-MFC-0.0/train/runs/2023-07-30_18-23-58/checkpoints/epoch_020.ckpt -------------------------------------------------------------------------------- /configs/experiment/malconv-bytes-train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: bytes.yaml 8 | - override /model: malconv.yaml 9 | - override /callbacks: default.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | data_name: MFC 18 | pack_ratio: 0.0 19 | task_name: malconv-bytes-MFC-0.0 20 | seed: 42 21 | 22 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"] 23 | 24 | trainer: 25 | accelerator: gpu 26 | min_epochs: 20 27 | max_epochs: 50 28 | gradient_clip_val: 0.5 29 | 30 | model: 31 | optimizer: 32 | lr: 0.001 33 | network: 34 | input_length: 1_048_576 35 | window_size: 500 36 | stride: 500 37 | channels: 128 38 | embed_size: 8 39 | output_size: 8 40 | 41 | data: 42 | data_name: ${data_name} 43 | train_size: 0.6 44 | val_size: 0.2 45 | test_size: 0.2 46 | batch_size: 32 47 | num_workers: 16 48 | pack_ratio: ${pack_ratio} 49 | first_n_byte: 1_048_576 50 | 51 | logger: 52 | wandb: 53 | name: ${task_name} 54 | group: malconv-bytes 55 | project: 
lab-benchmfc 56 | -------------------------------------------------------------------------------- /configs/experiment/mlp-ember-test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: ember.yaml 8 | - override /model: mlp.yaml 9 | - override /trainer: default.yaml 10 | 11 | # all parameters below will be merged with parameters from default configurations set above 12 | # this allows you to overwrite only specified parameters 13 | 14 | # name of the run determines folder name in logs 15 | data_name: MFC 16 | pack_ratio: 0.0 17 | task_name: mlp-ember-MFC-0.0 18 | seed: 42 19 | 20 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"] 21 | 22 | trainer: 23 | accelerator: gpu 24 | min_epochs: 20 25 | max_epochs: 50 26 | gradient_clip_val: 0.5 27 | 28 | model: 29 | optimizer: 30 | lr: 0.001 31 | network: 32 | input_size: 2381 33 | hidden_units: [1024, 512, 256] 34 | output_size: 8 35 | 36 | data: 37 | data_name: ${data_name} 38 | train_size: 0.6 39 | val_size: 0.2 40 | test_size: 0.2 41 | batch_size: 32 42 | pack_ratio: ${pack_ratio} 43 | 44 | ckpt_path: ${paths.root_dir}/logs/mlp-ember-MFC-0.0/train/runs/2023-07-30_11-54-21/checkpoints/epoch_017.ckpt -------------------------------------------------------------------------------- /configs/experiment/mlp-ember-train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=example 5 | 6 | defaults: 7 | - override /data: ember.yaml 8 | - override /model: mlp.yaml 9 | - override /callbacks: default.yaml 10 | - override /logger: wandb.yaml 11 | - override /trainer: default.yaml 12 | 13 | # all parameters below will be merged with parameters from default configurations set above 14 | # this allows you to overwrite 
only specified parameters 15 | 16 | # name of the run determines folder name in logs 17 | data_name: MFC 18 | pack_ratio: 0.0 19 | task_name: mlp-ember-MFC-0.0 20 | seed: 42 21 | 22 | tags: ["${task_name}", "${data_name}", "${pack_ratio}"] 23 | 24 | trainer: 25 | accelerator: gpu 26 | min_epochs: 20 27 | max_epochs: 50 28 | gradient_clip_val: 0.5 29 | 30 | model: 31 | optimizer: 32 | lr: 0.001 33 | network: 34 | input_size: 2381 35 | hidden_units: [1024, 512, 256] 36 | output_size: 8 37 | 38 | data: 39 | data_name: ${data_name} 40 | train_size: 0.6 41 | val_size: 0.2 42 | test_size: 0.2 43 | batch_size: 32 44 | pack_ratio: ${pack_ratio} 45 | 46 | logger: 47 | wandb: 48 | name: ${task_name} 49 | group: mlp-ember 50 | project: lab-benchmfc 51 | -------------------------------------------------------------------------------- /configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 
| dir: ${paths.log_dir}/${task_name}/${train_eval}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${train_eval}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | -------------------------------------------------------------------------------- /configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/crowdma/benchmfc/106e5d773e6fdf2bac429446183f996b3ece8e03/configs/local/.gitkeep -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: True # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /configs/model/malconv.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.malconv_module.MalconvModule 2 | 3 | optimizer: 4 | _target_: torch.optim.Adam 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | network: 17 | _target_: src.models.malconv.MalConv 18 | input_length: 1_048_576 19 | window_size: 500 20 | stride: 500 21 | channels: 128 22 | embed_size: 8 23 | output_size: 8 24 | 
-------------------------------------------------------------------------------- /configs/model/mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.mlp_module.MLPModule 2 | 3 | optimizer: 4 | _target_: torch.optim.Adam 5 | _partial_: true 6 | lr: 0.001 7 | weight_decay: 0.0 8 | 9 | scheduler: 10 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 11 | _partial_: true 12 | mode: min 13 | factor: 0.1 14 | patience: 10 15 | 16 | network: 17 | _target_: src.models.mlp.MLP 18 | input_size: 2381 19 | hidden_units: [1024, 512, 256] 20 | -------------------------------------------------------------------------------- /configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: mnist.yaml 8 | - model: mnist.yaml 9 | - callbacks: default.yaml 10 | - logger: null # set logger here or use command 
line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default.yaml 12 | - paths: default.yaml 13 | - extras: default.yaml 14 | - hydra: default.yaml 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default.yaml 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | train_eval: "train" 33 | 34 | # tags to help you identify your experiments 35 | # you can overwrite this in experiment configs 36 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 37 | tags: ["dev"] 38 | 39 | # set False to skip model training 40 | train: True 41 | 42 | # evaluate on test set, using best model weights achieved during training 43 | # lightning chooses best weights based on the metric specified in checkpoint callback 44 | test: True 45 | 46 | # compile model for faster training with pytorch 2.0 47 | compile: False 48 | 49 | # simply provide checkpoint path to resume training 50 | ckpt_path: null 51 | 52 | # seed for random number generators in pytorch, numpy and python.random 53 | seed: null 54 | -------------------------------------------------------------------------------- /configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/ddp.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # use "ddp_spawn" instead of "ddp", 5 | # it's slower but normal "ddp" currently doesn't work ideally with hydra 6 | # https://github.com/facebookresearch/hydra/issues/2070 7 | # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn 8 | strategy: ddp_spawn 9 | 10 | accelerator: gpu 11 | devices: 4 12 | num_nodes: 1 13 | sync_batchnorm: True 14 | -------------------------------------------------------------------------------- /configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | min_epochs: 1 # prevents early stopping 6 | max_epochs: 10 7 | 8 | accelerator: cpu 9 | devices: 1 10 | 11 | # mixed precision for extra speed-up 12 | # precision: 16 13 | 14 | # perform a validation loop every N training epochs 15 | check_val_every_n_epoch: 1 16 | 17 | # set True to to ensure deterministic results 18 | # makes training slower but gives more reproducibility than just setting seeds 19 | deterministic: False 20 | -------------------------------------------------------------------------------- /configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /configs/trainer/mps.yaml: 
"""Detect concept drift on the mlp-ember model with pytorch-ood detectors."""
import os
import random

import numpy as np
import pandas as pd
import pyrootutils
import torch
import typer
from loguru import logger
from pytorch_ood.detector import (
    ODIN,
    EnergyBased,
    Entropy,
    KLMatching,
    Mahalanobis,
    MaxLogit,
    MaxSoftmax,
    ViM,
)
from pytorch_ood.utils import OODMetrics
from torch.utils.data import ConcatDataset, DataLoader, Dataset

ROOT = pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.datasets.ember import EmberDataModule
from src.models.mlp import MLP

app = typer.Typer(add_completion=False)


class CDDataset(Dataset):
    """In-memory dataset wrapping pre-extracted concept-drift samples.

    Args:
        X: feature vectors, one per sample.
        y: integer labels; -1 marks a sample as OOD/drifted for OODMetrics.
    """

    def __init__(self, X: list[np.ndarray], y: list[int]):
        self.X = X
        self.y = y

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, index: int) -> tuple:
        data, target = self.X[index], self.y[index]
        return data, target


def seed_everything(seed: int) -> None:
    """Seed python, numpy and torch RNGs for reproducibility."""
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@app.command()
def main(
    model_file: str = None,
    data_name: str = None,
    pack_ratio: float = 0.0,
    device: str = "cuda:0",
    seed: int = 42,
):
    """Evaluate a suite of OOD detectors on ID (MFC) vs. concept-drift data.

    Loads a trained MLP checkpoint, builds an in-distribution test set from
    MFC and a drift test set from `data_name`, fits the detectors on the ID
    training data, and writes per-detector AUROC/FPR metrics to CSV.
    """
    # seed
    logger.info(f"Seeding everything with {seed}")
    seed_everything(seed)
    # load model
    logger.info(f"Loading model from {model_file}")
    model = MLP()
    model.load_state_dict(torch.load(ROOT / model_file))
    model = model.eval().to(device)
    last_layer = model.model[-1]  # final Linear layer; ViM needs its weight/bias

    # load in-distribution training/test data
    logger.info("Loading ID data MFC")
    ID = EmberDataModule(data_name="MFC")
    ID.setup()
    ID_train_loader = ID.train_dataloader()
    ID_test = ID.data_test

    # build the concept-drift (OOD) test set; all its labels are -1 so that
    # OODMetrics treats those samples as outliers
    logger.info(f"Hitting CD from {data_name}")
    if data_name in ["MFCUnseen", "MFCUnseenPacking"]:
        CD = EmberDataModule(data_name=data_name, pack_ratio=pack_ratio)
        CD.setup()
        CD_test = CD.data_test
        CD_test.y = [-1 for _ in CD_test.y]
    elif data_name in ["MFCPacking", "MFCEvolving"]:
        # keep only samples the model misclassifies: those are the drifted ones
        CD = EmberDataModule(data_name=data_name, pack_ratio=pack_ratio)
        CD.setup()
        CD_train_loader = CD.train_dataloader()
        CD_X = []
        with torch.no_grad():
            for x, y in CD_train_loader:
                logits = model(x.to(device))
                preds = torch.argmax(logits, dim=1).tolist()
                for i, p in enumerate(preds):
                    if p != y[i]:
                        CD_X.append(x[i].detach().numpy())
        # balance: cap the drift set at the size of the ID test split
        CD_X = CD_X[: len(ID_test)]
        CD_y = [-1 for _ in CD_X]
        CD_test = CDDataset(CD_X, CD_y)
    else:
        logger.error(f"Unknown data_name {data_name}")
        raise typer.Exit()

    # explicit validation instead of a bare assert (asserts vanish under -O)
    if len(CD_test) == 0:
        logger.error("No concept-drift samples found")
        raise typer.Exit(code=1)
    logger.info(f"TestData | ID: {len(ID_test)}, CD: {len(CD_test)}")
    # concatenate ID and CD data into one evaluation set
    test_data = ConcatDataset([ID_test, CD_test])
    test_loader = DataLoader(test_data, batch_size=32, shuffle=True)

    # create detectors
    # NOTE(review): norm_std=[1] assumes inputs are already unit-variance
    # (EmberDataModule standard-scales its features) -- confirm
    std = [1]
    logger.info("Creating detectors")
    detectors = {}
    detectors["MaxSoftmax"] = MaxSoftmax(model)
    detectors["ODIN"] = ODIN(model, norm_std=std, eps=0.002)
    detectors["Mahalanobis"] = Mahalanobis(model.features, norm_std=std, eps=0.002)
    detectors["EnergyBased"] = EnergyBased(model)
    detectors["Entropy"] = Entropy(model)
    detectors["MaxLogit"] = MaxLogit(model)
    detectors["KLMatching"] = KLMatching(model)
    detectors["ViM"] = ViM(model.features, d=64, w=last_layer.weight, b=last_layer.bias)

    # fit detectors to training data (some require this, some do not)
    logger.info(f"> Fitting {len(detectors)} detectors")
    for name, detector in detectors.items():
        logger.info(f"--> Fitting {name}")
        detector.fit(ID_train_loader, device=device)

    logger.info(f"Evaluating {len(detectors)} detectors on {data_name} concept drifts")
    results = []

    with torch.no_grad():
        for detector_name, detector in detectors.items():
            logger.info(f"> Evaluating {detector_name}")
            metrics = OODMetrics()
            for x, y in test_loader:
                metrics.update(detector(x.to(device)), y.to(device))

            r = {"Detector": detector_name}
            # report percentages rounded to two decimals
            d = {k: round(v * 100, 2) for k, v in metrics.compute().items()}
            r.update(d)
            results.append(r)

    df = pd.DataFrame(
        results, columns=["Detector", "AUROC", "FPR95TPR", "AUPR-IN", "AUPR-OUT"]
    )
    df.to_csv(ROOT / f"detect/mlp-ember-{data_name}.csv", index=False)
    mean_scores = df.groupby("Detector").mean()
    logger.info(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))


if __name__ == "__main__":
    app()
/scripts/detect_mlp_ember_drift.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python detect/mlp_ember.py \ 4 | --model-file logs/mlp-ember-MFC-0.0/train/runs/*/checkpoints/best.pt \ 5 | --data-name MFCUnseen \ 6 | --pack-ratio 1.0 -------------------------------------------------------------------------------- /scripts/test_gbdt_ember.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # GBDT | Train and Test on MFC, MFCEvolving, MFCPacking, MFCUnseen 4 | 5 | python src/eval_gbdt.py \ 6 | --data-name=MFC \ 7 | --pack-ratio=0.0 \ 8 | --ckpt-path=logs/gbdt-ember-MFC-0.0///gbdt-ember-*.lbg 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /scripts/test_malconv_bytes.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | ## MalConv | Train and Test on MFCEvolving, MFCPacking, MFCUnseen 4 | python src/eval.py \ 5 | experiment=malconv-bytes-test \ 6 | task_name=malconv-bytes-MFC-0.0 \ 7 | data_name=MFC \ 8 | pack_ratio=0.0 \ 9 | ckpt_path=logs/malconv-bytes-MFC-0.0/// -------------------------------------------------------------------------------- /scripts/test_mlp_ember.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ## Train on MFC and Test on MFCUnseen, MFCPacking, MFCEvolving 4 | 5 | python src/eval.py \ 6 | experiment=mlp-ember-test \ 7 | task_name=mlp-ember-MFC-0.0 \ 8 | data_name=MFCUnseen \ 9 | pack_ratio=0.0 \ 10 | ckpt_path=logs/mlp-ember-MFC-0.0/// -------------------------------------------------------------------------------- /scripts/train_gbdt_ember.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python src/train_gbdt.py \ 4 | --data-name MFC \ 5 | --pack-ratio 0.0 \ 6 | 
from pathlib import Path

import numpy as np
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from src import utils

from .mfc import Feature, MalconvByteLoader, get_dataloader

log = utils.get_pylogger(__name__)


class BytesDataset(Dataset):
    """Raw-bytes Dataset that lazily loads each sample file via ``load_fn``.

    Loading is deferred to ``__getitem__``, so memory stays low even though
    each MalConv input is a fixed-length 2**20-entry array.
    """

    def __init__(self, X: list[Path], y: np.ndarray, load_fn: callable):
        # X: file paths; y: integer class labels; load_fn: path -> np.ndarray
        self.X = X
        self.y = y
        self.load_fn = load_fn

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, index: int) -> tuple:
        # load on access: path -> padded/truncated byte array (MalconvByteLoader)
        data = self.load_fn(self.X[index])
        target = self.y[index]
        return data, target


class BytesDataModule(LightningDataModule):
    """LightningDataModule for raw-bytes (MalConv) input.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_name: str,
        train_size: float = 0.6,
        val_size: float = 0.2,
        test_size: float = 0.2,
        batch_size: int = 32,
        num_workers: int = 16,
        pack_ratio: float = 0.0,
        first_n_byte: int = 2**20,
    ):
        # data_name selects the MFCLoader variant (e.g. "MFC", "MFCPacking");
        # pack_ratio is the fraction of samples replaced by packed variants;
        # first_n_byte is the fixed MalConv input length (bytes beyond it are
        # truncated, shorter files are zero-padded)
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters()

        self.data_train: Dataset = None
        self.data_val: Dataset = None
        self.data_test: Dataset = None

        self.mfc = get_dataloader(data_name)
        self.load_fn = MalconvByteLoader(first_n_byte=first_n_byte)

    def summary(self) -> dict:
        """Return the split/packer/class summary, running setup lazily if needed."""
        if self.mfc.X_train is None:
            self.mfc.setup()
        return self.mfc.summary()

    def setup(self, stage: str = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
        so be careful not to execute the random split twice! The `stage` can be used to
        differentiate whether it's called before trainer.fit()` or `trainer.test()`.
        """

        # load datasets only if they're not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            mfc = self.mfc
            mfc.setup(
                feature=Feature.SAMPLES,
                pack_ratio=self.hparams.pack_ratio,
                train_size=self.hparams.train_size,
                val_size=self.hparams.val_size,
                test_size=self.hparams.test_size,
            )
            log.info(f"Summary: {self.summary()}")
            (X_train, X_val, X_test, y_train, y_val, y_test) = (
                mfc.X_train,
                mfc.X_val,
                mfc.X_test,
                mfc.y_train,
                mfc.y_val,
                mfc.y_test,
            )

            # X_* are file paths; bytes are read lazily per item by load_fn
            load_fn = self.load_fn
            self.data_train = BytesDataset(X_train, y_train, load_fn)
            self.data_val = BytesDataset(X_val, y_val, load_fn)
            self.data_test = BytesDataset(X_test, y_test, load_fn)

    def train_dataloader(self):
        # shuffle only the training split
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            shuffle=False,
        )
class EmberDataset(Dataset):
    """Ember Feature in-memory Dataset.

    Args:
        X: 2-D feature matrix of shape (num_samples, num_features).
        y: integer class labels, one per row of X.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        self.X = X
        self.y = y

    def __len__(self) -> int:
        return len(self.y)

    def __getitem__(self, index: int) -> tuple:
        # models expect float32 inputs
        data, target = self.X[index, :].astype(np.float32), self.y[index]
        return data, target


class EmberDataModule(LightningDataModule):
    """LightningDataModule for Ember Feature dataset.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_name: str,
        train_size: float = 0.6,
        val_size: float = 0.2,
        test_size: float = 0.2,
        batch_size: int = 32,
        pack_ratio: float = 0.0,
    ):
        # data_name selects the MFCLoader variant (e.g. "MFC", "MFCUnseen");
        # pack_ratio is the fraction of samples replaced by packed variants
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters()

        self.data_train: Dataset = None
        self.data_val: Dataset = None
        self.data_test: Dataset = None

        self.mfc = get_dataloader(data_name)

    def summary(self) -> dict:
        """Return the split/packer/class summary, running setup lazily if needed."""
        if self.mfc.X_train is None:
            self.mfc.setup()
        return self.mfc.summary()

    def setup(self, stage: str = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
        so be careful not to execute the random split twice! The `stage` can be used to
        differentiate whether it's called before trainer.fit()` or `trainer.test()`.
        """

        # load datasets only if they're not loaded already
        if not self.data_train and not self.data_val and not self.data_test:
            mfc = self.mfc
            mfc.setup(
                feature=Feature.EMBER_NUMPY,
                pack_ratio=self.hparams.pack_ratio,
                train_size=self.hparams.train_size,
                val_size=self.hparams.val_size,
                test_size=self.hparams.test_size,
            )
            log.info(f"Summary: {self.summary()}")
            (X_train, X_val, X_test, y_train, y_val, y_test) = (
                mfc.X_train,
                mfc.X_val,
                mfc.X_test,
                mfc.y_train,
                mfc.y_val,
                mfc.y_test,
            )
            # X_* hold .npy paths; materialize the feature vectors in memory
            X_train = [
                np.load(i)
                for i in track(
                    X_train, total=len(X_train), description="Loading train..."
                )
            ]
            X_val = [
                np.load(i)
                for i in track(X_val, total=len(X_val), description="Loading val...")
            ]
            X_test = [
                np.load(i)
                for i in track(X_test, total=len(X_test), description="Loading test...")
            ]

            log.info("StandardScalering...")
            # FIX: fit the scaler on the *training* split only. The previous
            # code fit on train + val + test, leaking statistics of the
            # evaluation sets into preprocessing and biasing test results.
            scaler = StandardScaler()
            scaler.fit(X_train)
            X_train = scaler.transform(X_train)
            X_val = scaler.transform(X_val)
            X_test = scaler.transform(X_test)

            self.data_train = EmberDataset(X_train, y_train)
            self.data_val = EmberDataset(X_val, y_val)
            self.data_test = EmberDataset(X_test, y_test)

    def train_dataloader(self):
        # shuffle and drop the last partial batch during training
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.batch_size,
            shuffle=False,
        )

    def test_dataloader(self):
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.batch_size,
            shuffle=False,
        )
def class_counter(data: list[str]) -> dict:
    """Return a {class: count} dict sorted by descending count."""
    return dict(sorted(Counter(data).items(), key=lambda x: x[1], reverse=True))


class MFCSample:
    """MFC Sample is organized by:
    ```
    root/samples/<group>/<family>/xxx

    >>> For example:
    MFC
    ├── samples
    │   ├── malicious
    │   │   ├── fareit
    │   │   ├── gandcrab
    │   │   ├── hotbar
    │   │   ├── parite
    │   │   ├── simda
    │   │   ├── upatre
    │   │   ├── yuner
    │   │   └── zbot
    ....
    ```
    """

    root: str = MFC_ROOT
    group: str = Group.MALICIOUS

    def get(
        self,
        group: str = None,
        root: str = None,
    ) -> tuple[list[str], list[str]]:
        """Walk root/samples/<group>/ and return (file names, family names)."""
        group = group or self.group
        root = root or self.root
        data_path = Path(root) / "samples"

        X = []
        y = []
        for r, _, files in os.walk(data_path / group):
            for f in files:
                file_path = Path(r, f)
                # family label is the parent directory name
                X.append(file_path.name)
                y.append(file_path.parent.name)

        if len(X) == 0:
            raise ValueError(f"Empty: {root}/samples/{group}")

        return X, y


class MFCLoader:
    """Base loader: discovers samples, optionally "packs" a fraction of them,
    resolves feature-file paths, and exposes stratified train/val/test splits.
    """

    root: Path = MFC_ROOT
    group: str = Group.MALICIOUS
    class_map: dict = {
        "fareit": 0,
        "gandcrab": 1,
        "hotbar": 2,
        "parite": 3,
        "simda": 4,
        "upatre": 5,
        "yuner": 6,
        "zbot": 7,
    }
    packers: list[str] = None

    def __init__(self):
        # all of these are populated by setup()
        self.X_train = None
        self.X_test = None
        self.X_val = None
        self.y_train = None
        self.y_test = None
        self.y_val = None
        self.feature = None
        self.pack_ratio = None
        self.load_fn: FeatureLoader = None

    def get_path(self, X: list[str], y: list[str]) -> tuple[list[Path], list[int]]:
        """Map sample names + family names to feature-file paths + int labels.

        Samples whose feature file does not exist on disk are silently dropped.
        """
        X_path = []
        y_id = []

        for name, family in zip(X, y):
            # packed samples are named "<name>_<packer>" and live under the
            # "<group>-<packer>" directory
            if "_" in name:
                group = "-".join([self.group] + name.split("_")[1:])
            else:
                group = self.group
            suffix = FEATURE_SUFFIX[self.feature]
            if suffix:
                name = f"{name}_{suffix}"
            p = self.root / self.feature / group / family / name
            if p.exists():
                X_path.append(p)
                y_id.append(self.class_map[family])
        if len(X_path) == 0:
            raise ValueError(
                f"No feature files found under {self.root / self.feature}"
            )
        return X_path, y_id

    def pack(
        self,
        X: list[str],
        y: list[int],
        pack_ratio: float,
    ) -> tuple[list[str], list[int]]:
        """Rename a stratified `pack_ratio` fraction of X to "<name>_<packer>",
        distributing the selected samples over `self.packers` in chunks.
        """
        if pack_ratio == 1.0:
            X_packed, y_packed = X, y
            X_unpacked, y_unpacked = [], []
        else:
            X_packed, X_unpacked, y_packed, y_unpacked = train_test_split(
                X, y, train_size=pack_ratio, stratify=y, random_state=42
            )
        num = len(X_packed)
        if num == 0:
            return X_unpacked, y_unpacked
        packers = self.packers
        m = len(packers)
        # chunk size per packer; max(1, ...) guards against range(..., 0)
        # when fewer samples than packers are selected
        n = max(1, num // m)
        packed = []
        for i, j in enumerate(range(0, num, n)):
            packed.extend(
                [f"{k}_{packers[i%m]}" for k in X_packed[j : min(num, j + n)]]
            )
        return packed + X_unpacked, y_packed + y_unpacked

    def setup(
        self,
        feature: str = None,
        pack_ratio: float = None,
        train_size: float = 0.6,
        val_size: float = 0.2,
        test_size: float = 0.2,
    ) -> tuple[list[Path], list[Path], list[Path], list[int], list[int], list[int]]:
        """Build the stratified splits and resolve feature paths.

        Returns
        -------
        X_train, X_val, X_test, y_train, y_val, y_test
        """
        # float-safe check: exact `== 1.0` fails for splits like 0.7/0.15/0.15
        total = train_size + val_size + test_size
        if abs(total - 1.0) > 1e-9:
            raise ValueError(f"train/val/test sizes must sum to 1.0, got {total}")

        group = self.group
        root = self.root

        # FIX: fall back to the class-level default (e.g. MFCPacking sets
        # pack_ratio = 1.0). The old code read the *instance* attribute, which
        # __init__ had already reset to None, so a None argument crashed below
        # on `0.1 <= None`.
        if pack_ratio is None:
            pack_ratio = getattr(type(self), "pack_ratio", None)

        self.feature = feature
        self.pack_ratio = pack_ratio

        if feature == Feature.EMBER_NUMPY:
            self.load_fn = NumpyLoader()
        elif feature == Feature.SAMPLES:
            self.load_fn = MalconvByteLoader()
        else:
            raise ValueError(f"Unknown feature: {feature}")

        X, y = MFCSample().get(group, root)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, train_size=train_size, stratify=y, random_state=42
        )
        # split the held-out remainder into test and val
        new_size = test_size / (test_size + val_size)
        X_test, X_val, y_test, y_val = train_test_split(
            X_test, y_test, train_size=new_size, stratify=y_test, random_state=42
        )

        # pack (ratios below 0.1 are treated as "no packing")
        if pack_ratio is not None and 0.1 <= pack_ratio <= 1.0:
            X_train, y_train = self.pack(X_train, y_train, pack_ratio)
            X_test, y_test = self.pack(X_test, y_test, pack_ratio)
            X_val, y_val = self.pack(X_val, y_val, pack_ratio)

        # resolve names to feature-file paths and integer labels
        X_train, y_train = self.get_path(X_train, y_train)
        X_test, y_test = self.get_path(X_test, y_test)
        X_val, y_val = self.get_path(X_val, y_val)

        self.X_train = X_train
        self.X_test = X_test
        self.X_val = X_val
        self.y_train = y_train
        self.y_test = y_test
        self.y_val = y_val

        return (X_train, X_val, X_test, y_train, y_val, y_test)

    def is_packed(self, x: Path) -> bool:
        """True if the file name mentions any known packer."""
        return any([i in x.name for i in PACKERS])

    def get_packed_ratio(self, X: list[Path]) -> float:
        """Fraction of packed samples in X, rounded to 2 decimals."""
        ratio = sum([self.is_packed(i) for i in X]) / len(X)
        return round(ratio, 2)

    def get_packer_dist(self, X: list[Path]) -> dict[str, int]:
        """Count samples per packer; unpacked samples count under 'none'."""
        packers = defaultdict(int)
        for f in X:
            hit = False
            for p in PACKERS:
                if p in f.name:
                    hit = True
                    packers[p] += 1
            if not hit:
                packers["none"] += 1
        return dict(packers)

    def load(self, x: Path) -> np.ndarray:
        """Load a single feature file with the loader chosen in setup()."""
        return self.load_fn(x)

    def summary(self) -> dict:
        """Summarize split sizes, packer distribution and class balance."""
        X_train, X_val, X_test = self.X_train, self.X_val, self.X_test
        y_train, y_val, y_test = self.y_train, self.y_val, self.y_test
        num_train, num_val, num_test = len(X_train), len(X_val), len(X_test)

        ratio_train = self.get_packed_ratio(X_train)
        ratio_val = self.get_packed_ratio(X_val)
        ratio_test = self.get_packed_ratio(X_test)

        packers_train = self.get_packer_dist(X_train)
        packers_val = self.get_packer_dist(X_val)
        packers_test = self.get_packer_dist(X_test)

        def data_class(y: list[int]) -> dict[str, int]:
            return dict(sorted(Counter(y).items()))

        # data
        data = {
            "train": {
                "total": num_train,
                "packer": packers_train,
                "packed_ratio": ratio_train,
                "class": data_class(y_train),
            },
            "val": {
                "total": num_val,
                "packer": packers_val,
                "packed_ratio": ratio_val,
                "class": data_class(y_val),
            },
            "test": {
                "total": num_test,
                "packer": packers_test,
                "packed_ratio": ratio_test,
                "class": data_class(y_test),
            },
        }
        # feature: describe one example to document loader output
        x_data = self.load(X_train[0])
        x_path = str(X_train[0].relative_to(self.root))
        feature = {
            "names": self.feature,
            "loader": self.load_fn.name,
            "dtype": self.load_fn.dtype,
            "example": x_path,
            "dimension": len(x_data),
        }
        return {"data": data, "feature": feature}


class MFC(MFCLoader):
    """In-distribution split: 8 seen malware families."""

    group: str = Group.MALICIOUS
    class_map: dict = {
        "fareit": 0,
        "gandcrab": 1,
        "hotbar": 2,
        "parite": 3,
        "simda": 4,
        "upatre": 5,
        "yuner": 6,
        "zbot": 7,
    }


class MFCEvolving(MFC):
    """Same families as MFC, but evolving variants."""

    group: str = Group.MALICIOUS_EVOLVING


class MFCUnseen(MFCLoader):
    """8 families never seen during training."""

    group: str = Group.MALICIOUS_UNSEEN
    class_map: dict = {
        "hupigon": 0,
        "imali": 1,
        "lydra": 2,
        "onlinegames": 3,
        "virut": 4,
        "vobfus": 5,
        "wannacry": 6,
        "zlob": 7,
    }


class MFCPacking(MFC):
    """MFC families, fully packed (default pack_ratio = 1.0)."""

    packers: list[str] = ["upx", "mpress", "aes"]
    pack_ratio: float = 1.0


class MFCAes(MFC):
    """MFC families, fully packed with AES only."""

    packers: list[str] = ["aes"]
    pack_ratio: float = 1.0
# Registry of ready-made loaders, keyed by the CLI/`data_name` string.
MFC_LOADER: dict[str, MFCLoader] = {
    "MFC": MFC(),
    "MFCAes": MFCAes(),
    "MFCEvolving": MFCEvolving(),
    "MFCPacking": MFCPacking(),
    "MFCUnseen": MFCUnseen(),
}


def get_dataloader(name: str) -> MFCLoader:
    """Return the registered MFCLoader for `name`.

    Raises:
        KeyError: with the list of valid names, if `name` is not registered
            (same exception type as the previous bare dict lookup).
    """
    try:
        return MFC_LOADER[name]
    except KeyError:
        raise KeyError(
            f"Unknown data name {name!r}; expected one of {sorted(MFC_LOADER)}"
        ) from None


if __name__ == "__main__":
    import rich

    # smoke test. setup() requires an explicit feature (the previous no-arg
    # call always raised "Unknown feature: None"); pack_ratio is passed
    # explicitly because __init__ resets the MFCPacking class default.
    mfc = MFCPacking()
    mfc.setup(feature=Feature.EMBER_NUMPY, pack_ratio=1.0)
    rich.print(mfc.summary())
@utils.task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[dict, dict]:
    """Evaluate a given checkpoint on the datamodule's test set.

    Wrapped in the optional ``@task_wrapper`` decorator, which controls the
    behavior during failure — useful for multiruns, saving crash info, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Metric dict and a dict of all instantiated objects.
    """

    assert cfg.ckpt_path

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    # Emit the data summary before testing starts.
    datamodule.setup()
    log.info("Logging data summary!")
    utils.log_data_summary(cfg, datamodule)

    log.info("Starting testing!")
    test_loader = datamodule.test_dataloader()

    # Collect ground-truth labels; when the target is a list, the label
    # tensor is its last element.
    y_true = []
    for _, y in test_loader:
        target = y[-1] if isinstance(y, list) else y
        y_true.extend(target.numpy())
    y_true = np.hstack(y_true)

    y_pred = np.hstack(
        trainer.predict(model=model, dataloaders=test_loader, ckpt_path=cfg.ckpt_path)
    )
    utils.log_test_results(cfg, y_true, y_pred)

    return trainer.callback_metrics, object_dict
class GBDTClassifier:
    """Thin wrapper around a LightGBM booster for multiclass classification.

    Constructor arguments are captured verbatim in ``self.hparams`` and
    forwarded to ``lgb.train``; the fitted (or loaded) booster lives in
    ``self.model`` (``None`` until :meth:`fit` or :meth:`load` is called).
    """

    def __init__(
        self,
        boosting: str = "gbdt",
        objective: str = "multiclass",
        num_class: int = 8,
        metric: str = "multi_logloss",
        num_iterations: int = 1_000,
        learning_rate: float = 0.05,
        num_leaves: int = 2048,
        max_depth: int = 15,
        min_data_in_leaf: int = 50,
        feature_fraction: float = 0.5,
        device: str = "cpu",
        num_threads: int = 24,
        verbosity: int = -1,
    ):
        self.hparams = {
            "boosting": boosting,
            "objective": objective,
            "num_class": num_class,
            "metric": metric,
            "num_iterations": num_iterations,
            "learning_rate": learning_rate,
            "num_leaves": num_leaves,
            "max_depth": max_depth,
            "min_data_in_leaf": min_data_in_leaf,
            "feature_fraction": feature_fraction,
            "device": device,
            "num_threads": num_threads,
            # 1 means INFO, > 1 means DEBUG, 0 means Error(WARNING), <0 means Fatal
            "verbosity": verbosity,
        }
        self.model = None

    def load(self, model_file: str) -> None:
        """Load a previously saved booster from ``model_file``."""
        self.model = lgb.Booster(model_file=model_file)

    def fit(self, X_train, y_train, callbacks=None) -> None:
        """Train a booster on ``(X_train, y_train)`` using ``self.hparams``.

        Args:
            X_train: Feature matrix accepted by ``lgb.Dataset``.
            y_train: Integer class labels, one per row of ``X_train``.
            callbacks: Optional list of LightGBM training callbacks.
        """
        lgbm_dataset = lgb.Dataset(X_train, y_train)
        self.model = lgb.train(self.hparams, lgbm_dataset, callbacks=callbacks)

    def predict(self, X_test):
        """Return per-class scores for ``X_test``.

        Raises:
            RuntimeError: If called before :meth:`fit` or :meth:`load`
                (previously this surfaced as an opaque ``AttributeError``
                on ``None``).
        """
        if self.model is None:
            raise RuntimeError(
                "GBDTClassifier has no model; call fit() or load() first."
            )
        return self.model.predict(X_test)
nn.Sigmoid() 29 | # num_classes 30 | self.num_classes = output_size 31 | 32 | def forward(self, x): 33 | x = self.embed(x) 34 | # Channel first 35 | x = torch.transpose(x, -1, -2) 36 | cnn_value = self.conv_1(x.narrow(-2, 0, 4)) 37 | gating_weight = self.sigmoid(self.conv_2(x.narrow(-2, 4, 4))) 38 | x = cnn_value * gating_weight 39 | x = self.pooling(x) 40 | x = x.view(-1, self.channels) 41 | x = self.fc_1(x) 42 | x = self.fc_2(x) 43 | return x 44 | 45 | def features(self, x): 46 | """ 47 | Extracts (flattened) features before the last fully connected layer. 48 | """ 49 | x = self.embed(x) 50 | # Channel first 51 | x = torch.transpose(x, -1, -2) 52 | cnn_value = self.conv_1(x.narrow(-2, 0, 4)) 53 | gating_weight = self.sigmoid(self.conv_2(x.narrow(-2, 4, 4))) 54 | x = cnn_value * gating_weight 55 | x = self.pooling(x) 56 | x = x.view(-1, self.channels) 57 | x = self.fc_1(x) 58 | return x 59 | -------------------------------------------------------------------------------- /src/models/malconv_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from lightning import LightningModule 5 | from torchmetrics import Accuracy, MaxMetric, MeanMetric 6 | 7 | 8 | class MalconvModule(LightningModule): 9 | """MLP Module. 10 | 11 | A LightningModule organizes your PyTorch code into 5 sections: 12 | - Computations (init). 
13 | - Train loop (training_step) 14 | - Validation loop (validation_step) 15 | - Test loop (test_step) 16 | - Optimizers (configure_optimizers) 17 | 18 | Read the docs: 19 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html 20 | """ 21 | 22 | def __init__( 23 | self, 24 | network: torch.nn.Module, 25 | optimizer: torch.optim.Optimizer, 26 | scheduler: torch.optim.lr_scheduler, 27 | ): 28 | super().__init__() 29 | 30 | # this line allows to access init params with 'self.hparams' attribute 31 | # it also ensures init params will be stored in ckpt 32 | self.save_hyperparameters(logger=False) 33 | 34 | self.network = network 35 | num_classes = network.num_classes 36 | 37 | # loss function 38 | self.criterion = torch.nn.CrossEntropyLoss() 39 | 40 | # metric objects for calculating and averaging accuracy across batches 41 | self.train_acc = Accuracy(task="multiclass", num_classes=num_classes) 42 | self.val_acc = Accuracy(task="multiclass", num_classes=num_classes) 43 | self.test_acc = Accuracy(task="multiclass", num_classes=num_classes) 44 | 45 | # for averaging loss across batches 46 | self.train_loss = MeanMetric() 47 | self.val_loss = MeanMetric() 48 | self.test_loss = MeanMetric() 49 | 50 | # for logging best so far validation accuracy 51 | self.val_acc_best = MaxMetric() 52 | 53 | def forward(self, x: torch.Tensor): 54 | return self.network(x) 55 | 56 | def on_train_start(self): 57 | # by default lightning executes validation step sanity checks before training starts, 58 | # so it's worth to make sure validation metrics don't store results from these checks 59 | self.val_loss.reset() 60 | self.val_acc.reset() 61 | self.val_acc_best.reset() 62 | 63 | def model_step(self, batch): 64 | x, y = batch 65 | logits = self.forward(x) 66 | loss = self.criterion(logits, y) 67 | preds = torch.argmax(logits, dim=1) 68 | return loss, preds, y 69 | 70 | def training_step(self, batch, batch_idx: int): 71 | loss, preds, targets = self.model_step(batch) 
72 | 73 | # update and log metrics 74 | self.train_loss(loss) 75 | self.train_acc(preds, targets) 76 | self.log( 77 | "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True 78 | ) 79 | self.log( 80 | "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True 81 | ) 82 | 83 | # return loss or backpropagation will fail 84 | return loss 85 | 86 | def on_train_epoch_end(self): 87 | pass 88 | 89 | def validation_step(self, batch, batch_idx: int): 90 | loss, preds, targets = self.model_step(batch) 91 | 92 | # update and log metrics 93 | self.val_loss(loss) 94 | self.val_acc(preds, targets) 95 | self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) 96 | self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) 97 | 98 | def on_validation_epoch_end(self): 99 | acc = self.val_acc.compute() # get current val acc 100 | self.val_acc_best(acc) # update best so far val acc 101 | # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object 102 | # otherwise metric would be reset by lightning after each epoch 103 | self.log("val/acc_best", self.val_acc_best.compute(), prog_bar=True) 104 | 105 | def test_step(self, batch, batch_idx: int): 106 | loss, preds, targets = self.model_step(batch) 107 | 108 | # update and log metrics 109 | self.test_loss(loss) 110 | self.test_acc(preds, targets) 111 | self.log( 112 | "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True 113 | ) 114 | self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) 115 | 116 | def on_test_epoch_end(self): 117 | pass 118 | 119 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: 120 | loss, preds, targets = self.model_step(batch) 121 | return preds.tolist() 122 | 123 | def configure_optimizers(self): 124 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 125 | Normally you'd need one. 
But in the case of GANs or similar you might have multiple. 126 | 127 | Examples: 128 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 129 | """ 130 | optimizer = self.hparams.optimizer(params=self.parameters()) 131 | if self.hparams.scheduler is not None: 132 | scheduler = self.hparams.scheduler(optimizer=optimizer) 133 | return { 134 | "optimizer": optimizer, 135 | "lr_scheduler": { 136 | "scheduler": scheduler, 137 | "monitor": "val/loss", 138 | "interval": "epoch", 139 | "frequency": 1, 140 | }, 141 | } 142 | return {"optimizer": optimizer} 143 | -------------------------------------------------------------------------------- /src/models/mlp.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import nn 4 | 5 | 6 | class MLP(nn.Module): 7 | def __init__( 8 | self, 9 | output_size: int = 8, 10 | input_size: int = 2381, 11 | hidden_units: list[int] = [1024, 512, 256], 12 | ): 13 | super().__init__() 14 | all_layers = [] 15 | for hidden in hidden_units: 16 | all_layers.append(nn.Linear(input_size, hidden)) 17 | all_layers.append(nn.BatchNorm1d(hidden)), 18 | all_layers.append(nn.ReLU()) 19 | input_size = hidden 20 | all_layers.append(nn.Linear(hidden_units[-1], output_size)) 21 | self.model = nn.Sequential(*all_layers) 22 | # num_classes 23 | self.num_classes = output_size 24 | 25 | def forward(self, x): 26 | batch_size, _ = x.size() 27 | # (batch, 1, width, height) -> (batch, 1*width*height) 28 | x = x.view(batch_size, -1) 29 | 30 | return self.model(x) 31 | 32 | def features(self, x): 33 | """ 34 | Extracts (flattened) features before the last fully connected layer. 
35 | """ 36 | batch_size, _ = x.size() 37 | # (batch, 1, width, height) -> (batch, 1*width*height) 38 | x = x.view(batch_size, -1) 39 | 40 | fea = self.model[:-1] 41 | return fea(x) 42 | -------------------------------------------------------------------------------- /src/models/mlp_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from lightning import LightningModule 5 | from torchmetrics import Accuracy, MaxMetric, MeanMetric 6 | 7 | from src import utils 8 | 9 | log = utils.get_pylogger(__name__) 10 | 11 | 12 | class MLPModule(LightningModule): 13 | """MLP Module. 14 | 15 | A LightningModule organizes your PyTorch code into 5 sections: 16 | - Computations (init). 17 | - Train loop (training_step) 18 | - Validation loop (validation_step) 19 | - Test loop (test_step) 20 | - Optimizers (configure_optimizers) 21 | 22 | Read the docs: 23 | https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html 24 | """ 25 | 26 | def __init__( 27 | self, 28 | network: torch.nn.Module, 29 | optimizer: torch.optim.Optimizer, 30 | scheduler: torch.optim.lr_scheduler, 31 | ): 32 | super().__init__() 33 | 34 | # this line allows to access init params with 'self.hparams' attribute 35 | # it also ensures init params will be stored in ckpt 36 | self.save_hyperparameters(logger=False) 37 | 38 | self.network = network 39 | num_classes = network.num_classes 40 | 41 | # loss function 42 | self.criterion = torch.nn.CrossEntropyLoss() 43 | 44 | # metric objects for calculating and averaging accuracy across batches 45 | self.train_acc = Accuracy(task="multiclass", num_classes=num_classes) 46 | self.val_acc = Accuracy(task="multiclass", num_classes=num_classes) 47 | self.test_acc = Accuracy(task="multiclass", num_classes=num_classes) 48 | 49 | # for averaging loss across batches 50 | self.train_loss = MeanMetric() 51 | self.val_loss = MeanMetric() 52 | self.test_loss = MeanMetric() 
53 | 54 | # for logging best so far validation accuracy 55 | self.val_acc_best = MaxMetric() 56 | 57 | def forward(self, x: torch.Tensor): 58 | return self.network(x) 59 | 60 | def on_train_start(self): 61 | # by default lightning executes validation step sanity checks before training starts, 62 | # so it's worth to make sure validation metrics don't store results from these checks 63 | self.val_loss.reset() 64 | self.val_acc.reset() 65 | self.val_acc_best.reset() 66 | 67 | def model_step(self, batch): 68 | x, y = batch 69 | logits = self.forward(x) 70 | loss = self.criterion(logits, y) 71 | preds = torch.argmax(logits, dim=1) 72 | return loss, preds, y 73 | 74 | def training_step(self, batch, batch_idx: int): 75 | loss, preds, targets = self.model_step(batch) 76 | 77 | # update and log metrics 78 | self.train_loss(loss) 79 | self.train_acc(preds, targets) 80 | self.log( 81 | "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True 82 | ) 83 | self.log( 84 | "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True 85 | ) 86 | 87 | # return loss or backpropagation will fail 88 | return loss 89 | 90 | def on_train_epoch_end(self): 91 | pass 92 | 93 | def validation_step(self, batch, batch_idx: int): 94 | loss, preds, targets = self.model_step(batch) 95 | 96 | # update and log metrics 97 | self.val_loss(loss) 98 | self.val_acc(preds, targets) 99 | self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) 100 | self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) 101 | 102 | def on_validation_epoch_end(self): 103 | acc = self.val_acc.compute() # get current val acc 104 | self.val_acc_best(acc) # update best so far val acc 105 | # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object 106 | # otherwise metric would be reset by lightning after each epoch 107 | self.log("val/acc_best", self.val_acc_best.compute(), prog_bar=True) 108 | 109 | def 
test_step(self, batch, batch_idx: int): 110 | loss, preds, targets = self.model_step(batch) 111 | 112 | # update and log metrics 113 | self.test_loss(loss) 114 | self.test_acc(preds, targets) 115 | self.log( 116 | "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True 117 | ) 118 | self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) 119 | 120 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: 121 | loss, preds, targets = self.model_step(batch) 122 | return preds.tolist() 123 | 124 | def on_test_epoch_end(self): 125 | pass 126 | 127 | def configure_optimizers(self): 128 | """Choose what optimizers and learning-rate schedulers to use in your optimization. 129 | Normally you'd need one. But in the case of GANs or similar you might have multiple. 130 | 131 | Examples: 132 | https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers 133 | """ 134 | optimizer = self.hparams.optimizer(params=self.parameters()) 135 | if self.hparams.scheduler is not None: 136 | scheduler = self.hparams.scheduler(optimizer=optimizer) 137 | return { 138 | "optimizer": optimizer, 139 | "lr_scheduler": { 140 | "scheduler": scheduler, 141 | "monitor": "val/loss", 142 | "interval": "epoch", 143 | "frequency": 1, 144 | }, 145 | } 146 | return {"optimizer": optimizer} 147 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Optional, Tuple 3 | 4 | import hydra 5 | import lightning as L 6 | import numpy as np 7 | import pyrootutils 8 | import torch 9 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 10 | from lightning.pytorch.loggers import Logger 11 | from omegaconf import DictConfig 12 | 13 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 14 | 
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    # Everything instantiated so far, returned to the caller for inspection
    # and handed to the hyperparameter logger below.
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    # log data summary
    datamodule.setup()
    log.info("Logging data summary!")
    utils.log_data_summary(cfg, datamodule)

    if cfg.get("compile"):
        log.info("Compiling model!")
        model = torch.compile(model)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
        # save state_dict
        # Reload the best checkpoint and persist only the bare network
        # weights as "best.pt" next to it. Note load_from_checkpoint is a
        # classmethod invoked via the instance: it returns a NEW model,
        # discarding the in-memory one.
        # NOTE(review): best_model_path may be "" if no checkpoint was
        # saved (e.g. checkpointing disabled) — the load would then fail;
        # confirm a ModelCheckpoint callback is always configured.
        ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
        log.info(f"Best ckpt path: {ckpt_path}")
        model = model.load_from_checkpoint(ckpt_path)
        torch.save(model.network.state_dict(), ckpt_path.parent / "best.pt")

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        log.info(f"Best ckpt path: {ckpt_path}")
        test_loader = datamodule.test_dataloader()
        # Collect ground-truth labels up front (same pattern as eval.py):
        # when the target is a list, its last element is the label tensor.
        # NOTE(review): assumes the dataloader yields (x, y) with 1-D label
        # tensors per batch — confirm against the datamodule.
        y_true = []
        for _, y in test_loader:
            if isinstance(y, list):
                y_true.extend(y[-1].numpy())
            else:
                y_true.extend(y.numpy())
        y_true = np.hstack(y_true)
        # Both calls reload the best checkpoint: test() for metrics,
        # predict() for the raw predictions used in the results log.
        trainer.test(model=model, dataloaders=test_loader, ckpt_path=ckpt_path)
        y_pred = trainer.predict(
            model=model, dataloaders=test_loader, ckpt_path=ckpt_path
        )
        y_pred = np.hstack(y_pred)
        utils.log_test_results(cfg, y_true, y_pred)

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
def pprint(data):
    """Render *data* as an indented JSON string for readable logging."""
    return json.dumps(data, indent=2)


def seed_everything(seed: int) -> None:
    """Seed numpy's and Python's global RNGs for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)


@app.command()
def main(
    # NOTE(review): annotated str but defaults to None — typer treats this
    # as an optional CLI option; a missing value fails at MFC_LOADER lookup.
    data_name: str = None,
    feature_name: str = "feature-ember-npy",
    task_group: str = "gbdt-ember",
    pack_ratio: float = 0.0,
    train_size: float = 0.6,
    val_size: float = 0.2,
    test_size: float = 0.2,
    boosting: str = "gbdt",
    objective: str = "multiclass",
    num_class: int = 8,
    metric: str = "multi_logloss",
    num_iterations: int = 1_000,
    learning_rate: float = 0.05,
    num_leaves: int = 2048,
    max_depth: int = 15,
    min_data_in_leaf: int = 50,
    feature_fraction: float = 0.5,
    verbosity: int = -1,
    device: str = "cpu",
    num_threads: int = 20,
    seed: int = 42,
    do_wandb: bool = False,
    project: str = "lab-benchmfc",
):
    """Train a LightGBM classifier on EMBER features of an MFC split and report test accuracy."""
    seed_everything(seed)
    # gbdt_config
    gbdt_params = {
        "boosting": boosting,
        "objective": objective,
        "num_class": num_class,
        "metric": metric,
        "num_iterations": num_iterations,
        "learning_rate": learning_rate,
        "num_leaves": num_leaves,
        "max_depth": max_depth,
        "min_data_in_leaf": min_data_in_leaf,
        "feature_fraction": feature_fraction,
        "device": device,
        "num_threads": num_threads,
        "verbosity": verbosity,
    }
    # Snapshot of every CLI argument (plus gbdt_params) for logging / wandb.
    config = locals()

    gbdt = GBDTClassifier(**gbdt_params)
    # time
    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # log_dir
    task_name = f"{task_group}-{data_name}-{pack_ratio}"
    log_dir = ROOT / f"logs/{task_name}/runs/{now}"
    log_dir.mkdir(parents=True, exist_ok=True)
    logger.add(f"{log_dir}/train.log", level="INFO")

    logger.info(f"[-] Global seed = {seed}")
    logger.info(f"[-] Config: {pprint(config)}")
    logger.info(f"[-] GBDT Params: {pprint(gbdt_params)}")

    start = timer()
    # prepare data
    X_train, X_val, X_test, y_train, y_val, y_test = MFC_LOADER[data_name].load(
        feature=feature_name,
        train_size=train_size,
        val_size=val_size,
        test_size=test_size,
        pack_ratio=pack_ratio,
    )
    # Paths -> feature rows: each entry is a .npy vector loaded from disk.
    X_train = [
        np.load(i)
        for i in track(X_train, total=len(X_train), description="Loading train...")
    ]
    X_test = [
        np.load(i)
        for i in track(X_test, total=len(X_test), description="Loading test...")
    ]

    X_train = np.vstack(X_train)
    y_train = np.array(y_train)
    X_test = np.vstack(X_test)
    y_test = np.array(y_test)
    # train
    name = task_name
    if do_wandb:
        wandb.login()
        wandb.init(
            project=project, name=name, group=task_group, config=config, dir=log_dir
        )
        gbdt.fit(X_train, y_train, callbacks=[wandb_callback()])
        log_summary(gbdt.model, feature_importance=True)
    else:
        gbdt.fit(X_train, y_train)
    # save model
    # NOTE(review): ".lbg" looks like a typo for ".lgb"; kept as-is since
    # downstream runs pass this exact path back via --ckpt-path.
    model_file = log_dir / f"{task_name}.lbg"
    logger.info(f"[-] save model: {model_file}")
    gbdt.model.save_model(model_file)
    # test
    predict = gbdt.predict(X_test)
    acc = top_k_accuracy_score(y_true=y_test, y_score=predict, k=1)
    logger.info(f"[*] Top-1 accuracy: {acc}")
    if do_wandb:
        wandb.log({"test/acc": acc})
        wandb.finish()

    # Per-row argmax over class scores -> predicted labels.
    predict = [np.argmax(i) for i in predict]
    result = classification_report(
        y_true=y_test, y_pred=predict, digits=4, output_dict=True
    )
    logger.info(f"[*] Classification_report (macro avg): {pprint(result['macro avg'])}")

    end = timer()
    logger.info(f"[-] timecost: {end - start} s")
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
from src.utils.logging_utils import log_hyperparameters
from src.utils.pylogger import get_pylogger
from src.utils.rich_utils import (
    enforce_tags,
    log_data_summary,
    log_test_results,
    print_config_tree,
)
from src.utils.utils import extras, get_metric_value, log_confusion_matrix, task_wrapper

# ``__all__`` entries must be strings. The previous version listed the
# imported objects themselves, which makes ``from src.utils import *``
# raise ``TypeError: Item in __all__ must be str``.
__all__ = [
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "get_pylogger",
    "enforce_tags",
    "log_test_results",
    "print_config_tree",
    "log_data_summary",
    "extras",
    "get_metric_value",
    "log_confusion_matrix",
    "task_wrapper",
]
Skipping..") 20 | return callbacks 21 | 22 | if not isinstance(callbacks_cfg, DictConfig): 23 | raise TypeError("Callbacks config must be a DictConfig!") 24 | 25 | for _, cb_conf in callbacks_cfg.items(): 26 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 27 | log.info(f"Instantiating callback <{cb_conf._target_}>") 28 | callbacks.append(hydra.utils.instantiate(cb_conf)) 29 | 30 | return callbacks 31 | 32 | 33 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 34 | """Instantiates loggers from config.""" 35 | 36 | logger: List[Logger] = [] 37 | 38 | if not logger_cfg: 39 | log.warning("No logger configs found! Skipping...") 40 | return logger 41 | 42 | if not isinstance(logger_cfg, DictConfig): 43 | raise TypeError("Logger config must be a DictConfig!") 44 | 45 | for _, lg_conf in logger_cfg.items(): 46 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 47 | log.info(f"Instantiating logger <{lg_conf._target_}>") 48 | logger.append(hydra.utils.instantiate(lg_conf)) 49 | 50 | return logger 51 | -------------------------------------------------------------------------------- /src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.utilities import rank_zero_only 2 | 3 | from src.utils import pylogger 4 | 5 | log = pylogger.get_pylogger(__name__) 6 | 7 | 8 | @rank_zero_only 9 | def log_hyperparameters(object_dict: dict) -> None: 10 | """Controls which config parts are saved by lightning loggers. 11 | 12 | Additionally saves: 13 | - Number of model parameters 14 | """ 15 | 16 | hparams = {} 17 | 18 | cfg = object_dict["cfg"] 19 | model = object_dict["model"] 20 | trainer = object_dict["trainer"] 21 | 22 | if not trainer.logger: 23 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 24 | return 25 | 26 | hparams["model"] = cfg["model"] 27 | 28 | # save number of model parameters 29 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 30 | hparams["model/params/trainable"] = sum( 31 | p.numel() for p in model.parameters() if p.requires_grad 32 | ) 33 | hparams["model/params/non_trainable"] = sum( 34 | p.numel() for p in model.parameters() if not p.requires_grad 35 | ) 36 | 37 | hparams["data"] = cfg["data"] 38 | hparams["trainer"] = cfg["trainer"] 39 | 40 | hparams["callbacks"] = cfg.get("callbacks") 41 | hparams["extras"] = cfg.get("extras") 42 | 43 | hparams["task_name"] = cfg.get("task_name") 44 | hparams["tags"] = cfg.get("tags") 45 | hparams["ckpt_path"] = cfg.get("ckpt_path") 46 | hparams["seed"] = cfg.get("seed") 47 | 48 | # send hparams to all loggers 49 | for logger in trainer.loggers: 50 | logger.log_hyperparams(hparams) 51 | -------------------------------------------------------------------------------- /src/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name=__name__) -> logging.Logger: 7 | """Initializes multi-GPU-friendly python command line logger.""" 8 | 9 | logger = logging.getLogger(name) 10 | 11 | # this ensures all logging levels get marked with the rank zero decorator 12 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 13 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 14 | for level in logging_levels: 15 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 16 | 17 | return logger 18 | -------------------------------------------------------------------------------- /src/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing 
import Sequence 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import rich 7 | import rich.syntax 8 | import rich.tree 9 | from hydra.core.hydra_config import HydraConfig 10 | from lightning.pytorch.utilities import rank_zero_only 11 | from omegaconf import DictConfig, OmegaConf, open_dict 12 | from rich.prompt import Prompt 13 | from sklearn.metrics import ( 14 | ConfusionMatrixDisplay, 15 | classification_report, 16 | confusion_matrix, 17 | ) 18 | 19 | from src.datasets.ember import EmberDataModule 20 | from src.utils import pylogger 21 | 22 | log = pylogger.get_pylogger(__name__) 23 | 24 | 25 | @rank_zero_only 26 | def print_config_tree( 27 | cfg: DictConfig, 28 | print_order: Sequence[str] = ( 29 | "data", 30 | "model", 31 | "callbacks", 32 | "logger", 33 | "trainer", 34 | "paths", 35 | "extras", 36 | ), 37 | resolve: bool = False, 38 | save_to_file: bool = False, 39 | ) -> None: 40 | """Prints content of DictConfig using Rich library and its tree structure. 41 | 42 | Args: 43 | cfg (DictConfig): Configuration composed by Hydra. 44 | print_order (Sequence[str], optional): Determines in what order config components are printed. 45 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 46 | save_to_file (bool, optional): Whether to export config to the hydra output folder. 47 | """ 48 | 49 | style = "dim" 50 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 51 | 52 | queue = [] 53 | 54 | # add fields from `print_order` to queue 55 | for field in print_order: 56 | queue.append(field) if field in cfg else log.warning( 57 | f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
58 | ) 59 | 60 | # add all the other fields to queue (not specified in `print_order`) 61 | for field in cfg: 62 | if field not in queue: 63 | queue.append(field) 64 | 65 | # generate config tree from queue 66 | for field in queue: 67 | branch = tree.add(field, style=style, guide_style=style) 68 | 69 | config_group = cfg[field] 70 | if isinstance(config_group, DictConfig): 71 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 72 | else: 73 | branch_content = str(config_group) 74 | 75 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 76 | 77 | # print config tree 78 | rich.print(tree) 79 | 80 | # save config tree to file 81 | if save_to_file: 82 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 83 | rich.print(tree, file=file) 84 | 85 | 86 | @rank_zero_only 87 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 88 | """Prompts user to input tags from command line if no tags are provided in config.""" 89 | 90 | if not cfg.get("tags"): 91 | if "id" in HydraConfig().cfg.hydra.job: 92 | raise ValueError("Specify tags before launching a multirun!") 93 | 94 | log.warning("No tags provided in config. 
Prompting user to input tags...") 95 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 96 | tags = [t.strip() for t in tags.split(",") if t != ""] 97 | 98 | with open_dict(cfg): 99 | cfg.tags = tags 100 | 101 | log.info(f"Tags: {cfg.tags}") 102 | 103 | if save_to_file: 104 | tags = [str(i) for i in cfg.tags] 105 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 106 | rich.print(tags, file=file) 107 | 108 | 109 | @rank_zero_only 110 | def log_test_results(cfg: DictConfig, y_true: list, y_pred: list) -> None: 111 | y_true = np.array(y_true) 112 | y_pred = np.array(y_pred) 113 | 114 | # classification report 115 | cm = confusion_matrix(y_true, y_pred) 116 | cr = classification_report(y_true, y_pred, digits=4, zero_division=0) 117 | rich.print("Test Classification Report:") 118 | rich.print(cr) 119 | 120 | test_path = Path(cfg.paths.output_dir) / "test" 121 | test_path.mkdir(parents=True, exist_ok=True) 122 | 123 | cm_fig = ConfusionMatrixDisplay(cm).plot().figure_ 124 | cm_fig.savefig(test_path / "confusion_matrix.png") 125 | np.savetxt(test_path / "confusion_matrix.txt", cm, fmt="%d") 126 | 127 | with open(test_path / "classification_report.log", "w") as file: 128 | rich.print(cr, file=file) 129 | 130 | 131 | def plot_features( 132 | X: np.array, 133 | labels: np.array, 134 | max_length: int = 4096, 135 | ) -> plt.figure: 136 | fig, ax = plt.subplots() 137 | for i, j in zip(X, labels): 138 | ax.plot(i[:max_length], label=j) 139 | ax.set_title("Example Features of Classes") 140 | ax.legend(title="class") 141 | ax.set_xlabel("id") 142 | ax.set_ylabel("value") 143 | return fig 144 | 145 | 146 | @rank_zero_only 147 | def log_data_summary(cfg: DictConfig, data_module: EmberDataModule) -> None: 148 | data_summary = data_module.summary() 149 | save_path = Path(cfg.paths.output_dir) / "data" 150 | save_path.mkdir(parents=True, exist_ok=True) 151 | 152 | with open(save_path / "data_summary.log", "w") as file: 153 | 
rich.print(data_summary, file=file) 154 | 155 | data_loader = data_module.test_dataloader() 156 | X, y = next(iter(data_loader)) 157 | 158 | if isinstance(y, list): 159 | y = y[-1] 160 | if isinstance(X, list): 161 | X = X[0] 162 | X, y = X.numpy(), y.numpy() 163 | 164 | batch_size = X.shape[0] 165 | X = X.reshape(batch_size, -1) 166 | 167 | labels, indices = np.unique(y, return_index=True) 168 | X = X[indices] 169 | fig = plot_features(X, labels) 170 | 171 | fig.savefig(save_path / "data_example.png") 172 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from importlib.util import find_spec 3 | from pathlib import Path 4 | from typing import Callable 5 | 6 | import numpy as np 7 | from omegaconf import DictConfig 8 | 9 | from src.utils import pylogger, rich_utils 10 | 11 | log = pylogger.get_pylogger(__name__) 12 | 13 | 14 | def extras(cfg: DictConfig) -> None: 15 | """Applies optional utilities before the task is started. 16 | 17 | Utilities: 18 | - Ignoring python warnings 19 | - Setting tags from command line 20 | - Rich config printing 21 | """ 22 | 23 | # return if no `extras` config 24 | if not cfg.get("extras"): 25 | log.warning("Extras config not found! ") 26 | return 27 | 28 | # disable python warnings 29 | if cfg.extras.get("ignore_warnings"): 30 | log.info("Disabling python warnings! ") 31 | warnings.filterwarnings("ignore") 32 | 33 | # prompt user to input tags from command line if none are provided in the config 34 | if cfg.extras.get("enforce_tags"): 35 | log.info("Enforcing tags! ") 36 | rich_utils.enforce_tags(cfg, save_to_file=True) 37 | 38 | # pretty print config tree using Rich library 39 | if cfg.extras.get("print_config"): 40 | log.info("Printing config tree with Rich! 
") 41 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 42 | 43 | 44 | def task_wrapper(task_func: Callable) -> Callable: 45 | """Optional decorator that controls the failure behavior when executing the task function. 46 | 47 | This wrapper can be used to: 48 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 49 | - save the exception to a `.log` file 50 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 51 | - etc. (adjust depending on your needs) 52 | 53 | Example: 54 | ``` 55 | @utils.task_wrapper 56 | def train(cfg: DictConfig) -> Tuple[dict, dict]: 57 | 58 | ... 59 | 60 | return metric_dict, object_dict 61 | ``` 62 | """ 63 | 64 | def wrap(cfg: DictConfig): 65 | # execute the task 66 | try: 67 | metric_dict, object_dict = task_func(cfg=cfg) 68 | 69 | # things to do if exception occurs 70 | except Exception as ex: 71 | # save exception to `.log` file 72 | log.exception("") 73 | 74 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 75 | # so when using hparam search plugins like Optuna, you might want to disable 76 | # raising the below exception to avoid multirun failure 77 | raise ex 78 | 79 | # things to always do after either success or exception 80 | finally: 81 | # display output dir path in terminal 82 | log.info(f"Output dir: {cfg.paths.output_dir}") 83 | 84 | # always close wandb run (even if exception occurs so multirun won't fail) 85 | if find_spec("wandb"): # check if wandb is installed 86 | import wandb 87 | 88 | if wandb.run: 89 | log.info("Closing wandb!") 90 | wandb.finish() 91 | 92 | return metric_dict, object_dict 93 | 94 | return wrap 95 | 96 | 97 | def get_metric_value(metric_dict: dict, metric_name: str) -> float: 98 | """Safely retrieves value of the metric logged in LightningModule.""" 99 | 100 | if not metric_name: 101 | log.info("Metric name is None! 
Skipping metric value retrieval...") 102 | return None 103 | 104 | if metric_name not in metric_dict: 105 | raise Exception( 106 | f"Metric value not found! \n" 107 | "Make sure metric name logged in LightningModule is correct!\n" 108 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 109 | ) 110 | 111 | metric_value = metric_dict[metric_name].item() 112 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") 113 | 114 | return metric_value 115 | 116 | 117 | def log_confusion_matrix(cfg: DictConfig, cm: np.array) -> None: 118 | save_file = Path(cfg.paths.output_dir) / "test_confusion_matrix.txt" 119 | with open(save_file, "w") as file: 120 | np.savetxt(file, cm, fmt="%d") 121 | 122 | 123 | def log_data_counter(cfg: DictConfig, data: list) -> None: 124 | save_file = Path(cfg.paths.output_dir) / "test_data_counter.txt" 125 | with open(save_file, "w") as file: 126 | for k, v in data: 127 | file.write(f"{k}: {v}\n") 128 | --------------------------------------------------------------------------------