├── README.md
├── MegaScale_Infer_v1.ipynb
└── MegaScale_Infer_v1_1.ipynb

/README.md:
--------------------------------------------------------------------------------
1 | # MegaScale-Infer Prototype
2 | 
3 | ## Overview
4 | 
5 | This repository contains a Python-based prototype implementation of the "MegaScale-Infer" system, inspired by the paper *"MegaScale-Infer: Serving Mixture-of-Experts at Scale with Disaggregated Expert Parallelism"* by Ruidong Zhu et al. (ByteDance Seed & Peking University, 2025). The system is designed for efficient and cost-effective serving of large-scale Mixture-of-Experts (MoE) models by disaggregating attention and feed-forward network (FFN) modules, employing ping-pong pipeline parallelism, and optimizing communication.
6 | 
7 | This prototype adapts these concepts for training an MoE model in a Google Colab environment, simulating disaggregation and communication due to Colab's single-device limitation. It includes a simplified MoE architecture, a synthetic dataset, and a training loop, making it suitable for experimentation and educational purposes.
8 | 
9 | ## Features
10 | 
11 | - **Disaggregated Expert Parallelism**: Simulates the separation of attention and FFN (expert) modules.
12 | - **Ping-Pong Pipeline Parallelism**: Processes micro-batches to mimic overlapping computation and communication.
13 | - **MoE Layer**: Implements a gating mechanism with top-k expert selection.
14 | - **Training Support**: Includes a full training loop with a dummy dataset, loss function, and optimizer.
15 | - **Simulated M2N Communication**: Uses delays to emulate network latency between modules.
16 | 
17 | ## Prerequisites
18 | 
19 | - **Google Colab Account**: A GPU runtime is recommended (e.g., a T4 on the free tier; the bundled notebooks were run on an A100).
20 | - **Python**: Version 3.7+ (pre-installed in Colab).
21 | - **PyTorch**: Version 2.0+ (installable via `pip` if not pre-installed).
22 | 
23 | ## Setup
24 | 
25 | 1. **Open Google Colab**:
26 |    - Visit [Google Colab](https://colab.research.google.com/) and create a new notebook.
27 | 
28 | 2. **Set Runtime to GPU** (optional, for faster training):
29 |    - Click `Runtime` > `Change runtime type` > Select `GPU` > Save.
30 | 
31 | 3. **Install Dependencies**:
32 |    - Run the following command in a Colab cell to ensure PyTorch and NumPy are installed:
33 |    ```bash
34 |    !pip install torch numpy
35 |    ```
36 |    - Colab typically has PyTorch pre-installed (e.g., 2.6.x as of April 2025). Verify with:
37 |    ```python
38 |    import torch
39 |    print(torch.__version__)
40 |    ```
41 | 
42 | 4. **Copy the Script**:
43 |    - Paste the full training script (provided separately) into a Colab code cell.
44 | 
45 | ## Usage
46 | 
47 | 1. **Run the Script**:
48 |    - Execute the cell containing the script by clicking the play button or pressing `Shift + Enter`.
49 |    - The script will:
50 |      - Initialize a `MegaScaleInfer` model with an MoE architecture.
51 |      - Create a synthetic dataset (`DummyDataset`) with random inputs and targets.
52 |      - Train the model for 5 epochs, printing loss and timing metrics.
53 | 
54 | 2. **Sample Output**:
55 |    ```
56 |    Running on: cuda
57 |    Starting training...
58 |    Micro-batch 1/4: Attention time: 0.0123s, Expert time: 0.0456s
59 |    Micro-batch 2/4: Attention time: 0.0118s, Expert time: 0.0432s
60 |    ...
61 |    Epoch [1/5], Batch [0/32], Loss: 1.2345
62 |    Epoch [1/5] completed in 12.34s, Avg Loss: 1.1234, Total Attention Time: 0.4567s, Total Expert Time: 1.6789s
63 |    ...
64 |    Training completed!
65 |    ```
66 | 
67 | 3. **Customize Hyperparameters**:
68 |    - Modify the `main()` function to adjust:
69 |      - `hidden_size`: Dimensionality of input/output (default: 256).
70 |      - `num_experts`: Number of FFN experts (default: 8).
71 |      - `top_k`: Number of experts selected per token (default: 2).
72 |      - `batch_size`: Training batch size (default: 32).
73 |      - `num_micro_batches`: Number of micro-batches for the ping-pong pipeline (default: 4).
74 |      - `num_epochs`: Training epochs (default: 5).
75 | 
76 | ## Code Structure
77 | 
78 | - **`AttentionModule`**: Simplified multi-head attention with QKV projection and a simulated KV cache.
79 | - **`Expert`**: Single FFN expert with two linear layers and a ReLU activation.
80 | - **`MoELayer`**: MoE layer with gating and top-k expert dispatch.
81 | - **`MegaScaleInfer`**: Main model integrating attention, MoE, and pipeline parallelism.
82 | - **`DummyDataset`**: Synthetic dataset for a dummy regression task.
83 | - **`train_megascale_infer`**: Training loop with loss computation and optimization.
84 | - **`simulate_m2n_communication`**: Placeholder that emulates M2N communication latency with a fixed delay.
85 | 
86 | ## Limitations
87 | 
88 | - **Single Device**: Runs on one GPU/CPU in Colab, simulating disaggregation rather than using multiple nodes.
89 | - **Simplified Task**: Uses a dummy regression task; real NLP tasks require additional datasets and tokenization.
90 | - **No Real M2N**: Communication is simulated with `time.sleep` because Colab offers no RDMA or multi-GPU support.
91 | - **Resource Constraints**: Colab's free tier (roughly 15 GB of GPU memory on a T4) limits model size and batch size.
92 | 
93 | ## Extending the Prototype
94 | 
95 | - **Real Dataset**: Replace `DummyDataset` with a real dataset (e.g., the Hugging Face `datasets` library, installed with `!pip install datasets`).
96 | - **Task Enhancement**: Switch to language modeling by adding a vocabulary and `nn.CrossEntropyLoss`.
97 | - **Profiling**: Use `torch.profiler` for detailed performance analysis (see the sketch in the appendix below).
98 | - **Multi-GPU**: Adapt for a multi-GPU setup outside Colab using `torch.distributed`.
99 | 
100 | ## References
101 | 
102 | - Zhu, Ruidong, et al. "MegaScale-Infer: Serving Mixture-of-Experts at Scale with Disaggregated Expert Parallelism." ByteDance Seed & Peking University, 2025.
103 | 
104 | ## License
105 | 
106 | This prototype is provided for educational and experimental purposes under the MIT License.
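
## Appendix: Profiling Sketch

A minimal sketch of the **Profiling** extension above, assuming the classes from `MegaScale_Infer_v1_1.ipynb` (in particular `MegaScaleInfer` and the `device` variable) have already been run in the same Colab session. The hyperparameter values are the notebook defaults; this is an illustrative example, not part of the notebooks.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Move the whole module to the target device so every submodule
# (including output_proj) lives on the same device as the inputs.
model = MegaScaleInfer(hidden_size=256, num_experts=8, top_k=2,
                       intermediate_size=1024, num_micro_batches=4).to(device)
x = torch.randn(16, 32, 256, device=device)  # (batch, seq_len, hidden_size)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

# Profile a single forward pass through the simulated ping-pong pipeline.
with profile(activities=activities, record_shapes=True) as prof:
    output, kv_cache, attn_time, expert_time = model(x)

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```

The profiler table breaks the forward pass down by operator, which makes it easy to see how much time the per-token expert loop in `MoELayer` costs compared with the attention projections.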
107 | 108 | -------------------------------------------------------------------------------- /MegaScale_Infer_v1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "machine_shape": "hm", 8 | "gpuType": "A100" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": { 24 | "colab": { 25 | "base_uri": "https://localhost:8080/" 26 | }, 27 | "id": "4Asakpq46fP7", 28 | "outputId": "da74e4ef-ba6f-4397-970b-738b5b553df8" 29 | }, 30 | "outputs": [ 31 | { 32 | "output_type": "stream", 33 | "name": "stdout", 34 | "text": [ 35 | "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", 36 | "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (2.0.2)\n", 37 | "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", 38 | "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.13.0)\n", 39 | "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n", 40 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", 41 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n", 42 | "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n", 43 | " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 44 | "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n", 45 | " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 46 | "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n", 47 | " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 48 | "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n", 49 | " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 50 | "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n", 51 | " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 52 | "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n", 53 | " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 54 | "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n", 55 | " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 56 | "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n", 57 | " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 58 | "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n", 59 | " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 60 | "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", 61 | "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", 62 | "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in 
/usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", 63 | "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n", 64 | " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 65 | "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", 66 | "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", 67 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", 68 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", 69 | "Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", 70 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 71 | "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", 72 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m72.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 73 | "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", 74 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m57.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 75 | "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", 76 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m50.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 77 | "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", 78 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 79 | "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", 80 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 81 | "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", 82 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m37.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 83 | "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", 84 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 85 | "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", 86 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 87 | "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", 88 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m88.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 89 | 
"\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", 90 | " Attempting uninstall: nvidia-nvjitlink-cu12\n", 91 | " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", 92 | " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", 93 | " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", 94 | " Attempting uninstall: nvidia-curand-cu12\n", 95 | " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", 96 | " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", 97 | " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", 98 | " Attempting uninstall: nvidia-cufft-cu12\n", 99 | " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", 100 | " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", 101 | " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", 102 | " Attempting uninstall: nvidia-cuda-runtime-cu12\n", 103 | " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", 104 | " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", 105 | " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", 106 | " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", 107 | " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", 108 | " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", 109 | " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", 110 | " Attempting uninstall: nvidia-cuda-cupti-cu12\n", 111 | " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", 112 | " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", 113 | " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", 114 | " Attempting uninstall: nvidia-cublas-cu12\n", 115 | " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", 116 | " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", 117 | " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", 118 | " Attempting uninstall: nvidia-cusparse-cu12\n", 119 | " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", 120 | " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", 121 | " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", 122 | " Attempting uninstall: nvidia-cudnn-cu12\n", 123 | " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", 124 | " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", 125 | " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", 126 | " Attempting uninstall: nvidia-cusolver-cu12\n", 127 | " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", 128 | " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", 129 | " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", 130 | "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "!pip install torch numpy" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "source": [ 141 | "import torch\n", 142 | "import torch.nn as nn\n", 143 | "import time\n", 144 | "from torch.cuda import nvtx # For profiling (optional, requires GPU)\n", 145 | "import numpy as np\n", 146 | "\n", 147 | "# Check if GPU is available\n", 148 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 149 | 
"print(f\"Running on: {device}\")\n", 150 | "\n", 151 | "# Simulated Attention Module\n", 152 | "class AttentionModule(nn.Module):\n", 153 | " def __init__(self, hidden_size, num_heads):\n", 154 | " super(AttentionModule, self).__init__()\n", 155 | " self.hidden_size = hidden_size\n", 156 | " self.num_heads = num_heads\n", 157 | " self.qkv_proj = nn.Linear(hidden_size, hidden_size * 3) # Simplified QKV projection\n", 158 | " self.out_proj = nn.Linear(hidden_size, hidden_size)\n", 159 | "\n", 160 | " def forward(self, x, kv_cache=None):\n", 161 | " batch_size, seq_len, _ = x.size()\n", 162 | " qkv = self.qkv_proj(x).view(batch_size, seq_len, 3, self.num_heads, -1)\n", 163 | " q, k, v = qkv.split(1, dim=2) # Split into Q, K, V\n", 164 | " q, k, v = q.squeeze(2), k.squeeze(2), v.squeeze(2)\n", 165 | "\n", 166 | " # Simulated attention (simplified, no real multi-head computation)\n", 167 | " attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.hidden_size ** 0.5)\n", 168 | " attn_weights = torch.softmax(attn_scores, dim=-1)\n", 169 | " attn_output = torch.matmul(attn_weights, v)\n", 170 | " output = self.out_proj(attn_output.view(batch_size, seq_len, -1))\n", 171 | "\n", 172 | " # Update KV cache (simulated)\n", 173 | " new_kv_cache = (k, v) if kv_cache is None else kv_cache\n", 174 | " return output, new_kv_cache\n", 175 | "\n", 176 | "# Simulated FFN Expert Module\n", 177 | "class Expert(nn.Module):\n", 178 | " def __init__(self, hidden_size, intermediate_size):\n", 179 | " super(Expert, self).__init__()\n", 180 | " self.fc1 = nn.Linear(hidden_size, intermediate_size)\n", 181 | " self.fc2 = nn.Linear(intermediate_size, hidden_size)\n", 182 | " self.activation = nn.ReLU()\n", 183 | "\n", 184 | " def forward(self, x):\n", 185 | " return self.fc2(self.activation(self.fc1(x)))\n", 186 | "\n", 187 | "# MoE Layer with Disaggregated Experts\n", 188 | "class MoELayer(nn.Module):\n", 189 | " def __init__(self, hidden_size, num_experts, top_k, intermediate_size):\n", 190 | " super(MoELayer, self).__init__()\n", 191 | " self.hidden_size = hidden_size\n", 192 | " self.num_experts = num_experts\n", 193 | " self.top_k = top_k\n", 194 | " self.gating = nn.Linear(hidden_size, num_experts)\n", 195 | " self.experts = nn.ModuleList([Expert(hidden_size, intermediate_size) for _ in range(num_experts)])\n", 196 | "\n", 197 | " def forward(self, x):\n", 198 | " # Gating to select top-k experts\n", 199 | " gate_scores = self.gating(x)\n", 200 | " topk_values, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)\n", 201 | " gate_weights = torch.softmax(topk_values, dim=-1)\n", 202 | "\n", 203 | " # Simulate expert dispatch (simplified, all on one device here)\n", 204 | " batch_size, seq_len, _ = x.size()\n", 205 | " output = torch.zeros_like(x)\n", 206 | " for b in range(batch_size):\n", 207 | " for s in range(seq_len):\n", 208 | " token = x[b, s:s+1]\n", 209 | " weights = gate_weights[b, s]\n", 210 | " indices = topk_indices[b, s]\n", 211 | " expert_out = torch.zeros(1, self.hidden_size, device=device)\n", 212 | " for i, idx in enumerate(indices):\n", 213 | " expert_out += weights[i] * self.experts[idx](token)\n", 214 | " output[b, s] = expert_out\n", 215 | " return output\n", 216 | "\n", 217 | "# Simulated M2N Communication (placeholder for real network comms)\n", 218 | "def simulate_m2n_communication(data, sender_id, receiver_id, comm_time=0.001):\n", 219 | " # Simulate network latency\n", 220 | " time.sleep(comm_time)\n", 221 | " return data\n", 222 | "\n", 223 | "# MegaScale-Infer Pipeline\n", 224 | 
"class MegaScaleInfer(nn.Module):\n", 225 | " def __init__(self, hidden_size, num_experts, top_k, intermediate_size, num_micro_batches):\n", 226 | " super(MegaScaleInfer, self).__init__()\n", 227 | " self.attention = AttentionModule(hidden_size, num_heads=4).to(device)\n", 228 | " self.moe_layer = MoELayer(hidden_size, num_experts, top_k, intermediate_size).to(device)\n", 229 | " self.num_micro_batches = num_micro_batches\n", 230 | "\n", 231 | " def forward(self, x, kv_cache=None):\n", 232 | " batch_size, seq_len, _ = x.size()\n", 233 | " micro_batch_size = max(1, batch_size // self.num_micro_batches)\n", 234 | " outputs = []\n", 235 | "\n", 236 | " # Ping-pong pipeline simulation\n", 237 | " for i in range(0, batch_size, micro_batch_size):\n", 238 | " micro_batch = x[i:i + micro_batch_size]\n", 239 | "\n", 240 | " # Attention computation\n", 241 | " start_time = time.time()\n", 242 | " attn_output, new_kv_cache = self.attention(micro_batch, kv_cache)\n", 243 | " attn_time = time.time() - start_time\n", 244 | "\n", 245 | " # Simulate sending to expert nodes\n", 246 | " attn_output = simulate_m2n_communication(attn_output, sender_id=0, receiver_id=1)\n", 247 | "\n", 248 | " # Expert computation\n", 249 | " start_time = time.time()\n", 250 | " expert_output = self.moe_layer(attn_output)\n", 251 | " expert_time = time.time() - start_time\n", 252 | "\n", 253 | " # Simulate sending back to attention nodes\n", 254 | " expert_output = simulate_m2n_communication(expert_output, sender_id=1, receiver_id=0)\n", 255 | " outputs.append(expert_output)\n", 256 | "\n", 257 | " # Update KV cache for next iteration (simulated autoregressive)\n", 258 | " kv_cache = new_kv_cache\n", 259 | "\n", 260 | " print(f\"Micro-batch {i // micro_batch_size + 1}/{self.num_micro_batches}: \"\n", 261 | " f\"Attention time: {attn_time:.4f}s, Expert time: {expert_time:.4f}s\")\n", 262 | "\n", 263 | " return torch.cat(outputs, dim=0), kv_cache\n", 264 | "\n", 265 | "# Test the prototype\n", 266 | "def test_megascale_infer():\n", 267 | " # Hyperparameters\n", 268 | " hidden_size = 256\n", 269 | " num_experts = 8\n", 270 | " top_k = 2\n", 271 | " intermediate_size = 1024\n", 272 | " batch_size = 16\n", 273 | " seq_len = 32\n", 274 | " num_micro_batches = 4\n", 275 | "\n", 276 | " # Initialize model\n", 277 | " model = MegaScaleInfer(hidden_size, num_experts, top_k, intermediate_size, num_micro_batches)\n", 278 | "\n", 279 | " # Generate dummy input (simulating decoding phase)\n", 280 | " input_tokens = torch.randn(batch_size, seq_len, hidden_size).to(device)\n", 281 | "\n", 282 | " # Run inference\n", 283 | " print(\"Starting inference...\")\n", 284 | " start_time = time.time()\n", 285 | " output, kv_cache = model(input_tokens)\n", 286 | " total_time = time.time() - start_time\n", 287 | "\n", 288 | " print(f\"\\nInference completed in {total_time:.4f} seconds\")\n", 289 | " print(f\"Output shape: {output.shape}\")\n", 290 | "\n", 291 | "if __name__ == \"__main__\":\n", 292 | " test_megascale_infer()" 293 | ], 294 | "metadata": { 295 | "colab": { 296 | "base_uri": "https://localhost:8080/" 297 | }, 298 | "id": "55zOneGg61A3", 299 | "outputId": "a3cf484d-ddbb-4728-dbd4-4cedec2584cb" 300 | }, 301 | "execution_count": 2, 302 | "outputs": [ 303 | { 304 | "output_type": "stream", 305 | "name": "stdout", 306 | "text": [ 307 | "Running on: cuda\n", 308 | "Starting inference...\n", 309 | "Micro-batch 1/4: Attention time: 0.2124s, Expert time: 0.3070s\n", 310 | "Micro-batch 2/4: Attention time: 0.0007s, Expert time: 0.0695s\n", 311 | 
"Micro-batch 3/4: Attention time: 0.0005s, Expert time: 0.0776s\n", 312 | "Micro-batch 4/4: Attention time: 0.0005s, Expert time: 0.0761s\n", 313 | "\n", 314 | "Inference completed in 0.7829 seconds\n", 315 | "Output shape: torch.Size([16, 32, 256])\n" 316 | ] 317 | } 318 | ] 319 | } 320 | ] 321 | } -------------------------------------------------------------------------------- /MegaScale_Infer_v1_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "machine_shape": "hm", 8 | "gpuType": "A100" 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | }, 17 | "accelerator": "GPU" 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": { 24 | "colab": { 25 | "base_uri": "https://localhost:8080/" 26 | }, 27 | "id": "4Asakpq46fP7", 28 | "outputId": "da74e4ef-ba6f-4397-970b-738b5b553df8" 29 | }, 30 | "outputs": [ 31 | { 32 | "output_type": "stream", 33 | "name": "stdout", 34 | "text": [ 35 | "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", 36 | "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (2.0.2)\n", 37 | "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", 38 | "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.13.0)\n", 39 | "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n", 40 | "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", 41 | "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.3.2)\n", 42 | "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n", 43 | " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 44 | "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n", 45 | " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 46 | "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n", 47 | " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 48 | "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n", 49 | " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 50 | "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n", 51 | " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 52 | "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n", 53 | " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 54 | "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n", 55 | " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 56 | "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n", 57 | " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 58 | "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n", 59 | " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", 60 | "Requirement already satisfied: 
nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", 61 | "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", 62 | "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", 63 | "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n", 64 | " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", 65 | "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", 66 | "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", 67 | "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", 68 | "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", 69 | "Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", 70 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 71 | "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", 72 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m72.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 73 | "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", 74 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m57.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 75 | "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", 76 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m50.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 77 | "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", 78 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 79 | "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", 80 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 81 | "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", 82 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m37.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 83 | "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", 84 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 85 | "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", 86 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n", 87 | "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", 88 | "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m88.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", 89 | "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", 90 | " Attempting uninstall: nvidia-nvjitlink-cu12\n", 91 | " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", 92 | " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", 93 | " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", 94 | " Attempting uninstall: nvidia-curand-cu12\n", 95 | " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", 96 | " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", 97 | " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", 98 | " Attempting uninstall: nvidia-cufft-cu12\n", 99 | " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", 100 | " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", 101 | " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", 102 | " Attempting uninstall: nvidia-cuda-runtime-cu12\n", 103 | " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", 104 | " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", 105 | " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", 106 | " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", 107 | " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", 108 | " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", 109 | " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", 110 | " Attempting uninstall: nvidia-cuda-cupti-cu12\n", 111 | " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", 112 | " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", 113 | " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", 114 | " Attempting uninstall: nvidia-cublas-cu12\n", 115 | " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", 116 | " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", 117 | " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", 118 | " Attempting uninstall: nvidia-cusparse-cu12\n", 119 | " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", 120 | " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", 121 | " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", 122 | " Attempting uninstall: nvidia-cudnn-cu12\n", 123 | " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", 124 | " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", 125 | " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", 126 | " Attempting uninstall: nvidia-cusolver-cu12\n", 127 | " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", 128 | " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", 129 | " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", 130 | "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "!pip install torch numpy" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "source": [ 141 | "import 
torch\n", 142 | "import torch.nn as nn\n", 143 | "import torch.optim as optim\n", 144 | "import time\n", 145 | "import numpy as np\n", 146 | "from torch.utils.data import Dataset, DataLoader\n", 147 | "\n", 148 | "# Check if GPU is available\n", 149 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 150 | "print(f\"Running on: {device}\")\n", 151 | "\n", 152 | "# Simulated Attention Module\n", 153 | "class AttentionModule(nn.Module):\n", 154 | " def __init__(self, hidden_size, num_heads):\n", 155 | " super(AttentionModule, self).__init__()\n", 156 | " self.hidden_size = hidden_size\n", 157 | " self.num_heads = num_heads\n", 158 | " self.qkv_proj = nn.Linear(hidden_size, hidden_size * 3) # Q, K, V projection\n", 159 | " self.out_proj = nn.Linear(hidden_size, hidden_size)\n", 160 | "\n", 161 | " def forward(self, x, kv_cache=None):\n", 162 | " batch_size, seq_len, _ = x.size()\n", 163 | " qkv = self.qkv_proj(x).view(batch_size, seq_len, 3, self.num_heads, -1)\n", 164 | " q, k, v = qkv.split(1, dim=2)\n", 165 | " q, k, v = q.squeeze(2), k.squeeze(2), v.squeeze(2)\n", 166 | "\n", 167 | " # Simplified attention computation\n", 168 | " attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.hidden_size ** 0.5)\n", 169 | " attn_weights = torch.softmax(attn_scores, dim=-1)\n", 170 | " attn_output = torch.matmul(attn_weights, v)\n", 171 | " output = self.out_proj(attn_output.view(batch_size, seq_len, -1))\n", 172 | "\n", 173 | " # Update KV cache (simulated)\n", 174 | " new_kv_cache = (k, v) if kv_cache is None else kv_cache\n", 175 | " return output, new_kv_cache\n", 176 | "\n", 177 | "# Simulated FFN Expert Module\n", 178 | "class Expert(nn.Module):\n", 179 | " def __init__(self, hidden_size, intermediate_size):\n", 180 | " super(Expert, self).__init__()\n", 181 | " self.fc1 = nn.Linear(hidden_size, intermediate_size)\n", 182 | " self.fc2 = nn.Linear(intermediate_size, hidden_size)\n", 183 | " self.activation = nn.ReLU()\n", 184 | "\n", 185 | " def forward(self, x):\n", 186 | " return self.fc2(self.activation(self.fc1(x)))\n", 187 | "\n", 188 | "# MoE Layer with Disaggregated Experts\n", 189 | "class MoELayer(nn.Module):\n", 190 | " def __init__(self, hidden_size, num_experts, top_k, intermediate_size):\n", 191 | " super(MoELayer, self).__init__()\n", 192 | " self.hidden_size = hidden_size\n", 193 | " self.num_experts = num_experts\n", 194 | " self.top_k = top_k\n", 195 | " self.gating = nn.Linear(hidden_size, num_experts)\n", 196 | " self.experts = nn.ModuleList([Expert(hidden_size, intermediate_size) for _ in range(num_experts)])\n", 197 | "\n", 198 | " def forward(self, x):\n", 199 | " # Gating to select top-k experts\n", 200 | " gate_scores = self.gating(x)\n", 201 | " topk_values, topk_indices = torch.topk(gate_scores, self.top_k, dim=-1)\n", 202 | " gate_weights = torch.softmax(topk_values, dim=-1)\n", 203 | "\n", 204 | " # Expert dispatch (simplified for single device)\n", 205 | " batch_size, seq_len, _ = x.size()\n", 206 | " output = torch.zeros_like(x)\n", 207 | " for b in range(batch_size):\n", 208 | " for s in range(seq_len):\n", 209 | " token = x[b, s:s+1]\n", 210 | " weights = gate_weights[b, s]\n", 211 | " indices = topk_indices[b, s]\n", 212 | " expert_out = torch.zeros(1, self.hidden_size, device=device)\n", 213 | " for i, idx in enumerate(indices):\n", 214 | " expert_out += weights[i] * self.experts[idx](token)\n", 215 | " output[b, s] = expert_out\n", 216 | " return output\n", 217 | "\n", 218 | "# Simulated M2N Communication\n", 219 | "def 
simulate_m2n_communication(data, sender_id, receiver_id, comm_time=0.001):\n", 220 | " time.sleep(comm_time) # Simulate network latency\n", 221 | " return data\n", 222 | "\n", 223 | "# MegaScale-Infer Model\n", 224 | "class MegaScaleInfer(nn.Module):\n", 225 | " def __init__(self, hidden_size, num_experts, top_k, intermediate_size, num_micro_batches):\n", 226 | " super(MegaScaleInfer, self).__init__()\n", 227 | " self.attention = AttentionModule(hidden_size, num_heads=4).to(device)\n", 228 | " self.moe_layer = MoELayer(hidden_size, num_experts, top_k, intermediate_size).to(device)\n", 229 | " self.num_micro_batches = num_micro_batches\n", 230 | " self.output_proj = nn.Linear(hidden_size, hidden_size) # Final projection for prediction\n", 231 | "\n", 232 | " def forward(self, x, kv_cache=None):\n", 233 | " batch_size, seq_len, _ = x.size()\n", 234 | " micro_batch_size = max(1, batch_size // self.num_micro_batches)\n", 235 | " outputs = []\n", 236 | " total_attn_time = 0\n", 237 | " total_expert_time = 0\n", 238 | "\n", 239 | " # Ping-pong pipeline simulation\n", 240 | " for i in range(0, batch_size, micro_batch_size):\n", 241 | " micro_batch = x[i:i + micro_batch_size]\n", 242 | "\n", 243 | " # Attention computation\n", 244 | " start_time = time.time()\n", 245 | " attn_output, new_kv_cache = self.attention(micro_batch, kv_cache)\n", 246 | " attn_time = time.time() - start_time\n", 247 | " total_attn_time += attn_time\n", 248 | "\n", 249 | " # Simulate sending to expert nodes\n", 250 | " attn_output = simulate_m2n_communication(attn_output, sender_id=0, receiver_id=1)\n", 251 | "\n", 252 | " # Expert computation\n", 253 | " start_time = time.time()\n", 254 | " expert_output = self.moe_layer(attn_output)\n", 255 | " expert_time = time.time() - start_time\n", 256 | " total_expert_time += expert_time\n", 257 | "\n", 258 | " # Simulate sending back to attention nodes\n", 259 | " expert_output = simulate_m2n_communication(expert_output, sender_id=1, receiver_id=0)\n", 260 | " outputs.append(expert_output)\n", 261 | "\n", 262 | " kv_cache = new_kv_cache\n", 263 | "\n", 264 | " print(f\"Micro-batch {i // micro_batch_size + 1}/{self.num_micro_batches}: \"\n", 265 | " f\"Attention time: {attn_time:.4f}s, Expert time: {expert_time:.4f}s\")\n", 266 | "\n", 267 | " output = torch.cat(outputs, dim=0)\n", 268 | " output = self.output_proj(output)\n", 269 | " return output, kv_cache, total_attn_time, total_expert_time\n", 270 | "\n", 271 | "# Simulated Dataset\n", 272 | "class DummyDataset(Dataset):\n", 273 | " def __init__(self, num_samples, seq_len, hidden_size):\n", 274 | " self.num_samples = num_samples\n", 275 | " self.seq_len = seq_len\n", 276 | " self.hidden_size = hidden_size\n", 277 | " self.data = torch.randn(num_samples, seq_len, hidden_size)\n", 278 | " self.labels = torch.randn(num_samples, seq_len, hidden_size) # Dummy regression targets\n", 279 | "\n", 280 | " def __len__(self):\n", 281 | " return self.num_samples\n", 282 | "\n", 283 | " def __getitem__(self, idx):\n", 284 | " return self.data[idx], self.labels[idx]\n", 285 | "\n", 286 | "# Training Function\n", 287 | "def train_megascale_infer(model, train_loader, optimizer, criterion, num_epochs):\n", 288 | " model.train()\n", 289 | " for epoch in range(num_epochs):\n", 290 | " epoch_loss = 0\n", 291 | " epoch_start_time = time.time()\n", 292 | " total_attn_time = 0\n", 293 | " total_expert_time = 0\n", 294 | "\n", 295 | " for batch_idx, (inputs, targets) in enumerate(train_loader):\n", 296 | " inputs, targets = inputs.to(device), 
targets.to(device)\n", 297 | " optimizer.zero_grad()\n", 298 | "\n", 299 | " # Forward pass\n", 300 | " outputs, kv_cache, attn_time, expert_time = model(inputs)\n", 301 | " loss = criterion(outputs, targets)\n", 302 | " total_attn_time += attn_time\n", 303 | " total_expert_time += expert_time\n", 304 | "\n", 305 | " # Backward pass and optimization\n", 306 | " loss.backward()\n", 307 | " optimizer.step()\n", 308 | "\n", 309 | " epoch_loss += loss.item()\n", 310 | " if batch_idx % 10 == 0:\n", 311 | " print(f\"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_loader)}], \"\n", 312 | " f\"Loss: {loss.item():.4f}\")\n", 313 | "\n", 314 | " avg_loss = epoch_loss / len(train_loader)\n", 315 | " epoch_time = time.time() - epoch_start_time\n", 316 | " print(f\"Epoch [{epoch+1}/{num_epochs}] completed in {epoch_time:.2f}s, \"\n", 317 | " f\"Avg Loss: {avg_loss:.4f}, \"\n", 318 | " f\"Total Attention Time: {total_attn_time:.4f}s, \"\n", 319 | " f\"Total Expert Time: {total_expert_time:.4f}s\")\n", 320 | "\n", 321 | "# Main Execution\n", 322 | "def main():\n", 323 | " # Hyperparameters\n", 324 | " hidden_size = 256\n", 325 | " num_experts = 8\n", 326 | " top_k = 2\n", 327 | " intermediate_size = 1024\n", 328 | " batch_size = 32\n", 329 | " seq_len = 32\n", 330 | " num_micro_batches = 4\n", 331 | " num_samples = 1000\n", 332 | " num_epochs = 5\n", 333 | " learning_rate = 0.001\n", 334 | "\n", 335 | " # Initialize model, dataset, and training components\n", 336 | " model = MegaScaleInfer(hidden_size, num_experts, top_k, intermediate_size, num_micro_batches)\n", 337 | " dataset = DummyDataset(num_samples, seq_len, hidden_size)\n", 338 | " train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", 339 | " optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n", 340 | " criterion = nn.MSELoss() # Mean Squared Error for regression task\n", 341 | "\n", 342 | " # Train the model\n", 343 | " print(\"Starting training...\")\n", 344 | " train_megascale_infer(model, train_loader, optimizer, criterion, num_epochs)\n", 345 | " print(\"Training completed!\")\n", 346 | "\n", 347 | "if __name__ == \"__main__\":\n", 348 | " main()" 349 | ], 350 | "metadata": { 351 | "id": "55zOneGg61A3" 352 | }, 353 | "execution_count": null, 354 | "outputs": [] 355 | } 356 | ] 357 | } --------------------------------------------------------------------------------