├── .gitignore ├── README.md ├── assets └── results │ ├── FMA.png │ ├── FloatvsSuperfloat.jpg │ ├── LLM Distribution │ ├── Llama-2-7b-chat-hf_layers.png │ ├── Llama-2-7b-chat-hf_parameters.png │ ├── MiniCPM-V-2_6_layers.png │ ├── MiniCPM-V-2_6_parameters.png │ ├── Mistral-7B-v0.1_layers.png │ ├── Mistral-7B-v0.1_parameters.png │ ├── gte-Qwen2-7B-instruct_layers.png │ ├── gte-Qwen2-7B-instruct_parameters.png │ ├── japanese-stablelm-base-beta-7b_layers.png │ └── japanese-stablelm-base-beta-7b_parameters.png │ ├── Qwen2 │ ├── RULE1-Max_Context_Length_Barrier.png │ ├── SF11.png │ ├── SF16.png │ ├── SF4.png │ └── SF8.png │ ├── YOLOv5 Weight Distribution │ ├── yolov5l.png │ ├── yolov5l6.png │ ├── yolov5m.png │ ├── yolov5m6.png │ ├── yolov5n.png │ ├── yolov5n6.png │ ├── yolov5s.png │ ├── yolov5s6.png │ ├── yolov5x.png │ └── yolov5x6.png │ ├── YOLOv7 Weight Distribution │ ├── layer_status_plot_yolov7-d6.png │ ├── layer_status_plot_yolov7-e6.png │ ├── layer_status_plot_yolov7-e6e.png │ ├── layer_status_plot_yolov7-w6.png │ ├── layer_status_plot_yolov7.png │ ├── layer_status_plot_yolov7x.png │ ├── yolov7-d6.pt.png │ ├── yolov7-e6.pt.png │ ├── yolov7-e6e.pt.png │ ├── yolov7-w6.pt.png │ ├── yolov7.pt.png │ └── yolov7x.pt.png │ ├── activation_unit.png │ ├── blockplan.png │ ├── cycle_count_logic.png │ ├── gtkwave_fma_comparison.png │ ├── gtkwave_superfloat.png │ ├── hardware architecture.png │ ├── improved_dadda.jpg │ ├── isa_integrated_floorplan.png │ ├── results_viewer.ipynb │ └── shift_register.png ├── docs └── technical_reports │ ├── Major Project.docx │ ├── SuperFloat PoC.pdf │ └── Superfloat Whitepaper.pdf └── src ├── modal ├── better.py ├── modal_eval.py ├── modal_fasteropt.py ├── modal_lth.py └── train-custom-gpt.py ├── test ├── emulate_sf.py ├── emulate_sf_dec.py └── sf_tensors.py ├── verilog ├── activationUnit_16bit.sv └── shiftRegister_16bit.sv ├── wasq ├── wasq_eval.py ├── wasq_fasteropt.py ├── wasq_fpm.py ├── wasq_inference.py ├── wasq_lth.py ├── wasq_mplth.py ├── wasq_opt.py ├── wasq_sa_mplth.py └── wasq_vanilla.py └── website ├── backend ├── parallel_backend.py ├── sequential_backend.py └── testing_backend.py └── frontend ├── .gitignore ├── package-lock.json ├── package.json ├── public ├── favicon.ico ├── index.html ├── logo192.png ├── logo512.png ├── manifest.json └── robots.txt ├── src ├── App.css ├── App.js ├── App.test.js ├── README.md ├── components │ ├── ChatInterface.jsx │ └── LandingPage.jsx ├── index.css ├── index.js ├── logo.svg ├── reportWebVitals.js └── setupTests.js └── tailwind.config.js /.gitignore: -------------------------------------------------------------------------------- 1 | models--* 2 | .locks/ 3 | .env 4 | frontend/node_modules/ 5 | *.exe 6 | __pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **Superfloat: Accelerators for AI on Edge. Reimagined.** 2 | 3 | This repository contains the code, methods, and scripts for implementing **Superfloat Quantization** and **Lottery Ticket Hypothesis (LTH)** techniques for optimizing neural networks. The repository focuses on various quantization algorithms, model evaluations, and fine-tuning techniques to minimize perplexity and stabilize activations. 4 | 5 | --- 6 | 7 | ## **What is Superfloat?** 8 | 9 | **Superfloat** is a custom quantization algorithm that operates with a **scalable precision format**. 
Unlike traditional floating-point systems (IEEE-754), Superfloat removes the exponent entirely and represents each value with just a **sign** and a **fraction (mantissa)**.
10 | 
11 | ### **Key Features**:
12 | 1. **Sign-Mantissa Representation**:
13 |    - Superfloat (SFx) uses `1 bit` for the **sign** and allocates the remaining `x-1 bits` to the **mantissa (fraction)**.
14 |    - For instance, in **SF16**:
15 |      - 1 bit → Sign
16 |      - 15 bits → Mantissa
17 | 
18 | 2. **Clamping Range**:
19 |    - All values are clamped within the range `[-1, 1]`. This ensures activation and parameter stability, reducing the likelihood of exploding or vanishing gradients.
20 | 
21 | 3. **Bit-width Flexibility**:
22 |    - Superfloat supports variable precision formats, scaling between **4-bit and 16-bit**:
23 |      - Lower precision (e.g., **SF4**) → Faster computation, reduced model size.
24 |      - Higher precision (e.g., **SF16**) → Improved accuracy while maintaining efficient quantization.
25 | 
26 | 4. **Gradient and Activation Capping**:
27 |    - To stabilize the training process, gradients and activations are **capped** at -1 and +1.
28 | 
29 | ### **Advantages of Superfloat**:
30 | - Reduces **bit-width** (and therefore memory footprint) without a significant drop in accuracy.
31 | - Reduces **computational complexity** compared to traditional floating-point representations.
32 | - Allows adaptive scaling for diverse quantization requirements.
33 | 
34 | ---
35 | 
36 | **Conversion FP32 → SF(4-16)**
37 | 
38 | A standard 32-bit floating-point number is converted into the custom Superfloat representation with a variable-sized mantissa as follows:
39 | 
40 | - **Clamp Input Range** – The input value is restricted to the range (-1, 1). If the value exceeds this range, it is set to the maximum representable value.
41 | 
42 | - **Extract Sign Bit** – The sign bit is determined and stored separately, while the value is converted to its absolute form.
43 | 
44 | - **Compute Mantissa** – The fractional value is scaled by `2^mantissa_bits` to convert it into an integer representation.
45 | 
46 | - **Bit Packing** – The sign bit and mantissa are arranged into a custom format, with the mantissa shifted to fit within a float-sized bit structure.
47 | 
48 | - **Bitwise Reinterpretation** – The constructed bit pattern is reinterpreted as a floating-point number and returned. (A short code sketch of this conversion appears just after the WASQ overview below.)
49 | 
50 | ---
51 | ## **What is WASQ?**
52 | 
53 | **WASQ** stands for **Weight and Activation Superfloat Quantization**. It is a **hybrid quantization framework** that leverages Superfloat precision to optimize both model weights and activations.
54 | 
55 | ### **Key Characteristics of WASQ**:
56 | 1. **Weight Quantization**:
57 |    - Model weights are converted to **Superfloat precision** (SFx) without requiring complex floating-point bookkeeping such as exponent handling.
58 | 
59 | 2. **Activation Quantization**:
60 |    - Activations are clamped and quantized within a stable range to prevent issues such as exploding activations.
61 | 
62 | 3. **Optimization Algorithms**:
63 |    - WASQ includes customized algorithms like **WASQ OPT** and the **Full Parameter Method (FPM)** to balance accuracy and convergence speed.
64 |    - New: the **Simulated Annealing Multi-Prize Lottery Ticket (SA-MPLTH)** algorithm for healing quantized models.
65 | 
66 | 4. **Scalability**:
67 |    - WASQ supports **multi-bit quantization** (from 4-bit to 16-bit), making it adaptable to different deployment environments, such as:
68 |      - **Edge devices** → Lower precision for speed and memory savings.
69 |      - **Servers** → Higher precision for accuracy-sensitive tasks.
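To make the FP32 → SFx conversion concrete, here is a minimal, self-contained sketch of the quantize/dequantize round-trip for a sign-plus-fraction format clamped to `[-1, 1]`. It follows the same recipe as the `Superfloat` helper classes under `src/`, but the function name and the choice of rounding are illustrative assumptions, not the repository's exact implementation:

```python
import torch

def sf_round_trip(tensor: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """Quantize a tensor to an SFx-style format and decode it back.

    The format spends 1 bit on the sign and (bits - 1) bits on the
    fraction; there is no exponent field, so every representable value
    lies on a uniform grid inside (-1, 1).
    """
    mantissa_bits = bits - 1
    max_val = 1 - 2 ** -mantissa_bits        # largest representable magnitude
    steps = 2 ** mantissa_bits - 1           # number of fraction levels
    clamped = torch.clamp(tensor, -max_val, max_val)
    quantized = torch.round(clamped.abs() * steps / max_val)   # integer fraction
    decoded = quantized * max_val / steps                      # back to float
    return torch.sign(clamped) * decoded

# Example: round-trip a small weight tensor through SF8
w = torch.randn(4, 4) * 0.5
w_sf8 = sf_round_trip(w, bits=8)
print((w - w_sf8).abs().max())   # worst-case quantization error
```

In the actual scripts the encoded integer form is kept for storage, but for experiments the round-trip above (often called fake quantization) is enough to measure the accuracy impact of a given bit-width.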
70 | 71 | ### **WASQ + Lottery Ticket Hypothesis (LTH)** 72 | WASQ integrates **LTH** to identify specific weights that are critical for maintaining model performance after quantization. By fine-tuning only the **essential weights**, WASQ reduces computational overhead while achieving high accuracy. 73 | 74 | --- 75 | 76 | ## **Files Overview** 77 | 78 | 1. **[Quant_Dequant.ipynb](Quant_Dequant.ipynb)** 79 | Contains the implementation of basic Superfloat quantization and dequantization functions. 80 | 81 | 2. **[sf16quant.ipynb](sf16quant.ipynb)** 82 | Builds on Superfloat quantization functions, specifically for **SF16 precision**. 83 | 84 | 3. **[lth_analysis.py](lth_analysis.py)** 85 | Analyzes **activation magnitude distribution** for **LTH**. It compares activation patterns of original and quantized models. 86 | 87 | 4. **[lth_trainer.py](lth_trainer.py)** 88 | The **LTH trainer** script for fine-tuning models based on the Lottery Ticket Hypothesis technique. 89 | 90 | 5. **[wasq_eval.py](wasq_eval.py)** 91 | Calculates **perplexity** for a series of models, grouped by context length, epochs, or model species. 92 | 93 | 6. **[wasq_inference.py](wasq_inference.py)** 94 | Provides inference capabilities for **individual** or **multiple WASQ-quantized models**. 95 | 96 | 7. **[wasq_fasteropt.py](wasq_fasteropt.py)** 97 | An optimized version of the **OPT algorithm** implemented in `wasq_opt.py`. 98 | 99 | 8. **[wasq_opt.py](wasq_opt.py)** 100 | Core implementation of the WASQ OPT algorithm. 101 | 102 | 9. **[wasq_fpm.py](wasq_fpm.py)** 103 | Implements the **Full Parameter Method** (FPM) for WASQ quantization. 104 | 105 | 10. **[wasq_vanilla.py](wasq_vanilla.py)** 106 | Baseline implementation of the **Vanilla algorithm** for WASQ. 107 | 108 | 11. **[sa_mplth.py](sa_mplth.py)** 109 | New: Implements Simulated Annealing Multi-Prize Lottery Ticket Hypothesis for healing quantized models. 110 | 111 | 12. **[assets/results](assets/results/)** 112 | Contains outputs of model tests, perplexity scores, and supplementary studies. 113 | 114 | --- 115 | 116 | ## **Scaling Laws** 117 | 118 | ### 1. **Maximum Context Length Barrier - Perplexity Factor** 119 | For a model with `n` parameters, a calibration dataset of maximum input length `c`, **three-shot quantization fine-tuning**, and Superfloat precision bit `x` (where `4 ≤ x ≤ 16`): 120 | 121 | \[ 122 | P = f(n, c, 3, x) 123 | \] 124 | 125 | - **Lower P** indicates better model understanding and calibration performance. 126 | 127 | --- 128 | 129 | ### 2. **Maximum Neuron Spread Factor** 130 | This scaling law uses the **Lottery Ticket Hypothesis** for WASQ quantization to stabilize activations: 131 | 132 | 1. Perform a forward pass using the **original model** and record the average magnitudes of activations across all layers. 133 | 2. Perform the same for the **vanilla quantized model** to observe how quantization impacts activation magnitudes. 134 | 3. Rank layers based on the **difference in activation magnitudes** between the original and quantized models. 135 | 4. Identify and **cluster layers** with significant deviations to address issues like exploding/vanishing activations. 136 | 5. Fine-tune or analyze these clusters to ensure stable activations and minimal performance degradation. 
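A minimal sketch of steps 1–3 above, recording per-layer mean activation magnitudes with forward hooks and ranking layers by the original-vs-quantized gap, could look like the following. The `original_model`, `quantized_model`, and `batch` names are placeholders for whatever model pair and calibration batch is being compared; this is an assumed illustration of the procedure, not the repository's analysis script:

```python
import torch

def mean_activation_magnitudes(model, batch):
    """One forward pass; returns {layer_name: mean |activation|}."""
    stats, hooks = {}, []

    def make_hook(name):
        def hook(_module, _inputs, output):
            if isinstance(output, torch.Tensor):
                stats[name] = output.detach().abs().mean().item()
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            hooks.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(**batch)  # batch: dict of input tensors (e.g. input_ids, attention_mask)
    for h in hooks:
        h.remove()
    return stats

# Rank layers by how far quantization shifted their activation magnitudes.
orig_stats = mean_activation_magnitudes(original_model, batch)
quant_stats = mean_activation_magnitudes(quantized_model, batch)
deviation = {name: abs(orig_stats[name] - quant_stats.get(name, 0.0))
             for name in orig_stats}
worst_layers = sorted(deviation, key=deviation.get, reverse=True)[:10]
print(worst_layers)  # candidate cluster for LTH-style fine-tuning
```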
137 | 
138 | The law establishes that the **maximum neuron spread** (the region targeted for fine-tuning/updating) is a function of:
139 | - **Activation magnitude**
140 | - **Activation fracture** (the spread of how a weight affects neighboring weights during backpropagation)
141 | 
142 | ---
143 | 
144 | ## **Quantization Algorithms**
145 | 
146 | The repository explores four quantization approaches:
147 | 
148 | 1. **Superfloat Precision**: Custom sign-plus-mantissa precision with no exponent field, clamped within `[-1, 1]` for stability.
149 | 2. **WASQ OPT**: Optimized quantization with faster convergence.
150 | 3. **Full Parameter Method (FPM)**: Retrains all parameters for higher accuracy.
151 | 4. **SA-MPLTH**: New simulated annealing approach for healing quantized models.
152 | 
153 | ---
154 | 
155 | ## **Usage**
156 | 
157 | ### **Setup**
158 | Clone the repository and install dependencies:
159 | 
160 | ```bash
161 | git clone https://github.com/aloshdenny/superfloat
162 | cd superfloat
163 | pip install -r requirements.txt
164 | ```
165 | 
166 | ### **Running Scripts**
167 | 
168 | - Train with **LTH**:
169 | ```bash
170 | python lth_trainer.py
171 | ```
172 | 
173 | - Evaluate Perplexity:
174 | ```bash
175 | python wasq_eval.py
176 | ```
177 | 
178 | - Perform Inference:
179 | ```bash
180 | python wasq_inference.py
181 | ```
182 | 
183 | - Run SA-MPLTH:
184 | ```bash
185 | python sa_mplth.py
186 | ```
187 | 
188 | ---
189 | 
190 | ## **assets/Results**
191 | 
192 | The assets/results folder contains:
193 | - **Perplexity scores** for different model configurations.
194 | - **Activation magnitude comparisons** before and after quantization.
195 | - Supplementary studies showcasing model performance.
196 | 
197 | ---
198 | 
199 | ## **Chip-1: Atreides**
200 | 
201 | Atreides is an ASIC accelerator designed specifically for Superfloat-based inference. We redesigned the systolic array to support SFx operations, adopting a modded RV32 ISA and faster Fused-Multiply-Adder (FMA) units. The end goal is not convention—it's breaking the rules of computing and physics to achieve faster inference, lower memory consumption, and the same accuracy.
202 | 
203 | ## FMA in Atreides
204 | 
205 | Below is an image showing the FMA in Atreides:
206 | 
207 | ![FMA](assets/results/FMA.png)
208 | 
209 | ## Expanded View of Chip-1's Architecture
210 | 
211 | An expanded view of Chip-1's architecture includes non-unified memory blocks (subject to unification), cache, a control store (modded RV32 ISA), and an array of FMAs:
212 | 
213 | ![Chip-1 Architecture](assets/results/hardware%20architecture.png)
214 | 
215 | ### FPGA Functional Units Design
216 | 
217 | #### 1. 8 x 16-bit Shift Register (simplified)
218 | 
219 | ![FPGA Floorplan](assets/results/shift_register.png)
220 | 
221 | #### 2. Activation Unit (simplified)
222 | 
223 | ![FPGA Floorplan](assets/results/activation_unit.png)
224 | 
225 | #### 3.
Cycle Count Logic
226 | 
227 | ![FPGA Floorplan](assets/results/cycle_count_logic.png)
228 | 
229 | ## Instruction Set
230 | 
231 | The current instruction set for the FPGA architecture is shown below:
232 | 
233 | | Instruction | Opcode(4) | Op 1(4) | Op 2(4) | Op 3(4) | Description |
234 | |-------------|-----------|---------|---------|---------|------------------------------------------------------------------------------------------|
235 | | STR | 0001 | addr | row | col | Stores the matrix data from the activation unit buffer into the specified address in memory |
236 | | LDR | 0010 | addr | row | col | Loads the matrix at addr into the Row Shift Buffer |
237 | | LDC | 0011 | addr | row | col | Loads the matrix at addr into the Column Shift Buffer |
238 | | MATMUL | 0100 | - | - | - | Performs matrix multiplication using data in the Row Shift Buffer and Column Shift Buffer |
239 | | RELU | 0101 | - | - | - | Applies the ReLU activation function to the Systolic Array output |
240 | | LIN | 0110 | - | - | - | Applies the Linear activation function to the Systolic Array output |
241 | | NOP | 0000 | - | - | - | No Operation |
242 | 
243 | ### FPGA floorplan (ISA integrated)
244 | 
245 | The FPGA floorplan with the instruction set integrated is shown below:
246 | 
247 | ![FPGA Floorplan](assets/results/isa_integrated_floorplan.png)
248 | 
249 | ---
250 | 
251 | ## **Contributions**
252 | 
253 | Contributions are welcome! Feel free to open issues or submit pull requests.
254 | 
255 | ---
256 | 
257 | ## **Sponsors**
258 | 
259 | We would like to thank our sponsors for their support:
260 | 
261 | 
262 | 263 | 264 | 265 | 266 |
267 | 268 | --- 269 | 270 | ## **License** 271 | 272 | This project is licensed under the MIT License. 273 | -------------------------------------------------------------------------------- /assets/results/FMA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/FMA.png -------------------------------------------------------------------------------- /assets/results/FloatvsSuperfloat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/FloatvsSuperfloat.jpg -------------------------------------------------------------------------------- /assets/results/LLM Distribution/Llama-2-7b-chat-hf_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Llama-2-7b-chat-hf_layers.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/Llama-2-7b-chat-hf_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Llama-2-7b-chat-hf_parameters.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/MiniCPM-V-2_6_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/MiniCPM-V-2_6_layers.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/MiniCPM-V-2_6_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/MiniCPM-V-2_6_parameters.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/Mistral-7B-v0.1_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Mistral-7B-v0.1_layers.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/Mistral-7B-v0.1_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Mistral-7B-v0.1_parameters.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/gte-Qwen2-7B-instruct_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/gte-Qwen2-7B-instruct_layers.png -------------------------------------------------------------------------------- /assets/results/LLM 
Distribution/gte-Qwen2-7B-instruct_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/gte-Qwen2-7B-instruct_parameters.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_layers.png -------------------------------------------------------------------------------- /assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_parameters.png -------------------------------------------------------------------------------- /assets/results/Qwen2/RULE1-Max_Context_Length_Barrier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/RULE1-Max_Context_Length_Barrier.png -------------------------------------------------------------------------------- /assets/results/Qwen2/SF11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF11.png -------------------------------------------------------------------------------- /assets/results/Qwen2/SF16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF16.png -------------------------------------------------------------------------------- /assets/results/Qwen2/SF4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF4.png -------------------------------------------------------------------------------- /assets/results/Qwen2/SF8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF8.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5l.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5l.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5l6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5l6.png 
-------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5m.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5m6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5m6.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5n.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5n.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5n6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5n6.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5s.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5s6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5s6.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5x.png -------------------------------------------------------------------------------- /assets/results/YOLOv5 Weight Distribution/yolov5x6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5x6.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-d6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-d6.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6e.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-w6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-w6.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7x.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/yolov7-d6.pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-d6.pt.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/yolov7-e6.pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-e6.pt.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/yolov7-e6e.pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-e6e.pt.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/yolov7-w6.pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-w6.pt.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/yolov7.pt.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7.pt.png -------------------------------------------------------------------------------- /assets/results/YOLOv7 Weight Distribution/yolov7x.pt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7x.pt.png -------------------------------------------------------------------------------- /assets/results/activation_unit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/activation_unit.png -------------------------------------------------------------------------------- /assets/results/blockplan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/blockplan.png -------------------------------------------------------------------------------- /assets/results/cycle_count_logic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/cycle_count_logic.png -------------------------------------------------------------------------------- /assets/results/gtkwave_fma_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/gtkwave_fma_comparison.png -------------------------------------------------------------------------------- /assets/results/gtkwave_superfloat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/gtkwave_superfloat.png -------------------------------------------------------------------------------- /assets/results/hardware architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/hardware architecture.png -------------------------------------------------------------------------------- /assets/results/improved_dadda.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/improved_dadda.jpg -------------------------------------------------------------------------------- /assets/results/isa_integrated_floorplan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/isa_integrated_floorplan.png -------------------------------------------------------------------------------- /assets/results/shift_register.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/shift_register.png -------------------------------------------------------------------------------- /docs/technical_reports/Major Project.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/docs/technical_reports/Major Project.docx -------------------------------------------------------------------------------- /docs/technical_reports/SuperFloat PoC.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/docs/technical_reports/SuperFloat PoC.pdf -------------------------------------------------------------------------------- /docs/technical_reports/Superfloat Whitepaper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/docs/technical_reports/Superfloat Whitepaper.pdf -------------------------------------------------------------------------------- /src/modal/better.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | # Create a Modal image with the required dependencies 4 | image = ( 5 | modal.Image.debian_slim() 6 | .pip_install( 7 | "torch", 8 | "transformers", 9 | "datasets", 10 | "tqdm", 11 | "huggingface_hub", 12 | ) 13 | .apt_install("gcc", "python3-dev") # Add necessary system libraries if needed 14 | ) 15 | 16 | app = modal.App("qwen-sf4-experimental") 17 | 18 | # Define the function that runs the script 19 | @app.function(gpu="H100", image=image, timeout=86400) 20 | def train_and_upload(): 21 | import torch 22 | import gc 23 | import os 24 | import re 25 | import requests 26 | from tqdm import tqdm 27 | from datasets import Dataset 28 | from torch.utils.data import DataLoader 29 | from transformers import AutoModelForCausalLM, AutoTokenizer 30 | import pandas as pd 31 | import math 32 | 33 | # Function to calculate perplexity 34 | def calculate_perplexity(model, dataloader, loss_fn, device): 35 | model.eval() # Set model to evaluation mode 36 | total_loss = 0.0 37 | total_steps = 0 38 | 39 | with torch.no_grad(): 40 | for batch in dataloader: 41 | input_ids = batch["input_ids"].to(device) 42 | attention_mask = batch["attention_mask"].to(device) 43 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 44 | logits = outputs.logits 45 | target = input_ids[:, 1:].contiguous() # Shift targets by one token 46 | logits = logits[:, :-1].contiguous() # Align logits with target 47 | 48 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) 49 | total_loss += loss.item() 50 | total_steps += 1 51 | 52 | avg_loss = total_loss / total_steps 53 | perplexity = math.exp(avg_loss) # Perplexity is the exponential of the average loss 54 | return perplexity 55 | 56 | # List of dataset URLs 57 | urls = [ 58 | "https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/resolve/main/data/train-00000-of-01650-f70471ee3deb09c0.parquet", 59 | ] 60 | 61 | # Local final output file path 62 | final_file_name = "train.parquet" 63 | 64 | # Check if the final file already exists 65 | if not os.path.exists(final_file_name): 66 | print(f"Downloading and combining dataset from {len(urls)} files...") 67 | 68 | # List to hold all the 
dataframes 69 | combined_df = pd.DataFrame() 70 | 71 | # Loop through each URL to download and combine the files 72 | for i, url in enumerate(urls): 73 | downloaded_file = f"temp_file_{i}.parquet" 74 | 75 | # Download the dataset 76 | print(f"Downloading dataset from {url}...") 77 | response = requests.get(url, stream=True) 78 | with open(downloaded_file, "wb") as f: 79 | for chunk in response.iter_content(chunk_size=8192): 80 | f.write(chunk) 81 | print(f"Downloaded to {downloaded_file}.") 82 | 83 | # Read the downloaded parquet file and append to the combined dataframe 84 | df = pd.read_parquet(downloaded_file) 85 | combined_df = pd.concat([combined_df, df], ignore_index=True) 86 | 87 | # Optionally remove the temporary file after reading 88 | os.remove(downloaded_file) 89 | 90 | # Save the combined dataframe as a final parquet file 91 | combined_df.to_parquet(final_file_name) 92 | print(f"Combined data saved to {final_file_name}.") 93 | else: 94 | print(f"{final_file_name} already exists. Skipping download.") 95 | 96 | 97 | # max_lengths = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] 98 | max_lengths = [2] 99 | bit = 4 100 | 101 | class Superfloat: 102 | # Mapping of bit-widths to floating-point types 103 | CASTING_TABLE = { 104 | 16: torch.float32, 105 | 15: torch.float32, 106 | 14: torch.float32, 107 | 13: torch.float32, 108 | 12: torch.float32, 109 | 11: torch.float16, 110 | 10: torch.float16, 111 | 9: torch.float16, 112 | 8: torch.bfloat16, 113 | 7: torch.bfloat16, 114 | 6: torch.bfloat16, 115 | 5: torch.bfloat16, 116 | 4: torch.bfloat16, 117 | } 118 | 119 | def __init__(self, bits: int): 120 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 121 | self.bits = bits 122 | self.mantissa_bits = bits - 1 123 | self.max_val = 1.0 # Default max value 124 | self.scale_factor = 1.0 # Initialize scale factor 125 | self.float_type = self.CASTING_TABLE[bits] # Set float_type based on bits 126 | 127 | def set_scale(self, weights, dim=None, percentile=None): 128 | if dim is not None: 129 | self.scale_factor = torch.max(torch.abs(weights), dim=dim, keepdim=True)[0] 130 | elif percentile: 131 | scale = torch.kthvalue(torch.abs(weights).view(-1), int(weights.numel() * percentile / 100))[0] 132 | self.scale_factor = scale 133 | else: 134 | self.scale_factor = torch.max(torch.abs(weights)) 135 | 136 | def encode(self, value: torch.Tensor): 137 | scaled_value = value / self.scale_factor 138 | quantized_value = torch.round(scaled_value * (2**self.mantissa_bits - 1)) / (2**self.mantissa_bits - 1) 139 | return quantized_value.to(self.float_type) # Cast to float_type 140 | 141 | def decode(self, quantized_value: torch.Tensor): 142 | decoded_value = quantized_value * self.scale_factor 143 | return decoded_value.to(self.float_type) # Cast to float_type 144 | 145 | sf = Superfloat(bit) 146 | 147 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 148 | print(f"Using device: {device}") 149 | 150 | # Initialize model and tokenizer 151 | model_name = "Qwen/Qwen2-0.5B" 152 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 153 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./", token='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') 154 | tokenizer.pad_token = tokenizer.eos_token 155 | 156 | def quantize_model(model, sf_type, dim=None, percentile=None): 157 | for name, param in model.named_parameters(): 158 | # Cast weights to the correct float_type before quantization 159 | param.data = param.data.to(sf_type.float_type) 160 | 
161 | # Set scale and quantize 162 | sf_type.set_scale(param.data, dim=dim, percentile=percentile) 163 | quantized_param = sf_type.encode(param.data) 164 | param.data = quantized_param.data 165 | return model 166 | 167 | def load_checkpoint(model, sf_bits, suffix="opt", device="cuda"): 168 | """ 169 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix. 170 | 171 | Args: 172 | quantized_model: The model to load the checkpoint into. 173 | sf_bits: Bitwidth of the Superfloat format (e.g., 11). 174 | suffix: The suffix of the filename (default: 'opt'). 175 | device: Device to load the model onto ('cuda' or 'cpu'). 176 | 177 | Returns: 178 | The quantized model with loaded weights and the epoch number. 179 | """ 180 | # Define the filename pattern to search for 181 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$") 182 | 183 | # Find all matching checkpoint files 184 | checkpoint_files = [ 185 | f for f in os.listdir(".") if checkpoint_pattern.match(f) 186 | ] 187 | 188 | if not checkpoint_files: 189 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.") 190 | return quantize_model(model, sf), 0 191 | 192 | # Extract epoch numbers and sort by latest epoch 193 | epochs_and_files = [ 194 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files 195 | ] 196 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0]) 197 | 198 | # Load the latest checkpoint 199 | print(f"Loading checkpoint: {latest_checkpoint}") 200 | checkpoint = torch.load(latest_checkpoint, map_location=device) 201 | model.load_state_dict(checkpoint) 202 | model.to(device) 203 | 204 | return model, latest_epoch 205 | 206 | # Pre-training parameter check to ensure they are within range 207 | def check_parameters_in_range(model, sf): 208 | out_of_range_params = [] 209 | for name, param in model.named_parameters(): 210 | if not torch.all(torch.abs(param.data) <= sf.max_val): 211 | out_of_range_params.append(name) 212 | if out_of_range_params: 213 | print(f"Warning: The following parameters are out of range:") 214 | for param_name in out_of_range_params: 215 | print(f"- {param_name}") 216 | else: 217 | print("All parameters are within the valid range.") 218 | 219 | 220 | def prepare_dataset(tokenizer, max_length=1): 221 | dataset = Dataset.from_parquet("train.parquet") 222 | 223 | def tokenize_function(examples): 224 | return tokenizer( 225 | examples["text"], 226 | truncation=True, 227 | max_length=max_length, 228 | padding="max_length", 229 | return_tensors="pt", 230 | ) 231 | 232 | tokenized_dataset = dataset.map( 233 | tokenize_function, batched=True, remove_columns=dataset.column_names 234 | ) 235 | return tokenized_dataset 236 | 237 | def collate_fn(batch): 238 | input_ids = torch.stack( 239 | [torch.tensor(example["input_ids"]) for example in batch] 240 | ) 241 | attention_mask = torch.stack( 242 | [torch.tensor(example["attention_mask"]) for example in batch] 243 | ) 244 | return {"input_ids": input_ids, "attention_mask": attention_mask} 245 | 246 | 247 | # Loop over different max_length values 248 | for max_length in max_lengths: 249 | print(f"Starting training for max_length = {max_length}") 250 | 251 | # Prepare Dataset 252 | tokenized_dataset = prepare_dataset(tokenizer, max_length=max_length) 253 | train_dataloader = DataLoader( 254 | tokenized_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn 255 | ) 256 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./", 
token='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') 257 | model = model.to(sf.float_type).to(device) 258 | 259 | quantized_model, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device) 260 | quantized_model.to(device) 261 | print(f"Resuming training from epoch {last_epoch + 1}.") 262 | 263 | # Check if model parameters are within range before training 264 | check_parameters_in_range(quantized_model, sf) 265 | 266 | # del model 267 | torch.cuda.empty_cache() 268 | gc.collect() 269 | 270 | optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-5, eps=1e-4) 271 | loss_fn = torch.nn.CrossEntropyLoss() 272 | 273 | # Calculate and print the original model's perplexity before training 274 | print("Calculating original model perplexity...") 275 | original_perplexity = calculate_perplexity(model, train_dataloader, loss_fn, device) 276 | print(f"Original model perplexity: {original_perplexity:.4f}") 277 | 278 | num_epochs = 10 279 | accumulation_steps = 16 280 | 281 | for epoch in range(num_epochs): 282 | epoch_loss = 0.0 283 | epoch_iterator = tqdm( 284 | enumerate(train_dataloader), 285 | total=len(train_dataloader), 286 | desc=f"Epoch {epoch + 1}/{num_epochs}", 287 | ) 288 | 289 | for step, batch in epoch_iterator: 290 | input_ids = batch["input_ids"].to(device) 291 | attention_mask = batch["attention_mask"].to(device) 292 | outputs = quantized_model(input_ids=input_ids, attention_mask=attention_mask) 293 | logits = outputs.logits 294 | target = input_ids[:, 1:].contiguous() 295 | logits = logits[:, :-1].contiguous() 296 | 297 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) 298 | loss = loss / accumulation_steps 299 | loss.backward() 300 | 301 | epoch_loss += loss.item() * accumulation_steps 302 | 303 | if (step + 1) % accumulation_steps == 0: 304 | optimizer.step() 305 | optimizer.zero_grad() 306 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"}) 307 | 308 | epoch_loss /= len(train_dataloader) 309 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}") 310 | 311 | # Calculate and print the perplexity after each epoch 312 | epoch_perplexity = calculate_perplexity(quantized_model, train_dataloader, loss_fn, device) 313 | print(f"Epoch {epoch + 1} perplexity: {epoch_perplexity:.4f}") 314 | 315 | model_path = f"sf{sf.bits}_{max_length}_{epoch + 1}_opt" 316 | torch.save(quantized_model.state_dict(), model_path) 317 | 318 | # Upload model to Hugging Face 319 | os.system( 320 | f"huggingface-cli upload aoxo/qwen2-idkwtf {model_path} --token='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'" 321 | ) 322 | 323 | del quantized_model 324 | torch.cuda.empty_cache() 325 | gc.collect() 326 | 327 | print(f"Completed training for max_length = {max_length}") 328 | 329 | # Entry point to run locally 330 | @app.local_entrypoint() 331 | def main(): 332 | train_and_upload.remote() -------------------------------------------------------------------------------- /src/modal/modal_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import torch 4 | from transformers import AutoTokenizer, AutoModelForCausalLM 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | 8 | # Device setup 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | print(f"Using device: {device}") 11 | 12 | base_dir = "./qwen2-sf4" 13 | os.makedirs(base_dir, exist_ok=True) 14 | 15 | # Existing Superfloat and other helper functions remain the same as in 
the previous script 16 | # (Copying the Superfloat class, load_model, calculate_perplexity functions from previous script) 17 | 18 | # Function to load model 19 | def load_model(model_path): 20 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir='./', token='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', trust_remote_code=True) 21 | model.load_state_dict(torch.load(model_path, map_location=device)) 22 | model = model.to(torch.bfloat16).to(device) 23 | model.eval() # Ensure model is in inference mode 24 | return model 25 | 26 | # Define Superfloat quantizer for clamping activations 27 | class Superfloat: 28 | def __init__(self, bits: int): 29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 30 | self.bits = bits 31 | self.mantissa_bits = bits - 1 32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 33 | 34 | def encode(self, value: torch.Tensor) -> torch.Tensor: 35 | """Encodes a tensor of values into the superfloat format with optimized operations.""" 36 | # Clip tensor values to the valid range for SFx 37 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 38 | 39 | # Calculate mantissa representation element-wise 40 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 41 | 42 | # Create the superfloat representation (1 bit for sign and mantissa bits) 43 | sign = (clipped_value < 0).to(torch.int32) 44 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32) 45 | 46 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor: 47 | """Decodes a tensor of encoded superfloat values to regular floats.""" 48 | # Extract mantissa and sign from the encoded superfloat 49 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 50 | sign = (encoded_value >> self.mantissa_bits) & 1 51 | 52 | # Calculate the decoded float using the mantissa and max_val 53 | decoded_value = (mantissa.to(torch.bfloat16) / (2**self.mantissa_bits - 1)) * self.max_val 54 | return decoded_value * (2 * sign - 1) # Apply the sign 55 | 56 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 57 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape.""" 58 | # Apply element-wise encoding to the entire tensor and then decode back 59 | encoded_tensor = self.encode(tensor) 60 | decoded_tensor = self.decode(encoded_tensor) 61 | return decoded_tensor 62 | 63 | # Initialize Superfloat quantizer for sf{sf.bits}amping 64 | sf = Superfloat(4) 65 | 66 | model_name = "Qwen/Qwen2-0.5B" 67 | 68 | # Load tokenizer 69 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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', trust_remote_code=True) 70 | tokenizer.pad_token = tokenizer.eos_token 71 | 72 | def quantized_inference(model, tokenizer, prompt, max_length=4096): 73 | """Runs inference on a prompt with activation quantization using Superfloat.""" 74 | # Encode input prompt 75 | inputs = tokenizer(prompt, return_tensors="pt", padding=True) 76 | input_ids = inputs.input_ids.to(device) 77 | attention_mask = inputs.attention_mask.to(device) 78 | 79 | with torch.no_grad(): 80 | # Perform generation with clamped activations 81 | outputs = model.generate( 82 | input_ids=input_ids, 83 | attention_mask=attention_mask, 84 | max_length=max_length, 85 | do_sample=True, 86 | top_k=50, 87 | top_p=0.95, 88 | temperature=0.7, 89 | pad_token_id=tokenizer.eos_token_id 90 | ) 91 | 92 | # Decode generated output 93 | generated_text = 
tokenizer.decode(outputs[0], skip_special_tokens=True) 94 | return generated_text 95 | 96 | def calculate_perplexity(model, tokenizer, prompt): 97 | """Calculates the perplexity of the model on a given prompt.""" 98 | # Tokenize input 99 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) 100 | input_ids = inputs.input_ids.to(device) 101 | attention_mask = inputs.attention_mask.to(device) 102 | 103 | # Get model outputs (logits) 104 | with torch.no_grad(): 105 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) 106 | 107 | # Get the loss (cross entropy) from the model's output 108 | loss = outputs.loss # This is the cross-entropy loss 109 | 110 | # Compute perplexity: exp(loss) 111 | perplexity = torch.exp(loss) 112 | return perplexity.item() 113 | 114 | # Model paths 115 | import os 116 | 117 | def get_model_paths(base_dir, sf_bits): 118 | """ 119 | Dynamically generate model paths based on the sf.bits format. 120 | Looks for models of the form: 121 | 1. sf{sf_bits}_vanilla 122 | 2. sf{sf_bits}_{epoch_num}_fpm 123 | 3. sf{sf_bits}_{epoch_num}_opt 124 | 125 | Args: 126 | base_dir (str): The directory where the models are stored. 127 | sf_bits (int): The bitwidth for the Superfloat quantizer. 128 | 129 | Returns: 130 | List of model paths. 131 | """ 132 | model_paths = [] 133 | model_pattern = f"sf{sf_bits}_" 134 | 135 | # Scan directory for models matching the pattern 136 | for model_name in os.listdir(base_dir): 137 | if model_name.startswith(model_pattern): 138 | model_paths.append(os.path.join(base_dir, model_name)) 139 | 140 | # Ensure models are sorted to follow the desired order: vanilla -> fpm -> opt 141 | model_paths.sort() 142 | 143 | return model_paths 144 | 145 | # Function to evaluate perplexity for a list of models and prompts 146 | def evaluate_models(base_dir, sf_bits, tokenizer, prompts): 147 | """ 148 | Evaluates models dynamically loaded based on the sf.bits format. 149 | 150 | Args: 151 | base_dir (str): The directory where the models are stored. 152 | sf_bits (int): The bitwidth for the Superfloat quantizer. 153 | tokenizer: The tokenizer to use for model inference. 154 | prompts: The list of prompts to evaluate. 155 | 156 | Returns: 157 | Dictionary with model names and their corresponding average perplexity. 
158 | """ 159 | model_perplexities = {} 160 | 161 | # Get dynamically generated model paths 162 | models = get_model_paths(base_dir, sf_bits) 163 | 164 | for model_path in models: 165 | model = load_model(model_path) 166 | print(f"Evaluating model: {model_path}") 167 | 168 | total_perplexity = 0.0 169 | num_prompts = len(prompts) 170 | 171 | # Compute perplexity for each prompt 172 | for prompt in tqdm(prompts, desc=f"Processing {model_path}", leave=False): 173 | perplexity = calculate_perplexity(model, tokenizer, prompt) 174 | total_perplexity += perplexity 175 | 176 | # Average perplexity for the current model 177 | avg_perplexity = total_perplexity / num_prompts 178 | model_perplexities[model_path] = avg_perplexity 179 | print(f"Average Perplexity for {model_path}: {avg_perplexity}") 180 | 181 | return model_perplexities 182 | 183 | # Function to load the HellaSwag dataset 184 | def load_hellaswag_data(): 185 | """Load the HellaSwag dataset from Hugging Face.""" 186 | dataset = load_dataset("hellaswag", split='validation', trust_remote_code=True) 187 | 188 | # Extract only the prompts (contexts) for evaluation 189 | prompts = [entry['ctx'] for entry in dataset] 190 | 191 | # Return the prompts as a list 192 | return prompts 193 | 194 | # Load HellaSwag data (prompts) 195 | prompts = load_hellaswag_data() 196 | 197 | 198 | # Function to download model 199 | def download_model(model_name): 200 | """Download a specific model using wget.""" 201 | url = f"https://huggingface.co/aoxo/qwen2-sf4/resolve/main/{model_name}" 202 | download_path = os.path.join(base_dir, model_name) 203 | 204 | try: 205 | # Use wget to download the model 206 | subprocess.run(["wget", "-O", download_path, url], check=True) 207 | print(f"Successfully downloaded {model_name}") 208 | return download_path 209 | except subprocess.CalledProcessError as e: 210 | print(f"Error downloading {model_name}: {e}") 211 | return None 212 | 213 | # Existing functions from previous script (load_model, calculate_perplexity, etc.) 214 | # [Keep all the previous implementations of these functions] 215 | 216 | # Main evaluation and cleanup function 217 | def process_models(model_list): 218 | """ 219 | Process models by downloading, evaluating, and deleting them. 
220 | 221 | Args: 222 | model_list (list): List of model names to process 223 | """ 224 | # Load tokenizer once 225 | model_name = "Qwen/Qwen2-0.5B" 226 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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') 227 | tokenizer.pad_token = tokenizer.eos_token 228 | 229 | # Load HellaSwag prompts 230 | prompts = load_hellaswag_data() 231 | 232 | # Results tracking 233 | results = {} 234 | 235 | for model_filename in model_list: 236 | print(f"\n--- Processing {model_filename} ---") 237 | 238 | # Download model 239 | model_path = download_model(model_filename) 240 | 241 | if not model_path: 242 | print(f"Skipping {model_filename} due to download failure") 243 | continue 244 | 245 | try: 246 | # Load model 247 | model = load_model(model_path) 248 | 249 | # Evaluate perplexity 250 | total_perplexity = 0.0 251 | for prompt in tqdm(prompts, desc=f"Evaluating {model_filename}"): 252 | perplexity = calculate_perplexity(model, tokenizer, prompt) 253 | total_perplexity += perplexity 254 | 255 | avg_perplexity = total_perplexity / len(prompts) 256 | results[model_filename] = avg_perplexity 257 | print(f"Average Perplexity for {model_filename}: {avg_perplexity}") 258 | 259 | except Exception as e: 260 | print(f"Error processing {model_filename}: {e}") 261 | 262 | # Delete the model file 263 | try: 264 | os.remove(model_path) 265 | print(f"Deleted {model_filename}") 266 | except Exception as e: 267 | print(f"Error deleting {model_filename}: {e}") 268 | 269 | # Save results to a file 270 | with open('perplexity_results.txt', 'w') as f: 271 | for model, perplexity in sorted(results.items(), key=lambda x: x[1]): 272 | f.write(f"{model}: {perplexity}\n") 273 | 274 | print("\nResults saved to perplexity_results.txt") 275 | return results 276 | 277 | # List of models to process 278 | models_to_process = [ 279 | 'sf4_1024_1_opt', 'sf4_128_1_opt', 'sf4_16_1_opt', 280 | 'sf4_2048_1_opt', 'sf4_256_1_opt', 'sf4_2_1_opt', 281 | 'sf4_32_1_opt', 'sf4_4096_1_opt', 'sf4_4_1_opt', 282 | 'sf4_512_1_opt', 'sf4_64_1_opt', 'sf4_8_1_opt', 283 | 'sf4_1024_2_opt', 'sf4_128_2_opt', 'sf4_16_2_opt', 284 | 'sf4_2048_2_opt', 'sf4_256_2_opt', 'sf4_2_2_opt', 285 | 'sf4_32_2_opt', 'sf4_4096_2_opt', 'sf4_4_2_opt', 286 | 'sf4_512_2_opt', 'sf4_64_2_opt', 'sf4_8_2_opt', 287 | 'sf4_1024_3_opt', 'sf4_128_3_opt', 'sf4_16_3_opt', 288 | 'sf4_2048_3_opt', 'sf4_256_3_opt', 'sf4_2_3_opt', 289 | 'sf4_32_3_opt', 'sf4_4096_3_opt', 'sf4_4_3_opt', 290 | 'sf4_512_3_opt', 'sf4_64_3_opt', 'sf4_8_3_opt' 291 | ] 292 | 293 | # Run the processing 294 | results = process_models(models_to_process) 295 | 296 | # Print sorted results 297 | print("\nSorted Perplexity Results:") 298 | for model, perplexity in sorted(results.items(), key=lambda x: x[1]): 299 | print(f"{model}: {perplexity}") -------------------------------------------------------------------------------- /src/modal/modal_fasteropt.py: -------------------------------------------------------------------------------- 1 | import modal 2 | 3 | # Create a Modal image with the required dependencies 4 | image = ( 5 | modal.Image.debian_slim() 6 | .pip_install( 7 | "torch", 8 | "transformers", 9 | "datasets", 10 | "tqdm", 11 | "huggingface_hub", 12 | ) 13 | .apt_install("gcc", "python3-dev") # Add necessary system libraries if needed 14 | ) 15 | 16 | app = modal.App("qwen-sf4-experimental") 17 | 18 | # Define the function that runs the script 19 | # @app.function(gpu=modal.gpu.A100(size="80GB"), image=image, timeout=86400) 20 | 
@app.function(gpu="A100", image=image, timeout=86400) 21 | def train_and_upload(): 22 | import torch 23 | import gc 24 | import os 25 | import re 26 | import requests 27 | from tqdm import tqdm 28 | from datasets import Dataset 29 | from torch.utils.data import DataLoader 30 | from transformers import AutoModelForCausalLM, AutoTokenizer 31 | import pandas as pd 32 | 33 | # List of dataset URLs 34 | urls = [ 35 | "https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/resolve/main/data/train-00000-of-01650-f70471ee3deb09c0.parquet", 36 | "https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/resolve/main/data/train-00001-of-01650-172fc1a0c346b36e.parquet" 37 | ] 38 | 39 | # Local final output file path 40 | final_file_name = "train.parquet" 41 | 42 | # Check if the final file already exists 43 | if not os.path.exists(final_file_name): 44 | print(f"Downloading and combining dataset from {len(urls)} files...") 45 | 46 | # List to hold all the dataframes 47 | combined_df = pd.DataFrame() 48 | 49 | # Loop through each URL to download and combine the files 50 | for i, url in enumerate(urls): 51 | downloaded_file = f"temp_file_{i}.parquet" 52 | 53 | # Download the dataset 54 | print(f"Downloading dataset from {url}...") 55 | response = requests.get(url, stream=True) 56 | with open(downloaded_file, "wb") as f: 57 | for chunk in response.iter_content(chunk_size=8192): 58 | f.write(chunk) 59 | print(f"Downloaded to {downloaded_file}.") 60 | 61 | # Read the downloaded parquet file and append to the combined dataframe 62 | df = pd.read_parquet(downloaded_file) 63 | combined_df = pd.concat([combined_df, df], ignore_index=True) 64 | 65 | # Optionally remove the temporary file after reading 66 | os.remove(downloaded_file) 67 | 68 | # Save the combined dataframe as a final parquet file 69 | combined_df.to_parquet(final_file_name) 70 | print(f"Combined data saved to {final_file_name}.") 71 | else: 72 | print(f"{final_file_name} already exists. Skipping download.") 73 | 74 | 75 | # max_lengths = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] 76 | max_lengths = [256] 77 | bit = 4 78 | 79 | class Superfloat: 80 | CASTING_TABLE = { 81 | 16: torch.float32, 82 | 15: torch.float32, 83 | 14: torch.float32, 84 | 13: torch.float32, 85 | 12: torch.float32, 86 | 11: torch.float16, 87 | 10: torch.float16, 88 | 9: torch.float16, 89 | 8: torch.bfloat16, 90 | 7: torch.bfloat16, 91 | 6: torch.bfloat16, 92 | 5: torch.bfloat16, 93 | 4: torch.bfloat16, 94 | } 95 | 96 | def __init__(self, bits: int): 97 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 
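        # Worked example (illustrative): with bits=4, mantissa_bits = 3 and
        # max_val = 1 - 2**-3 = 0.875; encoding 0.3 gives
        # mantissa = floor(0.3 * 7 / 0.875) = 2, which decodes back to 2 / 7 * 0.875 = 0.25,
        # so magnitudes snap onto an 8-level grid spanning [0, 0.875].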
98 | self.bits = bits 99 | self.mantissa_bits = bits - 1 100 | self.max_val = 1 - 2**-self.mantissa_bits 101 | self.float_type = self.CASTING_TABLE[bits] 102 | 103 | def encode(self, value: torch.Tensor): 104 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 105 | out_of_range = (value.abs() > self.max_val) 106 | mantissa = ( 107 | (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val) 108 | .floor() 109 | .to(torch.int32) 110 | ) 111 | sign = (clipped_value < 0).to(torch.int32) 112 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32), out_of_range 113 | 114 | def decode(self, encoded_value: torch.Tensor): 115 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 116 | sign = (encoded_value >> self.mantissa_bits) & 1 117 | decoded_value = ( 118 | mantissa.to(self.float_type) 119 | / (2**self.mantissa_bits - 1) 120 | * self.max_val 121 | ) 122 | return decoded_value * (2 * sign - 1) 123 | 124 | def tensor_quantize(self, tensor: torch.Tensor): 125 | encoded_tensor, out_of_range = self.encode(tensor) 126 | decoded_tensor = self.decode(encoded_tensor) 127 | return decoded_tensor, out_of_range 128 | 129 | sf = Superfloat(bit) 130 | 131 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 132 | print(f"Using device: {device}") 133 | 134 | # Initialize model and tokenizer 135 | model_name = "Qwen/Qwen2-0.5B" 136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 137 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./", token='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') 138 | tokenizer.pad_token = tokenizer.eos_token 139 | 140 | def quantize_model(model, sf_type): 141 | for name, param in model.named_parameters(): 142 | quantized_param, _ = sf_type.tensor_quantize(param) 143 | param.data = quantized_param.data 144 | return model 145 | 146 | def load_checkpoint(model, sf_bits, suffix="opt", device="cuda"): 147 | """ 148 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix. 149 | 150 | Args: 151 | quantized_model: The model to load the checkpoint into. 152 | sf_bits: Bitwidth of the Superfloat format (e.g., 11). 153 | suffix: The suffix of the filename (default: 'opt'). 154 | device: Device to load the model onto ('cuda' or 'cpu'). 155 | 156 | Returns: 157 | The quantized model with loaded weights and the epoch number. 
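        Filenames are matched against the pattern sf{sf_bits}_.*_epoch(\d+)_.*{suffix}$,
        so a file named 'sf4_256_epoch3_opt' (illustrative name) would be picked up as
        epoch 3; when nothing matches, the freshly quantized base model is returned
        together with epoch 0.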
158 | """ 159 | # Define the filename pattern to search for 160 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$") 161 | 162 | # Find all matching checkpoint files 163 | checkpoint_files = [ 164 | f for f in os.listdir(".") if checkpoint_pattern.match(f) 165 | ] 166 | 167 | if not checkpoint_files: 168 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.") 169 | return quantize_model(model, sf), 0 170 | 171 | # Extract epoch numbers and sort by latest epoch 172 | epochs_and_files = [ 173 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files 174 | ] 175 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0]) 176 | 177 | # Load the latest checkpoint 178 | print(f"Loading checkpoint: {latest_checkpoint}") 179 | checkpoint = torch.load(latest_checkpoint, map_location=device) 180 | model.load_state_dict(checkpoint) 181 | model.to(device) 182 | 183 | return model, latest_epoch 184 | 185 | # Pre-training parameter check to ensure they are within range 186 | def check_parameters_in_range(model, sf): 187 | out_of_range_params = [] 188 | for name, param in model.named_parameters(): 189 | if not torch.all(torch.abs(param.data) <= sf.max_val): 190 | out_of_range_params.append(name) 191 | if out_of_range_params: 192 | print(f"Warning: The following parameters are out of range:") 193 | for param_name in out_of_range_params: 194 | print(f"- {param_name}") 195 | else: 196 | print("All parameters are within the valid range.") 197 | 198 | 199 | def prepare_dataset(tokenizer, max_length=1): 200 | dataset = Dataset.from_parquet("train.parquet") 201 | 202 | def tokenize_function(examples): 203 | return tokenizer( 204 | examples["text"], 205 | truncation=True, 206 | max_length=max_length, 207 | padding="max_length", 208 | return_tensors="pt", 209 | ) 210 | 211 | tokenized_dataset = dataset.map( 212 | tokenize_function, batched=True, remove_columns=dataset.column_names 213 | ) 214 | return tokenized_dataset 215 | 216 | def collate_fn(batch): 217 | input_ids = torch.stack( 218 | [torch.tensor(example["input_ids"]) for example in batch] 219 | ) 220 | attention_mask = torch.stack( 221 | [torch.tensor(example["attention_mask"]) for example in batch] 222 | ) 223 | return {"input_ids": input_ids, "attention_mask": attention_mask} 224 | 225 | 226 | # Loop over different max_length values 227 | for max_length in max_lengths: 228 | print(f"Starting training for max_length = {max_length}") 229 | 230 | # Prepare Dataset 231 | tokenized_dataset = prepare_dataset(tokenizer, max_length=max_length) 232 | train_dataloader = DataLoader( 233 | tokenized_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn 234 | ) 235 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./", token='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') 236 | model = model.to(sf.float_type).to(device) 237 | quantized_model, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device) 238 | quantized_model.to(device) 239 | print(f"Resuming training from epoch {last_epoch + 1}.") 240 | 241 | # Check if model parameters are within range before training 242 | check_parameters_in_range(quantized_model, sf) 243 | 244 | del model 245 | torch.cuda.empty_cache() 246 | gc.collect() 247 | 248 | optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-5, eps=1e-4) 249 | loss_fn = torch.nn.CrossEntropyLoss() 250 | 251 | num_epochs = 10 252 | accumulation_steps = 16 253 | 254 | for epoch in range(num_epochs): 255 | epoch_loss = 0.0 256 | 
epoch_iterator = tqdm( 257 | enumerate(train_dataloader), 258 | total=len(train_dataloader), 259 | desc=f"Epoch {epoch + 1}/{num_epochs}", 260 | ) 261 | 262 | for step, batch in epoch_iterator: 263 | input_ids = batch["input_ids"].to(device) 264 | attention_mask = batch["attention_mask"].to(device) 265 | outputs = quantized_model(input_ids=input_ids, attention_mask=attention_mask) 266 | logits = outputs.logits 267 | target = input_ids[:, 1:].contiguous() 268 | logits = logits[:, :-1].contiguous() 269 | 270 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) 271 | loss = loss / accumulation_steps 272 | loss.backward() 273 | 274 | epoch_loss += loss.item() * accumulation_steps 275 | 276 | if (step + 1) % accumulation_steps == 0: 277 | optimizer.step() 278 | optimizer.zero_grad() 279 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"}) 280 | 281 | epoch_loss /= len(train_dataloader) 282 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}") 283 | 284 | model_path = f"sf{sf.bits}_{max_length}_{epoch + 1}_opt" 285 | torch.save(quantized_model.state_dict(), model_path) 286 | 287 | # Upload model to Hugging Face 288 | os.system( 289 | f"huggingface-cli upload aoxo/qwen2-sf4-experimental {model_path} --token='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'" 290 | ) 291 | 292 | del quantized_model 293 | torch.cuda.empty_cache() 294 | gc.collect() 295 | 296 | print(f"Completed training for max_length = {max_length}") 297 | 298 | # Entry point to run locally 299 | @app.local_entrypoint() 300 | def main(): 301 | train_and_upload.remote() -------------------------------------------------------------------------------- /src/test/emulate_sf.py: -------------------------------------------------------------------------------- 1 | def sf_mul(bin1: str, bin2: str, n_bits: int = None) -> str: 2 | """ 3 | Multiplies two signed fixed-point binary fractional numbers in s.xxx format. 4 | 5 | Args: 6 | - bin1, bin2: binary strings like '0.101' or '1.011' 7 | - n_bits: number of fractional bits (excluding sign). If None, inferred. 8 | 9 | Returns: 10 | - Result in s.xxx format with sign bit and n_bits fractional bits. 
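    Worked example (illustrative): sf_mul('0.101', '0.110') multiplies
    0.625 by 0.750; the exact product 0.46875 is truncated to 3 fractional
    bits, giving '0.011' (= 0.375).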
11 | """ 12 | # Extract sign and fractional parts 13 | sign1, frac1 = bin1[0], bin1[2:] 14 | sign2, frac2 = bin2[0], bin2[2:] 15 | 16 | # Infer n_bits if not given 17 | if n_bits is None: 18 | n_bits = max(len(frac1), len(frac2)) 19 | 20 | # Pad fractions to match n_bits 21 | frac1 = frac1.ljust(n_bits, '0') 22 | frac2 = frac2.ljust(n_bits, '0') 23 | 24 | # Convert to integers 25 | int1 = int(frac1, 2) 26 | int2 = int(frac2, 2) 27 | 28 | # Apply signs 29 | if sign1 == '1': 30 | int1 = -int1 31 | if sign2 == '1': 32 | int2 = -int2 33 | 34 | # Multiply 35 | product = int1 * int2 36 | 37 | # Result needs 2 * n_bits for full precision 38 | product_bits = 2 * n_bits 39 | abs_product = abs(product) 40 | product_bin = bin(abs_product)[2:].zfill(product_bits) 41 | 42 | # Take the top n_bits as fractional result 43 | fractional_part = product_bin[:n_bits] 44 | 45 | # Determine sign bit 46 | sign_bit = '0' if product >= 0 else '1' 47 | 48 | return f"{sign_bit}.{fractional_part}" 49 | 50 | 51 | 52 | 53 | # Example usage: 54 | # print(sf_mul('0.101', '0.110')) # Should return '0.011' 55 | # print(sf_mul('1.101', '0.110')) # Should return '1.011' 56 | # print(sf_mul('1.101', '1.110')) # Should return '0.011' 57 | 58 | print(sf_mul('0.1110001','0.0001001')) -------------------------------------------------------------------------------- /src/test/emulate_sf_dec.py: -------------------------------------------------------------------------------- 1 | def sf_mul(bin1: str, bin2: str, n_bits: int = None) -> str: 2 | """ 3 | Multiplies two signed fixed-point binary fractional numbers in s.xxx format. 4 | 5 | Args: 6 | - bin1, bin2: binary strings like '0.101' or '1.011' 7 | - n_bits: number of fractional bits (excluding sign). If None, inferred. 8 | 9 | Returns: 10 | - Result in s.xxx format with sign bit and n_bits fractional bits. 11 | """ 12 | # Extract sign and fractional parts 13 | sign1, frac1 = bin1[0], bin1[2:] 14 | sign2, frac2 = bin2[0], bin2[2:] 15 | # Infer n_bits if not given 16 | if n_bits is None: 17 | n_bits = max(len(frac1), len(frac2)) 18 | # Pad fractions to match n_bits 19 | frac1 = frac1.ljust(n_bits, '0') 20 | frac2 = frac2.ljust(n_bits, '0') 21 | # Convert to integers 22 | int1 = int(frac1, 2) 23 | int2 = int(frac2, 2) 24 | # Apply signs 25 | if sign1 == '1': 26 | int1 = -int1 27 | if sign2 == '1': 28 | int2 = -int2 29 | # Multiply 30 | product = int1 * int2 31 | # Result needs 2 * n_bits for full precision 32 | product_bits = 2 * n_bits 33 | abs_product = abs(product) 34 | product_bin = bin(abs_product)[2:].zfill(product_bits) 35 | # Take the top n_bits as fractional result 36 | fractional_part = product_bin[:n_bits] 37 | # Determine sign bit 38 | sign_bit = '0' if product >= 0 else '1' 39 | return f"{sign_bit}.{fractional_part}" 40 | 41 | 42 | def decimal_to_sf(decimal_val: float, n_bits: int = 8) -> str: 43 | """ 44 | Converts a decimal number in range [-1, 1) to signed fixed-point binary format. 
45 | 46 | Args: 47 | - decimal_val: decimal number between -1 and 1 48 | - n_bits: number of fractional bits 49 | 50 | Returns: 51 | - Binary string in s.xxx format 52 | """ 53 | if not (-1 <= decimal_val < 1): 54 | raise ValueError("Decimal value must be in range [-1, 1)") 55 | 56 | # Determine sign 57 | sign_bit = '0' if decimal_val >= 0 else '1' 58 | abs_val = abs(decimal_val) 59 | 60 | # Convert fractional part to binary 61 | frac_binary = "" 62 | for _ in range(n_bits): 63 | abs_val *= 2 64 | if abs_val >= 1: 65 | frac_binary += '1' 66 | abs_val -= 1 67 | else: 68 | frac_binary += '0' 69 | 70 | return f"{sign_bit}.{frac_binary}" 71 | 72 | 73 | def sf_to_decimal(sf_binary: str) -> float: 74 | """ 75 | Converts signed fixed-point binary format to decimal. 76 | 77 | Args: 78 | - sf_binary: binary string in s.xxx format 79 | 80 | Returns: 81 | - Decimal equivalent 82 | """ 83 | sign_bit = sf_binary[0] 84 | frac_part = sf_binary[2:] 85 | 86 | # Convert fractional part to decimal 87 | decimal_val = 0 88 | for i, bit in enumerate(frac_part): 89 | if bit == '1': 90 | decimal_val += 2 ** (-(i + 1)) 91 | 92 | # Apply sign 93 | if sign_bit == '1': 94 | decimal_val = -decimal_val 95 | 96 | return decimal_val 97 | 98 | 99 | def sf_mul_dec(dec1: float, dec2: float, n_bits: int = 8) -> dict: 100 | """ 101 | Multiplies two decimal numbers in SF range [-1, 1) and returns result in both formats. 102 | 103 | Args: 104 | - dec1, dec2: decimal numbers between -1 and 1 105 | - n_bits: number of fractional bits for binary representation 106 | 107 | Returns: 108 | - Dictionary with 'sf_binary', 'decimal', 'inputs_sf', and 'exact_decimal' keys 109 | """ 110 | # Validate inputs 111 | if not (-1 <= dec1 < 1) or not (-1 <= dec2 < 1): 112 | raise ValueError("Both decimal values must be in range [-1, 1)") 113 | 114 | # Convert decimals to SF binary format 115 | sf1 = decimal_to_sf(dec1, n_bits) 116 | sf2 = decimal_to_sf(dec2, n_bits) 117 | 118 | # Perform SF multiplication 119 | sf_result = sf_mul(sf1, sf2, n_bits) 120 | 121 | # Convert result back to decimal 122 | result_decimal = sf_to_decimal(sf_result) 123 | 124 | # Calculate exact decimal multiplication for comparison 125 | exact_decimal = dec1 * dec2 126 | 127 | return { 128 | 'sf_binary': sf_result, 129 | 'decimal': result_decimal, 130 | 'inputs_sf': (sf1, sf2), 131 | 'exact_decimal': exact_decimal, 132 | 'error': abs(exact_decimal - result_decimal) 133 | } 134 | 135 | result = sf_mul_dec(0.5, 0.213, 16) 136 | print(f" SF inputs: {result['inputs_sf'][0]} × {result['inputs_sf'][1]}") 137 | print(f" SF result: {result['sf_binary']}") 138 | print(f" Decimal result: {result['decimal']:.6f}") 139 | print(f" Exact decimal: {result['exact_decimal']:.6f}") 140 | print(f" Error: {result['error']:.6f}") -------------------------------------------------------------------------------- /src/test/sf_tensors.py: -------------------------------------------------------------------------------- 1 | def sf_mul(bin1: str, bin2: str, n_bits: int = None) -> str: 2 | """ 3 | Multiplies two signed fixed-point binary fractional numbers in s.xxx format. 4 | 5 | Args: 6 | - bin1, bin2: binary strings like '0.101' or '1.011' 7 | - n_bits: number of fractional bits (excluding sign). If None, inferred. 8 | 9 | Returns: 10 | - Result in s.xxx format with sign bit and n_bits fractional bits. 
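    Note: the double-width product is truncated (not rounded) to n_bits, so the
    returned magnitude can be smaller than the exact product by up to 2**-n_bits.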
11 | """ 12 | # Extract sign and fractional parts 13 | sign1, frac1 = bin1[0], bin1[2:] 14 | sign2, frac2 = bin2[0], bin2[2:] 15 | # Infer n_bits if not given 16 | if n_bits is None: 17 | n_bits = max(len(frac1), len(frac2)) 18 | # Pad fractions to match n_bits 19 | frac1 = frac1.ljust(n_bits, '0') 20 | frac2 = frac2.ljust(n_bits, '0') 21 | # Convert to integers 22 | int1 = int(frac1, 2) 23 | int2 = int(frac2, 2) 24 | # Apply signs 25 | if sign1 == '1': 26 | int1 = -int1 27 | if sign2 == '1': 28 | int2 = -int2 29 | # Multiply 30 | product = int1 * int2 31 | # Result needs 2 * n_bits for full precision 32 | product_bits = 2 * n_bits 33 | abs_product = abs(product) 34 | product_bin = bin(abs_product)[2:].zfill(product_bits) 35 | # Take the top n_bits as fractional result 36 | fractional_part = product_bin[:n_bits] 37 | # Determine sign bit 38 | sign_bit = '0' if product >= 0 else '1' 39 | return f"{sign_bit}.{fractional_part}" 40 | 41 | 42 | def decimal_to_sf(decimal_val: float, n_bits: int = 8) -> str: 43 | """ 44 | Converts a decimal number to signed fixed-point binary format with overflow clamping. 45 | 46 | Args: 47 | - decimal_val: decimal number (will be clamped to [-1, 1) range) 48 | - n_bits: number of fractional bits 49 | 50 | Returns: 51 | - Binary string in s.xxx format 52 | """ 53 | # Clamp to valid SF range 54 | max_val = 1.0 - (2 ** (-n_bits)) # Maximum positive value (just under 1.0) 55 | min_val = -1.0 # Minimum negative value 56 | 57 | if decimal_val >= 1.0: 58 | decimal_val = max_val 59 | elif decimal_val < -1.0: 60 | decimal_val = min_val 61 | 62 | # Determine sign 63 | sign_bit = '0' if decimal_val >= 0 else '1' 64 | abs_val = abs(decimal_val) 65 | 66 | # Convert fractional part to binary 67 | frac_binary = "" 68 | for _ in range(n_bits): 69 | abs_val *= 2 70 | if abs_val >= 1: 71 | frac_binary += '1' 72 | abs_val -= 1 73 | else: 74 | frac_binary += '0' 75 | 76 | return f"{sign_bit}.{frac_binary}" 77 | 78 | 79 | def sf_to_decimal(sf_binary: str) -> float: 80 | """ 81 | Converts signed fixed-point binary format to decimal. 82 | 83 | Args: 84 | - sf_binary: binary string in s.xxx format 85 | 86 | Returns: 87 | - Decimal equivalent 88 | """ 89 | sign_bit = sf_binary[0] 90 | frac_part = sf_binary[2:] 91 | 92 | # Convert fractional part to decimal 93 | decimal_val = 0 94 | for i, bit in enumerate(frac_part): 95 | if bit == '1': 96 | decimal_val += 2 ** (-(i + 1)) 97 | 98 | # Apply sign 99 | if sign_bit == '1': 100 | decimal_val = -decimal_val 101 | 102 | return decimal_val 103 | 104 | 105 | def sf_mul_dec(dec1: float, dec2: float, n_bits: int = 8) -> dict: 106 | """ 107 | Multiplies two decimal numbers using SF arithmetic with overflow clamping. 
108 | 109 | Args: 110 | - dec1, dec2: decimal numbers (will be clamped to SF range if needed) 111 | - n_bits: number of fractional bits for binary representation 112 | 113 | Returns: 114 | - Dictionary with 'sf_binary', 'decimal', 'inputs_sf', 'exact_decimal', and overflow info 115 | """ 116 | # Store original values for exact calculation 117 | original_dec1, original_dec2 = dec1, dec2 118 | 119 | # Clamp inputs to valid SF range 120 | max_val = 1.0 - (2 ** (-n_bits)) 121 | min_val = -1.0 122 | 123 | input1_overflow = False 124 | input2_overflow = False 125 | 126 | if dec1 >= 1.0: 127 | dec1 = max_val 128 | input1_overflow = True 129 | elif dec1 < -1.0: 130 | dec1 = min_val 131 | input1_overflow = True 132 | 133 | if dec2 >= 1.0: 134 | dec2 = max_val 135 | input2_overflow = True 136 | elif dec2 < -1.0: 137 | dec2 = min_val 138 | input2_overflow = True 139 | 140 | # Convert decimals to SF binary format 141 | sf1 = decimal_to_sf(dec1, n_bits) 142 | sf2 = decimal_to_sf(dec2, n_bits) 143 | 144 | # Perform SF multiplication 145 | sf_result = sf_mul(sf1, sf2, n_bits) 146 | 147 | # Convert result back to decimal 148 | result_decimal = sf_to_decimal(sf_result) 149 | 150 | # Calculate exact decimal multiplication for comparison 151 | exact_decimal = original_dec1 * original_dec2 152 | 153 | return { 154 | 'sf_binary': sf_result, 155 | 'decimal': result_decimal, 156 | 'inputs_sf': (sf1, sf2), 157 | 'exact_decimal': exact_decimal, 158 | 'error': abs(exact_decimal - result_decimal), 159 | 'input1_overflow': input1_overflow, 160 | 'input2_overflow': input2_overflow, 161 | 'clamped_inputs': (dec1, dec2) 162 | } 163 | 164 | 165 | def sf_add_dec(dec1: float, dec2: float, n_bits: int = 8) -> dict: 166 | """ 167 | Adds two decimal numbers using SF arithmetic with overflow clamping. 
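    For instance (illustrative), sf_add_dec(0.75, 0.75, 8) saturates: the exact
    sum 1.5 lies outside the representable range, so the SF result clamps to
    255/256 = 0.99609375 and result_overflow is True.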
168 | 169 | Args: 170 | - dec1, dec2: decimal numbers (will be clamped to SF range if needed) 171 | - n_bits: number of fractional bits for binary representation 172 | 173 | Returns: 174 | - Dictionary with result information 175 | """ 176 | # Store originals and clamp inputs 177 | original_dec1, original_dec2 = dec1, dec2 178 | max_val = 1.0 - (2 ** (-n_bits)) 179 | min_val = -1.0 180 | 181 | input1_overflow = False 182 | input2_overflow = False 183 | 184 | if dec1 >= 1.0: 185 | dec1 = max_val 186 | input1_overflow = True 187 | elif dec1 < -1.0: 188 | dec1 = min_val 189 | input1_overflow = True 190 | 191 | if dec2 >= 1.0: 192 | dec2 = max_val 193 | input2_overflow = True 194 | elif dec2 < -1.0: 195 | dec2 = min_val 196 | input2_overflow = True 197 | 198 | # Convert to SF format 199 | sf1 = decimal_to_sf(dec1, n_bits) 200 | sf2 = decimal_to_sf(dec2, n_bits) 201 | 202 | # Extract components 203 | sign1, frac1 = sf1[0], sf1[2:] 204 | sign2, frac2 = sf2[0], sf2[2:] 205 | 206 | # Convert to signed integers 207 | int1 = int(frac1, 2) 208 | int2 = int(frac2, 2) 209 | 210 | if sign1 == '1': 211 | int1 = -int1 212 | if sign2 == '1': 213 | int2 = -int2 214 | 215 | # Add 216 | result = int1 + int2 217 | 218 | # Handle overflow/underflow in SF integer space 219 | max_sf_int = (1 << n_bits) - 1 # Maximum positive SF integer 220 | min_sf_int = -(1 << n_bits) # Minimum negative SF integer 221 | 222 | result_overflow = False 223 | if result > max_sf_int: 224 | result = max_sf_int 225 | result_overflow = True 226 | elif result < min_sf_int: 227 | result = min_sf_int 228 | result_overflow = True 229 | 230 | # Convert back to SF format 231 | abs_result = abs(result) 232 | sign_bit = '0' if result >= 0 else '1' 233 | frac_binary = bin(abs_result)[2:].zfill(n_bits) 234 | 235 | sf_result = f"{sign_bit}.{frac_binary}" 236 | result_decimal = sf_to_decimal(sf_result) 237 | exact_decimal = original_dec1 + original_dec2 238 | 239 | return { 240 | 'sf_binary': sf_result, 241 | 'decimal': result_decimal, 242 | 'exact_decimal': exact_decimal, 243 | 'error': abs(exact_decimal - result_decimal), 244 | 'result_overflow': result_overflow, 245 | 'input1_overflow': input1_overflow, 246 | 'input2_overflow': input2_overflow 247 | } 248 | 249 | 250 | def sf_tensor_mul(tensor_a: list, tensor_b: list, n_bits: int = 8) -> dict: 251 | """ 252 | Multiplies two 2D tensors using signed fixed-point arithmetic with overflow clamping. 
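    Each output element is a dot product in which every partial product goes
    through sf_mul_dec, the running sum is accumulated in ordinary floating
    point, and only the final sum is clamped to the SF range and re-encoded.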
253 | 254 | Args: 255 | - tensor_a, tensor_b: 2D lists representing matrices (values will be clamped to SF range) 256 | - n_bits: number of fractional bits for SF representation 257 | 258 | Returns: 259 | - Dictionary with SF result, decimal result, exact result, and error analysis 260 | """ 261 | # Validate dimensions 262 | rows_a = len(tensor_a) 263 | cols_a = len(tensor_a[0]) if rows_a > 0 else 0 264 | rows_b = len(tensor_b) 265 | cols_b = len(tensor_b[0]) if rows_b > 0 else 0 266 | 267 | if cols_a != rows_b: 268 | raise ValueError(f"Cannot multiply matrices: {rows_a}x{cols_a} × {rows_b}x{cols_b}") 269 | 270 | # Initialize result matrices 271 | sf_result = [[None for _ in range(cols_b)] for _ in range(rows_a)] 272 | decimal_result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)] 273 | exact_result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)] 274 | clamped_exact_result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)] 275 | 276 | total_error = 0.0 277 | max_error = 0.0 278 | input_overflow_count = 0 279 | result_overflow_count = 0 280 | 281 | max_sf_val = 1.0 - (2 ** (-n_bits)) 282 | min_sf_val = -1.0 283 | 284 | # Perform matrix multiplication using SF arithmetic 285 | for i in range(rows_a): 286 | for j in range(cols_b): 287 | # Compute dot product of row i of A and column j of B 288 | sf_sum = 0.0 # Accumulate in decimal for intermediate sums 289 | exact_sum = 0.0 290 | 291 | for k in range(cols_a): 292 | # Track input overflows 293 | if tensor_a[i][k] >= 1.0 or tensor_a[i][k] < -1.0: 294 | input_overflow_count += 1 295 | if tensor_b[k][j] >= 1.0 or tensor_b[k][j] < -1.0: 296 | input_overflow_count += 1 297 | 298 | # SF multiplication (this handles input clamping internally) 299 | mul_result = sf_mul_dec(tensor_a[i][k], tensor_b[k][j], n_bits) 300 | 301 | # Add to running sum (in decimal space) 302 | sf_sum += mul_result['decimal'] 303 | exact_sum += tensor_a[i][k] * tensor_b[k][j] 304 | 305 | # Clamp sf_sum to valid SF range 306 | original_sf_sum = sf_sum 307 | if sf_sum >= 1.0: 308 | sf_sum = max_sf_val 309 | result_overflow_count += 1 310 | elif sf_sum < -1.0: 311 | sf_sum = min_sf_val 312 | result_overflow_count += 1 313 | 314 | # Convert final sum to SF format 315 | sf_binary = decimal_to_sf(sf_sum, n_bits) 316 | sf_result[i][j] = sf_binary 317 | decimal_result[i][j] = sf_to_decimal(sf_binary) 318 | exact_result[i][j] = exact_sum 319 | 320 | # Clamp exact result for error comparison 321 | if exact_sum >= 1.0: 322 | clamped_exact_result[i][j] = max_sf_val 323 | elif exact_sum < -1.0: 324 | clamped_exact_result[i][j] = min_sf_val 325 | else: 326 | clamped_exact_result[i][j] = exact_sum 327 | 328 | # Track errors against clamped exact result 329 | error = abs(clamped_exact_result[i][j] - decimal_result[i][j]) 330 | total_error += error 331 | max_error = max(max_error, error) 332 | 333 | return { 334 | 'sf_result': sf_result, 335 | 'decimal_result': decimal_result, 336 | 'exact_result': exact_result, 337 | 'clamped_exact_result': clamped_exact_result, 338 | 'total_error': total_error, 339 | 'average_error': total_error / (rows_a * cols_b), 340 | 'max_error': max_error, 341 | 'input_overflow_count': input_overflow_count, 342 | 'result_overflow_count': result_overflow_count, 343 | 'result_shape': (rows_a, cols_b) 344 | } 345 | 346 | 347 | def print_tensor(tensor, title, precision=6): 348 | """Helper function to print tensors nicely""" 349 | print(f"{title}:") 350 | for row in tensor: 351 | formatted_row = [] 352 | for val in row: 353 | if isinstance(val, str): # 
SF binary format 354 | formatted_row.append(f"{val:>12}") 355 | else: # Decimal format 356 | formatted_row.append(f"{val:>{precision+6}.{precision}f}") 357 | print(" [" + ", ".join(formatted_row) + "]") 358 | print() 359 | 360 | 361 | # Test tensor multiplication 362 | print("\n" + "="*60) 363 | print("TENSOR MULTIPLICATION EXAMPLE") 364 | print("="*60) 365 | 366 | A = [ 367 | [0.32, -0.75, 0.11], 368 | [-0.58, 0.94, -0.23], 369 | [0.67, -0.12, -0.81] 370 | ] 371 | 372 | B = [ 373 | [-0.44, 0.09, 0.65], 374 | [0.77, -0.36, -0.52], 375 | [-0.19, 0.84, 0.27] 376 | ] 377 | 378 | print("Input Tensors:") 379 | print_tensor(A, "Tensor A (Decimal)", 2) 380 | 381 | # Convert input tensors to SF format for display 382 | A_sf = [[decimal_to_sf(A[i][j], 16) for j in range(len(A[i]))] for i in range(len(A))] 383 | B_sf = [[decimal_to_sf(B[i][j], 16) for j in range(len(B[i]))] for i in range(len(B))] 384 | 385 | print_tensor(A_sf, "Tensor A (SF Binary)") 386 | print_tensor(B, "Tensor B (Decimal)", 2) 387 | print_tensor(B_sf, "Tensor B (SF Binary)") 388 | 389 | # Perform SF tensor multiplication 390 | result = sf_tensor_mul(A, B, n_bits=16) 391 | 392 | print("Results:") 393 | print_tensor(result['sf_result'], "SF Binary Result") 394 | print_tensor(result['decimal_result'], "SF Decimal Result", 6) 395 | print_tensor(result['exact_result'], "Exact Decimal Result", 6) 396 | print_tensor(result['clamped_exact_result'], "Clamped Exact Result", 6) 397 | 398 | print(f"Error Analysis (SF vs Clamped Exact):") 399 | print(f" Total Error: {result['total_error']:.8f}") 400 | print(f" Average Error: {result['average_error']:.8f}") 401 | print(f" Maximum Error: {result['max_error']:.8f}") 402 | print(f" Input Overflow Count: {result['input_overflow_count']}") 403 | print(f" Result Overflow Count: {result['result_overflow_count']}") 404 | print(f" Result Shape: {result['result_shape']}") -------------------------------------------------------------------------------- /src/verilog/activationUnit_16bit.sv: -------------------------------------------------------------------------------- 1 | module ActivationUnit ( 2 | input wire clk, 3 | input wire start, // Start processing 4 | input wire select, // 0: Linear, 1: ReLU 5 | input wire [15:0] buffer, // 16-bit buffer holding matrix values 6 | output reg [15:0] result // Output after activation 7 | ); 8 | reg signed [15:0] data; // Assuming 16-bit fixed/floating point numbers 9 | reg signed [15:0] output_data; 10 | 11 | always @(posedge clk) begin 12 | if (start) begin 13 | // Extract value from the buffer 14 | data = buffer; 15 | 16 | // Apply activation function 17 | if (select == 1'b1) begin 18 | // ReLU: Max(0, value) 19 | output_data = (data < 0) ? 
0 : data; 20 | end else begin 21 | // Linear: Identity function 22 | output_data = data; 23 | end 24 | 25 | // Store result 26 | result = output_data; 27 | end 28 | end 29 | endmodule 30 | -------------------------------------------------------------------------------- /src/verilog/shiftRegister_16bit.sv: -------------------------------------------------------------------------------- 1 | module ShiftRegister #(parameter SIZE = 16, DEPTH = 8) ( 2 | input wire clk, 3 | input wire reset, 4 | input wire shift_en, 5 | input wire [SIZE-1:0] data_in, 6 | output reg [SIZE-1:0] data_out 7 | ); 8 | 9 | reg [SIZE-1:0] shift_reg [DEPTH-1:0]; // Single chain of 8 registers, each 16-bit 10 | integer i; 11 | 12 | always @(posedge clk or posedge reset) begin 13 | if (reset) begin 14 | for (i = 0; i < DEPTH; i = i + 1) begin 15 | shift_reg[i] <= 0; 16 | end 17 | data_out <= 0; 18 | end else if (shift_en) begin 19 | for (i = DEPTH-1; i > 0; i = i - 1) begin 20 | shift_reg[i] <= shift_reg[i-1]; 21 | end 22 | shift_reg[0] <= data_in; 23 | data_out <= shift_reg[DEPTH-1]; 24 | end 25 | end 26 | endmodule 27 | -------------------------------------------------------------------------------- /src/wasq/wasq_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | from tqdm import tqdm 4 | from datasets import load_dataset 5 | 6 | # Device setup 7 | if torch.backends.mps.is_available(): 8 | device = torch.device("mps") 9 | elif torch.cuda.is_available(): 10 | device = torch.device("cuda") 11 | else: 12 | device = torch.device("cpu") 13 | 14 | print(f"Using device: {device}") 15 | 16 | base_dir = "./" 17 | 18 | # Function to load model 19 | def load_model(model_path): 20 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir='./', token='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') 21 | model.load_state_dict(torch.load(model_path, map_location=device)) 22 | model = model.to(torch.bfloat16).to(device) 23 | model.eval() # Ensure model is in inference mode 24 | return model 25 | 26 | # Define Superfloat quantizer for clamping activations 27 | class Superfloat: 28 | def __init__(self, bits: int): 29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 
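        # Illustrative values: with bits=8 (used below as Superfloat(8)) there are
        # 7 mantissa bits, so max_val = 1 - 2**-7 = 0.9921875; encode() clamps its
        # input into [-max_val, max_val] before quantizing.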
30 | self.bits = bits 31 | self.mantissa_bits = bits - 1 32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 33 | 34 | def encode(self, value: torch.Tensor) -> torch.Tensor: 35 | """Encodes a tensor of values into the superfloat format with optimized operations.""" 36 | # Clip tensor values to the valid range for SFx 37 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 38 | 39 | # Calculate mantissa representation element-wise 40 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 41 | 42 | # Create the superfloat representation (1 bit for sign and mantissa bits) 43 | sign = (clipped_value < 0).to(torch.int32) 44 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32) 45 | 46 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor: 47 | """Decodes a tensor of encoded superfloat values to regular floats.""" 48 | # Extract mantissa and sign from the encoded superfloat 49 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 50 | sign = (encoded_value >> self.mantissa_bits) & 1 51 | 52 | # Calculate the decoded float using the mantissa and max_val 53 | decoded_value = (mantissa.to(torch.bfloat16) / (2**self.mantissa_bits - 1)) * self.max_val 54 | return decoded_value * (2 * sign - 1) # Apply the sign 55 | 56 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 57 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape.""" 58 | # Apply element-wise encoding to the entire tensor and then decode back 59 | encoded_tensor = self.encode(tensor) 60 | decoded_tensor = self.decode(encoded_tensor) 61 | return decoded_tensor 62 | 63 | # Initialize Superfloat quantizer for sf{sf.bits}amping 64 | sf = Superfloat(8) 65 | 66 | model_name = "meta-llama/Llama-3.2-1B" 67 | 68 | # Load tokenizer 69 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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') 70 | tokenizer.pad_token = tokenizer.eos_token 71 | 72 | def quantized_inference(model, tokenizer, prompt, max_length=500): 73 | """Runs inference on a prompt with activation quantization using Superfloat.""" 74 | # Encode input prompt 75 | inputs = tokenizer(prompt, return_tensors="pt", padding=True) 76 | input_ids = inputs.input_ids.to(device) 77 | attention_mask = inputs.attention_mask.to(device) 78 | 79 | with torch.no_grad(): 80 | # Perform generation with clamped activations 81 | outputs = model.generate( 82 | input_ids=input_ids, 83 | attention_mask=attention_mask, 84 | max_length=max_length, 85 | do_sample=True, 86 | top_k=50, 87 | top_p=0.95, 88 | temperature=0.7, 89 | pad_token_id=tokenizer.eos_token_id 90 | ) 91 | 92 | # Decode generated output 93 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 94 | return generated_text 95 | 96 | def calculate_perplexity(model, tokenizer, prompt): 97 | """Calculates the perplexity of the model on a given prompt.""" 98 | # Tokenize input 99 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True) 100 | input_ids = inputs.input_ids.to(device) 101 | attention_mask = inputs.attention_mask.to(device) 102 | 103 | # Get model outputs (logits) 104 | with torch.no_grad(): 105 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) 106 | 107 | # Get the loss (cross entropy) from the model's output 108 | loss = outputs.loss # This is the cross-entropy loss 109 | 110 | # Compute perplexity: exp(loss) 111 | 
perplexity = torch.exp(loss) 112 | return perplexity.item() 113 | 114 | # Model paths 115 | import os 116 | 117 | def get_model_paths(base_dir, sf_bits): 118 | """ 119 | Dynamically generate model paths based on the sf.bits format. 120 | Looks for models of the form: 121 | 1. sf{sf_bits}_vanilla 122 | 2. sf{sf_bits}_{epoch_num}_fpm 123 | 3. sf{sf_bits}_{epoch_num}_opt 124 | 125 | Args: 126 | base_dir (str): The directory where the models are stored. 127 | sf_bits (int): The bitwidth for the Superfloat quantizer. 128 | 129 | Returns: 130 | List of model paths. 131 | """ 132 | model_paths = [] 133 | model_pattern = f"sf{sf_bits}_" 134 | 135 | # Scan directory for models matching the pattern 136 | for model_name in os.listdir(base_dir): 137 | if model_name.startswith(model_pattern): 138 | model_paths.append(os.path.join(base_dir, model_name)) 139 | 140 | # Ensure models are sorted to follow the desired order: vanilla -> fpm -> opt 141 | model_paths.sort() 142 | 143 | return model_paths 144 | 145 | # Function to evaluate perplexity for a list of models and prompts 146 | def evaluate_models(base_dir, sf_bits, tokenizer, prompts): 147 | """ 148 | Evaluates models dynamically loaded based on the sf.bits format. 149 | 150 | Args: 151 | base_dir (str): The directory where the models are stored. 152 | sf_bits (int): The bitwidth for the Superfloat quantizer. 153 | tokenizer: The tokenizer to use for model inference. 154 | prompts: The list of prompts to evaluate. 155 | 156 | Returns: 157 | Dictionary with model names and their corresponding average perplexity. 158 | """ 159 | model_perplexities = {} 160 | 161 | # Get dynamically generated model paths 162 | models = get_model_paths(base_dir, sf_bits) 163 | 164 | for model_path in models: 165 | model = load_model(model_path) 166 | print(f"Evaluating model: {model_path}") 167 | 168 | total_perplexity = 0.0 169 | num_prompts = len(prompts) 170 | 171 | # Compute perplexity for each prompt 172 | for prompt in tqdm(prompts, desc=f"Processing {model_path}", leave=False): 173 | perplexity = calculate_perplexity(model, tokenizer, prompt) 174 | total_perplexity += perplexity 175 | 176 | # Average perplexity for the current model 177 | avg_perplexity = total_perplexity / num_prompts 178 | model_perplexities[model_path] = avg_perplexity 179 | print(f"Average Perplexity for {model_path}: {avg_perplexity}") 180 | 181 | return model_perplexities 182 | 183 | # Function to load the HellaSwag dataset 184 | def load_hellaswag_data(): 185 | """Load the HellaSwag dataset from Hugging Face.""" 186 | dataset = load_dataset("hellaswag", split='validation') 187 | 188 | # Extract only the prompts (contexts) for evaluation 189 | prompts = [entry['ctx'] for entry in dataset] 190 | 191 | # Return the prompts as a list 192 | return prompts 193 | 194 | # Load HellaSwag data (prompts) 195 | prompts = load_hellaswag_data() 196 | 197 | # Evaluate all models on HellaSwag prompts 198 | model_perplexities = evaluate_models(base_dir, sf.bits, tokenizer, prompts) 199 | 200 | # Print final results 201 | print("\nAverage Perplexities for all models:") 202 | for model_path, avg_perplexity in model_perplexities.items(): 203 | print(f"{model_path}: {avg_perplexity}") -------------------------------------------------------------------------------- /src/wasq/wasq_fasteropt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | from torch.utils.data import 
DataLoader 5 | from datasets import load_dataset, Dataset 6 | from tqdm import tqdm 7 | import gc 8 | 9 | max_length = 512 10 | bit = 8 11 | ERROR_BUDGET = 0.01 # maximum allowed relative quantization error per layer 12 | 13 | class Superfloat: 14 | CASTING_TABLE = { 15 | 16: torch.float32, 16 | 15: torch.float32, 17 | 14: torch.float32, 18 | 13: torch.float32, 19 | 12: torch.float32, 20 | 11: torch.float16, 21 | 10: torch.float16, 22 | 9: torch.float16, 23 | 8: torch.bfloat16, 24 | 7: torch.bfloat16, 25 | 6: torch.bfloat16, 26 | 5: torch.bfloat16, 27 | 4: torch.bfloat16, 28 | } 29 | 30 | def __init__(self, bits: int): 31 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 32 | self.bits = bits 33 | self.mantissa_bits = bits - 1 34 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 35 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth 36 | 37 | def encode(self, value: torch.Tensor) -> torch.Tensor: 38 | """Encodes a tensor of values into the superfloat format.""" 39 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 40 | out_of_range = (value.abs() > self.max_val) 41 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 42 | sign = (clipped_value < 0).to(torch.int32) 43 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32), out_of_range 44 | 45 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor: 46 | """Decodes a tensor of encoded superfloat values to regular floats.""" 47 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 48 | sign = (encoded_value >> self.mantissa_bits) & 1 49 | decoded_value = (mantissa.to(self.float_type) / (2**self.mantissa_bits - 1)) * self.max_val 50 | return decoded_value * (2 * sign - 1) 51 | 52 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 53 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape.""" 54 | encoded_tensor, out_of_range = self.encode(tensor) 55 | decoded_tensor = self.decode(encoded_tensor) 56 | return decoded_tensor, out_of_range 57 | 58 | sf = Superfloat(bit) 59 | 60 | 61 | class SFQuantFunction(torch.autograd.Function): 62 | """Straight-through estimator for Superfloat quantisation.""" 63 | 64 | @staticmethod 65 | def forward(ctx, tensor, sf_obj): 66 | encoded, mask = sf_obj.encode(tensor) 67 | ctx.save_for_backward(mask) 68 | ctx.sf_obj = sf_obj 69 | return sf_obj.decode(encoded) 70 | 71 | @staticmethod 72 | def backward(ctx, grad_output): 73 | (mask,) = ctx.saved_tensors 74 | return grad_output * mask.to(grad_output.dtype), None 75 | 76 | 77 | class QuantizedLinear(nn.Module): 78 | """Linear layer with weights encoded once and decoded on-the-fly.""" 79 | 80 | def __init__(self, linear: nn.Linear, sf_bits: int): 81 | super().__init__() 82 | self.sf = Superfloat(sf_bits) 83 | self.in_features = linear.in_features 84 | self.out_features = linear.out_features 85 | self.bias = nn.Parameter(linear.bias.detach()) if linear.bias is not None else None 86 | self.register_buffer("encoded_weight", None) 87 | self.register_buffer("outlier_mask", None) 88 | self.register_buffer("outlier_values", None) 89 | # learnable per-channel scale (LSQ+ style) 90 | self.scale = nn.Parameter(torch.ones(linear.out_features, dtype=self.sf.float_type)) 91 | self.encode_weight(linear.weight) 92 | 93 | def encode_weight(self, weight, outlier_percent=0.5): 94 | # Encode once and split top-k outliers 95 | encoded, mask = 
self.sf.encode(weight) 96 | if outlier_percent: 97 | k = max(1, int(outlier_percent / 100.0 * weight.numel())) 98 | thresh = torch.topk(weight.abs().view(-1), k).values[-1] 99 | mask |= weight.abs() >= thresh 100 | self.encoded_weight = encoded 101 | self.outlier_mask = mask 102 | self.outlier_values = weight[mask] 103 | 104 | def forward(self, x): 105 | w = self.sf.decode(self.encoded_weight) 106 | if self.outlier_mask.any(): 107 | w = w.clone() 108 | w[self.outlier_mask] = self.outlier_values 109 | w = w * self.scale.unsqueeze(1) 110 | x = SFQuantFunction.apply(x, self.sf) 111 | return nn.functional.linear(x, w, self.bias) 112 | 113 | 114 | def quant_error(weight: torch.Tensor, sf_bits: int) -> float: 115 | sf_tmp = Superfloat(sf_bits) 116 | quant, _ = sf_tmp.tensor_quantize(weight) 117 | return torch.norm(weight - quant) / torch.norm(weight) 118 | 119 | 120 | def search_layer_bitwidth(weight: torch.Tensor, bits_list) -> int: 121 | for b in sorted(bits_list): 122 | if quant_error(weight, b) <= ERROR_BUDGET: 123 | return b 124 | return max(bits_list) 125 | 126 | 127 | def quantize_linear_layers(model, bits_candidates): 128 | for name, module in model.named_modules(): 129 | if isinstance(module, nn.Linear): 130 | chosen = search_layer_bitwidth(module.weight.data, bits_candidates) 131 | qlin = QuantizedLinear(module, chosen) 132 | parent = model 133 | name_parts = name.split('.') 134 | for n in name_parts[:-1]: 135 | parent = getattr(parent, n) 136 | setattr(parent, name_parts[-1], qlin) 137 | 138 | class QuantizedLlamaModel(torch.nn.Module): 139 | def __init__(self, base_model: torch.nn.Module, sf_quantizer: Superfloat): 140 | super(QuantizedLlamaModel, self).__init__() 141 | self.base_model = base_model 142 | self.sf_quantizer = sf_quantizer 143 | # Replace Linear layers with quantised versions 144 | quantize_linear_layers(self.base_model, [4, 8, 11, 16]) 145 | self.apply_gradient_hooks() 146 | 147 | def apply_gradient_hooks(self): 148 | for param in self.base_model.parameters(): 149 | def hook(grad, param=param): 150 | _, mask = self.sf_quantizer.encode(param) 151 | return grad * mask.to(grad.dtype) 152 | param.register_hook(hook) 153 | 154 | def forward(self, *args, **kwargs): 155 | if "input_ids" in kwargs: 156 | kwargs["input_ids"] = SFQuantFunction.apply(kwargs["input_ids"], self.sf_quantizer) 157 | outputs = self.base_model(*args, **kwargs) 158 | if hasattr(outputs, "logits"): 159 | outputs.logits = SFQuantFunction.apply(outputs.logits, self.sf_quantizer) 160 | return outputs 161 | 162 | # Initialize model and tokenizer 163 | if torch.backends.mps.is_available(): 164 | device = torch.device("mps") 165 | elif torch.cuda.is_available(): 166 | device = torch.device("cuda") 167 | else: 168 | device = torch.device("cpu") 169 | 170 | print(f"Using device: {device}") 171 | 172 | model_name = "Qwen/Qwen2-0.5B" 173 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir='./', token='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') 174 | model = model.to(sf.float_type).to(device) 175 | 176 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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') 177 | tokenizer.pad_token = tokenizer.eos_token 178 | 179 | # Quantize Model Weights Selectively 180 | def quantize_model(model, sf_type): 181 | for name, param in model.named_parameters(): 182 | quantized_param, _ = sf_type.tensor_quantize(param) 183 | param.data = quantized_param.data 184 | return model 185 | 186 | import os 187 | import re 188 | 189 | def load_checkpoint(model, 
sf_bits, suffix="opt", device=device): 190 | """ 191 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix. 192 | 193 | Args: 194 | quantized_model: The model to load the checkpoint into. 195 | sf_bits: Bitwidth of the Superfloat format (e.g., 11). 196 | suffix: The suffix of the filename (default: 'opt'). 197 | device: Device to load the model onto ('cuda' or 'cpu'). 198 | 199 | Returns: 200 | The quantized model with loaded weights and the epoch number. 201 | """ 202 | # Define the filename pattern to search for 203 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$") 204 | 205 | # Find all matching checkpoint files 206 | checkpoint_files = [ 207 | f for f in os.listdir(".") if checkpoint_pattern.match(f) 208 | ] 209 | 210 | if not checkpoint_files: 211 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.") 212 | return quantize_model(model, sf), 0 213 | 214 | # Extract epoch numbers and sort by latest epoch 215 | epochs_and_files = [ 216 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files 217 | ] 218 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0]) 219 | 220 | # Load the latest checkpoint 221 | print(f"Loading checkpoint: {latest_checkpoint}") 222 | checkpoint = torch.load(latest_checkpoint, map_location=device) 223 | model.load_state_dict(checkpoint) 224 | model.to(device) 225 | 226 | return model, latest_epoch 227 | 228 | # Pre-training parameter check to ensure they are within range 229 | def check_parameters_in_range(model, sf): 230 | out_of_range_params = [] 231 | for name, param in model.named_parameters(): 232 | if not torch.all(torch.abs(param.data) <= sf.max_val): 233 | out_of_range_params.append(name) 234 | if out_of_range_params: 235 | print(f"Warning: The following parameters are out of range:") 236 | for param_name in out_of_range_params: 237 | print(f"- {param_name}") 238 | else: 239 | print("All parameters are within the valid range.") 240 | 241 | # Usage 242 | base_model, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device) 243 | quantized = QuantizedLlamaModel(base_model, sf) 244 | print(f"Resuming training from epoch {last_epoch + 1}.") 245 | 246 | # Check if model parameters are within range before training 247 | check_parameters_in_range(quantized, sf) 248 | 249 | del model 250 | if device=="cuda": 251 | torch.cuda.empty_cache() 252 | elif device=="mps": 253 | torch.mps.empty_cache() 254 | gc.collect() 255 | 256 | # Prepare Dataset 257 | def prepare_dataset(tokenizer, max_length=1): 258 | dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") 259 | # dataset = Dataset.from_parquet('train.parquet') 260 | def tokenize_function(examples): 261 | return tokenizer( 262 | examples["text"], 263 | truncation=True, 264 | max_length=max_length, 265 | padding="max_length", 266 | return_tensors="pt" 267 | ) 268 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names) 269 | return tokenized_dataset 270 | 271 | # Custom collate function 272 | def collate_fn(batch): 273 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch]) 274 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch]) 275 | return {'input_ids': input_ids, 'attention_mask': attention_mask} 276 | 277 | # Prepare tokenized dataset and dataloader 278 | tokenized_dataset = prepare_dataset(tokenizer, max_length=max_length) 279 | train_dataloader 
= DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn) 280 | 281 | # Optimizer and Loss 282 | optimizer = torch.optim.Adam(quantized.parameters(), lr=1e-5, eps=1e-4) 283 | loss_fn = torch.nn.CrossEntropyLoss() 284 | 285 | # Training Loop 286 | num_epochs = 3 287 | accumulation_steps = 8 # Number of steps to accumulate gradients 288 | best_loss = float('inf') 289 | 290 | quantized.to(device) 291 | 292 | for epoch in range(num_epochs): 293 | epoch_loss = 0.0 294 | epoch_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}") 295 | 296 | for step, batch in epoch_iterator: 297 | input_ids = batch['input_ids'].to(device) 298 | attention_mask = batch['attention_mask'].to(device) 299 | 300 | # Forward pass 301 | outputs = quantized(input_ids=input_ids, attention_mask=attention_mask) 302 | logits = outputs.logits 303 | target = input_ids[:, 1:].contiguous() 304 | logits = logits[:, :-1].contiguous() 305 | 306 | # Calculate loss 307 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) / accumulation_steps 308 | 309 | # Backward pass with gradient quantization 310 | loss.backward() 311 | 312 | # Accumulate loss for reporting 313 | epoch_loss += loss.item() * accumulation_steps 314 | 315 | if (step + 1) % accumulation_steps == 0: 316 | # Clamp activations and model parameters within the Superfloat range 317 | for name, param in quantized.named_parameters(): 318 | param.data = torch.clamp(param.data, min=-sf.max_val, max=sf.max_val) 319 | 320 | # Check activations range 321 | for name, param in quantized.named_parameters(): 322 | if not torch.all(torch.abs(param.data) <= sf.max_val): 323 | print(f"Warning: {name} activation is out of range after clamping!") 324 | 325 | torch.nn.utils.clip_grad_value_(quantized.parameters(), clip_value=sf.max_val) 326 | optimizer.step() 327 | optimizer.zero_grad() 328 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"}) 329 | 330 | epoch_loss /= len(train_dataloader) 331 | if epoch_loss < best_loss: 332 | torch.save(quantized.state_dict(), f"sf{sf.bits}_{epoch+1}_opt") 333 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}") -------------------------------------------------------------------------------- /src/wasq/wasq_fpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import LlamaForCausalLM, PreTrainedTokenizerFast 3 | import os 4 | import re 5 | import gc 6 | from torch.utils.data import DataLoader 7 | from datasets import load_dataset, Dataset 8 | from tqdm import tqdm 9 | 10 | class Superfloat: 11 | CASTING_TABLE = { 12 | 16: torch.float32, 13 | 15: torch.float32, 14 | 14: torch.float32, 15 | 13: torch.float32, 16 | 12: torch.float32, 17 | 11: torch.float16, 18 | 10: torch.float16, 19 | 9: torch.float16, 20 | 8: torch.bfloat16, 21 | 7: torch.bfloat16, 22 | 6: torch.bfloat16, 23 | 5: torch.bfloat16, 24 | 4: torch.bfloat16, 25 | } 26 | 27 | def __init__(self, bits: int): 28 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 
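        # The CASTING_TABLE above only chooses the float dtype used to hold decoded
        # values (float32 for 12-16 bits, float16 for 9-11, bfloat16 for 4-8); the
        # precision of the quantization itself is set by mantissa_bits below.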
29 | self.bits = bits 30 | self.mantissa_bits = bits - 1 31 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 32 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth 33 | 34 | def encode(self, value: torch.Tensor) -> torch.Tensor: 35 | """Encodes a tensor of values into the superfloat format.""" 36 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 37 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 38 | sign = (clipped_value < 0).to(torch.int32) 39 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32) 40 | 41 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor: 42 | """Decodes a tensor of encoded superfloat values to regular floats.""" 43 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 44 | sign = (encoded_value >> self.mantissa_bits) & 1 45 | decoded_value = (mantissa.to(self.float_type) / (2**self.mantissa_bits - 1)) * self.max_val 46 | return decoded_value * (2 * sign - 1) 47 | 48 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 49 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape.""" 50 | # Apply element-wise encoding to the entire tensor and then decode back 51 | print(f"Params in layer: {len(tensor)}") 52 | encoded_tensor = self.encode(tensor) 53 | print("Encoding complete") 54 | decoded_tensor = self.decode(encoded_tensor) 55 | print("Decoding complete") 56 | return decoded_tensor 57 | 58 | sf = Superfloat(8) 59 | 60 | class QuantizedLlamaModel(torch.nn.Module): 61 | def __init__(self, base_model: torch.nn.Module, sf_quantizer: Superfloat): 62 | super(QuantizedLlamaModel, self).__init__() 63 | self.base_model = base_model 64 | self.sf_quantizer = sf_quantizer 65 | self.apply_gradient_hooks() 66 | 67 | def apply_gradient_hooks(self): 68 | # Register a hook to quantize gradients after backward pass 69 | for param in self.base_model.parameters(): 70 | param.register_hook(lambda grad: self.sf_quantizer.tensor_quantize(grad)) 71 | 72 | def forward(self, x): 73 | # Quantize activations and parameters during forward pass 74 | x = self.sf_quantizer.tensor_quantize(x) 75 | for layer in self.base_model.children(): 76 | if isinstance(layer, torch.nn.Linear): 77 | layer.weight.data = self.sf_quantizer.tensor_quantize(layer.weight.data) 78 | x = self.sf_quantizer.tensor_quantize(layer(x)) 79 | return x 80 | 81 | if torch.backends.mps.is_available(): 82 | device = torch.device("mps") 83 | elif torch.cuda.is_available(): 84 | device = torch.device("cuda") 85 | else: 86 | device = torch.device("cpu") 87 | 88 | print(f"Using device: {device}") 89 | 90 | # Initialize model and tokenizer 91 | model_name = "meta-llama/Llama-3.2-1B" 92 | model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./', token='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') 93 | model = model.to(sf.float_type).to(device) 94 | 95 | tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./', token='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') 96 | 97 | tokenizer.pad_token = tokenizer.eos_token 98 | 99 | def quantize_model(model, sf_type): 100 | for name, param in model.named_parameters(): 101 | print(name, len(param)) 102 | quantized_param = sf_type.tensor_quantize(param) 103 | param.data = quantized_param.data 104 | return model 105 | 106 | # Checker function to verify quantization 107 | def check_model_quantization(model, sf_type): 108 | all_parameters_valid = True 109 | for name, param 
in model.named_parameters(): 110 | param_data = param.data 111 | if param_data.dtype != sf_type.float_type: 112 | print(f"Parameter {name} is not in {sf_type.float_type} format!") 113 | all_parameters_valid = False 114 | if not torch.all((param_data >= -sf_type.max_val) & (param_data <= sf_type.max_val)): 115 | print(f"Parameter {name} has values outside the SF{sf_type.bits} range!") 116 | all_parameters_valid = False 117 | return all_parameters_valid 118 | 119 | def load_checkpoint(model, sf_bits, suffix="fpm", device=device): 120 | """ 121 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix. 122 | 123 | Args: 124 | quantized_model: The model to load the checkpoint into. 125 | sf_bits: Bitwidth of the Superfloat format (e.g., 11). 126 | suffix: The suffix of the filename (default: 'fpm'). 127 | device: Device to load the model onto ('cuda' or 'cpu'). 128 | 129 | Returns: 130 | The quantized model with loaded weights and the epoch number. 131 | """ 132 | # Define the filename pattern to search for 133 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$") 134 | 135 | # Find all matching checkpoint files 136 | checkpoint_files = [ 137 | f for f in os.listdir(".") if checkpoint_pattern.match(f) 138 | ] 139 | 140 | if not checkpoint_files: 141 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.") 142 | return quantize_model(model, sf), 0 143 | 144 | # Extract epoch numbers and sort by latest epoch 145 | epochs_and_files = [ 146 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files 147 | ] 148 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0]) 149 | 150 | # Load the latest checkpoint 151 | print(f"Loading checkpoint: {latest_checkpoint}") 152 | checkpoint = torch.load(latest_checkpoint, map_location=device) 153 | model.load_state_dict(checkpoint) 154 | model.to(device) 155 | 156 | return model, latest_epoch 157 | 158 | # Usage 159 | quantized, last_epoch = load_checkpoint(model, sf.bits, suffix="fpm", device=device) 160 | print(f"Resuming training from epoch {last_epoch + 1}.") 161 | 162 | del model 163 | if device=='cuda': 164 | # Clear CUDA cache to free up memory 165 | torch.cuda.empty_cache() 166 | elif device=='mps': 167 | # Clear MPS cache to free up memory 168 | torch.mps.empty_cache() 169 | gc.collect() 170 | 171 | # Prepare Dataset 172 | def prepare_dataset(tokenizer, max_length=512): 173 | """Prepare the dataset with proper tensor formatting.""" 174 | dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") 175 | # dataset = Dataset.from_parquet('train.parquet') 176 | def tokenize_function(examples): 177 | outputs = tokenizer( 178 | examples["text"], 179 | truncation=True, 180 | max_length=max_length, 181 | padding="max_length", 182 | return_tensors="pt" 183 | ) 184 | return outputs 185 | 186 | tokenized_dataset = dataset.map( 187 | tokenize_function, 188 | batched=True, 189 | remove_columns=dataset.column_names 190 | ) 191 | 192 | return tokenized_dataset 193 | 194 | # Custom collate function 195 | def collate_fn(batch): 196 | """Custom collate function to properly format tensors.""" 197 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch]) 198 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch]) 199 | return { 200 | 'input_ids': input_ids, 201 | 'attention_mask': attention_mask 202 | } 203 | 204 | # Optimizer and Loss 205 | optimizer = 
torch.optim.Adam(quantized.parameters(), lr=1e-5, eps=1e-4) 206 | loss_fn = torch.nn.CrossEntropyLoss() 207 | 208 | # Prepare tokenized dataset and dataloader 209 | tokenized_dataset = prepare_dataset(tokenizer) 210 | train_dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn) 211 | 212 | num_epochs = 3 213 | accumulation_steps = 8 # Number of steps to accumulate gradients 214 | best_loss = float('inf') 215 | 216 | quantized.to(device) 217 | 218 | # Training Loop with Autoregressive Target, Gradient Accumulation, and Progress Tracking 219 | for epoch in range(num_epochs): 220 | epoch_loss = 0.0 # Track total loss for each epoch 221 | 222 | # Initialize tqdm for tracking epoch progress 223 | epoch_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}") 224 | 225 | for step, batch in epoch_iterator: 226 | input_ids = batch['input_ids'].to(device) 227 | attention_mask = batch['attention_mask'].to(device) 228 | 229 | # Forward pass 230 | outputs = quantized(input_ids=input_ids, attention_mask=attention_mask) 231 | 232 | # Access the logits for token predictions 233 | logits = outputs.logits # Retrieve logits tensor from ModelOutput 234 | 235 | # Shift input_ids by one for autoregressive target 236 | target = input_ids[:, 1:].contiguous() # Target is the input shifted by one token 237 | logits = logits[:, :-1].contiguous() # Align logits with target length 238 | 239 | # Calculate loss 240 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) 241 | 242 | # Divide loss by accumulation steps 243 | loss = loss / accumulation_steps 244 | 245 | # Backward pass with gradient quantization 246 | loss.backward() # Gradient quantization occurs via hook 247 | 248 | # Accumulate the loss for reporting 249 | epoch_loss += loss.item() * accumulation_steps 250 | 251 | # Perform optimizer step after accumulating gradients 252 | if (step + 1) % accumulation_steps == 0: 253 | torch.nn.utils.clip_grad_value_(quantized.parameters(), clip_value = sf.max_val) 254 | optimizer.step() 255 | optimizer.zero_grad() # Clear gradients for next accumulation 256 | 257 | # Update tqdm progress bar with current step loss 258 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"}) 259 | 260 | # Average epoch loss 261 | epoch_loss /= len(train_dataloader) 262 | if epoch_loss < best_loss: 263 | torch.save(quantized.state_dict(), f"sf{sf.bits}_{epoch+1}_fpm") 264 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}") -------------------------------------------------------------------------------- /src/wasq/wasq_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import PreTrainedTokenizerFast, LlamaForCausalLM 3 | 4 | # Device setup 5 | if torch.backends.mps.is_available(): 6 | device = torch.device("mps") 7 | elif torch.cuda.is_available(): 8 | device = torch.device("cuda") 9 | else: 10 | device = torch.device("cpu") 11 | 12 | print(f"Using device: {device}") 13 | 14 | # Define Superfloat quantizer for clamping activations 15 | class Superfloat: 16 | def __init__(self, bits: int): 17 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 
18 | self.bits = bits 19 | self.mantissa_bits = bits - 1 20 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 21 | 22 | def encode(self, value: torch.Tensor) -> torch.Tensor: 23 | """Encodes a tensor of values into the superfloat format with optimized operations.""" 24 | # Clip tensor values to the valid range for SFx 25 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 26 | 27 | # Calculate mantissa representation element-wise 28 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 29 | 30 | # Create the superfloat representation (1 bit for sign and mantissa bits) 31 | sign = (clipped_value < 0).to(torch.int32) 32 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32) 33 | 34 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor: 35 | """Decodes a tensor of encoded superfloat values to regular floats.""" 36 | # Extract mantissa and sign from the encoded superfloat 37 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 38 | sign = (encoded_value >> self.mantissa_bits) & 1 39 | 40 | # Calculate the decoded float using the mantissa and max_val 41 | decoded_value = (mantissa.to(torch.bfloat16) / (2**self.mantissa_bits - 1)) * self.max_val 42 | return decoded_value * (2 * sign - 1) # Apply the sign 43 | 44 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 45 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape.""" 46 | # Apply element-wise encoding to the entire tensor and then decode back 47 | encoded_tensor = self.encode(tensor) 48 | decoded_tensor = self.decode(encoded_tensor) 49 | return decoded_tensor 50 | 51 | # Initialize Superfloat quantizer for clamping 52 | sf = Superfloat(8) 53 | 54 | # Load model in bfloat16 directly for inference 55 | model_name = "meta-llama/Llama-3.2-1B" 56 | model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./', token='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') 57 | model.load_state_dict(torch.load(f"sf{sf.bits}_trained_epoch3", map_location=device))  # f-string so sf.bits is interpolated into the checkpoint name 58 | model = model.to(torch.bfloat16).to(device) 59 | model.eval() # Ensure model is in inference mode 60 | 61 | # Load tokenizer 62 | tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./', token='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') 63 | tokenizer.pad_token = tokenizer.eos_token 64 | 65 | def quantized_inference(model, tokenizer, prompt, max_length=500): 66 | """Runs inference on a prompt with activation quantization using Superfloat.""" 67 | # Encode input prompt 68 | inputs = tokenizer(prompt, return_tensors="pt", padding=True) 69 | input_ids = inputs.input_ids.to(device) 70 | attention_mask = inputs.attention_mask.to(device) 71 | 72 | with torch.no_grad(): 73 | # Perform generation with clamped activations 74 | outputs = model.generate( 75 | input_ids=input_ids, 76 | attention_mask=attention_mask, 77 | max_length=max_length, 78 | do_sample=True, 79 | top_k=50, 80 | top_p=0.95, 81 | temperature=0.7, 82 | pad_token_id=tokenizer.eos_token_id 83 | ) 84 | 85 | # Decode generated output 86 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 87 | return generated_text 88 | 89 | # Example usage 90 | prompt = "It could be one of those nights, where we don't turn off the lights."
91 | generated_text = quantized_inference(model, tokenizer, prompt) 92 | print("Generated text:", generated_text) -------------------------------------------------------------------------------- /src/wasq/wasq_lth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from torch.utils.data import DataLoader 6 | from datasets import Dataset 7 | from tqdm import tqdm 8 | import copy 9 | import numpy as np 10 | 11 | class Superfloat: 12 | CASTING_TABLE = { 13 | 16: torch.float32, 14 | 15: torch.float32, 15 | 14: torch.float32, 16 | 13: torch.float32, 17 | 12: torch.float32, 18 | 11: torch.float16, 19 | 10: torch.float16, 20 | 9: torch.float16, 21 | 8: torch.bfloat16, 22 | 7: torch.bfloat16, 23 | 6: torch.bfloat16, 24 | 5: torch.bfloat16, 25 | 4: torch.bfloat16, 26 | } 27 | 28 | def __init__(self, bits: int): 29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 30 | self.bits = bits 31 | self.mantissa_bits = bits - 1 32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 33 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth 34 | 35 | def quantize(self, tensor: torch.Tensor) -> torch.Tensor: 36 | """Quantizes a tensor to Superfloat format.""" 37 | # Per-channel scaling for dynamic range 38 | scale = self.max_val / tensor.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-8) 39 | quantized = torch.clamp(tensor * scale, -self.max_val, self.max_val).round() 40 | return quantized / scale # Dequantize for inference 41 | 42 | if torch.backends.mps.is_available(): 43 | device = torch.device("mps") 44 | elif torch.cuda.is_available(): 45 | device = torch.device("cuda") 46 | else: 47 | device = torch.device("cpu") 48 | 49 | print(f"Using device: {device}") 50 | 51 | class LotteryTicketTrainer: 52 | def __init__(self, model, sf_quantizer, tokenizer, config): 53 | self.device = device 54 | self.sf_quantizer = sf_quantizer 55 | self.model = model.to(device=self.device, dtype=sf_quantizer.float_type) 56 | self.tokenizer = tokenizer 57 | self.config = config 58 | self.original_model_state = copy.deepcopy(self.model.state_dict()) 59 | self.winning_tickets = {} 60 | self.pruning_rate = config.get('pruning_rate', 0.2) 61 | self.pruning_iterations = config.get('pruning_iterations', 3) 62 | self.optimizer = optim.Adam(self.model.parameters(), lr=config.get('learning_rate', 1e-5), eps=config.get('optimizer_eps', 1e-4)) 63 | self.loss_fn = nn.CrossEntropyLoss() 64 | 65 | def prepare_dataset(self, max_length=512): 66 | dataset = Dataset.from_parquet('train.parquet') 67 | 68 | def tokenize_function(examples): 69 | return self.tokenizer( 70 | examples["text"], 71 | truncation=True, 72 | max_length=max_length, 73 | padding="max_length", 74 | return_tensors="pt" 75 | ) 76 | 77 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names) 78 | return tokenized_dataset 79 | 80 | def create_dataloader(self, dataset, batch_size=4): 81 | def collate_fn(batch): 82 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch]) 83 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch]) 84 | return {'input_ids': input_ids, 'attention_mask': attention_mask} 85 | 86 | return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 87 | 88 | def 
magnitude_based_pruning(self): 89 | pruning_masks = {} 90 | 91 | for name, param in self.model.named_parameters(): 92 | if len(param.shape) > 1: # Only prune weight matrices 93 | weight_abs = torch.abs(param.data) 94 | flat_weights = weight_abs.view(-1) 95 | k = int(flat_weights.numel() * self.pruning_rate) 96 | threshold = torch.topk(flat_weights, k, largest=False).values.max() 97 | mask = (weight_abs > threshold).float() 98 | pruning_masks[name] = mask 99 | param.data *= mask 100 | 101 | return pruning_masks 102 | 103 | def reset_to_winning_ticket(self, pruning_masks): 104 | for name, param in self.model.named_parameters(): 105 | if name in pruning_masks: 106 | # Reset to original initialization, then apply mask 107 | param.data.copy_(self.original_model_state[name]) 108 | param.data *= pruning_masks[name] 109 | 110 | def activation_magnitude_analysis(self): 111 | # Compare activations between original and quantized models 112 | with torch.no_grad(): 113 | original_activations = self.get_activations(self.original_model_state) 114 | quantized_activations = self.get_activations(self.model.state_dict()) 115 | return self.compute_layerwise_differences(original_activations, quantized_activations) 116 | 117 | def get_activations(self, model_state): 118 | # Placeholder: Implement forward pass to collect activations 119 | activations = {} 120 | for name, param in model_state.items(): 121 | if len(param.shape) > 1: 122 | activations[name] = torch.mean(torch.abs(param)).item() 123 | return activations 124 | 125 | def compute_layerwise_differences(self, original_activations, quantized_activations): 126 | differences = {} 127 | for name in original_activations: 128 | differences[name] = abs(original_activations[name] - quantized_activations[name]) 129 | return differences 130 | 131 | def fine_tune_based_on_activations(self, layer_activation_changes): 132 | # Fine-tune layers with significant activation change 133 | for layer, change in layer_activation_changes.items(): 134 | if change > self.config.get('activation_threshold', 0.1): 135 | # Fine-tune or adjust this layer specifically 136 | pass # Fine-tune layer weights based on magnitude analysis 137 | 138 | def train(self): 139 | tokenized_dataset = self.prepare_dataset() 140 | dataloader = self.create_dataloader(tokenized_dataset) 141 | 142 | num_epochs = self.config.get('num_epochs', 3) 143 | accumulation_steps = self.config.get('accumulation_steps', 32) 144 | best_loss = float('inf') 145 | 146 | for iteration in range(self.pruning_iterations): 147 | print(f"\nPruning Iteration {iteration + 1}/{self.pruning_iterations}") 148 | 149 | for epoch in range(num_epochs): 150 | self.model.train() 151 | epoch_loss = 0.0 152 | 153 | epoch_iterator = tqdm( 154 | enumerate(dataloader), 155 | total=len(dataloader), 156 | desc=f"Iteration {iteration + 1}, Epoch {epoch + 1}" 157 | ) 158 | 159 | for step, batch in epoch_iterator: 160 | input_ids = batch['input_ids'].to(device=self.device, dtype=torch.long) 161 | attention_mask = batch['attention_mask'].to(device=self.device, dtype=torch.long) 162 | 163 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 164 | logits = outputs.logits 165 | 166 | target = input_ids[:, 1:].contiguous() 167 | logits = logits[:, :-1].contiguous() 168 | 169 | loss = self.loss_fn( 170 | logits.view(-1, logits.size(-1)), 171 | target.view(-1) 172 | ) / accumulation_steps 173 | 174 | loss.backward() 175 | epoch_loss += loss.item() * accumulation_steps 176 | 177 | if (step + 1) % accumulation_steps == 0: 178 | for 
param in self.model.parameters(): 179 | param.data = torch.clamp( 180 | param.data, 181 | min=-self.sf_quantizer.max_val, 182 | max=self.sf_quantizer.max_val 183 | ) 184 | 185 | torch.nn.utils.clip_grad_value_( 186 | self.model.parameters(), 187 | clip_value=self.sf_quantizer.max_val 188 | ) 189 | 190 | self.optimizer.step() 191 | self.optimizer.zero_grad() 192 | 193 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"}) 194 | 195 | epoch_loss /= len(dataloader) 196 | print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f}") 197 | 198 | if epoch_loss < best_loss: 199 | best_loss = epoch_loss 200 | torch.save( 201 | self.model.state_dict(), 202 | f"sf{self.sf_quantizer.bits}_iteration{iteration+1}_epoch{epoch+1}_best.pth" 203 | ) 204 | 205 | pruning_masks = self.magnitude_based_pruning() 206 | self.reset_to_winning_ticket(pruning_masks) 207 | 208 | # After pruning, perform activation analysis and fine-tuning 209 | layer_activation_changes = self.activation_magnitude_analysis() 210 | self.fine_tune_based_on_activations(layer_activation_changes) 211 | 212 | torch.save(self.model.state_dict(), f"sf{self.sf_quantizer.bits}_winning_ticket_iteration{iteration+1}.pth") 213 | 214 | def main(): 215 | # Load the pre-trained model and tokenizer 216 | model_name = "gpt2" 217 | model = AutoModelForCausalLM.from_pretrained(model_name) 218 | tokenizer = AutoTokenizer.from_pretrained(model_name) 219 | 220 | # Initialize the Superfloat quantizer 221 | sf_quantizer = Superfloat(bits=11) # You can experiment with different bit-widths 222 | 223 | # Configuration settings 224 | config = { 225 | "pruning_rate": 0.2, 226 | "pruning_iterations": 3, 227 | "learning_rate": 1e-5, 228 | "optimizer_eps": 1e-4, 229 | "num_epochs": 3, 230 | "accumulation_steps": 32, 231 | "activation_threshold": 0.1 # You can adjust this threshold 232 | } 233 | 234 | # Instantiate the trainer 235 | trainer = LotteryTicketTrainer(model, sf_quantizer, tokenizer, config) 236 | 237 | # Train the model with LTH + Superfloat quantization 238 | trainer.train() 239 | 240 | if __name__ == "__main__": 241 | main() -------------------------------------------------------------------------------- /src/wasq/wasq_mplth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from torch.utils.data import DataLoader 6 | from datasets import Dataset 7 | from tqdm import tqdm 8 | import copy 9 | import numpy as np 10 | 11 | class Superfloat: 12 | CASTING_TABLE = { 13 | 16: torch.float32, 14 | 15: torch.float32, 15 | 14: torch.float32, 16 | 13: torch.float32, 17 | 12: torch.float32, 18 | 11: torch.float16, 19 | 10: torch.float16, 20 | 9: torch.float16, 21 | 8: torch.bfloat16, 22 | 7: torch.bfloat16, 23 | 6: torch.bfloat16, 24 | 5: torch.bfloat16, 25 | 4: torch.bfloat16, 26 | } 27 | 28 | def __init__(self, bits: int): 29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 
30 | self.bits = bits 31 | self.mantissa_bits = bits - 1 32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 33 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth 34 | 35 | def quantize(self, tensor: torch.Tensor) -> torch.Tensor: 36 | """Quantizes a tensor to Superfloat format.""" 37 | scale = self.max_val / tensor.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-8) 38 | quantized = torch.clamp(tensor * scale, -self.max_val, self.max_val).round() 39 | return quantized / scale # Dequantize for inference 40 | 41 | if torch.backends.mps.is_available(): 42 | device = torch.device("mps") 43 | elif torch.cuda.is_available(): 44 | device = torch.device("cuda") 45 | else: 46 | device = torch.device("cpu") 47 | 48 | print(f"Using device: {device}") 49 | 50 | class MultiPrizeLotteryTicketTrainer: 51 | def __init__(self, model, sf_quantizer, tokenizer, config): 52 | self.device = device 53 | self.sf_quantizer = sf_quantizer 54 | self.model = model.to(device=self.device, dtype=sf_quantizer.float_type) 55 | self.tokenizer = tokenizer 56 | self.config = config 57 | self.original_model_state = copy.deepcopy(self.model.state_dict()) 58 | self.winning_tickets = {} # Store multiple winning tickets 59 | self.pruning_rates = config.get('pruning_rates', [0.1, 0.2, 0.3]) # Multiple pruning rates 60 | self.pruning_iterations = config.get('pruning_iterations', 3) 61 | self.optimizer = optim.Adam(self.model.parameters(), lr=config.get('learning_rate', 1e-5), eps=config.get('optimizer_eps', 1e-4)) 62 | self.loss_fn = nn.CrossEntropyLoss() 63 | 64 | def prepare_dataset(self, max_length=512): 65 | dataset = Dataset.from_parquet('train.parquet') 66 | 67 | def tokenize_function(examples): 68 | return self.tokenizer( 69 | examples["text"], 70 | truncation=True, 71 | max_length=max_length, 72 | padding="max_length", 73 | return_tensors="pt" 74 | ) 75 | 76 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names) 77 | return tokenized_dataset 78 | 79 | def create_dataloader(self, dataset, batch_size=4): 80 | def collate_fn(batch): 81 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch]) 82 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch]) 83 | return {'input_ids': input_ids, 'attention_mask': attention_mask} 84 | 85 | return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 86 | 87 | def magnitude_based_pruning(self, pruning_rate): 88 | pruning_masks = {} 89 | 90 | for name, param in self.model.named_parameters(): 91 | if len(param.shape) > 1: # Only prune weight matrices 92 | weight_abs = torch.abs(param.data) 93 | flat_weights = weight_abs.view(-1) 94 | k = int(flat_weights.numel() * pruning_rate) 95 | threshold = torch.topk(flat_weights, k, largest=False).values.max() 96 | mask = (weight_abs > threshold).float() 97 | pruning_masks[name] = mask 98 | param.data *= mask 99 | 100 | return pruning_masks 101 | 102 | def reset_to_winning_ticket(self, pruning_masks): 103 | for name, param in self.model.named_parameters(): 104 | if name in pruning_masks: 105 | param.data.copy_(self.original_model_state[name]) 106 | param.data *= pruning_masks[name] 107 | 108 | def fine_tune(self, dataloader, num_epochs=3): 109 | best_loss = float('inf') 110 | 111 | for epoch in range(num_epochs): 112 | self.model.train() 113 | epoch_loss = 0.0 114 | 115 | epoch_iterator = tqdm(dataloader, total=len(dataloader), desc=f"Epoch {epoch 
+ 1}") 116 | 117 | for step, batch in epoch_iterator: 118 | input_ids = batch['input_ids'].to(device=self.device, dtype=torch.long) 119 | attention_mask = batch['attention_mask'].to(device=self.device, dtype=torch.long) 120 | 121 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) 122 | logits = outputs.logits 123 | 124 | target = input_ids[:, 1:].contiguous() 125 | logits = logits[:, :-1].contiguous() 126 | 127 | loss = self.loss_fn( 128 | logits.view(-1, logits.size(-1)), 129 | target.view(-1) 130 | ) 131 | 132 | loss.backward() 133 | epoch_loss += loss.item() 134 | 135 | if (step + 1) % self.config.get('accumulation_steps', 32) == 0: 136 | self.optimizer.step() 137 | self.optimizer.zero_grad() 138 | 139 | epoch_iterator.set_postfix({"Loss": f"{loss.item():.4f}"}) 140 | 141 | epoch_loss /= len(dataloader) 142 | print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f}") 143 | 144 | if epoch_loss < best_loss: 145 | best_loss = epoch_loss 146 | torch.save(self.model.state_dict(), f"best_model_epoch_{epoch+1}.pth") 147 | 148 | def train(self): 149 | tokenized_dataset = self.prepare_dataset() 150 | dataloader = self.create_dataloader(tokenized_dataset) 151 | 152 | for iteration in range(self.pruning_iterations): 153 | print(f"\nPruning Iteration {iteration + 1}/{self.pruning_iterations}") 154 | 155 | for pruning_rate in self.pruning_rates: 156 | print(f"Applying pruning with rate {pruning_rate}...") 157 | 158 | pruning_masks = self.magnitude_based_pruning(pruning_rate) 159 | self.reset_to_winning_ticket(pruning_masks) 160 | 161 | # Fine-tune with the current pruning rate 162 | self.fine_tune(dataloader) 163 | 164 | # Save the "winning ticket" for this pruning rate 165 | self.winning_tickets[pruning_rate] = self.model.state_dict() 166 | torch.save(self.model.state_dict(), f"sf{self.sf_quantizer.bits}_winning_ticket_rate{pruning_rate}.pth") 167 | 168 | # Optionally, after all pruning iterations, you could apply activation magnitude analysis and fine-tune further 169 | 170 | def main(): 171 | model_name = "gpt2" 172 | model = AutoModelForCausalLM.from_pretrained(model_name) 173 | tokenizer = AutoTokenizer.from_pretrained(model_name) 174 | 175 | # Initialize the Superfloat quantizer 176 | sf_quantizer = Superfloat(bits=11) 177 | 178 | # Configuration settings 179 | config = { 180 | "pruning_rates": [0.1, 0.2, 0.3], # Different pruning rates 181 | "pruning_iterations": 3, 182 | "learning_rate": 1e-5, 183 | "optimizer_eps": 1e-4, 184 | "num_epochs": 3, 185 | "accumulation_steps": 32, 186 | "activation_threshold": 0.1 # Optional: You can adjust this threshold 187 | } 188 | 189 | trainer = MultiPrizeLotteryTicketTrainer(model, sf_quantizer, tokenizer, config) 190 | trainer.train() 191 | 192 | if __name__ == "__main__": 193 | main() -------------------------------------------------------------------------------- /src/wasq/wasq_opt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import LlamaForCausalLM, PreTrainedTokenizerFast 4 | from torch.utils.data import DataLoader 5 | from datasets import load_dataset, Dataset 6 | from tqdm import tqdm 7 | import gc 8 | 9 | class Superfloat: 10 | CASTING_TABLE = { 11 | 16: torch.float32, 12 | 15: torch.float32, 13 | 14: torch.float32, 14 | 13: torch.float32, 15 | 12: torch.float32, 16 | 11: torch.float16, 17 | 10: torch.float16, 18 | 9: torch.float16, 19 | 8: torch.bfloat16, 20 | 7: torch.bfloat16, 21 | 6: torch.bfloat16, 22 | 5: torch.bfloat16, 23 
| 4: torch.bfloat16, 24 | } 25 | 26 | def __init__(self, bits: int): 27 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 28 | self.bits = bits 29 | self.mantissa_bits = bits - 1 30 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value 31 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth 32 | 33 | def encode(self, value: torch.Tensor) -> torch.Tensor: 34 | """Encodes a tensor of values into the superfloat format.""" 35 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val) 36 | out_of_range = (value.abs() > self.max_val) 37 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 38 | sign = (clipped_value < 0).to(torch.int32) 39 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32), out_of_range 40 | 41 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor: 42 | """Decodes a tensor of encoded superfloat values to regular floats.""" 43 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1) 44 | sign = (encoded_value >> self.mantissa_bits) & 1 45 | decoded_value = (mantissa.to(self.float_type) / (2**self.mantissa_bits - 1)) * self.max_val 46 | return decoded_value * (2 * sign - 1) 47 | 48 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 49 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape.""" 50 | encoded_tensor, out_of_range = self.encode(tensor) 51 | decoded_tensor = self.decode(encoded_tensor) 52 | return decoded_tensor, out_of_range 53 | 54 | sf = Superfloat(11) 55 | 56 | class QuantizedLlamaModel(torch.nn.Module): 57 | def __init__(self, base_model: torch.nn.Module, sf_quantizer: Superfloat): 58 | super(QuantizedLlamaModel, self).__init__() 59 | self.base_model = base_model 60 | self.sf_quantizer = sf_quantizer 61 | self.apply_gradient_hooks() 62 | 63 | def apply_gradient_hooks(self): 64 | for param in self.base_model.parameters(): 65 | def hook(grad, param=param): 66 | _, out_of_range = self.sf_quantizer.tensor_quantize(param) 67 | grad = grad * out_of_range.to(grad.dtype) # Mask to allow gradients only on out-of-range params 68 | return grad 69 | param.register_hook(hook) 70 | 71 | def forward(self, x): 72 | x, _ = self.sf_quantizer.tensor_quantize(x) 73 | for layer in self.base_model.children(): 74 | if isinstance(layer, torch.nn.Linear): 75 | layer.weight.data, _ = self.sf_quantizer.tensor_quantize(layer.weight.data) 76 | x = layer(x) 77 | x, _ = self.sf_quantizer.tensor_quantize(x) 78 | return x 79 | 80 | # Initialize model and tokenizer 81 | if torch.backends.mps.is_available(): 82 | device = torch.device("mps") 83 | elif torch.cuda.is_available(): 84 | device = torch.device("cuda") 85 | else: 86 | device = torch.device("cpu") 87 | 88 | print(f"Using device: {device}") 89 | 90 | model_name = "meta-llama/Llama-3.2-1B" 91 | model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./', token='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') 92 | model = model.to(sf.float_type).to(device) 93 | 94 | tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./', token='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') 95 | tokenizer.pad_token = tokenizer.eos_token 96 | 97 | # Quantize Model Weights Selectively 98 | def quantize_model(model, sf_type): 99 | for name, param in model.named_parameters(): 100 | quantized_param, _ = sf_type.tensor_quantize(param) 101 | param.data = quantized_param.data 102 | return model 103 | 104 | import os 105 | import re 106 
| 107 | def load_checkpoint(model, sf_bits, suffix="opt", device=device): 108 | """ 109 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix. 110 | 111 | Args: 112 | quantized_model: The model to load the checkpoint into. 113 | sf_bits: Bitwidth of the Superfloat format (e.g., 11). 114 | suffix: The suffix of the filename (default: 'opt'). 115 | device: Device to load the model onto ('cuda' or 'cpu'). 116 | 117 | Returns: 118 | The quantized model with loaded weights and the epoch number. 119 | """ 120 | # Define the filename pattern to search for 121 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$") 122 | 123 | # Find all matching checkpoint files 124 | checkpoint_files = [ 125 | f for f in os.listdir(".") if checkpoint_pattern.match(f) 126 | ] 127 | 128 | if not checkpoint_files: 129 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.") 130 | return quantize_model(model, sf), 0 131 | 132 | # Extract epoch numbers and sort by latest epoch 133 | epochs_and_files = [ 134 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files 135 | ] 136 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0]) 137 | 138 | # Load the latest checkpoint 139 | print(f"Loading checkpoint: {latest_checkpoint}") 140 | checkpoint = torch.load(latest_checkpoint, map_location=device) 141 | model.load_state_dict(checkpoint) 142 | model.to(device) 143 | 144 | return model, latest_epoch 145 | 146 | # Usage 147 | quantized, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device) 148 | print(f"Resuming training from epoch {last_epoch + 1}.") 149 | 150 | del model 151 | if device=="cuda": 152 | torch.cuda.empty_cache() 153 | if device=="mps": 154 | torch.mps.empty_cache() 155 | gc.collect() 156 | 157 | # Prepare Dataset 158 | def prepare_dataset(tokenizer, max_length=1024): 159 | dataset = Dataset.from_parquet('train.parquet') 160 | def tokenize_function(examples): 161 | return tokenizer( 162 | examples["text"], 163 | truncation=True, 164 | max_length=max_length, 165 | padding="max_length", 166 | return_tensors="pt" 167 | ) 168 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names) 169 | return tokenized_dataset 170 | 171 | # Custom collate function 172 | def collate_fn(batch): 173 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch]) 174 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch]) 175 | return {'input_ids': input_ids, 'attention_mask': attention_mask} 176 | 177 | # Prepare tokenized dataset and dataloader 178 | tokenized_dataset = prepare_dataset(tokenizer) 179 | train_dataloader = DataLoader(tokenized_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn) 180 | 181 | # Optimizer and Loss 182 | optimizer = torch.optim.Adam(quantized.parameters(), lr=1e-5, eps=1e-4) 183 | loss_fn = torch.nn.CrossEntropyLoss() 184 | 185 | # Training Loop 186 | num_epochs = 3 187 | accumulation_steps = 32 # Number of steps to accumulate gradients 188 | best_loss = float('inf') 189 | 190 | quantized.to(device) 191 | 192 | for epoch in range(num_epochs): 193 | epoch_loss = 0.0 194 | epoch_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}") 195 | 196 | for step, batch in epoch_iterator: 197 | input_ids = batch['input_ids'].to(device) 198 | attention_mask = batch['attention_mask'].to(device) 199 | 200 | # 
Forward pass 201 | outputs = quantized(input_ids=input_ids, attention_mask=attention_mask) 202 | logits = outputs.logits 203 | target = input_ids[:, 1:].contiguous() 204 | logits = logits[:, :-1].contiguous() 205 | 206 | # Calculate loss 207 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) / accumulation_steps 208 | 209 | # Backward pass with gradient quantization 210 | loss.backward() 211 | 212 | # Accumulate loss for reporting 213 | epoch_loss += loss.item() * accumulation_steps 214 | 215 | if (step + 1) % accumulation_steps == 0: 216 | torch.nn.utils.clip_grad_value_(quantized.parameters(), clip_value = sf.max_val) 217 | optimizer.step() 218 | optimizer.zero_grad() 219 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"}) 220 | 221 | epoch_loss /= len(train_dataloader) 222 | if epoch_loss < best_loss: 223 | torch.save(quantized.state_dict(), f"sf{sf.bits}_{epoch+1}_opt") 224 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}") -------------------------------------------------------------------------------- /src/wasq/wasq_vanilla.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from transformers import AutoModelForCausalLM 6 | 7 | 8 | class Superfloat: 9 | """Simple Superfloat quantizer with encode/decode utilities.""" 10 | 11 | CASTING_TABLE = { 12 | 16: torch.float32, 13 | 15: torch.float32, 14 | 14: torch.float32, 15 | 13: torch.float32, 16 | 12: torch.float32, 17 | 11: torch.float16, 18 | 10: torch.float16, 19 | 9: torch.float16, 20 | 8: torch.bfloat16, 21 | 7: torch.bfloat16, 22 | 6: torch.bfloat16, 23 | 5: torch.bfloat16, 24 | 4: torch.bfloat16, 25 | } 26 | 27 | def __init__(self, bits: int): 28 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16." 
29 | self.bits = bits 30 | self.mantissa_bits = bits - 1 31 | self.max_val = 1 - 2 ** -self.mantissa_bits 32 | self.float_type = self.CASTING_TABLE[bits] 33 | 34 | def encode(self, value: torch.Tensor): 35 | clipped = torch.clamp(value, min=-self.max_val, max=self.max_val) 36 | mantissa = (torch.abs(clipped) * (2 ** self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32) 37 | sign = (clipped < 0).to(torch.int32) 38 | encoded = (mantissa | (sign << self.mantissa_bits)).to(torch.int32) 39 | out_of_range = (value.abs() > self.max_val) 40 | return encoded, out_of_range 41 | 42 | def decode(self, encoded: torch.Tensor) -> torch.Tensor: 43 | mantissa = encoded & ((1 << self.mantissa_bits) - 1) 44 | sign = (encoded >> self.mantissa_bits) & 1 45 | decoded = (mantissa.to(self.float_type) / (2 ** self.mantissa_bits - 1)) * self.max_val 46 | return decoded * (2 * sign - 1) 47 | 48 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor: 49 | enc, _ = self.encode(tensor) 50 | return self.decode(enc) 51 | 52 | 53 | class SFQuant(Function): 54 | """Straight-through estimator for Superfloat quantization.""" 55 | 56 | @staticmethod 57 | def forward(ctx, input: torch.Tensor, sf: "Superfloat"): 58 | encoded, mask = sf.encode(input) 59 | ctx.save_for_backward(mask) 60 | ctx.sf = sf 61 | return sf.decode(encoded) 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | (mask,) = ctx.saved_tensors 66 | # Pass gradients only where values were in range 67 | return grad_output * mask.to(grad_output.dtype), None 68 | 69 | 70 | class QuantizedLinear(nn.Linear): 71 | """Linear layer with on-the-fly Superfloat decode and optional LSQ+ scale.""" 72 | 73 | def __init__(self, in_features, out_features, sf: Superfloat, bias=True, k_outlier=0.005): 74 | super().__init__(in_features, out_features, bias) 75 | self.sf = sf 76 | 77 | # Split outlier channels that would overflow after quantisation 78 | with torch.no_grad(): 79 | channel_max = self.weight.abs().max(dim=1).values 80 | k = max(1, int(k_outlier * out_features)) 81 | self.outlier_idx = torch.topk(channel_max, k).indices 82 | mask = torch.ones(out_features, dtype=torch.bool) 83 | mask[self.outlier_idx] = False 84 | base_w = self.weight[mask].clone() 85 | self.register_buffer("encoded_weight", sf.encode(base_w)[0]) 86 | self.register_parameter("scale", nn.Parameter(torch.ones(base_w.size(0)))) 87 | self.register_parameter("outlier_weight", nn.Parameter(self.weight[self.outlier_idx].clone())) 88 | self.register_buffer("mask", mask) 89 | # Remove original parameter 90 | self.weight.requires_grad = False 91 | 92 | def forward(self, input: torch.Tensor) -> torch.Tensor: 93 | # Decode base weight on the fly and apply LSQ+ scale 94 | decoded_base = self.sf.decode(self.encoded_weight) * self.scale.view(-1, 1) 95 | weight = self.weight.new_zeros(self.out_features, self.in_features) 96 | weight[self.mask] = decoded_base 97 | weight[self.outlier_idx] = self.outlier_weight 98 | return F.linear(input, weight, self.bias) 99 | 100 | 101 | class ActivationQuant(nn.Module): 102 | """Module to quantise activations symmetrically with Superfloat.""" 103 | 104 | def __init__(self, sf: Superfloat): 105 | super().__init__() 106 | self.sf = sf 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | return SFQuant.apply(x, self.sf) 110 | 111 | 112 | def compute_hessian_scores(model, data_loader, device, num_batches=1): 113 | """Approximate block-diagonal Hessian scores for parameters.""" 114 | scores = {name: torch.zeros_like(p) for name, p in 
model.named_parameters() if p.requires_grad} 115 | loss_fn = nn.CrossEntropyLoss() 116 | model.eval() 117 | for i, batch in enumerate(data_loader): 118 | if i >= num_batches: 119 | break 120 | batch = {k: v.to(device) for k, v in batch.items()} 121 | output = model(**batch) 122 | logits = output.logits 123 | target = batch["input_ids"][:, 1:].contiguous() 124 | logits = logits[:, :-1].contiguous() 125 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) 126 | grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad], create_graph=False) 127 | for (name, _), g in zip([(n, p) for n, p in model.named_parameters() if p.requires_grad], grads): 128 | scores[name] += g.pow(2) 129 | return scores 130 | 131 | 132 | def select_sf_bits(weight, score, bit_options=(16, 11, 8, 4), budget=1e-3): 133 | """Simple layer-adaptive bit-width search using a quantisation error budget.""" 134 | for bits in sorted(bit_options, reverse=True): 135 | sf = Superfloat(bits) 136 | q = sf.tensor_quantize(weight) 137 | err = (weight - q).abs().mean() * score.mean() 138 | if err <= budget: 139 | return sf 140 | return Superfloat(bit_options[0]) 141 | 142 | 143 | def quantize_model(model, sf_options=(16, 11, 8, 4), data_loader=None, device="cpu"): 144 | """Quantise linear layers adaptively and insert activation quantisation.""" 145 | if data_loader is not None: 146 | scores = compute_hessian_scores(model, data_loader, device) 147 | else: 148 | scores = {name: torch.ones_like(p) for name, p in model.named_parameters() if p.requires_grad} 149 | 150 | for name, module in model.named_modules(): 151 | if isinstance(module, nn.Linear): 152 | score = scores.get(f"{name}.weight", torch.ones_like(module.weight)) 153 | sf = select_sf_bits(module.weight.data, score) 154 | qlinear = QuantizedLinear(module.in_features, module.out_features, sf, module.bias is not None) 155 | qlinear.bias = module.bias 156 | setattr(model, name.split(".")[-1], qlinear) 157 | elif isinstance(module, nn.Module) and not isinstance(module, ActivationQuant): 158 | module.register_forward_pre_hook(lambda m, inp: (SFQuant.apply(inp[0], Superfloat(11)),)) 159 | return model 160 | 161 | 162 | def main(): 163 | model_name = "Qwen/Qwen2-0.5B" 164 | if torch.backends.mps.is_available(): 165 | device = torch.device("mps") 166 | elif torch.cuda.is_available(): 167 | device = torch.device("cuda") 168 | else: 169 | device = torch.device("cpu") 170 | print(f"Using device: {device}") 171 | 172 | # Model loading may require network; placeholder path for offline usage 173 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./") 174 | model = model.to(device) 175 | 176 | # Dummy dataloader for Hessian approximation (replace with real data) 177 | dummy_input = torch.randint(0, 10, (1, 8)) 178 | dummy_mask = torch.ones_like(dummy_input) 179 | dataset = [{"input_ids": dummy_input, "attention_mask": dummy_mask}] 180 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=1) 181 | 182 | print("Applying adaptive Superfloat quantisation...") 183 | quantized_model = quantize_model(model, data_loader=data_loader, device=device) 184 | 185 | save_path = "sf_vanilla_adaptive.pt" 186 | torch.save(quantized_model.state_dict(), save_path) 187 | print(f"Quantised model saved to {save_path}") 188 | 189 | 190 | if __name__ == "__main__": 191 | main() 192 | -------------------------------------------------------------------------------- /src/website/backend/parallel_backend.py: 
-------------------------------------------------------------------------------- 1 | import modal 2 | import torch 3 | import time 4 | import psutil 5 | import gc 6 | import os 7 | import uuid 8 | from concurrent.futures import ThreadPoolExecutor 9 | from fastapi import FastAPI, Request, Response 10 | from pydantic import BaseModel 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | import math 13 | from fastapi.responses import StreamingResponse 14 | 15 | # Set your Hugging Face token 16 | hf_token = "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" 17 | os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token 18 | 19 | # Define the mapping for bit-width 20 | def map_bitwidth(bits): 21 | if 4 <= bits <= 7: 22 | return 4 23 | elif 8 <= bits <= 15: 24 | return 8 25 | else: 26 | return 16 27 | 28 | # Mapping bit-width to model names 29 | model_mapping = { 30 | "Qwen/Qwen2.5-0.5B": { 31 | 4: "Qwen/Qwen2.5-0.5B", 32 | 8: "Qwen/Qwen2.5-0.5B", 33 | 16: "Qwen/Qwen2.5-0.5B" 34 | }, 35 | "Qwen/Qwen2.5-1.5B": { 36 | 4: "Qwen/Qwen2.5-1.5B", 37 | 8: "Qwen/Qwen2.5-1.5B", 38 | 16: "Qwen/Qwen2.5-1.5B" 39 | }, 40 | "meta-llama/Llama-3.2-1B": { 41 | 4: "meta-llama/Llama-3.2-1B", 42 | 8: "meta-llama/Llama-3.2-1B", 43 | 16: "meta-llama/Llama-3.2-1B" 44 | }, 45 | "meta-llama/Llama-3.2-3B": { 46 | 4: "meta-llama/Llama-3.2-3B", 47 | 8: "meta-llama/Llama-3.2-3B", 48 | 16: "meta-llama/Llama-3.2-3B" 49 | }, 50 | "meta-llama/Llama-3.1-8B": { 51 | 4: "meta-llama/Llama-3.1-8B", 52 | 8: "meta-llama/Llama-3.1-8B", 53 | 16: "meta-llama/Llama-3.1-8B" 54 | }, 55 | } 56 | 57 | # Function to quantize the model 58 | def absmax_quantize(tensor, bitwidth): 59 | scale = torch.max(torch.abs(tensor)) 60 | q_tensor = torch.round(tensor / scale * (2**(bitwidth - 1) - 1)) 61 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale 62 | return deq_tensor 63 | 64 | def zero_mean_quantize(tensor, bitwidth): 65 | scale = torch.max(torch.abs(tensor - tensor.mean())) 66 | q_tensor = torch.round((tensor - tensor.mean()) / scale * (2**(bitwidth - 1) - 1)) 67 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale + tensor.mean() 68 | return deq_tensor 69 | 70 | model_cache = {} 71 | 72 | def load_model(model_name, bitwidth, quantization_type, device): 73 | # Create a unique key for the model based on its name, bitwidth, and quantization type 74 | cache_key = (model_name, bitwidth, quantization_type) 75 | 76 | # Check if the model is already in the cache 77 | if cache_key in model_cache: 78 | print(f"Using cached model: {model_name} with bitwidth {bitwidth} and quantization type {quantization_type}") 79 | model, tokenizer = model_cache[cache_key] 80 | else: 81 | print(f"Downloading and caching model: {model_name} with bitwidth {bitwidth} and quantization type {quantization_type}") 82 | # Load the model in half-precision (torch.float16) 83 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True).to(device) 84 | 85 | # Apply quantization 86 | for param in model.parameters(): 87 | if quantization_type == 'WASQ-LTH': 88 | param.data = absmax_quantize(param.data, bitwidth).to(torch.bfloat16) # Ensure quantization output is float16 89 | elif quantization_type == 'WASQ-OPT': 90 | param.data = zero_mean_quantize(param.data, bitwidth).to(torch.bfloat16) # Ensure quantization output is float16 91 | 92 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) 93 | 94 | # Cache the model 95 | model_cache[cache_key] = (model, tokenizer) 96 | 97 | return model, tokenizer 98 | 99 
| def measure_performance(model, tokenizer, input_text, device): 100 | # Tokenize the input text 101 | inputs = tokenizer(input_text, return_tensors='pt').to(device) 102 | 103 | # Ensure input_ids remains as torch.long (integer) 104 | inputs = {k: v.to(torch.long) if k == "input_ids" else v.to(torch.bfloat16) for k, v in inputs.items()} 105 | 106 | start_time = time.time() 107 | with torch.no_grad(): 108 | outputs = model.generate( 109 | **inputs, 110 | max_new_tokens=256, 111 | num_return_sequences=1, 112 | do_sample=True, 113 | temperature=0.7, 114 | repetition_penalty=1.2, 115 | pad_token_id=tokenizer.pad_token_id, 116 | eos_token_id=tokenizer.eos_token_id, 117 | ) 118 | end_time = time.time() 119 | inference_time = end_time - start_time 120 | memory_usage = psutil.Process().memory_info().rss / (1024 ** 2) # in MB 121 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 122 | return inference_time, memory_usage, generated_text 123 | 124 | # Define request models 125 | class ModelRequest(BaseModel): 126 | model_name: str 127 | quantization_bits: int 128 | quantization_type: str 129 | input_text: str 130 | 131 | # Create a Modal Dict for persistent storage 132 | results_dict = modal.Dict.from_name("emelinlabs-results", create_if_missing=True) 133 | 134 | # Create a FastAPI app 135 | from fastapi.middleware.cors import CORSMiddleware 136 | 137 | app_fastapi = FastAPI() 138 | 139 | # Add CORS middleware 140 | app_fastapi.add_middleware( 141 | CORSMiddleware, 142 | allow_origins=["*"], # Allow all origins (replace with your frontend URL in production) 143 | allow_credentials=True, 144 | allow_methods=["*"], 145 | allow_headers=["*"], 146 | ) 147 | 148 | # Modal setup 149 | image = ( 150 | modal.Image.debian_slim(python_version="3.11") 151 | .pip_install("fastapi", "uvicorn", "transformers", "torch", "psutil", "pydantic") 152 | ) 153 | 154 | app = modal.App(name="emelinlabs-runner", image=image) 155 | 156 | def sanitize_float(value): 157 | """Ensure the value is a finite float and replace NaN/Infinity with 0.0.""" 158 | if not isinstance(value, (int, float)) or not math.isfinite(value): 159 | return 0.0 160 | return value 161 | 162 | # POST endpoint 163 | @app.function( 164 | gpu="A100", # Specify the GPU type (e.g., "A10G", "A100", "H100") 165 | timeout=86400, # Timeout in seconds (1 day = 86400 seconds) 166 | allow_concurrent_inputs=100 # Allow concurrent requests 167 | ) 168 | @modal.web_endpoint(method="POST") 169 | def run_inference(request: ModelRequest): 170 | device = "cuda" if torch.cuda.is_available() else "cpu" 171 | model_name = request.model_name 172 | quantization_bits = request.quantization_bits 173 | quantization_type = request.quantization_type 174 | input_text = request.input_text 175 | 176 | print(f"Model: {model_name}, Bits: {quantization_bits}, Type: {quantization_type}") 177 | 178 | # Generate a unique ID for this request 179 | request_id = str(uuid.uuid4()) 180 | 181 | # Load original model 182 | original_model = AutoModelForCausalLM.from_pretrained( 183 | model_name, torch_dtype=torch.bfloat16, use_safetensors=True 184 | ).to(device) 185 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) 186 | 187 | # Load and quantize model 188 | effective_bits = map_bitwidth(quantization_bits) 189 | quantized_model_name = model_mapping[model_name][effective_bits] 190 | quantized_model, _ = load_model(quantized_model_name, effective_bits, quantization_type, device) 191 | 192 | # Function to run inference and measure performance 193 | 
def run_model_inference(model, tokenizer, input_text, device): 194 | inference_time, memory_usage, generated_text = measure_performance(model, tokenizer, input_text, device) 195 | return inference_time, memory_usage, generated_text 196 | 197 | # Run original and quantized model inference in parallel 198 | with ThreadPoolExecutor() as executor: 199 | orig_future = executor.submit(run_model_inference, original_model, tokenizer, input_text, device) 200 | quant_future = executor.submit(run_model_inference, quantized_model, tokenizer, input_text, device) 201 | 202 | orig_inference_time, orig_memory_usage, orig_text = orig_future.result() 203 | quant_inference_time, quant_memory_usage, quant_text = quant_future.result() 204 | 205 | # Calculate memory usage for quantized model 206 | quant_memory_usage = (effective_bits / 16.0) * orig_memory_usage 207 | 208 | # Calculate differences 209 | speed_diff = (orig_inference_time - quant_inference_time) / orig_inference_time * 100 210 | memory_savings = (1 - (quantization_bits / 16.0)) * 100 211 | 212 | # Sanitize all floating-point values 213 | orig_inference_time = sanitize_float(orig_inference_time) 214 | orig_memory_usage = sanitize_float(orig_memory_usage) 215 | quant_inference_time = sanitize_float(quant_inference_time) 216 | quant_memory_usage = sanitize_float(quant_memory_usage) 217 | speed_diff = sanitize_float(speed_diff) 218 | memory_savings = sanitize_float(memory_savings) 219 | 220 | # Store results in Modal Dict 221 | results_dict[request_id] = { 222 | "original": { 223 | "text": orig_text, 224 | "inference_time": orig_inference_time, 225 | "memory_usage": orig_memory_usage, 226 | }, 227 | "quantized": { 228 | "text": quant_text, 229 | "inference_time": quant_inference_time, 230 | "memory_usage": quant_memory_usage, 231 | }, 232 | "comparison": { 233 | "speed_diff": speed_diff, 234 | "memory_savings": memory_savings, 235 | } 236 | } 237 | 238 | # Clean up to free memory 239 | del original_model 240 | del quantized_model 241 | gc.collect() 242 | torch.cuda.empty_cache() 243 | 244 | return {"request_id": request_id} 245 | 246 | # GET endpoint 247 | @app.function() 248 | @modal.web_endpoint() 249 | def get_result(request_id: str): 250 | result = results_dict.get(request_id, None) 251 | if result: 252 | return result 253 | else: 254 | return {"error": "Request ID not found"} 255 | 256 | # Health check endpoint 257 | @app.function() 258 | @modal.web_endpoint() 259 | def health_check(): 260 | return {"status": "active"} 261 | 262 | # Stream tokens from original model 263 | @app.function() 264 | @modal.web_endpoint() 265 | def stream_original(request_id: str): 266 | result = results_dict.get(request_id, None) 267 | if result: 268 | def generate(): 269 | for token in result["original"]["text"].split(): 270 | yield f"data: {token}\n\n" 271 | time.sleep(0.1) # Simulate streaming delay 272 | return StreamingResponse(generate(), media_type="text/event-stream") 273 | else: 274 | return {"error": "Request ID not found"} 275 | 276 | # Stream tokens from quantized model 277 | @app.function() 278 | @modal.web_endpoint() 279 | def stream_quantized(request_id: str): 280 | result = results_dict.get(request_id, None) 281 | if result: 282 | def generate(): 283 | for token in result["quantized"]["text"].split(): 284 | yield f"data: {token}\n\n" 285 | time.sleep(0.1) # Simulate streaming delay 286 | return StreamingResponse(generate(), media_type="text/event-stream") 287 | else: 288 | return {"error": "Request ID not found"} 
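The `run_inference` and `get_result` web endpoints above are what the website frontend calls once the Modal app is deployed. A minimal client sketch of that request/response flow is shown below; the endpoint URLs are placeholders (Modal assigns the real URLs at deploy time), and the payload fields simply mirror the `ModelRequest` schema defined above.

```python
# Hypothetical client for the Modal endpoints defined above.
# The URLs are placeholders -- `modal deploy` prints the real endpoint URLs
# for your workspace, so substitute them here before running.
import requests

RUN_INFERENCE_URL = "https://<workspace>--emelinlabs-runner-run-inference.modal.run"  # placeholder
GET_RESULT_URL = "https://<workspace>--emelinlabs-runner-get-result.modal.run"        # placeholder

payload = {
    "model_name": "Qwen/Qwen2.5-0.5B",
    "quantization_bits": 8,
    "quantization_type": "WASQ-OPT",
    "input_text": "Explain Superfloat quantization in one sentence.",
}

# POST starts the original-vs-quantized comparison and returns a request id.
request_id = requests.post(RUN_INFERENCE_URL, json=payload, timeout=600).json()["request_id"]

# GET fetches the results stored in the Modal Dict under that id.
result = requests.get(GET_RESULT_URL, params={"request_id": request_id}, timeout=30).json()
print(result["comparison"])  # e.g. {"speed_diff": ..., "memory_savings": ...}
```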
-------------------------------------------------------------------------------- /src/website/backend/sequential_backend.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import torch 3 | import time 4 | import psutil 5 | import gc 6 | import os 7 | import uuid 8 | from concurrent.futures import ThreadPoolExecutor 9 | from fastapi import FastAPI, Request 10 | from pydantic import BaseModel 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | import math 13 | 14 | # Set your Hugging Face token 15 | hf_token = "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" 16 | os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token 17 | 18 | # Define the mapping for bit-width 19 | def map_bitwidth(bits): 20 | if 4 <= bits <= 7: 21 | return 4 22 | elif 8 <= bits <= 15: 23 | return 8 24 | else: 25 | return 16 26 | 27 | # Mapping bit-width to model names 28 | model_mapping = { 29 | "Qwen/Qwen2.5-0.5B": { 30 | 4: "Qwen/Qwen2.5-0.5B", 31 | 8: "Qwen/Qwen2.5-0.5B", 32 | 16: "Qwen/Qwen2.5-0.5B" 33 | }, 34 | "Qwen/Qwen2.5-1.5B": { 35 | 4: "Qwen/Qwen2.5-1.5B", 36 | 8: "Qwen/Qwen2.5-1.5B", 37 | 16: "Qwen/Qwen2.5-1.5B" 38 | }, 39 | "meta-llama/Llama-3.2-1B": { 40 | 4: "meta-llama/Llama-3.2-1B", 41 | 8: "meta-llama/Llama-3.2-1B", 42 | 16: "meta-llama/Llama-3.2-1B" 43 | }, 44 | "meta-llama/Llama-3.2-3B": { 45 | 4: "meta-llama/Llama-3.2-3B", 46 | 8: "meta-llama/Llama-3.2-3B", 47 | 16: "meta-llama/Llama-3.2-3B" 48 | } 49 | } 50 | 51 | # Function to quantize the model 52 | def absmax_quantize(tensor, bitwidth): 53 | scale = torch.max(torch.abs(tensor)) 54 | q_tensor = torch.round(tensor / scale * (2**(bitwidth - 1) - 1)) 55 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale 56 | return deq_tensor 57 | 58 | def zero_mean_quantize(tensor, bitwidth): 59 | scale = torch.max(torch.abs(tensor - tensor.mean())) 60 | q_tensor = torch.round((tensor - tensor.mean()) / scale * (2**(bitwidth - 1) - 1)) 61 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale + tensor.mean() 62 | return deq_tensor 63 | 64 | def load_model(model_name, bitwidth, quantization_type, device): 65 | # Load the model in half-precision (torch.float16) 66 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) 67 | 68 | # Apply quantization 69 | for param in model.parameters(): 70 | if quantization_type == 'WASQ-LTH': 71 | param.data = absmax_quantize(param.data, bitwidth).to(torch.float16) # Ensure quantization output is float16 72 | elif quantization_type == 'WASQ-OPT': 73 | param.data = zero_mean_quantize(param.data, bitwidth).to(torch.float16) # Ensure quantization output is float16 74 | 75 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) 76 | return model, tokenizer 77 | 78 | def measure_performance(model, tokenizer, input_text, device): 79 | # Tokenize the input text 80 | inputs = tokenizer(input_text, return_tensors='pt').to(device) 81 | 82 | # Ensure input_ids remains as torch.long (integer) 83 | inputs = {k: v.to(torch.long) if k == "input_ids" else v.to(torch.float16) for k, v in inputs.items()} 84 | 85 | start_time = time.time() 86 | with torch.no_grad(): 87 | outputs = model.generate( 88 | **inputs, 89 | max_new_tokens=256, 90 | num_return_sequences=1, 91 | do_sample=True, 92 | temperature=0.7, 93 | repetition_penalty=1.2, 94 | pad_token_id=tokenizer.pad_token_id, 95 | eos_token_id=tokenizer.eos_token_id, 96 | ) 97 | end_time = time.time() 98 | inference_time = end_time - start_time 99 | memory_usage = 
psutil.Process().memory_info().rss / (1024 ** 2) # in MB 100 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 101 | return inference_time, memory_usage, generated_text 102 | 103 | # def calculate_perplexity(model, tokenizer, input_text, device): 104 | # inputs = tokenizer(input_text, return_tensors='pt').to(device) 105 | # max_length = inputs.input_ids.size(1) 106 | # with torch.no_grad(): 107 | # outputs = model(**inputs) 108 | # shift_logits = outputs.logits[..., :-1, :].contiguous() 109 | # shift_labels = inputs.input_ids[..., 1:].contiguous() 110 | # loss_fct = torch.nn.CrossEntropyLoss() 111 | # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 112 | # perplexity = torch.exp(loss).item() 113 | # return perplexity 114 | 115 | # Define request models 116 | class ModelRequest(BaseModel): 117 | model_name: str 118 | quantization_bits: int 119 | quantization_type: str 120 | input_text: str 121 | 122 | # Create a Modal Dict for persistent storage 123 | results_dict = modal.Dict.from_name("emelinlabs-results", create_if_missing=True) 124 | 125 | # Create a FastAPI app 126 | app_fastapi = FastAPI() 127 | 128 | # Modal setup 129 | image = ( 130 | modal.Image.debian_slim(python_version="3.11") 131 | .pip_install("fastapi", "uvicorn", "transformers", "torch", "psutil", "pydantic") 132 | ) 133 | 134 | app = modal.App(name="emelinlabs-runner", image=image) 135 | 136 | def sanitize_float(value): 137 | """Ensure the value is a finite float and replace NaN/Infinity with 0.0.""" 138 | if not isinstance(value, (int, float)) or not math.isfinite(value): 139 | return 0.0 140 | return value 141 | 142 | # POST endpoint 143 | @app.function( 144 | gpu="A100", # Specify the GPU type (e.g., "A10G", "A100", "H100") 145 | timeout=86400, # Timeout in seconds (1 day = 86400 seconds) 146 | allow_concurrent_inputs=100 # Allow concurrent requests 147 | ) 148 | @modal.web_endpoint(method="POST") 149 | def run_inference(request: ModelRequest): 150 | device = "cuda" if torch.cuda.is_available() else "cpu" 151 | model_name = request.model_name 152 | quantization_bits = request.quantization_bits 153 | quantization_type = request.quantization_type 154 | input_text = request.input_text 155 | 156 | # Generate a unique ID for this request 157 | request_id = str(uuid.uuid4()) 158 | 159 | # Load original model 160 | original_model = AutoModelForCausalLM.from_pretrained( 161 | model_name, torch_dtype=torch.float16 162 | ).to(device) 163 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) 164 | 165 | # Load and quantize model 166 | effective_bits = map_bitwidth(quantization_bits) 167 | quantized_model_name = model_mapping[model_name][effective_bits] 168 | quantized_model, _ = load_model(quantized_model_name, effective_bits, quantization_type, device) 169 | 170 | # Measure performance for original model 171 | orig_inference_time, orig_memory_usage, orig_text = measure_performance(original_model, tokenizer, input_text, device) 172 | # orig_perplexity = calculate_perplexity(original_model, tokenizer, input_text, device) 173 | 174 | # Measure performance for quantized model 175 | quant_inference_time, _, quant_text = measure_performance(quantized_model, tokenizer, input_text, device) 176 | # quant_perplexity = calculate_perplexity(quantized_model, tokenizer, input_text, device) 177 | 178 | # Calculate memory usage for quantized model 179 | quant_memory_usage = (effective_bits / 16.0) * orig_memory_usage 180 | 181 | # Calculate differences 182 | 
speed_diff = (orig_inference_time - quant_inference_time) / orig_inference_time * 100 183 | memory_savings = (1 - (quantization_bits / 16.0)) * 100 184 | # perplexity_diff = quant_perplexity - orig_perplexity 185 | 186 | # Sanitize all floating-point values 187 | orig_inference_time = sanitize_float(orig_inference_time) 188 | orig_memory_usage = sanitize_float(orig_memory_usage) 189 | # orig_perplexity = sanitize_float(orig_perplexity) 190 | quant_inference_time = sanitize_float(quant_inference_time) 191 | quant_memory_usage = sanitize_float(quant_memory_usage) 192 | # quant_perplexity = sanitize_float(quant_perplexity) 193 | speed_diff = sanitize_float(speed_diff) 194 | memory_savings = sanitize_float(memory_savings) 195 | # perplexity_diff = sanitize_float(perplexity_diff) 196 | 197 | # Store results in Modal Dict 198 | results_dict[request_id] = { 199 | "original": { 200 | "text": orig_text, 201 | "inference_time": orig_inference_time, 202 | "memory_usage": orig_memory_usage, 203 | # "perplexity": orig_perplexity 204 | }, 205 | "quantized": { 206 | "text": quant_text, 207 | "inference_time": quant_inference_time, 208 | "memory_usage": quant_memory_usage, 209 | # "perplexity": quant_perplexity 210 | }, 211 | "comparison": { 212 | "speed_diff": speed_diff, 213 | "memory_savings": memory_savings, 214 | # "perplexity_diff": perplexity_diff 215 | } 216 | } 217 | 218 | # Clean up to free memory 219 | del original_model 220 | del quantized_model 221 | gc.collect() 222 | torch.cuda.empty_cache() 223 | 224 | return {"request_id": request_id} 225 | 226 | # GET endpoint 227 | @app.function() 228 | @modal.web_endpoint() 229 | def get_result(request_id: str): 230 | result = results_dict.get(request_id, None) 231 | if result: 232 | return result 233 | else: 234 | return {"error": "Request ID not found"} 235 | 236 | # Health check endpoint 237 | @app.function() 238 | @modal.web_endpoint() 239 | def health_check(): 240 | return {"status": "active"} -------------------------------------------------------------------------------- /src/website/backend/testing_backend.py: -------------------------------------------------------------------------------- 1 | import modal 2 | import torch 3 | import time 4 | import psutil 5 | import gc 6 | import os 7 | import uuid 8 | from concurrent.futures import ThreadPoolExecutor 9 | from fastapi import FastAPI, Request, Response 10 | from pydantic import BaseModel 11 | from transformers import AutoTokenizer, AutoModelForCausalLM 12 | import math 13 | from fastapi.responses import StreamingResponse 14 | 15 | # Set your Hugging Face token 16 | hf_token = "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" 17 | os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token 18 | 19 | # Define the mapping for bit-width 20 | def map_bitwidth(bits): 21 | if 4 <= bits <= 7: 22 | return 4 23 | elif 8 <= bits <= 15: 24 | return 8 25 | else: 26 | return 16 27 | 28 | # Mapping bit-width to model names 29 | model_mapping = { 30 | "Qwen/Qwen2.5-0.5B": { 31 | 4: "Qwen/Qwen2.5-0.5B", 32 | 8: "Qwen/Qwen2.5-0.5B", 33 | 16: "Qwen/Qwen2.5-0.5B" 34 | }, 35 | "Qwen/Qwen2.5-1.5B": { 36 | 4: "Qwen/Qwen2.5-1.5B", 37 | 8: "Qwen/Qwen2.5-1.5B", 38 | 16: "Qwen/Qwen2.5-1.5B" 39 | }, 40 | "meta-llama/Llama-3.2-1B": { 41 | 4: "meta-llama/Llama-3.2-1B", 42 | 8: "meta-llama/Llama-3.2-1B", 43 | 16: "meta-llama/Llama-3.2-1B" 44 | }, 45 | "meta-llama/Llama-3.2-3B": { 46 | 4: "meta-llama/Llama-3.2-3B", 47 | 8: "meta-llama/Llama-3.2-3B", 48 | 16: "meta-llama/Llama-3.2-3B" 49 | }, 50 | "meta-llama/Llama-3.1-8B": { 51 | 4: 
"meta-llama/Llama-3.1-8B", 52 | 8: "meta-llama/Llama-3.1-8B", 53 | 16: "meta-llama/Llama-3.1-8B" 54 | }, 55 | } 56 | 57 | # Function to quantize the model 58 | def absmax_quantize(tensor, bitwidth): 59 | scale = torch.max(torch.abs(tensor)) 60 | q_tensor = torch.round(tensor / scale * (2**(bitwidth - 1) - 1)) 61 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale 62 | return deq_tensor 63 | 64 | def zero_mean_quantize(tensor, bitwidth): 65 | scale = torch.max(torch.abs(tensor - tensor.mean())) 66 | q_tensor = torch.round((tensor - tensor.mean()) / scale * (2**(bitwidth - 1) - 1)) 67 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale + tensor.mean() 68 | return deq_tensor 69 | 70 | def load_model(model_name, bitwidth, quantization_type, device): 71 | # Load the model in half-precision (torch.float16) 72 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True).to(device) 73 | 74 | # Apply quantization 75 | for param in model.parameters(): 76 | if quantization_type == 'WASQ-LTH': 77 | param.data = absmax_quantize(param.data, bitwidth).to(torch.bfloat16) # Ensure quantization output is float16 78 | elif quantization_type == 'WASQ-OPT': 79 | param.data = zero_mean_quantize(param.data, bitwidth).to(torch.bfloat16) # Ensure quantization output is float16 80 | 81 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) 82 | return model, tokenizer 83 | 84 | def measure_performance(model, tokenizer, input_text, device): 85 | # Tokenize the input text 86 | inputs = tokenizer(input_text, return_tensors='pt').to(device) 87 | 88 | # Ensure input_ids remains as torch.long (integer) 89 | inputs = {k: v.to(torch.long) if k == "input_ids" else v.to(torch.bfloat16) for k, v in inputs.items()} 90 | 91 | start_time = time.time() 92 | with torch.no_grad(): 93 | outputs = model.generate( 94 | **inputs, 95 | max_new_tokens=256, 96 | num_return_sequences=1, 97 | do_sample=True, 98 | temperature=0.7, 99 | repetition_penalty=1.2, 100 | pad_token_id=tokenizer.pad_token_id, 101 | eos_token_id=tokenizer.eos_token_id, 102 | ) 103 | end_time = time.time() 104 | inference_time = end_time - start_time 105 | memory_usage = psutil.Process().memory_info().rss / (1024 ** 2) # in MB 106 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) 107 | return inference_time, memory_usage, generated_text 108 | 109 | # Define request models 110 | class ModelRequest(BaseModel): 111 | model_name: str 112 | quantization_bits: int 113 | quantization_type: str 114 | input_text: str 115 | 116 | # Create a Modal Dict for persistent storage 117 | results_dict = modal.Dict.from_name("emelinlabs-results", create_if_missing=True) 118 | 119 | # Create a FastAPI app 120 | from fastapi.middleware.cors import CORSMiddleware 121 | 122 | app_fastapi = FastAPI() 123 | 124 | # Add CORS middleware 125 | app_fastapi.add_middleware( 126 | CORSMiddleware, 127 | allow_origins=["*"], # Allow all origins (replace with your frontend URL in production) 128 | allow_credentials=True, 129 | allow_methods=["*"], 130 | allow_headers=["*"], 131 | ) 132 | 133 | # Modal setup 134 | image = ( 135 | modal.Image.debian_slim(python_version="3.11") 136 | .pip_install("fastapi", "uvicorn", "transformers", "torch", "psutil", "pydantic") 137 | ) 138 | 139 | app = modal.App(name="emelinlabs-runners", image=image) 140 | 141 | def sanitize_float(value): 142 | """Ensure the value is a finite float and replace NaN/Infinity with 0.0.""" 143 | if not isinstance(value, (int, 
float)) or not math.isfinite(value): 144 | return 0.0 145 | return value 146 | 147 | # POST endpoint 148 | @app.function( 149 | gpu="A100", # Specify the GPU type (e.g., "A10G", "A100", "H100") 150 | timeout=86400, # Timeout in seconds (1 day = 86400 seconds) 151 | allow_concurrent_inputs=100 # Allow concurrent requests 152 | ) 153 | @modal.web_endpoint(method="POST") 154 | def run_inference(request: ModelRequest): 155 | device = "cuda" if torch.cuda.is_available() else "cpu" 156 | model_name = request.model_name 157 | quantization_bits = request.quantization_bits 158 | quantization_type = request.quantization_type 159 | input_text = request.input_text 160 | 161 | print(f"Model: {model_name}, Bits: {quantization_bits}, Type: {quantization_type}") 162 | 163 | # Generate a unique ID for this request 164 | request_id = str(uuid.uuid4()) 165 | 166 | # Load original model 167 | original_model = AutoModelForCausalLM.from_pretrained( 168 | model_name, torch_dtype=torch.bfloat16, use_safetensors=True 169 | ).to(device) 170 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token) 171 | 172 | # Load and quantize model 173 | effective_bits = map_bitwidth(quantization_bits) 174 | quantized_model_name = model_mapping[model_name][effective_bits] 175 | quantized_model, _ = load_model(quantized_model_name, effective_bits, quantization_type, device) 176 | 177 | # Function to run inference and measure performance 178 | def run_model_inference(model, tokenizer, input_text, device): 179 | inference_time, memory_usage, generated_text = measure_performance(model, tokenizer, input_text, device) 180 | return inference_time, memory_usage, generated_text 181 | 182 | # Run original and quantized model inference in parallel 183 | with ThreadPoolExecutor() as executor: 184 | orig_future = executor.submit(run_model_inference, original_model, tokenizer, input_text, device) 185 | quant_future = executor.submit(run_model_inference, quantized_model, tokenizer, input_text, device) 186 | 187 | orig_inference_time, orig_memory_usage, orig_text = orig_future.result() 188 | quant_inference_time, quant_memory_usage, quant_text = quant_future.result() 189 | 190 | # Calculate memory usage for quantized model 191 | quant_memory_usage = (effective_bits / 16.0) * orig_memory_usage 192 | 193 | # Calculate differences 194 | speed_diff = (orig_inference_time - quant_inference_time) / orig_inference_time * 100 195 | memory_savings = (1 - (quantization_bits / 16.0)) * 100 196 | 197 | # Sanitize all floating-point values 198 | orig_inference_time = sanitize_float(orig_inference_time) 199 | orig_memory_usage = sanitize_float(orig_memory_usage) 200 | quant_inference_time = sanitize_float(quant_inference_time) 201 | quant_memory_usage = sanitize_float(quant_memory_usage) 202 | speed_diff = sanitize_float(speed_diff) 203 | memory_savings = sanitize_float(memory_savings) 204 | 205 | # Store results in Modal Dict 206 | results_dict[request_id] = { 207 | "original": { 208 | "text": orig_text, 209 | "inference_time": orig_inference_time, 210 | "memory_usage": orig_memory_usage, 211 | }, 212 | "quantized": { 213 | "text": quant_text, 214 | "inference_time": quant_inference_time, 215 | "memory_usage": quant_memory_usage, 216 | }, 217 | "comparison": { 218 | "speed_diff": speed_diff, 219 | "memory_savings": memory_savings, 220 | } 221 | } 222 | 223 | # Clean up to free memory 224 | del original_model 225 | del quantized_model 226 | gc.collect() 227 | torch.cuda.empty_cache() 228 | 229 | return {"request_id": request_id} 230 | 
231 | # GET endpoint 232 | @app.function() 233 | @modal.web_endpoint() 234 | def get_result(request_id: str): 235 | result = results_dict.get(request_id, None) 236 | if result: 237 | return result 238 | else: 239 | return {"error": "Request ID not found"} 240 | 241 | # Health check endpoint 242 | @app.function() 243 | @modal.web_endpoint() 244 | def health_check(): 245 | return {"status": "active"} 246 | 247 | # Stream tokens from original model 248 | @app.function() 249 | @modal.web_endpoint() 250 | def stream_original(request_id: str): 251 | result = results_dict.get(request_id, None) 252 | if result: 253 | def generate(): 254 | for token in result["original"]["text"].split(): 255 | yield f"data: {token}\n\n" 256 | time.sleep(0.1) # Simulate streaming delay 257 | return StreamingResponse(generate(), media_type="text/event-stream") 258 | else: 259 | return {"error": "Request ID not found"} 260 | 261 | # Stream tokens from quantized model 262 | @app.function() 263 | @modal.web_endpoint() 264 | def stream_quantized(request_id: str): 265 | result = results_dict.get(request_id, None) 266 | if result: 267 | def generate(): 268 | for token in result["quantized"]["text"].split(): 269 | yield f"data: {token}\n\n" 270 | time.sleep(0.1) # Simulate streaming delay 271 | return StreamingResponse(generate(), media_type="text/event-stream") 272 | else: 273 | return {"error": "Request ID not found"} -------------------------------------------------------------------------------- /src/website/frontend/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /src/website/frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "jumbo", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "cra-template": "1.2.0", 7 | "lucide-react": "^0.472.0", 8 | "react": "^19.0.0", 9 | "react-dom": "^19.0.0", 10 | "react-markdown": "^9.0.3", 11 | "react-scripts": "^5.0.1", 12 | "web-vitals": "^4.2.4", 13 | "webpack": "^5.97.1" 14 | }, 15 | "scripts": { 16 | "start": "react-scripts start", 17 | "build": "react-scripts build", 18 | "test": "react-scripts test", 19 | "eject": "react-scripts eject" 20 | }, 21 | "eslintConfig": { 22 | "extends": [ 23 | "react-app", 24 | "react-app/jest" 25 | ] 26 | }, 27 | "browserslist": { 28 | "production": [ 29 | ">0.2%", 30 | "not dead", 31 | "not op_mini all" 32 | ], 33 | "development": [ 34 | "last 1 chrome version", 35 | "last 1 firefox version", 36 | "last 1 safari version" 37 | ] 38 | }, 39 | "devDependencies": { 40 | "@tailwindcss/typography": "^0.5.16", 41 | "tailwindcss": "^3.4.17" 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/website/frontend/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/src/website/frontend/public/favicon.ico 
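For reference, a small self-contained sketch (illustrative only, not a file from the repository) of the absmax round trip that the backend files above apply to each parameter tensor on the WASQ-LTH path; it shows how reconstruction error shrinks as the bit-width grows.

import torch

def absmax_round_trip(tensor, bitwidth):
    # Same scheme as absmax_quantize above: scale by the largest |value|,
    # round onto the signed integer grid, then dequantize back to floats.
    scale = torch.max(torch.abs(tensor))
    levels = 2 ** (bitwidth - 1) - 1
    quantized = torch.round(tensor / scale * levels)
    return quantized / levels * scale

weights = torch.randn(4096) * 0.05  # toy stand-in for a weight tensor
for bits in (4, 8, 16):
    error = (weights - absmax_round_trip(weights, bits)).abs().mean().item()
    print(f"{bits}-bit absmax round trip: mean abs error = {error:.6f}")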
-------------------------------------------------------------------------------- /src/website/frontend/public/index.html: -------------------------------------------------------------------------------- [The HTML markup of this file was not preserved in this dump; of its 44 lines, only the document title "React App" is recoverable.]
-------------------------------------------------------------------------------- /src/website/frontend/public/logo192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/src/website/frontend/public/logo192.png
-------------------------------------------------------------------------------- /src/website/frontend/public/logo512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/src/website/frontend/public/logo512.png
-------------------------------------------------------------------------------- /src/website/frontend/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 |
-------------------------------------------------------------------------------- /src/website/frontend/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 |
-------------------------------------------------------------------------------- /src/website/frontend/src/App.css: -------------------------------------------------------------------------------- 1 | .App { 2 | text-align: center; 3 | } 4 | 5 | .App-logo { 6 | height: 40vmin; 7 | pointer-events: none; 8 | } 9 | 10 | @media (prefers-reduced-motion: no-preference) { 11 | .App-logo { 12 | animation: App-logo-spin infinite 20s linear; 13 | } 14 | } 15 | 16 | .App-header { 17 | background-color: #282c34; 18 | min-height: 100vh; 19 | display: flex; 20 | flex-direction: column; 21 | align-items: center; 22 | justify-content: center; 23 | font-size: calc(10px + 2vmin); 24 | color: white; 25 | } 26 | 27 | .App-link { 28 | color: #61dafb; 29 | } 30 | 31 | @keyframes App-logo-spin { 32 | from { 33 | transform: rotate(0deg); 34 | } 35 | to { 36 | transform: rotate(360deg); 37 | } 38 | } 39 |
-------------------------------------------------------------------------------- /src/website/frontend/src/App.js: -------------------------------------------------------------------------------- 1 | import logo from './logo.svg'; 2 | import './App.css'; 3 | import ChatInterface from './components/ChatInterface'; 4 | import LandingPage from './components/LandingPage'; 5 | function App() { 6 | return (
[App.js lines 7-9, the JSX returned by the component, were stripped during extraction and are not recoverable from this dump.]
10 | ); 11 | } 12 | 13 | export default App; 14 | -------------------------------------------------------------------------------- /src/website/frontend/src/App.test.js: -------------------------------------------------------------------------------- 1 | import { render, screen } from '@testing-library/react'; 2 | import App from './App'; 3 | 4 | test('renders learn react link', () => { 5 | render(); 6 | const linkElement = screen.getByText(/learn react/i); 7 | expect(linkElement).toBeInTheDocument(); 8 | }); 9 | -------------------------------------------------------------------------------- /src/website/frontend/src/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started with Create React App 2 | 3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app). 4 | 5 | ## Available Scripts 6 | 7 | In the project directory, you can run: 8 | 9 | ### `npm start` 10 | 11 | Runs the app in the development mode.\ 12 | Open [http://localhost:3000](http://localhost:3000) to view it in your browser. 13 | 14 | The page will reload when you make changes.\ 15 | You may also see any lint errors in the console. 16 | 17 | ### `npm test` 18 | 19 | Launches the test runner in the interactive watch mode.\ 20 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information. 21 | 22 | ### `npm run build` 23 | 24 | Builds the app for production to the `build` folder.\ 25 | It correctly bundles React in production mode and optimizes the build for the best performance. 26 | 27 | The build is minified and the filenames include the hashes.\ 28 | Your app is ready to be deployed! 29 | 30 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information. 31 | 32 | ### `npm run eject` 33 | 34 | **Note: this is a one-way operation. Once you `eject`, you can't go back!** 35 | 36 | If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project. 37 | 38 | Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own. 39 | 40 | You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it. 41 | 42 | ## Learn More 43 | 44 | You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started). 45 | 46 | To learn React, check out the [React documentation](https://reactjs.org/). 
47 | 48 | ### Code Splitting 49 | 50 | This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting) 51 | 52 | ### Analyzing the Bundle Size 53 | 54 | This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size) 55 | 56 | ### Making a Progressive Web App 57 | 58 | This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app) 59 | 60 | ### Advanced Configuration 61 | 62 | This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration) 63 | 64 | ### Deployment 65 | 66 | This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment) 67 | 68 | ### `npm run build` fails to minify 69 | 70 | This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify) 71 | -------------------------------------------------------------------------------- /src/website/frontend/src/components/LandingPage.jsx: -------------------------------------------------------------------------------- 1 | import React, { useState } from 'react'; 2 | import { Github, Mail, ArrowRight, ArrowLeft } from 'lucide-react'; 3 | import ChatInterface from './ChatInterface'; 4 | 5 | const LandingPage = () => { 6 | const [showChat, setShowChat] = useState(false); 7 | 8 | return ( 9 |
[The JSX markup of this component, lines 9-166, was largely stripped during extraction; element tags and className attributes are not recoverable from this dump. The surviving text and logic, in order:]

- A ternary on `showChat`: when true, the page shows a control labelled "SuperFloat" and renders the `ChatInterface` component; when false, it renders the landing sections below.
- Hero Section: "EmelinLabs presents", "Superfloat", and the tagline "A revolutionary quantization algorithm optimizing neural networks through custom precision formats. Designed for edge computing with scalable precision and superior performance."
- What We Do Section, produced by `.map((feature, index) => ...)` over three cards rendered as `{feature.title}` / `{feature.description}`:
  - "Sign-Exponent Representation": "Efficient bit allocation with 1 bit for sign and remaining for exponent, optimizing precision without mantissa."
  - "Clamping Range": "Values clamped within [-1, 1] for activation and parameter stability, preventing gradient issues."
  - "Bit-width Flexibility": "Scalable precision from 3-bit to 16-bit, balancing computation speed and accuracy."
- About Section ("About Us"): "SuperFloat implements custom quantization algorithms focusing on the Lottery Ticket Hypothesis (LTH) and Weight and Activation SuperFloat Quantization (WASQ) techniques for optimizing neural networks on edge devices."
- Contact Section: "GitHub" and "Contact Us" links (the Github and Mail icons imported above).
- Lines 112-166 (closing markup and styling) were not preserved.
167 | ); 168 | }; 169 | 170 | export default LandingPage;
-------------------------------------------------------------------------------- /src/website/frontend/src/index.css: -------------------------------------------------------------------------------- 1 | @tailwind base; 2 | @tailwind components; 3 | @tailwind utilities;
-------------------------------------------------------------------------------- /src/website/frontend/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | import reportWebVitals from './reportWebVitals'; 6 | 7 | const root = ReactDOM.createRoot(document.getElementById('root')); 8 | root.render( 9 | <React.StrictMode> 10 | <App /> 11 | </React.StrictMode> 12 | ); 13 | 14 | // If you want to start measuring performance in your app, pass a function 15 | // to log results (for example: reportWebVitals(console.log)) 16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals 17 | reportWebVitals(); 18 |
-------------------------------------------------------------------------------- /src/website/frontend/src/logo.svg: -------------------------------------------------------------------------------- [SVG markup not preserved in this dump.]
-------------------------------------------------------------------------------- /src/website/frontend/src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const reportWebVitals = onPerfEntry => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 |
-------------------------------------------------------------------------------- /src/website/frontend/src/setupTests.js: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import '@testing-library/jest-dom'; 6 |
-------------------------------------------------------------------------------- /src/website/frontend/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: [ 4 | "./src/**/*.{js,jsx,ts,tsx}", 5 | ], 6 | theme: { 7 | extend: {}, 8 | }, 9 | plugins: [ 10 | require('@tailwindcss/typography'), 11 | ], 12 | } --------------------------------------------------------------------------------