├── NetDiffusion-LoRa-SD3-Fine-Tuning.ipynb
├── README.md
└── scripts
├── color_processor.py
├── column_example.nprint
├── compute_embeddings.py
├── correct_format.nprint
├── image_to_nprint.py
├── mass_reconstruction.py
├── nprint_to_png.py
├── reconstruction.py
├── traffic_conditioning_image.png
└── train_dreambooth_lora_sd3_miniature.py
/NetDiffusion-LoRa-SD3-Fine-Tuning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "scrolled": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# 📦 Install Dependencies\n",
12 | "# \n",
13 | "# 1. Installs the latest development version of Hugging Face's `diffusers` library directly from GitHub.\n",
14 | "# This library is used for working with diffusion models such as Stable Diffusion.\n",
15 | "!pip install -q -U git+https://github.com/huggingface/diffusers\n",
16 | "# 2. Installs and upgrades several core Hugging Face and optimization libraries:\n",
17 | "# - `transformers`: For using pre-trained language and vision models\n",
18 | "# - `accelerate`: Simplifies training across CPUs, GPUs, and distributed setups\n",
19 | "# - `wandb`: For logging experiments and visualizing training progress\n",
20 | "# - `bitsandbytes`: Enables 8-bit model loading for memory-efficient inference/training\n",
21 | "# - `peft`: For Parameter-Efficient Fine-Tuning of large models\n",
22 | "!pip install -q -U \\\n",
23 | " transformers \\\n",
24 | " accelerate \\\n",
25 | " wandb \\\n",
26 | " bitsandbytes \\\n",
27 | " peft\n",
28 | "# 3. Installs additional Python libraries needed for data handling, model input/output,\n",
29 | "# and networking or computer vision tasks:\n",
30 | "# - `pandas`: For data manipulation and tabular processing\n",
31 | "# - `torchvision`: For image transformations and loading datasets (used with PyTorch)\n",
32 | "# - `pyarrow`: For efficient I/O and working with Apache Arrow / Parquet formats\n",
33 | "# - `sentencepiece`: For subword tokenization used in many NLP models\n",
34 | "# - `controlnet_aux`: Adds support functions for ControlNet like HED, Canny, Depth, etc.\n",
35 | "# - `scapy`: For packet parsing and crafting, often used in networking/PCAP analysis\n",
36 | "# - `gdown`: For downloading files from Google Drive using file IDs\n",
37 | "# - `opencv-python`: For computer vision tasks and image manipulation\n",
38 | "!pip install pandas torchvision pyarrow sentencepiece controlnet_aux scapy gdown opencv-python"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "#As SD3 is gated, before using it with diffusers you first need to go to the Stable Diffusion 3 Medium Hugging Face page, fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:\n",
48 | "#Ignore if already logged in.\n",
49 | "!huggingface-cli login"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "# Download the example preprocessed traffic dataset.\n",
59 | "#\n",
60 | "# To create your own dataset:\n",
61 | "# 1) Install nPrint in your environment (see https://nprint.github.io/nprint/).\n",
62 | "# 2) Collect PCAPs belonging to the same service/application you plan to model.\n",
63 | "# 3) Convert each PCAP to an nPrint file with the following command:\n",
64 | "# nprint -F -1 -P {pcap} -4 -i -6 -t -u -p 0 -c 1024 -W {output_file}\n",
65 | "# 4) Place all resulting nPrint files in a folder named \"nprint_traffic\".\n",
66 | "\n",
67 | "# Download the example publicly available preprocessed traffic dataset\n",
68 | "!gdown --id 1vvneSH0a1WZFPHTafKusOUjNhg7oQioq --output preprocessed_dataset.zip\n",
69 | "\n",
70 | "# Unzip and move files into desired directory\n",
71 | "!unzip -q preprocessed_dataset.zip\n",
72 | "!mkdir -p nprint_traffic\n",
73 | "!mv amazon_nprint_traffic/* nprint_traffic/\n",
74 | "\n",
75 | "# Clean up\n",
76 | "!rm -r amazon_nprint_traffic __MACOSX preprocessed_dataset.zip"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {
83 | "scrolled": true
84 | },
85 | "outputs": [],
86 | "source": [
87 | "# -----------------------------------------------------------\n",
88 | "# 🖼️ Convert the nPrint representation of PCAPs into PNG images\n",
89 | "#\n",
90 | "# This script is designed to assist in transforming `.nprint` files—\n",
91 | "# tabular feature representations of network packets—into fixed-size\n",
92 | "# PNG images that can be used as input for SD fine-tuning.\n",
93 | "#\n",
94 | "# 🧾 Input:\n",
95 | "# - `.nprint` files generated from packet captures (PCAPs) using the\n",
96 | "# nPrint tool. Each file is a CSV-like matrix where each row\n",
97 | "# represents a single packet and columns correspond to extracted features.\n",
98 | "#\n",
99 | "# 🧹 Preprocessing:\n",
100 | "# - Drops IP address-related columns to avoid injecting identifiable\n",
101 | "# or non-generalizable information into the model.\n",
102 | "# - Maps integer values in the remaining columns to RGBA color tuples\n",
103 | "# to visualize numeric features as colored pixels.\n",
104 | "#\n",
105 | "# 🧱 Padding:\n",
106 | "# - Pads each image to a uniform height (default 1024) using a solid\n",
107 | "# background to ensure model input consistency across varying packet counts.\n",
108 | "#\n",
109 | "# 🎨 Output:\n",
110 | "# - Saves a PNG file for each `.nprint` file, preserving the packet structure\n",
111 | "# as a vertically stacked color-coded image (rows = packets, cols = features).\n",
112 | "# -----------------------------------------------------------\n",
113 | "!python ./scripts/nprint_to_png.py -i ./nprint_traffic/ -o ./nprint_traffic_images"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "#Compute embeddings:\n",
123 | "#Generates text prompt embeddings via a Stable Diffusion 3 pipeline and T5 text encoder.\n",
124 | "#Maps each local PNG image to a unique SHA-256 hash and associates it with the computed embeddings.\n",
125 | "#Stores the resulting image-hash-to-embedding data in a .parquet file for further processing.\n",
126 | "#Here we are using the default instance prompt \"pixelated network data for type-0 application traffic\".\n",
127 | "#But you can configure this. Refer to the compute_embeddings.py script for details on other supported arguments.\n",
128 | "!python ./scripts/compute_embeddings.py"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "metadata": {},
135 | "outputs": [],
136 | "source": [
137 | "#Clear memory\n",
138 | "import torch\n",
139 | "import gc\n",
140 | "\n",
141 | "def flush():\n",
142 | " torch.cuda.empty_cache()\n",
143 | " gc.collect()\n",
144 | "\n",
145 | "flush()"
146 | ]
147 | },
148 | {
149 | "cell_type": "code",
150 | "execution_count": null,
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "# -----------------------------------------------------------\n",
155 | "# 🚀 Train LoRA Adapter on Stable Diffusion 3 (Miniature Setup)\n",
156 | "#\n",
157 | "# This command launches training using `accelerate` with DreamBooth-style LoRA tuning,\n",
158 | "# optimized for quick experimentation or demo runs.\n",
159 | "#\n",
160 | "# ⚠️ Current configuration uses:\n",
161 | "# - Only 1 training step\n",
162 | "# - Small batch size\n",
163 | "# - No warmup\n",
164 | "# - High learning rate\n",
165 | "#\n",
166 | "# Intended **only for testing or verifying training scripts**, NOT for quality results.\n",
167 | "# For actual training, increase `max_train_steps`, adjust learning rate, and consider\n",
168 | "# enabling full validation and saving checkpoints.\n",
169 | "# -----------------------------------------------------------\n",
170 | "!accelerate launch ./scripts/train_dreambooth_lora_sd3_miniature.py \\\n",
171 | " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-3-medium-diffusers\" \\\n",
172 | " --instance_data_dir=\"nprint_traffic_images\" \\\n",
173 | " --data_df_path=\"sample_embeddings.parquet\" \\\n",
174 | " --output_dir=\"trained-sd3-lora-miniature\" \\\n",
175 | " --mixed_precision=\"fp16\" \\\n",
176 | " --instance_prompt=\"pixelated network data for type-0 application traffic\" \\\n",
177 | " --train_batch_size=2 \\\n",
178 | " --gradient_accumulation_steps=1 --gradient_checkpointing \\\n",
179 | " --use_8bit_adam \\\n",
180 | " --learning_rate=5e-5 \\\n",
181 | " --report_to=\"wandb\" \\\n",
182 | " --lr_scheduler=\"constant\" \\\n",
183 | " --lr_warmup_steps=0 \\\n",
184 | " --max_train_steps=1 \\\n",
185 | " --seed=\"0\""
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": null,
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "# -----------------------------------------------------------\n",
195 | "# 🎯 Inference Pipeline for SD3 + ControlNet + LoRA\n",
196 | "#\n",
197 | "# This cell demonstrates the full generation process:\n",
198 | "# - Loads Stable Diffusion 3 Medium base model\n",
199 | "# - Loads ControlNet (Canny edge-based guidance)\n",
200 | "# - Loads LoRA weights fine-tuned on pixelated network traffic\n",
201 | "# - Applies edge-conditioned generation using a sample input\n",
202 | "#\n",
203 | "# -----------------------------------------------------------\n",
204 | "flush()\n",
205 | "import os\n",
206 | "import torch\n",
207 | "import cv2\n",
208 | "from PIL import Image\n",
209 | "from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel\n",
210 | "from diffusers.utils import load_image\n",
211 | "\n",
212 | "# Make sure our output folder exists\n",
213 | "os.makedirs(\"generated_traffic_images\", exist_ok=True)\n",
214 | "\n",
215 | "# Base SD 3.0 model\n",
216 | "base_model_path = \"stabilityai/stable-diffusion-3-medium-diffusers\"\n",
217 | "\n",
218 | "# Canny-based ControlNet\n",
219 | "controlnet_path = \"InstantX/SD3-Controlnet-Canny\"\n",
220 | "\n",
221 | "# Load the ControlNet and pipeline\n",
222 | "controlnet = SD3ControlNetModel.from_pretrained(\n",
223 | " controlnet_path, torch_dtype=torch.float16\n",
224 | ")\n",
225 | "pipe = StableDiffusion3ControlNetPipeline.from_pretrained(\n",
226 | " base_model_path,\n",
227 | " controlnet=controlnet,\n",
228 | ")\n",
229 | "\n",
230 | "# Load LoRA weights\n",
231 | "lora_output_path = \"trained-sd3-lora-miniature\"\n",
232 | "pipe.load_lora_weights(lora_output_path)\n",
233 | "\n",
234 | "# Move pipeline to GPU (half precision)\n",
235 | "pipe.to(\"cuda\", torch.float16)\n",
236 | "pipe.enable_sequential_cpu_offload()\n",
237 | "\n",
238 | "# ----------------------------------------------------\n",
239 | "# 1) Convert original control image to Canny edges via OpenCV\n",
240 | "# ----------------------------------------------------\n",
241 | "orig_path = \"./scripts/traffic_conditioning_image.png\"\n",
242 | "orig_bgr = cv2.imread(orig_path)\n",
243 | "if orig_bgr is None:\n",
244 | " raise ValueError(f\"Could not load file: {orig_path}\")\n",
245 | "\n",
246 | "# Convert to grayscale\n",
247 | "gray = cv2.cvtColor(orig_bgr, cv2.COLOR_BGR2GRAY)\n",
248 | "\n",
249 | "# Generate Canny edge map (tweak thresholds as needed)\n",
250 | "edges = cv2.Canny(gray, 100, 200)\n",
251 | "\n",
252 | "# Convert single-channel edge map to 3-channel RGB\n",
253 | "edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)\n",
254 | "\n",
255 | "# Convert to PIL for use in ControlNet pipeline\n",
256 | "control_image = Image.fromarray(edges_rgb)\n",
257 | "orig_width, orig_height = control_image.size\n",
258 | "target_width = 1024\n",
259 | "if orig_width > target_width:\n",
260 | " # (left, upper, right, lower)\n",
261 | " control_image = control_image.crop((0, 0, target_width, orig_height))\n",
262 | " \n",
263 | "print(\"Displaying Canny control image:\")\n",
264 | "\n",
265 | "display(control_image)\n",
266 | "\n",
267 | "# ----------------------------------------------------\n",
268 | "# 3) Set up prompts and run pipeline\n",
269 | "# ----------------------------------------------------\n",
270 | "prompt = \"pixelated network data for type-0 application traffic\"\n",
271 | "generator = torch.manual_seed(0) # reproducibility\n",
272 | "\n",
273 | "# Generate at width=1088, height=1024 (the control image above is cropped to at most 1024 px wide)\n",
274 | "image = pipe(\n",
275 | " prompt=prompt,\n",
276 | " num_inference_steps=20,\n",
277 | " generator=generator,\n",
278 | " height=1024,\n",
279 | " width=1088,\n",
280 | " control_image=control_image,\n",
281 | " controlnet_conditioning_scale=0.5, # increase to adhere more strongly to edges\n",
282 | ").images[0]\n",
283 | "\n",
284 | "# ----------------------------------------------------\n",
285 | "# 4) Save the generated image\n",
286 | "# ----------------------------------------------------\n",
287 | "output_path = os.path.join(\"generated_traffic_images\", \"generated_traffic.png\")\n",
288 | "image.save(output_path)\n",
289 | "print(f\"Generated image saved to: {output_path}\")\n"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": null,
295 | "metadata": {},
296 | "outputs": [],
297 | "source": [
298 | "# -----------------------------------------------------------\n",
299 | "# 🔄 Post-Generation Processing Pipeline\n",
300 | "#\n",
301 | "# This cell performs a 3-stage transformation of the generated PNG images:\n",
302 | "# 1. Applies color correction for standardization.\n",
303 | "# 2. Converts augmented images back into nPrint-compatible feature format.\n",
304 | "# 3. Applies heuristic corrections and reconstructs valid PCAP files.\n",
305 | "#\n",
306 | "# ⚙️ This pipeline enables turning synthetic traffic images\n",
307 | "# back into replayable network traffic for evaluation or simulation.\n",
308 | "# -----------------------------------------------------------\n",
309 | "\n",
310 | "# 🎨 Step 1: Color Augmentation\n",
311 | "# Applies standardized color shifts to improve nprint reconstruction accuracy\n",
312 | "!python ./scripts/color_processor.py \\\n",
313 | " --input_dir=\"./generated_traffic_images\" \\\n",
314 | " --output_dir=\"./color_corrected_generated_traffic_images\"\n",
315 | "# -----------------------------------------------------------\n",
316 | "# 🔁 Step 2: Image-to-nPrint Conversion\n",
317 | "# Converts augmented PNG images back into `.nprint` tabular format.\n",
318 | "#\n",
319 | "# Uses a reference `.nprint` file to maintain consistent structure and column order.\n",
320 | "# This step allows diffusion-generated visual traffic to be fed into analysis tools.\n",
321 | "# -----------------------------------------------------------\n",
322 | "!python ./scripts/image_to_nprint.py \\\n",
323 | " --org_nprint ./scripts/column_example.nprint \\\n",
324 | " --input_dir ./color_corrected_generated_traffic_images \\\n",
325 | " --output_dir ./generated_nprint\n",
326 | "# -----------------------------------------------------------\n",
327 | "# 🧠 Step 3: Heuristic Correction & PCAP Reconstruction\n",
328 | "#\n",
329 | "# This step reconstructs a valid and replayable `.pcap` file\n",
330 | "# from the diffusion-generated `.nprint` representation.\n",
331 | "# 🔍 Core Functionalities:\n",
332 | "# ✅ Intra-packet corrections (fixes within individual packets).\n",
333 | "# 🔁 Inter-packet dependency enforcement.\n",
334 | "# 🔧 Reconstruction:\n",
335 | "# - Save the corrected `.nprint` to disk\n",
336 | "# - Call `nprint -W` to convert `.nprint` into `.pcap` using external tool\n",
337 | "# - Run Scapy-based checksum updates to ensure IPv4 validity\n",
338 | "# - Reconvert final `.pcap` back to `.nprint` (with fixed layout) for downstream tasks\n",
339 | "# -----------------------------------------------------------\n",
340 | "!python ./scripts/mass_reconstruction.py \\\n",
341 | " --input_dir ./generated_nprint \\\n",
342 | " --output_pcap_dir ./replayable_generated_pcaps \\\n",
343 | " --output_nprint_dir ./replayable_generated_nprints \\\n",
344 | " --formatted_nprint_path ./scripts/correct_format.nprint\n",
345 | "# Final Pcap is stored in replayable_generated_pcaps"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": null,
351 | "metadata": {},
352 | "outputs": [],
353 | "source": []
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": null,
358 | "metadata": {},
359 | "outputs": [],
360 | "source": []
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {},
366 | "outputs": [],
367 | "source": []
368 | }
369 | ],
370 | "metadata": {
371 | "kernelspec": {
372 | "display_name": "Python (new_netdiffusion)",
373 | "language": "python",
374 | "name": "new_netdiffusion"
375 | },
376 | "language_info": {
377 | "codemirror_mode": {
378 | "name": "ipython",
379 | "version": 3
380 | },
381 | "file_extension": ".py",
382 | "mimetype": "text/x-python",
383 | "name": "python",
384 | "nbconvert_exporter": "python",
385 | "pygments_lexer": "ipython3",
386 | "version": "3.10.6"
387 | }
388 | },
389 | "nbformat": 4,
390 | "nbformat_minor": 4
391 | }
392 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # 🌐 NetDiffusion: High-Fidelity Synthetic Network Traffic Generation
6 |
7 |
8 |
9 |
10 |
11 | ---
12 |
13 | ## 📘 Introduction
14 |
15 | **NetDiffusion** is an innovative tool designed to solve one of the core bottlenecks in networking ML research: the lack of high-quality, labeled, and privacy-preserving network traces.
16 |
17 | Traditional datasets often suffer from:
18 | - ⚠️ **Privacy concerns**
19 | - 🕓 **Data staleness**
20 | - 📉 **Limited diversity**
21 |
22 | NetDiffusion addresses these issues by using a **protocol-aware Stable Diffusion model** to synthesize network traffic that is both **realistic** and **standards-compliant**.
23 |
24 | > 🧪 The result? Synthetic packet captures that look and behave like real traffic—ideal for model training, testing, and simulation.
25 |
26 | ---
27 |
28 | ## ✨ Features
29 |
30 | - ✅ **High-Fidelity Data Generation**
31 | Generate synthetic traffic that matches real-world patterns and protocol semantics.
32 |
33 | - 🔌 **Tool Compatibility**
34 | Output traces are `.pcap` files—ready for use with Wireshark, Zeek, tshark, and other standard tools.
35 |
36 | - 🛠️ **Multi-Use Support**
37 | Beyond ML: Useful for system testing, anomaly detection, protocol emulation, and more.
38 |
39 | - 💡 **Fully Open Source**
40 | Built for the community. Modify, extend, and contribute freely.
41 |
42 | ---
43 |
44 | ## 📝 Note
45 |
46 | - The original **NetDiffusion** was implemented using **Stable Diffusion 1.5**, which is now deprecated with outdated dependencies.
47 | - This repo provides a **modern reimplementation using Stable Diffusion 3.0**, integrated with **InstantX/SD3-Controlnet-Canny**, preserving the framework’s core concepts while upgrading for compatibility and stability.
48 |
49 | ---
50 |
51 | ## 🗂 Project Structure
52 |
53 | - 🔧 All core scripts for preprocessing, training, inference, and reconstruction are located in the [`scripts/`](./scripts/) directory.
54 | - 📓 A step-by-step **Jupyter notebook** walks you through the entire pipeline:
55 |
56 | - 📦 **Dependency Installation**
57 | - 🧼 **Preprocessing (`.nprint` → `.png`)**
58 | - 🧠 **LoRA Fine-Tuning** on structured packet image embeddings
59 | - 🎨 **Diffusion-Based Generation** using ControlNet (Canny conditioning)
60 | - 🔄 **Post-Generation Processing**
61 | - Color correction
62 | - `.png` → `.nprint` → `.pcap` conversion
63 | - Replayable `.pcap` synthesis with protocol repair
64 |
65 | > ⚙️ The reimplementation is fully modular and forward-compatible, enabling seamless experimentation with next-gen diffusion architectures.
66 |
67 | ---
68 |
69 | ## 📚 Citing NetDiffusion
70 |
71 | If you use this tool or build on its techniques, please cite:
72 |
73 | ```bibtex
74 | @article{jiang2024netdiffusion,
75 | title={NetDiffusion: Network Data Augmentation Through Protocol-Constrained Traffic Generation},
76 | author={Jiang, Xi and Liu, Shinan and Gember-Jacobson, Aaron and Bhagoji, Arjun Nitin and Schmitt, Paul and Bronzino, Francesco and Feamster, Nick},
77 | journal={Proceedings of the ACM on Measurement and Analysis of Computing Systems},
78 | volume={8},
79 | number={1},
80 | pages={1--32},
81 | year={2024},
82 | publisher={ACM New York, NY, USA}
83 | }
84 | ```
--------------------------------------------------------------------------------
/scripts/color_processor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import argparse
4 | from PIL import Image
5 |
def process_image(input_file, output_file):
    """Snap every pixel of an image to pure red, green, or blue and save it.

    The image at ``input_file`` is loaded as RGBA; each pixel is replaced by
    ``(255, 0, 0, 255)``, ``(0, 255, 0, 255)`` or ``(0, 0, 255, 255)`` and the
    result is written to ``output_file``.

    Args:
        input_file: Path of the image to read.
        output_file: Path the recoloured image is saved to.
    """
    red = (255, 0, 0, 255)
    green = (0, 255, 0, 255)
    blue = (0, 0, 255, 255)

    # Per-colour (r, g, b) thresholds.
    r_red, g_red, b_red = 0, 160, 100
    r_green, g_green, b_green = 0, 160, 100
    r_blue, g_blue, b_blue = 0, 160, 100

    # Close the file handle as soon as the RGBA copy has been made.
    with Image.open(input_file) as source:
        image = source.convert('RGBA')
    width, height = image.size

    # Direct pixel access avoids the per-call overhead of getpixel/putpixel.
    pixels = image.load()

    for x in range(width):
        for y in range(height):
            r, g, b, _alpha = pixels[x, y]

            if r > r_red and g < g_red and b < b_red:
                pixels[x, y] = red
            # NOTE(review): with r_green == r_blue == 0 the `r < ...` tests
            # below can never hold for 0..255 channel values, so these two
            # branches are currently unreachable and such pixels fall through
            # to the max-channel rule. Confirm the intended thresholds.
            elif r < r_green and g > g_green and b < b_green:
                pixels[x, y] = green
            elif r < r_blue and g < g_blue and b > b_blue:
                pixels[x, y] = blue
            else:
                # Fall back to the dominant channel; ties favour red, then green.
                strongest = max(r, g, b)
                if strongest == r:
                    pixels[x, y] = red
                elif strongest == g:
                    pixels[x, y] = green
                else:
                    pixels[x, y] = blue

    image.save(output_file)
49 |
def main():
    """Recolour every ``.png`` in ``--input_dir`` and write it to ``--output_dir``.

    The output directory is created when missing. Files keep their names, and
    the number of processed images is printed at the end.
    """
    cli = argparse.ArgumentParser(
        description="Process all .png images in a directory and save them to another directory."
    )
    cli.add_argument(
        "--input_dir", "-i", required=True, help="Path to input directory containing .png images."
    )
    cli.add_argument(
        "--output_dir", "-o", required=True, help="Path to output directory for processed images."
    )
    options = cli.parse_args()

    os.makedirs(options.output_dir, exist_ok=True)

    # The extension check is case-insensitive; other entries are ignored.
    png_names = [name for name in os.listdir(options.input_dir) if name.lower().endswith(".png")]
    for name in png_names:
        process_image(
            os.path.join(options.input_dir, name),
            os.path.join(options.output_dir, name),
        )

    print(f"Processed {len(png_names)} images.")


if __name__ == "__main__":
    main()
72 |
--------------------------------------------------------------------------------
/scripts/compute_embeddings.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import glob
19 | import hashlib
20 |
21 | import pandas as pd
22 | import torch
23 | from transformers import T5EncoderModel
24 |
25 | from diffusers import StableDiffusion3Pipeline
26 |
27 |
# Default instance prompt whose embeddings are computed (overridable via --prompt).
PROMPT = "pixelated network data for type-0 application traffic"
# Default maximum sequence length used when encoding the prompt (--max_sequence_length).
MAX_SEQ_LENGTH = 77
# Default directory scanned for instance images (--local_data_dir).
LOCAL_DATA_DIR = "nprint_traffic_images"
# Default path the parquet file is serialized to (--output_path).
OUTPUT_PATH = "sample_embeddings.parquet"
32 |
33 |
def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes (1024 * 1024 * 1024 bytes each)."""
    scaled = bytes
    # Three successive divisions: bytes -> KiB -> MiB -> GiB.
    for _ in range(3):
        scaled = scaled / 1024
    return scaled
36 |
37 |
def generate_image_hash(image_path):
    """Return the SHA-256 hex digest of the raw bytes stored at ``image_path``."""
    hasher = hashlib.sha256()
    with open(image_path, "rb") as stream:
        hasher.update(stream.read())
    return hasher.hexdigest()
42 |
43 |
def load_sd3_pipeline():
    """Build the Stable Diffusion 3 pipeline used for prompt encoding.

    The T5 text encoder is loaded from the ``text_encoder_3`` subfolder with
    ``load_in_8bit=True`` and handed to the pipeline. ``transformer`` and
    ``vae`` are passed as ``None`` (presumably so they are not loaded, since
    this script only encodes prompts; confirm against the diffusers docs).

    Returns:
        The constructed ``StableDiffusion3Pipeline``.
    """
    model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
    t5_encoder = T5EncoderModel.from_pretrained(
        model_id,
        subfolder="text_encoder_3",
        load_in_8bit=True,
        device_map="auto",
    )
    return StableDiffusion3Pipeline.from_pretrained(
        model_id,
        text_encoder_3=t5_encoder,
        transformer=None,
        vae=None,
        device_map="balanced",
    )
51 |
52 |
@torch.no_grad()
def compute_embeddings(pipeline, prompt, max_sequence_length):
    """Encode ``prompt`` with ``pipeline`` and return the four resulting embeddings.

    Runs with gradient tracking disabled. The shape of each embedding and the
    peak CUDA memory allocated so far are printed for diagnostics.

    Args:
        pipeline: Object exposing ``encode_prompt`` (the SD3 pipeline in this script).
        prompt: Text prompt to encode; ``prompt_2`` and ``prompt_3`` are passed as ``None``.
        max_sequence_length: Forwarded to ``encode_prompt`` as ``max_sequence_length``.

    Returns:
        Tuple ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)`` exactly as returned by ``encode_prompt``.
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)

    # Every field uses the self-documenting `=` form so all four shapes are labelled.
    print(
        f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, "
        f"{pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape=}"
    )

    max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
    print(f"Max memory allocated: {max_memory:.3f} GB")
    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
69 |
70 |
def run(args):
    """Compute prompt embeddings once and store them per image hash in a parquet file.

    Every ``.png`` directly inside ``args.local_data_dir`` becomes one row with
    the file's SHA-256 hash and the four (flattened) prompt embeddings. All
    rows share the same embeddings because a single prompt is used.

    Args:
        args: Parsed CLI namespace providing ``prompt``, ``max_sequence_length``,
            ``local_data_dir`` and ``output_path``.
    """
    pipeline = load_sd3_pipeline()
    embeddings = compute_embeddings(pipeline, args.prompt, args.max_sequence_length)

    embedding_cols = [
        "prompt_embeds",
        "negative_prompt_embeds",
        "pooled_prompt_embeds",
        "negative_pooled_prompt_embeds",
    ]

    # The embeddings are identical for every image, so flatten each tensor to a
    # plain list a single time instead of once per row (lists are what get
    # stored in the parquet file).
    flat_embeddings = [tensor.cpu().numpy().flatten().tolist() for tensor in embeddings]

    # Assumes that the images within `args.local_data_dir` have a png extension.
    # Change as needed. Sorted so the row order is deterministic.
    image_paths = sorted(glob.glob(f"{args.local_data_dir}/*.png"))
    if not image_paths:
        print(f"Warning: no .png images found in {args.local_data_dir}; the parquet file will contain no rows.")

    data = [(generate_image_hash(image_path), *flat_embeddings) for image_path in image_paths]

    # Create a DataFrame
    df = pd.DataFrame(
        data,
        columns=["image_hash"] + embedding_cols,
    )

    # Save the dataframe to a parquet file
    df.to_parquet(args.output_path)
    print(f"Data successfully serialized to {args.output_path}")
106 |
107 |
if __name__ == "__main__":
    # Command-line entry point; each option falls back to a module-level default.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default=PROMPT,
        help="The instance prompt.",
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=MAX_SEQ_LENGTH,
        help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
    )
    parser.add_argument(
        "--local_data_dir",
        type=str,
        default=LOCAL_DATA_DIR,
        help="Path to the directory containing instance images.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=OUTPUT_PATH,
        help="Path to serialize the parquet file.",
    )
    args = parser.parse_args()

    run(args)
124 |
--------------------------------------------------------------------------------
/scripts/image_to_nprint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import pandas as pd
5 | from PIL import Image
6 |
def ip_to_binary(ip_address):
    """Convert a dotted-decimal IPv4 address into its 32-character bit string.

    Args:
        ip_address: IPv4 address such as ``"192.168.1.1"``.

    Returns:
        A string of 32 ``"0"``/``"1"`` characters, eight per octet, in the
        order the octets appear in ``ip_address``.

    Raises:
        ValueError: If the address does not have exactly four octets, or an
            octet is not an integer in the range 0..255.
    """
    octets = ip_address.split(".")
    if len(octets) != 4:
        raise ValueError("Input IP address must have exactly four octets")

    binary_octets = []
    for octet in octets:
        value = int(octet)  # raises ValueError on non-numeric text
        if not 0 <= value <= 255:
            raise ValueError("Each octet must be in the range 0-255")
        # Zero-pad so every octet contributes exactly eight bits.
        binary_octets.append(format(value, "08b"))
    return "".join(binary_octets)
13 |
def binary_to_ip(binary_ip_address):
    """Convert a 32-character bit string into a dotted-decimal IPv4 address.

    Raises:
        ValueError: If the input is not exactly 32 characters long (or a chunk
            cannot be parsed as a base-2 integer).
    """
    if len(binary_ip_address) != 32:
        raise ValueError("Input binary string must be 32 bits")

    parts = []
    # Walk the string in four 8-bit chunks, most significant octet first.
    for start in (0, 8, 16, 24):
        chunk = binary_ip_address[start : start + 8]
        parts.append(str(int(chunk, 2)))
    return ".".join(parts)
20 |
def rgba_to_ip(rgba):
    """Join the components of ``rgba`` with dots, e.g. ``(10, 0, 0, 1)`` -> ``"10.0.0.1"``."""
    return ".".join(str(component) for component in rgba)
24 |
def int_to_rgba(A):
    """Map an integer feature value to an RGBA colour tuple.

    ``1`` -> opaque red, ``0`` -> opaque green, ``-1`` -> opaque blue.
    Values above 1 become red with alpha ``A``; values below -1 become blue
    with alpha ``abs(A)``. Anything else yields ``None``.
    """
    if A == 1:
        return (255, 0, 0, 255)
    if A == 0:
        return (0, 255, 0, 255)
    if A == -1:
        return (0, 0, 255, 255)
    if A > 1:
        # Magnitude is carried in the alpha channel.
        return (255, 0, 0, A)
    if A < -1:
        return (0, 0, 255, abs(A))
    return None
38 |
def rgba_to_int(rgba):
    """Map an RGBA colour back to its integer feature value.

    Opaque red/green/blue map to ``1``/``0``/``-1``. Red with any other alpha
    returns that alpha; blue with any other alpha returns the negated alpha.
    Every other colour returns ``None``.
    """
    # Exact matches for the three fully opaque primaries come first.
    if rgba == (255, 0, 0, 255):
        return 1
    if rgba == (0, 255, 0, 255):
        return 0
    if rgba == (0, 0, 255, 255):
        return -1

    # Otherwise the alpha channel carries the magnitude of the value.
    is_red = rgba[0] == 255 and rgba[1] == 0 and rgba[2] == 0
    if is_red:
        return rgba[3]
    is_blue = rgba[0] == 0 and rgba[1] == 0 and rgba[2] == 255
    if is_blue:
        return -rgba[3]
    return None
53 |
def split_bits(s):
    """Return a list with each character of ``s`` converted to ``int``."""
    return list(map(int, s))
56 |
def png_to_dataframe(input_file, columns):
    """
    Convert a PNG to a DataFrame, using columns that match the reference .nprint CSV.

    Each image row becomes one DataFrame row and each pixel is decoded with
    ``rgba_to_int``, so a cell holds an integer or ``None``.

    Args:
        input_file: Path of the PNG image to decode.
        columns: Column names for the DataFrame; their count has to equal the
            image width, otherwise pandas raises when building the frame.

    Returns:
        A ``pandas.DataFrame`` with one row per image row.
    """
    # Close the file handle once the RGBA copy exists.
    with Image.open(input_file) as source:
        img = source.convert("RGBA")
    width, height = img.size
    print(f"Processing {input_file} with size {width} x {height}")

    # Direct pixel access is much cheaper than calling getpixel per pixel.
    pixels = img.load()
    data = [[rgba_to_int(pixels[x, y]) for x in range(width)] for y in range(height)]

    # Construct a DataFrame using the provided column names
    return pd.DataFrame(data, columns=columns)
77 |
def main():
    """Convert every ``.png`` in ``--input_dir`` into a ``.nprint`` CSV in ``--output_dir``.

    Column names are taken from the header of the reference file given by
    ``--org_nprint`` (an ``"Unnamed: 0"`` column, if present, is dropped).
    Each output file keeps the image's base name with a ``.nprint`` extension.
    """
    parser = argparse.ArgumentParser(description="Convert .png files back into .nprint format.")
    parser.add_argument(
        "--org_nprint",
        required=True,
        help="Path to the original .nprint CSV file for column reference."
    )
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Directory containing the .png images to convert."
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Directory to save the resulting .nprint files."
    )
    args = parser.parse_args()

    # 1) Only the header of the reference .nprint CSV is needed, so skip the data rows.
    org_df = pd.read_csv(args.org_nprint, nrows=0)
    # If there's an unwanted "Unnamed: 0" column, remove it
    if "Unnamed: 0" in org_df.columns:
        org_df = org_df.drop("Unnamed: 0", axis=1)

    # Extract the columns to preserve the same structure
    columns = org_df.columns.tolist()

    # 2) Make sure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # 3) Iterate over .png files in input_dir (extension match is case-insensitive)
    converted_count = 0
    for filename in os.listdir(args.input_dir):
        if not filename.lower().endswith(".png"):
            continue

        input_file = os.path.join(args.input_dir, filename)
        df = png_to_dataframe(input_file, columns=columns)

        # 4) Save as .nprint in output_dir. splitext swaps only the final
        #    extension and also handles upper-case ".PNG", which a plain
        #    str.replace(".png", ...) would leave untouched.
        stem, _ext = os.path.splitext(filename)
        output_file = os.path.join(args.output_dir, stem + ".nprint")
        df.to_csv(output_file, index=False)
        converted_count += 1
        print(f"Saved {output_file}")

    print(f"Done! Converted {converted_count} .png files into .nprint format.")


if __name__ == "__main__":
    main()
126 |
--------------------------------------------------------------------------------
/scripts/mass_reconstruction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import os
4 | import sys
5 |
6 |
def main():
    """Run ``reconstruction.py`` once for every ``.nprint`` file in a directory.

    Command-line arguments (all optional):
        --input_dir:             directory with generated .nprint files.
        --output_pcap_dir:       directory receiving the rebuilt .pcap files.
        --output_nprint_dir:     directory receiving the rewritten .nprint files.
        --formatted_nprint_path: reference 'correct_format.nprint' file.

    A failure on one file is reported and the loop moves on to the next file.
    Exits with status 0 when the input directory is empty.
    """
    # Local imports: only this function needs them.
    import shlex
    import subprocess

    parser = argparse.ArgumentParser(description="Generate pcap files from .nprint using reconstruction.py.")
    parser.add_argument(
        "--input_dir",
        default="./generated_nprint",
        help="Directory containing generated .nprint files to process."
    )
    parser.add_argument(
        "--output_pcap_dir",
        default="./replayable_generated_pcaps",
        help="Directory to save the resulting .pcap files."
    )
    parser.add_argument(
        "--output_nprint_dir",
        default="./replayable_generated_nprints",
        help="Directory to save any newly generated .nprint files."
    )
    parser.add_argument(
        "--formatted_nprint_path",
        default="./correct_format.nprint",
        help="Path to the 'correct_format.nprint' reference file."
    )
    args = parser.parse_args()

    # Create the output directories if they don't exist
    os.makedirs(args.output_pcap_dir, exist_ok=True)
    os.makedirs(args.output_nprint_dir, exist_ok=True)

    input_files = os.listdir(args.input_dir)
    if not input_files:
        print(f"No files found in {args.input_dir}")
        sys.exit(0)

    nprint_suffix = ".nprint"
    for org_nprint in input_files:
        if not org_nprint.endswith(nprint_suffix):
            # skip non-nprint files
            continue

        org_nprint_path = os.path.join(args.input_dir, org_nprint)
        # Swap only the trailing extension (str.replace would also rewrite
        # ".nprint" appearing in the middle of a name).
        output_pcap_path = os.path.join(
            args.output_pcap_dir,
            org_nprint[:-len(nprint_suffix)] + ".pcap",
        )
        output_nprint_path = os.path.join(
            args.output_nprint_dir,
            org_nprint,
        )

        print(f"\nProcessing: {org_nprint_path} -> {output_pcap_path}")

        if not os.path.isfile(org_nprint_path):
            print(f"Skipping: {org_nprint_path} is not a valid file.")
            continue

        # Pass arguments as a list (no shell): file names containing quotes,
        # spaces or shell metacharacters can neither break nor inject into the
        # command. The child runs under the same interpreter as this script.
        cmd = [
            sys.executable or "python3",
            "./scripts/reconstruction.py",
            "--generated_nprint_path", org_nprint_path,
            "--formatted_nprint_path", args.formatted_nprint_path,
            "--output", output_pcap_path,
            "--nprint", output_nprint_path,
        ]

        printable_cmd = " ".join(shlex.quote(part) for part in cmd)
        print(f"Running command:\n  {printable_cmd}")
        try:
            return_code = subprocess.run(cmd).returncode
        except OSError as e:
            # The interpreter itself could not be launched.
            print(f"Exception while running reconstruction command: {e}")
            print("Skipping this file.")
            continue

        if return_code != 0:
            # Non-zero exit code indicates an error
            print(f"ERROR: reconstruction.py failed with return code {return_code}")
            print("Skipping this file.")
            continue

        print(f"Success! Created {output_pcap_path} and possibly updated {output_nprint_path}.")


if __name__ == "__main__":
    main()
97 |
--------------------------------------------------------------------------------
/scripts/nprint_to_png.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pandas as pd
4 | import numpy as np
5 | from PIL import Image
6 |
def int_to_rgba(A):
    """Map one nprint cell value to an ``(R, G, B, A)`` tuple.

    1 -> opaque red, 0 -> opaque green, -1 -> opaque blue. Values above 1 are
    red with the value itself as alpha; values below -1 are blue with the
    absolute value as alpha. Anything else (e.g. NaN or a fraction) falls back
    to opaque black.
    """
    if A == 1:
        return (255, 0, 0, 255)
    if A == 0:
        return (0, 255, 0, 255)
    if A == -1:
        return (0, 0, 255, 255)
    if A > 1:
        return (255, 0, 0, A)
    if A < -1:
        return (0, 0, 255, abs(A))
    # default fallback
    return (0, 0, 0, 255)
20 |
def dataframe_to_png(df, output_file):
    """Render a DataFrame of RGBA tuples as a PNG padded to 1024 rows.

    Each cell of ``df`` must already hold an RGBA 4-tuple. Rows beyond the
    DataFrame's height are filled with opaque blue. If ``output_file`` already
    exists, ``_1``, ``_2``, ... is appended to the stem until a free name is
    found.
    """
    n_rows, n_cols = df.shape
    padded_height = 1024

    # Start from an all-blue canvas and paste the real rows on top.
    canvas = np.full((padded_height, n_cols, 4), (0, 0, 255, 255), dtype=np.uint8)
    pixels = np.array(df.applymap(np.array).to_numpy().tolist())
    canvas[:n_rows, :, :] = pixels

    img = Image.fromarray(canvas, 'RGBA')

    # Never overwrite an existing image: probe numbered variants of the name.
    stem, extension = os.path.splitext(output_file)
    suffix = 1
    while os.path.isfile(output_file):
        output_file = f"{stem}_{suffix}{extension}"
        suffix += 1

    img.save(output_file)
42 |
def convert_nprint_to_png(nprint_dir, output_dir):
    """Convert every ``.nprint`` CSV in ``nprint_dir`` to a PNG in ``output_dir``.

    IP-address columns are dropped, each remaining cell is mapped to an RGBA
    colour, and the result is written as an image. Empty files are skipped;
    a failure on one file is printed and processing continues.
    """
    os.makedirs(output_dir, exist_ok=True)
    ip_markers = ['ipv4_src', 'ipv4_dst', 'ipv6_src', 'ipv6_dst', 'src_ip']

    for entry in os.listdir(nprint_dir):
        if not entry.endswith(".nprint"):
            continue

        print(f"Processing {entry}")
        source_path = os.path.join(nprint_dir, entry)
        try:
            frame = pd.read_csv(source_path)
            if frame.empty:
                continue

            # Address bits are not part of the image representation.
            ip_columns = [
                name for name in frame.columns
                if any(marker in name for marker in ip_markers)
            ]
            frame = frame.drop(columns=ip_columns)

            for name in frame.columns:
                frame[name] = frame[name].apply(int_to_rgba)

            target_path = os.path.join(output_dir, entry.replace('.nprint', '.png'))
            dataframe_to_png(frame, target_path)

        except Exception as e:
            print(f"❌ Failed to process {entry}: {e}")
            continue
66 |
def main():
    """Parse ``--input_dir`` / ``--output_dir`` and run the nprint-to-PNG conversion."""
    cli = argparse.ArgumentParser(description="Convert .nprint files to PNG images.")
    cli.add_argument("--input_dir", "-i", required=True, help="Directory containing .nprint files")
    cli.add_argument("--output_dir", "-o", required=True, help="Directory to save PNG images")
    options = cli.parse_args()

    convert_nprint_to_png(options.input_dir, options.output_dir)


if __name__ == "__main__":
    main()
81 |
--------------------------------------------------------------------------------
/scripts/reconstruction.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import subprocess
3 | import numpy as np
4 | import sys
5 | import argparse
6 | import random
7 | import math
8 | from collections import Counter
9 | from scapy.all import *
10 | from scapy.utils import PcapWriter
11 | from scapy.all import rdpcap, wrpcap
12 | from scapy.layers.inet import IP, TCP
13 |
def update_ipv4_checksum(pcap_file, output_file):
    """Rewrite a pcap so that IPv4 header checksums are recomputed.

    Args:
        pcap_file: path of the pcap to read.
        output_file: path of the pcap to write, using link type 101 (DLT_RAW).

    The stored checksum is deleted from every IPv4 packet; Scapy recalculates
    it when the packet is serialized for writing.
    """
    # Read the packets from the pcap file
    packets = rdpcap(pcap_file)

    # Create a PcapWriter instance with the desired DLT value (DLT_RAW)
    writer = PcapWriter(output_file, linktype=101, sync=True)
    try:
        for packet in packets:
            if packet.haslayer(IP):
                # Dropping the field makes Scapy compute a fresh checksum
                # at serialization time.
                del packet[IP].chksum

            # Write packet to the output file with desired DLT
            writer.write(packet)
    finally:
        # Release the file handle even when a packet fails to serialize.
        writer.close()
35 |
def binary_to_decimal(binary_list):
    """Interpret a sequence of 0/1 values (most significant bit first) as an integer."""
    digits = "".join(str(bit) for bit in binary_list)
    return int(digits, 2)
40 |
def random_ip():
    """Return a random dotted-quad IPv4 address string (each octet in 0..255)."""
    octets = []
    for _ in range(4):
        octets.append(str(random.randint(0, 255)))
    return '.'.join(octets)
43 |
def read_syn_nprint(generated_nprint_path):
    """Load a generated .nprint CSV, discarding any leftover ``Unnamed: 0`` index column."""
    frame = pd.read_csv(generated_nprint_path)
    unwanted_markers = ['Unnamed: 0']
    stray_columns = [
        name for name in frame.columns
        if any(marker in name for marker in unwanted_markers)
    ]
    return frame.drop(columns=stray_columns)
52 |
def encode_ip(ip):
    """Return the 32-character bit string of a dotted-quad IPv4 address."""
    bit_groups = []
    for octet in ip.split('.'):
        bit_groups.append(format(int(octet), '08b'))
    return ''.join(bit_groups)
55 |
def reconstruction_to_pcap(formatted_generated_nprint_path, rebuilt_pcap_path):
    """Invoke the external ``nprint`` tool to rebuild a pcap from an .nprint file.

    Args:
        formatted_generated_nprint_path: .nprint file passed to ``nprint -N``.
        rebuilt_pcap_path: destination pcap passed to ``nprint -W``.

    The exit status of ``nprint`` is not checked. ``FileNotFoundError`` is
    raised when the ``nprint`` executable is not on PATH.
    """
    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through verbatim.
    subprocess.run(['nprint', '-N', str(formatted_generated_nprint_path), '-W', str(rebuilt_pcap_path)])
58 |
def ip_address_formatting(generated_nprint, synthetic_sequence):
    """Assign source/destination addresses per packet and expand them to bit columns.

    Reads the module-level ``args.src_ip`` / ``args.dst_ip``; the value
    '0.0.0.0' means "pick a random address". For each position in
    ``synthetic_sequence`` a 0 keeps the src->dst direction and any other value
    swaps the two endpoints. Adds ``ipv4_src_0..31`` and ``ipv4_dst_0..31`` and
    returns the frame with the helper columns and ``dst_ip`` removed
    (``src_ip`` is kept).
    """
    # Resolve the two endpoints (source first, so random draws keep their order).
    src_endpoint = random_ip() if args.src_ip == '0.0.0.0' else args.src_ip
    dst_endpoint = random_ip() if args.dst_ip == '0.0.0.0' else args.dst_ip

    # Walk the direction sequence and the dataframe rows together.
    for row_idx, direction in enumerate(synthetic_sequence):
        if direction == 0:
            sender, receiver = src_endpoint, dst_endpoint
        else:
            sender, receiver = dst_endpoint, src_endpoint
        generated_nprint.at[row_idx, 'src_ip'] = sender
        generated_nprint.at[row_idx, 'dst_ip'] = receiver

    # Derive the binary encoding of both addresses.
    generated_nprint['src_binary_ip'] = generated_nprint['src_ip'].apply(encode_ip)
    generated_nprint['dst_binary_ip'] = generated_nprint['dst_ip'].apply(encode_ip)

    # Split each 32-character bit string into one int8 column per bit.
    for side in ('src', 'dst'):
        for bit_pos in range(32):
            generated_nprint[f'ipv4_{side}_{bit_pos}'] = (
                generated_nprint[f'{side}_binary_ip']
                .apply(lambda bits: bits[bit_pos])
                .astype(np.int8)
            )

    # The helper columns and dst_ip are no longer needed.
    return generated_nprint.drop(columns=['src_binary_ip', 'dst_binary_ip', 'dst_ip'])
97 |
def ipv4_hl_formatting(generated_nprint, formatted_nprint):
    """Set ``ipv4_hl_0..3`` from the number of populated IPv4 bits in each row.

    For every row, the cells equal to 0 or 1 in columns whose name contains
    'ipv4' are counted, converted to 32-bit words (rounded up) and written as a
    4-bit value. ``formatted_nprint`` is accepted but not used.
    """
    ipv4_view = generated_nprint.filter(like='ipv4')
    for row_label, cells in ipv4_view.iterrows():
        populated_bits = ((cells == 1) | (cells == 0)).sum()
        # Header length is expressed in 32-bit/4-byte words.
        word_count = math.ceil(populated_bits / 32)
        encoded = format(word_count, '04b')
        for position, digit in zip(range(4), encoded):
            generated_nprint.at[row_label, f'ipv4_hl_{position}'] = int(digit)
    return generated_nprint
115 |
def ipv4_tl_formatting_tcp(generated_nprint, formatted_nprint):
    """Clamp each packet's IPv4 total length to a plausible range.

    For every row the 16-bit ``ipv4_tl`` value is raised to at least
    IPv4-header bytes + TCP-header bytes (``ipv4_hl`` and ``tcp_doff`` are in
    4-byte words) and capped at 1500 bytes; the result is written back into
    ``ipv4_tl_0..15``. ``formatted_nprint`` is accepted but not used.

    Returns the same DataFrame, modified in place.
    """
    max_total_length = 1500

    for idx, row in generated_nprint.iterrows():
        # Decode the three fields from their per-bit columns.
        ipv4_tl_value = binary_to_decimal([row[f'ipv4_tl_{i}'] for i in range(16)])
        ipv4_hl_value = binary_to_decimal([row[f'ipv4_hl_{i}'] for i in range(4)]) * 4  # words -> bytes
        tcp_doff_value = binary_to_decimal([row[f'tcp_doff_{i}'] for i in range(4)]) * 4  # words -> bytes

        headers_length = ipv4_hl_value + tcp_doff_value
        if ipv4_tl_value < headers_length:
            # A packet cannot be shorter than its own headers.
            new_ipv4_tl_value = headers_length
        elif ipv4_tl_value > max_total_length:
            new_ipv4_tl_value = max_total_length
        else:
            new_ipv4_tl_value = ipv4_tl_value

        # Always write the bits back so the cells end up as plain ints.
        for i, bit in enumerate(format(new_ipv4_tl_value, '016b')):
            generated_nprint.at[idx, f'ipv4_tl_{i}'] = int(bit)

    # The former second pass over all rows only re-decoded the same fields for
    # commented-out debug prints; it had no effect and has been removed.
    return generated_nprint
168 |
def ipv4_ver_formatting(generated_nprint, formatted_nprint):
    """Force the IP version bits to 0100 (IPv4).

    Every column of ``formatted_nprint`` whose name contains 'ipv4_ver' is set
    in ``generated_nprint``: a name containing '0' gets 0, otherwise a name
    containing '1' gets 1, everything else gets 0.
    """
    for name in formatted_nprint.columns:
        if "ipv4_ver" not in name:
            continue
        # Only the bit whose name carries '1' (and no '0') is set.
        is_set_bit = '0' not in name and '1' in name
        generated_nprint[name] = 1 if is_set_bit else 0

    return generated_nprint
191 |
def protocol_determination(generated_nprint):
    """Return 'tcp', 'udp' or 'icmp', whichever has the largest share of non-negative cells.

    Only columns whose name contains the protocol and does not contain 'opt'
    are considered. A protocol without such columns scores 0; ties go to the
    earlier protocol in the order tcp, udp, icmp.
    """
    percentages = {}

    for protocol in ("tcp", "udp", "icmp"):
        header_columns = [
            name for name in generated_nprint.columns
            if protocol in name and 'opt' not in name
        ]

        cells_seen = sum(len(generated_nprint[name]) for name in header_columns)
        cells_populated = sum((generated_nprint[name] >= 0).sum() for name in header_columns)

        if cells_seen > 0:
            percentages[protocol] = (cells_populated / cells_seen) * 100
        else:
            percentages[protocol] = 0

    # max() keeps the first key on ties, i.e. the dict's insertion order.
    return max(percentages, key=percentages.get)
216 |
def ipv4_pro_formatting(generated_nprint, formatted_nprint):
    """Write the IPv4 protocol field for the dominating protocol and blank the others.

    The dominating protocol (from ``protocol_determination``) is printed, its
    8-bit protocol number is written into the 'ipv4_pro' columns, and every
    column belonging to a non-dominating protocol is set to -1.
    """
    # 8-bit IP protocol numbers, most significant bit first.
    protocol_bits = {
        'tcp': (0, 0, 0, 0, 0, 1, 1, 0),
        'udp': (0, 0, 0, 1, 0, 0, 0, 1),
        'icmp': (0, 0, 0, 0, 0, 0, 0, 1),
    }

    dominating_protocol = protocol_determination(generated_nprint)
    print(dominating_protocol)

    bit_pattern = protocol_bits.get(dominating_protocol)
    for name in formatted_nprint.columns:
        if "ipv4_pro" not in name or bit_pattern is None:
            continue
        # The first matching '_<n>' suffix decides which bit this column holds.
        for position, bit in enumerate(bit_pattern):
            if f'_{position}' in name:
                generated_nprint[name] = bit
                break

    # Make sure non-dominant-protocol values are -1s.
    for name in formatted_nprint.columns:
        for protocol in ("tcp", "udp", "icmp"):
            if protocol in name and protocol != dominating_protocol:
                generated_nprint[name] = -1

    return generated_nprint
298 |
def ipv4_header_negative_removal(generated_nprint, formatted_nprint):
    """Replace -1 cells in the IPv4 header columns with random bits.

    Applies to every column of ``formatted_nprint`` whose name contains 'ipv4'
    but not 'opt'. Each -1 is replaced independently by 0 or 1.
    """
    def fill_missing_bit(cell):
        # Generates either 0 or 1 for a vacant cell, keeps everything else.
        return np.random.randint(0, 2) if cell == -1 else cell

    header_columns = [
        name for name in formatted_nprint.columns
        if "ipv4" in name and 'opt' not in name
    ]
    for name in header_columns:
        generated_nprint[name] = generated_nprint[name].apply(fill_missing_bit)

    return generated_nprint
321 |
def ipv4_option_removal(generated_nprint, formatted_nprint):
    """Blank out all IPv4 option bits.

    Every column of ``formatted_nprint`` whose name contains 'ipv4_opt' is set
    to -1 in ``generated_nprint``.
    """
    option_columns = [name for name in formatted_nprint.columns if "ipv4_opt" in name]
    for name in option_columns:
        generated_nprint[name] = -1

    return generated_nprint
337 |
def ipv4_ttl_ensure(generated_nprint, formatted_nprint):
    """Make sure no packet has a TTL of zero.

    Rows whose eight ``ipv4_ttl`` bits are all 0 get ``ipv4_ttl_7`` set to 1.
    Rows are addressed by the labels 0..len-1. ``formatted_nprint`` is accepted
    but not used.
    """
    ttl_columns = [f'ipv4_ttl_{bit}' for bit in range(8)]

    for row in range(len(generated_nprint)):
        zero_flags = [generated_nprint.at[row, name] == 0 for name in ttl_columns]
        if all(zero_flags):
            generated_nprint.at[row, 'ipv4_ttl_7'] = 1

    return generated_nprint
348 |
def tcp_header_negative_removal(generated_nprint, formatted_nprint):
    """Replace -1 cells in the TCP header columns with random bits.

    Applies to every column of ``formatted_nprint`` whose name contains 'tcp'
    but not 'opt'. Each -1 is replaced independently by 0 or 1.
    """
    def fill_missing_bit(cell):
        # Generates either 0 or 1 for a vacant cell, keeps everything else.
        return np.random.randint(0, 2) if cell == -1 else cell

    header_columns = [
        name for name in formatted_nprint.columns
        if "tcp" in name and 'opt' not in name
    ]
    for name in header_columns:
        generated_nprint[name] = generated_nprint[name].apply(fill_missing_bit)

    return generated_nprint
371 |
def modify_tcp_option(packet):
    """Rewrite one packet's TCP option bits so they form well-structured options.

    ``packet`` is one row (label-indexed) containing ``tcp_opt_0`` ..
    ``tcp_opt_319``. The 320 option bits are scanned left to right: each run of
    cells that are not -1 is measured, matched to the closest known option bit
    length, and overwritten with that option's kind/length bytes followed by the
    existing payload bits (any -1 in the payload becomes a random bit). Every
    option except NOP may be emitted at most once per packet.

    Returns the same row with the option cells replaced.
    """
    # This function processes each packet of the dataframe and modifies the TCP option fields to align with the actual structure of the TCP options.
    option_data = packet.loc['tcp_opt_0':'tcp_opt_319'].to_numpy()
    idx = 0
    # Candidate option sizes in bits: 0 = vacant slot, 8 = NOP/EOL, 32 = MSS,
    # 24 = Window Scale, 16 = SACK Permitted, 40 = SACK, 80 = Timestamp.
    options_lengths = [0, 8, 32, 24, 16, 40, 80]  # NOP/EOL, MSS, Window Scale, SACK Permitted, SACK, Timestamp

    while idx < 320:
        # Measure the run of populated (non -1) cells starting at idx.
        start_idx = idx
        end_idx = idx
        while end_idx < 320 and option_data[end_idx] != -1:
            end_idx += 1
        length = end_idx - start_idx
        # Pick the still-available option size nearest to the run length
        # (on ties min() returns the earlier entry of options_lengths).
        closest_option = min(options_lengths, key=lambda x: abs(x - length))

        if closest_option == 32:  # MSS
            #print('mss')
            idx += 32
            # Kind = 2, length = 4, followed by the 16 existing payload bits.
            mss_data = np.concatenate(([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0], option_data[start_idx+16:idx]))
            mss_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in mss_data]
            option_data[start_idx:idx] = mss_data
            # Each option type is used at most once.
            options_lengths.remove(closest_option)
        elif closest_option == 24:  # Window Scale
            #print('ws')
            idx += 24
            # Kind = 3, length = 3, followed by the 8 existing payload bits.
            ws_data = np.concatenate(([0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1], option_data[start_idx+16:idx]))
            ws_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in ws_data]
            option_data[start_idx:idx] = ws_data
            options_lengths.remove(closest_option)
        elif closest_option == 16:  # SACK Permitted
            #print('sack permitted')
            idx += 16
            # Kind = 4, length = 2, no payload.
            option_data[start_idx:idx] = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
            options_lengths.remove(closest_option)

        elif closest_option == 40:  # SACK (Assuming one block for simplicity)
            # Assuming the length would be for one SACK block: kind (1 byte), length (1 byte, value 10 for one block), and 8 bytes of data.
            # NOTE(review): the length byte written here is 10 (bytes), yet only
            # 40 bits (5 bytes) are consumed — confirm which is intended.
            idx+=40
            sack_data = np.concatenate(([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0], option_data[start_idx+16:idx]))
            sack_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in sack_data]
            option_data[start_idx:idx] = sack_data
            options_lengths.remove(closest_option)

        elif closest_option == 80:  # Timestamp
            #print('time stamp')
            idx += 80
            # Kind = 8, length = 10, followed by the 64 existing payload bits.
            ts_data = np.concatenate(([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0], option_data[start_idx+16:idx]))
            ts_data = [np.random.choice([0, 1]) if bit == -1 else bit for bit in ts_data]
            option_data[start_idx:idx] = ts_data
            options_lengths.remove(closest_option)

        elif closest_option == 8:  # single-byte option: EOL or NOP
            #print('eol/nop')
            # The first bit of the run decides between EOL (0) and NOP (1).
            # NOTE(review): if that bit is neither 0 nor 1, no branch below
            # advances idx and the outer loop would not terminate — confirm
            # inputs are restricted to -1/0/1.
            if option_data[start_idx] == 0:  # EOL
                if start_idx == 0:
                    # An EOL as the very first option is dropped (made vacant).
                    idx += 8
                    option_data[start_idx:idx] = [-1,-1,-1,-1,-1,-1,-1,-1]
                    options_lengths.remove(closest_option)
                    continue
                else:
                    # EOL ends the option list: everything after it is vacated.
                    idx += 8
                    option_data[start_idx:idx] = [0,0,0,0,0,0,0,0]
                    option_data[idx:] = [-1] * (320 - idx)
                    options_lengths.remove(closest_option)
                    break
            elif option_data[start_idx] == 1:  # NOP
                # NOP (kind = 1) may repeat, so 8 stays in options_lengths.
                idx += 8
                option_data[start_idx:idx] = [0,0,0,0,0,0,0,1]
        elif closest_option == 0:
            # Run too short to be an option: vacate one byte and move on.
            idx += 8
            option_data[start_idx:idx] = [-1,-1,-1,-1,-1,-1,-1,-1]


    # Assign back the modified options to the DataFrame's row
    packet.loc['tcp_opt_0':'tcp_opt_319'] = option_data
    return packet
447 |
def tcp_opt_formatting(generated_nprint, formatted_nprint):
    """Apply ``modify_tcp_option`` to every row and return the resulting frame.

    ``formatted_nprint`` is accepted but not used.
    """
    return generated_nprint.apply(modify_tcp_option, axis=1)
451 |
def tcp_data_offset_calculation(generated_nprint, formatted_nprint):
    """Set ``tcp_doff_0..3`` from the number of populated TCP bits in each row.

    For every row, the cells equal to 0 or 1 in columns whose name contains
    'tcp' are counted, converted to 32-bit words (rounded up) and written as a
    4-bit value. ``formatted_nprint`` is accepted but not used.
    """
    tcp_view = generated_nprint.filter(like='tcp')
    for row_label, cells in tcp_view.iterrows():
        populated_bits = ((cells == 1) | (cells == 0)).sum()
        # Data offset is expressed in 32-bit/4-byte words.
        word_count = math.ceil(populated_bits / 32)
        encoded = format(word_count, '04b')
        # Update the 'tcp_doff' columns in the original DataFrame.
        for position, digit in zip(range(4), encoded):
            generated_nprint.at[row_label, f'tcp_doff_{position}'] = int(digit)

    return generated_nprint
468 |
def src_ip_distribution(formatted_nprint_path):
    """Return ``{src_ip: share}`` for the 'src_ip' column of the CSV at the given path.

    Shares are fractions of the counted rows and sum to 1.
    """
    reference = pd.read_csv(formatted_nprint_path)
    occurrences = reference['src_ip'].value_counts()
    total = occurrences.sum()
    shares = occurrences / total
    return shares.to_dict()
479 |
def direction_sampleing(formatted_nprint_path, sequence_length=1024):
    """Sample a synthetic packet-direction sequence from a reference capture.

    The two most frequent source IPs in the reference file's 'src_ip' column
    are treated as the two endpoints. A first-order Markov chain over "who
    sends next" is estimated from consecutive packets between those two IPs and
    then sampled.

    Args:
        formatted_nprint_path: CSV (.nprint) file with a 'src_ip' column.
        sequence_length: number of directions to generate (default 1024).

    Returns:
        A list of ints; 0 means the endpoint that appears first in the
        reference file is the sender, 1 means the other endpoint. The first
        element is always 0.

    Raises:
        ValueError: if the reference file contains no source IP at all.
    """
    if sequence_length < 1:
        return []

    # Read the reference file once; everything below works on this column.
    src_ips = pd.read_csv(formatted_nprint_path)['src_ip']

    # value_counts() is sorted by frequency, so the first two labels are the
    # dominant endpoints.
    top_two_ips = set(src_ips.value_counts().index[:2])

    # State 0 / state 1 follow the order of first appearance in the capture.
    ordered_ips = [ip for ip in src_ips.drop_duplicates() if ip in top_two_ips]
    if not ordered_ips:
        raise ValueError(f"No source IPs found in {formatted_nprint_path}")
    if len(ordered_ips) == 1:
        # Only one sender was observed: every packet flows in the same direction.
        return [0] * sequence_length

    state_of = {ordered_ips[0]: 0, ordered_ips[1]: 1}

    # Count transitions between consecutive packets of the two endpoints.
    observed = src_ips.tolist()
    transition_counts = np.zeros((2, 2))
    for previous_ip, current_ip in zip(observed, observed[1:]):
        if previous_ip in state_of and current_ip in state_of:
            transition_counts[state_of[previous_ip], state_of[current_ip]] += 1

    # Overall share of each endpoint; used when a state has no outgoing
    # transitions (e.g. an endpoint seen only as the last packet), because an
    # all-zero probability row makes np.random.choice raise.
    occupancy = np.array(
        [sum(1 for ip in observed if state_of.get(ip) == state) for state in (0, 1)],
        dtype=float,
    )
    occupancy /= occupancy.sum()

    transition_matrix = np.empty((2, 2))
    for state in (0, 1):
        outgoing = transition_counts[state].sum()
        if outgoing > 0:
            transition_matrix[state] = transition_counts[state] / outgoing
        else:
            transition_matrix[state] = occupancy

    # Start with the first endpoint as the source, then walk the chain.
    current_state = 0
    synthetic_sequence = [current_state]
    for _ in range(sequence_length - 1):
        current_state = int(np.random.choice([0, 1], p=transition_matrix[current_state]))
        synthetic_sequence.append(current_state)
    return synthetic_sequence
542 |
def id_num_initialization_src_dst(generated_nprint):
    """Assign incrementing IPv4 IDs to packets sent by the first row's source.

    Row 0 receives a random 16-bit ID. Every later row whose 'src_ip' equals
    row 0's gets the previous ID plus one, written into ``ipv4_id_0..15``.
    Rows from the other endpoint are left untouched.
    """
    id_columns = [f'ipv4_id_{bit}' for bit in range(16)]

    current_id = random_bits_generation(16)
    for name, digit in zip(id_columns, current_id):
        generated_nprint.at[0, name] = int(digit)

    initiator_ip = generated_nprint.at[0, 'src_ip']

    for row in range(1, len(generated_nprint)):
        # Only packets travelling in the same direction as the first one.
        if generated_nprint.at[row, 'src_ip'] == initiator_ip:
            current_id = increment_binary(current_id)
            for name, digit in zip(id_columns, current_id):
                generated_nprint.at[row, name] = int(digit)

    return generated_nprint
559 |
def id_num_initialization_dst_src(generated_nprint):
    """Assign incrementing IPv4 IDs to packets sent by the other endpoint.

    Starting from a random 16-bit value, every row after the first whose
    'src_ip' differs from row 0's gets the running value plus one, written into
    ``ipv4_id_0..15``. Rows from row 0's source are left untouched.
    """
    id_columns = [f'ipv4_id_{bit}' for bit in range(16)]

    current_id = random_bits_generation(16)
    initiator_ip = generated_nprint.at[0, 'src_ip']

    for row in range(1, len(generated_nprint)):
        # Only packets travelling in the reverse direction of the first one.
        if generated_nprint.at[row, 'src_ip'] != initiator_ip:
            current_id = increment_binary(current_id)
            for name, digit in zip(id_columns, current_id):
                generated_nprint.at[row, name] = int(digit)

    return generated_nprint
574 |
def increment_binary(bin_str):
    """Increment a fixed-width binary string by one, wrapping on overflow.

    The result always has the same length as ``bin_str``; an all-ones input
    wraps to all zeros instead of growing by one digit (which would shift every
    bit when callers index the result by position).

    Raises:
        ValueError: if ``bin_str`` is empty or not a binary string.
    """
    width = len(bin_str)
    next_value = (int(bin_str, 2) + 1) % (1 << width)
    return format(next_value, f'0{width}b')
579 |
def ip_fragementation_bits(generated_nprint):
    """Mark every packet as unfragmented: reserved=0, DF=1, MF=0, fragment offset=0."""
    fixed_bits = [('ipv4_rbit_0', 0), ('ipv4_dfbit_0', 1), ('ipv4_mfbit_0', 0)]
    fixed_bits.extend((f'ipv4_foff_{position}', 0) for position in range(13))

    for name, value in fixed_bits:
        generated_nprint[name] = value
    return generated_nprint
587 |
def random_bits_generation(required_num_bits):
    """Return a string of ``required_num_bits`` random '0'/'1' characters."""
    digits = []
    for _ in range(required_num_bits):
        digits.append(str(random.randint(0, 1)))
    return ''.join(digits)
590 |
def port_initialization(generated_nprint):
    """Assign one random port pair to the flow, swapped for reverse-direction packets.

    Two random 16-bit ports are drawn. For the dominating protocol (TCP or
    UDP), rows whose 'src_ip' equals row 0's get (source port, destination
    port); all other rows get the pair swapped. Nothing is written when the
    dominating protocol is neither TCP nor UDP.
    """
    random_src_port = random_bits_generation(16)
    random_dst_port = random_bits_generation(16)
    dominating_protocol = protocol_determination(generated_nprint)

    # Column-name prefixes of the (source, destination) port bits per protocol.
    port_prefixes = {
        'tcp': ('tcp_sprt_', 'tcp_dprt_'),
        'udp': ('udp_sport_', 'udp_dport_'),
    }
    if dominating_protocol not in port_prefixes:
        return generated_nprint
    sport_prefix, dport_prefix = port_prefixes[dominating_protocol]

    def write_ports(row, sport_bits, dport_bits):
        for i in range(16):
            generated_nprint.at[row, f'{sport_prefix}{i}'] = int(sport_bits[i])
            generated_nprint.at[row, f'{dport_prefix}{i}'] = int(dport_bits[i])

    write_ports(0, random_src_port, random_dst_port)
    first_row_src_ip = generated_nprint.at[0, 'src_ip']

    for index in range(1, len(generated_nprint)):
        if generated_nprint.at[index, 'src_ip'] == first_row_src_ip:
            write_ports(index, random_src_port, random_dst_port)
        else:
            # Reverse direction: the ports trade places.
            write_ports(index, random_dst_port, random_src_port)

    return generated_nprint
628 |
def compute_tcp_segment_length(row):
    """Return the TCP payload length in bytes encoded in one nPrint row.

    The value is ``ipv4_tl - ipv4_hl * 4 - tcp_doff * 4``, where each field
    is rebuilt from its per-bit columns (most significant bit first).

    Args:
        row: Mapping-like object (e.g. a pandas Series or dict) exposing
            ``ipv4_tl_0..15``, ``ipv4_hl_0..3`` and ``tcp_doff_0..3``.

    Returns:
        The non-negative segment length. Inconsistent headers (total length
        smaller than the two headers) yield 0 instead of a negative number,
        because a negative increment would make
        ``increment_binary_non_fixed`` produce a non-binary string.

    Raises:
        ValueError: if a bit cell is not 0/1 (for example -1 or NaN).
    """
    def field_value(prefix, width):
        # int() first so float-typed columns (1.0 / 0.0) still yield "1"/"0".
        bit_string = ''.join(str(int(row[f'{prefix}_{i}'])) for i in range(width))
        return int(bit_string, 2)

    total_length = field_value('ipv4_tl', 16)
    ipv4_header_length = field_value('ipv4_hl', 4) * 4  # 32-bit words -> bytes
    tcp_header_length = field_value('tcp_doff', 4) * 4  # 32-bit words -> bytes

    return max(total_length - ipv4_header_length - tcp_header_length, 0)
635 |
def increment_binary_non_fixed(bin_str, increment_value):
    """Add ``increment_value`` to the binary string ``bin_str``.

    The result is left-padded with zeros to at least the width of the input;
    it is NOT truncated, so an overflowing sum comes back one or more
    characters longer than ``bin_str``.
    """
    width = len(bin_str)
    total = int(bin_str, 2) + increment_value
    return bin(total)[2:].zfill(width)
641 |
def seq_initialization_src_dst(generated_nprint):
    """Fill ``tcp_seq_*`` for every packet sent by the flow initiator.

    Row 0 receives a random 32-bit initial sequence number. Each later row
    with the same ``src_ip`` as row 0 receives the previous initiator
    sequence number plus the TCP segment length of the previous initiator
    packet. Rows from any other ``src_ip`` are left untouched.

    The running value is kept to 32 bits (TCP sequence numbers wrap modulo
    2**32); without this an overflowing sum would be 33 characters long and
    only its leading 32 bits would be written.

    Returns the same DataFrame, which is modified in place.
    """
    current_bin_str = random_bits_generation(32)
    for i in range(32):
        generated_nprint.at[0, f'tcp_seq_{i}'] = int(current_bin_str[i])
    first_row_src_ip = generated_nprint.at[0, 'src_ip']

    # Index of the most recent packet sent by the initiator.
    previous_index = 0

    for index in range(1, len(generated_nprint)):
        if generated_nprint.at[index, 'src_ip'] != first_row_src_ip:
            continue

        previous_row = generated_nprint.iloc[previous_index]
        segment_length = compute_tcp_segment_length(previous_row)
        # Keep only the low 32 bits so the sequence number wraps correctly.
        current_bin_str = increment_binary_non_fixed(current_bin_str, segment_length)[-32:]

        for i in range(32):
            generated_nprint.at[index, f'tcp_seq_{i}'] = int(current_bin_str[i])

        previous_index = index

    return generated_nprint
674 |
def seq_initialization_dst_src(generated_nprint):
    """Fill ``tcp_seq_*`` for every packet NOT sent by the flow initiator.

    The initiator is the ``src_ip`` of row 0. The first packet seen from any
    other ``src_ip`` receives a fresh random 32-bit sequence number; each
    later packet from an already-seen ``src_ip`` receives the running
    sequence number plus the TCP segment length of that IP's previous
    packet. Initiator rows are left untouched.

    The running value is kept to 32 bits (TCP sequence numbers wrap modulo
    2**32); without this an overflowing sum would be 33 characters long and
    only its leading 32 bits would be written.

    Returns the same DataFrame, which is modified in place.
    """
    # This draw is always replaced before it is written; it is kept so the
    # number of values consumed from ``random`` stays the same as before.
    current_bin_str = random_bits_generation(32)
    first_row_src_ip = generated_nprint.at[0, 'src_ip']

    # Last row index seen for each src_ip.
    last_packet_for_ip = {first_row_src_ip: 0}

    for index in range(1, len(generated_nprint)):
        current_src_ip = generated_nprint.at[index, 'src_ip']
        if current_src_ip == first_row_src_ip:
            continue

        if current_src_ip in last_packet_for_ip:
            previous_row = generated_nprint.iloc[last_packet_for_ip[current_src_ip]]
            segment_length = compute_tcp_segment_length(previous_row)
            # Keep only the low 32 bits so the sequence number wraps correctly.
            current_bin_str = increment_binary_non_fixed(current_bin_str, segment_length)[-32:]
        else:
            # First packet from this responder: start from a random number.
            current_bin_str = random_bits_generation(32)

        for i in range(32):
            generated_nprint.at[index, f'tcp_seq_{i}'] = int(current_bin_str[i])

        last_packet_for_ip[current_src_ip] = index

    return generated_nprint
705 |
def three_way_handshake(generated_nprint):
    """Rewrite the TCP flag bits so the flow opens with SYN, SYN/ACK, ACK.

    Row 0 is the initiator's SYN. Until a packet from another ``src_ip``
    appears, further initiator packets are also plain SYNs. Every packet
    from another ``src_ip`` before the handshake completes becomes a
    SYN/ACK. The first initiator packet after a SYN/ACK becomes a plain ACK
    and completes the handshake. Every row after that only has its ACK flag
    forced to 1; its other flags are left as they were.

    Returns the same DataFrame, which is modified in place.
    """
    non_syn_flags = ['tcp_ackf_0', 'tcp_psh_0', 'tcp_rst_0', 'tcp_fin_0']
    response_received = False
    handshake_complete = False

    for index in range(len(generated_nprint)):
        if index == 0:
            # Opening SYN from the initiator.
            generated_nprint.at[index, 'tcp_syn_0'] = 1
            generated_nprint.loc[index, non_syn_flags] = 0
            first_row_src_ip = generated_nprint.at[index, 'src_ip']
            continue

        if handshake_complete:
            # Established connection: only the ACK flag is enforced.
            generated_nprint.at[index, 'tcp_ackf_0'] = 1
            continue

        from_initiator = generated_nprint.at[index, 'src_ip'] == first_row_src_ip
        if not from_initiator:
            # Responder answers with SYN/ACK.
            generated_nprint.loc[index, ['tcp_syn_0', 'tcp_ackf_0']] = 1
            generated_nprint.loc[index, ['tcp_psh_0', 'tcp_rst_0', 'tcp_fin_0']] = 0
            response_received = True
        elif response_received:
            # Initiator's final ACK closes the handshake.
            generated_nprint.at[index, 'tcp_ackf_0'] = 1
            generated_nprint.loc[index, ['tcp_syn_0', 'tcp_psh_0', 'tcp_rst_0', 'tcp_fin_0']] = 0
            handshake_complete = True
        else:
            # Initiator sends again before any answer: another plain SYN.
            generated_nprint.at[index, 'tcp_syn_0'] = 1
            generated_nprint.loc[index, non_syn_flags] = 0

    return generated_nprint
732 |
def ackn_initialization_src_dst(generated_nprint):
    """Fill ``tcp_ackn_*`` from the most recent sequence number of the peer.

    The initiator is the ``src_ip`` of row 0. Initiator packets sent before
    the first packet from another ``src_ip`` get an all-zero acknowledgement
    number. After that, each packet's acknowledgement bits are copied from
    the ``tcp_seq_*`` bits of the latest packet seen from the opposite
    direction.

    Returns the same DataFrame, which is modified in place.
    """
    seq_columns = [f'tcp_seq_{i}' for i in range(32)]
    ack_columns = [f'tcp_ackn_{i}' for i in range(32)]

    def read_seq_bits(row_label):
        return [generated_nprint.at[row_label, column] for column in seq_columns]

    def write_ack_bits(row_label, bits):
        for column, bit in zip(ack_columns, bits):
            generated_nprint.at[row_label, column] = bit

    zero_bits = [0] * 32
    src_ip = generated_nprint.at[0, 'src_ip']
    write_ack_bits(0, zero_bits)
    last_src_to_dst_seq = read_seq_bits(0)
    # Used only if the initiator speaks after the first direction switch.
    last_dst_to_src_seq = list(zero_bits)

    first_switch_done = False
    for index in range(1, len(generated_nprint)):
        current_src_ip = generated_nprint.at[index, 'src_ip']
        if current_src_ip != src_ip:
            # Responder packet: acknowledge the initiator's latest sequence.
            first_switch_done = True
            write_ack_bits(index, last_src_to_dst_seq)
            last_dst_to_src_seq = read_seq_bits(index)
        elif current_src_ip == src_ip:
            if first_switch_done:
                # Initiator packet after the responder has spoken.
                write_ack_bits(index, last_dst_to_src_seq)
            else:
                # Nothing received yet, so nothing to acknowledge.
                write_ack_bits(index, zero_bits)
            last_src_to_dst_seq = read_seq_bits(index)

    return generated_nprint
779 |
def udp_len_calculation(generated_nprint, formatted_nprint):
    """Clamp every row's 16-bit ``udp_len`` field into a plausible range.

    The accepted range is ``8 .. 1500 - ipv4_header_bytes - 8``. Values
    below 8 are replaced by 8, values above the upper bound are replaced by
    the upper bound, and values already in range are left as they are.

    ``formatted_nprint`` is accepted for signature parity with the other
    formatting steps but is not used.

    Returns the same ``generated_nprint`` DataFrame, modified in place.
    """
    for idx, row in generated_nprint.iterrows():
        header_words = binary_to_decimal([row[f'ipv4_hl_{i}'] for i in range(4)])
        ipv4_hl_value = header_words * 4  # 32-bit words -> bytes
        upper_limit = 1500 - ipv4_hl_value - 8
        # udp_len is already expressed in bytes.
        udp_len_value = binary_to_decimal([row[f'udp_len_{i}'] for i in range(16)])

        if 8 <= udp_len_value <= upper_limit:
            continue

        clamped = 8 if udp_len_value < 8 else upper_limit
        for i, bit in enumerate(format(clamped, '016b')):
            generated_nprint.at[idx, f'udp_len_{i}'] = int(bit)

    return generated_nprint
800 |
def udp_header_negative_removal(generated_nprint, formatted_nprint):
    """Replace -1 cells in non-option UDP columns with random 0/1 bits.

    Every column of ``formatted_nprint`` whose name contains ``"udp"`` and
    does not contain ``"opt"`` is processed in ``generated_nprint``: cells
    equal to -1 get a value from ``np.random.randint(0, 2)``, all other
    cells are kept.

    Returns the same ``generated_nprint`` DataFrame, modified in place.
    """
    fields = ["udp"]

    def fill_missing_bit(cell):
        # -1 marks "bit absent"; any other value is kept as is.
        return np.random.randint(0, 2) if cell == -1 else cell

    for column in formatted_nprint.columns:
        for field in fields:
            if field not in column or 'opt' in column:
                continue
            generated_nprint[column] = generated_nprint[column].apply(fill_missing_bit)

    return generated_nprint
823 |
def ipv4_tl_formatting_udp(generated_nprint, formatted_nprint):
    """Make every row's IPv4 total length consistent with its UDP length.

    For each row the 16-bit ``ipv4_tl`` field is rewritten as follows:

    * if it is smaller than ``ipv4_header_bytes + udp_len`` it is raised to
      that sum;
    * otherwise, if it is larger than 1500 it is lowered to 1500;
    * otherwise the current value is written back unchanged.

    ``formatted_nprint`` is accepted for signature parity with the other
    formatting steps but is not used.

    Returns the same ``generated_nprint`` DataFrame, modified in place.
    """
    # NOTE(review): a former second pass over the rows only recomputed these
    # values and incremented an unused counter; it was removed on the
    # assumption that binary_to_decimal has no side effects -- confirm.
    for idx, row in generated_nprint.iterrows():
        ipv4_tl_value = binary_to_decimal([row[f'ipv4_tl_{i}'] for i in range(16)])
        # Header length is stored in 32-bit words; convert to bytes.
        ipv4_hl_value = binary_to_decimal([row[f'ipv4_hl_{i}'] for i in range(4)]) * 4
        # udp_len is already expressed in bytes.
        udp_len_value = binary_to_decimal([row[f'udp_len_{i}'] for i in range(16)])

        minimum_total_length = ipv4_hl_value + udp_len_value
        if ipv4_tl_value < minimum_total_length:
            new_ipv4_tl_value = minimum_total_length
        elif ipv4_tl_value > 1500:
            new_ipv4_tl_value = 1500
        else:
            new_ipv4_tl_value = ipv4_tl_value

        for i, bit in enumerate(format(new_ipv4_tl_value, '016b')):
            generated_nprint.at[idx, f'ipv4_tl_{i}'] = int(bit)

    return generated_nprint
869 |
def main(generated_nprint_path, formatted_nprint_path, output, nprint):
    """Turn a generated nPrint into a protocol-consistent nPrint and pcap.

    Args:
        generated_nprint_path: Path of the synthetic nPrint to repair.
        formatted_nprint_path: Path of a reference nPrint CSV whose column
            set and column order the output must follow.
        output: Path of the pcap file to write.
        nprint: Path of the repaired nPrint CSV to write. It is written once
            before pcap reconstruction and regenerated from the final pcap
            by the ``nprint`` command-line tool at the end.

    Raises:
        FileNotFoundError: if the ``nprint`` executable is not on ``PATH``
            when the final regeneration step runs.
    """
    rebuilt_pcap_path = output
    generated_nprint = read_syn_nprint(generated_nprint_path)
    formatted_nprint = pd.read_csv(formatted_nprint_path)

    # Add columns present only in the reference (filled with 0), then adopt
    # the reference column order.
    generated_columns = set(generated_nprint.columns)
    formatted_columns = set(formatted_nprint.columns)
    missing_columns = formatted_columns - generated_columns
    missing_data = pd.DataFrame(0, index=generated_nprint.index, columns=list(missing_columns))
    generated_nprint = pd.concat([generated_nprint, missing_data], axis=1)
    generated_nprint = generated_nprint.reindex(columns=formatted_nprint.columns)

    # ----------------- Intra-packet dependency adjustments -----------------
    # IP address formatting (we have flexibility here).
    synthetic_sequence = direction_sampleing(formatted_nprint_path)
    generated_nprint = ip_address_formatting(generated_nprint, synthetic_sequence)

    # IPv4
    # We are using IPv4 only.
    generated_nprint = ipv4_ver_formatting(generated_nprint, formatted_nprint)
    # Ensure the minimum IPv4 header is complete: missing header bits get a
    # random value, since the fields are largely correct after diffusion.
    generated_nprint = ipv4_header_negative_removal(generated_nprint, formatted_nprint)
    # Less flexible: pick the protocol with the highest share of non-negative
    # bits (options excluded) and set all other protocols' fields to -1.
    generated_nprint = ipv4_pro_formatting(generated_nprint, formatted_nprint)
    # The modern Internet rarely uses IPv4 options; they never appear in the
    # observed data.
    generated_nprint = ipv4_option_removal(generated_nprint, formatted_nprint)
    # Ensure TTL > 0.
    generated_nprint = ipv4_ttl_ensure(generated_nprint, formatted_nprint)
    # Header length is computed, so it must run after all other IPv4 fields
    # have been formatted.
    generated_nprint = ipv4_hl_formatting(generated_nprint, formatted_nprint)
    # The IPv4 checksum is updated at the very end.

    dominating_protocol = protocol_determination(generated_nprint)
    if dominating_protocol == 'tcp':
        generated_nprint = tcp_header_negative_removal(generated_nprint, formatted_nprint)
        # Options must be contiguous and have fixed lengths; the closest
        # approximation is used.
        generated_nprint = tcp_opt_formatting(generated_nprint, formatted_nprint)
        # Count the bytes of all TCP header fields, options included, and
        # store the sum as the data offset.
        generated_nprint = tcp_data_offset_calculation(generated_nprint, formatted_nprint)
        # IPv4 total length: the payload needs to be considered.
        generated_nprint = ipv4_tl_formatting_tcp(generated_nprint, formatted_nprint)
    elif dominating_protocol == 'udp':
        generated_nprint = udp_header_negative_removal(generated_nprint, formatted_nprint)
        generated_nprint = udp_len_calculation(generated_nprint, formatted_nprint)
        # IPv4 total length: the payload needs to be considered.
        generated_nprint = ipv4_tl_formatting_udp(generated_nprint, formatted_nprint)

    # ----------------- Inter-packet dependency adjustments -----------------
    # Random initial IP identification number for the first src->dst packet,
    # incremented along the flow.
    generated_nprint = id_num_initialization_src_dst(generated_nprint)
    # Same for the dst->src direction.
    generated_nprint = id_num_initialization_dst_src(generated_nprint)

    # Fragmentation-related bits (reserved, don't-fragment, more-fragments,
    # fragment offset).
    generated_nprint = ip_fragementation_bits(generated_nprint)

    dominating_protocol = protocol_determination(generated_nprint)
    if dominating_protocol == 'tcp':
        # Ports must be consistent across the flow but may be random.
        generated_nprint = port_initialization(generated_nprint)
        # Sequence numbers are computed from TCP segment lengths.
        generated_nprint = seq_initialization_src_dst(generated_nprint)
        generated_nprint = seq_initialization_dst_src(generated_nprint)
        # The initial three-way-handshake flags must be set correctly.
        generated_nprint = three_way_handshake(generated_nprint)
        # Acknowledgement number formatting.
        generated_nprint = ackn_initialization_src_dst(generated_nprint)
    elif dominating_protocol == 'udp':
        # Ports must be consistent across the flow but may be random.
        generated_nprint = port_initialization(generated_nprint)

    # IPv6 removal: mark every IPv6 column as absent.
    fields = ["ipv6"]
    for column in formatted_nprint.columns:
        for field in fields:
            if field in column:
                generated_nprint[column] = -1

    # Save the repaired nPrint and attempt the pcap reconstruction.
    formatted_generated_nprint_path = nprint
    generated_nprint.to_csv(formatted_generated_nprint_path, index=False)
    reconstruction_to_pcap(formatted_generated_nprint_path, rebuilt_pcap_path)
    update_ipv4_checksum(rebuilt_pcap_path, rebuilt_pcap_path)

    # Regenerate the nPrint from the final pcap. Arguments are passed as a
    # list (no shell) so paths containing spaces or shell metacharacters are
    # handled verbatim.
    subprocess.run([
        'nprint', '-F', '-1',
        '-P', rebuilt_pcap_path,
        '-4', '-i', '-6', '-t', '-u',
        '-p', '0',
        '-c', '1024',
        '-W', formatted_generated_nprint_path,
    ])
970 |
971 |
972 |
973 |
if __name__ == '__main__':
    # Command-line entry point: parse the paths and run the reconstruction.
    parser = argparse.ArgumentParser(description='pcap reconstruction')

    required_path_options = (
        ('--generated_nprint_path', 'Path to the generated nPrint file'),
        ('--formatted_nprint_path', 'Path to the formatted nPrint file'),
        ('--output', 'Path to the reconstructed pcap file'),
        ('--nprint', 'Path to the reconstructed nprint file'),
    )
    for option_name, option_help in required_path_options:
        parser.add_argument(option_name, required=True, help=option_help)

    # NOTE(review): --src_ip / --dst_ip are parsed but not forwarded to main().
    optional_ip_options = (
        ('--src_ip', 'Desired source IP address, randomly generated if not specified'),
        ('--dst_ip', 'Desired destination IP address, randomly generated if not specified'),
    )
    for option_name, option_help in optional_ip_options:
        parser.add_argument(option_name, help=option_help, default='0.0.0.0')

    args = parser.parse_args()
    main(
        args.generated_nprint_path,
        args.formatted_nprint_path,
        args.output,
        args.nprint,
    )
--------------------------------------------------------------------------------
/scripts/traffic_conditioning_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/noise-lab/NetDiffusion_Generator/7201a702149fbd26ecdcf50bb000da4bb7fa2787/scripts/traffic_conditioning_image.png
--------------------------------------------------------------------------------
/scripts/train_dreambooth_lora_sd3_miniature.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import argparse
18 | import copy
19 | import gc
20 | import hashlib
21 | import logging
22 | import math
23 | import os
24 | import random
25 | import shutil
26 | from contextlib import nullcontext
27 | from pathlib import Path
28 |
29 | import numpy as np
30 | import pandas as pd
31 | import torch
32 | import torch.utils.checkpoint
33 | import transformers
34 | from accelerate import Accelerator
35 | from accelerate.logging import get_logger
36 | from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
37 | from huggingface_hub import create_repo, upload_folder
38 | from peft import LoraConfig, set_peft_model_state_dict
39 | from peft.utils import get_peft_model_state_dict
40 | from PIL import Image
41 | from PIL.ImageOps import exif_transpose
42 | from torch.utils.data import Dataset
43 | from torchvision import transforms
44 | from torchvision.transforms.functional import crop
45 | from tqdm.auto import tqdm
46 |
47 | import diffusers
48 | from diffusers import (
49 | AutoencoderKL,
50 | FlowMatchEulerDiscreteScheduler,
51 | SD3Transformer2DModel,
52 | StableDiffusion3Pipeline,
53 | )
54 | from diffusers.optimization import get_scheduler
55 | from diffusers.training_utils import (
56 | cast_training_params,
57 | compute_density_for_timestep_sampling,
58 | compute_loss_weighting_for_sd3,
59 | )
60 | from diffusers.utils import (
61 | check_min_version,
62 | convert_unet_state_dict_to_peft,
63 | is_wandb_available,
64 | )
65 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
66 | from diffusers.utils.torch_utils import is_compiled_module
67 |
68 |
69 | if is_wandb_available():
70 | import wandb
71 |
72 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
73 | check_min_version("0.30.0.dev0")
74 |
75 | logger = get_logger(__name__)
76 |
77 |
def save_model_card(
    repo_id: str,
    images=None,
    base_model: str = None,
    train_text_encoder=False,
    instance_prompt=None,
    validation_prompt=None,
    repo_folder=None,
):
    """Write the model card (README.md) and sample images into ``repo_folder``.

    Each image in ``images`` is saved as ``image_<i>.png`` and referenced by
    a widget entry whose text is ``validation_prompt`` (a single space when
    the prompt is empty or None). The card is created with
    ``load_or_create_model_card``, tagged with ``populate_model_card`` and
    saved as ``README.md``.
    """
    widget_text = validation_prompt or " "
    widget_dict = []
    if images is not None:
        for idx, sample in enumerate(images):
            file_name = f"image_{idx}.png"
            sample.save(os.path.join(repo_folder, file_name))
            widget_dict.append({"text": widget_text, "output": {"url": file_name}})

    model_description = f"""
# SD3 DreamBooth LoRA - {repo_id}



## Model description

These are {repo_id} DreamBooth weights for {base_model}.

The weights were trained using [DreamBooth](https://dreambooth.github.io/).

LoRA for the text encoder was enabled: {train_text_encoder}.

## Trigger words

You should use {instance_prompt} to trigger the image generation.

## Download model

[Download]({repo_id}/tree/main) them in the Files & versions tab.

## License

Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
"""
    card = load_or_create_model_card(
        repo_id_or_path=repo_id,
        from_training=True,
        license="openrail++",
        base_model=base_model,
        prompt=instance_prompt,
        model_description=model_description,
        widget=widget_dict,
    )

    card_tags = [
        "text-to-image",
        "diffusers-training",
        "diffusers",
        "lora",
        "sd3",
        "sd3-diffusers",
        "template:sd-lora",
    ]
    card = populate_model_card(card, tags=card_tags)

    readme_path = os.path.join(repo_folder, "README.md")
    card.save(readme_path)
141 |
142 |
def log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args,
    epoch,
    is_final_validation=False,
):
    """Generate validation images and log them to the active trackers.

    Args:
        pipeline: Callable diffusion pipeline; called once per image with
            ``**pipeline_args`` and the shared generator.
        args: Parsed training arguments; ``num_validation_images``,
            ``validation_prompt`` and ``seed`` are read.
        accelerator: Provides ``device`` and the ``trackers`` to log to.
        pipeline_args: Keyword arguments forwarded to ``pipeline``.
        epoch: Step value passed to the tensorboard writer.
        is_final_validation: When True, images are logged under ``"test"``
            instead of ``"validation"``.

    Returns:
        The list of generated images.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )
    pipeline.enable_model_cpu_offload()
    pipeline.set_progress_bar_config(disable=True)

    # Seed explicitly whenever a seed was given. `is not None` (rather than a
    # truthiness test) so that `--seed 0` is honoured instead of silently
    # falling back to unseeded generation.
    if args.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    else:
        generator = None

    # Autocast is intentionally disabled for validation inference.
    autocast_ctx = nullcontext()

    with autocast_ctx:
        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

    phase_name = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
        if tracker.name == "wandb":
            tracker.log(
                {
                    phase_name: [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    # Drop the local reference and release cached GPU memory.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return images
185 |
186 |
187 | def parse_args(input_args=None):
188 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
189 | parser.add_argument(
190 | "--pretrained_model_name_or_path",
191 | type=str,
192 | default=None,
193 | required=True,
194 | help="Path to pretrained model or model identifier from huggingface.co/models.",
195 | )
196 | parser.add_argument(
197 | "--revision",
198 | type=str,
199 | default=None,
200 | required=False,
201 | help="Revision of pretrained model identifier from huggingface.co/models.",
202 | )
203 | parser.add_argument(
204 | "--variant",
205 | type=str,
206 | default=None,
207 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
208 | )
209 | parser.add_argument(
210 | "--instance_data_dir",
211 | type=str,
212 | default=None,
213 | help=("A folder containing the training data. "),
214 | )
215 | parser.add_argument(
216 | "--data_df_path",
217 | type=str,
218 | default=None,
219 | help=("Path to the parquet file serialized with compute_embeddings.py."),
220 | )
221 | parser.add_argument(
222 | "--cache_dir",
223 | type=str,
224 | default=None,
225 | help="The directory where the downloaded models and datasets will be stored.",
226 | )
227 | parser.add_argument(
228 | "--instance_prompt",
229 | type=str,
230 | default=None,
231 | required=True,
232 | help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
233 | )
234 | parser.add_argument(
235 | "--max_sequence_length",
236 | type=int,
237 | default=77,
238 | help="Maximum sequence length to use with with the T5 text encoder",
239 | )
240 | parser.add_argument(
241 | "--validation_prompt",
242 | type=str,
243 | default=None,
244 | help="A prompt that is used during validation to verify that the model is learning.",
245 | )
246 | parser.add_argument(
247 | "--num_validation_images",
248 | type=int,
249 | default=4,
250 | help="Number of images that should be generated during validation with `validation_prompt`.",
251 | )
252 | parser.add_argument(
253 | "--validation_epochs",
254 | type=int,
255 | default=50,
256 | help=(
257 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
258 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
259 | ),
260 | )
261 | parser.add_argument(
262 | "--rank",
263 | type=int,
264 | default=4,
265 | help=("The dimension of the LoRA update matrices."),
266 | )
267 | parser.add_argument(
268 | "--output_dir",
269 | type=str,
270 | default="sd3-dreambooth-lora",
271 | help="The output directory where the model predictions and checkpoints will be written.",
272 | )
273 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
274 | parser.add_argument(
275 | "--resolution",
276 | type=int,
277 | default=512,
278 | help=(
279 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
280 | " resolution"
281 | ),
282 | )
283 | parser.add_argument(
284 | "--center_crop",
285 | default=False,
286 | action="store_true",
287 | help=(
288 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
289 | " cropped. The images will be resized to the resolution first before cropping."
290 | ),
291 | )
292 | parser.add_argument(
293 | "--random_flip",
294 | action="store_true",
295 | help="whether to randomly flip images horizontally",
296 | )
297 |
298 | parser.add_argument(
299 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
300 | )
301 | parser.add_argument("--num_train_epochs", type=int, default=1)
302 | parser.add_argument(
303 | "--max_train_steps",
304 | type=int,
305 | default=None,
306 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
307 | )
308 | parser.add_argument(
309 | "--checkpointing_steps",
310 | type=int,
311 | default=500,
312 | help=(
313 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
314 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
315 | " training using `--resume_from_checkpoint`."
316 | ),
317 | )
318 | parser.add_argument(
319 | "--checkpoints_total_limit",
320 | type=int,
321 | default=None,
322 | help=("Max number of checkpoints to store."),
323 | )
324 | parser.add_argument(
325 | "--resume_from_checkpoint",
326 | type=str,
327 | default=None,
328 | help=(
329 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
330 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
331 | ),
332 | )
333 | parser.add_argument(
334 | "--gradient_accumulation_steps",
335 | type=int,
336 | default=1,
337 | help="Number of updates steps to accumulate before performing a backward/update pass.",
338 | )
339 | parser.add_argument(
340 | "--gradient_checkpointing",
341 | action="store_true",
342 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
343 | )
344 | parser.add_argument(
345 | "--learning_rate",
346 | type=float,
347 | default=1e-4,
348 | help="Initial learning rate (after the potential warmup period) to use.",
349 | )
350 | parser.add_argument(
351 | "--scale_lr",
352 | action="store_true",
353 | default=False,
354 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
355 | )
356 | parser.add_argument(
357 | "--lr_scheduler",
358 | type=str,
359 | default="constant",
360 | help=(
361 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
362 | ' "constant", "constant_with_warmup"]'
363 | ),
364 | )
365 | parser.add_argument(
366 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
367 | )
368 | parser.add_argument(
369 | "--lr_num_cycles",
370 | type=int,
371 | default=1,
372 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
373 | )
374 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
375 | parser.add_argument(
376 | "--dataloader_num_workers",
377 | type=int,
378 | default=0,
379 | help=(
380 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
381 | ),
382 | )
383 | parser.add_argument(
384 | "--weighting_scheme",
385 | type=str,
386 | default="logit_normal",
387 | choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
388 | )
389 | parser.add_argument(
390 | "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
391 | )
392 | parser.add_argument(
393 | "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
394 | )
395 | parser.add_argument(
396 | "--mode_scale",
397 | type=float,
398 | default=1.29,
399 | help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
400 | )
401 | parser.add_argument(
402 | "--optimizer",
403 | type=str,
404 | default="AdamW",
405 | help=('The optimizer type to use. Choose between ["AdamW"]'),
406 | )
407 |
408 | parser.add_argument(
409 | "--use_8bit_adam",
410 | action="store_true",
411 | help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
412 | )
413 |
414 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
415 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
416 | parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
417 |
418 | parser.add_argument(
419 | "--adam_epsilon",
420 | type=float,
421 | default=1e-08,
422 | help="Epsilon value for the Adam optimizer.",
423 | )
424 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
425 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
426 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
427 | parser.add_argument(
428 | "--hub_model_id",
429 | type=str,
430 | default=None,
431 | help="The name of the repository to keep in sync with the local `output_dir`.",
432 | )
433 | parser.add_argument(
434 | "--logging_dir",
435 | type=str,
436 | default="logs",
437 | help=(
438 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
439 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
440 | ),
441 | )
442 | parser.add_argument(
443 | "--allow_tf32",
444 | action="store_true",
445 | help=(
446 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
447 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
448 | ),
449 | )
450 | parser.add_argument(
451 | "--report_to",
452 | type=str,
453 | default="tensorboard",
454 | help=(
455 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
456 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
457 | ),
458 | )
459 | parser.add_argument(
460 | "--mixed_precision",
461 | type=str,
462 | default=None,
463 | choices=["no", "fp16", "bf16"],
464 | help=(
465 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
466 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
467 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
468 | ),
469 | )
470 | parser.add_argument(
471 | "--prior_generation_precision",
472 | type=str,
473 | default=None,
474 | choices=["no", "fp32", "fp16", "bf16"],
475 | help=(
476 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
477 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
478 | ),
479 | )
480 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
481 |
482 | if input_args is not None:
483 | args = parser.parse_args(input_args)
484 | else:
485 | args = parser.parse_args()
486 |
487 | if args.instance_data_dir is None:
488 | raise ValueError("Specify `instance_data_dir`.")
489 |
490 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
491 | if env_local_rank != -1 and env_local_rank != args.local_rank:
492 | args.local_rank = env_local_rank
493 |
494 | return args
495 |
496 |
class DreamBoothDataset(Dataset):
    """
    Dataset pairing every instance image with its precomputed prompt embeddings.

    All regular files found directly under ``instance_data_root`` are opened as
    images, then resized and normalized eagerly in ``__init__``. Each file is
    identified by the SHA-256 digest of its bytes; that digest is the key used to
    look up the ``(prompt_embeds, pooled_prompt_embeds)`` pair stored in the
    parquet file at ``data_df_path``.
    """

    # Every image is resized (not cropped) to this fixed resolution.
    TARGET_HEIGHT = 1024
    TARGET_WIDTH = 1089
    # Shapes the stored embeddings are reshaped to when loaded from parquet.
    PROMPT_EMBEDS_SHAPE = (154, 4096)
    POOLED_PROMPT_EMBEDS_SHAPE = (2048,)

    def __init__(
        self,
        data_df_path,
        instance_data_root,
        instance_prompt,
        size=1024,
        center_crop=False,
        random_flip=None,
    ):
        """
        Args:
            data_df_path: Parquet file with the columns ``image_hash``,
                ``prompt_embeds`` and ``pooled_prompt_embeds``.
            instance_data_root: Directory containing the instance images.
            instance_prompt: Stored on the instance; not used by the dataset itself.
            size: Stored on the instance; the resize target is fixed by
                ``TARGET_HEIGHT`` / ``TARGET_WIDTH``.
            center_crop: Stored on the instance; cropping is disabled.
            random_flip: Whether each image is mirrored horizontally with
                probability 0.5. ``None`` (default) falls back to the module-level
                ``args.random_flip`` when the script defines it, else ``False``.

        Raises:
            ValueError: If the root does not exist, contains no files, or an image
                has no matching row in the embeddings file.
        """
        # Logistics
        self.size = size
        self.center_crop = center_crop
        if random_flip is None:
            # Backward compatibility: the flag used to be read straight from the
            # script's global `args`, which does not exist when this class is imported.
            random_flip = getattr(globals().get("args"), "random_flip", False)
        self.random_flip = bool(random_flip)

        self.instance_prompt = instance_prompt
        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        # List the directory exactly once so that images and hashes are guaranteed to
        # line up; `iterdir()` yields entries in arbitrary order. Sorting makes the
        # index -> image mapping reproducible, and sub-directories are skipped.
        image_paths = sorted(path for path in self.instance_data_root.iterdir() if path.is_file())
        if not image_paths:
            raise ValueError(f"No instance images found in {self.instance_data_root}.")

        # Load images.
        instance_images = [Image.open(path) for path in image_paths]
        self.instance_images = instance_images
        self.image_hashes = [self.generate_image_hash(path) for path in image_paths]

        # Image transformations
        self.pixel_values = self.apply_image_transformations(
            instance_images=instance_images, size=size, center_crop=center_crop
        )

        # Map hashes to embeddings.
        self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)

        # Fail fast instead of raising a bare KeyError in the middle of training.
        missing = [
            str(path)
            for path, image_hash in zip(image_paths, self.image_hashes)
            if image_hash not in self.data_dict
        ]
        if missing:
            raise ValueError(f"No embeddings found in {data_df_path} for: {', '.join(missing)}")

        self.num_instance_images = len(instance_images)
        self._length = self.num_instance_images

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        """Return the image tensor and both embedding tensors for ``index``."""
        position = index % self.num_instance_images
        prompt_embeds, pooled_prompt_embeds = self.data_dict[self.image_hashes[position]]
        return {
            "instance_images": self.pixel_values[position],
            "prompt_embeds": prompt_embeds,
            "pooled_prompt_embeds": pooled_prompt_embeds,
        }

    def apply_image_transformations(self, instance_images, size, center_crop):
        """
        Convert PIL images into normalized tensors of a fixed resolution.

        Each image is EXIF-transposed, converted to RGB, resized to
        ``(TARGET_HEIGHT, TARGET_WIDTH)``, optionally mirrored, and finally turned
        into a tensor normalized with mean 0.5 and std 0.5.

        ``size`` and ``center_crop`` are accepted for interface compatibility only:
        images are resized to the fixed target resolution and never cropped.

        Returns:
            A list with one tensor per input image, in input order.
        """
        resize = transforms.Resize(
            (self.TARGET_HEIGHT, self.TARGET_WIDTH),
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
        # p=1.0 because the coin toss happens below, once per image.
        flip = transforms.RandomHorizontalFlip(p=1.0)
        to_normalized_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        random_flip = getattr(self, "random_flip", False)

        pixel_values = []
        for image in instance_images:
            image = exif_transpose(image)
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = resize(image)
            if random_flip and random.random() < 0.5:
                image = flip(image)
            pixel_values.append(to_normalized_tensor(image))

        return pixel_values

    def convert_to_torch_tensor(self, embeddings: list):
        """
        Reshape a ``[prompt_embeds, pooled_prompt_embeds]`` pair and wrap both in tensors.

        Returns:
            ``(prompt_embeds, pooled_prompt_embeds)`` with shapes
            ``PROMPT_EMBEDS_SHAPE`` and ``POOLED_PROMPT_EMBEDS_SHAPE``.
        """
        prompt_embeds, pooled_prompt_embeds = embeddings[0], embeddings[1]
        # np.array copies, so the resulting tensors never alias the dataframe's buffers.
        prompt_embeds = np.array(prompt_embeds).reshape(self.PROMPT_EMBEDS_SHAPE)
        pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(self.POOLED_PROMPT_EMBEDS_SHAPE)
        return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds)

    def map_image_hash_embedding(self, data_df_path):
        """Read the parquet file and return ``{image_hash: (prompt_embeds, pooled_prompt_embeds)}``."""
        hashes_df = pd.read_parquet(data_df_path)
        data_dict = {}
        for _, row in hashes_df.iterrows():
            data_dict[row["image_hash"]] = self.convert_to_torch_tensor(
                embeddings=[row["prompt_embeds"], row["pooled_prompt_embeds"]]
            )
        return data_dict

    def generate_image_hash(self, image_path):
        """Return the hex SHA-256 digest of the raw bytes of the file at ``image_path``."""
        with open(image_path, "rb") as f:
            img_data = f.read()
        return hashlib.sha256(img_data).hexdigest()
609 |
610 |
def collate_fn(examples):
    """
    Merge a list of dataset examples into a single training batch.

    Each example must provide the keys ``instance_images``, ``prompt_embeds`` and
    ``pooled_prompt_embeds``. The per-example tensors are stacked along a new
    leading dimension; the stacked images are additionally made contiguous and
    cast to float32.

    Returns:
        A dict with the keys ``pixel_values``, ``prompt_embeds`` and
        ``pooled_prompt_embeds``.
    """
    # Gather every field first, then stack, so a malformed example is reported
    # before any tensor work is done.
    gathered = {
        key: [item[key] for item in examples]
        for key in ("instance_images", "prompt_embeds", "pooled_prompt_embeds")
    }

    images = torch.stack(gathered["instance_images"])
    images = images.to(memory_format=torch.contiguous_format).float()

    return {
        "pixel_values": images,
        "prompt_embeds": torch.stack(gathered["prompt_embeds"]),
        "pooled_prompt_embeds": torch.stack(gathered["pooled_prompt_embeds"]),
    }
627 |
628 |
def main(args):
    """
    Fine-tune LoRA adapters on the SD3 transformer using precomputed prompt embeddings.

    The function sets up Accelerate, loads the scheduler, VAE and transformer, adds
    LoRA layers to the transformer's attention projections, builds the dataset and
    optimizer, runs the flow-matching training loop with periodic checkpointing and
    optional validation, and finally saves the LoRA weights (optionally pushing
    them to the Hub).

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        ValueError: If ``--report_to=wandb`` is combined with ``--hub_token``, or if
            bf16 mixed precision is requested on an MPS device.
        ImportError: If wandb or bitsandbytes is requested but not installed.
    """
    if args.report_to == "wandb" and args.hub_token is not None:
        raise ValueError(
            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
            " Please use `huggingface-cli login` to authenticate with the Hub."
        )

    if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs],
    )

    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
            ).repo_id

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler"
    )
    # The copy is the one consulted for timesteps/sigmas during training.
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    transformer = SD3Transformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
    )

    # Freeze the base models; only the LoRA layers added below are trained.
    transformer.requires_grad_(False)
    vae.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
            "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
        )

    # The VAE is kept in float32 regardless of the mixed-precision setting.
    vae.to(accelerator.device, dtype=torch.float32)
    transformer.to(accelerator.device, dtype=weight_dtype)

    if args.gradient_checkpointing:
        transformer.enable_gradient_checkpointing()

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )
    transformer.add_adapter(transformer_lora_config)

    def unwrap_model(model):
        """Strip the Accelerate wrapper and, if present, the torch.compile wrapper."""
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        """Save only the transformer's LoRA layers when Accelerate saves its state."""
        if accelerator.is_main_process:
            transformer_lora_layers_to_save = None
            for model in models:
                if isinstance(model, type(unwrap_model(transformer))):
                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                else:
                    raise ValueError(f"unexpected save model: {model.__class__}")

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

            StableDiffusion3Pipeline.save_lora_weights(
                output_dir,
                transformer_lora_layers=transformer_lora_layers_to_save,
            )

    def load_model_hook(models, input_dir):
        """Restore the transformer's LoRA layers when Accelerate loads a saved state."""
        transformer_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(unwrap_model(transformer))):
                transformer_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

        transformer_state_dict = {
            f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
        }
        transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
        incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
        if incompatible_keys is not None:
            # check only for unexpected keys
            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
            if unexpected_keys:
                logger.warning(
                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                    f" {unexpected_keys}. "
                )

        # Make sure the trainable params are in float32. This is again needed since the base models
        # are in `weight_dtype`. More details:
        # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
        if args.mixed_precision == "fp16":
            models = [transformer_]
            # only upcast trainable parameters (LoRA) into fp32
            cast_training_params(models)

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32 and torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Make sure the trainable params are in float32.
    if args.mixed_precision == "fp16":
        models = [transformer]
        # only upcast trainable parameters (LoRA) into fp32
        cast_training_params(models, dtype=torch.float32)

    # Optimization parameters
    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
    transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
    params_to_optimize = [transformer_parameters_with_lr]

    # Optimizer creation
    if args.optimizer.lower() != "adamw":
        logger.warning(
            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW]."
            "Defaulting to adamW"
        )
        args.optimizer = "adamw"

    # NOTE: after the fallback above the optimizer is always AdamW, so this warning
    # cannot currently fire; it is kept for when further optimizers are added.
    if args.use_8bit_adam and args.optimizer.lower() != "adamw":
        logger.warning(
            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
            f"set to {args.optimizer.lower()}"
        )

    if args.optimizer.lower() == "adamw":
        if args.use_8bit_adam:
            try:
                import bitsandbytes as bnb
            except ImportError:
                raise ImportError(
                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
                )

            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW

        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.adam_weight_decay,
            eps=args.adam_epsilon,
        )

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        data_df_path=args.data_df_path,
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    # Pass the module-level function itself: a lambda defined inside `main` cannot be
    # pickled, which breaks `num_workers > 0` on platforms that spawn worker processes.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=args.dataloader_num_workers,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_name = "dreambooth-sd3-lora-miniature"
        accelerator.init_trackers(tracker_name, config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num batches each epoch = {len(train_dataloader)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # Only the directory name is kept; it is resolved against `output_dir` below.
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            # Checkpoint directories are named `checkpoint-<global_step>`.
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        """Look up the scheduler sigma of each timestep, padded with trailing dims up to ``n_dim``."""
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    # Bind `epoch` up front: when a resumed run has no epochs left the loop body never
    # executes, and the final validation below would otherwise hit an unbound name.
    epoch = first_epoch
    for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            with accelerator.accumulate(models_to_accumulate):
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)

                # Convert images to latent space
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
                model_input = model_input.to(dtype=weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
                bsz = model_input.shape[0]

                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=args.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=args.logit_mean,
                    logit_std=args.logit_std,
                    mode_scale=args.mode_scale,
                )
                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

                # Add noise according to flow matching.
                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input

                # Predict the noise residual
                prompt_embeds, pooled_prompt_embeds = batch["prompt_embeds"], batch["pooled_prompt_embeds"]
                prompt_embeds = prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
                model_pred = transformer(
                    hidden_states=noisy_model_input,
                    timestep=timesteps,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    return_dict=False,
                )[0]

                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                # Preconditioning of the model outputs.
                model_pred = model_pred * (-sigmas) + noisy_model_input

                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

                # flow matching loss
                target = model_input

                # Compute regular loss.
                loss = torch.mean(
                    (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    1,
                )
                loss = loss.mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = transformer_lora_parameters
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                pipeline = StableDiffusion3Pipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    vae=vae,
                    transformer=accelerator.unwrap_model(transformer),
                    revision=args.revision,
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )
                pipeline_args = {"prompt": args.validation_prompt}
                images = log_validation(
                    pipeline=pipeline,
                    args=args,
                    accelerator=accelerator,
                    pipeline_args=pipeline_args,
                    epoch=epoch,
                )
                torch.cuda.empty_cache()
                gc.collect()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        transformer = transformer.to(torch.float32)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        StableDiffusion3Pipeline.save_lora_weights(
            save_directory=args.output_dir,
            transformer_lora_layers=transformer_lora_layers,
        )

        # Final inference
        # Load previous pipeline
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
        # load attention processors
        pipeline.load_lora_weights(args.output_dir)

        # run inference
        images = []
        if args.validation_prompt and args.num_validation_images > 0:
            pipeline_args = {"prompt": args.validation_prompt}
            images = log_validation(
                pipeline=pipeline,
                args=args,
                accelerator=accelerator,
                pipeline_args=pipeline_args,
                epoch=epoch,
                is_final_validation=True,
            )

        if args.push_to_hub:
            save_model_card(
                repo_id,
                images=images,
                base_model=args.pretrained_model_name_or_path,
                instance_prompt=args.instance_prompt,
                validation_prompt=args.validation_prompt,
                repo_folder=args.output_dir,
            )
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()
1150 |
1151 |
if __name__ == "__main__":
    # Script entry point: parse the command line, then run training.
    # NOTE: `args` is deliberately bound at module level here. DreamBoothDataset
    # consults the module-level `args` for `random_flip`, so do not inline this
    # into `main(parse_args())`.
    args = parse_args()
    main(args)
1155 |
--------------------------------------------------------------------------------