├── NetDiffusion-LoRa-SD3-Fine-Tuning.ipynb ├── README.md └── scripts ├── color_processor.py ├── column_example.nprint ├── compute_embeddings.py ├── correct_format.nprint ├── image_to_nprint.py ├── mass_reconstruction.py ├── nprint_to_png.py ├── reconstruction.py ├── traffic_conditioning_image.png └── train_dreambooth_lora_sd3_miniature.py /NetDiffusion-LoRa-SD3-Fine-Tuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "scrolled": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# 📦 Install Dependencies\n", 12 | "# \n", 13 | "# 1. Installs the latest development version of Hugging Face's `diffusers` library directly from GitHub.\n", 14 | "# This library is used for working with diffusion models such as Stable Diffusion.\n", 15 | "!pip install -q -U git+https://github.com/huggingface/diffusers\n", 16 | "# 2. Installs and upgrades several core Hugging Face and optimization libraries:\n", 17 | "# - `transformers`: For using pre-trained language and vision models\n", 18 | "# - `accelerate`: Simplifies training across CPUs, GPUs, and distributed setups\n", 19 | "# - `wandb`: For logging experiments and visualizing training progress\n", 20 | "# - `bitsandbytes`: Enables 8-bit model loading for memory-efficient inference/training\n", 21 | "# - `peft`: For Parameter-Efficient Fine-Tuning of large models\n", 22 | "!pip install -q -U \\\n", 23 | " transformers \\\n", 24 | " accelerate \\\n", 25 | " wandb \\\n", 26 | " bitsandbytes \\\n", 27 | " peft\n", 28 | "# 3. 
Installs additional Python libraries needed for data handling, model input/output,\n", 29 | "# and networking or computer vision tasks:\n", 30 | "# - `pandas`: For data manipulation and tabular processing\n", 31 | "# - `torchvision`: For image transformations and loading datasets (used with PyTorch)\n", 32 | "# - `pyarrow`: For efficient I/O and working with Apache Arrow / Parquet formats\n", 33 | "# - `sentencepiece`: For subword tokenization used in many NLP models\n", 34 | "# - `controlnet_aux`: Adds support functions for ControlNet like HED, Canny, Depth, etc.\n", 35 | "# - `scapy`: For packet parsing and crafting, often used in networking/PCAP analysis\n", 36 | "# - `gdown`: For downloading files from Google Drive using file IDs\n", 37 | "# - `opencv-python`: For computer vision tasks and image manipulation\n", 38 | "!pip install pandas torchvision pyarrow sentencepiece controlnet_aux scapy gdown opencv-python" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#As SD3 is gated, before using it with diffusers you first need to go to the Stable Diffusion 3 Medium Hugging Face page, fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. 
Use the command below to log in:\n", 48 | "#Ignore if already logged in.\n", 49 | "!huggingface-cli login" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# Download the example preprocessed traffic dataset.\n", 59 | "#\n", 60 | "# To create your own dataset:\n", 61 | "# 1) Install nPrint in your environment (see https://nprint.github.io/nprint/).\n", 62 | "# 2) Collect PCAPs belonging to the same service/application you plan to model.\n", 63 | "# 3) Convert each PCAP to an nPrint file with the following command:\n", 64 | "# nprint -F -1 -P {pcap} -4 -i -6 -t -u -p 0 -c 1024 -W {output_file}\n", 65 | "# 4) Place all resulting nPrint files in a folder named \"nprint_traffic\".\n", 66 | "\n", 67 | "# Download the example publicly available preprocessed traffic dataset\n", 68 | "!gdown --id 1vvneSH0a1WZFPHTafKusOUjNhg7oQioq --output preprocessed_dataset.zip\n", 69 | "\n", 70 | "# Unzip and move files into desired directory\n", 71 | "!unzip -q preprocessed_dataset.zip\n", 72 | "!mkdir -p nprint_traffic\n", 73 | "!mv amazon_nprint_traffic/* nprint_traffic/\n", 74 | "\n", 75 | "# Clean up\n", 76 | "!rm -r amazon_nprint_traffic __MACOSX preprocessed_dataset.zip" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "scrolled": true 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "# -----------------------------------------------------------\n", 88 | "# 🖼️ Convert the nPrint representation of PCAPs into PNG images\n", 89 | "#\n", 90 | "# This script is designed to assist in transforming `.nprint` files—\n", 91 | "# tabular feature representations of network packets—into fixed-size\n", 92 | "# PNG images that can be used as input for SD fine-tuning.\n", 93 | "#\n", 94 | "# 🧾 Input:\n", 95 | "# - `.nprint` files generated from packet captures (PCAPs) using the\n", 96 | "# nPrint tool. 
Each file is a CSV-like matrix where each row\n", 97 | "# represents a single packet and columns correspond to extracted features.\n", 98 | "#\n", 99 | "# 🧹 Preprocessing:\n", 100 | "# - Drops IP address-related columns to avoid injecting identifiable\n", 101 | "# or non-generalizable information into the model.\n", 102 | "# - Maps integer values in the remaining columns to RGBA color tuples\n", 103 | "# to visualize numeric features as colored pixels.\n", 104 | "#\n", 105 | "# 🧱 Padding:\n", 106 | "# - Pads each image to a uniform height (default 1024) using a solid\n", 107 | "# background to ensure model input consistency across varying packet counts.\n", 108 | "#\n", 109 | "# 🎨 Output:\n", 110 | "# - Saves a PNG file for each `.nprint` file, preserving the packet structure\n", 111 | "# as a vertically stacked color-coded image (rows = packets, cols = features).\n", 112 | "# -----------------------------------------------------------\n", 113 | "!python ./scripts/nprint_to_png.py -i ./nprint_traffic/ -o ./nprint_traffic_images" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "#Compute embeddings:\n", 123 | "#Generates text prompt embeddings via a Stable Diffusion 3 pipeline and T5 text encoder.\n", 124 | "#Maps each local PNG image to a unique SHA-256 hash and associates it with the computed embeddings.\n", 125 | "#Stores the resulting image-hash-to-embedding data in a .parquet file for further processing.\n", 126 | "#Here we are using the default instance prompt \"pixelated network data for type-0 application traffic\".\n", 127 | "#But you can configure this. 
Refer to the compute_embeddings.py script for details on other supported arguments.\n", 128 | "!python ./scripts/compute_embeddings.py" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "#Clear memory\n", 138 | "import torch\n", 139 | "import gc\n", 140 | "\n", 141 | "def flush():\n", 142 | " torch.cuda.empty_cache()\n", 143 | " gc.collect()\n", 144 | "\n", 145 | "flush()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# -----------------------------------------------------------\n", 155 | "# 🚀 Train LoRA Adapter on Stable Diffusion 3 (Miniature Setup)\n", 156 | "#\n", 157 | "# This command launches training using `accelerate` with DreamBooth-style LoRA tuning,\n", 158 | "# optimized for quick experimentation or demo runs.\n", 159 | "#\n", 160 | "# ⚠️ Current configuration uses:\n", 161 | "# - Only 1 training step\n", 162 | "# - Small batch size\n", 163 | "# - No warmup\n", 164 | "# - High learning rate\n", 165 | "#\n", 166 | "# Intended **only for testing or verifying training scripts**, NOT for quality results.\n", 167 | "# For actual training, increase `max_train_steps`, adjust learning rate, and consider\n", 168 | "# enabling full validation and saving checkpoints.\n", 169 | "# -----------------------------------------------------------\n", 170 | "!accelerate launch ./scripts/train_dreambooth_lora_sd3_miniature.py \\\n", 171 | " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-3-medium-diffusers\" \\\n", 172 | " --instance_data_dir=\"nprint_traffic_images\" \\\n", 173 | " --data_df_path=\"sample_embeddings.parquet\" \\\n", 174 | " --output_dir=\"trained-sd3-lora-miniature\" \\\n", 175 | " --mixed_precision=\"fp16\" \\\n", 176 | " --instance_prompt=\"pixelated network data for type-0 application traffic\" \\\n", 177 | " --train_batch_size=2 \\\n", 
178 | " --gradient_accumulation_steps=1 --gradient_checkpointing \\\n", 179 | " --use_8bit_adam \\\n", 180 | " --learning_rate=5e-5 \\\n", 181 | " --report_to=\"wandb\" \\\n", 182 | " --lr_scheduler=\"constant\" \\\n", 183 | " --lr_warmup_steps=0 \\\n", 184 | " --max_train_steps=1 \\\n", 185 | " --seed=\"0\"" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# -----------------------------------------------------------\n", 195 | "# 🎯 Inference Pipeline for SD3 + ControlNet + LoRA\n", 196 | "#\n", 197 | "# This cell demonstrates the full generation process:\n", 198 | "# - Loads Stable Diffusion 3 Medium base model\n", 199 | "# - Loads ControlNet (Canny edge-based guidance)\n", 200 | "# - Loads LoRA weights fine-tuned on pixelated network traffic\n", 201 | "# - Applies edge-conditioned generation using a sample input\n", 202 | "#\n", 203 | "# -----------------------------------------------------------\n", 204 | "flush()\n", 205 | "import os\n", 206 | "import torch\n", 207 | "import cv2\n", 208 | "from PIL import Image\n", 209 | "from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel\n", 210 | "from diffusers.utils import load_image\n", 211 | "\n", 212 | "# Make sure our output folder exists\n", 213 | "os.makedirs(\"generated_traffic_images\", exist_ok=True)\n", 214 | "\n", 215 | "# Base SD 3.0 model\n", 216 | "base_model_path = \"stabilityai/stable-diffusion-3-medium-diffusers\"\n", 217 | "\n", 218 | "# Canny-based ControlNet\n", 219 | "controlnet_path = \"InstantX/SD3-Controlnet-Canny\"\n", 220 | "\n", 221 | "# Load the ControlNet and pipeline\n", 222 | "controlnet = SD3ControlNetModel.from_pretrained(\n", 223 | " controlnet_path, torch_dtype=torch.float16\n", 224 | ")\n", 225 | "pipe = StableDiffusion3ControlNetPipeline.from_pretrained(\n", 226 | " base_model_path,\n", 227 | " controlnet=controlnet,\n", 228 | ")\n", 229 | "\n", 230 | "# Load 
LoRA weights\n", 231 | "lora_output_path = \"trained-sd3-lora-miniature\"\n", 232 | "pipe.load_lora_weights(lora_output_path)\n", 233 | "\n", 234 | "# Move pipeline to GPU (half precision)\n", 235 | "pipe.to(\"cuda\", torch.float16)\n", 236 | "pipe.enable_sequential_cpu_offload()\n", 237 | "\n", 238 | "# ----------------------------------------------------\n", 239 | "# 1) Convert original control image to Canny edges via OpenCV\n", 240 | "# ----------------------------------------------------\n", 241 | "orig_path = \"./scripts/traffic_conditioning_image.png\"\n", 242 | "orig_bgr = cv2.imread(orig_path)\n", 243 | "if orig_bgr is None:\n", 244 | " raise ValueError(f\"Could not load file: {orig_path}\")\n", 245 | "\n", 246 | "# Convert to grayscale\n", 247 | "gray = cv2.cvtColor(orig_bgr, cv2.COLOR_BGR2GRAY)\n", 248 | "\n", 249 | "# Generate Canny edge map (tweak thresholds as needed)\n", 250 | "edges = cv2.Canny(gray, 100, 200)\n", 251 | "\n", 252 | "# Convert single-channel edge map to 3-channel RGB\n", 253 | "edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)\n", 254 | "\n", 255 | "# Convert to PIL for use in ControlNet pipeline\n", 256 | "control_image = Image.fromarray(edges_rgb)\n", 257 | "orig_width, orig_height = control_image.size\n", 258 | "target_width = 1024\n", 259 | "if orig_width > target_width:\n", 260 | " # (left, upper, right, lower)\n", 261 | " control_image = control_image.crop((0, 0, target_width, orig_height))\n", 262 | " \n", 263 | "print(\"Displaying Canny control image:\")\n", 264 | "\n", 265 | "display(control_image)\n", 266 | "\n", 267 | "# ----------------------------------------------------\n", 268 | "# 3) Set up prompts and run pipeline\n", 269 | "# ----------------------------------------------------\n", 270 | "prompt = \"pixelated network data for type-0 application traffic\"\n", 271 | "generator = torch.manual_seed(0) # reproducibility\n", 272 | "\n", 273 | "# Generate at 1024×1024 to match the new control image\n", 274 | "image = 
pipe(\n", 275 | " prompt=prompt,\n", 276 | " num_inference_steps=20,\n", 277 | " generator=generator,\n", 278 | " height=1024,\n", 279 | " width=1088,\n", 280 | " control_image=control_image,\n", 281 | " controlnet_conditioning_scale=0.5, # increase to adhere more strongly to edges\n", 282 | ").images[0]\n", 283 | "\n", 284 | "# ----------------------------------------------------\n", 285 | "# 4) Save the generated image\n", 286 | "# ----------------------------------------------------\n", 287 | "output_path = os.path.join(\"generated_traffic_images\", \"generated_traffic.png\")\n", 288 | "image.save(output_path)\n", 289 | "print(f\"Generated image saved to: {output_path}\")\n" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "# -----------------------------------------------------------\n", 299 | "# 🔄 Post-Generation Processing Pipeline\n", 300 | "#\n", 301 | "# This cell performs a 3-stage transformation of the generated PNG images:\n", 302 | "# 1. Applies color correction for standardization.\n", 303 | "# 2. Converts augmented images back into nPrint-compatible feature format.\n", 304 | "# 3. 
Applies heuristic corrections and reconstructs valid PCAP files.\n", 305 | "#\n", 306 | "# ⚙️ This pipeline enables turning synthetic traffic images\n", 307 | "# back into replayable network traffic for evaluation or simulation.\n", 308 | "# -----------------------------------------------------------\n", 309 | "\n", 310 | "# 🎨 Step 1: Color Augmentation\n", 311 | "# Applies standardized color shifts to improve nprint reconstruction accuracy\n", 312 | "!python ./scripts/color_processor.py \\\n", 313 | " --input_dir=\"./generated_traffic_images\" \\\n", 314 | " --output_dir=\"./color_corrected_generated_traffic_images\"\n", 315 | "# -----------------------------------------------------------\n", 316 | "# 🔁 Step 2: Image-to-nPrint Conversion\n", 317 | "# Converts augmented PNG images back into `.nprint` tabular format.\n", 318 | "#\n", 319 | "# Uses a reference `.nprint` file to maintain consistent structure and column order.\n", 320 | "# This step allows diffusion-generated visual traffic to be fed into analysis tools.\n", 321 | "# -----------------------------------------------------------\n", 322 | "!python ./scripts/image_to_nprint.py \\\n", 323 | " --org_nprint ./scripts/column_example.nprint \\\n", 324 | " --input_dir ./color_corrected_generated_traffic_images \\\n", 325 | " --output_dir ./generated_nprint\n", 326 | "# -----------------------------------------------------------\n", 327 | "# 🧠 Step 3: Heuristic Correction & PCAP Reconstruction\n", 328 | "#\n", 329 | "# This step reconstructs a valid and replayable `.pcap` file\n", 330 | "# from the diffusion-generated `.nprint` representation.\n", 331 | "# 🔍 Core Functionalities:\n", 332 | "# ✅ Intra-packet corrections (fixes within individual packets).\n", 333 | "# 🔁 Inter-packet dependency enforcement.\n", 334 | "# 🔧 Reconstruction:\n", 335 | "# - Save the corrected `.nprint` to disk\n", 336 | "# - Call `nprint -W` to convert `.nprint` into `.pcap` using external tool\n", 337 | "# - Run Scapy-based checksum 
updates to ensure IPv4 validity\n", 338 | "# - Reconvert final `.pcap` back to `.nprint` (with fixed layout) for downstream tasks\n", 339 | "# -----------------------------------------------------------\n", 340 | "!python ./scripts/mass_reconstruction.py \\\n", 341 | " --input_dir ./generated_nprint \\\n", 342 | " --output_pcap_dir ./replayable_generated_pcaps \\\n", 343 | " --output_nprint_dir ./replayable_generated_nprints \\\n", 344 | " --formatted_nprint_path ./scripts/correct_format.nprint\n", 345 | "# Final Pcap is stored in replayable_generated_pcaps" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [] 368 | } 369 | ], 370 | "metadata": { 371 | "kernelspec": { 372 | "display_name": "Python (new_netdiffusion)", 373 | "language": "python", 374 | "name": "new_netdiffusion" 375 | }, 376 | "language_info": { 377 | "codemirror_mode": { 378 | "name": "ipython", 379 | "version": 3 380 | }, 381 | "file_extension": ".py", 382 | "mimetype": "text/x-python", 383 | "name": "python", 384 | "nbconvert_exporter": "python", 385 | "pygments_lexer": "ipython3", 386 | "version": "3.10.6" 387 | } 388 | }, 389 | "nbformat": 4, 390 | "nbformat_minor": 4 391 | } 392 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | NetDiffusion Example Output 3 |

4 | 5 | # 🌐 NetDiffusion: High-Fidelity Synthetic Network Traffic Generation 6 | 7 |

8 | NetDiffusion Example Output 9 |

10 | 11 | --- 12 | 13 | ## 📘 Introduction 14 | 15 | **NetDiffusion** is an innovative tool designed to solve one of the core bottlenecks in networking ML research: the lack of high-quality, labeled, and privacy-preserving network traces. 16 | 17 | Traditional datasets often suffer from: 18 | - ⚠️ **Privacy concerns** 19 | - 🕓 **Data staleness** 20 | - 📉 **Limited diversity** 21 | 22 | NetDiffusion addresses these issues by using a **protocol-aware Stable Diffusion model** to synthesize network traffic that is both **realistic** and **standards-compliant**. 23 | 24 | > 🧪 The result? Synthetic packet captures that look and behave like real traffic—ideal for model training, testing, and simulation. 25 | 26 | --- 27 | 28 | ## ✨ Features 29 | 30 | - ✅ **High-Fidelity Data Generation** 31 | Generate synthetic traffic that matches real-world patterns and protocol semantics. 32 | 33 | - 🔌 **Tool Compatibility** 34 | Output traces are `.pcap` files—ready for use with Wireshark, Zeek, tshark, and other standard tools. 35 | 36 | - 🛠️ **Multi-Use Support** 37 | Beyond ML: Useful for system testing, anomaly detection, protocol emulation, and more. 38 | 39 | - 💡 **Fully Open Source** 40 | Built for the community. Modify, extend, and contribute freely. 41 | 42 | --- 43 | 44 | ## 📝 Note 45 | 46 | - The original **NetDiffusion** was implemented using **Stable Diffusion 1.5**, which is now deprecated with outdated dependencies. 47 | - This repo provides a **modern reimplementation using Stable Diffusion 3.0**, integrated with **InstantX/SD3-Controlnet-Canny**, preserving the framework’s core concepts while upgrading for compatibility and stability. 48 | 49 | --- 50 | 51 | ## 🗂 Project Structure 52 | 53 | - 🔧 All core scripts for preprocessing, training, inference, and reconstruction are located in the [`scripts/`](./scripts/) directory. 
# Opaque RGBA tuples that every pixel is snapped to.
_PURE_RED = (255, 0, 0, 255)
_PURE_GREEN = (0, 255, 0, 255)
_PURE_BLUE = (0, 0, 255, 255)

# Upper bounds (exclusive) on green and blue for the "red" shortcut rule.
_RED_RULE_MAX_GREEN = 160
_RED_RULE_MAX_BLUE = 100


def _snap_to_primary(r, g, b):
    """Return the pure red, green or blue RGBA tuple for one pixel.

    Rules, applied in order:

    1. Any red at all (``r > 0``) with green below 160 and blue below 100
       becomes red.
    2. Otherwise the largest channel wins; ties resolve red, then green,
       then blue.

    The previous implementation also listed "green" and "blue" threshold
    rules, but both required ``r < 0``. Channels of an 8-bit RGBA image are
    never negative, so those branches could not run; dropping them changes
    no output.
    """
    if r > 0 and g < _RED_RULE_MAX_GREEN and b < _RED_RULE_MAX_BLUE:
        return _PURE_RED
    strongest = max(r, g, b)
    if strongest == r:
        return _PURE_RED
    if strongest == g:
        return _PURE_GREEN
    return _PURE_BLUE


def process_image(input_file, output_file):
    """Snap every pixel of an image to pure red, green or blue and save it.

    Args:
        input_file: Path of the image to read; it is converted to RGBA.
        output_file: Path the recoloured image is written to.

    The input alpha channel is discarded: every output pixel is fully opaque.
    """
    # Close the source file as soon as the RGBA copy exists instead of
    # leaving the handle to the garbage collector.
    with Image.open(input_file) as source:
        image = source.convert('RGBA')

    width, height = image.size
    # A pixel-access object avoids one Python-level getpixel()/putpixel()
    # method call per pixel.
    pixels = image.load()
    for y in range(height):
        for x in range(width):
            r, g, b, _alpha = pixels[x, y]
            pixels[x, y] = _snap_to_primary(r, g, b)

    image.save(output_file)


def main(argv=None):
    """Recolour every ``.png`` file in a directory into another directory.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` so the
            command-line behaviour is unchanged.

    Returns:
        The number of images processed.
    """
    parser = argparse.ArgumentParser(
        description="Process all .png images in a directory and save them to another directory."
    )
    parser.add_argument("--input_dir", "-i", required=True, help="Path to input directory containing .png images.")
    parser.add_argument("--output_dir", "-o", required=True, help="Path to output directory for processed images.")
    args = parser.parse_args(argv)

    os.makedirs(args.output_dir, exist_ok=True)

    count = 0
    # Sorted so runs are reproducible; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(args.input_dir)):
        if not filename.lower().endswith(".png"):
            continue
        input_path = os.path.join(args.input_dir, filename)
        # A directory that happens to be named "*.png" is not an image.
        if not os.path.isfile(input_path):
            continue
        process_image(input_path, os.path.join(args.output_dir, filename))
        count += 1

    print(f"Processed {count} images.")
    return count


if __name__ == "__main__":
    main()
PROMPT = "pixelated network data for type-0 application traffic"
MAX_SEQ_LENGTH = 77
LOCAL_DATA_DIR = "nprint_traffic_images"
OUTPUT_PATH = "sample_embeddings.parquet"


def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes (divides by 1024 three times).

    The parameter keeps its original name, although it shadows the builtin,
    so existing keyword callers continue to work.
    """
    return bytes / 1024 / 1024 / 1024


def generate_image_hash(image_path, chunk_size=1 << 20):
    """Return the hexadecimal SHA-256 digest of the file at ``image_path``.

    Args:
        image_path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration. The file is streamed
            so large images are never held in memory in full.
    """
    digest = hashlib.sha256()
    with open(image_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def load_sd3_pipeline():
    """Build a Stable Diffusion 3 pipeline that holds only the text encoders.

    The T5 encoder (``text_encoder_3``) is loaded in 8-bit; the transformer
    and the VAE are passed as ``None`` so they are not loaded.
    """
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
    # NOTE(review): newer transformers releases prefer
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True) over the bare
    # load_in_8bit flag - confirm against the installed version.
    text_encoder = T5EncoderModel.from_pretrained(
        model_id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto"
    )
    pipeline = StableDiffusion3Pipeline.from_pretrained(
        model_id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
    )
    return pipeline


@torch.no_grad()
def compute_embeddings(pipeline, prompt, max_sequence_length):
    """Encode ``prompt`` with the pipeline's text encoders.

    Args:
        pipeline: Object exposing ``encode_prompt`` (see ``load_sd3_pipeline``).
        prompt: Text prompt; ``prompt_2`` and ``prompt_3`` are passed as ``None``.
        max_sequence_length: Forwarded to ``encode_prompt``.

    Returns:
        The tuple ``(prompt_embeds, negative_prompt_embeds,
        pooled_prompt_embeds, negative_pooled_prompt_embeds)``.

    Also prints the four shapes and the peak CUDA memory allocated so far.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)

    # The last field previously lacked "=", so its shape was printed unlabeled.
    print(
        f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape=}"
    )

    max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
    print(f"Max memory allocated: {max_memory:.3f} GB")
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


def run(args):
    """Compute prompt embeddings once and store them per image in a parquet file.

    Reads ``args.prompt``, ``args.max_sequence_length``, ``args.local_data_dir``
    and ``args.output_path``. Each row of the output holds the SHA-256 hash of
    one ``.png`` file and the four flattened embedding lists.
    """
    pipeline = load_sd3_pipeline()
    embeddings = compute_embeddings(pipeline, args.prompt, args.max_sequence_length)

    embedding_cols = [
        "prompt_embeds",
        "negative_prompt_embeds",
        "pooled_prompt_embeds",
        "negative_pooled_prompt_embeds",
    ]

    # Every image shares the same prompt, so flatten each tensor to a plain
    # list once here rather than once per row.
    flat_embeddings = [tensor.cpu().numpy().flatten().tolist() for tensor in embeddings]

    # Assumes that the images within `args.local_data_dir` have a png
    # extension. Change as needed. Sorted so the row order is reproducible.
    image_paths = sorted(glob.glob(f"{args.local_data_dir}/*.png"))
    if not image_paths:
        print(f"Warning: no .png images found in {args.local_data_dir}; writing an empty table.")

    data = [(generate_image_hash(image_path), *flat_embeddings) for image_path in image_paths]

    df = pd.DataFrame(
        data,
        columns=["image_hash"] + embedding_cols,
    )

    # Save the dataframe to a parquet file
    df.to_parquet(args.output_path)
    print(f"Data successfully serialized to {args.output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=MAX_SEQ_LENGTH,
        help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
    )
    parser.add_argument(
        "--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
    )
    parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
    args = parser.parse_args()

    run(args)
def ip_to_binary(ip_address):
    """Return the 32-character bit string of a dotted-quad IPv4 address.

    Raises:
        ValueError: If the address does not have exactly four octets, or an
            octet is not an integer in 0..255. Such input used to yield a
            bit string of the wrong length without any error.
    """
    octets = ip_address.split(".")
    if len(octets) != 4:
        raise ValueError(f"Expected four dot-separated octets, got {ip_address!r}")
    bits = []
    for octet in octets:
        value = int(octet)
        if not 0 <= value <= 255:
            raise ValueError(f"Octet out of range 0-255 in {ip_address!r}")
        bits.append(format(value, "08b"))
    return "".join(bits)


def binary_to_ip(binary_ip_address):
    """Return the dotted-quad IPv4 address for a 32-character bit string.

    Raises:
        ValueError: If the string is not exactly 32 characters long, or it
            contains characters other than binary digits.
    """
    if len(binary_ip_address) != 32:
        raise ValueError("Input binary string must be 32 bits")
    return ".".join(str(int(binary_ip_address[i : i + 8], 2)) for i in range(0, 32, 8))


def rgba_to_ip(rgba):
    """Join the components of ``rgba`` with dots, e.g. (10, 0, 0, 1) -> "10.0.0.1"."""
    return ".".join(map(str, rgba))


def int_to_rgba(A):
    """Map an integer feature value to an RGBA tuple.

    1 -> opaque red, 0 -> opaque green, -1 -> opaque blue. Larger magnitudes
    keep the red (positive) or blue (negative) hue and store ``abs(A)`` in the
    alpha channel. Returns ``None`` for a value that satisfies none of the
    comparisons (for example NaN).
    """
    if A == 1:
        return (255, 0, 0, 255)
    elif A == 0:
        return (0, 255, 0, 255)
    elif A == -1:
        return (0, 0, 255, 255)
    elif A > 1:
        return (255, 0, 0, A)
    elif A < -1:
        return (0, 0, 255, abs(A))
    else:
        return None


def rgba_to_int(rgba):
    """Decode an RGBA pixel back to its integer feature value.

    This is the inverse of ``int_to_rgba``. Returns ``None`` for any colour
    that is not pure red, pure green (fully opaque) or pure blue.

    The pixel is normalized to a tuple of Python ints first. Comparing a list
    or array against the canonical tuples never matched, so an opaque red
    given as a list decoded to 255 instead of 1.
    """
    rgba = tuple(int(channel) for channel in rgba)
    if rgba == (255, 0, 0, 255):
        return 1
    if rgba == (0, 255, 0, 255):
        return 0
    if rgba == (0, 0, 255, 255):
        return -1
    if rgba[:3] == (255, 0, 0):
        return rgba[3]
    if rgba[:3] == (0, 0, 255):
        return -rgba[3]
    return None


def split_bits(s):
    """Return the characters of ``s`` converted to a list of ints."""
    return [int(b) for b in s]


def png_to_dataframe(input_file, columns):
    """
    Convert a PNG to a DataFrame, using columns that match the reference .nprint CSV.

    Each image row becomes one DataFrame row and each pixel is decoded with
    ``rgba_to_int``.

    Raises:
        ValueError: If the image width differs from ``len(columns)``.
    """
    # Close the source file as soon as the RGBA copy exists.
    with Image.open(input_file) as source:
        img = source.convert("RGBA")
    width, height = img.size
    print(f"Processing {input_file} with size {width} x {height}")

    if width != len(columns):
        raise ValueError(
            f"{input_file} is {width} pixels wide but the reference .nprint has {len(columns)} columns"
        )

    # A pixel-access object avoids one getpixel() method call per pixel.
    pixels = img.load()
    data = [[rgba_to_int(pixels[x, y]) for x in range(width)] for y in range(height)]

    # Construct a DataFrame using the provided column names
    return pd.DataFrame(data, columns=columns)


def main(argv=None):
    """Convert every ``.png`` file in a directory into an ``.nprint`` CSV file.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` so the
            command-line behaviour is unchanged.

    Returns:
        The number of images converted.
    """
    parser = argparse.ArgumentParser(description="Convert .png files back into .nprint format.")
    parser.add_argument(
        "--org_nprint",
        required=True,
        help="Path to the original .nprint CSV file for column reference."
    )
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Directory containing the .png images to convert."
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Directory to save the resulting .nprint files."
    )
    args = parser.parse_args(argv)

    # 1) Only the header of the reference .nprint is needed, so no data rows
    #    are parsed. An unwanted "Unnamed: 0" index column is dropped.
    header = pd.read_csv(args.org_nprint, nrows=0)
    columns = [name for name in header.columns if name != "Unnamed: 0"]

    # 2) Make sure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # 3) Iterate over .png files in input_dir (sorted for reproducible order)
    converted_count = 0
    for filename in sorted(os.listdir(args.input_dir)):
        stem, extension = os.path.splitext(filename)
        if extension.lower() != ".png":
            continue
        input_file = os.path.join(args.input_dir, filename)
        df = png_to_dataframe(input_file, columns=columns)

        # 4) Save as .nprint in output_dir. The name is built from the stem
        #    so "X.PNG" becomes "X.nprint"; a case-sensitive replace left
        #    such names unchanged and wrote CSV data to a ".PNG" file.
        output_file = os.path.join(args.output_dir, stem + ".nprint")
        df.to_csv(output_file, index=False)
        converted_count += 1
        print(f"Saved {output_file}")

    print(f"Done! Converted {converted_count} .png files into .nprint format.")
    return converted_count


if __name__ == "__main__":
    main()
def main():
    """Run reconstruction.py once for every .nprint file in a directory.

    Command-line options:
        --input_dir: directory scanned for ``*.nprint`` files.
        --output_pcap_dir: directory receiving one ``.pcap`` per input file.
        --output_nprint_dir: directory receiving the re-generated .nprint files.
        --formatted_nprint_path: reference .nprint handed to reconstruction.py.
        --reconstruction_script: script to invoke; defaults to the path that
            was previously hard-coded (``./scripts/reconstruction.py``).

    A failing reconstruction is reported and skipped; the remaining files are
    still processed. Exits with status 0 when the input directory is empty.
    """
    # Function-scope imports: only this entry point needs them.
    import shlex
    import subprocess

    parser = argparse.ArgumentParser(description="Generate pcap files from .nprint using reconstruction.py.")
    parser.add_argument(
        "--input_dir",
        default="./generated_nprint",
        help="Directory containing generated .nprint files to process."
    )
    parser.add_argument(
        "--output_pcap_dir",
        default="./replayable_generated_pcaps",
        help="Directory to save the resulting .pcap files."
    )
    parser.add_argument(
        "--output_nprint_dir",
        default="./replayable_generated_nprints",
        help="Directory to save any newly generated .nprint files."
    )
    parser.add_argument(
        "--formatted_nprint_path",
        default="./correct_format.nprint",
        help="Path to the 'correct_format.nprint' reference file."
    )
    parser.add_argument(
        "--reconstruction_script",
        default="./scripts/reconstruction.py",
        help="Path to the reconstruction.py script to run for each file."
    )
    args = parser.parse_args()

    # Create the output directories if they don't exist
    os.makedirs(args.output_pcap_dir, exist_ok=True)
    os.makedirs(args.output_nprint_dir, exist_ok=True)

    input_files = os.listdir(args.input_dir)
    if not input_files:
        print(f"No files found in {args.input_dir}")
        sys.exit(0)

    suffix = ".nprint"
    for org_nprint in input_files:
        if not org_nprint.endswith(suffix):
            # skip non-nprint files
            continue

        org_nprint_path = os.path.join(args.input_dir, org_nprint)
        # Swap only the trailing extension; str.replace would also rewrite
        # any ".nprint" that happens to appear earlier in the name.
        output_pcap_path = os.path.join(
            args.output_pcap_dir,
            org_nprint[:-len(suffix)] + ".pcap",
        )
        output_nprint_path = os.path.join(args.output_nprint_dir, org_nprint)

        print(f"\nProcessing: {org_nprint_path} -> {output_pcap_path}")

        if not os.path.isfile(org_nprint_path):
            print(f"Skipping: {org_nprint_path} is not a valid file.")
            continue

        # Pass an argument list instead of a shell string: file names with
        # quotes, spaces or shell metacharacters reach the script verbatim
        # and can no longer break or inject into the command line.
        cmd = [
            sys.executable or "python3",
            args.reconstruction_script,
            "--generated_nprint_path", org_nprint_path,
            "--formatted_nprint_path", args.formatted_nprint_path,
            "--output", output_pcap_path,
            "--nprint", output_nprint_path,
        ]
        printable = " ".join(shlex.quote(part) for part in cmd)
        print(f"Running command:\n  {printable}")

        try:
            completed = subprocess.run(cmd, check=False)
        except OSError as e:
            # The interpreter itself could not be started.
            print(f"Exception while running reconstruction command: {e}")
            print("Skipping this file.")
            continue

        if completed.returncode != 0:
            # Non-zero exit code indicates an error
            print(f"ERROR: reconstruction.py failed with return code {completed.returncode}")
            print("Skipping this file.")
            continue

        print(f"Success! Created {output_pcap_path} and possibly updated {output_nprint_path}.")
def main():
    """Command-line entry point: read both directories and run the conversion.

    Both ``--input_dir``/``-i`` and ``--output_dir``/``-o`` are mandatory;
    argparse exits with its usual usage error when either is missing.
    """
    cli = argparse.ArgumentParser(description="Convert .nprint files to PNG images.")
    option_specs = (
        (("--input_dir", "-i"), "Directory containing .nprint files"),
        (("--output_dir", "-o"), "Directory to save PNG images"),
    )
    for flags, help_text in option_specs:
        cli.add_argument(*flags, required=True, help=help_text)
    options = cli.parse_args()

    convert_nprint_to_png(options.input_dir, options.output_dir)
def binary_to_decimal(binary_list):
    """Convert a binary list representation to its decimal equivalent.

    The first element is the most significant bit. Bits may be ints, bools,
    numpy scalars, the strings "0"/"1", or floats holding exactly 0.0 or 1.0.
    Floats are accepted because pandas upcasts integer columns to float in
    several situations, and joining their string forms ("1.0") could not be
    parsed as a binary number.

    Args:
        binary_list: Iterable of bits, most significant first.

    Returns:
        The unsigned integer value as a Python ``int``.

    Raises:
        ValueError: If the iterable is empty or an element is not a 0/1 bit
            (for example the -1 "absent" marker used in nprint data).
    """
    value = 0
    seen_any = False
    for raw in binary_list:
        try:
            bit = int(raw)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"invalid binary digit: {raw!r}") from exc
        # Reject anything that is not exactly 0 or 1 (e.g. -1, 2, 0.5).
        if bit not in (0, 1) or (not isinstance(raw, str) and raw != bit):
            raise ValueError(f"invalid binary digit: {raw!r}")
        value = value * 2 + bit
        seen_any = True

    if not seen_any:
        raise ValueError("binary_list must contain at least one bit")
    return value
def ipv4_hl_formatting(generated_nprint,formatted_nprint):
    """Recompute the ``ipv4_hl_0`` .. ``ipv4_hl_3`` bits of every row.

    For each row, the cells equal to 0 or 1 across all columns whose name
    contains 'ipv4' are counted, the count is rounded up to whole 32-bit
    words, and that word count is written as a zero-padded 4-bit binary
    number into the ``ipv4_hl_*`` columns.

    Args:
        generated_nprint: DataFrame updated in place.
        formatted_nprint: Unused; kept for a signature parallel to the other
            formatting helpers.

    Returns:
        The same ``generated_nprint`` object.
    """
    # Snapshot of the IPv4 columns taken before any cell is rewritten.
    ipv4_view = generated_nprint.filter(like='ipv4')
    populated = ((ipv4_view == 1) | (ipv4_view == 0)).sum(axis=1)

    for row_label, bit_total in zip(ipv4_view.index, populated.to_numpy()):
        # Header length is expressed in 32-bit (4-byte) words.
        word_bits = format(math.ceil(bit_total / 32), '04b')
        for position in range(4):
            generated_nprint.at[row_label, f'ipv4_hl_{position}'] = int(word_bits[position])
    return generated_nprint
def ipv4_ver_formatting(generated_nprint,formatted_nprint):
    """Overwrite the IPv4 version columns with constant bits.

    Every column of ``formatted_nprint`` whose name contains "ipv4_ver" is
    set, for all rows of ``generated_nprint``, to 1 when the name contains
    the character '1' but no '0', and to 0 otherwise. With columns named
    ``ipv4_ver_0`` .. ``ipv4_ver_3`` this yields the bit pattern 0100.

    Args:
        generated_nprint: DataFrame updated in place.
        formatted_nprint: DataFrame whose column names select what to set.

    Returns:
        The same ``generated_nprint`` object.
    """
    for column in formatted_nprint.columns:
        if "ipv4_ver" not in column:
            continue
        # A '0' anywhere in the name takes precedence over a '1'.
        is_set_bit = '0' not in column and '1' in column
        generated_nprint[column] = 1 if is_set_bit else 0

    return generated_nprint
def ipv4_pro_formatting(generated_nprint,formatted_nprint):
    """Set the IPv4 protocol bits and blank the other transport protocols.

    The dominating protocol is obtained from ``protocol_determination`` and
    printed. Each column of ``formatted_nprint`` whose name contains
    "ipv4_pro" is set to the matching bit of that protocol's 8-bit pattern;
    the bit position is the first ``i`` in 0..7 for which ``_i`` occurs in
    the column name. Afterwards every column whose name contains a
    non-dominating protocol name ("tcp", "udp" or "icmp") is set to -1.

    Args:
        generated_nprint: DataFrame updated in place.
        formatted_nprint: DataFrame whose column names drive the update.

    Returns:
        The same ``generated_nprint`` object.
    """
    dominating_protocol = protocol_determination(generated_nprint)
    print(dominating_protocol)

    # Bit patterns, most significant bit first.
    protocol_bits = {
        'tcp': (0, 0, 0, 0, 0, 1, 1, 0),
        'udp': (0, 0, 0, 1, 0, 0, 0, 1),
        'icmp': (0, 0, 0, 0, 0, 0, 0, 1),
    }
    chosen_bits = protocol_bits.get(dominating_protocol)

    for column in formatted_nprint.columns:
        if chosen_bits is None or "ipv4_pro" not in column:
            continue
        for position, bit in enumerate(chosen_bits):
            if f'_{position}' in column:
                generated_nprint[column] = bit
                break  # only the first matching position applies

    # make sure non-dominant-protocol values are -1s
    for column in formatted_nprint.columns:
        for protocol in ("tcp", "udp", "icmp"):
            if protocol != dominating_protocol and protocol in column:
                generated_nprint[column] = -1

    return generated_nprint
def modify_tcp_option(packet):
    """Rewrite one packet's TCP option bits into recognisable option layouts.

    Intended to be applied to one row at a time. The cells labelled
    'tcp_opt_0' through 'tcp_opt_319' are scanned from the left; each run of
    cells that are not -1 is matched to the nearest candidate option length,
    and the first 16 bits of the run are overwritten with a fixed kind/length
    header for that option. Any -1 cells inside the written span are filled
    with random 0/1 values.

    Args:
        packet: Row-like object supporting label slicing through ``.loc``.
            # assumes the label slice covers exactly 320 cells - TODO confirm

    Returns:
        The same ``packet`` object after the option cells are assigned back.
    """
    option_data = packet.loc['tcp_opt_0':'tcp_opt_319'].to_numpy()
    idx = 0
    # Candidate run lengths in bits. Apart from 0 and the NOP case of 8, an
    # entry is removed once used, so each of those layouts is written at most
    # once per packet.
    options_lengths = [0, 8, 32, 24, 16, 40, 80] # NOP/EOL, MSS, Window Scale, SACK Permitted, SACK, Timestamp

    while idx < 320:
        start_idx = idx
        end_idx = idx
        # Measure the run of populated cells (anything other than -1).
        while end_idx < 320 and option_data[end_idx] != -1:
            end_idx += 1
        length = end_idx - start_idx
        # Nearest candidate length; on a tie min() keeps the earlier list entry.
        closest_option = min(options_lengths, key=lambda x: abs(x - length))

        # NOTE(review): in the fixed-size branches below the slice
        # option_data[start_idx:idx] can be shorter than the header when
        # start_idx is near 320 - confirm this cannot occur in practice.
        if closest_option == 32: # MSS
            idx += 32
            # Header bits: 00000010 00000100, followed by the existing payload bits.
            mss_data = np.concatenate(([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], option_data[start_idx+16:idx]))
            mss_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in mss_data]
            option_data[start_idx:idx] = mss_data
            options_lengths.remove(closest_option)
        elif closest_option == 24: # Window Scale
            idx += 24
            # Header bits: 00000011 00000011.
            ws_data = np.concatenate(([0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1], option_data[start_idx+16:idx]))
            ws_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in ws_data]
            option_data[start_idx:idx] = ws_data
            options_lengths.remove(closest_option)
        elif closest_option == 16: # SACK Permitted
            idx += 16
            # Header only (00000100 00000010); this layout carries no payload bits.
            option_data[start_idx:idx] = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
            options_lengths.remove(closest_option)

        elif closest_option == 40: # SACK (Assuming one block for simplicity)
            # NOTE(review): the second header byte written here is 00001010
            # (decimal 10), yet only 40 bits (5 bytes) are consumed - confirm
            # the intended SACK size.
            idx+=40
            sack_data = np.concatenate(([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0], option_data[start_idx+16:idx]))
            sack_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in sack_data]
            option_data[start_idx:idx] = sack_data
            options_lengths.remove(closest_option)

        elif closest_option == 80: # Timestamp
            idx += 80
            # Header bits: 00001000 00001010.
            ts_data = np.concatenate(([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0], option_data[start_idx+16:idx]))
            ts_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in ts_data]
            option_data[start_idx:idx] = ts_data
            options_lengths.remove(closest_option)

        elif closest_option == 8: # single-byte option: EOL or NOP
            # NOTE(review): if the first cell of the run is neither 0 nor 1,
            # no branch below advances idx, so the loop would not make
            # progress - confirm inputs are limited to -1/0/1.
            if option_data[start_idx] == 0: # EOL
                if start_idx == 0:
                    # An EOL-like byte at the very start is blanked instead of kept.
                    idx += 8
                    option_data[start_idx:idx] = [-1,-1,-1,-1,-1,-1,-1,-1]
                    options_lengths.remove(closest_option)
                    continue
                else:
                    # Write the all-zero byte, mark everything after it as
                    # absent (-1) and stop scanning.
                    idx += 8
                    option_data[start_idx:idx] = [0,0,0,0,0,0,0,0]
                    option_data[idx:] = [-1] * (320 - idx)
                    options_lengths.remove(closest_option)
                    break
            elif option_data[start_idx] == 1: # NOP
                # 8 stays in the candidate list, so this branch can repeat.
                idx += 8
                option_data[start_idx:idx] = [0,0,0,0,0,0,0,1]
        elif closest_option == 0:
            # Run too short to match any option: blank the next 8 cells.
            idx += 8
            option_data[start_idx:idx] = [-1,-1,-1,-1,-1,-1,-1,-1]


    # Assign back the modified options to the DataFrame's row
    packet.loc['tcp_opt_0':'tcp_opt_319'] = option_data
    return packet
def direction_sampleing(formatted_nprint_path, num_packets=1024):
    """Sample a packet-direction sequence from a reference .nprint file.

    The two most frequent values of the ``src_ip`` column are treated as the
    two endpoints of the flow. Transitions between consecutive rows that both
    belong to those endpoints are counted and normalised into a 2x2 Markov
    transition matrix, which is then sampled.

    State 0 is whichever of the two endpoints appears first in the file,
    state 1 is the other one. The sequence always starts in state 0.

    Args:
        formatted_nprint_path: CSV (.nprint) file containing a ``src_ip`` column.
        num_packets: Length of the returned sequence. The default of 1024
            matches the previously hard-coded length.

    Returns:
        A list of ``num_packets`` ints, each 0 or 1.

    Raises:
        KeyError: If the file has no ``src_ip`` column.
    """
    if num_packets <= 0:
        return []

    # Read the reference file once; it provides both the frequency ranking
    # and the ordered sequence of source addresses.
    src_ips = pd.read_csv(formatted_nprint_path)['src_ip']

    # value_counts() is sorted by descending frequency, so the first two
    # entries are the dominant endpoints.
    top_two_ips = set(src_ips.value_counts().index[:2])
    ordered_ips = [ip for ip in src_ips.drop_duplicates() if ip in top_two_ips]

    if len(ordered_ips) < 2:
        # With fewer than two endpoints there is no reverse direction to
        # sample; previously this case raised IndexError.
        return [0] * num_packets

    first_ip, second_ip = ordered_ips
    state_of = {first_ip: 0, second_ip: 1}

    transition_counts = np.zeros((2, 2))
    observed = src_ips.tolist()
    for previous_ip, current_ip in zip(observed, observed[1:]):
        if previous_ip in state_of and current_ip in state_of:
            transition_counts[state_of[previous_ip], state_of[current_ip]] += 1

    # A state with no observed outgoing transition falls back to a uniform
    # row. An all-zero row is not a valid probability vector and made
    # np.random.choice raise ValueError as soon as that state was reached.
    transition_matrix = np.full((2, 2), 0.5)
    for state in (0, 1):
        outgoing = transition_counts[state].sum()
        if outgoing > 0:
            transition_matrix[state] = transition_counts[state] / outgoing

    current_state = 0
    synthetic_sequence = [current_state]
    for _ in range(num_packets - 1):
        current_state = int(np.random.choice([0, 1], p=transition_matrix[current_state]))
        synthetic_sequence.append(current_state)
    return synthetic_sequence
def random_bits_generation(required_num_bits):
    """Return a string of `required_num_bits` random '0'/'1' characters."""
    return ''.join(str(random.randint(0, 1)) for _ in range(required_num_bits))


def _bits_to_int(row, prefix, width):
    """Decode columns `{prefix}_0 .. {prefix}_{width-1}` of `row` (MSB first) into an int."""
    # int() normalises float-typed cells such as 1.0, whose str() would
    # otherwise not parse as a base-2 digit.
    return int(''.join(str(int(row[f'{prefix}_{i}'])) for i in range(width)), 2)


def _write_bits(frame, index, prefix, bits):
    """Store each character/element of `bits` as an int in `{prefix}_{i}` of row `index`."""
    for i, bit in enumerate(bits):
        frame.at[index, f'{prefix}_{i}'] = int(bit)


# Column-name prefixes of the (source port, destination port) fields per protocol.
_PORT_FIELD_PREFIXES = {
    'tcp': ('tcp_sprt', 'tcp_dprt'),
    'udp': ('udp_sport', 'udp_dport'),
}


def port_initialization(generated_nprint):
    """Assign one random, direction-consistent port pair to every packet.

    Rows whose `src_ip` equals the first row's `src_ip` get
    (src_port, dst_port); all other rows get the pair swapped.  The fields
    written depend on the result of `protocol_determination`; for any
    protocol other than 'tcp' or 'udp' the frame is returned unchanged.

    Assumes the frame has a default 0..n-1 index (rows are addressed by
    position through `.at`).
    """
    random_src_port = random_bits_generation(16)
    random_dst_port = random_bits_generation(16)
    dominating_protocol = protocol_determination(generated_nprint)

    prefixes = _PORT_FIELD_PREFIXES.get(dominating_protocol)
    if prefixes is None or len(generated_nprint) == 0:
        return generated_nprint
    sport_prefix, dport_prefix = prefixes

    first_row_src_ip = generated_nprint.at[0, 'src_ip']
    for index in range(len(generated_nprint)):
        if generated_nprint.at[index, 'src_ip'] == first_row_src_ip:
            sport_bits, dport_bits = random_src_port, random_dst_port
        else:
            # Reverse direction: the two ports swap roles.
            sport_bits, dport_bits = random_dst_port, random_src_port
        _write_bits(generated_nprint, index, sport_prefix, sport_bits)
        _write_bits(generated_nprint, index, dport_prefix, dport_bits)
    return generated_nprint


def compute_tcp_segment_length(row):
    """Return the TCP payload length in bytes encoded in `row`.

    Computed as IPv4 total length minus the IPv4 header length (IHL * 4)
    minus the TCP header length (data offset * 4).  The result is clamped
    at 0: inconsistent header fields must not produce a negative length,
    which would make sequence numbers run backwards.
    """
    total_length = _bits_to_int(row, 'ipv4_tl', 16)
    ipv4_header_length = _bits_to_int(row, 'ipv4_hl', 4) * 4  # 32-bit words -> bytes
    tcp_header_length = _bits_to_int(row, 'tcp_doff', 4) * 4  # 32-bit words -> bytes
    return max(0, total_length - ipv4_header_length - tcp_header_length)


def increment_binary_non_fixed(bin_str, increment_value):
    """Add `increment_value` to the binary string `bin_str`, keeping its width.

    The sum wraps modulo 2**len(bin_str).  Without the wrap an overflow
    would yield a longer string (callers only read the first len(bin_str)
    characters, i.e. the wrong bits) and a negative sum would yield a
    string that is not binary at all.
    """
    width = len(bin_str)
    wrapped = (int(bin_str, 2) + increment_value) % (1 << width)
    return format(wrapped, f'0{width}b')


def seq_initialization_src_dst(generated_nprint):
    """Fill `tcp_seq_*` for packets sent by the first row's source IP.

    Row 0 receives a random 32-bit sequence number; every later row with
    the same `src_ip` receives the previous such row's number advanced by
    that previous row's TCP segment length.  Other rows are left untouched.
    """
    current_bin_str = random_bits_generation(32)
    _write_bits(generated_nprint, 0, 'tcp_seq', current_bin_str)
    first_row_src_ip = generated_nprint.at[0, 'src_ip']

    # Position of the most recent packet in this direction.
    previous_index = 0
    for index in range(1, len(generated_nprint)):
        if generated_nprint.at[index, 'src_ip'] != first_row_src_ip:
            continue
        segment_length = compute_tcp_segment_length(generated_nprint.iloc[previous_index])
        current_bin_str = increment_binary_non_fixed(current_bin_str, segment_length)
        _write_bits(generated_nprint, index, 'tcp_seq', current_bin_str)
        previous_index = index

    return generated_nprint


def seq_initialization_dst_src(generated_nprint):
    """Fill `tcp_seq_*` for packets NOT sent by the first row's source IP.

    The first packet seen from a given reverse-direction IP receives a
    fresh random 32-bit sequence number; later packets from that IP advance
    the running number by the segment length of that IP's previous packet.
    """
    # This initial draw is never written to the frame (the first reverse
    # packet always draws again); it is kept so seeded runs consume the
    # random stream exactly as before.
    current_bin_str = random_bits_generation(32)
    first_row_src_ip = generated_nprint.at[0, 'src_ip']

    # Position of the last packet seen for each source IP.
    last_packet_for_ip = {first_row_src_ip: 0}

    for index in range(1, len(generated_nprint)):
        current_src_ip = generated_nprint.at[index, 'src_ip']
        if current_src_ip == first_row_src_ip:
            continue
        if current_src_ip in last_packet_for_ip:
            previous_row = generated_nprint.iloc[last_packet_for_ip[current_src_ip]]
            segment_length = compute_tcp_segment_length(previous_row)
            current_bin_str = increment_binary_non_fixed(current_bin_str, segment_length)
        else:
            current_bin_str = random_bits_generation(32)
        _write_bits(generated_nprint, index, 'tcp_seq', current_bin_str)
        last_packet_for_ip[current_src_ip] = index

    return generated_nprint


def three_way_handshake(generated_nprint):
    """Set the TCP flag columns so the flow opens with SYN / SYN-ACK / ACK.

    * Row 0 and any further initiator packets before the first reply: SYN.
    * Every responder packet until the handshake completes: SYN+ACK.
    * First initiator packet after a reply: ACK (handshake complete).
    * Every packet after that: only `tcp_ackf_0` is forced to 1; the other
      flags keep their generated values.
    """
    if len(generated_nprint) == 0:
        return generated_nprint

    other_flags = ['tcp_psh_0', 'tcp_rst_0', 'tcp_fin_0']
    first_row_src_ip = generated_nprint.at[0, 'src_ip']
    response_received = False
    handshake_complete = False

    for index in range(len(generated_nprint)):
        if handshake_complete:
            generated_nprint.at[index, 'tcp_ackf_0'] = 1
            continue

        from_initiator = generated_nprint.at[index, 'src_ip'] == first_row_src_ip
        if index == 0 or (from_initiator and not response_received):
            # (Re)transmitted SYN from the initiator.
            generated_nprint.at[index, 'tcp_syn_0'] = 1
            generated_nprint.loc[index, ['tcp_ackf_0'] + other_flags] = 0
        elif not from_initiator:
            # SYN-ACK from the responder.
            generated_nprint.loc[index, ['tcp_syn_0', 'tcp_ackf_0']] = 1
            generated_nprint.loc[index, other_flags] = 0
            response_received = True
        else:
            # Final ACK from the initiator.
            generated_nprint.at[index, 'tcp_ackf_0'] = 1
            generated_nprint.loc[index, ['tcp_syn_0'] + other_flags] = 0
            handshake_complete = True
    return generated_nprint


def ackn_initialization_src_dst(generated_nprint):
    """Fill `tcp_ackn_*` for every row from the opposite direction's last sequence number.

    A packet's acknowledgement bits are copied from the `tcp_seq_*` bits of
    the most recent packet sent in the other direction.  Packets from the
    first row's source IP that precede any reply get an all-zero
    acknowledgement number.
    """
    def seq_bits(row_index):
        return [generated_nprint.at[row_index, f'tcp_seq_{i}'] for i in range(32)]

    src_ip = generated_nprint.at[0, 'src_ip']

    # All zeros until the first reverse-direction packet has been seen.
    last_dst_to_src_seq = [0] * 32
    for i in range(32):
        generated_nprint.at[0, f'tcp_ackn_{i}'] = 0
    last_src_to_dst_seq = seq_bits(0)

    for index in range(1, len(generated_nprint)):
        if generated_nprint.at[index, 'src_ip'] != src_ip:
            ack_bits = last_src_to_dst_seq
            last_dst_to_src_seq = seq_bits(index)
        else:
            ack_bits = last_dst_to_src_seq
            last_src_to_dst_seq = seq_bits(index)
        for i in range(32):
            generated_nprint.at[index, f'tcp_ackn_{i}'] = ack_bits[i]

    return generated_nprint


def udp_len_calculation(generated_nprint, formatted_nprint):
    """Clamp every row's `udp_len_*` field into [8, 1500 - IPv4 header bytes - 8].

    Values below 8 are raised to 8; values above the upper limit are set to
    the limit; values already in range are left untouched.
    `formatted_nprint` is unused and kept only for signature compatibility.
    """
    for idx, row in generated_nprint.iterrows():
        ipv4_hl_value = binary_to_decimal([row[f'ipv4_hl_{i}'] for i in range(4)]) * 4  # 32-bit words -> bytes
        upper_limit = 1500 - ipv4_hl_value - 8
        udp_len_value = binary_to_decimal([row[f'udp_len_{i}'] for i in range(16)])  # already in bytes

        if 8 <= udp_len_value <= upper_limit:
            continue
        clamped = 8 if udp_len_value < 8 else upper_limit
        _write_bits(generated_nprint, idx, 'udp_len', format(clamped, '016b'))

    return generated_nprint


def udp_header_negative_removal(generated_nprint, formatted_nprint):
    """Replace every -1 in the non-option UDP columns with a random 0 or 1.

    The columns processed are those of `formatted_nprint` whose name
    contains 'udp' but not 'opt'.
    """
    def replace_negative_one(val):
        # np.random.randint(0, 2) yields either 0 or 1.
        return np.random.randint(0, 2) if val == -1 else val

    for column in formatted_nprint.columns:
        if 'udp' in column and 'opt' not in column:
            generated_nprint[column] = generated_nprint[column].apply(replace_negative_one)

    return generated_nprint


def ipv4_tl_formatting_udp(generated_nprint, formatted_nprint):
    """Make each row's IPv4 total length consistent with its UDP length.

    * total length < IPv4 header bytes + UDP length -> set to that sum;
    * otherwise, total length > 1500 -> set to 1500;
    * otherwise the value is kept (its bits are rewritten as ints).

    `formatted_nprint` is unused and kept only for signature compatibility.
    """
    for idx, row in generated_nprint.iterrows():
        ipv4_tl_value = binary_to_decimal([row[f'ipv4_tl_{i}'] for i in range(16)])
        ipv4_hl_value = binary_to_decimal([row[f'ipv4_hl_{i}'] for i in range(4)]) * 4  # 32-bit words -> bytes
        udp_len_value = binary_to_decimal([row[f'udp_len_{i}'] for i in range(16)])  # already in bytes

        minimum_total_length = ipv4_hl_value + udp_len_value
        if ipv4_tl_value < minimum_total_length:
            new_ipv4_tl_value = minimum_total_length
        elif ipv4_tl_value > 1500:
            new_ipv4_tl_value = 1500
        else:
            new_ipv4_tl_value = ipv4_tl_value
        _write_bits(generated_nprint, idx, 'ipv4_tl', format(new_ipv4_tl_value, '016b'))

    return generated_nprint


def main(generated_nprint_path, formatted_nprint_path, output, nprint):
    """Repair a generated nPrint and rebuild a pcap from it.

    Args:
        generated_nprint_path: path of the generated nPrint, read with
            `read_syn_nprint`.
        formatted_nprint_path: path of the reference nPrint CSV that
            supplies the expected column set and order.
        output: path the rebuilt pcap is written to.
        nprint: path the repaired nPrint CSV is written to; the final
            `nprint` CLI call also writes to this path.
    """
    rebuilt_pcap_path = output
    formatted_generated_nprint_path = nprint

    generated_nprint = read_syn_nprint(generated_nprint_path)
    formatted_nprint = pd.read_csv(formatted_nprint_path)

    # Add columns that only the reference nPrint has (filled with 0), then
    # put the columns in the reference order.
    missing_columns = set(formatted_nprint.columns) - set(generated_nprint.columns)
    missing_data = pd.DataFrame(0, index=generated_nprint.index, columns=list(missing_columns))
    generated_nprint = pd.concat([generated_nprint, missing_data], axis=1)
    generated_nprint = generated_nprint.reindex(columns=formatted_nprint.columns)

    # ---------------- Intra-packet dependency adjustments ----------------
    # IP address formatting (we have flexibility here).
    synthetic_sequence = direction_sampleing(formatted_nprint_path)
    generated_nprint = ip_address_formatting(generated_nprint, synthetic_sequence)

    # IPv4
    generated_nprint = ipv4_ver_formatting(generated_nprint, formatted_nprint)  # IPv4 only
    # Ensure the minimum IPv4 header is present: no missing header fields.
    generated_nprint = ipv4_header_negative_removal(generated_nprint, formatted_nprint)
    # Pick the dominating protocol and set all other protocols' fields to -1.
    generated_nprint = ipv4_pro_formatting(generated_nprint, formatted_nprint)
    # IPv4 options are practically never present in the observed data.
    generated_nprint = ipv4_option_removal(generated_nprint, formatted_nprint)
    generated_nprint = ipv4_ttl_ensure(generated_nprint, formatted_nprint)  # TTL > 0
    # Header length is computed, so it must come after all other IPv4 fields.
    generated_nprint = ipv4_hl_formatting(generated_nprint, formatted_nprint)
    # The IPv4 checksum is updated at the very end, on the rebuilt pcap.

    dominating_protocol = protocol_determination(generated_nprint)
    if dominating_protocol == 'tcp':
        generated_nprint = tcp_header_negative_removal(generated_nprint, formatted_nprint)
        # Options must be contiguous and of fixed length; closest approximation is used.
        generated_nprint = tcp_opt_formatting(generated_nprint, formatted_nprint)
        # Data offset = total bytes of the TCP header fields including options.
        generated_nprint = tcp_data_offset_calculation(generated_nprint, formatted_nprint)
        # IPv4 total length: the payload needs to be considered.
        generated_nprint = ipv4_tl_formatting_tcp(generated_nprint, formatted_nprint)
    elif dominating_protocol == 'udp':
        generated_nprint = udp_header_negative_removal(generated_nprint, formatted_nprint)
        generated_nprint = udp_len_calculation(generated_nprint, formatted_nprint)
        # IPv4 total length: the payload needs to be considered.
        generated_nprint = ipv4_tl_formatting_udp(generated_nprint, formatted_nprint)

    # ---------------- Inter-packet dependency adjustments ----------------
    # Random initial identification number per direction, incremented afterwards.
    generated_nprint = id_num_initialization_src_dst(generated_nprint)
    generated_nprint = id_num_initialization_dst_src(generated_nprint)

    # Fragmentation bits: reserved 0, don't-fragment 1, more-fragments 0, offset 0.
    generated_nprint = ip_fragementation_bits(generated_nprint)

    dominating_protocol = protocol_determination(generated_nprint)
    if dominating_protocol == 'tcp':
        # Ports need to be consistent but can be randomly generated.
        generated_nprint = port_initialization(generated_nprint)
        # Sequence numbers are computed from TCP segment lengths.
        generated_nprint = seq_initialization_src_dst(generated_nprint)
        generated_nprint = seq_initialization_dst_src(generated_nprint)
        # Initial flags must form a three-way handshake.
        generated_nprint = three_way_handshake(generated_nprint)
        # Acknowledgement numbers.
        generated_nprint = ackn_initialization_src_dst(generated_nprint)
    elif dominating_protocol == 'udp':
        # Ports need to be consistent but can be randomly generated.
        generated_nprint = port_initialization(generated_nprint)

    # IPv6 removal: mark every IPv6 column as absent.
    for column in formatted_nprint.columns:
        if 'ipv6' in column:
            generated_nprint[column] = -1

    # Save the repaired nPrint and attempt reconstruction.
    generated_nprint.to_csv(formatted_generated_nprint_path, index=False)
    reconstruction_to_pcap(formatted_generated_nprint_path, rebuilt_pcap_path)
    update_ipv4_checksum(rebuilt_pcap_path, rebuilt_pcap_path)

    # Re-run the nprint CLI on the rebuilt pcap.  Arguments are passed as a
    # list (no shell) so paths containing spaces or shell metacharacters
    # are handled safely; the flags are the same as before.
    subprocess.run([
        'nprint', '-F', '-1',
        '-P', rebuilt_pcap_path,
        '-4', '-i', '-6', '-t', '-u',
        '-p', '0',
        '-c', '1024',
        '-W', formatted_generated_nprint_path,
    ])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='pcap reconstruction')
    parser.add_argument('--generated_nprint_path', required = True, help='Path to the generated nPrint file')
    parser.add_argument('--formatted_nprint_path', required = True, help='Path to the formatted nPrint file')
    parser.add_argument('--output', required = True, help='Path to the reconstructed pcap file')
    parser.add_argument('--nprint', required = True, help='Path to the reconstructed nprint file')
    # NOTE(review): --src_ip and --dst_ip are parsed but never passed to main().
    parser.add_argument('--src_ip',
                        help='Desired source IP address, randomly generated if not specified',
                        default='0.0.0.0')
    parser.add_argument('--dst_ip',
                        help='Desired destination IP address, randomly generated if not specified',
                        default='0.0.0.0')
    args = parser.parse_args()
    main(args.generated_nprint_path, args.formatted_nprint_path, args.output, args.nprint)

# NOTE(review): the lines below are the archive's own file-separator text that
# shares this span; they are kept verbatim (commented out) so nothing is lost.
# --------------------------------------------------------------------------------
# /scripts/traffic_conditioning_image.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/noise-lab/NetDiffusion_Generator/7201a702149fbd26ecdcf50bb000da4bb7fa2787/scripts/traffic_conditioning_image.png
# --------------------------------------------------------------------------------
# /scripts/train_dreambooth_lora_sd3_miniature.py:
# --------------------------------------------------------------------------------
# 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | import copy 19 | import gc 20 | import hashlib 21 | import logging 22 | import math 23 | import os 24 | import random 25 | import shutil 26 | from contextlib import nullcontext 27 | from pathlib import Path 28 | 29 | import numpy as np 30 | import pandas as pd 31 | import torch 32 | import torch.utils.checkpoint 33 | import transformers 34 | from accelerate import Accelerator 35 | from accelerate.logging import get_logger 36 | from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed 37 | from huggingface_hub import create_repo, upload_folder 38 | from peft import LoraConfig, set_peft_model_state_dict 39 | from peft.utils import get_peft_model_state_dict 40 | from PIL import Image 41 | from PIL.ImageOps import exif_transpose 42 | from torch.utils.data import Dataset 43 | from torchvision import transforms 44 | from torchvision.transforms.functional import crop 45 | from tqdm.auto import tqdm 46 | 47 | import diffusers 48 | from diffusers import ( 49 | AutoencoderKL, 50 | FlowMatchEulerDiscreteScheduler, 51 | SD3Transformer2DModel, 52 | StableDiffusion3Pipeline, 53 | ) 54 | from diffusers.optimization import get_scheduler 55 | from diffusers.training_utils import ( 56 | cast_training_params, 57 | compute_density_for_timestep_sampling, 58 | compute_loss_weighting_for_sd3, 59 | ) 60 | from diffusers.utils import ( 61 | 
check_min_version, 62 | convert_unet_state_dict_to_peft, 63 | is_wandb_available, 64 | ) 65 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 66 | from diffusers.utils.torch_utils import is_compiled_module 67 | 68 | 69 | if is_wandb_available(): 70 | import wandb 71 | 72 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 73 | check_min_version("0.30.0.dev0") 74 | 75 | logger = get_logger(__name__) 76 | 77 | 78 | def save_model_card( 79 | repo_id: str, 80 | images=None, 81 | base_model: str = None, 82 | train_text_encoder=False, 83 | instance_prompt=None, 84 | validation_prompt=None, 85 | repo_folder=None, 86 | ): 87 | widget_dict = [] 88 | if images is not None: 89 | for i, image in enumerate(images): 90 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 91 | widget_dict.append( 92 | {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} 93 | ) 94 | 95 | model_description = f""" 96 | # SD3 DreamBooth LoRA - {repo_id} 97 | 98 | 99 | 100 | ## Model description 101 | 102 | These are {repo_id} DreamBooth weights for {base_model}. 103 | 104 | The weights were trained using [DreamBooth](https://dreambooth.github.io/). 105 | 106 | LoRA for the text encoder was enabled: {train_text_encoder}. 107 | 108 | ## Trigger words 109 | 110 | You should use {instance_prompt} to trigger the image generation. 111 | 112 | ## Download model 113 | 114 | [Download]({repo_id}/tree/main) them in the Files & versions tab. 115 | 116 | ## License 117 | 118 | Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE). 
119 | """ 120 | model_card = load_or_create_model_card( 121 | repo_id_or_path=repo_id, 122 | from_training=True, 123 | license="openrail++", 124 | base_model=base_model, 125 | prompt=instance_prompt, 126 | model_description=model_description, 127 | widget=widget_dict, 128 | ) 129 | tags = [ 130 | "text-to-image", 131 | "diffusers-training", 132 | "diffusers", 133 | "lora", 134 | "sd3", 135 | "sd3-diffusers", 136 | "template:sd-lora", 137 | ] 138 | 139 | model_card = populate_model_card(model_card, tags=tags) 140 | model_card.save(os.path.join(repo_folder, "README.md")) 141 | 142 | 143 | def log_validation( 144 | pipeline, 145 | args, 146 | accelerator, 147 | pipeline_args, 148 | epoch, 149 | is_final_validation=False, 150 | ): 151 | logger.info( 152 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 153 | f" {args.validation_prompt}." 154 | ) 155 | pipeline.enable_model_cpu_offload() 156 | pipeline.set_progress_bar_config(disable=True) 157 | 158 | # run inference 159 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None 160 | # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() 161 | autocast_ctx = nullcontext() 162 | 163 | with autocast_ctx: 164 | images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] 165 | 166 | for tracker in accelerator.trackers: 167 | phase_name = "test" if is_final_validation else "validation" 168 | if tracker.name == "tensorboard": 169 | np_images = np.stack([np.asarray(img) for img in images]) 170 | tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") 171 | if tracker.name == "wandb": 172 | tracker.log( 173 | { 174 | phase_name: [ 175 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 176 | ] 177 | } 178 | ) 179 | 180 | del pipeline 181 | if torch.cuda.is_available(): 182 | 
torch.cuda.empty_cache() 183 | 184 | return images 185 | 186 | 187 | def parse_args(input_args=None): 188 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 189 | parser.add_argument( 190 | "--pretrained_model_name_or_path", 191 | type=str, 192 | default=None, 193 | required=True, 194 | help="Path to pretrained model or model identifier from huggingface.co/models.", 195 | ) 196 | parser.add_argument( 197 | "--revision", 198 | type=str, 199 | default=None, 200 | required=False, 201 | help="Revision of pretrained model identifier from huggingface.co/models.", 202 | ) 203 | parser.add_argument( 204 | "--variant", 205 | type=str, 206 | default=None, 207 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 208 | ) 209 | parser.add_argument( 210 | "--instance_data_dir", 211 | type=str, 212 | default=None, 213 | help=("A folder containing the training data. "), 214 | ) 215 | parser.add_argument( 216 | "--data_df_path", 217 | type=str, 218 | default=None, 219 | help=("Path to the parquet file serialized with compute_embeddings.py."), 220 | ) 221 | parser.add_argument( 222 | "--cache_dir", 223 | type=str, 224 | default=None, 225 | help="The directory where the downloaded models and datasets will be stored.", 226 | ) 227 | parser.add_argument( 228 | "--instance_prompt", 229 | type=str, 230 | default=None, 231 | required=True, 232 | help="The prompt with identifier specifying the instance, e.g. 
'photo of a TOK dog', 'in the style of TOK'", 233 | ) 234 | parser.add_argument( 235 | "--max_sequence_length", 236 | type=int, 237 | default=77, 238 | help="Maximum sequence length to use with with the T5 text encoder", 239 | ) 240 | parser.add_argument( 241 | "--validation_prompt", 242 | type=str, 243 | default=None, 244 | help="A prompt that is used during validation to verify that the model is learning.", 245 | ) 246 | parser.add_argument( 247 | "--num_validation_images", 248 | type=int, 249 | default=4, 250 | help="Number of images that should be generated during validation with `validation_prompt`.", 251 | ) 252 | parser.add_argument( 253 | "--validation_epochs", 254 | type=int, 255 | default=50, 256 | help=( 257 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" 258 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 259 | ), 260 | ) 261 | parser.add_argument( 262 | "--rank", 263 | type=int, 264 | default=4, 265 | help=("The dimension of the LoRA update matrices."), 266 | ) 267 | parser.add_argument( 268 | "--output_dir", 269 | type=str, 270 | default="sd3-dreambooth-lora", 271 | help="The output directory where the model predictions and checkpoints will be written.", 272 | ) 273 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 274 | parser.add_argument( 275 | "--resolution", 276 | type=int, 277 | default=512, 278 | help=( 279 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 280 | " resolution" 281 | ), 282 | ) 283 | parser.add_argument( 284 | "--center_crop", 285 | default=False, 286 | action="store_true", 287 | help=( 288 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 289 | " cropped. The images will be resized to the resolution first before cropping." 
290 | ), 291 | ) 292 | parser.add_argument( 293 | "--random_flip", 294 | action="store_true", 295 | help="whether to randomly flip images horizontally", 296 | ) 297 | 298 | parser.add_argument( 299 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 300 | ) 301 | parser.add_argument("--num_train_epochs", type=int, default=1) 302 | parser.add_argument( 303 | "--max_train_steps", 304 | type=int, 305 | default=None, 306 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 307 | ) 308 | parser.add_argument( 309 | "--checkpointing_steps", 310 | type=int, 311 | default=500, 312 | help=( 313 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 314 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 315 | " training using `--resume_from_checkpoint`." 316 | ), 317 | ) 318 | parser.add_argument( 319 | "--checkpoints_total_limit", 320 | type=int, 321 | default=None, 322 | help=("Max number of checkpoints to store."), 323 | ) 324 | parser.add_argument( 325 | "--resume_from_checkpoint", 326 | type=str, 327 | default=None, 328 | help=( 329 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 330 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
331 | ), 332 | ) 333 | parser.add_argument( 334 | "--gradient_accumulation_steps", 335 | type=int, 336 | default=1, 337 | help="Number of updates steps to accumulate before performing a backward/update pass.", 338 | ) 339 | parser.add_argument( 340 | "--gradient_checkpointing", 341 | action="store_true", 342 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 343 | ) 344 | parser.add_argument( 345 | "--learning_rate", 346 | type=float, 347 | default=1e-4, 348 | help="Initial learning rate (after the potential warmup period) to use.", 349 | ) 350 | parser.add_argument( 351 | "--scale_lr", 352 | action="store_true", 353 | default=False, 354 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 355 | ) 356 | parser.add_argument( 357 | "--lr_scheduler", 358 | type=str, 359 | default="constant", 360 | help=( 361 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 362 | ' "constant", "constant_with_warmup"]' 363 | ), 364 | ) 365 | parser.add_argument( 366 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 367 | ) 368 | parser.add_argument( 369 | "--lr_num_cycles", 370 | type=int, 371 | default=1, 372 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 373 | ) 374 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 375 | parser.add_argument( 376 | "--dataloader_num_workers", 377 | type=int, 378 | default=0, 379 | help=( 380 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
381 | ), 382 | ) 383 | parser.add_argument( 384 | "--weighting_scheme", 385 | type=str, 386 | default="logit_normal", 387 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], 388 | ) 389 | parser.add_argument( 390 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." 391 | ) 392 | parser.add_argument( 393 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." 394 | ) 395 | parser.add_argument( 396 | "--mode_scale", 397 | type=float, 398 | default=1.29, 399 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", 400 | ) 401 | parser.add_argument( 402 | "--optimizer", 403 | type=str, 404 | default="AdamW", 405 | help=('The optimizer type to use. Choose between ["AdamW"]'), 406 | ) 407 | 408 | parser.add_argument( 409 | "--use_8bit_adam", 410 | action="store_true", 411 | help="Whether or not to use 8-bit Adam from bitsandbytes. 
Ignored if optimizer is not set to AdamW", 412 | ) 413 | 414 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 415 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 416 | parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") 417 | 418 | parser.add_argument( 419 | "--adam_epsilon", 420 | type=float, 421 | default=1e-08, 422 | help="Epsilon value for the Adam optimizer.", 423 | ) 424 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 425 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 426 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 427 | parser.add_argument( 428 | "--hub_model_id", 429 | type=str, 430 | default=None, 431 | help="The name of the repository to keep in sync with the local `output_dir`.", 432 | ) 433 | parser.add_argument( 434 | "--logging_dir", 435 | type=str, 436 | default="logs", 437 | help=( 438 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 439 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 440 | ), 441 | ) 442 | parser.add_argument( 443 | "--allow_tf32", 444 | action="store_true", 445 | help=( 446 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 447 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 448 | ), 449 | ) 450 | parser.add_argument( 451 | "--report_to", 452 | type=str, 453 | default="tensorboard", 454 | help=( 455 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 456 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
class DreamBoothDataset(Dataset):
    """
    Dataset pairing pre-processed instance images with pre-computed prompt embeddings.

    Every regular file in `instance_data_root` is opened as an image, resized, normalised to
    tensors and kept in memory. Each image is identified by the SHA-256 of its file bytes; that
    hash is the key used to look up `(prompt_embeds, pooled_prompt_embeds)` in the parquet table
    at `data_df_path` (columns `image_hash`, `prompt_embeds`, `pooled_prompt_embeds`).

    Args:
        data_df_path: Path of the parquet file holding the embeddings.
        instance_data_root: Directory containing the training images.
        instance_prompt: Stored on the instance; not otherwise used by this class.
        size: Stored on the instance. The resize target is fixed by `TARGET_HEIGHT` /
            `TARGET_WIDTH`, so this value does not affect the produced tensors.
        center_crop: Stored on the instance. Cropping is disabled, so it has no effect.
        random_flip: Whether each image is horizontally flipped with probability 0.5. When left
            as `None`, the flag is read from a module-level `args.random_flip` if one exists
            (the behaviour when the file is run as a script) and is `False` otherwise.

    Raises:
        ValueError: If the directory does not exist, contains no files, or contains an image
            whose hash has no row in the embeddings table.
    """

    # Fixed (height, width) every image is resized to.
    # NOTE(review): 1089 is not a multiple of 8 -- confirm this width is intentional.
    TARGET_HEIGHT = 1024
    TARGET_WIDTH = 1089

    # Shapes the flattened embeddings stored in the parquet table are restored to.
    PROMPT_EMBEDS_SHAPE = (154, 4096)
    POOLED_PROMPT_EMBEDS_SHAPE = (2048,)

    def __init__(
        self,
        data_df_path,
        instance_data_root,
        instance_prompt,
        size=1024,
        center_crop=False,
        random_flip=None,
    ):
        # Logistics
        self.size = size
        self.center_crop = center_crop

        self.instance_prompt = instance_prompt
        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        # List the directory once so images and hashes are guaranteed to line up, skip
        # sub-directories (e.g. `.ipynb_checkpoints`) and sort for a reproducible order.
        image_paths = sorted(path for path in self.instance_data_root.iterdir() if path.is_file())
        if not image_paths:
            raise ValueError(f"No image files found in {self.instance_data_root}.")

        # Load images.
        self.instance_images = [Image.open(path) for path in image_paths]
        self.image_hashes = [self.generate_image_hash(path) for path in image_paths]

        # Image transformations
        self.pixel_values = self.apply_image_transformations(
            instance_images=self.instance_images,
            size=size,
            center_crop=center_crop,
            random_flip=random_flip,
        )

        # Map hashes to embeddings.
        self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)

        # Fail fast instead of hitting a KeyError in `__getitem__` in the middle of training.
        missing_hashes = [image_hash for image_hash in self.image_hashes if image_hash not in self.data_dict]
        if missing_hashes:
            raise ValueError(
                f"{len(missing_hashes)} image(s) in {self.instance_data_root} have no embeddings in"
                f" {data_df_path} (first missing hash: {missing_hashes[0]})."
            )

        self.num_instance_images = len(self.instance_images)
        self._length = self.num_instance_images

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return the image tensor and both embedding tensors for `index` (wrapped modulo the dataset size)."""
        position = index % self.num_instance_images
        prompt_embeds, pooled_prompt_embeds = self.data_dict[self.image_hashes[position]]
        return {
            "instance_images": self.pixel_values[position],
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
        }

    @staticmethod
    def _resolve_random_flip(random_flip):
        """Return the explicit flag, or fall back to a module-level `args.random_flip` (False if absent)."""
        if random_flip is not None:
            return bool(random_flip)
        return bool(getattr(globals().get("args"), "random_flip", False))

    def apply_image_transformations(self, instance_images, size, center_crop, random_flip=None):
        """
        Convert PIL images into normalised tensors.

        Each image is EXIF-transposed, converted to RGB, resized to
        (`TARGET_HEIGHT`, `TARGET_WIDTH`) with bilinear interpolation, optionally flipped
        horizontally, converted to a tensor and normalised with mean 0.5 / std 0.5.

        `size` and `center_crop` are accepted for interface compatibility only: the resize target
        is fixed and no cropping is performed.
        """
        flip_enabled = self._resolve_random_flip(random_flip)

        train_resize = transforms.Resize(
            (self.TARGET_HEIGHT, self.TARGET_WIDTH), interpolation=transforms.InterpolationMode.BILINEAR
        )
        train_flip = transforms.RandomHorizontalFlip(p=1.0)
        train_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        pixel_values = []
        for image in instance_images:
            image = exif_transpose(image)
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = train_resize(image)
            # `random.random()` is only drawn when flipping is enabled, so the RNG stream is
            # untouched otherwise.
            if flip_enabled and random.random() < 0.5:
                image = train_flip(image)
            pixel_values.append(train_transforms(image))

        return pixel_values

    def convert_to_torch_tensor(self, embeddings: list):
        """Reshape `[prompt_embeds, pooled_prompt_embeds]` to their fixed shapes and return them as tensors."""
        prompt_embeds = np.array(embeddings[0]).reshape(*self.PROMPT_EMBEDS_SHAPE)
        pooled_prompt_embeds = np.array(embeddings[1]).reshape(*self.POOLED_PROMPT_EMBEDS_SHAPE)
        return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds)

    def map_image_hash_embedding(self, data_df_path):
        """Read the parquet table and return `{image_hash: (prompt_embeds, pooled_prompt_embeds)}`."""
        hashes_df = pd.read_parquet(data_df_path)
        data_dict = {}
        for _, row in hashes_df.iterrows():
            data_dict[row["image_hash"]] = self.convert_to_torch_tensor(
                embeddings=[row["prompt_embeds"], row["pooled_prompt_embeds"]]
            )
        return data_dict

    def generate_image_hash(self, image_path):
        """Return the SHA-256 hex digest of the raw bytes of the file at `image_path`."""
        with open(image_path, "rb") as f:
            img_data = f.read()
        return hashlib.sha256(img_data).hexdigest()


def collate_fn(examples):
    """
    Stack per-example tensors from `DreamBoothDataset` into one batch.

    Returns a dict with `pixel_values` (contiguous, cast to float), `prompt_embeds` and
    `pooled_prompt_embeds`, each stacked along a new leading dimension.
    """
    pixel_values = torch.stack([example["instance_images"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    prompt_embeds = torch.stack([example["prompt_embeds"] for example in examples])
    pooled_prompt_embeds = torch.stack([example["pooled_prompt_embeds"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
    }
def main(args):
    """
    Fine-tune LoRA adapters on the SD3 transformer using pre-computed prompt embeddings.

    The function sets up `accelerate`, loads the scheduler, VAE and transformer from
    `args.pretrained_model_name_or_path`, attaches a LoRA adapter to the attention projections,
    trains it with the flow-matching loss, periodically checkpoints / validates, and finally
    saves the LoRA weights to `args.output_dir` (optionally pushing them to the Hub).

    Args:
        args: Namespace produced by `parse_args`.

    Raises:
        ValueError: For `--report_to=wandb` combined with `--hub_token`, or bf16 on MPS.
        ImportError: If wandb or bitsandbytes is requested but not installed.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    # Private copy used for sigma / timestep lookups during training.
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
    )

    # The base weights stay frozen; only the LoRA parameters added below are trained.
    transformer.requires_grad_(False)
    vae.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    # The VAE is kept in float32 regardless of the mixed-precision setting.
    vae.to(accelerator.device, dtype=torch.float32)
    transformer.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    def unwrap_model(model):
        """Strip the accelerate wrapper and, if present, the torch.compile wrapper."""
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

            StableDiffusion3Pipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        transformer_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

        transformer_state_dict = {
            f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            models = [transformer_]
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [transformer]
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models, dtype=torch.float32)

    # Optimization parameters
    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    params_to_optimize = [transformer_parameters_with_lr]

    # Optimizer creation
    if not args.optimizer.lower() == "adamw":
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW]."
            "Defaulting to adamW"
        )
        args.optimizer = "adamw"

    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.optimizer.lower() == "adamw":
        if args.use_8bit_adam:
            try:
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                )

            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        data_df_path=args.data_df_path,
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        # Pass the module-level function directly: a local lambda wrapper cannot be pickled,
        # which breaks worker start-up on spawn-based platforms when `num_workers > 0`.
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = "dreambooth-sd3-lora-miniature"
        accelerator.init_trackers(tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # Only the final path component is kept; the checkpoint is then looked up inside
            # `args.output_dir` below.
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            # Match on "checkpoint-" so entries without a step suffix cannot break the sort key.
            dirs = [d for d in dirs if d.startswith("checkpoint-")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        """Look up the scheduler sigma of each timestep and append singleton dims up to `n_dim`."""
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    # Bind `epoch` up front: if the loop below runs zero times (e.g. resuming from the last
    # checkpoint) the final validation would otherwise hit an unbound local.
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            with accelerator.accumulate(models_to_accumulate):
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)

                # Convert images to latent space
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
                model_input = model_input.to(dtype=weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
                bsz = model_input.shape[0]

                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=args.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=args.logit_mean,
                    logit_std=args.logit_std,
                    mode_scale=args.mode_scale,
                )
                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                # Add noise according to flow matching.
                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input

                # Predict the noise residual
                prompt_embeds, pooled_prompt_embeds = batch["prompt_embeds"], batch["pooled_prompt_embeds"]
                prompt_embeds = prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
                model_pred = transformer(
                    hidden_states=noisy_model_input,
                    timestep=timesteps,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    return_dict=False,
                )[0]

                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                # Preconditioning of the model outputs.
                model_pred = model_pred * (-sigmas) + noisy_model_input

                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                # flow matching loss
                target = model_input

                # Compute regular loss.
                loss = torch.mean(
                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = transformer_lora_parameters
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint-")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                pipeline = StableDiffusion3Pipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    vae=vae,
                    transformer=accelerator.unwrap_model(transformer),
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )
                pipeline_args = {"prompt": args.validation_prompt}
                images = log_validation(
                    pipeline=pipeline,
                    args=args,
                    accelerator=accelerator,
                    pipeline_args=pipeline_args,
                    epoch=epoch,
                )
                torch.cuda.empty_cache()
                gc.collect()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        transformer = transformer.to(torch.float32)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        StableDiffusion3Pipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
        )

        # Final inference
        # Load previous pipeline
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        # load attention processors
        pipeline.load_lora_weights(args.output_dir)

        # run inference
        images = []
        if args.validation_prompt and args.num_validation_images > 0:
            pipeline_args = {"prompt": args.validation_prompt}
            images = log_validation(
                pipeline=pipeline,
                args=args,
                accelerator=accelerator,
                pipeline_args=pipeline_args,
                epoch=epoch,
                is_final_validation=True,
            )

        if args.push_to_hub:
            save_model_card(
                repo_id,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                instance_prompt=args.instance_prompt,
                validation_prompt=args.validation_prompt,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)