├── 8 Python to FPGA ├── .gitignore ├── 3_manual_inference │ ├── hack_bug.png │ ├── final_output.png │ ├── final_custom_system.png │ ├── working_dma_config.png │ ├── vitis_software │ │ ├── generate_test_data.py │ │ └── main.c │ └── README.md ├── README.md └── 1_train_quantized_model.ipynb ├── 2 AXI IP Hello world custom LED driver ├── Hardware │ ├── constraints.xdc │ ├── led_axi_ip.v │ ├── LED_IP_v1_0.v │ └── LED_IP_v1_0_S00_AXI.v └── Software │ └── main.c ├── 3 HDL - A system verilog example ├── constraints_to_zybo.xdc └── tutorial_hdl.sv ├── 5 Petalinux on Zynq quick example └── commands.sh ├── 4 DMA tutorial └── dma_fft.c └── README.md /8 Python to FPGA/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | *.h -------------------------------------------------------------------------------- /8 Python to FPGA/3_manual_inference/hack_bug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0BAB1/BRH_Tutorials/HEAD/8 Python to FPGA/3_manual_inference/hack_bug.png -------------------------------------------------------------------------------- /8 Python to FPGA/3_manual_inference/final_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0BAB1/BRH_Tutorials/HEAD/8 Python to FPGA/3_manual_inference/final_output.png -------------------------------------------------------------------------------- /2 AXI IP Hello world custom LED driver/Hardware/constraints.xdc: -------------------------------------------------------------------------------- 1 | set_property IOSTANDARD LVCMOS33 [get_ports {led_out}] 2 | set_property PACKAGE_PIN D18 [get_ports {led_out}] -------------------------------------------------------------------------------- /8 Python to FPGA/3_manual_inference/final_custom_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0BAB1/BRH_Tutorials/HEAD/8 Python to FPGA/3_manual_inference/final_custom_system.png -------------------------------------------------------------------------------- /8 Python to FPGA/3_manual_inference/working_dma_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0BAB1/BRH_Tutorials/HEAD/8 Python to FPGA/3_manual_inference/working_dma_config.png -------------------------------------------------------------------------------- /3 HDL - A system verilog example/constraints_to_zybo.xdc: -------------------------------------------------------------------------------- 1 | # CONSTRAINTS 2 | 3 | set_property PACKAGE_PIN G15 [get_ports {sw1}] 4 | set_property IOSTANDARD LVCMOS33 [get_ports {sw1}] 5 | 6 | set_property PACKAGE_PIN P15 [get_ports {sw2}] 7 | set_property IOSTANDARD LVCMOS33 [get_ports {sw2}] 8 | 9 | set_property PACKAGE_PIN K18 [get_ports {btn}] 10 | set_property IOSTANDARD LVCMOS33 [get_ports {btn}] 11 | 12 | set_property PACKAGE_PIN M14 [get_ports {led_out}] 13 | set_property IOSTANDARD LVCMOS33 [get_ports {led_out}] 14 | 15 | # deal with the fact btn is not an actual clock 16 | set_property CLOCK_DEDICATED_ROUTE FALSE [get_nets {clk_IBUF_inst/O}] -------------------------------------------------------------------------------- /2 AXI IP Hello world custom LED driver/Software/main.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "xparameters.h" 3 | 4 | // give a 
better name to our base address (from xparameters.h)
5 | // REPLACE "XPAR_LED_IP_0_BASEADDR" with YOUR base address name
6 | // (search in xparameters.h, you'll find it !)
7 | // or directly the one you took from Vivado
8 | #define LED_ADDR XPAR_LED_IP_0_BASEADDR
9 | 
10 | int main(void){
11 | 
12 |     // write a 1 to the LED base address
13 |     volatile int* ptr = (volatile int*) LED_ADDR; // declare a pointer to the IP's register (volatile so the access is not optimized away)
14 |     *ptr = 0x1; // write 1 in the register at that address
15 | 
16 |     // this will go through our AXI hardware
17 |     // and make an AXI-Lite transaction to our IP
18 |     // and the rest is history (you wrote the HDL after all !)
19 | 
20 |     return 0;
21 | }
--------------------------------------------------------------------------------
/3 HDL - A system verilog example/tutorial_hdl.sv:
--------------------------------------------------------------------------------
1 | `timescale 1ns / 1ps
2 | //////////////////////////////////////////////////////////////////////////////////
3 | // Company:
4 | // Engineer:
5 | //
6 | // Create Date: 07/20/2024 11:42:12 AM
7 | // Design Name:
8 | // Module Name: tutorial_hdl
9 | // Project Name:
10 | // Target Devices:
11 | // Tool Versions:
12 | // Description:
13 | //
14 | // Dependencies:
15 | //
16 | // Revision:
17 | // Revision 0.01 - File Created
18 | // Additional Comments:
19 | //
20 | //////////////////////////////////////////////////////////////////////////////////
21 | 
22 | 
23 | module tutorial_hdl(
24 |     input logic sw1,
25 |     input logic sw2,
26 |     // (* clock_buffer_type="none" *) ===> deal with the fact btn is not an actual clock
27 |     (* clock_buffer_type="none" *) input logic btn,
28 |     output logic led_out
29 | );
30 | 
31 |     always @(posedge btn) begin
32 |         led_out <= sw1 & sw2;
33 |     end
34 | endmodule
35 | 
--------------------------------------------------------------------------------
/2 AXI IP Hello world custom LED driver/Hardware/led_axi_ip.v:
--------------------------------------------------------------------------------
1 | `timescale 1ns / 1ps
2 | //////////////////////////////////////////////////////////////////////////////////
3 | // Company:
4 | // Engineer:
5 | //
6 | // Create Date: 07/10/2024 06:40:48 PM
7 | // Design Name:
8 | // Module Name: led_axi_ip
9 | // Project Name:
10 | // Target Devices:
11 | // Tool Versions:
12 | // Description:
13 | //
14 | // Dependencies:
15 | //
16 | // Revision:
17 | // Revision 0.01 - File Created
18 | // Additional Comments:
19 | //
20 | //////////////////////////////////////////////////////////////////////////////////
21 | 
22 | 
23 | module led_axi_ip(
24 |     input wire [31:0] slv_reg,
25 |     input wire clk,
26 |     output reg led_out
27 | );
28 | 
29 |     // User submodule added to the generated AXI template: drive the LED from bit 0 of slave register 0
30 | 
31 |     always @(posedge clk) begin
32 |         if(slv_reg[0] == 1) begin
33 |             led_out <= 1;
34 |         end else begin
35 |             led_out <= 0;
36 |         end
37 |     end
38 | 
39 | endmodule
40 | 
--------------------------------------------------------------------------------
/5 Petalinux on Zynq quick example/commands.sh:
--------------------------------------------------------------------------------
1 | # This is a list of the commands I ran in the video
2 | # WARNING : This file is .sh for syntax only and is NOT meant to be executed
3 | 
4 | # Install petalinux.
5 | # LINK : https://www.xilinx.com/support/download/index.html/content/xilinx/en/downloadNav/embedded-design-tools.html
6 | 
7 | # Choose an installation path and let it install. 
You might have to install a bunch of libraries to make warnings and errors go away
8 | 
9 | # Then everything is pretty self-explanatory; here are the commands :
10 | 
11 | source ./settings.sh .
12 | 
13 | petalinux-create project --template zynq --name /path/to/your/project
14 | 
15 | petalinux-config --get-hw-description /path/to/XSA
16 | 
17 | # takes a long time, go for a walk (almost mandatory)
18 | petalinux-build
19 | 
20 | cd ./images/linux
21 | petalinux-package --boot --fsbl zynq_fsbl.elf --u-boot u-boot.elf --fpga system.bit --force
22 | 
23 | # If you want to know more about First Stage and Second Stage Boot Loaders, see the README.md at the root of the repo: I give resources on that there (the Zynq book)
24 | 
25 | # once on your board's UART : get its IP with :
26 | ifconfig
27 | 
28 | # you can then SSH into it using :
29 | 
30 | ssh petalinux@your.board.ip
31 | 
32 | # on your main machine
33 | # you can also test access to different addresses with
34 | ping ip.address
35 | 
36 | # I've had some feedback reporting a "ping : bad address" error.
37 | # try 192.168.1.1 (your home router most of the time)
38 | # If this does not work refer to these posts:
39 | # https://support.xilinx.com/s/question/0D52E00006hpTKeSAM/petalinux-201310-can-ping-outside-web-site?language=en_US
40 | # https://support.xilinx.com/s/question/0D52E00006hpRxBSAU/petalinux-build-ethernet-not-working-cannot-ping?language=en_US
41 | # As you can see, these bugs are sometimes pretty hard to troubleshoot, so start a discussion if you need deeper troubleshooting.
42 | 
--------------------------------------------------------------------------------
/2 AXI IP Hello world custom LED driver/Hardware/LED_IP_v1_0.v:
--------------------------------------------------------------------------------
1 | 
2 | `timescale 1 ns / 1 ps
3 | 
4 | 	module LED_IP_v1_0 #
5 | 	(
6 | 		// Users to add parameters here
7 | 
8 | 		// User parameters ends
9 | 		// Do not modify the parameters beyond this line
10 | 
11 | 
12 | 		// Parameters of Axi Slave Bus Interface S00_AXI
13 | 		parameter integer C_S00_AXI_DATA_WIDTH = 32,
14 | 		parameter integer C_S00_AXI_ADDR_WIDTH = 4
15 | 	)
16 | 	(
17 | 		// Users to add ports here
18 | 
19 | 		output wire led_out,
20 | 		// User ports ends
21 | 		// Do not modify the ports beyond this line
22 | 
23 | 
24 | 		// Ports of Axi Slave Bus Interface S00_AXI
25 | 		input wire s00_axi_aclk,
26 | 		input wire s00_axi_aresetn,
27 | 		input wire [C_S00_AXI_ADDR_WIDTH-1 : 0] s00_axi_awaddr,
28 | 		input wire [2 : 0] s00_axi_awprot,
29 | 		input wire s00_axi_awvalid,
30 | 		output wire s00_axi_awready,
31 | 		input wire [C_S00_AXI_DATA_WIDTH-1 : 0] s00_axi_wdata,
32 | 		input wire [(C_S00_AXI_DATA_WIDTH/8)-1 : 0] s00_axi_wstrb,
33 | 		input wire s00_axi_wvalid,
34 | 		output wire s00_axi_wready,
35 | 		output wire [1 : 0] s00_axi_bresp,
36 | 		output wire s00_axi_bvalid,
37 | 		input wire s00_axi_bready,
38 | 		input wire [C_S00_AXI_ADDR_WIDTH-1 : 0] s00_axi_araddr,
39 | 		input wire [2 : 0] s00_axi_arprot,
40 | 		input wire s00_axi_arvalid,
41 | 		output wire s00_axi_arready,
42 | 		output wire [C_S00_AXI_DATA_WIDTH-1 : 0] s00_axi_rdata,
43 | 		output wire [1 : 0] s00_axi_rresp,
44 | 		output wire s00_axi_rvalid,
45 | 		input wire s00_axi_rready
46 | 	);
47 | 	// Instantiation of Axi Bus Interface S00_AXI
48 | 
49 | 	// (led_out is already declared as an output port above, no extra net declaration is needed)
50 | 
51 | 	LED_IP_v1_0_S00_AXI # (
52 | 		.C_S_AXI_DATA_WIDTH(C_S00_AXI_DATA_WIDTH),
53 | 		.C_S_AXI_ADDR_WIDTH(C_S00_AXI_ADDR_WIDTH)
54 | 	) LED_IP_v1_0_S00_AXI_inst (
55 | 		.S_AXI_ACLK(s00_axi_aclk),
56 | 		.S_AXI_ARESETN(s00_axi_aresetn),
57 | 		
.S_AXI_AWADDR(s00_axi_awaddr),
58 | 		.S_AXI_AWPROT(s00_axi_awprot),
59 | 		.S_AXI_AWVALID(s00_axi_awvalid),
60 | 		.S_AXI_AWREADY(s00_axi_awready),
61 | 		.S_AXI_WDATA(s00_axi_wdata),
62 | 		.S_AXI_WSTRB(s00_axi_wstrb),
63 | 		.S_AXI_WVALID(s00_axi_wvalid),
64 | 		.S_AXI_WREADY(s00_axi_wready),
65 | 		.S_AXI_BRESP(s00_axi_bresp),
66 | 		.S_AXI_BVALID(s00_axi_bvalid),
67 | 		.S_AXI_BREADY(s00_axi_bready),
68 | 		.S_AXI_ARADDR(s00_axi_araddr),
69 | 		.S_AXI_ARPROT(s00_axi_arprot),
70 | 		.S_AXI_ARVALID(s00_axi_arvalid),
71 | 		.S_AXI_ARREADY(s00_axi_arready),
72 | 		.S_AXI_RDATA(s00_axi_rdata),
73 | 		.S_AXI_RRESP(s00_axi_rresp),
74 | 		.S_AXI_RVALID(s00_axi_rvalid),
75 | 		.S_AXI_RREADY(s00_axi_rready),
76 | 		.led_out(led_out)
77 | 	);
78 | 
79 | 	// Add user logic here
80 | 
81 | 	// User logic ends
82 | 
83 | 	endmodule
84 | 
--------------------------------------------------------------------------------
/8 Python to FPGA/3_manual_inference/vitis_software/generate_test_data.py:
--------------------------------------------------------------------------------
1 | # This generates a header file with N flattened MNIST samples.
2 | # Data is 8-bit UINT8 (or just unsigned char), meant to be imported
3 | # into a Tx buffer for use in Vitis (for DMA).
4 | # We can then use the labels to check the results of the FPGA run.
5 | 
6 | import torch
7 | from torchvision import datasets, transforms
8 | import random
9 | 
10 | def quantize_tensor(x, num_bits=8):
11 |     qmin = 0.
12 |     qmax = 2.**num_bits - 1.
13 |     min_val, max_val = x.min(), x.max()
14 | 
15 |     scale = (max_val - min_val) / (qmax - qmin)
16 |     initial_zero_point = qmin - min_val / scale
17 | 
18 |     zero_point = 0
19 |     if initial_zero_point < qmin:
20 |         zero_point = qmin
21 |     elif initial_zero_point > qmax:
22 |         zero_point = qmax
23 |     else:
24 |         zero_point = initial_zero_point
25 | 
26 |     zero_point = int(zero_point)
27 |     q_x = zero_point + x / scale
28 |     q_x.clamp_(qmin, qmax).round_()
29 | 
30 |     return q_x.byte()
31 | 
32 | # Load the FashionMNIST dataset, quantized to 0-255 on the fly
33 | transform = transforms.Compose([
34 |     transforms.ToTensor(),
35 |     transforms.Lambda(lambda x: quantize_tensor(x))
36 | ])
37 | 
38 | mnist_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
39 | 
40 | print(mnist_dataset[0][0])
41 | 
42 | # Select random samples (from the whole dataset, not just its first num_samples entries)
43 | num_samples = 100
44 | indices = random.sample(range(len(mnist_dataset)), num_samples)
45 | print(indices)
46 | samples = [mnist_dataset[i][0] for i in indices]
47 | labels = [mnist_dataset[i][1] for i in indices]
48 | print(labels)
49 | 
50 | # Generate C header file
51 | with open('mnist_samples.h', 'w') as f:
52 |     f.write("// This file has been auto-generated\n\n")
53 |     f.write("#ifndef MNIST_SAMPLES_H\n")
54 |     f.write("#define MNIST_SAMPLES_H\n\n")
55 |     f.write(f"#define NUM_SAMPLES {str(num_samples)}\n")
56 |     f.write("#define IMAGE_SIZE 784\n\n")
57 | 
58 |     f.write("const unsigned char mnist_samples[NUM_SAMPLES][IMAGE_SIZE] = {\n")
59 | 
60 |     for i, sample in enumerate(samples):
61 |         # The transform already quantized the sample to 0-255; just flatten it
62 |         img = sample.squeeze()
63 |         img_flat = img.reshape(-1).byte().tolist()
64 | 
65 |         f.write("    {")
66 |         f.write(", ".join(map(str, img_flat)))
67 |         f.write("},\n")
68 | 
69 |     f.write("};\n\n")
70 | 
71 |     f.write("const unsigned char mnist_labels[NUM_SAMPLES] = {\n")
72 | 
73 |     for j, label in enumerate(labels):
74 |         f.write("    " + str(label) + ",\n")
75 | 
76 |     f.write("};\n\n")
77 |     f.write("#endif // MNIST_SAMPLES_H\n")
78 | 
79 | print("MNIST samples have been generated and saved in 'mnist_samples.h'")
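# For reference, the emitted mnist_samples.h looks roughly like this
# (illustrative excerpt only; the actual pixel values and labels depend on the random draw):
#
#     #define NUM_SAMPLES 100
#     #define IMAGE_SIZE 784
#
#     const unsigned char mnist_samples[NUM_SAMPLES][IMAGE_SIZE] = {
#         {0, 0, 13, 254, /* ... 780 more bytes ... */},
#         /* ... 99 more rows ... */
#     };
#
#     const unsigned char mnist_labels[NUM_SAMPLES] = {
#         9,
#         0,
#         /* ... */
#     };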
--------------------------------------------------------------------------------
/4 DMA tutorial/dma_fft.c:
--------------------------------------------------------------------------------
1 | #include "xparameters.h"
2 | #include "xaxidma.h"
3 | #include <stdio.h>
4 | #include "xil_cache.h"
5 | #include <complex.h>
6 | 
7 | #define FFT_POINTS 8
8 | 
9 | XAxiDma AxiDma;
10 | int init_dma(XAxiDma *AxiDma){
11 |     XAxiDma_Config* CfgPtr;
12 |     int status;
13 | 
14 |     // Look up the configuration in the hardware configuration table
15 |     CfgPtr = XAxiDma_LookupConfig(XPAR_AXI_DMA_0_BASEADDR);
16 |     if (!CfgPtr) {
17 |         printf("No configuration found for %d\n", XPAR_AXI_DMA_0_BASEADDR);
18 |         return XST_FAILURE;
19 |     }
20 | 
21 |     // Initialize the DMA handle with the configuration structure
22 |     status = XAxiDma_CfgInitialize(AxiDma, CfgPtr);
23 |     if (status != XST_SUCCESS) {
24 |         printf("Initialization failed\n");
25 |         return XST_FAILURE;
26 |     }
27 | 
28 |     // Check for Scatter Gather mode, which is not supported in this example
29 |     if (XAxiDma_HasSg(AxiDma)) {
30 |         printf("Device configured as SG mode\n");
31 |         return XST_FAILURE;
32 |     }
33 | 
34 |     return XST_SUCCESS;
35 | }
36 | 
37 | int main(void){
38 |     int status = init_dma(&AxiDma);
39 |     if(status != XST_SUCCESS){
40 |         printf("Error while initializing the DMA\n");
41 |         return 1;
42 |     }
43 | 
44 |     // data buffers that will live in memory
45 |     float complex RxBuffer[FFT_POINTS] __attribute__ ((aligned (32)));
46 |     float complex TxBuffer[FFT_POINTS] __attribute__ ((aligned (32)));
47 | 
48 |     // fill the Tx buffer with actual data
49 |     for(int i = 0; i < FFT_POINTS; i++){
50 |         TxBuffer[i] = i;
51 |     }
52 | 
53 |     // flush the cache so the DMA sees the data in DDR, not a stale copy
54 |     Xil_DCacheFlushRange((UINTPTR)RxBuffer, FFT_POINTS * sizeof(float complex));
55 |     Xil_DCacheFlushRange((UINTPTR)TxBuffer, FFT_POINTS * sizeof(float complex));
56 | 
57 | 
58 | 
59 |     // START TRANSFERS
60 |     status = XAxiDma_SimpleTransfer(&AxiDma, (UINTPTR)TxBuffer, FFT_POINTS * sizeof(float complex), XAXIDMA_DMA_TO_DEVICE);
61 |     if (status != XST_SUCCESS) {
62 |         printf("DMA transfer to FFT failed\n");
63 |         return XST_FAILURE;
64 |     }
65 | 
66 |     status = XAxiDma_SimpleTransfer(&AxiDma, (UINTPTR)RxBuffer, FFT_POINTS * sizeof(float complex), XAXIDMA_DEVICE_TO_DMA);
67 |     if (status != XST_SUCCESS) {
68 |         printf("DMA transfer from FFT failed\n");
69 |         return XST_FAILURE;
70 |     }
71 | 
72 |     while (1) {
73 |         if (!(XAxiDma_Busy(&AxiDma, XAXIDMA_DEVICE_TO_DMA)) &&
74 |             !(XAxiDma_Busy(&AxiDma, XAXIDMA_DMA_TO_DEVICE))) {
75 |             break;
76 |         }
77 |     }
78 |     Xil_DCacheInvalidateRange((UINTPTR)RxBuffer, FFT_POINTS * sizeof(float complex)); // invalidate so the CPU reads what the DMA wrote, not stale cache lines
79 |     // make sure to use printf() / not xil_printf() here
80 |     for(int i = 0; i < FFT_POINTS; i++){
81 |         printf("FPGA value RxBuffer[%d] = %.2f %.2f i\n", i, crealf(RxBuffer[i]), cimagf(RxBuffer[i]));
82 |     }
83 | 
84 | 
85 |     return 0;
86 | }
--------------------------------------------------------------------------------
/8 Python to FPGA/3_manual_inference/vitis_software/main.c:
--------------------------------------------------------------------------------
1 | // This code is to be copy-pasted into your Vitis application component
2 | // alongside the data generated by the Python generator
3 | // PLEASE INCREASE YOUR STACK AND HEAP SIZE to avoid program stalls !
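// For reference, you set those sizes in the linker script (linker.ld in src/).
// Illustrative fragment using the 0xf000 values from this part's README; the
// symbol names follow the default Vitis-generated linker script and may differ
// in your project:
//     _STACK_SIZE = DEFINED(_STACK_SIZE) ? _STACK_SIZE : 0xf000;
//     _HEAP_SIZE  = DEFINED(_HEAP_SIZE)  ? _HEAP_SIZE  : 0xf000;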
4 | 
5 | #include "xparameters.h"
6 | #include "xaxidma.h"
7 | #include <stdio.h>
8 | #include <stdint.h>
9 | #include "mnist_samples.h"
10 | #include "xil_cache.h"
11 | XAxiDma AxiDma;
12 | 
13 | int init_dma(XAxiDma *AxiDma) {
14 |     XAxiDma_Config* CfgPtr;
15 |     int status;
16 | 
17 |     CfgPtr = XAxiDma_LookupConfig(XPAR_AXI_DMA_0_BASEADDR);
18 |     if (!CfgPtr) {
19 |         xil_printf("No configuration found for %d\n", XPAR_AXI_DMA_0_BASEADDR);
20 |         return XST_FAILURE;
21 |     }
22 | 
23 |     status = XAxiDma_CfgInitialize(AxiDma, CfgPtr);
24 |     if (status != XST_SUCCESS) {
25 |         xil_printf("Initialization failed\n");
26 |         return XST_FAILURE;
27 |     }
28 | 
29 |     if (XAxiDma_HasSg(AxiDma)) {
30 |         xil_printf("Device configured as SG mode\n");
31 |         return XST_FAILURE;
32 |     }
33 | 
34 |     return XST_SUCCESS;
35 | }
36 | 
37 | static inline void enable_pmu_cycle_counter(void) {
38 |     asm volatile("mcr p15, 0, %0, c9, c12, 1" :: "r"(1 << 31)); // Enable cycle counter
39 |     asm volatile("mcr p15, 0, %0, c9, c12, 0" :: "r"(1)); // Enable all counters
40 | }
41 | 
42 | static inline uint32_t read_pmu_cycle_counter(void) {
43 |     uint32_t value;
44 |     asm volatile("mrc p15, 0, %0, c9, c13, 0" : "=r"(value));
45 |     return value;
46 | }
47 | 
48 | int main(void) {
49 |     enable_pmu_cycle_counter();
50 |     uint32_t start, end;
51 | 
52 |     int status = init_dma(&AxiDma);
53 |     if(status != XST_SUCCESS) {
54 |         xil_printf("Error while initializing the DMA\n");
55 |         return 1;
56 |     }
57 | 
58 |     xil_printf("DMA initialized successfully\n");
59 | 
60 |     volatile char TxBuffer[IMAGE_SIZE*NUM_SAMPLES] __attribute__ ((aligned (32)));
61 |     volatile int RxBuffer[NUM_SAMPLES] __attribute__ ((aligned (32)));
62 | 
63 |     xil_printf("Memory init OKAY\n");
64 | 
65 |     for(int j = 0; j < NUM_SAMPLES; j++) {
66 |         for(int i = 0; i < IMAGE_SIZE; i++) {
67 |             // xil_printf("I : %d /// J : %d\n", i, j); // debug purpose
68 |             TxBuffer[j * IMAGE_SIZE + i] = (char)mnist_samples[j][i]; // copy the generated samples into the DMA Tx buffer
69 |         }
70 |     }
71 | 
72 |     xil_printf("Memory allocation OKAY\n");
73 | 
74 |     Xil_DCacheFlushRange((UINTPTR)TxBuffer, NUM_SAMPLES * IMAGE_SIZE * sizeof(char));
75 |     Xil_DCacheFlushRange((UINTPTR)RxBuffer, NUM_SAMPLES * sizeof(int)); // RxBuffer holds ints, flush its full size
76 | 
77 |     xil_printf("Cache flush OKAY, starting transfers...\n");
78 | 
79 |     start = read_pmu_cycle_counter();
80 |     for(int k = 0; k < NUM_SAMPLES; k++) {
81 | 
82 |         status = XAxiDma_SimpleTransfer(&AxiDma, (UINTPTR)&TxBuffer[k*IMAGE_SIZE], IMAGE_SIZE * sizeof(char), XAXIDMA_DMA_TO_DEVICE);
83 |         //printf("%i TO_DEVICE status code\n", status);
84 |         if (status != XST_SUCCESS) {
85 |             xil_printf("Error: DMA transfer to device failed\n");
86 |             return XST_FAILURE;
87 |         }
88 | 
89 |         status = XAxiDma_SimpleTransfer(&AxiDma, (UINTPTR)&RxBuffer[k], sizeof(int), XAXIDMA_DEVICE_TO_DMA);
90 |         //printf("%i FROM_DEVICE status code\n", status);
91 |         if (status != XST_SUCCESS) {
92 |             xil_printf("Error: DMA transfer from device failed\n");
93 |             return XST_FAILURE;
94 |         }
95 | 
96 |         while (XAxiDma_Busy(&AxiDma, XAXIDMA_DMA_TO_DEVICE) ||
97 |                XAxiDma_Busy(&AxiDma, XAXIDMA_DEVICE_TO_DMA)) {
98 |             ;
99 |         }
100 |         xil_printf("#%d iteration done\n", k);
101 |     }
102 |     end = read_pmu_cycle_counter();
103 | 
104 |     // Output the classifier's results & compute the accuracy
105 |     Xil_DCacheInvalidateRange((UINTPTR)RxBuffer, NUM_SAMPLES * sizeof(int)); // re-read the DMA's results from memory, not from stale cache lines
106 |     int valid = 0;
107 |     int accuracy_percentage;
108 | 
109 |     for(int i = 0; i < NUM_SAMPLES; i++) {
110 |         xil_printf("FPGA value RxBuffer[%d] = %d\n", i, RxBuffer[i]);
111 |         if(RxBuffer[i] == mnist_labels[i]){
112 |             valid++;
113 |         }
114 |     }
115 |     // Calculate accuracy as an integer percentage (multiply by 100 before dividing to preserve precision)
116 |     accuracy_percentage = (valid * 100) / NUM_SAMPLES;
117 |     xil_printf("\n\nMODEL ACCURACY = %d%%\n", accuracy_percentage);
118 | 
119 |     uint32_t cycles = end - start;
120 |     double time_ms = (double)cycles / 667000.0; // cycles -> ms, assuming the Zynq-7000's default 667 MHz CPU clock
121 |     printf("Execution time: %f milliseconds\n", time_ms);
122 | 
123 |     return 0;
124 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BRH Tutorials - Code Github repo
2 | 
3 | > [!TIP]
4 | > Contributions are welcome ! Especially since I'm prone to mistakes ;) You can also check out the [blog](https://0bab1.github.io/BRH/) where you'll find additional information for each tutorial, *but also exclusive content related to engineering in general and bonus tutorials to learn more valuable stuff in a text format*.
5 | 
6 | This repo contains **all the code you need** for the tutorials, ready to copy and paste.
7 | 
8 | It also contains **links to the associated [blog posts](https://0bab1.github.io/BRH/)** where you'll find additional tips & resources if you want to do it yourself or push the concepts I discuss a bit further.
9 | 
10 | I will also link all the very interesting resources that helped me in my own learning journey, for the curious among you.
11 | 
12 | > [!TIP]
13 | > In the future, I will focus more on "Project reviews" on YouTube and post actual technical tutorials here and on the blog. Stay tuned !
14 | 
15 | # Table of Contents
16 | 
17 | 1. [ZynQ in two minutes](#1--zynq-in-two-minutes)
18 | 2. [AXI IP, Hello world & custom LED driver](#2--axi-ip-hello-world--custom-led-driver)
19 | 3. [An introduction to HDLs](#3--an-introduction-to-hdls)
20 | 4. [DMA Tutorial with an FFT IP](#4--dma-tutorial-with-an-fft-ip)
21 | 5. [Quick Petalinux Example for Zynq](#5--quick-petalinux-example-for-zynq)
22 | 8. [Python AI to FPGA](#8--python-ai-model-to-fpga)
23 | 
24 | 
25 | # 1 : ZynQ in two minutes
26 | 
27 | [Video](https://www.youtube.com/watch?v=DQHTSelupDs)
28 | 
29 | This video is all about clarifying what ZynQ is and how it uses FPGAs to its advantage.
30 | 
31 | ## Code
32 | 
33 | There is no code necessary for this video
34 | 
35 | ## Blog post
36 | 
37 | Find the blog post [here](https://0bab1.github.io/BRH/posts/Zynq_in_120s/)
38 | 
39 | # 2 : AXI IP, Hello world & custom LED driver
40 | 
41 | [Video](https://www.youtube.com/watch?v=zJJTxOT37K4)
42 | 
43 | ## Code
44 | 
45 | In the video's directory you will find:
46 | 
47 | - Verilog code for the top wrapper
48 | - Verilog code for the AXI protocol handler in which we add our submodule
49 | - Verilog code for the submodule
50 | - Constraint file to "link" the IP's output to the LED's pin
51 | - The main.c file for the software part
52 | 
53 | ## Blog post
54 | 
55 | Find the blog post [here](https://0bab1.github.io/BRH/posts/Axi_led_ip/)
56 | 
57 | # 3 : An introduction to HDLs
58 | 
59 | [Video](https://www.youtube.com/watch?v=9wNddNaA_1o)
60 | 
61 | ## Code
62 | 
63 | In the video's directory you will find:
64 | 
65 | - just the basic HDL and constraints so you have a basic template, nothing fancy
66 | 
67 | ## Blog post
68 | 
69 | Find the blog post [here](https://0bab1.github.io/BRH/posts/HDL_in_120s/)
70 | 
71 | # 4 : DMA Tutorial with an FFT IP
72 | 
73 | [Video](https://youtu.be/aySO9jCKj9g)
74 | 
75 | ## Code
76 | 
77 | In the video's directory you will find:
78 | 
79 | - dma_fft.c : The main file that I took you through during the video
80 | 
81 | **IMPORTANT NOTE :** do not use xil_printf() but rather printf() to handle the display of floats correctly.
82 | 
83 | ## Blog post
84 | 
85 | Find the blog post [here](https://0bab1.github.io/BRH/posts/DMA_FFT_ON_ZYNQ/)
86 | 
87 | # 5 : Quick Petalinux Example for Zynq
88 | 
89 | Summer is here ! A lighthearted video, more or less entertainment, but keeping in mind the objective of delivering fast technical info for you to enjoy !
90 | 
91 | [Video](https://www.youtube.com/watch?v=SUBGtxwq7RY)
92 | 
93 | ## Code
94 | 
95 | In the video's directory you will find:
96 | 
97 | - commands.sh : a file (*NOT MEANT FOR EXECUTION*) with a list of the commands alongside some tips
98 | 
99 | 
100 | 
101 | ## Blog Post
102 | 
103 | There is no blog post for this one, but here are some resources :
104 | 
105 | - [Learn a LOT more on Linux for embedded applications and FPGA in the Zynq book PDF (from page 385, chapter 22 to chapter 24)](https://is.muni.cz/el/1433/jaro2015/PV191/um/The_Zynq_Book_ebook.pdf)
106 | - [Download Petalinux Here](https://www.xilinx.com/support/download/index.html/content/xilinx/en/downloadNav/embedded-design-tools.html)
107 | 
108 | Here are the topics covered in this resource :
109 | 
110 | - Chapter 22 _Linux: an overview_ : understand the basics of Linux to get some context
111 | - Chapter 23 _The Linux kernel_ : pretty interesting, but not needed to understand booting here
112 | - Chapter 24 _Linux booting_ : FSBL, SSBL, U-Boot, etc. are discussed
113 | 
114 | Combined with the video, this 1-hour read should give you a great overview of Linux for embedded solutions in no time (half a day of learning, well worth it) !
115 | 
116 | # 8 : Python AI Model to FPGA
117 | 
118 | THANK YOU SO MUCH for the success of the intro video !!! But now, let's get serious : time for the actual tutorial, and it's a BIG one ! 
119 | 
120 | [Video](https://www.youtube.com/@BRH_SoC)
121 | 
122 | ## Code
123 | 
124 | In the video's directory you will find:
125 | 
126 | - ALL THE RESOURCES
127 | 
128 | > This video is a big one, it has its own [Readme file](./8%20Python%20to%20FPGA//README.md). See you there !
129 | 
130 | ## Resources & Blog Post
131 | 
132 | Cf. the [Readme file](./8%20Python%20to%20FPGA//README.md).
133 | 
134 | You can also find the blog post [here](https://0bab1.github.io/BRH/posts/PY2FPGA/) for building the model's IP using FINN and [here](https://0bab1.github.io/BRH/posts/FPGA_MANUAL_INFERENCE/) to run manual inference on Zynq.
135 | 
--------------------------------------------------------------------------------
/8 Python to FPGA/3_manual_inference/README.md:
--------------------------------------------------------------------------------
1 | # Manual inference : A bonus sub-tutorial
2 | 
3 | > [!CAUTION]
4 | > This part is not like the others: it will be hard to do 100% on your own if you are a beginner. The video contains more details on this part (and includes a bit of debugging insight). Use the video as support to help you get through this part ! I also expect the reader to have some knowledge of Xilinx tools. If you are a complete beginner, check the resources in the [main readme](../README.md), or go directly to my channel for my early tutorials.
5 | 
6 | ## Why manual inference ?
7 | 
8 | FINN provides a very nice runtime environment based on PYNQ to run FPGA inference directly from a notebook.
9 | 
10 | However, you might want to run inference on other (unsupported) FPGA boards. Given that only 2 boards are made for PYNQ and only [a few](http://www.pynq.io/boards.html) are officially supported at the moment, we will do the FPGA inference manually.
11 | This allows for better understanding and more flexibility for your future projects.
12 | 
13 | To do so, we will go over various steps to make it work.
14 | 
15 | - 0 => Find the FINN exported stitched IP and integrate it into your Vivado project
16 | - 1 => Create our own "glue logic" IP to interface between the model and Xilinx's DMA IP
17 | - 2 => Run synthesis & implementation and export to Vitis
18 | - 3 => Create software to run inference using the DMA drivers
19 | 
20 | ## What is in this folder ?
21 | 
22 | This folder contains :
23 | 
24 | - HARDWARE : The glue logic IP's SystemVerilog code (use the ```git clone --recursive``` flag to clone the sub-repo if needed, or access it [here](https://github.com/0BAB1/Axi-Stream-FIFO-for-FINN))
25 | - SOFTWARE : The code we'll run in Vitis
26 | - SOFTWARE : A data generator for C inference of MNIST data.
27 | 
28 | ## How do I run manual inference ?
29 | 
30 | > Note that this is just a minimalist example and your applications/projects might need another architecture.
31 | 
32 | ## 0 : Find the FINN exported stitched IP and integrate it into your Vivado project
33 | 
34 | During PART 2, we used FINN to generate a "stitched IP" that conveniently generated a Zynq project for us. 
Regardless of the workflow you chose (FINN has multiple workflows, like a [CLI](https://finn.readthedocs.io/en/latest/command_line.html) or [Custom builders](https://finn.readthedocs.io/en/latest/command_line.html), that do everything we did in the lab in an automated way), you will always have a collection of outputs including :
35 | 
36 | - The different layers as IPs
37 | - A stitched IP
38 | 
39 | After PART 2, you can access the stitched IP by opening the ```/tmp/finn_dev_yourusername``` folder, where you will find a range of output products.
40 | We are going to focus on the ```vivado_zynq_xxx.../``` folder and open the .xpr using Vivado.
41 | 
42 | **Important :** Before doing anything, go into the project settings and change the target part/board !
43 | 
44 | > For this part, the video teaser of the tutorial includes more details. It was meant to be useful ! Don't hesitate to go back to the video on part 3 to see where my output model from PART 2 was, how I found it and how I prepped the Vivado project.
45 | 
46 | ## 1 : Create our own "glue logic" IP to interface between the model and Xilinx's DMA IP
47 | 
48 | DMA tutorial [here](https://www.youtube.com/watch?v=aySO9jCKj9g) for beginners.
49 | 
50 | > You have all the HDL for the custom glue logic in this repo (use ```git clone --recursive``` to clone it directly with this repo). Again, check the video if you have any doubt on how to make a custom IP block from HDL.
51 | 
52 | With the output Vivado project opened, we will now proceed to delete every IP used in the block design **except for the stitched IP**; we will keep it and build our system around it.
53 | 
54 | As we can see, the stitched IP is very conveniently packed with simplified stream interfaces and expects an 8-bit input for the data, just as planned !
55 | 
56 | But there is a problem : to transfer data, we will use Xilinx's DMA IP, which needs the TLAST signal to function properly.
57 | 
58 | You can create a custom FIFO IP using the HDL in this repo's folder in order to handle the correct signal assertion for the DMA to function properly.
59 | 
60 | Then add this custom IP right after the FINN IP.
61 | 
62 | We then add the usual DMA etc. to send data to the model via AXI Stream directly from memory. [Here is a small tutorial illustrating how to use DMA](https://www.youtube.com/watch?v=aySO9jCKj9g)
63 | 
64 | The final custom system should then look like this :
65 | 
66 | ![Final system image](./final_custom_system.png)
67 | 
68 | > [!CAUTION]
69 | > As you saw in the video, some weird stuff happens when you configure the custom FIFO data width to 32 bits to match that of the DMA. Here's a good way to get around this :
70 | > it consists of not changing the default 8-bit data width of the FIFO and letting the DMA operate in 32 bits with memory.
71 | > To avoid a mismatch between the 8-bit FIFO and the 32-bit DMA interfaces, we connect them manually using constants and concat blocks; this process is
72 | > described in the video.
73 | 
74 | **ANYWAY**, here are the working configs used in the tutorial at the end, with the concat and const manual connections :
75 | 
76 | ![image of the working hacky system](hack_bug.png)
77 | 
78 | Edit :
79 | 
80 | - The TKEEP const block is there to tell the DMA which bytes are valid. As we send 8-bit data, only the 1st byte is valid, so set it to 0b0001
81 | - The 24-bit const blocks are 0s; they just fill the gap between the 8 bits we send and the 32 bits required, to avoid side effects.
82 | - Don't forget to ensure the concat block puts your 8 bits in the lowest byte of the 32 bits ! (see the HDL sketch below)
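If you prefer to see the concat/const trick written down as HDL, here is a minimal sketch of the equivalent width adaptation. It is illustrative only: the module and signal names are made up (this is not the repo's FIFO IP), and it assumes an 8-bit AXI-Stream master on the FIFO side feeding the DMA's 32-bit S2MM slave interface.

```systemverilog
// Minimal 8-bit -> 32-bit AXI-Stream padding, mirroring the concat/const blocks.
module axis_8to32_pad (
    // 8-bit stream coming from the custom FIFO (hypothetical names)
    input  logic        fifo_tvalid,
    output logic        fifo_tready,
    input  logic [7:0]  fifo_tdata,
    input  logic        fifo_tlast,
    // 32-bit stream going to the DMA S2MM port (hypothetical names)
    output logic        dma_tvalid,
    input  logic        dma_tready,
    output logic [31:0] dma_tdata,
    output logic [3:0]  dma_tkeep,
    output logic        dma_tlast
);
    assign dma_tvalid  = fifo_tvalid;
    assign fifo_tready = dma_tready;           // pass the handshake straight through
    assign dma_tdata   = {24'b0, fifo_tdata};  // useful byte in the LSBs, 24 constant zeros above
    assign dma_tkeep   = 4'b0001;              // only byte 0 carries data (the TKEEP const block)
    assign dma_tlast   = fifo_tlast;
endmodule
```

Functionally this is exactly what the constant and concat blocks do in the block design, just written explicitly.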
83 | 
84 | And the final DMA config :
85 | 
86 | ![final dma config that worked for me](working_dma_config.png)
87 | 
88 | ## 2 : Run synthesis & implementation and export to Vitis
89 | 
90 | This step is not really a step per se: you simply have to generate a bitstream and export the resulting hardware to Vitis, so we can use the drivers to write some software
91 | 
92 | > Also briefly described in the video; my channel also has more [basic tutorials](https://www.youtube.com/watch?v=zJJTxOT37K4) on this workflow.
93 | 
94 | ## 3 : Create software to run inference using the DMA drivers
95 | 
96 | Once in Vitis, with platform & app components created, you can take inspiration from the code in the repo's vitis software folder.
97 | 
98 | You also have, in this repo, a [generate_test_data.py](./vitis_software/generate_test_data.py) script that will generate random quantized (UINT8) data alongside the corresponding labels, and put these in a header file to use in our software.
99 | 
100 | > Note : Be sure that your heap and stack sizes comply with the number of samples you will load in memory (I use 0xf000 for both stack and heap to run this example). You can modify this parameter in linker.ld in the src/ folder. Or you can lower the number of generated samples in the Python generator (e.g. a 0x2000 size fits 20 samples pretty well on my side).
101 | 
102 | At the end of the day, you should have (in Vitis):
103 | 
104 | - src/
105 |   - [main.c](./vitis_software/main.c) (with ```#include "mnist_samples.h"```)
106 |   - mnist_samples.h, generated by executing [this python script](./vitis_software/generate_test_data.py)
107 |   - linker.ld (modified linker to increase heap & stack size)
108 | 
109 | When this is done, you can simply build the code and flash it onto your FPGA.
110 | 
111 | ## 4 : What is next ?
112 | 
113 | This is the part where you figure out whether you need to spend hours debugging or get to enjoy a nicely built project.
114 | 
115 | - Open a UART tty
116 | - Run debugging mode, observe the results
117 | - Compare the FPGA accuracy with the Python simulations (assert they are equal)
118 | - Use a System ILA to debug DMA & driver problems
119 | 
120 | Do not hesitate to open an Issue or contact me if you have a problem. Here is a final output example :
121 | 
122 | ![final output image](final_output.png)
123 | 
124 | > If you need to debug stuff, use ILA. If you are lost, reach out to me in the video comments.
125 | 
126 | Quick note : I'm planning to do an ILA tutorial in the next few months in case the next videos take too much time to produce, in which case I will post the link here. In the meantime, the internet is full of tutorials for this; they just tend not to be the best quality...
127 | 
--------------------------------------------------------------------------------
/8 Python to FPGA/README.md:
--------------------------------------------------------------------------------
1 | # Python to FPGA video : Kickstart your FPGA AI Projects
2 | 
3 | > [!IMPORTANT]
4 | > This is a "DIY" tutorial derived from my 9-hour course. As I cannot teach all the different concepts here, you'll have access to all the resources you need to learn. If you start from 0 (given that you have some basics on FPGAs and such), I'd say this project can take a few days to complete with a good understanding. Hope you enjoy this project, good luck, and see you on the other side !
5 | 
6 | This folder contains EVERYTHING you need to kickstart your FPGA AI projects.
7 | 
8 | > It contains a FULL TUTORIAL meant to be "DIY", in which you will be guided to deploy your own AI model on FPGA ! 
9 | 
10 | To make the video as enjoyable as possible (it is not a course, because I sell the actual course to universities), I made the choice to let the viewer do their own research when it comes to the various details, **BUT** you are not alone ! This README acts like a small course, with all of the resources needed for a full understanding of what you are doing.
11 | 
12 | ## What will you find in this repo
13 | 
14 | ### PART 0 : This readme
15 | 
16 | The goal of this readme is to serve as a "lighthouse" you come back to, to gather resources, tips and various pointers to complete this project at your own pace.
17 | 
18 | > I recommend you read it fully before starting PART 1.
19 | 
20 | You have some explanations below on where and how to start this project.
21 | 
22 | ### PART 1 (Notebook) : Create and train your first QNN
23 | 
24 | Using Brevitas, you will leverage Quantization-Aware Training to create and train a classifier that recognizes samples from the FashionMNIST dataset.
25 | 
26 | The notebooks have plenty of explanations, and you can use the resources at the end of this readme to fill any gaps in your understanding.
27 | 
28 | ### PART 2 (Notebook) : Use FINN to convert your QNN to hardware
29 | 
30 | Learn how to use FINN to manipulate an ONNX graph until said graph has been reduced to a hardware representation.
31 | 
32 | ### PART 3 (Vivado & Vitis) : Run a simple manual inference application
33 | 
34 | If you know your way around Vivado and Vitis, this is a great exercise ! Create an SoC, write firmware for the DMA and finally get your results through UART, from a model run exclusively on FPGA ! YOUR OWN FPGA hardware !
35 | 
36 | You'll have a separate readme for PART 3, but more on that below :)
37 | 
38 | ## Prerequisites
39 | 
40 | ### Before listing the prerequisites
41 | 
42 | Regarding knowledge and understanding, you might not understand everything laid out here :
43 | 
44 | - Regarding hardware tools : take a look at my channel, I have basic tutorials to learn hardware project workflows
45 | - Regarding knowledge : Take a look at the resources below
46 | 
47 | ### Actual prerequisites
48 | 
49 | - Install [Vivado, Vitis & Vitis HLS](https://www.xilinx.com/support/download/index.html/content/xilinx/en/downloadNav/vivado-design-tools/2023-2.html), tutorial done with the 2023.2 version.
50 | - Linux, or a good mastery of your Docker Windows environment (please use Linux).
51 | - Some knowledge of :
52 |   - AI, Python, C, PyTorch
53 |   - Data and network quantization
54 |   - Vivado, Vitis
55 |   - HLS (you just need to know it exists, as FINN will handle it)
56 |   - Docker, Linux, Jupyter notebooks
57 |   - ONNX (You can learn that along the way, don't worry, but [look it up](https://fr.wikipedia.org/wiki/Open_Neural_Network_Exchange) if needed)
58 |   - And of course, FPGAs
59 | - Have an FPGA board. If you have a PYNQ-compatible board, you can skip the manual inference part
60 | 
61 | ## All set ? Let's get to work
62 | 
63 | > Okay, so where do I start ? What do I do now ?
- You (maybe)
64 | 
65 | Well here is how I suggest you go about this :
66 | 
67 | - First, watch the "[tutorial](https://www.youtube.com/watch?v=VsXMlSB6Yq4)" that summarizes what we'll do here.
68 | - Set up your dev environment; cf. the resources below for the official up-to-date tutorial from FINN (fairly simple) **AND** the notes on preparing the dev environment just below this bullet list
69 | - After that, clone this repo, cd into this subdirectory and simply follow the notebook instructions, starting from PART 1 obviously ;)
70 | 
71 | If you have a doubt, come back to the video or use the resources below to learn more about each and every point of interest.
72 | 
73 | ### Preparing the dev environment
74 | 
75 | You have the official tutorial in the resources, but here are some points of attention :
76 | 
77 | - use the `bash run-docker.sh notebook` command and **NOT** `sh run-docker.sh notebook`
78 | - It will give you the Jupyter URL
79 | - If you use your notebook in VS Code, simply :
80 |   - Select kernel...
81 |   - Existing Jupyter server...
82 |   - Paste in the URL that looks like this : `http://127.0.0.1:8888/tree?token=`
83 | - If you use Jupyter : simply go to the URL : `http://127.0.0.1:8888/tree?token=`
84 | 
85 | ### Your coding starting point
86 | 
87 | Once your dev environment is set up, you can start with the first notebook, explaining how to use Brevitas with a simple FC net on FashionMNIST.
88 | 
89 | Every notebook is in this subdirectory.
90 | 
91 | ### Regarding The Manual Inference
92 | 
93 | Manual inference implies you use code snippets and Xilinx tools rather than a notebook; the material is located in a dedicated sub-directory, alongside another [specialized README](./3_manual_inference/README.md) to help you figure it out.
94 | 
95 | See you on the other side !
96 | 
97 | ## General resources
98 | 
99 | > [!IMPORTANT]
100 | > Notebooks also contain context-dependent resources to explore and understand important topics.
101 | 
102 | If you want to truly understand what you are doing and become proficient in engineering, these resources are your go-to.
103 | 
104 | I personally used ALL of these resources, so I will annotate each one with some tips so you don't have to waste time on less valuable knowledge.
105 | 
106 | | Resource name | Link | Comment |
107 | | --- | --- | --- |
108 | | FINN : notebooks | [GitHub repo](https://github.com/Xilinx/finn/tree/main/notebooks) | **MOST** of my tutorial material is based on these notebooks, **HUGE** credit to the FINN dev team ! |
109 | | FINN : setup dev environment | [Link](https://finn.readthedocs.io/en/latest/getting_started.html#running-finn-in-docker) | Launch a notebook environment in Docker. Note : use the `bash` command and **NOT** `sh` |
110 | | My tutorials on Xilinx's HW tools and ZYNQ FPGA | [Youtube playlist](https://www.youtube.com/watch?v=DQHTSelupDs&list=PLCn4eX6oSgMbgI4WERry0XnHiVysNqtGc) | To get started on all the hardware tools we'll use, from A to Z, fast. 
|
111 | | Quantization | [Youtube Link](https://www.youtube.com/watch?v=0VdNflU08yA) | Watch it all, great lecture |
112 | | Quantization : 2 | [Paper link](https://arxiv.org/pdf/2106.08295) | Great paper, very accessible for most people. _(You can skip the backward propagation part.)_ |
113 | | Streamlining | [Paper link](https://arxiv.org/pdf/1709.04060) | Streamlining is the main FINN "math transformation" of our model, which converts it to a fully-integer form (some QNNs still have FP biases, for example). After that, we only have INTs in the model, making it MUCH easier to work with in hardware. |
114 | | FINN : how does it work | [Paper 1 Link](https://arxiv.org/pdf/1612.07119)
[Paper 2 Link](https://arxiv.org/pdf/1809.04570) | Very nice papers explaining how FINN operates to generate HW from a streamlined model (also benchmarks). The 1st was on the previous FINN version for Binary Neural Networks (BNNs), which makes it easier to read in my opinion. |
115 | | Write your own DMA firmware | [Youtube link](https://www.youtube.com/watch?v=aySO9jCKj9g) | This video is from me and explains everything you need to know about DMA. You can also simply copy-paste the firmware I wrote, but it is interesting for a better understanding if needed. |
116 | 
--------------------------------------------------------------------------------
/2 AXI IP Hello world custom LED driver/Hardware/LED_IP_v1_0_S00_AXI.v:
--------------------------------------------------------------------------------
1 | 
2 | `timescale 1 ns / 1 ps
3 | 
4 | 	module LED_IP_v1_0_S00_AXI #
5 | 	(
6 | 		// Users to add parameters here
7 | 
8 | 		// User parameters ends
9 | 		// Do not modify the parameters beyond this line
10 | 
11 | 		// Width of S_AXI data bus
12 | 		parameter integer C_S_AXI_DATA_WIDTH = 32,
13 | 		// Width of S_AXI address bus
14 | 		parameter integer C_S_AXI_ADDR_WIDTH = 4
15 | 	)
16 | 	(
17 | 		// Users to add ports here
18 | 
19 | 		output wire led_out,
20 | 
21 | 		// User ports ends
22 | 		// Do not modify the ports beyond this line
23 | 
24 | 		// Global Clock Signal
25 | 		input wire S_AXI_ACLK,
26 | 		// Global Reset Signal. This Signal is Active LOW
27 | 		input wire S_AXI_ARESETN,
28 | 		// Write address (issued by master, accepted by Slave)
29 | 		input wire [C_S_AXI_ADDR_WIDTH-1 : 0] S_AXI_AWADDR,
30 | 		// Write channel Protection type. This signal indicates the
31 | 		// privilege and security level of the transaction, and whether
32 | 		// the transaction is a data access or an instruction access.
33 | 		input wire [2 : 0] S_AXI_AWPROT,
34 | 		// Write address valid. This signal indicates that the master is signaling
35 | 		// valid write address and control information.
36 | 		input wire S_AXI_AWVALID,
37 | 		// Write address ready. This signal indicates that the slave is ready
38 | 		// to accept an address and associated control signals.
39 | 		output wire S_AXI_AWREADY,
40 | 		// Write data (issued by master, accepted by Slave)
41 | 		input wire [C_S_AXI_DATA_WIDTH-1 : 0] S_AXI_WDATA,
42 | 		// Write strobes. This signal indicates which byte lanes hold
43 | 		// valid data. There is one write strobe bit for each eight
44 | 		// bits of the write data bus.
45 | 		input wire [(C_S_AXI_DATA_WIDTH/8)-1 : 0] S_AXI_WSTRB,
46 | 		// Write valid. This signal indicates that valid write
47 | 		// data and strobes are available.
48 | 		input wire S_AXI_WVALID,
49 | 		// Write ready. This signal indicates that the slave
50 | 		// can accept the write data.
51 | 		output wire S_AXI_WREADY,
52 | 		// Write response. This signal indicates the status
53 | 		// of the write transaction.
54 | 		output wire [1 : 0] S_AXI_BRESP,
55 | 		// Write response valid. This signal indicates that the channel
56 | 		// is signaling a valid write response.
57 | 		output wire S_AXI_BVALID,
58 | 		// Response ready. This signal indicates that the master
59 | 		// can accept a write response.
60 | 		input wire S_AXI_BREADY,
61 | 		// Read address (issued by master, accepted by Slave)
62 | 		input wire [C_S_AXI_ADDR_WIDTH-1 : 0] S_AXI_ARADDR,
63 | 		// Protection type. This signal indicates the privilege
64 | 		// and security level of the transaction, and whether the
65 | 		// transaction is a data access or an instruction access.
66 | 		input wire [2 : 0] S_AXI_ARPROT,
67 | 		// Read address valid. 
This signal indicates that the channel 68 | // is signaling valid read address and control information. 69 | input wire S_AXI_ARVALID, 70 | // Read address ready. This signal indicates that the slave is 71 | // ready to accept an address and associated control signals. 72 | output wire S_AXI_ARREADY, 73 | // Read data (issued by slave) 74 | output wire [C_S_AXI_DATA_WIDTH-1 : 0] S_AXI_RDATA, 75 | // Read response. This signal indicates the status of the 76 | // read transfer. 77 | output wire [1 : 0] S_AXI_RRESP, 78 | // Read valid. This signal indicates that the channel is 79 | // signaling the required read data. 80 | output wire S_AXI_RVALID, 81 | // Read ready. This signal indicates that the master can 82 | // accept the read data and response information. 83 | input wire S_AXI_RREADY 84 | ); 85 | 86 | // AXI4LITE signals 87 | reg [C_S_AXI_ADDR_WIDTH-1 : 0] axi_awaddr; 88 | reg axi_awready; 89 | reg axi_wready; 90 | reg [1 : 0] axi_bresp; 91 | reg axi_bvalid; 92 | reg [C_S_AXI_ADDR_WIDTH-1 : 0] axi_araddr; 93 | reg axi_arready; 94 | reg [C_S_AXI_DATA_WIDTH-1 : 0] axi_rdata; 95 | reg [1 : 0] axi_rresp; 96 | reg axi_rvalid; 97 | 98 | // Example-specific design signals 99 | // local parameter for addressing 32 bit / 64 bit C_S_AXI_DATA_WIDTH 100 | // ADDR_LSB is used for addressing 32/64 bit registers/memories 101 | // ADDR_LSB = 2 for 32 bits (n downto 2) 102 | // ADDR_LSB = 3 for 64 bits (n downto 3) 103 | localparam integer ADDR_LSB = (C_S_AXI_DATA_WIDTH/32) + 1; 104 | localparam integer OPT_MEM_ADDR_BITS = 1; 105 | //---------------------------------------------- 106 | //-- Signals for user logic register space example 107 | //------------------------------------------------ 108 | //-- Number of Slave Registers 4 109 | reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg0; 110 | reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg1; 111 | reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg2; 112 | reg [C_S_AXI_DATA_WIDTH-1:0] slv_reg3; 113 | wire slv_reg_rden; 114 | wire slv_reg_wren; 115 | reg [C_S_AXI_DATA_WIDTH-1:0] reg_data_out; 116 | integer byte_index; 117 | reg aw_en; 118 | 119 | // I/O Connections assignments 120 | 121 | assign S_AXI_AWREADY = axi_awready; 122 | assign S_AXI_WREADY = axi_wready; 123 | assign S_AXI_BRESP = axi_bresp; 124 | assign S_AXI_BVALID = axi_bvalid; 125 | assign S_AXI_ARREADY = axi_arready; 126 | assign S_AXI_RDATA = axi_rdata; 127 | assign S_AXI_RRESP = axi_rresp; 128 | assign S_AXI_RVALID = axi_rvalid; 129 | // Implement axi_awready generation 130 | // axi_awready is asserted for one S_AXI_ACLK clock cycle when both 131 | // S_AXI_AWVALID and S_AXI_WVALID are asserted. axi_awready is 132 | // de-asserted when reset is low. 133 | 134 | always @( posedge S_AXI_ACLK ) 135 | begin 136 | if ( S_AXI_ARESETN == 1'b0 ) 137 | begin 138 | axi_awready <= 1'b0; 139 | aw_en <= 1'b1; 140 | end 141 | else 142 | begin 143 | if (~axi_awready && S_AXI_AWVALID && S_AXI_WVALID && aw_en) 144 | begin 145 | // slave is ready to accept write address when 146 | // there is a valid write address and write data 147 | // on the write address and data bus. This design 148 | // expects no outstanding transactions. 149 | axi_awready <= 1'b1; 150 | aw_en <= 1'b0; 151 | end 152 | else if (S_AXI_BREADY && axi_bvalid) 153 | begin 154 | aw_en <= 1'b1; 155 | axi_awready <= 1'b0; 156 | end 157 | else 158 | begin 159 | axi_awready <= 1'b0; 160 | end 161 | end 162 | end 163 | 164 | // Implement axi_awaddr latching 165 | // This process is used to latch the address when both 166 | // S_AXI_AWVALID and S_AXI_WVALID are valid. 
167 | 168 | always @( posedge S_AXI_ACLK ) 169 | begin 170 | if ( S_AXI_ARESETN == 1'b0 ) 171 | begin 172 | axi_awaddr <= 0; 173 | end 174 | else 175 | begin 176 | if (~axi_awready && S_AXI_AWVALID && S_AXI_WVALID && aw_en) 177 | begin 178 | // Write Address latching 179 | axi_awaddr <= S_AXI_AWADDR; 180 | end 181 | end 182 | end 183 | 184 | // Implement axi_wready generation 185 | // axi_wready is asserted for one S_AXI_ACLK clock cycle when both 186 | // S_AXI_AWVALID and S_AXI_WVALID are asserted. axi_wready is 187 | // de-asserted when reset is low. 188 | 189 | always @( posedge S_AXI_ACLK ) 190 | begin 191 | if ( S_AXI_ARESETN == 1'b0 ) 192 | begin 193 | axi_wready <= 1'b0; 194 | end 195 | else 196 | begin 197 | if (~axi_wready && S_AXI_WVALID && S_AXI_AWVALID && aw_en ) 198 | begin 199 | // slave is ready to accept write data when 200 | // there is a valid write address and write data 201 | // on the write address and data bus. This design 202 | // expects no outstanding transactions. 203 | axi_wready <= 1'b1; 204 | end 205 | else 206 | begin 207 | axi_wready <= 1'b0; 208 | end 209 | end 210 | end 211 | 212 | // Implement memory mapped register select and write logic generation 213 | // The write data is accepted and written to memory mapped registers when 214 | // axi_awready, S_AXI_WVALID, axi_wready and S_AXI_WVALID are asserted. Write strobes are used to 215 | // select byte enables of slave registers while writing. 216 | // These registers are cleared when reset (active low) is applied. 217 | // Slave register write enable is asserted when valid address and data are available 218 | // and the slave is ready to accept the write address and write data. 219 | assign slv_reg_wren = axi_wready && S_AXI_WVALID && axi_awready && S_AXI_AWVALID; 220 | 221 | always @( posedge S_AXI_ACLK ) 222 | begin 223 | if ( S_AXI_ARESETN == 1'b0 ) 224 | begin 225 | slv_reg0 <= 0; 226 | slv_reg1 <= 0; 227 | slv_reg2 <= 0; 228 | slv_reg3 <= 0; 229 | end 230 | else begin 231 | if (slv_reg_wren) 232 | begin 233 | case ( axi_awaddr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB] ) 234 | 2'h0: 235 | for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) 236 | if ( S_AXI_WSTRB[byte_index] == 1 ) begin 237 | // Respective byte enables are asserted as per write strobes 238 | // Slave register 0 239 | slv_reg0[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; 240 | end 241 | 2'h1: 242 | for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) 243 | if ( S_AXI_WSTRB[byte_index] == 1 ) begin 244 | // Respective byte enables are asserted as per write strobes 245 | // Slave register 1 246 | slv_reg1[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; 247 | end 248 | 2'h2: 249 | for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) 250 | if ( S_AXI_WSTRB[byte_index] == 1 ) begin 251 | // Respective byte enables are asserted as per write strobes 252 | // Slave register 2 253 | slv_reg2[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; 254 | end 255 | 2'h3: 256 | for ( byte_index = 0; byte_index <= (C_S_AXI_DATA_WIDTH/8)-1; byte_index = byte_index+1 ) 257 | if ( S_AXI_WSTRB[byte_index] == 1 ) begin 258 | // Respective byte enables are asserted as per write strobes 259 | // Slave register 3 260 | slv_reg3[(byte_index*8) +: 8] <= S_AXI_WDATA[(byte_index*8) +: 8]; 261 | end 262 | default : begin 263 | slv_reg0 <= slv_reg0; 264 | slv_reg1 <= slv_reg1; 265 | slv_reg2 <= slv_reg2; 266 | slv_reg3 <= slv_reg3; 
267 | end 268 | endcase 269 | end 270 | end 271 | end 272 | 273 | // Implement write response logic generation 274 | // The write response and response valid signals are asserted by the slave 275 | // when axi_wready, S_AXI_WVALID, axi_wready and S_AXI_WVALID are asserted. 276 | // This marks the acceptance of address and indicates the status of 277 | // write transaction. 278 | 279 | always @( posedge S_AXI_ACLK ) 280 | begin 281 | if ( S_AXI_ARESETN == 1'b0 ) 282 | begin 283 | axi_bvalid <= 0; 284 | axi_bresp <= 2'b0; 285 | end 286 | else 287 | begin 288 | if (axi_awready && S_AXI_AWVALID && ~axi_bvalid && axi_wready && S_AXI_WVALID) 289 | begin 290 | // indicates a valid write response is available 291 | axi_bvalid <= 1'b1; 292 | axi_bresp <= 2'b0; // 'OKAY' response 293 | end // work error responses in future 294 | else 295 | begin 296 | if (S_AXI_BREADY && axi_bvalid) 297 | //check if bready is asserted while bvalid is high) 298 | //(there is a possibility that bready is always asserted high) 299 | begin 300 | axi_bvalid <= 1'b0; 301 | end 302 | end 303 | end 304 | end 305 | 306 | // Implement axi_arready generation 307 | // axi_arready is asserted for one S_AXI_ACLK clock cycle when 308 | // S_AXI_ARVALID is asserted. axi_awready is 309 | // de-asserted when reset (active low) is asserted. 310 | // The read address is also latched when S_AXI_ARVALID is 311 | // asserted. axi_araddr is reset to zero on reset assertion. 312 | 313 | always @( posedge S_AXI_ACLK ) 314 | begin 315 | if ( S_AXI_ARESETN == 1'b0 ) 316 | begin 317 | axi_arready <= 1'b0; 318 | axi_araddr <= 32'b0; 319 | end 320 | else 321 | begin 322 | if (~axi_arready && S_AXI_ARVALID) 323 | begin 324 | // indicates that the slave has acceped the valid read address 325 | axi_arready <= 1'b1; 326 | // Read address latching 327 | axi_araddr <= S_AXI_ARADDR; 328 | end 329 | else 330 | begin 331 | axi_arready <= 1'b0; 332 | end 333 | end 334 | end 335 | 336 | // Implement axi_arvalid generation 337 | // axi_rvalid is asserted for one S_AXI_ACLK clock cycle when both 338 | // S_AXI_ARVALID and axi_arready are asserted. The slave registers 339 | // data are available on the axi_rdata bus at this instance. The 340 | // assertion of axi_rvalid marks the validity of read data on the 341 | // bus and axi_rresp indicates the status of read transaction.axi_rvalid 342 | // is deasserted on reset (active low). axi_rresp and axi_rdata are 343 | // cleared to zero on reset (active low). 344 | always @( posedge S_AXI_ACLK ) 345 | begin 346 | if ( S_AXI_ARESETN == 1'b0 ) 347 | begin 348 | axi_rvalid <= 0; 349 | axi_rresp <= 0; 350 | end 351 | else 352 | begin 353 | if (axi_arready && S_AXI_ARVALID && ~axi_rvalid) 354 | begin 355 | // Valid read data is available at the read data bus 356 | axi_rvalid <= 1'b1; 357 | axi_rresp <= 2'b0; // 'OKAY' response 358 | end 359 | else if (axi_rvalid && S_AXI_RREADY) 360 | begin 361 | // Read data is accepted by the master 362 | axi_rvalid <= 1'b0; 363 | end 364 | end 365 | end 366 | 367 | // Implement memory mapped register select and read logic generation 368 | // Slave register read enable is asserted when valid address is available 369 | // and the slave is ready to accept the read address. 
370 |     assign slv_reg_rden = axi_arready & S_AXI_ARVALID & ~axi_rvalid;
371 |     always @(*)
372 |     begin
373 |           // Address decoding for reading registers
374 |           case ( axi_araddr[ADDR_LSB+OPT_MEM_ADDR_BITS:ADDR_LSB] )
375 |             2'h0   : reg_data_out <= slv_reg0;
376 |             2'h1   : reg_data_out <= slv_reg1;
377 |             2'h2   : reg_data_out <= slv_reg2;
378 |             2'h3   : reg_data_out <= slv_reg3;
379 |             default : reg_data_out <= 0;
380 |           endcase
381 |     end
382 | 
383 |     // Output register or memory read data
384 |     always @( posedge S_AXI_ACLK )
385 |     begin
386 |       if ( S_AXI_ARESETN == 1'b0 )
387 |         begin
388 |           axi_rdata <= 0;
389 |         end
390 |       else
391 |         begin
392 |           // When there is a valid read address (S_AXI_ARVALID) with
393 |           // acceptance of read address by the slave (axi_arready),
394 |           // output the read data
395 |           if (slv_reg_rden)
396 |             begin
397 |               axi_rdata <= reg_data_out; // register read data
398 |             end
399 |         end
400 |     end
401 | 
402 |     // Add user logic here
403 | 
404 |     wire led_out;
405 | 
406 |     led_axi_ip custom_led_driver(
407 |         .slv_reg(slv_reg0),
408 |         .clk(S_AXI_ACLK),
409 |         .led_out(led_out)
410 |     );
411 | 
412 |     // User logic ends
413 | 
414 |     endmodule
415 | 
--------------------------------------------------------------------------------
/8 Python to FPGA/1_train_quantized_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "a71a8839-7c58-4862-a5f0-0f7e97085918",
6 |    "metadata": {},
7 |    "source": [
8 |     "# THIS NOTEBOOK TRAINS A FASHION-MNIST MODEL USING 0-255 VALUES (UINT8)\n",
9 |     "\n",
10 |     "> Note that the end model, exported in the FINN-ONNX format, will be quantized. Quantization is handled by Brevitas.\n",
11 |     "\n",
12 |     "We start after creating the Docker environment.\n",
13 |     "\n",
14 |     "The goal is to create a Fashion-MNIST model in PyTorch and experiment with the different parameters.\n",
15 |     "\n",
16 |     "Then, we will build the same model fully quantized and start adapting it for FINN.\n",
17 |     "\n",
18 |     "> Side note : We'll use Quantization Aware Training (QAT) for this tutorial, but another possibility is to use Post-Training Quantization (PTQ). [Here is a resource to learn more, but it's not really necessary for this tutorial](https://www.youtube.com/watch?v=0VdNflU08yA) "
19 |    ]
20 |   },
21 |   {
22 |    "cell_type": "markdown",
23 |    "id": "2a09c1a7-e82e-4d14-bced-29aaf340638c",
24 |    "metadata": {},
25 |    "source": [
26 |     "## Base model creation\n",
27 |     "\n",
28 |     "Before going any further, note that this notebook is meant to be run in the FINN Docker environment, from a Linux host.\n",
29 |     "\n",
30 |     "Use the command ```bash run-docker.sh notebook``` and use the Jupyter link in your VS Code/Vim editor, or simply open it in your browser if you like Jupyter. You will also need to set up your environment variables correctly.\n",
31 |     "\n",
32 |     "I encourage you to check the FINN docs for this part; you can find more precise guidelines in the resources of the main [readme](./README.md).\n",
33 |     "\n",
34 |     "> Don't forget to use your own username below. ```root_dir = f\"/tmp/finn_dev_{user_name}\"``` is the common mounted folder where you'll be able to access all outputs of the FINN Docker container from your host machine, to then use them for further processing.\n",
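    "\n",
    "If you'd rather not hard-code the user name, here is a minimal sketch that derives it automatically; it assumes the container runs under the same user name as your host, which the run-docker.sh setup normally ensures :\n",
    "\n",
    "```python\n",
    "import getpass\n",
    "\n",
    "# Infer the user name instead of hard-coding it; this only works if the\n",
    "# container user matches the host user name (an assumption, check yours).\n",
    "user_name = getpass.getuser()\n",
    "root_dir = f'/tmp/finn_dev_{user_name}'\n",
    "```"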
35 |    ]
36 |   },
37 |   {
38 |    "cell_type": "code",
39 |    "execution_count": 1,
40 |    "id": "a4e9dec1-299c-45ec-bc1a-41d4825fea44",
41 |    "metadata": {},
42 |    "outputs": [
43 |     {
44 |      "name": "stdout",
45 |      "output_type": "stream",
46 |      "text": [
47 |       "/tmp/finn_dev_rootmin\n"
48 |      ]
49 |     }
50 |    ],
51 |    "source": [
52 |     "import torch\n",
53 |     "from torchvision import datasets, transforms\n",
54 |     "from torch.utils.data import DataLoader\n",
55 |     "import os\n",
56 |     "\n",
57 |     "user_name = \"rootmin\" # REPLACE THIS WITH YOUR HOST MACHINE USER NAME \n",
58 |     "root_dir = f\"/tmp/finn_dev_{user_name}\"\n",
59 |     "\n",
60 |     "print(root_dir)"
61 |    ]
62 |   },
63 |   {
64 |    "cell_type": "markdown",
65 |    "id": "d5934e06",
66 |    "metadata": {},
67 |    "source": [
68 |     "Import the data and transform it. We don't normalize, to really keep things as simple as possible; the only transformation is converting the images to PyTorch tensors."
69 |    ]
70 |   },
71 |   {
72 |    "cell_type": "code",
73 |    "execution_count": 2,
74 |    "id": "b9651dd1-25d9-4c48-8e02-08c28d29fa84",
75 |    "metadata": {},
76 |    "outputs": [],
77 |    "source": [
78 |     "# Define a transform for the data (normalization is deliberately left out)\n",
79 |     "transform = transforms.Compose([\n",
80 |     "    transforms.ToTensor(), # Convert the image to a PyTorch tensor\n",
81 |     "    #transforms.Normalize((0.5,), (0.5,)) # Normalize the tensor with mean and std\n",
82 |     "])\n",
83 |     "\n",
84 |     "# Load the training dataset\n",
85 |     "train_dataset = datasets.FashionMNIST(\n",
86 |     "    root='./data', # Directory to save the dataset\n",
87 |     "    train=True, # Load the training set\n",
88 |     "    download=True, # Download the dataset if it doesn't exist\n",
89 |     "    transform=transform # Apply the defined transformations\n",
90 |     ")\n",
91 |     "\n",
92 |     "# Load the test dataset\n",
93 |     "test_dataset = datasets.FashionMNIST(\n",
94 |     "    root='./data',\n",
95 |     "    train=False, # Load the test set\n",
96 |     "    download=True,\n",
97 |     "    transform=transform\n",
98 |     ")"
99 |    ]
100 |   },
101 |   {
102 |    "cell_type": "markdown",
103 |    "id": "d20ba53e",
104 |    "metadata": {},
105 |    "source": [
106 |     "Let's visualize the data a little bit... We see the data ranges from 0 to 1, but we know our FPGA AI model IP will expect INTEGER values... More on that later !"
107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 4, 112 | "id": "7cb48f47-1369-4f87-8c9b-c0e29a8de5e5", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "Min : 0.0 /// Max : 0.78039217\n", 120 | "Data type : float32\n" 121 | ] 122 | }, 123 | { 124 | "data": { 125 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgzElEQVR4nO3de2zV9f3H8Vdb2tNSerGW3qSwAipTLm4IHdMhjg7oMiLCFm/JwBiIrhiROU0XFdmWdT9MnNEw+GcDTQQvmUA0jkXRlrgBCkIIblba1bUMWuTSntLaC+339wexW+Xm58Ppebfl+UhOQs85r34//fR7eHE4p+/GBEEQCACAKIu1XgAA4PJEAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMDEEOsFfFV3d7cOHz6slJQUxcTEWC8HAOAoCAI1NzcrLy9PsbHnf57T7wro8OHDys/Pt14GAOAS1dXVacSIEee9vd8VUEpKivUS0IcyMzOdM7fccotz5qc//alzRpKampqcM5WVlc6Zzs5O50xaWppzprCw0DkjSR9++KFzZuXKlc6ZtrY25wwGjov9fd5nBbR69Wo9/fTTqq+v16RJk/T8889r6tSpF80N1v928/m6BuOYvgs9HT+f+Ph450xycrJzRvIrhsTEROeMzz74HMd3H3yONRgfuzxuL83F9q9P3oTwyiuvaPny5VqxYoU++ugjTZo0SbNnz9bRo0f74nAAgAGoTwromWee0eLFi3Xvvffquuuu09q1azV06FD96U9/6ovDAQAGoIgXUEdHh/bs2aOioqL/HiQ2VkVFRdqxY8dZ929vb1c4HO51AQAMfhEvoGPHjqmrq0vZ2dm9rs/OzlZ9ff1Z9y8rK1NaWlrPhXfAAcDlwfwHUUtLS9XU1NRzqaurs14SACAKIv4uuMzMTMXFxamhoaHX9Q0NDcrJyTnr/qFQSKFQKNLLAAD0cxF/BpSQkKDJkydr27ZtPdd1d3dr27ZtmjZtWqQPBwAYoPrk54CWL1+uhQsX6sYbb9TUqVP17LPPqqWlRffee29fHA4AMAD1SQHdcccd+vzzz/Xkk0+qvr5eN9xwg7Zu3XrWGxMAAJevmKCf/dhuOBz2GjkSTf35p6N9Rt089NBDXsf637faf10+r/e1tLRE5TiSNG7cOOdMtMZH+UxpOHTokNexjhw54pxJSkpyzpw4ccI5s337dufM888/75yRpJMnT3rlcEZTU5NSU1PPe7v5u+AAAJcnCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJhhG6iFaw0jHjBnjnHnjjTecM1/95YFfV1tbm3PGZ6BmV1eXc6a9vd05I/kNxxw2bJhzJlpfU0JCgnNGkoYPH+6cGTLEfbi+z/p8Mq2trc4ZSVq7dq1zZtOmTV7HGowYRgoA6JcoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACaYht2Pvfrqq86ZzMxM54zPBGhJio+Pd874nG4+E7S7u7udM5LfxGmfjM8k8VAo5JzxfSz5fG99psT7iI11/3ez71Rwn32YN2+ec+bUqVPOmYGAadgAgH6JAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACAiSHWC7hc5ObmOmdycnKcM01NTc4Z30GNp0+fds4MHTrUOZOcnOyc8RlYKfkNMe3q6opKJjEx0Tnjs3eS3/p8zgef4/gM7vQZ/ir57d/cuXOdMxs3bnTODAY8AwIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCYaRRcsUVVzhnfIaR+gx39B1G6jOo0WdgZSgUcs74DBWVpJiYmKhkfMTFxTlnfNfms38+x/I5X4cPH+6cOXbsmHNG8nts/OAHP3DOMIwUAIAoooAAACYiXkBPPfWUYmJiel3GjRsX6cMAAAa4PnkN6Prrr9c777zz34MM4aUmAEBvfdIMQ4YM8XoBHQBw+eiT14AOHjyovLw8jR49Wvfcc49qa2vPe9/29naFw+FeFwDA4BfxAiosLNT69eu1detWrVmzRjU1Nfre976n5ubmc96/rKxMaWlpPZf8/PxILwkA0A9FvICKi4v1k5/8RBMnTtTs2bP11ltvqbGxUa+++uo5719aWqqmpqaeS11dXaSXBADoh/r83QHp6em65pprVFVVdc7bQ6GQ1w8aAgAGtj7/OaBTp06purpaubm5fX0oAMAAEvECeuSRR1RRUaHPPvtMf//733X77bcrLi5Od911V6QPBQAYwCL+X3CHDh3SXXfdpePHj2v48OG6+eabtXPnTq/5TQCAwSviBfTyyy9H+lMOChMnTnTO+Ayf9Pn5q9hYvyfCPrm2tjbnzOHDh50z1dXVzhlJ+uyzz5wzLS0tzhmfffA5Tmdnp3NG8hvC6XOO/+hHP3LO+Oxdenq6c0aShg0b5pzxGdJ7uWIWHADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMxQRAE1ov4X+FwWGlpadbL6Beuuuoq58w999zjnBk/frxzRpJ++9vfOmc++eQTr2NFy9ChQ50zSUlJUcn4DLlMTEx0zkh+g0/P90snI+3DDz90zvg8liSptbXVOXPy5EnnzJQpU5wzA0FTU5NSU1PPezvPgAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJoZYL+BysWrVKudMd3e3c+a9995zzuzdu9c5I+mCU27Px2cadkxMjHMmHA47ZyTp+PHjzpnGxkbnTGdnp3PGZ3C9z95J8ppIf/311ztnqqurnTM+E99PnTrlnJH8zof29navY12OeAYEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADAREzgM+GwD4XDYa9BiP3dzJkzo5LJzMx0zsyaNcs5I0kvvPCCc6a8vNw5k56e7pwZO3asc0
aShg0b5pzxeQjFxcU5ZxISEpwzHR0dzhnJbxDuxx9/7Jxpbm52zvz4xz92zvjuw8mTJ50z8+fPd85897vfdc6cOHHCORNtTU1NFxxazDMgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJhhGGiUffvihc6azs9M5c/jwYedMcnKyc0aSsrOznTPf+ta3vI7lymfvJKm9vd0509XV5ZzxedidPn3aOeMz9FSS4uPjnTM+g1x9hn1+8MEHzpn6+nrnjCS99dZbzhmfx9O6deucMwMBw0gBAP0SBQQAMOFcQNu3b9fcuXOVl5enmJgYbd68udftQRDoySefVG5urpKSklRUVKSDBw9Gar0AgEHCuYBaWlo0adIkrV69+py3r1q1Ss8995zWrl2rXbt2KTk5WbNnz1ZbW9slLxYAMHgMcQ0UFxeruLj4nLcFQaBnn31Wjz/+uG677TZJ0osvvqjs7Gxt3rxZd95556WtFgAwaET0NaCamhrV19erqKio57q0tDQVFhZqx44d58y0t7crHA73ugAABr+IFtCXb3X86ttzs7Ozz/s2yLKyMqWlpfVc8vPzI7kkAEA/Zf4uuNLSUjU1NfVc6urqrJcEAIiCiBZQTk6OJKmhoaHX9Q0NDT23fVUoFFJqamqvCwBg8ItoARUUFCgnJ0fbtm3ruS4cDmvXrl2aNm1aJA8FABjgnN8Fd+rUKVVVVfV8XFNTo3379ikjI0MjR47UsmXL9Jvf/EZXX321CgoK9MQTTygvL0/z5s2L5LoBAAOccwHt3r1bt956a8/Hy5cvlyQtXLhQ69ev16OPPqqWlhYtWbJEjY2Nuvnmm7V161YlJiZGbtUAgAGPYaRRUlpa6pyZOXOmc2bs2LHOmb/85S/OGUnav3+/cyYrK8s5U1tb65yJ5hBOn39cDRni/G8/Lz4DTCWptbXVOdPR0eGc8XnNd9SoUc6ZZcuWOWckqaKiwjkzY8YM54zPkN59+/Y5Z6KNYaQAgH6JAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGAiOiN5oeuuu84588UXXzhn6uvrnTM7d+50zkjSTTfd5JwZP368c8ZnYLvvNGwf3d3dzhmfrykmJiYqGclv/3z2wed83bBhg3PGd3L0v/71L+dMXV2dc+bTTz91zgwGPAMCAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABggmGkUTJ69GjnzJAh7t+eESNGOGd8BkJKUmtrq3Pm9OnTzpnm5mbnTGys37+tfNbnM7izq6vLORNNycnJzpnOzk7nzPDhw50zPuddSkqKc0byezylp6c7Z3JycpwzPoNS+xueAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDBMNIo8RmO2dbW5pzxGXLpM+xTkoYOHeqc6e7uds74DPv0yUhSTEyMc8bne+uT8Vmbz35LfutLSEhwzvh8n44dO+ac8ZWRkeGc8RkinJeX55xhGCkAAJ4oIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBhplPTn4ZMnTpxwzkhSUlKSc8ZnfT57FwSBc8aXz7F8Mj7nQ2dnp3NGkkKhkHPGZwinz/e2vr7eOeMz2FfyG+7rM2A1JSXFOTMY8AwIAGCCAgIAmHAuoO3bt2vu3LnKy8tTTEyMNm/e3Ov2RYsWKSYmptdlzpw5kVovAGCQcC6glpYWTZo0SatXrz7vfebMmaMjR470XDZu3HhJiwQADD7OrxoWFxeruLj4gvcJhULKycnxXhQAYPDrk9eAysvLlZWVpWuvvVYPPPCAjh8/ft77tre3KxwO97oAAAa/iBfQnDlz9OKLL2rbtm36v//7P1VUVKi4uPi8b2csKytTWlpazyU/Pz/SSwIA9EMR/zmgO++8s+fPEyZM0MSJEzVmzBiVl5dr5syZZ92/tLRUy5cv7/k4HA5TQgBwGejzt2GPHj1amZmZqqqqOuftoVBIqampvS4AgMGvzwvo0KFDOn78uHJzc/v6UACAAcT5v+BOnTrV69lMTU2N9u3bp4yMDGVkZGjlypVasGCBcnJyVF1drUcffVRjx47V7NmzI7pwAMDA5lxAu3fv1q233trz8Zev3yxcuFBr1qzR/v379cILL6ixsVF5eXmaNWuWfv3rX3vNlgIADF7OBTRjxowLDlL861//ekkLwn/5DDX0GfbZ0NDgnJH8hpFGi8/gTslv/6I1hDNaA22l6A3h9NHR0RGV40h+e96f966/YRYcAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMBExH8lN87tQhPEI8ln+vHJkye9jhUfH++c8dkHnwnVvlOgT58+7ZzxmZjssw/ROoek6O2Dz/fJZwp7Y2Ojc0aSEhMTvXL99Tj9Dc+AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAYKbz5DFCM1mBRn8GYvsfyEa3Bor7H8cl1dHQ4Z3y+Tz7DSKuqqpwzknTDDTc4Z3z2IVrnXX/DMyAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmGEYaJc3Nzc6Z5ORk54zvEE4fPkMhfQY1+gzG9Bl66stnfT7DJ30ycXFxzhnJ72vq7Ox0zkRr0Gxtba1zRpJuvPFG50x7e7tzxvf7NNDxDAgAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJhpF6SEhIcM74DHf0GboYDoedM77i4+OdMz4DK3347Lfk973t6upyzvgM4fQxZIjfQ9zna/IZAOvzffL5mj777DPnjOR3jvvsnc9xBgOeAQEATFBAAAATTgVUVlamKVOmKCUlRVlZWZo3b54qKyt73aetrU0lJSW68sorNWzYMC1YsEANDQ0RXTQAYOBzKqCKigqVlJRo586devvtt9XZ2alZs2appaWl5z4PP/yw3njjDb322muqqKjQ4cOHNX/+/IgvHAAwsDm9mrd169ZeH69fv15ZWVnas2ePpk+frqamJv3xj3/Uhg0b9P3vf1+StG7dOn3zm9/Uzp079Z3vfCdyKwcADGiX9BpQU1OTJCkjI0OStGfPHnV2dqqoqKjnPuPGjdPIkSO1Y8eOc36O9vZ2hcPhXhcAwODnXUDd3d1atmyZbrrpJo0fP16SVF9fr4SEBKWnp/e6b3Z2turr68/5ecrKypSWltZzyc/P910SAGAA8S6gkpISHThwQC+//PIlLaC0tFRNTU09l7q6ukv6f
ACAgcHrp9SWLl2qN998U9u3b9eIESN6rs/JyVFHR4caGxt7PQtqaGhQTk7OOT9XKBRSKBTyWQYAYABzegYUBIGWLl2qTZs26d1331VBQUGv2ydPnqz4+Hht27at57rKykrV1tZq2rRpkVkxAGBQcHoGVFJSog0bNmjLli1KSUnpeV0nLS1NSUlJSktL03333afly5crIyNDqampevDBBzVt2jTeAQcA6MWpgNasWSNJmjFjRq/r161bp0WLFkmSfv/73ys2NlYLFixQe3u7Zs+erT/84Q8RWSwAYPBwKqCvMzgwMTFRq1ev1urVq70X1d/5DFCM1tDF//znP84ZX3Fxcc4Zn33wGXLpy2dIaLQyPvvgMxhTit731md9KSkpzplPP/3UOSP5PQZ9vk/RGk7b3zALDgBgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgwus3osKdz6Tg2Fj3fx9Ecxq2z/p89iE+Pt4547M2yW8KdLSmdftMTPbZb8lvSnW0JjqnpaU5Zz7++GOvY/mcRz4ZpmEDABBFFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATDCMNEqiNYy0trbWOeOrvb3dOfP55587Z5qbm50zp0+fds74itbgzmgOufTJhUIh50xiYqJzJjk52TnjO6TXZx98htMOGXJ5/lXMMyAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmLs8JeJfIZ0Ch71BIV+FwOCrHkfyGT/pkOjs7nTMZGRnOGclvsKjP4NNonQ++x/EZfOpz7vkMFs3Ly3POtLW1OWckKSEhwTnjM1jU5ziDAc+AAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAYqYe4uDjnTEdHh3PGZ8ilzxBJX3/+85+dM6mpqc6Zo0ePOmd8BkJKfnvuw2d90RyC293d7Zzx2bumpibnzO7du50zvny+pv7+uO1PLs+vGgBgjgICAJhwKqCysjJNmTJFKSkpysrK0rx581RZWdnrPjNmzFBMTEyvy/333x/RRQMABj6nAqqoqFBJSYl27typt99+W52dnZo1a5ZaWlp63W/x4sU6cuRIz2XVqlURXTQAYOBzeiV069atvT5ev369srKytGfPHk2fPr3n+qFDhyonJycyKwQADEqX9BrQl+9g+eqvP37ppZeUmZmp8ePHq7S0VK2tref9HO3t7QqHw70uAIDBz/tt2N3d3Vq2bJluuukmjR8/vuf6u+++W6NGjVJeXp7279+vxx57TJWVlXr99dfP+XnKysq0cuVK32UAAAYo7wIqKSnRgQMH9P777/e6fsmSJT1/njBhgnJzczVz5kxVV1drzJgxZ32e0tJSLV++vOfjcDis/Px832UBAAYIrwJaunSp3nzzTW3fvl0jRoy44H0LCwslSVVVVecsoFAopFAo5LMMAMAA5lRAQRDowQcf1KZNm1ReXq6CgoKLZvbt2ydJys3N9VogAGBwciqgkpISbdiwQVu2bFFKSorq6+slSWlpaUpKSlJ1dbU2bNigH/7wh7ryyiu1f/9+Pfzww5o+fbomTpzYJ18AAGBgciqgNWvWSDrzw6b/a926dVq0aJESEhL0zjvv6Nlnn1VLS4vy8/O1YMECPf744xFbMABgcHD+L7gLyc/PV0VFxSUtCABweWAatoekpCTnjM9UYp8Juenp6c4ZX2VlZVE7FmDhYv/oPpf+/rjtTxhGCgAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwATDSD2cOHHCOfPpp586Zw4dOuSc2bVrl3PGl8+AVR8+AyGBSHjppZecM6NHj3bOfPTRR86ZwYBnQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAw0e9mwQ3WuV9tbW3OGZ9Za52dnc4ZX4P1ewV8yedx29ra6pyJ5uM2mi72d0RM0M/+Fjl06JDy8/OtlwEAuER1dXUaMWLEeW/vdwXU3d2tw4cPKyUl5axnAOFwWPn5+aqrq1NqaqrRCu2xD2ewD2ewD2ewD2f0h30IgkDNzc3Ky8tTbOz5X+npd/8FFxsbe8HGlKTU1NTL+gT7EvtwBvtwBvtwBvtwhvU+pKWlXfQ+vAkBAGCCAgIAmBhQBRQKhbRixQqFQiHrpZhiH85gH85gH85gH84YSPvQ796EAAC4PAyoZ0AAgMGDAgIAmKCAAAAmKCAAgIkBU0CrV6/WN77xDSUmJqqwsFAffPCB9ZKi7qmnnlJMTEyvy7hx46yX1ee2b9+uuXPnKi8vTzExMdq8eXOv24Mg0JNPPqnc3FwlJSWpqKhIBw8etFlsH7rYPixatOis82POnDk2i+0jZWVlmjJlilJSUpSVlaV58+apsrKy133a2tpUUlKiK6+8UsOGDdOCBQvU0NBgtOK+8XX2YcaMGWedD/fff7/Ris9tQBTQK6+8ouXLl2vFihX66KOPNGnSJM2ePVtHjx61XlrUXX/99Tpy5EjP5f3337deUp9raWnRpEmTtHr16nPevmrVKj333HNau3atdu3apeTkZM2ePdtrkGR/drF9kKQ5c+b0Oj82btwYxRX2vYqKCpWUlGjnzp16++231dnZqVmzZqmlpaXnPg8//LDeeOMNvfbaa6qoqNDhw4c1f/58w1VH3tfZB0lavHhxr/Nh1apVRis+j2AAmDp1alBSUtLzcVdXV5CXlxeUlZUZrir6VqxYEUyaNMl6GaYkBZs2ber5uLu7O8jJyQmefvrpnusaGxuDUCgUbNy40WCF0fHVfQiCIFi4cGFw2223mazHytGjRwNJQUVFRRAEZ7738fHxwWuvvdZzn3/+85+BpGDHjh1Wy+xzX92HIAiCW265JXjooYfsFvU19PtnQB0dHdqzZ4+Kiop6rouNjVVRUZF27NhhuDIbBw8eVF5enkaPHq177rlHtbW11ksyVVNTo/r6+l7nR1pamgoLCy/L86O8vFxZWVm69tpr9cADD+j48ePWS+pTTU1NkqSMjAxJ0p49e9TZ2dnrfBg3bpxGjhw5qM+Hr+7Dl1566SVlZmZq/PjxKi0t9fpVEX2p3w0j/apjx46pq6tL2dnZva7Pzs7WJ598YrQqG4WFhVq/fr2uvfZaHTlyRCtXrtT3vvc9HThwQCkpKdbLM1FfXy9J5zw/vrztcjFnzhzNnz9fBQUFqq6u1i9/+UsVFxdrx44diouLs15exHV3d2vZsmW66aabNH78eElnzoeEhASlp6f3uu9gPh/OtQ+SdPfdd2vUqFHKy8vT/v379dhjj6myslKvv/664Wp76/cFhP8qLi7u+fPEiRNVWFioUaNG6dVXX9V9991nuDL0B3feeWfPnydM
mKCJEydqzJgxKi8v18yZMw1X1jdKSkp04MCBy+J10As53z4sWbKk588TJkxQbm6uZs6cqerqao0ZMybayzynfv9fcJmZmYqLizvrXSwNDQ3KyckxWlX/kJ6ermuuuUZVVVXWSzHz5TnA+XG20aNHKzMzc1CeH0uXLtWbb76p9957r9evb8nJyVFHR4caGxt73X+wng/n24dzKSwslKR+dT70+wJKSEjQ5MmTtW3btp7ruru7tW3bNk2bNs1wZfZOnTql6upq5ebmWi/FTEFBgXJycnqdH+FwWLt27brsz49Dhw7p+PHjg+r8CIJAS5cu1aZNm/Tuu++qoKCg1+2TJ09WfHx8r/OhsrJStbW1g+p8uNg+nMu+ffskqX+dD9bvgvg6Xn755SAUCgXr168P/vGPfwRLliwJ0tPTg/r6euulRdXPf/7zoLy8PKipqQn+9re/BUVFRUFmZmZw9OhR66X1qebm5mDv3r3B3r17A0nBM888E+zduzf497//HQRBEPzud78L0tPTgy1btgT79+8PbrvttqCgoCD44osvjFceWRfah+bm5uCRRx4JduzYEdTU1ATvvPNO8O1vfzu4+uqrg7a2NuulR8wDDzwQpKWlBeXl5cGRI0d6Lq2trT33uf/++4ORI0cG7777brB79+5g2rRpwbRp0wxXHXkX24eqqqrgV7/6VbB79+6gpqYm2LJlSzB69Ohg+vTpxivvbUAUUBAEwfPPPx+MHDkySEhICKZOnRrs3LnTeklRd8cddwS5ublBQkJCcNVVVwV33HFHUFVVZb2sPvfee+8Fks66LFy4MAiCM2/FfuKJJ4Ls7OwgFAoFM2fODCorK20X3QcutA+tra3BrFmzguHDhwfx8fHBqFGjgsWLFw+6f6Sd6+uXFKxbt67nPl988UXws5/9LLjiiiuCoUOHBrfffntw5MgRu0X3gYvtQ21tbTB9+vQgIyMjCIVCwdixY4Nf/OIXQVNTk+3Cv4JfxwAAMNHvXwMCAAxOFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATPw/TLzao1TNC1kAAAAASUVORK5CYII=", 126 | "text/plain": [ 127 | "
" 128 | ] 129 | }, 130 | "metadata": {}, 131 | "output_type": "display_data" 132 | } 133 | ], 134 | "source": [ 135 | "import matplotlib.pyplot as plt\n", 136 | "import numpy as np\n", 137 | "\n", 138 | "image, label = train_dataset[5]\n", 139 | "image = np.array(image).squeeze()\n", 140 | "print(\"Min : \", np.min(image[0]), \" /// Max : \", np.max(image[0]))\n", 141 | "print(\"Data type :\", image[0].dtype)\n", 142 | "# plot the sample\n", 143 | "\n", 144 | "fig = plt.figure\n", 145 | "plt.imshow(image, cmap='gray')\n", 146 | "plt.show()" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "1648a989-431b-418d-9dbf-118ca52aa7e1", 152 | "metadata": {}, 153 | "source": [ 154 | "# Actual QNN\n", 155 | "\n", 156 | "> QNN = Quantized Neural Network\n", 157 | "\n", 158 | "That's what we are about to do using Quantized Aware Training, and the best part is Brevitas handles all of it in the background !\n", 159 | "\n", 160 | "This part is about creating a quantized version of the model and adapting it to finn." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 5, 166 | "id": "dab5b0a0-9cfd-41c7-99cc-3c2d3365cbcf", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "import torch\n", 171 | "from brevitas.nn import QuantLinear\n", 172 | "from brevitas.nn import QuantReLU\n", 173 | "\n", 174 | "import torch.nn as nn\n", 175 | "\n", 176 | "brevitas_input_size = 28 * 28\n", 177 | "brevitas_hidden1 = 64\n", 178 | "brevitas_hidden2 = 64\n", 179 | "brevitas_num_classes = 10\n", 180 | "weight_bit_width = 4\n", 181 | "act_bit_width = 4\n", 182 | "dropout_prob = 0.5\n", 183 | "\n", 184 | "#is this model fully quantized or only the wieghts, i shall dig to find out once done !\n", 185 | "brevitas_model = nn.Sequential(\n", 186 | " QuantLinear(brevitas_input_size, brevitas_hidden1, bias=True, weight_bit_width=weight_bit_width),\n", 187 | " nn.BatchNorm1d(brevitas_hidden1),\n", 188 | " nn.Dropout(0.5),\n", 189 | " QuantReLU(bit_width=act_bit_width),\n", 190 | " QuantLinear(brevitas_hidden1, brevitas_hidden2, bias=True, weight_bit_width=weight_bit_width),\n", 191 | " nn.BatchNorm1d(brevitas_hidden2),\n", 192 | " nn.Dropout(0.5),\n", 193 | " QuantReLU(bit_width=act_bit_width),\n", 194 | " QuantLinear(brevitas_hidden2, brevitas_num_classes, bias=True, weight_bit_width=weight_bit_width),\n", 195 | " QuantReLU(bit_width=act_bit_width)\n", 196 | ")\n", 197 | "\n", 198 | "# uncomment to check the network object\n", 199 | "# brevitas_model" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "289eec74", 205 | "metadata": {}, 206 | "source": [ 207 | "### The input data has to be quantized.\n", 208 | "\n", 209 | "Normaly in brevistas, we can use the ```QuantIdentity()``` layer for this but unfortunatly, it does not convert to hardware (yet) in FINN.\n", 210 | "\n", 211 | "It its really not a problem, as we can just quantize the input data !" 
212 |   ]
213 |  },
214 |  {
215 |   "cell_type": "code",
216 |   "execution_count": 6,
217 |   "id": "670acb3a",
218 |   "metadata": {},
219 |   "outputs": [],
220 |   "source": [
221 |    "from torch.utils.data import Dataset\n",
222 |    "\n",
223 |    "# Define a custom quantization function\n",
224 |    "def quantize_tensor(x, num_bits=8):\n",
225 |    "    qmin = 0.\n",
226 |    "    qmax = 2.**num_bits - 1.\n",
227 |    "    min_val, max_val = x.min(), x.max()\n",
228 |    "\n",
229 |    "    scale = (max_val - min_val) / (qmax - qmin) # maps the float range onto the integer grid\n",
230 |    "    initial_zero_point = qmin - min_val / scale # the integer that represents the float value 0.0\n",
231 |    "\n",
232 |    "    zero_point = 0\n",
233 |    "    if initial_zero_point < qmin:\n",
234 |    "        zero_point = qmin\n",
235 |    "    elif initial_zero_point > qmax:\n",
236 |    "        zero_point = qmax\n",
237 |    "    else:\n",
238 |    "        zero_point = initial_zero_point\n",
239 |    "\n",
240 |    "    zero_point = int(zero_point)\n",
241 |    "    q_x = zero_point + x / scale # affine mapping onto [qmin, qmax]\n",
242 |    "    q_x.clamp_(qmin, qmax).round_()\n",
243 |    "    \n",
244 |    "    return q_x\n",
245 |    "\n",
246 |    "# Define the quantized transform\n",
247 |    "transform_quantized = transforms.Compose([\n",
248 |    "    transforms.ToTensor(), # Convert the image to a PyTorch tensor\n",
249 |    "    transforms.Lambda(lambda x: quantize_tensor(x)) # Apply quantization\n",
250 |    "])\n",
251 |    "\n",
252 |    "# Load the training dataset\n",
253 |    "train_dataset_qnt = datasets.FashionMNIST(\n",
254 |    "    root='./data', # Directory to save the dataset\n",
255 |    "    train=True, # Load the training set\n",
256 |    "    download=True, # Download the dataset if it doesn't exist\n",
257 |    "    transform=transform_quantized # Apply the defined transformations\n",
258 |    ")\n",
259 |    "\n",
260 |    "# Load the test dataset\n",
261 |    "test_dataset_qnt = datasets.FashionMNIST(\n",
262 |    "    root='./data',\n",
263 |    "    train=False, # Load the test set\n",
264 |    "    download=True,\n",
265 |    "    transform=transform_quantized\n",
266 |    ")\n",
267 |    "\n",
268 |    "train_loader = DataLoader(train_dataset_qnt, batch_size=100)\n",
269 |    "test_loader = DataLoader(test_dataset_qnt, batch_size=100)"
270 |   ]
271 |  },
272 |  {
273 |   "cell_type": "markdown",
274 |   "id": "dc84a0bd",
275 |   "metadata": {},
276 |   "source": [
277 |    "Let's re-visualize the data now... That's better !!"
278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 8, 283 | "id": "67c819d8", 284 | "metadata": {}, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "Min : 0.0 /// Max : 255.0\n", 291 | "float32\n" 292 | ] 293 | }, 294 | { 295 | "data": { 296 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAg7ElEQVR4nO3dfWzV5f3G8eu0tKcttKeW0icoWFBgyoMbSu0YDKQBusQIks2nZOAMBC1uwJyui4puS7qfJmo0DP5xMBPxKRGIZmFBlBIVWAAJI3MVsEoZtEiVlrb0gfb7+4PYrQLifXN6Pm15v5KT0HPO1e997n716uk5/TQUBEEgAABiLM56AQCAKxMFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMDrBfwTZ2dnTp27JhSU1MVCoWslwMAcBQEgU6fPq28vDzFxV38eU6vK6Bjx44pPz/fehkAgMtUXV2tYcOGXfT2XldAqamp1ku44qSkpHjlHn30UedMYWGhc2b9+vXOmRdffNE5g8szd+5c58zPf/5z58yWLVucM6tXr3bO4PJd6v/nPVZAq1at0tNPP62amhpNnDhRL7zwgiZPnnzJHD92iz3fPU9KSnLODBw40DmTmJjonEHsJSQkOGd8zodwOOycgY1L/b+lR96E8Nprr2nFihVauXKl9u7dq4kTJ2r27Nk6ceJETxwOANAH9UgBPfPMM1q0aJHuvfdeXXfddVqzZo1SUlL0l7/8pScOBwDog6JeQG1tbdqzZ4+Ki4v/e5C4OBUXF2vHjh3n3b+1tVUNDQ3dLgCA/i/qBXTy5El1dHQoOzu72/XZ2dmqqak57/7l5eWKRCJdF94BBwBXBvNfRC0rK1N9fX3Xpbq62npJAIAYiPq74DIzMxUfH6/a2tpu19fW1ionJ+e8+4fDYd7VAgBXoKg/A0pMTNSkSZO0devWrus6Ozu1detWFRUVRftwAIA+qkd+D2jFihVasGCBbrzxRk2ePFnPPfecmpqadO+99/bE4QAAfVCPFNAdd9yhL774Qo8//rhqamp0ww03aPPmzee9MQEAcOUKBUEQWC/ifzU0NCgSiVgvo89as2aNc2batGlex4qPj3fOfPO1we/iuuuuc86cPHnSOSPJ600wn3zyiXPG59cNMjIynDM//OEPnTOS3/SJtLQ058yxY8ecM4MGDXLO+L65afHixc6ZTz/91OtY/VF9ff23nhfm74IDAFyZKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmGAYaS82Y8YM58xvf/tb50xdXZ1zRpJSU1OdM3Fx7t/zJCcnO2eGDBninJGklJQU58yF/tT8pezZs8c5c+ONNzpnkpKSnDPSuSGSrnwGzWZlZTlnvvzyS+dMenq6c0aSTp8+7ZyZN2+e17H6I4aRAgB6JQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACAiQHWC8DFzZo1yznz2WefOWfC4bBzRpLOnj3rnBkwwP2UO3nypHPGZ22SFAqFnDPx8fHOmeuuu84509LS4pxpampyzkh+U6CHDh3qnGlubnbO+ExU/89//uOckfStk5wvZsqUKc6ZDz74wDnTH/AMCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAmGkfZieXl5zpmGhgbnjO8w0vb2dueMz+BOn/W1trY6ZyS/4Z0JCQnOGZ+hpx0dHc4Zn2GakpSSkuKc8Rks6jP0NAgC54zPAFPfY02dOtU5wzBSAABiiAICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAmGkcaIzzBEn0GS9fX1MclIUlJSklfO1YAB7qepT8aXzzDStra2mBzHdwinz/75HMvnMZ05c8Y546uzs9M5M3r06B5YSf/EMyAAgAkKCABgIuoF9MQTTygUCnW7jB07NtqHAQD0cT3yg/Lrr79e77zzzn8PEsOfxwMA+oYeaYYBAwYoJyenJz41AKCf6JHXgA4ePKi8vDyNHDlS99xzj44cOXLR+7a2tqqhoaHbBQDQ/0W9gAoLC7Vu3Tpt3rxZq1evVlVVlaZOnXrRv/1eXl6uSCTSdcnPz4/2kgAAvVDUC6ikpEQ//elPNWHCBM2ePVt/+9vfdOrUKb3++usXvH9ZWZnq6+u7LtXV1dFeEgCgF+rxdwekp6dr9OjROnTo0AVvD4fDCofDPb0MAEAv0+O/B9TY2KjDhw8rNze3pw8FAOhDol5ADz30kCoqKvTZZ5/pww8/1Lx58xQfH6+77ror2ocCAPRhUf8R3NGjR3XXXXeprq5OQ4YM0Y9+9CPt3LlTQ4YMifahAAB9WNQL6NVXX432p+wXCgoKnDM+wx2Tk5OdM77DSL/66ivnjM8vJQ8ePNg5c/bsWeeMJK/XI0OhkHPGZ5Crz3Ha29udM5Lf18lnfT7DPn0yzc3NzhlfQ4cOjdmx+jpmwQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADDR43+QDufk5OQ4Z1pbW50zPoMafYZIStLnn3/unImPj3fONDY2Omd8H9PAgQOdMz6DT32+Tj6DRX2Gikp+wzt9HpPPOV5TU+OcSUlJcc5IUmpqqnOmrq7OOePz1wK++OIL50xvwzMgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJpmHHSGZmpnPm+PHjzplIJOKcmTp1qnNGkl5++WXnzLFjx5wzubm5zplwOOyckaQzZ844Z3ymVAdB4Jzp6OhwzrS1tTlnJCkhIcE547MPJ06ccM7cfPPNzhmfSd2S9PHHHztn0tLSnDNjxoxxzjANGwAATxQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAEwwjDRGhgwZ4pwZNGiQc2bGjBnOGZ9BqZJ04403Ome2b9/unJkwYYJz5tSpU84ZyW9oZVyc+/dxPoM7ExMTnTPx8fHOGUlKSkpyzmRkZDhnjhw
54pxpbm52zhQWFjpnJL99qK6uds7ccMMNzpn333/fOdPb8AwIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYoIACAiVAQBIH1Iv5XQ0ODIpGI9TJ6hREjRjhnnn32WefML3/5S+eMJP3iF79wzgwdOtQ5k5qa6pxpaGhwzkh+Az99+AwwDYVCzpmzZ886ZyRp4MCBzpns7GznTEdHh3PmZz/7mXNm+fLlzhlJGjZsmHNmyZIlzpnW1lbnTF9QX1+vtLS0i97OMyAAgAkKCABgwrmAtm/frltvvVV5eXkKhULauHFjt9uDINDjjz+u3NxcJScnq7i4WAcPHozWegEA/YRzATU1NWnixIlatWrVBW9/6qmn9Pzzz2vNmjXatWuXBg4cqNmzZ6ulpeWyFwsA6D+c/yJqSUmJSkpKLnhbEAR67rnn9Oijj+q2226TJL300kvKzs7Wxo0bdeedd17eagEA/UZUXwOqqqpSTU2NiouLu66LRCIqLCzUjh07LphpbW1VQ0NDtwsAoP+LagHV1NRIOv/tmNnZ2V23fVN5ebkikUjXJT8/P5pLAgD0UubvgisrK1N9fX3Xpbq62npJAIAYiGoB5eTkSJJqa2u7XV9bW9t12zeFw2GlpaV1uwAA+r+oFlBBQYFycnK0devWrusaGhq0a9cuFRUVRfNQAIA+zvldcI2NjTp06FDXx1VVVdq3b58yMjI0fPhwLVu2TH/84x917bXXqqCgQI899pjy8vI0d+7caK4bANDHORfQ7t27NWPGjK6PV6xYIUlasGCB1q1bp4cfflhNTU1avHixTp06pR/96EfavHmzkpKSordqAECfxzBSeJs3b55z5oEHHnDOHD161DnT1tbmnJGkAQOcvyfzGhIaq+P4OnPmjHOmoKDAORMfH++cueWWW5wzsMEwUgBAr0QBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMOE+khdefCYZx8W5f3/gk2lvb3fOSNI///lP50xjY6Nzxmdgu88+SFJCQoJz5uzZs86Zzs5O54zPY/KZNi357Xlzc7NzZtiwYc6ZWPLdP1cdHR0xOU5vwzMgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJhhGGiM+wx19BhT6DLn01dTUFJPjtLW1OWeSkpK8juUzWNRnYKXP+eAz0Nb3fPDZP5/zwXcQbqz47J/P1/ZKxTMgAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJhhG2s/4DMb0GcApSQkJCTE5ls/AyoEDBzpnfI8VDoedMz77EBfn/v2iz0BbSUpOTnbOtLa2Omc++eQT50ws+QyAZRjpd8czIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYYRgpveXl5zhmfYZ9JSUnOGV8+Q0x9HpOPzs5O54zPwFjJ7zHFaljqsGHDnDNHjx51zkh+w0jx3fEMCABgggICAJhwLqDt27fr1ltvVV5enkKhkDZu3Njt9oULFyoUCnW7zJkzJ1rrBQD0E84F1NTUpIkTJ2rVqlUXvc+cOXN0/Pjxrssrr7xyWYsEAPQ/zm9CKCkpUUlJybfeJxwOKycnx3tRAID+r0deA9q2bZuysrI0ZswY3X///aqrq7vofVtbW9XQ0NDtAgDo/6JeQHPmzNFLL72krVu36v/+7/9UUVGhkpKSi77dsry8XJFIpOuSn58f7SUBAHqhqP8e0J133tn17/Hjx2vChAkaNWqUtm3bppkzZ553/7KyMq1YsaLr44aGBkoIAK4APf427JEjRyozM1OHDh264O3hcFhpaWndLgCA/q/HC+jo0aOqq6tTbm5uTx8KANCHOP8IrrGxsduzmaqqKu3bt08ZGRnKyMjQk08+qfnz5ysnJ0eHDx/Www8/rGuuuUazZ8+O6sIBAH2bcwHt3r1bM2bM6Pr469dvFixYoNWrV2v//v3661//qlOnTikvL0+zZs3SH/7wB4XD4eitGgDQ5zkX0PTp0xUEwUVv//vf/35ZC8Ll+bavTbQVFRU5Z3yGXCYmJjpn4uPjnTPSuV8LcJWcnByT48RyGGlzc7NzxmfPffYuKyvLOeM7jDRWA1avVMyCAwCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYiPqf5IYtn4nJvq655hrnzNmzZ50zKSkpzhnfKdA+U6oHDHD/z8hnKngsv7ZJSUnOGZ8J2j6TzseMGeOc2bt3r3NGiu10+SsRz4AAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYYBhpLxYX5/79gc/ASp9hmpKUlZXlnGlpaXHO+AyEDIVCzhlf4XDYOdPW1uac6ejocM74nEOS37BUn2P5HMdnGKmvWA6AvRLxDAgAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJhpH2YrEaqJmWluaVq6urc84MGTLEOXP69GnnTGpqqnNGit0QTh/x8fHOGd9zyOdYPkNjfQbhjho1yjnjy2cYqc+e++xdf8AzIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCYoIAAACYYRtqLxWoYaX5+vlfOZ+Cnz9DFcDjsnElMTHTOSH7r8zmWz2NqaWlxzvgOuUxOTnbO+AyNPXv2rHPGZ2BsQkKCc8b3WD7DaTs6Opwz/QHPgAAAJiggAIAJpwIqLy/XTTfdpNTUVGVlZWnu3LmqrKzsdp+WlhaVlpZq8ODBGjRokObPn6/a2tqoLhoA0Pc5FVBFRYVKS0u1c+dObdmyRe3t7Zo1a5aampq67rN8+XK99dZbeuONN1RRUaFjx47p9ttvj/rCAQB9m9ObEDZv3tzt43Xr1ikrK0t79uzRtGnTVF9frxdffFHr16/XLbfcIklau3atvve972nnzp26+eabo7dyAECfdlmvAdXX10uSMjIyJEl79uxRe3u7iouLu+4zduxYDR8+XDt27Ljg52htbVVDQ0O3CwCg//MuoM7OTi1btkxTpkzRuHHjJEk1NTVKTExUenp6t/tmZ2erpqbmgp+nvLxckUik6+L7lmAAQN/iXUClpaU6cOCAXn311ctaQFlZmerr67su1dXVl/X5AAB9g9cvoi5dulRvv/22tm/frmHDhnVdn5OTo7a2Np06darbs6Da2lrl5ORc8HOFw2GvX8oDAPRtTs+AgiDQ0qVLtWHDBr377rsqKCjodvukSZOUkJCgrVu3dl1XWVmpI0eOqKioKDorBgD0C0
7PgEpLS7V+/Xpt2rRJqampXa/rRCIRJScnKxKJ6L777tOKFSuUkZGhtLQ0PfjggyoqKuIdcACAbpwKaPXq1ZKk6dOnd7t+7dq1WrhwoSTp2WefVVxcnObPn6/W1lbNnj1bf/7zn6OyWABA/+FUQN9lsGFSUpJWrVqlVatWeS8KsTV27FivXFpamnPmq6++cs5cddVVzpm2tjbnjCQNGOD+sqhPxmfYp88wUt99+OY7WXvqWD6PKSkpyTkTiUScM5J08uRJ50yshgj3B8yCAwCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgggICAJiggAAAJiggAIAJCggAYIICAgCY8PqLqOhfMjIyvHI+U4nb29udMz6TjOvq6pwzkt9k6+8yJf6b4uLcv/dLSEhwzjQ2NjpnJL89P336tHMmPj4+JpmL/UXmS/GZho3vjmdAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATDCMtBcLhUIxOU5BQYFXrq2tzTnj85gGDhzonPn000+dM5IUDoe9cq7S0tKcM1999ZVzxudrJEmpqanOmeTkZOdMa2urc8bnHBo0aJBzxles/rvtD3gGBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwATDSKGOjg6vnM8gSZ+BlT4DNdvb250zkpSYmOic8RmWmpGR4Zypqqpyzvg8Hl9xce7fz/qcewkJCc6ZWPLZhysVOwUAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEw0jhNexTit0gyRMnTjhnOjs7nTOS34BVn8fks3dffvmlcyYlJcU5I0mNjY3OGZ8hnL5fJ1ctLS0xOY4Uu8fUH/AMCABgggICAJhwKqDy8nLddNNNSk1NVVZWlubOnavKyspu95k+fbpCoVC3y5IlS6K6aABA3+dUQBUVFSotLdXOnTu1ZcsWtbe3a9asWWpqaup2v0WLFun48eNdl6eeeiqqiwYA9H1Ob0LYvHlzt4/XrVunrKws7dmzR9OmTeu6PiUlRTk5OdFZIQCgX7qs14Dq6+slnf/nhV9++WVlZmZq3LhxKisrU3Nz80U/R2trqxoaGrpdAAD9n/fbsDs7O7Vs2TJNmTJF48aN67r+7rvv1ogRI5SXl6f9+/frkUceUWVlpd58880Lfp7y8nI9+eSTvssAAPRR3gVUWlqqAwcO6P333+92/eLFi7v+PX78eOXm5mrmzJk6fPiwRo0add7nKSsr04oVK7o+bmhoUH5+vu+yAAB9hFcBLV26VG+//ba2b9+uYcOGfet9CwsLJUmHDh26YAGFw2GFw2GfZQAA+jCnAgqCQA8++KA2bNigbdu2qaCg4JKZffv2SZJyc3O9FggA6J+cCqi0tFTr16/Xpk2blJqaqpqaGklSJBJRcnKyDh8+rPXr1+snP/mJBg8erP3792v58uWaNm2aJkyY0CMPAADQNzkV0OrVqyWd+2XT/7V27VotXLhQiYmJeuedd/Tcc8+pqalJ+fn5mj9/vh599NGoLRgA0D84/wju2+Tn56uiouKyFgQAuDIwDRsaPXq0Vy49Pd05097eHpPjXHXVVc4ZSUpMTHTOZGZmOmfS0tKcM9dee61zJisryzkjSd///vedMx9++KFzJjU11TkTCoWcM74T39GzGEYKADBBAQEATFBAAAATFBAAwAQFBAAwQQEBAExQQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABMNIe7HOzs6YHGf37t1eOZ8hnCdOnHDOtLS0OGdOnjzpnJGks2fPOmeGDh3qnPH5A4179+51zvgMV5Wkq6++2jlzqWn5F9Lc3OycueGGG5wzX//tsliI1X+3/QHPgAAAJiggAIAJCggAYIICAgCYoIAAACYoIACACQoIAGCCAgIAmKCAAAAmKCAAgAkKCABgotfNgvOZJ9VfxWovWltbvXI+M9p8jtXW1uacaW9vd85IfrPgfNbns3c+jykUCjlnJL+vk8/56nOcM2fOOGdi+f8V/h/2X5fai1DQy3br6NGjys/Pt14GAOAyVVdXa9iwYRe9vdcVUGdnp44dO6bU1NTzvntraGhQfn6+qqurlZaWZrRCe+zDOezDOezDOezDOb1hH4Ig0OnTp5WXl6e4uIu/0tPrfgQXFxf3rY0pSWlpaVf0CfY19uEc9uEc9uEc9uEc632IRCKXvA9vQgAAmKCAAAAm+lQBhcNhrVy5UuFw2HopptiHc9iHc9iHc9iHc/rSPvS6NyEAAK4MfeoZEACg/6CAAAAmKCAAgAkKCABgos8U0KpVq3T11VcrKSlJhYWF+sc//mG9pJh74oknFAqFul3Gjh1rvawet337dt16663Ky8tTKBTSxo0bu90eBIEef/xx5ebmKjk5WcXFxTp48KDNYnvQpfZh4cKF550fc+bMsVlsDykvL9dNN92k1NRUZWVlae7cuaqsrOx2n5aWFpWWlmrw4MEaNGiQ5s+fr9raWqMV94zvsg/Tp08/73xYsmSJ0YovrE8U0GuvvaYVK1Zo5cqV2rt3ryZOnKjZs2frxIkT1kuLueuvv17Hjx/vurz//vvWS+pxTU1NmjhxolatWnXB25966ik9//zzWrNmjXbt2qWBAwdq9uzZXgM/e7NL7YMkzZkzp9v58corr8RwhT2voqJCpaWl2rlzp7Zs2aL29nbNmjVLTU1NXfdZvny53nrrLb3xxhuqqKjQsWPHdPvttxuuOvq+yz5I0qJFi7qdD0899ZTRii8i6AMmT54clJaWdn3c0dER5OXlBeXl5Yarir2VK1cGEydOtF6GKUnBhg0buj7u7OwMcnJygqeffrrrulOnTgXhcDh45ZVXDFYYG9/chyAIggULFgS33XabyXqsnDhxIpAUVFRUBEFw7mufkJAQvPHGG133+fjjjwNJwY4dO6yW2eO+uQ9BEAQ//vGPg1/96ld2i/oOev0zoLa2Nu3Zs0fFxcVd18XFxam4uFg7duwwXJmNgwcPKi8vTyNHjtQ999yjI0eOWC/JVFVVlWpqarqdH5FIRIWFhVfk+bFt2zZlZWVpzJgxuv/++1VXV2e9pB5VX18vScrIyJAk7dmzR+3t7d3Oh7Fjx2r48OH9+nz45j587eWXX1ZmZqbGjRunsrIyNTc3WyzvonrdMNJvOnnypDo6OpSdnd3t+uzsbP373/82WpWNwsJCrVu3TmPGjNHx48f15JNPaurUqTpw4IBSU1Otl2eipqZGki54fnx925Vizpw5uv3221VQUKDDhw/rd7/7nUpKSrRjxw7Fx8dbLy/qOjs7tWzZMk2ZMkXjxo2TdO58SExMVHp6erf79ufz4UL7IEl33323RowYoby8PO3fv1+PPPKIKisr9eabbxqutrteX0D4r5KSkq5/T5gwQYWFhRoxYoRef/113XfffYYrQ29w5513dv17/Pjxm
jBhgkaNGqVt27Zp5syZhivrGaWlpTpw4MAV8Trot7nYPixevLjr3+PHj1dubq5mzpypw4cPa9SoUbFe5gX1+h/BZWZmKj4+/rx3sdTW1ionJ8doVb1Denq6Ro8erUOHDlkvxczX5wDnx/lGjhypzMzMfnl+LF26VG+//bbee++9bn++JScnR21tbTp16lS3+/fX8+Fi+3AhhYWFktSrzodeX0CJiYmaNGmStm7d2nVdZ2entm7dqqKiIsOV2WtsbNThw4eVm5trvRQzBQUFysnJ6XZ+NDQ0aNeuXVf8+XH06FHV1dX1q/MjCAItXbpUGzZs0LvvvquCgoJut0+aNEkJCQndzofKykodOXKkX50Pl9qHC9m3b58k9a7zwfpdEN/Fq6++GoTD4WDdunXBv/71r2Dx4sVBenp6UFNTY720mPr1r38dbNu2Laiqqgo++OCDoLi4OMjMzAxOnDhhvbQedfr06eCjjz4KPvroo0BS8MwzzwQfffRR8PnnnwdBEAR/+tOfgvT09GDTpk3B/v37g9tuuy0oKCgIzpw5Y7zy6Pq2fTh9+nTw0EMPBTt27AiqqqqCd955J/jBD34QXHvttUFLS4v10qPm/vvvDyKRSLBt27bg+PHjXZfm5uau+yxZsiQYPnx48O677wa7d+8OioqKgqKiIsNVR9+l9uHQoUPB73//+2D37t1BVVVVsGnTpmDkyJHBtGnTjFfeXZ8ooCAIghdeeCEYPnx4kJiYGEyePDnYuXOn9ZJi7o477ghyc3ODxMTEYOjQocEdd9wRHDp0yHpZPe69994LJJ13WbBgQRAE596K/dhjjwXZ2dlBOBwOZs6cGVRWVtouugd82z40NzcHs2bNCoYMGRIkJCQEI0aMCBYtWtTvvkm70OOXFKxdu7brPmfOnAkeeOCB4KqrrgpSUlKCefPmBcePH7dbdA+41D4cOXIkmDZtWpCRkRGEw+HgmmuuCX7zm98E9fX1tgv/Bv4cAwDARK9/DQgA0D9RQAAAExQQAMAEBQQAMEEBAQBMUEAAABMUEADABAUEADBBAQEATFBAAAATFBAAwAQFBAAw8f+FkJDJAsm5jQAAAABJRU5ErkJggg==", 297 | "text/plain": [ 298 | "
" 299 | ] 300 | }, 301 | "metadata": {}, 302 | "output_type": "display_data" 303 | } 304 | ], 305 | "source": [ 306 | "import matplotlib.pyplot as plt\n", 307 | "import numpy as np\n", 308 | "\n", 309 | "image, label = train_dataset_qnt[10]\n", 310 | "image = np.array(image).squeeze()\n", 311 | "print(\"Min : \", np.min(image), \" /// Max : \", np.max(image))\n", 312 | "print(image.dtype)\n", 313 | "\n", 314 | "# plot the sample\n", 315 | "fig = plt.figure\n", 316 | "plt.imshow(image, cmap='gray')\n", 317 | "plt.show()" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "id": "a9a58259", 323 | "metadata": {}, 324 | "source": [ 325 | "## Actual training\n", 326 | "\n", 327 | "Now, just like in a regular PyTorch workflow, we train the model using a simple trainng loop.\n", 328 | "\n", 329 | "Note that it takes a bit of time because\n", 330 | "- We train on CPU\n", 331 | "- QAT \"simulates\" Quantization using quant/dequant layers (so the model becomes robust to quantization, which is not the case in PTQ) and backpropagation is different and that takes up computing power" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 9, 337 | "id": "236d85fe-986a-44f4-84d1-fce1d5591f80", 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "name": "stderr", 342 | "output_type": "stream", 343 | "text": [ 344 | "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py:1255: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1758.)\n", 345 | " return super(Tensor, self).rename(names)\n" 346 | ] 347 | }, 348 | { 349 | "name": "stdout", 350 | "output_type": "stream", 351 | "text": [ 352 | "Epoch [1/5], Loss: 0.6401\n", 353 | "Epoch [2/5], Loss: 0.4895\n", 354 | "Epoch [3/5], Loss: 0.5153\n", 355 | "Epoch [4/5], Loss: 0.6120\n", 356 | "Epoch [5/5], Loss: 0.3428\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "# loss criterion and optimizer\n", 362 | "criterion = nn.CrossEntropyLoss()\n", 363 | "optimizer = torch.optim.Adam(brevitas_model.parameters(), lr=0.001, betas=(0.9, 0.999))\n", 364 | "\n", 365 | "num_epochs = 5\n", 366 | "brevitas_model.train()\n", 367 | "batch_size = 100\n", 368 | "\n", 369 | "for epoch in range(num_epochs):\n", 370 | " for batch_idx, (images, labels) in enumerate(train_loader):\n", 371 | " images = torch.reshape(images, (batch_size, 28*28))\n", 372 | " out = brevitas_model(images.float()) # This just make the value a float ie 255 becomes 255,0 and not 1\n", 373 | " loss = criterion(out, labels)\n", 374 | " optimizer.zero_grad()\n", 375 | " loss.backward()\n", 376 | " optimizer.step()\n", 377 | "\n", 378 | " print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "id": "78f7ee17", 384 | "metadata": {}, 385 | "source": [ 386 | "We now test the model with ~85% accuracy, which is extremely close to non-quatized model performances (~85-86% when I tested it on my side when I made the course).\n", 387 | "\n", 388 | "This dataset isn't very complex though, this delta will depend on the data you have to classify." 
389 |   ]
390 |  },
391 |  {
392 |   "cell_type": "code",
393 |   "execution_count": 10,
394 |   "id": "f4d239cf-edb2-4faa-a502-d05861a156fb",
395 |   "metadata": {},
396 |   "outputs": [
397 |    {
398 |     "name": "stdout",
399 |     "output_type": "stream",
400 |     "text": [
401 |      "accuracy = 84.2 %\n"
402 |     ]
403 |    }
404 |   ],
405 |   "source": [
406 |    "# test the model\n",
407 |    "\n",
408 |    "brevitas_model.eval()\n",
409 |    "correct = 0\n",
410 |    "total = 0\n",
411 |    "loss_total = 0\n",
412 |    "\n",
413 |    "with torch.no_grad():\n",
414 |    "    for batch_idx, (images, labels) in enumerate(test_loader):\n",
415 |    "        images = torch.reshape(images, (batch_size, 28*28))\n",
416 |    "        out = brevitas_model(images.float())\n",
417 |    "        _, predicted = torch.max(out.data, 1)\n",
418 |    "        total += labels.size(0)\n",
419 |    "        correct += (predicted == labels).sum().item()\n",
420 |    "\n",
421 |    "    accuracy = 100 * correct / total\n",
422 |    "    print(\"accuracy =\", accuracy, \"%\")"
423 |   ]
424 |  },
425 |  {
426 |   "cell_type": "markdown",
427 |   "id": "da5c0b43",
428 |   "metadata": {},
429 |   "source": [
430 |    "Here are some lines to inspect the tensors that make up the model and their different representations :"
431 |   ]
432 |  },
433 |  {
434 |   "cell_type": "code",
435 |   "execution_count": 11,
436 |   "id": "c3843919-1e83-4c50-accd-3e9292cecc2f",
437 |   "metadata": {},
438 |   "outputs": [
439 |    {
440 |     "name": "stdout",
441 |     "output_type": "stream",
442 |     "text": [
443 |      "QuantTensor(value=tensor([[ 0.0294, -0.0000, 0.0589, ..., -0.0589, 0.0589, 0.0883],\n",
444 |      " [ 0.0294, 0.0000, 0.0294, ..., -0.0883, -0.1471, -0.0589],\n",
445 |      " [-0.0589, 0.0294, -0.0294, ..., 0.0883, 0.0883, -0.1177],\n",
446 |      " ...,\n",
447 |      " [ 0.0589, 0.0883, -0.0294, ..., -0.0589, -0.0000, 0.0589],\n",
448 |      " [ 0.0589, -0.0000, -0.0294, ..., 0.0000, 0.0294, -0.0294],\n",
449 |      " [-0.0294, 0.0000, -0.0589, ..., 0.0000, 0.0294, -0.0294]],\n",
450 |      " grad_fn=), scale=tensor(0.0294, grad_fn=), zero_point=tensor(0.), bit_width=tensor(4.), signed_t=tensor(True), training_t=tensor(False))\n",
451 |      "tensor([[ 1, 0, 2, ..., -2, 2, 3],\n",
452 |      " [ 1, 0, 1, ..., -3, -5, -2],\n",
453 |      " [-2, 1, -1, ..., 3, 3, -4],\n",
454 |      " ...,\n",
455 |      " [ 2, 3, -1, ..., -2, 0, 2],\n",
456 |      " [ 2, 0, -1, ..., 0, 1, -1],\n",
457 |      " [-1, 0, -2, ..., 0, 1, -1]], dtype=torch.int8)\n",
458 |      "torch.int8\n"
459 |     ]
460 |    }
461 |   ],
462 |   "source": [
463 |    "# let's have a quick look at the weights too\n",
464 |    "print(brevitas_model[0].quant_weight())\n",
465 |    "# internally, weights are stored as float32; here are ways to visualize the actual quantized weights :\n",
466 |    "print(brevitas_model[0].quant_weight().int())\n",
467 |    "print(brevitas_model[0].quant_weight().int().dtype)"
468 |   ]
469 |  },
470 |  {
471 |   "cell_type": "markdown",
472 |   "id": "3a1d9899-000b-4343-ac81-2c425b571519",
473 |   "metadata": {},
474 |   "source": [
475 |    "# EXPORTING THE MODEL TO FINN-ONNX\n",
476 |    "\n",
477 |    "Maybe you know ONNX, maybe you don't. At the end of the day, it's just a way to represent AI models (or just tensor operations) in an optimized and standard way. It also allows us to use Netron to visualize the model as a nice graph.\n",
478 |    "\n",
479 |    "> FINN-ONNX is just like ONNX, except it's better suited for quantized models (under 8-bit quantization)."
480 |   ]
481 |  },
482 |  {
483 |   "cell_type": "code",
484 |   "execution_count": 12,
485 |   "id": "11f27a3b-88c6-483e-be82-85d1ee53129b",
486 |   "metadata": {},
487 |   "outputs": [
488 |    {
489 |     "name": "stdout",
490 |     "output_type": "stream",
491 |     "text": [
492 |      "255.0\n",
493 |      "Model saved to /tmp/finn_dev_rootmin/ready_finn.onnx\n"
494 |     ]
495 |    },
496 |    {
497 |     "name": "stderr",
498 |     "output_type": "stream",
499 |     "text": [
500 |      "/home/rootmin/Documents/freelance/mission2/finn/deps/qonnx/src/qonnx/transformation/gemm_to_matmul.py:57: UserWarning: The GemmToMatMul transformation only offers explicit support for version 9 of the Gemm node, but the ONNX version of the supplied model is 14. Thus the transformation may fail or return incomplete results.\n",
501 |      "  warnings.warn(\n"
502 |     ]
503 |    }
504 |   ],
505 |   "source": [
506 |    "from brevitas.export import export_qonnx\n",
507 |    "from qonnx.util.cleanup import cleanup as qonnx_cleanup\n",
508 |    "from qonnx.core.modelwrapper import ModelWrapper\n",
509 |    "from qonnx.core.datatype import DataType\n",
510 |    "from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN\n",
511 |    "\n",
512 |    "filename = root_dir + \"/part1.onnx\"\n",
513 |    "filename_clean = root_dir + \"/part1_clean.onnx\"\n",
514 |    "\n",
515 |    "def asymmetric_quantize(arr, num_bits=8):\n",
516 |    "    qmin = 0\n",
517 |    "    qmax = 2**num_bits - 1\n",
518 |    "    \n",
519 |    "    beta = np.min(arr)\n",
520 |    "    alpha = np.max(arr)\n",
521 |    "    scale = (alpha - beta) / qmax\n",
522 |    "    zero_point = np.clip((-beta/scale), 0, qmax).round().astype(np.uint8) # unsigned, since the range is [0, 255]\n",
523 |    "\n",
524 |    "    quantized_arr = np.clip(np.round(arr / scale + zero_point), qmin, qmax).astype(np.float32)\n",
525 |    "    \n",
526 |    "    return quantized_arr\n",
527 |    "\n",
528 |    "# Create a tensor resembling the input tensor we saw earlier\n",
529 |    "input_a = np.random.rand(1,28*28).astype(np.float32)\n",
530 |    "input_a = asymmetric_quantize(input_a)\n",
531 |    "print(np.max(input_a[0]))\n",
532 |    "scale = 1.0\n",
533 |    "input_t = torch.from_numpy(input_a * scale)\n",
534 |    "\n",
535 |    "# Export to ONNX\n",
536 |    "export_qonnx(\n",
537 |    "    brevitas_model, export_path=filename, input_t=input_t\n",
538 |    ")\n",
539 |    "\n",
540 |    "# clean-up\n",
541 |    "qonnx_cleanup(filename, out_file=filename_clean)\n",
542 |    "\n",
543 |    "# ModelWrapper\n",
544 |    "model = ModelWrapper(filename_clean)\n",
545 |    "model = model.transform(ConvertQONNXtoFINN())\n",
546 |    "model.save(root_dir + \"/ready_finn.onnx\")\n",
547 |    "\n",
548 |    "print(\"Model saved to \" + root_dir + \"/ready_finn.onnx\")"
549 |   ]
550 |  },
551 |  {
552 |   "cell_type": "markdown",
553 |   "id": "f2c7e045",
554 |   "metadata": {},
555 |   "source": [
556 |    "## Visualization in Netron\n",
557 |    "\n",
558 |    "We can now visualize our network in Netron, and we can clearly identify each layer in this graph. Great !"
559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 13, 564 | "id": "8d6f0250-5b1c-4276-9293-c29a1b112782", 565 | "metadata": {}, 566 | "outputs": [ 567 | { 568 | "name": "stdout", 569 | "output_type": "stream", 570 | "text": [ 571 | "Serving '/tmp/finn_dev_rootmin/ready_finn.onnx' at http://0.0.0.0:8081\n" 572 | ] 573 | }, 574 | { 575 | "data": { 576 | "text/html": [ 577 | "\n", 578 | " \n", 586 | " " 587 | ], 588 | "text/plain": [ 589 | "" 590 | ] 591 | }, 592 | "execution_count": 13, 593 | "metadata": {}, 594 | "output_type": "execute_result" 595 | } 596 | ], 597 | "source": [ 598 | "from finn.util.visualization import showInNetron\n", 599 | "\n", 600 | "showInNetron(root_dir + \"/ready_finn.onnx\")" 601 | ] 602 | }, 603 | { 604 | "cell_type": "markdown", 605 | "id": "41db41a1", 606 | "metadata": {}, 607 | "source": [ 608 | "We are now ready to move on to FINN !" 609 | ] 610 | } 611 | ], 612 | "metadata": { 613 | "kernelspec": { 614 | "display_name": "Python 3 (ipykernel)", 615 | "language": "python", 616 | "name": "python3" 617 | } 618 | }, 619 | "nbformat": 4, 620 | "nbformat_minor": 5 621 | } 622 | --------------------------------------------------------------------------------