├── docs
│   ├── header.png
│   ├── image.png
│   ├── octav.png
│   ├── themcu.jpg
│   ├── NF4plot.png
│   ├── testdata.png
│   ├── 12kopt_test.png
│   ├── NF4scaling.png
│   ├── header_cnn.png
│   ├── schedules.png
│   ├── 12kopt_train.png
│   ├── Model.drawio.png
│   ├── NF4lossplots.png
│   ├── fp130_export.png
│   ├── model_cnn_mcu.png
│   ├── prepostquant.png
│   ├── 12kopt_augmented.png
│   ├── 4bit_histograms.png
│   ├── Model_mcu.drawio.png
│   ├── lossvsschedule.png
│   ├── octav_equation.png
│   ├── octav_weightdist.png
│   ├── quantscale_scan.png
│   ├── trainingpipeline.png
│   ├── cnn_tradeoff_plots.png
│   ├── model_cnn_channel.png
│   ├── model_cnn_overview.png
│   ├── quantscale_entropy.png
│   ├── train_loss_vs_now.png
│   ├── train_vs_test_loss.png
│   ├── explorationaugmented.png
│   ├── first_layer_weights.png
│   ├── quantscale_fp130_scan.png
│   ├── combined_pixel_analysis.png
│   ├── train_loss_vs_totalbits.png
│   ├── train_vs_test_accuracy.png
│   ├── first_layer_weights_fp130.png
│   ├── quantscale_fp130_entropy.png
│   ├── first_layer_weights_noaugment.png
│   ├── plots
│   │   ├── readme.md
│   │   ├── augmented_60ep.txt
│   │   ├── losses_vstotalbits.py
│   │   ├── prevspostquant_plot.py
│   │   ├── errorrate.py
│   │   ├── networksize_subset.txt
│   │   ├── accuracy_loss_plots.py
│   │   ├── plot12kopt.py
│   │   ├── plot12kaugopt.py
│   │   ├── 12kaugruns.txt
│   │   ├── create_cnn_plots.py
│   │   ├── 12kruns.txt
│   │   ├── clean_30epruns.txt
│   │   └── networksize_all.txt
│   ├── readme.md
│   ├── explorationaugmented.py
│   ├── trainingpipeline.drawio
│   ├── model_cnn_channel.drawio
│   ├── Model_mcu.drawio
│   ├── Model.drawio
│   └── model_cnn_overview.drawio
├── mcu
│   ├── console.png
│   ├── Makefile
│   ├── funconfig.h
│   ├── readme.md
│   ├── BitNetMCU_model_1k.h
│   ├── BitNetMCUdemo.c
│   └── BitNetMCU_model_cnn_16small.h
├── requirements.txt
├── .gitmodules
├── modeldata
│   ├── octav_FCMNIST_Aug_BitMnist_4bitsym_width64_64_64_epochs60.pth
│   ├── octav__CNNMNIST_Aug_BitMnist_4bitsym_width96_64_0_epochs60.pth
│   ├── opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bit_RMS_width64_64_64_bs128_epochs60.pth
│   ├── opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_FP130_RMS_width64_64_64_bs128_epochs60.pth
│   ├── opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_None_RMS_width64_64_64_bs128_epochs60.pth
│   ├── opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth
│   ├── octav_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth
│   ├── emnist_letters_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth
│   ├── emnist_letters_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width96_96_96_bs128_epochs60.pth
│   ├── a11_Opt12k_cos_Aug_BitMnist_PerTensor_2bitsym_RMS_width96_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth
│   ├── a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width56_56_56_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth
│   ├── a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth
│   ├── a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width72_80_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth
│   └── a11_Opt12k_cos_Aug_BitMnist_PerTensor_Binary_RMS_width160_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth
├── Makefile
├── .gitignore
├── trainingparameters.yaml
├── BitNetMCU_inference.h
├── BitNetMCU_MNIST_dll.c
├── readme.md
├── BitNetMCU_MNIST_test.c
├── models.py
├── test_inference.py
├── BitNetMCU_inference.c
├── training.py
└── BitNetMCU_MNIST_test_data.h
/docs/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/header.png -------------------------------------------------------------------------------- /docs/image.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/image.png -------------------------------------------------------------------------------- /docs/octav.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/octav.png -------------------------------------------------------------------------------- /docs/themcu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/themcu.jpg -------------------------------------------------------------------------------- /mcu/console.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/mcu/console.png -------------------------------------------------------------------------------- /docs/NF4plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/NF4plot.png -------------------------------------------------------------------------------- /docs/testdata.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/testdata.png -------------------------------------------------------------------------------- /docs/12kopt_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/12kopt_test.png -------------------------------------------------------------------------------- /docs/NF4scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/NF4scaling.png -------------------------------------------------------------------------------- /docs/header_cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/header_cnn.png -------------------------------------------------------------------------------- /docs/schedules.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/schedules.png -------------------------------------------------------------------------------- /docs/12kopt_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/12kopt_train.png -------------------------------------------------------------------------------- /docs/Model.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/Model.drawio.png -------------------------------------------------------------------------------- /docs/NF4lossplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/NF4lossplots.png -------------------------------------------------------------------------------- /docs/fp130_export.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/fp130_export.png 
-------------------------------------------------------------------------------- /docs/model_cnn_mcu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/model_cnn_mcu.png -------------------------------------------------------------------------------- /docs/prepostquant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/prepostquant.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | PyYAML 5 | tensorboard 6 | matplotlib 7 | -------------------------------------------------------------------------------- /docs/12kopt_augmented.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/12kopt_augmented.png -------------------------------------------------------------------------------- /docs/4bit_histograms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/4bit_histograms.png -------------------------------------------------------------------------------- /docs/Model_mcu.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/Model_mcu.drawio.png -------------------------------------------------------------------------------- /docs/lossvsschedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/lossvsschedule.png -------------------------------------------------------------------------------- /docs/octav_equation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/octav_equation.png -------------------------------------------------------------------------------- /docs/octav_weightdist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/octav_weightdist.png -------------------------------------------------------------------------------- /docs/quantscale_scan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/quantscale_scan.png -------------------------------------------------------------------------------- /docs/trainingpipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/trainingpipeline.png -------------------------------------------------------------------------------- /docs/cnn_tradeoff_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/cnn_tradeoff_plots.png -------------------------------------------------------------------------------- /docs/model_cnn_channel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/model_cnn_channel.png 
-------------------------------------------------------------------------------- /docs/model_cnn_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/model_cnn_overview.png -------------------------------------------------------------------------------- /docs/quantscale_entropy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/quantscale_entropy.png -------------------------------------------------------------------------------- /docs/train_loss_vs_now.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/train_loss_vs_now.png -------------------------------------------------------------------------------- /docs/train_vs_test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/train_vs_test_loss.png -------------------------------------------------------------------------------- /docs/explorationaugmented.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/explorationaugmented.png -------------------------------------------------------------------------------- /docs/first_layer_weights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/first_layer_weights.png -------------------------------------------------------------------------------- /docs/quantscale_fp130_scan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/quantscale_fp130_scan.png -------------------------------------------------------------------------------- /docs/combined_pixel_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/combined_pixel_analysis.png -------------------------------------------------------------------------------- /docs/train_loss_vs_totalbits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/train_loss_vs_totalbits.png -------------------------------------------------------------------------------- /docs/train_vs_test_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/train_vs_test_accuracy.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mcu/ch32v003fun"] 2 | path = mcu/ch32v003fun 3 | url = https://github.com/cnlohr/ch32v003fun 4 | -------------------------------------------------------------------------------- /docs/first_layer_weights_fp130.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/first_layer_weights_fp130.png -------------------------------------------------------------------------------- /docs/quantscale_fp130_entropy.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/quantscale_fp130_entropy.png -------------------------------------------------------------------------------- /docs/first_layer_weights_noaugment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/docs/first_layer_weights_noaugment.png -------------------------------------------------------------------------------- /modeldata/octav_FCMNIST_Aug_BitMnist_4bitsym_width64_64_64_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/octav_FCMNIST_Aug_BitMnist_4bitsym_width64_64_64_epochs60.pth -------------------------------------------------------------------------------- /modeldata/octav__CNNMNIST_Aug_BitMnist_4bitsym_width96_64_0_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/octav__CNNMNIST_Aug_BitMnist_4bitsym_width96_64_0_epochs60.pth -------------------------------------------------------------------------------- /docs/plots/readme.md: -------------------------------------------------------------------------------- 1 | This directory contains plots used in the development log, including data. Please note that most of the code used to generate these plots was created with ChatGPT, so use it with caution. -------------------------------------------------------------------------------- /modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bit_RMS_width64_64_64_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bit_RMS_width64_64_64_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_FP130_RMS_width64_64_64_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_FP130_RMS_width64_64_64_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_None_RMS_width64_64_64_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_None_RMS_width64_64_64_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/opt_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/octav_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/octav_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth -------------------------------------------------------------------------------- /mcu/Makefile: -------------------------------------------------------------------------------- 1 | all : build 2 | 3 | TARGET:=BitNetMCUdemo 4 | # TARGET_MCU?=CH32V003 5 | TARGET_MCU?=CH32V002 6 | 7 | CH32FUN ?= ch32v003fun/ch32fun 8 | 9 | include $(CH32FUN)/ch32fun.mk 10 | 11 | flash : cv_flash 12 | clean : cv_clean 13 | 14 | -------------------------------------------------------------------------------- /modeldata/emnist_letters_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/emnist_letters_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/emnist_letters_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width96_96_96_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/emnist_letters_Cosine_lr0.001_Aug_BitMnist_PerTensor_4bitsym_RMS_width96_96_96_bs128_epochs60.pth -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SOURCES = BitNetMCU_MNIST_dll.c BitNetMCU_inference.c 2 | HEADERS = BitNetMCU_model.h BitNetMCU_inference.h 3 | DLL = Bitnet_inf.dll 4 | 5 | $(DLL): $(SOURCES) $(HEADERS) 6 | cc -fPIC -shared -o $@ -D _DLL $< 7 | 8 | .PHONY: clean 9 | clean: 10 | rm -f $(DLL) 11 | -------------------------------------------------------------------------------- /modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_2bitsym_RMS_width96_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_2bitsym_RMS_width96_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width56_56_56_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width56_56_56_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width72_80_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width72_80_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth -------------------------------------------------------------------------------- /modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_Binary_RMS_width160_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cpldcpu/BitNetMCU/HEAD/modeldata/a11_Opt12k_cos_Aug_BitMnist_PerTensor_Binary_RMS_width160_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth -------------------------------------------------------------------------------- /mcu/funconfig.h: -------------------------------------------------------------------------------- 1 | #ifndef _FUNCONFIG_H 2 | #define _FUNCONFIG_H 3 | 4 | #define FUNCONF_USE_HSE 0 // external crystal on PA1 PA2 5 | #define FUNCONF_USE_HSI 1 // internal 24MHz clock oscillator 6 | #define FUNCONF_USE_PLL 1 // use PLL x2 7 | 8 | #define FUNCONF_SYSTICK_USE_HCLK 1 9 | 10 | #endif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # training data 2 | data/ 3 | runs/ 4 | runs_opt/ 5 | backup/ 6 | batchtest/ 7 | *.obj 8 | *.dll 9 | *.exp 10 | *.lib 11 | *.exe 12 | *.hex 13 | *.elf 14 | *.bin 15 | *.map 16 | *.o 17 | *.lst 18 | *.bkp 19 | *.pdf 20 | *.log 21 | *.elog 22 | # python cache 23 | __pycache__/ 24 | venv/ 25 | # ides 26 | .vscode/ 27 | .idea/ 28 | -------------------------------------------------------------------------------- /docs/readme.md: -------------------------------------------------------------------------------- 1 | # BitNetMCU Documentation Assets 2 | 3 | This folder contains my progress logs during implementation of the BitNetMCU project, including diagrams, charts, and images that illustrate the architecture, training process, quantization analysis, and deployment on microcontroller units (MCUs). 
4 | 5 | - [Primary documentation](./documentation.md), focussed on the initial fully connected MLP implementation 6 | - [CNN documentation](./documentation_cnn.md), focussed on the CNN implementation 7 | 8 | --- 9 | Last updated: 2025-09-07 10 | -------------------------------------------------------------------------------- /docs/plots/augmented_60ep.txt: -------------------------------------------------------------------------------- 1 | w1 w2 w3 L1_weights L2_weights L3_weights BPW total kbyte Layers test/accuracy 2 | 64 64 64 16384 4096 4096 4 12.3125 3 99.01 3 | 64 64 0 16384 4096 0 4 10.3125 2 98.79 4 | 64 48 32 16384 3072 1536 4 10.40625 3 98.51 5 | 56 56 56 14336 3136 3136 4 10.3359375 3 98.73 6 | 72 80 0 18432 5760 0 4 12.203125 2 98.98 7 | 80 80 80 20480 6400 6400 4 16.640625 3 99.04 8 | 96 96 0 24576 9216 0 4 16.96875 2 98.98 9 | 48 48 0 12288 2304 0 4 7.359375 2 98.52 10 | 48 48 48 12288 2304 2304 4 8.484375 3 98.72 11 | 72 64 0 18432 4608 0 4 11.5625 2 98.85 12 | 72 72 0 18432 5184 0 4 11.8828125 2 98.94 13 | 64 96 0 16384 6144 0 4 11.46875 2 98.96 14 | 64 128 0 16384 8192 0 4 12.625 2 99.03 15 | 64 72 0 16384 4608 0 4 10.6015625 2 98.76 16 | 56 112 0 14336 6272 0 4 10.609375 2 98.87 17 | 48 160 0 12288 7680 0 4 10.53125 2 98.79 18 | 48 128 0 12288 6144 0 4 9.625 2 98.7 19 | 40 64 0 10240 2560 0 4 6.5625 2 98.29 20 | 40 40 0 10240 1600 0 8 11.953125 2 98.6 21 | 40 40 0 10240 1600 0 4 5.976563 2 98.13 22 | 40 40 0 10240 1600 0 5 7.460703 2 98.44 23 | 40 40 0 10240 1600 0 2 2.988281 2 96.79 24 | -------------------------------------------------------------------------------- /docs/plots/losses_vstotalbits.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import pandas as pd 5 | 6 | # Assuming df is your DataFrame loaded from the file 7 | # Replace the loading part with your actual DataFrame loading code 8 | df = pd.read_csv('clean_30epruns.txt', sep='\t') 9 | 10 | #%% 11 | # Plotting training and test loss in separate panes, horizontally, in the same figure 12 | fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True) # Using shared y-axis 13 | 14 | # Plotting Training Loss on the first pane 15 | sns.scatterplot(ax=axes[0], data=df, x='Totalbits', y='Loss/train', hue='QuantType', style='QuantType', palette='tab10', s=100, edgecolor='black', linewidth=0.5) 16 | axes[0].set_title('Training Loss vs. Total Bits', fontsize=14) 17 | axes[0].set_xlabel('Total Bits (log scale)', fontsize=12) 18 | axes[0].set_ylabel('Loss (log scale)', fontsize=12) 19 | axes[0].set_xscale('log') 20 | axes[0].set_yscale('log') 21 | axes[0].legend(title='Quantization Type', bbox_to_anchor=(1.05, 1), loc='upper left') 22 | 23 | # Plotting Test Loss on the second pane 24 | sns.scatterplot(ax=axes[1], data=df, x='Totalbits', y='Loss/test', hue='QuantType', style='QuantType', palette='tab10', s=100, edgecolor='black', linewidth=0.5) 25 | axes[1].set_title('Test Loss vs. 
Total Bits', fontsize=14) 26 | axes[1].set_xlabel('Total Bits (log scale)', fontsize=12) 27 | # No need for y-label as it shares with the first pane 28 | axes[1].set_xscale('log') 29 | axes[1].set_yscale('log') 30 | axes[1].legend(title='Quantization Type', bbox_to_anchor=(1.05, 1), loc='upper left') 31 | 32 | # Adjusting layout for shared y-axis and ensuring the legend is not cut off 33 | plt.tight_layout() 34 | 35 | # Displaying the plot 36 | plt.show() 37 | 38 | # %% 39 | -------------------------------------------------------------------------------- /docs/plots/prevspostquant_plot.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | 5 | # Data preparation 6 | data = { 7 | "Quantization during training": [ 8 | "FP32", "FP32", "FP32", "FP32", "FP32", "FP32", 9 | "QAT", "QAT", "QAT", "QAT", "QAT", 10 | "4 Bit", "4 Bit", "4 Bit", "4 Bit", 11 | "8 Bit", "8 Bit", "8 Bit", "8 Bit", "8 Bit" 12 | ], 13 | "Postquantization": [ 14 | 1, 1.6, 2, 4, 8, None, 15 | 1, 1.6, 2, 4, 8, 16 | 1, 1.6, 2, 4, 17 | 1, 1.6, 2, 4, 8 18 | ], 19 | "Test/Accuracy": [ 20 | 51.74, 69.96, 80.09, 85.39, 97.92, 98.22, 21 | 97.03, 97.39, 97.68, 97.92, 98.02, 22 | 70.14, 91.99, 96.63, 97.92, 23 | 59.8, 83.26, 87.8, 94.34, 98.02 24 | ] 25 | } 26 | df = pd.DataFrame(data) 27 | #%% 28 | 29 | # Drop rows with NaN in 'Postquantization' 30 | df_clean = df.dropna(subset=['Postquantization']) 31 | 32 | # Plotting setup 33 | plt.figure(figsize=(8, 5)) 34 | markers = ['o', 's', '^', 'D', 'x'] # Different symbols 35 | colors = ['blue', 'green', 'red', 'purple', 'orange'] # Different colors 36 | 37 | quantization_categories = df_clean['Quantization during training'].unique() 38 | 39 | for i, category in enumerate(quantization_categories): 40 | subset = df_clean[df_clean['Quantization during training'] == category] 41 | plt.plot(subset['Postquantization'], subset['Test/Accuracy'], marker=markers[i], color=colors[i], label=category, linestyle='-', markersize=8) 42 | 43 | plt.xlabel('Postquantization [bits]', fontsize=12) 44 | plt.ylabel('Test/Accuracy [%]', fontsize=12) 45 | plt.title('Test Accuracy vs Postquantization, Grouped by Quantization during Training', fontsize=14) 46 | plt.legend(title='Quantization during training') 47 | plt.grid(True, which='both', linestyle='--', linewidth=0.5) 48 | 49 | plt.show() 50 | 51 | # %% 52 | -------------------------------------------------------------------------------- /docs/plots/errorrate.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import pandas as pd 5 | import numpy as np 6 | 7 | # Load your data 8 | # Assuming df is already loaded as shown earlier 9 | # df = pd.read_csv('your_data_file.csv', sep='\t') 10 | df = pd.read_csv('clean_30epruns.txt', sep='\t') 11 | 12 | # Preparing the figure 13 | fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True) 14 | 15 | # Defining color palette and markers 16 | palette = sns.color_palette('tab10', n_colors=len(df['QuantType'].unique())) 17 | marker_dict = {quant_type: marker for quant_type, marker in zip(df['QuantType'].unique(), ['o', 's', 'D', '^', 'v', '<', '>', 'p', '*', '+'])} 18 | 19 | for i, quant_type in enumerate(df['QuantType'].unique()): 20 | # Filter data for each quantization type 21 | temp_df = df[df['QuantType'] == quant_type].copy() 22 | 23 | # Calculate Error Rates as 100 - Accuracy 24 | temp_df['Error Rate/train'] = 
100 - temp_df['Accuracy/train'] 25 | temp_df['Error Rate/test'] = 100 - temp_df['Accuracy/test'] 26 | 27 | # Plot Training Error Rate 28 | axes[0].scatter(temp_df['Totalbits'], temp_df['Error Rate/train'], color=palette[i], label=quant_type, marker=marker_dict[quant_type]) 29 | 30 | # Plot Test Error Rate 31 | axes[1].scatter(temp_df['Totalbits'], temp_df['Error Rate/test'], color=palette[i], marker=marker_dict[quant_type]) 32 | 33 | # Setting log scale for x-axis and y-axis 34 | axes[0].set_xscale('log') 35 | axes[0].set_yscale('log') 36 | axes[1].set_xscale('log') 37 | 38 | # Adding titles and labels 39 | axes[0].set_title('Training Error Rate vs. Total Bits', fontsize=14) 40 | axes[0].set_xlabel('Total Bits (log scale)', fontsize=12) 41 | axes[0].set_ylabel('Error Rate (%)', fontsize=12) 42 | 43 | axes[1].set_title('Test Error Rate vs. Total Bits', fontsize=14) 44 | axes[1].set_xlabel('Total Bits (log scale)', fontsize=12) 45 | 46 | # Adding legend to the first plot 47 | axes[0].legend(title='Quantization Type', bbox_to_anchor=(1.05, 1), loc='upper left') 48 | 49 | plt.tight_layout() 50 | plt.show() 51 | 52 | # %% 53 | -------------------------------------------------------------------------------- /docs/explorationaugmented.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | 4 | # Data 5 | data = { 6 | "w1": [64, 64, 64, 56, 72, 80, 96, 48, 48, 72, 72, 64, 64, 64, 56, 48, 48, 40, 40, 40, 40, 40], 7 | "w2": [64, 64, 48, 56, 80, 80, 96, 48, 48, 64, 72, 96, 128, 72, 112, 160, 128, 64, 40, 40, 40, 40], 8 | "w3": [64, 0, 32, 56, 0, 80, 0, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 9 | "BPW": [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, 4, 5, 2], 10 | "total_kbyte": [12.3125, 10.3125, 10.40625, 10.3359375, 12.203125, 16.640625, 16.96875, 7.359375, 11 | 8.484375, 11.5625, 11.8828125, 11.46875, 12.625, 10.6015625, 10.609375, 10.53125, 12 | 9.625, 6.5625, 11.953125, 5.976563, 7.460703, 2.988281], 13 | "Layers": [3, 2, 3, 3, 2, 3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 14 | "test_accuracy": [99.01, 98.79, 98.51, 98.73, 98.98, 99.04, 98.98, 98.52, 98.72, 98.85, 98.94, 98.96, 15 | 99.03, 98.76, 98.87, 98.79, 98.7, 98.29, 98.6, 98.13, 98.44, 96.79] 16 | } 17 | 18 | df = pd.DataFrame(data) 19 | 20 | # Filter out a specific data point 21 | df_filtered = df[~((df['w1'] == 40) & (df['w2'] == 40) & (df['w3'] == 0) & (df['BPW'] == 2))] 22 | 23 | # Group the filtered data by the number of layers 24 | groups_filtered = df_filtered.groupby('Layers') 25 | 26 | # Colors and markers 27 | colors = {2: 'blue', 3: 'green'} 28 | markers = {2: 'o', 3: 's'} # Circle for 2 layers, square for 3 layers 29 | labels = {2: '2 layers', 3: '3 layers'} # Legend labels, keyed by layer count 30 | # Plotting 31 | fig, ax = plt.subplots(figsize=(8, 6)) 32 | 33 | for name, group in groups_filtered: 34 | ax.scatter(group['total_kbyte'], group['test_accuracy'], c=colors[name], label=labels[name], 35 | marker=markers[name], s=100, alpha=0.6, edgecolors='none') 36 | for i, row in group.iterrows(): 37 | label = f"{int(row['w1'])}/{int(row['w2'])}/{int(row['w3'])}/{int(row['BPW'])}b" 38 | # Bold specific labels 39 | if label == "40/40/0/8b" or label == "64/48/32/4b": 40 | ax.text(row['total_kbyte'] + 0.1, row['test_accuracy'], label, fontsize=8, fontweight='bold', 41 | verticalalignment='center', horizontalalignment='left') 42 | else: 43 | ax.text(row['total_kbyte'] + 0.1, row['test_accuracy'], label, fontsize=8, verticalalignment='center', 44 | horizontalalignment='left') 45 | 46 | 
ax.set_xlabel('Total kbyte') 47 | ax.set_ylabel('Test accuracy (%)') 48 | ax.legend() 49 | plt.show() 50 | -------------------------------------------------------------------------------- /docs/plots/networksize_subset.txt: -------------------------------------------------------------------------------- 1 | QuantType BPW batch_size network_width1 network_width2 network_width3 Parameters Totalbits Accuracy/train Accuracy/test Loss/train Loss/test 2 | Binary 1 128 64 64 64 25216 25216 97.59500122 96.80000305 0.075935371 0.10220889 3 | Binary 1 128 64 64 64 25216 25216 97.79333496 96.76999664 0.072463408 0.105802305 4 | Binary 1 128 64 64 64 25216 25216 97.63833618 96.61000061 0.076641515 0.117163301 5 | Binary 1 128 64 64 64 25216 25216 97.72000122 96.70999908 0.073574446 0.103008136 6 | Binary 1 128 64 64 64 25216 25216 97.72000122 96.72000122 0.074169569 0.106362127 7 | Binary 1 128 64 64 64 25216 25216 97.74666595 97.04000092 0.071919426 0.100333385 8 | Binary 1 128 64 64 64 25216 25216 97.67166901 96.91000366 0.074403964 0.10203369 9 | Ternary 1.6 128 64 64 64 25216 40345.60156 98.80500031 97.5 0.038920593 0.086704224 10 | Ternary 1.6 128 64 64 64 25216 40345.60156 98.82666779 97.66000366 0.039240263 0.080099061 11 | 2bitsym 2 128 64 48 48 22240 44480 98.94999695 97.62999725 0.034304049 0.081468672 12 | 2bitsym 2 128 64 64 48 24032 48064 99.12999725 97.79000092 0.030079897 0.07552468 13 | 2bitsym 2 128 64 64 48 24032 48064 99.06666565 97.73000336 0.030958444 0.081059061 14 | 2bitsym 2 128 64 64 80 26400 52800 99.04833221 97.80000305 0.030868249 0.076512448 15 | 2bitsym 2 128 64 64 64 25216 50432 99.15666962 97.95999908 0.029203488 0.074282631 16 | 2bitsym 2 128 64 64 64 25216 50432 99.07499695 97.66999817 0.030472489 0.086960226 17 | 4bitsym 4 128 64 64 64 25216 100864 99.77666473 98.20999908 0.011492738 0.067927346 18 | 4bitsym 4 128 64 64 64 25216 100864 99.73666382 97.98000336 0.012213291 0.075106576 19 | 4bitsym 4 128 64 64 64 25216 100864 99.76833344 97.76000214 0.01179162 0.07858 20 | 8bit 8 128 64 64 64 25216 201728 99.89167023 98.13999939 0.00680928 0.066086918 21 | 8bit 8 128 64 64 64 25216 201728 99.86833191 98.02999878 0.007506728 0.073647425 22 | 8bit 8 128 64 64 64 25216 201728 99.84333038 97.93000031 0.009154263 0.078787938 23 | 8bit 8 128 64 64 64 25216 201728 99.88833618 97.98999786 0.007123163 0.074402936 24 | None 32 128 64 64 64 25216 806912 99.25166321 98.75 0.025440145 0.043007035 25 | None 32 128 64 64 64 25216 806912 99.88500214 98.12000275 0.007095081 0.06810613 26 | None 32 128 64 64 64 25216 806912 99.87666321 98.08999634 0.007548976 0.075958706 27 | None 32 128 64 64 64 25216 806912 99.54666901 98.09999847 0.016896078 0.076675124 28 | None 32 128 64 64 64 25216 806912 99.83166504 98 0.008510431 0.07825122 29 | -------------------------------------------------------------------------------- /mcu/readme.md: -------------------------------------------------------------------------------- 1 | # 🚀 BitNetMCU Inference Engine Demo for CH32V003 (and other) MCUs 2 | 3 | This folder contains a demo that implements the BitNetMCU inference engine on an actual CH32V003 MCU with 16 kB of flash and 2 kB of RAM. This example is to be used with the [ch32fun SDK](https://github.com/cnlohr/ch32fun). 4 | 5 | ## File Descriptions 6 | 7 | - [`Makefile`](Makefile): The default makefile assumes that you clone ch32v003fun as a submodule. This Makefile will compile the demo and flash it to the MCU. Change this file to retarget other MCUs. 8 | 9 | - [`funconfig.h`](funconfig.h): Configuration of the main clock speed and SysTick clock source. 10 | 11 | - [`BitNetMCUdemo.c`](BitNetMCUdemo.c): This is the main C file for the demo. It includes the BitNetMCU inference engine and model from the main folder. It will perform inference on included test images and output the results to the monitoring console. The model header can be swapped out as shown in the sketch below.
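A minimal sketch of the model selection (illustrative only; the exact include line in `BitNetMCUdemo.c` may look slightly different). `BitNetMCU_model_1k.h` and `BitNetMCU_model_cnn_16small.h` are the two model headers shipped in this folder:

```c
// Select the model by including exactly one exported model header.
#include "BitNetMCU_model_1k.h"              // 1.1 kB FC model, 94.22% test accuracy
// #include "BitNetMCU_model_cnn_16small.h"  // 3.2 kB CNN model, 98.92% test accuracy
```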
12 | 13 | ## Models 14 | 15 | Different models can be selected by including the respective header file in `BitNetMCUdemo.c`. The following model files are included. Execution timings are measured on a CH32V002 at 48 MHz. 16 | 17 | | File Name | Configuration | CNN Width | Size (kB) | Test Accuracy | Cycles Avg. | Time (ms) | 18 | |-----------|---------------|-------|-----------|---------------|-------------|-----------| 19 | | `BitNetMCU_model_cnn_16small.h` | 16-wide CNN, small fc | 16 | 3.2 | 98.92% | 686,490 | 14.30 | 20 | | `BitNetMCU_model_cnn_16.h` | 16-wide CNN | 16 | 5.4 | 99.06% | 785,123 | 16.36 | 21 | | `BitNetMCU_model_cnn_32.h` | 32-wide CNN | 32 | 7.3 | 99.28% | 1,434,667 | 29.89 | 22 | | `BitNetMCU_model_cnn_48.h` | 48-wide CNN | 48 | 9.3 | 99.44% | 2,083,568 | 43.41 | 23 | | `BitNetMCU_model_cnn_64.h` | 64-wide CNN | 64 | 11.0 | 99.55% | 2,736,250 | 57.01 | 24 | | `BitNetMCU_model_1k.h` | 1k 2Bitsym FC | - | 1.1 | 94.22% | 99,783 | 2.08 | 25 | | `BitNetMCU_model_12k.h` | 12k 4Bitsym FC | - | 12.3 | 99.02% | 528,377 | 11.01 | 26 | | `BitNetMCU_model_12k_FP130.h` | 12k FP130 FC | - | 12.3 | 98.86% | 481,624 | 10.03 | 27 | 28 | Take a look at the documentation for more details on the model architecture and trade-offs: [Documentation](../docs/documentation.md) 29 | 30 | ## Usage 31 | 32 | ``` 33 | make flash 34 | make monitor 35 | ``` 36 | Example output: 37 | 38 | ![Example output on Monitor](console.png) 39 | -------------------------------------------------------------------------------- /trainingparameters.yaml: -------------------------------------------------------------------------------- 1 | # Description: Training parameters for the training script 2 | 3 | # Model selection 4 | model: 'CNNMNIST' # 'FCMNIST' or 'CNNMNIST'. This is the class name of the model as defined in models.py. 5 | dataset: 'MNIST' # 'MNIST', or EMNIST splits: EMNIST_BALANCED, EMNIST_BYCLASS, EMNIST_BYMERGE, EMNIST_LETTERS, EMNIST_DIGITS, EMNIST_MNIST 6 | 7 | # Quantization settings 8 | QuantType: '4bitsym' # 'Ternary', 'Binary', 'BinaryBalanced', '2bitsym', '4bit', '4bitsym', '8bit', 'None', 'FP130', 'NF4' 9 | NormType: 'RMS' # 'RMS', 'Lin', 'BatchNorm' 10 | WScale: 'PerTensor' # 'PerTensor', 'PerOutput' 11 | 12 | # Clipping parameters - only used for 2-bit and higher quantization 13 | maxw_algo: 'octav' # 'octav', 'prop'. Algorithm used to calculate the clipping parameters (maximum weight) 14 | maxw_update_until_epoch: 60 # Update clipping parameters until this epoch; they are frozen afterwards 15 | maxw_quantscale: 0.25 # Used only for maxw_algo='prop'. 
Determines the relation between stddev of weights and max_weight 16 | 17 | # Learning parameters 18 | num_epochs: 60 # 5, 20, 80 19 | batch_size: 64 20 | scheduler: "Cosine" # "StepLR", "Cosine", "CosineWarmRestarts" 21 | learning_rate: 0.001 22 | 23 | # CosineWarmRestarts parameters 24 | # T_0: 5 # Period of the first restart for CosineWarmRestarts - 10+20+40 = 70 epochs, need to step in epoch 69 at minimum LR 25 | # T_mult: 4 # Factor increasing T_i after a restart 26 | 27 | # StepLR parameters 28 | # lr_decay: 0.1 # lr_decay and step size for StepLR 29 | # step_size: 10 30 | 31 | # halve_lr_epoch: 30 # Epoch at which to halve the learning rate - to be used with Cosine schedule 32 | 33 | # Data augmentation 34 | augmentation: True 35 | rotation1: 10 # rotation1 and rotation2 are used for data augmentation 36 | rotation2: 10 37 | elastictransformprobability: 0.0 # probability of applying elastic transform 38 | 39 | # channel pruning settings. Requires "MaskLayer" in the model, otherwise these settings have no effect 40 | lambda_l1: 0.0005 # L1 regularization parameter for mask learning 41 | prune_epoch: -1 # Epoch at which to start pruning. -1 means no pruning 42 | prune_groupstoprune: 32 # number of groups to prune 43 | prune_totalgroups: 96 # total number of groups. e.g. if there are 384 channels and 96 groups, then each group has 4 channels 44 | 45 | # Model parameters 46 | network_width1: 96 47 | network_width2: 64 48 | network_width3: 0 49 | 50 | # name 51 | runtag: "octav" # runtag is prefix for runname 52 | -------------------------------------------------------------------------------- /docs/plots/accuracy_loss_plots.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | 6 | # Load the data 7 | df = pd.read_csv('clean_30epruns.txt', sep='\t') 8 | 9 | #%% 10 | 11 | # Filter out '8bit' and 'None' quantization types 12 | df_filtered = df[~df['QuantType'].isin(['8bit', 'None'])] 13 | 14 | # Prepare combined data for accuracy plot 15 | train_data = df_filtered[['Totalbits', 'Accuracy/train', 'Loss/train']].rename(columns={'Accuracy/train': 'Accuracy', 'Loss/train': 'Loss'}) 16 | train_data['Type'] = 'Train' 17 | test_data = df_filtered[['Totalbits', 'Accuracy/test', 'Loss/test']].rename(columns={'Accuracy/test': 'Accuracy', 'Loss/test': 'Loss'}) 18 | test_data['Type'] = 'Test' 19 | combined_data = pd.concat([train_data, test_data]) 20 | 21 | # Prepare combined data for loss plot 22 | train_loss_data = df_filtered[['Totalbits', 'Loss/train']].rename(columns={'Loss/train': 'Loss'}) 23 | train_loss_data['Type'] = 'Train' 24 | test_loss_data = df_filtered[['Totalbits', 'Loss/test']].rename(columns={'Loss/test': 'Loss'}) 25 | test_loss_data['Type'] = 'Test' 26 | combined_loss_data = pd.concat([train_loss_data, test_loss_data]) 27 | 28 | # Plotting Accuracy 29 | plt.figure(figsize=(8, 4)) 30 | sns.scatterplot(data=combined_data, x='Totalbits', y='Accuracy', hue='Type', style='Type', palette='Set1', s=100, edgecolor='black', linewidth=0.5) 31 | position_bits = 8 * 12 * 1024 # Corrected position 32 | plt.axvline(x=position_bits, color='r', linestyle='--') 33 | plt.text(position_bits, combined_data['Accuracy'].min(), '12kbyte', color='r', ha='right') 34 | plt.title('Train vs. Test Accuracy vs. 
Total Bits', fontsize=14) 35 | plt.xlabel('Total Bits (log scale)', fontsize=12) 36 | plt.ylabel('Accuracy', fontsize=12) 37 | plt.xscale('log') 38 | plt.legend(title='Data Type', bbox_to_anchor=(1.05, 1), loc='upper left') 39 | plt.tight_layout() 40 | plt.show() 41 | 42 | # Plotting Loss 43 | plt.figure(figsize=(8, 4)) 44 | sns.scatterplot(data=combined_loss_data, x='Totalbits', y='Loss', hue='Type', style='Type', palette='Set1', s=100, edgecolor='black', linewidth=0.5) 45 | plt.axvline(x=position_bits, color='r', linestyle='--') 46 | plt.text(position_bits, combined_loss_data['Loss'].min(), '12kbyte', color='r', ha='right') 47 | plt.title('Train vs. Test Loss vs. Total Bits', fontsize=14) 48 | plt.xlabel('Total Bits (log scale)', fontsize=12) 49 | plt.ylabel('Loss', fontsize=12) 50 | plt.xscale('log') 51 | plt.yscale('log') 52 | plt.legend(title='Data Type', bbox_to_anchor=(1.05, 1), loc='upper left') 53 | plt.tight_layout() 54 | plt.show() 55 | 56 | # %% 57 | -------------------------------------------------------------------------------- /BitNetMCU_inference.h: -------------------------------------------------------------------------------- 1 | #ifndef BITNETMCU_INFERENCE_H 2 | #define BITNETMCU_INFERENCE_H 3 | 4 | #include <stdint.h> 5 | 6 | /** 7 | * @brief Applies a ReLU activation function to an array of integers and normalizes the result to 8-bit integers. 8 | * 9 | * @param input Pointer to the input array of 32-bit integers. 10 | * @param output Pointer to the output array of 8-bit integers. 11 | * @param n_input The number of elements in the input array. 12 | * @return The position of the maximum value found in the input array before applying the ReLU activation. 13 | */ 14 | 15 | uint32_t ReLUNorm(int32_t *input, int8_t *output, uint32_t n_input); 16 | 17 | 18 | /** 19 | * @brief Processes a fully connected layer in a neural network. 20 | * 21 | * This function processes a fully connected layer in a neural network by performing 22 | * the dot product of the input activations and weights, and stores the result in the output array. 23 | * 24 | * @param input Pointer to the input activations of the layer. 25 | * @param weights Pointer to the weights of the layer. 26 | * @param bits_per_weight The number of bits per weight. 27 | * @param incoming_weights The number of input neurons (incoming weights per output neuron). 28 | * @param outgoing_weights The number of output neurons. 29 | * @param output Pointer to the output array where the result of the layer is stored. 30 | */ 31 | void processfclayer(int8_t *input, const uint32_t *weights, int32_t bits_per_weight, uint32_t incoming_weights, uint32_t outgoing_weights, int32_t *output); 32 | 33 | 
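/*
 * Usage sketch (illustrative only, not part of the original header): chaining
 * the primitives above for one fully connected layer. The L1_* macros and
 * MAX_N_ACTIVATIONS are provided by an exported model header such as
 * mcu/BitNetMCU_model_1k.h.
 *
 *   int32_t sums[MAX_N_ACTIVATIONS];
 *   int8_t activations[MAX_N_ACTIVATIONS];
 *   processfclayer(input, L1_weights, L1_bitperweight,
 *                  L1_incoming_weights, L1_outgoing_weights, sums);
 *   uint32_t label = ReLUNorm(sums, activations, L1_outgoing_weights);
 *   // Repeat for the following layers; the return value of the final
 *   // ReLUNorm call is the index of the largest output, i.e. the prediction.
 */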
34 | /** 35 | * @brief fused 3x3 conv2d and ReLU activation function 36 | * 37 | * @param activations Pointer to the input activations of the layer. 38 | * @param weights Pointer to the weights of the layer. 39 | * @param xy_input The width/height of the square input activation map. 40 | * @param n_shift The number of bits to shift the result of the convolution after summation, typically either 8+3=11 or 8+4=12. 41 | * @param output Pointer to the output array where the result of the layer is stored. 42 | * @return Pointer to the end of the output array. 43 | */ 44 | 45 | int32_t* processconv33ReLU(int32_t *activations, const int8_t *weights, uint32_t xy_input, uint32_t n_shift, int32_t *output); 46 | 47 | 48 | /** 49 | * @brief maxpool2d 2x2 function 50 | * 51 | * This function performs a 2x2 max pooling operation on a 2D array of input activations. 52 | * The function divides the input activations into 2x2 non-overlapping regions and selects the maximum value in each region. 53 | * 54 | * @param activations Pointer to the input activations of the layer. 55 | * @param xy_input The width/height of the square input activation map. 56 | * @param output Pointer to the output array where the result of the layer is stored. 57 | * @return Pointer to the end of the output array. 58 | */ 59 | 60 | int32_t *processmaxpool22(int32_t *activations, uint32_t xy_input, int32_t *output); 61 | 62 | 63 | 64 | #endif // BITNETMCU_INFERENCE_H -------------------------------------------------------------------------------- /docs/plots/plot12kopt.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | 6 | # Read the data 7 | data_path = '12kruns.txt' 8 | data = pd.read_csv(data_path, delimiter=",", skiprows=1, header=None) 9 | 10 | #%% 11 | # Assign column headers (from first row of the original file) 12 | headers = pd.read_csv(data_path, nrows=0).columns.tolist() 13 | data.columns = headers 14 | 15 | # Extract schedule from runname 16 | data['Schedule'] = data['runname'].str.extract(r'(Lin|Cos)') 17 | 18 | # Function to categorize groups with descriptive names 19 | def categorize_group(row): 20 | if row['Schedule'] == 'Lin' and row['num_epochs'] == 30: 21 | return 'A: L_30ep' 22 | elif row['Schedule'] == 'Cos' and row['num_epochs'] == 30: 23 | return 'B: C_30ep' 24 | elif row['Schedule'] == 'Cos' and row['num_epochs'] == 60: 25 | return 'C: C_60ep' 26 | elif row['Schedule'] == 'Cos' and row['num_epochs'] == 120: 27 | return 'D: C_120ep' 28 | else: 29 | return 'Other' # To catch any cases that don't match expected conditions 30 | 31 | # Applying the categorization function 32 | data['Group'] = data.apply(categorize_group, axis=1) 33 | 34 | # Ensuring groups are ordered correctly for plotting 35 | group_order = ['A: L_30ep', 'B: C_30ep', 'C: C_60ep', 'D: C_120ep'] 36 | data['Group'] = pd.Categorical(data['Group'], categories=group_order, ordered=True) 37 | 38 | # Set up the plotting 39 | sns.set(style="whitegrid") 40 | fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 6)) 41 | 42 | # Training Accuracy 43 | sns.lineplot(ax=axes[0, 0], data=data, x='Group', y='Accuracy/train', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 44 | axes[0, 0].set_title('Training Accuracy by Group and QuantType') 45 | axes[0, 0].set_xlabel('Group') 46 | axes[0, 0].set_ylabel('Train Accuracy (%)') 47 | 48 | # Training Loss 49 | sns.lineplot(ax=axes[0, 1], data=data, x='Group', y='Loss/train', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 50 | axes[0, 1].set_yscale('log') 51 | axes[0, 1].set_title('Training Loss by Group and QuantType') 52 | axes[0, 1].set_xlabel('Group') 53 | axes[0, 1].set_ylabel('Train Loss (log scale)') 54 | 55 | # Test Accuracy 56 | sns.lineplot(ax=axes[1, 0], data=data, x='Group', y='Accuracy/test', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 57 | axes[1, 0].set_title('Test Accuracy by Group and QuantType') 58 | axes[1, 0].set_xlabel('Group') 59 | axes[1, 0].set_ylabel('Test Accuracy (%)') 60 | 61 | # Test Loss 62 | sns.lineplot(ax=axes[1, 1], data=data, x='Group', y='Loss/test', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 63 | axes[1, 
1].set_yscale('log') 64 | axes[1, 1].set_title('Test Loss by Group and QuantType') 65 | axes[1, 1].set_xlabel('Group') 66 | axes[1, 1].set_ylabel('Test Loss (log scale)') 67 | 68 | plt.tight_layout() 69 | plt.show() 70 | 71 | # %% 72 | -------------------------------------------------------------------------------- /docs/plots/plot12kaugopt.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | 6 | # Read the data 7 | data_path = '12kaugruns.txt' 8 | data = pd.read_csv(data_path, delimiter=",", skiprows=1, header=None) 9 | 10 | #%% 11 | # Assign column headers (from first row of the original file) 12 | headers = pd.read_csv(data_path, nrows=0).columns.tolist() 13 | data.columns = headers 14 | 15 | # Extract schedule from runname 16 | data['Schedule'] = data['runname'].str.extract(r'(Lin|cos)') 17 | 18 | # Function to categorize groups with descriptive names 19 | def categorize_group(row): 20 | if row['Schedule'] == 'Lin' and row['num_epochs'] == 30: 21 | return 'A: L_30ep' 22 | elif row['Schedule'] == 'cos' and row['num_epochs'] == 30: 23 | return 'B: C_30ep' 24 | elif row['Schedule'] == 'cos' and row['num_epochs'] == 60: 25 | return 'C: C_60ep' 26 | elif row['Schedule'] == 'cos' and row['num_epochs'] == 120: 27 | return 'D: C_120ep' 28 | else: 29 | return 'Other' # To catch any cases that don't match expected conditions 30 | 31 | # Applying the categorization function 32 | data['Group'] = data.apply(categorize_group, axis=1) 33 | 34 | # Ensuring groups are ordered correctly for plotting 35 | group_order = ['A: L_30ep', 'B: C_30ep', 'C: C_60ep', 'D: C_120ep'] 36 | data['Group'] = pd.Categorical(data['Group'], categories=group_order, ordered=True) 37 | 38 | # Set up the plotting 39 | sns.set(style="whitegrid") 40 | fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8)) 41 | 42 | # Training Accuracy 43 | sns.lineplot(ax=axes[0, 0], data=data, x='Group', y='Accuracy/train', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 44 | axes[0, 0].set_title('Training Accuracy by Group and QuantType') 45 | axes[0, 0].set_xlabel('Group') 46 | axes[0, 0].set_ylabel('Train Accuracy (%)') 47 | 48 | # Training Loss 49 | sns.lineplot(ax=axes[0, 1], data=data, x='Group', y='Loss/train', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 50 | # axes[0, 1].set_yscale('log') 51 | axes[0, 1].set_title('Training Loss by Group and QuantType') 52 | axes[0, 1].set_xlabel('Group') 53 | axes[0, 1].set_ylabel('Train Loss') 54 | 55 | # Test Accuracy 56 | sns.lineplot(ax=axes[1, 0], data=data, x='Group', y='Accuracy/test', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 57 | axes[1, 0].set_title('Test Accuracy by Group and QuantType') 58 | axes[1, 0].set_xlabel('Group') 59 | axes[1, 0].set_ylabel('Test Accuracy (%)') 60 | 61 | # Test Loss 62 | sns.lineplot(ax=axes[1, 1], data=data, x='Group', y='Loss/test', hue='QuantType', style='QuantType', markers=True, dashes=False, palette='tab10', markersize=10) 63 | # axes[1, 1].set_yscale('log') 64 | axes[1, 1].set_title('Test Loss by Group and QuantType') 65 | axes[1, 1].set_xlabel('Group') 66 | axes[1, 1].set_ylabel('Test Loss') 67 | 68 | plt.tight_layout() 69 | plt.show() 70 | 71 | # %% 72 | --------------------------------------------------------------------------------
/docs/plots/12kaugruns.txt: -------------------------------------------------------------------------------- 1 | num_epochs,QuantType,BPW,NormType,WScale,batch_size,network_width1,network_width2,network_width3,runname,Parameters,Totalbits,Accuracy/train,Accuracy/test,Loss/train,Loss/test 2 | 120.0,2bitsym,2.0,RMS,PerTensor,128.0,112.0,122.0,96.0,Opt12k_cos_Aug_BitMnist_PerTensor_2bitsym_RMS_width112_122_96_lr0.001_decay0.1_stepsize10_bs128_epochs120,55008.0,110016.0,98.43499755859375,98.73798370361328,0.05160922929644585,0.04045324772596359 3 | 30.0,2bitsym,2.0,RMS,PerTensor,128.0,112.0,122.0,96.0,Opt12k_cos_Aug_BitMnist_PerTensor_2bitsym_RMS_width112_122_96_lr0.001_decay0.1_stepsize10_bs128_epochs30,55008.0,110016.0,97.83333587646484,98.7880630493164,0.06734599173069,0.039318375289440155 4 | 60.0,2bitsym,2.0,RMS,PerTensor,128.0,112.0,122.0,96.0,Opt12k_cos_Aug_BitMnist_PerTensor_2bitsym_RMS_width112_122_96_lr0.001_decay0.1_stepsize10_bs128_epochs60,55008.0,110016.0,98.26416778564453,98.80809020996094,0.05530998855829239,0.03794977441430092 5 | 120.0,4bitsym,4.0,RMS,PerTensor,128.0,64.0,64.0,64.0,Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs120,25216.0,100864.0,98.37000274658203,98.83814239501953,0.05225309729576111,0.03491820767521858 6 | 30.0,4bitsym,4.0,RMS,PerTensor,128.0,64.0,64.0,64.0,Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs30,25216.0,100864.0,97.69916534423828,98.5977554321289,0.073830746114254,0.040835313498973846 7 | 60.0,4bitsym,4.0,RMS,PerTensor,128.0,64.0,64.0,64.0,Opt12k_cos_Aug_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs60,25216.0,100864.0,98.1308364868164,98.88822174072266,0.061817750334739685,0.03904476389288902 8 | 120.0,Binary,1.0,RMS,PerTensor,128.0,176.0,160.0,160.0,Opt12k_cos_Aug_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs120,100416.0,100416.0,98.00666809082031,98.83814239501953,0.06432902067899704,0.0384233258664608 9 | 30.0,Binary,1.0,RMS,PerTensor,128.0,176.0,160.0,160.0,Opt12k_cos_Aug_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs30,100416.0,100416.0,97.54582977294922,98.55769348144531,0.07620281726121902,0.044467128813266754 10 | 60.0,Binary,1.0,RMS,PerTensor,128.0,176.0,160.0,160.0,Opt12k_cos_Aug_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs60,100416.0,100416.0,97.8758316040039,98.6278076171875,0.06653448939323425,0.03925606235861778 11 | 120.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,128.0,112.0,Opt12k_cos_Aug_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs120,64608.0,103372.796875,98.28916931152344,98.75801086425781,0.05376652628183365,0.040183547884225845 12 | 30.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,128.0,112.0,Opt12k_cos_Aug_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs30,64608.0,103372.796875,97.72916412353516,98.54767608642578,0.07222413271665573,0.04652818292379379 13 | 60.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,128.0,112.0,Opt12k_cos_Aug_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs60,64608.0,103372.796875,98.1449966430664,98.85816955566406,0.05941591039299965,0.03926968574523926 14 | -------------------------------------------------------------------------------- /docs/plots/create_cnn_plots.py: 
-------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | # Data from the table 5 | data = [ 6 | {"config": "16-wide CNN, small fc", "width": 16, "size": 3.2, "accuracy": 98.92, "cycles": 686490, "time": 14.30, "type": "CNN"}, 7 | {"config": "16-wide CNN", "width": 16, "size": 5.4, "accuracy": 99.06, "cycles": 785123, "time": 16.36, "type": "CNN"}, 8 | {"config": "32-wide CNN", "width": 32, "size": 7.3, "accuracy": 99.28, "cycles": 1434667, "time": 29.89, "type": "CNN"}, 9 | {"config": "48-wide CNN", "width": 48, "size": 9.3, "accuracy": 99.44, "cycles": 2083568, "time": 43.41, "type": "CNN"}, 10 | {"config": "64-wide CNN", "width": 64, "size": 11.0, "accuracy": 99.55, "cycles": 2736250, "time": 57.01, "type": "CNN"}, 11 | {"config": "4Bitsym FC", "width": None, "size": 12.3, "accuracy": 99.02, "cycles": 528377, "time": 11.01, "type": "4Bitsym"}, 12 | {"config": "FP130 FC", "width": None, "size": 12.3, "accuracy": 98.86, "cycles": 481624, "time": 10.03, "type": "FP130"}, 13 | {"config": "4Bitsym FC", "width": None, "size": 7.359375, "accuracy": 98.52, "cycles": None, "time": 6.59, "type": "4Bitsym"}, 14 | {"config": "4Bitsym FC", "width": None, "size": 8.484375, "accuracy": 98.72, "cycles": None, "time": 7.60, "type": "4Bitsym"} 15 | ] 16 | 17 | # Separate data by type 18 | cnn_data = [d for d in data if d["type"] == "CNN"] 19 | fp130_data = [d for d in data if d["type"] == "FP130"] 20 | bitsym_data = [d for d in data if d["type"] == "4Bitsym"] 21 | 22 | # Create figure with subplots 23 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4)) 24 | 25 | # Plot 1: Time vs. Accuracy 26 | ax1.scatter([d["time"] for d in cnn_data], [d["accuracy"] for d in cnn_data], 27 | color='blue', s=100, alpha=0.7, label='CNN', marker='o') 28 | ax1.scatter([d["time"] for d in fp130_data], [d["accuracy"] for d in fp130_data], 29 | color='green', s=100, alpha=0.7, label='FP130', marker='^') 30 | ax1.scatter([d["time"] for d in bitsym_data], [d["accuracy"] for d in bitsym_data], 31 | color='red', s=100, alpha=0.7, label='4Bitsym', marker='s') 32 | 33 | # Add labels for CNN points only 34 | for d in cnn_data: 35 | if "small fc" in d["config"]: 36 | ax1.annotate('16-wide (small)', (d["time"], d["accuracy"]), 37 | xytext=(5, 5), textcoords='offset points', fontsize=9) 38 | else: 39 | ax1.annotate(f'{d["width"]}-wide', (d["time"], d["accuracy"]), 40 | xytext=(5, 5), textcoords='offset points', fontsize=9) 41 | 42 | ax1.set_xlabel('Time (ms)') 43 | ax1.set_ylabel('Test Accuracy (%)') 44 | ax1.set_title('Time vs. Accuracy Trade-off') 45 | ax1.grid(True, alpha=0.3) 46 | ax1.legend() 47 | 48 | # Plot 2: Size vs. 
Accuracy 49 | ax2.scatter([d["size"] for d in cnn_data], [d["accuracy"] for d in cnn_data], 50 | color='blue', s=100, alpha=0.7, label='CNN', marker='o') 51 | ax2.scatter([d["size"] for d in fp130_data], [d["accuracy"] for d in fp130_data], 52 | color='green', s=100, alpha=0.7, label='FP130', marker='^') 53 | ax2.scatter([d["size"] for d in bitsym_data], [d["accuracy"] for d in bitsym_data], 54 | color='red', s=100, alpha=0.7, label='4Bitsym', marker='s') 55 | 56 | # Add labels for CNN points only 57 | for d in cnn_data: 58 | if "small fc" in d["config"]: 59 | ax2.annotate('16-wide (small)', (d["size"], d["accuracy"]), 60 | xytext=(5, 5), textcoords='offset points', fontsize=9) 61 | else: 62 | ax2.annotate(f'{d["width"]}-wide', (d["size"], d["accuracy"]), 63 | xytext=(5, 5), textcoords='offset points', fontsize=9) 64 | 65 | ax2.set_xlabel('Size (kB)') 66 | ax2.set_ylabel('Test Accuracy (%)') 67 | ax2.set_title('Size vs. Accuracy Trade-off') 68 | ax2.grid(True, alpha=0.3) 69 | ax2.legend() 70 | 71 | # Adjust layout and save 72 | plt.tight_layout() 73 | plt.savefig('cnn_tradeoff_plots.png', dpi=300, bbox_inches='tight') 74 | plt.show() 75 | 76 | print("Trade-off plots created and saved as 'cnn_tradeoff_plots.png'") -------------------------------------------------------------------------------- /mcu/BitNetMCU_model_1k.h: -------------------------------------------------------------------------------- 1 | // Automatically generated header file 2 | // Date: 2024-04-19 16:57:41.096677 3 | // Quantized model exported from a11_Opt12k_cos_BitMnist_PerTensor_2bitsym_RMS_width16_16_0_lr0.001_decay0.1_stepsize10_bs128_epochs60.pth 4 | // Generated by exportquant.py 5 | 6 | #include <stdint.h> 7 | 8 | #ifndef BITNETMCU_MODEL_H 9 | #define BITNETMCU_MODEL_H 10 | 11 | #define MODEL_FCMNIST 12 | 13 | /* 14 | Total number of bits: 9024 (1.1015625 kbytes) 15 | inference of quantized model 16 | Accuracy/Test of quantized model: 94.22 % 17 | */ 18 | // Number of layers 19 | #define NUM_LAYERS 3 20 | 21 | // Maximum number of activations per layer 22 | #define MAX_N_ACTIVATIONS 16 23 | 24 | // Layer: L1 25 | // QuantType: 2bitsym 26 | #define L1_active 27 | #define L1_bitperweight 2 28 | #define L1_incoming_weights 256 29 | #define L1_outgoing_weights 16 30 | const uint32_t L1_weights[] = {0x8022800, 0x88000, 0x800a888, 0x8aa000, 0x28aa00, 0x8a2282, 0x22aa2808, 0x882aa8a0, 0x82a02a00, 0x828aa20, 0x282a2a2, 0x28aa282, 0x202a82, 0x82202082, 0x20000000, 0x202020, 0xa02aaaaa, 0xa010aaa0, 0x85522ba, 0x81002aa, 0xa929682c, 0xa112d288, 0x2bfff2b8, 0x2ffffcb0, 0x8fcaabf0, 0x8491abe5, 0x4418054, 0x84559154, 0x95908556, 0x80008554, 0x2ffc2f00, 0xaaaa0a88, 0x200, 0x8100f80, 0xa0aea8, 0x2feae88, 0xbf9b23c, 0xb8b5bfc, 0xbfe53f0, 0x3fc63e5, 0x2111f14, 0x17c55, 0xb013554, 0xb45fa94, 0x5ab2fc0, 0x5aa8780, 0xffa82, 0x20aa000, 0x80800800, 0x22000aa2, 0xaaa150a8, 0xba0567a, 0xaa28652, 0xba0251e, 0xaaaf91e, 0x2c29be14, 0x2054f005, 0x2157d154, 0xa95fa854, 0x215ffa50, 0x8183fe2a, 0x2e00eff8, 0x2abfff2, 0xa08a8a0, 0x2220a2a8, 0xa00154aa, 0x800554f8, 0xa088003a, 0x8501227e, 0x43f957c, 0xbff51ba, 0xaffd02f2, 0xabc14a22, 0x15093016, 0x95be9418, 0x954ba82a, 0x438a9aa, 0x40a83e8, 0xa1558ba2, 0x22aaaa08, 0x882a00, 0xabfffa2, 0x2fffe80, 0x288a280, 0x8a82800, 0x8a80c2aa, 0xae1410bc, 0x945107e, 0x855294b8, 0x2153164e, 0x2848b8, 0x2bfffaf0, 0xbffffa0, 0xaabea00, 0x50ba848, 0x80004000, 0x22000, 0x28aeb800, 0x2aa28c10, 0x8802aa14, 0x294efa14, 0x847d514, 0x2814cea0, 0x8515ea4a, 0x604a01a, 0xae090a1e, 0xac01298c, 0xae0a82e, 0xab2a8028, 0xba5412a, 0x2afbef80,
0x2bffe80, 0x8010000, 0x154800, 0x8010a888, 0x2c2aae, 0xa0aabe, 0x228c5bbe, 0x2ab9000f, 0x8e80e212, 0xc93e010, 0x2197f114, 0x2159604, 0x6156028, 0x2a013f8, 0x848bffa, 0xabfffe8, 0x202ba200, 0x8a000000, 0x152000, 0x82008852, 0x228fefd0, 0x80aefff0, 0x2aaafebe, 0xbe2beebe, 0xac050856, 0x81151054, 0x9045556, 0x29445554, 0xa222022, 0xb80ebfa, 0x2efffb0, 0x8bfbfe0, 0x22affea0, 0xaaaaaaaa, 0xaaa882aa, 0xaaa2828a, 0xaaa0a212, 0xaaa0a882, 0x2a52226, 0x9015f052, 0x9157ff92, 0x850fff12, 0xa00ffa1f, 0xaaaff06e, 0xaa20a82a, 0xaba5540a, 0xaa80102a, 0xa855450a, 0xaa0000aa, 0x22208282, 0xaa000102, 0xae91002, 0x2ba889a2, 0xaea82a14, 0xae85fa88, 0xbd557010, 0xac294814, 0xaa80bf0, 0x2e902fe0, 0xae15aa92, 0xae12a052, 0x2828200a, 0x2b808108, 0x2fa010aa, 0xa2802288, 0xa82aaaaa, 0x2a8afa80, 0xa014ff82, 0xaa855ab8, 0x8fa150bc, 0xac01522e, 0xa540214a, 0x50e015a, 0x82f8a03a, 0xa3a2caea, 0x2a96a008, 0xae0ee880, 0x2820afaa, 0xa8aaaa52, 0xa155554a, 0xa015550a, 0xaaaaaa8a, 0xa2aa050a, 0xaa244656, 0xa0000556, 0xaaa05014, 0xaa01fffc, 0xa2012ffe, 0xaaa920fe, 0x8a008052, 0xabfee116, 0x85bfc152, 0x856ea25a, 0xa200281a, 0x20828a88, 0xa1555a0a, 0xa20aaaaa, 0xa8a200, 0x802aa820, 0xabfff80, 0x21008eaa, 0x550a00c, 0x1505500c, 0x1500527a, 0x13bfcaf1, 0x8fffc555, 0xafac0054, 0xa802294, 0xba03ff0, 0x2faabff8, 0x13baff0, 0x24000bfa, 0x2a0a1408, 0x20080882, 0xa02aebea, 0x82bffff8, 0x82e8bff8, 0xb880154, 0xba81555, 0x8082f55, 0x8a013ff4, 0x890fff8, 0x58282e2, 0x896a00f8, 0xac0aa802, 0xaaaaa880, 0xb82ba90, 0x2aba8550, 0x8a00820, 0xa882aa0, 0x8002e0a, 0x25400ffa, 0x5a08ff0, 0x848a6bfa, 0x4fd5bfa, 0x3e56fe8, 0xb112fc0, 0xa0f0a50, 0x8bff7810, 0x2baf6850, 0xe0e5400, 0xe0a8528, 0x82a4a0, 0x8283000, 0xaaa080}; 31 | //first channel is topmost bit 32 | 33 | // Layer: L2 34 | // QuantType: 2bitsym 35 | #define L2_active 36 | #define L2_bitperweight 2 37 | #define L2_incoming_weights 16 38 | #define L2_outgoing_weights 16 39 | const uint32_t L2_weights[] = {0x10922988, 0xbf394a2d, 0x3844766, 0x182c6013, 0x321388cd, 0xaf609844, 0x3907b582, 0xb6def070, 0x274bd891, 0x84b234c9, 0x15682aa9, 0x29b114bb, 0x2c953bfc, 0x2c64f138, 0x80e08867, 0x981dcdb}; 40 | //first channel is topmost bit 41 | 42 | // Layer: L3 43 | // QuantType: 2bitsym 44 | #define L3_active 45 | #define L3_bitperweight 2 46 | #define L3_incoming_weights 16 47 | #define L3_outgoing_weights 10 48 | const uint32_t L3_weights[] = {0xbefee925, 0xcbb856ab, 0x29abc6aa, 0x9ada339f, 0x337aa86e, 0xc6dabae7, 0xd0fecd33, 0x382b6bbd, 0xe2e6c89b, 0x6f21be88}; 49 | //first channel is topmost bit 50 | 51 | #endif 52 | -------------------------------------------------------------------------------- /docs/plots/12kruns.txt: -------------------------------------------------------------------------------- 1 | num_epochs,QuantType,BPW,NormType,WScale,network_width1,network_width2,network_width3,runname,Parameters,Totalbits,Accuracy/train,Accuracy/test,Loss/train,Loss/test 2 | 120.0,2bitsym,2.0,RMS,PerTensor,112.0,96.0,96.0,det12k_Cos_BitMnist_PerTensor_2bitsym_RMS_width112_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs120,49600.0,99200.0,99.83999633789062,98.4375,5.9198926464887336e-05,0.08611243963241577 3 | 30.0,2bitsym,2.0,RMS,PerTensor,112.0,96.0,96.0,det12k_Cos_BitMnist_PerTensor_2bitsym_RMS_width112_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs30,49600.0,99200.0,99.78666687011719,98.03685760498047,0.0032718118745833635,0.08026767522096634 4 | 
60.0,2bitsym,2.0,RMS,PerTensor,112.0,96.0,96.0,det12k_Cos_BitMnist_PerTensor_2bitsym_RMS_width112_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs60,49600.0,99200.0,99.83833312988281,98.15705108642578,0.0004393052076920867,0.07754910737276077 5 | 120.0,4bitsym,4.0,RMS,PerTensor,64.0,64.0,64.0,det12k_Cos_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs120,25216.0,100864.0,99.83999633789062,98.14703369140625,4.2223415221087635e-05,0.12716418504714966 6 | 30.0,4bitsym,4.0,RMS,PerTensor,64.0,64.0,64.0,det12k_Cos_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs30,25216.0,100864.0,99.76667022705078,97.8565673828125,0.004676529206335545,0.0802917629480362 7 | 60.0,4bitsym,4.0,RMS,PerTensor,64.0,64.0,64.0,det12k_Cos_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs60,25216.0,100864.0,99.836669921875,98.15705108642578,0.0005398741341196001,0.09897313266992569 8 | 120.0,8bit,8.0,RMS,PerTensor,40.0,32.0,32.0,det12k_Cos_BitMnist_PerTensor_8bit_RMS_width40_32_32_lr0.001_decay0.1_stepsize10_bs128_epochs120,12864.0,102912.0,99.8316650390625,97.35576629638672,0.000584523135330528,0.1700584888458252 9 | 30.0,8bit,8.0,RMS,PerTensor,40.0,32.0,32.0,det12k_Cos_BitMnist_PerTensor_8bit_RMS_width40_32_32_lr0.001_decay0.1_stepsize10_bs128_epochs30,12864.0,102912.0,99.52666473388672,97.63621520996094,0.01578357256948948,0.09619983285665512 10 | 60.0,8bit,8.0,RMS,PerTensor,40.0,32.0,32.0,det12k_Cos_BitMnist_PerTensor_8bit_RMS_width40_32_32_lr0.001_decay0.1_stepsize10_bs128_epochs60,12864.0,102912.0,99.79166412353516,97.48597717285156,0.003540520556271076,0.12196425348520279 11 | 120.0,Binary,1.0,RMS,PerTensor,176.0,160.0,160.0,det12k_Cos_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs120,100416.0,100416.0,99.83999633789062,98.3974380493164,0.000297209044219926,0.07841531932353973 12 | 30.0,Binary,1.0,RMS,PerTensor,176.0,160.0,160.0,det12k_Cos_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs30,100416.0,100416.0,99.71833038330078,98.32732391357422,0.006229992024600506,0.06156959757208824 13 | 60.0,Binary,1.0,RMS,PerTensor,176.0,160.0,160.0,det12k_Cos_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs60,100416.0,100416.0,99.82499694824219,98.28726196289062,0.0016564603429287672,0.06957776099443436 14 | 120.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,112.0,det12k_Cos_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs120,64608.0,103372.796875,99.83999633789062,98.29727935791016,5.413280814536847e-05,0.10357271879911423 15 | 30.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,112.0,det12k_Cos_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs30,64608.0,103372.796875,99.79499816894531,98.29727935791016,0.003240385791286826,0.06467792391777039 16 | 60.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,112.0,det12k_Cos_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs60,64608.0,103372.796875,99.83833312988281,98.29727935791016,0.0004688764165621251,0.08804627507925034 17 | 30.0,2bitsym,2.0,RMS,PerTensor,112.0,96.0,96.0,det12k_Lin_BitMnist_PerTensor_2bitsym_RMS_width112_96_96_lr0.001_decay0.1_stepsize10_bs128_epochs30,49600.0,99200.0,99.58499908447266,98.10697174072266,0.009854994714260101,0.06892337650060654 18 | 
30.0,4bitsym,4.0,RMS,PerTensor,64.0,64.0,64.0,det12k_Lin_BitMnist_PerTensor_4bitsym_RMS_width64_64_64_lr0.001_decay0.1_stepsize10_bs128_epochs30,25216.0,100864.0,99.5616683959961,98.0068130493164,0.012504241429269314,0.07720039784908295 19 | 30.0,8bit,8.0,RMS,PerTensor,40.0,32.0,32.0,det12k_Lin_BitMnist_PerTensor_8bit_RMS_width40_32_32_lr0.001_decay0.1_stepsize10_bs128_epochs30,12864.0,102912.0,99.07499694824219,97.63621520996094,0.03037620522081852,0.08293931931257248 20 | 30.0,Binary,1.0,RMS,PerTensor,176.0,160.0,160.0,det12k_Lin_BitMnist_PerTensor_Binary_RMS_width176_160_160_lr0.001_decay0.1_stepsize10_bs128_epochs30,100416.0,100416.0,99.46833038330078,98.2772445678711,0.013757308013737202,0.05867868661880493 21 | 30.0,Ternary,1.6,RMS,PerTensor,128.0,128.0,112.0,det12k_Lin_BitMnist_PerTensor_Ternary_RMS_width128_128_112_lr0.001_decay0.1_stepsize10_bs128_epochs30,64608.0,103372.796875,99.58000183105469,98.19711303710938,0.010668843984603882,0.06755951046943665 22 | -------------------------------------------------------------------------------- /BitNetMCU_MNIST_dll.c: -------------------------------------------------------------------------------- 1 | #include <stdint.h> 2 | #include <stdio.h> 3 | 4 | #include "BitNetMCU_model.h" 5 | #include "BitNetMCU_inference.c" 6 | 7 | /** 8 | * @file BitNetMCU_MNIST_dll.c 9 | * @brief DLL wrapper for the BitMnist model. 10 | * Build with 64-bit Visual Studio: 11 | * cl /LD BitNetMCU_MNIST_dll.c /MD /FeBitnet_inf.dll /link /MACHINE:X64 12 | * @param input The input data for the inference. 13 | * @return The result of the inference. 14 | */ 15 | 16 | uint32_t BitMnistInference(int8_t *input); 17 | 18 | #ifdef _DLL 19 | #ifdef _WIN64 20 | #define EXPORT __declspec(dllexport) 21 | #else 22 | #define EXPORT __attribute__((visibility("default"))) 23 | #endif 24 | EXPORT uint32_t Inference(int8_t *input) { 25 | return BitMnistInference(input); 26 | } 27 | #endif 28 | 29 | void printactivations(uint8_t *activations, int32_t n_activations) 30 | { 31 | for (int i = 0; i < n_activations; i++) { 32 | printf("%d, ", activations[i]); 33 | if ((i + 1) % 16 == 0) { 34 | printf("\n"); 35 | } 36 | } 37 | } 38 | 39 | /** 40 | * @brief Performs inference on the BitMnist model. 41 | * 42 | * @param input The input data for the inference. 43 | * @return The result of the inference. 44 | */ 45 | 46 | #ifdef MODEL_CNNMNIST 47 | 48 | uint32_t BitMnistInference(int8_t *input) { 49 | int32_t layer_out[MAX_N_ACTIVATIONS]; 50 | int8_t layer_in[MAX_N_ACTIVATIONS*4]; 51 | 52 | /* 53 | Layer: L2 Conv2d bpw: 8 1 -> 64 groups:1 Kernel: 3x3 Incoming: 16x16 Outgoing: 14x14 54 | Layer: L4 Conv2d bpw: 8 64 -> 64 groups:64 Kernel: 3x3 Incoming: 14x14 Outgoing: 12x12 55 | Layer: L6 MaxPool2d Pool Size: 2 Incoming: 12x12 Outgoing: 6x6 56 | Layer: L7 Conv2d bpw: 8 64 -> 64 groups:64 Kernel: 3x3 Incoming: 6x6 Outgoing: 4x4 57 | Layer: L9 MaxPool2d Pool Size: 2 Incoming: 4x4 Outgoing: 2x2 58 | Layer: L11 Quantization type: <2bitsym>, Bits per weight: 2, Num. incoming: 256, Num outgoing: 96 59 | Layer: L13 Quantization type: <4bitsym>, Bits per weight: 4, Num. incoming: 96, Num outgoing: 64 60 | Layer: L15 Quantization type: <4bitsym>, Bits per weight: 4, Num.
incoming: 64, Num outgoing: 10 61 | */ 62 | 63 | // Depthwise separable convolution with 32 bit activations and 8 bit weights 64 | int32_t *tmpbuf=(int32_t*)layer_out; 65 | int32_t *outputptr=(int32_t*)layer_in; 66 | for (uint32_t channel=0; channel < L7_out_channels; channel++) { 67 | 68 | for (uint32_t i=0; i < 16*16; i++) { 69 | tmpbuf[i]=input[i]; 70 | } 71 | processconv33ReLU(tmpbuf, L2_weights + 9*channel, L2_incoming_x, 4, tmpbuf); 72 | processconv33ReLU(tmpbuf, L4_weights + 9*channel, L4_incoming_x, 4, tmpbuf); 73 | processmaxpool22(tmpbuf, L6_incoming_x, tmpbuf); 74 | processconv33ReLU(tmpbuf, L7_weights + 9*channel, L7_incoming_x, 4, tmpbuf); 75 | 76 | outputptr= processmaxpool22(tmpbuf, L9_incoming_x, outputptr); 77 | } 78 | 79 | // Normalization and conversion to 8-bit 80 | ReLUNorm((int32_t*)layer_in, layer_in, L7_out_channels * L9_outgoing_x * L9_outgoing_y); 81 | 82 | // Fully connected layers 83 | processfclayer(layer_in, L11_weights, L11_bitperweight, L11_incoming_weights, L11_outgoing_weights, layer_out); 84 | ReLUNorm(layer_out, layer_in, L11_outgoing_weights); 85 | 86 | processfclayer(layer_in, L13_weights, L13_bitperweight, L13_incoming_weights, L13_outgoing_weights, layer_out); 87 | ReLUNorm(layer_out, layer_in, L13_outgoing_weights); 88 | 89 | processfclayer(layer_in, L15_weights, L15_bitperweight, L15_incoming_weights, L15_outgoing_weights, layer_out); 90 | return ReLUNorm(layer_out, layer_in, L15_outgoing_weights); 91 | } 92 | 93 | #elif defined(MODEL_FCMNIST) 94 | 95 | uint32_t BitMnistInference(int8_t *input) { 96 | int32_t layer_out[MAX_N_ACTIVATIONS]; 97 | int8_t layer_in[MAX_N_ACTIVATIONS]; 98 | 99 | processfclayer(input, L1_weights, L1_bitperweight, L1_incoming_weights, L1_outgoing_weights, layer_out); 100 | ReLUNorm(layer_out, layer_in, L1_outgoing_weights); 101 | 102 | // printf("L1 activations: \n"); 103 | // printactivations(layer_in, L1_outgoing_weights); 104 | 105 | processfclayer(layer_in, L2_weights, L2_bitperweight, L2_incoming_weights, L2_outgoing_weights, layer_out); 106 | ReLUNorm(layer_out, layer_in, L2_outgoing_weights); 107 | 108 | // printf("L2 activations: \n"); 109 | // printactivations(layer_in, L2_outgoing_weights); 110 | 111 | #ifdef L4_active 112 | processfclayer(layer_in, L3_weights, L3_bitperweight, L3_incoming_weights, L3_outgoing_weights, layer_out); 113 | ReLUNorm(layer_out, layer_in, L3_outgoing_weights); 114 | 115 | processfclayer(layer_in, L4_weights, L4_bitperweight, L4_incoming_weights, L4_outgoing_weights, layer_out); 116 | return ReLUNorm(layer_out, layer_in, L4_outgoing_weights); 117 | #else 118 | processfclayer(layer_in, L3_weights, L3_bitperweight, L3_incoming_weights, L3_outgoing_weights, layer_out); 119 | return ReLUNorm(layer_out, layer_in, L3_outgoing_weights); 120 | #endif 121 | } 122 | #else 123 | #error "No model defined" 124 | #endif 125 | 126 | -------------------------------------------------------------------------------- /docs/trainingpipeline.drawio: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # BitNetMCU: 
High Accuracy Low-Bit Quantized Neural Networks on a low-end Microcontroller 2 | 3 | **BitNetMCU** is a project focused on the training and inference of low-bit quantized neural networks, specifically designed to run efficiently on low-end microcontrollers like the CH32V003. Quantization aware training (QAT) and fine-tuning of model structure and inference code allowed *surpassing 99% test accuracy on a 16x16 MNIST dataset without using multiplication instructions and in only 2kb of RAM and 16kb of Flash*. 4 | 5 | **Update** Introducing a new model architecture based on depthwise separable convolutions pushed the accuracy even further to **99.55%**, meeting state-of-the-art MNIST accuracy for CNNs while still fitting into the same memory constraints. This model requires a hardware multiplier, which is available in many low-end RISC-V and ARM Cortex-M0 microcontrollers. 6 | 7 | The training pipeline is based on PyTorch and should run anywhere. The inference engine is implemented in ANSI C and can be easily ported to any microcontroller. 8 | 9 | **You can find my build logs for the project in the `docs/` directory [here](docs/documentation.md) and the CNN details [here](docs/documentation_cnn.md).** 10 | **Also see the [BitNetMCU blog articles](https://cpldcpu.github.io/tags/bitnetmcu/).** 11 | 12 |
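To illustrate what quantization aware training means here — the forward pass sees quantized weights while gradients still update the full-precision shadow weights — the following is a minimal straight-through-estimator sketch. It is illustrative only and much simpler than the `BitLinear`/`BitConv2d` classes in `BitNetMCU.py`, which add normalization and weight-scaling options:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QATLinearSketch(nn.Linear):
    """Toy QAT layer: 4-bit symmetric weights with a straight-through estimator."""
    def forward(self, x):
        scale = 2.0 * self.weight.abs().mean() + 1e-8        # crude per-tensor scale
        q = torch.clamp(torch.round(self.weight / scale * 7.0), -7, 7)
        w_q = q * scale / 7.0                                # dequantized 4-bit weights
        # Forward uses w_q; backward treats the rounding as identity (STE)
        w = self.weight + (w_q - self.weight).detach()
        return F.linear(x, w, self.bias)
```

A layer like this trains with a normal optimizer in float; only the export step (`exportquant.py` in this repository) freezes the integer weight codes for the C header.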
13 | 14 |
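The multiplication-free claim rests on restricting weights to values that can be applied with shifts and adds. The FP1.3.0 format from the updates below, for example, has a name that reads as 1 sign bit, 3 exponent bits and 0 mantissa bits, i.e. each weight is ± a power of two. A minimal sketch of the idea (integer arithmetic only; the actual C routines in `BitNetMCU_inference.c` differ in data layout and loop structure):

```python
def fc_dot_shift_only(activations, signs, exponents):
    """Dot product using only shifts and adds: weight = sign * 2**exponent."""
    acc = 0
    for x, s, e in zip(activations, signs, exponents):
        term = x << e                  # multiply by 2**e as a left shift
        acc += term if s > 0 else -term
    return acc

# Example with weights +4, -2, +1 applied to activations 3, 5, 7:
print(fc_dot_shift_only([3, 5, 7], [+1, -1, +1], [2, 1, 0]))  # 3*4 - 5*2 + 7*1 = 9
```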
15 | 16 | ## Project Structure 17 | 18 | ``` 19 | BitNetMCU/ 20 | │ 21 | ├── docs/ # Report 22 | ├── mcu/ # MCU specific code for CH32V003 23 | ├── modeldata/ # Pre-trained models 24 | │ 25 | ├── BitNetMCU.py # Pytorch model and QAT classes 26 | ├── BitNetMCU_inference.c # C code for inference 27 | ├── BitNetMCU_inference.h # Header file for C inference code 28 | ├── BitNetMCU_MNIST_test.c # Test script for MNIST dataset 29 | ├── BitNetMCU_MNIST_test_data.h # MNIST test data in header format (generated) 30 | ├── BitNetMCU_model.h # Model data in C header format (generated) 31 | ├── exportquant.py # Script to convert trained model to quantized format 32 | ├── test_inference.py # Script to test C implementation of inference 33 | ├── training.py # Training script for the neural network 34 | └── trainingparameters.yaml # Configuration file for training parameters 35 | ``` 36 | 37 | ## Training Pipeline 38 | 39 | The data pipeline is split into several Python scripts for flexibility: 40 | 41 | 1. **Configuration**: Modify `trainingparameters.yaml` to set all hyperparameters for training the model. 42 | 43 | 2. **Training the Model**: The `training.py` script is used to train the model and store the weights as a `.pth` file in the `modeldata/` folder. The model weights are still in float format at this stage, as they are quantized on-the-fly during training. 44 | 45 | 3. **Exporting the Quantized Model**: The `exportquant.py` script is used to convert the model into a quantized format. The quantized model weights are exported to the C header file `BitNetMCU_model.h`. 46 | 47 | 4. **Optional: Testing the C-Model**: Compile and execute `BitNetMCU_MNIST_test.c` to test inference of ten digits. The model data is included from `BitNetMCU_model.h`, and the test data from `BitNetMCU_MNIST_test_data.h`. 48 | 49 | 5. **Optional: Verification C vs Python Model on full dataset**: The inference code, along with the model data, is compiled into a DLL. The `test_inference.py` script calls the DLL and compares the results with the original Python model. This allows for an accurate comparison against the entire MNIST test data set of 10,000 images. 50 | 51 | 6. **Optional: Testing inference on the MCU**: Follow the instructions in `mcu/readme.md`. Porting to architectures other than the CH32V003 is straightforward, and the files in the `mcu` directory can serve as a reference. 52 | 53 | ## Updates 54 | 55 | - 24th April 2024 - First release with Binary, Ternary, 2 bit, 4 bit and 8 bit quantization. 56 | - 2nd May 2024 - [tagged version 0.1a](https://github.com/cpldcpu/BitNetMCU/tree/0.1a) 57 | - 8th May 2024 - Added FP1.3.0 Quantization to allow fully multiplication-free inference with 98.9% accuracy. 58 | - 11th May 2024 - Fixes for Linux. Thanks to @donn 59 | - 19th May 2024 - Added support for a non-symmetric 4bit quantization scheme that allows for easier inference on MCUs with a multiplier. The inference code now uses the multiplication-free code path only on RV32 architectures without a multiplier. 60 | - 20th May 2024 - Added ```quantscale``` as a hyperparameter to influence weight scaling. [Updated documentation on new quantization schemes](https://github.com/cpldcpu/BitNetMCU/blob/main/docs/documentation.md#may-20-2024-additional-quantization-schemes).
61 | - 26th May 2024 - [tagged version 0.2a](https://github.com/cpldcpu/BitNetMCU/tree/0.2a) 62 | - 19th July 2024 - [Added octav algorithm](https://github.com/cpldcpu/BitNetMCU/blob/main/docs/documentation.md#july-19-2024-octav-optimum-clipping) to calculate optimal clipping and quantization parameters. 63 | - 26th July 2024 - Added support for NormalFloat4 (NF4) Quantization. [Updated documentation](docs/documentation.md#july-26-2024-normalfloat4-nf4-quantization) 64 | - 7th September 2025 - New CNN architecture based on sequential depthwise separable convolutions allows reaching 99.55% accuracy while still fitting into 16kb Flash and 4kb RAM. [See documentation](docs/documentation_cnn.md) for details. 65 | -------------------------------------------------------------------------------- /BitNetMCU_MNIST_test.c: -------------------------------------------------------------------------------- 1 | #include <stdint.h> 2 | #include <stdio.h> 3 | 4 | #include "BitNetMCU_model.h" 5 | #include "BitNetMCU_inference.c" 6 | #include "BitNetMCU_MNIST_test_data.h" 7 | 8 | /** 9 | * Performs inference on the MNIST dataset using the BitNetMCU model. 10 | * 11 | * @param input The input data for the inference, a 16x16 array of int8_t. 12 | * @return The predicted digit. 13 | */ 14 | 15 | uint32_t BitMnistInference(int8_t*); 16 | 17 | int main(void) { 18 | uint32_t output[10]; 19 | uint8_t predicted_label; 20 | predicted_label = BitMnistInference(input_data_0); 21 | printf("label: %d predicted: %d\n", label_0, predicted_label); 22 | predicted_label = BitMnistInference(input_data_1); 23 | printf("label: %d predicted: %d\n", label_1, predicted_label); 24 | predicted_label = BitMnistInference(input_data_2); 25 | printf("label: %d predicted: %d\n", label_2, predicted_label); 26 | predicted_label = BitMnistInference(input_data_3); 27 | printf("label: %d predicted: %d\n", label_3, predicted_label); 28 | predicted_label = BitMnistInference(input_data_4); 29 | printf("label: %d predicted: %d\n", label_4, predicted_label); 30 | predicted_label = BitMnistInference(input_data_5); 31 | printf("label: %d predicted: %d\n", label_5, predicted_label); 32 | predicted_label = BitMnistInference(input_data_6); 33 | printf("label: %d predicted: %d\n", label_6, predicted_label); 34 | predicted_label = BitMnistInference(input_data_7); 35 | printf("label: %d predicted: %d\n", label_7, predicted_label); 36 | predicted_label = BitMnistInference(input_data_8); 37 | printf("label: %d predicted: %d\n", label_8, predicted_label); 38 | predicted_label = BitMnistInference(input_data_9); 39 | printf("label: %d predicted: %d\n", label_9, predicted_label); 40 | } 41 | 42 | 43 | #ifdef MODEL_CNNMNIST 44 | 45 | uint32_t BitMnistInference(int8_t *input) { 46 | int32_t layer_out[MAX_N_ACTIVATIONS]; 47 | int8_t layer_in[MAX_N_ACTIVATIONS*4]; 48 | 49 | /* 50 | Layer: L2 Conv2d bpw: 8 1 -> 64 groups:1 Kernel: 3x3 Incoming: 16x16 Outgoing: 14x14 51 | Layer: L4 Conv2d bpw: 8 64 -> 64 groups:64 Kernel: 3x3 Incoming: 14x14 Outgoing: 12x12 52 | Layer: L6 MaxPool2d Pool Size: 2 Incoming: 12x12 Outgoing: 6x6 53 | Layer: L7 Conv2d bpw: 8 64 -> 64 groups:64 Kernel: 3x3 Incoming: 6x6 Outgoing: 4x4 54 | Layer: L9 MaxPool2d Pool Size: 2 Incoming: 4x4 Outgoing: 2x2 55 | Layer: L11 Quantization type: <2bitsym>, Bits per weight: 2, Num. incoming: 256, Num outgoing: 96 56 | Layer: L13 Quantization type: <4bitsym>, Bits per weight: 4, Num. incoming: 96, Num outgoing: 64 57 | Layer: L15 Quantization type: <4bitsym>, Bits per weight: 4, Num.
incoming: 64, Num outgoing: 10 58 | */ 59 | 60 | // Depthwise separable convolution with 32 bit activations and 8 bit weights 61 | int32_t *tmpbuf=(int32_t*)layer_out; 62 | int32_t *outputptr=(int32_t*)layer_in; 63 | for (uint32_t channel=0; channel < L7_out_channels; channel++) { 64 | 65 | for (uint32_t i=0; i < 16*16; i++) { 66 | tmpbuf[i]=input[i]; 67 | } 68 | processconv33ReLU(tmpbuf, L2_weights + 9*channel, L2_incoming_x, 4, tmpbuf); 69 | processconv33ReLU(tmpbuf, L4_weights + 9*channel, L4_incoming_x, 4, tmpbuf); 70 | processmaxpool22(tmpbuf, L6_incoming_x, tmpbuf); 71 | processconv33ReLU(tmpbuf, L7_weights + 9*channel, L7_incoming_x, 4, tmpbuf); 72 | 73 | outputptr= processmaxpool22(tmpbuf, L9_incoming_x, outputptr); 74 | } 75 | 76 | // Normalization and conversion to 8-bit 77 | ReLUNorm((int32_t*)layer_in, layer_in, L7_out_channels * L9_outgoing_x * L9_outgoing_y); 78 | 79 | // Fully connected layers 80 | processfclayer(layer_in, L11_weights, L11_bitperweight, L11_incoming_weights, L11_outgoing_weights, layer_out); 81 | ReLUNorm(layer_out, layer_in, L11_outgoing_weights); 82 | 83 | processfclayer(layer_in, L13_weights, L13_bitperweight, L13_incoming_weights, L13_outgoing_weights, layer_out); 84 | ReLUNorm(layer_out, layer_in, L13_outgoing_weights); 85 | 86 | processfclayer(layer_in, L15_weights, L15_bitperweight, L15_incoming_weights, L15_outgoing_weights, layer_out); 87 | return ReLUNorm(layer_out, layer_in, L15_outgoing_weights); 88 | } 89 | 90 | #elif defined(MODEL_FCMNIST) 91 | 92 | uint32_t BitMnistInference(int8_t *input) { 93 | int32_t layer_out[MAX_N_ACTIVATIONS]; 94 | int8_t layer_in[MAX_N_ACTIVATIONS]; 95 | 96 | processfclayer(input, L1_weights, L1_bitperweight, L1_incoming_weights, L1_outgoing_weights, layer_out); 97 | ReLUNorm(layer_out, layer_in, L1_outgoing_weights); 98 | 99 | // printf("L1 activations: \n"); 100 | // printactivations(layer_in, L1_outgoing_weights); 101 | 102 | processfclayer(layer_in, L2_weights, L2_bitperweight, L2_incoming_weights, L2_outgoing_weights, layer_out); 103 | ReLUNorm(layer_out, layer_in, L2_outgoing_weights); 104 | 105 | // printf("L2 activations: \n"); 106 | // printactivations(layer_in, L2_outgoing_weights); 107 | 108 | #ifdef L4_active 109 | processfclayer(layer_in, L3_weights, L3_bitperweight, L3_incoming_weights, L3_outgoing_weights, layer_out); 110 | ReLUNorm(layer_out, layer_in, L3_outgoing_weights); 111 | 112 | processfclayer(layer_in, L4_weights, L4_bitperweight, L4_incoming_weights, L4_outgoing_weights, layer_out); 113 | return ReLUNorm(layer_out, layer_in, L4_outgoing_weights); 114 | #else 115 | processfclayer(layer_in, L3_weights, L3_bitperweight, L3_incoming_weights, L3_outgoing_weights, layer_out); 116 | return ReLUNorm(layer_out, layer_in, L3_outgoing_weights); 117 | #endif 118 | } 119 | #else 120 | #error "No model defined" 121 | #endif -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from BitNetMCU import BitLinear, BitConv2d 5 | 6 | class MaskingLayer(nn.Module): 7 | 8 | def __init__(self, num_channels): 9 | super(MaskingLayer, self).__init__() 10 | self.mask = nn.Parameter(torch.ones(num_channels)) 11 | 12 | def forward(self, x): 13 | return x * self.mask.view(1, -1) 14 | 15 | def prune_channels(self, prune_number=8, groups=0): 16 | with torch.no_grad(): 17 | if groups > 0: 18 | 19 | channels_per_group = 
self.mask.size(0) // groups 20 | group_mask_values = torch.zeros(groups) 21 | 22 | # Calculate the sum of mask values for each group 23 | for group in range(groups): 24 | start = group * channels_per_group 25 | end = start + channels_per_group 26 | group_mask_values[group] = self.mask[start:end].sum() 27 | 28 | # Sort the group mask values and determine the threshold 29 | sorted_group_mask_values, _ = torch.sort(group_mask_values) 30 | threshold = sorted_group_mask_values[prune_number - 1].item() 31 | 32 | # Update the mask values to prune entire groups 33 | mask_values = self.mask.clone() 34 | for group in range(groups): 35 | start = group * channels_per_group 36 | end = start + channels_per_group 37 | if group_mask_values[group] <= threshold: 38 | mask_values[start:end] = 0.0 39 | else: 40 | mask_values[start:end] = 1.0 41 | else: 42 | sorted_mask_values, _ = torch.sort(self.mask.view(-1)) 43 | threshold = sorted_mask_values[prune_number - 1].item() 44 | mask_values = (self.mask > threshold).float() 45 | 46 | self.mask.requires_grad = False 47 | self.mask.data = mask_values 48 | 49 | pruned_channels = (mask_values == 0).sum().item() # count on the binarized mask; threshold is not comparable across branches 50 | remaining_channels = (mask_values > 0).sum().item() 51 | print(f"Pruned {pruned_channels} channels. {remaining_channels} channels remaining.") 52 | return pruned_channels, remaining_channels 53 | 54 | 55 | class FCMNIST(nn.Module): 56 | """ 57 | Fully Connected Neural Network for MNIST dataset. 58 | 16x16 input image, 3 hidden layers with a configurable width. 59 | 60 | @cpldcpu 2024-March-24 61 | 62 | """ 63 | def __init__(self,network_width1=64,network_width2=64,network_width3=64,QuantType='Binary',WScale='PerTensor',NormType='RMS', num_classes: int = 10): 64 | super(FCMNIST, self).__init__() 65 | 66 | self.network_width1 = network_width1 67 | self.network_width2 = network_width2 68 | self.network_width3 = network_width3 69 | 70 | self.model = nn.Sequential( 71 | nn.Flatten(), 72 | BitLinear(1* 16 *16, network_width1,QuantType=QuantType,NormType=NormType, WScale=WScale), 73 | nn.ReLU(), 74 | BitLinear(network_width1, network_width2,QuantType=QuantType,NormType=NormType, WScale=WScale), 75 | nn.ReLU() 76 | ) 77 | 78 | if network_width3>0: 79 | self.model.add_module("fc3", BitLinear(network_width2, network_width3,QuantType=QuantType,NormType=NormType, WScale=WScale)) 80 | self.model.add_module("relu_fc2", nn.ReLU()) 81 | 82 | last_width = network_width3 if network_width3>0 else network_width2 83 | # Output layer parameterized by number of classes (default 10 for MNIST / 47 for EMNIST balanced, etc.) 84 | self.classifier= BitLinear(last_width, num_classes,QuantType=QuantType,NormType=NormType, WScale=WScale) 85 | 86 | def forward(self, x): 87 | x = self.model(x) 88 | x = self.classifier(x) 89 | 90 | return x 91 | 92 | class CNNMNIST(nn.Module): 93 | """ 94 | CNN+FC Neural Network for MNIST dataset. Depthwise separable convolutions. 95 | 16x16 input image, 3 hidden layers with a configurable width.
96 | 97 | @cpldcpu 2024-April-19 98 | 99 | """ 100 | def __init__(self,network_width1=64,network_width2=64,network_width3=64,QuantType='Binary',WScale='PerTensor',NormType='RMS', num_classes: int = 10): 101 | super(CNNMNIST, self).__init__() 102 | 103 | self.network_width1 = network_width1 104 | self.network_width2 = network_width2 105 | self.network_width3 = network_width3 106 | 107 | self.model = nn.Sequential( 108 | 109 | # 256ch out , 99.5% 110 | BitConv2d(1, 64, kernel_size=3, stride=1, padding=(0,0), groups=1,QuantType='8bit',NormType='None', WScale=WScale), 111 | nn.ReLU(), 112 | BitConv2d(64, 64, kernel_size=3, stride=1, padding=(0,0), groups=64,QuantType='8bit',NormType='None', WScale=WScale), 113 | nn.ReLU(), 114 | nn.MaxPool2d(kernel_size=2, stride=2), 115 | BitConv2d(64, 64, kernel_size=3, stride=1, padding=(0,0), groups=64,QuantType='8bit',NormType='None', WScale=WScale), 116 | nn.ReLU(), 117 | nn.MaxPool2d(kernel_size=2, stride=2), 118 | 119 | nn.Flatten(), 120 | # MaskingLayer(96*4), # learnable masking layer for auto-pruning 121 | BitLinear(64*4 , network_width1,QuantType='2bitsym',NormType=NormType, WScale=WScale), 122 | nn.ReLU(), 123 | BitLinear(network_width1, network_width2,QuantType=QuantType,NormType=NormType, WScale=WScale), 124 | nn.ReLU() 125 | ) 126 | 127 | if network_width3>0: 128 | self.model.add_module("fc3", BitLinear(network_width2, network_width3,QuantType=QuantType,NormType=NormType, WScale=WScale)) 129 | self.model.add_module("relu_fc2", nn.ReLU()) 130 | 131 | last_width = network_width3 if network_width3>0 else network_width2 132 | # Output layer parameterized by number of classes (default 10 for MNIST / 47 for EMNIST balanced, etc.) 133 | self.classifier= BitLinear(last_width, num_classes,QuantType=QuantType,NormType=NormType, WScale=WScale) 134 | # self.dropout = nn.Dropout(0.05) 135 | 136 | def forward(self, x): 137 | x = self.model(x) 138 | x = self.classifier(x) 139 | return x 140 | -------------------------------------------------------------------------------- /docs/plots/clean_30epruns.txt: -------------------------------------------------------------------------------- 1 | num_epochs QuantType BPW NormType batch_size learning_rate lr_decay step_size network_width1 network_width2 network_width3 Parameters Totalbits Accuracy/train Accuracy/test Loss/train Loss/test 2 | 30 2bitsym 2 RMS 128 0.001 0.1 10 128 128 128 66816 133632 99.88999939 98.19000244 0.005878554 0.069907822 3 | 30 2bitsym 2 RMS 128 0.001 0.1 10 160 160 160 93760 187520 99.95999908 98.47000122 0.002911944 0.060931921 4 | 30 2bitsym 2 RMS 128 0.001 0.1 10 160 160 160 93760 187520 99.95666504 98.36000061 0.003068546 0.067683592 5 | 30 2bitsym 2 RMS 128 0.001 0.1 10 16 16 16 4768 9536 93.96833038 93.54000092 0.210660011 0.220271438 6 | 30 2bitsym 2 RMS 128 0.001 0.1 10 24 24 24 7536 15072 96.30500031 95.86000061 0.124898829 0.144815609 7 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 48 48 22240 44480 98.94999695 97.62999725 0.034304049 0.081468672 8 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 48 24032 48064 99.12999725 97.79000092 0.030079897 0.07552468 9 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 48 24032 48064 99.06666565 97.73000336 0.030958444 0.081059061 10 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 80 26400 52800 99.04833221 97.80000305 0.030868249 0.076512448 11 | 30 2bitsym 2 RMS 128 0.001 0.1 10 80 64 64 30336 60672 99.3833313 97.76000214 0.021506036 0.0804158 12 | 30 2bitsym 2 RMS 128 0.001 0.1 10 80 80 64 32640 65280 99.47499847 97.91000366 0.019363793 0.071172968 13 | 30 2bitsym 2 
RMS 128 0.001 0.1 10 80 80 80 34080 68160 99.47166443 97.81999969 0.018568316 0.075221173 14 | 30 2bitsym 2 RMS 128 0.001 0.1 10 96 96 96 43968 87936 99.64333344 97.87999725 0.014006477 0.078228727 15 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 64 25216 50432 99.07499695 97.66999817 0.030472489 0.086960226 16 | 30 2bitsym 2 RMS 128 0.001 0.1 10 32 32 32 10560 21120 97.36166382 96.19999695 0.087572411 0.129585445 17 | 30 2bitsym 2 RMS 128 0.001 0.1 10 48 48 48 17376 34752 98.48666382 97.54000092 0.049696106 0.087036267 18 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 64 25216 50432 99.10666656 98.05000305 0.030031208 0.073056258 19 | 30 4bitsym 4 RMS 128 0.001 0.1 10 128 128 128 66816 267264 99.98833466 98.33999634 0.001657521 0.070398375 20 | 30 4bitsym 4 RMS 128 0.001 0.1 10 160 160 160 93760 375040 99.99333191 98.52999878 0.001168294 0.062753431 21 | 30 4bitsym 4 RMS 128 0.001 0.1 10 16 16 16 4768 19072 95.8833313 95.15000153 0.142325208 0.166150466 22 | 30 4bitsym 4 RMS 128 0.001 0.1 10 24 24 24 7536 30144 97.59166718 96.58000183 0.081336185 0.116349958 23 | 30 4bitsym 4 RMS 128 0.001 0.1 10 48 48 48 17376 69504 99.37833405 97.75 0.02541119 0.078946352 24 | 30 4bitsym 4 RMS 128 0.001 0.1 10 80 80 80 34080 136320 99.86833191 98.16000366 0.007509662 0.063676886 25 | 30 4bitsym 4 RMS 128 0.001 0.1 10 96 96 96 43968 175872 99.94166565 98.27999878 0.004224633 0.067755178 26 | 30 4bitsym 4 RMS 128 0.001 0.1 10 64 64 64 25216 100864 99.73666382 97.98000336 0.012213291 0.075106576 27 | 30 4bitsym 4 RMS 128 0.001 0.1 10 64 64 64 25216 100864 99.76667023 97.91999817 0.011324343 0.070851415 28 | 30 4bitsym 4 RMS 128 0.001 0.1 10 32 32 32 10560 42240 98.625 97.44999695 0.050012723 0.091146365 29 | 30 4bitsym 4 RMS 128 0.001 0.1 10 48 48 48 17376 69504 99.45333099 97.66000366 0.023183124 0.088122196 30 | 30 4bitsym 4 RMS 128 0.001 0.1 10 64 64 64 25216 100864 99.77166748 98.02999878 0.012159877 0.075515941 31 | 30 8bit 8 RMS 128 0.001 0.1 10 128 128 128 66816 534528 99.98166656 98.26000214 0.001929083 0.069513015 32 | 30 8bit 8 RMS 128 0.001 0.1 10 160 160 160 93760 750080 99.9916687 98.31999969 0.000933084 0.067958511 33 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 95.5533371 95.01999664 0.149930149 0.172408655 34 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 96.48999786 95.75 0.121690705 0.149008647 35 | 30 8bit 8 RMS 128 0.001 0.1 10 32 32 32 10560 84480 98.66166687 97.38999939 0.048467185 0.09226539 36 | 30 8bit 8 RMS 128 0.001 0.1 10 32 32 32 10560 84480 98.68333435 97.16000366 0.048221145 0.097158171 37 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.39167023 97.94000244 0.024051715 0.077490211 38 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.51667023 97.76000214 0.02148254 0.077619374 39 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.49833679 97.69000244 0.021350816 0.085419267 40 | 30 8bit 8 RMS 128 0.001 0.1 10 80 80 80 34080 272640 99.89833069 98.08999634 0.006452073 0.076201476 41 | 30 8bit 8 RMS 128 0.001 0.1 10 96 96 96 43968 351744 99.95999908 98.31999969 0.003559482 0.066809647 42 | 30 8bit 8 RMS 128 0.001 0.1 10 64 64 64 25216 201728 99.86833191 98.02999878 0.007506728 0.073647425 43 | 30 8bit 8 RMS 128 0.001 0.1 10 64 64 64 25216 201728 99.78166962 98.05999756 0.011363678 0.07218986 44 | 30 Binary 1 RMS 128 0.001 0.1 10 128 128 128 66816 66816 99.25499725 97.93000031 0.024259612 0.069476709 45 | 30 Binary 1 RMS 128 0.001 0.1 10 160 160 160 93760 93760 99.51833344 98.26999664 0.016576426 0.060108811 46 | 30 Binary 1 RMS 128 0.001 0.1 10 32 48 
32 11584 11584 94.95166779 94.88999939 0.171185374 0.175417304 47 | 30 Binary 1 RMS 128 0.001 0.1 10 48 32 32 15168 15168 96.12666321 95.54000092 0.130515814 0.145601466 48 | 30 Binary 1 RMS 128 0.001 0.1 10 48 48 32 16448 16448 96.61333466 96.06999969 0.113300718 0.132798299 49 | 30 Binary 1 RMS 128 0.001 0.1 10 96 96 96 43968 43968 98.6883316 97.80999756 0.041361477 0.075462244 50 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.79333496 96.76999664 0.072463408 0.105802305 51 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.75333405 97.15000153 0.073109455 0.09615273 52 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.61499786 96.77999878 0.075120606 0.111203559 53 | 30 Binary 1 RMS 128 0.001 0.1 10 32 32 32 10560 10560 94.74500275 94.55000305 0.18030183 0.19052586 54 | 30 Binary 1 RMS 128 0.001 0.1 10 48 48 48 17376 17376 96.86000061 96.41999817 0.10408707 0.127519086 55 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.74833679 97.01999664 0.071593277 0.107530497 56 | 30 Binary 1 RMS 128 0.001 0.1 10 80 80 80 34080 34080 98.29666901 97.18000031 0.055908281 0.093226478 57 | 30 None 32 RMS 128 0.001 0.1 10 48 48 48 17376 556032 99.66999817 97.61000061 0.016479535 0.081654519 58 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.88500214 98.12000275 0.007095081 0.06810613 59 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.87666321 98.08999634 0.007548976 0.075958706 60 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.89499664 98.22000122 0.00649053 0.071967855 61 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 64 64 64 25216 40345.60156 98.82666779 97.66000366 0.039240263 0.080099061 62 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 99.76166534 98.34999847 0.0100767 0.06396731 63 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 160 160 160 93760 150016 99.90000153 98.37999725 0.004819755 0.061947428 64 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 24 24 24 7536 12057.59961 95.5083313 95.05999756 0.157431617 0.172766447 65 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 96 96 96 43968 70348.79688 99.5083313 97.86000061 0.018384641 0.074451007 66 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 32 32 32 10560 16896 96.65499878 96.08000183 0.113325745 0.129083708 67 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 48 48 48 17376 27801.6 97.95999908 97.11000061 0.06503389 0.096156642 68 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 64 64 64 25216 40345.6 98.68499756 97.38999939 0.042172994 0.08866106 69 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 80 80 80 34080 54528 99.1783371 97.77999878 0.027794939 0.073586062 70 | -------------------------------------------------------------------------------- /test_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | from torch.utils.data import DataLoader 4 | import numpy as np 5 | from BitNetMCU import QuantizedModel 6 | # from models import FCMNIST 7 | from ctypes import CDLL, c_uint32, c_int8, c_uint8, POINTER 8 | import argparse 9 | import yaml 10 | import importlib 11 | 12 | # Export quantized model from saved checkpoint 13 | # cpldcpu 2024-04-14 14 | # Note: Hyperparameters are used to generated the filename 15 | #--------------------------------------------- 16 | 17 | def create_run_name(hyperparameters): 18 | runname = hyperparameters["runtag"] + '_' + hyperparameters["model"] + ('_Aug' if hyperparameters["augmentation"] else '') + '_BitMnist_' + hyperparameters["QuantType"] + "_width" + 
str(hyperparameters["network_width1"]) + "_" + str(hyperparameters["network_width2"]) + "_" + str(hyperparameters["network_width3"]) + "_epochs" + str(hyperparameters["num_epochs"]) 19 | hyperparameters["runname"] = runname 20 | return runname 21 | 22 | def load_model(model_name, params): 23 | try: 24 | module = importlib.import_module('models') 25 | model_class = getattr(module, model_name) 26 | return model_class( 27 | network_width1=params["network_width1"], 28 | network_width2=params["network_width2"], 29 | network_width3=params["network_width3"], 30 | QuantType=params["QuantType"], 31 | NormType=params["NormType"], 32 | WScale=params["WScale"] 33 | ) 34 | except AttributeError: 35 | raise ValueError(f"Model {model_name} not found in models.py") 36 | 37 | def export_test_data_to_c(test_loader, filename, num=8): 38 | with open(filename, 'w') as f: 39 | for i, (input_data, labels) in enumerate(test_loader): 40 | if i >= num: 41 | break 42 | # Reshape and convert to numpy 43 | input_data = input_data.view(input_data.size(0), -1).cpu().numpy() 44 | labels = labels.cpu().numpy() 45 | 46 | scale = 127.0 / np.maximum(np.abs(input_data).max(axis=-1, keepdims=True), 1e-5) 47 | scaled_data = np.round(input_data * scale).clip(-128, 127).astype(np.uint8) 48 | 49 | f.write(f'int8_t input_data_{i}[256] = {{\n') 50 | flattened_data = scaled_data.flatten() 51 | for k in range(0, len(flattened_data), 16): 52 | f.write(', '.join(f'0x{value:02X}' for value in flattened_data[k:k+16]) + ',\n') 53 | f.write('};\n') 54 | 55 | f.write(f'uint8_t label_{i} = ' + str(labels[0]) + ';\n') 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser(description='Training script') 59 | parser.add_argument('--params', type=str, help='Name of the parameter file', default='trainingparameters.yaml') 60 | 61 | args = parser.parse_args() 62 | 63 | if args.params: 64 | paramname = args.params 65 | else: 66 | paramname = 'trainingparameters.yaml' 67 | 68 | print(f'Load parameters from file: {paramname}') 69 | with open(paramname) as f: 70 | hyperparameters = yaml.safe_load(f) 71 | 72 | # main 73 | runname= create_run_name(hyperparameters) 74 | print(runname) 75 | 76 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 77 | 78 | # Load the MNIST dataset 79 | transform = transforms.Compose([ 80 | transforms.Resize((16, 16)), # Resize images to 16x16 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.1307,), (0.3081,)) 83 | ]) 84 | 85 | train_data = datasets.MNIST(root='data', train=True, transform=transform, download=True) 86 | test_data = datasets.MNIST(root='data', train=False, transform=transform) 87 | # Create data loaders 88 | test_loader = DataLoader(test_data, batch_size=hyperparameters["batch_size"], shuffle=False) 89 | 90 | model = load_model(hyperparameters["model"], hyperparameters).to(device) 91 | 92 | print('Loading model...') 93 | try: 94 | model.load_state_dict(torch.load(f'modeldata/{runname}.pth')) 95 | except FileNotFoundError: 96 | print(f"The file 'modeldata/{runname}.pth' does not exist.") 97 | exit() 98 | 99 | print('Inference using the original model...') 100 | correct = 0 101 | total = 0 102 | test_loss = [] 103 | with torch.no_grad(): 104 | for images, labels in test_loader: 105 | images, labels = images.to(device), labels.to(device) 106 | outputs = model(images) 107 | _, predicted = torch.max(outputs.data, 1) 108 | total += labels.size(0) 109 | correct += (predicted == labels).sum().item() 110 | testaccuracy = correct / total * 100 111 | print(f'Accuracy/Test of 
trained model: {testaccuracy} %') 112 | 113 | print('Quantizing model...') 114 | # Quantize the model 115 | quantized_model = QuantizedModel(model) 116 | print(f'Total number of bits: {quantized_model.totalbits()} ({quantized_model.totalbits()/8/1024} kbytes)') 117 | 118 | # Inference using the quantized model 119 | print ("Verifying inference of quantized model in Python and C") 120 | 121 | # Initialize counter 122 | counter = 0 123 | correct_c = 0 124 | correct_py = 0 125 | mismatch = 0 126 | 127 | test_loader2 = DataLoader(test_data, batch_size=1, shuffle=False) 128 | 129 | # export_test_data_to_c(test_loader2, 'BitNetMCU_MNIST_test_data.h', num=10) 130 | 131 | lib = CDLL('./Bitnet_inf.dll') 132 | 133 | for input_data, labels in test_loader2: 134 | input_data = input_data.view(input_data.size(0), -1).cpu().numpy() 135 | labels = labels.cpu().numpy() 136 | 137 | scale = 127.0 / np.maximum(np.abs(input_data).max(axis=-1, keepdims=True), 1e-5) 138 | scaled_data = np.round(input_data * scale).clip(-128, 127) 139 | 140 | # Create a pointer to the ctypes array 141 | input_data_pointer = (c_int8 * len(scaled_data.flatten()))(*scaled_data.astype(np.int8).flatten()) 142 | 143 | lib.Inference.argtypes = [POINTER(c_int8)] 144 | lib.Inference.restype = c_uint32 145 | 146 | # Inference C 147 | result_c = lib.Inference(input_data_pointer) 148 | 149 | # Inference Python 150 | result_py = quantized_model.inference_quantized(input_data) 151 | predict_py = np.argmax(result_py, axis=1) 152 | 153 | # activations = quantized_model.get_activations(input_data) 154 | 155 | if (result_c == labels[0]): 156 | correct_c += 1 157 | 158 | if (predict_py[0] == labels[0]): 159 | correct_py += 1 160 | 161 | if (result_c != predict_py[0]): 162 | print(f'{counter:5} Mismatch between inference engines found. 
Prediction C: {result_c} Prediction Python: {predict_py[0]} True: {labels[0]}') 163 | mismatch +=1 164 | 165 | counter += 1 166 | 167 | print("size of test data:", counter) 168 | print(f'Mispredictions C: {counter - correct_c} Py: {counter - correct_py}') 169 | print('Overall accuracy C:', correct_c / counter * 100, '%') 170 | print('Overall accuracy Python:', correct_py / counter * 100, '%') 171 | 172 | print(f'Mismatches between engines: {mismatch} ({mismatch/counter*100}%)') -------------------------------------------------------------------------------- /docs/model_cnn_channel.drawio: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /docs/plots/networksize_all.txt: -------------------------------------------------------------------------------- 1 | num_epochs QuantType BPW NormType batch_size learning_rate lr_decay step_size network_width1 network_width2 network_width3 Parameters Totalbits Accuracy/train Accuracy/test Loss/train Loss/test 2 | 30 2bitsym 2 RMS 128 0.001 0.1 10 16 16 16 4768 9536 93.96833038 93.54000092 0.210660011 0.220271438 3 | 30 Binary 1 RMS 128 0.001 0.1 10 32 48 32 11584 11584 94.95166779 94.88999939 0.171185374 0.175417304 4 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 24 24 24 7536 12057.59961 95.5083313 95.05999756 0.157431617 0.172766447 5 | 30 2bitsym 2 RMS 128 0.001 0.1 10 24 24 24 7536 15072 96.30500031 95.86000061 0.124898829 0.144815609 6 | 30 Binary 1 RMS 128 0.001 0.1 10 48 32 32 15168 15168 96.12666321 95.54000092 0.130515814 0.145601466 7 | 30 Binary 1 RMS 128 0.001 0.1 10 48 48 32 16448 16448 96.61333466 96.06999969 0.113300718 0.132798299 8 | 30 4bitsym 4 RMS 128 0.001 0.1 10 16 16 16 4768 19072 95.8833313 95.15000153 0.142325208 0.166150466 9 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.59500122 96.80000305 0.075935371 0.10220889 10 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.79333496 96.76999664 0.072463408 0.105802305 11 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.63833618 96.61000061 0.076641515 0.117163301 12 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.72000122 96.70999908 0.073574446 0.103008136 13 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.72000122 96.72000122 0.074169569 0.106362127 14 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.74666595 97.04000092 0.071919426 0.100333385 15 | 30 Binary 1 RMS 128 0.001 0.1 10 64 64 64 25216 25216 97.67166901 96.91000366 0.074403964 0.10203369 16 | 30 4bitsym 4 RMS 128 0.001 0.1 10 24 24 24 7536 30144 97.59166718 96.58000183 0.081336185 0.116349958 17 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 95.5533371 95.01999664 0.149930149 0.172408655 18 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 96.48999786 95.75 0.121690705 0.149008647 19 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 95.96333313 95.20999908 0.150282741 0.175092667 20 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 95.93166351 95.27999878 
0.138523698 0.163161501 21 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 96.42666626 95.66000366 0.125531822 0.149377212 22 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 96.2583313 95.65000153 0.130329177 0.152184337 23 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 97.07333374 96.05999756 0.103929766 0.141094044 24 | 30 8bit 8 RMS 128 0.001 0.1 10 16 16 16 4768 38144 96.21333313 95.72000122 0.131994531 0.154453799 25 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 64 64 64 25216 40345.60156 98.80500031 97.5 0.038920593 0.086704224 26 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 64 64 64 25216 40345.60156 98.82666779 97.66000366 0.039240263 0.080099061 27 | 30 Binary 1 RMS 128 0.001 0.1 10 96 96 96 43968 43968 98.6883316 97.80999756 0.041361477 0.075462244 28 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 48 48 22240 44480 98.94999695 97.62999725 0.034304049 0.081468672 29 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 48 24032 48064 99.12999725 97.79000092 0.030079897 0.07552468 30 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 48 24032 48064 99.06666565 97.73000336 0.030958444 0.081059061 31 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 64 25216 50432 99.15666962 97.95999908 0.029203488 0.074282631 32 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 64 25216 50432 99.07499695 97.66999817 0.030472489 0.086960226 33 | 30 2bitsym 2 RMS 128 0.001 0.1 10 64 64 80 26400 52800 99.04833221 97.80000305 0.030868249 0.076512448 34 | 30 2bitsym 2 RMS 128 0.001 0.1 10 80 64 64 30336 60672 99.3833313 97.76000214 0.021506036 0.0804158 35 | 30 2bitsym 2 RMS 128 0.001 0.1 10 80 80 64 32640 65280 99.47499847 97.91000366 0.019363793 0.071172968 36 | 30 Binary 1 RMS 128 0.001 0.1 10 128 128 128 66816 66816 99.25499725 97.93000031 0.024259612 0.069476709 37 | 30 2bitsym 2 RMS 128 0.001 0.1 10 80 80 80 34080 68160 99.47166443 97.81999969 0.018568316 0.075221173 38 | 30 4bitsym 4 RMS 128 0.001 0.1 10 48 48 48 17376 69504 99.37833405 97.75 0.02541119 0.078946352 39 | 30 4bitsym 4 RMS 128 0.001 0.1 10 48 48 48 17376 69504 97.09166718 96.18000031 0.096560962 0.135403857 40 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 96 96 96 43968 70348.79688 99.5083313 97.86000061 0.018384641 0.074451007 41 | 30 8bit 8 RMS 128 0.001 0.1 10 32 32 32 10560 84480 98.74666595 97.44000244 0.047507994 0.090913475 42 | 30 8bit 8 RMS 128 0.001 0.1 10 32 32 32 10560 84480 98.66166687 97.38999939 0.048467185 0.09226539 43 | 30 8bit 8 RMS 128 0.001 0.1 10 32 32 32 10560 84480 98.68333435 97.16000366 0.048221145 0.097158171 44 | 30 2bitsym 2 RMS 128 0.001 0.1 10 96 96 96 43968 87936 99.64333344 97.87999725 0.014006477 0.078228727 45 | 30 Binary 1 RMS 128 0.001 0.1 10 160 160 160 93760 93760 99.51833344 98.26999664 0.016576426 0.060108811 46 | 30 4bitsym 4 RMS 128 0.001 0.1 10 64 64 64 25216 100864 99.77666473 98.20999908 0.011492738 0.067927346 47 | 30 4bitsym 4 RMS 128 0.001 0.1 10 64 64 64 25216 100864 99.73666382 97.98000336 0.012213291 0.075106576 48 | 30 4bitsym 4 RMS 128 0.001 0.1 10 64 64 64 25216 100864 99.76833344 97.76000214 0.01179162 0.07858 49 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 99.17166901 98.43000031 0.026882112 0.047549289 50 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 99.34500122 98.63999939 0.021795386 0.04329231 51 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 99.76166534 98.34999847 0.0100767 0.06396731 52 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 98.08999634 98.65000153 0.060123611 0.044574291 53 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 
66816 106905.6016 97.51333618 98.62000275 0.077958964 0.04628583 54 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 97.89167023 98.69999695 0.067252591 0.040403374 55 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 128 128 128 66816 106905.6016 97.87750244 98.76000214 0.068583876 0.041533936 56 | 30 2bitsym 2 RMS 128 0.001 0.1 10 128 128 128 66816 133632 99.88999939 98.19000244 0.005878554 0.069907822 57 | 30 4bitsym 4 RMS 128 0.001 0.1 10 80 80 80 34080 136320 99.18000031 98.58000183 0.026338689 0.044426065 58 | 30 4bitsym 4 RMS 128 0.001 0.1 10 80 80 80 34080 136320 99.20249939 98.62000275 0.026008073 0.045075145 59 | 30 4bitsym 4 RMS 128 0.001 0.1 10 80 80 80 34080 136320 99.21916962 98.55999756 0.026043883 0.045832504 60 | 30 4bitsym 4 RMS 128 0.001 0.1 10 80 80 80 34080 136320 99.86833191 98.16000366 0.007509662 0.063676886 61 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.39167023 97.94000244 0.024051715 0.077490211 62 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.51667023 97.76000214 0.02148254 0.077619374 63 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.49833679 97.69000244 0.021350816 0.085419267 64 | 30 8bit 8 RMS 128 0.001 0.1 10 48 48 48 17376 139008 99.66333008 97.73000336 0.016294222 0.085673384 65 | 30 Ternary 1.6 RMS 128 0.001 0.1 10 160 160 160 93760 150016 99.90000153 98.37999725 0.004819755 0.061947428 66 | 30 4bitsym 4 RMS 128 0.001 0.1 10 96 96 96 43968 175872 99.94166565 98.27999878 0.004224633 0.067755178 67 | 30 2bitsym 2 RMS 128 0.001 0.1 10 160 160 160 93760 187520 99.95999908 98.47000122 0.002911944 0.060931921 68 | 30 2bitsym 2 RMS 128 0.001 0.1 10 160 160 160 93760 187520 98.50499725 97.55999756 0.04577975 0.087386839 69 | 30 2bitsym 2 RMS 128 0.001 0.1 10 160 160 160 93760 187520 99.95666504 98.36000061 0.003068546 0.067683592 70 | 30 8bit 8 RMS 128 0.001 0.1 10 64 64 64 25216 201728 99.89167023 98.13999939 0.00680928 0.066086918 71 | 30 8bit 8 RMS 128 0.001 0.1 10 64 64 64 25216 201728 99.86833191 98.02999878 0.007506728 0.073647425 72 | 30 8bit 8 RMS 128 0.001 0.1 10 64 64 64 25216 201728 99.84333038 97.93000031 0.009154263 0.078787938 73 | 30 8bit 8 RMS 128 0.001 0.1 10 64 64 64 25216 201728 99.88833618 97.98999786 0.007123163 0.074402936 74 | 30 4bitsym 4 RMS 128 0.001 0.1 10 128 128 128 66816 267264 99.98833466 98.33999634 0.001657521 0.070398375 75 | 30 8bit 8 RMS 128 0.001 0.1 10 80 80 80 34080 272640 99.89833069 98.08999634 0.006452073 0.076201476 76 | 30 None 32 RMS 128 0.001 0.1 10 32 32 32 10560 337920 98.88999939 97.47000122 0.04172992 0.089228399 77 | 30 8bit 8 RMS 128 0.001 0.1 10 96 96 96 43968 351744 99.95999908 98.31999969 0.003559482 0.066809647 78 | 30 4bitsym 4 RMS 128 0.001 0.1 10 160 160 160 93760 375040 99.99333191 98.52999878 0.001168294 0.062753431 79 | 30 8bit 8 RMS 128 0.001 0.1 10 128 128 128 66816 534528 99.98166656 98.26000214 0.001929083 0.069513015 80 | 30 None 32 RMS 128 0.001 0.1 10 48 48 48 17376 556032 99.66999817 97.61000061 0.016479535 0.081654519 81 | 30 8bit 8 RMS 128 0.001 0.1 10 160 160 160 93760 750080 99.9916687 98.31999969 0.000933084 0.067958511 82 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.25166321 98.75 0.025440145 0.043007035 83 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.88500214 98.12000275 0.007095081 0.06810613 84 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.87666321 98.08999634 0.007548976 0.075958706 85 | 30 None 32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.54666901 98.09999847 0.016896078 0.076675124 86 | 30 None 
32 RMS 128 0.001 0.1 10 64 64 64 25216 806912 99.83166504 98 0.008510431 0.07825122 87 |
--------------------------------------------------------------------------------
/docs/Model_mcu.drawio:
--------------------------------------------------------------------------------
(draw.io XML content not captured in this dump)
--------------------------------------------------------------------------------
/docs/Model.drawio:
--------------------------------------------------------------------------------
(draw.io XML content not captured in this dump)
--------------------------------------------------------------------------------
/mcu/BitNetMCUdemo.c:
--------------------------------------------------------------------------------
1 | /* Using BitNetMCU for inference of 16x16 MNIST images on a CH32V003 */ 2 | 3 | #include "ch32fun.h" 4 | 5 | // The latest version of CH32FUN seems to have more overhead; hence, only three test patterns can be included.
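// Note on the declarations just below: tagging the inference functions with
// __attribute__((section(".srodata"))) places them in a data section, so the
// startup code copies them into SRAM together with the initialized data.
// Running the hot loops from SRAM avoids the flash wait state the CH32V003
// needs at 48 MHz, which is presumably where the speedup comes from.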
6 | // Declare processfclayer an SRAM based function for speedup 7 | void processfclayer(int8_t *, const uint32_t *, int32_t, uint32_t, uint32_t, int32_t *) __attribute__((section(".srodata"))) __attribute__((used)); 8 | int32_t *processmaxpool22(int32_t *activations, uint32_t xy_input, int32_t *output) __attribute__((section(".srodata"))) __attribute__((used)); 9 | int32_t* processconv33ReLU(int32_t *activations, const int8_t *weights, uint32_t xy_input, uint32_t n_shift , int32_t *output) __attribute__((section(".srodata"))) __attribute__((used)); 10 | 11 | // #include "BitNetMCU_model_1k.h" 12 | // #include "BitNetMCU_model_12k.h" 13 | // #include "BitNetMCU_model_12k_FP130.h" 14 | // #include "BitNetMCU_model_cnn_48.h" 15 | // #include "BitNetMCU_model_cnn_32.h" 16 | // #include "BitNetMCU_model_cnn_16.h" 17 | // #include "BitNetMCU_model_cnn_16small.h" 18 | #include "BitNetMCU_model_cnn_64.h" 19 | // #include "BitNetMCU_model_cnn_letters.h" 20 | #include "../BitNetMCU_inference.c" 21 | #include 22 | 23 | const int8_t input_data_0[256] = {-22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, 11.0, 64.0, 30.0, 6.0, -14.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, 28.0, 124.0, 127.0, 115.0, 66.0, -3.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -12.0, 18.0, 58.0, 97.0, 124.0, 70.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -16.0, 47.0, 100.0, -11.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -21.0, 44.0, 104.0, -11.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -16.0, 68.0, 106.0, -12.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -13.0, 77.0, 99.0, -18.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -13.0, 77.0, 96.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -13.0, 77.0, 96.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -13.0, 77.0, 96.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -17.0, 62.0, 97.0, -20.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, 18.0, 71.0, -14.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -22.0, -20.0, -16.0, -21.0, -22.0, -22.0, -22.0, -22.0, -22.0}; 24 | const uint8_t label_0 = 7; 25 | const int8_t input_data_1[256] = {-20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -4.0, 69.0, 6.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 5.0, 106.0, 42.0, -18.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 28.0, 119.0, 50.0, -17.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -14.0, 64.0, 125.0, 19.0, -20.0, -20.0, -20.0, -20.0, 
-20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 3.0, 99.0, 121.0, 13.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -19.0, 33.0, 120.0, 100.0, -7.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -16.0, 71.0, 126.0, 65.0, -17.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 6.0, 106.0, 112.0, 13.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 47.0, 125.0, 100.0, -3.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 54.0, 127.0, 68.0, -19.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 53.0, 119.0, 43.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 16.0, 59.0, -3.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0}; 26 | const uint8_t label_1 = 1; 27 | const int8_t input_data_2[256] = {-21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -16.0, 11.0, 49.0, 48.0, 0.0, -20.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -20.0, -5.0, 41.0, 80.0, 62.0, 56.0, 70.0, 0.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -18.0, 10.0, 76.0, 58.0, 3.0, -18.0, -14.0, 70.0, 24.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -19.0, 28.0, 94.0, 29.0, -15.0, -21.0, -21.0, 1.0, 99.0, 47.0, -20.0, -21.0, -21.0, -21.0, -21.0, -21.0, 7.0, 87.0, 29.0, -19.0, -21.0, -20.0, -9.0, 65.0, 90.0, 7.0, -21.0, -21.0, -21.0, -21.0, -21.0, -19.0, 55.0, 67.0, -14.0, -21.0, -20.0, -4.0, 77.0, 118.0, 30.0, -19.0, -21.0, -21.0, -21.0, -21.0, -21.0, -17.0, 68.0, 33.0, -12.0, -0.0, 29.0, 69.0, 127.0, 72.0, -12.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -18.0, 55.0, 91.0, 84.0, 75.0, 38.0, 51.0, 111.0, 8.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -11.0, 16.0, 14.0, -8.0, -13.0, 62.0, 65.0, -18.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -3.0, 84.0, 18.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, 14.0, 68.0, -13.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, 8.0, 39.0, -17.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -20.0, -18.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0, -21.0}; 28 | const uint8_t label_2 = 9; 29 | // const int8_t input_data_3[256] = {-20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -13.0, -15.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -6.0, 41.0, 78.0, 38.0, -18.0, -20.0, -20.0, -20.0, -20.0, -17.0, -17.0, -20.0, -20.0, -20.0, -20.0, -11.0, 67.0, 109.0, 63.0, 6.0, -20.0, -20.0, -20.0, -20.0, -8.0, 48.0, 50.0, -8.0, -20.0, -20.0, -20.0, 2.0, 108.0, 
65.0, -14.0, -20.0, -20.0, -20.0, -20.0, -12.0, 59.0, 114.0, 89.0, 4.0, -20.0, -20.0, -20.0, 10.0, 114.0, 27.0, -20.0, -20.0, -20.0, -20.0, -20.0, 36.0, 122.0, 65.0, -14.0, -20.0, -20.0, -20.0, -20.0, -2.0, 96.0, 55.0, -13.0, -20.0, -20.0, -20.0, -12.0, 89.0, 114.0, 16.0, -20.0, -20.0, -20.0, -20.0, -20.0, -17.0, 43.0, 100.0, 46.0, -5.0, -15.0, -18.0, 6.0, 115.0, 84.0, -9.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -14.0, 45.0, 115.0, 100.0, 78.0, 50.0, 66.0, 127.0, 53.0, -17.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -13.0, 28.0, 76.0, 91.0, 104.0, 127.0, 122.0, 28.0, -18.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -19.0, -16.0, -14.0, -1.0, 71.0, 114.0, 8.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, 19.0, 112.0, 39.0, -13.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -11.0, 70.0, 89.0, 19.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -18.0, -6.0, -18.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0, -20.0}; 30 | // const uint8_t label_3 = 4; 31 | 32 | #ifdef MODEL_CNNMNIST 33 | 34 | uint32_t BitMnistInference(const int8_t *input) { 35 | int32_t layer_out[256]; // has to hold 16x16 image 36 | int8_t layer_in[MAX_N_ACTIVATIONS*4]; 37 | 38 | /* 39 | Layer: L2 Conv2d bpw: 8 1 -> 64 groups:1 Kernel: 3x3 Incoming: 16x16 Outgoing: 14x14 40 | Layer: L4 Conv2d bpw: 8 64 -> 64 groups:64 Kernel: 3x3 Incoming: 14x14 Outgoing: 12x12 41 | Layer: L6 MaxPool2d Pool Size: 2 Incoming: 12x12 Outgoing: 6x6 42 | Layer: L7 Conv2d bpw: 8 64 -> 64 groups:64 Kernel: 3x3 Incoming: 6x6 Outgoing: 4x4 43 | Layer: L9 MaxPool2d Pool Size: 2 Incoming: 4x4 Outgoing: 2x2 44 | Layer: L11 Quantization type: <2bitsym>, Bits per weight: 2, Num. incoming: 256, Num outgoing: 96 45 | Layer: L13 Quantization type: <4bitsym>, Bits per weight: 4, Num. incoming: 96, Num outgoing: 64 46 | Layer: L15 Quantization type: <4bitsym>, Bits per weight: 4, Num. 
incoming: 64, Num outgoing: 10 47 | */ 48 | 49 | // Depthwise separable convolution with 32 bit activations and 8 bit weights 50 | int32_t *tmpbuf=(int32_t*)layer_out; 51 | int32_t *outputptr=(int32_t*)layer_in; 52 | for (uint32_t channel=0; channel < L7_out_channels; channel++) { 53 | 54 | for (uint32_t i=0; i < 16*16; i++) { 55 | tmpbuf[i]=input[i]; 56 | } 57 | processconv33ReLU(tmpbuf, L2_weights + 9*channel, L2_incoming_x, 4, tmpbuf); 58 | processconv33ReLU(tmpbuf, L4_weights + 9*channel, L4_incoming_x, 4, tmpbuf); 59 | processmaxpool22(tmpbuf, L6_incoming_x, tmpbuf); 60 | processconv33ReLU(tmpbuf, L7_weights + 9*channel, L7_incoming_x, 4, tmpbuf); 61 | 62 | outputptr= processmaxpool22(tmpbuf, L9_incoming_x, outputptr); 63 | } 64 | 65 | // Normalization and conversion to 8-bit 66 | ReLUNorm((int32_t*)layer_in, layer_in, L7_out_channels * L9_outgoing_x * L9_outgoing_y); 67 | 68 | // Fully connected layers 69 | processfclayer(layer_in, L11_weights, L11_bitperweight, L11_incoming_weights, L11_outgoing_weights, layer_out); 70 | ReLUNorm(layer_out, layer_in, L11_outgoing_weights); 71 | 72 | processfclayer(layer_in, L13_weights, L13_bitperweight, L13_incoming_weights, L13_outgoing_weights, layer_out); 73 | ReLUNorm(layer_out, layer_in, L13_outgoing_weights); 74 | 75 | processfclayer(layer_in, L15_weights, L15_bitperweight, L15_incoming_weights, L15_outgoing_weights, layer_out); 76 | return ReLUNorm(layer_out, layer_in, L15_outgoing_weights); 77 | } 78 | 79 | #elif defined(MODEL_FCMNIST) 80 | 81 | uint32_t BitMnistInference(const int8_t *input) { 82 | int32_t layer_out[MAX_N_ACTIVATIONS]; 83 | int8_t layer_in[MAX_N_ACTIVATIONS]; 84 | int32_t prediction; 85 | 86 | processfclayer((int8_t*)input, L1_weights, L1_bitperweight, L1_incoming_weights, L1_outgoing_weights, layer_out); 87 | ReLUNorm(layer_out, layer_in, L1_outgoing_weights); 88 | 89 | processfclayer(layer_in, L2_weights, L2_bitperweight, L2_incoming_weights, L2_outgoing_weights, layer_out); 90 | ReLUNorm(layer_out, layer_in, L2_outgoing_weights); 91 | 92 | processfclayer(layer_in, L3_weights, L3_bitperweight, L3_incoming_weights, L3_outgoing_weights, layer_out); 93 | prediction=ReLUNorm(layer_out, layer_in, L3_outgoing_weights); 94 | 95 | #if NUM_LAYERS == 4 96 | processfclayer(layer_in, L4_weights, L4_bitperweight, L4_incoming_weights, L4_outgoing_weights, layer_out); 97 | prediction=ReLUNorm(layer_out, layer_in, L4_outgoing_weights); 98 | #endif 99 | 100 | return prediction; 101 | } 102 | 103 | #endif 104 | 105 | void TestSample(const int8_t *input, const uint8_t label, const uint8_t sample) { 106 | volatile int32_t startticks, endticks; 107 | int32_t prediction; 108 | 109 | startticks = SysTick->CNT; 110 | prediction = BitMnistInference(input); 111 | endticks = SysTick->CNT; 112 | 113 | printf( "Inference of Sample %d\tPrediction: %ld\tLabel: %d\tTiming: %lu clock cycles\n", sample, prediction, label, endticks-startticks); 114 | } 115 | 116 | 117 | int main() 118 | { 119 | SystemInit(); 120 | 121 | while(1) 122 | { 123 | printf("Starting MNIST inference...\n"); 124 | TestSample(input_data_0, label_0,1); 125 | TestSample(input_data_1, label_1,2); 126 | TestSample(input_data_2, label_2,3); 127 | // BitMnistInference(input_data_3, label_3,4); 128 | Delay_Ms(1000); 129 | } 130 | } 131 | 132 | -------------------------------------------------------------------------------- /mcu/BitNetMCU_model_cnn_16small.h: -------------------------------------------------------------------------------- 1 | // Automatically generated header file 2 | // 
Date: 2025-09-07 11:10:55.437128 3 | // Quantized model exported from octav_CNNMNIST_Aug_BitMnist_4bitsym_width64_48_0_epochs60.pth 4 | // Generated by exportquant.py 5 | 6 | #include 7 | 8 | #ifndef BITNETMCU_MODEL_H 9 | #define BITNETMCU_MODEL_H 10 | 11 | // Model class name as defined in models.py 12 | #define MODEL_CNNMNIST 13 | 14 | // Number of layers 15 | #define NUM_LAYERS 8 16 | 17 | // Maximum number of activations per layer 18 | #define MAX_N_ACTIVATIONS 64 19 | 20 | // Layer: L2 (Convolutional) 21 | #define L2_active 22 | #define L2_type BitConv2d 23 | #define L2_in_channels 1 24 | #define L2_out_channels 16 25 | #define L2_incoming_x 16 26 | #define L2_incoming_y 16 27 | #define L2_outgoing_x 14 28 | #define L2_outgoing_y 14 29 | #define L2_kernel_size 3 30 | #define L2_stride 1 31 | #define L2_padding 0 32 | #define L2_groups 1 33 | #define L2_bitperweight 8 34 | const int8_t L2_weights[] = { 35 | 18,-28,-49,-21,94,-45,15,4,-47,32,-65,-21,-52,-24,57,-67, 36 | 31,6,2,-99,4,66,21,38,-16,15,-2,9,63,-42,39,16, 37 | -22,23,-30,-56,-14,-39,-31,38,21,-32,18,2,41,-102,-17,-16, 38 | -20,40,-24,5,42,44,16,-53,-65,5,73,47,59,29,-1,3, 39 | -56,42,-18,-41,38,-34,85,27,-19,-69,11,61,-23,-88,-21,73, 40 | -55,-43,-34,-28,-19,26,-79,-73,26,8,-19,-27,-7,20,53,82, 41 | -28,-47,-22,-34,12,25,-30,-97,-72,9,64,-59,-21,-8,-85,-24, 42 | -29,76,33,41,-3,-71,-14,43,-54,44,-20,-90,-128,-29,30,-27, 43 | 30,5,15,-76,24,12,64,35,-24,15,55,-101,31,32,18,-36, 44 | }; 45 | 46 | // Layer: L4 (Convolutional) 47 | #define L4_active 48 | #define L4_type BitConv2d 49 | #define L4_in_channels 16 50 | #define L4_out_channels 16 51 | #define L4_incoming_x 14 52 | #define L4_incoming_y 14 53 | #define L4_outgoing_x 12 54 | #define L4_outgoing_y 12 55 | #define L4_kernel_size 3 56 | #define L4_stride 1 57 | #define L4_padding 0 58 | #define L4_groups 16 59 | #define L4_bitperweight 8 60 | const int8_t L4_weights[] = { 61 | 47,47,28,45,42,36,38,5,9,-15,34,49,50,17,35,16, 62 | 14,14,-81,57,-3,-101,-6,29,3,-51,-4,35,17,30,10,15, 63 | 26,0,-5,-1,29,28,5,8,37,18,32,21,22,-76,-72,-31, 64 | 36,29,-37,15,14,29,-11,26,-28,-31,54,-27,-26,32,-15,3, 65 | 0,-7,7,26,-5,71,35,75,-13,-44,-56,32,32,-3,21,23, 66 | 73,31,37,27,7,-24,-45,-52,-14,15,15,10,32,20,30,-11, 67 | -102,-112,-34,19,34,28,-8,19,63,-39,-71,-61,45,-8,-96,-5, 68 | -128,-12,-35,23,-11,10,38,25,35,11,-36,-16,-42,-24,2,16, 69 | 9,29,20,20,46,24,17,-23,-1,8,9,27,27,47,47,10, 70 | }; 71 | 72 | #define L6_active 73 | #define L6_type MaxPool2d 74 | #define L6_pool_size 2 75 | #define L6_incoming_x 12 76 | #define L6_incoming_y 12 77 | #define L6_outgoing_x 6 78 | #define L6_outgoing_y 6 79 | 80 | // Layer: L7 (Convolutional) 81 | #define L7_active 82 | #define L7_type BitConv2d 83 | #define L7_in_channels 16 84 | #define L7_out_channels 16 85 | #define L7_incoming_x 6 86 | #define L7_incoming_y 6 87 | #define L7_outgoing_x 4 88 | #define L7_outgoing_y 4 89 | #define L7_kernel_size 3 90 | #define L7_stride 1 91 | #define L7_padding 0 92 | #define L7_groups 16 93 | #define L7_bitperweight 8 94 | const int8_t L7_weights[] = { 95 | 19,22,39,-12,5,-17,-39,-6,-14,37,20,7,1,-4,-13,-6, 96 | -34,-25,1,2,0,30,11,11,20,44,15,18,34,8,7,-6, 97 | -10,0,2,-9,18,8,29,-15,-13,0,-36,-15,12,36,20,10, 98 | 27,11,4,20,8,-2,7,27,53,-4,11,17,-6,0,4,10, 99 | 12,13,1,1,8,-26,-61,-128,-1,32,41,-12,5,-3,-30,-37, 100 | -32,42,35,0,28,0,-13,-19,12,-22,15,37,27,1,-1,-2, 101 | -5,2,-1,-10,-1,-12,13,5,18,29,30,38,8,-11,-3,31, 102 | 31,4,41,19,-3,13,21,44,-8,-7,2,-55,-49,-31,-2,-1, 103 | 
-9,-3,-11,4,46,4,16,-26,-9,0,-38,7,7,13,36,-2, 104 | }; 105 | 106 | #define L9_active 107 | #define L9_type MaxPool2d 108 | #define L9_pool_size 2 109 | #define L9_incoming_x 4 110 | #define L9_incoming_y 4 111 | #define L9_outgoing_x 2 112 | #define L9_outgoing_y 2 113 | 114 | // Layer: L11 115 | // QuantType: 2bitsym 116 | #define L11_active 117 | #define L11_bitperweight 2 118 | #define L11_incoming_weights 64 119 | #define L11_outgoing_weights 64 120 | const uint32_t L11_weights[] = { 121 | 0x8a0a28aa,0xa22888a8,0x2a2aaa8a,0x8008aaa2,0xf96e2e16,0xa8b42928,0x16869718,0xf0a0a1b0, 122 | 0x8a7b6240,0xb28ab212,0x492c1098,0x002b2023,0x8826c13e,0x810c9ab2,0x0a23762d,0x40a42023, 123 | 0x2a22aa08,0xaaaa0aaa,0x2a088888,0xa88aaaaa,0xb09012cc,0x9b063584,0xe0f29137,0xe2e6a28c, 124 | 0xb091e452,0xb8c830a3,0x898cd893,0xcc820a0c,0x72a39862,0x602a7628,0x00a0fb02,0x70bb8068, 125 | 0xa48ae002,0x120e1038,0xab020fa8,0x0faee144,0x80684a0e,0x8aa20420,0x2b0bb6b1,0xae68218c, 126 | 0x0e5248ac,0x28eea290,0xa01280aa,0xa6238aa5,0xbc228ff4,0x180aba09,0x02c1f803,0x8489b254, 127 | 0x50f28a99,0x682a2a80,0x21b93298,0x40b0e280,0xa6f88237,0x09080133,0x000ab003,0xb804672c, 128 | 0x2090c8c2,0x88ae41a8,0x08c4fa2f,0x82c28504,0xe02228ed,0x2f94229f,0x99012aa7,0x813f8398, 129 | 0xeab84abe,0x4052c2aa,0x8300400f,0xa9608a01,0x20082c88,0x048b096b,0x0accfb2e,0xfba3088c, 130 | 0xaa2210f8,0x82aa82aa,0x8250aeae,0x65050524,0x0aac2b01,0x4202238c,0x2b28002e,0xab082f95, 131 | 0x04a1a8ca,0xa8f01ab0,0x2380fba8,0xa6911220,0xde8288ac,0x9a18d220,0xa0a64030,0xd4b282a6, 132 | 0x6a0f5841,0x80a00800,0x8bacc424,0xa0a0ffc2,0x50ae6e1a,0x22240d38,0x392a0484,0x1e8b8cee, 133 | 0x88218cbb,0x31049f08,0x0382220d,0xa40081a4,0x8c6a202a,0x8a88802b,0x1d6f22a5,0x493a2583, 134 | 0xb4fd0a76,0x100e0229,0x82a08888,0xca09e140,0xb52a2325,0x9e83a1a2,0x1b0882ab,0x80037f00, 135 | 0x0a006440,0xb2ee36d8,0x88a8edd8,0x02022403,0x202a4a28,0xc8ca0e09,0x9f88a2a7,0x080e4f29, 136 | 0x9aa2c082,0x216c8315,0x2f220a9e,0x4a098f8a,0x048bc2c2,0x1f2820a9,0xa52b00b7,0x488b072f, 137 | 0x8a0a080a,0x20a8620a,0x208cbb00,0xaa32c0aa,0x2e274000,0x032a3098,0xacaa988d,0xa68281a7, 138 | 0xaaaaaaa8,0x8aaaa2a8,0xa88aaaa0,0xa80aaaaa,0xa0a4e6b8,0x898aa93b,0x2f0a2a82,0x0d8a5f4c, 139 | 0x4a9f0cc4,0xb438122a,0x2eb4f52f,0x02880fa0,0xb27b5080,0x5002b0a1,0xa02a9aa3,0xf282a808, 140 | 0x029ae8ca,0x0a4220aa,0x8a802287,0x1eafc58a,0x20f8ff02,0x9aa80f00,0x23108ba9,0x40052400, 141 | 0x0082a41a,0x020fb205,0x84b24807,0x0e82268f,0x14ed30aa,0xb0a40bcf,0xa250ee80,0x281d5808, 142 | 0x2aa8fec0,0xc0a6a253,0xa284c231,0x12862990,0x2c4a4aa8,0x08e4ba0c,0xb08cc586,0x281f0f69, 143 | 0x4ac80822,0x44480bc8,0x292888ce,0xe0d1c186,0x08e9fe98,0x68240aa8,0x8a227094,0x80e0c0ce, 144 | 0x2822a22a,0xaa0882aa,0xa88aaaaa,0xa22a228a,0x2a88b094,0xa022053f,0xaf2422ac,0x2b080f08, 145 | 0x22a82aaa,0x8aa88a20,0x20aaaaaa,0xaa0aa88a,0x9880ea03,0xf2cdf88a,0xb181e81a,0x0a2a2540, 146 | 0xd0325260,0x02aa1a10,0xa83e48b8,0x10f9aa22,0x48142142,0xa2ce20b8,0x1aceee0b,0xee8f0c0b, 147 | 0xf62b0a46,0x0a80340a,0xa08e1092,0xee08aa7c,0xfc56044a,0xf2ccc0a9,0x9802902b,0x2800253a, 148 | 0x52284880,0x093612a3,0x2a223b2d,0x46200e04,0xaaa08a2a,0xa28a02a2,0xa8aaaa22,0x2aa822a0, 149 | 0xa02382e0,0x87200b05,0x2bb88a32,0x2c314b0f,0x2400fcce,0xa024a017,0xa5a64365,0xaa22a48e, 150 | 0xa38001fc,0x0a474613,0x60a69828,0xe48224ae,0x88f888a6,0x480ae8bb,0x0158a885,0x014bcc2a, 151 | 0xa8195c8c,0x85ea20f0,0xa88a8aad,0x0a60e9c0,0x2a028af7,0x0000ba00,0x39a8e914,0xab18200b, 152 | 0x421a8444,0xe2b81a01,0x3a20fa00,0x220a032a,0x0808fe11,0xa3032220,0x2fe9fba2,0x2afa8855, 153 | }; //first 
channel is topmost bit 154 | 155 | // Layer: L13 156 | // QuantType: 4bitsym 157 | #define L13_active 158 | #define L13_bitperweight 4 159 | #define L13_incoming_weights 64 160 | #define L13_outgoing_weights 48 161 | const uint32_t L13_weights[] = { 162 | 0x91ab1a11,0x8313a218,0xa4023b29,0x8d8820bc,0x429129b2,0xb1088800,0x81c30981,0x081e0d33, 163 | 0x9b9181f5,0x879081a9,0x01860910,0xa3c1990a,0x1183b041,0x118ad204,0x9f8afd30,0xd2098c1a, 164 | 0x00080909,0xb9abc21b,0xb3003da3,0x28ae9801,0x8a023da1,0xf682a981,0x00b32821,0x2a80aa31, 165 | 0x07a880c6,0x04994b99,0x808b29c0,0x3108008c,0x280a8b14,0xb2688402,0x0d8b8a10,0x12a029a3, 166 | 0x0a040f41,0x9fbc0531,0x803f1d9e,0xd9880b00,0x19140805,0x013c008b,0x04028009,0x8088811d, 167 | 0x959b0429,0x0480cd81,0x9aa92b33,0xa19194a1,0x2a01b219,0x9083908a,0x00a881c0,0x811190a9, 168 | 0x91948a18,0xc008898a,0x8b0a22db,0x11d9da28,0x99803889,0x8b39b984,0x89aab328,0x940f1010, 169 | 0x93428a00,0xf488b8d8,0x1482c900,0x83f2c1a9,0xb900a1aa,0x0ca0ca92,0x0a920130,0x82989993, 170 | 0x18a882aa,0x2b839031,0x85410139,0x9d82a001,0x5892a922,0xb9299994,0x8a29d0a8,0x1eaa4985, 171 | 0x180a0280,0x8c086f11,0x9a2c82d9,0x9bb20820,0x158f81bb,0x8f9c1991,0x1d2fb008,0x3a422b8b, 172 | 0x03b49a43,0x38da911a,0x9a391998,0x2018d9b8,0x810af882,0xac4a108b,0x00019111,0x180188aa, 173 | 0x8a1a8a91,0x0b0c8a90,0x5549c8b3,0x188c1c9c,0x80baa30c,0x14aa2c0a,0x8a88e818,0xa8011b84, 174 | 0x814a821d,0x0c13bb89,0x29392aa0,0x28a95d28,0x950a859c,0x83001a08,0x0e090808,0x9939520b, 175 | 0x8048023a,0xbf181020,0xb89d9113,0x82132082,0x0188aa8a,0x319a210b,0x929482b0,0xa889998b, 176 | 0x03a0049a,0xc9081f90,0x9d0c0332,0x92a1b4a3,0x2289123f,0x20c2108a,0x98a8a2b8,0x380cb891, 177 | 0x09a802df,0x1003a9ba,0x29a52381,0x2e1b1a0a,0xd4091812,0x14811084,0x1ab90a18,0x09aa990c, 178 | 0x0b831088,0xc430b080,0x1c001532,0x32c90118,0xd90200b9,0xa3a29a82,0x9c299228,0x8b2d1d19, 179 | 0x0a110a11,0x9a018208,0x97810948,0x8ab0089a,0x448311a1,0x9a29b814,0x09010918,0xad090c03, 180 | 0x19230908,0xad0b3b00,0x1d2e82ea,0x248c2eab,0x010d0a99,0xa21c188a,0x018ba118,0x88114b2b, 181 | 0x0ac889c9,0xb02a381c,0x82803c30,0x1b5c0078,0x91003880,0x33c23101,0x88dad928,0xab291c00, 182 | 0x88908bb5,0x4987382a,0x8828a1ad,0xb01281a3,0x39898022,0x0f280a81,0x88304898,0x5aa21318, 183 | 0x9bf19280,0x3bc1c130,0x081a20bf,0xa229ac93,0x08001160,0x29810919,0x002090a0,0x7a39054c, 184 | 0x0a201129,0xc82408a8,0x99aabbd1,0xa0b81f12,0x10040880,0x9183118a,0x82eb4a08,0xd033068c, 185 | 0x99008d03,0x3da2b92e,0xb63b0a8a,0x81291a80,0x110af222,0x880c1b0a,0x03a32190,0x1cb81223, 186 | 0x00091112,0xb809441c,0x80911d4b,0xee6631a2,0x918840ab,0x40c1188a,0x80c38889,0x0c9dc92b, 187 | 0x1d2382a1,0xb80a50a0,0xaf8aa293,0xb48898c4,0xa90391bd,0x219a0989,0x8803d810,0x81a309ac, 188 | 0x18a989c5,0xdc0c28e3,0x19b0d803,0x10ac9da3,0x9880192b,0x02b11501,0x0cabbd08,0xc102b9a1, 189 | 0x042a8129,0x79a3e020,0x9289ff88,0xa1138fc0,0xa189d229,0x1a40198b,0x08d810c8,0x911902aa, 190 | 0x8009892b,0x0322f822,0x303919ab,0x138a90c9,0x8b11aa48,0xc8829808,0x9b2a15d8,0xaa8050c9, 191 | 0x9a0b1938,0xaf3a1a83,0x2080420c,0x8df1033b,0xa884a8db,0x90a3a882,0x08130000,0x1cd10a01, 192 | 0x8b2099ba,0xb9008aaa,0x89a2a810,0x90912a42,0x92983009,0x0aa10991,0x0b0c8a81,0x38398308, 193 | 0x93a80888,0x20c382ab,0x4bb39089,0x3c31b839,0xb0883289,0x30092112,0x02220018,0xa002b5a0, 194 | 0x081a9041,0x12211a89,0x85cf0b11,0x99082239,0x010cb002,0x105c128b,0x14a41098,0x83ba4801, 195 | 0x09108883,0xa9103924,0x0189a122,0x80a89ba2,0x0194936b,0x21288302,0x8e0aa950,0xa8151915, 196 | 
0x08c91922,0x19d8a442,0x23c11021,0x9e21b1bc,0x3d8130a0,0xd9490801,0x80387999,0xca801302, 197 | 0x01a109c4,0x1100208c,0x8b2f209a,0x122b8a73,0x9809b822,0x2c89020e,0x0b810088,0x6aa22358, 198 | 0x09b8014b,0x122aa098,0xb1190c4a,0xaa202181,0x90011a11,0x93240081,0x02b0b010,0xc118114e, 199 | 0x8d8182a8,0xd8a998c2,0x29b1c5e8,0x2949b840,0x1b889930,0x289c0282,0x028ca099,0x2231c1d1, 200 | 0x89c192a1,0x839e4bbb,0xa9eb0227,0xb01bba02,0x8a8e2828,0x498a819d,0x1bba9291,0x12a21a29, 201 | 0x154a0001,0xa99a0202,0x11038969,0x9293a09a,0x818ab3b1,0x00099080,0x0c3d4208,0x59caacb8, 202 | 0x183080ae,0xd2208998,0xa191d033,0x84c810a9,0xd28d51fd,0x898a0c01,0x0a19b018,0x8a8b313c, 203 | 0x8b818c2a,0x9389c398,0xa1908b2b,0xb2112423,0x1918f123,0x3ab8c99d,0x901302b1,0x4ac88321, 204 | 0x8d360d92,0x8b085999,0x0ab292df,0x012f0d3b,0x9319a9b0,0x2a1f2894,0x8b8af938,0x890c9b20, 205 | 0x8a3c0395,0x18a24a06,0x01982888,0x8ab30818,0x4b802819,0x1910a001,0x18ec3c10,0x437b8201, 206 | 0x10091132,0x081421a3,0x1094a251,0x8b9292a2,0xb80440b8,0xd9928085,0x0c083b20,0xbcb8aa81, 207 | 0x82021139,0xd08c9a8c,0x89083d50,0x0f148279,0x3119a2b8,0x2b98038a,0x080293a0,0x200a881a, 208 | 0x8b1f1093,0x19850b23,0x8990092f,0x9d914089,0x219d10a8,0xb9288981,0x01115b08,0x0b9da305, 209 | 0x98808b48,0xa828c398,0x4a23a890,0x2a118889,0xbb8388b1,0xb0b19b00,0x82510109,0xc98012db, 210 | }; //first channel is topmost bit 211 | 212 | // Layer: L15 213 | // QuantType: 4bitsym 214 | #define L15_active 215 | #define L15_bitperweight 4 216 | #define L15_incoming_weights 48 217 | #define L15_outgoing_weights 10 218 | const uint32_t L15_weights[] = { 219 | 0x909dadc9,0x3a119cfa,0x8a9c5215,0xa8148109,0x8ac4bca2,0xb4dff1fd,0x22e2e92f,0x2ac3aff2, 220 | 0xc4c99cd3,0xaefc1a9c,0x3151a07e,0xeb83986c,0xaaa4b259,0xbb5bb9ba,0x1c22bbbb,0x18c80911, 221 | 0xc984c2fb,0x0d62c000,0xed39b98b,0xea99b24f,0xab0bb28c,0x846c2c99,0xc398c639,0xdc930899, 222 | 0x4a3a4901,0xca388cc2,0x22b4c0c0,0x2dbcc40b,0xe8892a0b,0x3bbc2cab,0x9f1ab402,0xbc9c2239, 223 | 0xec8b836d,0xba94ca8b,0x08fa2d84,0xccc08319,0x2ce02c03,0xbcb894bc,0xbb33ad82,0x5ad3a0ae, 224 | 0x4caa1bca,0x95cbc1af,0xf102809e,0x90a88d94,0x8ad94bcc,0x8c3aba84,0xc03b91b8,0xca3a3ec2, 225 | 0x9c1abda9,0x07b43299,0x192910ac,0x839a2e8a,0x80c88b90,0x4aaaf803,0xd3ac52e4,0x880a9b0b, 226 | 0x32deb088,0xd0bc24d8,0xb8d90c12,0x18892c04, 227 | }; //first channel is topmost bit 228 | 229 | #endif 230 | -------------------------------------------------------------------------------- /docs/model_cnn_overview.drawio: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /BitNetMCU_inference.c: 
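The packed weight arrays in the model header above store sixteen 2-bit or eight 4-bit weights per uint32_t, with the first weight in the topmost bits. For the 4bitsym type, each nibble holds a sign bit plus three magnitude bits encoding the odd values ±1, ±3, ..., ±15 — the same values that the shift-and-add arithmetic in processfclayer (next file) evaluates. A minimal decoding sketch; decode_4bitsym and unpack_4bitsym_word are illustrative helpers, not part of the repo:

    #include <stdint.h>

    // Decode one 4bitsym nibble (LSB-aligned) into its signed weight value.
    static int32_t decode_4bitsym(uint32_t nibble) {
        int32_t mag = 2 * (int32_t)(nibble & 7) + 1;  // odd magnitudes 1,3,...,15
        return (nibble & 8) ? -mag : mag;             // top bit of nibble = sign
    }

    // Unpack the eight weights of one packed word; the first weight sits in
    // the most significant nibble ("first channel is topmost bit").
    static void unpack_4bitsym_word(uint32_t word, int32_t out[8]) {
        for (int j = 0; j < 8; j++) {
            out[j] = decode_4bitsym((word >> (28 - 4 * j)) & 0xFu);
        }
    }

The 2bitsym layout is analogous: one sign bit and one magnitude bit per weight (values ±1 and ±3), sixteen weights per word.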
-------------------------------------------------------------------------------- 1 | /* 2 | BitNetMCU inference functions 3 | @cpldcpu April 2024 4 | 5 | Performs inference on fully connected layer on a very resource constrained MCU. 6 | 1,2,4 bit weights are supported. 7 | 8 | */ 9 | 10 | #include 11 | #include 12 | #include "BitNetMCU_inference.h" 13 | 14 | /** 15 | * @brief Applies a ReLU activation function to an array of integers and normalizes the result to 8-bit integers. 16 | * 17 | * @param input Pointer to the input array of 32-bit integers. 18 | * @param output Pointer to the output array of 8-bit integers. 19 | * @param n_input The number of elements in the input array. 20 | * @return The position of maximum value found in the input array before applying the ReLU activation. 21 | */ 22 | 23 | uint32_t ReLUNorm(int32_t *input, int8_t *output, uint32_t n_input) { 24 | int32_t max_val = -INT32_MAX; 25 | int32_t max_pos = 255; 26 | uint32_t scale; 27 | uint32_t shift; 28 | int32_t rounding; 29 | int32_t tmp; 30 | 31 | // Find the maximum value in the input array 32 | for (uint32_t i = 0; i < n_input; i++) { 33 | if (input[i] > max_val) { 34 | max_val = input[i]; 35 | max_pos = i; 36 | } 37 | } 38 | 39 | // Normalization 40 | // Dynamic shift according to max value in the input array 41 | scale=max_val>>7; // define max range, all bits above 7 will be shifted down 42 | shift=0; 43 | 44 | while (scale>0) { 45 | shift++; 46 | scale>>=1; 47 | } 48 | 49 | // impact of rounding is almost negligible (+0.03% in eval accuracy) 50 | // But rounding affects mismatch to python inference engine 51 | rounding = (1 << (shift))>>1; 52 | 53 | // Apply ReLU activation and normalize to 8-bit 54 | for (uint32_t i = 0; i < n_input; i++) { 55 | // Apply ReLU activation 56 | if (input[i] < 0) { 57 | output[i] = 0; 58 | } else { 59 | tmp=(input[i] + rounding) >> shift; 60 | 61 | // clipping needed to catch overflow from rounding 62 | if (tmp > 127) { 63 | output[i] = 127; 64 | } else { 65 | output[i] = tmp; 66 | } 67 | } 68 | // printf("%d,", output[i]); 69 | } 70 | // printf("---\n"); 71 | return max_pos; 72 | } 73 | 74 | /** 75 | * @brief Processes a fully connected layer in a neural network. 76 | * 77 | * This function processes a fully connected layer in a neural network by performing 78 | * the dot product of the input activations and weights, and stores the result in the output array. 79 | * 80 | * @param activations Pointer to the input activations of the layer. 81 | * @param weights Pointer to the weights of the layer. 82 | * @param bits_per_weight The number of bits per weight. 83 | * @param n_input The number of input neurons. 84 | * @param n_output The number of output neurons. 85 | * @param output Pointer to the output array where the result of the layer is stored. 86 | */ 87 | 88 | void processfclayer( int8_t *activations, const uint32_t *weights, int32_t bits_per_weight, uint32_t n_input, uint32_t n_output, int32_t *output) 89 | { 90 | const uint32_t *weightidx = weights; 91 | 92 | for (uint32_t i = 0; i < n_output; i++) { 93 | int8_t *activations_idx = activations; 94 | int32_t sum = 0; 95 | 96 | if (bits_per_weight == 1) { 97 | for (uint32_t k = 0; k < n_input; k+=32) { 98 | uint32_t weightChunk = *weightidx++; 99 | for (uint32_t j = 0; j < 32; j++) { 100 | int32_t in=*activations_idx++; 101 | sum += (weightChunk & 0x80000000) ? 
in : -in; // Note that sign is flipped for Binary quant (bit set equals positive) 102 | weightChunk <<= 1; 103 | } 104 | } 105 | } else if (bits_per_weight == 2 ) { 106 | for (uint32_t k = 0; k < n_input; k+=16) { 107 | uint32_t weightChunk = *weightidx++; 108 | for (uint32_t j = 0; j < 16; j++) { 109 | int32_t in=*activations_idx++; 110 | int32_t tmpsum = (weightChunk & 0x80000000) ? -in : in; // one complements sign (bit set equals negative) 111 | sum += tmpsum; // sign*in*1 112 | if (weightChunk & 0x40000000) sum += tmpsum<<1; // sign*in*2 113 | weightChunk <<= 2; 114 | } 115 | } 116 | // Multiplier-less inference for CH32V003 117 | // #if defined(__riscv) && !defined(__riscv_mul) 118 | #if defined(CH32V003) 119 | } else if (bits_per_weight == 4 ) { 120 | for (uint32_t k = 0; k < n_input; k+=8) { 121 | uint32_t weightChunk = *weightidx++; 122 | for (uint32_t j = 0; j < 8; j++) { 123 | int32_t in=*activations_idx++; 124 | if (in != 0) { // Skip zero activations to speed up inference in layers after first layer 125 | int32_t tmpsum = (weightChunk & 0x80000000) ? -in : in; // one complements sign (bit set equals negative) 126 | sum += tmpsum; // sign*in*1 127 | if (weightChunk & 0x10000000) sum += tmpsum<<1; // sign*in*2 128 | if (weightChunk & 0x20000000) sum += tmpsum<<2; // sign*in*4 129 | if (weightChunk & 0x40000000) sum += tmpsum<<3; // sign*in*8 130 | } 131 | weightChunk <<= 4; 132 | } 133 | } 134 | #else 135 | } else if (bits_per_weight == 4 ) { // 4bitsym 136 | for (uint32_t k = 0; k < n_input; k+=8) { 137 | uint32_t weightChunk = *weightidx++; 138 | for (uint32_t j = 0; j < 8; j++) { 139 | int32_t in=*activations_idx++; 140 | if (in != 0) { // Skip zero activations to speed up inference in layers after first layer 141 | int32_t tmpsum = (weightChunk & 0x80000000) ? -in : in; // one complements sign (bit set equals negative) 142 | sum += tmpsum; // sign*in*1 143 | sum += tmpsum * ((weightChunk>>(32-4-1))&0x0e); // sum += tmpsum * 2 144 | } 145 | weightChunk <<= 4; 146 | } 147 | } 148 | } else if (bits_per_weight == 8 + 4 ) { // 4 bit twos-complement 149 | for (uint32_t k = 0; k < n_input; k+=8) { 150 | int32_t weightChunk = *weightidx++; 151 | for (uint32_t j = 0; j < 8; j++) { 152 | int32_t in=*activations_idx++; 153 | int32_t weight = (weightChunk) >> (32-4); // extend sign, cut off lower bits 154 | sum += in*weight; 155 | weightChunk <<= 4; 156 | } 157 | } 158 | } else if (bits_per_weight == 8 + 8 ) { // 8 bit twos-complement 159 | for (uint32_t k = 0; k < n_input; k+=4) { 160 | int32_t weightChunk = *weightidx++; 161 | for (uint32_t j = 0; j < 4; j++) { 162 | int32_t in=*activations_idx++; 163 | int32_t weight = (weightChunk) >> (32-8); // extend sign, cut off lower bits 164 | sum += in*weight; 165 | weightChunk <<= 8; 166 | } 167 | } 168 | #endif 169 | } else if (bits_per_weight == 16 + 4 ) { // 4 bit shift 170 | for (uint32_t k = 0; k < n_input; k+=8) { 171 | uint32_t weightChunk = *weightidx++; 172 | for (uint32_t j = 0; j < 8; j++) { 173 | int32_t in=*activations_idx++; 174 | int32_t tmpsum; 175 | 176 | tmpsum = (weightChunk & 0x80000000) ? 
-in : in; // one complements sign (bit set equals negative) 177 | sum += tmpsum << ((weightChunk >> 28) & 7); // sign*in*2^log 178 | weightChunk <<= 4; 179 | } 180 | } 181 | } // else printf("Error: unsupported weight bit width %d\n", bits_per_weight); 182 | 183 | output[i] = sum; 184 | // printf("%d,", output[i]); 185 | } 186 | // printf("-X-\n"); 187 | } 188 | 189 | #ifndef MODEL_FCMNIST 190 | 191 | /** 192 | * @brief fused 3x3 conv2d and ReLU activation function 193 | * convo 194 | * This function processes a 3x3 convolutional layer in a neural network by performing 195 | * the dot product of the input activations and weights, and stores the result in the output array. 196 | * The function also applies a ReLU activation function to the result. 197 | * 198 | * To simplify the implementation, some assumptions are made: 199 | * - The kernel size is always 3x3, and the stride is always 1 and padding is always 0. 200 | * - Only square arrays (x=y) are supported. 201 | * - Always the full array is processed, no border handling. 202 | * - The input activations are stored in a 2D array with dimensions (xy_input, xy_input). 203 | * - The weights are stored in a 2D array with dimensions (3, 3). The weights are assumed to be 8-bit signed integers. 204 | * - The output is stored in a 2D array with dimensions (xy_input - 2, xy_input - 2). 205 | * 206 | * This function is intended to be used in a loop to process multiple channels in parallel. 207 | * Convolutions can be performed in place, i.e., the output array can be the same as the input activations array. 208 | * 209 | * @param activations Pointer to the input activations of the layer. 210 | * @param weights Pointer to the weights of the layer. 211 | * @param xy_input The number of input neurons. 212 | * @param n_shift The number of bits to shift the result of the convolution after summation, typically 8. 213 | * @param output Pointer to the output array where the result of the layer is stored. 214 | * @return Pointer to the end of the output array. 215 | */ 216 | 217 | int32_t* processconv33ReLU(int32_t *activations, const int8_t *weightsin, uint32_t xy_input, uint32_t n_shift , int32_t *output) { 218 | 219 | // Create SRAM copy of the weights for speed up 220 | int8_t weights[9]; 221 | 222 | for (uint32_t i = 0; i < 9; i++) { 223 | weights[i] = weightsin[i]; 224 | } 225 | 226 | for (uint32_t i = 0; i < xy_input - 2; i++) { 227 | int32_t *row = activations + i * xy_input; 228 | for (uint32_t j = 0; j < xy_input - 2; j++) { 229 | int32_t sum = 0; 230 | int32_t *in = row ++; 231 | 232 | // Unrolled convolution loop for 3x3 kernel 233 | sum += weights[0] * in[0] + weights[1] * in[1] + weights[2] * in[2]; 234 | in += xy_input; 235 | sum += weights[3] * in[0] + weights[4] * in[1] + weights[5] * in[2]; 236 | in += xy_input; 237 | sum += weights[6] * in[0] + weights[7] * in[1] + weights[8] * in[2]; 238 | 239 | // Apply shift and ReLU 240 | if (sum < 0) { 241 | sum = 0; // ReLU 242 | } else { 243 | 244 | // sum += (1 << n_shift) >> 1; // Add 1/2 of the shift value for rounding 245 | sum = sum >> n_shift; 246 | 247 | // if (sum > 127) { 248 | // sum = 127; // Clip to int8_t range. Important, otherwise the rounding can cause overflow! 249 | // } 250 | } 251 | *output++ = (int32_t)sum; 252 | } 253 | } 254 | 255 | return output; 256 | } 257 | 258 | /** 259 | * @brief maxpool2d 2x2 function 260 | * 261 | * This function performs a 2x2 max pooling operation on a 2D array of input activations. 
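 * As a shape example (the sizes the CNN demo happens to use): a 12x12 input
 * map reduces to a 6x6 output map. The in-place variant is safe because each
 * output element is written only after all four inputs of its 2x2 patch have
 * been read, and the output write position never overtakes the read position.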
262 | * The function divides the input activations into 2x2 non-overlapping regions and selects the maximum value in each region. 263 | * 264 | * To simplify the implementation, some assumptions are made: 265 | * - The input activations are stored in a 2D array with dimensions (xy_input, xy_input). 266 | * - The input activations are assumed to be 8-bit signed integers. 267 | * - The output is stored in a 2D array with dimensions (xy_input / 2, xy_input / 2). 268 | * - The stride of the max pooling operation is 2. 269 | * - Padding is not supported, so the input dimensions must be divisible by 2. 270 | * - Dilation is not supported. 271 | * - The output array can be the same as the input activations array. (in place operation) 272 | * 273 | * @param activations Pointer to the input activations of the layer. 274 | * @param xy_input The number of input neurons. 275 | * @param output Pointer to the output array where the result of the layer is stored. 276 | * @return Pointer to the end of the output array. 277 | */ 278 | 279 | int32_t *processmaxpool22(int32_t *activations, uint32_t xy_input, int32_t *output) { 280 | uint32_t xy_output = xy_input / 2; 281 | 282 | // Iterate over the output array dimensions 283 | for (uint32_t i = 0; i < xy_output; i++) { 284 | int32_t *row = activations + (2 * i) * xy_input; 285 | for (uint32_t j = 0; j < xy_output; j++) { 286 | 287 | // Find the maximum value in the corresponding 2x2 patch in the input activations 288 | int32_t max_val; 289 | max_val = row[0]; 290 | max_val = max_val > row[xy_input] ? max_val : row[xy_input]; 291 | row++; 292 | max_val = max_val > row[0] ? max_val : row[0]; 293 | max_val = max_val > row[xy_input] ? max_val : row[xy_input]; 294 | row++; 295 | 296 | // Store the maximum value in the output array 297 | *output++ = max_val; 298 | } 299 | } 300 | return output; 301 | } 302 | 303 | #endif -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn, torch.optim as optim 2 | from torchvision import datasets, transforms 3 | from torch.utils.data import DataLoader 4 | from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, CosineAnnealingWarmRestarts 5 | import numpy as np 6 | from torch.utils.tensorboard import SummaryWriter 7 | from torch.utils.data import ConcatDataset 8 | from datetime import datetime 9 | # from models import FCMNIST, CNNMNIST 10 | from BitNetMCU import BitLinear, BitConv2d, Activation 11 | import time 12 | import random 13 | import argparse 14 | import yaml 15 | from torchsummary import summary 16 | import importlib 17 | from models import MaskingLayer 18 | 19 | #---------------------------------------------- 20 | # BitNetMCU training 21 | #---------------------------------------------- 22 | 23 | def create_run_name(hyperparameters): 24 | runname = hyperparameters["runtag"] + '_' + hyperparameters["model"] + ('_Aug' if hyperparameters["augmentation"] else '') + '_BitMnist_' + hyperparameters["QuantType"] + "_width" + str(hyperparameters["network_width1"]) + "_" + str(hyperparameters["network_width2"]) + "_" + str(hyperparameters["network_width3"]) + "_epochs" + str(hyperparameters["num_epochs"]) 25 | hyperparameters["runname"] = runname 26 | return runname 27 | 28 | def load_model(model_name, params): 29 | try: 30 | module = importlib.import_module('models') 31 | model_class = getattr(module, model_name) 32 | kwargs = dict( 33 | 
network_width1=params["network_width1"], 34 | network_width2=params["network_width2"], 35 | network_width3=params["network_width3"], 36 | QuantType=params["QuantType"], 37 | NormType=params["NormType"], 38 | WScale=params["WScale"] 39 | ) 40 | if 'num_classes' in params: 41 | kwargs['num_classes'] = params['num_classes'] 42 | return model_class(**kwargs) 43 | except AttributeError: 44 | raise ValueError(f"Model {model_name} not found in models.py") 45 | 46 | def log_positive_activations(model, writer, epoch, all_test_images, batch_size): 47 | total_activations = 0 48 | positive_activations = 0 49 | 50 | def hook_fn(module, input, output): 51 | nonlocal total_activations, positive_activations 52 | if isinstance(module, nn.ReLU) or isinstance(module, Activation): 53 | total_activations += output.numel() 54 | positive_activations += (output > 0).sum().item() 55 | 56 | hooks = [] 57 | for layer in model.modules(): 58 | if isinstance(layer, nn.ReLU) or isinstance(layer, Activation): 59 | hooks.append(layer.register_forward_hook(hook_fn)) 60 | 61 | # Run a forward pass to trigger hooks 62 | with torch.no_grad(): 63 | for i in range(len(all_test_images) // batch_size): 64 | images = all_test_images[i * batch_size:(i + 1) * batch_size] 65 | model(images) 66 | 67 | for hook in hooks: 68 | hook.remove() 69 | 70 | fraction_positive = positive_activations / total_activations 71 | writer.add_scalar('Activations/positive_fraction', fraction_positive, epoch+1) 72 | 73 | return fraction_positive 74 | 75 | 76 | # Function to add L1 regularization on the mask 77 | def add_mask_regularization(model, lambda_l1): 78 | mask_layer = next((layer for layer in model.modules() if isinstance(layer, MaskingLayer)), None) 79 | 80 | if mask_layer is None: 81 | return 0 82 | 83 | l1_reg = lambda_l1 * torch.norm(mask_layer.mask, 1) 84 | return l1_reg 85 | 86 | 87 | def train_model(model, device, hyperparameters, train_data, test_data): 88 | num_epochs = hyperparameters["num_epochs"] 89 | learning_rate = hyperparameters["learning_rate"] 90 | halve_lr_epoch = hyperparameters.get("halve_lr_epoch", -1) 91 | runname = create_run_name(hyperparameters) 92 | 93 | # define dataloaders 94 | 95 | batch_size = hyperparameters["batch_size"] # Define your batch size 96 | 97 | # ON-the-fly augmentation requires using the (slow) dataloader. 
Without augmentation, we can load the entire dataset into GPU for speedup 98 | if hyperparameters["augmentation"]: 99 | train_loader = DataLoader( 100 | train_data, batch_size=batch_size, shuffle=True, 101 | num_workers=4, pin_memory=True) 102 | else: 103 | # load entire dataset into GPU for 5x speedup 104 | train_loader = DataLoader(train_data, batch_size=len(train_data), shuffle=False) # shuffling will be done separately 105 | entire_dataset = next(iter(train_loader)) 106 | all_train_images, all_train_labels = entire_dataset[0].to(device), entire_dataset[1].to(device) 107 | 108 | # Test dataset is always in GPU 109 | test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False) 110 | entire_dataset = next(iter(test_loader)) 111 | all_test_images, all_test_labels = entire_dataset[0].to(device), entire_dataset[1].to(device) 112 | 113 | optimizer = optim.Adam(model.parameters(), lr=learning_rate) 114 | 115 | if hyperparameters["scheduler"] == "StepLR": 116 | scheduler = StepLR(optimizer, step_size=hyperparameters["step_size"], gamma=hyperparameters["lr_decay"]) 117 | elif hyperparameters["scheduler"] == "Cosine": 118 | scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0) 119 | elif hyperparameters["scheduler"] == "CosineWarmRestarts": 120 | scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=hyperparameters["T_0"], T_mult=hyperparameters["T_mult"], eta_min=0) 121 | else: 122 | raise ValueError("Invalid scheduler") 123 | 124 | criterion = nn.CrossEntropyLoss() 125 | 126 | # tensorboard writer 127 | now_str = datetime.now().strftime("%Y%m%d-%H%M%S") 128 | writer = SummaryWriter(log_dir=f'runs/{runname}-{now_str}') 129 | 130 | train_loss=[] 131 | test_loss = [] 132 | 133 | # Train the CNN 134 | for epoch in range(num_epochs): 135 | correct = 0 136 | train_loss=[] 137 | start_time = time.time() 138 | 139 | if hyperparameters["augmentation"]: 140 | for i, (images, labels) in enumerate(train_loader): 141 | images, labels = images.to(device), labels.to(device) 142 | optimizer.zero_grad() 143 | outputs = model(images) 144 | _, predicted = torch.max(outputs.data, 1) 145 | loss = criterion(outputs, labels) 146 | if epoch < hyperparameters['prune_epoch']: 147 | loss += add_mask_regularization(model, hyperparameters["lambda_l1"]) 148 | loss.backward() 149 | optimizer.step() 150 | train_loss.append(loss.item()) 151 | correct += (predicted == labels).sum().item() 152 | else: 153 | # Shuffle images (important!) 
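            # With augmentation disabled, the whole training set lives in one
            # GPU tensor and the DataLoader's shuffling is bypassed, so a fresh
            # random permutation must be drawn here every epoch; otherwise every
            # epoch would iterate over identical batches.
            # A sketch of an equivalent (and typically faster) variant using a
            # permutation tensor instead of Python-side index lists --
            # illustrative only, not what this script does:
            #   perm = torch.randperm(len(all_train_images), device=device)
            #   for i in range(len(perm) // batch_size):
            #       idx = perm[i * batch_size:(i + 1) * batch_size]
            #       images, labels = all_train_images[idx], all_train_labels[idx]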
154 | indices = list(range(len(all_train_images))) 155 | random.shuffle(indices) 156 | 157 | for i in range(len(indices) // batch_size): 158 | batch_indices = indices[i * batch_size:(i + 1) * batch_size] 159 | images = torch.stack([all_train_images[i] for i in batch_indices]) 160 | labels = torch.stack([all_train_labels[i] for i in batch_indices]) 161 | optimizer.zero_grad() 162 | outputs = model(images) 163 | _, predicted = torch.max(outputs.data, 1) 164 | loss = criterion(outputs, labels) 165 | if epoch < hyperparameters['prune_epoch']: 166 | loss += add_mask_regularization(model, hyperparameters["lambda_l1"]) 167 | loss.backward() 168 | optimizer.step() 169 | train_loss.append(loss.item()) 170 | correct += (predicted == labels).sum().item() 171 | 172 | scheduler.step() 173 | 174 | if epoch + 1 == halve_lr_epoch: 175 | for param_group in optimizer.param_groups: 176 | param_group['lr'] *= 0.5 177 | print(f"Learning rate halved at epoch {epoch + 1}") 178 | 179 | 180 | trainaccuracy = correct / len(train_loader.dataset) * 100 181 | 182 | correct = 0 183 | total = 0 184 | test_loss = [] 185 | with torch.no_grad(): 186 | for i in range(len(all_test_images) // batch_size): 187 | images = all_test_images[i * batch_size:(i + 1) * batch_size] 188 | labels = all_test_labels[i * batch_size:(i + 1) * batch_size] 189 | 190 | outputs = model(images) 191 | _, predicted = torch.max(outputs.data, 1) 192 | loss = criterion(outputs, labels) 193 | test_loss.append(loss.item()) 194 | total += labels.size(0) 195 | correct += (predicted == labels).sum().item() 196 | 197 | # Log positive activations 198 | activity=log_positive_activations(model, writer, epoch, all_test_images, batch_size) 199 | 200 | end_time = time.time() 201 | epoch_time = end_time - start_time 202 | 203 | testaccuracy = correct / total * 100 204 | 205 | print(f'Epoch [{epoch+1}/{num_epochs}], LTrain:{np.mean(train_loss):.6f} ATrain: {trainaccuracy:.2f}% LTest:{np.mean(test_loss):.6f} ATest: {correct / total * 100:.2f}% Time[s]: {epoch_time:.2f} Act: {activity*100:.1f}% w_clip/entropy[bits]: ', end='') 206 | 207 | # update clipping scalars once per epoch 208 | totalbits = 0 209 | for i, layer in enumerate(model.modules()): 210 | if isinstance(layer, BitLinear) or isinstance(layer, BitConv2d): 211 | 212 | # update clipping scalar 213 | if epoch < hyperparameters['maxw_update_until_epoch']: 214 | layer.update_clipping_scalar(layer.weight, hyperparameters['maxw_algo'], hyperparameters['maxw_quantscale']) 215 | 216 | # calculate entropy of weights 217 | w_quant, _, _ = layer.weight_quant(layer.weight) 218 | _, counts = np.unique(w_quant.cpu().detach().numpy(), return_counts=True) 219 | probabilities = counts / np.sum(counts) 220 | entropy = -np.sum(probabilities * np.log2(probabilities)) 221 | 222 | print(f'{layer.s.item():.3f}/{entropy:.2f}', end=' ') 223 | 224 | totalbits += layer.weight.numel() * layer.bpw 225 | 226 | print() 227 | 228 | if epoch + 1 == hyperparameters ["prune_epoch"]: 229 | for m in model.modules(): 230 | if isinstance(m, MaskingLayer): 231 | pruned_channels, remaining_channels = m.prune_channels(prune_number=hyperparameters['prune_groupstoprune'], groups=hyperparameters['prune_totalgroups']) 232 | 233 | writer.add_scalar('Loss/train', np.mean(train_loss), epoch+1) 234 | writer.add_scalar('Accuracy/train', trainaccuracy, epoch+1) 235 | writer.add_scalar('Loss/test', np.mean(test_loss), epoch+1) 236 | writer.add_scalar('Accuracy/test', testaccuracy, epoch+1) 237 | writer.add_scalar('learning_rate', 
optimizer.param_groups[0]['lr'], epoch+1) 238 | writer.flush() 239 | 240 | numofweights = sum(p.numel() for p in model.parameters() if p.requires_grad) 241 | # totalbits = numofweights * hyperparameters['BPW'] 242 | 243 | print(f'TotalBits: {totalbits} TotalBytes: {totalbits/8.0} ') 244 | 245 | writer.add_hparams(hyperparameters, {'Parameters': numofweights, 'Totalbits': totalbits, 'Accuracy/train': trainaccuracy, 'Accuracy/test': testaccuracy, 'Loss/train': np.mean(train_loss), 'Loss/test': np.mean(test_loss)}) 246 | writer.close() 247 | 248 | if __name__ == '__main__': 249 | parser = argparse.ArgumentParser(description='Training script') 250 | parser.add_argument('--params', type=str, help='Name of the parameter file', default='trainingparameters.yaml') 251 | 252 | args = parser.parse_args() 253 | 254 | if args.params: 255 | paramname = args.params 256 | else: 257 | paramname = 'trainingparameters.yaml' 258 | 259 | print(f'Load parameters from file: {paramname}') 260 | with open(paramname) as f: 261 | hyperparameters = yaml.safe_load(f) 262 | 263 | runname= create_run_name(hyperparameters) 264 | print(runname) 265 | 266 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 267 | 268 | # Dataset selection (MNIST default, EMNIST optional) 269 | dataset_name = hyperparameters.get("dataset", "MNIST").upper() 270 | 271 | if dataset_name == "MNIST": 272 | num_classes = 10 273 | mean, std = (0.1307,), (0.3081,) 274 | base_dataset_train = datasets.MNIST 275 | base_dataset_test = datasets.MNIST 276 | dataset_kwargs = {"train": True} 277 | dataset_kwargs_test = {"train": False} 278 | elif dataset_name.startswith("EMNIST"): 279 | # Expected format: EMNIST or EMNIST_BALANCED, EMNIST_BYCLASS etc. 280 | # Torchvision subsets: 'byclass'(62), 'bymerge'(47), 'balanced'(47), 'letters'(37), 'digits'(10), 'mnist'(10) 281 | split = dataset_name.split('_')[1].lower() if '_' in dataset_name else 'balanced' 282 | # Map common names 283 | split_alias = { 'BALANCED':'balanced', 'BYCLASS':'byclass', 'BYMERGE':'bymerge', 'LETTERS':'letters', 'DIGITS':'digits', 'MNIST':'mnist'} 284 | split = split_alias.get(split.upper(), split) 285 | # class counts per split 286 | split_classes = { 'byclass':62, 'bymerge':47, 'balanced':47, 'letters':37, 'digits':10, 'mnist':10 } 287 | num_classes = split_classes.get(split, 47) 288 | # EMNIST uses same normalization as MNIST typically 289 | mean, std = (0.1307,), (0.3081,) 290 | from torchvision.datasets import EMNIST 291 | base_dataset_train = EMNIST 292 | base_dataset_test = EMNIST 293 | dataset_kwargs = {"split": split, "train": True} 294 | dataset_kwargs_test = {"split": split, "train": False} 295 | else: 296 | raise ValueError(f"Unsupported dataset: {dataset_name}") 297 | 298 | transform = transforms.Compose([ 299 | transforms.Resize((16, 16)), 300 | transforms.ToTensor(), 301 | transforms.Normalize(mean, std) 302 | ]) 303 | 304 | train_data = base_dataset_train(root='data', transform=transform, download=True, **dataset_kwargs) 305 | test_data = base_dataset_test(root='data', transform=transform, download=True, **dataset_kwargs_test) 306 | 307 | if hyperparameters["augmentation"]: 308 | # Data augmentation for training data 309 | augmented_transform = transforms.Compose([ 310 | transforms.RandomRotation(degrees=hyperparameters["rotation1"]), 311 | transforms.RandomAffine(degrees=hyperparameters["rotation2"], translate=(0.1, 0.1), scale=(0.9, 1.1)), 312 | transforms.RandomApply([ 313 | transforms.ElasticTransform(alpha=40.0, sigma=4.0) 314 | ], 

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Training script')
    parser.add_argument('--params', type=str, help='Name of the parameter file', default='trainingparameters.yaml')

    args = parser.parse_args()

    # argparse supplies the default, so no fallback branch is needed
    paramname = args.params

    print(f'Load parameters from file: {paramname}')
    with open(paramname) as f:
        hyperparameters = yaml.safe_load(f)

    runname = create_run_name(hyperparameters)
    print(runname)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset selection (MNIST default, EMNIST optional)
    dataset_name = hyperparameters.get("dataset", "MNIST").upper()

    if dataset_name == "MNIST":
        num_classes = 10
        mean, std = (0.1307,), (0.3081,)
        base_dataset_train = datasets.MNIST
        base_dataset_test = datasets.MNIST
        dataset_kwargs = {"train": True}
        dataset_kwargs_test = {"train": False}
    elif dataset_name.startswith("EMNIST"):
        # Expected format: EMNIST, or EMNIST_BALANCED, EMNIST_BYCLASS, etc.
        # Torchvision split class counts: 'byclass' 62, 'bymerge' 47, 'balanced' 47,
        # 'letters' 26 (labels run 1..26), 'digits' 10, 'mnist' 10.
        split = dataset_name.split('_')[1].lower() if '_' in dataset_name else 'balanced'
        # Map common names
        split_alias = {'BALANCED': 'balanced', 'BYCLASS': 'byclass', 'BYMERGE': 'bymerge', 'LETTERS': 'letters', 'DIGITS': 'digits', 'MNIST': 'mnist'}
        split = split_alias.get(split.upper(), split)
        # Output-layer size per split (note: 'letters' is oversized relative to
        # its 26 classes; the torchvision labels for it are 1-indexed)
        split_classes = {'byclass': 62, 'bymerge': 47, 'balanced': 47, 'letters': 37, 'digits': 10, 'mnist': 10}
        num_classes = split_classes.get(split, 47)
        # EMNIST typically uses the same normalization as MNIST
        mean, std = (0.1307,), (0.3081,)
        from torchvision.datasets import EMNIST
        base_dataset_train = EMNIST
        base_dataset_test = EMNIST
        dataset_kwargs = {"split": split, "train": True}
        dataset_kwargs_test = {"split": split, "train": False}
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    transform = transforms.Compose([
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_data = base_dataset_train(root='data', transform=transform, download=True, **dataset_kwargs)
    test_data = base_dataset_test(root='data', transform=transform, download=True, **dataset_kwargs_test)

    if hyperparameters["augmentation"]:
        # Data augmentation for training data
        augmented_transform = transforms.Compose([
            transforms.RandomRotation(degrees=hyperparameters["rotation1"]),
            transforms.RandomAffine(degrees=hyperparameters["rotation2"], translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomApply([
                transforms.ElasticTransform(alpha=40.0, sigma=4.0)
            ], p=hyperparameters["elastictransformprobability"]),
            transforms.Resize((16, 16)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        augmented_train_data = base_dataset_train(root='data', transform=augmented_transform, download=True, **dataset_kwargs)
        train_data = ConcatDataset([train_data, augmented_train_data])

    # Pass num_classes to the model; if the model class supports a num_classes
    # argument it will be used, otherwise it is ignored.
    hyperparameters['num_classes'] = num_classes
    model = load_model(hyperparameters["model"], hyperparameters)
    model = model.to(device)

    summary(model, input_size=(1, 16, 16))  # input is (1, 16, 16) after the Resize above

    print('training...')
    train_model(model, device, hyperparameters, train_data, test_data)

    print('saving model...')
    torch.save(model.state_dict(), f'modeldata/{runname}.pth')
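
# Typical invocation (values below are illustrative -- see trainingparameters.yaml
# for the actual defaults):
#
#   python training.py --params trainingparameters.yaml
#
# where the YAML file supplies the keys read above, e.g.:
#
#   dataset: MNIST              # or EMNIST_LETTERS, EMNIST_BALANCED, ...
#   model: FCMNIST
#   augmentation: true
#   rotation1: 10
#   rotation2: 10
#   elastictransformprobability: 0.5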
--------------------------------------------------------------------------------
/BitNetMCU_MNIST_test_data.h:
--------------------------------------------------------------------------------
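/*
 * Format note (inferred from the data and from training.py, which resizes all
 * inputs to 16x16): each input_data_N below is one 16x16 test image, stored
 * row-major as 256 signed 8-bit values -- normalized pixels quantized to int8,
 * so the constant background value (e.g. 0xEC = -20) is the empty canvas --
 * and label_N holds the ground-truth digit for that image.
 */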
int8_t input_data_0[256] = {
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF6, 0xFC, 0x02, 0xF8, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEE, 0x01, 0x2B, 0x54, 0x59, 0x61, 0x57, 0x0D, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0x03, 0x5C, 0x6C, 0x4E, 0x1C, 0x1B, 0x6F, 0x49, 0xEE, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0x01, 0x52, 0x16, 0xF5, 0xEC, 0x00, 0x67, 0x4F, 0xEF, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEE, 0xFA, 0xEF, 0xEC, 0xF2, 0x38, 0x7C, 0x68, 0x2F, 0xFD, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEF, 0x34, 0x6F, 0x43, 0x29, 0x68, 0x33, 0xEF, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xFF, 0x5F, 0x43, 0xFB, 0xEE, 0x40, 0x4E, 0xF6, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF0, 0x06, 0xF5, 0xEC, 0xED, 0x3E, 0x49, 0xF4, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x06, 0x70, 0x2C, 0xED, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF2, 0x13, 0x0A, 0xF5, 0x32, 0x61, 0x00, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x10, 0x71, 0x5E, 0x40, 0x69, 0x24, 0xEE, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x00, 0x6B, 0x7F, 0x71, 0x2C, 0xF0, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xED, 0x09, 0x16, 0x08, 0xEF, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
};
uint8_t label_0 = 3;
int8_t input_data_1[256] = {
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF0, 0x0C, 0x27, 0x2A, 0x20, 0xFF, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0x0F, 0x6D, 0x72, 0x6B, 0x6B, 0x54, 0xFF, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0x10, 0x5B, 0x22, 0x06, 0x15, 0x68, 0x22, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF1, 0xFC, 0xEE, 0xEE, 0x1F, 0x70, 0x1E, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xEE, 0x0A, 0x62, 0x4C, 0xF8, 0xF2, 0xFC, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xF1, 0x0A, 0x56, 0x62, 0x06, 0xED, 0x13, 0x41, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xEF, 0x06, 0x31, 0x62, 0x5F, 0x16, 0xEE, 0xF5, 0x41, 0x51, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xF9, 0x27, 0x67, 0x72, 0x47, 0x0D, 0xEE, 0xF2, 0x23, 0x6B, 0x26, 0xED, 0xED,
    0xED, 0xED, 0x01, 0x51, 0x7B, 0x7F, 0x6E, 0x2A, 0x19, 0x1E, 0x47, 0x73, 0x49, 0xF7, 0xED, 0xED,
    0xED, 0xED, 0x38, 0x7C, 0x72, 0x3C, 0x4F, 0x6A, 0x70, 0x70, 0x6F, 0x4A, 0x03, 0xED, 0xED, 0xED,
    0xED, 0xED, 0x4A, 0x6D, 0x2F, 0xF5, 0xF8, 0x08, 0x19, 0x1B, 0x12, 0xF8, 0xEE, 0xED, 0xED, 0xED,
    0xED, 0xED, 0x05, 0x0D, 0xF3, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
};
uint8_t label_1 = 2;
int8_t input_data_2[256] = {
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF4, 0x0C, 0x48, 0x40, 0xF9, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xED, 0x06, 0x41, 0x6E, 0x7F, 0x50, 0xFF, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEE, 0x06, 0x5C, 0x7A, 0x7B, 0x7A, 0x47, 0x09, 0xF0, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xED, 0x05, 0x5A, 0x73, 0x3E, 0x21, 0x28, 0x52, 0x66, 0x37, 0xF5, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0x03, 0x4D, 0x71, 0x20, 0xF0, 0xEC, 0xEC, 0xF8, 0x4C, 0x78, 0x17, 0xEC, 0xEC,
    0xEC, 0xEC, 0xFB, 0x52, 0x74, 0x2E, 0xF0, 0xEC, 0xEC, 0xEC, 0xEC, 0x21, 0x77, 0x49, 0xEC, 0xEC,
    0xEC, 0xEC, 0x11, 0x77, 0x4C, 0xF8, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x1C, 0x77, 0x57, 0xEC, 0xEC,
    0xEC, 0xEC, 0x3A, 0x70, 0x13, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x2A, 0x79, 0x47, 0xEC, 0xEC,
    0xEC, 0xEC, 0x36, 0x6A, 0x05, 0xEC, 0xEC, 0xEC, 0xEC, 0xEE, 0x0F, 0x5E, 0x71, 0x13, 0xEC, 0xEC,
    0xEC, 0xEC, 0x10, 0x68, 0x44, 0x03, 0xFC, 0xFF, 0x0F, 0x36, 0x69, 0x6F, 0x2D, 0xF5, 0xEC, 0xEC,
    0xEC, 0xEC, 0xF0, 0x25, 0x62, 0x66, 0x62, 0x6A, 0x6F, 0x76, 0x58, 0x21, 0xF3, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xF3, 0x03, 0x1A, 0x47, 0x56, 0x3C, 0x1C, 0xF9, 0xED, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
};
uint8_t label_2 = 0;
int8_t input_data_3[256] = {
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEE, 0x03, 0x34, 0x3F, 0x2C, 0x09, 0xEF, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xFE, 0x58, 0x72, 0x5B, 0x50, 0x46, 0x25, 0xF6, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0x1C, 0x6D, 0x20, 0xF9, 0xF6, 0xFB, 0x2C, 0x06, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xF0, 0x3A, 0x51, 0xEF, 0xEC, 0xF2, 0x1C, 0x6A, 0x1D, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEE, 0x28, 0x60, 0x03, 0xFD, 0x3A, 0x77, 0x70, 0x05, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xFC, 0x46, 0x58, 0x53, 0x69, 0x7F, 0x4F, 0xF1, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF6, 0x13, 0x16, 0x1D, 0x73, 0x1D, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x22, 0x6E, 0x06, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF5, 0x4C, 0x64, 0xF6, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xFF, 0x6C, 0x53, 0xEF, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x1A, 0x77, 0x33, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0x0B, 0x44, 0x11, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xED, 0xF0, 0xED, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
};
uint8_t label_3 = 9;
int8_t input_data_4[256] = {
    0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEC, 0x05, 0x33, 0x39, 0x0D, 0xEC, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xEC, 0x0E, 0x69, 0x74, 0x77, 0x6E, 0x18, 0xEE, 0xEB, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xF8, 0x5D, 0x6A, 0x19, 0x22, 0x7B, 0x6C, 0x0A, 0xEC, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEC, 0x1D, 0x7F, 0x34, 0xEE, 0xF3, 0x3C, 0x7E, 0x54, 0xFD, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xF3, 0x47, 0x79, 0x09, 0xEB, 0xEB, 0xEF, 0x2E, 0x7E, 0x3B, 0xED, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xFC, 0x62, 0x64, 0xF4, 0xEB, 0xEB, 0xEB, 0xFF, 0x67, 0x68, 0xFD, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0x0F, 0x77, 0x3F, 0xEC, 0xEB, 0xEB, 0xEB, 0xEE, 0x43, 0x76, 0x06, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0x0F, 0x77, 0x45, 0xEC, 0xEB, 0xEB, 0xEB, 0xEB, 0x38, 0x7B, 0x15, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xFA, 0x5B, 0x68, 0xF8, 0xEB, 0xEB, 0xEB, 0xED, 0x3E, 0x6D, 0x04, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEF, 0x2C, 0x7A, 0x3B, 0xFD, 0xF4, 0xF5, 0x1B, 0x6C, 0x3D, 0xEF, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xF6, 0x3D, 0x72, 0x60, 0x4C, 0x51, 0x71, 0x68, 0x05, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xF0, 0x11, 0x41, 0x55, 0x59, 0x37, 0x08, 0xEE, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB,
    0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB, 0xEB,
};
uint8_t label_4 = 0;
int8_t input_data_5[256] = {
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xF0, 0xF1, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0x2B, 0x37, 0xF4, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xEF, 0x4A, 0x5F, 0xFC, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x60, 0xFC, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x60, 0xFC, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x60, 0x07, 0x00, 0xEE, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x67, 0x49, 0x67, 0x24, 0xF2, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x79, 0x77, 0x7F, 0x6E, 0x18, 0xEE, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x7F, 0x7E, 0x69, 0x6C, 0x54, 0xFB, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x5B, 0x7F, 0x7D, 0x52, 0x6A, 0x6E, 0x06, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xEE, 0x3E, 0x7D, 0x7F, 0x7E, 0x7E, 0x5B, 0xFF, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0x0A, 0x66, 0x7E, 0x7A, 0x45, 0x0A, 0xEE, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xF0, 0x13, 0x3D, 0x32, 0xFB, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
};
uint8_t label_5 = 6;
int8_t input_data_6[256] = {
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF5, 0x30, 0x41, 0x29, 0xF6, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEF, 0x29, 0x76, 0x50, 0x47, 0x2E, 0xF1, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xFC, 0x63, 0x50, 0xFD, 0xFD, 0x57, 0x29, 0xF2, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xED, 0x15, 0x6D, 0x13, 0xED, 0xFE, 0x67, 0x76, 0x17, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xF2, 0x3C, 0x6A, 0x02, 0xEE, 0x0A, 0x70, 0x7E, 0x2C, 0xEE, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEF, 0x1E, 0x6D, 0x3E, 0x2A, 0x3D, 0x75, 0x6E, 0x13, 0xED, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xF0, 0x19, 0x52, 0x62, 0x6C, 0x7F, 0x48, 0xF2, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF2, 0xFA, 0x0F, 0x6E, 0x43, 0xF0, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF3, 0x5E, 0x43, 0xF0, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF2, 0x58, 0x49, 0xF0, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xF0, 0x41, 0x60, 0xF2, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xED, 0x17, 0x3E, 0xF1, 0xEC, 0xEC, 0xEC, 0xEC,
    0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC, 0xEE, 0xF0, 0xEC, 0xEC, 0xEC, 0xEC, 0xEC,
};
uint8_t label_6 = 9;
int8_t input_data_7[256] = {
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xFB, 0x2A, 0x50, 0x49, 0x2E, 0x00, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0x09, 0x60, 0x5D, 0x3C, 0x3E, 0x4F, 0x0D, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xF0, 0x0F, 0x02, 0xF2, 0x08, 0x6A, 0x4A, 0xF1, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xF2, 0x37, 0x7D, 0x55, 0xF2, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xFC, 0x5B, 0x7F, 0x43, 0xF0, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0x0E, 0x6F, 0x75, 0x1B, 0xF5, 0xF6, 0xF1, 0xEE, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xF4, 0x3A, 0x7B, 0x7D, 0x6A, 0x47, 0x4B, 0x31, 0x19, 0xF2,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xFF, 0x4A, 0x76, 0x7E, 0x62, 0x45, 0x3A, 0x29, 0x23, 0x15, 0xF5,
    0xED, 0xED, 0xED, 0xED, 0xF5, 0x40, 0x7D, 0x7A, 0x53, 0x02, 0xEE, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xEE, 0x18, 0x73, 0x7D, 0x3E, 0x06, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xF2, 0x35, 0x7A, 0x64, 0x04, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xF2, 0x2C, 0x53, 0x17, 0xEE, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
};
uint8_t label_7 = 2;
int8_t input_data_8[256] = {
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xEE, 0xF3, 0x02, 0x17, 0x33, 0x24, 0xF1, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xEE, 0xF0, 0xF7, 0x1E, 0x3E, 0x62, 0x73, 0x7B, 0x64, 0x03, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xF3, 0x16, 0x3F, 0x58, 0x76, 0x79, 0x6B, 0x5A, 0x72, 0x7A, 0x21, 0xED, 0xED,
    0xED, 0xED, 0xF6, 0x3B, 0x72, 0x7F, 0x78, 0x48, 0x19, 0x02, 0x10, 0x6A, 0x75, 0x22, 0xED, 0xED,
    0xED, 0xED, 0x1A, 0x76, 0x7A, 0x5E, 0x27, 0xF8, 0xED, 0xF3, 0x48, 0x7C, 0x4E, 0xF7, 0xED, 0xED,
    0xED, 0xED, 0x27, 0x73, 0x47, 0x02, 0xF0, 0xED, 0xEF, 0x1D, 0x73, 0x74, 0x20, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xF9, 0x1A, 0xFF, 0xED, 0xED, 0xEF, 0x14, 0x66, 0x7E, 0x4A, 0xFB, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xF0, 0x19, 0x66, 0x7E, 0x5E, 0x03, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xF4, 0x25, 0x6A, 0x7E, 0x6E, 0x17, 0xEE, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xEE, 0x1D, 0x74, 0x7E, 0x75, 0x2D, 0xF1, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xF0, 0x39, 0x7F, 0x6B, 0x2C, 0xF5, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xEE, 0x14, 0x4B, 0x12, 0xEF, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
    0xED, 0xED, 0xED, 0xED, 0xED, 0xEF, 0xF1, 0xEE, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED, 0xED,
};
uint8_t label_8 = 7;
int8_t input_data_9[256] = {
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE9, 0xF6, 0xF9, 0xE9, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xEB, 0xF3, 0x31, 0x6D, 0x77, 0x18, 0xE8, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xEF, 0x4C, 0x69, 0x6F, 0x37, 0x44, 0x64, 0xEC, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xFA, 0x76, 0x3C, 0xF9, 0xE9, 0x3A, 0x55, 0xEB, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xEE, 0x13, 0xEC, 0xE6, 0xF4, 0x6C, 0x32, 0xE8, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE7, 0x22, 0x7C, 0x00, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xF3, 0x64, 0x4B, 0xE8, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0x0F, 0x7E, 0x0C, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE9, 0x4D, 0x61, 0xF1, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xFC, 0x74, 0x2B, 0xE7, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0x0D, 0x7F, 0x04, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0x3F, 0x63, 0xF3, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
    0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0x0B, 0x12, 0xE9, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6, 0xE6,
};
uint8_t label_9 = 7;
--------------------------------------------------------------------------------