├── update_particle_phases.py ├── README.md ├── pbp_plotting.ipynb ├── pbp_preprocessing.ipynb └── particle_classification_CNN.ipynb /update_particle_phases.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Script to update particle phases based on directory classification. 4 | Assigns phase 2 to particles in 'donut' directory and phase 3 to particles in 'noise' directory. 5 | """ 6 | 7 | import pandas as pd 8 | import os 9 | from pathlib import Path 10 | 11 | def extract_particle_number(filename): 12 | """ 13 | Extract particle number from filename like 'particle_123.png' 14 | 15 | Args: 16 | filename: Image filename 17 | 18 | Returns: 19 | int: Particle number or None if parsing fails 20 | """ 21 | try: 22 | # Remove extension and 'particle_' prefix 23 | num_str = filename.replace('particle_', '').replace('.png', '') 24 | return int(num_str) 25 | except (ValueError, AttributeError): 26 | return None 27 | 28 | def get_particle_numbers_from_directory(directory_path): 29 | """ 30 | Get all particle numbers from a directory. 31 | 32 | Args: 33 | directory_path: Path to directory containing particle images 34 | 35 | Returns: 36 | set: Set of particle numbers found in the directory 37 | """ 38 | particle_numbers = set() 39 | 40 | if not os.path.exists(directory_path): 41 | print(f"Warning: Directory {directory_path} does not exist") 42 | return particle_numbers 43 | 44 | for filename in os.listdir(directory_path): 45 | if filename.endswith('.png'): 46 | particle_num = extract_particle_number(filename) 47 | if particle_num is not None: 48 | particle_numbers.add(particle_num) 49 | 50 | return particle_numbers 51 | 52 | def update_particle_phases(csv_path, base_dir): 53 | """ 54 | Update particle phases based on directory classification. 
55 | 
56 |     Args:
57 |         csv_path: Path to the particle phases CSV (e.g., particle_df.csv)
58 |         base_dir: Base directory containing donut and noise subdirectories
59 |     """
60 |     # Read the CSV
61 |     print(f"Reading {csv_path}...")
62 |     df = pd.read_csv(csv_path)
63 |     print(f"Loaded {len(df)} particles")
64 | 
65 |     # Get particle numbers from each directory
66 |     donut_dir = os.path.join(base_dir, 'donut')
67 |     noise_dir = os.path.join(base_dir, 'noise')
68 | 
69 |     print(f"\nScanning {donut_dir}...")
70 |     donut_particles = get_particle_numbers_from_directory(donut_dir)
71 |     print(f"Found {len(donut_particles)} particles in donut directory")
72 | 
73 |     print(f"\nScanning {noise_dir}...")
74 |     noise_particles = get_particle_numbers_from_directory(noise_dir)
75 |     print(f"Found {len(noise_particles)} particles in noise directory")
76 | 
77 |     # Update phases
78 |     print("\nUpdating phases...")
79 |     donut_updated = 0
80 |     noise_updated = 0
81 | 
82 |     for idx, row in df.iterrows():
83 |         particle_idx = row['particle_idx_seq']
84 | 
85 |         if particle_idx in donut_particles:
86 |             df.at[idx, 'phase'] = 2
87 |             donut_updated += 1
88 |         elif particle_idx in noise_particles:
89 |             df.at[idx, 'phase'] = 3
90 |             noise_updated += 1
91 | 
92 |     print(f"Updated {donut_updated} particles to phase 2 (donut)")
93 |     print(f"Updated {noise_updated} particles to phase 3 (noise)")
94 | 
95 |     # Save the updated CSV
96 |     print(f"\nSaving updated CSV to {csv_path}...")
97 |     df.to_csv(csv_path, index=False)
98 |     print("Done!")
99 | 
100 |     # Print summary statistics
101 |     print("\n--- Phase Distribution ---")
102 |     phase_counts = df['phase'].value_counts().sort_index()
103 |     for phase, count in phase_counts.items():
104 |         print(f"Phase {phase}: {count} particles")
105 | 
106 | if __name__ == "__main__":
107 |     # Set paths
108 |     base_dir = "particle_images_filtered"
109 |     csv_path = os.path.join(base_dir, "particle_df.csv")
110 | 
111 |     # Update phases
112 |     update_particle_phases(csv_path, base_dir)
113 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Particle Phase Classification in Mixed-Phase Clouds
2 | 
3 | ## Proposal Title
4 | **Classification of Liquid and Ice Particles in Mixed-Phase Clouds via Hybrid Convolutional Neural Network (CNN) and Multimodal Feature Learning**
5 | 
6 | ---
7 | 
8 | ## 1. Problem Significance
9 | 
10 | **Cloud composition is critical in climate modeling and weather prediction**, as the phase of water (liquid, solid, or mixed) governs fundamental physical processes like solar radiation transfer (albedo), energy exchange, and precipitation formation.
11 | 
12 | The **NSF/NCAR Research Aviation Facility** utilizes airborne instruments, such as the [**Two-Dimensional, Stereo, Particle Imaging Probe (2D-S)**](https://www.eol.ucar.edu/instruments/two-dimensional-stereo-particle-imaging-probe), which captures a **binary 2D image** (shadowgraph) of each particle as it passes through the probe. This provides high-resolution data on particle size and complex shape.
13 | 
14 | 
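For concreteness, a single particle's shadowgraph can be cropped directly out of the PbP NetCDF using its stored boundary indices. This minimal sketch mirrors the extraction cell in `pbp_plotting.ipynb` (the file path is one example flight file used there; the index `i` is arbitrary):

```python
import xarray as xr

# Open one particle-by-particle (PbP) file from the 2D-S probe,
# as in pbp_plotting.ipynb
ds = xr.open_dataset('PbP_SODA/20250524_022503_F2DS_V.pbp.nc', decode_times=False)

# Each particle records its bounding box within the flat 'image' array
i = 34000  # an arbitrary particle index
shadowgraph = ds['image'].values[
    int(ds['starty'].values[i]):int(ds['stopy'].values[i]),  # slices (along airflow)
    int(ds['startx'].values[i]):int(ds['stopx'].values[i]),  # diodes (across array)
]
print(shadowgraph.shape)  # binary 0/1 array: (slices, diodes)
```
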
15 | 
16 | * **Definitive Phases:** At air temperatures above roughly 0-1 °C, particles are assumed liquid (water). Below -40 °C, they are assumed solid (ice).
17 | * **Mixed-Phase:** Between these values is the **mixed-phase range**, where supercooled water and ice particles coexist and the particle state cannot be assumed from air temperature alone.
18 | * **Current Challenge:** Phase classification within the mixed-phase range is currently tedious, often relying on manual inspection of the imagery. There is no reliable, automated way to accurately classify these particles.
19 | 
20 | I am seeking to address this gap by applying a CNN to the particle-by-particle data to classify mixed-phase cloud particles as either liquid or ice. I will first train using the particle images alone, and then add environmental context to help refine the model.
21 | 
22 | ---
23 | 
24 | ## 2. Machine Learning Task Description
25 | 
26 | The core objective is a **Binary Classification** task to identify whether a particle is liquid or solid.
27 | 
28 | ### A. Data Integration and Preprocessing
29 | 
30 | 1. **Data Filtering:** To keep the training data clean, we filter out particles under 100 microns and near-perfectly rectangular particles (within a 5% tolerance). This removes small particles, which image as squares and could confuse the model, and rectangles, which are almost always noise.
31 | 2. **Labeling:** The labeled dataset is created from the temperature data recorded concurrently with the particle measurements.
32 | 3. **Image Segmentation & Standardization:** The NetCDF file contains the raw imagery data and the boundary indices ($\mathbf{starty}$, $\mathbf{stopy}$, $\mathbf{startx}$, $\mathbf{stopx}$) for each particle. I am extracting cropped 2D binary images for each particle using these indices. Each image is then scaled (preserving its aspect ratio) and zero-padded to a uniform **$128 \times 128$ pixel** canvas, creating the required grid input for the CNN.
33 | 4. **Class Imbalance:** I have ~3.6 times as many images for solid as for liquid. To handle the class imbalance, I will try oversampling the liquid images with augmentation, ensuring balanced batches.
34 | 
35 | ### B. Classification Model (CNN --> Hybrid CNN)
36 | 
37 | Initially, I will test a CNN using only the labeled images to see how it performs, first with randomly initialized weights and then exploring options for pretrained weights.
38 | 
39 | I ultimately expect to employ a **Multi-Modal Convolutional Neural Network (CNN)** to process both data types, as referenced [here](https://rosenfelder.ai/multi-input-neural-network-pytorch/); a minimal sketch follows this list:
40 | 
41 | 1. **Image Branch (CNN):** The $\mathbf{128 \times 128 \times 1}$ grayscale image is processed through 4 convolutional blocks:
42 |    * Conv2D layers with filters: 32 → 64 → 128 → 128
43 |    * Each block includes BatchNormalization and MaxPooling (2×2)
44 |    * Output is flattened and passed through Dense layers (256 neurons) with Dropout (0.3)
45 | 2. **Tabular Branch (Fully Connected or MLP):** The features (temperature, air speed, altitude) are processed through:
46 |    * Dense layers: 32 → 16 neurons
47 |    * Dropout regularization
48 | 3. **Fusion & Prediction:** Feature vectors from both branches are concatenated and passed through additional Dense layers (128 → 64 neurons) with Dropout, ending in a **softmax output layer** that produces binary class probabilities (0: Liquid, 1: Ice).
49 | 
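The sketch below instantiates this architecture in Keras, following the import style of `particle_classification_CNN.ipynb`. Filter counts, dense widths, and the 0.3 dropout follow the list above; the 3×3 kernel size, ReLU activations, tabular-branch dropout rate, and Adam optimizer are assumptions for illustration, not settled design choices:

```python
from keras.models import Model
from keras.layers import (Input, Conv2D, MaxPooling2D, BatchNormalization,
                          Flatten, Dense, Dropout, Concatenate)

# --- Image branch: 4 conv blocks, filters 32 -> 64 -> 128 -> 128 ---
img_in = Input(shape=(128, 128, 1), name='image')
x = img_in
for n_filters in (32, 64, 128, 128):
    x = Conv2D(n_filters, (3, 3), padding='same', activation='relu')(x)  # 3x3 kernels assumed
    x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.3)(x)

# --- Tabular branch: temperature, air speed, altitude ---
tab_in = Input(shape=(3,), name='env_features')
t = Dense(32, activation='relu')(tab_in)
t = Dropout(0.3)(t)  # dropout rate assumed
t = Dense(16, activation='relu')(t)

# --- Fusion & prediction ---
z = Concatenate()([x, t])
z = Dense(128, activation='relu')(z)
z = Dropout(0.3)(z)
z = Dense(64, activation='relu')(z)
out = Dense(2, activation='softmax')(z)  # one-hot binary output: 0 = Liquid, 1 = Ice

model = Model(inputs=[img_in, tab_in], outputs=out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
```

The two-unit softmax with one-hot labels matches the `to_categorical(y, num_classes=2)` encoding used in the training notebook.
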
50 | ## 3. Dataset Characteristics
51 | 
52 | | Metric | Detail |
53 | | :--- | :--- |
54 | | **Project Data Source** | CGWaveS RF02 project data (with potential for multi-project validation) |
55 | | **Dataset Size (Labeled Samples)** | $\mathbf{25,270}$ individual particles after filtering |
56 | | **Target Variable** | **Particle Phase (Binary Classification):** 0 (Liquid) / 1 (Ice) |
57 | | **Input Feature Type** | **Multimodal** (Image Array + Tabular Vector) |
58 | | **Image Input Dimensions** | $128 \times 128 \times 1$ (Standardized Grayscale Image) |
59 | | **Tabular Features** | 3 (Temperature, Air Speed, Altitude) |
60 | 
61 | ---
62 | 
63 | ## Appendix A: Implementation Status
64 | 
65 | | Component | Status | Details |
66 | | :--- | :--- | :--- |
67 | | **Data Access** | Complete | NetCDF particle-by-particle and state variables loaded in `pbp_plotting.ipynb` |
68 | | **Pre-processing** | Complete | Image extraction, scaling, and standardization to $128 \times 128$ implemented in the `plot_particle_standardized()` function |
69 | | **Image Export** | Complete | Standardized particle images saved. |
70 | | **Model Architecture** | In progress | Multi-modal CNN implemented in `particle_classification_CNN.ipynb` with separate branches for images and environmental features |
71 | | **Training Pipeline** | In progress | Training pipeline with data loading, splitting, callbacks, and evaluation; remaining work is to train, validate, and refine the model on classified particles |
72 | 
73 | ---
74 | 
75 | ## Appendix B: Key Feature Variables
76 | 
77 | | Variable | Type | Role | Implementation |
78 | | :--- | :--- | :--- | :--- |
79 | | $\mathbf{image}$ | Array ($128 \times 128 \times 1$) | Primary input for CNN; provides detailed shape morphology | Processed through 4-block CNN architecture |
80 | | $\mathbf{label}$ | Binary | Target variable: 0 (Liquid) / 1 (Ice) | One-hot encoded for training |
81 | 
82 | ### Additional Variables Available (Not Currently Used)
83 | 
84 | * $\mathbf{aspectratio}$: Particle shape metric (could be added to tabular features)
85 | * $\mathbf{diam}$: Particle size metric (could be added to tabular features)
86 | * $\mathbf{area}$: Particle area (could be added to tabular features)
87 | * $\mathbf{TASX}$: True Air Speed, for environmental context
88 | * $\mathbf{GGALT}$: Flight altitude, for environmental context
89 | * $\mathbf{ATX}$: Air temperature. Used for labeling, and could also be used as a feature in the mixed-phase range.
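Because $\mathbf{ATX}$ drives the labels, the cutoff logic is worth pinning down. The sketch below condenses the `find_cutoffs` helper from `pbp_preprocessing.ipynb`; the name `find_phase_cutoffs` is just for illustration, and the docstring reflects what the code actually computes (latest warm time / earliest cold time):

```python
import xarray as xr

def find_phase_cutoffs(ds_env: xr.Dataset):
    """Condensed from find_cutoffs() in pbp_preprocessing.ipynb.

    Particles observed at or before liquid_cutoff (the latest time with
    ATX >= 1 C) are labeled liquid (phase 0); particles at or after
    solid_cutoff (the earliest time with ATX <= -40 C) are labeled solid
    (phase 1). Everything in between falls in the mixed-phase range the
    CNN is meant to classify.
    """
    liquid_cutoff = ds_env.isel(Time=(ds_env.ATX >= 1).values)['Time'].max().values
    solid_cutoff = ds_env.isel(Time=(ds_env.ATX <= -40).values)['Time'].min().values
    return liquid_cutoff, solid_cutoff
```
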
90 | -------------------------------------------------------------------------------- /pbp_plotting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7fc01067", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import xarray as xr\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "0625d93f", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# Open pbp file\n", 22 | "ds = xr.open_dataset('PbP_SODA/20250524_022503_F2DS_V.pbp.nc',decode_times=False)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "eab693c5", 28 | "metadata": {}, 29 | "source": [ 30 | "### Saves probe time to UTC to display on images" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "4891c89b", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "#convert seconds since midnight to hh:mm:ss\n", 41 | "origin = np.datetime64('2025-05-24T00:00:00')\n", 42 | "utc_times = origin + ds['probetime'].astype('timedelta64[s]')\n", 43 | "ds['UTC_Time'] = utc_times" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "3976f4e1", 49 | "metadata": {}, 50 | "source": [ 51 | "## Plot Raw image, no scaling or trimming" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "5a4457b2", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import numpy as np\n", 62 | "import matplotlib.pyplot as plt\n", 63 | "\n", 64 | "# --- 1. Select a Particle Index (e.g., the 5000th particle) ---\n", 65 | "particle_index = 34000\n", 66 | "\n", 67 | "# --- 2. Extract Pointers ---\n", 68 | "start_slice = ds['starty'].values[particle_index]\n", 69 | "stop_slice = ds['stopy'].values[particle_index]\n", 70 | "start_diode = ds['startx'].values[particle_index]\n", 71 | "stop_diode = ds['stopx'].values[particle_index]\n", 72 | "\n", 73 | "# --- 3. Extract the Image Sub-Array ---\n", 74 | "# Slicing the large 'image' array using the extracted indices\n", 75 | "particle_image = ds['image'].values[\n", 76 | " start_slice : stop_slice, \n", 77 | " start_diode : stop_diode\n", 78 | "]\n", 79 | "\n", 80 | "# Get descriptive size metrics for the title\n", 81 | "ysize_microns = ds['ysize'].values[particle_index]\n", 82 | "xsize_microns = ds['xsize'].values[particle_index]\n", 83 | "\n", 84 | "# --- 4. 
Plot the Image ---\n", 85 | "plt.figure(figsize=(4, 6))\n", 86 | "# 'binary' cmap shows 1s (shaded pixels) as black and 0s as white\n", 87 | "plt.imshow(particle_image, cmap='binary', interpolation='none') \n", 88 | "plt.title(f\"Particle {particle_index} | Size: {particle_image.shape[0]} slices x {particle_image.shape[1]} diodes\")\n", 89 | "plt.xlabel('Diode (X-Position)')\n", 90 | "plt.ylabel('Slice (Y-Position / Along Airflow)')\n", 91 | "plt.show()" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "207bc955", 97 | "metadata": {}, 98 | "source": [ 99 | "## Plot histogram of phases for flight of interest" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "60f3c17b", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "import pandas as pd\n", 110 | "import matplotlib.pyplot as plt\n", 111 | "# Open your NetCDF file\n", 112 | "rf_ds = xr.open_dataset('RF05.20240311.051845_125704.PNI.nc')\n", 113 | "\n", 114 | "# Replace 'temperature' with the actual variable name in your dataset\n", 115 | "temp = rf_ds['ATX'].values\n", 116 | "\n", 117 | "# Define bins: (-inf, -30], (-30, 0], (0, inf)\n", 118 | "bins = [-np.inf, -40, 0, np.inf]\n", 119 | "labels = ['Ice', 'Mixed Phase', 'Water']\n", 120 | "\n", 121 | "# Bin the data\n", 122 | "categories = pd.cut(temp, bins=bins, labels=labels)\n", 123 | "\n", 124 | "# Count occurrences in each bin\n", 125 | "hist = pd.value_counts(categories, sort=False)\n", 126 | "\n", 127 | "\n", 128 | "hist.plot(kind='bar')\n", 129 | "plt.xlabel('Temperature Range')\n", 130 | "plt.ylabel('Count')\n", 131 | "plt.title('Temperature Histogram')\n", 132 | "plt.show()" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "id": "c699472c", 138 | "metadata": {}, 139 | "source": [ 140 | "## Standardized function to scale a particle to a 128x128 canvas and save the plot into specified output folder" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "49b8b4ea", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "import numpy as np\n", 151 | "import matplotlib.pyplot as plt\n", 152 | "from skimage.transform import resize\n", 153 | "import os\n", 154 | "\n", 155 | "# Make sure the output directory exists\n", 156 | "output_dir = \"cgwaves2_outputs\"\n", 157 | "os.makedirs(output_dir, exist_ok=True)\n", 158 | "\n", 159 | "def plot_particle_standardized(ds, particle_index, target_size=128, max_fit=124):\n", 160 | " \"\"\"\n", 161 | " Plot a single particle on a standardized square canvas.\n", 162 | " \n", 163 | " Parameters:\n", 164 | " -----------\n", 165 | " ds : xarray.Dataset\n", 166 | " Dataset containing particle image data\n", 167 | " particle_index : int\n", 168 | " Index of the particle to plot\n", 169 | " target_size : int\n", 170 | " Size of the output canvas (default: 128x128)\n", 171 | " max_fit : int\n", 172 | " Maximum dimension for the particle before padding (default: 124)\n", 173 | " \n", 174 | " Returns:\n", 175 | " --------\n", 176 | " canvas : np.ndarray\n", 177 | " The standardized 128x128 image array\n", 178 | " \"\"\"\n", 179 | " \n", 180 | " # Extract bounding box\n", 181 | " start_slice = ds['starty'].values[particle_index]\n", 182 | " stop_slice = ds['stopy'].values[particle_index]\n", 183 | " start_diode = ds['startx'].values[particle_index]\n", 184 | " stop_diode = ds['stopx'].values[particle_index]\n", 185 | " probe_time=ds.UTC_Time[particle_index].values.astype('M8[ms]').astype('str')\n", 186 | " # Crop the 
particle image\n", 187 | " cropped_image = ds['image'].values[\n", 188 | " start_slice:stop_slice, \n", 189 | " start_diode:stop_diode\n", 190 | " ]\n", 191 | " \n", 192 | " # Trim to only rows/cols with data\n", 193 | " rows_with_data = np.any(cropped_image == 1, axis=1)\n", 194 | " cols_with_data = np.any(cropped_image == 1, axis=0)\n", 195 | " \n", 196 | " if not np.any(rows_with_data) or not np.any(cols_with_data):\n", 197 | " print(f\"Warning: Particle {particle_index} has no shaded pixels\")\n", 198 | " return None\n", 199 | " \n", 200 | " cropped_image = cropped_image[rows_with_data][:, cols_with_data]\n", 201 | " H_current, W_current = cropped_image.shape\n", 202 | " \n", 203 | " # Calculate scaling factor (preserve aspect ratio)\n", 204 | " max_dim = max(H_current, W_current)\n", 205 | " scale_factor = max_fit / max_dim\n", 206 | " \n", 207 | " # Calculate new dimensions\n", 208 | " new_H = max(1, int(round(H_current * scale_factor)))\n", 209 | " new_W = max(1, int(round(W_current * scale_factor)))\n", 210 | " \n", 211 | " # Resize using nearest-neighbor to keep binary\n", 212 | " resized_image = resize(\n", 213 | " cropped_image,\n", 214 | " (new_H, new_W),\n", 215 | " order=0,\n", 216 | " anti_aliasing=False,\n", 217 | " preserve_range=True\n", 218 | " )\n", 219 | " resized_image = (resized_image > 0.5).astype(np.uint8)\n", 220 | " \n", 221 | " # Create canvas and center the particle\n", 222 | " canvas = np.zeros((target_size, target_size), dtype=np.uint8)\n", 223 | " pad_y = (target_size - new_H) // 2\n", 224 | " pad_x = (target_size - new_W) // 2\n", 225 | " canvas[pad_y:pad_y + new_H, pad_x:pad_x + new_W] = resized_image\n", 226 | " \n", 227 | " # Plot\n", 228 | " plt.figure(figsize=(6, 6))\n", 229 | " plt.imshow(canvas, cmap='binary', interpolation='none')\n", 230 | " plt.title(f\"{probe_time} | Original: {H_current}×{W_current} → Scaled: {target_size}×{target_size}\")\n", 231 | " plt.xlabel('X-Dimension')\n", 232 | " plt.ylabel('Y-Dimension')\n", 233 | " plt.grid(color='gray', linestyle='--', linewidth=0.5, alpha=0.3)\n", 234 | " plt.tight_layout()\n", 235 | " plt.savefig(f\"{output_dir}/particle_{particle_index}.png\")\n", 236 | " plt.close()\n", 237 | " return canvas\n", 238 | "\n", 239 | "# Example usage:\n", 240 | "# canvas = plot_particle_standardized(ds, particle_index=34000)" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "e838be9f", 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "plot_particle_standardized(ds, particle_index=34000)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "id": "239065dc", 256 | "metadata": {}, 257 | "source": [ 258 | "### Save images for all particles, or specify your start particle" 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "id": "bcb1ee4f", 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [ 268 | "start_particle= 721600 ## specify where in array you want to start plotting\n", 269 | "\n", 270 | "# Loop through all indices (or a subset for testing)\n", 271 | "for particle_index in range(start_particle,ds.dims['Time']):\n", 272 | " plot_particle_standardized(ds, particle_index)" 273 | ] 274 | } 275 | ], 276 | "metadata": { 277 | "kernelspec": { 278 | "display_name": "Python 3", 279 | "language": "python", 280 | "name": "python3" 281 | }, 282 | "language_info": { 283 | "codemirror_mode": { 284 | "name": "ipython", 285 | "version": 3 286 | }, 287 | "file_extension": ".py", 288 | "mimetype": "text/x-python", 
289 | "name": "python", 290 | "nbconvert_exporter": "python", 291 | "pygments_lexer": "ipython3", 292 | "version": "3.13.0" 293 | } 294 | }, 295 | "nbformat": 4, 296 | "nbformat_minor": 5 297 | } 298 | -------------------------------------------------------------------------------- /pbp_preprocessing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "7fc01067", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import xarray as xr\n", 11 | "import numpy as np\n", 12 | "import os\n", 13 | "import pandas as pd" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "0625d93f", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "# Open pbp file\n", 24 | "dsH = xr.open_dataset('Data/20250524_022503_F2DS_H.pbp.nc',decode_times=False)\n", 25 | "dsV = xr.open_dataset('Data/20250524_022503_F2DS_V.pbp.nc',decode_times=False)" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "id": "39ea2faf", 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "#Open temp etc file\n", 36 | "ds_env = xr.open_dataset('Data/RF02.20250524.004127_075522.PNI.nc',decode_times=True)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "eab693c5", 42 | "metadata": {}, 43 | "source": [ 44 | "### Saves probe time to UTC to display on images" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "id": "19fe624e", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "def convert_time(ds):\n", 55 | " \"\"\"\n", 56 | " Convert Time variable in pbp dataset to datetime64[ns] using FlightDate attribute.\n", 57 | " \"\"\"\n", 58 | " FlightDate = pd.to_datetime(ds.FlightDate)\n", 59 | "\n", 60 | " origin = np.datetime64(FlightDate)\n", 61 | " utc_times = origin + ds['probetime'].astype('timedelta64[s]')\n", 62 | " return utc_times" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "id": "604a465e", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "def find_cutoffs(ds_env):\n", 73 | " \"\"\"\n", 74 | " Find the liquid and solid cutoffs based on ATX values in the environmental dataset.\n", 75 | " Liquid cutoff: earliest time where ATX >= 1\n", 76 | " Solid cutoff: earliest time where ATX <= -40\n", 77 | " Returns:\n", 78 | " liquid_cutoff (np.datetime64): Time of liquid cutoff\n", 79 | " solid_cutoff (np.datetime64): Time of solid cutoff\n", 80 | " \"\"\"\n", 81 | " mask = (ds_env.ATX >=1)\n", 82 | " # Find the earliest time where ATX >= 0\n", 83 | " liquid_cutoff = ds_env.isel(Time =mask.values)['Time'].max().values\n", 84 | "\n", 85 | " mask = (ds_env.ATX <= -40)\n", 86 | " solid_cutoff = ds_env.isel(Time =mask.values)['Time'].min().values\n", 87 | " print(f\"Liquid cutoff time: {liquid_cutoff}\")\n", 88 | " print(f\"Solid cutoff time: {solid_cutoff}\")\n", 89 | " return liquid_cutoff, solid_cutoff" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "id": "c07b1d09", 95 | "metadata": {}, 96 | "source": [ 97 | "## Quality Filtering" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 32, 103 | "id": "632f4903", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "AREARATIO_MAX = 0.95\n", 108 | "ASPECTRATIO_MIN = 0.90\n", 109 | "ASPECTRATIO_TOL = 0.05\n", 110 | "ASPECTRATIO_LINE_MAX = 0.1\n", 111 | "VOID_THRESHOLD = 0.10 # Max allowable void fraction (e.g., 10%)\n", 112 | 
"DIODEGAPS_THRESH = 1\n", 113 | "SIZE_THRESH = 70" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 33, 119 | "id": "5b025566", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "def generate_rect_mask(ds: xr.Dataset) -> xr.DataArray:\n", 124 | " \"\"\"\n", 125 | " Generates a boolean mask to isolate long lines/perfect rectangle particles\n", 126 | " from raw particle metadata in an xarray Dataset.\n", 127 | "\n", 128 | " The filter uses two main criteria:\n", 129 | " 1. High Area Ratio (arearatio): Particle area relative to its bounding box.\n", 130 | " A high value indicates a solid, well-defined shape (like a line or\n", 131 | " rectangle), not a fragmented/wispy particle.\n", 132 | " 2. High Elongation: The ratio of the particle's longest dimension to its\n", 133 | " shortest dimension is high, characterizing a long line.\n", 134 | "\n", 135 | " Args:\n", 136 | " ds: xarray Dataset containing particle metrics with dimensions for\n", 137 | " 'particle_id' and variables 'xsize', 'ysize', and 'arearatio'.\n", 138 | "\n", 139 | " Returns:\n", 140 | " xr.DataArray: A boolean mask (True for particles to keep).\n", 141 | " \"\"\"\n", 142 | " \n", 143 | " # --- 1. Define Filtering Thresholds ---\n", 144 | " # Threshold for arearatio (Area / Bounding Box Area).\n", 145 | " # LINE: 0.09095, PARTICLE: 0.00615. Using 0.05 will isolate the LINE.\n", 146 | " AREA_RATIO_THRESHOLD = 0.05 \n", 147 | " \n", 148 | " # Threshold for Elongation (max_size / min_size).\n", 149 | " # LINE: 140/10 = 14. PARTICLE: 1160/70 ~ 16.5. \n", 150 | " # We use this to ensure the particle is very long/thin.\n", 151 | " ELONGATION_THRESHOLD = 5 \n", 152 | "\n", 153 | " \n", 154 | " # --- 2. Calculate Elongation Ratio ---\n", 155 | " # Determine the maximum and minimum size dimensions for each particle\n", 156 | " max_size = ds['xsize'].where(ds['xsize'] > ds['ysize'], other=ds['ysize'])\n", 157 | " min_size = ds['ysize'].where(ds['xsize'] > ds['ysize'], other=ds['xsize'])\n", 158 | " \n", 159 | " # Calculate the elongation ratio\n", 160 | " elongation_ratio = max_size / min_size\n", 161 | " \n", 162 | " # --- 3. Create Mask Components ---\n", 163 | " \n", 164 | " # Filter A: Particles must have a high arearatio (solid shape)\n", 165 | " mask_solid = ds['arearatio'] > AREA_RATIO_THRESHOLD\n", 166 | " \n", 167 | " # Filter B: Particles must be highly elongated (long line/rectangle)\n", 168 | " mask_elongated = elongation_ratio > ELONGATION_THRESHOLD\n", 169 | " edge_mask = ds['edgetouch'] == 0\n", 170 | " \n", 171 | " # --- 4. 
Combine Masks ---\n", 172 | " # The final mask requires both conditions to be True:\n", 173 | " # High Area Ratio AND High Elongation\n", 174 | " final_mask = mask_solid & mask_elongated & edge_mask\n", 175 | "\n", 176 | " return final_mask" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 34, 182 | "id": "46c171b5", 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "def create_exclusion_mask(ds,\n", 187 | " arearatio_max=AREARATIO_MAX,\n", 188 | " aspectratio_min=ASPECTRATIO_MIN,\n", 189 | " aspectratio_line_max=ASPECTRATIO_LINE_MAX,\n", 190 | " void_threshold=VOID_THRESHOLD,\n", 191 | " size_threshold=100,\n", 192 | " diodegaps_thresh=1):\n", 193 | " \"\"\"\n", 194 | " Build a boolean exclusion mask (True = exclude) for particles in `ds`.\n", 195 | " Returns an xarray.DataArray aligned with ds['Time'].\n", 196 | " \"\"\"##\n", 197 | " # Donut / hollow / out-of-focus\n", 198 | " donut_mask = (ds['diodegaps'] > diodegaps_thresh)\n", 199 | "\n", 200 | " # Near-perfect rectangles / squares\n", 201 | " #square_mask = (ds['arearatiofilled'] > arearatio_max) & (ds['aspectratio'] > aspectratio_min)\n", 202 | "\n", 203 | " \n", 204 | " ## wider lines\n", 205 | " rect_mask = generate_rect_mask(ds)\n", 206 | "\n", 207 | " line_mask = (ds['aspectratio'] < aspectratio_line_max) & (ds['edgetouch'] == 0)\n", 208 | "\n", 209 | " # Calculate the Void Index (Area of the Void / Total Filled Area)\n", 210 | " void_index = (ds['areafilled'] - ds['area']) / ds['areafilled']\n", 211 | "\n", 212 | " # Create a mask to REJECT particles with a high void index\n", 213 | " # We use >= to ensure the filter works correctly\n", 214 | " void_mask = void_index >= void_threshold\n", 215 | " \n", 216 | "\n", 217 | " # Small particles (size cutoff)\n", 218 | " size_mask = ds['diam'] <= size_threshold\n", 219 | "\n", 220 | " exclusion_mask = size_mask | donut_mask | rect_mask | line_mask |void_mask\n", 221 | " return exclusion_mask" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 35, 227 | "id": "543ddcc6", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "def create_mask(ds):\n", 232 | " # keep conditions\n", 233 | " size_ok = ds['diam'] >= SIZE_THRESH\n", 234 | " aspect_ok = ds['aspectratio'] > ASPECTRATIO_TOL\n", 235 | "\n", 236 | " # exclude donuts: diodegaps > threshold => NOT keep\n", 237 | " not_donut = ds['diodegaps'] <= DIODEGAPS_THRESH\n", 238 | "\n", 239 | " # exclude near-perfect filled squares/rectangles:\n", 240 | " # prefer 'arearatiofilled' if present, else fall back to 'arearatio'\n", 241 | " areafill = ds.get('arearatiofilled', ds.get('arearatio', None))\n", 242 | " if areafill is None:\n", 243 | " # if neither exists, conservatively don't exclude on fill\n", 244 | " not_square = True\n", 245 | " else:\n", 246 | " square_exclude = (areafill >= AREARATIO_MAX) & (np.abs(ds['aspectratio'] - 1) <= ASPECTRATIO_TOL)\n", 247 | " not_square = ~square_exclude\n", 248 | "\n", 249 | " # final keep mask: size AND aspect AND not donut AND not square\n", 250 | " keep = size_ok & aspect_ok & not_donut & not_square\n", 251 | " return keep" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 36, 257 | "id": "ddcc3c01", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "# Combine all exclusion masks using OR\n", 262 | "def filter_particles(ds,ds_env, start_index=0):\n", 263 | " l_mask, s_mask = find_cutoffs(ds_env)\n", 264 | " mask = create_mask(ds)\n", 265 | "\n", 266 | " # --- Apply the 
mask ---\n", 267 | " ds_filt = ds.isel(Time=mask)\n", 268 | " rect_mask = generate_rect_mask(ds_filt)\n", 269 | " ds_filt = ds_filt.isel(Time=~rect_mask)\n", 270 | "\n", 271 | " # To filter the 'image' variable, you'll need to calculate the new start/stop indices. \n", 272 | " # This can be tricky and resource-intensive because 'image' is a flat array \n", 273 | " # and not indexed by 'Time'.\n", 274 | " # For now, focus on the scalar data, which is filtered by 'Time'.\n", 275 | " print(f\"Original particle count: {len(ds['Time'])}\")\n", 276 | " print(f\"Filtered particle count: {len(ds_filt['Time'])}\")\n", 277 | " mask = (ds_filt['Time'] <= l_mask)\n", 278 | " liquid_particles = ds_filt.isel(Time =mask.values)\n", 279 | " \n", 280 | " \n", 281 | " mask = (ds_filt['Time'] >= s_mask)\n", 282 | " solid_particles = ds_filt.isel(Time =mask.values)\n", 283 | " ##Assign phase labels\n", 284 | " liquid_particles['phase'] = 0\n", 285 | " solid_particles['phase'] = 1\n", 286 | " n_liq = liquid_particles.dims['Time']\n", 287 | " n_sol = solid_particles.dims['Time']\n", 288 | " print(f\"Number of solid particles: {n_sol}\")\n", 289 | " print(f\"Number of liquid particles: {n_liq}\")\n", 290 | " # Re-index particle indices\n", 291 | " end_index = start_index + n_liq\n", 292 | " ds_sol = solid_particles.rename_dims({'Time':'particle_idx_seq'})\n", 293 | " ds_liq = liquid_particles.rename_dims({'Time':'particle_idx_seq'})\n", 294 | " ds_liq['particle_idx_seq'] = np.arange(start_index, end_index)\n", 295 | " ds_sol['particle_idx_seq'] = np.arange(end_index, end_index + n_sol)\n", 296 | "\n", 297 | " print(f\"Assigned particle_idx_seq from {start_index} to {end_index + n_sol -1}. \\n New start index for next call: {end_index + n_sol}\")\n", 298 | "\n", 299 | " #concat_ds = xr.concat([ds_liq, ds_sol], dim='particle_idx_seq')\n", 300 | " if start_index ==0:\n", 301 | " return ds_liq, ds_sol, end_index + n_sol\n", 302 | " return ds_liq, ds_sol" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 27, 308 | "id": "66aa838a", 309 | "metadata": {}, 310 | "outputs": [ 311 | { 312 | "name": "stderr", 313 | "output_type": "stream", 314 | "text": [ 315 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/3328646582.py:8: UserWarning: Converting non-nanosecond precision timedelta values to nanosecond precision. This behavior can eventually be relaxed in xarray, as it is an artifact from pandas which is now beginning to support non-nanosecond precision values. 
This warning is caused by passing non-nanosecond np.datetime64 or np.timedelta64 values to the DataArray or Variable constructor; it can be silenced by converting the values to nanosecond precision ahead of time.\n", 316 | " utc_times = origin + ds['probetime'].astype('timedelta64[s]')\n" 317 | ] 318 | }, 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "Liquid cutoff time: 2025-05-24T02:52:29.000000000\n", 324 | "Solid cutoff time: 2025-05-24T03:01:37.000000000\n" 325 | ] 326 | } 327 | ], 328 | "source": [ 329 | "dsH['Time']= convert_time(dsH)\n", 330 | "rect_mask = (dsH['area']==dsH['perimeterarea']) & (dsH['diam'] > 100) & (dsH['arearatio'] < 0.1)\n", 331 | "rects = dsH.isel(Time=rect_mask)\n", 332 | "\n", 333 | "l,s = find_cutoffs(ds_env)\n", 334 | "solmask = (rects['Time'] >= s)\n", 335 | "solrect = rects.isel(Time=solmask)\n" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 37, 341 | "id": "82a9049f", 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "name": "stderr", 346 | "output_type": "stream", 347 | "text": [ 348 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/3328646582.py:8: UserWarning: Converting non-nanosecond precision timedelta values to nanosecond precision. This behavior can eventually be relaxed in xarray, as it is an artifact from pandas which is now beginning to support non-nanosecond precision values. This warning is caused by passing non-nanosecond np.datetime64 or np.timedelta64 values to the DataArray or Variable constructor; it can be silenced by converting the values to nanosecond precision ahead of time.\n", 349 | " utc_times = origin + ds['probetime'].astype('timedelta64[s]')\n" 350 | ] 351 | }, 352 | { 353 | "name": "stdout", 354 | "output_type": "stream", 355 | "text": [ 356 | "Liquid cutoff time: 2025-05-24T02:52:29.000000000\n", 357 | "Solid cutoff time: 2025-05-24T03:01:37.000000000\n", 358 | "Original particle count: 3754562\n", 359 | "Filtered particle count: 73166\n", 360 | "Number of solid particles: 3012\n", 361 | "Number of liquid particles: 1629\n", 362 | "Assigned particle_idx_seq from 0 to 4640. \n", 363 | " New start index for next call: 4641\n" 364 | ] 365 | }, 366 | { 367 | "name": "stderr", 368 | "output_type": "stream", 369 | "text": [ 370 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/4228364844.py:26: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.\n", 371 | " n_liq = liquid_particles.dims['Time']\n", 372 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/4228364844.py:27: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.\n", 373 | " n_sol = solid_particles.dims['Time']\n", 374 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/3328646582.py:8: UserWarning: Converting non-nanosecond precision timedelta values to nanosecond precision. This behavior can eventually be relaxed in xarray, as it is an artifact from pandas which is now beginning to support non-nanosecond precision values. 
This warning is caused by passing non-nanosecond np.datetime64 or np.timedelta64 values to the DataArray or Variable constructor; it can be silenced by converting the values to nanosecond precision ahead of time.\n", 375 | " utc_times = origin + ds['probetime'].astype('timedelta64[s]')\n" 376 | ] 377 | }, 378 | { 379 | "name": "stdout", 380 | "output_type": "stream", 381 | "text": [ 382 | "Liquid cutoff time: 2025-05-24T02:52:29.000000000\n", 383 | "Solid cutoff time: 2025-05-24T03:01:37.000000000\n", 384 | "Original particle count: 3666455\n", 385 | "Filtered particle count: 62598\n", 386 | "Number of solid particles: 947\n", 387 | "Number of liquid particles: 1594\n", 388 | "Assigned particle_idx_seq from 4641 to 7181. \n", 389 | " New start index for next call: 7182\n" 390 | ] 391 | }, 392 | { 393 | "name": "stderr", 394 | "output_type": "stream", 395 | "text": [ 396 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/4228364844.py:26: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.\n", 397 | " n_liq = liquid_particles.dims['Time']\n", 398 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/4228364844.py:27: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.\n", 399 | " n_sol = solid_particles.dims['Time']\n" 400 | ] 401 | } 402 | ], 403 | "source": [ 404 | "##filter V and H datasets\n", 405 | "dsV['Time'] = convert_time(dsV)\n", 406 | "liq, sol, start_idx = filter_particles(dsV,ds_env)\n", 407 | "dsH['Time'] = convert_time(dsH)\n", 408 | "liqH,solH = filter_particles(dsH,ds_env, start_index=start_idx)" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 39, 414 | "id": "59bd6f07", 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "l_particles = [liq, liqH]\n", 419 | "s_particles = [sol, solH]" 420 | ] 421 | }, 422 | { 423 | "cell_type": "markdown", 424 | "id": "c699472c", 425 | "metadata": {}, 426 | "source": [ 427 | "## Standardized function to scale a particle to a 128x128 canvas and save the plot into specified output folder" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "id": "239065dc", 433 | "metadata": {}, 434 | "source": [ 435 | "### Save images for all particles, or specify your start particle" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 18, 441 | "id": "536e2c44", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "import numpy as np\n", 446 | "import matplotlib.pyplot as plt\n", 447 | "from skimage.transform import resize\n", 448 | "import os\n", 449 | "\n", 450 | "# Make sure the output directory exists\n", 451 | "output_dir = 'particle_images_filtered'\n", 452 | "os.makedirs(output_dir, exist_ok=True)\n", 453 | "def plot_particle_standardized(ds, particle_index, target_size=128, max_fit=128, save=True, output_dir=output_dir):\n", 454 | " \"\"\"\n", 455 | " Produce a clean 128x128 grayscale image suitable for CNN input.\n", 456 | " - No text overlay\n", 457 | " - Returned array is float32 normalized to [0, 1]\n", 458 | " - Saved PNG (if save=True) contains only the image pixels\n", 459 | " \"\"\"\n", 460 | " import numpy as np\n", 461 | " from 
skimage.transform import resize\n", 462 | " import os\n", 463 | " import matplotlib.pyplot as plt\n", 464 | " pnumber = str(ds['particle_idx_seq'][particle_index].values)\n", 465 | " # safety: ensure output dir exists\n", 466 | " os.makedirs(output_dir, exist_ok=True)\n", 467 | "\n", 468 | " # defensive indexing\n", 469 | " try:\n", 470 | " start_slice = int(ds['starty'].values[particle_index])\n", 471 | " stop_slice = int(ds['stopy'].values[particle_index])\n", 472 | " start_diode = int(ds['startx'].values[particle_index])\n", 473 | " stop_diode = int(ds['stopx'].values[particle_index])\n", 474 | " except Exception:\n", 475 | " return None\n", 476 | "\n", 477 | " # load only the small image patch for this particle\n", 478 | " cropped_image = ds['image'].values[start_slice:stop_slice, start_diode:stop_diode]\n", 479 | "\n", 480 | " # ensure binary and small memory footprint\n", 481 | " cropped_image = (cropped_image > 0).astype(np.uint8)\n", 482 | "\n", 483 | " # trim empty rows/cols\n", 484 | " rows_with_data = np.any(cropped_image == 1, axis=1)\n", 485 | " cols_with_data = np.any(cropped_image == 1, axis=0)\n", 486 | " if not np.any(rows_with_data) or not np.any(cols_with_data):\n", 487 | " return None\n", 488 | "\n", 489 | " cropped_image = cropped_image[rows_with_data][:, cols_with_data]\n", 490 | " H_current, W_current = cropped_image.shape\n", 491 | "\n", 492 | " # scale preserving aspect ratio so largest side == max_fit\n", 493 | " max_dim = max(H_current, W_current)\n", 494 | " if max_dim == 0:\n", 495 | " return None\n", 496 | " scale_factor = float(max_fit) / float(max_dim)\n", 497 | "\n", 498 | " new_H = max(1, int(round(H_current * scale_factor)))\n", 499 | " new_W = max(1, int(round(W_current * scale_factor)))\n", 500 | "\n", 501 | " # nearest-neighbor resize to keep binary structure\n", 502 | " resized_image = resize(\n", 503 | " cropped_image,\n", 504 | " (new_H, new_W),\n", 505 | " order=0,\n", 506 | " anti_aliasing=False,\n", 507 | " preserve_range=True\n", 508 | " )\n", 509 | " resized_image = (resized_image > 0.5).astype(np.uint8)\n", 510 | "\n", 511 | " # center on canvas\n", 512 | " canvas = np.zeros((target_size, target_size), dtype=np.uint8)\n", 513 | " pad_y = (target_size - new_H) // 2\n", 514 | " pad_x = (target_size - new_W) // 2\n", 515 | " canvas[pad_y:pad_y + new_H, pad_x:pad_x + new_W] = resized_image\n", 516 | " \n", 517 | " ##invert colors\n", 518 | " canvas = 1 - canvas\n", 519 | "\n", 520 | " # normalize to float32 [0,1] for CNN\n", 521 | " canvas_f = canvas.astype(np.float32) / 255.0 if canvas.max() > 1 else canvas.astype(np.float32)\n", 522 | "\n", 523 | " # save as 8-bit PNG without text/axes\n", 524 | " if save:\n", 525 | " # convert to 0-255 uint8 for disk\n", 526 | " out_img = (canvas_f * 255).astype(np.uint8)\n", 527 | " plt.imsave(f\"{output_dir}/particle_{pnumber}.png\", out_img, cmap='gray', vmin=0, vmax=255)\n", 528 | "\n", 529 | " return canvas_f\n" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": 40, 535 | "id": "bcb1ee4f", 536 | "metadata": {}, 537 | "outputs": [ 538 | { 539 | "name": "stderr", 540 | "output_type": "stream", 541 | "text": [ 542 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/262451929.py:7: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. 
To access a mapping from dimension names to lengths, please use `Dataset.sizes`.\n", 543 | " print(f\"Processing dataset with {p_ds.dims['particle_idx_seq']} particles...\")\n", 544 | "/var/folders/bm/ggq_7b1n5g10y_2mdnxrjtc1fjxyx8/T/ipykernel_47937/262451929.py:8: FutureWarning: The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.\n", 545 | " for particle_index in range(p_ds.dims['particle_idx_seq']):\n" 546 | ] 547 | }, 548 | { 549 | "name": "stdout", 550 | "output_type": "stream", 551 | "text": [ 552 | "Processing dataset with 1629 particles...\n", 553 | "Processing dataset with 1594 particles...\n", 554 | "Processing dataset with 3012 particles...\n", 555 | "Processing dataset with 947 particles...\n" 556 | ] 557 | } 558 | ], 559 | "source": [ 560 | "liquid_dir = os.path.join(output_dir, 'liquid')\n", 561 | "os.makedirs(liquid_dir, exist_ok=True)\n", 562 | "solid_dir = os.path.join(output_dir, 'solid')\n", 563 | "os.makedirs(solid_dir, exist_ok=True)\n", 564 | "# Loop through all indices (or a subset for testing)\n", 565 | "def process_and_save_particles(p_ds, output_dir=output_dir):\n", 566 | " print(f\"Processing dataset with {p_ds.dims['particle_idx_seq']} particles...\")\n", 567 | " for particle_index in range(p_ds.dims['particle_idx_seq']):\n", 568 | " plot_particle_standardized(p_ds, particle_index, output_dir=output_dir)\n", 569 | "\n", 570 | "for ds in l_particles:\n", 571 | " process_and_save_particles(ds, output_dir=liquid_dir)\n", 572 | "\n", 573 | "for ds in s_particles:\n", 574 | " process_and_save_particles(ds, output_dir=solid_dir)" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 41, 580 | "id": "150f5f47", 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [ 584 | "#convert subset ds to pandas dataframe\n", 585 | "particles_df = []\n", 586 | "for particle in l_particles:\n", 587 | " df = particle[['particle_idx_seq','phase']].to_dataframe()\n", 588 | " particles_df.append(df)\n", 589 | "for particles in s_particles:\n", 590 | " df = particles[['particle_idx_seq','phase']].to_dataframe()\n", 591 | " particles_df.append(df)" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": 42, 597 | "id": "a0fdbc2b", 598 | "metadata": {}, 599 | "outputs": [], 600 | "source": [ 601 | "df_all= pd.concat(particles_df).reset_index()" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "id": "823d3127", 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "df_all.to_csv(output_dir +'/particle_phases.csv', index=False)" 612 | ] 613 | }, 614 | { 615 | "cell_type": "code", 616 | "execution_count": null, 617 | "id": "d6f76e28", 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [] 621 | } 622 | ], 623 | "metadata": { 624 | "kernelspec": { 625 | "display_name": "myenv", 626 | "language": "python", 627 | "name": "python3" 628 | }, 629 | "language_info": { 630 | "codemirror_mode": { 631 | "name": "ipython", 632 | "version": 3 633 | }, 634 | "file_extension": ".py", 635 | "mimetype": "text/x-python", 636 | "name": "python", 637 | "nbconvert_exporter": "python", 638 | "pygments_lexer": "ipython3", 639 | "version": "3.13.0" 640 | } 641 | }, 642 | "nbformat": 4, 643 | "nbformat_minor": 5 644 | } 645 | -------------------------------------------------------------------------------- 
/particle_classification_CNN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "ViqmHWqfGumf" 7 | }, 8 | "source": [ 9 | "# Particle Classification CNN: 4-Class Hybrid Model\n", 10 | "\n", 11 | "This notebook trains a CNN to classify cloud particles into 4 phases using:\n", 12 | "- Particle images (128x128 grayscale)\n", 13 | "- Environmental features: temperature, air speed, altitude\n", 14 | "\n", 15 | "**Output:** 4-class classification\n", 16 | "- Phase 0: Liquid\n", 17 | "- Phase 1: Solid (Ice)\n", 18 | "- Phase 2: Donut\n", 19 | "- Phase 3: Noise" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "id": "morPzRTFGumg" 26 | }, 27 | "source": [ 28 | "## 1. Import Libraries" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "id": "EQQZjoPzGumg" 36 | }, 37 | "outputs": [], 38 | "source": [ 39 | "import numpy as np\n", 40 | "import pandas as pd\n", 41 | "import os\n", 42 | "from PIL import Image\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "%matplotlib inline\n", 45 | "\n", 46 | "from sklearn.model_selection import train_test_split\n", 47 | "from sklearn.preprocessing import StandardScaler\n", 48 | "\n", 49 | "import tensorflow as tf\n", 50 | "from tensorflow import keras\n", 51 | "from keras.models import Model\n", 52 | "from keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Concatenate, BatchNormalization\n", 53 | "from keras.utils import to_categorical\n", 54 | "from keras.callbacks import EarlyStopping, ModelCheckpoint\n", 55 | "from tensorflow.keras.optimizers import Adam\n", 56 | "\n", 57 | "from sklearn.metrics import accuracy_score, confusion_matrix, classification_report" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "-rATtmL5Gumg" 64 | }, 65 | "source": [ 66 | "## 2. 
Load and Prepare Data\n", 67 | "\n", 68 | "**NOTE:** You need to have a DataFrame with the following columns:\n", 69 | "- `particle_idx_seq`: matches the particle_X.png filenames\n", 70 | "- `phase`: 0 for liquid, 1 for solid, 2 for donut, 3 for noise\n", 71 | "- `ATX`: temperature value\n", 72 | "- `TASX`: air speed value\n", 73 | "- `GGALT`: altitude value" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "colab": { 81 | "base_uri": "https://localhost:8080/" 82 | }, 83 | "id": "KuRXs8OcGumg", 84 | "outputId": "73a1cfab-12b8-4034-e24b-b3f765ccfb04" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "#connect to google drive\n", 89 | "from google.colab import drive\n", 90 | "device_name = tf.test.gpu_device_name()\n", 91 | "if device_name != '/device:GPU:0':\n", 92 | " raise SystemError('GPU device not found')\n", 93 | "print('Found GPU at: {}'.format(device_name))\n", 94 | "\n", 95 | "drive.mount('/content/drive')" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "colab": { 103 | "base_uri": "https://localhost:8080/", 104 | "height": 403 105 | }, 106 | "id": "_Ky1ERnbGumh", 107 | "outputId": "9bfb77f7-be33-4ac1-9d75-614ff40a9e1d" 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "base_path = '/content/drive/MyDrive/GEOG5100/aircraft_ml/'\n", 112 | "\n", 113 | "# Load labeled data\n", 114 | "df = pd.read_csv(base_path+'particle_df.csv')\n", 115 | "if df is None:\n", 116 | " raise ValueError(\"Please load your labeled dataframe before proceeding\")\n", 117 | "\n", 118 | "# Display basic info about the dataset\n", 119 | "print(f\"Total labeled particles: {len(df)}\")\n", 120 | "print(f\"\\nClass distribution:\")\n", 121 | "print(df['phase'].value_counts().sort_index())\n", 122 | "print(f\"\\nDataFrame info:\")\n", 123 | "df.head()" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": { 129 | "id": "_vdjCkpvGumh" 130 | }, 131 | "source": [ 132 | "## 3. 
Load Particle Images" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "id": "511JMsMzGumh" 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "def load_particle_image(particle_num, image_dir=base_path, target_size=(128, 128)):\n", 144 | " \"\"\"\n", 145 | " Load a single particle image and preprocess it.\n", 146 | "\n", 147 | " Parameters:\n", 148 | " -----------\n", 149 | " particle_num : int\n", 150 | " Particle number (corresponds to particle_X.png)\n", 151 | " image_dir : str\n", 152 | " Directory containing particle images\n", 153 | " target_size : tuple\n", 154 | " Target image size (height, width)\n", 155 | "\n", 156 | " Returns:\n", 157 | " --------\n", 158 | " image : np.ndarray\n", 159 | " Normalized image array of shape (target_size[0], target_size[1], 1)\n", 160 | " \"\"\"\n", 161 | " # Check all possible directories in order\n", 162 | " subdirs = ['liquid', 'solid']\n", 163 | "\n", 164 | " for subdir in subdirs:\n", 165 | " img_path = os.path.join(image_dir, f'particle_images_filtered/{subdir}', f'particle_{particle_num}.png')\n", 166 | " if os.path.exists(img_path):\n", 167 | " # Load image as grayscale\n", 168 | " img = Image.open(img_path).convert('L')\n", 169 | "\n", 170 | " # Resize if necessary\n", 171 | " img = img.resize(target_size)\n", 172 | "\n", 173 | " # Convert to numpy array and normalize to [0, 1]\n", 174 | " img_array = np.array(img, dtype=np.float32) / 255.0\n", 175 | "\n", 176 | " # Add channel dimension\n", 177 | " img_array = np.expand_dims(img_array, axis=-1)\n", 178 | "\n", 179 | " return img_array\n", 180 | "\n", 181 | " #print(f\"Warning: Image not found for particle {particle_num}\")\n", 182 | " return None" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "colab": { 190 | "base_uri": "https://localhost:8080/" 191 | }, 192 | "id": "_xqW0KjfGumh", 193 | "outputId": "fca92e48-9baf-447f-bcdb-aeca83d3ab00" 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "# Load all images (with parallel processing for speed)\n", 198 | "from concurrent.futures import ThreadPoolExecutor\n", 199 | "from tqdm import tqdm\n", 200 | "\n", 201 | "print(\"Loading particle images in parallel...\")\n", 202 | "\n", 203 | "def load_image_wrapper(args):\n", 204 | " \"\"\"Wrapper function for parallel image loading\"\"\"\n", 205 | " idx, particle_num = args\n", 206 | " img = load_particle_image(particle_num)\n", 207 | " return idx, img\n", 208 | "\n", 209 | "# Prepare arguments for parallel processing\n", 210 | "load_args = [(idx, row['particle_idx_seq'])\n", 211 | " for idx, row in df.iterrows()]\n", 212 | "\n", 213 | "# Load images in parallel using ThreadPoolExecutor\n", 214 | "images = []\n", 215 | "valid_indices = []\n", 216 | "\n", 217 | "with ThreadPoolExecutor(max_workers=8) as executor:\n", 218 | " # Use tqdm for progress bar\n", 219 | " results = list(tqdm(executor.map(load_image_wrapper, load_args),\n", 220 | " total=len(load_args),\n", 221 | " desc=\"Loading images\"))\n", 222 | "\n", 223 | " for idx, img in results:\n", 224 | " if img is not None:\n", 225 | " images.append(img)\n", 226 | " valid_indices.append(idx)\n", 227 | "\n", 228 | "# Convert to numpy array\n", 229 | "X_images = np.array(images)\n", 230 | "\n", 231 | "# Filter dataframe to only include particles with valid images\n", 232 | "df_valid = df.loc[valid_indices].reset_index(drop=True)\n", 233 | "\n", 234 | "print(f\"\\nSuccessfully loaded {len(X_images)} images\")\n", 235 | 
"print(f\"Image shape: {X_images.shape}\")\n", 236 | "\n", 237 | "# Filter out phases 2 and 3 from both df_valid and X_images to ensure alignment\n", 238 | "initial_num_images = len(X_images)\n", 239 | "initial_num_df_rows = len(df_valid)\n", 240 | "\n", 241 | "phase_filter_mask = df_valid['phase'].isin([0, 1])\n", 242 | "\n", 243 | "df_valid = df_valid[phase_filter_mask].reset_index(drop=True)\n", 244 | "X_images = X_images[phase_filter_mask]\n", 245 | "\n", 246 | "print(f\"Filtered for phases 0 and 1. Dropped {initial_num_df_rows - len(df_valid)} rows.\")\n", 247 | "print(f\"New X_images shape: {X_images.shape}\")\n", 248 | "print(f\"New df_valid shape: {df_valid.shape}\")" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": { 254 | "id": "WhXYXqO2Gumh" 255 | }, 256 | "source": [ 257 | "## 4. Prepare Environmental Features" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": { 264 | "colab": { 265 | "base_uri": "https://localhost:8080/" 266 | }, 267 | "id": "1jbWe41sGumh", 268 | "outputId": "4d2cd04d-db69-4d79-f476-c402d5e33d1e" 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "# Extract environmental features\n", 273 | "print(\"Extracting feature names\")\n", 274 | "feature_columns = ['area','xsize']\n", 275 | "X_features = df_valid[feature_columns].values\n", 276 | "\n", 277 | "# Standardize the features (important for neural networks)\n", 278 | "scaler = StandardScaler()\n", 279 | "X_features_scaled = scaler.fit_transform(X_features)\n", 280 | "\n", 281 | "# Handle NaN values in scaled features, replacing them with 0.0 (the mean of scaled features)\n", 282 | "X_features_scaled = np.nan_to_num(X_features_scaled, nan=0.0)\n", 283 | "\n", 284 | "print(f\"Environmental features shape: {X_features_scaled.shape}\")\n", 285 | "print(f\"\\nFeature statistics (after scaling and NaN handling):\")\n", 286 | "print(pd.DataFrame(X_features_scaled, columns=feature_columns).describe())" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": { 293 | "colab": { 294 | "base_uri": "https://localhost:8080/" 295 | }, 296 | "id": "CkkGnBCQpcbG", 297 | "outputId": "5dee90e0-2c31-44cd-8da3-c19e5a973a55" 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "print(\"Preparing data file paths and features...\")\n", 302 | "\n", 303 | "# Add the full image path to the DataFrame\n", 304 | "def get_full_image_path(row, image_dir=base_path):\n", 305 | " \"\"\"Determine the full path for a particle image.\"\"\"\n", 306 | " particle_num = row['particle_idx_seq']\n", 307 | " subdirs = ['liquid', 'solid']\n", 308 | "\n", 309 | " for subdir in subdirs:\n", 310 | " img_path = os.path.join(image_dir, f'particle_images_filtered/{subdir}', f'particle_{particle_num}.png')\n", 311 | " if os.path.exists(img_path):\n", 312 | " return img_path\n", 313 | " return None\n", 314 | "\n", 315 | "# Find all paths and filter the DataFrame\n", 316 | "df['image_path'] = df.apply(get_full_image_path, axis=1)\n", 317 | "df_valid = df.dropna(subset=['image_path']).reset_index(drop=True)\n", 318 | "\n", 319 | "print(f\"Total valid samples with images found: {len(df_valid)}\")" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": { 326 | "colab": { 327 | "base_uri": "https://localhost:8080/" 328 | }, 329 | "id": "Lgkdfbf5poSX", 330 | "outputId": "8e478c34-eda9-495d-9799-493569a2fa6b" 331 | }, 332 | "outputs": [], 333 | "source": [ 334 | "# --- Step 2: Filter for specific 
phases (0 and 1) and synchronize data ---\n", 335 | "\n", 336 | "# Filter out phases 2 and 3\n", 337 | "phase_filter_mask = df_valid['phase'].isin([0, 1])\n", 338 | "df_valid = df_valid[phase_filter_mask].reset_index(drop=True)\n", 339 | "\n", 340 | "print(f\"Filtered for phases 0 and 1. Final samples: {len(df_valid)}\")\n", 341 | "\n", 342 | "# Extract the final synchronized arrays\n", 343 | "image_paths = df_valid['image_path'].values\n", 344 | "#Use x_features that we scaled\n", 345 | "#X_features # = df_valid[feature_columns].values # Assuming feature_columns is a list of column names\n", 346 | "y_labels = pd.get_dummies(df_valid['phase']).values # Convert phase (0, 1) to one-hot labels (categorical)\n", 347 | "\n", 348 | "print(f\"Synchronized Data Shapes:\")\n", 349 | "print(f\" Image Paths: {image_paths.shape}\")\n", 350 | "print(f\" Features: {X_features_scaled.shape}\")\n", 351 | "print(f\" Labels (Categorical): {y_labels.shape}\")" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "metadata": { 358 | "colab": { 359 | "base_uri": "https://localhost:8080/" 360 | }, 361 | "id": "1719OPdZpqpq", 362 | "outputId": "32b0f395-9217-4084-bafc-ddcce160684f" 363 | }, 364 | "outputs": [], 365 | "source": [ 366 | "# --- Step 3: Data Splitting (Image Paths, Features, Labels) ---\n", 367 | "\n", 368 | "print(\"\\nSplitting data...\")\n", 369 | "# Initial split: 90% temp, 10% test\n", 370 | "X_paths_temp, X_paths_test, \\\n", 371 | "X_features_temp, X_features_test, \\\n", 372 | "y_temp, y_test = train_test_split(\n", 373 | " image_paths, X_features_scaled, y_labels,\n", 374 | " test_size=0.10,\n", 375 | " stratify=y_labels,\n", 376 | " random_state=42\n", 377 | ")\n", 378 | "\n", 379 | "# Second split: 89% train, 11% validation (of temp)\n", 380 | "X_paths_train, X_paths_val, \\\n", 381 | "X_features_train, X_features_val, \\\n", 382 | "y_train, y_val = train_test_split(\n", 383 | " X_paths_temp, X_features_temp, y_temp,\n", 384 | " test_size=0.11,\n", 385 | " stratify=y_temp,\n", 386 | " random_state=42\n", 387 | ")\n", 388 | "\n", 389 | "print(f\"Training set: {len(X_paths_train)} samples\")\n", 390 | "print(f\"Validation set: {len(X_paths_val)} samples\")\n", 391 | "print(f\"Test set: {len(X_paths_test)} samples\")\n", 392 | "print(\"Data is ready for the tf.data.Dataset pipeline.\")" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "colab": { 400 | "base_uri": "https://localhost:8080/" 401 | }, 402 | "id": "SAxs_eVzGumh", 403 | "outputId": "5026633b-9237-487d-e3ae-9d88543b9cad" 404 | }, 405 | "outputs": [], 406 | "source": [ 407 | "# Extract labels and convert to one-hot encoding\n", 408 | "from sklearn.utils import class_weight\n", 409 | "\n", 410 | "y = df_valid['phase'].values\n", 411 | "y_categorical = to_categorical(y, num_classes=2)\n", 412 | "\n", 413 | "print(f\"Labels shape: {y_categorical.shape}\")\n", 414 | "print(f\"Class distribution: {np.bincount(y)}\")\n", 415 | "print(f\"Class 0 (Liquid): {np.sum(y == 0)} samples ({np.sum(y == 0) / len(y) * 100:.1f}%)\")\n", 416 | "print(f\"Class 1 (Solid): {np.sum(y == 1)} samples ({np.sum(y == 1) / len(y) * 100:.1f}%)\")\n", 417 | "# print(f\"Class 2 (Donut): {np.sum(y == 2)} samples ({np.sum(y == 2) / len(y) * 100:.1f}%)\")\n", 418 | "# print(f\"Class 3 (Noise): {np.sum(y == 3)} samples ({np.sum(y == 3) / len(y) * 100:.1f}%)\")\n", 419 | "\n", 420 | "# Calculate class weights for imbalanced data\n", 421 | "class_weights = 
class_weight.compute_class_weight(\n",
422 | "    class_weight='balanced',\n",
423 | "    classes=np.unique(y),\n",
424 | "    y=y\n",
425 | ")\n",
426 | "class_weight_dict = {i: class_weights[i] for i in range(len(class_weights))}\n",
427 | "\n",
428 | "print(f\"\\nClass weights (to handle imbalance):\")\n",
429 | "for i, w in class_weight_dict.items():\n",
430 | "    print(f\"  Class {i}: {w:.2f}\")"
431 | ]
432 | },
433 | {
434 | "cell_type": "markdown",
435 | "metadata": {
436 | "id": "DH51YK5dqRfW"
437 | },
438 | "source": [
439 | "## Data augmentation"
440 | ]
441 | },
442 | {
443 | "cell_type": "code",
444 | "execution_count": null,
445 | "metadata": {
446 | "id": "14cCGiMxaSnc"
447 | },
448 | "outputs": [],
449 | "source": [
450 | "import tensorflow as tf\n",
451 | "from typing import Tuple\n",
452 | "\n",
453 | "IMAGE_SIZE = (128, 128)\n",
454 | "\n",
455 | "def random_geometric_augment(image: tf.Tensor) -> tf.Tensor:\n",
456 | "    \"\"\"\n",
457 | "    Apply geometric augmentations suitable for binary particle images.\n",
458 | "    \"\"\"\n",
459 | "    # Random 90-degree rotations (0, 90, 180, or 270 degrees)\n",
460 | "    if tf.random.uniform(()) < 0.5:\n",
461 | "        k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32)\n",
462 | "        image = tf.image.rot90(image, k=k)\n",
463 | "\n",
464 | "    # Random horizontal flip\n",
465 | "    if tf.random.uniform(()) < 0.5:\n",
466 | "        image = tf.image.flip_left_right(image)\n",
467 | "\n",
468 | "    # Random vertical flip\n",
469 | "    if tf.random.uniform(()) < 0.5:\n",
470 | "        image = tf.image.flip_up_down(image)\n",
471 | "\n",
472 | "    # Random zoom (via crop and resize) - equivalent to zoom_range=0.2\n",
473 | "    if tf.random.uniform(()) < 0.5:\n",
474 | "        # Zoom range of 0.2 means 80% to 100% crop\n",
475 | "        crop_factor = tf.random.uniform((), 0.8, 1.0)\n",
476 | "        crop_size = tf.cast(IMAGE_SIZE[0] * crop_factor, tf.int32)\n",
477 | "        image = tf.image.random_crop(image, size=[crop_size, crop_size, 1])\n",
478 | "        image = tf.image.resize(image, IMAGE_SIZE)\n",
479 | "\n",
480 | "    # Random translation (width_shift and height_shift of 0.1)\n",
481 | "    if tf.random.uniform(()) < 0.5:\n",
482 | "        # 10% of image size = 12.8 pixels; int() truncates this to 12\n",
483 | "        max_shift = int(IMAGE_SIZE[0] * 0.1)\n",
484 | "        shift_height = tf.random.uniform((), -max_shift, max_shift, dtype=tf.int32)\n",
485 | "        shift_width = tf.random.uniform((), -max_shift, max_shift, dtype=tf.int32)\n",
486 | "\n",
487 | "        # Pad image to allow for translation\n",
488 | "        padded = tf.pad(image, [[max_shift, max_shift], [max_shift, max_shift], [0, 0]],\n",
489 | "                        constant_values=0.0)\n",
490 | "\n",
491 | "        # Crop to simulate translation\n",
492 | "        offset_height = max_shift + shift_height\n",
493 | "        offset_width = max_shift + shift_width\n",
494 | "        image = tf.image.crop_to_bounding_box(padded, offset_height, offset_width,\n",
495 | "                                              IMAGE_SIZE[0], IMAGE_SIZE[1])\n",
496 | "\n",
497 | "    return image"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": null,
503 | "metadata": {
504 | "id": "b8yzM1O_yrVo"
505 | },
506 | "outputs": [],
507 | "source": [
508 | "import tensorflow as tf\n",
509 | "from typing import Tuple\n",
510 | "\n",
511 | "# Define the image size based on your model architecture\n",
512 | "IMAGE_SIZE = (128, 128)\n",
513 | "AUGMENTATION_PROB = 0.5\n",
514 | "MAX_SHIFT = 0.2  # 20% maximum shift\n",
515 | "\n",
516 | "def random_geometric_augment(image: tf.Tensor) -> tf.Tensor:\n",
517 | "    \"\"\"Rotate, shift, and zoom; this definition overrides the simpler version above.\"\"\"\n",
518 | "    # Image dimensions (height, width, channels)\n",
519 | "    H, W, C 
= IMAGE_SIZE[0], IMAGE_SIZE[1], 1\n", 520 | "\n", 521 | " # --- 1. Rotation (Using robust tf.image.rot90) ---\n", 522 | " # Apply a random rotation of 0, 90, 180, or 270 degrees.\n", 523 | " k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)\n", 524 | " image = tf.image.rot90(image, k=k)\n", 525 | "\n", 526 | " # --- 2. Zooming and Shifting (Manual Implementation) ---\n", 527 | "\n", 528 | " # A. Calculate random shifts in pixels\n", 529 | " max_shift_pixels = tf.cast(H, tf.float32) * tf.cast(MAX_SHIFT, tf.float32)\n", 530 | " # The result is float32. Now, you should cast the result back to int32\n", 531 | " # if you want the maximum shift in whole pixels for subsequent calculations.\n", 532 | " max_shift_pixels = tf.cast(max_shift_pixels, tf.int32)\n", 533 | " shift_h = tf.random.uniform(shape=[], minval=-max_shift_pixels, maxval=max_shift_pixels, dtype=tf.int32)\n", 534 | " shift_w = tf.random.uniform(shape=[], minval=-max_shift_pixels, maxval=max_shift_pixels, dtype=tf.int32)\n", 535 | "\n", 536 | " # B. Apply Translation (Shift) using Padding and Cropping\n", 537 | "\n", 538 | " # Calculate padding needed to cover the max shift\n", 539 | " pad_h = tf.abs(shift_h)\n", 540 | " pad_w = tf.abs(shift_w)\n", 541 | "\n", 542 | " # Pad the image symmetrically by the maximum possible shift\n", 543 | " padded_image = tf.pad(\n", 544 | " image,\n", 545 | " [[pad_h, pad_h], [pad_w, pad_w], [0, 0]],\n", 546 | " mode='CONSTANT',\n", 547 | " constant_values=0.0 # Black background for binary images\n", 548 | " )\n", 549 | "\n", 550 | " # Calculate the starting point for cropping back to the original size\n", 551 | " # Start H = max_shift_pixels + shift_h (to correctly apply positive/negative shift)\n", 552 | " # The starting point is the padding size plus the calculated shift.\n", 553 | " start_h = pad_h + shift_h\n", 554 | " start_w = pad_w + shift_w\n", 555 | "\n", 556 | " # Crop the padded image back to the original size (128, 128)\n", 557 | " image = tf.image.crop_to_bounding_box(\n", 558 | " padded_image,\n", 559 | " offset_height=start_h,\n", 560 | " offset_width=start_w,\n", 561 | " target_height=H,\n", 562 | " target_width=W\n", 563 | " )\n", 564 | "\n", 565 | " # C. 
Apply Random Zoom\n", 566 | " zoom_factor = tf.random.uniform(shape=[], minval=0.9, maxval=1.1, dtype=tf.float32)\n", 567 | "\n", 568 | " if zoom_factor < 1.0:\n", 569 | " # Zoom out (pad with black background)\n", 570 | " pad_h = tf.cast(tf.cast(H, tf.float32) * (1.0 - zoom_factor) / 2.0, tf.int32)\n", 571 | " pad_w = tf.cast(tf.cast(W, tf.float32) * (1.0 - zoom_factor) / 2.0, tf.int32)\n", 572 | " image = tf.pad(image, [[pad_h, pad_h], [pad_w, pad_w], [0, 0]], constant_values=0.0)\n", 573 | " image = tf.image.resize(image, IMAGE_SIZE, method='nearest')\n", 574 | " elif zoom_factor > 1.0:\n", 575 | " # Zoom in (crop and resize)\n", 576 | " crop_h = tf.cast(tf.cast(H, tf.float32) / zoom_factor, tf.int32)\n", 577 | " crop_w = tf.cast(tf.cast(W, tf.float32) / zoom_factor, tf.int32)\n", 578 | "\n", 579 | " offset_h = (H - crop_h) // 2\n", 580 | " offset_w = (W - crop_w) // 2\n", 581 | "\n", 582 | " image = tf.image.crop_to_bounding_box(image, offset_h, offset_w, crop_h, crop_w)\n", 583 | " image = tf.image.resize(image, IMAGE_SIZE, method='nearest')\n", 584 | "\n", 585 | " return image" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": null, 591 | "metadata": { 592 | "id": "8x08z56TqBZu" 593 | }, 594 | "outputs": [], 595 | "source": [ 596 | "import tensorflow as tf\n", 597 | "from typing import Tuple\n", 598 | "\n", 599 | "# Define the image size based on your model architecture\n", 600 | "IMAGE_SIZE = (128, 128)\n", 601 | "AUGMENTATION_PROB = 0.5 # Probability for applying each augmentation step\n", 602 | "\n", 603 | "def load_and_augment_hybrid(\n", 604 | " image_path: str, feature_vector: tf.Tensor, label: tf.Tensor, augment: bool\n", 605 | ") -> Tuple[Tuple[tf.Tensor, tf.Tensor], tf.Tensor]:\n", 606 | " \"\"\"\n", 607 | " Loads an image from path, preprocesses it, applies conditional augmentation,\n", 608 | " and returns the structured inputs and label for the hybrid model.\n", 609 | " \"\"\"\n", 610 | "\n", 611 | " # 1. Image Loading and Initial Processing\n", 612 | " image = tf.io.read_file(filename=image_path)\n", 613 | " # Decode PNG (assuming your particle images are PNG) as grayscale (1 channel)\n", 614 | " image = tf.image.decode_png(contents=image, channels=1)\n", 615 | "\n", 616 | " # Convert to float32 and resize (Normalization will happen after augmentation)\n", 617 | " image = tf.image.convert_image_dtype(image=image, dtype=tf.float32)\n", 618 | " image = tf.image.resize(images=image, size=IMAGE_SIZE)\n", 619 | "\n", 620 | " # 2. Apply Augmentation (Only if 'augment' is True and based on probability)\n", 621 | " if augment:\n", 622 | " image = random_geometric_augment(image)\n", 623 | "\n", 624 | " # 3. Final Normalization/Clipping\n", 625 | " image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)\n", 626 | "\n", 627 | " # 4. Return in the format Keras model.fit expects: ( [image_input, feature_input], label )\n", 628 | " return (image, feature_vector), label" 629 | ] 630 | }, 631 | { 632 | "cell_type": "code", 633 | "execution_count": null, 634 | "metadata": { 635 | "colab": { 636 | "base_uri": "https://localhost:8080/" 637 | }, 638 | "id": "vYTXaiFRqU4M", 639 | "outputId": "cedc409d-4435-4908-930e-50ee4941773c" 640 | }, 641 | "outputs": [], 642 | "source": [ 643 | "# --- Configuration ---\n", 644 | "BATCH_SIZE = 16\n", 645 | "\n", 646 | "print(\"Building tf.data.Dataset pipelines...\")\n", 647 | "\n", 648 | "# --- 1. 
Training Dataset (with augmentation) ---\n",
649 | "train_ds = tf.data.Dataset.from_tensor_slices(\n",
650 | "    (X_paths_train, X_features_train, y_train)\n",
651 | ")\n",
652 | "\n",
653 | "# Cache the small (path, feature, label) tuples BEFORE the augmenting map;\n",
654 | "# caching after it would reuse one frozen set of augmented images every epoch.\n",
655 | "train_ds = train_ds.cache()\n",
656 | "train_ds = train_ds.map(\n",
657 | "    lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=True),\n",
658 | "    num_parallel_calls=tf.data.AUTOTUNE  # Parallelize data loading for speed\n",
659 | ")\n",
660 | "# Shuffle, batch, and prefetch\n",
661 | "train_ds = train_ds.shuffle(buffer_size=len(X_paths_train))\n",
662 | "train_ds = train_ds.batch(BATCH_SIZE)\n",
663 | "train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)  # Pre-load next batch\n",
664 | "\n",
665 | "print(\"Training Dataset prepared.\")\n",
666 | "\n",
667 | "# --- 2. Validation Dataset (no augmentation) ---\n",
668 | "val_ds = tf.data.Dataset.from_tensor_slices(\n",
669 | "    (X_paths_val, X_features_val, y_val)\n",
670 | ")\n",
671 | "\n",
672 | "# Apply the loading and processing function (augment=False)\n",
673 | "val_ds = val_ds.map(\n",
674 | "    lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=False),\n",
675 | "    num_parallel_calls=tf.data.AUTOTUNE\n",
676 | ")\n",
677 | "\n",
678 | "# Cache decoded images (safe here: no augmentation), batch, and prefetch; no shuffle needed\n",
679 | "val_ds = val_ds.cache()\n",
680 | "val_ds = val_ds.batch(BATCH_SIZE)\n",
681 | "val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)\n",
682 | "\n",
683 | "print(\"Validation Dataset prepared.\")\n",
684 | "\n",
685 | "# --- 3. Test Dataset (for final evaluation) ---\n",
686 | "test_ds = tf.data.Dataset.from_tensor_slices(\n",
687 | "    (X_paths_test, X_features_test, y_test)\n",
688 | ")\n",
689 | "\n",
690 | "# Apply the loading and processing function (augment=False)\n",
691 | "test_ds = test_ds.map(\n",
692 | "    lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=False),\n",
693 | "    num_parallel_calls=tf.data.AUTOTUNE\n",
694 | ")\n",
695 | "\n",
696 | "# Batch and prefetch\n",
697 | "test_ds = test_ds.batch(BATCH_SIZE)\n",
698 | "test_ds = test_ds.prefetch(buffer_size=tf.data.AUTOTUNE)\n",
699 | "\n",
700 | "print(\"All datasets successfully created. Ready for model compilation and training.\")"
701 | ]
702 | },
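Before moving on, a quick pipeline sanity check can catch shape or structure mistakes early. A minimal, optional snippet (assuming the `train_ds`, `BATCH_SIZE = 16`, and two-column feature setup defined above) that pulls a single batch and confirms the `((image, features), label)` structure the two-input model expects:

```python
# Optional sanity check: inspect one batch from the training pipeline.
# Shapes assume BATCH_SIZE = 16, 128x128 grayscale images, and 2 scalar features.
(images_b, features_b), labels_b = next(iter(train_ds))
print(images_b.shape)    # expected: (16, 128, 128, 1)
print(features_b.shape)  # expected: (16, 2)
print(labels_b.shape)    # expected: (16, 2), one-hot phase labels
```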
703 | {
704 | "cell_type": "markdown",
705 | "metadata": {
706 | "id": "KhRitF_GGumh"
707 | },
708 | "source": [
709 | "## 5. Visualize Sample Particles"
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": null,
715 | "metadata": {
716 | "colab": {
717 | "base_uri": "https://localhost:8080/",
718 | "height": 1000
719 | },
720 | "id": "bwNfOb8mGumi",
721 | "outputId": "14abe60c-56b5-4740-a1b7-ff3e047b29a6"
722 | },
723 | "outputs": [],
724 | "source": [
725 | "def get_class_name(label):\n",
726 | "    \"\"\"Convert numeric label to class name\"\"\"\n",
727 | "    class_names = {0: 'Liquid', 1: 'Solid'}\n",
728 | "    return class_names.get(label, 'Unknown')\n",
729 | "\n",
730 | "# Display random samples\n",
731 | "plt.figure(figsize=(15, 10))\n",
732 | "for i in range(12):\n",
733 | "    plt.subplot(3, 4, i+1)\n",
734 | "    idx = np.random.randint(0, len(X_images))\n",
735 | "    plt.imshow(X_images[idx, :, :, 0], cmap='gray')\n",
736 | "    plt.title(f'{get_class_name(y[idx])}\\nd={df_valid.iloc[idx][\"diam\"]:.1f} microns')\n",
737 | "    plt.axis('off')\n",
738 | "plt.tight_layout()\n",
739 | "plt.show()"
740 | ]
741 | },
742 | {
743 | "cell_type": "markdown",
744 | "metadata": {
745 | "id": "EdwIGRSVGumi"
746 | },
747 | "source": [
748 | "## 7. Build the Hybrid CNN Model\n",
749 | "\n",
750 | "This model uses a multi-input architecture:\n",
751 | "- **CNN branch**: processes particle images to extract morphological features\n",
752 | "- **Dense branch**: processes the scalar per-particle features (currently area and x-size)\n",
753 | "- **Concatenation**: combines both branches before final classification\n",
754 | "\n",
755 | "This hybrid approach combines visual particle characteristics with per-particle scalar measurements for improved classification; a minimal sketch of the two-input pattern follows below."
756 | ]
757 | },
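As a minimal sketch of the two-input pattern described above (illustrative only: `toy` and its layer sizes are placeholders, not the project's model), Keras matches a dataset yielding `((image, features), label)` tuples to a functional `Model` with two inputs by position, so no extra wiring is needed:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# Tiny two-branch model: CNN branch + raw scalar features, concatenated.
img_in = layers.Input(shape=(128, 128, 1), name='image_input')
feat_in = layers.Input(shape=(2,), name='feature_input')

x = layers.Conv2D(8, 3, activation='relu')(img_in)   # toy CNN branch
x = layers.GlobalAveragePooling2D()(x)

merged = layers.Concatenate()([x, feat_in])          # combine both branches
out = layers.Dense(2, activation='softmax')(merged)  # liquid vs. solid

toy = Model([img_in, feat_in], out)
toy.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# toy.fit(train_ds, epochs=1)  # works as-is: train_ds yields ((image, features), label)
```

The full models below follow the same pattern, with deeper CNN branches, batch normalization, dropout, and L2 regularization.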
758 | {
759 | "cell_type": "code",
760 | "execution_count": null,
761 | "metadata": {
762 | "colab": {
763 | "base_uri": "https://localhost:8080/",
764 | "height": 1000
765 | },
766 | "id": "CSZgINBKT73u",
767 | "outputId": "d29f784b-44fc-4ded-867d-bfc2dd0a705c"
768 | },
769 | "outputs": [],
770 | "source": [
771 | "## Alternative simple model\n",
772 | "from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, Dropout, Concatenate, SpatialDropout2D\n",
773 | "from tensorflow.keras.models import Model\n",
774 | "from tensorflow.keras import regularizers\n",
775 | "\n",
776 | "# L2 regularization strength (a common starting value)\n",
777 | "L2_REG = 0.0001\n",
778 | "\n",
779 | "# Simplified version - just concatenate raw features\n",
780 | "image_input = Input(shape=(128, 128, 1), name='image_input')\n",
781 | "# Reduce initial filters for simpler images\n",
782 | "x = Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same',\n",
783 | "           kernel_regularizer=regularizers.l2(L2_REG))(image_input)\n",
784 | "x = BatchNormalization()(x)\n",
785 | "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
786 | "x = SpatialDropout2D(0.2)(x)  # Spatial dropout for conv layers\n",
787 | "\n",
788 | "x = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same',\n",
789 | "           kernel_regularizer=regularizers.l2(L2_REG))(x)\n",
790 | "x = BatchNormalization()(x)\n",
791 | "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
792 | "x = SpatialDropout2D(0.2)(x)\n",
793 | "\n",
794 | "x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same',\n",
795 | "           kernel_regularizer=regularizers.l2(L2_REG))(x)\n",
796 | "x = BatchNormalization()(x)\n",
797 | "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
798 | "x = SpatialDropout2D(0.3)(x)\n",
799 | "\n",
800 | "# Flatten the CNN output\n",
801 | "x = Flatten()(x)\n",
802 | "# Increase Dropout rate\n",
803 | "x = Dropout(0.4)(x)\n",
804 | "# Simplify and add L2 regularization\n",
805 | "cnn_output = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(x)\n",
806 | "\n",
807 | "feature_input = Input(shape=(len(feature_columns),), name='feature_input')\n",
808 | "\n",
809 | "# Normalize features first (important!)\n",
810 | "features_normalized = BatchNormalization()(feature_input)\n",
811 | "\n",
812 | "# Direct concatenation\n",
813 | "combined = Concatenate()([cnn_output, features_normalized])  # 64 CNN + 2 scalar = 66 features\n",
814 | "\n",
815 | "# Classifier\n",
816 | "# Simpler final layers\n",
817 | "z = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(combined)\n",
818 | "z = Dropout(0.3)(z)\n",
819 | "z = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(z)\n",
820 | "z = Dropout(0.5)(z)\n",
821 | "output = Dense(2, activation='softmax')(z)\n",
822 | "\n",
823 | "model = Model(inputs=[image_input, feature_input], outputs=output)\n",
824 | "model.summary()"
825 | ]
826 | },
827 | {
828 | "cell_type": "code",
829 | "execution_count": null,
830 | "metadata": {
831 | "colab": {
832 | "base_uri": "https://localhost:8080/",
833 | "height": 1000
834 | },
835 | "id": "He1B1unuvALu",
836 | "outputId": "dc0b087b-5dbc-4fec-e3e4-19ab6b025536"
837 | },
838 | "outputs": [],
839 | "source": [
840 | "from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, Dropout, Concatenate, SpatialDropout2D\n",
841 | "from tensorflow.keras.models import Model\n",
842 | "from tensorflow.keras import regularizers\n",
843 | "\n",
844 | "# L2 regularization strength\n",
845 | "L2_REG = 0.0001\n",
846 | "\n",
847 | "image_input = Input(shape=(128, 128, 1), name='image_input')\n",
848 | "\n",
849 | "# Reduce initial filters for simpler images\n",
850 | "x = Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same',\n",
851 | "           kernel_regularizer=regularizers.l2(L2_REG))(image_input)\n",
852 | "x = BatchNormalization()(x)\n",
853 | "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
854 | "x = SpatialDropout2D(0.2)(x)  # Spatial dropout for conv layers\n",
855 | "\n",
856 | "x = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same',\n",
857 | "           kernel_regularizer=regularizers.l2(L2_REG))(x)\n",
858 | "x = BatchNormalization()(x)\n",
859 | "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
860 | "x = SpatialDropout2D(0.2)(x)\n",
861 | "\n",
862 | "x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same',\n",
863 | "           kernel_regularizer=regularizers.l2(L2_REG))(x)\n",
864 | "x = BatchNormalization()(x)\n",
865 | "x = MaxPooling2D(pool_size=(2, 2))(x)\n",
866 | "x = SpatialDropout2D(0.3)(x)\n",
867 | "\n",
868 | "# Flatten the CNN output\n",
869 | "x = Flatten()(x)\n",
870 | "# Increase Dropout rate\n",
871 | "x = Dropout(0.5)(x)\n",
872 | "# Simplify and add L2 regularization\n",
873 | "cnn_output = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(x)\n",
874 | "\n",
875 | "\n",
876 | "# Scalar (per-particle) feature input branch\n",
877 | "feature_input = Input(shape=(len(feature_columns),), name='feature_input')\n",
878 | "\n",
879 | "# Normalize features first (important!)\n",
880 | "features_normalized = BatchNormalization()(feature_input)\n",
881 | "\n",
882 | "# Direct concatenation\n",
883 | "combined = Concatenate()([cnn_output, features_normalized])  # 64 CNN + 2 scalar = 66 features\n",
884 | "\n",
885 | "\n",
886 | "# Final classification layers\n",
887 | "# Simpler final layers\n",
888 | "z = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(combined)\n", 889 | "z = Dropout(0.5)(z)\n", 890 | "z = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(z)\n", 891 | "z = Dropout(0.3)(z)\n", 892 | "\n", 893 | "# Output layer (2-class classification: liquid, solid)\n", 894 | "# Assuming you switched to binary crossentropy loss:\n", 895 | "#output = Dense(1, activation='sigmoid', name='output')(z)\n", 896 | "# OR, if using categorical crossentropy for 2 classes (one-hot):\n", 897 | "output = Dense(2, activation='softmax', name='output')(z)\n", 898 | "\n", 899 | "\n", 900 | "# Create the model\n", 901 | "model = Model(inputs=[image_input, feature_input], outputs=output, name='Hybrid_CNN_Regularized_V2')\n", 902 | "# Display model architecture\n", 903 | "model.summary()" 904 | ] 905 | }, 906 | { 907 | "cell_type": "markdown", 908 | "metadata": { 909 | "id": "1wy75eccZAVJ" 910 | }, 911 | "source": [ 912 | "k-fold cross val" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": null, 918 | "metadata": { 919 | "colab": { 920 | "base_uri": "https://localhost:8080/", 921 | "height": 998 922 | }, 923 | "id": "oBrxJqyJZAFd", 924 | "outputId": "10cbda31-e47d-41d7-c9ab-db986817e982" 925 | }, 926 | "outputs": [], 927 | "source": [ 928 | "import numpy as np\n", 929 | "import tensorflow as tf\n", 930 | "from tensorflow.keras.models import Model\n", 931 | "from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense, Dropout, Concatenate\n", 932 | "from tensorflow.keras.optimizers import Adam\n", 933 | "from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau\n", 934 | "from tensorflow.keras import regularizers\n", 935 | "from sklearn.model_selection import KFold\n", 936 | "import pandas as pd\n", 937 | "import matplotlib.pyplot as plt\n", 938 | "import seaborn as sns\n", 939 | "from itertools import product\n", 940 | "import json\n", 941 | "\n", 942 | "# ==========================================\n", 943 | "# HYPERPARAMETER SEARCH SPACE\n", 944 | "# ==========================================\n", 945 | "\n", 946 | "param_grid = {\n", 947 | " 'batch_size': [8, 16, 32],\n", 948 | " 'learning_rate': [0.00001, 0.00005, 0.0001],\n", 949 | " 'l2_reg': [0.00001, 0.0001, 0.001],\n", 950 | " 'dropout_cnn': [0.3, 0.4, 0.5],\n", 951 | " 'dropout_features': [0.2, 0.3, 0.4],\n", 952 | " 'dropout_final': [0.4, 0.5, 0.6],\n", 953 | " 'cnn_dense_units': [64, 128, 256],\n", 954 | " 'feature_dense_units': [8, 16, 32],\n", 955 | "}\n", 956 | "\n", 957 | "# ==========================================\n", 958 | "# MODEL BUILDER FUNCTION\n", 959 | "# ==========================================\n", 960 | "\n", 961 | "def build_hybrid_model(num_features, params):\n", 962 | " \"\"\"\n", 963 | " Build hybrid model with specified hyperparameters.\n", 964 | "\n", 965 | " Parameters:\n", 966 | " - num_features: Number of input features\n", 967 | " - params: Dictionary of hyperparameters\n", 968 | " \"\"\"\n", 969 | " L2_REG = params['l2_reg']\n", 970 | "\n", 971 | " # Image branch\n", 972 | " image_input = Input(shape=(128, 128, 1), name='image_input')\n", 973 | "\n", 974 | " x = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same',\n", 975 | " kernel_regularizer=regularizers.l2(L2_REG))(image_input)\n", 976 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 977 | "\n", 978 | " x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same',\n", 979 | " 
kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 980 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 981 | "\n", 982 | " x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same',\n", 983 | " kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 984 | " x = BatchNormalization()(x)\n", 985 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 986 | "\n", 987 | " x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same',\n", 988 | " kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 989 | " x = BatchNormalization()(x)\n", 990 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 991 | "\n", 992 | " x = Flatten()(x)\n", 993 | " x = Dropout(params['dropout_cnn'])(x)\n", 994 | " cnn_output = Dense(params['cnn_dense_units'], activation='relu',\n", 995 | " kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 996 | "\n", 997 | " # Feature branch\n", 998 | " feature_input = Input(shape=(num_features,), name='feature_input')\n", 999 | " # Normalize features first (important!)\n", 1000 | " features_normalized = BatchNormalization()(feature_input)\n", 1001 | "\n", 1002 | " # Direct concatenation\n", 1003 | " combined = Concatenate()([cnn_output, features_normalized]) # 130 features\n", 1004 | " # Final layers\n", 1005 | " z = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(combined)\n", 1006 | " z = Dropout(params['dropout_final'])(z)\n", 1007 | " z = Dense(32, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(z)\n", 1008 | " z = Dropout(params['dropout_final'] * 0.6)(z) # Slightly less dropout in second layer\n", 1009 | "\n", 1010 | " output = Dense(2, activation='softmax', name='output')(z)\n", 1011 | "\n", 1012 | " model = Model(inputs=[image_input, feature_input], outputs=output, name='Hybrid_CNN')\n", 1013 | "\n", 1014 | " # Compile\n", 1015 | " model.compile(\n", 1016 | " optimizer=Adam(learning_rate=params['learning_rate']),\n", 1017 | " loss='categorical_crossentropy',\n", 1018 | " metrics=['accuracy']\n", 1019 | " )\n", 1020 | "\n", 1021 | " return model\n", 1022 | "\n", 1023 | "\n", 1024 | "# ==========================================\n", 1025 | "# K-FOLD CROSS-VALIDATION FUNCTION\n", 1026 | "# ==========================================\n", 1027 | "\n", 1028 | "def kfold_cross_validation(X_paths, X_features, y, params, n_splits=5,\n", 1029 | " feature_columns=None, max_epochs=30, verbose=1):\n", 1030 | " \"\"\"\n", 1031 | " Perform k-fold cross-validation with given hyperparameters.\n", 1032 | "\n", 1033 | " Parameters:\n", 1034 | " - X_paths: Array of image paths\n", 1035 | " - X_features: Array of feature values\n", 1036 | " - y: Array of labels (one-hot encoded)\n", 1037 | " - params: Dictionary of hyperparameters to test\n", 1038 | " - n_splits: Number of folds\n", 1039 | " - feature_columns: List of feature column names\n", 1040 | " - max_epochs: Maximum training epochs per fold\n", 1041 | " - verbose: Verbosity level\n", 1042 | "\n", 1043 | " Returns:\n", 1044 | " - Dictionary with CV results\n", 1045 | " \"\"\"\n", 1046 | " kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)\n", 1047 | "\n", 1048 | " fold_results = []\n", 1049 | " fold_histories = []\n", 1050 | "\n", 1051 | " print(f\"\\n{'='*80}\")\n", 1052 | " print(f\"TESTING HYPERPARAMETERS:\")\n", 1053 | " print(f\"{'='*80}\")\n", 1054 | " for key, value in params.items():\n", 1055 | " print(f\" {key:20s}: {value}\")\n", 1056 | " print(f\"{'='*80}\\n\")\n", 1057 | "\n", 1058 | " for fold_idx, (train_idx, val_idx) in 
enumerate(kf.split(X_paths)):\n", 1059 | " print(f\"\\n--- Fold {fold_idx + 1}/{n_splits} ---\")\n", 1060 | "\n", 1061 | " # Split data\n", 1062 | " X_paths_train, X_paths_val = X_paths[train_idx], X_paths[val_idx]\n", 1063 | " X_features_train, X_features_val = X_features[train_idx], X_features[val_idx]\n", 1064 | " y_train, y_val = y[train_idx], y[val_idx]\n", 1065 | "\n", 1066 | " # Create datasets\n", 1067 | " train_ds = tf.data.Dataset.from_tensor_slices(\n", 1068 | " (X_paths_train, X_features_train, y_train)\n", 1069 | " )\n", 1070 | " train_ds = train_ds.map(\n", 1071 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=True),\n", 1072 | " num_parallel_calls=tf.data.AUTOTUNE\n", 1073 | " ).cache().shuffle(buffer_size=len(X_paths_train)).batch(params['batch_size']).prefetch(tf.data.AUTOTUNE)\n", 1074 | "\n", 1075 | " val_ds = tf.data.Dataset.from_tensor_slices(\n", 1076 | " (X_paths_val, X_features_val, y_val)\n", 1077 | " )\n", 1078 | " val_ds = val_ds.map(\n", 1079 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=False),\n", 1080 | " num_parallel_calls=tf.data.AUTOTUNE\n", 1081 | " ).cache().batch(params['batch_size']).prefetch(tf.data.AUTOTUNE)\n", 1082 | "\n", 1083 | " # Build model\n", 1084 | " model = build_hybrid_model(num_features=X_features.shape[1], params=params)\n", 1085 | "\n", 1086 | " # Callbacks\n", 1087 | " callbacks = [\n", 1088 | " EarlyStopping(\n", 1089 | " monitor='val_loss',\n", 1090 | " patience=10,\n", 1091 | " restore_best_weights=True,\n", 1092 | " verbose=0\n", 1093 | " ),\n", 1094 | " ReduceLROnPlateau(\n", 1095 | " monitor='val_loss',\n", 1096 | " factor=0.5,\n", 1097 | " patience=3,\n", 1098 | " min_lr=1e-7,\n", 1099 | " verbose=0\n", 1100 | " )\n", 1101 | " ]\n", 1102 | "\n", 1103 | " # Train\n", 1104 | " history = model.fit(\n", 1105 | " train_ds,\n", 1106 | " validation_data=val_ds,\n", 1107 | " epochs=max_epochs,\n", 1108 | " callbacks=callbacks,\n", 1109 | " verbose=verbose\n", 1110 | " )\n", 1111 | "\n", 1112 | " # Evaluate\n", 1113 | " train_loss, train_acc = model.evaluate(train_ds, verbose=0)\n", 1114 | " val_loss, val_acc = model.evaluate(val_ds, verbose=0)\n", 1115 | "\n", 1116 | " # Store results\n", 1117 | " fold_results.append({\n", 1118 | " 'fold': fold_idx + 1,\n", 1119 | " 'train_loss': train_loss,\n", 1120 | " 'train_acc': train_acc,\n", 1121 | " 'val_loss': val_loss,\n", 1122 | " 'val_acc': val_acc,\n", 1123 | " 'best_epoch': np.argmin(history.history['val_loss']) + 1,\n", 1124 | " 'final_epoch': len(history.history['loss'])\n", 1125 | " })\n", 1126 | "\n", 1127 | " fold_histories.append(history.history)\n", 1128 | "\n", 1129 | " print(f\" Train: Loss={train_loss:.4f}, Acc={train_acc:.4f}\")\n", 1130 | " print(f\" Val: Loss={val_loss:.4f}, Acc={val_acc:.4f}\")\n", 1131 | "\n", 1132 | " # Aggregate results\n", 1133 | " results_df = pd.DataFrame(fold_results)\n", 1134 | "\n", 1135 | " cv_results = {\n", 1136 | " 'params': params,\n", 1137 | " 'mean_val_acc': results_df['val_acc'].mean(),\n", 1138 | " 'std_val_acc': results_df['val_acc'].std(),\n", 1139 | " 'mean_train_acc': results_df['train_acc'].mean(),\n", 1140 | " 'std_train_acc': results_df['train_acc'].std(),\n", 1141 | " 'mean_val_loss': results_df['val_loss'].mean(),\n", 1142 | " 'std_val_loss': results_df['val_loss'].std(),\n", 1143 | " 'overfitting_gap': results_df['train_acc'].mean() - results_df['val_acc'].mean(),\n", 1144 | " 'fold_results': fold_results,\n", 1145 | " 'fold_histories': fold_histories\n", 1146 | " }\n", 1147 | "\n", 1148 | " 
print(f\"\\n{'='*80}\")\n", 1149 | " print(f\"CROSS-VALIDATION SUMMARY:\")\n", 1150 | " print(f\"{'='*80}\")\n", 1151 | " print(f\"Mean Val Accuracy: {cv_results['mean_val_acc']:.4f} ± {cv_results['std_val_acc']:.4f}\")\n", 1152 | " print(f\"Mean Train Accuracy: {cv_results['mean_train_acc']:.4f} ± {cv_results['std_train_acc']:.4f}\")\n", 1153 | " print(f\"Overfitting Gap: {cv_results['overfitting_gap']:.4f}\")\n", 1154 | " print(f\"{'='*80}\\n\")\n", 1155 | "\n", 1156 | " return cv_results\n", 1157 | "\n", 1158 | "\n", 1159 | "# ==========================================\n", 1160 | "# GRID SEARCH WITH K-FOLD CV\n", 1161 | "# ==========================================\n", 1162 | "\n", 1163 | "def grid_search_kfold(X_paths, X_features, y, param_grid, n_splits=5,\n", 1164 | " feature_columns=None, max_epochs=30, max_combinations=None):\n", 1165 | " \"\"\"\n", 1166 | " Perform grid search with k-fold cross-validation.\n", 1167 | "\n", 1168 | " Parameters:\n", 1169 | " - X_paths: Array of image paths\n", 1170 | " - X_features: Array of feature values\n", 1171 | " - y: Array of labels\n", 1172 | " - param_grid: Dictionary of hyperparameter lists to search\n", 1173 | " - n_splits: Number of CV folds\n", 1174 | " - feature_columns: List of feature column names\n", 1175 | " - max_epochs: Maximum epochs per fold\n", 1176 | " - max_combinations: Limit number of combinations to test (None = all)\n", 1177 | "\n", 1178 | " Returns:\n", 1179 | " - results_df: DataFrame with all results\n", 1180 | " - best_params: Best hyperparameter combination\n", 1181 | " \"\"\"\n", 1182 | " # Generate all combinations\n", 1183 | " param_names = list(param_grid.keys())\n", 1184 | " param_values = list(param_grid.values())\n", 1185 | " all_combinations = list(product(*param_values))\n", 1186 | "\n", 1187 | " print(f\"\\n{'='*80}\")\n", 1188 | " print(f\"GRID SEARCH WITH {n_splits}-FOLD CROSS-VALIDATION\")\n", 1189 | " print(f\"{'='*80}\")\n", 1190 | " print(f\"Total combinations to test: {len(all_combinations)}\")\n", 1191 | "\n", 1192 | " if max_combinations and len(all_combinations) > max_combinations:\n", 1193 | " print(f\"Limiting to {max_combinations} random combinations\")\n", 1194 | " np.random.shuffle(all_combinations)\n", 1195 | " all_combinations = all_combinations[:max_combinations]\n", 1196 | "\n", 1197 | " print(f\"Testing {len(all_combinations)} combinations...\")\n", 1198 | " print(f\"Estimated time: ~{len(all_combinations) * n_splits * 5} minutes\")\n", 1199 | " print(f\"{'='*80}\\n\")\n", 1200 | "\n", 1201 | " all_results = []\n", 1202 | "\n", 1203 | " for idx, param_combo in enumerate(all_combinations):\n", 1204 | " print(f\"\\n{'#'*80}\")\n", 1205 | " print(f\"COMBINATION {idx + 1}/{len(all_combinations)}\")\n", 1206 | " print(f\"{'#'*80}\")\n", 1207 | "\n", 1208 | " # Create params dict\n", 1209 | " params = {name: value for name, value in zip(param_names, param_combo)}\n", 1210 | "\n", 1211 | " # Run k-fold CV\n", 1212 | " cv_results = kfold_cross_validation(\n", 1213 | " X_paths, X_features, y, params,\n", 1214 | " n_splits=n_splits,\n", 1215 | " feature_columns=feature_columns,\n", 1216 | " max_epochs=max_epochs,\n", 1217 | " verbose=0 # Less verbose for grid search\n", 1218 | " )\n", 1219 | "\n", 1220 | " all_results.append(cv_results)\n", 1221 | "\n", 1222 | " # Convert to DataFrame\n", 1223 | " results_data = []\n", 1224 | " for result in all_results:\n", 1225 | " row = result['params'].copy()\n", 1226 | " row.update({\n", 1227 | " 'mean_val_acc': result['mean_val_acc'],\n", 1228 | " 
'std_val_acc': result['std_val_acc'],\n", 1229 | " 'mean_train_acc': result['mean_train_acc'],\n", 1230 | " 'overfitting_gap': result['overfitting_gap'],\n", 1231 | " 'mean_val_loss': result['mean_val_loss']\n", 1232 | " })\n", 1233 | " results_data.append(row)\n", 1234 | "\n", 1235 | " results_df = pd.DataFrame(results_data)\n", 1236 | " results_df = results_df.sort_values('mean_val_acc', ascending=False)\n", 1237 | "\n", 1238 | " # Get best params\n", 1239 | " best_params = results_df.iloc[0][list(param_grid.keys())].to_dict()\n", 1240 | "\n", 1241 | " print(f\"\\n{'='*80}\")\n", 1242 | " print(f\"GRID SEARCH COMPLETE\")\n", 1243 | " print(f\"{'='*80}\")\n", 1244 | " print(f\"\\nBest Validation Accuracy: {results_df.iloc[0]['mean_val_acc']:.4f} \"\n", 1245 | " f\"± {results_df.iloc[0]['std_val_acc']:.4f}\")\n", 1246 | " print(f\"\\nBest Hyperparameters:\")\n", 1247 | " for key, value in best_params.items():\n", 1248 | " print(f\" {key:20s}: {value}\")\n", 1249 | " print(f\"{'='*80}\\n\")\n", 1250 | "\n", 1251 | " return results_df, best_params, all_results\n", 1252 | "\n", 1253 | "\n", 1254 | "# ==========================================\n", 1255 | "# VISUALIZATION FUNCTIONS\n", 1256 | "# ==========================================\n", 1257 | "\n", 1258 | "def plot_grid_search_results(results_df, save_path='grid_search_results.png'):\n", 1259 | " \"\"\"Visualize grid search results.\"\"\"\n", 1260 | " fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", 1261 | "\n", 1262 | " # Plot 1: Top 10 configurations\n", 1263 | " ax1 = axes[0, 0]\n", 1264 | " top_10 = results_df.head(10)\n", 1265 | " y_pos = np.arange(len(top_10))\n", 1266 | " ax1.barh(y_pos, top_10['mean_val_acc'], xerr=top_10['std_val_acc'],\n", 1267 | " color='steelblue', alpha=0.7, capsize=5)\n", 1268 | " ax1.set_yticks(y_pos)\n", 1269 | " ax1.set_yticklabels([f\"Config {i+1}\" for i in range(len(top_10))])\n", 1270 | " ax1.set_xlabel('Validation Accuracy', fontweight='bold')\n", 1271 | " ax1.set_title('Top 10 Configurations', fontweight='bold')\n", 1272 | " ax1.invert_yaxis()\n", 1273 | " ax1.grid(axis='x', alpha=0.3)\n", 1274 | "\n", 1275 | " # Plot 2: Learning rate effect\n", 1276 | " ax2 = axes[0, 1]\n", 1277 | " lr_groups = results_df.groupby('learning_rate')['mean_val_acc'].agg(['mean', 'std'])\n", 1278 | " ax2.errorbar(lr_groups.index, lr_groups['mean'], yerr=lr_groups['std'],\n", 1279 | " marker='o', capsize=5, linewidth=2, markersize=8)\n", 1280 | " ax2.set_xlabel('Learning Rate', fontweight='bold')\n", 1281 | " ax2.set_ylabel('Mean Validation Accuracy', fontweight='bold')\n", 1282 | " ax2.set_title('Learning Rate Effect', fontweight='bold')\n", 1283 | " ax2.set_xscale('log')\n", 1284 | " ax2.grid(alpha=0.3)\n", 1285 | "\n", 1286 | " # Plot 3: Batch size effect\n", 1287 | " ax3 = axes[0, 2]\n", 1288 | " bs_groups = results_df.groupby('batch_size')['mean_val_acc'].agg(['mean', 'std'])\n", 1289 | " ax3.errorbar(bs_groups.index, bs_groups['mean'], yerr=bs_groups['std'],\n", 1290 | " marker='s', capsize=5, linewidth=2, markersize=8, color='coral')\n", 1291 | " ax3.set_xlabel('Batch Size', fontweight='bold')\n", 1292 | " ax3.set_ylabel('Mean Validation Accuracy', fontweight='bold')\n", 1293 | " ax3.set_title('Batch Size Effect', fontweight='bold')\n", 1294 | " ax3.grid(alpha=0.3)\n", 1295 | "\n", 1296 | " # Plot 4: L2 regularization effect\n", 1297 | " ax4 = axes[1, 0]\n", 1298 | " l2_groups = results_df.groupby('l2_reg')['mean_val_acc'].agg(['mean', 'std'])\n", 1299 | " ax4.errorbar(l2_groups.index, l2_groups['mean'], 
yerr=l2_groups['std'],\n", 1300 | " marker='^', capsize=5, linewidth=2, markersize=8, color='green')\n", 1301 | " ax4.set_xlabel('L2 Regularization', fontweight='bold')\n", 1302 | " ax4.set_ylabel('Mean Validation Accuracy', fontweight='bold')\n", 1303 | " ax4.set_title('L2 Regularization Effect', fontweight='bold')\n", 1304 | " ax4.set_xscale('log')\n", 1305 | " ax4.grid(alpha=0.3)\n", 1306 | "\n", 1307 | " # Plot 5: Overfitting gap\n", 1308 | " ax5 = axes[1, 1]\n", 1309 | " ax5.scatter(results_df['mean_val_acc'], results_df['overfitting_gap'],\n", 1310 | " alpha=0.6, s=100, c=results_df['batch_size'], cmap='viridis')\n", 1311 | " ax5.axhline(y=0, color='red', linestyle='--', linewidth=2, label='No overfitting')\n", 1312 | " ax5.set_xlabel('Validation Accuracy', fontweight='bold')\n", 1313 | " ax5.set_ylabel('Overfitting Gap (Train - Val)', fontweight='bold')\n", 1314 | " ax5.set_title('Overfitting Analysis', fontweight='bold')\n", 1315 | " ax5.legend()\n", 1316 | " ax5.grid(alpha=0.3)\n", 1317 | " cbar = plt.colorbar(ax5.collections[0], ax=ax5)\n", 1318 | " cbar.set_label('Batch Size', fontweight='bold')\n", 1319 | "\n", 1320 | " # Plot 6: Accuracy vs Loss\n", 1321 | " ax6 = axes[1, 2]\n", 1322 | " scatter = ax6.scatter(results_df['mean_val_loss'], results_df['mean_val_acc'],\n", 1323 | " alpha=0.6, s=100, c=results_df['learning_rate'],\n", 1324 | " cmap='plasma', norm=plt.matplotlib.colors.LogNorm())\n", 1325 | " ax6.set_xlabel('Validation Loss', fontweight='bold')\n", 1326 | " ax6.set_ylabel('Validation Accuracy', fontweight='bold')\n", 1327 | " ax6.set_title('Loss vs Accuracy', fontweight='bold')\n", 1328 | " ax6.grid(alpha=0.3)\n", 1329 | " cbar = plt.colorbar(scatter, ax=ax6)\n", 1330 | " cbar.set_label('Learning Rate', fontweight='bold')\n", 1331 | "\n", 1332 | " plt.tight_layout()\n", 1333 | " plt.savefig(save_path, dpi=300, bbox_inches='tight')\n", 1334 | " print(f\"✓ Grid search visualization saved to {save_path}\")\n", 1335 | " plt.show()\n", 1336 | "\n", 1337 | "\n", 1338 | "def plot_cv_fold_variance(all_results, best_idx=0, save_path='cv_fold_variance.png'):\n", 1339 | " \"\"\"Plot variance across CV folds for best configuration.\"\"\"\n", 1340 | " best_result = all_results[best_idx]\n", 1341 | " fold_results = pd.DataFrame(best_result['fold_results'])\n", 1342 | "\n", 1343 | " fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", 1344 | "\n", 1345 | " # Plot 1: Accuracy across folds\n", 1346 | " ax1 = axes[0]\n", 1347 | " x = fold_results['fold']\n", 1348 | " ax1.plot(x, fold_results['train_acc'], 'o-', label='Train', linewidth=2, markersize=8)\n", 1349 | " ax1.plot(x, fold_results['val_acc'], 's-', label='Validation', linewidth=2, markersize=8)\n", 1350 | " ax1.fill_between(x,\n", 1351 | " fold_results['val_acc'].mean() - fold_results['val_acc'].std(),\n", 1352 | " fold_results['val_acc'].mean() + fold_results['val_acc'].std(),\n", 1353 | " alpha=0.2)\n", 1354 | " ax1.set_xlabel('Fold', fontweight='bold')\n", 1355 | " ax1.set_ylabel('Accuracy', fontweight='bold')\n", 1356 | " ax1.set_title(f\"Accuracy Across CV Folds\\nMean Val: {fold_results['val_acc'].mean():.4f} ± {fold_results['val_acc'].std():.4f}\",\n", 1357 | " fontweight='bold')\n", 1358 | " ax1.legend()\n", 1359 | " ax1.grid(alpha=0.3)\n", 1360 | " ax1.set_xticks(x)\n", 1361 | "\n", 1362 | " # Plot 2: Loss across folds\n", 1363 | " ax2 = axes[1]\n", 1364 | " ax2.plot(x, fold_results['train_loss'], 'o-', label='Train', linewidth=2, markersize=8)\n", 1365 | " ax2.plot(x, fold_results['val_loss'], 's-', 
label='Validation', linewidth=2, markersize=8)\n", 1366 | " ax2.fill_between(x,\n", 1367 | " fold_results['val_loss'].mean() - fold_results['val_loss'].std(),\n", 1368 | " fold_results['val_loss'].mean() + fold_results['val_loss'].std(),\n", 1369 | " alpha=0.2)\n", 1370 | " ax2.set_xlabel('Fold', fontweight='bold')\n", 1371 | " ax2.set_ylabel('Loss', fontweight='bold')\n", 1372 | " ax2.set_title(f\"Loss Across CV Folds\\nMean Val: {fold_results['val_loss'].mean():.4f} ± {fold_results['val_loss'].std():.4f}\",\n", 1373 | " fontweight='bold')\n", 1374 | " ax2.legend()\n", 1375 | " ax2.grid(alpha=0.3)\n", 1376 | " ax2.set_xticks(x)\n", 1377 | "\n", 1378 | " plt.tight_layout()\n", 1379 | " plt.savefig(save_path, dpi=300, bbox_inches='tight')\n", 1380 | " print(f\"✓ CV fold variance plot saved to {save_path}\")\n", 1381 | " plt.show()\n", 1382 | "\n", 1383 | "\n", 1384 | "# ==========================================\n", 1385 | "# MAIN EXECUTION\n", 1386 | "# ==========================================\n", 1387 | "\n", 1388 | "if __name__ == \"__main__\":\n", 1389 | "\n", 1390 | " print(\"=\"*80)\n", 1391 | " print(\"HYBRID CNN HYPERPARAMETER OPTIMIZATION\")\n", 1392 | " print(\"=\"*80)\n", 1393 | "\n", 1394 | " # Combine train + val for cross-validation\n", 1395 | " # (Keep test set completely separate!)\n", 1396 | " X_paths_cv = np.concatenate([X_paths_train, X_paths_val])\n", 1397 | " X_features_cv = np.concatenate([X_features_train, X_features_val])\n", 1398 | " y_cv = np.concatenate([y_train, y_val])\n", 1399 | "\n", 1400 | " print(f\"\\nCombined dataset for CV: {len(X_paths_cv)} samples\")\n", 1401 | " print(f\"Test set (held out): {len(X_paths_test)} samples\")\n", 1402 | "\n", 1403 | " # Option 1: Quick search (smaller grid)\n", 1404 | " print(\"\\n--- OPTION 1: QUICK SEARCH ---\")\n", 1405 | " quick_param_grid = {\n", 1406 | " 'batch_size': [16, 32],\n", 1407 | " 'learning_rate': [0.00005, 0.0001],\n", 1408 | " 'l2_reg': [0.0001],\n", 1409 | " 'dropout_cnn': [0.4],\n", 1410 | " 'dropout_features': [0.3],\n", 1411 | " 'dropout_final': [0.5],\n", 1412 | " 'cnn_dense_units': [128],\n", 1413 | " 'feature_dense_units': [16],\n", 1414 | " }\n", 1415 | "\n", 1416 | " # Run quick search (2 x 2 = 4 combinations)\n", 1417 | " results_df_quick, best_params_quick, all_results_quick = grid_search_kfold(\n", 1418 | " X_paths_cv, X_features_cv, y_cv,\n", 1419 | " param_grid=quick_param_grid,\n", 1420 | " n_splits=5,\n", 1421 | " feature_columns=feature_columns,\n", 1422 | " max_epochs=30\n", 1423 | " )\n", 1424 | "\n", 1425 | " # Visualize results\n", 1426 | " plot_grid_search_results(results_df_quick, save_path='quick_search_results.png')\n", 1427 | " plot_cv_fold_variance(all_results_quick, best_idx=0, save_path='quick_search_cv_folds.png')\n", 1428 | "\n", 1429 | " # Save results\n", 1430 | " results_df_quick.to_csv('quick_search_results.csv', index=False)\n", 1431 | "\n", 1432 | " with open('best_params_quick.json', 'w') as f:\n", 1433 | " json.dump(best_params_quick, f, indent=2)\n", 1434 | "\n", 1435 | " print(\"\\n✓ Quick search complete!\")\n", 1436 | " print(\"✓ Results saved to 'quick_search_results.csv'\")\n", 1437 | " print(\"✓ Best parameters saved to 'best_params_quick.json'\")\n", 1438 | "\n", 1439 | " # Option 2: Full search (if you have time - this will take hours!)\n", 1440 | " run_full_search = input(\"\\nRun full grid search? This will take several hours. 
(y/n): \")\n", 1441 | "\n", 1442 | " if run_full_search.lower() == 'y':\n", 1443 | " print(\"\\n--- OPTION 2: FULL SEARCH ---\")\n", 1444 | "\n", 1445 | " results_df_full, best_params_full, all_results_full = grid_search_kfold(\n", 1446 | " X_paths_cv, X_features_cv, y_cv,\n", 1447 | " param_grid=param_grid, # Full grid\n", 1448 | " n_splits=5,\n", 1449 | " feature_columns=feature_columns,\n", 1450 | " max_epochs=30,\n", 1451 | " max_combinations=50 # Limit to 50 random combinations\n", 1452 | " )\n", 1453 | "\n", 1454 | " # Visualize\n", 1455 | " plot_grid_search_results(results_df_full, save_path='full_search_results.png')\n", 1456 | " plot_cv_fold_variance(all_results_full, best_idx=0, save_path='full_search_cv_folds.png')\n", 1457 | "\n", 1458 | " # Save\n", 1459 | " results_df_full.to_csv('full_search_results.csv', index=False)\n", 1460 | "\n", 1461 | " with open('best_params_full.json', 'w') as f:\n", 1462 | " json.dump(best_params_full, f, indent=2)\n", 1463 | "\n", 1464 | " print(\"\\n✓ Full search complete!\")\n", 1465 | " print(\"✓ Results saved to 'full_search_results.csv'\")\n", 1466 | " print(\"✓ Best parameters saved to 'best_params_full.json'\")\n", 1467 | "\n", 1468 | " # ==========================================\n", 1469 | " # TRAIN FINAL MODEL WITH BEST PARAMS\n", 1470 | " # ==========================================\n", 1471 | "\n", 1472 | " print(\"\\n\" + \"=\"*80)\n", 1473 | " print(\"TRAINING FINAL MODEL WITH BEST PARAMETERS\")\n", 1474 | " print(\"=\"*80)\n", 1475 | "\n", 1476 | " # Use best params from quick search\n", 1477 | " best_params = best_params_quick\n", 1478 | "\n", 1479 | " # Recreate train/val split\n", 1480 | " train_ds_final = tf.data.Dataset.from_tensor_slices(\n", 1481 | " (X_paths_train, X_features_train, y_train)\n", 1482 | " ).map(\n", 1483 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=True),\n", 1484 | " num_parallel_calls=tf.data.AUTOTUNE\n", 1485 | " ).cache().shuffle(buffer_size=len(X_paths_train)).batch(best_params['batch_size']).prefetch(tf.data.AUTOTUNE)\n", 1486 | "\n", 1487 | " val_ds_final = tf.data.Dataset.from_tensor_slices(\n", 1488 | " (X_paths_val, X_features_val, y_val)\n", 1489 | " ).map(\n", 1490 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=False),\n", 1491 | " num_parallel_calls=tf.data.AUTOTUNE\n", 1492 | " ).cache().batch(best_params['batch_size']).prefetch(tf.data.AUTOTUNE)\n", 1493 | "\n", 1494 | " test_ds_final = tf.data.Dataset.from_tensor_slices(\n", 1495 | " (X_paths_test, X_features_test, y_test)\n", 1496 | " ).map(\n", 1497 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=False),\n", 1498 | " num_parallel_calls=tf.data.AUTOTUNE\n", 1499 | " ).batch(best_params['batch_size']).prefetch(tf.data.AUTOTUNE)\n", 1500 | "\n", 1501 | " # Build and train final model\n", 1502 | " final_model = build_hybrid_model(num_features=X_features_train.shape[1], params=best_params)\n", 1503 | "\n", 1504 | " callbacks_final = [\n", 1505 | " EarlyStopping(patience=15, restore_best_weights=True, verbose=1),\n", 1506 | " ReduceLROnPlateau(factor=0.5, patience=5, min_lr=1e-7, verbose=1)\n", 1507 | " ]\n", 1508 | "\n", 1509 | " history_final = final_model.fit(\n", 1510 | " train_ds_final,\n", 1511 | " validation_data=val_ds_final,\n", 1512 | " epochs=50,\n", 1513 | " callbacks=callbacks_final,\n", 1514 | " verbose=1\n", 1515 | " )\n", 1516 | "\n", 1517 | " # Evaluate on test set\n", 1518 | " test_loss, test_acc = final_model.evaluate(test_ds_final)\n", 1519 | "\n", 1520 | " print(f\"\\n{'='*80}\")\n", 
1521 | " print(f\"FINAL MODEL PERFORMANCE\")\n", 1522 | " print(f\"{'='*80}\")\n", 1523 | " print(f\"Test Accuracy: {test_acc:.4f}\")\n", 1524 | " print(f\"Test Loss: {test_loss:.4f}\")\n", 1525 | " print(f\"{'='*80}\\n\")\n", 1526 | "\n", 1527 | " # Save final model\n", 1528 | " final_model.save('final_optimized_model.h5')\n", 1529 | " print(\"✓ Final model saved to 'final_optimized_model.h5'\")" 1530 | ] 1531 | }, 1532 | { 1533 | "cell_type": "markdown", 1534 | "metadata": { 1535 | "id": "4CXejxeRGumi" 1536 | }, 1537 | "source": [ 1538 | "## 8. Compile the Model" 1539 | ] 1540 | }, 1541 | { 1542 | "cell_type": "code", 1543 | "execution_count": null, 1544 | "metadata": { 1545 | "id": "a0JrptfGNwAe" 1546 | }, 1547 | "outputs": [], 1548 | "source": [ 1549 | "from tensorflow.keras.callbacks import ReduceLROnPlateau# Exponential learning rate decay\n", 1550 | "\n", 1551 | "\n", 1552 | "lr_scheduler = ReduceLROnPlateau(\n", 1553 | " monitor='val_loss',\n", 1554 | " factor=0.5, # Reduce LR by half\n", 1555 | " patience=3, # Wait 3 epochs before reducing\n", 1556 | " min_lr=1e-7,\n", 1557 | " verbose=1\n", 1558 | ")" 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "code", 1563 | "execution_count": null, 1564 | "metadata": { 1565 | "id": "EbTxiswS3oeH" 1566 | }, 1567 | "outputs": [], 1568 | "source": [ 1569 | "from tensorflow.keras.optimizers import Adam\n", 1570 | "from tensorflow.keras.metrics import Precision, Recall, AUC\n", 1571 | "def compile_model(model):\n", 1572 | " \"\"\"\n", 1573 | " Compile the model with appropriate loss function and optimizer.\n", 1574 | " \"\"\"\n", 1575 | " model.compile(\n", 1576 | " optimizer=Adam(learning_rate=0.0001), # Start lower\n", 1577 | " loss='categorical_crossentropy',\n", 1578 | " metrics=['accuracy', Precision(), Recall(), AUC()]\n", 1579 | " )" 1580 | ] 1581 | }, 1582 | { 1583 | "cell_type": "code", 1584 | "execution_count": null, 1585 | "metadata": { 1586 | "id": "YuPV6bYtGumi" 1587 | }, 1588 | "outputs": [], 1589 | "source": [ 1590 | "# Compile with appropriate loss function and optimizer\n", 1591 | "\n", 1592 | "compile_model(model)" 1593 | ] 1594 | }, 1595 | { 1596 | "cell_type": "markdown", 1597 | "metadata": { 1598 | "id": "PwaMZebuGumi" 1599 | }, 1600 | "source": [ 1601 | "## 9. Set Up Callbacks" 1602 | ] 1603 | }, 1604 | { 1605 | "cell_type": "code", 1606 | "execution_count": null, 1607 | "metadata": { 1608 | "colab": { 1609 | "base_uri": "https://localhost:8080/" 1610 | }, 1611 | "id": "a1mC1HCcGumi", 1612 | "outputId": "687fdfa4-c0b4-4531-dbe8-82c9a32a7303" 1613 | }, 1614 | "outputs": [], 1615 | "source": [ 1616 | "\n", 1617 | "# --- 1. Early Stopping (Updated Patience) ---\n", 1618 | "early_stop = EarlyStopping(\n", 1619 | " monitor='val_accuracy', # Monitors the metric you want to minimize\n", 1620 | " mode='max', # Stop when val_loss stops decreasing\n", 1621 | " patience=10, # Reduced patience to 10 epochs\n", 1622 | " verbose=1,\n", 1623 | " restore_best_weights=True\n", 1624 | ")\n", 1625 | "\n", 1626 | "# --- 2. 
Model Checkpoint ---\n",
1627 | "checkpoint = ModelCheckpoint(\n",
1628 | "    # File name reflects the binary (liquid/solid) classifier\n",
1629 | "    'best_particle_classifier_hybrid_binary.keras',\n",
1630 | "    monitor='val_accuracy',   # Match EarlyStopping\n",
1631 | "    mode='max',               # Save when val_accuracy is at its maximum\n",
1632 | "    verbose=1,\n",
1633 | "    save_best_only=True\n",
1634 | ")\n",
1635 | "\n",
1636 | "print(\"Callbacks configured for robust training.\")"
1637 | ]
1638 | },
1639 | {
1640 | "cell_type": "markdown",
1641 | "metadata": {
1642 | "id": "YbllFZmLGumi"
1643 | },
1644 | "source": [
1645 | "## 10. Train the Model"
1646 | ]
1647 | },
1648 | {
1649 | "cell_type": "code",
1650 | "execution_count": null,
1651 | "metadata": {
1652 | "colab": {
1653 | "base_uri": "https://localhost:8080/"
1654 | },
1655 | "id": "oB5JopEYGumi",
1656 | "outputId": "52f4718b-7768-46c7-fa5d-0e4d693756f6"
1657 | },
1658 | "outputs": [],
1659 | "source": [
1660 | "# --- Calculate Steps Per Epoch ---\n",
1661 | "# Keras can infer these for a finite tf.data.Dataset, but we set the batch\n",
1662 | "# counts explicitly as ceil(N / BATCH_SIZE).\n",
1663 | "N_TRAIN_SAMPLES = len(X_paths_train)  # Use the count of training samples\n",
1664 | "N_VAL_SAMPLES = len(X_paths_val)      # Use the count of validation samples\n",
1665 | "N_EPOCHS = 100\n",
1666 | "steps_per_epoch = N_TRAIN_SAMPLES // BATCH_SIZE\n",
1667 | "# Use validation_steps to ensure validation is done over the entire set\n",
1668 | "validation_steps = N_VAL_SAMPLES // BATCH_SIZE\n",
1669 | "# If the sample count is not perfectly divisible, add 1 to cover the remainder\n",
1670 | "if N_TRAIN_SAMPLES % BATCH_SIZE != 0:\n",
1671 | "    steps_per_epoch += 1\n",
1672 | "if N_VAL_SAMPLES % BATCH_SIZE != 0:\n",
1673 | "    validation_steps += 1\n",
1674 | "# --- Train the model with class weights using the Datasets ---\n",
1675 | "print(\"Starting training with class weights...\")\n",
1676 | "print(f\"Using class weights:\")\n",
1677 | "for i, w in class_weight_dict.items():\n",
1678 | "    print(f\"  Class {i} ({get_class_name(i)}): {w:.2f}\")\n",
1679 | "\n",
1680 | "history = model.fit(\n",
1681 | "    # 1. Pass the training dataset object\n",
1682 | "    train_ds,\n",
1683 | "    # 2. Set the number of steps per epoch\n",
1684 | "    steps_per_epoch=steps_per_epoch,\n",
1685 | "    epochs=N_EPOCHS,\n",
1686 | "    # 3. Pass the validation dataset object\n",
1687 | "    validation_data=val_ds,\n",
1688 | "    # 4. Set the validation steps\n",
1689 | "    validation_steps=validation_steps,\n",
1690 | "    # 5. Keep class weights (applied per batch)\n",
1691 | "    class_weight=class_weight_dict,\n",
1692 | "    callbacks=[early_stop, checkpoint, lr_scheduler],\n",
1693 | "    verbose=1\n",
1694 | ")\n",
1695 | "\n",
1696 | "print(\"\\nTraining complete!\")"
1697 | ]
1698 | },
1699 | {
1700 | "cell_type": "markdown",
1701 | "metadata": {
1702 | "id": "HNFU-co8Gumi"
1703 | },
1704 | "source": [
1705 | "## 11. 
Visualize Training History" 1706 | ] 1707 | }, 1708 | { 1709 | "cell_type": "code", 1710 | "execution_count": null, 1711 | "metadata": { 1712 | "colab": { 1713 | "base_uri": "https://localhost:8080/", 1714 | "height": 578 1715 | }, 1716 | "id": "zfuBCuPOGumi", 1717 | "outputId": "5c7a8a35-866f-46e4-cb4c-7d8b86a1307a" 1718 | }, 1719 | "outputs": [], 1720 | "source": [ 1721 | "# Plot training history\n", 1722 | "fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n", 1723 | "\n", 1724 | "# Plot accuracy\n", 1725 | "axes[0].plot(history.history['accuracy'], label='Training Accuracy')\n", 1726 | "axes[0].plot(history.history['val_accuracy'], label='Validation Accuracy')\n", 1727 | "axes[0].set_title('Model Accuracy (Hybrid CNN)')\n", 1728 | "axes[0].set_ylabel('Accuracy')\n", 1729 | "axes[0].set_xlabel('Epoch')\n", 1730 | "axes[0].legend()\n", 1731 | "axes[0].grid(True)\n", 1732 | "\n", 1733 | "# Plot loss\n", 1734 | "axes[1].plot(history.history['loss'], label='Training Loss')\n", 1735 | "axes[1].plot(history.history['val_loss'], label='Validation Loss')\n", 1736 | "axes[1].set_title('Model Loss (Hybrid CNN)')\n", 1737 | "axes[1].set_ylabel('Loss')\n", 1738 | "axes[1].set_xlabel('Epoch')\n", 1739 | "axes[1].legend()\n", 1740 | "axes[1].grid(True)\n", 1741 | "\n", 1742 | "plt.tight_layout()\n", 1743 | "plt.show()\n", 1744 | "\n", 1745 | "# Print final metrics\n", 1746 | "print(f\"\\nFinal Training Accuracy: {history.history['accuracy'][-1]:.4f}\")\n", 1747 | "print(f\"Final Validation Accuracy: {history.history['val_accuracy'][-1]:.4f}\")\n", 1748 | "print(f\"Best Validation Accuracy: {max(history.history['val_accuracy']):.4f}\")" 1749 | ] 1750 | }, 1751 | { 1752 | "cell_type": "code", 1753 | "execution_count": null, 1754 | "metadata": { 1755 | "colab": { 1756 | "base_uri": "https://localhost:8080/" 1757 | }, 1758 | "id": "sXA9BlLzw689", 1759 | "outputId": "f56a8d2e-5b77-4068-ab54-8fe09925d28f" 1760 | }, 1761 | "outputs": [], 1762 | "source": [ 1763 | "from sklearn.metrics import roc_curve, auc\n", 1764 | "import matplotlib.pyplot as plt\n", 1765 | "import numpy as np\n", 1766 | "\n", 1767 | "# Assuming your test data is processed and synchronized as NumPy arrays for simplicity\n", 1768 | "# If using the test_ds generator, you must iterate it to get all X and y_true values.\n", 1769 | "# For simplicity, we assume you have the test arrays: X_features_test, X_paths_test, y_test\n", 1770 | "# and load the full test data here (or use the test_ds object)\n", 1771 | "\n", 1772 | "# 1. Get the true labels\n", 1773 | "# If y_test is one-hot (e.g., [[1, 0], [0, 1]]), extract the positive class column (index 1)\n", 1774 | "y_true = y_test[:, 1]\n", 1775 | "\n", 1776 | "# 2. 
Get the model's probability predictions for the test set\n", 1777 | "# If using the tf.data.Dataset pipeline:\n", 1778 | "y_pred_probs = model.predict(test_ds)\n", 1779 | "\n", 1780 | "# If your model output is Dense(1, sigmoid), y_pred_probs is already (N_samples, 1)\n", 1781 | "# If your model output is Dense(2, softmax), y_pred_probs has shape (N_samples, 2).\n", 1782 | "# We take the probability of the positive class (index 1).\n", 1783 | "if y_pred_probs.shape[1] > 1:\n", 1784 | " y_scores = y_pred_probs[:, 1]\n", 1785 | "else:\n", 1786 | " y_scores = y_pred_probs.ravel() # flatten (N, 1) to 1-D for roc_curve" 1787 | ] 1788 | }, 1789 | { 1790 | "cell_type": "code", 1791 | "execution_count": null, 1792 | "metadata": { 1793 | "colab": { 1794 | "base_uri": "https://localhost:8080/", 1795 | "height": 564 1796 | }, 1797 | "id": "XkbJVZCYxEEg", 1798 | "outputId": "4fc242ae-5cb6-42e9-e0e9-01a8ce77442b" 1799 | }, 1800 | "outputs": [], 1801 | "source": [ 1802 | "# Calculate the True Positive Rate (TPR) and False Positive Rate (FPR)\n", 1803 | "fpr, tpr, thresholds = roc_curve(y_true, y_scores)\n", 1804 | "\n", 1805 | "# Calculate the AUC\n", 1806 | "roc_auc = auc(fpr, tpr)\n", 1807 | "\n", 1808 | "# Plot the ROC curve\n", 1809 | "plt.figure(figsize=(8, 6))\n", 1810 | "plt.plot(fpr, tpr, color='darkorange', lw=2,\n", 1811 | " label=f'ROC curve (AUC = {roc_auc:.4f})')\n", 1812 | "plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Guess (AUC = 0.5)')\n", 1813 | "plt.xlim([0.0, 1.0])\n", 1814 | "plt.ylim([0.0, 1.05])\n", 1815 | "plt.xlabel('False Positive Rate (FPR)')\n", 1816 | "plt.ylabel('True Positive Rate (TPR) / Recall')\n", 1817 | "plt.title('Receiver Operating Characteristic (ROC) Curve')\n", 1818 | "plt.legend(loc=\"lower right\")\n", 1819 | "plt.grid(True)\n", 1820 | "plt.show()" 1821 | ] 1822 | }, 1823 | { 1824 | "cell_type": "markdown", 1825 | "metadata": { 1826 | "id": "mmK_N-pAGumi" 1827 | }, 1828 | "source": [ 1829 | "## 12. Evaluate on Test Set" 1830 | ] 1831 | }, 1832 | { 1833 | "cell_type": "code", 1834 | "execution_count": null, 1835 | "metadata": { 1836 | "colab": { 1837 | "base_uri": "https://localhost:8080/" 1838 | }, 1839 | "id": "zr-OTfFIGumi", 1840 | "outputId": "09d5924e-d045-4457-859d-ca555156a456" 1841 | }, 1842 | "outputs": [], 1843 | "source": [ 1844 | "# Load the best model\n", 1845 | "from keras.models import load_model\n", 1846 | "\n", 1847 | "best_model = load_model('best_particle_classifier_hybrid_binary.keras')\n", 1848 | "\n", 1849 | "# Evaluate on training set\n", 1850 | "train_metrics = best_model.evaluate(train_ds, verbose=0)\n", 1851 | "# Evaluate on validation set\n", 1852 | "val_metrics = best_model.evaluate(val_ds, verbose=0)\n", 1853 | "# Evaluate on test set\n", 1854 | "test_metrics = best_model.evaluate(test_ds, verbose=0)\n", 1855 | "\n", 1856 | "# The metrics list order is: [Loss, Accuracy, Precision, Recall, AUC]\n", 1857 | "# Note: Ensure your model was compiled with these metrics for this to work\n", 1858 | "metric_names = ['Loss', 'Accuracy', 'Precision', 'Recall', 'AUC']\n", 1859 | "\n", 1860 | "# --- 3. 
Print Results ---\n", 1861 | "print(\"\\n\" + \"=\" * 50)\n", 1862 | "print(\"HYBRID BINARY CNN PERFORMANCE\")\n", 1863 | "print(\"=\" * 50)\n", 1864 | "\n", 1865 | "def print_metrics(set_name, metrics):\n", 1866 | " \"\"\"Helper function to print formatted metrics.\"\"\"\n", 1867 | " print(f\"\\n--- {set_name} Set ---\")\n", 1868 | " for name, value in zip(metric_names, metrics):\n", 1869 | " if name == 'Accuracy':\n", 1870 | " print(f'{name}: {value:.4f} ({value*100:.2f}%)')\n", 1871 | " else:\n", 1872 | " print(f'{name}: {value:.4f}')\n", 1873 | "\n", 1874 | "print_metrics(\"Training\", train_metrics)\n", 1875 | "print_metrics(\"Validation\", val_metrics)\n", 1876 | "print_metrics(\"Test\", test_metrics)\n", 1877 | "\n", 1878 | "print(\"\\n\" + \"=\" * 50)" 1879 | ] 1880 | }, 1881 | { 1882 | "cell_type": "code", 1883 | "execution_count": null, 1884 | "metadata": { 1885 | "colab": { 1886 | "base_uri": "https://localhost:8080/", 1887 | "height": 1000 1888 | }, 1889 | "id": "evA8ObwJzuMm", 1890 | "outputId": "3fc0a386-2162-4f8b-d3c7-48e35b082623" 1891 | }, 1892 | "outputs": [], 1893 | "source": [ 1894 | "import matplotlib.pyplot as plt\n", 1895 | "import numpy as np\n", 1896 | "\n", 1897 | "def get_class_name(class_idx):\n", 1898 | " \"\"\"Convert class index to name.\"\"\"\n", 1899 | " return 'Liquid' if class_idx == 0 else 'Solid'\n", 1900 | "\n", 1901 | "\n", 1902 | "def get_all_images_from_dataset(test_ds):\n", 1903 | " \"\"\"Extract all images from batched test dataset.\"\"\"\n", 1904 | " all_images = []\n", 1905 | " all_features = []\n", 1906 | " all_labels = []\n", 1907 | "\n", 1908 | " for batch in test_ds:\n", 1909 | " (images, features), labels = batch\n", 1910 | " all_images.append(images.numpy())\n", 1911 | " all_features.append(features.numpy())\n", 1912 | " all_labels.append(labels.numpy())\n", 1913 | "\n", 1914 | " all_images = np.concatenate(all_images, axis=0)\n", 1915 | " all_features = np.concatenate(all_features, axis=0)\n", 1916 | " all_labels = np.concatenate(all_labels, axis=0)\n", 1917 | "\n", 1918 | " return all_images, all_features, all_labels\n", 1919 | "\n", 1920 | "\n", 1921 | "# Extract images and labels from test dataset\n", 1922 | "test_images, test_features, test_labels = get_all_images_from_dataset(test_ds)\n", 1923 | "print(f\"Extracted {len(test_images)} test images\")\n", 1924 | "\n", 1925 | "# Convert one-hot labels to class indices\n", 1926 | "y_true = np.argmax(test_labels, axis=1)\n", 1927 | "y_pred_probs = best_model.predict(test_ds, verbose=0) # best model loaded in Section 12\n", 1928 | "y_pred = np.argmax(y_pred_probs, axis=1) # needed by the analyses below\n", 1929 | "def plot_misclassified_particles(y_true, y_pred, y_pred_probs, test_images, test_features,\n", 1930 | " feature_columns, max_display=16):\n", 1931 | " \"\"\"\n", 1932 | " Plot all misclassified particles with detailed information.\n", 1933 | "\n", 1934 | " Parameters:\n", 1935 | " - y_true: True labels (class indices)\n", 1936 | " - y_pred: Predicted labels (class indices)\n", 1937 | " - y_pred_probs: Prediction probabilities\n", 1938 | " - test_images: Test images array\n", 1939 | " - test_features: Test features array\n", 1940 | " - feature_columns: List of feature names\n", 1941 | " - max_display: Maximum number of samples to display\n", 1942 | " \"\"\"\n", 1943 | " # Find misclassified indices\n", 1944 | " misclassified_idx = np.where(y_pred != y_true)[0]\n", 1945 | "\n", 1946 | " print(\"=\" * 80)\n", 1947 | " print(\"MISCLASSIFICATION ANALYSIS\")\n", 1948 | " print(\"=\" * 80)\n", 1949 | " print(f\"Total test samples: {len(y_true)}\")\n", 1950 | " print(f\"Misclassified: {len(misclassified_idx)} 
({len(misclassified_idx)/len(y_true)*100:.2f}%)\")\n", 1951 | " print()\n", 1952 | "\n", 1953 | " # Analyze misclassification patterns\n", 1954 | " # Liquid misclassified as Solid\n", 1955 | " liquid_as_solid = np.where((y_true == 0) & (y_pred == 1))[0]\n", 1956 | " # Solid misclassified as Liquid\n", 1957 | " solid_as_liquid = np.where((y_true == 1) & (y_pred == 0))[0]\n", 1958 | "\n", 1959 | " print(f\"Liquid → Solid errors: {len(liquid_as_solid)} ({len(liquid_as_solid)/np.sum(y_true==0)*100:.1f}% of liquids)\")\n", 1960 | " print(f\"Solid → Liquid errors: {len(solid_as_liquid)} ({len(solid_as_liquid)/np.sum(y_true==1)*100:.1f}% of solids)\")\n", 1961 | " print()\n", 1962 | "\n", 1963 | " if len(misclassified_idx) == 0:\n", 1964 | " print(\"🎉 No misclassifications! Perfect model!\")\n", 1965 | " return\n", 1966 | "\n", 1967 | " # Limit display\n", 1968 | " display_idx = misclassified_idx[:max_display]\n", 1969 | " n_samples = len(display_idx)\n", 1970 | "\n", 1971 | " # Calculate grid size\n", 1972 | " n_cols = 4\n", 1973 | " n_rows = int(np.ceil(n_samples / n_cols))\n", 1974 | "\n", 1975 | " fig = plt.figure(figsize=(20, 5*n_rows))\n", 1976 | "\n", 1977 | " for plot_idx, sample_idx in enumerate(display_idx):\n", 1978 | " ax = plt.subplot(n_rows, n_cols, plot_idx + 1)\n", 1979 | "\n", 1980 | " # Get data\n", 1981 | " image = test_images[sample_idx, :, :, 0]\n", 1982 | " true_class = y_true[sample_idx]\n", 1983 | " pred_class = y_pred[sample_idx]\n", 1984 | " confidence = y_pred_probs[sample_idx, pred_class]\n", 1985 | "\n", 1986 | " # Get feature values\n", 1987 | " features = test_features[sample_idx]\n", 1988 | "\n", 1989 | " # Display image\n", 1990 | " ax.imshow(image, cmap='gray')\n", 1991 | "\n", 1992 | " # Create detailed title\n", 1993 | " title = f\"Sample #{sample_idx}\\n\"\n", 1994 | " title += f\"TRUE: {get_class_name(true_class)} | PRED: {get_class_name(pred_class)}\\n\"\n", 1995 | " title += f\"Confidence: {confidence:.3f}\\n\"\n", 1996 | " title += f\"Probs: [L:{y_pred_probs[sample_idx, 0]:.3f}, S:{y_pred_probs[sample_idx, 1]:.3f}]\\n\"\n", 1997 | "\n", 1998 | " # Add feature information\n", 1999 | " if len(feature_columns) == 1:\n", 2000 | " title += f\"{feature_columns[0]}: {features[0]:.1f}\"\n", 2001 | " else:\n", 2002 | " feature_str = \", \".join([f\"{name}:{val:.1f}\" for name, val in zip(feature_columns, features)])\n", 2003 | " title += f\"{feature_str}\"\n", 2004 | "\n", 2005 | " ax.set_title(title, fontsize=10, color='red', weight='bold')\n", 2006 | " ax.axis('off')\n", 2007 | "\n", 2008 | " plt.suptitle(f'Misclassified Particles (Showing {n_samples} of {len(misclassified_idx)})',\n", 2009 | " fontsize=16, weight='bold', y=1.0)\n", 2010 | " plt.tight_layout()\n", 2011 | " plt.savefig('misclassified_particles.png', dpi=150, bbox_inches='tight')\n", 2012 | " plt.show()\n", 2013 | "\n", 2014 | " # Print detailed statistics\n", 2015 | " print(\"\\n\" + \"=\" * 80)\n", 2016 | " print(\"DETAILED MISCLASSIFICATION STATISTICS\")\n", 2017 | " print(\"=\" * 80)\n", 2018 | "\n", 2019 | " if len(misclassified_idx) > 0:\n", 2020 | " # Confidence analysis\n", 2021 | " misclass_confidences = y_pred_probs[misclassified_idx, y_pred[misclassified_idx]]\n", 2022 | " correct_idx = np.where(y_pred == y_true)[0]\n", 2023 | " correct_confidences = y_pred_probs[correct_idx, y_pred[correct_idx]]\n", 2024 | "\n", 2025 | " print(f\"\\nConfidence Analysis:\")\n", 2026 | " print(f\" Misclassified - Mean confidence: {np.mean(misclass_confidences):.3f} 
(±{np.std(misclass_confidences):.3f})\")\n", 2027 | " print(f\" Correct - Mean confidence: {np.mean(correct_confidences):.3f} (±{np.std(correct_confidences):.3f})\")\n", 2028 | "\n", 2029 | " # Feature analysis for misclassified samples\n", 2030 | " print(f\"\\nFeature Statistics for Misclassified Samples:\")\n", 2031 | " misclass_features = test_features[misclassified_idx]\n", 2032 | "\n", 2033 | " for i, feat_name in enumerate(feature_columns):\n", 2034 | " if len(feature_columns) == 1:\n", 2035 | " feat_values = misclass_features\n", 2036 | " else:\n", 2037 | " feat_values = misclass_features[:, i]\n", 2038 | " print(f\" {feat_name}: mean={np.mean(feat_values):.2f}, \"\n", 2039 | " f\"std={np.std(feat_values):.2f}, \"\n", 2040 | " f\"range=[{np.min(feat_values):.2f}, {np.max(feat_values):.2f}]\")\n", 2041 | "\n", 2042 | "\n", 2043 | "def plot_confidence_distribution(y_true, y_pred, y_pred_probs):\n", 2044 | " \"\"\"\n", 2045 | " Plot confidence distributions for correct vs incorrect predictions.\n", 2046 | " \"\"\"\n", 2047 | " fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", 2048 | "\n", 2049 | " # Get confidences\n", 2050 | " confidences = np.max(y_pred_probs, axis=1)\n", 2051 | " correct_mask = (y_pred == y_true)\n", 2052 | "\n", 2053 | " correct_conf = confidences[correct_mask]\n", 2054 | " incorrect_conf = confidences[~correct_mask]\n", 2055 | "\n", 2056 | " # Plot 1: Confidence histograms\n", 2057 | " ax1 = axes[0]\n", 2058 | " ax1.hist(correct_conf, bins=20, alpha=0.7, label='Correct', color='green', edgecolor='black')\n", 2059 | " ax1.hist(incorrect_conf, bins=20, alpha=0.7, label='Incorrect', color='red', edgecolor='black')\n", 2060 | " ax1.set_xlabel('Prediction Confidence', fontweight='bold')\n", 2061 | " ax1.set_ylabel('Count', fontweight='bold')\n", 2062 | " ax1.set_title('Confidence Distribution: Correct vs Incorrect', fontweight='bold')\n", 2063 | " ax1.legend()\n", 2064 | " ax1.grid(alpha=0.3)\n", 2065 | "\n", 2066 | " # Plot 2: Confidence by class\n", 2067 | " ax2 = axes[1]\n", 2068 | "\n", 2069 | " # Liquid predictions\n", 2070 | " liquid_correct = confidences[(y_pred == 0) & correct_mask]\n", 2071 | " liquid_incorrect = confidences[(y_pred == 0) & ~correct_mask]\n", 2072 | "\n", 2073 | " # Solid predictions\n", 2074 | " solid_correct = confidences[(y_pred == 1) & correct_mask]\n", 2075 | " solid_incorrect = confidences[(y_pred == 1) & ~correct_mask]\n", 2076 | "\n", 2077 | " positions = [1, 2, 4, 5]\n", 2078 | " bp = ax2.boxplot([liquid_correct, liquid_incorrect, solid_correct, solid_incorrect],\n", 2079 | " positions=positions,\n", 2080 | " widths=0.6,\n", 2081 | " patch_artist=True,\n", 2082 | " labels=['Liquid\\nCorrect', 'Liquid\\nIncorrect',\n", 2083 | " 'Solid\\nCorrect', 'Solid\\nIncorrect'])\n", 2084 | "\n", 2085 | " # Color the boxes\n", 2086 | " colors = ['lightgreen', 'lightcoral', 'lightgreen', 'lightcoral']\n", 2087 | " for patch, color in zip(bp['boxes'], colors):\n", 2088 | " patch.set_facecolor(color)\n", 2089 | "\n", 2090 | " ax2.set_ylabel('Prediction Confidence', fontweight='bold')\n", 2091 | " ax2.set_title('Confidence by Class and Correctness', fontweight='bold')\n", 2092 | " ax2.grid(alpha=0.3, axis='y')\n", 2093 | " ax2.set_ylim([0, 1])\n", 2094 | "\n", 2095 | " plt.tight_layout()\n", 2096 | " plt.savefig('confidence_analysis.png', dpi=150, bbox_inches='tight')\n", 2097 | " plt.show()\n", 2098 | "\n", 2099 | "\n", 2100 | "def plot_feature_distributions_by_correctness(y_true, y_pred, test_features, feature_columns):\n", 2101 | " \"\"\"\n", 
2102 | " Compare feature distributions between correctly and incorrectly classified samples.\n", 2103 | " \"\"\"\n", 2104 | " correct_mask = (y_pred == y_true)\n", 2105 | "\n", 2106 | " n_features = len(feature_columns) if len(test_features.shape) > 1 else 1\n", 2107 | "\n", 2108 | " if n_features == 1:\n", 2109 | " feature_columns_list = feature_columns\n", 2110 | " test_features_array = test_features.reshape(-1, 1)\n", 2111 | " else:\n", 2112 | " feature_columns_list = feature_columns\n", 2113 | " test_features_array = test_features\n", 2114 | "\n", 2115 | " fig, axes = plt.subplots(1, n_features, figsize=(6*n_features, 5))\n", 2116 | "\n", 2117 | " if n_features == 1:\n", 2118 | " axes = [axes]\n", 2119 | "\n", 2120 | " for i, (ax, feat_name) in enumerate(zip(axes, feature_columns_list)):\n", 2121 | " feat_values = test_features_array[:, i]\n", 2122 | "\n", 2123 | " correct_vals = feat_values[correct_mask]\n", 2124 | " incorrect_vals = feat_values[~correct_mask]\n", 2125 | "\n", 2126 | " ax.hist(correct_vals, bins=20, alpha=0.6, label='Correct', color='green', edgecolor='black')\n", 2127 | " ax.hist(incorrect_vals, bins=20, alpha=0.6, label='Incorrect', color='red', edgecolor='black')\n", 2128 | "\n", 2129 | " ax.set_xlabel(feat_name, fontweight='bold')\n", 2130 | " ax.set_ylabel('Count', fontweight='bold')\n", 2131 | " ax.set_title(f'{feat_name} Distribution', fontweight='bold')\n", 2132 | " ax.legend()\n", 2133 | " ax.grid(alpha=0.3, axis='y')\n", 2134 | "\n", 2135 | " plt.suptitle('Feature Distributions: Correct vs Incorrect Predictions',\n", 2136 | " fontsize=14, fontweight='bold')\n", 2137 | " plt.tight_layout()\n", 2138 | " plt.savefig('feature_distributions.png', dpi=150, bbox_inches='tight')\n", 2139 | " plt.show()\n", 2140 | "\n", 2141 | "\n", 2142 | "def analyze_low_confidence_predictions(y_true, y_pred, y_pred_probs, test_images,\n", 2143 | " test_features, feature_columns, threshold=0.7):\n", 2144 | " \"\"\"\n", 2145 | " Analyze predictions with low confidence (potential ambiguous cases).\n", 2146 | " \"\"\"\n", 2147 | " confidences = np.max(y_pred_probs, axis=1)\n", 2148 | " low_conf_idx = np.where(confidences < threshold)[0]\n", 2149 | "\n", 2150 | " print(\"=\" * 80)\n", 2151 | " print(f\"LOW CONFIDENCE PREDICTIONS (confidence < {threshold})\")\n", 2152 | " print(\"=\" * 80)\n", 2153 | " print(f\"Total low confidence predictions: {len(low_conf_idx)} ({len(low_conf_idx)/len(y_true)*100:.2f}%)\")\n", 2154 | "\n", 2155 | " if len(low_conf_idx) > 0:\n", 2156 | " # Check how many are correct vs incorrect\n", 2157 | " low_conf_correct = np.sum(y_pred[low_conf_idx] == y_true[low_conf_idx])\n", 2158 | " low_conf_incorrect = len(low_conf_idx) - low_conf_correct\n", 2159 | "\n", 2160 | " print(f\" Correct: {low_conf_correct} ({low_conf_correct/len(low_conf_idx)*100:.1f}%)\")\n", 2161 | " print(f\" Incorrect: {low_conf_incorrect} ({low_conf_incorrect/len(low_conf_idx)*100:.1f}%)\")\n", 2162 | " print()\n", 2163 | " print(\"These are ambiguous cases where the model is uncertain.\")\n", 2164 | "\n", 2165 | "\n", 2166 | "# ==========================================\n", 2167 | "# RUN ALL ANALYSES\n", 2168 | "# ==========================================\n", 2169 | "\n", 2170 | "print(\"\\n\" + \"=\" * 80)\n", 2171 | "print(\"COMPREHENSIVE MISCLASSIFICATION ANALYSIS\")\n", 2172 | "print(\"=\" * 80 + \"\\n\")\n", 2173 | "\n", 2174 | "# 1. 
Plot misclassified samples\n", 2175 | "plot_misclassified_particles(y_true, y_pred, y_pred_probs, test_images,\n", 2176 | " test_features, feature_columns, max_display=16)\n", 2177 | "\n", 2178 | "# 2. Confidence analysis\n", 2179 | "print(\"\\n\")\n", 2180 | "plot_confidence_distribution(y_true, y_pred, y_pred_probs)\n", 2181 | "\n", 2182 | "# 3. Feature distribution analysis\n", 2183 | "print(\"\\n\")\n", 2184 | "plot_feature_distributions_by_correctness(y_true, y_pred, test_features, feature_columns)\n", 2185 | "\n", 2186 | "# 4. Low confidence analysis\n", 2187 | "print(\"\\n\")\n", 2188 | "analyze_low_confidence_predictions(y_true, y_pred, y_pred_probs, test_images,\n", 2189 | " test_features, feature_columns, threshold=0.7)\n", 2190 | "\n", 2191 | "print(\"\\n\" + \"=\" * 80)\n", 2192 | "print(\"Analysis complete! Check the saved images:\")\n", 2193 | "print(\" - misclassified_particles.png\")\n", 2194 | "print(\" - confidence_analysis.png\")\n", 2195 | "print(\" - feature_distributions.png\")\n", 2196 | "print(\"=\" * 80)" 2197 | ] 2198 | }, 2199 | { 2200 | "cell_type": "code", 2201 | "execution_count": null, 2202 | "metadata": { 2203 | "colab": { 2204 | "base_uri": "https://localhost:8080/" 2205 | }, 2206 | "id": "JoiW8loeGumi", 2207 | "outputId": "6c1d353f-ae83-4ad0-a3bb-abb1abc62200" 2208 | }, 2209 | "outputs": [], 2210 | "source": [ 2211 | "# Make predictions on test set\n", 2212 | "y_pred_probs = best_model.predict(test_ds, verbose=1)\n", 2213 | "y_pred = np.argmax(y_pred_probs, axis=1)\n", 2214 | "y_true = np.argmax(y_test, axis=1)\n", 2215 | "\n", 2216 | "print(f\"Predictions shape: {y_pred_probs.shape}\")" 2217 | ] 2218 | }, 2219 | { 2220 | "cell_type": "markdown", 2221 | "metadata": { 2222 | "id": "E8zyh01NGumi" 2223 | }, 2224 | "source": [ 2225 | "## 13. 
Generate Predictions and Metrics" 2226 | ] 2227 | }, 2228 | { 2229 | "cell_type": "code", 2230 | "execution_count": null, 2231 | "metadata": { 2232 | "colab": { 2233 | "base_uri": "https://localhost:8080/" 2234 | }, 2235 | "id": "PWfXvLX0Gumi", 2236 | "outputId": "a30bfa7a-8ac1-4aa3-85c4-b7c59bab9cab" 2237 | }, 2238 | "outputs": [], 2239 | "source": [ 2240 | "# Calculate metrics\n", 2241 | "accuracy = accuracy_score(y_true, y_pred)\n", 2242 | "print(f\"Test Accuracy: {accuracy:.4f}\\n\")\n", 2243 | "print(\"Classification Report:\")\n", 2244 | "print(classification_report(y_true, y_pred, target_names=['Liquid', 'Solid']))" 2245 | ] 2246 | }, 2247 | { 2248 | "cell_type": "code", 2249 | "execution_count": null, 2250 | "metadata": { 2251 | "colab": { 2252 | "base_uri": "https://localhost:8080/", 2253 | "height": 543 2254 | }, 2255 | "id": "yI0sq1XeGumi", 2256 | "outputId": "fe66609f-d2f1-405f-fcd8-817baee582f2" 2257 | }, 2258 | "outputs": [], 2259 | "source": [ 2260 | "# Confusion matrix\n", 2261 | "from sklearn.metrics import ConfusionMatrixDisplay\n", 2262 | "\n", 2263 | "cm = confusion_matrix(y_true, y_pred)\n", 2264 | "disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Liquid', 'Solid'])\n", 2265 | "disp.plot(cmap='Blues', values_format='d')\n", 2266 | "plt.title('Confusion Matrix (Hybrid 2-Class CNN)')\n", 2267 | "plt.show()\n", 2268 | "\n", 2269 | "print(\"\\nConfusion Matrix Analysis:\")\n", 2270 | "for i, label in enumerate(['Liquid', 'Solid']):\n", 2271 | " correct = cm[i, i]\n", 2272 | " total = cm[i, :].sum()\n", 2273 | " print(f\"{label}: {correct}/{total} correct ({correct/total*100:.1f}%)\")" 2274 | ] 2275 | }, 2276 | { 2277 | "cell_type": "markdown", 2278 | "metadata": { 2279 | "id": "9jao2sN3Gumi" 2280 | }, 2281 | "source": [ 2282 | "## 14. Visualize Predictions" 2283 | ] 2284 | }, 2285 | { 2286 | "cell_type": "markdown", 2287 | "metadata": { 2288 | "id": "bcUg2nLIGumj" 2289 | }, 2290 | "source": [ 2291 | "## 15. 
Feature Importance Analysis" 2292 | ] 2293 | }, 2294 | { 2295 | "cell_type": "code", 2296 | "execution_count": null, 2297 | "metadata": { 2298 | "id": "ZA8S_f-13Y08" 2299 | }, 2300 | "outputs": [], 2301 | "source": [ 2302 | "def create_image_only_model():\n", 2303 | " # Image input branch (CNN)\n", 2304 | " image_input = Input(shape=(128, 128, 1), name='image_input')\n", 2305 | "\n", 2306 | " # Reduce initial filters for simpler images\n", 2307 | " x = Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same',\n", 2308 | " kernel_regularizer=regularizers.l2(L2_REG))(image_input)\n", 2309 | " x = BatchNormalization()(x)\n", 2310 | " x = MaxPooling2D(pool_size=(2, 2))(x)\n", 2311 | " x = SpatialDropout2D(0.2)(x) # Spatial dropout for conv layers\n", 2312 | "\n", 2313 | " x = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same',\n", 2314 | " kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 2315 | " x = BatchNormalization()(x)\n", 2316 | " x = MaxPooling2D(pool_size=(2, 2))(x)\n", 2317 | " x = SpatialDropout2D(0.2)(x)\n", 2318 | "\n", 2319 | " x = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same',\n", 2320 | " kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 2321 | " x = BatchNormalization()(x)\n", 2322 | " x = MaxPooling2D(pool_size=(2, 2))(x)\n", 2323 | " x = SpatialDropout2D(0.3)(x)\n", 2324 | "\n", 2325 | " # Flatten the CNN output\n", 2326 | " x = Flatten()(x)\n", 2327 | " x = Dropout(0.5)(x)\n", 2328 | " cnn_output = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 2329 | "\n", 2330 | " # --- Final classification layers (Now connect directly to cnn_output) ---\n", 2331 | "\n", 2332 | " z = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(cnn_output)\n", 2333 | " z = Dropout(0.5)(z)\n", 2334 | " z = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(z)\n", 2335 | " z = Dropout(0.3)(z)\n", 2336 | "\n", 2337 | " # Output layer\n", 2338 | " output = Dense(2, activation='softmax', name='output')(z)\n", 2339 | "\n", 2340 | " model = Model(inputs=image_input, outputs=output, name='Image_Only_CNN')\n", 2341 | " return model" 2342 | ] 2343 | }, 2344 | { 2345 | "cell_type": "code", 2346 | "execution_count": null, 2347 | "metadata": { 2348 | "id": "dfBVb7NY3YId" 2349 | }, 2350 | "outputs": [], 2351 | "source": [ 2352 | "def create_feature_only_model(len_feature_columns):\n", 2353 | "\n", 2354 | " # Environmental features input branch\n", 2355 | " feature_input = Input(shape=(len_feature_columns,), name='feature_input')\n", 2356 | "\n", 2357 | " # Process environmental features (Identical to Hybrid Model)\n", 2358 | " f = Dense(32, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(feature_input)\n", 2359 | " f = Dropout(0.3)(f)\n", 2360 | " feature_output = Dense(16, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(f)\n", 2361 | "\n", 2362 | " # --- Final classification layers (Now connect directly to feature_output) ---\n", 2363 | " z = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(feature_output)\n", 2364 | " z = Dropout(0.5)(z)\n", 2365 | " z = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(z)\n", 2366 | " z = Dropout(0.3)(z)\n", 2367 | "\n", 2368 | " # Output layer\n", 2369 | " output = Dense(2, activation='softmax', name='output')(z)\n", 2370 | "\n", 2371 | " model = Model(inputs=feature_input, outputs=output, name='Feature_Only_Dense')\n", 2372 | " return model" 2373 | ] 2374 | }, 2375 | { 2376 | 
"cell_type": "code", 2377 | "execution_count": null, 2378 | "metadata": { 2379 | "id": "AI7O-1Qa6nMw" 2380 | }, 2381 | "outputs": [], 2382 | "source": [ 2383 | "def load_and_augment_image_only(\n", 2384 | " image_path: str, label: tf.Tensor, augment: bool\n", 2385 | ") -> Tuple[tf.Tensor, tf.Tensor]:\n", 2386 | " \"\"\"\n", 2387 | " IMAGE-ONLY FUNCTION: Loads an image, preprocesses it, applies conditional augmentation,\n", 2388 | " and returns the structured input and label for the image-only model: (image, label).\n", 2389 | " \"\"\"\n", 2390 | "\n", 2391 | " # 1. Image Loading and Initial Processing\n", 2392 | " image = tf.io.read_file(filename=image_path)\n", 2393 | " image = tf.image.decode_png(contents=image, channels=1)\n", 2394 | " image = tf.image.convert_image_dtype(image=image, dtype=tf.float32)\n", 2395 | " image = tf.image.resize(images=image, size=IMAGE_SIZE)\n", 2396 | "\n", 2397 | " # 2. Apply Augmentation (Only if 'augment' is True and based on probability)\n", 2398 | " if augment:\n", 2399 | " if tf.random.uniform(()) < AUGMENTATION_PROB:\n", 2400 | " # 1. Random horizontal flipping\n", 2401 | " image = tf.image.flip_left_right(image)\n", 2402 | " if tf.random.uniform(()) < AUGMENTATION_PROB:\n", 2403 | " # 2. Geometric augmentation (your custom function)\n", 2404 | " image = random_geometric_augment(image)\n", 2405 | "\n", 2406 | " # 3. Final Normalization/Clipping\n", 2407 | " image = tf.clip_by_value(image, clip_value_min=0.0, clip_value_max=1.0)\n", 2408 | "\n", 2409 | " # 4. Return in the format Keras model.fit expects: ( image_input, label )\n", 2410 | " return image, label\n", 2411 | "\n", 2412 | "\n", 2413 | "def create_feature_only_ds(features: tf.Tensor, labels: tf.Tensor) -> tf.data.Dataset:\n", 2414 | " \"\"\"\n", 2415 | " FEATURE-ONLY FUNCTION: Creates a dataset mapping only feature vectors and labels\n", 2416 | " directly from tensor slices.\n", 2417 | " \"\"\"\n", 2418 | " # Features are typically already preprocessed NumPy arrays (tensors)\n", 2419 | " ds = tf.data.Dataset.from_tensor_slices((features, labels))\n", 2420 | " # No map function is needed as no file loading or augmentation is performed\n", 2421 | " return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)" 2422 | ] 2423 | }, 2424 | { 2425 | "cell_type": "code", 2426 | "execution_count": null, 2427 | "metadata": { 2428 | "id": "acdnpNMq7GRs" 2429 | }, 2430 | "outputs": [], 2431 | "source": [ 2432 | "def create_image_only_ds(paths, labels, augment=False):\n", 2433 | " \"\"\"\n", 2434 | " Creates dataset mapping only image paths and labels, using the dedicated\n", 2435 | " load_and_augment_image_only function.\n", 2436 | " \"\"\"\n", 2437 | " ds = tf.data.Dataset.from_tensor_slices((paths, labels))\n", 2438 | "\n", 2439 | " # Assuming 'load_and_augment_image_only' is available in the environment\n", 2440 | " ds = ds.map(\n", 2441 | " lambda p, l: load_and_augment_image_only(p, l, augment=augment),\n", 2442 | " num_parallel_calls=tf.data.AUTOTUNE\n", 2443 | " )\n", 2444 | "\n", 2445 | " return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n", 2446 | "\n", 2447 | "def create_feature_only_ds(features, labels):\n", 2448 | " \"\"\"Creates dataset mapping only feature vectors and labels.\"\"\"\n", 2449 | " ds = tf.data.Dataset.from_tensor_slices((features, labels))\n", 2450 | " return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n" 2451 | ] 2452 | }, 2453 | { 2454 | "cell_type": "code", 2455 | "execution_count": null, 2456 | "metadata": { 2457 | "id": "GPrgKzLb4cfw" 2458 | }, 2459 | "outputs": [], 
2460 | "source": [ 2461 | "\n", 2462 | "# 1. Load/Define the Models\n", 2463 | "\n", 2464 | "# Baseline Model (Best Hybrid Model)\n", 2465 | "try:\n", 2466 | " baseline_model = load_model('best_particle_classifier_hybrid_binary.keras')\n", 2467 | "except Exception as e:\n", 2468 | " print(f\"Error loading baseline model: {e}\")\n", 2469 | " exit()\n", 2470 | "\n", 2471 | "# Image-Only Model (Need to define/import create_image_only_model)\n", 2472 | "image_only_model = create_image_only_model()\n", 2473 | "compile_model(image_only_model)\n", 2474 | "\n", 2475 | "# Feature-Only Model (Need to define/import create_feature_only_model)\n", 2476 | "feature_only_model = create_feature_only_model(len(feature_columns))\n", 2477 | "compile_model(feature_only_model)" 2478 | ] 2479 | }, 2480 | { 2481 | "cell_type": "code", 2482 | "execution_count": null, 2483 | "metadata": { 2484 | "id": "1CVydM60dwpP" 2485 | }, 2486 | "outputs": [], 2487 | "source": [ 2488 | "def plot_training_history(history):\n", 2489 | " # Plot training history\n", 2490 | " fig, axes = plt.subplots(1, 2, figsize=(15, 5))\n", 2491 | "\n", 2492 | " # Plot accuracy\n", 2493 | " axes[0].plot(history.history['accuracy'], label='Training Accuracy')\n", 2494 | " axes[0].plot(history.history['val_accuracy'], label='Validation Accuracy')\n", 2495 | " axes[0].set_title('Model Accuracy (Hybrid CNN)')\n", 2496 | " axes[0].set_ylabel('Accuracy')\n", 2497 | " axes[0].set_xlabel('Epoch')\n", 2498 | " axes[0].legend()\n", 2499 | " axes[0].grid(True)\n", 2500 | "\n", 2501 | " # Plot loss\n", 2502 | " axes[1].plot(history.history['loss'], label='Training Loss')\n", 2503 | " axes[1].plot(history.history['val_loss'], label='Validation Loss')\n", 2504 | " axes[1].set_title('Model Loss (Hybrid CNN)')\n", 2505 | " axes[1].set_ylabel('Loss')\n", 2506 | " axes[1].set_xlabel('Epoch')\n", 2507 | " axes[1].legend()\n", 2508 | " axes[1].grid(True)\n", 2509 | "\n", 2510 | " plt.tight_layout()\n", 2511 | " plt.show()\n", 2512 | "\n", 2513 | " # Print final metrics\n", 2514 | " print(f\"\\nFinal Training Accuracy: {history.history['accuracy'][-1]:.4f}\")\n", 2515 | " print(f\"Final Validation Accuracy: {history.history['val_accuracy'][-1]:.4f}\")\n", 2516 | " print(f\"Best Validation Accuracy: {max(history.history['val_accuracy']):.4f}\")" 2517 | ] 2518 | }, 2519 | { 2520 | "cell_type": "code", 2521 | "execution_count": null, 2522 | "metadata": { 2523 | "colab": { 2524 | "base_uri": "https://localhost:8080/" 2525 | }, 2526 | "id": "GcilOAcB3c-x", 2527 | "outputId": "bb7d7447-7071-4cb0-d07c-0162bd3a0058" 2528 | }, 2529 | "outputs": [], 2530 | "source": [ 2531 | "print(\"\\n\" + \"=\" * 50)\n", 2532 | "print(\"STARTING IMAGE-ONLY MODEL TRAINING\")\n", 2533 | "print(\"=\" * 50)\n", 2534 | "N_EPOCHS=40\n", 2535 | "\n", 2536 | "# Create datasets for image-only training\n", 2537 | "train_ds_img = create_image_only_ds(X_paths_train, y_train, augment=True)\n", 2538 | "val_ds_img = create_image_only_ds(X_paths_val, y_val, augment=False)\n", 2539 | "\n", 2540 | "steps_per_epoch_train = N_TRAIN_SAMPLES // BATCH_SIZE\n", 2541 | "validation_steps_val = N_VAL_SAMPLES // BATCH_SIZE\n", 2542 | "\n", 2543 | "checkpoint_img = ModelCheckpoint('best_particle_classifier_image_only.keras', monitor='val_loss', mode='min', verbose=1, save_best_only=True)\n", 2544 | "\n", 2545 | "hist_image = image_only_model.fit(\n", 2546 | " train_ds_img,\n", 2547 | " steps_per_epoch=steps_per_epoch_train,\n", 2548 | " epochs=N_EPOCHS,\n", 2549 | " validation_data=val_ds_img,\n", 2550 | " 
validation_steps=validation_steps_val,\n", 2551 | " callbacks=[early_stop, checkpoint_img],\n", 2552 | " verbose=1\n", 2553 | ")" 2554 | ] 2555 | }, 2556 | { 2557 | "cell_type": "code", 2558 | "execution_count": null, 2559 | "metadata": { 2560 | "colab": { 2561 | "base_uri": "https://localhost:8080/", 2562 | "height": 578 2563 | }, 2564 | "id": "SkNbj-3Jt8IQ", 2565 | "outputId": "a67ee42e-dfdd-4a57-964e-55bffcf875d2" 2566 | }, 2567 | "outputs": [], 2568 | "source": [ 2569 | "plot_training_history(hist_image)" 2570 | ] 2571 | }, 2572 | { 2573 | "cell_type": "code", 2574 | "execution_count": null, 2575 | "metadata": { 2576 | "colab": { 2577 | "base_uri": "https://localhost:8080/" 2578 | }, 2579 | "id": "Nggc_PHl5k8a", 2580 | "outputId": "bc563508-e5c1-43bc-a3d7-cf4d6c2950b1" 2581 | }, 2582 | "outputs": [], 2583 | "source": [ 2584 | "print(\"\\n\" + \"=\" * 50)\n", 2585 | "print(\"STARTING FEATURE-ONLY MODEL TRAINING\")\n", 2586 | "print(\"=\" * 50)\n", 2587 | "\n", 2588 | "# Create datasets for feature-only training\n", 2589 | "train_ds_feat = create_feature_only_ds(X_features_train, y_train)\n", 2590 | "val_ds_feat = create_feature_only_ds(X_features_val, y_val)\n", 2591 | "\n", 2592 | "checkpoint_feat = ModelCheckpoint('best_particle_classifier_feature_only.keras', monitor='val_loss', mode='min', verbose=1, save_best_only=True)\n", 2593 | "\n", 2594 | "hist_feat = feature_only_model.fit(\n", 2595 | " train_ds_feat,\n", 2596 | " steps_per_epoch=steps_per_epoch_train, # Reuse step count for simplicity, assuming same size\n", 2597 | " epochs=N_EPOCHS,\n", 2598 | " validation_data=val_ds_feat,\n", 2599 | " validation_steps=validation_steps_val,\n", 2600 | " callbacks=[early_stop, checkpoint_feat],\n", 2601 | " verbose=1\n", 2602 | " )" 2603 | ] 2604 | }, 2605 | { 2606 | "cell_type": "code", 2607 | "execution_count": null, 2608 | "metadata": { 2609 | "colab": { 2610 | "base_uri": "https://localhost:8080/", 2611 | "height": 576 2612 | }, 2613 | "id": "iblthy_6ekKP", 2614 | "outputId": "c2dbdde0-5c5a-47b0-f7b3-30b0d266aabf" 2615 | }, 2616 | "outputs": [], 2617 | "source": [ 2618 | "plot_training_history(hist_feat)" 2619 | ] 2620 | }, 2621 | { 2622 | "cell_type": "code", 2623 | "execution_count": null, 2624 | "metadata": { 2625 | "colab": { 2626 | "base_uri": "https://localhost:8080/", 2627 | "height": 1000 2628 | }, 2629 | "id": "_FwiZvqrg8n6", 2630 | "outputId": "585acc0a-8e2d-4801-eb7d-e79af6f0d4f4" 2631 | }, 2632 | "outputs": [], 2633 | "source": [ 2634 | "import tensorflow as tf\n", 2635 | "import pandas as pd\n", 2636 | "import numpy as np\n", 2637 | "from tensorflow.keras.models import Model\n", 2638 | "from tensorflow.keras.layers import Input, Dense, Dropout, Concatenate, Conv2D, MaxPooling2D, BatchNormalization, Flatten\n", 2639 | "from tensorflow.keras.optimizers import Adam\n", 2640 | "from tensorflow.keras import regularizers\n", 2641 | "from tensorflow.keras.callbacks import EarlyStopping\n", 2642 | "\n", 2643 | "# Your configuration\n", 2644 | "BATCH_SIZE = 16\n", 2645 | "IMAGE_SIZE = (128, 128)\n", 2646 | "L2_REG = 0.0001\n", 2647 | "\n", 2648 | "def build_cnn_branch():\n", 2649 | " \"\"\"Build the CNN branch of your hybrid model.\"\"\"\n", 2650 | " image_input = Input(shape=(128, 128, 1), name='image_input')\n", 2651 | "\n", 2652 | " x = Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same')(image_input)\n", 2653 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 2654 | "\n", 2655 | " x = Conv2D(64, kernel_size=(3, 3), activation='relu', 
padding='same')(x)\n", 2656 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 2657 | "\n", 2658 | " x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(x)\n", 2659 | " x = BatchNormalization()(x)\n", 2660 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 2661 | "\n", 2662 | " x = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same')(x)\n", 2663 | " x = BatchNormalization()(x)\n", 2664 | " x = MaxPooling2D(pool_size=(2, 2), padding='same')(x)\n", 2665 | "\n", 2666 | " x = Flatten()(x)\n", 2667 | " x = Dropout(0.4)(x)\n", 2668 | " cnn_output = Dense(128, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(x)\n", 2669 | "\n", 2670 | " return Model(inputs=image_input, outputs=cnn_output, name='cnn_branch')\n", 2671 | "\n", 2672 | "\n", 2673 | "def build_hybrid_model(num_features):\n", 2674 | " \"\"\"Build hybrid model with specified number of features.\"\"\"\n", 2675 | " # Image branch\n", 2676 | " image_input = Input(shape=(128, 128, 1), name='image_input')\n", 2677 | " cnn_branch = build_cnn_branch()\n", 2678 | " cnn_output = cnn_branch(image_input)\n", 2679 | "\n", 2680 | " # Feature branch\n", 2681 | " feature_input = Input(shape=(num_features,), name='feature_input')\n", 2682 | " f = Dense(32, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(feature_input)\n", 2683 | " f = Dropout(0.3)(f)\n", 2684 | " feature_output = Dense(16, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(f)\n", 2685 | "\n", 2686 | " # Combine\n", 2687 | " combined = Concatenate()([cnn_output, feature_output])\n", 2688 | "\n", 2689 | " # Final layers\n", 2690 | " z = Dense(64, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(combined)\n", 2691 | " z = Dropout(0.5)(z)\n", 2692 | " z = Dense(32, activation='relu', kernel_regularizer=regularizers.l2(L2_REG))(z)\n", 2693 | " z = Dropout(0.3)(z)\n", 2694 | "\n", 2695 | " output = Dense(2, activation='softmax', name='output')(z)\n", 2696 | "\n", 2697 | " model = Model(inputs=[image_input, feature_input], outputs=output, name='Hybrid_CNN')\n", 2698 | "\n", 2699 | " model.compile(\n", 2700 | " optimizer=Adam(learning_rate=0.00005),\n", 2701 | " loss='categorical_crossentropy',\n", 2702 | " metrics=['accuracy']\n", 2703 | " )\n", 2704 | "\n", 2705 | " return model\n", 2706 | "\n", 2707 | "\n", 2708 | "def create_dataset_with_feature_subset(X_paths, X_features, y, feature_indices, augment=False):\n", 2709 | " \"\"\"Create tf.data.Dataset with only selected features.\"\"\"\n", 2710 | " # Select only the specified feature columns\n", 2711 | " X_features_subset = X_features[:, feature_indices]\n", 2712 | "\n", 2713 | " # Create dataset\n", 2714 | " ds = tf.data.Dataset.from_tensor_slices((X_paths, X_features_subset, y))\n", 2715 | "\n", 2716 | " # Map with augmentation\n", 2717 | " ds = ds.map(\n", 2718 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=augment),\n", 2719 | " num_parallel_calls=tf.data.AUTOTUNE\n", 2720 | " )\n", 2721 | "\n", 2722 | " # Cache, shuffle (if training), batch, prefetch\n", 2723 | " ds = ds.cache()\n", 2724 | " if augment: # Only shuffle training data\n", 2725 | " ds = ds.shuffle(buffer_size=len(X_paths))\n", 2726 | " ds = ds.batch(BATCH_SIZE)\n", 2727 | " ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)\n", 2728 | "\n", 2729 | " return ds\n", 2730 | "\n", 2731 | "\n", 2732 | "def feature_ablation_study(X_paths_train, X_features_train, y_train,\n", 2733 | " X_paths_val, X_features_val, y_val,\n", 2734 | " X_paths_test, X_features_test, 
y_test,\n", 2735 | " feature_columns):\n", 2736 | " \"\"\"\n", 2737 | " Run comprehensive feature ablation study.\n", 2738 | "\n", 2739 | " Returns:\n", 2740 | " DataFrame with results for each feature combination tested\n", 2741 | " \"\"\"\n", 2742 | " results = []\n", 2743 | " n_features_total = len(feature_columns)\n", 2744 | "\n", 2745 | " print(\"=\" * 80)\n", 2746 | " print(\"FEATURE ABLATION STUDY\")\n", 2747 | " print(\"=\" * 80)\n", 2748 | " print(f\"Total features: {n_features_total}\")\n", 2749 | " print(f\"Features: {feature_columns}\")\n", 2750 | " print()\n", 2751 | "\n", 2752 | " # ==========================================\n", 2753 | " # 1. BASELINE: All features\n", 2754 | " # ==========================================\n", 2755 | " print(\"=\" * 80)\n", 2756 | " print(\"1. BASELINE - ALL FEATURES\")\n", 2757 | " print(\"=\" * 80)\n", 2758 | "\n", 2759 | " # Use your existing datasets\n", 2760 | " model = build_hybrid_model(n_features_total)\n", 2761 | "\n", 2762 | " history = model.fit(\n", 2763 | " train_ds,\n", 2764 | " validation_data=val_ds,\n", 2765 | " epochs=30,\n", 2766 | " callbacks=[EarlyStopping(patience=10, restore_best_weights=True, verbose=1)],\n", 2767 | " verbose=1\n", 2768 | " )\n", 2769 | "\n", 2770 | " test_loss, test_acc = model.evaluate(test_ds, verbose=0)\n", 2771 | " baseline_acc = test_acc\n", 2772 | "\n", 2773 | " results.append({\n", 2774 | " 'experiment': 'ALL_FEATURES',\n", 2775 | " 'n_features': n_features_total,\n", 2776 | " 'features_used': ', '.join(feature_columns),\n", 2777 | " 'test_accuracy': test_acc,\n", 2778 | " 'test_loss': test_loss,\n", 2779 | " 'accuracy_vs_baseline': 0.0\n", 2780 | " })\n", 2781 | "\n", 2782 | " print(f\"✓ Baseline Test Accuracy: {test_acc:.4f}\\n\")\n", 2783 | "\n", 2784 | " # ==========================================\n", 2785 | " # 2. LEAVE-ONE-OUT: Remove each feature\n", 2786 | " # ==========================================\n", 2787 | " print(\"=\" * 80)\n", 2788 | " print(\"2. 
LEAVE-ONE-OUT ANALYSIS\")\n", 2789 | " print(\"=\" * 80)\n", 2790 | " print(\"Testing performance when each feature is removed...\\n\")\n", 2791 | "\n", 2792 | " for i, feature_to_remove in enumerate(feature_columns):\n", 2793 | " print(f\"[{i+1}/{n_features_total}] Removing: {feature_to_remove}\")\n", 2794 | "\n", 2795 | " # Get indices of features to keep\n", 2796 | " feature_indices = [j for j, f in enumerate(feature_columns) if f != feature_to_remove]\n", 2797 | " remaining_features = [f for f in feature_columns if f != feature_to_remove]\n", 2798 | "\n", 2799 | " # Create datasets with subset of features\n", 2800 | " train_ds_subset = create_dataset_with_feature_subset(\n", 2801 | " X_paths_train, X_features_train, y_train, feature_indices, augment=True\n", 2802 | " )\n", 2803 | " val_ds_subset = create_dataset_with_feature_subset(\n", 2804 | " X_paths_val, X_features_val, y_val, feature_indices, augment=False\n", 2805 | " )\n", 2806 | " test_ds_subset = create_dataset_with_feature_subset(\n", 2807 | " X_paths_test, X_features_test, y_test, feature_indices, augment=False\n", 2808 | " )\n", 2809 | "\n", 2810 | " # Build and train model\n", 2811 | " model = build_hybrid_model(len(remaining_features))\n", 2812 | "\n", 2813 | " history = model.fit(\n", 2814 | " train_ds_subset,\n", 2815 | " validation_data=val_ds_subset,\n", 2816 | " epochs=30,\n", 2817 | " callbacks=[EarlyStopping(patience=10, restore_best_weights=True, verbose=0)],\n", 2818 | " verbose=0\n", 2819 | " )\n", 2820 | "\n", 2821 | " test_loss, test_acc = model.evaluate(test_ds_subset, verbose=0)\n", 2822 | " accuracy_drop = baseline_acc - test_acc\n", 2823 | "\n", 2824 | " results.append({\n", 2825 | " 'experiment': f'REMOVE_{feature_to_remove}',\n", 2826 | " 'n_features': len(remaining_features),\n", 2827 | " 'features_used': ', '.join(remaining_features),\n", 2828 | " 'removed_feature': feature_to_remove,\n", 2829 | " 'test_accuracy': test_acc,\n", 2830 | " 'test_loss': test_loss,\n", 2831 | " 'accuracy_vs_baseline': accuracy_drop\n", 2832 | " })\n", 2833 | "\n", 2834 | " symbol = \"⚠️\" if accuracy_drop > 0.01 else \"✓\"\n", 2835 | " print(f\" {symbol} Accuracy: {test_acc:.4f} (Δ = {accuracy_drop:+.4f})\\n\")\n", 2836 | "\n", 2837 | " # ==========================================\n", 2838 | " # 3. SINGLE FEATURE: Use only one feature\n", 2839 | " # ==========================================\n", 2840 | " print(\"=\" * 80)\n", 2841 | " print(\"3. 
SINGLE FEATURE ANALYSIS\")\n", 2842 | " print(\"=\" * 80)\n", 2843 | " print(\"Testing performance with each feature individually...\\n\")\n", 2844 | "\n", 2845 | " for i, feature in enumerate(feature_columns):\n", 2846 | " print(f\"[{i+1}/{n_features_total}] Using only: {feature}\")\n", 2847 | "\n", 2848 | " # Get index of this feature\n", 2849 | " feature_idx = feature_columns.index(feature)\n", 2850 | " # Wrap in a list so the feature subset keeps its 2-D shape\n", 2851 | " feature_indices = [feature_idx]\n", 2852 | "\n", 2853 | " # Create datasets with single feature\n", 2854 | " train_ds_single = create_dataset_with_feature_subset(\n", 2855 | " X_paths_train, X_features_train, y_train, feature_indices, augment=True\n", 2856 | " )\n", 2857 | " val_ds_single = create_dataset_with_feature_subset(\n", 2858 | " X_paths_val, X_features_val, y_val, feature_indices, augment=False\n", 2859 | " )\n", 2860 | " test_ds_single = create_dataset_with_feature_subset(\n", 2861 | " X_paths_test, X_features_test, y_test, feature_indices, augment=False\n", 2862 | " )\n", 2863 | "\n", 2864 | " # Build and train model\n", 2865 | " model = build_hybrid_model(1)\n", 2866 | "\n", 2867 | " history = model.fit(\n", 2868 | " train_ds_single,\n", 2869 | " validation_data=val_ds_single,\n", 2870 | " epochs=30,\n", 2871 | " callbacks=[EarlyStopping(patience=10, restore_best_weights=True, verbose=0)],\n", 2872 | " verbose=0\n", 2873 | " )\n", 2874 | "\n", 2875 | " test_loss, test_acc = model.evaluate(test_ds_single, verbose=0)\n", 2876 | "\n", 2877 | " results.append({\n", 2878 | " 'experiment': f'ONLY_{feature}',\n", 2879 | " 'n_features': 1,\n", 2880 | " 'features_used': feature,\n", 2881 | " 'test_accuracy': test_acc,\n", 2882 | " 'test_loss': test_loss,\n", 2883 | " 'accuracy_vs_baseline': test_acc - baseline_acc\n", 2884 | " })\n", 2885 | "\n", 2886 | " print(f\" Accuracy: {test_acc:.4f} (vs baseline: {test_acc - baseline_acc:+.4f})\\n\")\n", 2887 | "\n", 2888 | " # ==========================================\n", 2889 | " # Create results DataFrame\n", 2890 | " # ==========================================\n", 2891 | " results_df = pd.DataFrame(results)\n", 2892 | "\n", 2893 | " return results_df, baseline_acc\n", 2894 | "\n", 2895 | "\n", 2896 | "# ==========================================\n", 2897 | "# RUN THE STUDY\n", 2898 | "# ==========================================\n", 2899 | "\n", 2900 | "# Make sure you have feature_columns defined\n", 2901 | "# feature_columns = ['temperature', 'pressure', 'altitude', etc.]\n", 2902 | "\n", 2903 | "results_df, baseline_acc = feature_ablation_study(\n", 2904 | " X_paths_train, X_features_train, y_train,\n", 2905 | " X_paths_val, X_features_val, y_val,\n", 2906 | " X_paths_test, X_features_test, y_test,\n", 2907 | " feature_columns\n", 2908 | ")\n", 2909 | "\n", 2910 | "# Save results\n", 2911 | "results_df.to_csv('feature_ablation_results.csv', index=False)\n", 2912 | "print(\"\\n✓ Results saved to 'feature_ablation_results.csv'\")\n", 2913 | "\n", 2914 | "# ==========================================\n", 2915 | "# ANALYSIS AND VISUALIZATION\n", 2916 | "# ==========================================\n", 2917 | "\n", 2918 | "print(\"\\n\" + \"=\" * 80)\n", 2919 | "print(\"ABLATION STUDY RESULTS\")\n", 2920 | "print(\"=\" * 80)\n", 2921 | "\n", 2922 | "# Feature importance from leave-one-out\n", 2923 | "print(\"\\n📊 FEATURE IMPORTANCE (Leave-One-Out)\")\n", 2924 | "print(\"-\" * 80)\n", 2925 | "leave_one_out = results_df[results_df['experiment'].str.startswith('REMOVE_')].copy()\n", 2926 | "leave_one_out = 
leave_one_out.sort_values('accuracy_vs_baseline', ascending=False)\n", 2927 | "\n", 2928 | "print(\"\\nMost Important Features (largest accuracy drop when removed):\")\n", 2929 | "print(leave_one_out[['removed_feature', 'test_accuracy', 'accuracy_vs_baseline']].to_string(index=False))\n", 2930 | "\n", 2931 | "# Single feature performance\n", 2932 | "print(\"\\n📊 SINGLE FEATURE PERFORMANCE\")\n", 2933 | "print(\"-\" * 80)\n", 2934 | "single_features = results_df[results_df['experiment'].str.startswith('ONLY_')].copy()\n", 2935 | "single_features = single_features.sort_values('test_accuracy', ascending=False)\n", 2936 | "\n", 2937 | "print(\"\\nBest Individual Features:\")\n", 2938 | "print(single_features[['features_used', 'test_accuracy', 'accuracy_vs_baseline']].to_string(index=False))\n", 2939 | "\n", 2940 | "# Summary statistics\n", 2941 | "print(\"\\n📊 SUMMARY STATISTICS\")\n", 2942 | "print(\"-\" * 80)\n", 2943 | "print(f\"Baseline (all features): {baseline_acc:.4f}\")\n", 2944 | "print(f\"Best single feature: {single_features.iloc[0]['features_used']} ({single_features.iloc[0]['test_accuracy']:.4f})\")\n", 2945 | "print(f\"Worst single feature: {single_features.iloc[-1]['features_used']} ({single_features.iloc[-1]['test_accuracy']:.4f})\")\n", 2946 | "print(f\"Most critical feature: {leave_one_out.iloc[0]['removed_feature']} (drops {leave_one_out.iloc[0]['accuracy_vs_baseline']:.4f})\")\n", 2947 | "print(f\"Least critical feature: {leave_one_out.iloc[-1]['removed_feature']} (drops {leave_one_out.iloc[-1]['accuracy_vs_baseline']:.4f})\")\n", 2948 | "\n", 2949 | "# ==========================================\n", 2950 | "# VISUALIZATIONS\n", 2951 | "# ==========================================\n", 2952 | "\n", 2953 | "import matplotlib.pyplot as plt\n", 2954 | "import seaborn as sns\n", 2955 | "\n", 2956 | "sns.set_style(\"whitegrid\")\n", 2957 | "\n", 2958 | "fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", 2959 | "\n", 2960 | "# Plot 1: Leave-one-out importance\n", 2961 | "ax1 = axes[0, 0]\n", 2962 | "leave_one_out_plot = leave_one_out.sort_values('accuracy_vs_baseline')\n", 2963 | "colors = ['red' if x > 0.01 else 'steelblue' for x in leave_one_out_plot['accuracy_vs_baseline']]\n", 2964 | "ax1.barh(leave_one_out_plot['removed_feature'], leave_one_out_plot['accuracy_vs_baseline'], color=colors)\n", 2965 | "ax1.set_xlabel('Accuracy Drop When Removed', fontweight='bold')\n", 2966 | "ax1.set_title('Feature Importance (Leave-One-Out)\\nRed = Critical Features', fontweight='bold', fontsize=12)\n", 2967 | "ax1.axvline(x=0, color='black', linestyle='--', linewidth=0.8)\n", 2968 | "ax1.grid(axis='x', alpha=0.3)\n", 2969 | "\n", 2970 | "# Plot 2: Single feature performance\n", 2971 | "ax2 = axes[0, 1]\n", 2972 | "single_features_plot = single_features.sort_values('test_accuracy')\n", 2973 | "ax2.barh(single_features_plot['features_used'], single_features_plot['test_accuracy'], color='skyblue')\n", 2974 | "ax2.axvline(x=baseline_acc, color='red', linestyle='--', linewidth=2, label=f'All Features ({baseline_acc:.3f})')\n", 2975 | "ax2.set_xlabel('Test Accuracy', fontweight='bold')\n", 2976 | "ax2.set_title('Individual Feature Performance\\n(Image + Single Feature)', fontweight='bold', fontsize=12)\n", 2977 | "ax2.legend()\n", 2978 | "ax2.grid(axis='x', alpha=0.3)\n", 2979 | "\n", 2980 | "# Plot 3: Accuracy vs baseline comparison\n", 2981 | "ax3 = axes[1, 0]\n", 2982 | "all_experiments = results_df[results_df['experiment'] != 'ALL_FEATURES'].copy()\n", 2983 | "all_experiments = 
all_experiments.sort_values('test_accuracy', ascending=False).head(10)\n", 2984 | "colors_acc = ['green' if x >= baseline_acc else 'orange' for x in all_experiments['test_accuracy']]\n", 2985 | "ax3.barh(range(len(all_experiments)), all_experiments['test_accuracy'], color=colors_acc)\n", 2986 | "ax3.set_yticks(range(len(all_experiments)))\n", 2987 | "ax3.set_yticklabels([exp.replace('ONLY_', '').replace('REMOVE_', 'No ') for exp in all_experiments['experiment']], fontsize=9)\n", 2988 | "ax3.axvline(x=baseline_acc, color='red', linestyle='--', linewidth=2, label='Baseline')\n", 2989 | "ax3.set_xlabel('Test Accuracy', fontweight='bold')\n", 2990 | "ax3.set_title('Top 10 Feature Combinations', fontweight='bold', fontsize=12)\n", 2991 | "ax3.legend()\n", 2992 | "ax3.grid(axis='x', alpha=0.3)\n", 2993 | "\n", 2994 | "# Plot 4: Feature correlation heatmap (if you want to add it)\n", 2995 | "ax4 = axes[1, 1]\n", 2996 | "# Calculate correlation between accuracy drop and single feature performance\n", 2997 | "feature_comparison = leave_one_out.merge(\n", 2998 | " single_features[['features_used', 'test_accuracy']],\n", 2999 | " left_on='removed_feature',\n", 3000 | " right_on='features_used',\n", 3001 | " suffixes=('_removed', '_single')\n", 3002 | ")\n", 3003 | "if len(feature_comparison) > 0:\n", 3004 | " ax4.scatter(feature_comparison['accuracy_vs_baseline'],\n", 3005 | " feature_comparison['test_accuracy_single'],\n", 3006 | " s=100, alpha=0.6, color='purple')\n", 3007 | " for idx, row in feature_comparison.iterrows():\n", 3008 | " ax4.annotate(row['removed_feature'],\n", 3009 | " (row['accuracy_vs_baseline'], row['test_accuracy_single']),\n", 3010 | " fontsize=8, alpha=0.7)\n", 3011 | " ax4.set_xlabel('Accuracy Drop When Removed', fontweight='bold')\n", 3012 | " ax4.set_ylabel('Accuracy as Single Feature', fontweight='bold')\n", 3013 | " ax4.set_title('Feature Importance: Removal vs Individual Performance', fontweight='bold', fontsize=12)\n", 3014 | " ax4.grid(alpha=0.3)\n", 3015 | "\n", 3016 | "plt.tight_layout()\n", 3017 | "plt.savefig('feature_ablation_analysis.png', dpi=300, bbox_inches='tight')\n", 3018 | "print(\"\\n✓ Visualization saved to 'feature_ablation_analysis.png'\")\n", 3019 | "plt.show()" 3020 | ] 3021 | }, 3022 | { 3023 | "cell_type": "code", 3024 | "execution_count": null, 3025 | "metadata": { 3026 | "colab": { 3027 | "base_uri": "https://localhost:8080/", 3028 | "height": 762 3029 | }, 3030 | "id": "Ri_P2nHml4Hx", 3031 | "outputId": "f4286c25-7e36-480b-cf26-f9d7ea2c1b5b" 3032 | }, 3033 | "outputs": [], 3034 | "source": [ 3035 | "import pandas as pd\n", 3036 | "import seaborn as sns\n", 3037 | "import matplotlib.pyplot as plt\n", 3038 | "\n", 3039 | "# Create DataFrame of your features\n", 3040 | "feature_df = pd.DataFrame(\n", 3041 | " X_features_train,\n", 3042 | " columns=feature_columns\n", 3043 | ")\n", 3044 | "\n", 3045 | "# Calculate correlation matrix\n", 3046 | "correlation_matrix = feature_df.corr()\n", 3047 | "\n", 3048 | "print(\"Feature Correlation Matrix:\")\n", 3049 | "print(correlation_matrix)\n", 3050 | "\n", 3051 | "# Visualize\n", 3052 | "plt.figure(figsize=(8, 6))\n", 3053 | "sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,\n", 3054 | " square=True, linewidths=1, cbar_kws={\"shrink\": 0.8})\n", 3055 | "plt.title('Feature Correlation Matrix', fontweight='bold', fontsize=14)\n", 3056 | "plt.tight_layout()\n", 3057 | "plt.savefig('feature_correlation.png', dpi=300, bbox_inches='tight')\n", 3058 | "plt.show()\n", 3059 | "\n", 3060 | "# 
Check for high correlations\n", 3061 | "print(\"\\nHighly Correlated Feature Pairs (|r| > 0.8):\")\n", 3062 | "for i in range(len(correlation_matrix.columns)):\n", 3063 | " for j in range(i+1, len(correlation_matrix.columns)):\n", 3064 | " corr = correlation_matrix.iloc[i, j]\n", 3065 | " if abs(corr) > 0.8:\n", 3066 | " print(f\"{correlation_matrix.columns[i]} <-> {correlation_matrix.columns[j]}: {corr:.3f}\")" 3067 | ] 3068 | }, 3069 | { 3070 | "cell_type": "code", 3071 | "execution_count": null, 3072 | "metadata": { 3073 | "colab": { 3074 | "base_uri": "https://localhost:8080/" 3075 | }, 3076 | "id": "92T6d3in5sU9", 3077 | "outputId": "1740308b-2ca7-47f3-82ed-6c1dcce07c45" 3078 | }, 3079 | "outputs": [], 3080 | "source": [ 3081 | "# Load the best weights for ablation\n", 3082 | "best_image_only = load_model('best_particle_classifier_image_only.keras')\n", 3083 | "best_feature_only = load_model('best_particle_classifier_feature_only.keras')\n", 3084 | "\n", 3085 | "# 3. Evaluate All Models on Test Set\n", 3086 | "\n", 3087 | "# Prepare test datasets for ablated models\n", 3088 | "test_ds_img = create_image_only_ds(X_paths_test, y_test, augment=False)\n", 3089 | "test_ds_feat = create_feature_only_ds(X_features_test, y_test)\n", 3090 | "\n", 3091 | "# Evaluation function\n", 3092 | "def evaluate_model(model, ds):\n", 3093 | " metrics = model.evaluate(ds, verbose=0)\n", 3094 | " return metrics[4] # Return AUC\n", 3095 | "\n", 3096 | "print(\"\\n\" + \"=\" * 50)\n", 3097 | "print(\"BRANCH ABLATION RESULTS (TEST AUC)\")\n", 3098 | "print(\"=\" * 50)\n", 3099 | "\n", 3100 | "# Calculate AUC for all three models\n", 3101 | "auc_baseline = evaluate_model(baseline_model, test_ds)\n", 3102 | "auc_img_only = evaluate_model(best_image_only, test_ds_img)\n", 3103 | "auc_feat_only = evaluate_model(best_feature_only, test_ds_feat)\n", 3104 | "\n", 3105 | "results = {\n", 3106 | " 'Hybrid Model': auc_baseline,\n", 3107 | " 'Image-Only Model': auc_img_only,\n", 3108 | " 'Feature-Only Model': auc_feat_only\n", 3109 | "}\n", 3110 | "\n", 3111 | "print(f\"Hybrid Model (Baseline): {auc_baseline:.4f}\")\n", 3112 | "print(f\"Image-Only Model: {auc_img_only:.4f} (Drop: {auc_baseline - auc_img_only:.4f})\")\n", 3113 | "print(f\"Feature-Only Model: {auc_feat_only:.4f} (Drop: {auc_baseline - auc_feat_only:.4f})\")\n", 3114 | "\n", 3115 | "\n", 3116 | "# 4. Individual Feature Ablation (Perturbing Test Data)\n", 3117 | "print(\"\\n\" + \"=\" * 50)\n", 3118 | "print(\"INDIVIDUAL FEATURE ABLATION (Hybrid Model Test AUC)\")\n", 3119 | "print(\"=\" * 50)\n", 3120 | "\n", 3121 | "# We evaluate the performance of the full hybrid model when individual features are zeroed out.\n", 3122 | "auc_individual_ablation = {}\n", 3123 | "X_feat_test_original = X_features_test.copy()\n", 3124 | "\n", 3125 | "for i, feature_name in enumerate(feature_columns):\n", 3126 | "\n", 3127 | " # 1. Create a perturbed test feature set (zero out column i)\n", 3128 | " X_feat_test_perturbed = X_feat_test_original.copy()\n", 3129 | " X_feat_test_perturbed[:, i] = 0.0 # Set column i to zero\n", 3130 | "\n", 3131 | " # 2. 
Create the perturbed hybrid test dataset\n", 3132 | " test_ds_perturbed = tf.data.Dataset.from_tensor_slices(\n", 3133 | " (X_paths_test, X_feat_test_perturbed, y_test)\n", 3134 | " )\n", 3135 | " test_ds_perturbed = test_ds_perturbed.map(\n", 3136 | " lambda p, f, l: load_and_augment_hybrid(p, f, l, augment=False),\n", 3137 | " num_parallel_calls=tf.data.AUTOTUNE\n", 3138 | " ).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n", 3139 | "\n", 3140 | " # 3. Evaluate the full hybrid model\n", 3141 | " metrics = baseline_model.evaluate(test_ds_perturbed, verbose=0)\n", 3142 | " auc_perturbed = metrics[4]\n", 3143 | "\n", 3144 | " drop = auc_baseline - auc_perturbed\n", 3145 | " auc_individual_ablation[feature_name] = {'AUC': auc_perturbed, 'Drop': drop}\n", 3146 | "\n", 3147 | "# Display ranked results\n", 3148 | "ranked_results = sorted(auc_individual_ablation.items(), key=lambda item: item[1]['Drop'], reverse=True)\n", 3149 | "\n", 3150 | "print(\"Feature\\t\\tAUC After Ablation\\tAUC Drop\")\n", 3151 | "print(\"-\" * 50)\n", 3152 | "for feature, data in ranked_results:\n", 3153 | " print(f\"{feature}:\\t\\t{data['AUC']:.4f}\\t\\t\\t{data['Drop']:.4f}\")\n", 3154 | "\n", 3155 | "print(\"\\nFinal Branch Ablation Summary:\")\n", 3156 | "print(\"A higher AUC Drop indicates a more important branch/feature.\")\n", 3157 | "\n", 3158 | "\n", 3159 | "# Example usage (uncomment and run if all data variables are defined)\n", 3160 | "# run_ablation_analysis()" 3161 | ] 3162 | }, 3163 | { 3164 | "cell_type": "code", 3165 | "execution_count": null, 3166 | "metadata": { 3167 | "id": "29nZSAbZ9w2_" 3168 | }, 3169 | "outputs": [], 3170 | "source": [ 3171 | "##Ablation analysis:\n", 3172 | "\n", 3173 | "import matplotlib.pyplot as plt\n", 3174 | "import seaborn as sns\n", 3175 | "import pandas as pd\n", 3176 | "import numpy as np\n", 3177 | "\n", 3178 | "# --- Data provided by the user's ablation analysis ---\n", 3179 | "\n", 3180 | "# 1. Branch Ablation Results\n", 3181 | "BRANCH_RESULTS = {\n", 3182 | " 'Model': ['Hybrid (Baseline)', 'Image-Only', 'Feature-Only'],\n", 3183 | " 'AUC': [0.9954, 0.9791, 1.0000],\n", 3184 | " 'Drop': [0.0000, 0.0163, -0.0046] # Negative drop means performance improved\n", 3185 | "}\n", 3186 | "\n", 3187 | "# 2. 
Individual Feature Ablation Results\n",
3188 | "FEATURE_RESULTS = {\n",
3189 | "    'Feature': ['aircrafttas', 'diam', 'xsize', 'ysize', 'area', 'GGALT', 'WSC', 'WDC', 'PALT_A'],\n",
3190 | "    'Drop': [0.0461, 0.0023, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]\n",
3191 | "}\n",
3192 | "\n",
3193 | "def plot_branch_ablation(branch_df: pd.DataFrame):\n",
3194 | "    \"\"\"Generates a bar chart comparing the AUC of the three main models.\"\"\"\n",
3195 | "    plt.figure(figsize=(9, 6))\n",
3196 | "    sns.barplot(x='Model', y='AUC', data=branch_df, palette='viridis')\n",
3197 | "\n",
3198 | "    # Add AUC values on top of bars\n",
3199 | "    for index, row in branch_df.iterrows():\n",
3200 | "        plt.text(index, row['AUC'] - 0.005, f\"{row['AUC']:.4f}\", color='white', ha='center', fontweight='bold')\n",
3201 | "\n",
3202 | "    plt.ylim(min(branch_df['AUC']) - 0.01, 1.01)\n",
3203 | "    plt.title('Branch Ablation Analysis: Model Performance (Test AUC)', fontsize=16)\n",
3204 | "    plt.ylabel('Test AUC Score', fontsize=12)\n",
3205 | "    plt.xlabel('Model Configuration', fontsize=12)\n",
3206 | "    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
3207 | "    plt.show()\n",
3208 | "\n",
3209 | "\n",
3210 | "def plot_feature_importance(feature_df: pd.DataFrame):\n",
3211 | "    \"\"\"Generates a ranked bar chart showing the AUC drop per feature ablation.\"\"\"\n",
3212 | "\n",
3213 | "    # Calculate performance change as a percentage drop for better interpretation\n",
3214 | "    feature_df['Drop_Pct'] = (feature_df['Drop'] / BRANCH_RESULTS['AUC'][0]) * 100\n",
3215 | "\n",
3216 | "    # Sort by drop size; reset the index so bar positions match row order\n",
3217 | "    feature_df = feature_df.sort_values(by='Drop', ascending=False).reset_index(drop=True)\n",
3218 | "\n",
3219 | "    plt.figure(figsize=(12, 7))\n",
3220 | "    sns.barplot(x='Drop_Pct', y='Feature', data=feature_df, palette='magma')\n",
3221 | "\n",
3222 | "    # Add drop values next to each bar\n",
3223 | "    for index, row in feature_df.iterrows():\n",
3224 | "        plt.text(row['Drop_Pct'] + 0.1, index, f\"{row['Drop_Pct']:.2f}% ({row['Drop']:.4f} AUC Drop)\",\n",
3225 | "                 color='black', va='center')\n",
3226 | "\n",
3227 | "    plt.title('Individual Feature Importance: AUC Drop from Baseline', fontsize=16)\n",
3228 | "    plt.xlabel('Performance Drop (%)', fontsize=12)\n",
3229 | "    plt.ylabel('Ablated Feature', fontsize=12)\n",
3230 | "    plt.grid(axis='x', linestyle='--', alpha=0.7)\n",
3231 | "    plt.show()\n",
3232 | "\n",
3233 | "\n",
3234 | "def plot_feature_separability_check(\n",
3235 | "    X_feat_test: np.ndarray, y_test: np.ndarray, feature_columns: list, feature_name: str\n",
3236 | "):\n",
3237 | "    \"\"\"\n",
3238 | "    Plots the density of a single critical feature (e.g., 'aircrafttas')\n",
3239 | "    separated by the true class label (0 or 1).\n",
3240 | "\n",
3241 | "    NOTE: X_feat_test and y_test must be provided from your main environment.\n",
3242 | "    \"\"\"\n",
3243 | "\n",
3244 | "    # Locate the column index of the requested feature;\n",
3245 | "    # list.index() raises ValueError if the name is missing.\n",
3246 | "    try:\n",
3247 | "        feature_index = feature_columns.index(feature_name)\n",
3248 | "    except ValueError:\n",
3249 | "        print(f\"Error: Feature '{feature_name}' not found in feature_columns.\")\n",
3250 | "        return\n",
3251 | "\n",
3252 | "    # Create a DataFrame for easy plotting\n",
3253 | "    plot_df = pd.DataFrame({\n",
3254 | "        feature_name: X_feat_test[:, feature_index],\n",
3255 | "        'Class Label': np.argmax(y_test, axis=1)  # Convert one-hot to class index\n",
3256 | "    })\n",
3257 | "\n",
3258 | "    plt.figure(figsize=(10, 6))\n",
3259 | "\n",
3260 | "    # Use Kernel Density Estimate (KDE) plot to show distribution 
overlap\n",
3261 | "    sns.kdeplot(\n",
3262 | "        data=plot_df,\n",
3263 | "        x=feature_name,\n",
3264 | "        hue='Class Label',\n",
3265 | "        fill=True,\n",
3266 | "        alpha=0.5,\n",
3267 | "        linewidth=2,\n",
3268 | "        legend=True\n",
3269 | "    )\n",
3270 | "\n",
3271 | "    plt.title(f'Feature Separability: Density of \"{feature_name}\" by Class', fontsize=16)\n",
3272 | "    plt.xlabel(feature_name, fontsize=12)\n",
3273 | "    plt.ylabel('Density', fontsize=12)\n",
3274 | "    plt.legend(title='Class')\n",
3275 | "    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
3276 | "    plt.show()\n"
3277 | ]
3278 | },
3279 | {
3280 | "cell_type": "code",
3281 | "execution_count": null,
3282 | "metadata": {
3283 | "colab": {
3284 | "base_uri": "https://localhost:8080/",
3285 | "height": 657
3286 | },
3287 | "id": "PcnDEDzM92M5",
3288 | "outputId": "355147da-9f37-4884-aa96-1b2e1248807e"
3289 | },
3290 | "outputs": [],
3291 | "source": [
3292 | "branch_df = pd.DataFrame(BRANCH_RESULTS)\n",
3293 | "feature_df = pd.DataFrame(FEATURE_RESULTS)\n",
3294 | "\n",
3295 | "# 1. Visualize Branch Ablation\n",
3296 | "plot_branch_ablation(branch_df)\n",
3297 | "\n"
3298 | ]
3299 | },
3300 | {
3301 | "cell_type": "code",
3302 | "execution_count": null,
3303 | "metadata": {
3304 | "colab": {
3305 | "base_uri": "https://localhost:8080/",
3306 | "height": 734
3307 | },
3308 | "id": "iR_X0i5A99vA",
3309 | "outputId": "7b8403dc-507f-497d-d17f-0c4e9b690b21"
3310 | },
3311 | "outputs": [],
3312 | "source": [
3313 | "# 2. Visualize Individual Feature Importance\n",
3314 | "plot_feature_importance(feature_df)"
3315 | ]
3316 | },
3317 | {
3318 | "cell_type": "code",
3319 | "execution_count": null,
3320 | "metadata": {
3321 | "colab": {
3322 | "base_uri": "https://localhost:8080/"
3323 | },
3324 | "id": "sRSKCtaIH4lA",
3325 | "outputId": "2a7f8ab1-09c9-4a59-8b7e-f44b8c0170e2"
3326 | },
3327 | "outputs": [],
3328 | "source": [
3329 | "feature_columns"
3330 | ]
3331 | },
3332 | {
3333 | "cell_type": "code",
3334 | "execution_count": null,
3335 | "metadata": {
3336 | "colab": {
3337 | "base_uri": "https://localhost:8080/",
3338 | "height": 605
3339 | },
3340 | "id": "p2-cwRUF-CI9",
3341 | "outputId": "37fde4ed-2b09-41b2-89ad-d9358fbf9ce2"
3342 | },
3343 | "outputs": [],
3344 | "source": [
3345 | "try:\n",
3346 | "    plot_feature_separability_check(X_features_test, y_test, feature_columns, 'xsize')\n",
3347 | "except NameError:\n",
3348 | "    print(\"Skipping Feature Separability Check: Data arrays are not defined in this scope.\")"
3349 | ]
3350 | },
3351 | {
3352 | "cell_type": "code",
3353 | "execution_count": null,
3354 | "metadata": {
3355 | "id": "AP4BZNxwAVQo"
3356 | },
3357 | "outputs": [],
3358 | "source": [
3359 | "from typing import Tuple\n",
3360 | "\n",
3361 | "def find_optimal_threshold(\n",
3362 | "    X_feat_test: np.ndarray, y_test: np.ndarray, feature_columns: list, feature_name: str\n",
3363 | ") -> Tuple[float, float, float]:\n",
3364 | "    \"\"\"\n",
3365 | "    Finds the optimal binary classification threshold for a single feature\n",
3366 | "    by maximizing accuracy on the test set.\n",
3367 | "\n",
3368 | "    Args:\n",
3369 | "        X_feat_test: NumPy array of test features (already normalized/scaled).\n",
3370 | "        y_test: NumPy array of test labels (one-hot encoded).\n",
3371 | "        feature_columns: List of feature names.\n",
3372 | "        feature_name: The specific feature to analyze (e.g., 'aircrafttas').\n",
3373 | "\n",
3374 | "    Returns:\n",
3375 | "        A tuple containing (optimal_threshold, max_accuracy, corresponding_auc).\n",
3376 | "    \"\"\"\n",
3377 | "\n",
3378 | "    try:\n",
3379 | "        feature_index = feature_columns.index(feature_name)\n",
3380 | "    except ValueError:\n",
3381 | "        print(f\"Error: Feature '{feature_name}' not found in feature_columns.\")\n",
3382 | "        return None, None, None\n",
3383 | "\n",
3384 | "    # 1. Extract the feature vector and true binary labels\n",
3385 | "    feature_data = X_feat_test[:, feature_index]\n",
3386 | "    # Convert the one-hot labels to 0/1 class indices\n",
3387 | "    true_labels = np.argmax(y_test, axis=1)\n",
3388 | "\n",
3389 | "    # 2. Get unique values to check as potential thresholds (sorted)\n",
3390 | "    threshold_candidates = np.sort(np.unique(feature_data))\n",
3391 | "\n",
3392 | "    best_accuracy = 0.0\n",
3393 | "    optimal_threshold = None\n",
3394 | "\n",
3395 | "    # 3. Iterate through every unique value as a potential threshold\n",
3396 | "    for threshold in threshold_candidates:\n",
3397 | "        # Predict: Class 1 if feature value > threshold, Class 0 otherwise.\n",
3398 | "        # This assumes Class 1 is associated with higher values, based on the plot.\n",
3399 | "        predictions = (feature_data > threshold).astype(int)\n",
3400 | "\n",
3401 | "        # Calculate accuracy for this threshold\n",
3402 | "        accuracy = np.mean(predictions == true_labels)\n",
3403 | "\n",
3404 | "        if accuracy > best_accuracy:\n",
3405 | "            best_accuracy = accuracy\n",
3406 | "            optimal_threshold = threshold\n",
3407 | "\n",
3408 | "    # Also report a threshold-free AUC computed from the raw feature values;\n",
3409 | "    # it summarizes how separable the feature is regardless of the chosen\n",
3410 | "    # cut point.\n",
3411 | "    from sklearn.metrics import roc_auc_score\n",
3412 | "    final_predictions_proba = feature_data  # Use the feature value itself as the score\n",
3413 | "    try:\n",
3414 | "        final_auc = roc_auc_score(true_labels, final_predictions_proba)\n",
3415 | "    except ValueError:\n",
3416 | "        final_auc = float('nan')  # roc_auc_score raises ValueError if y_true has a single class\n",
3417 | "\n",
3418 | "    return optimal_threshold, best_accuracy, final_auc"
3419 | ]
3420 | },
3421 | {
3422 | "cell_type": "code",
3423 | "execution_count": null,
3424 | "metadata": {
3425 | "colab": {
3426 | "base_uri": "https://localhost:8080/"
3427 | },
3428 | "id": "lcnmGcCyAWcw",
3429 | "outputId": "71e1f039-83df-4675-b5b2-a2b7552fb762"
3430 | },
3431 | "outputs": [],
3432 | "source": [
3433 | "threshold, accuracy, auc = find_optimal_threshold(\n",
3434 | "    X_features_test, y_test, feature_columns, 'aircrafttas'\n",
3435 | ")\n",
3436 | "# NOTE: the threshold is in scaled feature units (see the de-scaling helper below).\n",
3437 | "if threshold is not None:\n",
3438 | "    print(f\"\\nOptimal Threshold Analysis for 'aircrafttas':\")\n",
3439 | "    print(\"-\" * 40)\n",
3440 | "    print(f\"Optimal Threshold Value: {threshold:.4f}\")\n",
3441 | "    print(f\"Classification Accuracy at this Threshold: {accuracy:.4f}\")\n",
3442 | "    print(f\"Corresponding AUC (Score based on raw feature): {auc:.4f}\")\n",
3443 | "    print(\"\\nInterpretation:\")\n",
3444 | "    print(f\"If aircrafttas > {threshold:.4f}, predict Class 1.\")\n",
3445 | "    print(f\"If aircrafttas <= {threshold:.4f}, predict Class 0.\")"
3446 | ]
3447 | },
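3448 | {
3449 | "cell_type": "markdown",
3450 | "metadata": {},
3451 | "source": [
3452 | "The optimal threshold above lives in the scaled feature space. The next cell is a small, hypothetical helper for mapping it back to physical units; it assumes the StandardScaler `scaler` fitted during preprocessing is column-aligned with feature_columns."
3453 | ]
3454 | },
3455 | {
3456 | "cell_type": "code",
3457 | "execution_count": null,
3458 | "metadata": {},
3459 | "outputs": [],
3460 | "source": [
3461 | "# Hypothetical: undo the standardization for 'aircrafttas', assuming the\n",
3462 | "# preprocessing StandardScaler `scaler` is column-aligned with feature_columns.\n",
3463 | "if threshold is not None:\n",
3464 | "    i = feature_columns.index('aircrafttas')\n",
3465 | "    raw_threshold = threshold * scaler.scale_[i] + scaler.mean_[i]\n",
3466 | "    print(f\"aircrafttas threshold in raw units: {raw_threshold:.2f}\")"
3467 | ]
3468 | },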
"sns.scatterplot(x=df_valid['Time'], y=df_valid['aircrafttas'])" 3462 | ] 3463 | }, 3464 | { 3465 | "cell_type": "code", 3466 | "execution_count": null, 3467 | "metadata": { 3468 | "colab": { 3469 | "base_uri": "https://localhost:8080/" 3470 | }, 3471 | "id": "PYIHfJYqGKqw", 3472 | "outputId": "80b4b091-7a43-4dcd-975e-a32f959ed8ff" 3473 | }, 3474 | "outputs": [], 3475 | "source": [ 3476 | "len(df_valid)/10" 3477 | ] 3478 | }, 3479 | { 3480 | "cell_type": "code", 3481 | "execution_count": null, 3482 | "metadata": { 3483 | "colab": { 3484 | "base_uri": "https://localhost:8080/", 3485 | "height": 424 3486 | }, 3487 | "id": "KboU75_OHIBP", 3488 | "outputId": "66e9efc2-d9eb-4b40-89a5-4e85eeaf092d" 3489 | }, 3490 | "outputs": [], 3491 | "source": [ 3492 | "df_valid" 3493 | ] 3494 | }, 3495 | { 3496 | "cell_type": "code", 3497 | "execution_count": null, 3498 | "metadata": { 3499 | "id": "DsXMjb2VGDDR" 3500 | }, 3501 | "outputs": [], 3502 | "source": [ 3503 | "xticks = []\n", 3504 | "for x in range(0,len(df_valid),433):\n", 3505 | " xticks.append(df_valid['Time'].iloc[x])" 3506 | ] 3507 | }, 3508 | { 3509 | "cell_type": "code", 3510 | "execution_count": null, 3511 | "metadata": { 3512 | "colab": { 3513 | "base_uri": "https://localhost:8080/", 3514 | "height": 1000 3515 | }, 3516 | "id": "O6GsTQykFkYS", 3517 | "outputId": "cf66b585-5f34-4866-8af4-5ec7834b5d6a" 3518 | }, 3519 | "outputs": [], 3520 | "source": [ 3521 | "import matplotlib.pyplot as plt\n", 3522 | "import seaborn as sns\n", 3523 | "for feature in feature_columns:\n", 3524 | " sns.scatterplot(x=df_valid['Time'], y=df_valid[feature], hue=df_valid['phase'])\n", 3525 | " plt.show()" 3526 | ] 3527 | }, 3528 | { 3529 | "cell_type": "code", 3530 | "execution_count": null, 3531 | "metadata": { 3532 | "colab": { 3533 | "base_uri": "https://localhost:8080/", 3534 | "height": 466 3535 | }, 3536 | "id": "oFzHn6qjBv4X", 3537 | "outputId": "8b925982-0895-4281-883e-a251d1a5544b" 3538 | }, 3539 | "outputs": [], 3540 | "source": [ 3541 | "sns.scatterplot(x=df_valid['Time'], y=df_valid['diam'])" 3542 | ] 3543 | }, 3544 | { 3545 | "cell_type": "markdown", 3546 | "metadata": { 3547 | "id": "fUtBS64OBvtb" 3548 | }, 3549 | "source": [] 3550 | }, 3551 | { 3552 | "cell_type": "markdown", 3553 | "metadata": { 3554 | "id": "jK6cRtfBGumj" 3555 | }, 3556 | "source": [ 3557 | "## 16. 
Save Predictions (Optional)"
3572 | ]
3573 | },
3574 | {
3575 | "cell_type": "code",
3576 | "execution_count": null,
3577 | "metadata": {
3578 | "id": "WGLEGH4DGumj"
3579 | },
3580 | "outputs": [],
3581 | "source": [
3582 | "# Save predictions to CSV for further analysis\n",
3583 | "# NOTE: the de-scaled columns below assume temperature, air speed, and altitude\n",
3584 | "# are the first three entries of feature_columns; adjust the indices to match.\n",
3585 | "# predictions_df = pd.DataFrame({\n",
3586 | "#     'particle_idx_seq': df_valid.iloc[test_indices]['particle_idx_seq'].values,\n",
3587 | "#     'true_label': y_true,\n",
3588 | "#     'predicted_label': y_pred,\n",
3589 | "#     'liquid_probability': y_pred_probs[:, 0],\n",
3590 | "#     'solid_probability': y_pred_probs[:, 1],\n",
3591 | "#     'donut_probability': y_pred_probs[:, 2],\n",
3592 | "#     'noise_probability': y_pred_probs[:, 3],\n",
3593 | "#     'temperature': X_feat_test[:, 0] * scaler.scale_[0] + scaler.mean_[0],\n",
3594 | "#     'air_speed': X_feat_test[:, 1] * scaler.scale_[1] + scaler.mean_[1],\n",
3595 | "#     'altitude': X_feat_test[:, 2] * scaler.scale_[2] + scaler.mean_[2]\n",
3596 | "# })\n",
3597 | "# predictions_df.to_csv('particle_predictions_hybrid_4class.csv', index=False)\n",
3598 | "# print(\"Predictions saved to 'particle_predictions_hybrid_4class.csv'\")"
3599 | ]
3600 | },
3601 | {
3602 | "cell_type": "markdown",
3603 | "metadata": {
3604 | "id": "TISvzAD5Gumj"
3605 | },
3606 | "source": [
3607 | "## Notes and Next Steps\n",
3608 | "\n",
3609 | "### Hybrid 4-Class Model Architecture:\n",
3610 | "- **CNN branch**: 4 convolutional blocks extract spatial/morphological features from the particle images\n",
3611 | "- **Environmental branch**: Dense layers process the scalar flight and size features\n",
3612 | "- **Fusion**: Both branches are concatenated before the final classification head\n",
3613 | "- **Output**: 4 classes (Liquid, Solid, Donut, Noise)\n",
3614 | "- **Regularization**: Dropout and BatchNormalization reduce overfitting\n",
3615 | "- **Class weights**: Applied to handle class imbalance\n",
3616 | "\n",
3617 | "(An illustrative sketch of this two-branch layout appears in Section 17 below.)\n",
3618 | "\n",
3619 | "### Phase Descriptions:\n",
3620 | "- **Liquid (0)**: Liquid water droplets\n",
3621 | "- **Solid (1)**: Ice crystals\n",
3622 | "- **Donut (2)**: Donut-shaped artifacts or special particle types\n",
3623 | "- **Noise (3)**: Noisy or invalid particle images\n",
3624 | "\n",
3625 | "### Advantages of Hybrid Approach:\n",
3626 | "1. **Multi-modal learning**: Combines visual and environmental information\n",
3627 | "2. **Physical constraints**: Temperature can help disambiguate liquid vs ice\n",
3628 | "3. **Context awareness**: Altitude and air speed provide atmospheric context\n",
3629 | "4. **Improved accuracy**: Environmental features can improve classification, especially for ambiguous cases\n",
3630 | "\n",
3631 | "### Comparison with Image-Only Model:\n",
3632 | "Compare this hybrid model with the image-only baseline:\n",
3633 | "1. **Accuracy improvement**: How much do environmental features help overall?\n",
3634 | "2. **Per-class performance**: Which phases benefit most from environmental data?\n",
3635 | "3. **Confidence**: Does the hybrid model have higher prediction confidence?\n",
3636 | "4. **Misclassification patterns**: Does it reduce specific confusion pairs?\n",
3637 | "5. **Physical consistency**: Are predictions more physically plausible?\n",
3638 | "\n",
3639 | "### Potential Improvements:\n",
3640 | "1. **Data augmentation**: Add rotation, flipping, and scaling to the particle images\n",
3641 | "2. **Attention mechanisms**: Let the model learn which features matter most\n",
3642 | "3. **Different architectures**: Try ResNet, EfficientNet, or Vision Transformers\n",
3643 | "4. **Hyperparameter tuning**: Learning rate, batch size, dropout rates, architecture depth\n",
3644 | "5. **Ensemble methods**: Combine multiple models\n",
3645 | "6. **Feature engineering**: Add derived features (e.g., supersaturation, distance from the freezing level)\n",
3646 | "7. **Focal loss**: Better handling of hard examples (a hedged sketch appears in Section 17 below)\n",
3647 | "8. **Cross-validation**: K-fold CV for more robust evaluation\n",
3648 | "\n",
3649 | "### Key Research Questions:\n",
3650 | "- How much does temperature improve liquid/solid classification?\n",
3651 | "- Are donut particles associated with specific atmospheric conditions?\n",
3652 | "- Can environmental features help filter noise automatically?\n",
3653 | "- What's the optimal balance between image and environmental features?\n",
3654 | "- How well does the model generalize across different flight conditions?"
3655 | ]
3656 | },
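3657 | {
3658 | "cell_type": "markdown",
3659 | "metadata": {},
3660 | "source": [
3661 | "## 17. Illustrative Sketches (Optional)\n",
3662 | "\n",
3663 | "The next cell sketches the two-branch hybrid layout described in the notes with the Keras functional API; the image size, filter counts, and dense-layer widths are illustrative assumptions, not the configuration trained earlier in this notebook. The cell after it sketches a categorical focal loss (improvement 7) that could stand in for categorical cross-entropy; the `sketch_model` in its usage comment refers to the illustrative model below."
3664 | ]
3665 | },
3666 | {
3667 | "cell_type": "code",
3668 | "execution_count": null,
3669 | "metadata": {},
3670 | "outputs": [],
3671 | "source": [
3672 | "# Minimal sketch of the two-branch hybrid architecture (illustrative sizes).\n",
3673 | "import tensorflow as tf\n",
3674 | "from tensorflow.keras import layers, Model\n",
3675 | "\n",
3676 | "IMG_SIZE = 128     # assumed image side length\n",
3677 | "N_FEATURES = 9     # assumed number of scalar features\n",
3678 | "N_CLASSES = 4      # Liquid, Solid, Donut, Noise\n",
3679 | "\n",
3680 | "# CNN branch: 4 conv blocks extract morphology from the binary shadowgraph\n",
3681 | "img_in = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 1), name='image')\n",
3682 | "x = img_in\n",
3683 | "for n_filters in (32, 64, 128, 256):\n",
3684 | "    x = layers.Conv2D(n_filters, 3, padding='same', activation='relu')(x)\n",
3685 | "    x = layers.BatchNormalization()(x)\n",
3686 | "    x = layers.MaxPooling2D()(x)\n",
3687 | "x = layers.GlobalAveragePooling2D()(x)\n",
3688 | "\n",
3689 | "# Environmental branch: dense layers over the scalar features\n",
3690 | "feat_in = layers.Input(shape=(N_FEATURES,), name='features')\n",
3691 | "f = layers.Dense(32, activation='relu')(feat_in)\n",
3692 | "f = layers.Dense(16, activation='relu')(f)\n",
3693 | "\n",
3694 | "# Fusion: concatenate both branches before the 4-class softmax head\n",
3695 | "merged = layers.Concatenate()([x, f])\n",
3696 | "merged = layers.Dropout(0.3)(merged)\n",
3697 | "out = layers.Dense(N_CLASSES, activation='softmax')(merged)\n",
3698 | "\n",
3699 | "sketch_model = Model([img_in, feat_in], out)\n",
3700 | "sketch_model.compile(optimizer='adam', loss='categorical_crossentropy',\n",
3701 | "                     metrics=[tf.keras.metrics.AUC(name='auc')])\n",
3702 | "sketch_model.summary()"
3703 | ]
3704 | },
3705 | {
3706 | "cell_type": "markdown",
3707 | "metadata": {},
3708 | "source": [
3709 | "A hedged sketch of categorical focal loss: cross-entropy re-weighted so that easy, well-classified examples contribute less. The gamma and alpha values are common defaults, not values tuned for this dataset."
3710 | ]
3711 | },
3712 | {
3713 | "cell_type": "code",
3714 | "execution_count": null,
3715 | "metadata": {},
3716 | "outputs": [],
3717 | "source": [
3718 | "# Sketch: categorical focal loss as a possible drop-in replacement for\n",
3719 | "# categorical_crossentropy; gamma/alpha are illustrative defaults.\n",
3720 | "import tensorflow as tf\n",
3721 | "\n",
3722 | "def categorical_focal_loss(gamma=2.0, alpha=0.25):\n",
3723 | "    def loss(y_true, y_pred):\n",
3724 | "        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)\n",
3725 | "        cross_entropy = -y_true * tf.math.log(y_pred)\n",
3726 | "        weight = alpha * tf.pow(1.0 - y_pred, gamma)  # small for easy examples\n",
3727 | "        return tf.reduce_sum(weight * cross_entropy, axis=-1)\n",
3728 | "    return loss\n",
3729 | "\n",
3730 | "# Hypothetical usage:\n",
3731 | "# sketch_model.compile(optimizer='adam', loss=categorical_focal_loss(),\n",
3732 | "#                      metrics=['accuracy'])"
3733 | ]
3734 | }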
3735 | ],
3736 | "metadata": {
3737 | "accelerator": "GPU",
3738 | "colab": {
3739 | "gpuType": "A100",
3740 | "machine_shape": "hm",
3741 | "provenance": []
3742 | },
3743 | "kernelspec": {
3744 | "display_name": "Python 3",
3745 | "name": "python3"
3746 | },
3747 | "language_info": {
3748 | "codemirror_mode": {
3749 | "name": "ipython",
3750 | "version": 3
3751 | },
3752 | "file_extension": ".py",
3753 | "mimetype": "text/x-python",
3754 | "name": "python",
3755 | "nbconvert_exporter": "python",
3756 | "pygments_lexer": "ipython3",
3757 | "version": "3.8.0"
3758 | }
3759 | },
3760 | "nbformat": 4,
3761 | "nbformat_minor": 0
3762 | }
3763 | --------------------------------------------------------------------------------