├── .gitignore
├── README.md
├── assets
│   └── results
│       ├── FMA.png
│       ├── FloatvsSuperfloat.jpg
│       ├── LLM Distribution
│       │   ├── Llama-2-7b-chat-hf_layers.png
│       │   ├── Llama-2-7b-chat-hf_parameters.png
│       │   ├── MiniCPM-V-2_6_layers.png
│       │   ├── MiniCPM-V-2_6_parameters.png
│       │   ├── Mistral-7B-v0.1_layers.png
│       │   ├── Mistral-7B-v0.1_parameters.png
│       │   ├── gte-Qwen2-7B-instruct_layers.png
│       │   ├── gte-Qwen2-7B-instruct_parameters.png
│       │   ├── japanese-stablelm-base-beta-7b_layers.png
│       │   └── japanese-stablelm-base-beta-7b_parameters.png
│       ├── Qwen2
│       │   ├── RULE1-Max_Context_Length_Barrier.png
│       │   ├── SF11.png
│       │   ├── SF16.png
│       │   ├── SF4.png
│       │   └── SF8.png
│       ├── YOLOv5 Weight Distribution
│       │   ├── yolov5l.png
│       │   ├── yolov5l6.png
│       │   ├── yolov5m.png
│       │   ├── yolov5m6.png
│       │   ├── yolov5n.png
│       │   ├── yolov5n6.png
│       │   ├── yolov5s.png
│       │   ├── yolov5s6.png
│       │   ├── yolov5x.png
│       │   └── yolov5x6.png
│       ├── YOLOv7 Weight Distribution
│       │   ├── layer_status_plot_yolov7-d6.png
│       │   ├── layer_status_plot_yolov7-e6.png
│       │   ├── layer_status_plot_yolov7-e6e.png
│       │   ├── layer_status_plot_yolov7-w6.png
│       │   ├── layer_status_plot_yolov7.png
│       │   ├── layer_status_plot_yolov7x.png
│       │   ├── yolov7-d6.pt.png
│       │   ├── yolov7-e6.pt.png
│       │   ├── yolov7-e6e.pt.png
│       │   ├── yolov7-w6.pt.png
│       │   ├── yolov7.pt.png
│       │   └── yolov7x.pt.png
│       ├── activation_unit.png
│       ├── blockplan.png
│       ├── cycle_count_logic.png
│       ├── gtkwave_fma_comparison.png
│       ├── gtkwave_superfloat.png
│       ├── hardware architecture.png
│       ├── improved_dadda.jpg
│       ├── isa_integrated_floorplan.png
│       ├── results_viewer.ipynb
│       └── shift_register.png
├── docs
│   └── technical_reports
│       ├── Major Project.docx
│       ├── SuperFloat PoC.pdf
│       └── Superfloat Whitepaper.pdf
└── src
    ├── modal
    │   ├── better.py
    │   ├── modal_eval.py
    │   ├── modal_fasteropt.py
    │   ├── modal_lth.py
    │   └── train-custom-gpt.py
    ├── test
    │   ├── emulate_sf.py
    │   ├── emulate_sf_dec.py
    │   └── sf_tensors.py
    ├── verilog
    │   ├── activationUnit_16bit.sv
    │   └── shiftRegister_16bit.sv
    ├── wasq
    │   ├── wasq_eval.py
    │   ├── wasq_fasteropt.py
    │   ├── wasq_fpm.py
    │   ├── wasq_inference.py
    │   ├── wasq_lth.py
    │   ├── wasq_mplth.py
    │   ├── wasq_opt.py
    │   ├── wasq_sa_mplth.py
    │   └── wasq_vanilla.py
    └── website
        ├── backend
        │   ├── parallel_backend.py
        │   ├── sequential_backend.py
        │   └── testing_backend.py
        └── frontend
            ├── .gitignore
            ├── package-lock.json
            ├── package.json
            ├── public
            │   ├── favicon.ico
            │   ├── index.html
            │   ├── logo192.png
            │   ├── logo512.png
            │   ├── manifest.json
            │   └── robots.txt
            ├── src
            │   ├── App.css
            │   ├── App.js
            │   ├── App.test.js
            │   ├── README.md
            │   ├── components
            │   │   ├── ChatInterface.jsx
            │   │   └── LandingPage.jsx
            │   ├── index.css
            │   ├── index.js
            │   ├── logo.svg
            │   ├── reportWebVitals.js
            │   └── setupTests.js
            └── tailwind.config.js
/.gitignore:
--------------------------------------------------------------------------------
1 | models--*
2 | .locks/
3 | .env
4 | frontend/node_modules/
5 | *.exe
6 | __pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **Superfloat: Accelerators for AI on Edge. Reimagined.**
2 |
3 | This repository contains the code, methods, and scripts for implementing **Superfloat Quantization** and **Lottery Ticket Hypothesis (LTH)** techniques for optimizing neural networks. The repository focuses on various quantization algorithms, model evaluations, and fine-tuning techniques to minimize perplexity and stabilize activations.
4 |
5 | ---
6 |
7 | ## **What is Superfloat?**
8 |
9 | **Superfloat** is a custom quantization algorithm that operates with a **scalable precision format**. Unlike traditional floating-point systems (IEEE-754), Superfloat removes the mantissa entirely and focuses solely on the **exponent** for precision representation.
10 |
11 | ### **Key Features**:
12 | 1. **Sign-Exponent Representation**:
13 | - Superfloat (SFx) uses `1 bit` for the **sign** and allocates the remaining `x-1 bits` for the **exponent**.
14 | - For instance, in **SF16**:
15 | - 1 bit → Sign
16 | - 15 bits → Exponent
17 |
18 | 2. **Clamping Range**:
19 | - All values are clamped within the range `[-1, 1]`. This ensures activation and parameter stability, reducing the likelihood of exploding or vanishing gradients.
20 |
21 | 3. **Bit-width Flexibility**:
22 | - Superfloat supports variable precision formats, scaling between **4-bit and 16-bit**:
23 | - Lower precision (e.g., **SF4**) → Faster computation, reduced model size.
24 | - Higher precision (e.g., **SF16**) → Improved accuracy while maintaining efficient quantization.
25 |
26 | 4. **Gradient and Activation Capping**:
27 | - To stabilize the training process, gradients and activations are **capped** at -1 and +1.
28 |
29 | ### **Advantages of Superfloat**:
30 | - Saves **precision** without a significant drop in accuracy.
31 | - Reduces **computational complexity** compared to traditional floating-point representations.
32 | - Allows adaptive scaling for diverse quantization requirements.
33 |
34 | ---
35 |
36 | **Conversion: FP32 → SF(4–16)**
37 |
38 | A standard 32-bit floating-point number is converted into a custom superfloat representation with a variable-sized mantissa.
39 |
40 | - **Clamp Input Range** – The input value is restricted to the range (-1, 1). If the value exceeds this, it is set to a predefined maximum value.
41 |
42 | - **Extract Sign Bit** – The sign bit is determined and stored separately, while the value is converted to its absolute form.
43 |
44 | - **Compute Mantissa** – The fractional value is scaled by `2^mantissa_bits` to convert it into an integer representation.
45 |
46 | - **Bit Packing** – The sign bit and mantissa are arranged into a custom format, with the mantissa shifted to fit within a float-sized bit structure.
47 |
48 | - **Bitwise Reinterpretation** – The constructed bit pattern is reinterpreted as a floating-point number and returned.
49 |
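Below is a minimal, illustrative Python sketch of the steps above (clamp, sign extraction, mantissa scaling, bit packing, reinterpretation). The helper name and the exact bit layout are assumptions made for illustration; the repository's kernels may pack the fields differently.

```python
import struct

def fp32_to_superfloat(value: float, mantissa_bits: int = 15) -> float:
    """Illustrative FP32 -> SFx conversion following the five steps above (not the repo's exact kernel)."""
    max_val = 1 - 2 ** -mantissa_bits                          # largest representable magnitude
    value = max(-max_val, min(max_val, value))                 # 1. clamp input to (-1, 1)
    sign = 1 if value < 0 else 0                               # 2. extract the sign bit
    mantissa = int(abs(value) * (1 << mantissa_bits))          # 3. scale the fraction to an integer mantissa
    bits = (sign << 31) | (mantissa << (31 - mantissa_bits))   # 4. pack sign + mantissa (assumed layout)
    return struct.unpack("f", struct.pack("I", bits))[0]       # 5. reinterpret the bit pattern as a float
```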
50 | ---
51 | ## **What is WASQ?**
52 |
53 | **WASQ** stands for **Weight and Activation Superfloat Quantization**. It is a **hybrid quantization framework** that leverages Superfloat precision to optimize both model weights and activations.
54 |
55 | ### **Key Characteristics of WASQ**:
56 | 1. **Weight Quantization**:
57 | - Model weights are converted to **Superfloat precision** (SFx) without requiring complex computations like mantissa adjustments.
58 |
59 | 2. **Activation Quantization**:
60 | - Activations are clamped and quantized within a stable range to prevent issues such as exploding activations.
61 |
62 | 3. **Optimization Algorithms**:
63 | - WASQ includes customized algorithms like **WASQ OPT** and **Full Parameter Method (FPM)** to balance accuracy and convergence speed.
64 |    - New: **Simulated Annealing Multi-Prize Lottery Ticket Hypothesis (SA-MPLTH)**, an algorithm for healing quantized models.
65 |
66 | 4. **Scalability**:
67 | - WASQ supports **multi-bit quantization** (from 4-bit to 16-bit), making it adaptable for different deployment environments, such as:
68 | - **Edge devices** → Lower precision for speed and memory savings.
69 | - **Servers** → Higher precision for accuracy-sensitive tasks.
70 |
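As a rough illustration of point 1 above (weight quantization), the sketch below clamps every weight to `[-1, 1]` and rounds it onto an SFx grid, mirroring the per-tensor encode/decode in `src/modal/better.py`. The function name and default bit-width are illustrative, not part of the repository's API.

```python
import torch

def wasq_quantize_weights(model: torch.nn.Module, bits: int = 8) -> torch.nn.Module:
    """Clamp and round every parameter onto the SF{bits} grid in [-1, 1] (in place)."""
    levels = 2 ** (bits - 1) - 1                                   # positive quantization steps
    with torch.no_grad():
        for param in model.parameters():
            param.clamp_(-1.0, 1.0)                                # keep weights in the stable range
            param.copy_(torch.round(param * levels) / levels)      # snap to the SFx grid
    return model
```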
71 | ### **WASQ + Lottery Ticket Hypothesis (LTH)**
72 | WASQ integrates **LTH** to identify specific weights that are critical for maintaining model performance after quantization. By fine-tuning only the **essential weights**, WASQ reduces computational overhead while achieving high accuracy.
73 |
74 | ---
75 |
76 | ## **Files Overview**
77 |
78 | 1. **[Quant_Dequant.ipynb](Quant_Dequant.ipynb)**
79 | Contains the implementation of basic Superfloat quantization and dequantization functions.
80 |
81 | 2. **[sf16quant.ipynb](sf16quant.ipynb)**
82 | Builds on Superfloat quantization functions, specifically for **SF16 precision**.
83 |
84 | 3. **[lth_analysis.py](lth_analysis.py)**
85 | Analyzes **activation magnitude distribution** for **LTH**. It compares activation patterns of original and quantized models.
86 |
87 | 4. **[lth_trainer.py](lth_trainer.py)**
88 | The **LTH trainer** script for fine-tuning models based on the Lottery Ticket Hypothesis technique.
89 |
90 | 5. **[wasq_eval.py](wasq_eval.py)**
91 | Calculates **perplexity** for a series of models, grouped by context length, epochs, or model species.
92 |
93 | 6. **[wasq_inference.py](wasq_inference.py)**
94 | Provides inference capabilities for **individual** or **multiple WASQ-quantized models**.
95 |
96 | 7. **[wasq_fasteropt.py](wasq_fasteropt.py)**
97 | An optimized version of the **OPT algorithm** implemented in `wasq_opt.py`.
98 |
99 | 8. **[wasq_opt.py](wasq_opt.py)**
100 | Core implementation of the WASQ OPT algorithm.
101 |
102 | 9. **[wasq_fpm.py](wasq_fpm.py)**
103 | Implements the **Full Parameter Method** (FPM) for WASQ quantization.
104 |
105 | 10. **[wasq_vanilla.py](wasq_vanilla.py)**
106 | Baseline implementation of the **Vanilla algorithm** for WASQ.
107 |
108 | 11. **[sa_mplth.py](sa_mplth.py)**
109 | New: Implements Simulated Annealing Multi-Prize Lottery Ticket Hypothesis for healing quantized models.
110 |
111 | 12. **[assets/results](assets/results/)**
112 | Contains outputs of model tests, perplexity scores, and supplementary studies.
113 |
114 | ---
115 |
116 | ## **Scaling Laws**
117 |
118 | ### 1. **Maximum Context Length Barrier - Perplexity Factor**
119 | For a model with `n` parameters, a calibration dataset of maximum input length `c`, **three-shot quantization fine-tuning**, and Superfloat precision bit `x` (where `4 ≤ x ≤ 16`):
120 |
121 | \[
122 | P = f(n, c, 3, x)
123 | \]
124 |
125 | - **Lower P** indicates better model understanding and calibration performance.
126 |
127 | ---
128 |
129 | ### 2. **Maximum Neuron Spread Factor**
130 | This scaling law uses the **Lottery Ticket Hypothesis** for WASQ quantization to stabilize activations:
131 |
132 | 1. Perform a forward pass using the **original model** and record the average magnitudes of activations across all layers.
133 | 2. Perform the same for the **vanilla quantized model** to observe how quantization impacts activation magnitudes.
134 | 3. Rank layers based on the **difference in activation magnitudes** between the original and quantized models.
135 | 4. Identify and **cluster layers** with significant deviations to address issues like exploding/vanishing activations.
136 | 5. Fine-tune or analyze these clusters to ensure stable activations and minimal performance degradation (see the sketch below).
137 |
138 | The law establishes that the **maximum neuron spread** (region targeted for fine-tuning/updating) is a function of:
139 | - **Activation magnitude**
140 | - **Activation fracture** (spread of how a weight affects neighboring weights during backpropagation)
141 |
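A minimal PyTorch sketch of steps 1–4 above, assuming both models expose the same module names; hooking only `nn.Linear` layers and the helper name are illustrative choices, not the repository's exact implementation:

```python
import torch

def rank_layers_by_activation_gap(original, quantized, batch):
    """Record the mean |activation| of each Linear layer in both models, then rank layers by drift."""
    def record(model):
        stats, hooks = {}, []
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                hooks.append(module.register_forward_hook(
                    lambda mod, inp, out, n=name: stats.__setitem__(n, out.detach().abs().mean().item())))
        with torch.no_grad():
            model(**batch)                                            # one forward pass (steps 1 and 2)
        for handle in hooks:
            handle.remove()
        return stats

    base, quant = record(original), record(quantized)
    gaps = {n: abs(base[n] - quant.get(n, 0.0)) for n in base}       # step 3: magnitude difference per layer
    return sorted(gaps.items(), key=lambda kv: kv[1], reverse=True)  # step 4: largest deviation first
```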
142 | ---
143 |
144 | ## **Quantization Algorithms**
145 |
146 | The repository explores four quantization approaches:
147 |
148 | 1. **Superfloat Precision**: Custom precision without mantissa, clamped within `[-1, 1]` for stability.
149 | 2. **WASQ OPT**: Optimized quantization with faster convergence.
150 | 3. **Full Parameter Method (FPM)**: Retrains all parameters for higher accuracy.
151 | 4. **SA-MPLTH**: New simulated annealing approach for healing quantized models.
152 |
153 | ---
154 |
155 | ## **Usage**
156 |
157 | ### **Setup**
158 | Clone the repository and install dependencies:
159 |
160 | ```bash
161 | git clone https://github.com/aloshdenny/superfloat
162 | cd superfloat
163 | pip install -r requirements.txt
164 | ```
165 |
166 | ### **Running Scripts**
167 |
168 | - Train with **LTH**:
169 | ```bash
170 | python lth_trainer.py
171 | ```
172 |
173 | - Evaluate Perplexity:
174 | ```bash
175 | python wasq_eval.py
176 | ```
177 |
178 | - Perform Inference:
179 | ```bash
180 | python wasq_inference.py
181 | ```
182 |
183 | - Run SA-MPLTH:
184 | ```bash
185 | python sa_mplth.py
186 | ```
187 |
188 | ---
189 |
190 | ## **Results**
191 |
192 | The assets/results folder contains:
193 | - **Perplexity scores** for different model configurations.
194 | - **Activation magnitude comparisons** before and after quantization.
195 | - Supplementary studies showcasing model performance.
196 |
197 | ---
198 |
199 | ## **Chip-1: Atreides**
200 |
201 | Atreides is an ASIC accelerator designed specifically for Superfloat-based inference. We redesigned the systolic array to support SFx operations, adopting a modded RV32 ISA and faster fused multiply-add (FMA) units. The end goal is not convention: it is breaking the rules of computing and physics to achieve faster inference, lower memory consumption, and the same accuracy.
202 |
203 | ## FMA in Atreides
204 |
205 | Below is an image showing the FMA in Atreides:
206 |
207 | 
208 |
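As a rough behavioral reference (not RTL), the SFx FMA can be read as a multiply-accumulate whose result is clamped to `[-1, 1]` and re-quantized onto the SFx grid. The Python sketch below states that contract; the re-quantization step and default bit-width are assumptions for illustration.

```python
def sf_fma(a: float, b: float, c: float, bits: int = 16) -> float:
    """Behavioral model of an SFx fused multiply-add: acc = a*b + c, clamped and re-quantized."""
    levels = 2 ** (bits - 1) - 1        # positive steps on the SFx grid
    acc = a * b + c
    acc = max(-1.0, min(1.0, acc))      # keep the result in the stable [-1, 1] range
    return round(acc * levels) / levels
```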
209 | ## Expanded View of Chip-1's Architecture
210 |
211 | An expanded view of Chip-1's architecture includes non-unified memory blocks (subject to unification), cache, control store (modded RV32 ISA), and an array of FMAs:
212 |
213 | 
214 |
215 | ### FPGA Functional Units Design
216 |
217 | #### 1. 8 x 16-bit Shift Register (simplified)
218 |
219 | 
220 |
221 | #### 2. Activation Unit (simplified)
222 |
223 | 
224 |
225 | #### 3. Cycle Count Logic
226 |
227 | 
228 |
229 | ## Instruction Set
230 |
231 | The current instruction set for the FPGA architecture is shown below:
232 |
233 | | Instruction | Opcode(4) | Op 1(4) | Op 2(4) | Op 3(4) | Description |
234 | |-------------|-----------|---------|---------|---------|---------------------------------------------------------------------------------------|
235 | | STR | 0001 | addr | row | col | Stores the matrix data from activation unit buffer into specified address in memory |
236 | | LDR | 0010 | addr | row | col | Loads the matrix at addr into the Row Shift Buffer |
237 | | LDC | 0011 | addr | row | col | Loads the matrix at addr into the Column Shift Buffer |
238 | | MATMUL | 0100 | - | - | - | Performs matrix multiplication using data in Row Shift Buffer and Column Shift Buffer |
239 | | RELU | 0101 | - | - | - | Performs ReLU activation function on Systolic Array output |
240 | | LIN | 0110 | - | - | - | Performs Linear activation function on Systolic Array output |
241 | | NOP | 0000 | - | - | - | No Operation |
242 |
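For illustration, each instruction packs into a 16-bit word of four 4-bit fields. The field order (opcode in the most-significant nibble, then Op 1–3) is an assumption based on the table layout, and the helper below is not part of the repository:

```python
def encode_instruction(opcode: int, op1: int = 0, op2: int = 0, op3: int = 0) -> int:
    """Pack a 16-bit instruction word: [opcode | op1 | op2 | op3], 4 bits each (assumed field order)."""
    for field in (opcode, op1, op2, op3):
        assert 0 <= field < 16, "each field is 4 bits wide"
    return (opcode << 12) | (op1 << 8) | (op2 << 4) | op3

# e.g. LDR with addr=3, row=1, col=2 -> 0b0010_0011_0001_0010 == 0x2312
assert encode_instruction(0b0010, 3, 1, 2) == 0x2312
```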
243 | ### FPGA floorplan (ISA integrated)
244 |
245 | The FPGA floorplan with the instruction set integrated is shown below:
246 |
247 | 
248 |
249 | ---
250 |
251 | ## **Contributions**
252 |
253 | Contributions are welcome! Feel free to open issues or submit pull requests.
254 |
255 | ---
256 |
257 | ## **Sponsors**
258 |
259 | We would like to thank our sponsors for their support:
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 | ---
269 |
270 | ## **License**
271 |
272 | This project is licensed under the MIT License.
273 |
--------------------------------------------------------------------------------
/assets/results/FMA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/FMA.png
--------------------------------------------------------------------------------
/assets/results/FloatvsSuperfloat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/FloatvsSuperfloat.jpg
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/Llama-2-7b-chat-hf_layers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Llama-2-7b-chat-hf_layers.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/Llama-2-7b-chat-hf_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Llama-2-7b-chat-hf_parameters.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/MiniCPM-V-2_6_layers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/MiniCPM-V-2_6_layers.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/MiniCPM-V-2_6_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/MiniCPM-V-2_6_parameters.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/Mistral-7B-v0.1_layers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Mistral-7B-v0.1_layers.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/Mistral-7B-v0.1_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/Mistral-7B-v0.1_parameters.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/gte-Qwen2-7B-instruct_layers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/gte-Qwen2-7B-instruct_layers.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/gte-Qwen2-7B-instruct_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/gte-Qwen2-7B-instruct_parameters.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_layers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_layers.png
--------------------------------------------------------------------------------
/assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/LLM Distribution/japanese-stablelm-base-beta-7b_parameters.png
--------------------------------------------------------------------------------
/assets/results/Qwen2/RULE1-Max_Context_Length_Barrier.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/RULE1-Max_Context_Length_Barrier.png
--------------------------------------------------------------------------------
/assets/results/Qwen2/SF11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF11.png
--------------------------------------------------------------------------------
/assets/results/Qwen2/SF16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF16.png
--------------------------------------------------------------------------------
/assets/results/Qwen2/SF4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF4.png
--------------------------------------------------------------------------------
/assets/results/Qwen2/SF8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/Qwen2/SF8.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5l.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5l.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5l6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5l6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5m.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5m.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5m6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5m6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5n.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5n.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5n6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5n6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5s.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5s6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5s6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5x.png
--------------------------------------------------------------------------------
/assets/results/YOLOv5 Weight Distribution/yolov5x6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv5 Weight Distribution/yolov5x6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-d6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-d6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-e6e.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-w6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7-w6.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/layer_status_plot_yolov7x.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/yolov7-d6.pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-d6.pt.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/yolov7-e6.pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-e6.pt.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/yolov7-e6e.pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-e6e.pt.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/yolov7-w6.pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7-w6.pt.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/yolov7.pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7.pt.png
--------------------------------------------------------------------------------
/assets/results/YOLOv7 Weight Distribution/yolov7x.pt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/YOLOv7 Weight Distribution/yolov7x.pt.png
--------------------------------------------------------------------------------
/assets/results/activation_unit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/activation_unit.png
--------------------------------------------------------------------------------
/assets/results/blockplan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/blockplan.png
--------------------------------------------------------------------------------
/assets/results/cycle_count_logic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/cycle_count_logic.png
--------------------------------------------------------------------------------
/assets/results/gtkwave_fma_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/gtkwave_fma_comparison.png
--------------------------------------------------------------------------------
/assets/results/gtkwave_superfloat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/gtkwave_superfloat.png
--------------------------------------------------------------------------------
/assets/results/hardware architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/hardware architecture.png
--------------------------------------------------------------------------------
/assets/results/improved_dadda.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/improved_dadda.jpg
--------------------------------------------------------------------------------
/assets/results/isa_integrated_floorplan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/isa_integrated_floorplan.png
--------------------------------------------------------------------------------
/assets/results/shift_register.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/assets/results/shift_register.png
--------------------------------------------------------------------------------
/docs/technical_reports/Major Project.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/docs/technical_reports/Major Project.docx
--------------------------------------------------------------------------------
/docs/technical_reports/SuperFloat PoC.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/docs/technical_reports/SuperFloat PoC.pdf
--------------------------------------------------------------------------------
/docs/technical_reports/Superfloat Whitepaper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/docs/technical_reports/Superfloat Whitepaper.pdf
--------------------------------------------------------------------------------
/src/modal/better.py:
--------------------------------------------------------------------------------
1 | import modal
2 |
3 | # Create a Modal image with the required dependencies
4 | image = (
5 | modal.Image.debian_slim()
6 | .pip_install(
7 | "torch",
8 | "transformers",
9 | "datasets",
10 | "tqdm",
11 | "huggingface_hub",
12 | )
13 | .apt_install("gcc", "python3-dev") # Add necessary system libraries if needed
14 | )
15 |
16 | app = modal.App("qwen-sf4-experimental")
17 |
18 | # Define the function that runs the script
19 | @app.function(gpu="H100", image=image, timeout=86400)
20 | def train_and_upload():
21 | import torch
22 | import gc
23 | import os
24 | import re
25 | import requests
26 | from tqdm import tqdm
27 | from datasets import Dataset
28 | from torch.utils.data import DataLoader
29 | from transformers import AutoModelForCausalLM, AutoTokenizer
30 | import pandas as pd
31 | import math
32 |
33 | # Function to calculate perplexity
34 | def calculate_perplexity(model, dataloader, loss_fn, device):
35 | model.eval() # Set model to evaluation mode
36 | total_loss = 0.0
37 | total_steps = 0
38 |
39 | with torch.no_grad():
40 | for batch in dataloader:
41 | input_ids = batch["input_ids"].to(device)
42 | attention_mask = batch["attention_mask"].to(device)
43 | outputs = model(input_ids=input_ids, attention_mask=attention_mask)
44 | logits = outputs.logits
45 | target = input_ids[:, 1:].contiguous() # Shift targets by one token
46 | logits = logits[:, :-1].contiguous() # Align logits with target
47 |
48 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
49 | total_loss += loss.item()
50 | total_steps += 1
51 |
52 | avg_loss = total_loss / total_steps
53 | perplexity = math.exp(avg_loss) # Perplexity is the exponential of the average loss
54 | return perplexity
55 |
56 | # List of dataset URLs
57 | urls = [
58 | "https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/resolve/main/data/train-00000-of-01650-f70471ee3deb09c0.parquet",
59 | ]
60 |
61 | # Local final output file path
62 | final_file_name = "train.parquet"
63 |
64 | # Check if the final file already exists
65 | if not os.path.exists(final_file_name):
66 | print(f"Downloading and combining dataset from {len(urls)} files...")
67 |
68 | # List to hold all the dataframes
69 | combined_df = pd.DataFrame()
70 |
71 | # Loop through each URL to download and combine the files
72 | for i, url in enumerate(urls):
73 | downloaded_file = f"temp_file_{i}.parquet"
74 |
75 | # Download the dataset
76 | print(f"Downloading dataset from {url}...")
77 | response = requests.get(url, stream=True)
78 | with open(downloaded_file, "wb") as f:
79 | for chunk in response.iter_content(chunk_size=8192):
80 | f.write(chunk)
81 | print(f"Downloaded to {downloaded_file}.")
82 |
83 | # Read the downloaded parquet file and append to the combined dataframe
84 | df = pd.read_parquet(downloaded_file)
85 | combined_df = pd.concat([combined_df, df], ignore_index=True)
86 |
87 | # Optionally remove the temporary file after reading
88 | os.remove(downloaded_file)
89 |
90 | # Save the combined dataframe as a final parquet file
91 | combined_df.to_parquet(final_file_name)
92 | print(f"Combined data saved to {final_file_name}.")
93 | else:
94 | print(f"{final_file_name} already exists. Skipping download.")
95 |
96 |
97 | # max_lengths = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
98 | max_lengths = [2]
99 | bit = 4
100 |
101 | class Superfloat:
102 | # Mapping of bit-widths to floating-point types
103 | CASTING_TABLE = {
104 | 16: torch.float32,
105 | 15: torch.float32,
106 | 14: torch.float32,
107 | 13: torch.float32,
108 | 12: torch.float32,
109 | 11: torch.float16,
110 | 10: torch.float16,
111 | 9: torch.float16,
112 | 8: torch.bfloat16,
113 | 7: torch.bfloat16,
114 | 6: torch.bfloat16,
115 | 5: torch.bfloat16,
116 | 4: torch.bfloat16,
117 | }
118 |
119 | def __init__(self, bits: int):
120 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
121 | self.bits = bits
122 | self.mantissa_bits = bits - 1
123 | self.max_val = 1.0 # Default max value
124 | self.scale_factor = 1.0 # Initialize scale factor
125 | self.float_type = self.CASTING_TABLE[bits] # Set float_type based on bits
126 |
127 | def set_scale(self, weights, dim=None, percentile=None):
128 | if dim is not None:
129 | self.scale_factor = torch.max(torch.abs(weights), dim=dim, keepdim=True)[0]
130 | elif percentile:
131 | scale = torch.kthvalue(torch.abs(weights).view(-1), int(weights.numel() * percentile / 100))[0]
132 | self.scale_factor = scale
133 | else:
134 | self.scale_factor = torch.max(torch.abs(weights))
135 |
136 | def encode(self, value: torch.Tensor):
137 | scaled_value = value / self.scale_factor
138 | quantized_value = torch.round(scaled_value * (2**self.mantissa_bits - 1)) / (2**self.mantissa_bits - 1)
139 | return quantized_value.to(self.float_type) # Cast to float_type
140 |
141 | def decode(self, quantized_value: torch.Tensor):
142 | decoded_value = quantized_value * self.scale_factor
143 | return decoded_value.to(self.float_type) # Cast to float_type
144 |
145 | sf = Superfloat(bit)
146 |
147 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148 | print(f"Using device: {device}")
149 |
150 | # Initialize model and tokenizer
151 | model_name = "Qwen/Qwen2-0.5B"
152 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./", token='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')
154 | tokenizer.pad_token = tokenizer.eos_token
155 |
156 | def quantize_model(model, sf_type, dim=None, percentile=None):
157 | for name, param in model.named_parameters():
158 | # Cast weights to the correct float_type before quantization
159 | param.data = param.data.to(sf_type.float_type)
160 |
161 | # Set scale and quantize
162 | sf_type.set_scale(param.data, dim=dim, percentile=percentile)
163 | quantized_param = sf_type.encode(param.data)
164 | param.data = quantized_param.data
165 | return model
166 |
167 | def load_checkpoint(model, sf_bits, suffix="opt", device="cuda"):
168 | """
169 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix.
170 |
171 | Args:
172 | model: The model to load the checkpoint into.
173 | sf_bits: Bitwidth of the Superfloat format (e.g., 11).
174 | suffix: The suffix of the filename (default: 'opt').
175 | device: Device to load the model onto ('cuda' or 'cpu').
176 |
177 | Returns:
178 | The quantized model with loaded weights and the epoch number.
179 | """
180 | # Define the filename pattern to search for
181 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$")
182 |
183 | # Find all matching checkpoint files
184 | checkpoint_files = [
185 | f for f in os.listdir(".") if checkpoint_pattern.match(f)
186 | ]
187 |
188 | if not checkpoint_files:
189 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.")
190 | return quantize_model(model, sf), 0
191 |
192 | # Extract epoch numbers and sort by latest epoch
193 | epochs_and_files = [
194 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files
195 | ]
196 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0])
197 |
198 | # Load the latest checkpoint
199 | print(f"Loading checkpoint: {latest_checkpoint}")
200 | checkpoint = torch.load(latest_checkpoint, map_location=device)
201 | model.load_state_dict(checkpoint)
202 | model.to(device)
203 |
204 | return model, latest_epoch
205 |
206 | # Pre-training parameter check to ensure they are within range
207 | def check_parameters_in_range(model, sf):
208 | out_of_range_params = []
209 | for name, param in model.named_parameters():
210 | if not torch.all(torch.abs(param.data) <= sf.max_val):
211 | out_of_range_params.append(name)
212 | if out_of_range_params:
213 | print(f"Warning: The following parameters are out of range:")
214 | for param_name in out_of_range_params:
215 | print(f"- {param_name}")
216 | else:
217 | print("All parameters are within the valid range.")
218 |
219 |
220 | def prepare_dataset(tokenizer, max_length=1):
221 | dataset = Dataset.from_parquet("train.parquet")
222 |
223 | def tokenize_function(examples):
224 | return tokenizer(
225 | examples["text"],
226 | truncation=True,
227 | max_length=max_length,
228 | padding="max_length",
229 | return_tensors="pt",
230 | )
231 |
232 | tokenized_dataset = dataset.map(
233 | tokenize_function, batched=True, remove_columns=dataset.column_names
234 | )
235 | return tokenized_dataset
236 |
237 | def collate_fn(batch):
238 | input_ids = torch.stack(
239 | [torch.tensor(example["input_ids"]) for example in batch]
240 | )
241 | attention_mask = torch.stack(
242 | [torch.tensor(example["attention_mask"]) for example in batch]
243 | )
244 | return {"input_ids": input_ids, "attention_mask": attention_mask}
245 |
246 |
247 | # Loop over different max_length values
248 | for max_length in max_lengths:
249 | print(f"Starting training for max_length = {max_length}")
250 |
251 | # Prepare Dataset
252 | tokenized_dataset = prepare_dataset(tokenizer, max_length=max_length)
253 | train_dataloader = DataLoader(
254 | tokenized_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
255 | )
256 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./", token='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')
257 | model = model.to(sf.float_type).to(device)
258 |
259 | quantized_model, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device)
260 | quantized_model.to(device)
261 | print(f"Resuming training from epoch {last_epoch + 1}.")
262 |
263 | # Check if model parameters are within range before training
264 | check_parameters_in_range(quantized_model, sf)
265 |
266 | # del model
267 | torch.cuda.empty_cache()
268 | gc.collect()
269 |
270 | optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-5, eps=1e-4)
271 | loss_fn = torch.nn.CrossEntropyLoss()
272 |
273 | # Calculate and print the original model's perplexity before training
274 | print("Calculating original model perplexity...")
275 | original_perplexity = calculate_perplexity(model, train_dataloader, loss_fn, device)
276 | print(f"Original model perplexity: {original_perplexity:.4f}")
277 |
278 | num_epochs = 10
279 | accumulation_steps = 16
280 |
281 | for epoch in range(num_epochs):
282 | epoch_loss = 0.0
283 | epoch_iterator = tqdm(
284 | enumerate(train_dataloader),
285 | total=len(train_dataloader),
286 | desc=f"Epoch {epoch + 1}/{num_epochs}",
287 | )
288 |
289 | for step, batch in epoch_iterator:
290 | input_ids = batch["input_ids"].to(device)
291 | attention_mask = batch["attention_mask"].to(device)
292 | outputs = quantized_model(input_ids=input_ids, attention_mask=attention_mask)
293 | logits = outputs.logits
294 | target = input_ids[:, 1:].contiguous()
295 | logits = logits[:, :-1].contiguous()
296 |
297 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
298 | loss = loss / accumulation_steps
299 | loss.backward()
300 |
301 | epoch_loss += loss.item() * accumulation_steps
302 |
303 | if (step + 1) % accumulation_steps == 0:
304 | optimizer.step()
305 | optimizer.zero_grad()
306 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"})
307 |
308 | epoch_loss /= len(train_dataloader)
309 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}")
310 |
311 | # Calculate and print the perplexity after each epoch
312 | epoch_perplexity = calculate_perplexity(quantized_model, train_dataloader, loss_fn, device)
313 | print(f"Epoch {epoch + 1} perplexity: {epoch_perplexity:.4f}")
314 |
315 | model_path = f"sf{sf.bits}_{max_length}_{epoch + 1}_opt"
316 | torch.save(quantized_model.state_dict(), model_path)
317 |
318 | # Upload model to Hugging Face
319 | os.system(
320 | f"huggingface-cli upload aoxo/qwen2-idkwtf {model_path} --token='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'"
321 | )
322 |
323 | del quantized_model
324 | torch.cuda.empty_cache()
325 | gc.collect()
326 |
327 | print(f"Completed training for max_length = {max_length}")
328 |
329 | # Entry point to run locally
330 | @app.local_entrypoint()
331 | def main():
332 | train_and_upload.remote()
--------------------------------------------------------------------------------
/src/modal/modal_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import torch
4 | from transformers import AutoTokenizer, AutoModelForCausalLM
5 | from tqdm import tqdm
6 | from datasets import load_dataset
7 |
8 | # Device setup
9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 | print(f"Using device: {device}")
11 |
12 | base_dir = "./qwen2-sf4"
13 | os.makedirs(base_dir, exist_ok=True)
14 |
15 | # Existing Superfloat and other helper functions remain the same as in the previous script
16 | # (Copying the Superfloat class, load_model, calculate_perplexity functions from previous script)
17 |
18 | # Function to load model
19 | def load_model(model_path):
20 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir='./', token='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', trust_remote_code=True)
21 | model.load_state_dict(torch.load(model_path, map_location=device))
22 | model = model.to(torch.bfloat16).to(device)
23 | model.eval() # Ensure model is in inference mode
24 | return model
25 |
26 | # Define Superfloat quantizer for clamping activations
27 | class Superfloat:
28 | def __init__(self, bits: int):
29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
30 | self.bits = bits
31 | self.mantissa_bits = bits - 1
32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
33 |
34 | def encode(self, value: torch.Tensor) -> torch.Tensor:
35 | """Encodes a tensor of values into the superfloat format with optimized operations."""
36 | # Clip tensor values to the valid range for SFx
37 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
38 |
39 | # Calculate mantissa representation element-wise
40 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
41 |
42 | # Create the superfloat representation (1 bit for sign and mantissa bits)
43 | sign = (clipped_value < 0).to(torch.int32)
44 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32)
45 |
46 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor:
47 | """Decodes a tensor of encoded superfloat values to regular floats."""
48 | # Extract mantissa and sign from the encoded superfloat
49 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
50 | sign = (encoded_value >> self.mantissa_bits) & 1
51 |
52 | # Calculate the decoded float using the mantissa and max_val
53 | decoded_value = (mantissa.to(torch.bfloat16) / (2**self.mantissa_bits - 1)) * self.max_val
54 | return decoded_value * (2 * sign - 1) # Apply the sign
55 |
56 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
57 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape."""
58 | # Apply element-wise encoding to the entire tensor and then decode back
59 | encoded_tensor = self.encode(tensor)
60 | decoded_tensor = self.decode(encoded_tensor)
61 | return decoded_tensor
62 |
63 | # Initialize Superfloat quantizer for activation clamping
64 | sf = Superfloat(4)
65 |
66 | model_name = "Qwen/Qwen2-0.5B"
67 |
68 | # Load tokenizer
69 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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', trust_remote_code=True)
70 | tokenizer.pad_token = tokenizer.eos_token
71 |
72 | def quantized_inference(model, tokenizer, prompt, max_length=4096):
73 | """Runs inference on a prompt with activation quantization using Superfloat."""
74 | # Encode input prompt
75 | inputs = tokenizer(prompt, return_tensors="pt", padding=True)
76 | input_ids = inputs.input_ids.to(device)
77 | attention_mask = inputs.attention_mask.to(device)
78 |
79 | with torch.no_grad():
80 | # Perform generation with clamped activations
81 | outputs = model.generate(
82 | input_ids=input_ids,
83 | attention_mask=attention_mask,
84 | max_length=max_length,
85 | do_sample=True,
86 | top_k=50,
87 | top_p=0.95,
88 | temperature=0.7,
89 | pad_token_id=tokenizer.eos_token_id
90 | )
91 |
92 | # Decode generated output
93 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
94 | return generated_text
95 |
96 | def calculate_perplexity(model, tokenizer, prompt):
97 | """Calculates the perplexity of the model on a given prompt."""
98 | # Tokenize input
99 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
100 | input_ids = inputs.input_ids.to(device)
101 | attention_mask = inputs.attention_mask.to(device)
102 |
103 | # Get model outputs (logits)
104 | with torch.no_grad():
105 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
106 |
107 | # Get the loss (cross entropy) from the model's output
108 | loss = outputs.loss # This is the cross-entropy loss
109 |
110 | # Compute perplexity: exp(loss)
111 | perplexity = torch.exp(loss)
112 | return perplexity.item()
113 |
114 | # Model paths
115 | import os
116 |
117 | def get_model_paths(base_dir, sf_bits):
118 | """
119 | Dynamically generate model paths based on the sf.bits format.
120 | Looks for models of the form:
121 | 1. sf{sf_bits}_vanilla
122 | 2. sf{sf_bits}_{epoch_num}_fpm
123 | 3. sf{sf_bits}_{epoch_num}_opt
124 |
125 | Args:
126 | base_dir (str): The directory where the models are stored.
127 | sf_bits (int): The bitwidth for the Superfloat quantizer.
128 |
129 | Returns:
130 | List of model paths.
131 | """
132 | model_paths = []
133 | model_pattern = f"sf{sf_bits}_"
134 |
135 | # Scan directory for models matching the pattern
136 | for model_name in os.listdir(base_dir):
137 | if model_name.startswith(model_pattern):
138 | model_paths.append(os.path.join(base_dir, model_name))
139 |
140 | # Ensure models are sorted to follow the desired order: vanilla -> fpm -> opt
141 | model_paths.sort()
142 |
143 | return model_paths
144 |
145 | # Function to evaluate perplexity for a list of models and prompts
146 | def evaluate_models(base_dir, sf_bits, tokenizer, prompts):
147 | """
148 | Evaluates models dynamically loaded based on the sf.bits format.
149 |
150 | Args:
151 | base_dir (str): The directory where the models are stored.
152 | sf_bits (int): The bitwidth for the Superfloat quantizer.
153 | tokenizer: The tokenizer to use for model inference.
154 | prompts: The list of prompts to evaluate.
155 |
156 | Returns:
157 | Dictionary with model names and their corresponding average perplexity.
158 | """
159 | model_perplexities = {}
160 |
161 | # Get dynamically generated model paths
162 | models = get_model_paths(base_dir, sf_bits)
163 |
164 | for model_path in models:
165 | model = load_model(model_path)
166 | print(f"Evaluating model: {model_path}")
167 |
168 | total_perplexity = 0.0
169 | num_prompts = len(prompts)
170 |
171 | # Compute perplexity for each prompt
172 | for prompt in tqdm(prompts, desc=f"Processing {model_path}", leave=False):
173 | perplexity = calculate_perplexity(model, tokenizer, prompt)
174 | total_perplexity += perplexity
175 |
176 | # Average perplexity for the current model
177 | avg_perplexity = total_perplexity / num_prompts
178 | model_perplexities[model_path] = avg_perplexity
179 | print(f"Average Perplexity for {model_path}: {avg_perplexity}")
180 |
181 | return model_perplexities
182 |
183 | # Function to load the HellaSwag dataset
184 | def load_hellaswag_data():
185 | """Load the HellaSwag dataset from Hugging Face."""
186 | dataset = load_dataset("hellaswag", split='validation', trust_remote_code=True)
187 |
188 | # Extract only the prompts (contexts) for evaluation
189 | prompts = [entry['ctx'] for entry in dataset]
190 |
191 | # Return the prompts as a list
192 | return prompts
193 |
194 | # Load HellaSwag data (prompts)
195 | prompts = load_hellaswag_data()
196 |
197 |
198 | # Function to download model
199 | def download_model(model_name):
200 | """Download a specific model using wget."""
201 | url = f"https://huggingface.co/aoxo/qwen2-sf4/resolve/main/{model_name}"
202 | download_path = os.path.join(base_dir, model_name)
203 |
204 | try:
205 | # Use wget to download the model
206 | subprocess.run(["wget", "-O", download_path, url], check=True)
207 | print(f"Successfully downloaded {model_name}")
208 | return download_path
209 | except subprocess.CalledProcessError as e:
210 | print(f"Error downloading {model_name}: {e}")
211 | return None
212 |
213 | # Existing functions from previous script (load_model, calculate_perplexity, etc.)
214 | # [Keep all the previous implementations of these functions]
215 |
216 | # Main evaluation and cleanup function
217 | def process_models(model_list):
218 | """
219 | Process models by downloading, evaluating, and deleting them.
220 |
221 | Args:
222 | model_list (list): List of model names to process
223 | """
224 | # Load tokenizer once
225 | model_name = "Qwen/Qwen2-0.5B"
226 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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')
227 | tokenizer.pad_token = tokenizer.eos_token
228 |
229 | # Load HellaSwag prompts
230 | prompts = load_hellaswag_data()
231 |
232 | # Results tracking
233 | results = {}
234 |
235 | for model_filename in model_list:
236 | print(f"\n--- Processing {model_filename} ---")
237 |
238 | # Download model
239 | model_path = download_model(model_filename)
240 |
241 | if not model_path:
242 | print(f"Skipping {model_filename} due to download failure")
243 | continue
244 |
245 | try:
246 | # Load model
247 | model = load_model(model_path)
248 |
249 | # Evaluate perplexity
250 | total_perplexity = 0.0
251 | for prompt in tqdm(prompts, desc=f"Evaluating {model_filename}"):
252 | perplexity = calculate_perplexity(model, tokenizer, prompt)
253 | total_perplexity += perplexity
254 |
255 | avg_perplexity = total_perplexity / len(prompts)
256 | results[model_filename] = avg_perplexity
257 | print(f"Average Perplexity for {model_filename}: {avg_perplexity}")
258 |
259 | except Exception as e:
260 | print(f"Error processing {model_filename}: {e}")
261 |
262 | # Delete the model file
263 | try:
264 | os.remove(model_path)
265 | print(f"Deleted {model_filename}")
266 | except Exception as e:
267 | print(f"Error deleting {model_filename}: {e}")
268 |
269 | # Save results to a file
270 | with open('perplexity_results.txt', 'w') as f:
271 | for model, perplexity in sorted(results.items(), key=lambda x: x[1]):
272 | f.write(f"{model}: {perplexity}\n")
273 |
274 | print("\nResults saved to perplexity_results.txt")
275 | return results
276 |
277 | # List of models to process
278 | models_to_process = [
279 | 'sf4_1024_1_opt', 'sf4_128_1_opt', 'sf4_16_1_opt',
280 | 'sf4_2048_1_opt', 'sf4_256_1_opt', 'sf4_2_1_opt',
281 | 'sf4_32_1_opt', 'sf4_4096_1_opt', 'sf4_4_1_opt',
282 | 'sf4_512_1_opt', 'sf4_64_1_opt', 'sf4_8_1_opt',
283 | 'sf4_1024_2_opt', 'sf4_128_2_opt', 'sf4_16_2_opt',
284 | 'sf4_2048_2_opt', 'sf4_256_2_opt', 'sf4_2_2_opt',
285 | 'sf4_32_2_opt', 'sf4_4096_2_opt', 'sf4_4_2_opt',
286 | 'sf4_512_2_opt', 'sf4_64_2_opt', 'sf4_8_2_opt',
287 | 'sf4_1024_3_opt', 'sf4_128_3_opt', 'sf4_16_3_opt',
288 | 'sf4_2048_3_opt', 'sf4_256_3_opt', 'sf4_2_3_opt',
289 | 'sf4_32_3_opt', 'sf4_4096_3_opt', 'sf4_4_3_opt',
290 | 'sf4_512_3_opt', 'sf4_64_3_opt', 'sf4_8_3_opt'
291 | ]
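# Filenames follow the sf{bits}_{max_length}_{epoch}_opt convention used by the
# SF4 training script, e.g. 'sf4_256_3_opt' is the SF4 model fine-tuned with a
# 256-token context for 3 epochs.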
292 |
293 | # Run the processing
294 | results = process_models(models_to_process)
295 |
296 | # Print sorted results
297 | print("\nSorted Perplexity Results:")
298 | for model, perplexity in sorted(results.items(), key=lambda x: x[1]):
299 | print(f"{model}: {perplexity}")
--------------------------------------------------------------------------------
/src/modal/modal_fasteropt.py:
--------------------------------------------------------------------------------
1 | import modal
2 |
3 | # Create a Modal image with the required dependencies
4 | image = (
5 | modal.Image.debian_slim()
6 | .pip_install(
7 | "torch",
8 | "transformers",
9 | "datasets",
10 | "tqdm",
11 | "huggingface_hub",
12 | )
13 | .apt_install("gcc", "python3-dev") # Add necessary system libraries if needed
14 | )
15 |
16 | app = modal.App("qwen-sf4-experimental")
17 |
18 | # Define the function that runs the script
19 | # @app.function(gpu=modal.gpu.A100(size="80GB"), image=image, timeout=86400)
20 | @app.function(gpu="A100", image=image, timeout=86400)
21 | def train_and_upload():
22 | import torch
23 | import gc
24 | import os
25 | import re
26 | import requests
27 | from tqdm import tqdm
28 | from datasets import Dataset
29 | from torch.utils.data import DataLoader
30 | from transformers import AutoModelForCausalLM, AutoTokenizer
31 | import pandas as pd
32 |
33 | # List of dataset URLs
34 | urls = [
35 | "https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/resolve/main/data/train-00000-of-01650-f70471ee3deb09c0.parquet",
36 | "https://huggingface.co/datasets/EleutherAI/the_pile_deduplicated/resolve/main/data/train-00001-of-01650-172fc1a0c346b36e.parquet"
37 | ]
38 |
39 | # Local final output file path
40 | final_file_name = "train.parquet"
41 |
42 | # Check if the final file already exists
43 | if not os.path.exists(final_file_name):
44 | print(f"Downloading and combining dataset from {len(urls)} files...")
45 |
46 | # List to hold all the dataframes
47 | combined_df = pd.DataFrame()
48 |
49 | # Loop through each URL to download and combine the files
50 | for i, url in enumerate(urls):
51 | downloaded_file = f"temp_file_{i}.parquet"
52 |
53 | # Download the dataset
54 | print(f"Downloading dataset from {url}...")
55 | response = requests.get(url, stream=True)
56 | with open(downloaded_file, "wb") as f:
57 | for chunk in response.iter_content(chunk_size=8192):
58 | f.write(chunk)
59 | print(f"Downloaded to {downloaded_file}.")
60 |
61 | # Read the downloaded parquet file and append to the combined dataframe
62 | df = pd.read_parquet(downloaded_file)
63 | combined_df = pd.concat([combined_df, df], ignore_index=True)
64 |
65 | # Optionally remove the temporary file after reading
66 | os.remove(downloaded_file)
67 |
68 | # Save the combined dataframe as a final parquet file
69 | combined_df.to_parquet(final_file_name)
70 | print(f"Combined data saved to {final_file_name}.")
71 | else:
72 | print(f"{final_file_name} already exists. Skipping download.")
73 |
74 |
75 | # max_lengths = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
76 | max_lengths = [256]
77 | bit = 4
78 |
79 | class Superfloat:
80 | CASTING_TABLE = {
81 | 16: torch.float32,
82 | 15: torch.float32,
83 | 14: torch.float32,
84 | 13: torch.float32,
85 | 12: torch.float32,
86 | 11: torch.float16,
87 | 10: torch.float16,
88 | 9: torch.float16,
89 | 8: torch.bfloat16,
90 | 7: torch.bfloat16,
91 | 6: torch.bfloat16,
92 | 5: torch.bfloat16,
93 | 4: torch.bfloat16,
94 | }
95 |
96 | def __init__(self, bits: int):
97 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
98 | self.bits = bits
99 | self.mantissa_bits = bits - 1
100 | self.max_val = 1 - 2**-self.mantissa_bits
101 | self.float_type = self.CASTING_TABLE[bits]
102 |
103 | def encode(self, value: torch.Tensor):
104 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
105 | out_of_range = (value.abs() > self.max_val)
106 | mantissa = (
107 | (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val)
108 | .floor()
109 | .to(torch.int32)
110 | )
111 | sign = (clipped_value < 0).to(torch.int32)
112 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32), out_of_range
113 |
114 | def decode(self, encoded_value: torch.Tensor):
115 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
116 | sign = (encoded_value >> self.mantissa_bits) & 1
117 | decoded_value = (
118 | mantissa.to(self.float_type)
119 | / (2**self.mantissa_bits - 1)
120 | * self.max_val
121 | )
122 | return decoded_value * (2 * sign - 1)
123 |
124 | def tensor_quantize(self, tensor: torch.Tensor):
125 | encoded_tensor, out_of_range = self.encode(tensor)
126 | decoded_tensor = self.decode(encoded_tensor)
127 | return decoded_tensor, out_of_range
128 |
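    # Worked example (illustrative, not executed): with bit = 4 the quantizer has
    # mantissa_bits = 3 and max_val = 1 - 2**-3 = 0.875. Encoding 0.6 gives
    # mantissa = floor(0.6 * 7 / 0.875) = 4 with sign = 0, and decoding returns
    # 4 / 7 * 0.875 = 0.5, i.e. a quantization error of 0.1 for this value.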
129 | sf = Superfloat(bit)
130 |
131 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
132 | print(f"Using device: {device}")
133 |
134 | # Initialize model and tokenizer
135 | model_name = "Qwen/Qwen2-0.5B"
136 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
137 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./", token='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')
138 | tokenizer.pad_token = tokenizer.eos_token
139 |
140 | def quantize_model(model, sf_type):
141 | for name, param in model.named_parameters():
142 | quantized_param, _ = sf_type.tensor_quantize(param)
143 | param.data = quantized_param.data
144 | return model
145 |
146 | def load_checkpoint(model, sf_bits, suffix="opt", device="cuda"):
147 | """
148 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix.
149 |
150 | Args:
151 |             model: The model to load the checkpoint into.
152 | sf_bits: Bitwidth of the Superfloat format (e.g., 11).
153 | suffix: The suffix of the filename (default: 'opt').
154 | device: Device to load the model onto ('cuda' or 'cpu').
155 |
156 | Returns:
157 | The quantized model with loaded weights and the epoch number.
158 | """
159 | # Define the filename pattern to search for
160 |         checkpoint_pattern = re.compile(f"sf{sf_bits}_\\d+_(\\d+)_{suffix}$")  # matches the f"sf{sf.bits}_{max_length}_{epoch}_{suffix}" names saved below
161 |
162 | # Find all matching checkpoint files
163 | checkpoint_files = [
164 | f for f in os.listdir(".") if checkpoint_pattern.match(f)
165 | ]
166 |
167 | if not checkpoint_files:
168 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.")
169 | return quantize_model(model, sf), 0
170 |
171 | # Extract epoch numbers and sort by latest epoch
172 | epochs_and_files = [
173 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files
174 | ]
175 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0])
176 |
177 | # Load the latest checkpoint
178 | print(f"Loading checkpoint: {latest_checkpoint}")
179 | checkpoint = torch.load(latest_checkpoint, map_location=device)
180 | model.load_state_dict(checkpoint)
181 | model.to(device)
182 |
183 | return model, latest_epoch
184 |
185 | # Pre-training parameter check to ensure they are within range
186 | def check_parameters_in_range(model, sf):
187 | out_of_range_params = []
188 | for name, param in model.named_parameters():
189 | if not torch.all(torch.abs(param.data) <= sf.max_val):
190 | out_of_range_params.append(name)
191 | if out_of_range_params:
192 | print(f"Warning: The following parameters are out of range:")
193 | for param_name in out_of_range_params:
194 | print(f"- {param_name}")
195 | else:
196 | print("All parameters are within the valid range.")
197 |
198 |
199 | def prepare_dataset(tokenizer, max_length=1):
200 | dataset = Dataset.from_parquet("train.parquet")
201 |
202 | def tokenize_function(examples):
203 | return tokenizer(
204 | examples["text"],
205 | truncation=True,
206 | max_length=max_length,
207 | padding="max_length",
208 | return_tensors="pt",
209 | )
210 |
211 | tokenized_dataset = dataset.map(
212 | tokenize_function, batched=True, remove_columns=dataset.column_names
213 | )
214 | return tokenized_dataset
215 |
216 | def collate_fn(batch):
217 | input_ids = torch.stack(
218 | [torch.tensor(example["input_ids"]) for example in batch]
219 | )
220 | attention_mask = torch.stack(
221 | [torch.tensor(example["attention_mask"]) for example in batch]
222 | )
223 | return {"input_ids": input_ids, "attention_mask": attention_mask}
224 |
225 |
226 | # Loop over different max_length values
227 | for max_length in max_lengths:
228 | print(f"Starting training for max_length = {max_length}")
229 |
230 | # Prepare Dataset
231 | tokenized_dataset = prepare_dataset(tokenizer, max_length=max_length)
232 | train_dataloader = DataLoader(
233 | tokenized_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
234 | )
235 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./", token='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')
236 | model = model.to(sf.float_type).to(device)
237 | quantized_model, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device)
238 | quantized_model.to(device)
239 | print(f"Resuming training from epoch {last_epoch + 1}.")
240 |
241 | # Check if model parameters are within range before training
242 | check_parameters_in_range(quantized_model, sf)
243 |
244 | del model
245 | torch.cuda.empty_cache()
246 | gc.collect()
247 |
248 | optimizer = torch.optim.Adam(quantized_model.parameters(), lr=1e-5, eps=1e-4)
249 | loss_fn = torch.nn.CrossEntropyLoss()
250 |
251 | num_epochs = 10
252 | accumulation_steps = 16
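        # With batch_size=2 and accumulation_steps=16 the optimizer effectively
        # steps on a batch of 32 sequences; dividing each loss by
        # accumulation_steps below keeps the accumulated gradient equal to the
        # gradient of the average loss over that effective batch.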
253 |
254 | for epoch in range(num_epochs):
255 | epoch_loss = 0.0
256 | epoch_iterator = tqdm(
257 | enumerate(train_dataloader),
258 | total=len(train_dataloader),
259 | desc=f"Epoch {epoch + 1}/{num_epochs}",
260 | )
261 |
262 | for step, batch in epoch_iterator:
263 | input_ids = batch["input_ids"].to(device)
264 | attention_mask = batch["attention_mask"].to(device)
265 | outputs = quantized_model(input_ids=input_ids, attention_mask=attention_mask)
266 | logits = outputs.logits
267 | target = input_ids[:, 1:].contiguous()
268 | logits = logits[:, :-1].contiguous()
269 |
270 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
271 | loss = loss / accumulation_steps
272 | loss.backward()
273 |
274 | epoch_loss += loss.item() * accumulation_steps
275 |
276 | if (step + 1) % accumulation_steps == 0:
277 | optimizer.step()
278 | optimizer.zero_grad()
279 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"})
280 |
281 | epoch_loss /= len(train_dataloader)
282 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}")
283 |
284 | model_path = f"sf{sf.bits}_{max_length}_{epoch + 1}_opt"
285 | torch.save(quantized_model.state_dict(), model_path)
286 |
287 | # Upload model to Hugging Face
288 | os.system(
289 | f"huggingface-cli upload aoxo/qwen2-sf4-experimental {model_path} --token='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'"
290 | )
291 |
292 | del quantized_model
293 | torch.cuda.empty_cache()
294 | gc.collect()
295 |
296 | print(f"Completed training for max_length = {max_length}")
297 |
298 | # Entry point to run locally
299 | @app.local_entrypoint()
300 | def main():
301 | train_and_upload.remote()
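# Typical invocation (assuming the file path used in this repo): `modal run
# src/modal/modal_fasteropt.py` builds the image defined above and executes
# train_and_upload on a remote A100 container.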
--------------------------------------------------------------------------------
/src/test/emulate_sf.py:
--------------------------------------------------------------------------------
1 | def sf_mul(bin1: str, bin2: str, n_bits: int = None) -> str:
2 | """
3 | Multiplies two signed fixed-point binary fractional numbers in s.xxx format.
4 |
5 | Args:
6 | - bin1, bin2: binary strings like '0.101' or '1.011'
7 | - n_bits: number of fractional bits (excluding sign). If None, inferred.
8 |
9 | Returns:
10 | - Result in s.xxx format with sign bit and n_bits fractional bits.
11 | """
12 | # Extract sign and fractional parts
13 | sign1, frac1 = bin1[0], bin1[2:]
14 | sign2, frac2 = bin2[0], bin2[2:]
15 |
16 | # Infer n_bits if not given
17 | if n_bits is None:
18 | n_bits = max(len(frac1), len(frac2))
19 |
20 | # Pad fractions to match n_bits
21 | frac1 = frac1.ljust(n_bits, '0')
22 | frac2 = frac2.ljust(n_bits, '0')
23 |
24 | # Convert to integers
25 | int1 = int(frac1, 2)
26 | int2 = int(frac2, 2)
27 |
28 | # Apply signs
29 | if sign1 == '1':
30 | int1 = -int1
31 | if sign2 == '1':
32 | int2 = -int2
33 |
34 | # Multiply
35 | product = int1 * int2
36 |
37 | # Result needs 2 * n_bits for full precision
38 | product_bits = 2 * n_bits
39 | abs_product = abs(product)
40 | product_bin = bin(abs_product)[2:].zfill(product_bits)
41 |
42 | # Take the top n_bits as fractional result
43 | fractional_part = product_bin[:n_bits]
44 |
45 | # Determine sign bit
46 | sign_bit = '0' if product >= 0 else '1'
47 |
48 | return f"{sign_bit}.{fractional_part}"
49 |
50 |
51 |
52 |
53 | # Example usage:
54 | # print(sf_mul('0.101', '0.110')) # Should return '0.011'
55 | # print(sf_mul('1.101', '0.110')) # Should return '1.011'
56 | # print(sf_mul('1.101', '1.110')) # Should return '0.011'
57 |
58 | print(sf_mul('0.1110001','0.0001001'))
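# Expected output (worked by hand): '0.0000111', i.e. 7/128 = 0.0546875; the
# exact product 113/128 * 9/128 = 1017/16384 ≈ 0.0621 is truncated to 7 fractional bits.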
--------------------------------------------------------------------------------
/src/test/emulate_sf_dec.py:
--------------------------------------------------------------------------------
1 | def sf_mul(bin1: str, bin2: str, n_bits: int = None) -> str:
2 | """
3 | Multiplies two signed fixed-point binary fractional numbers in s.xxx format.
4 |
5 | Args:
6 | - bin1, bin2: binary strings like '0.101' or '1.011'
7 | - n_bits: number of fractional bits (excluding sign). If None, inferred.
8 |
9 | Returns:
10 | - Result in s.xxx format with sign bit and n_bits fractional bits.
11 | """
12 | # Extract sign and fractional parts
13 | sign1, frac1 = bin1[0], bin1[2:]
14 | sign2, frac2 = bin2[0], bin2[2:]
15 | # Infer n_bits if not given
16 | if n_bits is None:
17 | n_bits = max(len(frac1), len(frac2))
18 | # Pad fractions to match n_bits
19 | frac1 = frac1.ljust(n_bits, '0')
20 | frac2 = frac2.ljust(n_bits, '0')
21 | # Convert to integers
22 | int1 = int(frac1, 2)
23 | int2 = int(frac2, 2)
24 | # Apply signs
25 | if sign1 == '1':
26 | int1 = -int1
27 | if sign2 == '1':
28 | int2 = -int2
29 | # Multiply
30 | product = int1 * int2
31 | # Result needs 2 * n_bits for full precision
32 | product_bits = 2 * n_bits
33 | abs_product = abs(product)
34 | product_bin = bin(abs_product)[2:].zfill(product_bits)
35 | # Take the top n_bits as fractional result
36 | fractional_part = product_bin[:n_bits]
37 | # Determine sign bit
38 | sign_bit = '0' if product >= 0 else '1'
39 | return f"{sign_bit}.{fractional_part}"
40 |
41 |
42 | def decimal_to_sf(decimal_val: float, n_bits: int = 8) -> str:
43 | """
44 | Converts a decimal number in range [-1, 1) to signed fixed-point binary format.
45 |
46 | Args:
47 | - decimal_val: decimal number between -1 and 1
48 | - n_bits: number of fractional bits
49 |
50 | Returns:
51 | - Binary string in s.xxx format
52 | """
53 | if not (-1 <= decimal_val < 1):
54 | raise ValueError("Decimal value must be in range [-1, 1)")
55 |
56 | # Determine sign
57 | sign_bit = '0' if decimal_val >= 0 else '1'
58 | abs_val = abs(decimal_val)
59 |
60 | # Convert fractional part to binary
61 | frac_binary = ""
62 | for _ in range(n_bits):
63 | abs_val *= 2
64 | if abs_val >= 1:
65 | frac_binary += '1'
66 | abs_val -= 1
67 | else:
68 | frac_binary += '0'
69 |
70 | return f"{sign_bit}.{frac_binary}"
71 |
72 |
73 | def sf_to_decimal(sf_binary: str) -> float:
74 | """
75 | Converts signed fixed-point binary format to decimal.
76 |
77 | Args:
78 | - sf_binary: binary string in s.xxx format
79 |
80 | Returns:
81 | - Decimal equivalent
82 | """
83 | sign_bit = sf_binary[0]
84 | frac_part = sf_binary[2:]
85 |
86 | # Convert fractional part to decimal
87 | decimal_val = 0
88 | for i, bit in enumerate(frac_part):
89 | if bit == '1':
90 | decimal_val += 2 ** (-(i + 1))
91 |
92 | # Apply sign
93 | if sign_bit == '1':
94 | decimal_val = -decimal_val
95 |
96 | return decimal_val
97 |
98 |
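# Round-trip example (worked by hand): decimal_to_sf(0.213, 8) -> '0.00110110',
# and sf_to_decimal('0.00110110') -> 54/256 = 0.2109375, a representation error
# of roughly 0.0021 at 8 fractional bits.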
99 | def sf_mul_dec(dec1: float, dec2: float, n_bits: int = 8) -> dict:
100 | """
101 | Multiplies two decimal numbers in SF range [-1, 1) and returns result in both formats.
102 |
103 | Args:
104 | - dec1, dec2: decimal numbers between -1 and 1
105 | - n_bits: number of fractional bits for binary representation
106 |
107 | Returns:
108 | - Dictionary with 'sf_binary', 'decimal', 'inputs_sf', and 'exact_decimal' keys
109 | """
110 | # Validate inputs
111 | if not (-1 <= dec1 < 1) or not (-1 <= dec2 < 1):
112 | raise ValueError("Both decimal values must be in range [-1, 1)")
113 |
114 | # Convert decimals to SF binary format
115 | sf1 = decimal_to_sf(dec1, n_bits)
116 | sf2 = decimal_to_sf(dec2, n_bits)
117 |
118 | # Perform SF multiplication
119 | sf_result = sf_mul(sf1, sf2, n_bits)
120 |
121 | # Convert result back to decimal
122 | result_decimal = sf_to_decimal(sf_result)
123 |
124 | # Calculate exact decimal multiplication for comparison
125 | exact_decimal = dec1 * dec2
126 |
127 | return {
128 | 'sf_binary': sf_result,
129 | 'decimal': result_decimal,
130 | 'inputs_sf': (sf1, sf2),
131 | 'exact_decimal': exact_decimal,
132 | 'error': abs(exact_decimal - result_decimal)
133 | }
134 |
135 | result = sf_mul_dec(0.5, 0.213, 16)
136 | print(f" SF inputs: {result['inputs_sf'][0]} × {result['inputs_sf'][1]}")
137 | print(f" SF result: {result['sf_binary']}")
138 | print(f" Decimal result: {result['decimal']:.6f}")
139 | print(f" Exact decimal: {result['exact_decimal']:.6f}")
140 | print(f" Error: {result['error']:.6f}")
--------------------------------------------------------------------------------
/src/test/sf_tensors.py:
--------------------------------------------------------------------------------
1 | def sf_mul(bin1: str, bin2: str, n_bits: int = None) -> str:
2 | """
3 | Multiplies two signed fixed-point binary fractional numbers in s.xxx format.
4 |
5 | Args:
6 | - bin1, bin2: binary strings like '0.101' or '1.011'
7 | - n_bits: number of fractional bits (excluding sign). If None, inferred.
8 |
9 | Returns:
10 | - Result in s.xxx format with sign bit and n_bits fractional bits.
11 | """
12 | # Extract sign and fractional parts
13 | sign1, frac1 = bin1[0], bin1[2:]
14 | sign2, frac2 = bin2[0], bin2[2:]
15 | # Infer n_bits if not given
16 | if n_bits is None:
17 | n_bits = max(len(frac1), len(frac2))
18 | # Pad fractions to match n_bits
19 | frac1 = frac1.ljust(n_bits, '0')
20 | frac2 = frac2.ljust(n_bits, '0')
21 | # Convert to integers
22 | int1 = int(frac1, 2)
23 | int2 = int(frac2, 2)
24 | # Apply signs
25 | if sign1 == '1':
26 | int1 = -int1
27 | if sign2 == '1':
28 | int2 = -int2
29 | # Multiply
30 | product = int1 * int2
31 | # Result needs 2 * n_bits for full precision
32 | product_bits = 2 * n_bits
33 | abs_product = abs(product)
34 | product_bin = bin(abs_product)[2:].zfill(product_bits)
35 | # Take the top n_bits as fractional result
36 | fractional_part = product_bin[:n_bits]
37 | # Determine sign bit
38 | sign_bit = '0' if product >= 0 else '1'
39 | return f"{sign_bit}.{fractional_part}"
40 |
41 |
42 | def decimal_to_sf(decimal_val: float, n_bits: int = 8) -> str:
43 | """
44 | Converts a decimal number to signed fixed-point binary format with overflow clamping.
45 |
46 | Args:
47 | - decimal_val: decimal number (will be clamped to [-1, 1) range)
48 | - n_bits: number of fractional bits
49 |
50 | Returns:
51 | - Binary string in s.xxx format
52 | """
53 | # Clamp to valid SF range
54 | max_val = 1.0 - (2 ** (-n_bits)) # Maximum positive value (just under 1.0)
55 | min_val = -1.0 # Minimum negative value
56 |
57 | if decimal_val >= 1.0:
58 | decimal_val = max_val
59 | elif decimal_val < -1.0:
60 | decimal_val = min_val
61 |
62 | # Determine sign
63 | sign_bit = '0' if decimal_val >= 0 else '1'
64 | abs_val = abs(decimal_val)
65 |
66 | # Convert fractional part to binary
67 | frac_binary = ""
68 | for _ in range(n_bits):
69 | abs_val *= 2
70 | if abs_val >= 1:
71 | frac_binary += '1'
72 | abs_val -= 1
73 | else:
74 | frac_binary += '0'
75 |
76 | return f"{sign_bit}.{frac_binary}"
77 |
78 |
79 | def sf_to_decimal(sf_binary: str) -> float:
80 | """
81 | Converts signed fixed-point binary format to decimal.
82 |
83 | Args:
84 | - sf_binary: binary string in s.xxx format
85 |
86 | Returns:
87 | - Decimal equivalent
88 | """
89 | sign_bit = sf_binary[0]
90 | frac_part = sf_binary[2:]
91 |
92 | # Convert fractional part to decimal
93 | decimal_val = 0
94 | for i, bit in enumerate(frac_part):
95 | if bit == '1':
96 | decimal_val += 2 ** (-(i + 1))
97 |
98 | # Apply sign
99 | if sign_bit == '1':
100 | decimal_val = -decimal_val
101 |
102 | return decimal_val
103 |
104 |
105 | def sf_mul_dec(dec1: float, dec2: float, n_bits: int = 8) -> dict:
106 | """
107 | Multiplies two decimal numbers using SF arithmetic with overflow clamping.
108 |
109 | Args:
110 | - dec1, dec2: decimal numbers (will be clamped to SF range if needed)
111 | - n_bits: number of fractional bits for binary representation
112 |
113 | Returns:
114 | - Dictionary with 'sf_binary', 'decimal', 'inputs_sf', 'exact_decimal', and overflow info
115 | """
116 | # Store original values for exact calculation
117 | original_dec1, original_dec2 = dec1, dec2
118 |
119 | # Clamp inputs to valid SF range
120 | max_val = 1.0 - (2 ** (-n_bits))
121 | min_val = -1.0
122 |
123 | input1_overflow = False
124 | input2_overflow = False
125 |
126 | if dec1 >= 1.0:
127 | dec1 = max_val
128 | input1_overflow = True
129 | elif dec1 < -1.0:
130 | dec1 = min_val
131 | input1_overflow = True
132 |
133 | if dec2 >= 1.0:
134 | dec2 = max_val
135 | input2_overflow = True
136 | elif dec2 < -1.0:
137 | dec2 = min_val
138 | input2_overflow = True
139 |
140 | # Convert decimals to SF binary format
141 | sf1 = decimal_to_sf(dec1, n_bits)
142 | sf2 = decimal_to_sf(dec2, n_bits)
143 |
144 | # Perform SF multiplication
145 | sf_result = sf_mul(sf1, sf2, n_bits)
146 |
147 | # Convert result back to decimal
148 | result_decimal = sf_to_decimal(sf_result)
149 |
150 | # Calculate exact decimal multiplication for comparison
151 | exact_decimal = original_dec1 * original_dec2
152 |
153 | return {
154 | 'sf_binary': sf_result,
155 | 'decimal': result_decimal,
156 | 'inputs_sf': (sf1, sf2),
157 | 'exact_decimal': exact_decimal,
158 | 'error': abs(exact_decimal - result_decimal),
159 | 'input1_overflow': input1_overflow,
160 | 'input2_overflow': input2_overflow,
161 | 'clamped_inputs': (dec1, dec2)
162 | }
163 |
164 |
165 | def sf_add_dec(dec1: float, dec2: float, n_bits: int = 8) -> dict:
166 | """
167 | Adds two decimal numbers using SF arithmetic with overflow clamping.
168 |
169 | Args:
170 | - dec1, dec2: decimal numbers (will be clamped to SF range if needed)
171 | - n_bits: number of fractional bits for binary representation
172 |
173 | Returns:
174 | - Dictionary with result information
175 | """
176 | # Store originals and clamp inputs
177 | original_dec1, original_dec2 = dec1, dec2
178 | max_val = 1.0 - (2 ** (-n_bits))
179 | min_val = -1.0
180 |
181 | input1_overflow = False
182 | input2_overflow = False
183 |
184 | if dec1 >= 1.0:
185 | dec1 = max_val
186 | input1_overflow = True
187 | elif dec1 < -1.0:
188 | dec1 = min_val
189 | input1_overflow = True
190 |
191 | if dec2 >= 1.0:
192 | dec2 = max_val
193 | input2_overflow = True
194 | elif dec2 < -1.0:
195 | dec2 = min_val
196 | input2_overflow = True
197 |
198 | # Convert to SF format
199 | sf1 = decimal_to_sf(dec1, n_bits)
200 | sf2 = decimal_to_sf(dec2, n_bits)
201 |
202 | # Extract components
203 | sign1, frac1 = sf1[0], sf1[2:]
204 | sign2, frac2 = sf2[0], sf2[2:]
205 |
206 | # Convert to signed integers
207 | int1 = int(frac1, 2)
208 | int2 = int(frac2, 2)
209 |
210 | if sign1 == '1':
211 | int1 = -int1
212 | if sign2 == '1':
213 | int2 = -int2
214 |
215 | # Add
216 | result = int1 + int2
217 |
218 | # Handle overflow/underflow in SF integer space
219 | max_sf_int = (1 << n_bits) - 1 # Maximum positive SF integer
220 |     min_sf_int = -((1 << n_bits) - 1)  # Minimum negative SF integer representable in sign-magnitude
221 |
222 | result_overflow = False
223 | if result > max_sf_int:
224 | result = max_sf_int
225 | result_overflow = True
226 | elif result < min_sf_int:
227 | result = min_sf_int
228 | result_overflow = True
229 |
230 | # Convert back to SF format
231 | abs_result = abs(result)
232 | sign_bit = '0' if result >= 0 else '1'
233 | frac_binary = bin(abs_result)[2:].zfill(n_bits)
234 |
235 | sf_result = f"{sign_bit}.{frac_binary}"
236 | result_decimal = sf_to_decimal(sf_result)
237 | exact_decimal = original_dec1 + original_dec2
238 |
239 | return {
240 | 'sf_binary': sf_result,
241 | 'decimal': result_decimal,
242 | 'exact_decimal': exact_decimal,
243 | 'error': abs(exact_decimal - result_decimal),
244 | 'result_overflow': result_overflow,
245 | 'input1_overflow': input1_overflow,
246 | 'input2_overflow': input2_overflow
247 | }
248 |
249 |
250 | def sf_tensor_mul(tensor_a: list, tensor_b: list, n_bits: int = 8) -> dict:
251 | """
252 | Multiplies two 2D tensors using signed fixed-point arithmetic with overflow clamping.
253 |
254 | Args:
255 | - tensor_a, tensor_b: 2D lists representing matrices (values will be clamped to SF range)
256 | - n_bits: number of fractional bits for SF representation
257 |
258 | Returns:
259 | - Dictionary with SF result, decimal result, exact result, and error analysis
260 | """
261 | # Validate dimensions
262 | rows_a = len(tensor_a)
263 | cols_a = len(tensor_a[0]) if rows_a > 0 else 0
264 | rows_b = len(tensor_b)
265 | cols_b = len(tensor_b[0]) if rows_b > 0 else 0
266 |
267 | if cols_a != rows_b:
268 | raise ValueError(f"Cannot multiply matrices: {rows_a}x{cols_a} × {rows_b}x{cols_b}")
269 |
270 | # Initialize result matrices
271 | sf_result = [[None for _ in range(cols_b)] for _ in range(rows_a)]
272 | decimal_result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)]
273 | exact_result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)]
274 | clamped_exact_result = [[0.0 for _ in range(cols_b)] for _ in range(rows_a)]
275 |
276 | total_error = 0.0
277 | max_error = 0.0
278 | input_overflow_count = 0
279 | result_overflow_count = 0
280 |
281 | max_sf_val = 1.0 - (2 ** (-n_bits))
282 | min_sf_val = -1.0
283 |
284 | # Perform matrix multiplication using SF arithmetic
285 | for i in range(rows_a):
286 | for j in range(cols_b):
287 | # Compute dot product of row i of A and column j of B
288 | sf_sum = 0.0 # Accumulate in decimal for intermediate sums
289 | exact_sum = 0.0
290 |
291 | for k in range(cols_a):
292 | # Track input overflows
293 | if tensor_a[i][k] >= 1.0 or tensor_a[i][k] < -1.0:
294 | input_overflow_count += 1
295 | if tensor_b[k][j] >= 1.0 or tensor_b[k][j] < -1.0:
296 | input_overflow_count += 1
297 |
298 | # SF multiplication (this handles input clamping internally)
299 | mul_result = sf_mul_dec(tensor_a[i][k], tensor_b[k][j], n_bits)
300 |
301 | # Add to running sum (in decimal space)
302 | sf_sum += mul_result['decimal']
303 | exact_sum += tensor_a[i][k] * tensor_b[k][j]
304 |
305 | # Clamp sf_sum to valid SF range
306 | original_sf_sum = sf_sum
307 | if sf_sum >= 1.0:
308 | sf_sum = max_sf_val
309 | result_overflow_count += 1
310 | elif sf_sum < -1.0:
311 | sf_sum = min_sf_val
312 | result_overflow_count += 1
313 |
314 | # Convert final sum to SF format
315 | sf_binary = decimal_to_sf(sf_sum, n_bits)
316 | sf_result[i][j] = sf_binary
317 | decimal_result[i][j] = sf_to_decimal(sf_binary)
318 | exact_result[i][j] = exact_sum
319 |
320 | # Clamp exact result for error comparison
321 | if exact_sum >= 1.0:
322 | clamped_exact_result[i][j] = max_sf_val
323 | elif exact_sum < -1.0:
324 | clamped_exact_result[i][j] = min_sf_val
325 | else:
326 | clamped_exact_result[i][j] = exact_sum
327 |
328 | # Track errors against clamped exact result
329 | error = abs(clamped_exact_result[i][j] - decimal_result[i][j])
330 | total_error += error
331 | max_error = max(max_error, error)
332 |
333 | return {
334 | 'sf_result': sf_result,
335 | 'decimal_result': decimal_result,
336 | 'exact_result': exact_result,
337 | 'clamped_exact_result': clamped_exact_result,
338 | 'total_error': total_error,
339 | 'average_error': total_error / (rows_a * cols_b),
340 | 'max_error': max_error,
341 | 'input_overflow_count': input_overflow_count,
342 | 'result_overflow_count': result_overflow_count,
343 | 'result_shape': (rows_a, cols_b)
344 | }
345 |
346 |
347 | def print_tensor(tensor, title, precision=6):
348 | """Helper function to print tensors nicely"""
349 | print(f"{title}:")
350 | for row in tensor:
351 | formatted_row = []
352 | for val in row:
353 | if isinstance(val, str): # SF binary format
354 | formatted_row.append(f"{val:>12}")
355 | else: # Decimal format
356 | formatted_row.append(f"{val:>{precision+6}.{precision}f}")
357 | print(" [" + ", ".join(formatted_row) + "]")
358 | print()
359 |
360 |
361 | # Test tensor multiplication
362 | print("\n" + "="*60)
363 | print("TENSOR MULTIPLICATION EXAMPLE")
364 | print("="*60)
365 |
366 | A = [
367 | [0.32, -0.75, 0.11],
368 | [-0.58, 0.94, -0.23],
369 | [0.67, -0.12, -0.81]
370 | ]
371 |
372 | B = [
373 | [-0.44, 0.09, 0.65],
374 | [0.77, -0.36, -0.52],
375 | [-0.19, 0.84, 0.27]
376 | ]
377 |
378 | print("Input Tensors:")
379 | print_tensor(A, "Tensor A (Decimal)", 2)
380 |
381 | # Convert input tensors to SF format for display
382 | A_sf = [[decimal_to_sf(A[i][j], 16) for j in range(len(A[i]))] for i in range(len(A))]
383 | B_sf = [[decimal_to_sf(B[i][j], 16) for j in range(len(B[i]))] for i in range(len(B))]
384 |
385 | print_tensor(A_sf, "Tensor A (SF Binary)")
386 | print_tensor(B, "Tensor B (Decimal)", 2)
387 | print_tensor(B_sf, "Tensor B (SF Binary)")
388 |
389 | # Perform SF tensor multiplication
390 | result = sf_tensor_mul(A, B, n_bits=16)
391 |
392 | print("Results:")
393 | print_tensor(result['sf_result'], "SF Binary Result")
394 | print_tensor(result['decimal_result'], "SF Decimal Result", 6)
395 | print_tensor(result['exact_result'], "Exact Decimal Result", 6)
396 | print_tensor(result['clamped_exact_result'], "Clamped Exact Result", 6)
397 |
398 | print(f"Error Analysis (SF vs Clamped Exact):")
399 | print(f" Total Error: {result['total_error']:.8f}")
400 | print(f" Average Error: {result['average_error']:.8f}")
401 | print(f" Maximum Error: {result['max_error']:.8f}")
402 | print(f" Input Overflow Count: {result['input_overflow_count']}")
403 | print(f" Result Overflow Count: {result['result_overflow_count']}")
404 | print(f" Result Shape: {result['result_shape']}")
--------------------------------------------------------------------------------
/src/verilog/activationUnit_16bit.sv:
--------------------------------------------------------------------------------
1 | module ActivationUnit (
2 | input wire clk,
3 | input wire start, // Start processing
4 | input wire select, // 0: Linear, 1: ReLU
5 | input wire [15:0] buffer, // 16-bit buffer holding matrix values
6 | output reg [15:0] result // Output after activation
7 | );
8 | reg signed [15:0] data; // Assuming 16-bit fixed/floating point numbers
9 | reg signed [15:0] output_data;
10 |
11 | always @(posedge clk) begin
12 | if (start) begin
13 | // Extract value from the buffer
14 | data = buffer;
15 |
16 | // Apply activation function
17 | if (select == 1'b1) begin
18 | // ReLU: Max(0, value)
19 | output_data = (data < 0) ? 0 : data;
20 | end else begin
21 | // Linear: Identity function
22 | output_data = data;
23 | end
24 |
25 | // Store result
26 | result = output_data;
27 | end
28 | end
29 | endmodule
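// Example behaviour: with select = 1 (ReLU) and buffer = 16'hF000 (-4096 as a
// signed value), result is driven to 0 at the next rising clock edge while
// start is asserted; with select = 0 the input value passes through unchanged.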
30 |
--------------------------------------------------------------------------------
/src/verilog/shiftRegister_16bit.sv:
--------------------------------------------------------------------------------
1 | module ShiftRegister #(parameter SIZE = 16, DEPTH = 8) (
2 | input wire clk,
3 | input wire reset,
4 | input wire shift_en,
5 | input wire [SIZE-1:0] data_in,
6 | output reg [SIZE-1:0] data_out
7 | );
8 |
9 | reg [SIZE-1:0] shift_reg [DEPTH-1:0]; // Single chain of 8 registers, each 16-bit
10 | integer i;
11 |
12 | always @(posedge clk or posedge reset) begin
13 | if (reset) begin
14 | for (i = 0; i < DEPTH; i = i + 1) begin
15 | shift_reg[i] <= 0;
16 | end
17 | data_out <= 0;
18 | end else if (shift_en) begin
19 | for (i = DEPTH-1; i > 0; i = i - 1) begin
20 | shift_reg[i] <= shift_reg[i-1];
21 | end
22 | shift_reg[0] <= data_in;
23 | data_out <= shift_reg[DEPTH-1];
24 | end
25 | end
26 | endmodule
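// Example behaviour: an 8-stage delay line for 16-bit words. A value presented
// on data_in appears on data_out DEPTH (= 8) shift-enabled clock edges later;
// reset clears every stage and the output to zero.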
27 |
--------------------------------------------------------------------------------
/src/wasq/wasq_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | from tqdm import tqdm
4 | from datasets import load_dataset
5 |
6 | # Device setup
7 | if torch.backends.mps.is_available():
8 | device = torch.device("mps")
9 | elif torch.cuda.is_available():
10 | device = torch.device("cuda")
11 | else:
12 | device = torch.device("cpu")
13 |
14 | print(f"Using device: {device}")
15 |
16 | base_dir = "./"
17 |
18 | # Function to load model
19 | def load_model(model_path):
20 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir='./', token='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')
21 | model.load_state_dict(torch.load(model_path, map_location=device))
22 | model = model.to(torch.bfloat16).to(device)
23 | model.eval() # Ensure model is in inference mode
24 | return model
25 |
26 | # Define Superfloat quantizer for clamping activations
27 | class Superfloat:
28 | def __init__(self, bits: int):
29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
30 | self.bits = bits
31 | self.mantissa_bits = bits - 1
32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
33 |
34 | def encode(self, value: torch.Tensor) -> torch.Tensor:
35 | """Encodes a tensor of values into the superfloat format with optimized operations."""
36 | # Clip tensor values to the valid range for SFx
37 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
38 |
39 | # Calculate mantissa representation element-wise
40 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
41 |
42 | # Create the superfloat representation (1 bit for sign and mantissa bits)
43 | sign = (clipped_value < 0).to(torch.int32)
44 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32)
45 |
46 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor:
47 | """Decodes a tensor of encoded superfloat values to regular floats."""
48 | # Extract mantissa and sign from the encoded superfloat
49 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
50 | sign = (encoded_value >> self.mantissa_bits) & 1
51 |
52 | # Calculate the decoded float using the mantissa and max_val
53 | decoded_value = (mantissa.to(torch.bfloat16) / (2**self.mantissa_bits - 1)) * self.max_val
54 | return decoded_value * (2 * sign - 1) # Apply the sign
55 |
56 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
57 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape."""
58 | # Apply element-wise encoding to the entire tensor and then decode back
59 | encoded_tensor = self.encode(tensor)
60 | decoded_tensor = self.decode(encoded_tensor)
61 | return decoded_tensor
62 |
63 | # Initialize Superfloat quantizer for activation clamping
64 | sf = Superfloat(8)
65 |
66 | model_name = "meta-llama/Llama-3.2-1B"
67 |
68 | # Load tokenizer
69 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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')
70 | tokenizer.pad_token = tokenizer.eos_token
71 |
72 | def quantized_inference(model, tokenizer, prompt, max_length=500):
73 | """Runs inference on a prompt with activation quantization using Superfloat."""
74 | # Encode input prompt
75 | inputs = tokenizer(prompt, return_tensors="pt", padding=True)
76 | input_ids = inputs.input_ids.to(device)
77 | attention_mask = inputs.attention_mask.to(device)
78 |
79 | with torch.no_grad():
80 | # Perform generation with clamped activations
81 | outputs = model.generate(
82 | input_ids=input_ids,
83 | attention_mask=attention_mask,
84 | max_length=max_length,
85 | do_sample=True,
86 | top_k=50,
87 | top_p=0.95,
88 | temperature=0.7,
89 | pad_token_id=tokenizer.eos_token_id
90 | )
91 |
92 | # Decode generated output
93 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
94 | return generated_text
95 |
96 | def calculate_perplexity(model, tokenizer, prompt):
97 | """Calculates the perplexity of the model on a given prompt."""
98 | # Tokenize input
99 | inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
100 | input_ids = inputs.input_ids.to(device)
101 | attention_mask = inputs.attention_mask.to(device)
102 |
103 | # Get model outputs (logits)
104 | with torch.no_grad():
105 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
106 |
107 | # Get the loss (cross entropy) from the model's output
108 | loss = outputs.loss # This is the cross-entropy loss
109 |
110 | # Compute perplexity: exp(loss)
111 | perplexity = torch.exp(loss)
112 | return perplexity.item()
113 |
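# Note: for a causal LM called with labels=input_ids, outputs.loss is the mean
# token-level cross-entropy -(1/N) * sum_i log p(x_i | x_<i), so exp(loss) above
# is the standard per-token perplexity of the prompt.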
114 | # Model paths
115 | import os
116 |
117 | def get_model_paths(base_dir, sf_bits):
118 | """
119 | Dynamically generate model paths based on the sf.bits format.
120 | Looks for models of the form:
121 | 1. sf{sf_bits}_vanilla
122 | 2. sf{sf_bits}_{epoch_num}_fpm
123 | 3. sf{sf_bits}_{epoch_num}_opt
124 |
125 | Args:
126 | base_dir (str): The directory where the models are stored.
127 | sf_bits (int): The bitwidth for the Superfloat quantizer.
128 |
129 | Returns:
130 | List of model paths.
131 | """
132 | model_paths = []
133 | model_pattern = f"sf{sf_bits}_"
134 |
135 | # Scan directory for models matching the pattern
136 | for model_name in os.listdir(base_dir):
137 | if model_name.startswith(model_pattern):
138 | model_paths.append(os.path.join(base_dir, model_name))
139 |
140 | # Ensure models are sorted to follow the desired order: vanilla -> fpm -> opt
141 | model_paths.sort()
142 |
143 | return model_paths
144 |
145 | # Function to evaluate perplexity for a list of models and prompts
146 | def evaluate_models(base_dir, sf_bits, tokenizer, prompts):
147 | """
148 | Evaluates models dynamically loaded based on the sf.bits format.
149 |
150 | Args:
151 | base_dir (str): The directory where the models are stored.
152 | sf_bits (int): The bitwidth for the Superfloat quantizer.
153 | tokenizer: The tokenizer to use for model inference.
154 | prompts: The list of prompts to evaluate.
155 |
156 | Returns:
157 | Dictionary with model names and their corresponding average perplexity.
158 | """
159 | model_perplexities = {}
160 |
161 | # Get dynamically generated model paths
162 | models = get_model_paths(base_dir, sf_bits)
163 |
164 | for model_path in models:
165 | model = load_model(model_path)
166 | print(f"Evaluating model: {model_path}")
167 |
168 | total_perplexity = 0.0
169 | num_prompts = len(prompts)
170 |
171 | # Compute perplexity for each prompt
172 | for prompt in tqdm(prompts, desc=f"Processing {model_path}", leave=False):
173 | perplexity = calculate_perplexity(model, tokenizer, prompt)
174 | total_perplexity += perplexity
175 |
176 | # Average perplexity for the current model
177 | avg_perplexity = total_perplexity / num_prompts
178 | model_perplexities[model_path] = avg_perplexity
179 | print(f"Average Perplexity for {model_path}: {avg_perplexity}")
180 |
181 | return model_perplexities
182 |
183 | # Function to load the HellaSwag dataset
184 | def load_hellaswag_data():
185 | """Load the HellaSwag dataset from Hugging Face."""
186 | dataset = load_dataset("hellaswag", split='validation')
187 |
188 | # Extract only the prompts (contexts) for evaluation
189 | prompts = [entry['ctx'] for entry in dataset]
190 |
191 | # Return the prompts as a list
192 | return prompts
193 |
194 | # Load HellaSwag data (prompts)
195 | prompts = load_hellaswag_data()
196 |
197 | # Evaluate all models on HellaSwag prompts
198 | model_perplexities = evaluate_models(base_dir, sf.bits, tokenizer, prompts)
199 |
200 | # Print final results
201 | print("\nAverage Perplexities for all models:")
202 | for model_path, avg_perplexity in model_perplexities.items():
203 | print(f"{model_path}: {avg_perplexity}")
--------------------------------------------------------------------------------
/src/wasq/wasq_fasteropt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from torch.utils.data import DataLoader
5 | from datasets import load_dataset, Dataset
6 | from tqdm import tqdm
7 | import gc
8 |
9 | max_length = 512
10 | bit = 8
11 | ERROR_BUDGET = 0.01 # maximum allowed relative quantization error per layer
12 |
13 | class Superfloat:
14 | CASTING_TABLE = {
15 | 16: torch.float32,
16 | 15: torch.float32,
17 | 14: torch.float32,
18 | 13: torch.float32,
19 | 12: torch.float32,
20 | 11: torch.float16,
21 | 10: torch.float16,
22 | 9: torch.float16,
23 | 8: torch.bfloat16,
24 | 7: torch.bfloat16,
25 | 6: torch.bfloat16,
26 | 5: torch.bfloat16,
27 | 4: torch.bfloat16,
28 | }
29 |
30 | def __init__(self, bits: int):
31 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
32 | self.bits = bits
33 | self.mantissa_bits = bits - 1
34 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
35 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth
36 |
37 | def encode(self, value: torch.Tensor) -> torch.Tensor:
38 | """Encodes a tensor of values into the superfloat format."""
39 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
40 | out_of_range = (value.abs() > self.max_val)
41 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
42 | sign = (clipped_value < 0).to(torch.int32)
43 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32), out_of_range
44 |
45 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor:
46 | """Decodes a tensor of encoded superfloat values to regular floats."""
47 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
48 | sign = (encoded_value >> self.mantissa_bits) & 1
49 | decoded_value = (mantissa.to(self.float_type) / (2**self.mantissa_bits - 1)) * self.max_val
50 | return decoded_value * (2 * sign - 1)
51 |
52 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
53 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape."""
54 | encoded_tensor, out_of_range = self.encode(tensor)
55 | decoded_tensor = self.decode(encoded_tensor)
56 | return decoded_tensor, out_of_range
57 |
58 | sf = Superfloat(bit)
59 |
60 |
61 | class SFQuantFunction(torch.autograd.Function):
62 | """Straight-through estimator for Superfloat quantisation."""
63 |
64 | @staticmethod
65 | def forward(ctx, tensor, sf_obj):
66 | encoded, mask = sf_obj.encode(tensor)
67 | ctx.save_for_backward(mask)
68 | ctx.sf_obj = sf_obj
69 | return sf_obj.decode(encoded)
70 |
71 | @staticmethod
72 | def backward(ctx, grad_output):
73 | (mask,) = ctx.saved_tensors
74 | return grad_output * mask.to(grad_output.dtype), None
75 |
76 |
77 | class QuantizedLinear(nn.Module):
78 | """Linear layer with weights encoded once and decoded on-the-fly."""
79 |
80 | def __init__(self, linear: nn.Linear, sf_bits: int):
81 | super().__init__()
82 | self.sf = Superfloat(sf_bits)
83 | self.in_features = linear.in_features
84 | self.out_features = linear.out_features
85 | self.bias = nn.Parameter(linear.bias.detach()) if linear.bias is not None else None
86 | self.register_buffer("encoded_weight", None)
87 | self.register_buffer("outlier_mask", None)
88 | self.register_buffer("outlier_values", None)
89 | # learnable per-channel scale (LSQ+ style)
90 | self.scale = nn.Parameter(torch.ones(linear.out_features, dtype=self.sf.float_type))
91 | self.encode_weight(linear.weight)
92 |
93 | def encode_weight(self, weight, outlier_percent=0.5):
94 | # Encode once and split top-k outliers
95 | encoded, mask = self.sf.encode(weight)
96 | if outlier_percent:
97 | k = max(1, int(outlier_percent / 100.0 * weight.numel()))
98 | thresh = torch.topk(weight.abs().view(-1), k).values[-1]
99 | mask |= weight.abs() >= thresh
100 | self.encoded_weight = encoded
101 | self.outlier_mask = mask
102 | self.outlier_values = weight[mask]
103 |
104 | def forward(self, x):
105 | w = self.sf.decode(self.encoded_weight)
106 | if self.outlier_mask.any():
107 | w = w.clone()
108 | w[self.outlier_mask] = self.outlier_values
109 | w = w * self.scale.unsqueeze(1)
110 | x = SFQuantFunction.apply(x, self.sf)
111 | return nn.functional.linear(x, w, self.bias)
112 |
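# Note on QuantizedLinear above: weights are encoded to Superfloat once at
# construction; any out-of-range values plus roughly the top 0.5% of weights by
# magnitude are kept exactly in outlier_values and written back after decoding,
# a learnable per-output-channel scale rescales the decoded weight matrix, and
# activations are fake-quantized through SFQuantFunction before the matmul.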
113 |
114 | def quant_error(weight: torch.Tensor, sf_bits: int) -> float:
115 | sf_tmp = Superfloat(sf_bits)
116 | quant, _ = sf_tmp.tensor_quantize(weight)
117 | return torch.norm(weight - quant) / torch.norm(weight)
118 |
119 |
120 | def search_layer_bitwidth(weight: torch.Tensor, bits_list) -> int:
121 | for b in sorted(bits_list):
122 | if quant_error(weight, b) <= ERROR_BUDGET:
123 | return b
124 | return max(bits_list)
125 |
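# The two helpers above drive the mixed-precision assignment: quant_error is the
# relative Frobenius error ||W - Q(W)|| / ||W||, and search_layer_bitwidth picks
# the smallest candidate bitwidth whose error stays within ERROR_BUDGET (1%),
# falling back to the largest candidate otherwise.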
126 |
127 | def quantize_linear_layers(model, bits_candidates):
128 | for name, module in model.named_modules():
129 | if isinstance(module, nn.Linear):
130 | chosen = search_layer_bitwidth(module.weight.data, bits_candidates)
131 | qlin = QuantizedLinear(module, chosen)
132 | parent = model
133 | name_parts = name.split('.')
134 | for n in name_parts[:-1]:
135 | parent = getattr(parent, n)
136 | setattr(parent, name_parts[-1], qlin)
137 |
138 | class QuantizedLlamaModel(torch.nn.Module):
139 | def __init__(self, base_model: torch.nn.Module, sf_quantizer: Superfloat):
140 | super(QuantizedLlamaModel, self).__init__()
141 | self.base_model = base_model
142 | self.sf_quantizer = sf_quantizer
143 | # Replace Linear layers with quantised versions
144 | quantize_linear_layers(self.base_model, [4, 8, 11, 16])
145 | self.apply_gradient_hooks()
146 |
147 | def apply_gradient_hooks(self):
148 | for param in self.base_model.parameters():
149 | def hook(grad, param=param):
150 | _, mask = self.sf_quantizer.encode(param)
151 | return grad * mask.to(grad.dtype)
152 | param.register_hook(hook)
153 |
154 | def forward(self, *args, **kwargs):
155 | if "input_ids" in kwargs:
156 | kwargs["input_ids"] = SFQuantFunction.apply(kwargs["input_ids"], self.sf_quantizer)
157 | outputs = self.base_model(*args, **kwargs)
158 | if hasattr(outputs, "logits"):
159 | outputs.logits = SFQuantFunction.apply(outputs.logits, self.sf_quantizer)
160 | return outputs
161 |
162 | # Initialize model and tokenizer
163 | if torch.backends.mps.is_available():
164 | device = torch.device("mps")
165 | elif torch.cuda.is_available():
166 | device = torch.device("cuda")
167 | else:
168 | device = torch.device("cpu")
169 |
170 | print(f"Using device: {device}")
171 |
172 | model_name = "Qwen/Qwen2-0.5B"
173 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir='./', token='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')
174 | model = model.to(sf.float_type).to(device)
175 |
176 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./', token='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')
177 | tokenizer.pad_token = tokenizer.eos_token
178 |
179 | # Quantize Model Weights Selectively
180 | def quantize_model(model, sf_type):
181 | for name, param in model.named_parameters():
182 | quantized_param, _ = sf_type.tensor_quantize(param)
183 | param.data = quantized_param.data
184 | return model
185 |
186 | import os
187 | import re
188 |
189 | def load_checkpoint(model, sf_bits, suffix="opt", device=device):
190 | """
191 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix.
192 |
193 | Args:
194 |             model: The model to load the checkpoint into.
195 | sf_bits: Bitwidth of the Superfloat format (e.g., 11).
196 | suffix: The suffix of the filename (default: 'opt').
197 | device: Device to load the model onto ('cuda' or 'cpu').
198 |
199 | Returns:
200 | The quantized model with loaded weights and the epoch number.
201 | """
202 | # Define the filename pattern to search for
203 | checkpoint_pattern = re.compile(f"sf{sf_bits}_.*_epoch(\\d+)_.*{suffix}$")
204 |
205 | # Find all matching checkpoint files
206 | checkpoint_files = [
207 | f for f in os.listdir(".") if checkpoint_pattern.match(f)
208 | ]
209 |
210 | if not checkpoint_files:
211 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.")
212 | return quantize_model(model, sf), 0
213 |
214 | # Extract epoch numbers and sort by latest epoch
215 | epochs_and_files = [
216 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files
217 | ]
218 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0])
219 |
220 | # Load the latest checkpoint
221 | print(f"Loading checkpoint: {latest_checkpoint}")
222 | checkpoint = torch.load(latest_checkpoint, map_location=device)
223 | model.load_state_dict(checkpoint)
224 | model.to(device)
225 |
226 | return model, latest_epoch
227 |
228 | # Pre-training parameter check to ensure they are within range
229 | def check_parameters_in_range(model, sf):
230 | out_of_range_params = []
231 | for name, param in model.named_parameters():
232 | if not torch.all(torch.abs(param.data) <= sf.max_val):
233 | out_of_range_params.append(name)
234 | if out_of_range_params:
235 | print(f"Warning: The following parameters are out of range:")
236 | for param_name in out_of_range_params:
237 | print(f"- {param_name}")
238 | else:
239 | print("All parameters are within the valid range.")
240 |
241 | # Usage
242 | base_model, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device)
243 | quantized = QuantizedLlamaModel(base_model, sf)
244 | print(f"Resuming training from epoch {last_epoch + 1}.")
245 |
246 | # Check if model parameters are within range before training
247 | check_parameters_in_range(quantized, sf)
248 |
249 | del model
250 | if device.type == "cuda":
251 |     torch.cuda.empty_cache()
252 | elif device.type == "mps":
253 |     torch.mps.empty_cache()
254 | gc.collect()
255 |
256 | # Prepare Dataset
257 | def prepare_dataset(tokenizer, max_length=1):
258 | dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
259 | # dataset = Dataset.from_parquet('train.parquet')
260 | def tokenize_function(examples):
261 | return tokenizer(
262 | examples["text"],
263 | truncation=True,
264 | max_length=max_length,
265 | padding="max_length",
266 | return_tensors="pt"
267 | )
268 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
269 | return tokenized_dataset
270 |
271 | # Custom collate function
272 | def collate_fn(batch):
273 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch])
274 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch])
275 | return {'input_ids': input_ids, 'attention_mask': attention_mask}
276 |
277 | # Prepare tokenized dataset and dataloader
278 | tokenized_dataset = prepare_dataset(tokenizer, max_length=max_length)
279 | train_dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
280 |
281 | # Optimizer and Loss
282 | optimizer = torch.optim.Adam(quantized.parameters(), lr=1e-5, eps=1e-4)
283 | loss_fn = torch.nn.CrossEntropyLoss()
284 |
285 | # Training Loop
286 | num_epochs = 3
287 | accumulation_steps = 8 # Number of steps to accumulate gradients
288 | best_loss = float('inf')
289 |
290 | quantized.to(device)
291 |
292 | for epoch in range(num_epochs):
293 | epoch_loss = 0.0
294 | epoch_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}")
295 |
296 | for step, batch in epoch_iterator:
297 | input_ids = batch['input_ids'].to(device)
298 | attention_mask = batch['attention_mask'].to(device)
299 |
300 | # Forward pass
301 | outputs = quantized(input_ids=input_ids, attention_mask=attention_mask)
302 | logits = outputs.logits
303 | target = input_ids[:, 1:].contiguous()
304 | logits = logits[:, :-1].contiguous()
305 |
306 | # Calculate loss
307 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) / accumulation_steps
308 |
309 | # Backward pass with gradient quantization
310 | loss.backward()
311 |
312 | # Accumulate loss for reporting
313 | epoch_loss += loss.item() * accumulation_steps
314 |
315 | if (step + 1) % accumulation_steps == 0:
316 | # Clamp activations and model parameters within the Superfloat range
317 | for name, param in quantized.named_parameters():
318 | param.data = torch.clamp(param.data, min=-sf.max_val, max=sf.max_val)
319 |
320 | # Check activations range
321 | for name, param in quantized.named_parameters():
322 | if not torch.all(torch.abs(param.data) <= sf.max_val):
323 | print(f"Warning: {name} activation is out of range after clamping!")
324 |
325 | torch.nn.utils.clip_grad_value_(quantized.parameters(), clip_value=sf.max_val)
326 | optimizer.step()
327 | optimizer.zero_grad()
328 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"})
329 |
330 | epoch_loss /= len(train_dataloader)
331 |     if epoch_loss < best_loss:
332 |         best_loss = epoch_loss
333 |         torch.save(quantized.state_dict(), f"sf{sf.bits}_{epoch+1}_opt")
334 |     print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}")
--------------------------------------------------------------------------------
/src/wasq/wasq_fpm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
3 | import os
4 | import re
5 | import gc
6 | from torch.utils.data import DataLoader
7 | from datasets import load_dataset, Dataset
8 | from tqdm import tqdm
9 |
10 | class Superfloat:
11 | CASTING_TABLE = {
12 | 16: torch.float32,
13 | 15: torch.float32,
14 | 14: torch.float32,
15 | 13: torch.float32,
16 | 12: torch.float32,
17 | 11: torch.float16,
18 | 10: torch.float16,
19 | 9: torch.float16,
20 | 8: torch.bfloat16,
21 | 7: torch.bfloat16,
22 | 6: torch.bfloat16,
23 | 5: torch.bfloat16,
24 | 4: torch.bfloat16,
25 | }
26 |
27 | def __init__(self, bits: int):
28 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
29 | self.bits = bits
30 | self.mantissa_bits = bits - 1
31 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
32 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth
33 |
34 | def encode(self, value: torch.Tensor) -> torch.Tensor:
35 | """Encodes a tensor of values into the superfloat format."""
36 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
37 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
38 | sign = (clipped_value < 0).to(torch.int32)
39 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32)
40 |
41 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor:
42 | """Decodes a tensor of encoded superfloat values to regular floats."""
43 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
44 | sign = (encoded_value >> self.mantissa_bits) & 1
45 | decoded_value = (mantissa.to(self.float_type) / (2**self.mantissa_bits - 1)) * self.max_val
46 | return decoded_value * (2 * sign - 1)
47 |
48 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
49 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape."""
50 | # Apply element-wise encoding to the entire tensor and then decode back
51 | print(f"Params in layer: {len(tensor)}")
52 | encoded_tensor = self.encode(tensor)
53 | print("Encoding complete")
54 | decoded_tensor = self.decode(encoded_tensor)
55 | print("Decoding complete")
56 | return decoded_tensor
57 |
58 | sf = Superfloat(8)
59 |
60 | class QuantizedLlamaModel(torch.nn.Module):
61 | def __init__(self, base_model: torch.nn.Module, sf_quantizer: Superfloat):
62 | super(QuantizedLlamaModel, self).__init__()
63 | self.base_model = base_model
64 | self.sf_quantizer = sf_quantizer
65 | self.apply_gradient_hooks()
66 |
67 | def apply_gradient_hooks(self):
68 | # Register a hook to quantize gradients after backward pass
69 | for param in self.base_model.parameters():
70 | param.register_hook(lambda grad: self.sf_quantizer.tensor_quantize(grad))
71 |
72 | def forward(self, x):
73 | # Quantize activations and parameters during forward pass
74 | x = self.sf_quantizer.tensor_quantize(x)
75 | for layer in self.base_model.children():
76 | if isinstance(layer, torch.nn.Linear):
77 | layer.weight.data = self.sf_quantizer.tensor_quantize(layer.weight.data)
78 | x = self.sf_quantizer.tensor_quantize(layer(x))
79 | return x
80 |
81 | if torch.backends.mps.is_available():
82 | device = torch.device("mps")
83 | elif torch.cuda.is_available():
84 | device = torch.device("cuda")
85 | else:
86 | device = torch.device("cpu")
87 |
88 | print(f"Using device: {device}")
89 |
90 | # Initialize model and tokenizer
91 | model_name = "meta-llama/Llama-3.2-1B"
92 | model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./', token='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')
93 | model = model.to(sf.float_type).to(device)
94 |
95 | tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./', token='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')
96 |
97 | tokenizer.pad_token = tokenizer.eos_token
98 |
99 | def quantize_model(model, sf_type):
100 | for name, param in model.named_parameters():
101 | print(name, len(param))
102 | quantized_param = sf_type.tensor_quantize(param)
103 | param.data = quantized_param.data
104 | return model
105 |
106 | # Checker function to verify quantization
107 | def check_model_quantization(model, sf_type):
108 | all_parameters_valid = True
109 | for name, param in model.named_parameters():
110 | param_data = param.data
111 | if param_data.dtype != sf_type.float_type:
112 | print(f"Parameter {name} is not in {sf_type.float_type} format!")
113 | all_parameters_valid = False
114 | if not torch.all((param_data >= -sf_type.max_val) & (param_data <= sf_type.max_val)):
115 | print(f"Parameter {name} has values outside the SF{sf_type.bits} range!")
116 | all_parameters_valid = False
117 | return all_parameters_valid
118 |
119 | def load_checkpoint(model, sf_bits, suffix="fpm", device=device):
120 | """
121 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix.
122 |
123 | Args:
124 |         model: The model to load the checkpoint into.
125 | sf_bits: Bitwidth of the Superfloat format (e.g., 11).
126 | suffix: The suffix of the filename (default: 'fpm').
127 | device: Device to load the model onto ('cuda' or 'cpu').
128 |
129 | Returns:
130 | The quantized model with loaded weights and the epoch number.
131 | """
132 |     # Match checkpoints saved by this script as sf{bits}_{epoch}_{suffix}
133 |     checkpoint_pattern = re.compile(f"sf{sf_bits}_(\\d+)_{suffix}$")
134 |
135 | # Find all matching checkpoint files
136 | checkpoint_files = [
137 | f for f in os.listdir(".") if checkpoint_pattern.match(f)
138 | ]
139 |
140 | if not checkpoint_files:
141 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.")
142 | return quantize_model(model, sf), 0
143 |
144 | # Extract epoch numbers and sort by latest epoch
145 | epochs_and_files = [
146 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files
147 | ]
148 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0])
149 |
150 | # Load the latest checkpoint
151 | print(f"Loading checkpoint: {latest_checkpoint}")
152 | checkpoint = torch.load(latest_checkpoint, map_location=device)
153 | model.load_state_dict(checkpoint)
154 | model.to(device)
155 |
156 | return model, latest_epoch
157 |
158 | # Usage
159 | quantized, last_epoch = load_checkpoint(model, sf.bits, suffix="fpm", device=device)
160 | print(f"Resuming training from epoch {last_epoch + 1}.")
161 |
162 | del model
163 | if device.type == 'cuda':
164 |     # Clear CUDA cache to free up memory
165 |     torch.cuda.empty_cache()
166 | elif device.type == 'mps':
167 |     # Clear MPS cache to free up memory
168 |     torch.mps.empty_cache()
169 | gc.collect()
170 |
171 | # Prepare Dataset
172 | def prepare_dataset(tokenizer, max_length=512):
173 | """Prepare the dataset with proper tensor formatting."""
174 | dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
175 | # dataset = Dataset.from_parquet('train.parquet')
176 | def tokenize_function(examples):
177 | outputs = tokenizer(
178 | examples["text"],
179 | truncation=True,
180 | max_length=max_length,
181 | padding="max_length",
182 | return_tensors="pt"
183 | )
184 | return outputs
185 |
186 | tokenized_dataset = dataset.map(
187 | tokenize_function,
188 | batched=True,
189 | remove_columns=dataset.column_names
190 | )
191 |
192 | return tokenized_dataset
193 |
194 | # Custom collate function
195 | def collate_fn(batch):
196 | """Custom collate function to properly format tensors."""
197 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch])
198 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch])
199 | return {
200 | 'input_ids': input_ids,
201 | 'attention_mask': attention_mask
202 | }
203 |
204 | # Optimizer and Loss
205 | optimizer = torch.optim.Adam(quantized.parameters(), lr=1e-5, eps=1e-4)
206 | loss_fn = torch.nn.CrossEntropyLoss()
207 |
208 | # Prepare tokenized dataset and dataloader
209 | tokenized_dataset = prepare_dataset(tokenizer)
210 | train_dataloader = DataLoader(tokenized_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
211 |
212 | num_epochs = 3
213 | accumulation_steps = 8 # Number of steps to accumulate gradients
214 | best_loss = float('inf')
215 |
216 | quantized.to(device)
217 |
218 | # Training Loop with Autoregressive Target, Gradient Accumulation, and Progress Tracking
219 | for epoch in range(num_epochs):
220 | epoch_loss = 0.0 # Track total loss for each epoch
221 |
222 | # Initialize tqdm for tracking epoch progress
223 | epoch_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}")
224 |
225 | for step, batch in epoch_iterator:
226 | input_ids = batch['input_ids'].to(device)
227 | attention_mask = batch['attention_mask'].to(device)
228 |
229 | # Forward pass
230 | outputs = quantized(input_ids=input_ids, attention_mask=attention_mask)
231 |
232 | # Access the logits for token predictions
233 | logits = outputs.logits # Retrieve logits tensor from ModelOutput
234 |
235 | # Shift input_ids by one for autoregressive target
236 | target = input_ids[:, 1:].contiguous() # Target is the input shifted by one token
237 | logits = logits[:, :-1].contiguous() # Align logits with target length
238 |
239 | # Calculate loss
240 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
241 |
242 | # Divide loss by accumulation steps
243 | loss = loss / accumulation_steps
244 |
245 | # Backward pass with gradient quantization
246 | loss.backward() # Gradient quantization occurs via hook
247 |
248 | # Accumulate the loss for reporting
249 | epoch_loss += loss.item() * accumulation_steps
250 |
251 | # Perform optimizer step after accumulating gradients
252 | if (step + 1) % accumulation_steps == 0:
253 | torch.nn.utils.clip_grad_value_(quantized.parameters(), clip_value = sf.max_val)
254 | optimizer.step()
255 | optimizer.zero_grad() # Clear gradients for next accumulation
256 |
257 | # Update tqdm progress bar with current step loss
258 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"})
259 |
260 | # Average epoch loss
261 | epoch_loss /= len(train_dataloader)
262 | if epoch_loss < best_loss:
263 | torch.save(quantized.state_dict(), f"sf{sf.bits}_{epoch+1}_fpm")
264 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}")
--------------------------------------------------------------------------------
/src/wasq/wasq_inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
3 |
4 | # Device setup
5 | if torch.backends.mps.is_available():
6 | device = torch.device("mps")
7 | elif torch.cuda.is_available():
8 | device = torch.device("cuda")
9 | else:
10 | device = torch.device("cpu")
11 |
12 | print(f"Using device: {device}")
13 |
14 | # Define Superfloat quantizer for clamping activations
15 | class Superfloat:
16 | def __init__(self, bits: int):
17 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
18 | self.bits = bits
19 | self.mantissa_bits = bits - 1
20 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
21 |
22 | def encode(self, value: torch.Tensor) -> torch.Tensor:
23 | """Encodes a tensor of values into the superfloat format with optimized operations."""
24 | # Clip tensor values to the valid range for SFx
25 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
26 |
27 | # Calculate mantissa representation element-wise
28 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
29 |
30 | # Create the superfloat representation (1 bit for sign and mantissa bits)
31 | sign = (clipped_value < 0).to(torch.int32)
32 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32)
33 |
34 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor:
35 | """Decodes a tensor of encoded superfloat values to regular floats."""
36 | # Extract mantissa and sign from the encoded superfloat
37 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
38 | sign = (encoded_value >> self.mantissa_bits) & 1
39 |
40 | # Calculate the decoded float using the mantissa and max_val
41 | decoded_value = (mantissa.to(torch.bfloat16) / (2**self.mantissa_bits - 1)) * self.max_val
42 | return decoded_value * (2 * sign - 1) # Apply the sign
43 |
44 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
45 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape."""
46 | # Apply element-wise encoding to the entire tensor and then decode back
47 | encoded_tensor = self.encode(tensor)
48 | decoded_tensor = self.decode(encoded_tensor)
49 | return decoded_tensor
50 |
51 | # Initialize Superfloat quantizer for clamping
52 | sf = Superfloat(8)
53 |
54 | # Load model in bfloat16 directly for inference
55 | model_name = "meta-llama/Llama-3.2-1B"
56 | model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./', token='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')
57 | model.load_state_dict(torch.load(f"sf{sf.bits}_trained_epoch3", map_location=device))
58 | model = model.to(torch.bfloat16).to(device)
59 | model.eval() # Ensure model is in inference mode
60 |
61 | # Load tokenizer
62 | tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./', token='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')
63 | tokenizer.pad_token = tokenizer.eos_token
64 |
65 | def quantized_inference(model, tokenizer, prompt, max_length=500):
66 | """Runs inference on a prompt with activation quantization using Superfloat."""
67 | # Encode input prompt
68 | inputs = tokenizer(prompt, return_tensors="pt", padding=True)
69 | input_ids = inputs.input_ids.to(device)
70 | attention_mask = inputs.attention_mask.to(device)
71 |
72 | with torch.no_grad():
73 | # Perform generation with clamped activations
74 | outputs = model.generate(
75 | input_ids=input_ids,
76 | attention_mask=attention_mask,
77 | max_length=max_length,
78 | do_sample=True,
79 | top_k=50,
80 | top_p=0.95,
81 | temperature=0.7,
82 | pad_token_id=tokenizer.eos_token_id
83 | )
84 |
85 | # Decode generated output
86 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
87 | return generated_text
88 |
89 | # Example usage
90 | prompt = "It could be one of those nights, where we don't turn off the lights."
91 | generated_text = quantized_inference(model, tokenizer, prompt)
92 | print("Generated text:", generated_text)
--------------------------------------------------------------------------------
/src/wasq/wasq_lth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from transformers import AutoModelForCausalLM, AutoTokenizer
5 | from torch.utils.data import DataLoader
6 | from datasets import Dataset
7 | from tqdm import tqdm
8 | import copy
9 | import numpy as np
10 |
11 | class Superfloat:
12 | CASTING_TABLE = {
13 | 16: torch.float32,
14 | 15: torch.float32,
15 | 14: torch.float32,
16 | 13: torch.float32,
17 | 12: torch.float32,
18 | 11: torch.float16,
19 | 10: torch.float16,
20 | 9: torch.float16,
21 | 8: torch.bfloat16,
22 | 7: torch.bfloat16,
23 | 6: torch.bfloat16,
24 | 5: torch.bfloat16,
25 | 4: torch.bfloat16,
26 | }
27 |
28 | def __init__(self, bits: int):
29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
30 | self.bits = bits
31 | self.mantissa_bits = bits - 1
32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
33 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth
34 |
35 |     def quantize(self, tensor: torch.Tensor) -> torch.Tensor:
36 |         """Quantizes a tensor to Superfloat format (round-trip quantize/dequantize)."""
37 |         # Per-channel scaling for dynamic range
38 |         scale = self.max_val / tensor.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-8)
39 |         quantized = torch.round(torch.clamp(tensor * scale, -self.max_val, self.max_val) * (2**self.mantissa_bits - 1) / self.max_val)
40 |         return quantized * self.max_val / (2**self.mantissa_bits - 1) / scale  # Dequantize for inference
41 |
42 | if torch.backends.mps.is_available():
43 | device = torch.device("mps")
44 | elif torch.cuda.is_available():
45 | device = torch.device("cuda")
46 | else:
47 | device = torch.device("cpu")
48 |
49 | print(f"Using device: {device}")
50 |
51 | class LotteryTicketTrainer:
52 | def __init__(self, model, sf_quantizer, tokenizer, config):
53 | self.device = device
54 | self.sf_quantizer = sf_quantizer
55 | self.model = model.to(device=self.device, dtype=sf_quantizer.float_type)
56 | self.tokenizer = tokenizer
57 | self.config = config
58 | self.original_model_state = copy.deepcopy(self.model.state_dict())
59 | self.winning_tickets = {}
60 | self.pruning_rate = config.get('pruning_rate', 0.2)
61 | self.pruning_iterations = config.get('pruning_iterations', 3)
62 | self.optimizer = optim.Adam(self.model.parameters(), lr=config.get('learning_rate', 1e-5), eps=config.get('optimizer_eps', 1e-4))
63 | self.loss_fn = nn.CrossEntropyLoss()
64 |
65 | def prepare_dataset(self, max_length=512):
66 | dataset = Dataset.from_parquet('train.parquet')
67 |
68 | def tokenize_function(examples):
69 | return self.tokenizer(
70 | examples["text"],
71 | truncation=True,
72 | max_length=max_length,
73 | padding="max_length",
74 | return_tensors="pt"
75 | )
76 |
77 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
78 | return tokenized_dataset
79 |
80 | def create_dataloader(self, dataset, batch_size=4):
81 | def collate_fn(batch):
82 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch])
83 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch])
84 | return {'input_ids': input_ids, 'attention_mask': attention_mask}
85 |
86 | return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
87 |
88 | def magnitude_based_pruning(self):
89 | pruning_masks = {}
90 |
91 | for name, param in self.model.named_parameters():
92 | if len(param.shape) > 1: # Only prune weight matrices
93 | weight_abs = torch.abs(param.data)
94 | flat_weights = weight_abs.view(-1)
95 | k = int(flat_weights.numel() * self.pruning_rate)
96 | threshold = torch.topk(flat_weights, k, largest=False).values.max()
97 | mask = (weight_abs > threshold).float()
98 | pruning_masks[name] = mask
99 | param.data *= mask
100 |
101 | return pruning_masks
102 |
103 | def reset_to_winning_ticket(self, pruning_masks):
104 | for name, param in self.model.named_parameters():
105 | if name in pruning_masks:
106 | # Reset to original initialization, then apply mask
107 | param.data.copy_(self.original_model_state[name])
108 | param.data *= pruning_masks[name]
109 |
110 | def activation_magnitude_analysis(self):
111 | # Compare activations between original and quantized models
112 | with torch.no_grad():
113 | original_activations = self.get_activations(self.original_model_state)
114 | quantized_activations = self.get_activations(self.model.state_dict())
115 | return self.compute_layerwise_differences(original_activations, quantized_activations)
116 |
117 | def get_activations(self, model_state):
118 | # Placeholder: Implement forward pass to collect activations
119 | activations = {}
120 | for name, param in model_state.items():
121 | if len(param.shape) > 1:
122 | activations[name] = torch.mean(torch.abs(param)).item()
123 | return activations
124 |
125 | def compute_layerwise_differences(self, original_activations, quantized_activations):
126 | differences = {}
127 | for name in original_activations:
128 | differences[name] = abs(original_activations[name] - quantized_activations[name])
129 | return differences
130 |
131 | def fine_tune_based_on_activations(self, layer_activation_changes):
132 | # Fine-tune layers with significant activation change
133 | for layer, change in layer_activation_changes.items():
134 | if change > self.config.get('activation_threshold', 0.1):
135 | # Fine-tune or adjust this layer specifically
136 | pass # Fine-tune layer weights based on magnitude analysis
137 |
138 | def train(self):
139 | tokenized_dataset = self.prepare_dataset()
140 | dataloader = self.create_dataloader(tokenized_dataset)
141 |
142 | num_epochs = self.config.get('num_epochs', 3)
143 | accumulation_steps = self.config.get('accumulation_steps', 32)
144 | best_loss = float('inf')
145 |
146 | for iteration in range(self.pruning_iterations):
147 | print(f"\nPruning Iteration {iteration + 1}/{self.pruning_iterations}")
148 |
149 | for epoch in range(num_epochs):
150 | self.model.train()
151 | epoch_loss = 0.0
152 |
153 | epoch_iterator = tqdm(
154 | enumerate(dataloader),
155 | total=len(dataloader),
156 | desc=f"Iteration {iteration + 1}, Epoch {epoch + 1}"
157 | )
158 |
159 | for step, batch in epoch_iterator:
160 | input_ids = batch['input_ids'].to(device=self.device, dtype=torch.long)
161 | attention_mask = batch['attention_mask'].to(device=self.device, dtype=torch.long)
162 |
163 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
164 | logits = outputs.logits
165 |
166 | target = input_ids[:, 1:].contiguous()
167 | logits = logits[:, :-1].contiguous()
168 |
169 | loss = self.loss_fn(
170 | logits.view(-1, logits.size(-1)),
171 | target.view(-1)
172 | ) / accumulation_steps
173 |
174 | loss.backward()
175 | epoch_loss += loss.item() * accumulation_steps
176 |
177 | if (step + 1) % accumulation_steps == 0:
178 | for param in self.model.parameters():
179 | param.data = torch.clamp(
180 | param.data,
181 | min=-self.sf_quantizer.max_val,
182 | max=self.sf_quantizer.max_val
183 | )
184 |
185 | torch.nn.utils.clip_grad_value_(
186 | self.model.parameters(),
187 | clip_value=self.sf_quantizer.max_val
188 | )
189 |
190 | self.optimizer.step()
191 | self.optimizer.zero_grad()
192 |
193 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"})
194 |
195 | epoch_loss /= len(dataloader)
196 | print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f}")
197 |
198 | if epoch_loss < best_loss:
199 | best_loss = epoch_loss
200 | torch.save(
201 | self.model.state_dict(),
202 | f"sf{self.sf_quantizer.bits}_iteration{iteration+1}_epoch{epoch+1}_best.pth"
203 | )
204 |
205 | pruning_masks = self.magnitude_based_pruning()
206 | self.reset_to_winning_ticket(pruning_masks)
207 |
208 | # After pruning, perform activation analysis and fine-tuning
209 | layer_activation_changes = self.activation_magnitude_analysis()
210 | self.fine_tune_based_on_activations(layer_activation_changes)
211 |
212 | torch.save(self.model.state_dict(), f"sf{self.sf_quantizer.bits}_winning_ticket_iteration{iteration+1}.pth")
213 |
214 | def main():
215 | # Load the pre-trained model and tokenizer
216 | model_name = "gpt2"
217 | model = AutoModelForCausalLM.from_pretrained(model_name)
218 | tokenizer = AutoTokenizer.from_pretrained(model_name)
219 |
220 | # Initialize the Superfloat quantizer
221 | sf_quantizer = Superfloat(bits=11) # You can experiment with different bit-widths
222 |
223 | # Configuration settings
224 | config = {
225 | "pruning_rate": 0.2,
226 | "pruning_iterations": 3,
227 | "learning_rate": 1e-5,
228 | "optimizer_eps": 1e-4,
229 | "num_epochs": 3,
230 | "accumulation_steps": 32,
231 | "activation_threshold": 0.1 # You can adjust this threshold
232 | }
233 |
234 | # Instantiate the trainer
235 | trainer = LotteryTicketTrainer(model, sf_quantizer, tokenizer, config)
236 |
237 | # Train the model with LTH + Superfloat quantization
238 | trainer.train()
239 |
240 | if __name__ == "__main__":
241 | main()
--------------------------------------------------------------------------------
/src/wasq/wasq_mplth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | from transformers import AutoModelForCausalLM, AutoTokenizer
5 | from torch.utils.data import DataLoader
6 | from datasets import Dataset
7 | from tqdm import tqdm
8 | import copy
9 | import numpy as np
10 |
11 | class Superfloat:
12 | CASTING_TABLE = {
13 | 16: torch.float32,
14 | 15: torch.float32,
15 | 14: torch.float32,
16 | 13: torch.float32,
17 | 12: torch.float32,
18 | 11: torch.float16,
19 | 10: torch.float16,
20 | 9: torch.float16,
21 | 8: torch.bfloat16,
22 | 7: torch.bfloat16,
23 | 6: torch.bfloat16,
24 | 5: torch.bfloat16,
25 | 4: torch.bfloat16,
26 | }
27 |
28 | def __init__(self, bits: int):
29 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
30 | self.bits = bits
31 | self.mantissa_bits = bits - 1
32 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
33 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth
34 |
35 |     def quantize(self, tensor: torch.Tensor) -> torch.Tensor:
36 |         """Quantizes a tensor to Superfloat format (round-trip quantize/dequantize)."""
37 |         scale = self.max_val / tensor.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-8)
38 |         quantized = torch.round(torch.clamp(tensor * scale, -self.max_val, self.max_val) * (2**self.mantissa_bits - 1) / self.max_val)
39 |         return quantized * self.max_val / (2**self.mantissa_bits - 1) / scale  # Dequantize for inference
40 |
41 | if torch.backends.mps.is_available():
42 | device = torch.device("mps")
43 | elif torch.cuda.is_available():
44 | device = torch.device("cuda")
45 | else:
46 | device = torch.device("cpu")
47 |
48 | print(f"Using device: {device}")
49 |
50 | class MultiPrizeLotteryTicketTrainer:
51 | def __init__(self, model, sf_quantizer, tokenizer, config):
52 | self.device = device
53 | self.sf_quantizer = sf_quantizer
54 | self.model = model.to(device=self.device, dtype=sf_quantizer.float_type)
55 | self.tokenizer = tokenizer
56 | self.config = config
57 | self.original_model_state = copy.deepcopy(self.model.state_dict())
58 | self.winning_tickets = {} # Store multiple winning tickets
59 | self.pruning_rates = config.get('pruning_rates', [0.1, 0.2, 0.3]) # Multiple pruning rates
60 | self.pruning_iterations = config.get('pruning_iterations', 3)
61 | self.optimizer = optim.Adam(self.model.parameters(), lr=config.get('learning_rate', 1e-5), eps=config.get('optimizer_eps', 1e-4))
62 | self.loss_fn = nn.CrossEntropyLoss()
63 |
64 | def prepare_dataset(self, max_length=512):
65 | dataset = Dataset.from_parquet('train.parquet')
66 |
67 | def tokenize_function(examples):
68 | return self.tokenizer(
69 | examples["text"],
70 | truncation=True,
71 | max_length=max_length,
72 | padding="max_length",
73 | return_tensors="pt"
74 | )
75 |
76 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
77 | return tokenized_dataset
78 |
79 | def create_dataloader(self, dataset, batch_size=4):
80 | def collate_fn(batch):
81 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch])
82 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch])
83 | return {'input_ids': input_ids, 'attention_mask': attention_mask}
84 |
85 | return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
86 |
87 | def magnitude_based_pruning(self, pruning_rate):
88 | pruning_masks = {}
89 |
90 | for name, param in self.model.named_parameters():
91 | if len(param.shape) > 1: # Only prune weight matrices
92 | weight_abs = torch.abs(param.data)
93 | flat_weights = weight_abs.view(-1)
94 | k = int(flat_weights.numel() * pruning_rate)
95 | threshold = torch.topk(flat_weights, k, largest=False).values.max()
96 | mask = (weight_abs > threshold).float()
97 | pruning_masks[name] = mask
98 | param.data *= mask
99 |
100 | return pruning_masks
101 |
102 | def reset_to_winning_ticket(self, pruning_masks):
103 | for name, param in self.model.named_parameters():
104 | if name in pruning_masks:
105 | param.data.copy_(self.original_model_state[name])
106 | param.data *= pruning_masks[name]
107 |
108 | def fine_tune(self, dataloader, num_epochs=3):
109 | best_loss = float('inf')
110 |
111 | for epoch in range(num_epochs):
112 | self.model.train()
113 | epoch_loss = 0.0
114 |
115 |             epoch_iterator = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch + 1}")
116 |
117 | for step, batch in epoch_iterator:
118 | input_ids = batch['input_ids'].to(device=self.device, dtype=torch.long)
119 | attention_mask = batch['attention_mask'].to(device=self.device, dtype=torch.long)
120 |
121 | outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
122 | logits = outputs.logits
123 |
124 | target = input_ids[:, 1:].contiguous()
125 | logits = logits[:, :-1].contiguous()
126 |
127 | loss = self.loss_fn(
128 | logits.view(-1, logits.size(-1)),
129 | target.view(-1)
130 | )
131 |
132 | loss.backward()
133 | epoch_loss += loss.item()
134 |
135 | if (step + 1) % self.config.get('accumulation_steps', 32) == 0:
136 | self.optimizer.step()
137 | self.optimizer.zero_grad()
138 |
139 | epoch_iterator.set_postfix({"Loss": f"{loss.item():.4f}"})
140 |
141 | epoch_loss /= len(dataloader)
142 | print(f"Epoch {epoch + 1} Loss: {epoch_loss:.4f}")
143 |
144 | if epoch_loss < best_loss:
145 | best_loss = epoch_loss
146 | torch.save(self.model.state_dict(), f"best_model_epoch_{epoch+1}.pth")
147 |
148 | def train(self):
149 | tokenized_dataset = self.prepare_dataset()
150 | dataloader = self.create_dataloader(tokenized_dataset)
151 |
152 | for iteration in range(self.pruning_iterations):
153 | print(f"\nPruning Iteration {iteration + 1}/{self.pruning_iterations}")
154 |
155 | for pruning_rate in self.pruning_rates:
156 | print(f"Applying pruning with rate {pruning_rate}...")
157 |
158 | pruning_masks = self.magnitude_based_pruning(pruning_rate)
159 | self.reset_to_winning_ticket(pruning_masks)
160 |
161 | # Fine-tune with the current pruning rate
162 | self.fine_tune(dataloader)
163 |
164 | # Save the "winning ticket" for this pruning rate
165 |                 self.winning_tickets[pruning_rate] = copy.deepcopy(self.model.state_dict())  # Snapshot, not a live reference
166 | torch.save(self.model.state_dict(), f"sf{self.sf_quantizer.bits}_winning_ticket_rate{pruning_rate}.pth")
167 |
168 | # Optionally, after all pruning iterations, you could apply activation magnitude analysis and fine-tune further
169 |
170 | def main():
171 | model_name = "gpt2"
172 | model = AutoModelForCausalLM.from_pretrained(model_name)
173 | tokenizer = AutoTokenizer.from_pretrained(model_name)
174 |
175 | # Initialize the Superfloat quantizer
176 | sf_quantizer = Superfloat(bits=11)
177 |
178 | # Configuration settings
179 | config = {
180 | "pruning_rates": [0.1, 0.2, 0.3], # Different pruning rates
181 | "pruning_iterations": 3,
182 | "learning_rate": 1e-5,
183 | "optimizer_eps": 1e-4,
184 | "num_epochs": 3,
185 | "accumulation_steps": 32,
186 | "activation_threshold": 0.1 # Optional: You can adjust this threshold
187 | }
188 |
189 | trainer = MultiPrizeLotteryTicketTrainer(model, sf_quantizer, tokenizer, config)
190 | trainer.train()
191 |
192 | if __name__ == "__main__":
193 | main()
--------------------------------------------------------------------------------
/src/wasq/wasq_opt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
4 | from torch.utils.data import DataLoader
5 | from datasets import load_dataset, Dataset
6 | from tqdm import tqdm
7 | import gc
8 |
9 | class Superfloat:
10 | CASTING_TABLE = {
11 | 16: torch.float32,
12 | 15: torch.float32,
13 | 14: torch.float32,
14 | 13: torch.float32,
15 | 12: torch.float32,
16 | 11: torch.float16,
17 | 10: torch.float16,
18 | 9: torch.float16,
19 | 8: torch.bfloat16,
20 | 7: torch.bfloat16,
21 | 6: torch.bfloat16,
22 | 5: torch.bfloat16,
23 | 4: torch.bfloat16,
24 | }
25 |
26 | def __init__(self, bits: int):
27 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
28 | self.bits = bits
29 | self.mantissa_bits = bits - 1
30 | self.max_val = 1 - 2**-self.mantissa_bits # Precompute max representable value
31 | self.float_type = self.CASTING_TABLE[bits] # Get float type based on bitwidth
32 |
33 |     def encode(self, value: torch.Tensor) -> tuple:
34 | """Encodes a tensor of values into the superfloat format."""
35 | clipped_value = torch.clamp(value, min=-self.max_val, max=self.max_val)
36 | out_of_range = (value.abs() > self.max_val)
37 | mantissa = (torch.abs(clipped_value) * (2**self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
38 | sign = (clipped_value < 0).to(torch.int32)
39 | return (mantissa | (sign << self.mantissa_bits)).to(torch.int32), out_of_range
40 |
41 | def decode(self, encoded_value: torch.Tensor) -> torch.Tensor:
42 | """Decodes a tensor of encoded superfloat values to regular floats."""
43 | mantissa = encoded_value & ((1 << self.mantissa_bits) - 1)
44 | sign = (encoded_value >> self.mantissa_bits) & 1
45 | decoded_value = (mantissa.to(self.float_type) / (2**self.mantissa_bits - 1)) * self.max_val
46 | return decoded_value * (2 * sign - 1)
47 |
48 |     def tensor_quantize(self, tensor: torch.Tensor) -> tuple:
49 | """Quantizes a tensor to the superfloat format, preserving the tensor's shape."""
50 | encoded_tensor, out_of_range = self.encode(tensor)
51 | decoded_tensor = self.decode(encoded_tensor)
52 | return decoded_tensor, out_of_range
53 |
54 | sf = Superfloat(11)
55 |
56 | class QuantizedLlamaModel(torch.nn.Module):
57 | def __init__(self, base_model: torch.nn.Module, sf_quantizer: Superfloat):
58 | super(QuantizedLlamaModel, self).__init__()
59 | self.base_model = base_model
60 | self.sf_quantizer = sf_quantizer
61 | self.apply_gradient_hooks()
62 |
63 | def apply_gradient_hooks(self):
64 | for param in self.base_model.parameters():
65 | def hook(grad, param=param):
66 | _, out_of_range = self.sf_quantizer.tensor_quantize(param)
67 | grad = grad * out_of_range.to(grad.dtype) # Mask to allow gradients only on out-of-range params
68 | return grad
69 | param.register_hook(hook)
70 |
71 | def forward(self, x):
72 | x, _ = self.sf_quantizer.tensor_quantize(x)
73 | for layer in self.base_model.children():
74 | if isinstance(layer, torch.nn.Linear):
75 | layer.weight.data, _ = self.sf_quantizer.tensor_quantize(layer.weight.data)
76 | x = layer(x)
77 | x, _ = self.sf_quantizer.tensor_quantize(x)
78 | return x
79 |
80 | # Initialize model and tokenizer
81 | if torch.backends.mps.is_available():
82 | device = torch.device("mps")
83 | elif torch.cuda.is_available():
84 | device = torch.device("cuda")
85 | else:
86 | device = torch.device("cpu")
87 |
88 | print(f"Using device: {device}")
89 |
90 | model_name = "meta-llama/Llama-3.2-1B"
91 | model = LlamaForCausalLM.from_pretrained(model_name, cache_dir='./', token='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')
92 | model = model.to(sf.float_type).to(device)
93 |
94 | tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, cache_dir='./', token='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')
95 | tokenizer.pad_token = tokenizer.eos_token
96 |
97 | # Quantize Model Weights Selectively
98 | def quantize_model(model, sf_type):
99 | for name, param in model.named_parameters():
100 | quantized_param, _ = sf_type.tensor_quantize(param)
101 | param.data = quantized_param.data
102 | return model
103 |
104 | import os
105 | import re
106 |
107 | def load_checkpoint(model, sf_bits, suffix="opt", device=device):
108 | """
109 | Load the latest checkpoint based on the provided Superfloat bitwidth and filename suffix.
110 |
111 | Args:
112 |         model: The model to load the checkpoint into.
113 | sf_bits: Bitwidth of the Superfloat format (e.g., 11).
114 | suffix: The suffix of the filename (default: 'opt').
115 | device: Device to load the model onto ('cuda' or 'cpu').
116 |
117 | Returns:
118 | The quantized model with loaded weights and the epoch number.
119 | """
120 |     # Match checkpoints saved by this script as sf{bits}_{epoch}_{suffix}
121 |     checkpoint_pattern = re.compile(f"sf{sf_bits}_(\\d+)_{suffix}$")
122 |
123 | # Find all matching checkpoint files
124 | checkpoint_files = [
125 | f for f in os.listdir(".") if checkpoint_pattern.match(f)
126 | ]
127 |
128 | if not checkpoint_files:
129 | print(f"No checkpoints found for sf{sf_bits} with suffix '{suffix}'.")
130 | return quantize_model(model, sf), 0
131 |
132 | # Extract epoch numbers and sort by latest epoch
133 | epochs_and_files = [
134 | (int(checkpoint_pattern.match(f).group(1)), f) for f in checkpoint_files
135 | ]
136 | latest_epoch, latest_checkpoint = max(epochs_and_files, key=lambda x: x[0])
137 |
138 | # Load the latest checkpoint
139 | print(f"Loading checkpoint: {latest_checkpoint}")
140 | checkpoint = torch.load(latest_checkpoint, map_location=device)
141 | model.load_state_dict(checkpoint)
142 | model.to(device)
143 |
144 | return model, latest_epoch
145 |
146 | # Usage
147 | quantized, last_epoch = load_checkpoint(model, sf.bits, suffix="opt", device=device)
148 | print(f"Resuming training from epoch {last_epoch + 1}.")
149 |
150 | del model
151 | if device.type == "cuda":
152 |     torch.cuda.empty_cache()
153 | elif device.type == "mps":
154 |     torch.mps.empty_cache()
155 | gc.collect()
156 |
157 | # Prepare Dataset
158 | def prepare_dataset(tokenizer, max_length=1024):
159 | dataset = Dataset.from_parquet('train.parquet')
160 | def tokenize_function(examples):
161 | return tokenizer(
162 | examples["text"],
163 | truncation=True,
164 | max_length=max_length,
165 | padding="max_length",
166 | return_tensors="pt"
167 | )
168 | tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
169 | return tokenized_dataset
170 |
171 | # Custom collate function
172 | def collate_fn(batch):
173 | input_ids = torch.stack([torch.tensor(example['input_ids']) for example in batch])
174 | attention_mask = torch.stack([torch.tensor(example['attention_mask']) for example in batch])
175 | return {'input_ids': input_ids, 'attention_mask': attention_mask}
176 |
177 | # Prepare tokenized dataset and dataloader
178 | tokenized_dataset = prepare_dataset(tokenizer)
179 | train_dataloader = DataLoader(tokenized_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
180 |
181 | # Optimizer and Loss
182 | optimizer = torch.optim.Adam(quantized.parameters(), lr=1e-5, eps=1e-4)
183 | loss_fn = torch.nn.CrossEntropyLoss()
184 |
185 | # Training Loop
186 | num_epochs = 3
187 | accumulation_steps = 32 # Number of steps to accumulate gradients
188 | best_loss = float('inf')
189 |
190 | quantized.to(device)
191 |
192 | for epoch in range(num_epochs):
193 | epoch_loss = 0.0
194 | epoch_iterator = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}")
195 |
196 | for step, batch in epoch_iterator:
197 | input_ids = batch['input_ids'].to(device)
198 | attention_mask = batch['attention_mask'].to(device)
199 |
200 | # Forward pass
201 | outputs = quantized(input_ids=input_ids, attention_mask=attention_mask)
202 | logits = outputs.logits
203 | target = input_ids[:, 1:].contiguous()
204 | logits = logits[:, :-1].contiguous()
205 |
206 | # Calculate loss
207 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1)) / accumulation_steps
208 |
209 | # Backward pass with gradient quantization
210 | loss.backward()
211 |
212 | # Accumulate loss for reporting
213 | epoch_loss += loss.item() * accumulation_steps
214 |
215 | if (step + 1) % accumulation_steps == 0:
216 | torch.nn.utils.clip_grad_value_(quantized.parameters(), clip_value = sf.max_val)
217 | optimizer.step()
218 | optimizer.zero_grad()
219 | epoch_iterator.set_postfix({"Loss": f"{loss.item() * accumulation_steps:.4f}"})
220 |
221 | epoch_loss /= len(train_dataloader)
222 | if epoch_loss < best_loss:
223 | torch.save(quantized.state_dict(), f"sf{sf.bits}_{epoch+1}_opt")
224 | print(f"Epoch {epoch + 1} completed with average loss: {epoch_loss:.4f}")
--------------------------------------------------------------------------------
/src/wasq/wasq_vanilla.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Function
5 | from transformers import AutoModelForCausalLM
6 |
7 |
8 | class Superfloat:
9 | """Simple Superfloat quantizer with encode/decode utilities."""
10 |
11 | CASTING_TABLE = {
12 | 16: torch.float32,
13 | 15: torch.float32,
14 | 14: torch.float32,
15 | 13: torch.float32,
16 | 12: torch.float32,
17 | 11: torch.float16,
18 | 10: torch.float16,
19 | 9: torch.float16,
20 | 8: torch.bfloat16,
21 | 7: torch.bfloat16,
22 | 6: torch.bfloat16,
23 | 5: torch.bfloat16,
24 | 4: torch.bfloat16,
25 | }
26 |
27 | def __init__(self, bits: int):
28 | assert 4 <= bits <= 16, "Superfloat bitwidth must be between 4 and 16."
29 | self.bits = bits
30 | self.mantissa_bits = bits - 1
31 | self.max_val = 1 - 2 ** -self.mantissa_bits
32 | self.float_type = self.CASTING_TABLE[bits]
33 |
34 | def encode(self, value: torch.Tensor):
35 | clipped = torch.clamp(value, min=-self.max_val, max=self.max_val)
36 | mantissa = (torch.abs(clipped) * (2 ** self.mantissa_bits - 1) / self.max_val).floor().to(torch.int32)
37 | sign = (clipped < 0).to(torch.int32)
38 | encoded = (mantissa | (sign << self.mantissa_bits)).to(torch.int32)
39 | out_of_range = (value.abs() > self.max_val)
40 | return encoded, out_of_range
41 |
42 | def decode(self, encoded: torch.Tensor) -> torch.Tensor:
43 | mantissa = encoded & ((1 << self.mantissa_bits) - 1)
44 | sign = (encoded >> self.mantissa_bits) & 1
45 | decoded = (mantissa.to(self.float_type) / (2 ** self.mantissa_bits - 1)) * self.max_val
46 | return decoded * (2 * sign - 1)
47 |
48 | def tensor_quantize(self, tensor: torch.Tensor) -> torch.Tensor:
49 | enc, _ = self.encode(tensor)
50 | return self.decode(enc)
51 |
52 |
53 | class SFQuant(Function):
54 | """Straight-through estimator for Superfloat quantization."""
55 |
56 | @staticmethod
57 | def forward(ctx, input: torch.Tensor, sf: "Superfloat"):
58 | encoded, mask = sf.encode(input)
59 | ctx.save_for_backward(mask)
60 | ctx.sf = sf
61 | return sf.decode(encoded)
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | (mask,) = ctx.saved_tensors
66 | # Pass gradients only where values were in range
67 | return grad_output * mask.to(grad_output.dtype), None
68 |
69 |
70 | class QuantizedLinear(nn.Linear):
71 | """Linear layer with on-the-fly Superfloat decode and optional LSQ+ scale."""
72 |
73 | def __init__(self, in_features, out_features, sf: Superfloat, bias=True, k_outlier=0.005):
74 | super().__init__(in_features, out_features, bias)
75 | self.sf = sf
76 |
77 | # Split outlier channels that would overflow after quantisation
78 | with torch.no_grad():
79 | channel_max = self.weight.abs().max(dim=1).values
80 | k = max(1, int(k_outlier * out_features))
81 | self.outlier_idx = torch.topk(channel_max, k).indices
82 | mask = torch.ones(out_features, dtype=torch.bool)
83 | mask[self.outlier_idx] = False
84 | base_w = self.weight[mask].clone()
85 | self.register_buffer("encoded_weight", sf.encode(base_w)[0])
86 | self.register_parameter("scale", nn.Parameter(torch.ones(base_w.size(0))))
87 | self.register_parameter("outlier_weight", nn.Parameter(self.weight[self.outlier_idx].clone()))
88 | self.register_buffer("mask", mask)
89 | # Remove original parameter
90 | self.weight.requires_grad = False
91 |
92 | def forward(self, input: torch.Tensor) -> torch.Tensor:
93 | # Decode base weight on the fly and apply LSQ+ scale
94 | decoded_base = self.sf.decode(self.encoded_weight) * self.scale.view(-1, 1)
95 | weight = self.weight.new_zeros(self.out_features, self.in_features)
96 | weight[self.mask] = decoded_base
97 | weight[self.outlier_idx] = self.outlier_weight
98 | return F.linear(input, weight, self.bias)
99 |
100 |
101 | class ActivationQuant(nn.Module):
102 | """Module to quantise activations symmetrically with Superfloat."""
103 |
104 | def __init__(self, sf: Superfloat):
105 | super().__init__()
106 | self.sf = sf
107 |
108 | def forward(self, x: torch.Tensor) -> torch.Tensor:
109 | return SFQuant.apply(x, self.sf)
110 |
111 |
112 | def compute_hessian_scores(model, data_loader, device, num_batches=1):
113 | """Approximate block-diagonal Hessian scores for parameters."""
114 | scores = {name: torch.zeros_like(p) for name, p in model.named_parameters() if p.requires_grad}
115 | loss_fn = nn.CrossEntropyLoss()
116 | model.eval()
117 | for i, batch in enumerate(data_loader):
118 | if i >= num_batches:
119 | break
120 | batch = {k: v.to(device) for k, v in batch.items()}
121 | output = model(**batch)
122 | logits = output.logits
123 | target = batch["input_ids"][:, 1:].contiguous()
124 | logits = logits[:, :-1].contiguous()
125 | loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
126 | grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad], create_graph=False)
127 | for (name, _), g in zip([(n, p) for n, p in model.named_parameters() if p.requires_grad], grads):
128 | scores[name] += g.pow(2)
129 | return scores
130 |
131 |
132 | def select_sf_bits(weight, score, bit_options=(16, 11, 8, 4), budget=1e-3):
133 | """Simple layer-adaptive bit-width search using a quantisation error budget."""
134 | for bits in sorted(bit_options, reverse=True):
135 | sf = Superfloat(bits)
136 | q = sf.tensor_quantize(weight)
137 | err = (weight - q).abs().mean() * score.mean()
138 | if err <= budget:
139 | return sf
140 | return Superfloat(bit_options[0])
141 |
142 |
143 | def quantize_model(model, sf_options=(16, 11, 8, 4), data_loader=None, device="cpu"):
144 | """Quantise linear layers adaptively and insert activation quantisation."""
145 | if data_loader is not None:
146 | scores = compute_hessian_scores(model, data_loader, device)
147 | else:
148 | scores = {name: torch.ones_like(p) for name, p in model.named_parameters() if p.requires_grad}
149 |
150 | for name, module in model.named_modules():
151 | if isinstance(module, nn.Linear):
152 | score = scores.get(f"{name}.weight", torch.ones_like(module.weight))
153 | sf = select_sf_bits(module.weight.data, score)
154 | qlinear = QuantizedLinear(module.in_features, module.out_features, sf, module.bias is not None)
155 | qlinear.bias = module.bias
156 |             setattr(model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model, name.rsplit(".", 1)[-1], qlinear)  # Replace the layer on its parent module, not the model root
157 | elif isinstance(module, nn.Module) and not isinstance(module, ActivationQuant):
158 | module.register_forward_pre_hook(lambda m, inp: (SFQuant.apply(inp[0], Superfloat(11)),))
159 | return model
160 |
161 |
162 | def main():
163 | model_name = "Qwen/Qwen2-0.5B"
164 | if torch.backends.mps.is_available():
165 | device = torch.device("mps")
166 | elif torch.cuda.is_available():
167 | device = torch.device("cuda")
168 | else:
169 | device = torch.device("cpu")
170 | print(f"Using device: {device}")
171 |
172 | # Model loading may require network; placeholder path for offline usage
173 | model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./")
174 | model = model.to(device)
175 |
176 | # Dummy dataloader for Hessian approximation (replace with real data)
177 | dummy_input = torch.randint(0, 10, (1, 8))
178 | dummy_mask = torch.ones_like(dummy_input)
179 | dataset = [{"input_ids": dummy_input, "attention_mask": dummy_mask}]
180 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
181 |
182 | print("Applying adaptive Superfloat quantisation...")
183 | quantized_model = quantize_model(model, data_loader=data_loader, device=device)
184 |
185 | save_path = "sf_vanilla_adaptive.pt"
186 | torch.save(quantized_model.state_dict(), save_path)
187 | print(f"Quantised model saved to {save_path}")
188 |
189 |
190 | if __name__ == "__main__":
191 | main()
192 |
--------------------------------------------------------------------------------
/src/website/backend/parallel_backend.py:
--------------------------------------------------------------------------------
1 | import modal
2 | import torch
3 | import time
4 | import psutil
5 | import gc
6 | import os
7 | import uuid
8 | from concurrent.futures import ThreadPoolExecutor
9 | from fastapi import FastAPI, Request, Response
10 | from pydantic import BaseModel
11 | from transformers import AutoTokenizer, AutoModelForCausalLM
12 | import math
13 | from fastapi.responses import StreamingResponse
14 |
15 | # Set your Hugging Face token
16 | hf_token = "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"
17 | os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
18 |
19 | # Define the mapping for bit-width
20 | def map_bitwidth(bits):
21 | if 4 <= bits <= 7:
22 | return 4
23 | elif 8 <= bits <= 15:
24 | return 8
25 | else:
26 | return 16
27 |
28 | # Mapping bit-width to model names
29 | model_mapping = {
30 | "Qwen/Qwen2.5-0.5B": {
31 | 4: "Qwen/Qwen2.5-0.5B",
32 | 8: "Qwen/Qwen2.5-0.5B",
33 | 16: "Qwen/Qwen2.5-0.5B"
34 | },
35 | "Qwen/Qwen2.5-1.5B": {
36 | 4: "Qwen/Qwen2.5-1.5B",
37 | 8: "Qwen/Qwen2.5-1.5B",
38 | 16: "Qwen/Qwen2.5-1.5B"
39 | },
40 | "meta-llama/Llama-3.2-1B": {
41 | 4: "meta-llama/Llama-3.2-1B",
42 | 8: "meta-llama/Llama-3.2-1B",
43 | 16: "meta-llama/Llama-3.2-1B"
44 | },
45 | "meta-llama/Llama-3.2-3B": {
46 | 4: "meta-llama/Llama-3.2-3B",
47 | 8: "meta-llama/Llama-3.2-3B",
48 | 16: "meta-llama/Llama-3.2-3B"
49 | },
50 | "meta-llama/Llama-3.1-8B": {
51 | 4: "meta-llama/Llama-3.1-8B",
52 | 8: "meta-llama/Llama-3.1-8B",
53 | 16: "meta-llama/Llama-3.1-8B"
54 | },
55 | }
56 |
57 | # Fake-quantization helpers: quantize a tensor to the target bit-width, then dequantize back to float
58 | def absmax_quantize(tensor, bitwidth):
59 | scale = torch.max(torch.abs(tensor))
60 | q_tensor = torch.round(tensor / scale * (2**(bitwidth - 1) - 1))
61 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale
62 | return deq_tensor
63 |
64 | def zero_mean_quantize(tensor, bitwidth):
65 | scale = torch.max(torch.abs(tensor - tensor.mean()))
66 | q_tensor = torch.round((tensor - tensor.mean()) / scale * (2**(bitwidth - 1) - 1))
67 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale + tensor.mean()
68 | return deq_tensor
69 |
70 | model_cache = {}
71 |
72 | def load_model(model_name, bitwidth, quantization_type, device):
73 | # Create a unique key for the model based on its name, bitwidth, and quantization type
74 | cache_key = (model_name, bitwidth, quantization_type)
75 |
76 | # Check if the model is already in the cache
77 | if cache_key in model_cache:
78 | print(f"Using cached model: {model_name} with bitwidth {bitwidth} and quantization type {quantization_type}")
79 | model, tokenizer = model_cache[cache_key]
80 | else:
81 | print(f"Downloading and caching model: {model_name} with bitwidth {bitwidth} and quantization type {quantization_type}")
82 |         # Load the model in bfloat16
83 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True).to(device)
84 |
85 | # Apply quantization
86 | for param in model.parameters():
87 | if quantization_type == 'WASQ-LTH':
88 |                 param.data = absmax_quantize(param.data, bitwidth).to(torch.bfloat16) # Keep quantized weights in bfloat16
89 |             elif quantization_type == 'WASQ-OPT':
90 |                 param.data = zero_mean_quantize(param.data, bitwidth).to(torch.bfloat16) # Keep quantized weights in bfloat16
91 |
92 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
93 |
94 | # Cache the model
95 | model_cache[cache_key] = (model, tokenizer)
96 |
97 | return model, tokenizer
98 |
99 | def measure_performance(model, tokenizer, input_text, device):
100 | # Tokenize the input text
101 | inputs = tokenizer(input_text, return_tensors='pt').to(device)
102 |
103 | # Ensure input_ids remains as torch.long (integer)
104 | inputs = {k: v.to(torch.long) if k == "input_ids" else v.to(torch.bfloat16) for k, v in inputs.items()}
105 |
106 | start_time = time.time()
107 | with torch.no_grad():
108 | outputs = model.generate(
109 | **inputs,
110 | max_new_tokens=256,
111 | num_return_sequences=1,
112 | do_sample=True,
113 | temperature=0.7,
114 | repetition_penalty=1.2,
115 | pad_token_id=tokenizer.pad_token_id,
116 | eos_token_id=tokenizer.eos_token_id,
117 | )
118 | end_time = time.time()
119 | inference_time = end_time - start_time
120 | memory_usage = psutil.Process().memory_info().rss / (1024 ** 2) # in MB
121 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
122 | return inference_time, memory_usage, generated_text
123 |
124 | # Define request models
125 | class ModelRequest(BaseModel):
126 | model_name: str
127 | quantization_bits: int
128 | quantization_type: str
129 | input_text: str
130 |
131 | # Create a Modal Dict for persistent storage
132 | results_dict = modal.Dict.from_name("emelinlabs-results", create_if_missing=True)
133 |
134 | # Create a FastAPI app
135 | from fastapi.middleware.cors import CORSMiddleware
136 |
137 | app_fastapi = FastAPI()
138 |
139 | # Add CORS middleware
140 | app_fastapi.add_middleware(
141 | CORSMiddleware,
142 | allow_origins=["*"], # Allow all origins (replace with your frontend URL in production)
143 | allow_credentials=True,
144 | allow_methods=["*"],
145 | allow_headers=["*"],
146 | )
147 |
148 | # Modal setup
149 | image = (
150 | modal.Image.debian_slim(python_version="3.11")
151 | .pip_install("fastapi", "uvicorn", "transformers", "torch", "psutil", "pydantic")
152 | )
153 |
154 | app = modal.App(name="emelinlabs-runner", image=image)
155 |
156 | def sanitize_float(value):
157 | """Ensure the value is a finite float and replace NaN/Infinity with 0.0."""
158 | if not isinstance(value, (int, float)) or not math.isfinite(value):
159 | return 0.0
160 | return value
161 |
162 | # POST endpoint
163 | @app.function(
164 | gpu="A100", # Specify the GPU type (e.g., "A10G", "A100", "H100")
165 | timeout=86400, # Timeout in seconds (1 day = 86400 seconds)
166 | allow_concurrent_inputs=100 # Allow concurrent requests
167 | )
168 | @modal.web_endpoint(method="POST")
169 | def run_inference(request: ModelRequest):
170 | device = "cuda" if torch.cuda.is_available() else "cpu"
171 | model_name = request.model_name
172 | quantization_bits = request.quantization_bits
173 | quantization_type = request.quantization_type
174 | input_text = request.input_text
175 |
176 | print(f"Model: {model_name}, Bits: {quantization_bits}, Type: {quantization_type}")
177 |
178 | # Generate a unique ID for this request
179 | request_id = str(uuid.uuid4())
180 |
181 | # Load original model
182 | original_model = AutoModelForCausalLM.from_pretrained(
183 | model_name, torch_dtype=torch.bfloat16, use_safetensors=True
184 | ).to(device)
185 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
186 |
187 | # Load and quantize model
188 | effective_bits = map_bitwidth(quantization_bits)
189 | quantized_model_name = model_mapping[model_name][effective_bits]
190 | quantized_model, _ = load_model(quantized_model_name, effective_bits, quantization_type, device)
191 |
192 | # Function to run inference and measure performance
193 | def run_model_inference(model, tokenizer, input_text, device):
194 | inference_time, memory_usage, generated_text = measure_performance(model, tokenizer, input_text, device)
195 | return inference_time, memory_usage, generated_text
196 |
197 | # Run original and quantized model inference in parallel
198 | with ThreadPoolExecutor() as executor:
199 | orig_future = executor.submit(run_model_inference, original_model, tokenizer, input_text, device)
200 | quant_future = executor.submit(run_model_inference, quantized_model, tokenizer, input_text, device)
201 |
202 | orig_inference_time, orig_memory_usage, orig_text = orig_future.result()
203 | quant_inference_time, quant_memory_usage, quant_text = quant_future.result()
204 |
205 |     # Estimate quantized-model memory from the bit-width ratio (the RSS measured above reflects the whole process, not the model alone)
206 |     quant_memory_usage = (effective_bits / 16.0) * orig_memory_usage
207 |
208 | # Calculate differences
209 | speed_diff = (orig_inference_time - quant_inference_time) / orig_inference_time * 100
210 | memory_savings = (1 - (quantization_bits / 16.0)) * 100
211 |
212 | # Sanitize all floating-point values
213 | orig_inference_time = sanitize_float(orig_inference_time)
214 | orig_memory_usage = sanitize_float(orig_memory_usage)
215 | quant_inference_time = sanitize_float(quant_inference_time)
216 | quant_memory_usage = sanitize_float(quant_memory_usage)
217 | speed_diff = sanitize_float(speed_diff)
218 | memory_savings = sanitize_float(memory_savings)
219 |
220 | # Store results in Modal Dict
221 | results_dict[request_id] = {
222 | "original": {
223 | "text": orig_text,
224 | "inference_time": orig_inference_time,
225 | "memory_usage": orig_memory_usage,
226 | },
227 | "quantized": {
228 | "text": quant_text,
229 | "inference_time": quant_inference_time,
230 | "memory_usage": quant_memory_usage,
231 | },
232 | "comparison": {
233 | "speed_diff": speed_diff,
234 | "memory_savings": memory_savings,
235 | }
236 | }
237 |
238 | # Clean up to free memory
239 | del original_model
240 | del quantized_model
241 | gc.collect()
242 | torch.cuda.empty_cache()
243 |
244 | return {"request_id": request_id}
245 |
246 | # GET endpoint
247 | @app.function()
248 | @modal.web_endpoint()
249 | def get_result(request_id: str):
250 | result = results_dict.get(request_id, None)
251 | if result:
252 | return result
253 | else:
254 | return {"error": "Request ID not found"}
255 |
256 | # Health check endpoint
257 | @app.function()
258 | @modal.web_endpoint()
259 | def health_check():
260 | return {"status": "active"}
261 |
262 | # Stream tokens from original model
263 | @app.function()
264 | @modal.web_endpoint()
265 | def stream_original(request_id: str):
266 | result = results_dict.get(request_id, None)
267 | if result:
268 | def generate():
269 | for token in result["original"]["text"].split():
270 | yield f"data: {token}\n\n"
271 | time.sleep(0.1) # Simulate streaming delay
272 | return StreamingResponse(generate(), media_type="text/event-stream")
273 | else:
274 | return {"error": "Request ID not found"}
275 |
276 | # Stream tokens from quantized model
277 | @app.function()
278 | @modal.web_endpoint()
279 | def stream_quantized(request_id: str):
280 | result = results_dict.get(request_id, None)
281 | if result:
282 | def generate():
283 | for token in result["quantized"]["text"].split():
284 | yield f"data: {token}\n\n"
285 | time.sleep(0.1) # Simulate streaming delay
286 | return StreamingResponse(generate(), media_type="text/event-stream")
287 | else:
288 | return {"error": "Request ID not found"}
--------------------------------------------------------------------------------
/src/website/backend/sequential_backend.py:
--------------------------------------------------------------------------------
1 | import modal
2 | import torch
3 | import time
4 | import psutil
5 | import gc
6 | import os
7 | import uuid
8 | from concurrent.futures import ThreadPoolExecutor
9 | from fastapi import FastAPI, Request
10 | from pydantic import BaseModel
11 | from transformers import AutoTokenizer, AutoModelForCausalLM
12 | import math
13 |
14 | # Set your Hugging Face token
15 | hf_token = "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"
16 | os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
17 |
18 | # Define the mapping for bit-width
19 | def map_bitwidth(bits):
20 | if 4 <= bits <= 7:
21 | return 4
22 | elif 8 <= bits <= 15:
23 | return 8
24 | else:
25 | return 16
26 |
27 | # Mapping bit-width to model names
28 | model_mapping = {
29 | "Qwen/Qwen2.5-0.5B": {
30 | 4: "Qwen/Qwen2.5-0.5B",
31 | 8: "Qwen/Qwen2.5-0.5B",
32 | 16: "Qwen/Qwen2.5-0.5B"
33 | },
34 | "Qwen/Qwen2.5-1.5B": {
35 | 4: "Qwen/Qwen2.5-1.5B",
36 | 8: "Qwen/Qwen2.5-1.5B",
37 | 16: "Qwen/Qwen2.5-1.5B"
38 | },
39 | "meta-llama/Llama-3.2-1B": {
40 | 4: "meta-llama/Llama-3.2-1B",
41 | 8: "meta-llama/Llama-3.2-1B",
42 | 16: "meta-llama/Llama-3.2-1B"
43 | },
44 | "meta-llama/Llama-3.2-3B": {
45 | 4: "meta-llama/Llama-3.2-3B",
46 | 8: "meta-llama/Llama-3.2-3B",
47 | 16: "meta-llama/Llama-3.2-3B"
48 | }
49 | }
50 |
51 | # Fake-quantization helpers: quantize a tensor to the target bit-width, then dequantize back to float
52 | def absmax_quantize(tensor, bitwidth):
53 | scale = torch.max(torch.abs(tensor))
54 | q_tensor = torch.round(tensor / scale * (2**(bitwidth - 1) - 1))
55 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale
56 | return deq_tensor
57 |
58 | def zero_mean_quantize(tensor, bitwidth):
59 | scale = torch.max(torch.abs(tensor - tensor.mean()))
60 | q_tensor = torch.round((tensor - tensor.mean()) / scale * (2**(bitwidth - 1) - 1))
61 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale + tensor.mean()
62 | return deq_tensor
63 |
64 | def load_model(model_name, bitwidth, quantization_type, device):
65 | # Load the model in half-precision (torch.float16)
66 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
67 |
68 | # Apply quantization
69 | for param in model.parameters():
70 | if quantization_type == 'WASQ-LTH':
71 | param.data = absmax_quantize(param.data, bitwidth).to(torch.float16) # Ensure quantization output is float16
72 | elif quantization_type == 'WASQ-OPT':
73 | param.data = zero_mean_quantize(param.data, bitwidth).to(torch.float16) # Ensure quantization output is float16
74 |
75 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
76 | return model, tokenizer
77 |
78 | def measure_performance(model, tokenizer, input_text, device):
79 | # Tokenize the input text
80 | inputs = tokenizer(input_text, return_tensors='pt').to(device)
81 |
82 | # Ensure input_ids remains as torch.long (integer)
83 | inputs = {k: v.to(torch.long) if k == "input_ids" else v.to(torch.float16) for k, v in inputs.items()}
84 |
85 | start_time = time.time()
86 | with torch.no_grad():
87 | outputs = model.generate(
88 | **inputs,
89 | max_new_tokens=256,
90 | num_return_sequences=1,
91 | do_sample=True,
92 | temperature=0.7,
93 | repetition_penalty=1.2,
94 | pad_token_id=tokenizer.pad_token_id,
95 | eos_token_id=tokenizer.eos_token_id,
96 | )
97 | end_time = time.time()
98 | inference_time = end_time - start_time
99 | memory_usage = psutil.Process().memory_info().rss / (1024 ** 2) # in MB
100 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
101 | return inference_time, memory_usage, generated_text
102 |
103 | # def calculate_perplexity(model, tokenizer, input_text, device):
104 | # inputs = tokenizer(input_text, return_tensors='pt').to(device)
105 | # max_length = inputs.input_ids.size(1)
106 | # with torch.no_grad():
107 | # outputs = model(**inputs)
108 | # shift_logits = outputs.logits[..., :-1, :].contiguous()
109 | # shift_labels = inputs.input_ids[..., 1:].contiguous()
110 | # loss_fct = torch.nn.CrossEntropyLoss()
111 | # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
112 | # perplexity = torch.exp(loss).item()
113 | # return perplexity
114 |
115 | # Define request models
116 | class ModelRequest(BaseModel):
117 | model_name: str
118 | quantization_bits: int
119 | quantization_type: str
120 | input_text: str
121 |
122 | # Create a Modal Dict for persistent storage
123 | results_dict = modal.Dict.from_name("emelinlabs-results", create_if_missing=True)
124 |
125 | # Create a FastAPI app
126 | app_fastapi = FastAPI()
127 |
128 | # Modal setup
129 | image = (
130 | modal.Image.debian_slim(python_version="3.11")
131 | .pip_install("fastapi", "uvicorn", "transformers", "torch", "psutil", "pydantic")
132 | )
133 |
134 | app = modal.App(name="emelinlabs-runner", image=image)
135 |
136 | def sanitize_float(value):
137 | """Ensure the value is a finite float and replace NaN/Infinity with 0.0."""
138 | if not isinstance(value, (int, float)) or not math.isfinite(value):
139 | return 0.0
140 | return value
141 |
142 | # POST endpoint
143 | @app.function(
144 | gpu="A100", # Specify the GPU type (e.g., "A10G", "A100", "H100")
145 | timeout=86400, # Timeout in seconds (1 day = 86400 seconds)
146 | allow_concurrent_inputs=100 # Allow concurrent requests
147 | )
148 | @modal.web_endpoint(method="POST")
149 | def run_inference(request: ModelRequest):
150 | device = "cuda" if torch.cuda.is_available() else "cpu"
151 | model_name = request.model_name
152 | quantization_bits = request.quantization_bits
153 | quantization_type = request.quantization_type
154 | input_text = request.input_text
155 |
156 | # Generate a unique ID for this request
157 | request_id = str(uuid.uuid4())
158 |
159 | # Load original model
160 | original_model = AutoModelForCausalLM.from_pretrained(
161 | model_name, torch_dtype=torch.float16
162 | ).to(device)
163 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
164 |
165 | # Load and quantize model
166 | effective_bits = map_bitwidth(quantization_bits)
167 | quantized_model_name = model_mapping[model_name][effective_bits]
168 | quantized_model, _ = load_model(quantized_model_name, effective_bits, quantization_type, device)
169 |
170 | # Measure performance for original model
171 | orig_inference_time, orig_memory_usage, orig_text = measure_performance(original_model, tokenizer, input_text, device)
172 | # orig_perplexity = calculate_perplexity(original_model, tokenizer, input_text, device)
173 |
174 | # Measure performance for quantized model
175 | quant_inference_time, _, quant_text = measure_performance(quantized_model, tokenizer, input_text, device)
176 | # quant_perplexity = calculate_perplexity(quantized_model, tokenizer, input_text, device)
177 |
178 | # Estimate quantized-model memory as a fraction of the fp16 baseline, based on the effective bit-width
179 | quant_memory_usage = (effective_bits / 16.0) * orig_memory_usage
180 |
181 | # Calculate differences
182 | speed_diff = (orig_inference_time - quant_inference_time) / orig_inference_time * 100
183 | memory_savings = (1 - (effective_bits / 16.0)) * 100  # use the effective bit-width, consistent with quant_memory_usage
184 | # perplexity_diff = quant_perplexity - orig_perplexity
185 |
186 | # Sanitize all floating-point values
187 | orig_inference_time = sanitize_float(orig_inference_time)
188 | orig_memory_usage = sanitize_float(orig_memory_usage)
189 | # orig_perplexity = sanitize_float(orig_perplexity)
190 | quant_inference_time = sanitize_float(quant_inference_time)
191 | quant_memory_usage = sanitize_float(quant_memory_usage)
192 | # quant_perplexity = sanitize_float(quant_perplexity)
193 | speed_diff = sanitize_float(speed_diff)
194 | memory_savings = sanitize_float(memory_savings)
195 | # perplexity_diff = sanitize_float(perplexity_diff)
196 |
197 | # Store results in Modal Dict
198 | results_dict[request_id] = {
199 | "original": {
200 | "text": orig_text,
201 | "inference_time": orig_inference_time,
202 | "memory_usage": orig_memory_usage,
203 | # "perplexity": orig_perplexity
204 | },
205 | "quantized": {
206 | "text": quant_text,
207 | "inference_time": quant_inference_time,
208 | "memory_usage": quant_memory_usage,
209 | # "perplexity": quant_perplexity
210 | },
211 | "comparison": {
212 | "speed_diff": speed_diff,
213 | "memory_savings": memory_savings,
214 | # "perplexity_diff": perplexity_diff
215 | }
216 | }
217 |
218 | # Clean up to free memory
219 | del original_model
220 | del quantized_model
221 | gc.collect()
222 | torch.cuda.empty_cache()
223 |
224 | return {"request_id": request_id}
225 |
226 | # GET endpoint
227 | @app.function()
228 | @modal.web_endpoint()
229 | def get_result(request_id: str):
230 | result = results_dict.get(request_id, None)
231 | if result:
232 | return result
233 | else:
234 | return {"error": "Request ID not found"}
235 |
236 | # Health check endpoint
237 | @app.function()
238 | @modal.web_endpoint()
239 | def health_check():
240 | return {"status": "active"}
--------------------------------------------------------------------------------
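
The two fake-quantization schemes above are easiest to see on a small tensor. A self-contained sketch (the helpers are restated so the snippet runs on its own; the error figures are illustrative, not benchmarks):

```python
# Standalone illustration of the absmax / zero-mean fake quantization used by the backends:
# weights are rounded onto a (2^(b-1) - 1)-level grid and dequantized back to float,
# so the model keeps its dtype but loses precision.
import torch

def absmax_quantize(tensor, bitwidth):
    scale = torch.max(torch.abs(tensor))
    q = torch.round(tensor / scale * (2 ** (bitwidth - 1) - 1))
    return q / (2 ** (bitwidth - 1) - 1) * scale

def zero_mean_quantize(tensor, bitwidth):
    mean = tensor.mean()
    scale = torch.max(torch.abs(tensor - mean))
    q = torch.round((tensor - mean) / scale * (2 ** (bitwidth - 1) - 1))
    return q / (2 ** (bitwidth - 1) - 1) * scale + mean

w = torch.randn(4, 4)
for bits in (4, 8):
    err_absmax = (w - absmax_quantize(w, bits)).abs().max().item()
    err_zeromean = (w - zero_mean_quantize(w, bits)).abs().max().item()
    print(f"{bits}-bit  max abs error  absmax={err_absmax:.4f}  zero-mean={err_zeromean:.4f}")
```
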
/src/website/backend/testing_backend.py:
--------------------------------------------------------------------------------
1 | import modal
2 | import torch
3 | import time
4 | import psutil
5 | import gc
6 | import os
7 | import uuid
8 | from concurrent.futures import ThreadPoolExecutor
9 | from fastapi import FastAPI, Request, Response
10 | from pydantic import BaseModel
11 | from transformers import AutoTokenizer, AutoModelForCausalLM
12 | import math
13 | from fastapi.responses import StreamingResponse
14 |
15 | # Read the Hugging Face access token from the environment (e.g. injected via a Modal Secret) instead of hardcoding it
16 | hf_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", "")
17 | os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
18 |
19 | # Define the mapping for bit-width
20 | def map_bitwidth(bits):
21 | if 4 <= bits <= 7:
22 | return 4
23 | elif 8 <= bits <= 15:
24 | return 8
25 | else:
26 | return 16
27 |
28 | # Mapping bit-width to model names
29 | model_mapping = {
30 | "Qwen/Qwen2.5-0.5B": {
31 | 4: "Qwen/Qwen2.5-0.5B",
32 | 8: "Qwen/Qwen2.5-0.5B",
33 | 16: "Qwen/Qwen2.5-0.5B"
34 | },
35 | "Qwen/Qwen2.5-1.5B": {
36 | 4: "Qwen/Qwen2.5-1.5B",
37 | 8: "Qwen/Qwen2.5-1.5B",
38 | 16: "Qwen/Qwen2.5-1.5B"
39 | },
40 | "meta-llama/Llama-3.2-1B": {
41 | 4: "meta-llama/Llama-3.2-1B",
42 | 8: "meta-llama/Llama-3.2-1B",
43 | 16: "meta-llama/Llama-3.2-1B"
44 | },
45 | "meta-llama/Llama-3.2-3B": {
46 | 4: "meta-llama/Llama-3.2-3B",
47 | 8: "meta-llama/Llama-3.2-3B",
48 | 16: "meta-llama/Llama-3.2-3B"
49 | },
50 | "meta-llama/Llama-3.1-8B": {
51 | 4: "meta-llama/Llama-3.1-8B",
52 | 8: "meta-llama/Llama-3.1-8B",
53 | 16: "meta-llama/Llama-3.1-8B"
54 | },
55 | }
56 |
57 | # Tensor-level fake-quantization helpers (applied per parameter when loading the model)
58 | def absmax_quantize(tensor, bitwidth):
59 | scale = torch.max(torch.abs(tensor))
60 | q_tensor = torch.round(tensor / scale * (2**(bitwidth - 1) - 1))
61 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale
62 | return deq_tensor
63 |
64 | def zero_mean_quantize(tensor, bitwidth):
65 | scale = torch.max(torch.abs(tensor - tensor.mean()))
66 | q_tensor = torch.round((tensor - tensor.mean()) / scale * (2**(bitwidth - 1) - 1))
67 | deq_tensor = q_tensor / (2**(bitwidth - 1) - 1) * scale + tensor.mean()
68 | return deq_tensor
69 |
70 | def load_model(model_name, bitwidth, quantization_type, device):
71 | # Load the model in bfloat16 (torch.bfloat16)
72 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, use_safetensors=True).to(device)
73 |
74 | # Apply quantization
75 | for param in model.parameters():
76 | if quantization_type == 'WASQ-LTH':
77 | param.data = absmax_quantize(param.data, bitwidth).to(torch.bfloat16) # Keep the quantized weights in bfloat16
78 | elif quantization_type == 'WASQ-OPT':
79 | param.data = zero_mean_quantize(param.data, bitwidth).to(torch.bfloat16) # Keep the quantized weights in bfloat16
80 |
81 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
82 | return model, tokenizer
83 |
84 | def measure_performance(model, tokenizer, input_text, device):
85 | # Tokenize the input text
86 | inputs = tokenizer(input_text, return_tensors='pt').to(device)
87 |
88 | # Ensure input_ids remains as torch.long (integer)
89 | inputs = {k: v.to(torch.long) if k == "input_ids" else v.to(torch.bfloat16) for k, v in inputs.items()}
90 |
91 | start_time = time.time()
92 | with torch.no_grad():
93 | outputs = model.generate(
94 | **inputs,
95 | max_new_tokens=256,
96 | num_return_sequences=1,
97 | do_sample=True,
98 | temperature=0.7,
99 | repetition_penalty=1.2,
100 | pad_token_id=tokenizer.pad_token_id,
101 | eos_token_id=tokenizer.eos_token_id,
102 | )
103 | end_time = time.time()
104 | inference_time = end_time - start_time
105 | memory_usage = psutil.Process().memory_info().rss / (1024 ** 2) # in MB
106 | generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
107 | return inference_time, memory_usage, generated_text
108 |
109 | # Define request models
110 | class ModelRequest(BaseModel):
111 | model_name: str
112 | quantization_bits: int
113 | quantization_type: str
114 | input_text: str
115 |
116 | # Create a Modal Dict for persistent storage
117 | results_dict = modal.Dict.from_name("emelinlabs-results", create_if_missing=True)
118 |
119 | # Create a FastAPI app
120 | from fastapi.middleware.cors import CORSMiddleware
121 |
122 | app_fastapi = FastAPI()
123 |
124 | # Add CORS middleware
125 | app_fastapi.add_middleware(
126 | CORSMiddleware,
127 | allow_origins=["*"], # Allow all origins (replace with your frontend URL in production)
128 | allow_credentials=True,
129 | allow_methods=["*"],
130 | allow_headers=["*"],
131 | )
132 |
133 | # Modal setup
134 | image = (
135 | modal.Image.debian_slim(python_version="3.11")
136 | .pip_install("fastapi", "uvicorn", "transformers", "torch", "psutil", "pydantic")
137 | )
138 |
139 | app = modal.App(name="emelinlabs-runners", image=image)
140 |
141 | def sanitize_float(value):
142 | """Ensure the value is a finite float and replace NaN/Infinity with 0.0."""
143 | if not isinstance(value, (int, float)) or not math.isfinite(value):
144 | return 0.0
145 | return value
146 |
147 | # POST endpoint
148 | @app.function(
149 | gpu="A100", # Specify the GPU type (e.g., "A10G", "A100", "H100")
150 | timeout=86400, # Timeout in seconds (1 day = 86400 seconds)
151 | allow_concurrent_inputs=100 # Allow concurrent requests
152 | )
153 | @modal.web_endpoint(method="POST")
154 | def run_inference(request: ModelRequest):
155 | device = "cuda" if torch.cuda.is_available() else "cpu"
156 | model_name = request.model_name
157 | quantization_bits = request.quantization_bits
158 | quantization_type = request.quantization_type
159 | input_text = request.input_text
160 |
161 | print(f"Model: {model_name}, Bits: {quantization_bits}, Type: {quantization_type}")
162 |
163 | # Generate a unique ID for this request
164 | request_id = str(uuid.uuid4())
165 |
166 | # Load original model
167 | original_model = AutoModelForCausalLM.from_pretrained(
168 | model_name, torch_dtype=torch.bfloat16, use_safetensors=True
169 | ).to(device)
170 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
171 |
172 | # Load and quantize model
173 | effective_bits = map_bitwidth(quantization_bits)
174 | quantized_model_name = model_mapping[model_name][effective_bits]
175 | quantized_model, _ = load_model(quantized_model_name, effective_bits, quantization_type, device)
176 |
177 | # Function to run inference and measure performance
178 | def run_model_inference(model, tokenizer, input_text, device):
179 | inference_time, memory_usage, generated_text = measure_performance(model, tokenizer, input_text, device)
180 | return inference_time, memory_usage, generated_text
181 |
182 | # Run original and quantized model inference in parallel
183 | with ThreadPoolExecutor() as executor:
184 | orig_future = executor.submit(run_model_inference, original_model, tokenizer, input_text, device)
185 | quant_future = executor.submit(run_model_inference, quantized_model, tokenizer, input_text, device)
186 |
187 | orig_inference_time, orig_memory_usage, orig_text = orig_future.result()
188 | quant_inference_time, quant_memory_usage, quant_text = quant_future.result()
189 |
190 | # Estimate quantized-model memory as a fraction of the bf16 baseline, based on the effective bit-width
191 | quant_memory_usage = (effective_bits / 16.0) * orig_memory_usage
192 |
193 | # Calculate differences
194 | speed_diff = (orig_inference_time - quant_inference_time) / orig_inference_time * 100
195 | memory_savings = (1 - (effective_bits / 16.0)) * 100  # use the effective bit-width, consistent with quant_memory_usage
196 |
197 | # Sanitize all floating-point values
198 | orig_inference_time = sanitize_float(orig_inference_time)
199 | orig_memory_usage = sanitize_float(orig_memory_usage)
200 | quant_inference_time = sanitize_float(quant_inference_time)
201 | quant_memory_usage = sanitize_float(quant_memory_usage)
202 | speed_diff = sanitize_float(speed_diff)
203 | memory_savings = sanitize_float(memory_savings)
204 |
205 | # Store results in Modal Dict
206 | results_dict[request_id] = {
207 | "original": {
208 | "text": orig_text,
209 | "inference_time": orig_inference_time,
210 | "memory_usage": orig_memory_usage,
211 | },
212 | "quantized": {
213 | "text": quant_text,
214 | "inference_time": quant_inference_time,
215 | "memory_usage": quant_memory_usage,
216 | },
217 | "comparison": {
218 | "speed_diff": speed_diff,
219 | "memory_savings": memory_savings,
220 | }
221 | }
222 |
223 | # Clean up to free memory
224 | del original_model
225 | del quantized_model
226 | gc.collect()
227 | torch.cuda.empty_cache()
228 |
229 | return {"request_id": request_id}
230 |
231 | # GET endpoint
232 | @app.function()
233 | @modal.web_endpoint()
234 | def get_result(request_id: str):
235 | result = results_dict.get(request_id, None)
236 | if result:
237 | return result
238 | else:
239 | return {"error": "Request ID not found"}
240 |
241 | # Health check endpoint
242 | @app.function()
243 | @modal.web_endpoint()
244 | def health_check():
245 | return {"status": "active"}
246 |
247 | # Stream tokens from original model
248 | @app.function()
249 | @modal.web_endpoint()
250 | def stream_original(request_id: str):
251 | result = results_dict.get(request_id, None)
252 | if result:
253 | def generate():
254 | for token in result["original"]["text"].split():
255 | yield f"data: {token}\n\n"
256 | time.sleep(0.1) # Simulate streaming delay
257 | return StreamingResponse(generate(), media_type="text/event-stream")
258 | else:
259 | return {"error": "Request ID not found"}
260 |
261 | # Stream tokens from quantized model
262 | @app.function()
263 | @modal.web_endpoint()
264 | def stream_quantized(request_id: str):
265 | result = results_dict.get(request_id, None)
266 | if result:
267 | def generate():
268 | for token in result["quantized"]["text"].split():
269 | yield f"data: {token}\n\n"
270 | time.sleep(0.1) # Simulate streaming delay
271 | return StreamingResponse(generate(), media_type="text/event-stream")
272 | else:
273 | return {"error": "Request ID not found"}
--------------------------------------------------------------------------------
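
The backends above derive their reported memory numbers arithmetically from the bit-width rather than measuring the quantized model. A short sketch of that bookkeeping (the 2048 MB baseline is an assumed figure, purely for illustration):

```python
# How a requested bit-width maps to an effective one, and how the reported
# memory estimate and savings percentage follow from it.
def map_bitwidth(bits):
    if 4 <= bits <= 7:
        return 4
    elif 8 <= bits <= 15:
        return 8
    return 16

orig_memory_mb = 2048.0  # assumed resident memory of the half-precision baseline
for requested_bits in (5, 8, 12, 16):
    eff = map_bitwidth(requested_bits)
    quant_memory_mb = (eff / 16.0) * orig_memory_mb
    savings_pct = (1 - eff / 16.0) * 100
    print(f"requested={requested_bits:2d}  effective={eff:2d}  "
          f"est. memory={quant_memory_mb:6.0f} MB  savings={savings_pct:.1f}%")
```
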
/src/website/frontend/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # production
12 | /build
13 |
14 | # misc
15 | .DS_Store
16 | .env.local
17 | .env.development.local
18 | .env.test.local
19 | .env.production.local
20 |
21 | npm-debug.log*
22 | yarn-debug.log*
23 | yarn-error.log*
24 |
--------------------------------------------------------------------------------
/src/website/frontend/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "jumbo",
3 | "version": "0.1.0",
4 | "private": true,
5 | "dependencies": {
6 | "cra-template": "1.2.0",
7 | "lucide-react": "^0.472.0",
8 | "react": "^19.0.0",
9 | "react-dom": "^19.0.0",
10 | "react-markdown": "^9.0.3",
11 | "react-scripts": "^5.0.1",
12 | "web-vitals": "^4.2.4",
13 | "webpack": "^5.97.1"
14 | },
15 | "scripts": {
16 | "start": "react-scripts start",
17 | "build": "react-scripts build",
18 | "test": "react-scripts test",
19 | "eject": "react-scripts eject"
20 | },
21 | "eslintConfig": {
22 | "extends": [
23 | "react-app",
24 | "react-app/jest"
25 | ]
26 | },
27 | "browserslist": {
28 | "production": [
29 | ">0.2%",
30 | "not dead",
31 | "not op_mini all"
32 | ],
33 | "development": [
34 | "last 1 chrome version",
35 | "last 1 firefox version",
36 | "last 1 safari version"
37 | ]
38 | },
39 | "devDependencies": {
40 | "@tailwindcss/typography": "^0.5.16",
41 | "tailwindcss": "^3.4.17"
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/src/website/frontend/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/src/website/frontend/public/favicon.ico
--------------------------------------------------------------------------------
/src/website/frontend/public/index.html:
--------------------------------------------------------------------------------
(HTML markup not preserved in this dump; only the document title "React App" is recoverable.)
--------------------------------------------------------------------------------
/src/website/frontend/public/logo192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/src/website/frontend/public/logo192.png
--------------------------------------------------------------------------------
/src/website/frontend/public/logo512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aloshdenny/superfloat/47475a5182cfcf0d5231db28b59b59ea2fe3cf32/src/website/frontend/public/logo512.png
--------------------------------------------------------------------------------
/src/website/frontend/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "short_name": "React App",
3 | "name": "Create React App Sample",
4 | "icons": [
5 | {
6 | "src": "favicon.ico",
7 | "sizes": "64x64 32x32 24x24 16x16",
8 | "type": "image/x-icon"
9 | },
10 | {
11 | "src": "logo192.png",
12 | "type": "image/png",
13 | "sizes": "192x192"
14 | },
15 | {
16 | "src": "logo512.png",
17 | "type": "image/png",
18 | "sizes": "512x512"
19 | }
20 | ],
21 | "start_url": ".",
22 | "display": "standalone",
23 | "theme_color": "#000000",
24 | "background_color": "#ffffff"
25 | }
26 |
--------------------------------------------------------------------------------
/src/website/frontend/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/src/website/frontend/src/App.css:
--------------------------------------------------------------------------------
1 | .App {
2 | text-align: center;
3 | }
4 |
5 | .App-logo {
6 | height: 40vmin;
7 | pointer-events: none;
8 | }
9 |
10 | @media (prefers-reduced-motion: no-preference) {
11 | .App-logo {
12 | animation: App-logo-spin infinite 20s linear;
13 | }
14 | }
15 |
16 | .App-header {
17 | background-color: #282c34;
18 | min-height: 100vh;
19 | display: flex;
20 | flex-direction: column;
21 | align-items: center;
22 | justify-content: center;
23 | font-size: calc(10px + 2vmin);
24 | color: white;
25 | }
26 |
27 | .App-link {
28 | color: #61dafb;
29 | }
30 |
31 | @keyframes App-logo-spin {
32 | from {
33 | transform: rotate(0deg);
34 | }
35 | to {
36 | transform: rotate(360deg);
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/src/website/frontend/src/App.js:
--------------------------------------------------------------------------------
1 | import logo from './logo.svg';
2 | import './App.css';
3 | import ChatInterface from './components/ChatInterface';
4 | import LandingPage from './components/LandingPage';
5 | function App() {
6 | return (
7 |
8 |
9 |
10 | );
11 | }
12 |
13 | export default App;
14 |
--------------------------------------------------------------------------------
/src/website/frontend/src/App.test.js:
--------------------------------------------------------------------------------
1 | import { render, screen } from '@testing-library/react';
2 | import App from './App';
3 |
4 | test('renders learn react link', () => {
5 | render(<App />);
6 | const linkElement = screen.getByText(/learn react/i);
7 | expect(linkElement).toBeInTheDocument();
8 | });
9 |
--------------------------------------------------------------------------------
/src/website/frontend/src/README.md:
--------------------------------------------------------------------------------
1 | # Getting Started with Create React App
2 |
3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
4 |
5 | ## Available Scripts
6 |
7 | In the project directory, you can run:
8 |
9 | ### `npm start`
10 |
11 | Runs the app in the development mode.\
12 | Open [http://localhost:3000](http://localhost:3000) to view it in your browser.
13 |
14 | The page will reload when you make changes.\
15 | You may also see any lint errors in the console.
16 |
17 | ### `npm test`
18 |
19 | Launches the test runner in the interactive watch mode.\
20 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information.
21 |
22 | ### `npm run build`
23 |
24 | Builds the app for production to the `build` folder.\
25 | It correctly bundles React in production mode and optimizes the build for the best performance.
26 |
27 | The build is minified and the filenames include the hashes.\
28 | Your app is ready to be deployed!
29 |
30 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information.
31 |
32 | ### `npm run eject`
33 |
34 | **Note: this is a one-way operation. Once you `eject`, you can't go back!**
35 |
36 | If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project.
37 |
38 | Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own.
39 |
40 | You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it.
41 |
42 | ## Learn More
43 |
44 | You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started).
45 |
46 | To learn React, check out the [React documentation](https://reactjs.org/).
47 |
48 | ### Code Splitting
49 |
50 | This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting)
51 |
52 | ### Analyzing the Bundle Size
53 |
54 | This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size)
55 |
56 | ### Making a Progressive Web App
57 |
58 | This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app)
59 |
60 | ### Advanced Configuration
61 |
62 | This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration)
63 |
64 | ### Deployment
65 |
66 | This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment)
67 |
68 | ### `npm run build` fails to minify
69 |
70 | This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify)
71 |
--------------------------------------------------------------------------------
/src/website/frontend/src/components/LandingPage.jsx:
--------------------------------------------------------------------------------
1 | import React, { useState } from 'react';
2 | import { Github, Mail, ArrowRight, ArrowLeft } from 'lucide-react';
3 | import ChatInterface from './ChatInterface';
4 |
5 | const LandingPage = () => {
6 | const [showChat, setShowChat] = useState(false);
7 |
8 | return (
9 |
10 | {showChat ? (
11 |
12 |
13 |
20 |
SuperFloat
21 |
22 |
23 |
24 |
25 |
26 |
27 | ) : (
28 |
29 | {/* Hero Section */}
30 |
31 |
32 | EmelinLabs presents
33 |
34 |
35 | Superfloat
36 |
37 |
38 | A revolutionary quantization algorithm optimizing neural networks through custom precision formats.
39 | Designed for edge computing with scalable precision and superior performance.
40 |
41 |
48 |
49 |
50 | {/* What We Do Section */}
51 |
52 |
What We Do
53 |
54 | {[
55 | {
56 | title: "Sign-Exponent Representation",
57 | description: "Efficient bit allocation with 1 bit for sign and remaining for exponent, optimizing precision without mantissa."
58 | },
59 | {
60 | title: "Clamping Range",
61 | description: "Values clamped within [-1, 1] for activation and parameter stability, preventing gradient issues."
62 | },
63 | {
64 | title: "Bit-width Flexibility",
65 | description: "Scalable precision from 3-bit to 16-bit, balancing computation speed and accuracy."
66 |
67 | }
68 | ].map((feature, index) => (
69 |
73 |
{feature.title}
74 |
{feature.description}
75 |
76 | ))}
77 |
78 |
79 |
80 | {/* About Section */}
81 |
82 |
About Us
83 |
84 | SuperFloat implements custom quantization algorithms focusing on the Lottery Ticket Hypothesis (LTH)
85 | and Weight and Activation SuperFloat Quantization (WASQ) techniques for optimizing neural networks
86 | on edge devices.
87 |