├── configs ├── .gitkeep ├── TestFinetuningConfig.json ├── TestPretrainingConfig.json ├── UCSBPretrainingConfig.json ├── MawiPretrainingConfig.json ├── UCSBFinetuningConfig.json ├── UCSBFinetuningTCPOptionsConfig.json ├── UCSBPretrainingTCPOptionsConfig.json ├── CICIDSFinetuningConfig.json └── deepspeed_example.json ├── src ├── __init__.py ├── train │ ├── __init__.py │ ├── NetFoundTrainer.py │ ├── NetfoundConfig.py │ ├── NetfoundTokenizer.py │ ├── NetFoundDataCollator.py │ ├── NetfoundFinetuning.py │ ├── NetfoundPretraining.py │ ├── utils.py │ └── NetFoundModels.py └── pre_process │ ├── __init__.py │ ├── packets_processing_src │ ├── README.md │ ├── CMakeLists.txt │ ├── 1_filter │ │ └── 1_filter.cpp │ └── 3_field_extraction │ │ └── 3_field_extraction.cpp │ ├── 2_pcap_splitting.sh │ ├── 1_filter.sh │ ├── Shuffle.py │ ├── CollectTokensInFiles.py │ ├── 3_extract_fields.sh │ ├── README.md │ └── Tokenize.py ├── .dockerignore ├── data └── test │ ├── pretraining │ └── raw │ │ └── imap.pcap │ └── finetuning │ └── raw │ ├── 0 │ └── imap1.pcap │ └── 1 │ └── imap2.pcap ├── requirements.txt ├── .gitignore ├── scripts ├── pretrain-test.sh ├── finetune-test.sh ├── shuffler.py ├── print_arrow.py ├── slurm │ ├── netfound_pretraining.sl │ └── netfound_finetuning.sl └── preprocess_data.py ├── LICENSE ├── Dockerfile ├── Makefile └── README.md /configs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | build -------------------------------------------------------------------------------- /src/pre_process/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/pretraining/raw/imap.pcap: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNL-UCSB/netFound/HEAD/data/test/pretraining/raw/imap.pcap -------------------------------------------------------------------------------- /data/test/finetuning/raw/0/imap1.pcap: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNL-UCSB/netFound/HEAD/data/test/finetuning/raw/0/imap1.pcap -------------------------------------------------------------------------------- /data/test/finetuning/raw/1/imap2.pcap: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SNL-UCSB/netFound/HEAD/data/test/finetuning/raw/1/imap2.pcap -------------------------------------------------------------------------------- /src/pre_process/packets_processing_src/README.md: -------------------------------------------------------------------------------- 1 | # Compilation 2 | - Install gcc and PcapPlusPlus for building. 
3 | - create a folder for build: `mkdir build` 4 | - `cd build` 5 | - `cmake ..` 6 | - `make` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | matplotlib 3 | seaborn 4 | scipy 5 | pcap_splitter 6 | scapy 7 | scikit-learn 8 | python-dateutil 9 | transformers 10 | datasets 11 | tqdm 12 | packaging 13 | graphviz 14 | fsspec 15 | pyarrow 16 | deepspeed 17 | accelerate 18 | tensorboard 19 | torchinfo 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/.idea 2 | venv/ 3 | **/cmake-build-debug 4 | **/.PVS-Studio 5 | */**/*.pcap 6 | */**/*.pcapng 7 | */**/*.so 8 | data/test/* 9 | */**/__pycache__ 10 | */**/*.c 11 | build/ 12 | src/pre_process/1_filter 13 | src/pre_process/3_field_extraction 14 | models/ 15 | example.log 16 | **/.env -------------------------------------------------------------------------------- /src/pre_process/2_pcap_splitting.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set +x 5 | 6 | if [ "$#" -ne 2 ]; then 7 | echo "Usage: $0 input_folder output_folder" 8 | exit 1 9 | fi 10 | 11 | input_folder="$1" 12 | output_folder="$2" 13 | 14 | mkdir -p "$output_folder" 15 | 16 | find "$input_folder" -type f | parallel "mkdir -p $output_folder/{/.} && PcapSplitter -f {} -o $output_folder/{/.}/ -m connection" 17 | -------------------------------------------------------------------------------- /src/pre_process/packets_processing_src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.20) 2 | project(ndr_preprocessing) 3 | 4 | set(CMAKE_CXX_STANDARD 23) 5 | set(CMAKE_CXX_EXTENSIONS ON) 6 | 7 | find_package(PcapPlusPlus REQUIRED) 8 | 9 | 10 | add_executable(1_filter 1_filter/1_filter.cpp) 11 | add_executable(3_field_extraction 3_field_extraction/3_field_extraction.cpp) 12 | target_link_libraries("1_filter" PUBLIC PcapPlusPlus::Pcap++) 13 | target_link_libraries("3_field_extraction" PUBLIC PcapPlusPlus::Pcap++) 14 | -------------------------------------------------------------------------------- /scripts/pretrain-test.sh: -------------------------------------------------------------------------------- 1 | # this script pretrains a model on a test dataset 2 | 3 | # TODO(maybe-hello-world): understand why safetensors are not usable and resolve 4 | 5 | python \ 6 | src/train/NetfoundPretraining.py \ 7 | --train_dir data/test/pretraining/final/combined/ \ 8 | --output_dir models/test/pretraining/pretrained_model \ 9 | --report_to tensorboard \ 10 | --do_train \ 11 | --num_train_epochs 3 \ 12 | --overwrite_output_dir \ 13 | --save_safetensors false \ 14 | --mlm_probability 0.10 \ 15 | --learning_rate 2e-5 \ 16 | --do_eval \ 17 | --validation_split_percentage 30 18 | -------------------------------------------------------------------------------- /scripts/finetune-test.sh: -------------------------------------------------------------------------------- 1 | # this script finetunes a model on a test dataset 2 | 3 | python \ 4 | src/train/NetfoundFinetuning.py \ 5 | --train_dir data/test/finetuning/final/combined \ 6 | --model_name_or_path models/test/pretraining/pretrained_model \ 7 | --output_dir models/test/finetuning/finetuned_model \ 8 | --report_to tensorboard \ 9 | 
--overwrite_output_dir \ 10 | --save_safetensors false \ 11 | --do_train \ 12 | --do_eval \ 13 | --eval_strategy epoch \ 14 | --save_strategy epoch \ 15 | --learning_rate 0.01 \ 16 | --num_train_epochs 4 \ 17 | --problem_type single_label_classification \ 18 | --num_labels 2 \ 19 | --load_best_model_at_end 20 | 21 | 22 | -------------------------------------------------------------------------------- /src/pre_process/1_filter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | set +x 4 | 5 | if [ "$#" -ne 2 ]; then 6 | echo "Usage: $0 input_folder output_folder" 7 | exit 1 8 | fi 9 | 10 | # Get the directory where the current script is located 11 | script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 12 | 13 | input_folder="$1" 14 | output_folder="$2" 15 | 16 | mkdir -p "$output_folder" 17 | 18 | # Check if 1_filter script exists in the same directory as the current script 19 | filter_script="$script_dir/1_filter" 20 | if [ ! -f "$filter_script" ]; then 21 | echo "Error: 1_filter script not found in $script_dir, please run make all from the project root directory" 22 | exit 1 23 | fi 24 | 25 | find "$input_folder" -type f | parallel "$filter_script {} $output_folder/{/}" 26 | -------------------------------------------------------------------------------- /configs/TestFinetuningConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 3 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}], 4 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 5 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 6 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 7 | "internalIPs": ["127.0.0.0/8"], 8 | "Finetuning": true 9 | } -------------------------------------------------------------------------------- /configs/TestPretrainingConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 3 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}], 4 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 5 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 6 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 7 | "internalIPs": ["127.0.0.0/8"], 8 | "Finetuning": false 9 | } -------------------------------------------------------------------------------- /src/train/NetFoundTrainer.py: 
-------------------------------------------------------------------------------- 1 | from transformers import Trainer 2 | 3 | 4 | class NetfoundTrainer(Trainer): 5 | 6 | extraFields = {} 7 | 8 | def _set_signature_columns_if_needed(self): 9 | super()._set_signature_columns_if_needed() 10 | self._signature_columns += { 11 | "direction", 12 | "iats", 13 | "bytes", 14 | "pkt_count", 15 | "total_bursts", 16 | "ports", 17 | "stats", 18 | "protocol", 19 | } 20 | self._signature_columns += self.extraFields 21 | self._signature_columns = list(set(self._signature_columns)) 22 | 23 | def __init__(self, label_names=None, extraFields = None, *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | if extraFields is not None: 26 | self.extraFields = extraFields 27 | if label_names is not None: 28 | self.label_names.extend(label_names) 29 | -------------------------------------------------------------------------------- /configs/UCSBPretrainingConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_dir": "/data/flow_split/", 3 | "output_dir": "/data/UCSBwithMetaPretrainingNoOPT/", 4 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 5 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}], 6 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 7 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 8 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 9 | "internalIPs": ["128.111.0.0/16", "169.231.0.0/16"], 10 | "Finetuning": false 11 | } -------------------------------------------------------------------------------- /configs/MawiPretrainingConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_dir": "/data/202106211400_npts/", 3 | "output_dir": "/data/MawiPretrainingNoOPT/", 4 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 5 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}], 6 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 7 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 8 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 9 | "internalIPs": ["163.0.0.0/8", "202.0.0.0/8", "203.0.0.0/8"], 10 | "Finetuning": false 11 | } -------------------------------------------------------------------------------- /configs/UCSBFinetuningConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_dir": "/data/UCSBFinetuning_with_payload/", 3 | "output_dir": 
"/data/UCSBwithMetaFinetuningNoOPT/", 4 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 5 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}], 6 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 7 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 8 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 9 | "internalIPs": ["128.111.0.0/16", "169.231.0.0/16"], 10 | "Finetuning": true 11 | } -------------------------------------------------------------------------------- /configs/UCSBFinetuningTCPOptionsConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_dir": "/data/flow_split/", 3 | "output_dir": "/data/UCSBwithMetaPretrainingNoOPT/", 4 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 5 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}, {"field":"TCP_options", "bits": 320, "numberTokens": 20}], 6 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 7 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 8 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 9 | "internalIPs": ["128.111.0.0/16", "169.231.0.0/16"], 10 | "Finetuning": true 11 | } -------------------------------------------------------------------------------- /configs/UCSBPretrainingTCPOptionsConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_dir": "/data/flow_split/", 3 | "output_dir": "/data/UCSBwithMetaPretrainingNoOPT/", 4 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 5 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}, {"field":"TCP_options", "bits": 320, "numberTokens": 20}], 6 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 7 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 8 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 9 | "internalIPs": ["128.111.0.0/16", "169.231.0.0/16"], 10 | "Finetuning": false 11 | } 
-------------------------------------------------------------------------------- /configs/CICIDSFinetuningConfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "input_dir": "/data/CICIDS-2017-N/CICIDS_2017_N_npts1/", 3 | "output_dir": "/data/CICIDS-2017withMetaFinetuning/", 4 | "IPFields": [{"field":"IP_hl", "bits": 4, "numberTokens": 1}, {"field":"IP_tos", "bits": 8, "numberTokens": 1}, {"field":"IP_tl", "bits": 16, "numberTokens": 1}, {"field":"IP_Flags", "bits": 3, "numberTokens": 1}, {"field":"IP_ttl", "bits": 8, "numberTokens": 1}], 5 | "TCPFields": [{"field":"TCP_Flags", "bits": 16, "numberTokens": 1}, {"field":"TCP_wsize", "bits": 16, "numberTokens": 1}, {"field":"TCP_seq", "bits": 32, "numberTokens": 2}, {"field":"TCP_ackn", "bits": 32, "numberTokens": 2}, {"field":"TCP_urp", "bits": 16, "numberTokens": 1}], 6 | "UDPFields": [{"field":"UDP_len", "bits": 16, "numberTokens": 1}], 7 | "ICMPFields" : [{"field":"ICMP_type", "bits": 8, "numberTokens": 1},{"field":"ICMP_code", "bits": 8, "numberTokens": 1}], 8 | "Payload": [{"field":"Payload", "bits": 96, "numberTokens": 6}], 9 | "internalIPs": ["205.174.165.73/32", "205.174.165.69/32", "205.174.165.70/32", "205.174.165.71/32", "205.174.165.73/32", "205.174.165.80/32", "192.168.10.8/32"], 10 | "Finetuning": true 11 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Systems & Networking Lab, UCSB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.7.1-cuda11.8-cudnn9-devel 2 | 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | RUN apt-get update && apt-get upgrade -y && \ 5 | apt-get install -y --no-install-recommends \ 6 | gcc \ 7 | python3-dev \ 8 | wget \ 9 | graphviz \ 10 | parallel \ 11 | nano make cmake g++ \ 12 | libpcap-dev && \ 13 | apt-get clean && \ 14 | rm -rf /var/lib/apt/lists/* 15 | 16 | WORKDIR /workspace 17 | COPY requirements.txt requirements.txt 18 | 19 | RUN pip install --no-cache-dir -r requirements.txt 20 | 21 | RUN wget https://github.com/seladb/PcapPlusPlus/releases/download/v24.09/pcapplusplus-24.09-ubuntu-22.04-gcc-11.4.0-x86_64.tar.gz && \ 22 | tar -xvf pcapplusplus-24.09-ubuntu-22.04-gcc-11.4.0-x86_64.tar.gz && \ 23 | rm pcapplusplus-24.09-ubuntu-22.04-gcc-11.4.0-x86_64.tar.gz && \ 24 | mv pcapplusplus-24.09-ubuntu-22.04-gcc-11.4.0-x86_64 /usr/local/ && \ 25 | ln -s /usr/local/pcapplusplus-24.09-ubuntu-22.04-gcc-11.4.0-x86_64 /usr/local/pcapplusplus 26 | 27 | ENV PATH="/usr/local/pcapplusplus/bin:$PATH" 28 | 29 | COPY . . 30 | RUN find . -type f -name "*.sh" -exec chmod +x {} \; 31 | 32 | CMD ["bash"] 33 | -------------------------------------------------------------------------------- /scripts/shuffler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import pyarrow as pa 4 | import pyarrow.ipc as ipc 5 | 6 | def shuffle_arrow_file(input_path, output_path, seed=42): 7 | with pa.memory_map(input_path, 'r') as source: 8 | reader = ipc.open_stream(source) 9 | table = reader.read_all() 10 | 11 | num_rows = table.num_rows 12 | print(f"Loaded table with {num_rows} rows.") 13 | np.random.seed(seed) 14 | perm = np.random.permutation(num_rows) 15 | shuffled_table = table.take(pa.array(perm)) 16 | 17 | print(f"Shuffled table with seed {seed}.") 18 | 19 | with pa.OSFile(output_path, 'wb') as sink: 20 | writer = ipc.new_stream(sink, shuffled_table.schema) 21 | writer.write_table(shuffled_table) 22 | writer.close() 23 | 24 | print(f"Shuffled file written to {output_path}.") 25 | 26 | if __name__ == "__main__": 27 | if len(sys.argv) < 3: 28 | print("Usage: python shuffle_arrow.py [seed]") 29 | sys.exit(1) 30 | 31 | input_file = sys.argv[1] 32 | output_file = sys.argv[2] 33 | seed = int(sys.argv[3]) if len(sys.argv) > 3 else 42 34 | 35 | shuffle_arrow_file(input_file, output_file, seed) 36 | -------------------------------------------------------------------------------- /src/pre_process/Shuffle.py: -------------------------------------------------------------------------------- 1 | # shuffle combined dataset 2 | 3 | import argparse 4 | import numpy as np 5 | import pyarrow as pa 6 | import pyarrow.ipc as ipc 7 | 8 | def shuffle_dataset(input_file, output_file): 9 | # Read the input file 10 | with open(input_file, "rb") as f: 11 | reader = ipc.open_stream(f) 12 | table = reader.read_all() 13 | 14 | # Shuffle the table 15 | num_rows = table.num_rows 16 | shuffled_indices = np.random.permutation(num_rows) 17 | shuffled_table = table.take(pa.array(shuffled_indices)) 18 | 19 | # Write the shuffled table to the output file 20 | with open(output_file, "wb") as f: 21 | writer = ipc.new_stream(f, shuffled_table.schema) 22 | writer.write_table(shuffled_table) 23 | writer.close() 24 | 25 | 26 | def main(): 27 | parser = 
argparse.ArgumentParser(description="Shuffle Apache Arrow streaming files.") 28 | parser.add_argument("input_file", type=str, help="The input Arrow file with data.") 29 | parser.add_argument("output_file", type=str, help="The output Arrow streaming file.") 30 | 31 | args = parser.parse_args() 32 | 33 | shuffle_dataset(args.input_file, args.output_file) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /configs/deepspeed_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "optimizer": { 3 | "type": "AdamW", 4 | "params": { 5 | "lr": "auto", 6 | "betas": "auto", 7 | "eps": "auto", 8 | "weight_decay": "auto" 9 | } 10 | }, 11 | "scheduler": { 12 | "type": "WarmupLR", 13 | "params": { 14 | "warmup_min_lr": "auto", 15 | "warmup_max_lr": "auto", 16 | "warmup_num_steps": "auto" 17 | } 18 | }, 19 | "zero_optimization": { 20 | "stage": 3, 21 | "allgather_partitions": true, 22 | "allgather_bucket_size": 2e8, 23 | "overlap_comm": false, 24 | "reduce_scatter": true, 25 | "reduce_bucket_size": 2e8, 26 | "contiguous_gradients": true, 27 | "stage3_prefetch_bucket_size": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_max_live_parameters": 1e9, 30 | "stage3_max_reuse_distance": 1e9, 31 | "stage3_gather_16bit_weights_on_model_save": true 32 | }, 33 | "gradient_accumulation_steps": "auto", 34 | "gradient_clipping": "auto", 35 | "train_batch_size": "auto", 36 | "train_micro_batch_size_per_gpu":"auto", 37 | "wall_clock_breakdown": true 38 | } 39 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PREPROCESS_SRC_DIR = src/pre_process/packets_processing_src 2 | PREPROCESS_BUILD_DIR = build/preprocess 3 | DATA_DIR = data/test 4 | 5 | # Targets 6 | all: clean compile preprocess pretrain finetune 7 | 8 | clean: 9 | rm -rf $(PREPROCESS_BUILD_DIR) 10 | rm -f src/pre_process/1_filter 11 | rm -f src/pre_process/3_field_extraction 12 | rm -rf $(DATA_DIR)/pretraining/split $(DATA_DIR)/pretraining/filtered $(DATA_DIR)/pretraining/extracted $(DATA_DIR)/pretraining/final 13 | rm -rf $(DATA_DIR)/finetuning/split $(DATA_DIR)/finetuning/filtered $(DATA_DIR)/finetuning/extracted $(DATA_DIR)/finetuning/final 14 | 15 | compile: 16 | mkdir -p $(PREPROCESS_BUILD_DIR) 17 | cmake -S $(PREPROCESS_SRC_DIR) -B $(PREPROCESS_BUILD_DIR) 18 | make -C $(PREPROCESS_BUILD_DIR) 19 | cp $(PREPROCESS_BUILD_DIR)/1_filter src/pre_process/1_filter 20 | cp $(PREPROCESS_BUILD_DIR)/3_field_extraction src/pre_process/3_field_extraction 21 | 22 | preprocess: 23 | python3 ./scripts/preprocess_data.py --input_folder $(DATA_DIR)/pretraining --action pretrain --tokenizer_config configs/TestPretrainingConfig.json --combined 24 | python3 ./scripts/preprocess_data.py --input_folder $(DATA_DIR)/finetuning --action finetune --tokenizer_config configs/TestFinetuningConfig.json --combined 25 | 26 | pretrain: 27 | ./scripts/pretrain-test.sh 28 | 29 | finetune: 30 | ./scripts/finetune-test.sh 31 | 32 | 33 | .PHONY: all clean compile preprocess pretrain finetune -------------------------------------------------------------------------------- /src/pre_process/CollectTokensInFiles.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pyarrow as pa 4 | import pyarrow.ipc as ipc 5 | 6 | 7 | def 
merge_arrow_files(input_folder, output_file): 8 | # Get all the Arrow files in the specified folder 9 | input_files = [os.path.join(input_folder, f) for f in os.listdir(input_folder) if f.endswith('.arrow')] 10 | 11 | # get schema 12 | first_file = input_files[0] 13 | with pa.memory_map(first_file, 'r') as source: 14 | reader = ipc.open_stream(source) 15 | schema = reader.schema 16 | 17 | # Initialize the output stream 18 | with pa.OSFile(output_file, 'wb') as sink: 19 | with ipc.new_stream(sink, schema) as writer: 20 | for input_file in input_files: 21 | with pa.memory_map(input_file, 'r') as source: 22 | reader = ipc.open_stream(source) 23 | for batch in reader: 24 | writer.write_batch(batch) 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser(description="Merge Apache Arrow streaming files.") 29 | parser.add_argument("input_folder", type=str, help="The folder containing the Arrow streaming files.") 30 | parser.add_argument("output_file", type=str, help="The output Arrow streaming file.") 31 | 32 | args = parser.parse_args() 33 | 34 | merge_arrow_files(args.input_folder, args.output_file) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /src/pre_process/3_extract_fields.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set +x 5 | 6 | if [ "$#" -lt 2 ]; then 7 | echo "Usage: $0 input_folder output_folder [tcpoptions]" 8 | exit 1 9 | fi 10 | 11 | # Get the directory where the current script is located 12 | script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 13 | 14 | input_folder="$1" 15 | output_folder="$2" 16 | tcpoptions=0 17 | if [ "$#" -eq 3 ]; then 18 | tcpoptions="$3" 19 | fi 20 | 21 | 22 | # Check if input_folder exists and is a directory 23 | if [ ! -d "$input_folder" ]; then 24 | echo "Error: Input folder '$input_folder' does not exist or is not a directory." 25 | exit 1 26 | fi 27 | 28 | # Create the output folder if it doesn't exist 29 | mkdir -p "$output_folder" 30 | 31 | # Check if 3_field_extraction script exists in the same directory as the current script 32 | field_extraction_script="$script_dir/3_field_extraction" 33 | if [ ! 
-f "$field_extraction_script" ]; then 34 | echo "Error: 3_field_extraction script not found in $script_dir, run make all" 35 | exit 1 36 | fi 37 | 38 | # Create output directories for each subdirectory in the input folder 39 | find "$input_folder" -mindepth 1 -maxdepth 1 -type d -print0 | while IFS= read -r -d '' dir; do 40 | dir_name="$(basename "$dir")" 41 | mkdir -p "$output_folder/$dir_name" 42 | done 43 | 44 | find "$input_folder" -mindepth 1 -maxdepth 1 -type d -print0 | parallel -0 "$field_extraction_script {} $output_folder/{/} $tcpoptions" 45 | -------------------------------------------------------------------------------- /scripts/print_arrow.py: -------------------------------------------------------------------------------- 1 | import pyarrow as pa 2 | import pyarrow.ipc as ipc 3 | import sys 4 | 5 | def read_and_print_labels(file_path): 6 | try: 7 | # Open the Arrow file as a stream 8 | with pa.OSFile(file_path, 'rb') as source: 9 | reader = ipc.open_stream(source) 10 | 11 | # Read all batches in the stream 12 | for batch in reader: 13 | # Convert the batch to a table 14 | table = pa.Table.from_batches([batch]) 15 | print(table) 16 | 17 | # Get the "label" field if it exists 18 | if 'labels' in table.column_names: 19 | label_column = table['labels'] 20 | 21 | # Print the first 10 rows of the "label" column 22 | print("First 10 'label' values:") 23 | print(label_column.to_pylist()[:10]) 24 | else: 25 | print("'labels' field not found in the file.") 26 | break # Exit after processing the first batch 27 | except Exception as e: 28 | print(f"Error processing the file: {e}") 29 | 30 | if __name__ == "__main__": 31 | # Ensure a filename is provided as the first argument 32 | if len(sys.argv) < 2: 33 | print("Usage: python script_name.py <arrow_file>") 34 | sys.exit(1) 35 | 36 | # Get the filename from the command-line argument 37 | file_path = sys.argv[1] 38 | 39 | # Call the function with the provided filename 40 | read_and_print_labels(file_path) 41 | -------------------------------------------------------------------------------- /src/pre_process/packets_processing_src/1_filter/1_filter.cpp: -------------------------------------------------------------------------------- 1 | #include "PcapFileDevice.h" 2 | #include "Packet.h" 3 | #include "IPv4Layer.h" 4 | #include <iostream> 5 | #include <string> 6 | 7 | int main(int argc, char *argv[]) { 8 | if (argc < 3 or argc > 4) { 9 | std::cerr << "Usage: " << argv[0] << " <input file> <output file> [optional: <unix timestamp>]\n"; 10 | return 1; 11 | } 12 | 13 | auto *reader = pcpp::IFileReaderDevice::getReader(argv[1]); 14 | 15 | if (reader == nullptr) { 16 | std::cerr << "Cannot determine reader for file: " << argv[1] << std::endl; 17 | return 1; 18 | } 19 | if (!reader->open()) { 20 | std::cerr << "Cannot open " << argv[1] << " for reading" << std::endl; 21 | return 1; 22 | } 23 | 24 | 25 | std::string outFileName = argv[2]; 26 | pcpp::IFileWriterDevice *writer; 27 | if (outFileName.ends_with(".pcap")) { 28 | writer = new pcpp::PcapFileWriterDevice(argv[2]); 29 | } else if (outFileName.ends_with(".pcapng")) { 30 | writer = new pcpp::PcapNgFileWriterDevice(argv[2]); 31 | } else { 32 | std::cerr << "Output file must have .pcap or .pcapng extension" << std::endl; 33 | return 1; 34 | } 35 | 36 | if (!writer->open()) { 37 | // Handle error 38 | std::cerr << "Error opening output file: " << argv[2] << std::endl; 39 | return 1; 40 | } 41 | 42 | // parse the optional argument: unix time (in seconds) that the capture should be shifted to start at 43 | bool enable_time_shift = false; 44 | long unixtime_seconds = 0; 45 | if (argc == 4) { 46 | unixtime_seconds = std::stol(argv[3]);
47 | enable_time_shift = true; 48 | } 49 | 50 | long diff = 0; 51 | bool first = true; 52 | 53 | pcpp::RawPacket rawPacket; 54 | while (reader->getNextPacket(rawPacket)) { 55 | if (enable_time_shift) { 56 | if (first) { 57 | first = false; 58 | diff = rawPacket.getPacketTimeStamp().tv_sec - unixtime_seconds; 59 | } else { 60 | auto x = rawPacket.getPacketTimeStamp(); 61 | x.tv_sec -= diff; 62 | rawPacket.setPacketTimeStamp(x); 63 | } 64 | } 65 | 66 | pcpp::Packet parsedPacket(&rawPacket); 67 | auto *ipLayer = parsedPacket.getLayerOfType<pcpp::IPv4Layer>(); 68 | if (ipLayer != nullptr && (ipLayer->getIPv4Header()->ipVersion == 4 || ipLayer->getIPv4Header()->ipVersion == 6) ) { 69 | if (parsedPacket.isPacketOfType(pcpp::TCP) || 70 | parsedPacket.isPacketOfType(pcpp::UDP) || 71 | parsedPacket.isPacketOfType(pcpp::ICMP)) { 72 | writer->writePacket(rawPacket); 73 | } 74 | } 75 | } 76 | 77 | reader->close(); 78 | writer->close(); 79 | 80 | return 0; 81 | } 82 | 83 | 84 | -------------------------------------------------------------------------------- /src/pre_process/README.md: -------------------------------------------------------------------------------- 1 | # Data preprocessing 2 | 3 | TL;DR: 4 | Data for pretraining should live in a folder X that contains a "raw" subfolder with pcap files. 5 | Data for finetuning should have N folders (one per class for classification), each containing a "raw" subfolder with pcap files. 6 | See data/test for an example. 7 | 8 | See scripts/preprocess_data.py for details. 9 | 10 | ## Field extraction 11 | 12 | Input: merged pcaps with flows 13 | 14 | 1. Process all pcaps and keep only TCP/UDP/ICMP packets 15 | `./1_filter.sh input_folder output_folder` 16 | 2. Split pcaps by flow 17 | `./2_pcap_splitting.sh input_folder output_folder` 18 | 3. Extract packet features from each flow 19 | `./3_extract_fields.sh input_folder output_folder` 20 | 21 | ### Folder structure 22 | Resulting folder structure: 23 | - (folder) 24 | - .pcap. 25 | - .pcap. 26 | - ... 27 | 28 | ### File structure and fields 29 | 30 | Each file is a binary stream of packets in a custom format. 31 | The first byte is always a protocol number: 1 for ICMP, 6 for TCP, 17 for UDP. 32 | Then, each packet is represented by a sequence of fields without separators. 33 | 34 | So a file containing a TCP flow has the following structure: 35 | - uint8_t: protocol number 36 | - packet0 representation 37 | - packet1 representation 38 | - ...
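The per-protocol layouts are listed in the sections below. As a rough illustration, a TCP flow file could be read with Python's `struct` module as sketched here, assuming little-endian byte order and tight packing of the listed fields (neither is documented in this README, so treat both as assumptions); the authoritative definitions are in `packets_processing_src/3_field_extraction/3_field_extraction.cpp`, and `read_tcp_flow` is an illustrative helper, not part of the repository.

```python
# Sketch only: field order follows the "TCP packet structure" list below; little-endian
# byte order and tight packing are assumptions, not documented guarantees.
import struct

TCP_RECORD = struct.Struct("<QBBHBBIIHHBHIIH12s")  # 51 bytes per TCP packet record

def read_tcp_flow(path):
    """Yield per-packet dicts from one flow file produced by 3_field_extraction."""
    with open(path, "rb") as f:
        protocol = f.read(1)[0]  # first byte of the file: 1 = ICMP, 6 = TCP, 17 = UDP
        if protocol != 6:
            raise ValueError(f"not a TCP flow file (protocol={protocol})")
        while (chunk := f.read(TCP_RECORD.size)) and len(chunk) == TCP_RECORD.size:
            (ts_ns, ip_hl, tos, total_len, ip_flags, ttl, src_ip, dst_ip, sport, dport,
             tcp_flags, window, seq, ack, urgent, payload) = TCP_RECORD.unpack(chunk)
            yield {"timestamp_ns": ts_ns, "total_len": total_len, "sport": sport,
                   "dport": dport, "tcp_flags": tcp_flags, "seq": seq, "ack": ack,
                   "payload": payload}
```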
39 | 40 | #### ICMP packet structure 41 | - uint64_t: unix timestamp with nanoseconds 42 | - uint8_t: IP header length (in bytes) 43 | - uint8_t: Type of Service 44 | - uint16_t: Total Length 45 | - uint8_t: IP Flags 46 | - uint8_t: TTL 47 | - uint32_t: Source IP (as an integer) 48 | - uint32_t: Destination IP (as an integer) 49 | - uint8_t: ICMP type 50 | - uint8_t: ICMP code 51 | - 12 bytes of data padded with zeros 52 | 53 | #### TCP packet structure 54 | - uint64_t: unix timestamp with nanoseconds 55 | - uint8_t: IP header length (in bytes) 56 | - uint8_t: Type of Service 57 | - uint16_t: Total Length 58 | - uint8_t: IP Flags 59 | - uint8_t: TTL 60 | - uint32_t: Source IP (as an integer) 61 | - uint32_t: Destination IP (as an integer) 62 | - uint16_t: Source Port 63 | - uint16_t: Destination Port 64 | - uint8_t: TCP flags 65 | - uint16_t: TCP Window Size 66 | - uint32_t: Relative Sequence Number 67 | - uint32_t: Relative Acknowledgement Number 68 | - uint16_t: TCP urgent pointer 69 | - 12 bytes of data padded with zeros 70 | 71 | #### UDP packet structure 72 | - uint64_t: unix timestamp with nanoseconds 73 | - uint8_t: IP header length (in bytes) 74 | - uint8_t: Type of Service 75 | - uint16_t: Total Length 76 | - uint8_t: IP Flags 77 | - uint8_t: TTL 78 | - uint32_t: Source IP (as an integer) 79 | - uint32_t: Destination IP (as an integer) 80 | - uint16_t: Source Port 81 | - uint16_t: Destination Port 82 | - uint16_t: Length 83 | - 12 bytes of data padded with zeros 84 | -------------------------------------------------------------------------------- /scripts/slurm/netfound_pretraining.sl: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### ATTENTION: this script is provided for reference only to give an idea how to make it run on SLURM-based systems 4 | ### you need to adjust the script to your needs and test it properly 5 | 6 | ### START: do not change this 7 | #SBATCH --account= 8 | #SBATCH --licenses= 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --constraint= 11 | #SBATCH --gpus-per-node= 12 | ### END: do not change this 13 | 14 | # --constraint= gpu with 80gb - should fit batch size 20 with the big model 15 | # --constraint= gpu with 40gb - should fit batch size 8 with big model and 16 with medium model 16 | 17 | ### START: usually you do not need to change this 18 | #SBATCH --output=%x-%j.out 19 | #SBATCH --error=%x-%j.err 20 | #SBATCH --qos= 21 | ### END: usually you do not need to change this 22 | 23 | ### START: feel free to change this 24 | #SBATCH --job-name= 25 | #SBATCH --time= 26 | #SBATCH --nodes= 27 | ### END: feel free to change this 28 | 29 | ### START: do not change this 30 | set -x -e 31 | 32 | module load cudatoolkit 33 | module load cray-mpich 34 | module load gcc 35 | module load conda 36 | conda activate 37 | 38 | export GPUS_PER_NODE=4 39 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 40 | export MASTER_PORT=9901 41 | export WORLD_SIZE=$(($GPUS_PER_NODE*$SLURM_NNODES)) 42 | ### END: do not change this 43 | 44 | ## add this to resume from checkpoint 45 | ## --resume_from_checkpoint checkpoint_path \ 46 | ## --ignore_data_skip True \ 47 | ## ignore data skip until https://github.com/huggingface/transformers/pull/33544 is merged 48 | 49 | 50 | 51 | srun --ntasks-per-node=1 --gpus-per-node=4 --jobid $SLURM_JOBID bash -c '\ 52 | torchrun \ 53 | --nproc_per_node $GPUS_PER_NODE \ 54 | --nnodes $SLURM_NNODES \ 55 | --node_rank $SLURM_PROCID \ 56 | --master_addr $MASTER_ADDR \ 57 
| --master_port $MASTER_PORT \ 58 | $PSCRATCH/network-data-representation/src/train/NetFoundPretraining.py \ 59 | --report_to tensorboard \ 60 | --save_safetensors false \ 61 | --dispatch_batches False \ 62 | --bf16 \ 63 | --do_train \ 64 | --do_eval \ 65 | --eval_strategy steps \ 66 | --save_strategy steps \ 67 | --dataloader_num_workers 32 \ 68 | --dataloader_prefetch_factor 16 \ 69 | --logging_steps 20000 \ 70 | --save_steps 20000 \ 71 | --streaming True \ 72 | --gradient_accumulation_steps 1 \ 73 | --hidden_size 1024 \ 74 | --num_hidden_layers 24 \ 75 | --num_attention_heads 16 \ 76 | --tcpoptions False \ 77 | --train_dir /path/to/train_data \ 78 | --test_dir /path/to/eval \ 79 | --output_dir /path/to/modelXX \ 80 | --model_name_or_path /path/to/checkpointYY \ 81 | --per_device_train_batch_size 20 \ 82 | --learning_rate 0.000013 \ 83 | --lr_scheduler_type linear \ 84 | --warmup_steps 5000 \ 85 | --max_steps 300000 86 | ' -------------------------------------------------------------------------------- /scripts/slurm/netfound_finetuning.sl: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ### ATTENTION: this script is provided for reference only to give an idea how to make it run on SLURM-based systems 3 | ### you need to adjust the script to your needs and test it properly 4 | 5 | ### START: do not change this 6 | #SBATCH --account= 7 | #SBATCH --licenses= 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --constraint= 10 | #SBATCH --gpus-per-node= 11 | ### END: do not change this 12 | 13 | # --constraint= gpu with 80gb - should fit batch size 20 with the big model 14 | # --constraint= gpu with 40gb - should fit batch size 8 with big model and 16 with medium model 15 | 16 | ### START: usually you do not need to change this 17 | #SBATCH --output=%x-%j.out 18 | #SBATCH --error=%x-%j.err 19 | #SBATCH --qos= 20 | ### END: usually you do not need to change this 21 | 22 | ### START: feel free to change this 23 | #SBATCH --job-name= 24 | #SBATCH --time= 25 | #SBATCH --nodes= 26 | ### END: feel free to change this 27 | 28 | ### START: do not change this 29 | set -x -e 30 | 31 | module load cudatoolkit 32 | module load cray-mpich 33 | module load gcc 34 | module load conda 35 | conda activate 36 | 37 | export GPUS_PER_NODE=4 38 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 39 | export MASTER_PORT=9901 40 | export WORLD_SIZE=$(($GPUS_PER_NODE*$SLURM_NNODES)) 41 | ### END: do not change this 42 | 43 | ## add this to resume from checkpoint 44 | ## --resume_from_checkpoint checkpoint_path \ 45 | ## --ignore_data_skip True \ 46 | ## ignore data skip until https://github.com/huggingface/transformers/pull/33544 is merged 47 | 48 | # copy the file and modify the params as needed 49 | srun --jobid $SLURM_JOBID bash -c '\ 50 | python \ 51 | -m torch.distributed.run \ 52 | --nproc_per_node $GPUS_PER_NODE \ 53 | --nnodes $SLURM_NNODES \ 54 | --node_rank $SLURM_PROCID \ 55 | --master_addr $MASTER_ADDR \ 56 | --master_port $MASTER_PORT \ 57 | $PSCRATCH/network-data-representation/src/train/NetfoundFinetuning.py \ 58 | --report_to tensorboard \ 59 | --save_safetensors false \ 60 | --dispatch_batches False \ 61 | --bf16 \ 62 | --do_train \ 63 | --do_eval \ 64 | --eval_strategy steps \ 65 | --save_strategy steps \ 66 | --dataloader_num_workers 32 \ 67 | --dataloader_prefetch_factor 16 \ 68 | --logging_steps 5000 \ 69 | --save_steps 5000 \ 70 | --streaming True \ 71 | --gradient_accumulation_steps 1 \ 72 | --hidden_size 1024 \ 73 | 
--num_hidden_layers 24 \ 74 | --num_attention_heads 16 \ 75 | --tcpoptions False \ 76 | --validation_split_percentage 20 \ 77 | --per_device_eval_batch_size 8 \ 78 | --per_device_train_batch_size 8 \ 79 | --load_best_model_at_end \ 80 | --train_dir /path/to/train \ 81 | --test_dir /path/to/test \ 82 | --output_dir /path/to/output \ 83 | --model_name_or_path /path/to/checkpoint-XXX \ 84 | --problem_type single_label_classification \ 85 | --num_labels 8 \ 86 | --learning_rate 2e-5 \ 87 | --freeze_base True \ 88 | --max_steps 300000 \ 89 | ' 90 | -------------------------------------------------------------------------------- /src/train/NetfoundConfig.py: -------------------------------------------------------------------------------- 1 | from transformers.utils import logging 2 | from transformers import PretrainedConfig 3 | 4 | logger = logging.get_logger(__name__) 5 | 6 | 7 | class NetfoundConfig(PretrainedConfig): 8 | model_type = "NetFound" 9 | 10 | def __init__( 11 | self, 12 | vocab_size=65539, 13 | hidden_size=768, 14 | max_bursts=12, 15 | max_burst_length=108 + 1, 16 | model_max_length=1296 + 12, 17 | num_hidden_layers=12, 18 | num_attention_heads=12, 19 | intermediate_size=3072, 20 | hidden_act="gelu", 21 | hidden_dropout_prob=0.1, 22 | attention_probs_dropout_prob=0.1, 23 | max_position_embeddings=108 + 1, 24 | type_vocab_size=2, 25 | initializer_range=0.02, 26 | layer_norm_eps=1e-12, 27 | pad_token_id=0, 28 | position_embedding_type="absolute", 29 | encoder_layout=None, 30 | use_cache=True, 31 | classifier_dropout=None, 32 | metaFeatures=4, 33 | roformer=False, 34 | no_meta = False, 35 | flat = False, 36 | no_mlm=False, 37 | no_swapped_bursts = True, 38 | rep_output_path = None, 39 | subflow_bursts = 3, 40 | no_metadata_loss=False, 41 | no_direction_loss=False, 42 | **kwargs 43 | ): 44 | super().__init__(pad_token_id=pad_token_id, **kwargs) 45 | 46 | self.vocab_size = vocab_size 47 | self.hidden_size = hidden_size 48 | self.embedding_size = hidden_size 49 | self.max_bursts = max_bursts 50 | self.max_burst_length = max_burst_length 51 | self.model_max_length = model_max_length 52 | self.encoder_layout = encoder_layout 53 | self.num_hidden_layers = num_hidden_layers 54 | self.num_attention_heads = num_attention_heads 55 | self.hidden_act = hidden_act 56 | self.intermediate_size = intermediate_size 57 | self.hidden_dropout_prob = hidden_dropout_prob 58 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 59 | self.max_position_embeddings = max_position_embeddings 60 | self.type_vocab_size = type_vocab_size 61 | self.initializer_range = initializer_range 62 | self.layer_norm_eps = layer_norm_eps 63 | self.position_embedding_type = position_embedding_type 64 | self.use_cache = use_cache 65 | self.classifier_dropout = classifier_dropout 66 | self.metaFeatures = metaFeatures 67 | self.p = 0 68 | self.pretraining = True 69 | self.roformer = roformer 70 | self.no_meta = no_meta 71 | self.flat = flat 72 | self.limit_bursts = False 73 | self.rotary_value = False 74 | self.subflow_len=-1 75 | self.no_mlm = no_mlm 76 | self.no_swapped_bursts = no_swapped_bursts 77 | self.rep_output_path = rep_output_path 78 | self.subflow_bursts = subflow_bursts 79 | self.no_metadata_loss = no_metadata_loss 80 | self.no_direction_loss = no_direction_loss 81 | 82 | class NetFoundLarge(NetfoundConfig): 83 | def __init__(self, **kwargs): 84 | super().__init__(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, **kwargs) 85 | 86 | 87 | class NetFoundTCPOptionsConfig(NetfoundConfig): 88 | 
def __init__( 89 | self, 90 | max_burst_length=6 * (18 + 20) + 1, # 6 packets max * (18 tokens max for tcp + 20 tokens for tcpoptions) + 1 for CLS 91 | max_position_embeddings=6 * (18 + 20) + 1, 92 | model_max_length=(6 * (18 + 20)) * 12 + 12, # (6 packets * (18 tokens + 20 tcpoptions)) * 12 bursts + 12 CLS tokens 93 | *args, **kwargs 94 | ): 95 | super().__init__(*args, max_burst_length=max_burst_length, model_max_length=model_max_length, max_position_embeddings=max_position_embeddings, **kwargs) -------------------------------------------------------------------------------- /scripts/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import os 4 | import logging 5 | 6 | logger = logging.getLogger(__name__) 7 | logger.setLevel(logging.INFO) 8 | logger.addHandler(logging.StreamHandler()) 9 | 10 | 11 | def get_args(): 12 | description = """ 13 | This script preprocesses the raw pcap data into the tokenized format. It takes the input folder as an argument and one of two required flags: --pretrain or --finetune. 14 | The input folder must contain '/raw' folder with either raw pcap files (for pretraining, no labels) or folders with pcap files (finetuning, folder names must be integers and would be used as labels). 15 | The input folder would be used for intermediate files and the final tokenized data would be stored in the /final/shards folder as Apache Arrow shards. 16 | """ 17 | parser = argparse.ArgumentParser(description=description) 18 | parser.add_argument("--input_folder", type=str, required=True, help="The input folder") 19 | parser.add_argument("--action", choices=["pretrain", "finetune"], required=True, 20 | help="Preprocess data for pretraining or finetuning.") 21 | parser.add_argument("--tokenizer_config", type=str, required=True, help="The tokenizer config file.") 22 | parser.add_argument("--tcp_options", action="store_true", default=False, help="Include TCP options in the tokenized data.") 23 | parser.add_argument("--combined", action="store_true", default=False, 24 | help="Combine all the pcap files in the /final/shards into a single file (suitable for small datasets).") 25 | 26 | return parser 27 | 28 | 29 | def run(command: list[str]) -> subprocess.CompletedProcess: 30 | logger.info(f"Running command: {' '.join(command)}") 31 | process = subprocess.run(command, check=True, capture_output=True) 32 | if process.stderr: 33 | logger.error(process.stderr.decode()) 34 | if process.stdout: 35 | logger.info(process.stdout.decode()) 36 | return process 37 | 38 | 39 | def get_base_directory(args): 40 | # one step up of the directory of this file 41 | return os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 42 | 43 | 44 | def preprocess_pretrain(args): 45 | base_directory = get_base_directory(args) 46 | input_folder = args.input_folder 47 | run([f"{base_directory}/src/pre_process/1_filter.sh", f"{input_folder}/raw", f"{input_folder}/filtered"]) 48 | run([f"{base_directory}/src/pre_process/2_pcap_splitting.sh", f"{input_folder}/filtered", f"{input_folder}/split"]) 49 | run([f"{base_directory}/src/pre_process/3_extract_fields.sh", f"{input_folder}/split", f"{input_folder}/extracted", "1" if args.tcp_options else ""]) 50 | 51 | for folder_name in os.listdir(f"{input_folder}/extracted"): 52 | full_folder_name = os.path.join(f"{input_folder}/extracted", folder_name) 53 | os.makedirs(os.path.join(f"{input_folder}/final/shards", folder_name), exist_ok=True) 54 | run(["python3", 
f"{base_directory}/src/pre_process/Tokenize.py", "--conf_file", args.tokenizer_config, 55 | "--input_dir", full_folder_name, "--output_dir", 56 | os.path.join(f"{input_folder}/final/shards", folder_name)]) 57 | if args.combined: 58 | os.makedirs(os.path.join(f"{input_folder}/final", "combined"), exist_ok=True) 59 | run(["python3", f"{base_directory}/src/pre_process/CollectTokensInFiles.py", 60 | os.path.join(f"{input_folder}/final/shards", folder_name), 61 | os.path.join(f"{input_folder}/final/combined", f"{folder_name}.arrow")]) 62 | 63 | 64 | def preprocess_finetune(args): 65 | base_directory = get_base_directory(args) 66 | input_folder = args.input_folder 67 | for label in os.listdir(f"{input_folder}/raw"): 68 | for stage_name in ["filtered", "split", "extracted", "final/shards"]: 69 | os.makedirs(os.path.join(input_folder, stage_name, label), exist_ok=True) 70 | run([f"{base_directory}/src/pre_process/1_filter.sh", f"{input_folder}/raw/{label}", 71 | f"{input_folder}/filtered/{label}"]) 72 | run([f"{base_directory}/src/pre_process/2_pcap_splitting.sh", f"{input_folder}/filtered/{label}", 73 | f"{input_folder}/split/{label}"]) 74 | run([f"{base_directory}/src/pre_process/3_extract_fields.sh", f"{input_folder}/split/{label}", 75 | f"{input_folder}/extracted/{label}", "1" if args.tcp_options else ""]) 76 | 77 | for folder_name in os.listdir(f"{input_folder}/extracted/{label}"): 78 | full_folder_name = os.path.join(f"{input_folder}/extracted/{label}", folder_name) 79 | os.makedirs(os.path.join(f"{input_folder}/final/shards/{label}", folder_name), exist_ok=True) 80 | run(["python3", f"{base_directory}/src/pre_process/Tokenize.py", "--conf_file", args.tokenizer_config, 81 | "--input_dir", full_folder_name, "--output_dir", 82 | os.path.join(f"{input_folder}/final/shards/{label}", folder_name), '--label', label]) 83 | if args.combined: 84 | os.makedirs(os.path.join(f"{input_folder}/final", "combined"), exist_ok=True) 85 | run(["python3", f"{base_directory}/src/pre_process/CollectTokensInFiles.py", 86 | os.path.join(f"{input_folder}/final/shards/{label}", folder_name), 87 | os.path.join(f"{input_folder}/final/combined", f"{label}_{folder_name}.arrow")]) 88 | 89 | def main(): 90 | parser = get_args() 91 | args = parser.parse_args() 92 | input_folder = args.input_folder 93 | action = args.action 94 | 95 | raw_data_folder = os.path.join(input_folder, "raw") 96 | if not os.path.exists(raw_data_folder): 97 | print(f"Input folder {raw_data_folder} does not exist.") 98 | return 99 | 100 | for folder in ["filtered", "split", "extracted", "final", "final/shards"]: 101 | os.makedirs(os.path.join(input_folder, folder), exist_ok=True) 102 | 103 | match action: 104 | case "pretrain": 105 | preprocess_pretrain(args) 106 | case "finetune": 107 | preprocess_finetune(args) 108 | case _: 109 | raise ValueError("Unexpected action") 110 | 111 | return 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /src/train/NetfoundTokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import List, Union, Optional, Tuple 4 | import itertools 5 | 6 | import numpy as np 7 | from transformers import PreTrainedTokenizer, BatchEncoding 8 | from datasets.formatting.formatting import LazyBatch 9 | 10 | PROTOS_TO_LEN = {6: 18, 1: 13, 17: 12} # TODO(maybe-hello-world): refactor 11 | 12 | 13 | class NetFoundTokenizer(PreTrainedTokenizer): 14 | CLS_TOKEN = 
65537 15 | PAD_TOKEN = 0 16 | mask_token = 65538 17 | vocab_size = 65539 18 | ATTN_PRESENCE_TOKEN = 1 19 | ATTN_ABSENCE_TOKEN = 0 20 | 21 | def __init__(self, config): 22 | self.vocab_size = config.vocab_size 23 | self.max_bursts = config.max_bursts 24 | self.max_burst_length = config.max_burst_length 25 | self.p = config.p 26 | self.pretraining = config.pretraining 27 | self.name_or_path = config.name_or_path 28 | self.limit_bursts = config.limit_bursts 29 | 30 | def __repr__(self) -> str: 31 | return ( 32 | f"{self.__class__.__name__}(name_or_path='{self.name_or_path}'," 33 | f" vocab_size={self.vocab_size}, max_bursts={self.max_bursts}, max_burst_length={self.max_burst_length}, p={self.p})" 34 | ) 35 | 36 | @property 37 | def all_special_ids(self) -> List[int]: 38 | """ 39 | `List[int]`: List the ids of the special tokens(`''`, `''`, etc.) mapped to class attributes. 40 | """ 41 | return [self.CLS_TOKEN, self.PAD_TOKEN] 42 | 43 | def save_pretrained( 44 | self, 45 | save_directory: Union[str, os.PathLike], 46 | legacy_format: Optional[bool] = None, 47 | filename_prefix: Optional[str] = None, 48 | push_to_hub: bool = False, 49 | **kwargs, 50 | ) -> Tuple[str]: 51 | return 52 | 53 | def __len__(self): 54 | return self.vocab_size 55 | 56 | def pad_bursts( 57 | self, 58 | flow: list[list[int]], 59 | max_burst_length: int, 60 | pad_token: Optional[int] = None 61 | ) -> np.ndarray: 62 | """ 63 | Truncate each burst to `max_burst_length` and pad with token if necessary. 64 | """ 65 | if pad_token is None: 66 | pad_token = self.PAD_TOKEN 67 | return np.array([ 68 | burst[:max_burst_length] + [pad_token] * max((max_burst_length - len(burst)), 0) 69 | for burst in flow 70 | ]) 71 | 72 | def pad_flow(self, flow, max_bursts: int, token: int = None): 73 | """ 74 | Truncate the flow to `max_bursts` and pad with token if necessary. 75 | """ 76 | if token is None: 77 | token = self.PAD_TOKEN 78 | 79 | pad_bursts = max(max_bursts - len(flow), 0) 80 | pads = [token] * len(flow[0]) * pad_bursts 81 | 82 | flow = list(itertools.chain.from_iterable(flow[:max_bursts])) # flatten 83 | flow += pads 84 | return flow 85 | 86 | @staticmethod 87 | def prepend_to_list(flow: list[list[int]], token: Optional[int]) -> list[list[int]]: 88 | # Sometimes we prepend CLS_TOKEN or similar 89 | if token is not None: 90 | return [[token] + burst for burst in flow] 91 | else: 92 | return [[burst[0]] + burst for burst in flow] 93 | 94 | @staticmethod 95 | def convert_to_tokens(flow: list[list[int]], add_one: bool = False) -> list[list[int]]: 96 | if not add_one: 97 | return flow # noop 98 | return [[tok + add_one for tok in burst] for burst in flow] 99 | 100 | @staticmethod 101 | def convert_to_attn(bursts): 102 | return [[1] * len(burst) for burst in bursts] 103 | 104 | def __call__(self, dataset): 105 | return self.tokenize(dataset) 106 | 107 | def trunc_flow(self, ls, idxs): 108 | return [ 109 | ".".join(ls[i].split(".")[:idxs[i]]) + "." 110 | for i in range(len(ls)) 111 | ] 112 | 113 | @staticmethod 114 | def _expand_bursts(flows: list[list[int]], burst_sizes: list[list[int]]) -> list[list[list[int]]]: 115 | """ 116 | To save space, some repetitive info is stored as a single value for the entire burst. 117 | This function expands the burst sizes to match the actual burst lengths. 
118 | """ 119 | return [ 120 | [ 121 | [value] * burst_sizes[idx][i] 122 | for i, value in enumerate(flow) 123 | ] 124 | for idx, flow in enumerate(flows) 125 | ] 126 | 127 | @staticmethod 128 | def multiply_burst_values(flows: list[list[float]], multiplier: float, ftype=float) -> list[list[float]]: 129 | return [ 130 | [ftype(burst_value * multiplier) for burst_value in flow] 131 | for flow in flows 132 | ] 133 | 134 | def tokenize(self, text, **kwargs): 135 | dataset: LazyBatch = text 136 | dataset['iats'] = self.multiply_burst_values(dataset['iats'], 1e-3, int) 137 | dataset_burst_sizes = [[len(burst) for burst in flow] for flow in dataset["burst_tokens"]] 138 | 139 | if not self.pretraining and "labels" in dataset: 140 | labels = np.array(dataset["labels"], dtype=int) 141 | labels = labels.tolist() 142 | 143 | # restore directions: true/false -> 1/-1 144 | direction = [[1 if direction else -1 for direction in flow] for flow in dataset["directions"]] 145 | direction = self.tokenize_fields(self._expand_bursts(direction, dataset_burst_sizes)) 146 | 147 | pkt_bytes = self.tokenize_fields(self._expand_bursts(dataset["bytes"], dataset_burst_sizes)) 148 | pkt_count = self.tokenize_fields(self._expand_bursts(dataset["counts"], dataset_burst_sizes)) 149 | iats = self.tokenize_fields(self._expand_bursts(dataset["iats"], dataset_burst_sizes)) 150 | input_ids, attention_mask = self.tokenize_fields_with_attn( 151 | dataset["burst_tokens"], prepend_token=self.CLS_TOKEN, add_one=True 152 | ) 153 | total_bursts = [len(flow) for flow in dataset["burst_tokens"]] 154 | 155 | batchDict = { 156 | "input_ids": input_ids, 157 | "attention_mask": attention_mask, 158 | "direction": direction, 159 | "bytes": pkt_bytes, 160 | "pkt_count": pkt_count, 161 | "iats": iats, 162 | "total_bursts": total_bursts, 163 | "flow_duration": dataset["flow_duration"], 164 | "protocol": dataset["protocol"], 165 | } 166 | if not self.pretraining and "labels" in dataset: 167 | batchDict.update({"labels": labels}) 168 | 169 | return BatchEncoding(batchDict) 170 | 171 | def tokenize_fields( 172 | self, 173 | dataset: list[list[list[int]]], 174 | prepend_token: int = None, 175 | add_one: bool = False 176 | ) -> list[list[list[int]]]: 177 | tokenized_data = [ 178 | self.pad_flow( 179 | self.pad_bursts( 180 | self.prepend_to_list(self.convert_to_tokens(flow, add_one), prepend_token), 181 | self.max_burst_length, 182 | ), 183 | self.max_bursts, 184 | ) 185 | for flow in dataset 186 | ] 187 | 188 | return tokenized_data 189 | 190 | def tokenize_fields_with_attn( 191 | self, 192 | dataset: list[list[list[int]]], 193 | prepend_token: int = None, 194 | add_one: bool = False 195 | ) -> Tuple[list[list[list[int]]], list[list[list[int]]]]: 196 | tokenized_data = self.tokenize_fields(dataset, prepend_token, add_one) 197 | attn = [ 198 | self.pad_flow( 199 | self.pad_bursts( 200 | self.prepend_to_list(self.convert_to_attn(flow), self.ATTN_PRESENCE_TOKEN), 201 | max_burst_length=self.max_burst_length, 202 | pad_token=self.ATTN_ABSENCE_TOKEN 203 | ), 204 | max_bursts=self.max_bursts, 205 | token=self.ATTN_ABSENCE_TOKEN 206 | ) 207 | for flow in dataset 208 | ] 209 | return tokenized_data, attn 210 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # netFound: Foundation Model for Network Security 2 | This repository contains the **source code for netFound**, a foundation model for network telemetry developed by the 
**Systems & Networking Lab (SNL) at UC Santa Barbara**. 3 | ## Description 4 | netFound is designed to learn **spatial-temporal relationships** from raw network traffic, making it a powerful tool for network analysis, anomaly detection, and traffic prediction. 5 | ## :key: Key Features 6 | - **Raw Packet Processing**: Directly processes raw *PCAP* files as input, enabling full-scale network traffic analysis. 7 | - **Pretraining on Unlabeled Data**: Requires pretraining on large-scale, *unlabeled* network telemetry datasets, leveraging *self-supervised learning*. 8 | - **Hierarchical Transformer Architecture**: Captures both *packet bursts* and *flow-level behavior*, ensuring robust feature extraction. 9 | - **Metadata-Aware Processing**: Integrates **burst-level metadata** such as: 10 | - Inter-arrival time 11 | - Number of bytes per burst 12 | - Packet-level timing and structure 13 | ## :pushpin: Why Use netFound? 14 | netFound is part of a larger effort to develop **self-driving networks**—autonomous, adaptive network systems that require minimal human intervention. By leveraging *network foundation models*, we aim to improve the efficiency and scalability of *AI-powered Network Operations (AIOps)*. 15 | Corresponding paper: https://arxiv.org/abs/2310.17025 16 | ## Checkpoint 17 | https://huggingface.co/snlucsb/netFound-640M-base 18 | The checkpoint is pretrained on ~450 million flows of real-world network traffic from the University of California, Santa Barbara. 19 | Because the checkpoint is built on the Large version of netFound, pass `--netfound_large True` when fine-tuning. 20 | Pretrained model metrics: 21 | ``` 22 | eval_loss = 1.8847 23 | eval_macro_mlm_f1 = 0.4038 24 | eval_macro_mlm_prec = 0.7205 25 | eval_macro_mlm_recall = 0.3005 26 | eval_mlm_acc = 0.8514 27 | eval_swapped_macro_pred_f1 = 0.9605 28 | eval_swapped_macro_pred_prec = 0.963 29 | eval_swapped_macro_pred_recall = 0.9603 30 | eval_swapped_pred_acc = 0.9605 31 | eval_swapped_weighted_pred_f1 = 0.9605 32 | eval_swapped_weighted_pred_prec = 0.9628 33 | eval_swapped_weighted_pred_recall = 0.9605 34 | eval_weighted_mlm_f1 = 0.8451 35 | eval_weighted_mlm_prec = 0.8816 36 | eval_weighted_mlm_recall = 0.8514 37 | perplexity = 6.5842 38 | Total params: 643,825,672 39 | ``` 40 | ## :rocket: Quick Start: Running netFound with Docker & Makefile 41 | The *easiest way* to verify that the *preprocessing code and model work correctly* is to use the *provided Dockerfile and Makefile*. This setup ensures a *reproducible environment* with all dependencies installed and includes a *small test dataset* to validate the pipeline. 42 | ### :hammer_and_wrench: **Step 1: Build the Docker Container** 43 | Run the following command to build the container: 44 | ```sh 45 | docker build -t netfound:test . 46 | ``` 47 | This will create a Docker image named `netfound:test`, including the *source code* and a *test dataset* located in `data/test`. 48 | ### :arrow_forward: **Step 2: Run the Container** 49 | Start an interactive session inside the container: 50 | ```sh 51 | docker run -it netfound:test 52 | ``` 53 | This will launch a shell inside the container in the `/workspace` directory. 54 | ### :zap: **Step 3: Run the Full Pipeline** 55 | Inside the container, execute: 56 | ```sh 57 | make all 58 | ``` 59 | This will sequentially run the following *three steps* on the test dataset: 60 | 1. **Preprocessing**: Converts raw PCAP files into a format suitable for training. 61 | 2. 
**Pretraining**: Runs *self-supervised learning* on preprocessed data. 62 | 3. **Finetuning**: Adapts the model for downstream tasks using the preprocessed test dataset. 63 | 64 | ## :building_construction: **Understanding the Makefile & Dockerfile** 65 | The *Dockerfile and Makefile* automate the pipeline and provide a structured workflow: 66 | ### :pushpin: **Dockerfile** 67 | - Creates a *containerized environment* with all necessary dependencies installed. 68 | - Ensures consistent execution across different systems. 69 | ### :pushpin: **Test Dataset (`data/test/`)** 70 | - Contains *raw PCAP files* formatted for preprocessing. 71 | - Used to verify the pipeline’s functionality. 72 | ### :pushpin: **Makefile Structure** 73 | - **`make preprocess`**: 74 | - Filters, splits, and tokenizes the raw packet data. 75 | - **`make pretrain`**: 76 | - Runs **self-supervised pretraining** on the preprocessed dataset. 77 | - **`make finetune`**: 78 | - Trains the model on task-specific labeled data. 79 | # :rocket: Bring Your Own Data (BYOD) 80 | To train or fine-tune **netFound** on your own dataset, follow the steps below to **preprocess and tokenize your PCAP files**. 81 | ## :pushpin: Preprocessing Your Dataset 82 | The easiest way to preprocess your dataset is to use the **`scripts/preprocess_data.py`** script. 83 | ### :open_file_folder: Folder Structure for Pretraining 84 | Organize your dataset as follows: 85 | ``` 86 | folder_name/ 87 | ├── raw/ 88 | │ ├── file1.pcap 89 | │ ├── file2.pcap 90 | │ ├── ... 91 | ``` 92 | Then, run the following command: 93 | ```bash 94 | python3 scripts/preprocess_data.py --input_folder folder_name --action pretrain --tokenizer_config configs/TestPretrainingConfig.json --combined 95 | ``` 96 | :small_blue_diamond: **What happens next?** 97 | - The script will generate **intermediate folders** (`extracted`, `split`, etc.). 98 | - The resulting **tokenized data** will be stored in the `"tokens"` folder. 99 | - The **`--combined`** flag merges all tokenized files into a single **Arrow** file (useful for training). 100 | - If you **remove `--combined`**, multiple **Arrow** files (one per PCAP) will be created—this is beneficial for parallel processing across multiple nodes. 101 | - You can **modify the tokenizer configuration** (`configs/TestPretrainingConfig.json`) to control how internal and external IPs are handled. 102 | ### :open_file_folder: Folder Structure for Fine-Tuning 103 | To fine-tune netFound, structure your dataset into **class-separated folders**, where **folder names should be integers** (used as class labels). 104 | ``` 105 | raw/ 106 | ├── 0/ 107 | │ ├── class1_sample1.pcap 108 | │ ├── class1_sample2.pcap 109 | │ ├── ... 110 | ├── 1/ 111 | │ ├── class2_sample1.pcap 112 | │ ├── class2_sample2.pcap 113 | │ ├── ... 114 | ``` 115 | Run the preprocessing script again, changing the `--action` to `finetune`: 116 | ```bash 117 | python3 scripts/preprocess_data.py --input_folder folder_name --action finetune --tokenizer_config configs/TestPretrainingConfig.json --combined 118 | ``` 119 | :small_blue_diamond: **Fine-Tuning Notes:** 120 | - **Class labels must be integers** (e.g., `1, 2, 3, ...`). 121 | - The resulting **Arrow files** will include a `"labels"` column. 122 | - You can **manually edit the `"labels"` column** for **custom class adjustments** (including regression tasks). 
123 | - By default, the validation data split does not shuffle the data file before splitting. If your data is not already shuffled, please use `scripts/shuffler.py` to shuffle the train file to ensure that the resulting test file contains instances of the different classes. 124 | ## :wrench: Advanced Options 125 | ### **Handling TCP Options** 126 | - To include **TCPOptions** in your preprocessed data, use the `--tcp_options` flag: 127 | ```bash 128 | python3 scripts/preprocess_data.py --input_folder folder_name --action pretrain --tokenizer_config configs/UCSBPretrainingTCPOptionsConfig.json --combined --tcp_options 129 | ``` 130 | - **Prerequisite**: Your dataset must be **preprocessed with an additional flag** when using `3_extract_fields.py`: 131 | ```bash 132 | python3 scripts/3_extract_fields.py input.pcap output.pcap 1 133 | ``` 134 | - Ensure you use a **config file that includes TCPOptions processing** (e.g., `configs/UCSBPretrainingTCPOptionsConfig.json`). 135 | ## How to cite 136 | ``` 137 | @misc{guthula2024netfoundfoundationmodelnetwork, 138 | title={netFound: Foundation Model for Network Security}, 139 | author={Satyandra Guthula and Roman Beltiukov and Navya Battula and Wenbo Guo and Arpit Gupta and Inder Monga}, 140 | year={2024}, 141 | eprint={2310.17025}, 142 | archivePrefix={arXiv}, 143 | primaryClass={cs.NI}, 144 | url={https://arxiv.org/abs/2310.17025}, 145 | } 146 | ``` 147 | ## Acknowledgements 148 | NSF Awards CNS-2323229, OAC-2126327, and OAC-2126281 supported this work. This research used resources at the National Energy Research Scientific Computing Center (NERSC), a DOE Office of Science User Facility supported by the Office of Science of the U.S. Department of Energy under Contract No. DE-AC02-05CH11231 using NERSC award NERSC DDR-ERCAP0029768. Additionally, we would like to thank Cisco Research for their support. 
149 | -------------------------------------------------------------------------------- /src/train/NetFoundDataCollator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import DataCollatorForLanguageModeling, BatchEncoding 3 | from typing import Any, Dict, List, Optional, Union 4 | import numpy as np 5 | import random 6 | from transformers.utils import requires_backends, is_torch_device 7 | from utils import get_logger 8 | 9 | logger = get_logger(name=__name__) 10 | 11 | 12 | class DataCollatorWithMeta(DataCollatorForLanguageModeling): 13 | def __init__(self, values_clip: Optional[int] = None, swap_rate=0.5, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.values_clip = values_clip 16 | self.swap_rate = swap_rate 17 | 18 | def torch_call( 19 | self, examples: List[Union[List[int], Any, Dict[str, Any]]] 20 | ) -> Dict[str, Any]: 21 | batch = {} 22 | burstsInEachFlow = [example["total_bursts"] for example in examples] 23 | maxBursts = max(burstsInEachFlow) 24 | for i in range(len(examples)): 25 | inputs = dict((k, v) for k, v in examples[i].items()) 26 | for key in inputs.keys(): 27 | if key == "labels" or key == "total_bursts" or key == "replacedAfter": 28 | continue 29 | if key not in batch: 30 | if key != "replacedAfter": 31 | batch[key] = [] 32 | if key == "ports": 33 | batch[key].append(inputs[key] + 1) 34 | elif key in ("protocol", "flow_duration"): 35 | batch[key].append(inputs[key]) 36 | else: 37 | batch[key].append( 38 | inputs[key][: maxBursts * self.tokenizer.max_burst_length] 39 | ) 40 | for key in batch.keys(): 41 | batch[key] = torch.Tensor(np.array(batch[key])) 42 | if ( 43 | key == "input_ids" 44 | or key == "attention_masks" 45 | or key == "ports" 46 | or key == "protocol" 47 | ): 48 | batch[key] = torch.Tensor(batch[key]).to(torch.long) 49 | 50 | if self.mlm: 51 | batch["input_ids"], batch["labels"], batch["swappedLabels"], batch[ 52 | "burstMetasToBeMasked"] = self.torch_mask_tokens( 53 | batch["input_ids"], burstsInEachFlow, self.tokenizer.max_burst_length, self.swap_rate, 54 | batch["protocol"], special_tokens_mask=None 55 | ) 56 | else: 57 | labels = batch["input_ids"].clone() 58 | if self.tokenizer.pad_token_id is not None: 59 | labels[labels == self.tokenizer.pad_token_id] = -100 60 | batch["labels"] = labels 61 | return BatchEncoding(batch) 62 | 63 | def swap_bursts_adjust_prob_matrix(self, input_ids, burstsInEachFlow, max_burst_length, swap_rate): 64 | labels = torch.from_numpy(np.array(np.random.rand(len(burstsInEachFlow)) < swap_rate, dtype=int)) 65 | swappedIds = [] 66 | for i in range(input_ids.shape[0]): 67 | if labels[i] == 1: 68 | burstToRep = random.randint(0, burstsInEachFlow[i] - 1) 69 | flowChoice = random.randint(0, input_ids.shape[0] - 1) 70 | if flowChoice == i: 71 | flowChoice = (flowChoice + 1) % input_ids.shape[0] 72 | burstChoice = random.randint(0, burstsInEachFlow[flowChoice] - 1) 73 | swappedIds.append([i, burstToRep]) 74 | input_ids[i][burstToRep * max_burst_length:(burstToRep + 1) * max_burst_length] = input_ids[flowChoice][burstChoice * max_burst_length:(burstChoice + 1) * max_burst_length] 75 | return input_ids, swappedIds, labels 76 | 77 | def maskMetaData(self, input_ids, burstsInEachFlow, swapped_bursts): 78 | maskedMetaBursts = np.full((input_ids.shape[0], max(burstsInEachFlow)), 0.3) 79 | for ids in swapped_bursts: 80 | maskedMetaBursts[ids[0]][ids[1]] = 0 81 | candidateFlows = np.array( 82 | [np.array(np.array(burstsInEachFlow) > 3, 
dtype=int)]).transpose() # converting to nX1 matrix 83 | return torch.bernoulli(torch.from_numpy(candidateFlows * maskedMetaBursts)).bool() 84 | 85 | def torch_mask_tokens(self, input_ids, burstsInEachFlow, max_burst_length, swap_rate, protos, **kwargs): 86 | labels = input_ids.clone() 87 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 88 | probability_matrix = torch.full(labels.shape, self.mlm_probability) 89 | new_ip_ids, swappedIds, swappedLabels = self.swap_bursts_adjust_prob_matrix(input_ids, burstsInEachFlow, 90 | max_burst_length, swap_rate) 91 | maskMetaData = self.maskMetaData(input_ids, burstsInEachFlow, swappedIds) 92 | for ids in swappedIds: 93 | probability_matrix[ids[0]][ids[1] * max_burst_length:(ids[1]) * max_burst_length] = 0 94 | input_ids = new_ip_ids 95 | 96 | special_tokens_mask = [ 97 | self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) 98 | for val in labels.tolist() 99 | ] 100 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool) 101 | 102 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 103 | masked_indices = torch.bernoulli(probability_matrix).bool() 104 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 105 | 106 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 107 | indices_replaced = ( 108 | torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 109 | ) 110 | input_ids[indices_replaced] = self.tokenizer.mask_token 111 | 112 | # 10% of the time, we replace masked input tokens with random word 113 | indices_random = ( 114 | torch.bernoulli(torch.full(labels.shape, 0.5)).bool() 115 | & masked_indices 116 | & ~indices_replaced 117 | ) 118 | random_words = torch.randint( 119 | len(self.tokenizer), labels.shape, dtype=torch.long 120 | ) 121 | input_ids[indices_random] = random_words[indices_random] 122 | 123 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 124 | return input_ids, labels, swappedLabels, maskMetaData 125 | 126 | 127 | class DataCollatorForFlowClassification: 128 | label_names: Dict 129 | 130 | def __init__(self, max_burst_length): 131 | self.max_burst_length = max_burst_length 132 | 133 | def __call__(self, examples): 134 | first = examples[0] 135 | maxBursts = max([int(example["total_bursts"]) for example in examples]) 136 | for i in range(len(examples)): 137 | if "stats" in examples[i]: 138 | examples[i]["stats"] = [ 139 | float(t) for t in examples[i]["stats"].split(" ") 140 | ] 141 | batch = {} 142 | if "labels" in first and first["labels"] is not None: 143 | label = ( 144 | first["labels"].item() 145 | if isinstance(first["labels"], torch.Tensor) 146 | else first["labels"] 147 | ) 148 | dtype = torch.long if isinstance(label, int) else torch.float 149 | batch["labels"] = torch.tensor([f["labels"] for f in examples], dtype=dtype) 150 | if "protocol" in first and first["protocol"] is not None: 151 | label = ( 152 | first["protocol"].item() 153 | if isinstance(first["protocol"], torch.Tensor) 154 | else first["protocol"] 155 | ) 156 | dtype = torch.long if isinstance(label, int) else torch.float 157 | batch["protocol"] = torch.tensor([f["protocol"] for f in examples], dtype=dtype) 158 | if "flow_duration" in first and first["flow_duration"] is not None: 159 | label = ( 160 | first["flow_duration"].item() 161 | if isinstance(first["flow_duration"], torch.Tensor) 162 | else first["flow_duration"] 163 | ) 164 | dtype = 
torch.long if isinstance(label, int) else torch.float 165 | batch["flow_duration"] = torch.tensor([f["flow_duration"] for f in examples], dtype=dtype) 166 | for k, v in first.items(): 167 | if ( 168 | k not in ("labels", "label_ids", "total_bursts", "protocol", "flow_duration") 169 | and v is not None 170 | and not isinstance(v, str) 171 | ): 172 | if isinstance(v, torch.Tensor): 173 | batch[k] = torch.stack( 174 | [f[k][: maxBursts * self.max_burst_length] for f in examples] 175 | ) 176 | else: 177 | batch[k] = torch.tensor( 178 | [f[k][: maxBursts * self.max_burst_length] for f in examples] 179 | ) 180 | return batch 181 | -------------------------------------------------------------------------------- /src/train/NetfoundFinetuning.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from sklearn.exceptions import UndefinedMetricWarning 3 | warnings.filterwarnings("ignore", category=UndefinedMetricWarning) 4 | warnings.filterwarnings("ignore", category=FutureWarning) 5 | 6 | import os 7 | import torch 8 | import torch.distributed 9 | import numpy as np 10 | import utils 11 | import random 12 | from dataclasses import field, dataclass 13 | from datasets.distributed import split_dataset_by_node 14 | from typing import Optional 15 | from copy import deepcopy 16 | from torchinfo import summary 17 | from torch.distributed.elastic.multiprocessing.errors import record 18 | 19 | from transformers import ( 20 | EvalPrediction, 21 | HfArgumentParser, 22 | TrainingArguments, 23 | EarlyStoppingCallback, 24 | ) 25 | 26 | from sklearn.metrics import ( 27 | f1_score, 28 | accuracy_score, 29 | precision_score, 30 | recall_score, 31 | top_k_accuracy_score, 32 | classification_report, confusion_matrix 33 | ) 34 | 35 | from NetFoundDataCollator import DataCollatorForFlowClassification 36 | from NetFoundModels import NetfoundFinetuningModel, NetfoundNoPTM 37 | from NetFoundTrainer import NetfoundTrainer 38 | from NetfoundConfig import NetfoundConfig, NetFoundTCPOptionsConfig, NetFoundLarge 39 | from NetfoundTokenizer import NetFoundTokenizer 40 | from utils import ModelArguments, CommonDataTrainingArguments, freeze, verify_checkpoint, \ 41 | load_train_test_datasets, get_90_percent_cpu_count, get_logger, init_tbwriter, update_deepspeed_config, \ 42 | LearningRateLogCallback 43 | 44 | random.seed(42) 45 | logger = get_logger(name=__name__) 46 | 47 | 48 | @dataclass 49 | class FineTuningDataTrainingArguments(CommonDataTrainingArguments): 50 | """ 51 | Arguments pertaining to what data we are going to input our model for training and eval. 
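Example (hypothetical invocation): passing --num_labels 2 selects a two-class classification setup, and --netfound_large True switches to the large model configuration described by the fields below.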
52 | """ 53 | 54 | num_labels: int = field(metadata={"help": "number of classes in the datasets"}, default=None) 55 | problem_type: Optional[str] = field( 56 | default=None, 57 | metadata={"help": "Override regression or classification task"}, 58 | ) 59 | p_val: float = field( 60 | default=0, 61 | metadata={ 62 | "help": "noise rate" 63 | }, 64 | ) 65 | netfound_large: bool = field( 66 | default=False, 67 | metadata={ 68 | "help": "Use the large configuration for netFound model" 69 | }, 70 | ) 71 | 72 | 73 | def regression_metrics(p: EvalPrediction): 74 | logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 75 | label_ids = p.label_ids.astype(int) 76 | return {"loss": np.mean(np.absolute((logits - label_ids)))} 77 | 78 | 79 | def classif_metrics(p: EvalPrediction, num_classes): 80 | logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 81 | label_ids = p.label_ids.astype(int) 82 | weighted_f1 = f1_score( 83 | y_true=label_ids, y_pred=logits.argmax(axis=1), average="weighted", zero_division=0 84 | ) 85 | weighted_prec = precision_score( 86 | y_true=label_ids, y_pred=logits.argmax(axis=1), average="weighted", zero_division=0 87 | ) 88 | weighted_recall = recall_score( 89 | y_true=label_ids, y_pred=logits.argmax(axis=1), average="weighted", zero_division=0 90 | ) 91 | accuracy = accuracy_score(y_true=label_ids, y_pred=logits.argmax(axis=1)) 92 | logger.warning(classification_report(label_ids, logits.argmax(axis=1), digits=5)) 93 | logger.warning(confusion_matrix(label_ids, logits.argmax(axis=1))) 94 | if num_classes > 3: 95 | logger.warning(f"top3:{top_k_accuracy_score(label_ids, logits, k=3, labels=np.arange(num_classes))}") 96 | if num_classes > 5: 97 | logger.warning(f"top5:{top_k_accuracy_score(label_ids, logits, k=5, labels=np.arange(num_classes))}") 98 | if num_classes > 10: 99 | logger.warning(f"top10:{top_k_accuracy_score(label_ids, logits, k=10, labels=np.arange(num_classes))}") 100 | return { 101 | "weighted_f1": weighted_f1, 102 | "accuracy": accuracy, 103 | "weighted_prec: ": weighted_prec, 104 | "weighted_recall": weighted_recall, 105 | } 106 | 107 | 108 | @record 109 | def main(): 110 | parser = HfArgumentParser( 111 | (ModelArguments, FineTuningDataTrainingArguments, TrainingArguments) 112 | ) 113 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 114 | utils.LOGGING_LEVEL = training_args.get_process_log_level() 115 | 116 | logger.info(f"model_args: {model_args}") 117 | logger.info(f"data_args: {data_args}") 118 | logger.info(f"training_args: {training_args}") 119 | 120 | train_dataset, test_dataset = load_train_test_datasets(logger, data_args) 121 | if "WORLD_SIZE" in os.environ: 122 | train_dataset = split_dataset_by_node(train_dataset, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])) 123 | test_dataset = split_dataset_by_node(test_dataset, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])) 124 | 125 | config = NetFoundTCPOptionsConfig if data_args.tcpoptions else NetfoundConfig 126 | config = config( 127 | num_hidden_layers=model_args.num_hidden_layers, 128 | num_attention_heads=model_args.num_attention_heads, 129 | hidden_size=model_args.hidden_size, 130 | no_meta=data_args.no_meta, 131 | flat=data_args.flat, 132 | ) 133 | if data_args.netfound_large: 134 | config.hidden_size = NetFoundLarge().hidden_size 135 | config.num_hidden_layers = NetFoundLarge().num_hidden_layers 136 | config.num_attention_heads = NetFoundLarge().num_attention_heads 137 | 138 
| config.pretraining = False 139 | config.num_labels = data_args.num_labels 140 | config.problem_type = data_args.problem_type 141 | testingTokenizer = NetFoundTokenizer(config=config) 142 | 143 | training_config = deepcopy(config) 144 | training_config.p = data_args.p_val 145 | training_config.limit_bursts = data_args.limit_bursts 146 | trainingTokenizer = NetFoundTokenizer(config=training_config) 147 | additionalFields = None 148 | 149 | if "WORLD_SIZE" in os.environ and training_args.local_rank > 0 and not data_args.streaming: 150 | logger.warning("Waiting for main process to perform the mapping") 151 | torch.distributed.barrier() 152 | 153 | params = { 154 | "batched": True 155 | } 156 | if not data_args.streaming: 157 | params['num_proc'] = data_args.preprocessing_num_workers or get_90_percent_cpu_count() 158 | train_dataset = train_dataset.map(function=trainingTokenizer, **params) 159 | test_dataset = test_dataset.map(function=testingTokenizer, **params) 160 | 161 | if "WORLD_SIZE" in os.environ and training_args.local_rank == 0 and not data_args.streaming: 162 | logger.warning("Loading results from main process") 163 | torch.distributed.barrier() 164 | 165 | data_collator = DataCollatorForFlowClassification(config.max_burst_length) 166 | if model_args.model_name_or_path is not None and os.path.exists( 167 | model_args.model_name_or_path 168 | ): 169 | logger.warning(f"Using weights from {model_args.model_name_or_path}") 170 | model = freeze(NetfoundFinetuningModel.from_pretrained( 171 | model_args.model_name_or_path, config=config 172 | ), model_args) 173 | elif model_args.no_ptm: 174 | model = NetfoundNoPTM(config=config) 175 | else: 176 | model = freeze(NetfoundFinetuningModel(config=config), model_args) 177 | if training_args.local_rank == 0: 178 | summary(model) 179 | 180 | # metrics 181 | problem_type = data_args.problem_type 182 | if problem_type == "regression": 183 | compute_metrics = regression_metrics 184 | else: 185 | compute_metrics = lambda p: classif_metrics(p, data_args.num_labels) 186 | 187 | trainer = NetfoundTrainer( 188 | model=model, 189 | extraFields=additionalFields, 190 | args=training_args, 191 | train_dataset=train_dataset, 192 | eval_dataset=test_dataset, 193 | tokenizer=testingTokenizer, 194 | compute_metrics=compute_metrics, 195 | callbacks=[EarlyStoppingCallback(early_stopping_patience=6)], 196 | data_collator=data_collator, 197 | ) 198 | init_tbwriter(training_args.output_dir) 199 | trainer.add_callback(LearningRateLogCallback(utils.TB_WRITER)) 200 | utils.start_gpu_logging(training_args.output_dir) 201 | utils.start_cpu_logging(training_args.output_dir) 202 | 203 | verify_checkpoint(logger, training_args) 204 | 205 | if training_args.do_train: 206 | logger.warning("*** Train ***") 207 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 208 | trainer.save_model() # doesn't store tokenizer 209 | metrics = train_result.metrics 210 | 211 | trainer.log_metrics("train", metrics) 212 | trainer.save_metrics("train", metrics) 213 | trainer.save_state() 214 | 215 | if training_args.do_eval: 216 | logger.warning("*** Evaluate ***") 217 | metrics = trainer.evaluate(eval_dataset=test_dataset) 218 | trainer.log_metrics("eval", metrics) 219 | trainer.save_metrics("eval", metrics) 220 | 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /src/train/NetfoundPretraining.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from sklearn.exceptions import UndefinedMetricWarning 3 | warnings.filterwarnings("ignore", category=UndefinedMetricWarning) 4 | warnings.filterwarnings("ignore", category=FutureWarning) 5 | import math 6 | import random 7 | import os 8 | import torch 9 | import torch.distributed 10 | import utils 11 | from dataclasses import field, dataclass 12 | from typing import Optional 13 | from torchinfo import summary 14 | 15 | from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score 16 | from torch.distributed.elastic.multiprocessing.errors import record 17 | from datasets.distributed import split_dataset_by_node 18 | from transformers.trainer_utils import get_last_checkpoint 19 | from transformers import ( 20 | MODEL_FOR_MASKED_LM_MAPPING, 21 | HfArgumentParser, 22 | TrainingArguments, 23 | ) 24 | 25 | from NetFoundModels import NetFoundLanguageModelling 26 | from NetFoundTrainer import NetfoundTrainer 27 | from NetFoundDataCollator import DataCollatorWithMeta 28 | from NetfoundConfig import NetfoundConfig, NetFoundTCPOptionsConfig 29 | from NetfoundTokenizer import NetFoundTokenizer 30 | from utils import ModelArguments, CommonDataTrainingArguments, freeze, verify_checkpoint, \ 31 | load_train_test_datasets, get_90_percent_cpu_count, initialize_model_with_deepspeed, get_logger, init_tbwriter, update_deepspeed_config, \ 32 | LearningRateLogCallback 33 | 34 | 35 | random.seed(42) 36 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) 37 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 38 | 39 | @dataclass 40 | class PretrainingDataTrainingArguments(CommonDataTrainingArguments): 41 | """ 42 | Arguments pertaining to what data we are going to input our model for training and eval. 
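Example (hypothetical invocation): --mlm_probability 0.30 --swap_rate 0.5 masks roughly 30% of tokens for the MLM objective and swaps a burst in about half of the flows; --no_swapped_bursts disables the swap objective entirely (swap_rate is forced to 0).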
43 | """ 44 | no_mlm: bool = field( 45 | default=False, 46 | metadata={"help": "no MLM loss function"}, 47 | ) 48 | no_swapped_bursts: bool = field( 49 | default=False, 50 | metadata={"help": "no swapped bursts loss function"}, 51 | ) 52 | no_metadata_loss: bool = field( 53 | default=False, 54 | metadata={"help": "no metadata loss function"}, 55 | ) 56 | no_direction_loss: bool = field( 57 | default=False, 58 | metadata={"help": "no direction loss function"}, 59 | ) 60 | swap_rate: Optional[float] = field( 61 | default=0.5, 62 | metadata={"help": "probability of swapping the burst in the flow during training"}, 63 | ) 64 | subflow_len: Optional[int] = field( 65 | default=-1, 66 | metadata={"help": "subflow length, -1 for no subflow"}, 67 | ) 68 | mlm_probability: float = field( 69 | default=0.30, 70 | metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}, 71 | ) 72 | 73 | 74 | def preprocess_logits_for_metrics(logits, _): 75 | if isinstance(logits, tuple): 76 | return tuple(i.argmax(dim=-1) for i in logits) 77 | return logits.argmax(dim=-1) 78 | 79 | 80 | def compute_metrics(eval_preds): 81 | all_preds, all_labels = eval_preds 82 | 83 | labels = all_labels[0] if isinstance(all_labels, tuple) else all_labels 84 | preds = all_preds[0] if isinstance(all_preds, tuple) else all_preds 85 | swappedBurstGTs = all_labels[1] if isinstance(all_labels, tuple) else None 86 | swappedBurstPreds = all_preds[1] if isinstance(all_preds, tuple) else None 87 | 88 | labels = labels.reshape(-1) 89 | preds = preds.reshape(-1) 90 | mask = labels != -100 91 | labels = labels[mask] 92 | preds = preds[mask] 93 | return_metrics = { 94 | "macro_mlm_f1": f1_score(labels, preds, average="macro"), 95 | "macro_mlm_prec": precision_score(labels, preds, average="macro"), 96 | "macro_mlm_recall": recall_score(labels, preds, average="macro"), 97 | "weighted_mlm_f1": f1_score(labels, preds, average="weighted"), 98 | "weighted_mlm_prec": precision_score(labels, preds, average="weighted"), 99 | "weighted_mlm_recall": recall_score(labels, preds, average="weighted"), 100 | "mlm_acc": accuracy_score(labels, preds), 101 | } 102 | if swappedBurstGTs is not None and swappedBurstPreds is not None: 103 | return_metrics.update( 104 | { 105 | "swapped_macro_pred_f1": f1_score(swappedBurstGTs, swappedBurstPreds, average="macro"), 106 | "swapped_macro_pred_prec": precision_score( 107 | swappedBurstGTs, swappedBurstPreds, average="macro" 108 | ), 109 | "swapped_macro_pred_recall": recall_score( 110 | swappedBurstGTs, swappedBurstPreds, average="macro" 111 | ), 112 | "swapped_weighted_pred_f1": f1_score( 113 | swappedBurstGTs, swappedBurstPreds, average="weighted" 114 | ), 115 | "swapped_weighted_pred_prec": precision_score( 116 | swappedBurstGTs, swappedBurstPreds, average="weighted" 117 | ), 118 | "swapped_weighted_pred_recall": recall_score( 119 | swappedBurstGTs, swappedBurstPreds, average="weighted" 120 | ), 121 | "swapped_pred_acc": accuracy_score(swappedBurstGTs, swappedBurstPreds), 122 | } 123 | ) 124 | return return_metrics 125 | 126 | @record 127 | def main(): 128 | parser = HfArgumentParser( 129 | (ModelArguments, PretrainingDataTrainingArguments, TrainingArguments) 130 | ) 131 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 132 | 133 | utils.LOGGING_LEVEL = training_args.get_process_log_level() 134 | logger = get_logger(name=__name__) 135 | 136 | logger.info(f"model_args: {model_args}") 137 | logger.info(f"data_args: {data_args}") 138 | logger.info(f"training_args: 
{training_args}") 139 | 140 | train_dataset, test_dataset = load_train_test_datasets(logger, data_args) 141 | if "WORLD_SIZE" in os.environ: 142 | train_dataset = split_dataset_by_node(train_dataset, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])) 143 | test_dataset = split_dataset_by_node(test_dataset, rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])) 144 | 145 | logger.warning("Tokenizing datasets") 146 | config = NetFoundTCPOptionsConfig if data_args.tcpoptions else NetfoundConfig 147 | config = config( 148 | num_hidden_layers=model_args.num_hidden_layers, 149 | num_attention_heads=model_args.num_attention_heads, 150 | hidden_size=model_args.hidden_size, 151 | no_meta=data_args.no_meta, 152 | flat=data_args.flat, 153 | ) 154 | 155 | config.roformer = False 156 | config.limit_bursts = data_args.limit_bursts 157 | config.no_mlm = data_args.no_mlm 158 | if config.no_mlm: 159 | data_args.mlm_probability = 0.00001 # epsilon 160 | swap_rate = data_args.swap_rate 161 | config.no_swapped_bursts = data_args.no_swapped_bursts 162 | config.no_metadata_loss = data_args.no_metadata_loss 163 | config.no_direction_loss = data_args.no_direction_loss 164 | if config.no_swapped_bursts: 165 | swap_rate = 0 166 | config.name_or_path = model_args.model_name_or_path 167 | tokenizer = NetFoundTokenizer(config=config) 168 | 169 | data_collator = DataCollatorWithMeta( 170 | tokenizer=tokenizer, 171 | mlm_probability=data_args.mlm_probability, 172 | swap_rate=swap_rate 173 | ) 174 | 175 | if "WORLD_SIZE" in os.environ and training_args.local_rank > 0 and not data_args.streaming: 176 | logger.warning("Waiting for main process to perform the mapping") 177 | torch.distributed.barrier() 178 | 179 | params = { 180 | "function": tokenizer, 181 | "batched": True 182 | } 183 | if not data_args.streaming: 184 | params['num_proc'] = data_args.preprocessing_num_workers or get_90_percent_cpu_count() 185 | train_dataset = train_dataset.map(**params) 186 | test_dataset = test_dataset.map(**params) 187 | 188 | if "WORLD_SIZE" in os.environ and training_args.local_rank == 0 and not data_args.streaming: 189 | logger.warning("Loading results from main process") 190 | torch.distributed.barrier() 191 | 192 | if model_args.model_name_or_path is not None and os.path.exists( 193 | model_args.model_name_or_path 194 | ): 195 | logger.warning(f"Using weights from {model_args.model_name_or_path}") 196 | model = freeze(NetFoundLanguageModelling.from_pretrained(model_args.model_name_or_path, config=config), model_args) 197 | else: 198 | model = freeze(NetFoundLanguageModelling(config=config), model_args) 199 | if training_args.local_rank == 0: 200 | summary(model) 201 | 202 | trainer = NetfoundTrainer( 203 | label_names=["swappedLabels"], 204 | model=model, 205 | args=training_args, 206 | train_dataset=train_dataset if training_args.do_train else None, 207 | eval_dataset=test_dataset if training_args.do_eval else None, 208 | tokenizer=tokenizer, 209 | compute_metrics=compute_metrics, 210 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 211 | data_collator=data_collator, 212 | ) 213 | init_tbwriter(training_args.output_dir) 214 | trainer.add_callback(LearningRateLogCallback(utils.TB_WRITER)) 215 | utils.start_gpu_logging(training_args.output_dir) 216 | utils.start_cpu_logging(training_args.output_dir) 217 | 218 | verify_checkpoint(logger, training_args) 219 | 220 | if training_args.do_train: 221 | train_result = 
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 222 | trainer.save_model() 223 | metrics = train_result.metrics 224 | 225 | trainer.log_metrics("train", metrics) 226 | trainer.save_metrics("train", metrics) 227 | trainer.save_state() 228 | 229 | if training_args.do_eval: 230 | logger.warning("*** Evaluate ***") 231 | metrics = trainer.evaluate() 232 | try: 233 | perplexity = math.exp(metrics["eval_loss"]) 234 | except OverflowError: 235 | perplexity = float("inf") 236 | metrics["perplexity"] = perplexity 237 | trainer.log_metrics("eval", metrics) 238 | trainer.save_metrics("eval", metrics) 239 | 240 | if __name__ == "__main__": 241 | main() 242 | -------------------------------------------------------------------------------- /src/train/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | import datasets 4 | import transformers 5 | import logging 6 | import os 7 | import json 8 | import threading 9 | import subprocess 10 | import torch 11 | import time 12 | import socket 13 | import psutil 14 | 15 | from collections import defaultdict 16 | 17 | from torch.utils.tensorboard import SummaryWriter 18 | from transformers import TrainerCallback 19 | from transformers.trainer_utils import get_last_checkpoint 20 | from datasets import load_dataset 21 | 22 | LOGGING_LEVEL = logging.WARNING 23 | TB_WRITER: Optional[SummaryWriter] = None 24 | 25 | 26 | @dataclass 27 | class ModelArguments: 28 | model_name_or_path: str = field( 29 | default=None, 30 | metadata={ 31 | "help": "The model checkpoint for weights initialization." 32 | "Don't set if you want to train a model from scratch." 33 | }, 34 | ) 35 | metaFeatures: int = field( 36 | default=4, 37 | metadata={"help": "number of metadata fields."}, 38 | ) 39 | num_hidden_layers: int = field( 40 | default=12, 41 | metadata={"help": "Number of hidden layers."}, 42 | ) 43 | num_attention_heads: int = field( 44 | default=12, 45 | metadata={"help": "Number of attention heads."}, 46 | ) 47 | hidden_size: int = field( 48 | default=768, 49 | metadata={"help": "Hidden size."}, 50 | ) 51 | no_ptm: bool = field( 52 | default=False, 53 | metadata={"help": "If True, use NoPTM model (only for fine-tuning)."}, 54 | ) 55 | freeze_flow_encoder: bool = field( 56 | default=False, 57 | metadata={"help": "Freeze flow encoders"}, 58 | ) 59 | freeze_burst_encoder: bool = field( 60 | default=False, 61 | metadata={"help": "Freeze burst encoders"}, 62 | ) 63 | freeze_embeddings: bool = field( 64 | default=False, 65 | metadata={"help": "Freeze embeddings"}, 66 | ) 67 | freeze_base: bool = field( 68 | default=False, 69 | metadata={"help": "Freeze base model"}, 70 | ) 71 | 72 | 73 | @dataclass 74 | class CommonDataTrainingArguments: 75 | train_dir: Optional[str] = field( 76 | metadata={"help": "Directory with training data (Apache Arrow files)"}) 77 | test_dir: Optional[str] = field(default=None, metadata={ 78 | "help": "Directory with testing data (Apache Arrow files)"}) 79 | no_meta: bool = field( 80 | default=False, 81 | metadata={"help": "no meta fields"}, 82 | ) 83 | flat: bool = field( 84 | default=False, 85 | metadata={"help": "no cross burst encoder"}, 86 | ) 87 | limit_bursts: bool = field( 88 | default=False, 89 | metadata={"help": "limit_bursts"}, 90 | ) 91 | validation_dir: Optional[str] = field( 92 | default=None, 93 | metadata={ 94 | "help": "Directory with optional input evaluation data to evaluate the perplexity on 
(Apache Arrow files)"}, 95 | ) 96 | validation_split_percentage: Optional[int] = field( 97 | default=30, 98 | metadata={"help": "The percentage of the train set used as validation set in case there's no validation split"} 99 | ) 100 | data_cache_dir: Optional[str] = field( 101 | default="/tmp", 102 | metadata={"help": "Where to store the dataset cache."}, 103 | ) 104 | overwrite_cache: bool = field( 105 | default=False, 106 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 107 | ) 108 | max_bursts: int = field( 109 | default=12, 110 | metadata={ 111 | "help": "The maximum number of sentences after tokenization. Sequences longer " 112 | "than this will be truncated." 113 | }, 114 | ) 115 | max_seq_length: Optional[int] = field( 116 | default=1296 + 12, 117 | metadata={ 118 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 119 | "than this will be truncated." 120 | }, 121 | ) 122 | preprocessing_num_workers: Optional[int] = field( 123 | default=None, 124 | metadata={"help": "The number of processes to use for the preprocessing."}, 125 | ) 126 | max_train_samples: Optional[float] = field( 127 | default=None, 128 | metadata={ 129 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 130 | "value if set." 131 | }, 132 | ) 133 | max_eval_samples: Optional[int] = field( 134 | default=None, 135 | metadata={ 136 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 137 | "value if set." 138 | }, 139 | ) 140 | streaming: bool = field( 141 | default=False, 142 | metadata={"help": "Whether to load dataset in the streaming mode."}, 143 | ) 144 | tcpoptions: bool = field( 145 | default=False, 146 | metadata={"help": "Whether the data contains TCP options."}, 147 | ) 148 | 149 | 150 | def freeze(model, model_args): 151 | for name, param in model.base_transformer.named_parameters(): 152 | if model_args.freeze_flow_encoder and ( 153 | "flow_encoder" in name or ("encoder" in name and "position_embeddings" in name)): 154 | param.requires_grad = False 155 | if model_args.freeze_burst_encoder and "burst_encoder" in name: 156 | param.requires_grad = False 157 | if model_args.freeze_embeddings and (name.startswith("embed") or name.startswith("seg_embed")): 158 | param.requires_grad = False 159 | if model_args.freeze_base: 160 | param.requires_grad = False 161 | return model 162 | 163 | 164 | def get_logger(name): 165 | logger = logging.getLogger(name) 166 | logger.addHandler(logging.StreamHandler()) 167 | logger.setLevel(LOGGING_LEVEL) 168 | datasets.utils.logging.set_verbosity(LOGGING_LEVEL) 169 | transformers.utils.logging.set_verbosity(LOGGING_LEVEL) 170 | transformers.utils.logging.enable_default_handler() 171 | transformers.utils.logging.enable_explicit_format() 172 | return logger 173 | 174 | 175 | def verify_checkpoint(logger, training_args): 176 | if not training_args.resume_from_checkpoint: 177 | folders = set(os.listdir(training_args.output_dir)) - {"runs"} 178 | if len(folders) > 0: 179 | if training_args.local_rank == 0: 180 | raise ValueError( 181 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 182 | "Use --overwrite_output_dir to overwrite it." 
183 | ) 184 | else: 185 | if training_args.local_rank == 0: 186 | resume_from_checkpoint = training_args.resume_from_checkpoint if isinstance(training_args.resume_from_checkpoint, str) else get_last_checkpoint(training_args.output_dir) 187 | logger.warning( 188 | f"Checkpoint detected, resuming training at {resume_from_checkpoint}. To avoid this behavior, change " 189 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 190 | ) 191 | 192 | 193 | def get_90_percent_cpu_count(): 194 | return max(1, int(os.cpu_count() * 0.9)) 195 | 196 | 197 | def load_train_test_datasets(logger, data_args): 198 | logger.warning("Loading datasets") 199 | if data_args.test_dir is None: 200 | data_args.test_dir = data_args.train_dir 201 | train_split = f"train[{data_args.validation_split_percentage}%:]" 202 | test_split = f"train[:{data_args.validation_split_percentage}%]" 203 | else: 204 | train_split = "train" 205 | test_split = "train" 206 | 207 | train_dataset = load_dataset( 208 | "arrow", 209 | data_dir=data_args.train_dir, 210 | split=train_split, 211 | cache_dir=data_args.data_cache_dir, 212 | streaming=data_args.streaming, 213 | ) 214 | 215 | test_dataset = load_dataset( 216 | "arrow", 217 | data_dir=data_args.test_dir, 218 | split=test_split, 219 | cache_dir=data_args.data_cache_dir, 220 | streaming=data_args.streaming, 221 | ) 222 | 223 | if data_args.max_eval_samples is not None: 224 | test_dataset = test_dataset.select( 225 | range(min(test_dataset.shape[0], data_args.max_eval_samples)) 226 | ) 227 | if data_args.max_train_samples is not None: 228 | train_dataset = train_dataset.select( 229 | range(min(train_dataset.shape[0], data_args.max_train_samples)) 230 | ) 231 | 232 | if not data_args.streaming: 233 | total_bursts_train = [0] * len(train_dataset) 234 | total_bursts_test = [0] * len(test_dataset) 235 | else: 236 | total_bursts_train = defaultdict(lambda: 0) 237 | total_bursts_test = defaultdict(lambda: 0) 238 | 239 | train_dataset = train_dataset.add_column("total_bursts", total_bursts_train) 240 | test_dataset = test_dataset.add_column("total_bursts", total_bursts_test) 241 | 242 | 243 | return train_dataset, test_dataset 244 | 245 | 246 | def initialize_model_with_deepspeed(logger, training_args, get_model): 247 | ''' 248 | here we do only specific init if stage 3 is used, otherwise huggingface trainer will do the rest 249 | ''' 250 | import deepspeed 251 | import base64 252 | logger.warning("Initializing deepspeed-optimized model") 253 | # only if stage 3 254 | if training_args.deepspeed.endswith(".json"): 255 | with open(training_args.deepspeed, "r") as f: 256 | deepspeed_config = json.load(f) 257 | else: 258 | deepspeed_config = training_args.deepspeed 259 | # unbase64 260 | deepspeed_config = json.loads(base64.b64decode(deepspeed_config).decode("utf-8")) 261 | 262 | is_stage_3 = deepspeed_config.get("zero_optimization", {}).get("stage", 0) == 3 263 | with deepspeed.zero.Init(enabled=is_stage_3): 264 | model = get_model() 265 | optimizers = (None, None) 266 | return model, optimizers 267 | 268 | 269 | def init_tbwriter(output_dir=".") -> None: 270 | global TB_WRITER 271 | current_time = time.strftime("%b%d_%H-%M-%S", time.localtime()) 272 | if not torch.cuda.is_available(): 273 | TB_WRITER = SummaryWriter(os.path.join(output_dir, "runs", current_time + "_" + socket.gethostname() + f"_pid{os.getpid()}_custom_metrics")) 274 | return 275 | TB_WRITER = SummaryWriter(os.path.join(output_dir, "runs", current_time + "_" + socket.gethostname() + 
f"_gpu{torch.cuda.current_device()}_custom_metrics")) 276 | 277 | def get_gpu_utilization(gpu_id): 278 | """Fetch GPU utilization using nvidia-smi for the given GPU.""" 279 | try: 280 | result = subprocess.run( 281 | ["nvidia-smi", f"--query-gpu=utilization.gpu", "--format=csv,noheader,nounits", f"--id={gpu_id}"], 282 | stdout=subprocess.PIPE, 283 | stderr=subprocess.PIPE, 284 | text=True 285 | ) 286 | utilization = int(result.stdout.strip()) 287 | return utilization 288 | except Exception as e: 289 | get_logger(__name__).error(f"Error fetching GPU utilization: {e}") 290 | return 0 291 | 292 | 293 | def log_gpu_stats(gpu_id, output_dir, interval=10): 294 | """ 295 | Log GPU utilization and memory usage for the assigned GPU to TensorBoard every `interval` seconds. 296 | """ 297 | if not torch.cuda.is_available(): 298 | get_logger(__name__).error("No GPU found.") 299 | return 300 | 301 | current_time = time.strftime("%b%d_%H-%M-%S", time.localtime()) 302 | writer = SummaryWriter(os.path.join(output_dir, "runs", current_time + "_" + socket.gethostname() + f"_gpu{gpu_id}")) 303 | 304 | while True: 305 | # Get GPU stats for the current process's assigned GPU 306 | device = torch.device(f"cuda:{gpu_id}") 307 | memory_allocated = torch.cuda.memory_allocated(device) / (1024 ** 3) # In GB 308 | memory_reserved = torch.cuda.memory_reserved(device) / (1024 ** 3) # In MB 309 | memory_free = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) - memory_reserved 310 | 311 | # Get GPU utilization using nvidia-smi 312 | utilization = get_gpu_utilization(gpu_id) 313 | 314 | # Log to TensorBoard 315 | writer.add_scalar(f"GPU/Memory Allocated (GB)", memory_allocated, time.time()) 316 | writer.add_scalar(f"GPU/Memory Reserved (GB)", memory_reserved, time.time()) 317 | writer.add_scalar(f"GPU/Memory Free (GB)", memory_free, time.time()) 318 | writer.add_scalar(f"GPU/Utilization (%)", utilization, time.time()) 319 | 320 | # Sleep before logging the next set of stats 321 | time.sleep(interval) 322 | 323 | def start_gpu_logging(output_dir="."): 324 | """ 325 | Start logging GPU stats to TensorBoard for the current process's assigned GPU. 326 | """ 327 | if not torch.cuda.is_available(): 328 | get_logger(__name__).error("No GPU found.") 329 | return 330 | 331 | gpu_id = torch.cuda.current_device() 332 | 333 | # Start logging GPU stats in a separate thread 334 | gpu_stats_thread = threading.Thread(target=log_gpu_stats, args=(gpu_id, output_dir)) 335 | gpu_stats_thread.daemon = True 336 | gpu_stats_thread.start() 337 | 338 | def log_cpu_stats(output_dir, interval=10): 339 | current_time = time.strftime("%b%d_%H-%M-%S", time.localtime()) 340 | writer = SummaryWriter(os.path.join(output_dir, "runs", current_time + "_" + socket.gethostname() + f"_cpu_metrics")) 341 | 342 | while True: 343 | try: 344 | cpu_load = psutil.cpu_percent(interval=None) 345 | writer.add_scalar(f"CPU/Utilization %", psutil.cpu_percent(interval=None), time.time()) 346 | except Exception as e: 347 | get_logger(__name__).error(f"Error fetching CPU utilization: {e}") 348 | return 0 349 | 350 | time.sleep(interval) 351 | 352 | def start_cpu_logging(output_dir="."): 353 | """ 354 | Start logging overall CPU stats to TensorBoard. 
355 | """ 356 | # do it only for a single process per node 357 | if os.environ.get("SLURM_LOCALID", "-1") != "0": 358 | return 359 | 360 | cpu_stats_thread = threading.Thread(target=log_cpu_stats, args=(output_dir,)) 361 | cpu_stats_thread.daemon = True 362 | cpu_stats_thread.start() 363 | 364 | def update_deepspeed_config(training_args): 365 | if training_args.deepspeed is not None and training_args.deepspeed.endswith(".json"): 366 | with open(training_args.deepspeed, "r") as f: 367 | training_args.deepspeed = json.load(f) 368 | if "tensorboard" in training_args.deepspeed: 369 | training_args.deepspeed["tensorboard"]["output_path"] = training_args.output_dir 370 | training_args.deepspeed["tensorboard"]["job_name"] = os.environ.get("SLURM_JOB_NAME", "local") 371 | return training_args 372 | 373 | class LearningRateLogCallback(TrainerCallback): 374 | def __init__(self, tb_writer): 375 | self.tb_writer = tb_writer 376 | 377 | def on_step_end(self, args, state, control, **kwargs): 378 | # The optimizer is passed as a keyword argument 379 | optimizer = kwargs.get('optimizer') 380 | if optimizer is not None: 381 | # If you have multiple parameter groups, you can log each group’s LR 382 | for i, param_group in enumerate(optimizer.param_groups): 383 | self.tb_writer.add_scalar(f"train/learning_rate/group_{i}", param_group['lr'], state.global_step) 384 | return control -------------------------------------------------------------------------------- /src/pre_process/packets_processing_src/3_field_extraction/3_field_extraction.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | uint8_t getTcpFields(const pcpp::tcphdr *tcpheader) { 14 | uint8_t flags = 0; 15 | if (tcpheader->cwrFlag) { 16 | flags |= 1 << 7; 17 | } 18 | if (tcpheader->eceFlag) { 19 | flags |= 1 << 6; 20 | } 21 | if (tcpheader->urgFlag) { 22 | flags |= 1 << 5; 23 | } 24 | if (tcpheader->ackFlag) { 25 | flags |= 1 << 4; 26 | } 27 | if (tcpheader->pshFlag) { 28 | flags |= 1 << 3; 29 | } 30 | if (tcpheader->rstFlag) { 31 | flags |= 1 << 2; 32 | } 33 | if (tcpheader->synFlag) { 34 | flags |= 1 << 1; 35 | } 36 | if (tcpheader->finFlag) { 37 | flags |= 1; 38 | } 39 | return flags; 40 | } 41 | 42 | int process_file(const std::string input_file, const std::string output_file, bool tcpoptions_flag) { 43 | // Open the input pcap file 44 | auto *reader = pcpp::IFileReaderDevice::getReader(input_file); 45 | if (reader == nullptr) { 46 | std::cerr << "Cannot determine reader for file: " << input_file << std::endl; 47 | return 1; 48 | } 49 | if (!reader->open()) { 50 | std::cerr << "Error opening input pcap file: " << input_file << std::endl; 51 | return 1; 52 | } 53 | 54 | std::ofstream outputFileStream; 55 | 56 | pcpp::RawPacket rawPacket; 57 | uint64_t packetCount = 0; 58 | uint8_t global_protocol; 59 | uint32_t absolute_seq_src_ip = 0; 60 | uint32_t tcp_absolute_seq = 0; 61 | uint32_t tcp_absolute_ack = 0; 62 | 63 | while (reader->getNextPacket(rawPacket)) { 64 | pcpp::Packet parsedPacket(&rawPacket); 65 | uint8_t *data; 66 | size_t data_len; 67 | 68 | uint8_t ipversion = 4; 69 | 70 | // Here you would extract each field as per your requirements and write to outputFile 71 | // Example for IP and TCP fields: 72 | 73 | pcpp::IPv4Layer *ipLayer = nullptr; 74 | pcpp::IPv6Layer *ipv6Layer = nullptr; 75 | 76 | ipLayer = parsedPacket.getLayerOfType(); 77 | if 
(ipLayer == nullptr) { 78 | ipversion = 6; 79 | ipv6Layer = parsedPacket.getLayerOfType(); 80 | if (ipv6Layer == nullptr) { 81 | std::cerr << "File " << input_file << " contains packets with unknown network layer protocol: " << std::to_string(ipversion) << std::endl; 82 | return 1; 83 | } 84 | } 85 | 86 | // getting protocol number 87 | uint8_t protocol; 88 | if (ipversion == 4) { 89 | protocol = ipLayer->getIPv4Header()->protocol; 90 | } else { 91 | protocol = ipv6Layer->getIPv6Header()->nextHeader; 92 | } 93 | 94 | if (packetCount == 0) { 95 | global_protocol = protocol; 96 | 97 | // rename output file to . 98 | std::string outputFilename(output_file); 99 | if (tcpoptions_flag && protocol == pcpp::IPProtocolTypes::PACKETPP_IPPROTO_TCP) { 100 | outputFilename += ".tcpoptions"; 101 | } 102 | outputFilename += "." + std::to_string(protocol); 103 | 104 | // Create and open the output file 105 | outputFileStream.open(outputFilename, std::ios::binary | std::ios::out); 106 | if (!outputFileStream.is_open()) { 107 | std::cerr << "Error opening output file: " << outputFilename << std::endl; 108 | return 1; 109 | } 110 | 111 | outputFileStream.write(reinterpret_cast(&protocol), sizeof(protocol)); 112 | } 113 | 114 | if (protocol != global_protocol) { 115 | std::cerr << "File " << input_file << " contains packets with different protocols. " << 116 | "Protocol of the first packet: " << std::to_string(global_protocol) << ", current packet number: " << packetCount << 117 | ", current protocol: " << std::to_string(protocol) << std::endl; 118 | return 1; 119 | } 120 | 121 | // we are guaranteed to have ipLayer 122 | 123 | // unixtime with nanoseconds, frame.time_epoch 124 | uint64_t epoch = rawPacket.getPacketTimeStamp().tv_sec * static_cast(1000000000L) + rawPacket.getPacketTimeStamp().tv_nsec; 125 | 126 | uint8_t ip_hdr_len; 127 | if (ipversion == 4) { 128 | // value * 32bits to find out number of bytes in the header, ip.hdr_len 129 | constexpr uint8_t IHL_INCREMENTS_BYTES = 4; 130 | ip_hdr_len = ipLayer->getIPv4Header()->internetHeaderLength * IHL_INCREMENTS_BYTES; 131 | } else { 132 | ip_hdr_len = 40; 133 | } 134 | 135 | uint8_t type_of_service; 136 | if (ipversion == 4) { 137 | // ip.dsfield 138 | type_of_service = ipLayer->getIPv4Header()->typeOfService; 139 | } else { 140 | type_of_service = ipv6Layer->getIPv6Header()->trafficClass; 141 | } 142 | 143 | // total length of the packet, ip.len, big endian -> little endian 144 | uint16_t total_length; 145 | if (ipversion == 4) { 146 | total_length = ipLayer->getIPv4Header()->totalLength; 147 | } else { 148 | total_length = ipv6Layer->getIPv6Header()->payloadLength; 149 | } 150 | total_length = ((total_length & 0xff00) >> 8) | ((total_length & 0x00ff) << 8); 151 | if (ipversion == 6) { 152 | // ipv6 includes only payload length 153 | total_length += 40; 154 | } 155 | 156 | // ip.flags 157 | uint8_t flags; 158 | if (ipversion == 4) { 159 | flags = ipLayer->getFragmentFlags(); 160 | //aligning the 3bit flags 161 | flags = flags >> 5; 162 | } else { 163 | flags = 0; 164 | } 165 | 166 | // ip.ttl 167 | uint8_t ttl; 168 | if (ipversion == 4) { 169 | ttl = ipLayer->getIPv4Header()->timeToLive; 170 | } else { 171 | ttl = ipv6Layer->getIPv6Header()->hopLimit; 172 | } 173 | 174 | // ip.src, big endian -> little endian 175 | uint32_t src_ip; 176 | if (ipversion == 4) { 177 | src_ip = ipLayer->getSrcIPv4Address().toInt(); 178 | src_ip = ((src_ip & 0xff000000) >> 24) | ((src_ip & 0x00ff0000) >> 8) | ((src_ip & 0x0000ff00) << 8) | ((src_ip & 0x000000ff) << 24); 179 
| } else { 180 | // legacy - only have 32bit for address - let's get first 32bits of ipv6 181 | uint8_t *src_ip_ptr = ipv6Layer->getIPv6Header()->ipSrc; 182 | src_ip = src_ip_ptr[0] << 24 | src_ip_ptr[1] << 16 | src_ip_ptr[2] << 8 | src_ip_ptr[3]; 183 | } 184 | if (packetCount == 0) { 185 | absolute_seq_src_ip = src_ip; 186 | } 187 | 188 | // ip.dst, big endian -> little endian 189 | uint32_t dst_ip; 190 | if (ipversion == 4) { 191 | dst_ip = ipLayer->getDstIPv4Address().toInt(); 192 | dst_ip = ((dst_ip & 0xff000000) >> 24) | ((dst_ip & 0x00ff0000) >> 8) | ((dst_ip & 0x0000ff00) << 8) | ((dst_ip & 0x000000ff) << 24); 193 | } else { 194 | // legacy - only have 32bit for address - let's get first 32bits of ipv6 195 | uint8_t *dst_ip_ptr = ipv6Layer->getIPv6Header()->ipDst; 196 | dst_ip = dst_ip_ptr[0] << 24 | dst_ip_ptr[1] << 16 | dst_ip_ptr[2] << 8 | dst_ip_ptr[3]; 197 | } 198 | 199 | outputFileStream.write(reinterpret_cast(&epoch), sizeof(epoch)); 200 | outputFileStream.write(reinterpret_cast(&ip_hdr_len), sizeof(ip_hdr_len)); 201 | outputFileStream.write(reinterpret_cast(&type_of_service), sizeof(type_of_service)); 202 | outputFileStream.write(reinterpret_cast(&total_length), sizeof(total_length)); 203 | outputFileStream.write(reinterpret_cast(&flags), sizeof(flags)); 204 | outputFileStream.write(reinterpret_cast(&ttl), sizeof(ttl)); 205 | outputFileStream.write(reinterpret_cast(&src_ip), sizeof(src_ip)); 206 | outputFileStream.write(reinterpret_cast(&dst_ip), sizeof(dst_ip)); 207 | 208 | // based on protocol number, we can determine which layer to get 209 | if (protocol == pcpp::IPProtocolTypes::PACKETPP_IPPROTO_TCP) { 210 | auto *tcpLayer = parsedPacket.getLayerOfType(); 211 | if (tcpLayer == nullptr) { 212 | std::cerr << "File " << input_file << " contains packets with unknown transport layer protocol during TCP parsing: " << std::to_string(protocol) << std::endl; 213 | return 1; 214 | } 215 | 216 | // tcp.srcport, big endian -> little endian 217 | uint16_t src_port = tcpLayer->getTcpHeader()->portSrc; 218 | src_port = ((src_port & 0xff00) >> 8) | ((src_port & 0x00ff) << 8); 219 | 220 | // tcp.dstport, big endian -> little endian 221 | uint16_t dst_port = tcpLayer->getTcpHeader()->portDst; 222 | dst_port = ((dst_port & 0xff00) >> 8) | ((dst_port & 0x00ff) << 8); 223 | 224 | // tcp.flags 225 | uint8_t tcp_flags = getTcpFields(tcpLayer->getTcpHeader()); 226 | 227 | // tcp.window_size, big endian -> little endian 228 | uint16_t tcp_window_size = tcpLayer->getTcpHeader()->windowSize; 229 | tcp_window_size = ((tcp_window_size & 0xff00) >> 8) | ((tcp_window_size & 0x00ff) << 8); 230 | 231 | // tcp.seq, big endian -> little endian 232 | uint32_t tcp_seq = tcpLayer->getTcpHeader()->sequenceNumber; 233 | tcp_seq = ((tcp_seq & 0xff000000) >> 24) | ((tcp_seq & 0x00ff0000) >> 8) | ((tcp_seq & 0x0000ff00) << 8) | ((tcp_seq & 0x000000ff) << 24); 234 | 235 | // tcp.ack, big endian -> little endian 236 | uint32_t tcp_ack = tcpLayer->getTcpHeader()->ackNumber; 237 | tcp_ack = ((tcp_ack & 0xff000000) >> 24) | ((tcp_ack & 0x00ff0000) >> 8) | ((tcp_ack & 0x0000ff00) << 8) | ((tcp_ack & 0x000000ff) << 24); 238 | 239 | if (packetCount == 0) { 240 | tcp_absolute_seq = tcp_seq; 241 | } 242 | if (packetCount == 0 && tcp_ack!=0) { 243 | //this would be the case where the session the capture was started midway 244 | tcp_absolute_ack = tcp_ack; 245 | } else if (packetCount == 1 && tcp_absolute_ack == 0) { 246 | //this should be the ideal case where the 2nd packet is ack and the first packet had 0 as ack, the ack 
in response is present in the seq number 247 | tcp_absolute_ack = tcp_seq; 248 | } 249 | 250 | if (src_ip == absolute_seq_src_ip) { 251 | tcp_seq -= tcp_absolute_seq; 252 | } else { 253 | tcp_seq -= tcp_absolute_ack; 254 | } 255 | 256 | if (src_ip == absolute_seq_src_ip) { 257 | if (tcp_absolute_ack == 0){ 258 | // this will be 0 only when the ack is 0 in the first packet 259 | tcp_ack = 0; 260 | } else { 261 | tcp_ack -= tcp_absolute_ack; 262 | } 263 | } else { 264 | tcp_ack -= tcp_absolute_seq; 265 | } 266 | 267 | // tcp.urgent_pointer, big endian -> little endian 268 | uint16_t tcp_urgent_pointer = tcpLayer->getTcpHeader()->urgentPointer; 269 | tcp_urgent_pointer = ((tcp_urgent_pointer & 0xff00) >> 8) | ((tcp_urgent_pointer & 0x00ff) << 8); 270 | 271 | outputFileStream.write(reinterpret_cast(&src_port), sizeof(src_port)); 272 | outputFileStream.write(reinterpret_cast(&dst_port), sizeof(dst_port)); 273 | outputFileStream.write(reinterpret_cast(&tcp_flags), sizeof(tcp_flags)); 274 | outputFileStream.write(reinterpret_cast(&tcp_window_size), sizeof(tcp_window_size)); 275 | outputFileStream.write(reinterpret_cast(&tcp_seq), sizeof(tcp_seq)); 276 | outputFileStream.write(reinterpret_cast(&tcp_ack), sizeof(tcp_ack)); 277 | outputFileStream.write(reinterpret_cast(&tcp_urgent_pointer), sizeof(tcp_urgent_pointer)); 278 | 279 | if (tcpoptions_flag) { 280 | uint16_t total_options_len = tcpLayer->getTcpHeader()->dataOffset * 4 - 20; 281 | // copy data from tcpLayer->getData()[20] (beginning of tcpoptions) to tcpLayer->getData()[20 + total_options_len] 282 | std::vector tcp_options(total_options_len); 283 | for (int i = 0; i < total_options_len; i++) { 284 | tcp_options[i] = tcpLayer->getData()[20 + i]; 285 | } 286 | 287 | // pad with zeroes to 40 bytes 288 | if (tcp_options.size() < 40) { 289 | tcp_options.resize(40, 0); 290 | } 291 | for (int i = 0; i < 40; i++) { 292 | outputFileStream.write(reinterpret_cast(&tcp_options[i]), sizeof(tcp_options[i])); 293 | } 294 | } 295 | 296 | data = tcpLayer->getLayerPayload(); 297 | data_len = tcpLayer->getLayerPayloadSize(); 298 | } else if (protocol == pcpp::IPProtocolTypes::PACKETPP_IPPROTO_UDP) { 299 | auto *udpLayer = parsedPacket.getLayerOfType(); 300 | if (udpLayer == nullptr) { 301 | std::cerr << "File " << input_file << " contains packets with unknown transport layer protocol during UDP parsing: " << std::to_string(protocol) << std::endl; 302 | return 1; 303 | } 304 | 305 | // udp.srcport, big endian -> little endian 306 | uint16_t src_port = udpLayer->getUdpHeader()->portSrc; 307 | src_port = ((src_port & 0xff00) >> 8) | ((src_port & 0x00ff) << 8); 308 | 309 | // udp.dstport, big endian -> little endian 310 | uint16_t dst_port = udpLayer->getUdpHeader()->portDst; 311 | dst_port = ((dst_port & 0xff00) >> 8) | ((dst_port & 0x00ff) << 8); 312 | 313 | // udp.length, big endian -> little endian 314 | uint16_t udp_length = udpLayer->getUdpHeader()->length; 315 | udp_length = ((udp_length & 0xff00) >> 8) | ((udp_length & 0x00ff) << 8); 316 | 317 | outputFileStream.write(reinterpret_cast(&src_port), sizeof(src_port)); 318 | outputFileStream.write(reinterpret_cast(&dst_port), sizeof(dst_port)); 319 | outputFileStream.write(reinterpret_cast(&udp_length), sizeof(udp_length)); 320 | 321 | data = udpLayer->getLayerPayload(); 322 | data_len = udpLayer->getLayerPayloadSize(); 323 | } else if (protocol == pcpp::IPProtocolTypes::PACKETPP_IPPROTO_ICMP) { 324 | auto *icmpLayer = parsedPacket.getLayerOfType(); 325 | if (icmpLayer == nullptr) { 326 | std::cerr << "File " 
<< input_file << " contains packets with unknown transport layer protocol during ICMP parsing: " << std::to_string(protocol) << std::endl; 327 | return 1; 328 | } 329 | 330 | uint8_t icmp_type = icmpLayer->getIcmpHeader()->type; // icmp.type 331 | uint8_t icmp_code = icmpLayer->getIcmpHeader()->code; // icmp.code 332 | 333 | outputFileStream.write(reinterpret_cast(&icmp_type), sizeof(icmp_type)); 334 | outputFileStream.write(reinterpret_cast(&icmp_code), sizeof(icmp_code)); 335 | 336 | if (icmpLayer->isMessageOfType(pcpp::ICMP_ECHO_REQUEST)) { 337 | data = icmpLayer->getEchoRequestData()->data; 338 | data_len = icmpLayer->getEchoRequestData()->dataLength; 339 | } else if (icmpLayer->isMessageOfType(pcpp::ICMP_ECHO_REPLY)) { 340 | data = icmpLayer->getEchoReplyData()->data; 341 | data_len = icmpLayer->getEchoReplyData()->dataLength; 342 | } else { 343 | data = icmpLayer->getData(); 344 | data_len = icmpLayer->getDataLen(); 345 | } 346 | } else { 347 | std::cerr << "File " << input_file << " contains packets with unknown transport layer protocol: " << std::to_string(protocol) << std::endl; 348 | return 1; 349 | } 350 | 351 | // data: get first 12 bytes of payload unless it's smaller, pad with zeros if data_len < 12 352 | size_t payload_len = std::min(data_len, static_cast(12)); 353 | for (int i = 0; i < payload_len; i++) { 354 | outputFileStream.write(reinterpret_cast(&data[i]), sizeof(data[i])); 355 | } 356 | for (size_t i = payload_len; i < 12; i++) { 357 | uint8_t zero = 0; 358 | outputFileStream.write(reinterpret_cast(&zero), sizeof(zero)); 359 | } 360 | 361 | packetCount++; 362 | } 363 | 364 | outputFileStream.close(); 365 | reader->close(); 366 | 367 | return 0; 368 | } 369 | 370 | int main(int argc, char *argv[]) { 371 | if (argc < 3 || argc > 4) { 372 | std::cout << "Usage: " << argv[0] << " [tcpoptions_flag]" << std::endl; 373 | return 1; 374 | } 375 | 376 | std::string input_folder(argv[1]); 377 | std::string output_folder(argv[2]); 378 | bool tcpoptions_flag = false; 379 | if (argc == 4) { 380 | tcpoptions_flag = std::stoi(argv[3]); 381 | } 382 | 383 | // for each file in the input folder, process it and write to the output folder 384 | for (const auto &entry : std::filesystem::directory_iterator(input_folder)) { 385 | std::string input_file = entry.path().string(); 386 | std::string output_file = output_folder + "/" + entry.path().filename().string(); 387 | process_file(input_file, output_file, tcpoptions_flag); 388 | } 389 | 390 | return 0; 391 | } 392 | -------------------------------------------------------------------------------- /src/pre_process/Tokenize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from os.path import join 5 | from multiprocessing import Pool 6 | from typing import Optional, List, Tuple 7 | 8 | import pandas as pd 9 | import numpy as np 10 | import socket 11 | from enum import Enum 12 | 13 | import pyarrow as pa 14 | 15 | from argparse import ArgumentParser, Namespace 16 | 17 | 18 | class Protocol(Enum): 19 | TCP = 6 20 | UDP = 17 21 | ICMP = 1 22 | 23 | 24 | PROTOCOL_FIELDS = { 25 | Protocol.TCP: "TCPFields", 26 | Protocol.UDP: "UDPFields", 27 | Protocol.ICMP: "ICMPFields", 28 | } 29 | 30 | FLOW_SIZE_LIMIT = 12 31 | BURST_SIZE_LIMIT = 6 32 | STRICT_IP_CHECK = False 33 | MAX_CLASS_SIZE = 10000 34 | BURST_SPLIT_BORDER = 10000000 35 | INTERNAL_IPS = ["127.0.0.1/8"] 36 | TOKEN_BYTES_LENGTH = 2 37 | TCP_OPTIONS = False 38 | 39 | ## Reserved label names 40 | 
FILENAME_SUBSTITUTION = "RESERVED_FILENAME" 41 | 42 | 43 | def get_protocol(file_name: str) -> Optional[Protocol]: 44 | try: 45 | return Protocol(int(file_name.split(".")[-1])) 46 | except ValueError: 47 | return None 48 | 49 | 50 | def tokenize_file(inpt_file, label) -> Optional[list]: 51 | try: 52 | bursts, flowDur = get_bursts_from_flow_file(inpt_file) 53 | 54 | protocol = get_protocol(inpt_file) 55 | if protocol is None: 56 | print(f"Error in file {inpt_file} : Protocol not found: {protocol}") 57 | return 58 | field_name = PROTOCOL_FIELDS.get(protocol) 59 | if field_name is None: 60 | print( 61 | f"Error in file {inpt_file} : fields are not defined for protocol {protocol}]" 62 | ) 63 | return 64 | 65 | payloadTokenNum = config["Payload"][0]["numberTokens"] 66 | tokensPerPacket = sum([field["numberTokens"] for field in config["IPFields"]]) 67 | tokensPerPacket += ( 68 | sum([field["numberTokens"] for field in config[field_name]]) 69 | + payloadTokenNum 70 | ) 71 | 72 | bursts = bursts[ 73 | bursts["burstID"] < FLOW_SIZE_LIMIT 74 | ] # leave only first FLOW_SIZE_LIMIT bursts 75 | bursts = ( 76 | bursts.groupby("burstID").head(BURST_SIZE_LIMIT).reset_index(drop=True) 77 | ) # leave only first BURST_SIZE_LIMIT packets in each burst 78 | 79 | grouped_bursts = bursts.groupby("burstID") 80 | grouped_bursts_stat = grouped_bursts.agg( 81 | { 82 | "IP_tl": "sum", 83 | "IAT": lambda x: int(round((x.iloc[0] / 1000))), 84 | "direction": "first", 85 | } 86 | ).sort_index() 87 | group_sizes = grouped_bursts.size().reset_index(name="packet_count") 88 | grouped_bursts_stat = grouped_bursts_stat.reset_index().merge( 89 | group_sizes, on="burstID", how="left" 90 | ) 91 | grouped_bursts_stat["total_tokens"] = ( 92 | tokensPerPacket * grouped_bursts_stat["packet_count"] 93 | ) 94 | 95 | number_of_bursts = len(grouped_bursts_stat) 96 | directionsls = [ 97 | grouped_bursts_stat["direction"][i] for i in range(number_of_bursts) 98 | ] 99 | bytels = [grouped_bursts_stat["IP_tl"][i] for i in range(number_of_bursts)] 100 | iatls = [grouped_bursts_stat["IAT"][i] for i in range(number_of_bursts)] 101 | countls = [ 102 | grouped_bursts_stat["packet_count"][i] for i in range(number_of_bursts) 103 | ] 104 | 105 | tokenize_fields_df(config, "IPFields", bursts) 106 | tokenize_fields_df(config, field_name, bursts) 107 | tokenize_fields_df(config, "Payload", bursts) 108 | 109 | columns = ( 110 | [x["field"] for x in config["IPFields"]] 111 | + [x["field"] for x in config[field_name]] 112 | + [x["field"] for x in config["Payload"]] 113 | ) 114 | 115 | bursts.loc[:, "concatenated"] = bursts[columns].sum(axis=1) 116 | grouped_bursts = bursts.groupby("burstID")["concatenated"].sum() 117 | 118 | if len(grouped_bursts) > 0: 119 | return [ 120 | flowDur, 121 | grouped_bursts, 122 | directionsls, 123 | bytels, 124 | iatls, 125 | countls, 126 | protocol.value, 127 | label, 128 | ] 129 | except Exception as e: 130 | print(f"Error in file {inpt_file} : {str(e)}") 131 | 132 | 133 | def get_int_from_byte(byte_vals: bytes, byte_order="little"): 134 | if byte_vals == b"": 135 | raise ValueError("byte_vals is empty") 136 | # noinspection PyTypeChecker 137 | return int.from_bytes(byte_vals, byteorder=byte_order) 138 | 139 | 140 | def get_bursts_from_flow_file(inpt_file): 141 | flow_rows = [] 142 | protocol = get_protocol(inpt_file) 143 | 144 | with open(inpt_file, mode="rb") as f: 145 | assert protocol.value == get_int_from_byte(f.read(1)) 146 | try: 147 | while True: 148 | ts = f.read(8) 149 | if ts == b"": 150 | break # end of file 151 
| currFlowRow = [ 152 | get_int_from_byte(ts), 153 | get_int_from_byte(f.read(1)), 154 | get_int_from_byte(f.read(1)), 155 | get_int_from_byte(f.read(2)), 156 | get_int_from_byte(f.read(1)), 157 | get_int_from_byte(f.read(1)), 158 | get_int_from_byte(f.read(4)), 159 | get_int_from_byte(f.read(4)), 160 | ] 161 | if protocol == Protocol.TCP: 162 | f.read(2) # srcport 163 | f.read(2) # dstport 164 | currFlowRow.extend( 165 | [ 166 | get_int_from_byte(f.read(1)), 167 | get_int_from_byte(f.read(2)), 168 | get_int_from_byte(f.read(4)), 169 | get_int_from_byte(f.read(4)), 170 | get_int_from_byte(f.read(2)), 171 | ] 172 | ) 173 | if TCP_OPTIONS: 174 | currFlowRow.extend([get_int_from_byte(f.read(40), byte_order="big")]) 175 | if protocol == Protocol.UDP: 176 | f.read(2) # srcport 177 | f.read(2) # dstport 178 | currFlowRow.append(get_int_from_byte(f.read(2))) 179 | if protocol == Protocol.ICMP: 180 | currFlowRow.extend( 181 | [ 182 | get_int_from_byte(f.read(1)), 183 | get_int_from_byte(f.read(1)), 184 | ] 185 | ) 186 | # currFlowRow.append(file.read(12)) # payload 187 | currFlowRow.append(get_int_from_byte(f.read(12), byte_order="big")) 188 | flow_rows.append(currFlowRow) 189 | except ValueError: 190 | print(f"Unexpected end of file occured for {inpt_file}") 191 | return None, 0 192 | 193 | if len(flow_rows) == 0: 194 | return None, 0 195 | 196 | columns = [ 197 | "rts", 198 | "IP_hl", 199 | "IP_tos", 200 | "IP_tl", 201 | "IP_Flags", 202 | "IP_ttl", 203 | "SrcIP", 204 | "DstIP", 205 | ] 206 | 207 | if protocol == Protocol.TCP: 208 | columns += [ 209 | "TCP_Flags", 210 | "TCP_wsize", 211 | "TCP_seq", 212 | "TCP_ackn", 213 | "TCP_urp", 214 | ] 215 | if TCP_OPTIONS: 216 | columns.append("TCP_options") 217 | columns.append("Payload") 218 | elif protocol == Protocol.UDP: 219 | columns += [ 220 | "UDP_len", 221 | "Payload", 222 | ] 223 | else: 224 | columns += [ 225 | "ICMP_type", 226 | "ICMP_code", 227 | "Payload", 228 | ] 229 | 230 | df = pd.DataFrame(flow_rows, columns=columns) 231 | 232 | df["rts"] = df["rts"].astype(int) 233 | df["rts"] -= df["rts"].min() 234 | flowDur = int((df["rts"].max() / 1000)) 235 | df = df.sort_values(by="rts", ignore_index=True, kind="stable") 236 | fwdDf, bkdDf = split_bursts_on_dir(df, inpt_file) 237 | if fwdDf is None and bkdDf is None: 238 | fwdDf = df 239 | combinedBurstLen = (0 if fwdDf is None else fwdDf.shape[0]) + ( 240 | 0 if bkdDf is None else bkdDf.shape[0] 241 | ) 242 | if df.shape[0] > combinedBurstLen: 243 | print( 244 | f"Original length : {df.shape[0]} but combined length : {combinedBurstLen}" 245 | ) 246 | fwdBursts = split_based_on_iat(fwdDf, starting_index=0) 247 | fwdBursts["direction"] = True 248 | starting_index = ( 249 | 0 250 | if (fwdBursts is None or fwdBursts.shape[0] == 0) 251 | else fwdBursts.burstID.max() + 1 252 | ) 253 | bkdBursts = split_based_on_iat(bkdDf, starting_index=starting_index) 254 | bkdBursts["direction"] = False 255 | bursts = ( 256 | (pd.concat([fwdBursts, bkdBursts]) if bkdBursts is not None else fwdBursts) 257 | .sort_values(by="first_packet_time", ignore_index=True, kind="stable") 258 | .drop(columns=["first_packet_time"]) 259 | ) 260 | bursts["burstID"] = bursts["burstID"].diff().ne(0).cumsum() - 1 261 | 262 | if bursts.shape[0] == 0: 263 | return None, 0 264 | if bursts is None: 265 | logger.info(f"No burst in file {inpt_file}") 266 | return None, 0 267 | 268 | return bursts, flowDur 269 | 270 | 271 | def split_bursts_on_dir(df, inputfile): 272 | ipSet = set(df["SrcIP"]).union(set(df["DstIP"])) 273 | 274 | if len(ipSet) 
> 2: 275 | ipsToPrint = ",".join([int_to_ip_address(ip) for ip in ipSet]) 276 | print(f"inputfile {inputfile} has flows IPs {ipsToPrint}") 277 | raise ValueError(f"inputfile {inputfile} has flows IPs {ipsToPrint}") 278 | srcIP = None 279 | for ip in ipSet: 280 | if is_internal_ip(ip): 281 | srcIP = ip 282 | break 283 | if srcIP is None: 284 | ipList = ",".join([int_to_ip_address(ip) for ip in ipSet]) 285 | if STRICT_IP_CHECK: 286 | logger.error(f"IPs {ipList} not in internalIPs for {inputfile}") 287 | return None, None 288 | srcIP = df.SrcIP[0] 289 | retBkd = df[df.SrcIP != srcIP] 290 | retFwd = df[df.SrcIP == srcIP] 291 | return retFwd, retBkd 292 | 293 | 294 | def convert_ip_str_to_bits(ip_with_subnet): 295 | split_vals = ip_with_subnet.split("/") 296 | ip = split_vals[0] 297 | if len(split_vals) == 1: 298 | subnetRange = "32" 299 | else: 300 | subnetRange = ip_with_subnet.split("/")[1] 301 | return ( 302 | "".join([bin(int(octet)).split("b")[1].zfill(8) for octet in ip.split(".")]) 303 | + "/" 304 | + subnetRange 305 | ) 306 | 307 | 308 | def int_to_ip_address(ip_int): 309 | # Convert integer to IP address 310 | return socket.inet_ntoa(int.to_bytes(ip_int, 4, "big")) 311 | 312 | 313 | def is_ip_in_range(ip, ip_range): 314 | subnetLength = int(ip_range.split("/")[1]) 315 | return ( 316 | ip[:subnetLength] 317 | == convert_ip_str_to_bits(ip_range.split("/")[0])[:subnetLength] 318 | ) 319 | 320 | 321 | def is_internal_ip(ip): 322 | ip = "{0:b}".format(ip).rjust(32, "0") 323 | return any([is_ip_in_range(ip, ipRange) for ipRange in INTERNAL_IPS]) 324 | 325 | 326 | def split_based_on_iat( 327 | df: pd.DataFrame, starting_index: int = 0 328 | ) -> Optional[pd.DataFrame]: 329 | if df is None: 330 | return None 331 | df = df.reset_index(drop=True) 332 | 333 | burstStartIdx = df[ 334 | (df.rts - df.rts.shift(1, axis=0, fill_value=(-1 * BURST_SPLIT_BORDER - 1))) 335 | > BURST_SPLIT_BORDER 336 | ].index 337 | if len(burstStartIdx) == 0: 338 | df["IAT"] = 0 339 | df["burstID"] = 0 340 | return df 341 | 342 | df["burstID"] = df.index.isin(burstStartIdx).cumsum() + starting_index 343 | 344 | first_packets = df.groupby("burstID")["rts"].first().reset_index() 345 | first_packets.rename(columns={"rts": "first_packet_time"}, inplace=True) 346 | 347 | first_packets["IAT"] = first_packets["first_packet_time"].diff().fillna(0) 348 | df = df.merge( 349 | first_packets[["burstID", "first_packet_time", "IAT"]], on="burstID", how="left" 350 | ) 351 | df.loc[df["burstID"] == starting_index + 1, "IAT"] = 0 352 | df["IAT"] = df["IAT"].astype(int) 353 | return df 354 | 355 | 356 | def tokenize_fields_df(_config, type_of_field, df): 357 | for ipConf in _config[type_of_field]: 358 | field = ipConf["field"] 359 | numberOfTokens = ipConf["numberTokens"] 360 | df[field] = df[field].apply( 361 | lambda val: val.to_bytes( 362 | numberOfTokens * TOKEN_BYTES_LENGTH, byteorder="big" 363 | ) 364 | ) 365 | 366 | 367 | def get_logger(logger_file: str) -> logging.Logger: 368 | _logger = logging.getLogger("TokenizerLog") 369 | _logger.setLevel(logging.DEBUG) 370 | fh = logging.FileHandler(logger_file) 371 | fh.setLevel(logging.DEBUG) 372 | _logger.addHandler(fh) 373 | return _logger 374 | 375 | 376 | def slice_bytes_to_16bit_tokens(burst_tokens: pd.Series) -> list[list[int]]: 377 | result = burst_tokens.apply( 378 | lambda x: [ 379 | int.from_bytes(x[i: i + TOKEN_BYTES_LENGTH], byteorder="big") 380 | for i in range(0, len(x), TOKEN_BYTES_LENGTH) 381 | ] 382 | ) 383 | 384 | return result 385 | 386 | 387 | def tokenizer_helper( 388 
| output_filename: str, 389 | tokenization_args: List[Tuple[str, Optional[str]]], 390 | batch_size: Optional[int] = 1000, 391 | ) -> None: 392 | flow_duration_type = pa.uint64() 393 | burst_tokens_type = pa.list_(pa.list_(pa.uint16())) 394 | directions_type = pa.list_(pa.bool_()) 395 | bytes_type = pa.list_(pa.uint32()) 396 | iats_type = pa.list_(pa.uint64()) 397 | counts_type = pa.list_(pa.uint32()) 398 | protocol_type = pa.uint16() 399 | label_type = pa.string() 400 | 401 | table_schema = pa.schema( 402 | [ 403 | pa.field("flow_duration", flow_duration_type), 404 | pa.field("burst_tokens", burst_tokens_type), 405 | pa.field("directions", directions_type), 406 | pa.field("bytes", bytes_type), 407 | pa.field("iats", iats_type), 408 | pa.field("counts", counts_type), 409 | pa.field("protocol", protocol_type), 410 | pa.field("labels", label_type), 411 | ] 412 | ) 413 | 414 | # flow_durations, burst_tokens, directions, bytes_ar, iats, counts, protocols, labels 415 | data = [[] for _ in range(8)] 416 | 417 | with pa.OSFile(output_filename, "wb") as sink: 418 | with pa.ipc.new_stream(sink, schema=table_schema) as writer: 419 | total_files = len(tokenization_args) 420 | for i, (inpt_file, label) in enumerate(tokenization_args, start=1): 421 | result = tokenize_file(inpt_file, label) 422 | 423 | if result is not None: 424 | result[1] = slice_bytes_to_16bit_tokens(result[1]) 425 | for j in range(8): 426 | data[j].append(result[j]) 427 | 428 | if i % batch_size == 0 or i == total_files: 429 | print(f"Processed {i} files") 430 | if data[1]: 431 | batch = pa.record_batch( 432 | [ 433 | pa.array(data[0], type=flow_duration_type), 434 | pa.array(data[1], type=burst_tokens_type), 435 | pa.array(data[2], type=directions_type), 436 | pa.array(data[3], type=bytes_type), 437 | pa.array(data[4], type=iats_type), 438 | pa.array(data[5], type=counts_type), 439 | pa.array(data[6], type=protocol_type), 440 | pa.array(data[7], type=label_type), 441 | ], 442 | schema=table_schema, 443 | ) 444 | writer.write_batch(batch) 445 | data = [[] for _ in range(8)] 446 | 447 | 448 | def get_args() -> Namespace: 449 | parser = ArgumentParser() 450 | parser.add_argument("--conf_file", type=str, required=True) 451 | parser.add_argument("--input_dir", type=str, default=None) 452 | parser.add_argument("--output_dir", type=str, default=None) 453 | parser.add_argument("--label", type=str, default=None) 454 | parser.add_argument("--logger_file", type=str, default="/tmp/tokenizer.log") 455 | parser.add_argument("--flow_size_limit", type=int, default=12) 456 | parser.add_argument("--burst_size_limit", type=int, default=6) 457 | parser.add_argument("--burst_split_border", type=float, default=10000000) 458 | parser.add_argument("--strict_ip_checking", type=bool, default=False) 459 | parser.add_argument("--max_class_size", type=int, default=10000) 460 | parser.add_argument("--cores", type=int, default=0) 461 | parser.add_argument("--arrow_batch_size", type=int, default=1000) 462 | return parser.parse_args() 463 | 464 | 465 | if __name__ == "__main__": 466 | script_args = get_args() 467 | 468 | FLOW_SIZE_LIMIT = script_args.flow_size_limit 469 | BURST_SIZE_LIMIT = script_args.burst_size_limit 470 | BURST_SPLIT_BORDER = script_args.burst_split_border 471 | STRICT_IP_CHECK = script_args.strict_ip_checking 472 | MAX_CLASS_SIZE = script_args.max_class_size 473 | 474 | logger = get_logger(script_args.logger_file) 475 | 476 | with open(script_args.conf_file, "r") as config_file: 477 | config = json.load(config_file) 478 | 479 | INTERNAL_IPS = 
config["internalIPs"] 480 | 481 | for field in config["TCPFields"]: 482 | if field["field"] == "TCP_options": 483 | TCP_OPTIONS = True 484 | 485 | if script_args.input_dir: 486 | input_dir = script_args.input_dir 487 | else: 488 | input_dir = config["input_dir"] 489 | 490 | if script_args.output_dir: 491 | output_dir = script_args.output_dir 492 | else: 493 | output_dir = config["output_dir"] 494 | 495 | args = [ 496 | ( 497 | join(input_dir, file), 498 | script_args.label if script_args.label != FILENAME_SUBSTITUTION else join(input_dir, file), 499 | ) 500 | for file 501 | in os.listdir(input_dir) 502 | ] 503 | 504 | # validation for tcp options 505 | tcpoptions_pattern = f"tcpoptions.{Protocol.TCP.value}" 506 | if TCP_OPTIONS: 507 | for filepath, _ in args: 508 | if filepath.endswith(f".{Protocol.TCP.value}") and not filepath.endswith(tcpoptions_pattern): 509 | raise ValueError(f"TCP options are enabled, but file {filepath} does not have tcp options") 510 | else: 511 | for filepath, _ in args: 512 | if filepath.endswith(tcpoptions_pattern): 513 | raise ValueError(f"TCP options are disabled, but file {filepath} has tcp options") 514 | 515 | print(f"Total files: {len(args)}") 516 | 517 | cores = script_args.cores 518 | if cores == 0: 519 | cores = min(os.cpu_count() - 2, len(args)) 520 | 521 | # split args to cores equal lists 522 | args = np.array_split(args, cores) 523 | input_args = ( 524 | (join(output_dir, f"shard.{i}.arrow"), args[i], script_args.arrow_batch_size) 525 | for i in range(cores) 526 | ) 527 | 528 | print(f"Started processing files, time: {pd.Timestamp.now()}") 529 | with Pool(cores) as p: 530 | p.starmap(tokenizer_helper, input_args) 531 | print(f"Finished processing files, time: {pd.Timestamp.now()}") 532 | -------------------------------------------------------------------------------- /src/train/NetFoundModels.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from utils import get_logger 4 | import utils 5 | import time 6 | 7 | import torch.nn 8 | from torch.nn import CrossEntropyLoss, MSELoss, L1Loss 9 | from transformers import PreTrainedModel 10 | import torch.nn as nn 11 | from transformers.activations import gelu 12 | from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput 13 | from transformers.models.roformer.modeling_roformer import ( 14 | RoFormerAttention, 15 | RoFormerIntermediate, 16 | RoFormerOutput, 17 | RoFormerEmbeddings, 18 | RoFormerSinusoidalPositionalEmbedding 19 | ) 20 | from transformers.models.roberta.modeling_roberta import ( 21 | RobertaAttention, 22 | RobertaIntermediate, 23 | RobertaOutput, 24 | RobertaEmbeddings, 25 | ) 26 | from transformers.utils import ModelOutput 27 | import copy 28 | from dataclasses import dataclass 29 | 30 | logger = get_logger(__name__) 31 | 32 | TORCH_IGNORE_INDEX = -100 33 | 34 | def transform_tokens2bursts(hidden_states, num_bursts, max_burst_length): 35 | # transform sequence into segments 36 | seg_hidden_states = torch.reshape( 37 | hidden_states, 38 | (hidden_states.size(0), num_bursts, max_burst_length, hidden_states.size(-1)), 39 | ) 40 | # squash segments into sequence into a single axis (samples * segments, max_segment_length, hidden_size) 41 | hidden_states_reshape = seg_hidden_states.contiguous().view( 42 | hidden_states.size(0) * num_bursts, max_burst_length, seg_hidden_states.size(-1) 43 | ) 44 | 45 | return hidden_states_reshape 46 | 47 | 48 | def transform_masks2bursts(hidden_states, 
num_bursts, max_burst_length): 49 | # transform sequence into segments 50 | seg_hidden_states = torch.reshape( 51 | hidden_states, (hidden_states.size(0), 1, 1, num_bursts, max_burst_length) 52 | ) 53 | # squash segments into sequence into a single axis (samples * segments, 1, 1, max_segment_length) 54 | hidden_states_reshape = seg_hidden_states.contiguous().view( 55 | hidden_states.size(0) * num_bursts, 1, 1, seg_hidden_states.size(-1) 56 | ) 57 | 58 | return hidden_states_reshape 59 | 60 | 61 | def transform_bursts2tokens(seg_hidden_states, num_bursts, max_burst_length): 62 | # transform squashed sequence into segments 63 | hidden_states = seg_hidden_states.contiguous().view( 64 | seg_hidden_states.size(0) // num_bursts, 65 | num_bursts, 66 | max_burst_length, 67 | seg_hidden_states.size(-1), 68 | ) 69 | # transform segments into sequence 70 | hidden_states = hidden_states.contiguous().view( 71 | hidden_states.size(0), num_bursts * max_burst_length, hidden_states.size(-1) 72 | ) 73 | return hidden_states 74 | 75 | 76 | class TransformerLayer(nn.Module): 77 | def __init__(self, config): 78 | super().__init__() 79 | self.roformer = config.roformer 80 | self.chunk_size_feed_forward = config.chunk_size_feed_forward 81 | self.seq_len_dim = 1 82 | self.attention = RoFormerAttention(config) if self.roformer else RobertaAttention(config) 83 | self.is_decoder = config.is_decoder 84 | self.intermediate = RoFormerIntermediate(config) if self.roformer else RobertaIntermediate(config) 85 | self.output = RoFormerOutput(config) if self.roformer else RobertaOutput(config) 86 | 87 | def forward( 88 | self, 89 | hidden_states, 90 | attention_mask=None, 91 | output_attentions=False, 92 | seqNo = None 93 | ): 94 | if not self.roformer: 95 | self_attention_outputs = self.attention( 96 | hidden_states, 97 | attention_mask, 98 | output_attentions=output_attentions, 99 | ) 100 | else: 101 | self_attention_outputs = self.attention( 102 | hidden_states, 103 | attention_mask, 104 | sinusoidal_pos = seqNo, 105 | output_attentions=output_attentions, 106 | ) 107 | attention_output = self_attention_outputs[0] 108 | 109 | outputs = self_attention_outputs[1:] 110 | 111 | intermediate_output = self.intermediate(attention_output) 112 | layer_output = self.output(intermediate_output, attention_output) 113 | outputs = (layer_output,) + outputs 114 | 115 | return outputs 116 | 117 | 118 | class NetFoundEncoder(nn.Module): 119 | def __init__(self, config): 120 | super().__init__() 121 | self.config = config 122 | self.roformer = config.roformer 123 | self.layer = nn.ModuleList( 124 | [NetFoundLayer(config) if not self.config.flat else NetFoundLayerFlat(config) for idx in range(config.num_hidden_layers)] 125 | ) 126 | self.burst_positions = RoFormerSinusoidalPositionalEmbedding( 127 | config.max_position_embeddings, config.hidden_size // config.num_attention_heads 128 | ) 129 | self.flow_positions = RoFormerSinusoidalPositionalEmbedding( 130 | config.max_bursts + 1, config.hidden_size // config.num_attention_heads 131 | ) 132 | self.gradient_checkpointing = False 133 | 134 | def forward( 135 | self, 136 | hidden_states, 137 | attention_mask=None, 138 | num_bursts=None, 139 | output_attentions=False, 140 | output_hidden_states=False, 141 | return_dict=True, 142 | ): 143 | all_hidden_states = () if output_hidden_states else None 144 | all_self_attentions = () if output_attentions else None 145 | all_burst_attentions = () if output_attentions else None 146 | 147 | burst_seqs = transform_tokens2bursts( 148 | hidden_states, 
num_bursts=num_bursts, max_burst_length=self.config.max_burst_length 149 | ) 150 | past_key_values_length = 0 151 | burstSeqNo = self.burst_positions(burst_seqs.shape[:-1], past_key_values_length)[None, None, :, :] 152 | flow_seqs = transform_bursts2tokens( 153 | burst_seqs, 154 | num_bursts=num_bursts, 155 | max_burst_length=self.config.max_burst_length, 156 | )[:, :: self.config.max_burst_length] 157 | flowSeqNo = self.flow_positions(flow_seqs.shape[:-1], past_key_values_length)[None, None, :, :] 158 | 159 | for i, layer_module in enumerate(self.layer): 160 | if output_hidden_states: 161 | all_hidden_states = all_hidden_states + (hidden_states,) 162 | 163 | if self.gradient_checkpointing and self.training: 164 | 165 | def create_custom_forward(module): 166 | def custom_forward(*inputs): 167 | return module(*inputs, output_attentions, burstSeqNo, flowSeqNo) 168 | 169 | return custom_forward 170 | 171 | layer_outputs = torch.utils.checkpoint.checkpoint( 172 | create_custom_forward(layer_module), 173 | hidden_states, 174 | attention_mask, 175 | ) 176 | else: 177 | layer_outputs = layer_module( 178 | hidden_states, attention_mask, num_bursts, output_attentions, burstSeqNo, flowSeqNo 179 | ) 180 | 181 | hidden_states = layer_outputs[0] 182 | if output_attentions: 183 | all_self_attentions = all_self_attentions + (layer_outputs[1],) 184 | all_burst_attentions = all_burst_attentions + (layer_outputs[2],) 185 | else: 186 | all_self_attentions = None 187 | all_burst_attentions = None 188 | if output_hidden_states: 189 | all_hidden_states = all_hidden_states + (hidden_states,) 190 | if not return_dict: 191 | return tuple( 192 | v 193 | for v in [ 194 | hidden_states, 195 | all_hidden_states, 196 | all_self_attentions, 197 | all_burst_attentions, 198 | ] 199 | if v is not None 200 | ) 201 | return BaseModelOutputWithFlowAttentions( 202 | last_hidden_state=hidden_states, 203 | hidden_states=all_hidden_states, 204 | attentions=all_self_attentions, 205 | flow_attentions=all_burst_attentions, 206 | ) 207 | 208 | def _tie_weights(self): 209 | original_position_embeddings = None 210 | for module in self.layer: 211 | if hasattr(module, "position_embeddings"): 212 | assert hasattr(module.position_embeddings, "weight") 213 | if original_position_embeddings is None: 214 | original_position_embeddings = module.position_embeddings 215 | if self.config.torchscript: 216 | module.position_embeddings.weight = nn.Parameter( 217 | original_position_embeddings.weight.clone() 218 | ) 219 | else: 220 | module.position_embeddings.weight = ( 221 | original_position_embeddings.weight 222 | ) 223 | return 224 | 225 | class NetFoundEmbeddingsWithMeta: 226 | def __init__(self, config): 227 | self.metaEmbeddingLayer1 = nn.Linear(config.metaFeatures, 1024) 228 | self.metaEmbeddingLayer2 = nn.Linear(1024, config.hidden_size) 229 | self.no_meta = config.no_meta 230 | self.protoEmbedding = nn.Embedding(65536, config.hidden_size) 231 | self.compressEmbeddings = nn.Linear(config.hidden_size*3, config.hidden_size) 232 | 233 | def addMetaEmbeddings(self, 234 | embeddings, 235 | direction=None, 236 | iats=None, 237 | bytes=None, 238 | pkt_count=None, 239 | protocol=None): 240 | linearLayerDtype = self.metaEmbeddingLayer1.weight.dtype 241 | if not self.no_meta: 242 | metaEmbeddings = self.metaEmbeddingLayer2( 243 | self.metaEmbeddingLayer1( 244 | torch.concat( 245 | [ 246 | direction.unsqueeze(2).to(linearLayerDtype), 247 | bytes.unsqueeze(2).to(linearLayerDtype) / 1000, 248 | pkt_count.unsqueeze(2).to(linearLayerDtype), 249 | 
iats.unsqueeze(2).to(linearLayerDtype), 250 | ], 251 | dim=-1, 252 | ) 253 | ) 254 | ) 255 | embeddings = torch.concat([embeddings, metaEmbeddings], dim = -1) 256 | else: 257 | embeddings = torch.concat([embeddings, torch.zeros_like(embeddings)], dim = -1) 258 | protoEmbeddings = ( 259 | self.protoEmbedding(protocol).unsqueeze(1).repeat(1, embeddings.shape[1], 1) 260 | ) 261 | 262 | return self.compressEmbeddings(torch.concat([embeddings, protoEmbeddings], dim = -1)) 263 | 264 | class NetFoundRobertaEmbeddings(RobertaEmbeddings, NetFoundEmbeddingsWithMeta): 265 | def __init__(self, config): 266 | RobertaEmbeddings.__init__(self, config) 267 | NetFoundEmbeddingsWithMeta.__init__(self, config) 268 | 269 | def forward( 270 | self, 271 | input_ids=None, 272 | position_ids=None, 273 | direction=None, 274 | iats=None, 275 | bytes=None, 276 | pkt_count=None, 277 | protocol=None, 278 | ): 279 | position_ids = self.create_position_ids_from_input_ids( 280 | input_ids, self.padding_idx, self.position_ids 281 | ) 282 | embeddings = self.word_embeddings(input_ids) 283 | if self.position_embedding_type == "absolute": 284 | position_embeddings = self.position_embeddings(position_ids) 285 | embeddings += position_embeddings 286 | embeddings = self.addMetaEmbeddings(embeddings, direction, iats, bytes, pkt_count, protocol) 287 | embeddings = self.LayerNorm(embeddings) 288 | embeddings = self.dropout(embeddings) 289 | return embeddings 290 | 291 | @staticmethod 292 | def create_position_ids_from_input_ids(input_ids, padding_idx, position_ids): 293 | mask = input_ids.ne(padding_idx).int() 294 | position_ids = ( 295 | position_ids.repeat( 296 | input_ids.shape[0], input_ids.shape[1] // position_ids.shape[1] 297 | ) 298 | * mask 299 | ) 300 | return position_ids 301 | 302 | class NetFoundRoformerEmbeddings(RoFormerEmbeddings, NetFoundEmbeddingsWithMeta): 303 | def __init__(self, config): 304 | RoFormerEmbeddings.__init__(self, config) 305 | NetFoundEmbeddingsWithMeta.__init__(self, config) 306 | self.roformer = config.roformer 307 | 308 | def forward( 309 | self, 310 | input_ids=None, 311 | position_ids=None, 312 | direction=None, 313 | iats=None, 314 | bytes=None, 315 | pkt_count=None, 316 | protocol=None, 317 | ): 318 | embeddings = self.word_embeddings(input_ids) 319 | embeddings = self.addMetaEmbeddings(embeddings, direction, iats, bytes, pkt_count, protocol) 320 | embeddings = self.LayerNorm(embeddings) 321 | embeddings = self.dropout(embeddings) 322 | return embeddings 323 | 324 | class NetFoundLayer(nn.Module): 325 | def __init__(self, config): 326 | super().__init__() 327 | self.max_burst_length = config.max_burst_length 328 | self.max_bursts = config.max_bursts 329 | self.hidden_size = config.hidden_size 330 | self.burst_encoder = TransformerLayer(config) 331 | self.flow_encoder = TransformerLayer(config) 332 | self.position_embeddings = nn.Embedding( 333 | config.max_bursts + 1, config.hidden_size, padding_idx=config.pad_token_id 334 | ) 335 | self.roformer = config.roformer 336 | 337 | def forward( 338 | self, 339 | hidden_states, 340 | attention_mask=None, 341 | num_bursts=None, 342 | output_attentions=False, 343 | burstSeqNo = None, 344 | flowSeqNo = None 345 | ): 346 | # transform sequences to bursts 347 | burst_inputs = transform_tokens2bursts( 348 | hidden_states, num_bursts=num_bursts, max_burst_length=self.max_burst_length 349 | ) 350 | burst_masks = transform_masks2bursts( 351 | attention_mask, 352 | num_bursts=num_bursts, 353 | max_burst_length=self.max_burst_length, 354 | ) 355 |
burst_outputs = self.burst_encoder( 356 | burst_inputs, burst_masks, output_attentions=output_attentions, seqNo = burstSeqNo 357 | ) 358 | 359 | # flatten bursts back to tokens 360 | outputs = transform_bursts2tokens( 361 | burst_outputs[0], 362 | num_bursts=num_bursts, 363 | max_burst_length=self.max_burst_length, 364 | ) 365 | 366 | burst_global_tokens = outputs[:, :: self.max_burst_length].clone() 367 | burst_attention_mask = attention_mask[:, :, :, :: self.max_burst_length].clone() 368 | 369 | burst_positions = torch.arange(1, num_bursts + 1).repeat(outputs.size(0), 1)\ 370 | .to(outputs.device) * (burst_attention_mask.reshape(-1, num_bursts) >= -1).int().to(outputs.device) 371 | outputs[:, :: self.max_burst_length] += self.position_embeddings(burst_positions) 372 | 373 | flow_outputs = self.flow_encoder( 374 | burst_global_tokens, 375 | burst_attention_mask, 376 | output_attentions=output_attentions, 377 | seqNo = flowSeqNo 378 | ) 379 | 380 | # replace burst representative tokens 381 | outputs[:, :: self.max_burst_length] = flow_outputs[0] 382 | 383 | return outputs, burst_outputs, flow_outputs 384 | 385 | 386 | class NetFoundLayerFlat(nn.Module): 387 | def __init__(self, config): 388 | super().__init__() 389 | self.max_burst_length = config.max_burst_length 390 | self.max_bursts = config.max_bursts 391 | self.hidden_size = config.hidden_size 392 | self.burst_encoder = TransformerLayer(config) 393 | self.position_embeddings = nn.Embedding( 394 | config.max_bursts + 1, config.hidden_size, padding_idx=config.pad_token_id 395 | ) 396 | 397 | def forward( 398 | self, 399 | hidden_states, 400 | attention_mask=None, 401 | num_bursts=None, 402 | output_attentions=False, 403 | burstSeqNo = None, 404 | flowSeqNo = None 405 | ): 406 | burst_inputs = transform_tokens2bursts( 407 | hidden_states, num_bursts=num_bursts, max_burst_length=self.max_burst_length 408 | ) 409 | burst_masks = transform_masks2bursts( 410 | attention_mask, 411 | num_bursts=num_bursts, 412 | max_burst_length=self.max_burst_length, 413 | ) 414 | burst_outputs = self.burst_encoder( 415 | burst_inputs, burst_masks, output_attentions=output_attentions, seqNo = burstSeqNo 416 | ) 417 | outputs = transform_bursts2tokens( 418 | burst_outputs[0], 419 | num_bursts=num_bursts, 420 | max_burst_length=self.max_burst_length, 421 | ) 422 | return outputs, burst_outputs 423 | 424 | 425 | 426 | @dataclass 427 | class BaseModelOutputWithFlowAttentions(ModelOutput): 428 | 429 | last_hidden_state: torch.FloatTensor = None 430 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 431 | attentions: Optional[Tuple[torch.FloatTensor]] = None 432 | flow_attentions: Optional[Tuple[torch.FloatTensor]] = None 433 | 434 | 435 | class NetFoundPretrainedModel(PreTrainedModel): 436 | def _init_weights(self, module): 437 | """Initialize the weights""" 438 | if isinstance(module, nn.Linear): 439 | # Slightly different from the TF version which uses truncated_normal for initialization 440 | # cf https://github.com/pytorch/pytorch/pull/5617 441 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 442 | if module.bias is not None: 443 | module.bias.data.zero_() 444 | elif isinstance(module, nn.Embedding): 445 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 446 | if module.padding_idx is not None: 447 | module.weight.data[module.padding_idx].zero_() 448 | elif isinstance(module, nn.LayerNorm): 449 | module.bias.data.zero_() 450 | module.weight.data.fill_(1.0) 451 | 452 | def _set_gradient_checkpointing(self, 
module, value=False): 453 | if isinstance(module, NetFoundEncoder): 454 | module.gradient_checkpointing = value 455 | 456 | def update_keys_to_ignore(self, config, del_keys_to_ignore): 457 | """Remove some keys from ignore list""" 458 | if not config.tie_word_embeddings: 459 | # must make a new list, or the class variable gets modified! 460 | self._keys_to_ignore_on_save = [ 461 | k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore 462 | ] 463 | self._keys_to_ignore_on_load_missing = [ 464 | k 465 | for k in self._keys_to_ignore_on_load_missing 466 | if k not in del_keys_to_ignore 467 | ] 468 | 469 | @classmethod 470 | def from_config(cls, config): 471 | return cls._from_config(config) 472 | 473 | 474 | class NetFoundBase(NetFoundPretrainedModel): 475 | 476 | _keys_to_ignore_on_load_missing = [r"position_ids"] 477 | 478 | def __init__(self, config): 479 | super().__init__(config) 480 | self.config = config 481 | if config.roformer: 482 | self.embeddings = NetFoundRoformerEmbeddings(config) 483 | else: 484 | self.embeddings = NetFoundRobertaEmbeddings(config) 485 | self.seg_embeddings = torch.nn.Embedding( 486 | num_embeddings=3, embedding_dim=config.hidden_size 487 | ) 488 | self.encoder = NetFoundEncoder(config) 489 | 490 | self.post_init() 491 | 492 | def get_input_embeddings(self): 493 | return self.embeddings.word_embeddings 494 | 495 | def set_input_embeddings(self, value): 496 | self.embeddings.word_embeddings = value 497 | 498 | def forward( 499 | self, 500 | input_ids=None, 501 | attention_mask=None, 502 | position_ids=None, 503 | output_attentions=None, 504 | output_hidden_states=None, 505 | return_dict=None, 506 | direction=None, 507 | iats=None, 508 | bytes=None, 509 | pkt_count=None, 510 | protocol=None, 511 | ): 512 | 513 | embeddings = self.embeddings( 514 | input_ids, position_ids, direction, iats, bytes, pkt_count, protocol 515 | ) 516 | input_shape = input_ids.size() 517 | device = input_ids.device 518 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( 519 | attention_mask, input_shape, device 520 | ) 521 | num_bursts = input_ids.shape[-1] // self.config.max_burst_length 522 | encoder_outputs = self.encoder( 523 | embeddings, 524 | extended_attention_mask, 525 | num_bursts, 526 | output_attentions, 527 | output_hidden_states, 528 | ) 529 | final_output = encoder_outputs[0] 530 | 531 | if not return_dict: 532 | return (final_output) + encoder_outputs[1:] 533 | 534 | return BaseModelOutputWithFlowAttentions( 535 | last_hidden_state=final_output, 536 | hidden_states=encoder_outputs.hidden_states, 537 | attentions=encoder_outputs.attentions, 538 | flow_attentions=encoder_outputs.flow_attentions, 539 | ) 540 | 541 | """ 542 | BaseModel: 543 | embedding: 544 | tokens RobertaAttention: vocab->768 545 | meta: 4->768 546 | encoder: 547 | burst : (seqLength+1) X 768 548 | concat: (seqLength+1)*num_sen X 768 549 | 550 | flow: num_sen X 768 551 | replace the reps 552 | 553 | """ 554 | 555 | 556 | class LMHead(nn.Module): 557 | def __init__(self, config): 558 | config = copy.deepcopy(config) 559 | super().__init__() 560 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 561 | self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 562 | 563 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size) 564 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 565 | self.decoder.bias = self.bias 566 | 567 | def forward(self, features, **kwargs): 568 | x = self.dense(features) 569 | x = gelu(x) 570 | 
x = self.layer_norm(x) 571 | 572 | # project back to size of vocabulary with bias 573 | x = self.decoder(x) 574 | 575 | return x 576 | 577 | def _tie_weights(self): 578 | # To tie those two weights if they get disconnected (on TPU or when the bias is resized) 579 | self.bias = self.decoder.bias 580 | 581 | 582 | class NetFoundLanguageModelling(NetFoundPretrainedModel): 583 | 584 | _keys_to_ignore_on_load_missing = [r"position_ids"] 585 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 586 | 587 | def __init__(self, config): 588 | super().__init__(config) 589 | 590 | self.base_transformer = NetFoundBase(config) 591 | self.lm_head = LMHead(config) 592 | self.no_mlm = config.no_mlm 593 | self.no_swapped_bursts = config.no_swapped_bursts 594 | self.no_metadata_loss = config.no_metadata_loss 595 | self.no_direction_loss = config.no_direction_loss 596 | self.attentivePooling = AttentivePooling(config) 597 | self.max_burst_length = config.max_burst_length 598 | self.portClassifierHiddenLayer = nn.Linear(config.hidden_size, 65536) 599 | self.swappedClassifierHiddenLayer = nn.Linear(config.hidden_size, 2) 600 | self.linearMetadataPred = nn.Linear(config.hidden_size, 3) 601 | self.dirPred = nn.Linear(config.hidden_size, 2) 602 | 603 | # The LM head weights require special treatment only when they are tied with the word embeddings 604 | self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) 605 | 606 | # Initialize weights and apply final processing 607 | self.post_init() 608 | 609 | def _tie_or_clone_weights(self, output_embeddings, input_embeddings): 610 | """Tie or clone module weights depending of whether we are using TorchScript or not""" 611 | if self.config.torchscript: 612 | output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) 613 | else: 614 | output_embeddings.weight = input_embeddings.weight 615 | 616 | if getattr(output_embeddings, "bias", None) is not None: 617 | output_embeddings.bias.data = nn.functional.pad( 618 | output_embeddings.bias.data, 619 | ( 620 | 0, 621 | output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], 622 | ), 623 | "constant", 624 | 0, 625 | ) 626 | if hasattr(output_embeddings, "out_features") and hasattr( 627 | input_embeddings, "num_embeddings" 628 | ): 629 | output_embeddings.out_features = input_embeddings.num_embeddings 630 | 631 | def get_output_embeddings(self): 632 | return self.lm_head.decoder 633 | 634 | def set_output_embeddings(self, new_embeddings): 635 | self.lm_head.decoder = new_embeddings 636 | 637 | def get_input_embeddings(self): 638 | return self.base_transformer.embeddings.word_embeddings 639 | 640 | def set_input_embeddings(self, value): 641 | self.base_transformer.embeddings.word_embeddings = value 642 | 643 | def maskMeta(self, bursts_to_mask, metaFeature): 644 | for i in range(bursts_to_mask.shape[0]): 645 | for j in range(bursts_to_mask.shape[1]): 646 | if bursts_to_mask[i][j]: 647 | metaFeature[i][j*self.max_burst_length:(j+1)*self.max_burst_length] = 0 648 | return metaFeature 649 | 650 | 651 | def forward( 652 | self, 653 | input_ids=None, 654 | attention_mask=None, 655 | position_ids=None, 656 | labels=None, 657 | output_attentions=None, 658 | output_hidden_states=None, 659 | return_dict=None, 660 | direction=None, 661 | iats=None, 662 | bytes=None, 663 | pkt_count=None, 664 | ports=None, 665 | swappedLabels=None, 666 | burstMetasToBeMasked = None, 667 | protocol=None, 668 | ): 669 | return_dict = ( 670 | return_dict if return_dict is not None else self.config.use_return_dict 671 | ) 672 | 673 
| #creating ground truths tensors before masking 674 | direction_orig = direction.clone().to(torch.long) 675 | iat_orig = iats.clone()/1000 #adjusting as values are higher. 676 | bytes_orig = bytes.clone()/1000 #adjusting as values are higher. 677 | pktCount_orig = pkt_count.clone() 678 | 679 | direction = self.maskMeta(burstMetasToBeMasked, direction) 680 | iats = self.maskMeta(burstMetasToBeMasked, iats) 681 | bytes = self.maskMeta(burstMetasToBeMasked, bytes) 682 | pktCount = self.maskMeta(burstMetasToBeMasked, pkt_count) 683 | outputs = self.base_transformer( 684 | input_ids, 685 | attention_mask=attention_mask, 686 | position_ids=position_ids, 687 | output_attentions=output_attentions, 688 | output_hidden_states=output_hidden_states, 689 | return_dict=return_dict, 690 | direction=direction, 691 | iats=iats, 692 | bytes=bytes, 693 | pkt_count=pktCount, 694 | protocol=protocol, 695 | ) 696 | 697 | # mlm prediction 698 | sequence_output = outputs[0] 699 | prediction_scores = self.lm_head(sequence_output) 700 | 701 | # swapped bursts predictions 702 | pooled_output = poolingByAttention( 703 | self.attentivePooling, sequence_output, self.config.max_burst_length 704 | ) 705 | swappedLogits = self.swappedClassifierHiddenLayer(pooled_output) 706 | 707 | # metadata prediction except direction 708 | burstReps = sequence_output[:, ::self.max_burst_length, :] 709 | burstMetaFieldsToBeMasked = burstMetasToBeMasked.unsqueeze(dim=2).expand(-1, -1, self.linearMetadataPred.bias.shape[-1]).to(torch.float32) 710 | metaPreds = self.linearMetadataPred(burstReps) * burstMetaFieldsToBeMasked 711 | metaLabels = burstMetaFieldsToBeMasked * torch.stack([ 712 | iat_orig[:, ::self.max_burst_length], 713 | bytes_orig[:, ::self.max_burst_length], 714 | pktCount_orig[:, ::self.max_burst_length] 715 | ], dim=2) 716 | 717 | # metadata prediction - direction 718 | # direction will be a classification task, -100 is used to not compute loss in pytorch. 719 | # All the unmasked values will be set to 0, so we remove the 0 directions. 720 | direction_orig_ = direction_orig[:, ::self.max_burst_length] 721 | direction_orig_ = burstMetasToBeMasked.to(torch.long) * direction_orig_ 722 | direction_orig_[direction_orig_.to(torch.long) == 0] = TORCH_IGNORE_INDEX 723 | # We have +1 -1 as direction, but for classification we need 0 1. 
Setting -1 as 0 for classification 724 | direction_orig_[direction_orig_.to(torch.long) == -1] = 0 725 | direction_logits = torch.softmax(self.dirPred(burstReps), -1) 726 | 727 | losses = [] 728 | ce_loss = CrossEntropyLoss(ignore_index=TORCH_IGNORE_INDEX) 729 | l1_loss = L1Loss() 730 | prefix = "train" if self.training else "eval" 731 | if not self.no_mlm: 732 | masked_lm_loss = ce_loss(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 733 | losses.append(masked_lm_loss) 734 | utils.TB_WRITER.add_scalar( 735 | tag=f"{prefix}/mlm_loss", 736 | scalar_value=masked_lm_loss.item(), 737 | global_step=int(time.time()), 738 | ) 739 | 740 | if not self.no_swapped_bursts: 741 | swappedClassificationLoss = ce_loss(swappedLogits, swappedLabels) 742 | losses.append(swappedClassificationLoss) 743 | utils.TB_WRITER.add_scalar( 744 | tag=f"{prefix}/swap_bursts_loss", 745 | scalar_value=swappedClassificationLoss.item(), 746 | global_step=int(time.time()), 747 | ) 748 | 749 | if not self.no_metadata_loss: 750 | metaLoss = l1_loss(metaPreds, metaLabels.to(metaPreds.dtype)) 751 | losses.append(metaLoss) 752 | utils.TB_WRITER.add_scalar( 753 | tag=f"{prefix}/metadata_loss", 754 | scalar_value=metaLoss.item(), 755 | global_step=int(time.time()), 756 | ) 757 | 758 | # transpose for k-dimension loss that wants (BATCH x CLASS_NUMBER x OTHER_DIMENSION) 759 | if not self.no_direction_loss: 760 | if (direction_orig_ != -100).any(): 761 | dirLoss = ce_loss(direction_logits.transpose(1, 2), direction_orig_) 762 | else: 763 | # if all labels are -100 - loss is nan: https://github.com/pytorch/pytorch/issues/70348 - let's do like facebook: https://github.com/facebookresearch/detectron2/commit/04fc85a0c44675559c2fbc9c7541cbb8b443819c 764 | dirLoss = direction_logits.sum() * 0 765 | 766 | if not torch.isnan(dirLoss): 767 | losses.append(dirLoss) 768 | utils.TB_WRITER.add_scalar( 769 | tag=f"{prefix}/direction_loss", 770 | scalar_value=dirLoss.item(), 771 | global_step=int(time.time()), 772 | ) 773 | 774 | if not losses: 775 | raise ValueError("No valid losses are defined") 776 | 777 | totalLoss = torch.stack(losses).sum() 778 | if not return_dict: 779 | output = (prediction_scores,) + outputs[2:] 780 | return (totalLoss,) + output 781 | 782 | return MaskedLMOutput( 783 | loss=totalLoss, 784 | logits=(prediction_scores, swappedLogits), 785 | hidden_states=outputs.hidden_states, 786 | attentions=outputs.attentions, 787 | ) 788 | 789 | 790 | def poolingByConcat(sequence_output, max_burst_length, hidden_size, max_bursts): 791 | burstReps = sequence_output[:, ::max_burst_length, :].clone() 792 | pads = torch.zeros( 793 | burstReps.shape[0], 794 | hidden_size * (max_bursts - burstReps.shape[1]), 795 | dtype=burstReps.dtype, 796 | ).to(burstReps.device) 797 | return torch.concat( 798 | [torch.reshape(burstReps, (burstReps.shape[0], -1)), pads], dim=-1 799 | ).to(burstReps.device) 800 | 801 | 802 | def poolingByMean(sequence_output, attention_mask, max_burst_length): 803 | burst_attention = attention_mask[:, ::max_burst_length].detach().clone() 804 | burstReps = sequence_output[:, ::max_burst_length, :].clone() 805 | burst_attention = burst_attention / torch.sum(burst_attention, dim=-1).unsqueeze( 806 | 0 807 | ).transpose(0, 1) 808 | orig_shape = burstReps.shape 809 | burstReps = burst_attention.reshape( 810 | burst_attention.shape[0] * burst_attention.shape[1], -1 811 | ) * burstReps.reshape((burstReps.shape[0] * burstReps.shape[1], -1)) 812 | return burstReps.reshape(orig_shape).sum(dim=1) 813 | 814 | 815 | 
def poolingByAttention(attentivePooling, sequence_output, max_burst_length): 816 | burstReps = sequence_output[:, ::max_burst_length, :].clone() 817 | return attentivePooling(burstReps) 818 | 819 | 820 | class AttentivePooling(nn.Module): 821 | def __init__(self, config): 822 | super().__init__() 823 | self.attn_dropout = config.hidden_dropout_prob 824 | self.lin_proj = nn.Linear(config.hidden_size, config.hidden_size) 825 | self.v = nn.Linear(config.hidden_size, 1, bias=False) 826 | 827 | def forward(self, inputs): 828 | lin_out = self.lin_proj(inputs) 829 | attention_weights = torch.tanh(self.v(lin_out)).squeeze(-1) 830 | attention_weights_normalized = torch.softmax(attention_weights, -1) 831 | return torch.sum(attention_weights_normalized.unsqueeze(-1) * inputs, 1) 832 | 833 | 834 | class NetfoundFinetuningModel(NetFoundPretrainedModel): 835 | _keys_to_ignore_on_load_missing = [r"position_ids"] 836 | 837 | def __init__(self, config): 838 | super().__init__(config) 839 | self.num_labels = config.num_labels 840 | self.config = config 841 | self.model_max_length = config.model_max_length 842 | self.max_burst_length = self.config.max_burst_length 843 | self.base_transformer = NetFoundBase(config) 844 | self.attentivePooling = AttentivePooling(config) 845 | classifier_dropout = ( 846 | config.classifier_dropout 847 | if config.classifier_dropout is not None 848 | else config.hidden_dropout_prob 849 | ) 850 | self.dropout = nn.Dropout(classifier_dropout) 851 | self.hiddenLayer = nn.Linear(config.hidden_size, config.hidden_size) 852 | self.hiddenLayer2 = nn.Linear(config.hidden_size, config.hidden_size) 853 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 854 | self.attentivePooling = AttentivePooling(config=config) 855 | self.relu = nn.ReLU() 856 | 857 | # Initialize weights and apply final processing 858 | self.post_init() 859 | 860 | def poolingByAttention(self, sequence_output, max_burst_length): 861 | burstReps = sequence_output[:, ::max_burst_length, :].clone() 862 | return self.attentivePooling(burstReps) 863 | 864 | def forward( 865 | self, 866 | input_ids=None, 867 | attention_mask=None, 868 | position_ids=None, 869 | labels=None, 870 | output_attentions=None, 871 | output_hidden_states=None, 872 | return_dict=None, 873 | direction=None, 874 | iats=None, 875 | bytes=None, 876 | pkt_count=None, 877 | protocol=None, 878 | stats=None, 879 | flow_duration = None 880 | ): 881 | r""" 882 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 883 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 884 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 885 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
886 | """ 887 | if labels is None: 888 | labels = flow_duration / 1000.0 889 | return_dict = ( 890 | return_dict if return_dict is not None else self.config.use_return_dict 891 | ) 892 | 893 | outputs = self.base_transformer( 894 | input_ids, 895 | attention_mask=attention_mask, 896 | position_ids=position_ids, 897 | output_attentions=output_attentions, 898 | output_hidden_states=output_hidden_states, 899 | return_dict=return_dict, 900 | direction=direction, 901 | iats=iats, 902 | bytes=bytes, 903 | pkt_count=pkt_count, 904 | protocol=protocol, 905 | ) 906 | 907 | sequence_output = outputs[0] 908 | pooled_output = poolingByAttention( 909 | self.attentivePooling, sequence_output, self.config.max_burst_length 910 | ) 911 | pooled_output = self.hiddenLayer2(self.hiddenLayer(pooled_output)) 912 | if stats is not None: 913 | logits = self.classifier(torch.concatenate([pooled_output, stats], dim=-1)) 914 | else: 915 | logits = self.classifier(torch.concatenate([pooled_output], dim=-1)) 916 | 917 | loss = None 918 | if labels is not None: 919 | if self.config.problem_type is None: 920 | if self.num_labels == 1: 921 | self.config.problem_type = "regression" 922 | elif self.num_labels > 1 and ( 923 | labels.dtype == torch.long or labels.dtype == torch.int 924 | ): 925 | self.config.problem_type = "single_label_classification" 926 | else: 927 | self.config.problem_type = "multi_label_classification" 928 | 929 | if self.config.problem_type == "regression": 930 | loss_fct = L1Loss() 931 | if self.num_labels == 1: 932 | loss = loss_fct(logits.squeeze(), (labels.squeeze().to(torch.float32))) 933 | else: 934 | loss = loss_fct(logits, labels) 935 | elif self.config.problem_type == "single_label_classification": 936 | loss_fct = CrossEntropyLoss() 937 | loss = loss_fct(logits.view(-1, self.num_labels), labels) 938 | 939 | if not return_dict: 940 | output = (logits,) + outputs[2:] 941 | return ((loss,) + output) if loss is not None else output 942 | 943 | return SequenceClassifierOutput( 944 | loss=loss, 945 | logits=logits, 946 | ) 947 | 948 | 949 | class NetfoundNoPTM(NetFoundPretrainedModel): 950 | _keys_to_ignore_on_load_missing = [r"position_ids"] 951 | 952 | def __init__(self, config): 953 | super().__init__(config) 954 | self.num_labels = config.num_labels 955 | self.config = config 956 | self.model_max_length = config.model_max_length 957 | self.max_burst_length = self.config.max_burst_length 958 | classifier_dropout = ( 959 | config.classifier_dropout 960 | if config.classifier_dropout is not None 961 | else config.hidden_dropout_prob 962 | ) 963 | self.dropout = nn.Dropout(classifier_dropout) 964 | self.hiddenLayer = nn.Linear(1595, config.hidden_size * 2) 965 | self.hiddenLayer2 = nn.Linear(config.hidden_size * 2, config.hidden_size) 966 | 967 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 968 | self.relu = nn.ReLU() 969 | 970 | # Initialize weights and apply final processing 971 | self.post_init() 972 | 973 | def poolingByAttention(self, sequence_output, max_burst_length): 974 | burstReps = sequence_output[:, ::max_burst_length, :].clone() 975 | return self.attentivePooling(burstReps) 976 | 977 | def forward( 978 | self, 979 | input_ids=None, 980 | attention_mask=None, 981 | position_ids=None, 982 | labels=None, 983 | output_attentions=None, 984 | output_hidden_states=None, 985 | return_dict=None, 986 | direction=None, 987 | iat=None, 988 | bytes=None, 989 | pktCount=None, 990 | stats=None, 991 | ): 992 | r""" 993 | labels (`torch.LongTensor` of shape `(batch_size,)`, 
*optional*): 994 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 995 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 996 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 997 | """ 998 | 999 | return_dict = ( 1000 | return_dict if return_dict is not None else self.config.use_return_dict 1001 | ) 1002 | input = torch.concatenate( 1003 | [ 1004 | input_ids, 1005 | torch.zeros((input_ids.shape[0], 1595 - input_ids.shape[1])).to( 1006 | input_ids.device 1007 | ), 1008 | ], 1009 | dim=-1, 1010 | ) 1011 | 1012 | pooled_output = self.hiddenLayer2(self.hiddenLayer(input)) 1013 | logits = self.classifier(torch.concatenate([pooled_output], dim=-1)) 1014 | 1015 | loss = None 1016 | if labels is not None: 1017 | if self.config.problem_type is None: 1018 | if self.num_labels == 1: 1019 | self.config.problem_type = "regression" 1020 | elif self.num_labels > 1 and ( 1021 | labels.dtype == torch.long or labels.dtype == torch.int 1022 | ): 1023 | self.config.problem_type = "single_label_classification" 1024 | else: 1025 | self.config.problem_type = "multi_label_classification" 1026 | 1027 | if self.config.problem_type == "regression": 1028 | loss_fct = MSELoss() 1029 | if self.num_labels == 1: 1030 | logits = self.relu(logits) 1031 | loss = loss_fct(logits.squeeze(), (labels.squeeze())) 1032 | else: 1033 | loss = loss_fct(logits, labels) 1034 | elif self.config.problem_type == "single_label_classification": 1035 | loss_fct = CrossEntropyLoss() 1036 | loss = loss_fct(logits.view(-1, self.num_labels), labels) 1037 | 1038 | if not return_dict: 1039 | output = (logits,) + pooled_output[2:] 1040 | return ((loss,) + output) if loss is not None else output 1041 | 1042 | return SequenceClassifierOutput( 1043 | loss=loss, 1044 | logits=logits, 1045 | ) 1046 | --------------------------------------------------------------------------------
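Illustrative note on the burst/flow reshaping in src/train/NetFoundModels.py: NetFoundLayer squashes the flat token sequence into per-burst chunks with transform_tokens2bursts, runs the burst encoder on those chunks, flattens them back with transform_bursts2tokens, and then feeds one representative token per burst (every max_burst_length-th position) to the flow encoder before writing the flow outputs back at those strided positions. The snippet below is a minimal, standalone sketch of that shape bookkeeping, not code from the repository; it uses plain torch, and the batch, burst, burst-length, and hidden sizes are illustrative assumptions only.

import torch

# illustrative toy dimensions (assumptions, not values from the repository)
batch, num_bursts, max_burst_length, hidden = 2, 3, 4, 8
hidden_states = torch.randn(batch, num_bursts * max_burst_length, hidden)

# tokens -> bursts: (batch, bursts * burst_len, hidden) -> (batch * bursts, burst_len, hidden),
# so burst-level attention only sees the tokens of a single burst
burst_seqs = hidden_states.reshape(batch, num_bursts, max_burst_length, hidden).contiguous().view(
    batch * num_bursts, max_burst_length, hidden
)

# bursts -> tokens: invert the squash so burst outputs line up with the flat flow again
tokens_again = burst_seqs.contiguous().view(batch, num_bursts, max_burst_length, hidden).view(
    batch, num_bursts * max_burst_length, hidden
)
assert torch.equal(tokens_again, hidden_states)  # round trip is an identity

# one representative token per burst (stride max_burst_length) feeds the flow-level encoder
burst_global_tokens = tokens_again[:, ::max_burst_length]
print(burst_seqs.shape, burst_global_tokens.shape)  # torch.Size([6, 4, 8]) torch.Size([2, 3, 8])

transform_masks2bursts applies the analogous reshape to the extended attention mask, and num_bursts in NetFoundBase.forward is computed as input_ids.shape[-1] // config.max_burst_length, so the flat sequence is expected to factor cleanly into (bursts, burst length).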