├── NoProp.pdf
├── README.md
└── NoProp.ipynb


/NoProp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sid3503/NoProp/HEAD/NoProp.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NoProp: Training Neural Networks Without Forward/Backward Propagation 🚀

[![Paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2503.24322)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org)

Official implementation of **NoProp**, a novel neural network training method that uses denoising diffusion to avoid end-to-end forward and backward propagation. It achieves competitive performance on MNIST/CIFAR while enabling parallel layer training.


## 🔍 Table of Contents
- [Key Innovations](#-key-innovations)
- [Mathematical Foundations](#-mathematical-foundations)
- [Implementation](#-implementation)

## 🚀 Key Innovations
NoProp introduces three paradigm shifts:

1. **Propagation-Free Training**
   - No sequential forward/backward passes
   - Layers train **independently** via denoising

2. **Diffusion-Based Learning**
   - Corrupts labels progressively (forward diffusion)
   - Each layer learns to denoise a specific noise level

3. **Biological Plausibility**
   - Avoids weight transport problem
   - Local learning only (no global gradients)

---

## 📜 Mathematical Foundations

### A. Forward Diffusion Process (Noising)

NoProp gradually corrupts clean labels through a Markov chain of noise additions:

#### 1. Noise Corruption Equation
```math
z_t = \sqrt{\alpha_t} z_{t-1} + \sqrt{1-\alpha_t} \epsilon_t
```
Where:
- $z_t$: Noisy label at step $t$
- $\alpha_t$: Noise schedule, starting at $\alpha_0=1$ and decreasing with $t$ (so that $\bar{\alpha}_T\approx0$)
- $\epsilon_t \sim \mathcal{N}(0,I)$: Gaussian noise
- $z_0 = u_y$: Ground truth one-hot label

#### 2. Noise Schedule Properties
| Parameter | Role | Typical Value |
|-----------|------|---------------|
| `αₜ` | Controls noise level | Linear: `1.0 → 0.1` (`αₜ` decreases linearly) |
| `T` | Total steps | 10-1000 |
| `ᾱₜ = ∏ αₛ` (s=1 to t) | Cumulative product | (auto-computed) `ᾱₜ = α₁ × α₂ × ... × αₜ` |

#### 3. Step-by-Step Example (MNIST)
Given label "2" ($u_y = [0,0,1,0,...]$), with illustrative values for the first four entries of $z_t$:

| Step $t$ | $\alpha_t$ | $z_t$ (Visualized) | Noise Level |
|---------|-----------|--------------------|------------|
| 0 | 1.0 | [0, 0, 1.0, 0] | 0% |
| 1 | 0.9 | [0, 0.1, 0.85, 0.05] | 10% |
| 2 | 0.7 | [0.05, 0.15, 0.7, 0.1] | 30% |
| ... | ... | ... | ... |
| T | 0.1 | [0.25, 0.25, 0.3, 0.2] | 90% |

#### 4. Key Properties
1. **Gradual Corruption**: the signal-to-noise ratio of $z_t$ relative to the clean label is
   ```math
   \text{SNR}(t) = \frac{\bar{\alpha}_t}{1-\bar{\alpha}_t} \quad \text{(Monotonically decreases)}
   ```
2. **Variance-Preserving**: each step keeps unit variance if the previous step had it
   ```math
   \text{Var}(z_{t-1}) = 1 \;\Rightarrow\; \text{Var}(z_t) = \alpha_t + (1-\alpha_t) = 1
   ```
3. **Closed-Form Sampling**:
   ```math
   q(z_t|u_y) = \mathcal{N}(z_t; \sqrt{\bar{\alpha}_t}u_y, (1-\bar{\alpha}_t)I)
   ```


### B. Reverse Process (Training)
Each MLP layer $t$ predicts clean labels from noisy inputs:

```math
\mathcal{L}_t = \mathbb{E} \| \hat{u}_\theta(z_t,x) - u_y \|^2
```

Where:
- $\hat{u}_\theta$: MLP prediction
- $x$: Input image features
- $u_y$: Ground truth one-hot label


During inference, NoProp iteratively refines noisy labels through learned denoising steps:

#### 1. Denoising Update Rule
```math
z_{t-1} = \sqrt{\alpha_{t-1}} \underbrace{\hat{u}_\theta(z_t,x)}_{\text{Predicted clean label}} + \sqrt{1-\alpha_{t-1}} \epsilon_t
```

Where:
- $z_t$: Noisy label at step $t$
- $\hat{u}_\theta(z_t,x)$: MLP's prediction of clean label
- $\alpha_{t-1}$: Noise schedule value
- $\epsilon_t \sim \mathcal{N}(0,I)$: Fresh Gaussian noise

#### 2. Step-by-Step Process
1. **Start from noise**: $z_T \sim \mathcal{N}(0,I)$
2. **Iterate for** $t=T$ to $1$:
   - Predict $\hat{u}_\theta(z_t,x)$ using MLP
   - Compute $z_{t-1}$ via denoising update
3. **Final prediction**: $\arg\max(z_0)$

#### 3. Example (MNIST)
| Step | $z_t$ (Noisy) | $\hat{u}_\theta(z_t,x)$ (Predicted) | $z_{t-1}$ (Refined) |
|------|-------------------------|-------------------------------------|---------------------------|
| t=3 | [0.4, 0.3, 0.3] | [0.1, 0.0, 0.9] | [0.25, 0.05, 0.7] |
| t=2 | [0.25, 0.05, 0.7] | [0.0, 0.0, 1.0] | [0.1, 0.0, 0.9] |
| t=1 | [0.1, 0.0, 0.9] | [0.0, 0.0, 1.0] | [0.0, 0.0, 1.0] (Final) |

#### 4. Key Properties
1. **Noise Injection**:
   - The $\sqrt{1-\alpha_{t-1}} \epsilon_t$ term prevents deterministic collapse
2. **Geometric Interpolation**:
   - Balances between prediction ($\sqrt{\alpha_{t-1}}$) and noise ($\sqrt{1-\alpha_{t-1}}$)
3. **Stochasticity**:
   - Different noise samples $\epsilon_t$ yield varied trajectories


## ⚙️ Implementation

### Core Components
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T, embed_dim = 10, 10  # diffusion steps, number of classes

# 1. Diffusion Noise Scheduler
alpha = torch.linspace(1.0, 0.1, T)  # one alpha per diffusion step

# 2. Denoising MLP (one per layer)
class DenoisingMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(128 + embed_dim, 256),  # image features + noisy label
            nn.ReLU(),
            nn.Linear(256, embed_dim),        # denoised label
        )

    def forward(self, x_features, z_t):
        return self.mlp(torch.cat([x_features, z_t], dim=1))

# 3. Training Loop (excerpt: x_features, z, u_y and mlps are built in NoProp.ipynb)
for t in range(T):
    u_hat = mlps[t](x_features, z[t+1].detach())
    loss = F.mse_loss(u_hat, u_y)
```

### Key Features
- **Parallel Training**: Each MLP has its own loss and optimizer, so all `T` MLPs update simultaneously
- **Memory Efficient**: No activation storage
- **Flexible**: Works with CNNs/Transformers


## 👤 Author

For any questions or issues, please open an issue on GitHub: [@Siddharth Mishra](https://github.com/Sid3503)

---
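The closed-form sampling property from the forward diffusion section can be checked numerically. The sketch below is a minimal, standard-library-only illustration (no PyTorch needed); the helper names `linear_alpha_schedule`, `cumulative_alpha`, `chain_moments`, and `snr` are illustrative assumptions and do not appear in the repository. It assumes the linear `1.0 → 0.1` schedule with `T = 10`, propagates the mean and variance of $z_t$ through the Markov chain one step at a time, and compares them with $\sqrt{\bar{\alpha}_t}\,u_y$ and $1-\bar{\alpha}_t$.

```python
import math


def linear_alpha_schedule(T, start=1.0, end=0.1):
    """Evenly spaced per-step alphas, mirroring torch.linspace(start, end, T)."""
    if T == 1:
        return [start]
    step = (end - start) / (T - 1)
    return [start + i * step for i in range(T)]


def cumulative_alpha(alphas):
    """alpha_bar[t] = product of alphas[0..t]."""
    alpha_bar, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bar.append(running)
    return alpha_bar


def chain_moments(alphas, u_y):
    """Mean and per-coordinate variance of z_t, propagated step by step.

    For z_t = sqrt(a) * z_prev + sqrt(1 - a) * eps with independent,
    unit-variance eps:
        mean_t = sqrt(a) * mean_prev
        var_t  = a * var_prev + (1 - a)
    starting from the noise-free label (variance 0).
    """
    mean, var = [float(v) for v in u_y], 0.0
    moments = []
    for a in alphas:
        mean = [math.sqrt(a) * m for m in mean]
        var = a * var + (1.0 - a)
        moments.append((mean, var))
    return moments


def snr(alpha_bar_t):
    """Signal-to-noise ratio of the marginal q(z_t | u_y)."""
    if alpha_bar_t >= 1.0:
        return math.inf  # no noise has been added yet
    return alpha_bar_t / (1.0 - alpha_bar_t)


alphas = linear_alpha_schedule(T=10)
alpha_bar = cumulative_alpha(alphas)
u_y = [0, 0, 1, 0]  # toy one-hot label (assumption: 4 classes for brevity)
moments = chain_moments(alphas, u_y)

# The step-by-step moments should agree with the closed form at every step.
for (mean, var), ab in zip(moments, alpha_bar):
    closed_mean = [math.sqrt(ab) * v for v in u_y]
    assert all(math.isclose(m, c, abs_tol=1e-12) for m, c in zip(mean, closed_mean))
    assert math.isclose(var, 1.0 - ab, abs_tol=1e-12)

print("alpha_bar:", [round(ab, 6) for ab in alpha_bar])
print("SNR:", [snr(ab) for ab in alpha_bar])
```

With this schedule the cumulative product after ten steps is $0.9 \times 0.8 \times \dots \times 0.1 \approx 3.6 \times 10^{-4}$, so $z_T$ is almost pure Gaussian noise, which is consistent with starting inference from $z_T \sim \mathcal{N}(0,I)$.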

Made with ❤️ and lots of ☕

-------------------------------------------------------------------------------- /NoProp.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4", 8 | "authorship_tag": "ABX9TyOg0HFF5C/8CLWLLGaczBlp", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU" 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "# Setup Libraries" 35 | ], 36 | "metadata": { 37 | "id": "zCu1mXACaKdx" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 1, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "mve5S0YE-oeM", 48 | "outputId": "2503536e-6b14-403c-c38e-b64dd36983bf" 49 | }, 50 | "outputs": [], 156 | "source": [ 157 | "!pip install torch torchvision" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "source": [ 163 | "import torch\n", 164 | "import torch.nn as nn\n", 165 | "import torch.optim as optim\n", 166 | "from torchvision import datasets, transforms\n", 167 | "from torch.utils.data import DataLoader" 168 | ], 169 | "metadata": { 170 | "id": "ZoxseEX7_Fy1" 171 | }, 172 | "execution_count": 2, 173 | "outputs": [] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "source": [ 178 | "# Hyperparameters" 179 | ], 180 | "metadata": { 181 | "id": "nxES-OQfaZaF" 182 | } 183 | }, 184 | { 185 | "cell_type": "code", 186 | "source": [ 187 | "# Hyperparameters\n", 188 | "T = 10 # Diffusion steps\n", 189 | "embed_dim = 10 # Label embedding dimension(No. 
of Classes)\n", 190 | "batch_size = 128\n", 191 | "lr = 0.001\n", 192 | "epochs = 50" 193 | ], 194 | "metadata": { 195 | "id": "GApOWb8K_QbC" 196 | }, 197 | "execution_count": 22, 198 | "outputs": [] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "source": [ 203 | "# Noise schedule (linear)\n", 204 | "alpha = torch.linspace(1.0, 0.1, T) # α_t from 1.0 → 0.1" 205 | ], 206 | "metadata": { 207 | "id": "vzSRvuN-_QUv" 208 | }, 209 | "execution_count": 4, 210 | "outputs": [] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "source": [ 215 | "print(alpha)" 216 | ], 217 | "metadata": { 218 | "colab": { 219 | "base_uri": "https://localhost:8080/" 220 | }, 221 | "id": "Sqfg91HP_QSx", 222 | "outputId": "10d03ab1-b7ad-4d32-92fc-47d80b4aae24" 223 | }, 224 | "execution_count": 6, 225 | "outputs": [ 226 | { 227 | "output_type": "stream", 228 | "name": "stdout", 229 | "text": [ 230 | "tensor([1.0000, 0.9000, 0.8000, 0.7000, 0.6000, 0.5000, 0.4000, 0.3000, 0.2000,\n", 231 | " 0.1000])\n" 232 | ] 233 | } 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "source": [ 239 | "# Setting Up Models" 240 | ], 241 | "metadata": { 242 | "id": "eeZ3MSd3adhn" 243 | } 244 | }, 245 | { 246 | "cell_type": "code", 247 | "source": [ 248 | "# MLP for denoising\n", 249 | "class DenoisingMLP(nn.Module):\n", 250 | " def __init__(self):\n", 251 | " super().__init__()\n", 252 | " self.mlp = nn.Sequential(\n", 253 | " nn.Linear(128 + embed_dim, 256), # Input: image features + noisy label\n", 254 | " nn.ReLU(),\n", 255 | " nn.Linear(256, embed_dim) # Output: denoised label\n", 256 | " )\n", 257 | "\n", 258 | " def forward(self, x_features, z_t):\n", 259 | " combined = torch.cat([x_features, z_t], dim=1)\n", 260 | " return self.mlp(combined)" 261 | ], 262 | "metadata": { 263 | "id": "uJ26Tb7Z_QQR" 264 | }, 265 | "execution_count": 7, 266 | "outputs": [] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "source": [ 271 | "# CNN for image features\n", 272 | "class CNN(nn.Module):\n", 273 | " 
def __init__(self):\n", 274 | " super().__init__()\n", 275 | " self.features = nn.Sequential(\n", 276 | " nn.Conv2d(1, 32, 3, 1),\n", 277 | " nn.ReLU(),\n", 278 | " nn.MaxPool2d(2),\n", 279 | " nn.Conv2d(32, 64, 3, 1),\n", 280 | " nn.ReLU(),\n", 281 | " nn.MaxPool2d(2),\n", 282 | " nn.Flatten(),\n", 283 | " nn.Linear(1600, 128) # MNIST: 28x28 → 1600-dim\n", 284 | " )\n", 285 | "\n", 286 | " def forward(self, x):\n", 287 | " return self.features(x)" 288 | ], 289 | "metadata": { 290 | "id": "JDYgzrdr_tsx" 291 | }, 292 | "execution_count": 8, 293 | "outputs": [] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "source": [ 298 | "# Initialize models\n", 299 | "cnn = CNN()\n", 300 | "mlps = nn.ModuleList([DenoisingMLP() for _ in range(T)]) # One MLP per layer\n", 301 | "optimizers = [optim.Adam(mlp.parameters(), lr=lr) for mlp in mlps]" 302 | ], 303 | "metadata": { 304 | "id": "bTXUdHIm_wEw" 305 | }, 306 | "execution_count": 9, 307 | "outputs": [] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "source": [ 312 | "cnn" 313 | ], 314 | "metadata": { 315 | "colab": { 316 | "base_uri": "https://localhost:8080/" 317 | }, 318 | "id": "3c_c5vIBaE1o", 319 | "outputId": "348103a5-19c8-4a99-db33-d5997ab6a314" 320 | }, 321 | "execution_count": 55, 322 | "outputs": [ 323 | { 324 | "output_type": "execute_result", 325 | "data": { 326 | "text/plain": [ 327 | "CNN(\n", 328 | " (features): Sequential(\n", 329 | " (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n", 330 | " (1): ReLU()\n", 331 | " (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 332 | " (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n", 333 | " (4): ReLU()\n", 334 | " (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 335 | " (6): Flatten(start_dim=1, end_dim=-1)\n", 336 | " (7): Linear(in_features=1600, out_features=128, bias=True)\n", 337 | " )\n", 338 | ")" 339 | ] 340 | }, 341 | "metadata": {}, 342 | "execution_count": 55 343 | } 344 
| ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "source": [ 349 | "mlps" 350 | ], 351 | "metadata": { 352 | "colab": { 353 | "base_uri": "https://localhost:8080/" 354 | }, 355 | "id": "uwA8rRCqaF4C", 356 | "outputId": "47eed38e-48ce-4186-a945-d0f0795f1f95" 357 | }, 358 | "execution_count": 56, 359 | "outputs": [ 360 | { 361 | "output_type": "execute_result", 362 | "data": { 363 | "text/plain": [ 364 | "ModuleList(\n", 365 | " (0-9): 10 x DenoisingMLP(\n", 366 | " (mlp): Sequential(\n", 367 | " (0): Linear(in_features=138, out_features=256, bias=True)\n", 368 | " (1): ReLU()\n", 369 | " (2): Linear(in_features=256, out_features=10, bias=True)\n", 370 | " )\n", 371 | " )\n", 372 | ")" 373 | ] 374 | }, 375 | "metadata": {}, 376 | "execution_count": 56 377 | } 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "source": [ 383 | "# Preparing Dataset" 384 | ], 385 | "metadata": { 386 | "id": "TY3BUc9Dah44" 387 | } 388 | }, 389 | { 390 | "cell_type": "code", 391 | "source": [ 392 | "# Load MNIST\n", 393 | "transform = transforms.ToTensor()\n", 394 | "train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n", 395 | "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)" 396 | ], 397 | "metadata": { 398 | "colab": { 399 | "base_uri": "https://localhost:8080/" 400 | }, 401 | "id": "zZAWOkNB_xuF", 402 | "outputId": "3acd5dba-09e9-40a9-cd8c-16e8a87d6ffc" 403 | }, 404 | "execution_count": 10, 405 | "outputs": [ 406 | { 407 | "output_type": "stream", 408 | "name": "stderr", 409 | "text": [ 410 | "100%|██████████| 9.91M/9.91M [00:02<00:00, 4.56MB/s]\n", 411 | "100%|██████████| 28.9k/28.9k [00:00<00:00, 65.1kB/s]\n", 412 | "100%|██████████| 1.65M/1.65M [00:01<00:00, 1.26MB/s]\n", 413 | "100%|██████████| 4.54k/4.54k [00:00<00:00, 6.77MB/s]\n" 414 | ] 415 | } 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "source": [ 421 | "# Training" 422 | ], 423 | "metadata": { 424 | "id": "0lgU9EC0amwQ" 425 
| } 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "source": [ 430 | "---\n", 431 | "\n", 432 | "### **NoProp Training Dry Run (1 Epoch Example)**\n", 433 | "\n", 434 | "\n", 435 | "# 🧪 NoProp Training Dry Run (1 Epoch)\n", 436 | "\n", 437 | "Let's simulate **one epoch** of training with:\n", 438 | "- Batch size = 3\n", 439 | "- Classes = 3 (\"cat\", \"dog\", \"bird\")\n", 440 | "- Diffusion steps (T) = 2\n", 441 | "\n", 442 | "## 📥 Batch Data\n", 443 | "**Input (x):** 3 images \n", 444 | "**Labels (y):** [\"cat\"=0, \"dog\"=1, \"bird\"=2] \n", 445 | "→ One-hot encoded `u_y`:\n", 446 | "```\n", 447 | "tensor([\n", 448 | " [1., 0., 0.], # cat\n", 449 | " [0., 1., 0.], # dog\n", 450 | " [0., 0., 1.] # bird\n", 451 | "])\n", 452 | "```\n", 453 | "\n", 454 | "---\n", 455 | "\n", 456 | "## 🌪️ Forward Diffusion (Corrupt Labels)\n", 457 | "**Noise schedule (α):** [α₀=1.0, α₁=0.6, α₂=0.3] \n", 458 | "1. **t=0:** `z₀ = u_y` (clean)\n", 459 | "\n", 460 | "2. **t=1:** \n", 461 | " `z₁ = √0.6*z₀ + √0.4*noise` ≈\n", 462 | "```\n", 463 | "tensor([\n", 464 | " [0.77, 0.20, 0.03], # noisy cat\n", 465 | " [0.10, 0.85, 0.05], # noisy dog\n", 466 | " [0.05, 0.10, 0.85] # noisy bird\n", 467 | "])\n", 468 | "```\n", 469 | "\n", 470 | "\n", 471 | "3. **t=2:** \n", 472 | " `z₂ = √0.3*z₁ + √0.7*noise` ≈\n", 473 | "\n", 474 | "```\n", 475 | "tensor([\n", 476 | " [0.40, 0.35, 0.25], # very noisy cat\n", 477 | " [0.25, 0.45, 0.30], # very noisy dog\n", 478 | " [0.20, 0.25, 0.55] # very noisy bird\n", 479 | "])\n", 480 | "```\n", 481 | "\n", 482 | "---\n", 483 | "\n", 484 | "**Example for MLP1 (t=1):**\n", 485 | "- **Input:** Image features + `z₁ = [0.77,0.20,0.03]` (noisy cat)\n", 486 | "- **Prediction:** `[0.9, 0.1, 0.0]` (should approach `[1,0,0]`)\n", 487 | "- **Loss:** MSE([0.9,0.1,0.0], [1,0,0]) = (0.01 + 0.01 + 0) / 3 ≈ 0.0067\n", 488 | "\n", 489 | "## 🔄 Weight Updates\n", 490 | "1. Sum losses from all MLPs → `total_loss`\n", 491 | "2. 
Backpropagate → Update all MLPs independently\n", 492 | "\n", 493 | "## 📊 Epoch Output\n", 494 | "`Epoch 1/10 | Avg Loss: 0.85` \n", 495 | "*(Loss decreases as MLPs learn to denoise better)*\n", 496 | "\n", 497 | "---\n", 498 | "\n", 499 | "### **Key Takeaways**\n", 500 | "1. **Diffusion**: Labels are progressively noised from clean → random.\n", 501 | "2. **Specialization**: Each MLP handles a specific noise level.\n", 502 | "3. **Independence**: No backprop between layers → parallel training." 503 | ], 504 | "metadata": { 505 | "id": "HjthYG3JddhO" 506 | } 507 | }, 508 | { 509 | "cell_type": "code", 510 | "source": [ 511 | "# Training loop\n", 512 | "for epoch in range(epochs):\n", 513 | " epoch_loss = 0.0\n", 514 | " batch_count = 0\n", 515 | "\n", 516 | " for x, y in train_loader:\n", 517 | " current_batch_size = x.shape[0]\n", 518 | " u_y = torch.zeros(current_batch_size, embed_dim).scatter_(1, y.unsqueeze(1), 1)\n", 519 | "\n", 520 | " # Forward diffusion(Adding Noise for each 'T')\n", 521 | " z = [u_y]\n", 522 | " for t in range(1, T + 1):\n", 523 | " eps = torch.randn_like(u_y)\n", 524 | " z_t = torch.sqrt(alpha[t-1]) * z[-1] + torch.sqrt(1 - alpha[t-1]) * eps\n", 525 | " z.append(z_t)\n", 526 | "\n", 527 | " # Train MLPs\n", 528 | " x_features = cnn(x)\n", 529 | " losses = []\n", 530 | " for t in range(T):\n", 531 | " u_hat = mlps[t](x_features, z[t+1].detach())\n", 532 | " losses.append(torch.mean((u_hat - u_y) ** 2))\n", 533 | "\n", 534 | " total_loss = sum(losses)\n", 535 | " for opt in optimizers:\n", 536 | " opt.zero_grad()\n", 537 | " total_loss.backward()\n", 538 | " for opt in optimizers:\n", 539 | " opt.step()\n", 540 | "\n", 541 | " epoch_loss += total_loss.item()\n", 542 | " batch_count += 1\n", 543 | "\n", 544 | " # Epoch summary print\n", 545 | " avg_loss = epoch_loss / batch_count\n", 546 | " print(f\"Epoch {epoch+1}/{epochs} | Avg Loss: {avg_loss:.4f}\")\n", 547 | "\n", 548 | "# Final message\n", 549 | "print(\"Training complete!\")" 550 | 
], 551 | "metadata": { 552 | "colab": { 553 | "base_uri": "https://localhost:8080/" 554 | }, 555 | "id": "KtGT_c1H_y9r", 556 | "outputId": "2ae2430a-4e96-4e08-9ce6-1e7bc7b719ed" 557 | }, 558 | "execution_count": 23, 559 | "outputs": [ 560 | { 561 | "output_type": "stream", 562 | "name": "stdout", 563 | "text": [ 564 | "Epoch 1/50 | Avg Loss: 0.0835\n", 565 | "Epoch 2/50 | Avg Loss: 0.0819\n", 566 | "Epoch 3/50 | Avg Loss: 0.0806\n", 567 | "Epoch 4/50 | Avg Loss: 0.0791\n", 568 | "Epoch 5/50 | Avg Loss: 0.0778\n", 569 | "Epoch 6/50 | Avg Loss: 0.0766\n", 570 | "Epoch 7/50 | Avg Loss: 0.0753\n", 571 | "Epoch 8/50 | Avg Loss: 0.0744\n", 572 | "Epoch 9/50 | Avg Loss: 0.0734\n", 573 | "Epoch 10/50 | Avg Loss: 0.0725\n", 574 | "Epoch 11/50 | Avg Loss: 0.0714\n", 575 | "Epoch 12/50 | Avg Loss: 0.0706\n", 576 | "Epoch 13/50 | Avg Loss: 0.0700\n", 577 | "Epoch 14/50 | Avg Loss: 0.0692\n", 578 | "Epoch 15/50 | Avg Loss: 0.0685\n", 579 | "Epoch 16/50 | Avg Loss: 0.0677\n", 580 | "Epoch 17/50 | Avg Loss: 0.0671\n", 581 | "Epoch 18/50 | Avg Loss: 0.0664\n", 582 | "Epoch 19/50 | Avg Loss: 0.0660\n", 583 | "Epoch 20/50 | Avg Loss: 0.0654\n", 584 | "Epoch 21/50 | Avg Loss: 0.0648\n", 585 | "Epoch 22/50 | Avg Loss: 0.0643\n", 586 | "Epoch 23/50 | Avg Loss: 0.0637\n", 587 | "Epoch 24/50 | Avg Loss: 0.0631\n", 588 | "Epoch 25/50 | Avg Loss: 0.0627\n", 589 | "Epoch 26/50 | Avg Loss: 0.0621\n", 590 | "Epoch 27/50 | Avg Loss: 0.0617\n", 591 | "Epoch 28/50 | Avg Loss: 0.0615\n", 592 | "Epoch 29/50 | Avg Loss: 0.0614\n", 593 | "Epoch 30/50 | Avg Loss: 0.0604\n", 594 | "Epoch 31/50 | Avg Loss: 0.0600\n", 595 | "Epoch 32/50 | Avg Loss: 0.0596\n", 596 | "Epoch 33/50 | Avg Loss: 0.0594\n", 597 | "Epoch 34/50 | Avg Loss: 0.0589\n", 598 | "Epoch 35/50 | Avg Loss: 0.0588\n", 599 | "Epoch 36/50 | Avg Loss: 0.0583\n", 600 | "Epoch 37/50 | Avg Loss: 0.0581\n", 601 | "Epoch 38/50 | Avg Loss: 0.0577\n", 602 | "Epoch 39/50 | Avg Loss: 0.0575\n", 603 | "Epoch 40/50 | Avg Loss: 0.0570\n", 604 | "Epoch 
41/50 | Avg Loss: 0.0568\n", 605 | "Epoch 42/50 | Avg Loss: 0.0564\n", 606 | "Epoch 43/50 | Avg Loss: 0.0562\n", 607 | "Epoch 44/50 | Avg Loss: 0.0560\n", 608 | "Epoch 45/50 | Avg Loss: 0.0558\n", 609 | "Epoch 46/50 | Avg Loss: 0.0555\n", 610 | "Epoch 47/50 | Avg Loss: 0.0552\n", 611 | "Epoch 48/50 | Avg Loss: 0.0551\n", 612 | "Epoch 49/50 | Avg Loss: 0.0548\n", 613 | "Epoch 50/50 | Avg Loss: 0.0545\n", 614 | "Training complete!\n" 615 | ] 616 | } 617 | ] 618 | }, 619 | { 620 | "cell_type": "markdown", 621 | "source": [ 622 | "# Inferencing" 623 | ], 624 | "metadata": { 625 | "id": "kGrsvukHaqGX" 626 | } 627 | }, 628 | { 629 | "cell_type": "code", 630 | "source": [ 631 | "# Inference (denoising)\n", 632 | "def predict(x):\n", 633 | " z_t = torch.randn(1, embed_dim) # Start from noise\n", 634 | " x_features = cnn(x.unsqueeze(0))\n", 635 | " for t in reversed(range(T)):\n", 636 | " u_hat = mlps[t](x_features, z_t)\n", 637 | " z_t = torch.sqrt(alpha[t]) * u_hat + torch.sqrt(1 - alpha[t]) * torch.randn_like(u_hat)\n", 638 | " return torch.argmax(z_t) # Final prediction" 639 | ], 640 | "metadata": { 641 | "id": "LINA1yBV_0XH" 642 | }, 643 | "execution_count": 24, 644 | "outputs": [] 645 | }, 646 | { 647 | "cell_type": "code", 648 | "source": [ 649 | "# Test on an example\n", 650 | "x_test, y_test = next(iter(train_loader))\n", 651 | "pred = predict(x_test[0])\n", 652 | "print(f\"Predicted: {pred}, True: {y_test[0]}\")" 653 | ], 654 | "metadata": { 655 | "colab": { 656 | "base_uri": "https://localhost:8080/" 657 | }, 658 | "id": "Qs8kIZcb_0Tp", 659 | "outputId": "8823f70e-d3f6-4eb8-c2a6-bd771136c828" 660 | }, 661 | "execution_count": 48, 662 | "outputs": [ 663 | { 664 | "output_type": "stream", 665 | "name": "stdout", 666 | "text": [ 667 | "Predicted: 1, True: 1\n" 668 | ] 669 | } 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "source": [ 675 | "import matplotlib.pyplot as plt\n", 676 | "import numpy as np\n", 677 | "\n", 678 | "def plot_prediction(x, 
true_label, pred_label, class_names=None):\n", 679 | " \"\"\"\n", 680 | " Plot image with true and predicted labels.\n", 681 | "\n", 682 | " Args:\n", 683 | " x (torch.Tensor): Input image tensor (1, C, H, W)\n", 684 | " true_label (int): Ground truth class index\n", 685 | " pred_label (int): Predicted class index\n", 686 | " class_names (list): Optional list of class names\n", 687 | " \"\"\"\n", 688 | " # Convert tensor to numpy and denormalize if needed\n", 689 | " img = x.squeeze().cpu().numpy()\n", 690 | " if img.min() < 0 or img.max() > 1: # Assuming [0,1] or [-1,1] range\n", 691 | " img = (img - img.min()) / (img.max() - img.min())\n", 692 | "\n", 693 | " # Create figure\n", 694 | " plt.figure(figsize=(6, 3))\n", 695 | "\n", 696 | " # Plot image\n", 697 | " plt.subplot(1, 2, 1)\n", 698 | " plt.imshow(img, cmap='gray' if img.ndim == 2 else None)\n", 699 | " plt.axis('off')\n", 700 | " plt.title('Input Image', pad=10)\n", 701 | "\n", 702 | " # Plot labels\n", 703 | " plt.subplot(1, 2, 2)\n", 704 | " plt.axis('off')\n", 705 | "\n", 706 | " if class_names:\n", 707 | " true_str = f\"True: {class_names[true_label]} ({true_label})\"\n", 708 | " pred_str = f\"Predicted: {class_names[pred_label]} ({pred_label})\"\n", 709 | " else:\n", 710 | " true_str = f\"True label: {true_label}\"\n", 711 | " pred_str = f\"Predicted: {pred_label}\"\n", 712 | "\n", 713 | " plt.text(0.1, 0.7, true_str, fontsize=12, color='green')\n", 714 | " plt.text(0.1, 0.5, pred_str,\n", 715 | " fontsize=12,\n", 716 | " color='red' if true_label != pred_label else 'green')\n", 717 | "\n", 718 | " # Highlight incorrect predictions\n", 719 | " if true_label != pred_label:\n", 720 | " plt.text(0.1, 0.3, \"INCORRECT\", fontsize=14, color='red', weight='bold')\n", 721 | "\n", 722 | " plt.tight_layout()\n", 723 | " plt.show()" 724 | ], 725 | "metadata": { 726 | "id": "XTHv45S5_0RO" 727 | }, 728 | "execution_count": 29, 729 | "outputs": [] 730 | }, 731 | { 732 | "cell_type": "code", 733 | "source": [ 734 
| "# Example usage with MNIST\n", 735 | "class_names = [str(i) for i in range(10)] # ['0', '1', ..., '9']\n", 736 | "x_test, y_test = next(iter(train_loader))\n", 737 | "pred = predict(x_test[0])\n", 738 | "\n", 739 | "plot_prediction(x_test[0],\n", 740 | " y_test[0].item(),\n", 741 | " pred.item(),\n", 742 | " class_names)" 743 | ], 744 | "metadata": { 745 | "colab": { 746 | "base_uri": "https://localhost:8080/", 747 | "height": 316 748 | }, 749 | "id": "WK1nzjCt_0O-", 750 | "outputId": "d3afcc0a-d137-4ad0-dec9-e1a7fe0faa07" 751 | }, 752 | "execution_count": 57, 753 | "outputs": [ 754 | { 755 | "output_type": "display_data", 756 | "data": { 757 | "text/plain": [ 758 | "
" 759 | ], 760 | "image/png": "\n" 761 | }, 762 | "metadata": {} 763 | } 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "source": [], 769 | "metadata": { 770 | "id": "nGEr22BecAXl" 771 | }, 772 | "execution_count": null, 773 | "outputs": [] 774 | } 775 | ] 776 | } 777 | --------------------------------------------------------------------------------