├── README.md ├── Resnet-Eager-Mode-Dynamic-Quant ├── README.md ├── __pycache__ │ ├── evaluate.cpython-310.pyc │ ├── evaluate.cpython-311.pyc │ ├── ipdb_hook.cpython-310.pyc │ └── qconfigs.cpython-311.pyc ├── dog.jpg ├── evaluate.py ├── imagenet_classes.txt ├── ipdb_hook.py ├── model │ ├── __pycache__ │ │ ├── resnet.cpython-310.pyc │ │ └── resnet.cpython-311.pyc │ └── resnet.py └── quant_dynamic_resnet.py ├── Resnet-Eager-Mode-Quant ├── README.md ├── __pycache__ │ ├── evaluate.cpython-310.pyc │ ├── evaluate.cpython-311.pyc │ └── qconfigs.cpython-311.pyc ├── dog.jpg ├── evaluate.py ├── imagenet_classes.txt ├── model │ ├── __pycache__ │ │ ├── resnet.cpython-310.pyc │ │ └── resnet.cpython-311.pyc │ └── resnet.py └── quant_resnet.py ├── Resnet-FX-CLE ├── .gitignore ├── Box_plots │ ├── With_CLE │ │ └── layer2.0.conv1.png │ └── Without_CLE │ │ └── layer2.0.conv1.png ├── CLE_notebook.ipynb ├── README.md ├── dog.jpg ├── evaluate │ ├── __init__.py │ ├── evaluate.py │ ├── imagenet_classes.txt │ └── images │ │ ├── README.md │ │ ├── Samoyed.jpg │ │ ├── clog.jpg │ │ ├── hen.jpg │ │ └── mail_box.jpg ├── graph.svg ├── main.py ├── model │ └── resnet.py ├── quant_vis │ ├── histograms │ │ ├── __init__.py │ │ ├── hooks │ │ │ ├── __init__.py │ │ │ ├── forward_hooks.py │ │ │ └── sa_back_hooks.py │ │ └── plots │ │ │ ├── __init__.py │ │ │ ├── plot_histograms.py │ │ │ ├── utils.py │ │ │ └── weights.py │ ├── settings.py │ └── utils │ │ ├── act_histogram.py │ │ ├── hooks.py │ │ └── prop_data.py └── utils │ ├── dotdict.py │ ├── graph_manip.py │ ├── ipdb_hook.py │ ├── logger.py │ └── qconfigs.py ├── Resnet-FX-Graph-Mode-Quant ├── README.md ├── __pycache__ │ ├── evaluate.cpython-310.pyc │ ├── evaluate.cpython-311.pyc │ ├── evaluate.cpython-312.pyc │ ├── ipdb_hook.cpython-310.pyc │ ├── ipdb_hook.cpython-312.pyc │ ├── qconfigs.cpython-310.pyc │ ├── qconfigs.cpython-311.pyc │ └── qconfigs.cpython-312.pyc ├── evaluate │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-312.pyc │ │ └── evaluate.cpython-312.pyc │ ├── dog.jpg │ ├── evaluate.py │ └── imagenet_classes.txt ├── graph.svg ├── ipdb_hook.py ├── main.py ├── model │ ├── __pycache__ │ │ ├── resnet.cpython-310.pyc │ │ ├── resnet.cpython-311.pyc │ │ └── resnet.cpython-312.pyc │ └── resnet.py └── qconfigs.py └── Resnet-FX-QAT ├── .gitignore ├── README.md ├── evaluate ├── __init__.py ├── evaluate.py ├── imagenet_classes.txt └── images │ ├── README.md │ ├── Samoyed.jpg │ ├── clog.jpg │ ├── hen.jpg │ └── mail_box.jpg ├── graph.svg ├── main.py ├── model └── resnet.py └── utils ├── graph_manip.py ├── ipdb_hook.py └── qconfigs.py /README.md: -------------------------------------------------------------------------------- 1 | # Quantization-Tutorials 2 | A bunch of coding tutorials for my [Youtube videos on Neural Network Quantization](https://www.youtube.com/@NeuralNetworkQuantization). 3 | 4 | # Resnet-Eager-Mode-Quant: 5 | 6 | [![How to Quantize a ResNet from Scratch! Full Coding Tutorial (Eager Mode)](https://ytcards.demolab.com/?id=jNZ1rkIfwsM&title=How+to+Quantize+a+ResNet+from+Scratch%21+Full+Coding+Tutorial+%28Eager+Mode%29%0D%0A&lang=en×tamp=1706473016&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "How to Quantize a ResNet from Scratch! Full Coding Tutorial (Eager Mode)")](https://www.youtube.com/watch?v=jNZ1rkIfwsM) 7 | 8 | This is the first coding tutorial. 
We take the `torchvision` `ResNet` model and quantize it entirely from scratch with the PyTorch quantization library, using Eager Mode Quantization. 9 | 10 | We discuss common issues one can run into, as well as some interesting but tricky bugs. 11 | 12 | # Resnet-Eager-Mode-Dynamic-Quant: 13 | 14 | **TODO** 15 | 16 | In this tutorial, we do dynamic quantization on a ResNet model. We look at how dynamic quantization works, what the default settings are in PyTorch, and discuss how it differs from static quantization. 17 | 18 | 19 | # How to do FX Graph Mode Quantization (PyTorch ResNet Coding tutorial) 20 | 21 | In this tutorial series, we use Torch's FX Graph mode quantization to quantize a ResNet. In the first video, we look at the Directed Acyclic Graph (DAG), and see how the fusing, placement of quantstubs and FloatFunctionals all happen automatically. In the second, we look at some of the intricacies of how quantization interacts with the GraphModule. In the third and final video, we look at some more advanced techniques for manipulating and traversing the graph, and use these to discover an alternative to forward hooks, and for fusing BatchNorm layers into their preceding Convs. 22 | 23 | [![How to do FX Graph Mode Quantization: FX Graph Mode Quantization Coding tutorial - Part 1/3](https://ytcards.demolab.com/?id=AHw5BOUfLU4&title=How+to+do+FX+Graph+Mode+Quantization%3A+FX+Graph+Mode+Quantization+Coding+tutorial+-+Part+1%2F3&lang=en&timestamp=1710264531&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "How to do FX Graph Mode Quantization: FX Graph Mode Quantization Coding tutorial - Part 1/3")](https://www.youtube.com/watch?v=AHw5BOUfLU4) 24 | [![How does Graph Mode Affect Quantization? FX Graph Mode Quantization Coding tutorial - Part 2/3](https://ytcards.demolab.com/?id=1S3jlGdGdjM&title=How+does+Graph+Mode+Affect+Quantization%3F+FX+Graph+Mode+Quantization+Coding+tutorial+-+Part+2%2F3&lang=en&timestamp=1710452876&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "How does Graph Mode Affect Quantization? FX Graph Mode Quantization Coding tutorial - Part 2/3")](https://www.youtube.com/watch?v=1S3jlGdGdjM) 25 | [![Advanced PyTorch Graph Manipulation: FX Graph Mode Quantization Coding tutorial - Part 3/3](https://ytcards.demolab.com/?id=azpsgB8y0A8&title=Advanced+PyTorch+Graph+Manipulation%3A+FX+Graph+Mode+Quantization+Coding+tutorial+-+Part+3%2F3&lang=en&timestamp=1711116192&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "Advanced PyTorch Graph Manipulation: FX Graph Mode Quantization Coding tutorial - Part 3/3")](https://www.youtube.com/watch?v=azpsgB8y0A8) 26 | 27 | 28 | # Quantization Aware Training 29 | 30 | In this tutorial we look at how to do Quantization Aware Training (QAT) on an FX Graph Mode quantized Resnet. We build a small training loop with a mini custom data loader. We also generalise the evaluate function we've been using in our tutorials so that it works with other images. We go looking for, and find, some of the dangers of overfitting.
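For a rough idea of the kind of training loop built in that video, here is a minimal, hedged sketch (the names `fake_quant_model` and `train_loader` are placeholders, not objects defined in this repo):

```python
import torch

# Assumes `fake_quant_model` came from torch.ao.quantization.prepare_qat(...) and
# `train_loader` yields (image, label) batches - both are placeholder names.
optimizer = torch.optim.SGD(fake_quant_model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

fake_quant_model.train()
for images, labels in train_loader:
    optimizer.zero_grad()
    loss = criterion(fake_quant_model(images), labels)  # forward pass through fake-quant modules
    loss.backward()  # gradients flow past the fake-quant nodes via the straight-through estimator
    optimizer.step()
```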
31 | 32 | [![Quantization Aware Training (QAT) With a Custom DataLoader: Beginner's Tutorial to Training Loops](https://ytcards.demolab.com/?id=s3tqqBaRuHE&title=Quantization+Aware+Training+%28QAT%29+With+a+Custom+DataLoader%3A+Beginner%27s+Tutorial+to+Training+Loops&lang=en×tamp=1712648353&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "Quantization Aware Training (QAT) With a Custom DataLoader: Beginner's Tutorial to Training Loops")](https://www.youtube.com/watch?v=s3tqqBaRuHE) 33 | 34 | 35 | # Cross Layer Equalization (CLE) 36 | 37 | In this tutorial, we look at Cross-Layer Equalization, a classic data-free method for 38 | improving the quantization of one's models. We use a graph-tracing method to find all of the 39 | layers we can do CLE on, do CLE, evaluate the results, and then visualize what's happening inside 40 | the model. 41 | 42 | [![Cross Layer Equalization: Everything You Need to Know](https://ytcards.demolab.com/?id=3eATdsWmHyI&title=Cross+Layer+Equalization%3A+Everything+You+Need+to+Know&lang=en×tamp=1715768680&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "Cross Layer Equalization: Everything You Need to Know")](https://www.youtube.com/watch?v=3eATdsWmHyI) 43 | 44 | 45 | -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/README.md: -------------------------------------------------------------------------------- 1 | # How to Quantize a ResNet from Scratch (Eager Mode Dynamic Quantization) 2 | 3 | This is the finished code associated with the Youtube tutorial at: 4 | 5 | **TODO** 6 | 7 | 8 | ### Prerequisites: 9 | To run this code, you need to have PyTorch installed in your environment. If you do not have PyTorch installed, please follow this [official guide](https://pytorch.org/get-started/locally/). 10 | 11 | I created this code with PyTorch Version: 2.1.1. In case you have any versioning issues, you can revert to that version. 12 | 13 | ### Running this code: 14 | Once you have PyTorch installed, first navigate to a directory you will be working from. As you follow the next steps, your future file structure will look like this: `your-directory/Resnet-Eager-Mode-Dynamic-Quant`. 15 | 16 | Next, from `your-directory`, clone the `Quantization-Tutorials` repo. This repo contains different tutorials, but they are all interlinked. Feel no need to do any of the others! I just structured it this way because the tutorials share a lot of code and it might help people to see different parts in one place. 17 | 18 | You can also `git init` and then `git pull/fetch`, depending on what you prefer. 19 | 20 | To clone the repo, run: 21 | ``` 22 | git clone git@github.com:OscarSavolainenDR/Quantization-Tutorials.git . 23 | ``` 24 | 25 | If you did the cloning in place with the `.` at the end, your folder structure should look like `your-folder/Resnet-Eager-Mode-Dynamic-Quant`, with various other folders for other tutorials. 26 | 27 | Next, cd into the Resnet Eager Mode Quantization tutorial: 28 | ``` 29 | cd Resnet-Eager-Mode-Dynamic-Quant 30 | ``` 31 | Then, just run `python quant_dynamic_resnet.py` from your command line! However I would obviously recommend that you follow along with the tutorial, so that you learn how it all works and get your hands dirty. 32 | 33 | Let me know if there are any issues! 
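As a quick point of reference before you dive in: the heart of dynamic quantization is a single API call. Below is a minimal, standalone sketch; the toy model is purely illustrative and is not part of this repo's script.

```python
import torch
import torch.nn as nn

# Toy float model - illustrative only.
float_model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# Weights are quantized ahead of time; activation scales are computed per batch at runtime.
quantized_model = torch.quantization.quantize_dynamic(
    float_model, qconfig_spec={nn.Linear}, dtype=torch.qint8
)
print(quantized_model)  # the Linear layers are replaced by DynamicQuantizedLinear modules
```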
34 | -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/__pycache__/evaluate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/__pycache__/evaluate.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/__pycache__/evaluate.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/__pycache__/evaluate.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/__pycache__/ipdb_hook.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/__pycache__/ipdb_hook.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/__pycache__/qconfigs.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/__pycache__/qconfigs.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/dog.jpg -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def evaluate(model, device_str: str): 4 | # Download an example image from the pytorch website 5 | """ 6 | The provided code defines a Python function called `evaluate` that allows a 7 | PyTorch model to be run on a sample image from the ImageNet dataset and generates 8 | probability distributions for each of the 1000 classes. 9 | 10 | Args: 11 | model (): The model input parameter is a pre-trained PyTorch model. It's 12 | used to process an input image via the model and receive an output 13 | containing class probabilities. 14 | device_str (str): The `device_str` input parameter specifies whether the 15 | computation should be executed on CPU or GPU. 
If device str is 'cuda' 16 | then it will be run only if CUDA is available 17 | 18 | """ 19 | import urllib 20 | url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") 21 | try: urllib.URLopener().retrieve(url, filename) 22 | except: urllib.request.urlretrieve(url, filename) 23 | # sample execution (requires torchvision) 24 | 25 | from PIL import Image 26 | from torchvision import transforms 27 | input_image = Image.open(filename) 28 | preprocess = transforms.Compose([ 29 | transforms.Resize(256), 30 | transforms.CenterCrop(224), 31 | transforms.ToTensor(), 32 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 33 | ]) 34 | input_tensor = preprocess(input_image) 35 | input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model 36 | 37 | # move the input and model to GPU for speed if available, or to CPU if converted 38 | if not (device_str in['cpu', 'cuda']): 39 | raise NotImplementedError("`device_str` should be 'cpu' or 'cuda' ") 40 | if device_str == 'cuda': 41 | assert torch.cuda.is_available(), 'Check CUDA is available' 42 | 43 | input_batch = input_batch.to(device_str) 44 | model.to(device_str) 45 | 46 | with torch.no_grad(): 47 | output = model(input_batch) 48 | # Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes 49 | # print(output[0]) 50 | # The output has unnormalized scores. To get probabilities, you can run a softmax on it. 51 | probabilities = torch.nn.functional.softmax(output[0], dim=0) 52 | # print(probabilities) 53 | 54 | # Read the categories 55 | with open("imagenet_classes.txt", "r") as f: 56 | categories = [s.strip() for s in f.readlines()] 57 | # Show top categories per image 58 | top5_prob, top5_catid = torch.topk(probabilities, 5) 59 | for i in range(top5_prob.size(0)): 60 | print(categories[top5_catid[i]], top5_prob[i].item()) 61 | -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/ipdb_hook.py: -------------------------------------------------------------------------------- 1 | import traceback, ipdb 2 | import sys 3 | 4 | def ipdb_sys_excepthook(): 5 | """ 6 | When called this function will set up the system exception hook. 7 | This hook throws one into an ipdb breakpoint if and where a system 8 | exception occurs in one's run. 9 | 10 | E.g. 11 | >>> ipdb_sys_excepthook() 12 | """ 13 | 14 | 15 | def info(type, value, tb): 16 | """ 17 | System excepthook that includes an ipdb breakpoint. 18 | """ 19 | if hasattr(sys, 'ps1') or not sys.stderr.isatty(): 20 | # we are in interactive mode or we don't have a tty-like 21 | # device, so we call the default hook 22 | sys.__excepthook__(type, value, tb) 23 | else: 24 | # we are NOT in interactive mode, print the exception... 25 | traceback.print_exception(type, value, tb) 26 | print 27 | # ...then start the debugger in post-mortem mode. 
28 | # pdb.pm() # deprecated 29 | ipdb.post_mortem(tb) # more "modern" 30 | sys.excepthook = info -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/model/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/model/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/model/__pycache__/resnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Dynamic-Quant/model/__pycache__/resnet.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/model/resnet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from torchvision.transforms._presets import ImageClassification 9 | from torchvision.utils import _log_api_usage_once 10 | from torchvision.models._api import register_model, Weights, WeightsEnum 11 | from torchvision.models._meta import _IMAGENET_CATEGORIES 12 | from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface 13 | 14 | 15 | __all__ = [ 16 | "ResNet", 17 | "ResNet18_Weights", 18 | "ResNet34_Weights", 19 | "ResNet50_Weights", 20 | "ResNet101_Weights", 21 | "ResNet152_Weights", 22 | "ResNeXt50_32X4D_Weights", 23 | "ResNeXt101_32X8D_Weights", 24 | "ResNeXt101_64X4D_Weights", 25 | "Wide_ResNet50_2_Weights", 26 | "Wide_ResNet101_2_Weights", 27 | "resnet18", 28 | "resnet34", 29 | "resnet50", 30 | "resnet101", 31 | "resnet152", 32 | "resnext50_32x4d", 33 | "resnext101_32x8d", 34 | "resnext101_64x4d", 35 | "wide_resnet50_2", 36 | "wide_resnet101_2", 37 | ] 38 | 39 | 40 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 41 | """3x3 convolution with padding""" 42 | return nn.Conv2d( 43 | in_planes, 44 | out_planes, 45 | kernel_size=3, 46 | stride=stride, 47 | padding=dilation, 48 | groups=groups, 49 | bias=False, 50 | dilation=dilation, 51 | ) 52 | 53 | 54 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 55 | """1x1 convolution""" 56 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | expansion: int = 1 61 | 62 | def __init__( 63 | self, 64 | inplanes: int, 65 | planes: int, 66 | stride: int = 1, 67 | downsample: Optional[nn.Module] = None, 68 | groups: int = 1, 69 | base_width: int = 64, 70 | dilation: int = 1, 71 | norm_layer: Optional[Callable[..., nn.Module]] = None, 72 | ) -> None: 73 | super().__init__() 74 | if norm_layer is None: 75 | norm_layer = nn.BatchNorm2d 76 | if groups != 1 or base_width != 64: 77 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 78 | if dilation > 1: 79 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 80 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 81 | 
self.conv1 = conv3x3(inplanes, planes, stride) 82 | self.bn1 = norm_layer(planes) 83 | self.relu1 = nn.ReLU(inplace=True) 84 | self.conv2 = conv3x3(planes, planes) 85 | self.bn2 = norm_layer(planes) 86 | self.downsample = downsample 87 | self.FFAddReLU = torch.ao.nn.quantized.FloatFunctional() 88 | # self.relu_out = nn.ReLU(inplace=True) 89 | self.stride = stride 90 | 91 | def modules_to_fuse(self, prefix): 92 | modules_to_fuse_ = [] 93 | modules_to_fuse_.append([f'{prefix}.conv1', f'{prefix}.bn1', f'{prefix}.relu1']) 94 | modules_to_fuse_.append([f'{prefix}.conv2', f'{prefix}.bn2']) 95 | if self.downsample: 96 | modules_to_fuse_.append([f'{prefix}.downsample.0', f'{prefix}.downsample.1']) 97 | 98 | return modules_to_fuse_ 99 | 100 | def forward(self, x: Tensor) -> Tensor: 101 | identity = x 102 | 103 | out = self.conv1(x) 104 | out = self.bn1(out) 105 | out = self.relu1(out) 106 | 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | 110 | if self.downsample is not None: 111 | identity = self.downsample(x) 112 | 113 | out = self.FFAddReLU.add_relu(out, identity) 114 | # out = self.relu_out(out) 115 | 116 | return out 117 | 118 | 119 | class ResNet(nn.Module): 120 | def __init__( 121 | self, 122 | block: Type[BasicBlock], 123 | layers: List[int], 124 | num_classes: int = 1000, 125 | zero_init_residual: bool = False, 126 | groups: int = 1, 127 | width_per_group: int = 64, 128 | replace_stride_with_dilation: Optional[List[bool]] = None, 129 | norm_layer: Optional[Callable[..., nn.Module]] = None, 130 | ) -> None: 131 | super().__init__() 132 | _log_api_usage_once(self) 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | self._norm_layer = norm_layer 136 | 137 | self.inplanes = 64 138 | self.dilation = 1 139 | if replace_stride_with_dilation is None: 140 | # each element in the tuple indicates if we should replace 141 | # the 2x2 stride with a dilated convolution instead 142 | replace_stride_with_dilation = [False, False, False] 143 | if len(replace_stride_with_dilation) != 3: 144 | raise ValueError( 145 | "replace_stride_with_dilation should be None " 146 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 147 | ) 148 | self.groups = groups 149 | self.base_width = width_per_group 150 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 151 | self.bn1 = norm_layer(self.inplanes) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 154 | self.layer1 = self._make_layer(block, 64, layers[0]) 155 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 157 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 158 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 159 | self.fc = nn.Linear(512 * block.expansion, num_classes) 160 | self.quant = torch.ao.quantization.QuantStub() 161 | self.dequant = torch.ao.quantization.DeQuantStub() 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | 170 | # Zero-initialize the last BN in each residual branch, 171 | # so that the residual branch starts with zeros, and each residual block behaves like an 
identity. 172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 173 | if zero_init_residual: 174 | for m in self.modules(): 175 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 176 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 177 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 178 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 179 | 180 | def _make_layer( 181 | self, 182 | block: Type[BasicBlock], 183 | planes: int, 184 | blocks: int, 185 | stride: int = 1, 186 | dilate: bool = False, 187 | ) -> nn.Sequential: 188 | norm_layer = self._norm_layer 189 | downsample = None 190 | previous_dilation = self.dilation 191 | if dilate: 192 | self.dilation *= stride 193 | stride = 1 194 | if stride != 1 or self.inplanes != planes * block.expansion: 195 | downsample = nn.Sequential( 196 | conv1x1(self.inplanes, planes * block.expansion, stride), 197 | norm_layer(planes * block.expansion), 198 | ) 199 | 200 | layers = [] 201 | layers.append( 202 | block( 203 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 204 | ) 205 | ) 206 | self.inplanes = planes * block.expansion 207 | for _ in range(1, blocks): 208 | layers.append( 209 | block( 210 | self.inplanes, 211 | planes, 212 | groups=self.groups, 213 | base_width=self.base_width, 214 | dilation=self.dilation, 215 | norm_layer=norm_layer, 216 | ) 217 | ) 218 | 219 | 220 | return nn.Sequential(*layers) 221 | 222 | def modules_to_fuse(self): 223 | modules_to_fuse_ = [] 224 | modules_to_fuse_.append(['conv1', 'bn1', 'relu']) 225 | 226 | for layer_str in ['layer1', 'layer2', 'layer3', 'layer4']: 227 | layer = eval(f'self.{layer_str}') 228 | for block_nb in range(len(layer)): 229 | prefix = f'{layer_str}.{block_nb}' 230 | modules_to_fuse_layer = layer[block_nb].modules_to_fuse(prefix) 231 | modules_to_fuse_.extend(modules_to_fuse_layer) 232 | 233 | return modules_to_fuse_ 234 | 235 | def _forward_impl(self, x: Tensor) -> Tensor: 236 | # See note [TorchScript super()] 237 | x = self.conv1(x) 238 | x = self.bn1(x) 239 | x = self.relu(x) 240 | x = self.maxpool(x) 241 | 242 | x = self.layer1(x) 243 | x = self.layer2(x) 244 | x = self.layer3(x) 245 | x = self.layer4(x) 246 | 247 | x = self.avgpool(x) 248 | x = torch.flatten(x, 1) 249 | x = self.fc(x) 250 | 251 | return x 252 | 253 | def forward(self, x: Tensor) -> Tensor: 254 | x = self.quant(x) 255 | x = self._forward_impl(x) 256 | x = self.dequant(x) 257 | return x 258 | 259 | 260 | def _resnet( 261 | block: Type[BasicBlock], 262 | layers: List[int], 263 | weights: Optional[WeightsEnum], 264 | progress: bool, 265 | **kwargs: Any, 266 | ) -> ResNet: 267 | if weights is not None: 268 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 269 | 270 | model = ResNet(block, layers, **kwargs) 271 | 272 | if weights is not None: 273 | model.load_state_dict(weights.get_state_dict(progress=progress)) 274 | 275 | return model 276 | 277 | 278 | _COMMON_META = { 279 | "min_size": (1, 1), 280 | "categories": _IMAGENET_CATEGORIES, 281 | } 282 | 283 | 284 | class ResNet18_Weights(WeightsEnum): 285 | IMAGENET1K_V1 = Weights( 286 | url="https://download.pytorch.org/models/resnet18-f37072fd.pth", 287 | transforms=partial(ImageClassification, crop_size=224), 288 | meta={ 289 | **_COMMON_META, 290 | "num_params": 11689512, 291 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", 292 | "_metrics": { 293 | "ImageNet-1K": { 
294 | "acc@1": 69.758, 295 | "acc@5": 89.078, 296 | } 297 | }, 298 | "_ops": 1.814, 299 | "_file_size": 44.661, 300 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", 301 | }, 302 | ) 303 | DEFAULT = IMAGENET1K_V1 304 | 305 | 306 | @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) 307 | def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: 308 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 309 | 310 | Args: 311 | weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The 312 | pretrained weights to use. See 313 | :class:`~torchvision.models.ResNet18_Weights` below for 314 | more details, and possible values. By default, no pre-trained 315 | weights are used. 316 | progress (bool, optional): If True, displays a progress bar of the 317 | download to stderr. Default is True. 318 | **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` 319 | base class. Please refer to the `source code 320 | `_ 321 | for more details about this class. 322 | 323 | .. autoclass:: torchvision.models.ResNet18_Weights 324 | :members: 325 | """ 326 | weights = ResNet18_Weights.verify(weights) 327 | 328 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) 329 | -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Dynamic-Quant/quant_dynamic_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.resnet import resnet18 3 | 4 | from evaluate import evaluate 5 | from ipdb_hook import ipdb_sys_excepthook 6 | 7 | ipdb_sys_excepthook() 8 | 9 | model = resnet18(pretrained=True) 10 | print(model) 11 | 12 | # Step 1: architecture changes 13 | # QuantStubs (we will do FloatFunctionals later) 14 | # Done 15 | 16 | # Step 2: fuse modules (recommended but not necessary) 17 | modules_to_list = model.modules_to_fuse() 18 | 19 | # It will keep Batchnorm 20 | model.eval() 21 | # fused_model = torch.ao.quantization.fuse_modules_qat(model, modules_to_list) 22 | 23 | # This will fuse BatchNorm weights into the preceding Conv 24 | fused_model = torch.ao.quantization.fuse_modules(model, modules_to_list) 25 | 26 | # Step 3: Assign qconfigs 27 | backend = 'fbgemm' 28 | qconfig = torch.quantization.get_default_qconfig(backend) 29 | torch.backends.quantized.engine = backend 30 | 31 | for name, module in fused_model.named_modules(): 32 | module.qconfig = qconfig 33 | 34 | # Step 4: Prepare for fake-quant 35 | fused_model.train() 36 | fake_quant_model = torch.ao.quantization.prepare_qat(fused_model) 37 | 38 | # Step 4b: Try dynamic quantization 39 | # NOTE: we overrride the default mapping to have some more examples 40 | from torch.ao.quantization.quantization_mappings import get_default_dynamic_quant_module_mappings 41 | from torch.ao.quantization.qconfig import default_dynamic_qconfig 42 | import torch.ao.nn.quantized.dynamic as nnqd 43 | mapping = get_default_dynamic_quant_module_mappings() 44 | mapping[torch.nn.Conv2d] = nnqd.Conv2d 45 | qconfig_spec = { 46 | torch.nn.Linear : default_dynamic_qconfig, 47 | torch.nn.LSTM : default_dynamic_qconfig, 48 | torch.nn.GRU : default_dynamic_qconfig, 49 | torch.nn.LSTMCell : default_dynamic_qconfig, 50 | torch.nn.RNNCell : default_dynamic_qconfig, 51 | torch.nn.GRUCell : default_dynamic_qconfig, 52 | torch.nn.Conv2d: default_dynamic_qconfig, # has bad numerical performance 53 | 
} 54 | fake_quant_model_dynamic = torch.quantization.quantize_dynamic(model, qconfig_spec=qconfig_spec, mapping=mapping) 55 | 56 | # Evaluate 57 | print('\noriginal') 58 | evaluate(model, 'cpu') 59 | print('\nfused') 60 | evaluate(fused_model, 'cpu') 61 | 62 | print('\ndynamic') 63 | evaluate(fake_quant_model_dynamic, 'cpu') 64 | 65 | 66 | # Step 5: convert (true int8 model) 67 | fake_quant_model.to('cpu') 68 | converted_model = torch.quantization.convert(fake_quant_model) 69 | 70 | print('\nfake quant') 71 | evaluate(fake_quant_model, 'cpu') 72 | 73 | 74 | print('\nconverted') 75 | evaluate(converted_model, 'cpu') 76 | 77 | xxx 78 | # ## Torch compile 79 | # compiled_model = torch.compile(model) 80 | # print(compiled_model) -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/README.md: -------------------------------------------------------------------------------- 1 | # How to Quantize a ResNet from Scratch (Eager Mode Static Quantization) 2 | 3 | This is the finished code associated with the YouTube tutorial at: 4 | 5 | [![How to Quantize a ResNet from Scratch! Full Coding Tutorial (Eager Mode)](https://ytcards.demolab.com/?id=jNZ1rkIfwsM&title=How+to+Quantize+a+ResNet+from+Scratch!+Full+Coding+Tutorial+(Eager+Mode)&lang=en×tamp=1706473016&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "How to Quantize a ResNet from Scratch! Full Coding Tutorial (Eager Mode)")](https://www.youtube.com/watch?v=jNZ1rkIfwsM) 6 | 7 | ### Prerequisites: 8 | To run this code, you need to have PyTorch installed in your environment. If you do not have PyTorch installed, please follow this [official guide](https://pytorch.org/get-started/locally/). 9 | 10 | I created this code with PyTorch Version: 2.1.1. In case you have any versioning issues, you can revert to that version. 11 | 12 | ### Running this code: 13 | Once you have PyTorch installed, first navigate to a directory you will be working from. As you follow the next steps, your future file structure will look like this: `your-directory/Resnet-Eager-Mode-Quant`. 14 | 15 | Next, from `your-directory`, clone the `Quantization-Tutorials` repo. This repo contains different tutorials, but they are all interlinked. Feel no need to do any of the others! I just structured it this way because the tutorials share a lot of code and it might help people to see different parts in one place. 16 | 17 | You can also `git init` and then `git pull/fetch`, depending on what you prefer. 18 | 19 | To clone the repo, run: 20 | ``` 21 | git clone git@github.com:OscarSavolainenDR/Quantization-Tutorials.git . 22 | ``` 23 | 24 | If you did the cloning in place with the `.` at the end, your folder structure should look like `your-folder/Resnet-Eager-Mode-Quant`, with various other folders for other tutorials. 25 | 26 | Next, cd into the Resnet Eager Mode Quantization tutorial: 27 | ``` 28 | cd Resnet-Eager-Mode-Quant 29 | ``` 30 | Then, just run `python quant_resnet.py` from your command line! However I would obviously recommend that you follow along with the tutorial, so that you learn how it all works and get your hands dirty. 31 | 32 | Let me know if there are any issues! 
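For orientation before reading `quant_resnet.py`, this is the generic eager-mode static quantization recipe in miniature. It is a hedged sketch with a placeholder `model` (a float network that already contains QuantStub/DeQuantStub); the script in this folder is the authoritative version and differs in its details.

```python
import torch

model.eval()  # `model` is a placeholder float model
model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
fused = torch.ao.quantization.fuse_modules(model, [["conv1", "bn1", "relu"]])  # example fuse list
prepared = torch.ao.quantization.prepare(fused)      # insert observers
prepared(torch.randn(1, 3, 224, 224))                # calibrate on representative data
quantized = torch.ao.quantization.convert(prepared)  # swap in true int8 modules
```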
33 | -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/__pycache__/evaluate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Quant/__pycache__/evaluate.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/__pycache__/evaluate.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Quant/__pycache__/evaluate.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/__pycache__/qconfigs.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Quant/__pycache__/qconfigs.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Quant/dog.jpg -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def evaluate(model, device_str='cuda'): 4 | 5 | # Download an example image from the pytorch website 6 | """ 7 | This is a functional piece of code that downloads an image from a predefined 8 | URL and processes it through a model and then generates a probability distribution 9 | over ImageNet classes using the softmax activation. 10 | 11 | Args: 12 | model (): The input parameter `model` to `evaluate()` is a PyTorch model 13 | whose inputs are to be passed through it for prediction and evaluation. 14 | device_str (str): The `device_str` input parameter specifies whether to 15 | move the input and model to GPU or CPU after preprocessing and before 16 | executing the model. 
17 | 18 | """ 19 | import urllib 20 | url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg") 21 | try: urllib.URLopener().retrieve(url, filename) 22 | except: urllib.request.urlretrieve(url, filename) 23 | # sample execution (requires torchvision) 24 | from PIL import Image 25 | from torchvision import transforms 26 | input_image = Image.open(filename) 27 | preprocess = transforms.Compose([ 28 | transforms.Resize(256), 29 | transforms.CenterCrop(224), 30 | transforms.ToTensor(), 31 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 32 | ]) 33 | input_tensor = preprocess(input_image) 34 | input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model 35 | 36 | # Move the input and model to GPU for speed if available, or to CPU if converted 37 | if not (device_str in['cpu', 'cuda']): 38 | raise NotImplementedError("`device_str` should be 'cpu' or 'cuda' ") 39 | if device_str == 'cuda': 40 | assert torch.cuda.is_available(), 'Check CUDA is available' 41 | 42 | input_batch = input_batch.to(device_str) 43 | model.to(device_str) 44 | 45 | with torch.no_grad(): 46 | output = model(input_batch) 47 | 48 | # Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes 49 | #print(output[0]) 50 | # The output has unnormalized scores. To get probabilities, you can run a softmax on it. 51 | probabilities = torch.nn.functional.softmax(output[0], dim=0) 52 | #print(probabilities) 53 | 54 | # Read the categories 55 | with open("imagenet_classes.txt", "r") as f: 56 | categories = [s.strip() for s in f.readlines()] 57 | # Show top categories per image 58 | top5_prob, top5_catid = torch.topk(probabilities, 5) 59 | for i in range(top5_prob.size(0)): 60 | print(categories[top5_catid[i]], top5_prob[i].item()) -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/imagenet_classes.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great white shark 4 | tiger shark 5 | hammerhead 6 | electric ray 7 | stingray 8 | cock 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house finch 14 | junco 15 | indigo bunting 16 | robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel 22 | kite 23 | bald eagle 24 | vulture 25 | great grey owl 26 | European fire salamander 27 | common newt 28 | eft 29 | spotted salamander 30 | axolotl 31 | bullfrog 32 | tree frog 33 | tailed frog 34 | loggerhead 35 | leatherback turtle 36 | mud turtle 37 | terrapin 38 | box turtle 39 | banded gecko 40 | common iguana 41 | American chameleon 42 | whiptail 43 | agama 44 | frilled lizard 45 | alligator lizard 46 | Gila monster 47 | green lizard 48 | African chameleon 49 | Komodo dragon 50 | African crocodile 51 | American alligator 52 | triceratops 53 | thunder snake 54 | ringneck snake 55 | hognose snake 56 | green snake 57 | king snake 58 | garter snake 59 | water snake 60 | vine snake 61 | night snake 62 | boa constrictor 63 | rock python 64 | Indian cobra 65 | green mamba 66 | sea snake 67 | horned viper 68 | diamondback 69 | sidewinder 70 | trilobite 71 | harvestman 72 | scorpion 73 | black and gold garden spider 74 | barn spider 75 | garden spider 76 | black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie chicken 85 | peacock 86 | quail 87 | partridge 88 | African grey 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 
| bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | American egret 134 | bittern 135 | crane 136 | limpkin 137 | European gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | red-backed sandpiper 142 | redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | dugong 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog 155 | Pekinese 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound 162 | basset 163 | beagle 164 | bloodhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound 168 | English foxhound 169 | redbone 170 | borzoi 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound 175 | Norwegian elkhound 176 | otterhound 177 | Saluki 178 | Scottish deerhound 179 | Weimaraner 180 | Staffordshire bullterrier 181 | American Staffordshire terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier 192 | Airedale 193 | cairn 194 | Australian terrier 195 | Dandie Dinmont 196 | Boston bull 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier 201 | Tibetan terrier 202 | silky terrier 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla 213 | English setter 214 | Irish setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber 218 | English springer 219 | Welsh springer spaniel 220 | cocker spaniel 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog 231 | Shetland sheepdog 232 | collie 233 | Border collie 234 | Bouvier des Flandres 235 | Rottweiler 236 | German shepherd 237 | Doberman 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard 249 | Eskimo dog 250 | malamute 251 | Siberian husky 252 | dalmatian 253 | affenpinscher 254 | basenji 255 | pug 256 | Leonberg 257 | Newfoundland 258 | Great Pyrenees 259 | Samoyed 260 | Pomeranian 261 | chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke 265 | Cardigan 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf 271 | white wolf 272 | red wolf 273 | coyote 274 | dingo 275 | dhole 276 | African hunting dog 277 | hyena 278 | red fox 279 | kit fox 280 | Arctic fox 281 | grey fox 
282 | tabby 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | Egyptian cat 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | ice bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | long-horned beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | walking stick 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | admiral 323 | ringlet 324 | monarch 325 | cabbage butterfly 326 | sulphur butterfly 327 | lycaenid 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | wood rabbit 332 | hare 333 | Angora 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | sorrel 341 | zebra 342 | hog 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram 350 | bighorn 351 | ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | Arabian camel 356 | llama 357 | weasel 358 | mink 359 | polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas 373 | baboon 374 | macaque 375 | langur 376 | colobus 377 | proboscis monkey 378 | marmoset 379 | capuchin 380 | howler monkey 381 | titi 382 | spider monkey 383 | squirrel monkey 384 | Madagascar cat 385 | indri 386 | Indian elephant 387 | African elephant 388 | lesser panda 389 | giant panda 390 | barracouta 391 | eel 392 | coho 393 | rock beauty 394 | anemone fish 395 | sturgeon 396 | gar 397 | lionfish 398 | puffer 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibian 410 | analog clock 411 | apiary 412 | apron 413 | ashcan 414 | assault rifle 415 | backpack 416 | bakery 417 | balance beam 418 | balloon 419 | ballpoint 420 | Band Aid 421 | banjo 422 | bannister 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | barrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap 435 | bath towel 436 | bathtub 437 | beach wagon 438 | beacon 439 | beaker 440 | bearskin 441 | beer bottle 442 | beer glass 443 | bell cote 444 | bib 445 | bicycle-built-for-two 446 | bikini 447 | binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsled 452 | bolo tie 453 | bonnet 454 | bookcase 455 | bookshop 456 | bottlecap 457 | bow 458 | bow tie 459 | brass 460 | brassiere 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | bullet train 468 | butcher shop 469 | cab 470 | caldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | carpenter's kit 479 | carton 480 | car wheel 481 | cash machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | cellular telephone 489 | chain 490 | chainlink fence 491 | chain mail 492 | chain saw 493 | chest 494 | chiffonier 495 | chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | cinema 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clog 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil 508 | combination lock 509 | 
computer keyboard 510 | confectionery 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | dishrag 535 | dishwasher 536 | disk brake 537 | dock 538 | dogsled 539 | dome 540 | doormat 541 | drilling platform 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa 554 | file 555 | fireboat 556 | fire engine 557 | fire screen 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gasmask 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | grille 583 | grocery store 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower 591 | hand-held computer 592 | handkerchief 593 | hard disc 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoopskirt 603 | horizontal bar 604 | horse cart 605 | hourglass 606 | iPod 607 | iron 608 | jack-o'-lantern 609 | jean 610 | jeep 611 | jersey 612 | jigsaw puzzle 613 | jinrikisha 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop 622 | lawn mower 623 | lens cap 624 | letter opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | liner 630 | lipstick 631 | Loafer 632 | lotion 633 | loudspeaker 634 | loupe 635 | lumbermill 636 | magnetic compass 637 | mailbag 638 | mailbox 639 | maillot 640 | maillot 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring cup 649 | medicine chest 650 | megalith 651 | microphone 652 | microwave 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter 672 | mountain bike 673 | mountain tent 674 | mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle 695 | paddlewheel 696 | padlock 697 | paintbrush 698 | pajama 699 | palace 700 | panpipe 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | passenger car 707 | patio 708 | pay-phone 709 | pedestal 710 | pencil box 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | pick 716 | pickelhaube 717 | picket fence 718 | pickup 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate 726 | pitcher 727 | plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow 732 | plunger 733 | Polaroid 
camera 734 | pole 735 | police van 736 | poncho 737 | pool table 738 | pop bottle 739 | pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | racer 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | rubber eraser 769 | rugby ball 770 | rule 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker 775 | sandal 776 | sarong 777 | sax 778 | scabbard 779 | scale 780 | school bus 781 | schooner 782 | scoreboard 783 | screen 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | shield 789 | shoe shop 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web 817 | spindle 818 | sports car 819 | spotlight 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | streetcar 831 | stretcher 832 | studio couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | swab 842 | sweatshirt 843 | swimming trunks 844 | swing 845 | switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy 852 | television 853 | tennis ball 854 | thatch 855 | theater curtain 856 | thimble 857 | thresher 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toyshop 867 | tractor 868 | trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright 883 | vacuum 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | wardrobe 896 | warplane 897 | washbasin 898 | washer 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool 913 | worm fence 914 | wreck 915 | yawl 916 | yurt 917 | web site 918 | comic book 919 | crossword puzzle 920 | street sign 921 | traffic light 922 | book jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | ice lolly 931 | French loaf 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hotdog 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard apple 958 | pomegranate 
959 | hay 960 | carbonara 961 | chocolate sauce 962 | dough 963 | meat loaf 964 | pizza 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeside 977 | promontory 978 | sandbar 979 | seashore 980 | valley 981 | volcano 982 | ballplayer 983 | groom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | hip 991 | buckeye 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn 996 | earthstar 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet tissue -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/model/__pycache__/resnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Quant/model/__pycache__/resnet.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/model/__pycache__/resnet.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-Eager-Mode-Quant/model/__pycache__/resnet.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-Eager-Mode-Quant/quant_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.resnet import resnet18 3 | from evaluate import evaluate 4 | 5 | model = resnet18(pretrained=True) 6 | #print(model) 7 | 8 | # Step 1: architecture changes 9 | # QuantStubs (we will do FloatFunctionals later) 10 | # Done 11 | 12 | # Step 2: fuse modules (recommended but not necessary) 13 | modules_to_list = model.modules_to_fuse() 14 | 15 | # It will keep Batchnorm 16 | model.eval() 17 | # fused_model = torch.ao.quantization.fuse_modules_qat(model, modules_to_list) 18 | 19 | # This will fuse BatchNorm weights into the preceding Conv 20 | fused_model = torch.ao.quantization.fuse_modules(model, modules_to_list) 21 | 22 | # Step 3: Assign qconfigs 23 | from torch.ao.quantization.fake_quantize import FakeQuantize 24 | activation_qconfig = FakeQuantize.with_args( 25 | observer=torch.ao.quantization.observer.HistogramObserver.with_args( 26 | quant_min=0, 27 | quant_max=255, 28 | dtype=torch.quint8, 29 | qscheme=torch.per_tensor_affine, 30 | ) 31 | ) 32 | 33 | weight_qconfig = FakeQuantize.with_args( 34 | observer=torch.ao.quantization.observer.PerChannelMinMaxObserver.with_args( 35 | quant_min=-128, 36 | quant_max=127, 37 | dtype=torch.qint8, 38 | qscheme=torch.per_channel_symmetric, 39 | ) 40 | ) 41 | 42 | qconfig = torch.quantization.QConfig(activation=activation_qconfig, 43 | weight=weight_qconfig) 44 | fused_model.qconfig = qconfig 45 | 46 | # Step 4: Prepare for fake-quant 47 | fused_model.train() 48 | fake_quant_model = torch.ao.quantization.prepare_qat(fused_model) 49 | 50 | 51 | print("\nFloat") 52 | evaluate(model, 'cpu') 53 | 54 | 55 | print("\nFused Model") 56 | evaluate(fused_model, 'cpu') 57 | 58 | 59 | print("\nFake quant - PTQ") 60 | evaluate(fake_quant_model, 'cpu') 61 | 62 | fake_quant_model.apply(torch.ao.quantization.fake_quantize.disable_observer) 63 | 64 | print("\nFake quant - post-PTQ") 65 | evaluate(fake_quant_model, 'cpu') 66 | 67 | 
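# Note: disable_observer (above) freezes the scales/zero-points that the observers
# gathered during the earlier evaluation passes, so the fake-quant modules now apply
# fixed quantization parameters instead of continuing to adapt.
# convert() (below) then swaps each fake-quantized module for its true int8
# counterpart, which runs on quantized CPU kernels.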
68 | # Step 5: convert (true int8 model) 69 | converted_model = torch.ao.quantization.convert(fake_quant_model) 70 | 71 | print("\nConverted model") 72 | evaluate(converted_model, 'cpu') 73 | 74 | import ipdb 75 | ipdb.set_trace() -------------------------------------------------------------------------------- /Resnet-FX-CLE/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | Histogram_plots -------------------------------------------------------------------------------- /Resnet-FX-CLE/Box_plots/With_CLE/layer2.0.conv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/Box_plots/With_CLE/layer2.0.conv1.png -------------------------------------------------------------------------------- /Resnet-FX-CLE/Box_plots/Without_CLE/layer2.0.conv1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/Box_plots/Without_CLE/layer2.0.conv1.png -------------------------------------------------------------------------------- /Resnet-FX-CLE/README.md: -------------------------------------------------------------------------------- 1 | # Cross Layer Equalization (CLE): PyTorch ResNet Coding tutorial 2 | 3 | This is the finished code associated with the YouTube tutorial at: 4 | 5 | [![Cross Layer Equalization: Everything You Need to Know](https://ytcards.demolab.com/?id=3eATdsWmHyI&title=Cross+Layer+Equalization%3A+Everything+You+Need+to+Know&lang=en&timestamp=1715768680&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "Cross Layer Equalization: Everything You Need to Know")](https://www.youtube.com/watch?v=3eATdsWmHyI) 6 | 7 | This code is built from the code for the QAT tutorial, located in `Resnet-FX-QAT`. 8 | We expand upon it to allow fusing of Conv and BN for float as well as quantized models. 9 | We add the capability to do CLE, including automating the production of the list of to-be-CLE'd 10 | layer pairs via a graph-tracing technique. 11 | 12 | ### Prerequisites: 13 | 14 | #### Installing PyTorch: 15 | 16 | To run this code, you need to have PyTorch installed in your environment. If you do not have PyTorch installed, please follow this [official guide](https://pytorch.org/get-started/locally/). 17 | 18 | I created this code with PyTorch Version: 2.1.1. In case you have any versioning issues, you can revert to that version. 19 | 20 | #### Printing the FX graph: 21 | 22 | To run `fx_model.graph.print_tabular()`, one needs to have `tabulate` installed. To do so, activate your (e.g. conda) environment and run 23 | 24 | ``` 25 | pip install tabulate 26 | ``` 27 | 28 | #### Evaluation images: 29 | 30 | For this tutorial, I downloaded some images from Google search, one example each for a handful of the ImageNet classes. 31 | You can add whatever ImageNet class examples you want, but make sure to name the images the same as the class names, e.g. `hen.jpg` for class name `hen`. 32 | Or, feel free to generalise the code so that this isn't a constraint! 33 | 34 | ### Running this code: 35 | 36 | Once you have PyTorch installed, first navigate to a directory you will be working from. 
As you follow the next steps, your final file structure will look like this: `your-directory/Resnet-FX-CLE`. 37 | 38 | Next, from `your-directory`, clone the `Quantization-Tutorials` repo. This repo contains different tutorials, but they are all interlinked. There's no need to do any of the others! I just structured it this way because the tutorials share a lot of code and it might help people to see different parts in one place. 39 | 40 | You can also `git init` and then `git pull/fetch`, depending on what you prefer. 41 | 42 | To clone the repo, run: 43 | 44 | ``` 45 | git clone git@github.com:OscarSavolainenDR/Quantization-Tutorials.git . 46 | ``` 47 | 48 | If you did the cloning in place with the `.` at the end, your folder structure should look like `your-folder/Resnet-FX-CLE`, with various other folders for the other tutorials. 49 | 50 | Next, cd into the Resnet FX CLE tutorial: 51 | 52 | ``` 53 | cd Resnet-FX-CLE 54 | ``` 55 | 56 | Then, just run `python main.py` from your command line! However, I would obviously recommend that you follow along with the tutorial, so that you learn how it all works and get your hands dirty. 57 | 58 | The code is also available as a Jupyter Notebook: `CLE_notebook.ipynb`, in case this is preferred. 59 | 60 | Let me know if there are any issues! 61 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/dog.jpg -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluate import evaluate 2 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | 5 | def evaluate(model, device_str: str, target: str): 6 | # Download an example image from the pytorch website 7 | import urllib 8 | 9 | filename = Path(f"evaluate/images/{target}.jpg") 10 | 11 | from PIL import Image 12 | from torchvision import transforms 13 | 14 | input_image = Image.open(filename) 15 | preprocess = transforms.Compose( 16 | [ 17 | transforms.Resize(256), 18 | transforms.CenterCrop(224), 19 | transforms.ToTensor(), 20 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 21 | ] 22 | ) 23 | input_tensor = preprocess(input_image) 24 | input_batch = input_tensor.unsqueeze( 25 | 0 26 | ) # create a mini-batch as expected by the model 27 | 28 | # move the input and model to GPU for speed if available, or to CPU if converted 29 | if not (device_str in ["cpu", "cuda"]): 30 | raise NotImplementedError("`device_str` should be 'cpu' or 'cuda' ") 31 | if device_str == "cuda": 32 | assert torch.cuda.is_available(), "Check CUDA is available" 33 | 34 | input_batch = input_batch.to(device_str) 35 | model.to(device_str) 36 | model.eval() 37 | 38 | with torch.no_grad(): 39 | output = model(input_batch) 40 | # Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes 41 | # print(output[0]) 42 | # The output has unnormalized scores. To get probabilities, you can run a softmax on it. 
43 | probabilities = torch.nn.functional.softmax(output[0], dim=0) 44 | # print(probabilities) 45 | 46 | # Read the categories 47 | with open(Path("evaluate/imagenet_classes.txt"), "r") as f: 48 | categories = [s.strip() for s in f.readlines()] 49 | # Show top categories per image 50 | top5_prob, top5_catid = torch.topk(probabilities, 5) 51 | for i in range(top5_prob.size(0)): 52 | print(categories[top5_catid[i]], top5_prob[i].item()) 53 | print("\n") 54 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/imagenet_classes.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great white shark 4 | tiger shark 5 | hammerhead 6 | electric ray 7 | stingray 8 | cock 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house finch 14 | junco 15 | indigo bunting 16 | robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel 22 | kite 23 | bald eagle 24 | vulture 25 | great grey owl 26 | European fire salamander 27 | common newt 28 | eft 29 | spotted salamander 30 | axolotl 31 | bullfrog 32 | tree frog 33 | tailed frog 34 | loggerhead 35 | leatherback turtle 36 | mud turtle 37 | terrapin 38 | box turtle 39 | banded gecko 40 | common iguana 41 | American chameleon 42 | whiptail 43 | agama 44 | frilled lizard 45 | alligator lizard 46 | Gila monster 47 | green lizard 48 | African chameleon 49 | Komodo dragon 50 | African crocodile 51 | American alligator 52 | triceratops 53 | thunder snake 54 | ringneck snake 55 | hognose snake 56 | green snake 57 | king snake 58 | garter snake 59 | water snake 60 | vine snake 61 | night snake 62 | boa constrictor 63 | rock python 64 | Indian cobra 65 | green mamba 66 | sea snake 67 | horned viper 68 | diamondback 69 | sidewinder 70 | trilobite 71 | harvestman 72 | scorpion 73 | black and gold garden spider 74 | barn spider 75 | garden spider 76 | black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie chicken 85 | peacock 86 | quail 87 | partridge 88 | African grey 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | American egret 134 | bittern 135 | crane 136 | limpkin 137 | European gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | red-backed sandpiper 142 | redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | dugong 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 | Maltese dog 155 | Pekinese 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound 162 | basset 163 | beagle 164 | bloodhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound 168 | English foxhound 169 | redbone 170 | borzoi 171 | Irish wolfhound 172 | Italian 
greyhound 173 | whippet 174 | Ibizan hound 175 | Norwegian elkhound 176 | otterhound 177 | Saluki 178 | Scottish deerhound 179 | Weimaraner 180 | Staffordshire bullterrier 181 | American Staffordshire terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier 192 | Airedale 193 | cairn 194 | Australian terrier 195 | Dandie Dinmont 196 | Boston bull 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier 201 | Tibetan terrier 202 | silky terrier 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla 213 | English setter 214 | Irish setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber 218 | English springer 219 | Welsh springer spaniel 220 | cocker spaniel 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog 231 | Shetland sheepdog 232 | collie 233 | Border collie 234 | Bouvier des Flandres 235 | Rottweiler 236 | German shepherd 237 | Doberman 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard 249 | Eskimo dog 250 | malamute 251 | Siberian husky 252 | dalmatian 253 | affenpinscher 254 | basenji 255 | pug 256 | Leonberg 257 | Newfoundland 258 | Great Pyrenees 259 | Samoyed 260 | Pomeranian 261 | chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke 265 | Cardigan 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf 271 | white wolf 272 | red wolf 273 | coyote 274 | dingo 275 | dhole 276 | African hunting dog 277 | hyena 278 | red fox 279 | kit fox 280 | Arctic fox 281 | grey fox 282 | tabby 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | Egyptian cat 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | ice bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | long-horned beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket 314 | walking stick 315 | cockroach 316 | mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | admiral 323 | ringlet 324 | monarch 325 | cabbage butterfly 326 | sulphur butterfly 327 | lycaenid 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | wood rabbit 332 | hare 333 | Angora 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | sorrel 341 | zebra 342 | hog 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram 350 | bighorn 351 | ibex 352 | hartebeest 353 | impala 354 | gazelle 355 | Arabian camel 356 | llama 357 | weasel 358 | mink 359 | polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | 
gibbon 370 | siamang 371 | guenon 372 | patas 373 | baboon 374 | macaque 375 | langur 376 | colobus 377 | proboscis monkey 378 | marmoset 379 | capuchin 380 | howler monkey 381 | titi 382 | spider monkey 383 | squirrel monkey 384 | Madagascar cat 385 | indri 386 | Indian elephant 387 | African elephant 388 | lesser panda 389 | giant panda 390 | barracouta 391 | eel 392 | coho 393 | rock beauty 394 | anemone fish 395 | sturgeon 396 | gar 397 | lionfish 398 | puffer 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibian 410 | analog clock 411 | apiary 412 | apron 413 | ashcan 414 | assault rifle 415 | backpack 416 | bakery 417 | balance beam 418 | balloon 419 | ballpoint 420 | Band Aid 421 | banjo 422 | bannister 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | barrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap 435 | bath towel 436 | bathtub 437 | beach wagon 438 | beacon 439 | beaker 440 | bearskin 441 | beer bottle 442 | beer glass 443 | bell cote 444 | bib 445 | bicycle-built-for-two 446 | bikini 447 | binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsled 452 | bolo tie 453 | bonnet 454 | bookcase 455 | bookshop 456 | bottlecap 457 | bow 458 | bow tie 459 | brass 460 | brassiere 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | bullet train 468 | butcher shop 469 | cab 470 | caldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | carpenter's kit 479 | carton 480 | car wheel 481 | cash machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | cellular telephone 489 | chain 490 | chainlink fence 491 | chain mail 492 | chain saw 493 | chest 494 | chiffonier 495 | chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | cinema 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clog 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil 508 | combination lock 509 | computer keyboard 510 | confectionery 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | dishrag 535 | dishwasher 536 | disk brake 537 | dock 538 | dogsled 539 | dome 540 | doormat 541 | drilling platform 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa 554 | file 555 | fireboat 556 | fire engine 557 | fire screen 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gasmask 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | grille 583 | grocery store 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower 591 | 
hand-held computer 592 | handkerchief 593 | hard disc 594 | harmonica 595 | harp 596 | harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoopskirt 603 | horizontal bar 604 | horse cart 605 | hourglass 606 | iPod 607 | iron 608 | jack-o'-lantern 609 | jean 610 | jeep 611 | jersey 612 | jigsaw puzzle 613 | jinrikisha 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop 622 | lawn mower 623 | lens cap 624 | letter opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | liner 630 | lipstick 631 | Loafer 632 | lotion 633 | loudspeaker 634 | loupe 635 | lumbermill 636 | magnetic compass 637 | mailbag 638 | mailbox 639 | maillot 640 | maillot 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring cup 649 | medicine chest 650 | megalith 651 | microphone 652 | microwave 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter 672 | mountain bike 673 | mountain tent 674 | mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | organ 689 | oscilloscope 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle 695 | paddlewheel 696 | padlock 697 | paintbrush 698 | pajama 699 | palace 700 | panpipe 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | passenger car 707 | patio 708 | pay-phone 709 | pedestal 710 | pencil box 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | pick 716 | pickelhaube 717 | picket fence 718 | pickup 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate 726 | pitcher 727 | plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow 732 | plunger 733 | Polaroid camera 734 | pole 735 | police van 736 | poncho 737 | pool table 738 | pop bottle 739 | pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | projectile 746 | projector 747 | puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | racer 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | rubber eraser 769 | rugby ball 770 | rule 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker 775 | sandal 776 | sarong 777 | sax 778 | scabbard 779 | scale 780 | school bus 781 | schooner 782 | scoreboard 783 | screen 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | shield 789 | shoe shop 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web 817 | spindle 818 | sports 
car 819 | spotlight 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | streetcar 831 | stretcher 832 | studio couch 833 | stupa 834 | submarine 835 | suit 836 | sundial 837 | sunglass 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | swab 842 | sweatshirt 843 | swimming trunks 844 | swing 845 | switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy 852 | television 853 | tennis ball 854 | thatch 855 | theater curtain 856 | thimble 857 | thresher 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toyshop 867 | tractor 868 | trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright 883 | vacuum 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | wardrobe 896 | warplane 897 | washbasin 898 | washer 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool 913 | worm fence 914 | wreck 915 | yawl 916 | yurt 917 | web site 918 | comic book 919 | crossword puzzle 920 | street sign 921 | traffic light 922 | book jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | ice lolly 931 | French loaf 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hotdog 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce 962 | dough 963 | meat loaf 964 | pizza 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeside 977 | promontory 978 | sandbar 979 | seashore 980 | valley 981 | volcano 982 | ballplayer 983 | groom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | hip 991 | buckeye 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn 996 | earthstar 997 | hen-of-the-woods 998 | bolete 999 | ear 1000 | toilet tissue -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/images/README.md: -------------------------------------------------------------------------------- 1 | ## The ground truth images 2 | 3 | **IMPORTANT**: the image names should exactly match the ImageNet class names, e.g. `hen`, `Samoyed`, etc. 
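If you want to double-check your filenames before running `main.py`, a small sketch like the following works. This snippet is an editor's illustration rather than part of the repo; it assumes you run it from the `Resnet-FX-CLE` folder, and it reuses the same paths as `evaluate/evaluate.py`:

```
from pathlib import Path

# Load the known ImageNet class names
with open("evaluate/imagenet_classes.txt") as f:
    classes = {line.strip() for line in f}

# Flag any image whose stem (filename without .jpg) is not an exact class name
for img in sorted(Path("evaluate/images").glob("*.jpg")):
    status = "ok" if img.stem in classes else "does NOT match an ImageNet class name"
    print(f"{img.name}: {status}")
```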
-------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/images/Samoyed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/evaluate/images/Samoyed.jpg -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/images/clog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/evaluate/images/clog.jpg -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/images/hen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/evaluate/images/hen.jpg -------------------------------------------------------------------------------- /Resnet-FX-CLE/evaluate/images/mail_box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-CLE/evaluate/images/mail_box.jpg -------------------------------------------------------------------------------- /Resnet-FX-CLE/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from quant_vis.histograms import ( 5 | add_sensitivity_analysis_hooks, 6 | plot_quant_act_SA_hist, 7 | plot_quant_weight_hist, 8 | ) 9 | import torch 10 | from torch import fx 11 | from torch.ao.quantization._equalize import equalize 12 | from torch.ao.quantization._learnable_fake_quantize import ( 13 | _LearnableFakeQuantize as LearnableFakeQuantize, 14 | ) 15 | from torch.ao.quantization.qconfig_mapping import QConfigMapping 16 | from torch.ao.quantization.quantize_fx import prepare_qat_fx 17 | import torch.quantization as tq 18 | from utils.ipdb_hook import ipdb_sys_excepthook 19 | import matplotlib.pyplot as plt 20 | 21 | from evaluate import evaluate 22 | from model.resnet import resnet18 23 | from quant_vis.utils.prop_data import forward_and_backprop_an_image 24 | from utils.graph_manip import ( 25 | float_convbn_to_conv, 26 | qat_convbn_to_conv, 27 | get_previous_module_node, 28 | ) 29 | from utils.qconfigs import ( 30 | fake_quant_act, 31 | fake_quant_weight, 32 | learnable_act, 33 | learnable_weights, 34 | ) 35 | 36 | # Adds ipdb breakpoint if and where we have an error 37 | ipdb_sys_excepthook() 38 | 39 | # Initialize model 40 | model = resnet18(pretrained=True) 41 | 42 | 43 | ############################ 44 | # CROSS LAYER EQUALIZATION # 45 | ############################ 46 | 47 | # Graph-trace the model 48 | model.train() 49 | float_traced_model = fx.symbolic_trace(model) 50 | 51 | # Merge all batchnorms into preceding convs 52 | float_traced_model.eval() 53 | float_traced_model = float_convbn_to_conv(float_traced_model) 54 | 55 | # Iterate through graph, find CLE layer pairs. 56 | # The graph is already in order of execution, and there isn't 57 | # any complicated branching, so we can just treat all layers as 58 | # sequential. 
59 | pairs = [] 60 | for node in float_traced_model.graph.nodes: 61 | if node.op == "call_module": 62 | module = float_traced_model.get_submodule(node.target) 63 | if hasattr(module, "weight"): 64 | prev_node = get_previous_module_node( 65 | float_traced_model, 66 | node, 67 | (torch.nn.Conv2d, torch.nn.Linear), 68 | CLE_compatible=True, 69 | ) 70 | if prev_node: 71 | pairs.append([prev_node.target, node.target]) 72 | 73 | # Perform CLE, using `equalize` from torch.ao.quantization._equalize (imported above) 74 | cle_model = equalize(float_traced_model, pairs, threshold=1e-4, inplace=False) 75 | 76 | ###################### 77 | # QUANTIZE THE MODEL # 78 | ###################### 79 | # Define qconfigs 80 | qconfig_global = tq.QConfig(activation=fake_quant_act, weight=fake_quant_weight) 81 | 82 | # Assign qconfigs 83 | qconfig_mapping = QConfigMapping() 84 | 85 | # We loop through the modules so that we can access the `out_channels` attribute 86 | for name, module in cle_model.named_modules(): 87 | if hasattr(module, "out_channels"): 88 | qconfig = tq.QConfig( 89 | activation=learnable_act(range=2), 90 | weight=learnable_weights(channels=module.out_channels), 91 | ) 92 | qconfig_mapping.set_module_name(name, qconfig) 93 | 94 | 95 | # Do symbolic tracing and quantization 96 | example_inputs = (torch.randn(1, 3, 224, 224),) 97 | cle_model.eval() 98 | fx_model_w_cle = prepare_qat_fx(cle_model, qconfig_mapping, example_inputs) 99 | 100 | # For comparison, we also get an FX model without CLE. We do so by 101 | # performing FX quantization and fusing the BNs into the Convs. 102 | qconfig_mapping = QConfigMapping() # .set_global(qconfig_global) 103 | for name, module in model.named_modules(): 104 | if hasattr(module, "out_channels"): 105 | qconfig = tq.QConfig( 106 | activation=learnable_act(range=2), 107 | weight=learnable_weights(channels=module.out_channels), 108 | ) 109 | qconfig_mapping.set_module_name(name, qconfig) 110 | fx_model_no_cle = prepare_qat_fx(model, qconfig_mapping, example_inputs) 111 | fx_model_no_cle.eval() 112 | fx_model_no_cle = qat_convbn_to_conv(fx_model_no_cle) 113 | 114 | # Evaluate model 115 | print("\nOriginal") 116 | evaluate(model, "cpu", "Samoyed") 117 | 118 | print("\nTraced model") 119 | evaluate(float_traced_model, "cpu", "Samoyed") 120 | 121 | print("\nCLE model") 122 | evaluate(cle_model, "cpu", "Samoyed") 123 | 124 | print("\nFX prepared, with CLE") 125 | evaluate(fx_model_w_cle, "cpu", "Samoyed") 126 | 127 | # Check performance on hen 128 | print("CLE model evaluation (hen):") 129 | evaluate(fx_model_w_cle, "cpu", "hen") 130 | 131 | # Check performance on clog 132 | print("CLE model evaluation (clog):") 133 | evaluate(fx_model_w_cle, "cpu", "clog") 134 | 135 | # Check performance on mail box 136 | print("CLE model evaluation (mail box):") 137 | evaluate(fx_model_w_cle, "cpu", "mail_box") 138 | 139 | print("\nFX prepared, without CLE") 140 | evaluate(fx_model_no_cle, "cpu", "Samoyed") 141 | 142 | # Check performance on hen 143 | print("FX prepared, without CLE (hen):") 144 | evaluate(fx_model_no_cle, "cpu", "hen") 145 | 146 | # Check performance on clog 147 | print("FX prepared, without CLE (clog):") 148 | evaluate(fx_model_no_cle, "cpu", "clog") 149 | 150 | # Check performance on mail box 151 | print("FX prepared, without CLE (mail box):") 152 | evaluate(fx_model_no_cle, "cpu", "mail_box") 153 | # # Prints the graph as a table 154 | # print("\nGraph as a Table:\n") 155 
| # fx_model.graph.print_tabular() 156 | 157 | ######################## 158 | # VISUALIZE CLE EFFECT # 159 | ######################## 160 | 161 | 162 | def per_channel_boxplots(weight_tensor: torch.Tensor, title: str, CLE: bool): 163 | """ 164 | Given a weight tensor, plots its per-output-channel boxplots so we can 165 | observe its dynamic range. 166 | """ 167 | # Ensure the weight tensor is on the CPU 168 | weight_tensor = weight_tensor.cpu() 169 | 170 | # Get the number of output channels 171 | num_output_channels = weight_tensor.shape[0] 172 | 173 | # Create a figure and axis 174 | fig, ax = plt.subplots(figsize=(10, 6)) 175 | 176 | # Create a list of data for each output channel 177 | data = [weight_tensor[i].flatten().tolist() for i in range(num_output_channels)] 178 | 179 | # Create the boxplots 180 | ax.boxplot(data) 181 | 182 | # Set the title and labels 183 | CLE_str = "With_CLE" if CLE else "Without_CLE" 184 | ax.set_title(f"{CLE_str}, {title}, Per-Output-Channel Weight Boxplots") 185 | ax.set_xlabel("Output Channel") 186 | ax.set_ylabel("Weight Value") 187 | 188 | # Set the x-tick labels 189 | ax.set_xticks(range(1, num_output_channels + 1)) 190 | ax.set_xticklabels(range(num_output_channels)) 191 | 192 | # Adjust the spacing and display the plot 193 | plt.tight_layout() 194 | plt.show() 195 | 196 | # Save file 197 | folder_path = Path(os.path.abspath("") + f"/Box_plots/{CLE_str}") 198 | file_path = os.path.join(folder_path, f"{title}.png") 199 | title = title.replace(".", "-") 200 | if not os.path.exists(folder_path): 201 | os.makedirs(folder_path, exist_ok=True) 202 | fig.savefig(file_path, dpi=450) 203 | 204 | 205 | # ACTIVATION PLOTS 206 | def create_act_plots(model, title): 207 | def conditions_met_forward_act_hook(module: torch.nn.Module, name: str) -> bool: 208 | if isinstance(module, LearnableFakeQuantize): 209 | # if '1' in name: 210 | print(f"Adding hook to {name}") 211 | return True 212 | return False 213 | 214 | # We add the hooks 215 | act_forward_histograms, act_backward_histograms = add_sensitivity_analysis_hooks( 216 | model, conditions_met=conditions_met_forward_act_hook, bit_res=8 217 | ) 218 | 219 | forward_and_backprop_an_image(model) 220 | 221 | # Generate the forward and Sensitivity Analysis plots 222 | plot_quant_act_SA_hist( 223 | act_forward_histograms, 224 | act_backward_histograms, 225 | file_path=Path(os.path.abspath("") + f"/Histogram_plots/{title}"), 226 | sum_pos_1=[0.18, 0.60, 0.1, 0.1], # location of the first mean intra-bin plot 227 | sum_pos_2=[0.75, 0.43, 0.1, 0.1], 228 | plot_title="SA act hists", 229 | module_name_mapping=None, 230 | bit_res=8, # This should match the quantization resolution. Changing this will not change the model quantization, only the plots. 
231 | ) 232 | 233 | 234 | # WEIGHT PLOTS 235 | # Clear gradients with the sake of an otherwised unused optimizer 236 | def create_weight_plots(model, title): 237 | from torch.optim import Adam 238 | 239 | optimizer = Adam(model.parameters(), lr=1) 240 | optimizer.zero_grad() 241 | 242 | # Check gradients cleared 243 | for parameter in model.parameters(): 244 | assert parameter.grad is None 245 | 246 | # Produce new gradients 247 | forward_and_backprop_an_image(model) 248 | 249 | # Check gradients exist 250 | for parameter in model.parameters(): 251 | assert parameter.grad is not None 252 | 253 | # Create the weight histogram plots, this time with Sensitivity Analysis 254 | # plots 255 | plot_quant_weight_hist( 256 | model, 257 | file_path=Path(os.path.abspath("") + f"/Histogram_plots/{title}"), 258 | plot_title="SA weight hists", 259 | module_name_mapping=None, 260 | conditions_met=None, 261 | # The below flag specifies that we should also do a Sensitivity 262 | # Analysis 263 | sensitivity_analysis=True, 264 | ) 265 | 266 | 267 | # # Cross Layer Equalized plots 268 | # create_act_plots(fx_model_w_cle, "FX, with CLE") 269 | # create_weight_plots(fx_model_w_cle, "FX, with CLE") 270 | # 271 | # # Original, non-Cross Layer Equalized plots 272 | # create_act_plots(fx_model_no_cle, "FX, no CLE") 273 | # create_weight_plots(fx_model_no_cle, "FX, no CLE") 274 | weight_tensor = fx_model_w_cle.layer2.get_submodule("0").conv1.weight 275 | per_channel_boxplots(weight_tensor, title="layer2.0.conv1", CLE=True) 276 | 277 | weight_tensor = fx_model_no_cle.layer2.get_submodule("0").conv1.weight 278 | per_channel_boxplots(weight_tensor, title="layer2.0.conv1", CLE=False) 279 | 280 | 281 | ############### 282 | # EVALUATIONS # 283 | ############### 284 | # Print out some paramaters before we do CLE 285 | def print_scale_and_zp(model: torch.nn.Module, module_name: str): 286 | module = model.get_submodule(module_name) 287 | scale = module.scale 288 | zero_point = module.zero_point 289 | if len(scale) == 1: 290 | print( 291 | f"{module_name} scale and zero_point: {scale.item():.5}, {zero_point.item()}" 292 | ) 293 | else: 294 | print(f"{module_name} scale and zero_point: {scale}, {zero_point}") 295 | 296 | 297 | print("\nWithout CLE:") 298 | print_scale_and_zp(fx_model_no_cle, "layer2.0.conv1.weight_fake_quant") 299 | 300 | print("\nAfter CLE:") 301 | print_scale_and_zp(fx_model_w_cle, "layer2.0.conv1.weight_fake_quant") 302 | 303 | XXX 304 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/model/resnet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from torchvision.transforms._presets import ImageClassification 9 | from torchvision.utils import _log_api_usage_once 10 | from torchvision.models._api import register_model, Weights, WeightsEnum 11 | from torchvision.models._meta import _IMAGENET_CATEGORIES 12 | from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface 13 | 14 | 15 | __all__ = [ 16 | "ResNet", 17 | "ResNet18_Weights", 18 | "resnet18", 19 | ] 20 | 21 | 22 | def conv3x3( 23 | in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 24 | ) -> nn.Conv2d: 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d( 27 | in_planes, 28 | out_planes, 29 | kernel_size=3, 30 | 
stride=stride, 31 | padding=dilation, 32 | groups=groups, 33 | bias=False, 34 | dilation=dilation, 35 | ) 36 | 37 | 38 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion: int = 1 45 | 46 | def __init__( 47 | self, 48 | inplanes: int, 49 | planes: int, 50 | stride: int = 1, 51 | downsample: Optional[nn.Module] = None, 52 | groups: int = 1, 53 | base_width: int = 64, 54 | dilation: int = 1, 55 | norm_layer: Optional[Callable[..., nn.Module]] = None, 56 | ) -> None: 57 | super().__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 62 | if dilation > 1: 63 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 64 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 65 | self.conv1 = conv3x3(inplanes, planes, stride) 66 | self.bn1 = norm_layer(planes) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.conv2 = conv3x3(planes, planes) 69 | self.bn2 = norm_layer(planes) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x: Tensor) -> Tensor: 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class Bottleneck(nn.Module): 93 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 94 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 95 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 96 | # This variant is also known as ResNet V1.5 and improves accuracy according to 97 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
98 | 99 | expansion: int = 4 100 | 101 | def __init__( 102 | self, 103 | inplanes: int, 104 | planes: int, 105 | stride: int = 1, 106 | downsample: Optional[nn.Module] = None, 107 | groups: int = 1, 108 | base_width: int = 64, 109 | dilation: int = 1, 110 | norm_layer: Optional[Callable[..., nn.Module]] = None, 111 | ) -> None: 112 | super().__init__() 113 | if norm_layer is None: 114 | norm_layer = nn.BatchNorm2d 115 | width = int(planes * (base_width / 64.0)) * groups 116 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 117 | self.conv1 = conv1x1(inplanes, width) 118 | self.bn1 = norm_layer(width) 119 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 120 | self.bn2 = norm_layer(width) 121 | self.conv3 = conv1x1(width, planes * self.expansion) 122 | self.bn3 = norm_layer(planes * self.expansion) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.downsample = downsample 125 | self.stride = stride 126 | 127 | def forward(self, x: Tensor) -> Tensor: 128 | identity = x 129 | 130 | out = self.conv1(x) 131 | out = self.bn1(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv2(out) 135 | out = self.bn2(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv3(out) 139 | out = self.bn3(out) 140 | 141 | if self.downsample is not None: 142 | identity = self.downsample(x) 143 | 144 | out += identity 145 | out = self.relu(out) 146 | 147 | return out 148 | 149 | 150 | class ResNet(nn.Module): 151 | def __init__( 152 | self, 153 | block: Type[Union[BasicBlock, Bottleneck]], 154 | layers: List[int], 155 | num_classes: int = 1000, 156 | zero_init_residual: bool = False, 157 | groups: int = 1, 158 | width_per_group: int = 64, 159 | replace_stride_with_dilation: Optional[List[bool]] = None, 160 | norm_layer: Optional[Callable[..., nn.Module]] = None, 161 | ) -> None: 162 | super().__init__() 163 | _log_api_usage_once(self) 164 | if norm_layer is None: 165 | norm_layer = nn.BatchNorm2d 166 | self._norm_layer = norm_layer 167 | 168 | self.inplanes = 64 169 | self.dilation = 1 170 | if replace_stride_with_dilation is None: 171 | # each element in the tuple indicates if we should replace 172 | # the 2x2 stride with a dilated convolution instead 173 | replace_stride_with_dilation = [False, False, False] 174 | if len(replace_stride_with_dilation) != 3: 175 | raise ValueError( 176 | "replace_stride_with_dilation should be None " 177 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 178 | ) 179 | self.groups = groups 180 | self.base_width = width_per_group 181 | self.conv1 = nn.Conv2d( 182 | 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 183 | ) 184 | self.bn1 = norm_layer(self.inplanes) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 187 | self.layer1 = self._make_layer(block, 64, layers[0]) 188 | self.layer2 = self._make_layer( 189 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 190 | ) 191 | self.layer3 = self._make_layer( 192 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 193 | ) 194 | self.layer4 = self._make_layer( 195 | block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 196 | ) 197 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 198 | self.fc = nn.Linear(512 * block.expansion, num_classes) 199 | 200 | for m in self.modules(): 201 | if isinstance(m, nn.Conv2d): 202 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 203 | elif isinstance(m, (nn.BatchNorm2d, 
nn.GroupNorm)): 204 | nn.init.constant_(m.weight, 1) 205 | nn.init.constant_(m.bias, 0) 206 | 207 | # Zero-initialize the last BN in each residual branch, 208 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 209 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 210 | if zero_init_residual: 211 | for m in self.modules(): 212 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 213 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 214 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 215 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 216 | 217 | def _make_layer( 218 | self, 219 | block: Type[Union[BasicBlock, Bottleneck]], 220 | planes: int, 221 | blocks: int, 222 | stride: int = 1, 223 | dilate: bool = False, 224 | ) -> nn.Sequential: 225 | norm_layer = self._norm_layer 226 | downsample = None 227 | previous_dilation = self.dilation 228 | if dilate: 229 | self.dilation *= stride 230 | stride = 1 231 | if stride != 1 or self.inplanes != planes * block.expansion: 232 | downsample = nn.Sequential( 233 | conv1x1(self.inplanes, planes * block.expansion, stride), 234 | norm_layer(planes * block.expansion), 235 | ) 236 | 237 | layers = [] 238 | layers.append( 239 | block( 240 | self.inplanes, 241 | planes, 242 | stride, 243 | downsample, 244 | self.groups, 245 | self.base_width, 246 | previous_dilation, 247 | norm_layer, 248 | ) 249 | ) 250 | self.inplanes = planes * block.expansion 251 | for _ in range(1, blocks): 252 | layers.append( 253 | block( 254 | self.inplanes, 255 | planes, 256 | groups=self.groups, 257 | base_width=self.base_width, 258 | dilation=self.dilation, 259 | norm_layer=norm_layer, 260 | ) 261 | ) 262 | 263 | return nn.Sequential(*layers) 264 | 265 | def _forward_impl(self, x: Tensor) -> Tensor: 266 | # See note [TorchScript super()] 267 | x = self.conv1(x) 268 | x = self.bn1(x) 269 | x = self.relu(x) 270 | x = self.maxpool(x) 271 | 272 | x = self.layer1(x) 273 | x = self.layer2(x) 274 | x = self.layer3(x) 275 | x = self.layer4(x) 276 | 277 | x = self.avgpool(x) 278 | x = torch.flatten(x, 1) 279 | x = self.fc(x) 280 | 281 | return x 282 | 283 | def forward(self, x: Tensor) -> Tensor: 284 | return self._forward_impl(x) 285 | 286 | 287 | def _resnet( 288 | block: Type[Union[BasicBlock, Bottleneck]], 289 | layers: List[int], 290 | weights: Optional[WeightsEnum], 291 | progress: bool, 292 | **kwargs: Any, 293 | ) -> ResNet: 294 | if weights is not None: 295 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 296 | 297 | model = ResNet(block, layers, **kwargs) 298 | 299 | if weights is not None: 300 | model.load_state_dict( 301 | weights.get_state_dict(progress=progress, check_hash=True) 302 | ) 303 | 304 | return model 305 | 306 | 307 | _COMMON_META = { 308 | "min_size": (1, 1), 309 | "categories": _IMAGENET_CATEGORIES, 310 | } 311 | 312 | 313 | class ResNet18_Weights(WeightsEnum): 314 | IMAGENET1K_V1 = Weights( 315 | url="https://download.pytorch.org/models/resnet18-f37072fd.pth", 316 | transforms=partial(ImageClassification, crop_size=224), 317 | meta={ 318 | **_COMMON_META, 319 | "num_params": 11689512, 320 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", 321 | "_metrics": { 322 | "ImageNet-1K": { 323 | "acc@1": 69.758, 324 | "acc@5": 89.078, 325 | } 326 | }, 327 | "_ops": 1.814, 328 | "_file_size": 44.661, 329 | "_docs": """These weights reproduce closely the results of the 
paper using a simple training recipe.""", 330 | }, 331 | ) 332 | DEFAULT = IMAGENET1K_V1 333 | 334 | 335 | @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) 336 | def resnet18( 337 | *, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any 338 | ) -> ResNet: 339 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 340 | 341 | Args: 342 | weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The 343 | pretrained weights to use. See 344 | :class:`~torchvision.models.ResNet18_Weights` below for 345 | more details, and possible values. By default, no pre-trained 346 | weights are used. 347 | progress (bool, optional): If True, displays a progress bar of the 348 | download to stderr. Default is True. 349 | **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` 350 | base class. Please refer to the `source code 351 | `_ 352 | for more details about this class. 353 | 354 | .. autoclass:: torchvision.models.ResNet18_Weights 355 | :members: 356 | """ 357 | weights = ResNet18_Weights.verify(weights) 358 | 359 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) 360 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/quant_vis/histograms/__init__.py: -------------------------------------------------------------------------------- 1 | from .plots import plot_quant_act_SA_hist, plot_quant_weight_hist, plot_quant_act_hist 2 | from .hooks import ( 3 | activation_forward_histogram_hook, 4 | add_activation_forward_hooks, 5 | add_sensitivity_analysis_hooks, 6 | add_sensitivity_backward_hooks, 7 | backwards_SA_histogram_hook, 8 | ) 9 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/quant_vis/histograms/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward_hooks import ( 2 | activation_forward_histogram_hook, 3 | add_activation_forward_hooks, 4 | ) 5 | from .sa_back_hooks import ( 6 | add_sensitivity_analysis_hooks, 7 | add_sensitivity_backward_hooks, 8 | backwards_SA_histogram_hook, 9 | ) 10 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/quant_vis/histograms/hooks/forward_hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.quantization._numeric_suite as ns 3 | from ...utils.act_histogram import ActHistogram 4 | from ...utils.hooks import is_model_quantizable 5 | from utils.dotdict import dotdict 6 | 7 | from ...settings import HIST_XMIN, HIST_XMAX, HIST_QUANT_BIN_RATIO 8 | 9 | from typing import Callable, Union 10 | from utils.logger import setup_logger 11 | 12 | # Configure the logger 13 | logger = setup_logger(__name__) 14 | 15 | 16 | def activation_forward_histogram_hook( 17 | act_histogram: ActHistogram, name: str, qscheme: torch.qscheme, bit_res: int = 8 18 | ): 19 | """ 20 | A pre-forward hook that measures the floating-point activation being fed into a quantization module. 21 | This hook calculates a histogram, with the bins given by the quantization module's qparams, 22 | and stores the histogram in a global class. 23 | If the histogram for the given quantization module has not yet been initialised, this hook initialises 24 | it as an entry in a dict. If it has been initialised, this hook adds to it. 
25 | 26 | Therefore, as more and more data is fed throuhg the quantization module and this hook, 27 | the histogram will accumulate the frequencies of all of the binned values. 28 | 29 | activation_histogram_hook inputs: 30 | - act_histogram (ActHistogram): a dataclass instance that stores the activation histograms and hook handles. 31 | - name (str): the name of the module, and how its histogram will be stored in the dict. 32 | - qscheme (torch.qscheme): the qscheme of the quantization module. 33 | - bit_res (int): the quantization bit width of the tensor, e.g. 8 for int8. 34 | 35 | hook inputs: 36 | - module: the quantization module. 37 | - input: the activation fed to the quantization module. 38 | """ 39 | 40 | def hook(module, input): 41 | # Ensure we are in eval mode, and ensure that this is not during a Shadow conversion check. 42 | if not module.training and type(module) is not ns.Shadow: 43 | 44 | # Get number of quantization bins from the quantization bit width 45 | qrange = 2**bit_res 46 | 47 | local_input = input[0].detach().cpu() 48 | 49 | # If the entry in the `act_histogram` dict has not been initialised, i.e. this is the first forward pass 50 | # for this module 51 | if name not in act_histogram.data: 52 | # We calculate the limits of the histogram. These are dependent on the qparams, as well as how 53 | # much "buffer" we want on either side of the quantization range, defined by `HIST_XMIN` and 54 | # `HIST_XMAX` and the qparams. 55 | hist_min_bin = (-HIST_XMIN * qrange - module.zero_point) * module.scale 56 | hist_max_bin = ( 57 | (HIST_XMAX + 1) * qrange - module.zero_point 58 | ) * module.scale 59 | 60 | # If symmetric quantization, we offset the range by half. 61 | if qscheme in ( 62 | torch.per_channel_symmetric, 63 | torch.per_tensor_symmetric, 64 | ): 65 | hist_min_bin -= qrange / 2 * module.scale 66 | hist_max_bin -= qrange / 2 * module.scale 67 | 68 | # Create the histogram bins, with `HIST_QUANT_BIN_RATIO` histogram bins per quantization bin. 69 | hist_bins = ( 70 | torch.arange( 71 | hist_min_bin.item(), 72 | hist_max_bin.item(), 73 | (module.scale / HIST_QUANT_BIN_RATIO).item(), 74 | ) 75 | - (0.5 * module.scale / HIST_QUANT_BIN_RATIO).item() 76 | # NOTE: offset by half a quant bin fraction, so that quantization centroids 77 | # fall into the middle of a histogram bin. 78 | ) 79 | # TODO: figure out a way to do this histogram on CUDA 80 | tensor_histogram = torch.histogram(local_input, bins=hist_bins) 81 | 82 | # Create a map between the histogram and values by using torch.bucketize() 83 | # The idea is to be able to map the gradients to the same histogram bins 84 | bin_indices = torch.bucketize(local_input, tensor_histogram.bin_edges) 85 | 86 | # Initialise stored histogram for this quant module 87 | stored_histogram = dotdict() 88 | stored_histogram.hist = tensor_histogram.hist 89 | stored_histogram.bin_edges = tensor_histogram.bin_edges 90 | stored_histogram.bin_indices = bin_indices 91 | 92 | # Store final dict in `act_histogram` 93 | act_histogram.data[name] = stored_histogram 94 | 95 | # This histogram entry for this quant module has already been intialised. 96 | else: 97 | # We use the stored histogram bins to bin the incoming activation, and add its 98 | # frequencies to the histogram. 
99 | histogram = torch.histogram( 100 | local_input, 101 | bins=act_histogram.data[name].bin_edges.cpu(), 102 | ) 103 | act_histogram.data[name].hist += histogram.hist 104 | 105 | # We overwrite the bin indices with the most recent bin indices 106 | bin_indices = torch.bucketize(local_input, histogram.bin_edges) 107 | act_histogram.data[name].bin_indices = bin_indices 108 | 109 | return hook 110 | 111 | 112 | def add_activation_forward_hooks( 113 | model: torch.nn.Module, 114 | conditions_met: Union[Callable, None] = None, 115 | bit_res: int = 8, 116 | ): 117 | """ 118 | This function adds forward activation hooks to the quantization modules in the model, if their names 119 | match any of the patterns in `act_histogram.accepted_module_name_patterns`. 120 | These hooks measure and store an aggregated histogram, with the bins defined by the quantization 121 | grid. This tells us how the activation data is distributed on the quantization grid. 122 | 123 | Inputs: 124 | - model (torch.nn.Module): the model we will be adding hooks to. 125 | - conditions_met (Callable): a function that returns True if the conditons are met for 126 | adding a hook to a module, and false otherwise. Defaults to None. 127 | - bit_res (int): the quantization bit width of the tensor, e.g. 8 for int8. 128 | 129 | Returns: 130 | - act_histograms (ActHistogram): A dataclass instance that contains the stored histograms 131 | and hook handles. 132 | """ 133 | 134 | # If the conditons are met for adding hooks 135 | if not is_model_quantizable(model, "activation"): 136 | logger.warning(f"None of the model activations are quantizable") 137 | return 138 | 139 | logger.warning( 140 | f"\nAdding forward activation histogram hooks. This will significantly slow down the forward calls for " 141 | "the targetted modules." 142 | ) 143 | 144 | # We intialise a new ActHistogram instance, which will be responsible for containing the 145 | # activation histogram data 146 | act_histograms = ActHistogram(data={}, hook_handles={}) 147 | 148 | # Add activation-hist pre-forward hooks to the desired quantizable module 149 | for name, module in model.named_modules(): 150 | if hasattr(module, "fake_quant_enabled") and "weight_fake_quant" not in name: 151 | if conditions_met and not conditions_met(module, name): 152 | logger.debug( 153 | f"The conditons for adding an activation hook to module {name} were not met." 154 | ) 155 | continue 156 | 157 | hook_handle = module.register_forward_pre_hook( 158 | activation_forward_histogram_hook( 159 | act_histograms, 160 | name, 161 | module.qscheme, 162 | bit_res=bit_res, 163 | ) 164 | ) 165 | # We store the hook handles so that we can remove the hooks once we have finished 166 | # accumulating the histograms. 
167 | act_histograms.hook_handles[name] = hook_handle 168 | 169 | return act_histograms 170 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/quant_vis/histograms/hooks/sa_back_hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...utils.act_histogram import ActHistogram 4 | from .forward_hooks import add_activation_forward_hooks 5 | from utils.dotdict import dotdict 6 | 7 | from typing import Union, Callable 8 | from utils.logger import setup_logger 9 | 10 | # Configure the logger 11 | logger = setup_logger(__name__) 12 | 13 | 14 | def add_sensitivity_analysis_hooks( 15 | model: torch.nn.Module, 16 | conditions_met: Union[Callable, None] = None, 17 | bit_res: int = 8, 18 | ): 19 | """ 20 | Adds the required forward and baxckwards hooks to gather the required data for the 21 | forward and backwards histograms, for the combined forward / sensitivity analysis 22 | plots. 23 | NOTE: the `bit_res` parameter does not control the quantization reoslution of the model, only of the 24 | histograms. Ideally they should match. 25 | 26 | `conditions_met` (Callable): This is a function that takes in a module and its name, and returns a boolean 27 | indicating whether one should add the hook to it or not. 28 | Example: 29 | ``` 30 | def conditions_met_forward_act_hook(module: torch.nn.Module, name: str) -> bool: 31 | if "hello" in name: 32 | return True 33 | else: 34 | return False 35 | ``` 36 | """ 37 | act_forward_histograms = add_activation_forward_hooks( 38 | model, conditions_met=conditions_met, bit_res=bit_res 39 | ) 40 | act_backward_histograms = add_sensitivity_backward_hooks( 41 | model, act_forward_histograms 42 | ) 43 | 44 | return act_forward_histograms, act_backward_histograms 45 | 46 | 47 | def add_sensitivity_backward_hooks( 48 | model: torch.nn.Module, act_forward_histograms: ActHistogram 49 | ): 50 | """ 51 | Adds the backwards hooks that gather the gradients and sums them up according to the forward 52 | histogram bins, so that one gets the summed gradients for each histogrma bin. If the output of the 53 | model is backpropagated without any manipulation of the loss, then these gradients will correspond 54 | to the relative contribution of each quantization bin to the output. 55 | """ 56 | 57 | # We intialise a new ActHistogram instance, which will be responsible for containing the 58 | # backwards pass data 59 | act_backward_histograms = ActHistogram(data={}, hook_handles={}) 60 | for module_name in act_forward_histograms.hook_handles.keys(): 61 | module = model.get_submodule(module_name) 62 | hook_handle = module.register_full_backward_hook( 63 | backwards_SA_histogram_hook( 64 | act_forward_histograms, act_backward_histograms, module_name 65 | ) 66 | ) 67 | 68 | # We store the hook handles so that we can remove the hooks once we have finished 69 | # accumulating the backwards gradients. 70 | act_backward_histograms.hook_handles[module_name] = hook_handle 71 | 72 | return act_backward_histograms 73 | 74 | 75 | def backwards_SA_histogram_hook( 76 | act_forward_histograms: ActHistogram, 77 | act_backward_histograms: ActHistogram, 78 | name: str, 79 | ): 80 | """ 81 | A backward hook that measures the gradient being fed back through a quantization module. 82 | 83 | It requires that the `add_activation_forward_hooks` be called first. 
84 | 
85 |     The hook will capture the backwards gradient, and map it to the same histogram bins as were
86 |     used for the histograms in the forward hook. This will make it so that the gradients will be summed
87 |     into bins that correspond to the forward values, so that they can be associated. If the output of the
88 |     model is backpropagated without any manipulation of the loss, then these gradients will correspond
89 |     to the relative contribution of each quantization bin to the output. I.e., it will correspond to a
90 |     sensitivity analysis.
91 | 
92 |     backwards_SA_histogram_hook inputs:
93 |     - act_forward_histograms / act_backward_histograms (ActHistogram): dataclass instances that store the forward activation histograms and the summed backward gradients, respectively, along with their hook handles.
94 |     - name (str): the name of the module, and how its histogram will be stored in the dict.
95 | 
96 |     hook inputs:
97 |     - module: the quantization module.
98 |     - inp_grad: input gradient.
99 |     - out_grad: output gradient.
100 |     """
101 | 
102 |     def hook(module, inp_grad, out_grad):
103 |         if name not in act_forward_histograms.data:
104 |             return
105 | 
106 |         # Access the values-to-histogram-bins mapping from the forward call
107 |         bin_indices = act_forward_histograms.data[name].bin_indices - 1
108 |         grad = out_grad[0]
109 | 
110 |         # Compute the sum of gradients, with the forward histogram bins, using torch.bincount()
111 |         size_diff = (
112 |             act_forward_histograms.data[name].hist.size()[0] - bin_indices.max() - 1
113 |         )
114 |         padding = torch.zeros(size_diff)
115 |         binned_grads = torch.concat(
116 |             [torch.bincount(bin_indices.flatten(), weights=grad.flatten()), padding]
117 |         )
118 | 
119 |         # Store the summed gradients into the dataclass
120 |         back = dotdict()
121 |         back.binned_grads = binned_grads
122 |         act_backward_histograms.data[name] = back
123 | 
124 |     return hook
125 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/histograms/plots/__init__.py:
--------------------------------------------------------------------------------
1 | from .plot_histograms import (
2 |     plot_quant_act_hist,
3 |     plot_quant_act_SA_hist,
4 |     plot_quant_weight_hist,
5 | )
6 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/histograms/plots/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import matplotlib
5 | import os
6 | 
7 | matplotlib.use("Agg")
8 | 
9 | from ...settings import HIST_QUANT_BIN_RATIO, HIST_XMAX, HIST_XMIN
10 | 
11 | 
12 | ############
13 | # SUBPLOTS #
14 | ############
15 | def fill_in_mean_subplot(
16 |     distribution: torch.Tensor,
17 |     zero_bin_value: torch.Tensor,
18 |     clamped_prob_mass: torch.Tensor,
19 |     ax_sub: matplotlib.axes._axes.Axes,
20 |     color: str = "blue",
21 |     data_name: str = "",
22 | ):
23 |     """
24 |     Fills in the summary sub-plot. This involves calculating the mean intra-bin values, and plotting them.
25 |     We also add a few interesting statistics:
26 |     - the amount of not-on-bin-centroid probability mass
27 |     - the zero-bin value
28 |     - the amount of clamped probability mass
29 | 
30 |     Inputs:
31 |     - distribution (torch.Tensor): the PDF we will be getting the mean intra-bin plot of.
32 |     - zero_bin_value (torch.Tensor): the zero bin probability mass value.
33 |     - clamped_prob_mass (torch.Tensor): the clamped probability mass scalar value.
34 |     - ax_sub (matplotlib.axes._axes.Axes): the Axes object we will manipulate to fill in the subplot.
35 | - color (str): the color of the imean intra-bin plot. 36 | - data_name (str): part of the subtitle, e.g. "Forward Activation", "Gradient", etc. 37 | """ 38 | # Sum every HIST_QUANT_BIN_RATIO'th value in the histogram. 39 | intra_bin = torch.zeros(HIST_QUANT_BIN_RATIO) 40 | for step in torch.arange(HIST_QUANT_BIN_RATIO): 41 | intra_bin[step] = distribution[step::HIST_QUANT_BIN_RATIO].sum() 42 | indices = range(HIST_QUANT_BIN_RATIO) 43 | intra_bin = intra_bin.numpy() 44 | 45 | # Plot the intra-bin behavior as subplot 46 | ax_sub.bar(indices, intra_bin, color=color) 47 | 48 | # Remove tick labels and set background to transparent for the overlay subplot 49 | ax_sub.set_xticks(np.arange(0, HIST_QUANT_BIN_RATIO + 1, HIST_QUANT_BIN_RATIO / 2)) 50 | ax_sub.set_xticklabels( 51 | [ 52 | f"{int(i)}/{HIST_QUANT_BIN_RATIO}" 53 | for i in np.arange(0, HIST_QUANT_BIN_RATIO + 1, HIST_QUANT_BIN_RATIO / 2) 54 | ] 55 | ) 56 | ax_sub.set_xlim(-0.5, HIST_QUANT_BIN_RATIO + 0.5) 57 | ax_sub.patch.set_alpha(1) 58 | 59 | # Add title (with summary-ish statistics) and labels 60 | title_str = f"{data_name}\nMean Intra-bin Behavior\n(Not-on-quant-bin-centroid\nprob mass: {intra_bin[1:].sum():.2f})\nZero-bin mass: {zero_bin_value:.2f}" 61 | title_str += f"\nClamped prob mass: {clamped_prob_mass:.6f}" 62 | ax_sub.set_title(title_str) 63 | ax_sub.set_ylabel("Prob") 64 | ax_sub.set_xlabel("Bins (0 and 1 are centroids)") 65 | ax_sub.axvline(x=0, color="black", linewidth=1) 66 | ax_sub.axvline(x=HIST_QUANT_BIN_RATIO, color="black", linewidth=1) 67 | 68 | 69 | def draw_centroids_and_tensor_range( 70 | ax: matplotlib.axes._axes.Axes, 71 | bin_edges: torch.Tensor, 72 | qrange: int, 73 | tensor_min_index: torch.Tensor, 74 | tensor_max_index: torch.Tensor, 75 | scale: torch.Tensor, 76 | ): 77 | """ 78 | Draws black vertical lines at each quantization centroid, and adds thick red lines at the edges 79 | of the floating point tensor, i.e. highlights its dynamic range. 80 | 81 | Inputs: 82 | - ax (matplotlib.axes._axes.Axes): the Axes object we will be manipulating to add the plot elements. 83 | - bin_edges (torch.Tensor): the histogram bin edges 84 | - qrange (int): the number of quantization bins 85 | - tensor_min_index (torch.Tensor): the minimum value in the floating point tensor. 86 | - tensor_max_index (torch.Tensor): the maximum value in the floating point tensor. 87 | - scale (torch.Tensor): the quantization scale. 88 | """ 89 | # Draws black vertical lines 90 | for index, x_val in enumerate( 91 | np.arange( 92 | start=bin_edges[int(HIST_XMIN * qrange * HIST_QUANT_BIN_RATIO)], 93 | stop=bin_edges[-int(HIST_XMAX * qrange * HIST_QUANT_BIN_RATIO)], 94 | step=scale, 95 | ) 96 | ): 97 | if index == 0: 98 | ax.axvline( 99 | x=x_val, 100 | color="black", 101 | linewidth=0.08, 102 | label="Quantization bin centroids", 103 | ) 104 | else: 105 | ax.axvline(x=x_val, color="black", linewidth=0.08) 106 | 107 | # Draw vertical lines at dynamic range boundaries of forward tensor (1 quantization bin padding) 108 | ax.axvline( 109 | x=bin_edges[tensor_min_index] - scale, 110 | color="red", 111 | linewidth=1, 112 | label="Tensor dynamic range", 113 | ) 114 | ax.axvline(x=bin_edges[tensor_max_index] + scale, color="red", linewidth=1) 115 | 116 | 117 | ################### 118 | # DATA PROCESSING # 119 | ################### 120 | def get_prob_mass_outside_quant_range( 121 | distribution: torch.Tensor, qrange: int 122 | ) -> torch.Tensor: 123 | """ 124 | Returns the amount of probability mass outside the quantization range. 
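    Inputs:
    - distribution (torch.Tensor): the histogram / PDF whose clamped mass we are measuring.
    - qrange (int): the number of quantization bins, e.g. 256 for 8-bit quantization.

    Outputs:
    - clamped_prob_mass (torch.Tensor): the total probability mass that falls outside the quantization range.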
125 |     """
126 |     clamped_prob_mass = torch.sum(
127 |         distribution[: int(HIST_XMIN * qrange * HIST_QUANT_BIN_RATIO)]
128 |     ) + torch.sum(distribution[int((HIST_XMIN + 1) * qrange * HIST_QUANT_BIN_RATIO) :])
129 |     return clamped_prob_mass
130 | 
131 | 
132 | def moving_average(input_tensor, window_size):
133 |     """
134 |     Get a 1d moving average of a 1D torch tensor, used for creating a smoothed
135 |     data distribution for the histograms.
136 |     """
137 |     # Create a 1D convolution kernel filled with ones
138 |     kernel = torch.ones(1, 1, window_size) / window_size
139 | 
140 |     # Apply padding to handle boundary elements
141 |     padding = (window_size - 1) // 2
142 | 
143 |     # Apply the convolution operation
144 |     output_tensor = F.conv1d(
145 |         input_tensor.unsqueeze(0).unsqueeze(0), kernel, padding=padding
146 |     )
147 | 
148 |     return output_tensor.squeeze()
149 | 
150 | 
151 | ###########
152 | # PATHING #
153 | ###########
154 | 
155 | 
156 | def create_double_level_plot_folder(file_path: str, lvl_1: str, lvl_2: str) -> str:
157 |     weight_plot_folder = file_path / lvl_1 / lvl_2
158 |     if not os.path.exists(weight_plot_folder):
159 |         os.makedirs(weight_plot_folder, exist_ok=True)
160 |     return weight_plot_folder
161 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/histograms/plots/weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple, Union
3 | 
4 | from ...settings import HIST_QUANT_BIN_RATIO, HIST_XMAX, HIST_XMIN
5 | 
6 | from utils.logger import setup_logger
7 | 
8 | # Configure logger
9 | logger = setup_logger(__name__)
10 | 
11 | 
12 | def get_weight_quant_histogram(
13 |     weight: torch.nn.Parameter,
14 |     scale: torch.nn.Parameter,
15 |     zero_point: torch.nn.Parameter,
16 |     qscheme: torch.qscheme,
17 |     sensitivity_analysis: bool = False,
18 |     bit_res: int = 8,
19 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Union[torch.Tensor, None]]:
20 |     """
21 |     Calculates the histogram of the weight, with bins defined by its scale and zero-point.
22 |     Unlike the activation, we plot the weight tensor on the integer scale. This is because:
23 |     1) The weight tensor values are difficult to interpret anyway, so there isn't much to gain from the original scale.
24 |     2) Normalizing by each channel's quantization parameters makes sense, so we can aggregate across channels.
25 | 
26 |     Inputs:
27 |     - weight (torch.nn.Parameter): a weight tensor.
28 |     - scale (torch.nn.Parameter): a qparam scale. This can be a single parameter, or in the case of per-channel quantization, a tensor with len > 1.
29 |     - zero_point (torch.nn.Parameter): a qparam zero_point. This can be a single parameter, or in the case of per-channel quantization, a tensor with len > 1.
30 |     - qscheme: specifies the quantization scheme of the weight tensor.
31 |     - sensitivity_analysis (bool): whether or not, if we have grads, we should plot the sensitivity analysis for the weights.
32 |     - bit_res (int): the quantization bit width, e.g. 8 for 8-bit quantization.
33 | 
34 |     Outputs:
35 |     - hist (Tuple[torch.Tensor, torch.Tensor]): torch.histogram output instance, with histogram and bin edges.
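    - binned_grads (torch.Tensor or None): if `sensitivity_analysis` is True and the weight has gradients attached, the gradients summed into the same histogram bins; otherwise None.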
36 |     """
37 | 
38 |     if qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
39 |         scale = scale.view(len(scale), 1, 1, 1)
40 |         zero_point = zero_point.view(len(zero_point), 1, 1, 1)
41 |     elif qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
42 |         pass
43 |     else:
44 |         raise ValueError(
45 |             "`qscheme` variable should be per-channel symmetric or affine, or per-tensor symmetric or affine"
46 |         )
47 | 
48 |     # Weight tensor in fake-quantized space
49 |     fake_quant_tensor = weight.detach() / scale.detach() + zero_point.detach()
50 | 
51 |     # Flatten the weight tensor
52 |     fake_quant_tensor = fake_quant_tensor.reshape(-1)
53 | 
54 |     # Get number of quantization bins from the quantization bit width
55 |     qrange = 2**bit_res
56 | 
57 |     # Calculate the histogram between `-HIST_XMIN * qrange` and `(1 + HIST_XMAX) * qrange`, with `HIST_QUANT_BIN_RATIO` samples per quantization bin.
58 |     # This covers space on either side of the 0-qrange quantization range, so we can see any overflow, i.e. clamping.
59 |     hist_bins = torch.arange(
60 |         -HIST_XMIN * qrange,
61 |         (1 + HIST_XMAX) * qrange,
62 |         1 / HIST_QUANT_BIN_RATIO,
63 |     )
64 |     # If we are doing symmetric quantization, center the range at 0.
65 |     if qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
66 |         hist_bins -= qrange / 2
67 | 
68 |     fake_quant_tensor = fake_quant_tensor.cpu()
69 |     hist = torch.histogram(fake_quant_tensor, bins=hist_bins)
70 | 
71 |     if sensitivity_analysis and weight.grad is not None:
72 |         # Create a map between the histogram and values by using torch.bucketize()
73 |         # The idea is to be able to map the gradients to the same histogram bins
74 |         bin_indices = torch.bucketize(
75 |             fake_quant_tensor, hist.bin_edges
76 |         )  # , right=True)
77 | 
78 |         # Compute the sum of gradients, with the forward histogram bins, using torch.bincount()
79 |         binned_grads = torch.bincount(
80 |             bin_indices.flatten(), weights=weight.grad.flatten()
81 |         )
82 | 
83 |         # Padding may be required, if the bin_indices (which stop at the index of the maximum value
84 |         # of `fake_quant_tensor` when mapped onto the histogram) are smaller than the maximum
85 |         # histogram bin value. I.e., if the tensor doesn't fill the rightmost bin of the histogram.
86 |         size_diff = hist_bins.size()[0] - bin_indices.max() - 2
87 |         if size_diff > 0:
88 |             # We add zeros to the end, as bincount automatically zero-pads the beginning as needed
89 |             padding = torch.zeros(size_diff)
90 |             binned_grads = torch.concat([binned_grads, padding])
91 |         elif size_diff == -2:
92 |             binned_grads = binned_grads[1:-1]
93 |         elif size_diff == -1:
94 |             binned_grads = binned_grads[:-1]
95 | 
96 |         return hist, binned_grads
97 | 
98 |     elif sensitivity_analysis:
99 |         logger.warning(
100 |             "`get_weight_quant_histogram` was provided `sensitivity_analysis=True`, but the weight tensor does not have any attached gradients."
101 |         )
102 | 
103 |     return hist, None
104 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/settings.py:
--------------------------------------------------------------------------------
1 | # The histogram plot range will extend negatively beyond the minimum quantization value by `HIST_XMIN` of the quantization range.
2 | HIST_XMIN = 0.5
3 | # The histogram plot range will extend positively beyond the maximum quantization value by `HIST_XMAX` of the quantization range.
4 | HIST_XMAX = 0.5
5 | # How many histogram bins per quantization bin
6 | HIST_QUANT_BIN_RATIO = 5
7 | # How many quantization bins to average in the smoothed moving-average plot for the forward and grads
8 | SMOOTH_WINDOW = 9
9 | 
10 | # Coordinates for the sub-plot for the forward histogram mini-plot
11 | SUM_POS_1_DEFAULT = [0.18, 0.60, 0.1, 0.1]
12 | # Coordinates for the sub-plot for the sensitivity analysis mini-plot
13 | SUM_POS_2_DEFAULT = [0.75, 0.43, 0.1, 0.1]
14 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/utils/act_histogram.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | 
3 | 
4 | @dataclass
5 | class ActHistogram:
6 |     """
7 |     A dataclass for storing arbitrary activation data via hooks.
8 |     The data is stored in `data`, and `hook_handles` stores the hook handles.
9 |     """
10 | 
11 |     data: dict
12 |     hook_handles: dict
13 | 
14 |     def reset(self):
15 | 
16 |         # Remove the activation hooks once we're done with them
17 |         # Otherwise the hooks will remain, slowing down forward calls
18 |         for handle in self.hook_handles.values():
19 |             handle.remove()
20 | 
21 |         self.data = {}
22 |         self.hook_handles = {}
23 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/utils/hooks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | def is_model_quantizable(model: nn.Module, weight_or_act: str) -> bool:
5 |     """
6 |     A function for checking if the model has any quantizable weights and/or activations.
7 | 
8 |     Inputs:
9 |     - model: Our quantizable model.
10 |     - weight_or_act (str): Whether we are checking for quantizable weights, activations, or both.
11 |       Options: ['activation', 'weight', 'both'].
12 | 
13 |     Outputs:
14 |     - return (bool): whether or not the model has a quantizable weight and/or activation.
15 |     """
16 | 
17 |     if weight_or_act not in ["weight", "activation", "both"]:
18 |         raise ValueError(
19 |             "`weight_or_act` should be a string equal to `weight`, `activation`, or `both`"
20 |         )
21 | 
22 |     # Check weights
23 |     quantizable_model = False
24 |     if weight_or_act in ["weight", "both"]:
25 |         for name, _ in model.named_parameters():
26 |             if "weight_fake_quant" in name:
27 |                 quantizable_model = True
28 | 
29 |         if not quantizable_model:
30 |             return False
31 | 
32 |     # Check activations
33 |     if weight_or_act in ["activation", "both"]:
34 |         for name, module in model.named_modules():
35 |             if hasattr(module, "activation_post_process") and hasattr(
36 |                 module.activation_post_process, "qscheme"
37 |             ):
38 |                 quantizable_model = True
39 | 
40 |         if not quantizable_model:
41 |             return False
42 | 
43 |     return True
44 | 
--------------------------------------------------------------------------------
/Resnet-FX-CLE/quant_vis/utils/prop_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def forward_and_backprop_an_image(model: torch.nn.Module):
5 |     """
6 |     Forward and backwards propagate an image of a dog.
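    The image is downloaded from the PyTorch hub examples, preprocessed with the standard ImageNet
    transforms, fed through `model` on the CPU, and the mean of the output is then backpropagated so
    that every output element contributes equally to the gradients.

    Inputs:
    - model (torch.nn.Module): the model to forward and backward propagate the image through.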
7 | """ 8 | import urllib 9 | 10 | url, filename = ( 11 | "https://github.com/pytorch/hub/raw/master/images/dog.jpg", 12 | "dog.jpg", 13 | ) 14 | try: 15 | urllib.URLopener().retrieve(url, filename) 16 | except: 17 | urllib.request.urlretrieve(url, filename) 18 | # sample execution (requires torchvision) 19 | from PIL import Image 20 | from torchvision import transforms 21 | 22 | input_image = Image.open(filename) 23 | preprocess = transforms.Compose( 24 | [ 25 | transforms.Resize(256), 26 | transforms.CenterCrop(224), 27 | transforms.ToTensor(), 28 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 29 | ] 30 | ) 31 | input_tensor = preprocess(input_image) 32 | input_batch = input_tensor.unsqueeze( 33 | 0 34 | ) # create a mini-batch as expected by the model 35 | input_batch.to("cpu") 36 | model.to("cpu") 37 | 38 | # Feed data through the model 39 | output = model(input_batch) 40 | 41 | # We backpropagate the gradients. We take the mean of the output, ensuring that 42 | # we backprop a scalar where all outputs are equally represented. 43 | output.mean().backward() 44 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/utils/dotdict.py: -------------------------------------------------------------------------------- 1 | class dotdict(dict): 2 | """ 3 | A class with dot.notation access to dictionary attributes 4 | """ 5 | 6 | __getattr__ = dict.get 7 | __setattr__ = dict.__setitem__ 8 | __delattr__ = dict.__delitem__ 9 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/utils/graph_manip.py: -------------------------------------------------------------------------------- 1 | ######################### 2 | # SOME GRAPH TECHNIQUES # 3 | ######################### 4 | # Experiment with iterator pattern: 5 | # NOTE: taken from https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern 6 | import torch 7 | from torch import fx 8 | from torch.fx.node import Node 9 | from typing import Dict, Union, Tuple, Any 10 | 11 | from torch.ao.nn.intrinsic.qat.modules.conv_fused import ( 12 | ConvBnReLU2d, 13 | ConvReLU2d, 14 | ConvBn2d, 15 | ) 16 | from torch.nn.modules.conv import Conv2d 17 | from torch.ao.nn.qat import Conv2d as QATConv2d 18 | from torch.nn.modules.batchnorm import BatchNorm2d 19 | 20 | 21 | ######################################### 22 | # Fusing Bn in ConvBnReLU into ConvReLU # 23 | ######################################### 24 | def qat_fuse_conv_bn_relu_eval( 25 | conv: Union[ConvBnReLU2d, ConvBn2d] 26 | ) -> Union[ConvReLU2d, Conv2d]: 27 | """ 28 | Given a quantizable ConvBnReLU2d Module returns a quantizable ConvReLU2d 29 | module such that the BatchNorm has been fused into the Conv, in inference mode. 30 | Given a ConvBn2d, it does the same to produce a Conv2d. 31 | One could also use `torch.nn.utils.fuse_conv_bn_eval` to produce a Conv, and then quantize that as desired. 32 | """ 33 | assert not (conv.training or conv.bn.training), "Fusion only for eval!" 
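    # Carry over the original module's qconfig so that the fused module keeps the same fake-quantization settings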
34 | qconfig = conv.qconfig 35 | if type(conv) is ConvBnReLU2d: 36 | new_conv = ConvReLU2d( 37 | conv.in_channels, 38 | conv.out_channels, 39 | conv.kernel_size, 40 | conv.stride, 41 | conv.padding, 42 | conv.dilation, 43 | conv.groups, 44 | conv.bias is not None, 45 | conv.padding_mode, 46 | qconfig=qconfig, 47 | ) 48 | elif type(conv) is ConvBn2d: 49 | new_conv = QATConv2d( 50 | conv.in_channels, 51 | conv.out_channels, 52 | conv.kernel_size, 53 | conv.stride, 54 | conv.padding, 55 | conv.dilation, 56 | conv.groups, 57 | conv.bias is not None, 58 | conv.padding_mode, 59 | qconfig=qconfig, 60 | ) 61 | else: 62 | raise NotImplementedError(f"conv type {type(conv)} not supported.") 63 | 64 | new_conv.weight, new_conv.bias = fuse_conv_bn_weights( 65 | conv.weight, 66 | conv.bias, 67 | conv.bn.running_mean, 68 | conv.bn.running_var, 69 | conv.bn.eps, 70 | conv.bn.weight, 71 | conv.bn.bias, 72 | ) 73 | 74 | return new_conv 75 | 76 | 77 | def float_fuse_conv_bn_relu_eval( 78 | conv: Union[ConvReLU2d, Conv2d], bn 79 | ) -> Union[ConvReLU2d, Conv2d]: 80 | """ 81 | Given a Conv2d and a BatchNorm module pair, returns a Conv2d 82 | module such that the BatchNorm has been fused into the Conv, in inference mode. 83 | """ 84 | assert not (conv.training or bn.training), "Fusion only for eval!" 85 | if type(conv) is Conv2d: 86 | new_conv = Conv2d( 87 | conv.in_channels, 88 | conv.out_channels, 89 | conv.kernel_size, 90 | conv.stride, 91 | conv.padding, 92 | conv.dilation, 93 | conv.groups, 94 | conv.bias is not None, 95 | conv.padding_mode, 96 | ) 97 | 98 | new_conv.weight, new_conv.bias = fuse_conv_bn_weights( 99 | conv.weight, 100 | conv.bias, 101 | bn.running_mean, 102 | bn.running_var, 103 | bn.eps, 104 | bn.weight, 105 | bn.bias, 106 | ) 107 | 108 | return new_conv 109 | 110 | 111 | def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): 112 | """ 113 | Helper function for fusing a Conv and BatchNorm into a single weight/bias tensor pair. 114 | """ 115 | if conv_b is None: 116 | conv_b = torch.zeros_like(bn_rm) 117 | if bn_w is None: 118 | bn_w = torch.ones_like(bn_rm) 119 | if bn_b is None: 120 | bn_b = torch.zeros_like(bn_rm) 121 | bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) 122 | 123 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape( 124 | [-1] + [1] * (len(conv_w.shape) - 1) 125 | ) 126 | conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b 127 | 128 | return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) 129 | 130 | 131 | # Graph manipulation functions for fusing Convs and BatchNorms 132 | def _parent_name(target: str) -> Tuple[str, str]: 133 | """ 134 | Splits a qualname into parent path and last atom. 135 | For example, `foo.bar.baz` -> (`foo.bar`, `baz`) 136 | """ 137 | *parent, name = target.rsplit(".", 1) 138 | return parent[0] if parent else "", name 139 | 140 | 141 | def replace_node_module( 142 | node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module 143 | ): 144 | """ 145 | Helper function for having `new_module` take the place of `node` in a dict of modules. 
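    The parent module is looked up from the node's (string) target, and `setattr` is used to
    overwrite the attribute corresponding to the node's name with `new_module`.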
146 | """ 147 | assert isinstance(node.target, str) 148 | parent_name, name = _parent_name(node.target) 149 | # modules[node.target] = new_module 150 | setattr(modules[parent_name], name, new_module) 151 | 152 | 153 | def replace_conv_bn_pair( 154 | conv_node: fx.Node, 155 | bn_node: fx.Node, 156 | modules: Dict[str, Any], 157 | new_module: torch.nn.Module, 158 | model: torch.fx.GraphModule, 159 | ): 160 | """ 161 | Helper function for having `new_module` take the place of two adjacent 162 | nodes (`conv_node` and `bn_node`) in a dict of modules. 163 | """ 164 | # Replace the Convs with Convs with fused Batchnorms 165 | assert isinstance(conv_node.target, str) 166 | parent_name, name = _parent_name(conv_node.target) 167 | modules[conv_node.target] = new_module 168 | setattr(modules[parent_name], name, new_module) 169 | 170 | # Delete the Batchnorms from the graph 171 | assert isinstance(bn_node.target, str) 172 | bn_node.replace_all_uses_with(bn_node.args[0]) 173 | model.graph.erase_node(bn_node) 174 | 175 | 176 | def get_previous_module_node( 177 | fx_model: torch.fx.GraphModule, 178 | node: torch.fx.Node, 179 | module_type, 180 | CLE_compatible: bool = False, 181 | ): 182 | """ 183 | For a given node, find the closest previous node of a certain type. 184 | module_type: Can be an individual module type, or a tuple of types. 185 | 186 | If we specify that `CLE_compatible = True`, we only return a predecessor if 187 | the nodes between the current node and its predecessor don't have any 188 | CLE-breaking operations between them, e.g. `avgpool`, `add`, etc. 189 | """ 190 | modules = dict(fx_model.named_modules()) 191 | 192 | # Traverse the graph backwards 193 | for predecessor in node.all_input_nodes: 194 | if isinstance(predecessor, torch.fx.Node): 195 | 196 | # Return None if we run into any CLE breaking operation 197 | if CLE_compatible: 198 | if predecessor.target not in modules: 199 | return None 200 | 201 | # If the current node is CLE "breaking", we didn't find a 202 | # predecessor that was CLE compatible 203 | if not isinstance( 204 | fx_model.get_submodule(node.target), torch.nn.ReLU 205 | ) and not isinstance(fx_model.get_submodule(node.target), module_type): 206 | return None 207 | 208 | # If we found the preceding node that matches the type 209 | if predecessor.target in modules: 210 | # Check if the predecessor node is the desired module type 211 | if isinstance(fx_model.get_submodule(predecessor.target), module_type): 212 | return predecessor 213 | 214 | # Recursively search for the module in the predecessor's inputs 215 | prev_module_node = get_previous_module_node( 216 | fx_model, predecessor, module_type, CLE_compatible 217 | ) 218 | if prev_module_node is not None: 219 | return prev_module_node 220 | 221 | # Module not found in the previous nodes 222 | return None 223 | 224 | 225 | def qat_convbn_to_conv(fx_model: torch.fx.GraphModule) -> torch.fx.GraphModule: 226 | """ 227 | Iterates through the graph nodes, and: 228 | - where it finds a ConvBnReLU it replaces it with ConvReLU 229 | - where it finds a ConvBn it replaces it with Conv 230 | 231 | This function works in-place on `fx_model`. 232 | 233 | Inputs: 234 | fx_model (torch.fx.GraphModule): a graph module, that we want to 235 | perform transformations on. 236 | 237 | Output: 238 | (torch.fx.GraphModule): a model where we have swapped out the 2d 239 | ConvBn/ConvBnReLU for Conv/ConvReLU, and 240 | fused the Bns into the Convs. 
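    Example (hypothetical usage; the model is modified in-place and also returned):
        fx_model = qat_convbn_to_conv(fx_model)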
241 | """ 242 | modules = dict(fx_model.named_modules()) 243 | 244 | for node in fx_model.graph.nodes: 245 | # If the operation the node is doing is to call a module 246 | if node.op == "call_module": 247 | # The current node 248 | orig = fx_model.get_submodule(node.target) 249 | if type(orig) in [ConvBnReLU2d, ConvBn2d]: 250 | # Produce a fused Bn equivalent. 251 | fused_conv = qat_fuse_conv_bn_relu_eval(orig) 252 | # This updates `modules` so that `fused_conv` takes the place 253 | # of what was represented by `node` 254 | replace_node_module(node, modules, fused_conv) 255 | 256 | return fx_model 257 | 258 | 259 | def float_convbn_to_conv(fx_model: torch.fx.GraphModule) -> torch.fx.GraphModule: 260 | """ 261 | Iterates through the graph nodes, and where it finds a pair of 262 | Conv-Bn nodes it replaces it with Conv with the Bn fused in. 263 | 264 | This is distinct from the `qat_convbn_to_conv` function that deals with 265 | taking fused [ConvBnReLU2d, ConvBn2d] instances and replaces them with 266 | quantized Conv2d/ConvReLU2d equivalents. 267 | 268 | This function works in-place on `fx_model`. 269 | 270 | Inputs: 271 | fx_model (torch.fx.GraphModule): a graph module, that we want to 272 | perform transformations on. 273 | 274 | Output: 275 | (torch.fx.GraphModule): a model where we have swapped out the 2d 276 | ConvBn/ConvBnReLU for Conv/ConvReLU, and 277 | fused the Bns into the Convs. 278 | """ 279 | modules = dict(fx_model.named_modules()) 280 | 281 | pair, pairs = [], [] 282 | for node in fx_model.graph.nodes: 283 | if node.op == "call_module": 284 | module = fx_model.get_submodule(node.target) 285 | if hasattr(module, "weight"): 286 | pair.append(node) 287 | if len(pair) == 2: 288 | if isinstance(module, BatchNorm2d): 289 | pairs.append(pair) 290 | pair = [] 291 | pair.append(node) 292 | 293 | for conv_bn in pairs: 294 | conv_name, bn_name = conv_bn 295 | conv = fx_model.get_submodule(conv_name.target) 296 | bn = fx_model.get_submodule(bn_name.target) 297 | # Produce a fused Bn equivalent. 298 | fused_conv = float_fuse_conv_bn_relu_eval(conv, bn) 299 | # This updates `modules` so that `fused_conv` takes the place of what 300 | # was represented by `node` 301 | replace_conv_bn_pair(conv_name, bn_name, modules, fused_conv, fx_model) 302 | 303 | return fx_model 304 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/utils/ipdb_hook.py: -------------------------------------------------------------------------------- 1 | import traceback, ipdb 2 | import sys 3 | 4 | 5 | def ipdb_sys_excepthook(): 6 | """ 7 | When called this function will set up the system exception hook. 8 | This hook throws one into an ipdb breakpoint if and where a system 9 | exception occurs in one's run. 10 | 11 | E.g. 12 | >>> ipdb_sys_excepthook() 13 | """ 14 | 15 | def info(type, value, tb): 16 | """ 17 | System excepthook that includes an ipdb breakpoint. 18 | """ 19 | if hasattr(sys, "ps1") or not sys.stderr.isatty(): 20 | # we are in interactive mode or we don't have a tty-like 21 | # device, so we call the default hook 22 | sys.__excepthook__(type, value, tb) 23 | else: 24 | # we are NOT in interactive mode, print the exception... 25 | traceback.print_exception(type, value, tb) 26 | print 27 | # ...then start the debugger in post-mortem mode. 
28 | # pdb.pm() # deprecated 29 | ipdb.post_mortem(tb) # more "modern" 30 | 31 | sys.excepthook = info 32 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | LOGGING_LEVEL = logging.DEBUG 4 | 5 | 6 | def setup_logger(logger_name, level=LOGGING_LEVEL): 7 | """ 8 | A basic logger setup that logs to console. 9 | """ 10 | logger = logging.getLogger(logger_name) 11 | logger.setLevel(level) 12 | 13 | # Create a console handler 14 | console_handler = logging.StreamHandler() 15 | console_handler.setLevel(level) 16 | 17 | # Create a formatter 18 | formatter = logging.Formatter( 19 | "%(levelname)s - %(filename)s:%(lineno)d - %(message)s" 20 | ) 21 | console_handler.setFormatter(formatter) 22 | 23 | # Add the handlers to the logger 24 | logger.addHandler(console_handler) 25 | 26 | return logger 27 | -------------------------------------------------------------------------------- /Resnet-FX-CLE/utils/qconfigs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.quantization as tq 3 | from torch.ao.quantization.fake_quantize import FakeQuantize 4 | from torch.ao.quantization._learnable_fake_quantize import ( 5 | _LearnableFakeQuantize as LearnableFakeQuantize, 6 | ) 7 | 8 | learnable_act = lambda range: LearnableFakeQuantize.with_args( 9 | observer=tq.HistogramObserver, 10 | quant_min=0, 11 | quant_max=255, 12 | dtype=torch.quint8, 13 | qscheme=torch.per_tensor_affine, 14 | scale=range / 255.0, 15 | zero_point=0.0, 16 | use_grad_scaling=True, 17 | ) 18 | 19 | learnable_weights = lambda channels: LearnableFakeQuantize.with_args( # need to specify number of channels here 20 | observer=tq.PerChannelMinMaxObserver, 21 | quant_min=-128, 22 | quant_max=127, 23 | dtype=torch.qint8, 24 | qscheme=torch.per_channel_symmetric, 25 | scale=0.1, 26 | zero_point=0.0, 27 | use_grad_scaling=True, 28 | channel_len=channels, 29 | ) 30 | 31 | fake_quant_act = FakeQuantize.with_args( 32 | observer=tq.HistogramObserver.with_args( 33 | quant_min=0, 34 | quant_max=255, 35 | dtype=torch.quint8, 36 | qscheme=torch.per_tensor_affine, 37 | ), 38 | ) 39 | 40 | fake_quant_weight = FakeQuantize.with_args( 41 | observer=tq.PerChannelMinMaxObserver, 42 | quant_min=-128, 43 | quant_max=127, 44 | dtype=torch.qint8, 45 | qscheme=torch.per_channel_symmetric, 46 | ) 47 | -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/README.md: -------------------------------------------------------------------------------- 1 | # How to do FX Graph Mode Quantization (PyTorch ResNet Coding tutorial) 2 | 3 | In this tutorial series, we use Torch's FX Graph mode quantization to quantize a ResNet. In the first video, we look at the Directed Acyclic Graph (DAG), and see how the fusing, placement of quantstubs and FloatFunctionals all happen automatically. In the second, we look at some of the intricacies of how quantization interacts with the GraphModule. In the third and final video, we look at some more advanced techniques for manipulating and traversing the graph, and use these to discover an alternative to forward hooks, and for fusing BatchNorm layers into their preceding Convs. 
4 | 5 | [![How to do FX Graph Mode Quantization: FX Graph Mode Quantization Coding tutorial - Part 1/3](https://ytcards.demolab.com/?id=AHw5BOUfLU4&title=How+to+do+FX+Graph+Mode+Quantization%3A+FX+Graph+Mode+Quantization+Coding+tutorial+-+Part+1%2F3&lang=en×tamp=1710264531&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "How to do FX Graph Mode Quantization: FX Graph Mode Quantization Coding tutorial - Part 1/3")](https://www.youtube.com/watch?v=AHw5BOUfLU4) 6 | [![How does Graph Mode Affect Quantization? FX Graph Mode Quantization Coding tutorial - Part 2/3](https://ytcards.demolab.com/?id=1S3jlGdGdjM&title=How+does+Graph+Mode+Affect+Quantization%3F+FX+Graph+Mode+Quantization+Coding+tutorial+-+Part+2%2F3&lang=en×tamp=1710452876&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "How does Graph Mode Affect Quantization? FX Graph Mode Quantization Coding tutorial - Part 2/3")](https://www.youtube.com/watch?v=1S3jlGdGdjM) 7 | [![Advanced PyTorch Graph Manipulation: FX Graph Mode Quantization Coding tutorial - Part 3/3](https://ytcards.demolab.com/?id=azpsgB8y0A8&title=Advanced+PyTorch+Graph+Manipulation%3A+FX+Graph+Mode+Quantization+Coding+tutorial+-+Part+3%2F3&lang=en×tamp=1711116192&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "Advanced PyTorch Graph Manipulation: FX Graph Mode Quantization Coding tutorial - Part 3/3")](https://www.youtube.com/watch?v=azpsgB8y0A8) 8 | 9 | This repo is the code associated with the videos. 10 | 11 | ### Prerequisites: 12 | To run this code, you need to have PyTorch installed in your environment. If you do not have PyTorch installed, please follow this [official guide](https://pytorch.org/get-started/locally/). 13 | 14 | I created this code with PyTorch Version: 2.1.1. In case you have any versioning issues, you can revert to that version. 15 | 16 | To run `fx_model.graph.print_tabular()`, one needs to have `tabulate` installed. To do, activate your (e.g. conda) environment and run 17 | ``` 18 | pip install tabulate 19 | ``` 20 | 21 | To plot the graph as a tree in an SVG file, i.e. to run: 22 | ``` 23 | g = passes.graph_drawer.FxGraphDrawer(fx_model, 'resnet18-fx-model') 24 | ``` 25 | 26 | One needs to install GraphViz and have it on PATH (or as a local PATH variable). 27 | 28 | ### Running this code: 29 | Once you have PyTorch installed, first navigate to a directory you will be working from. As you follow the next steps, your final file structure will look like this: `your-directory/Resnet-FX-Graph-Mode-Quant`. 30 | 31 | Next, from `your-directory`, clone the `Quantization-Tutorials` repo. This repo contains different tutorials, but they are all interlinked. Feel no need to do any of the others! I just structured it this way because the tutorials share a lot of code and it might help people to see different parts in one place. 32 | 33 | You can also `git init` and then `git pull/fetch`, depending on what you prefer. 34 | 35 | To clone the repo, run: 36 | ``` 37 | git clone git@github.com:OscarSavolainenDR/Quantization-Tutorials.git . 38 | ``` 39 | 40 | If you did the cloning in place with the `.` at the end, your folder structure should look like `your-folder/Resnet-FX-Graph-Mode-Quant`, with various other folders for other tutorials. 
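For reference, the complete graph-drawing snippet from the Prerequisites section above looks like the following. It mirrors the commented-out code in `main.py`; here `fx_model` stands for the prepared `GraphModule` that the tutorial later produces with `prepare_qat_fx`:

```python
from torch.fx import passes

g = passes.graph_drawer.FxGraphDrawer(fx_model, 'resnet18-fx-model')
with open("graph.svg", "wb") as f:
    f.write(g.get_dot_graph().create_svg())
```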
41 | 42 | Next, cd into the Resnet FX Graph Mode Quantization tutorial: 43 | ``` 44 | cd Resnet-FX-Graph-Mode-Quant 45 | ``` 46 | Then, just run `python main.py` from your command line! However I would obviously recommend that you follow along with the tutorial, so that you learn how it all works and get your hands dirty. 47 | 48 | For the tutorial on eager mode (which includes making architecture changes to the publicly available Resenet model to make it quantizable), cd into `your-folder/Resnet-Eager-Mode-Quant` and see the `README.md` there. 49 | 50 | 51 | Let me know if there are any issues! 52 | -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/evaluate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/evaluate.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/evaluate.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/evaluate.cpython-311.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/evaluate.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/evaluate.cpython-312.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/ipdb_hook.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/ipdb_hook.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/ipdb_hook.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/ipdb_hook.cpython-312.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/qconfigs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/qconfigs.cpython-310.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/__pycache__/qconfigs.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/qconfigs.cpython-311.pyc -------------------------------------------------------------------------------- 
/Resnet-FX-Graph-Mode-Quant/__pycache__/qconfigs.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/__pycache__/qconfigs.cpython-312.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluate import evaluate -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/evaluate/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/evaluate/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/evaluate/__pycache__/evaluate.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/evaluate/__pycache__/evaluate.cpython-312.pyc -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/evaluate/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/evaluate/dog.jpg -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | def evaluate(model, device_str: str): 5 | # Download an example image from the pytorch website 6 | import urllib 7 | url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", Path("evaluate/dog.jpg")) 8 | try: urllib.URLopener().retrieve(url, filename) 9 | except: urllib.request.urlretrieve(url, filename) 10 | # sample execution (requires torchvision) 11 | 12 | from PIL import Image 13 | from torchvision import transforms 14 | input_image = Image.open(filename) 15 | preprocess = transforms.Compose([ 16 | transforms.Resize(256), 17 | transforms.CenterCrop(224), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 20 | ]) 21 | input_tensor = preprocess(input_image) 22 | input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model 23 | 24 | # move the input and model to GPU for speed if available, or to CPU if converted 25 | if not (device_str in['cpu', 'cuda']): 26 | raise NotImplementedError("`device_str` should be 'cpu' or 'cuda' ") 27 | if device_str == 'cuda': 28 | assert torch.cuda.is_available(), 'Check CUDA is available' 29 | 30 | input_batch = input_batch.to(device_str) 31 | model.to(device_str) 32 | model.eval() 33 | 34 | with torch.no_grad(): 35 | output = model(input_batch) 36 | # Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes 37 | # print(output[0]) 38 | # The output has unnormalized scores. To get probabilities, you can run a softmax on it. 
39 | probabilities = torch.nn.functional.softmax(output[0], dim=0) 40 | # print(probabilities) 41 | 42 | # Read the categories 43 | with open(Path("evaluate/imagenet_classes.txt"), "r") as f: 44 | categories = [s.strip() for s in f.readlines()] 45 | # Show top categories per image 46 | top5_prob, top5_catid = torch.topk(probabilities, 5) 47 | for i in range(top5_prob.size(0)): 48 | print(categories[top5_catid[i]], top5_prob[i].item()) 49 | -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/ipdb_hook.py: -------------------------------------------------------------------------------- 1 | import traceback, ipdb 2 | import sys 3 | 4 | def ipdb_sys_excepthook(): 5 | """ 6 | When called this function will set up the system exception hook. 7 | This hook throws one into an ipdb breakpoint if and where a system 8 | exception occurs in one's run. 9 | 10 | E.g. 11 | >>> ipdb_sys_excepthook() 12 | """ 13 | 14 | 15 | def info(type, value, tb): 16 | """ 17 | System excepthook that includes an ipdb breakpoint. 18 | """ 19 | if hasattr(sys, 'ps1') or not sys.stderr.isatty(): 20 | # we are in interactive mode or we don't have a tty-like 21 | # device, so we call the default hook 22 | sys.__excepthook__(type, value, tb) 23 | else: 24 | # we are NOT in interactive mode, print the exception... 25 | traceback.print_exception(type, value, tb) 26 | print 27 | # ...then start the debugger in post-mortem mode. 28 | # pdb.pm() # deprecated 29 | ipdb.post_mortem(tb) # more "modern" 30 | sys.excepthook = info -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/main.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Union 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.quantization as tq 6 | from torch.ao.quantization.quantize_fx import prepare_qat_fx 7 | from torch.ao.quantization.qconfig_mapping import QConfigMapping 8 | import torch.fx as fx 9 | 10 | from evaluate import evaluate 11 | from qconfigs import learnable_act, learnable_weights, fake_quant_act 12 | from ipdb_hook import ipdb_sys_excepthook 13 | 14 | from model.resnet import resnet18 15 | 16 | # Adds ipdb breakpoint if and where we have an error 17 | ipdb_sys_excepthook() 18 | 19 | # Intialize model 20 | model = resnet18(pretrained=True) 21 | 22 | # Define qconfigs 23 | qconfig_global = tq.QConfig( 24 | activation=fake_quant_act, 25 | weight=tq.default_fused_per_channel_wt_fake_quant 26 | ) 27 | 28 | 29 | # Assign qconfigs 30 | qconfig_mapping = QConfigMapping().set_global(qconfig_global) 31 | 32 | # We loop through the modules so that we can access the `out_channels` attribute 33 | for name, module in model.named_modules(): 34 | if hasattr(module, 'out_channels'): 35 | qconfig = tq.QConfig( 36 | activation=learnable_act(range=2), 37 | weight=learnable_weights(channels=module.out_channels) 38 | ) 39 | qconfig_mapping.set_module_name(name, qconfig) 40 | 41 | # Do symbolic tracing and quantization 42 | example_inputs = (torch.randn(1, 3, 224, 224),) 43 | model.eval() 44 | fx_model = prepare_qat_fx(model, qconfig_mapping, example_inputs) 45 | 46 | # Evaluate model 47 | print('\n Original') 48 | evaluate(model, 'cpu') 49 | 50 | print('\n FX prepared') 51 | evaluate(fx_model, 'cpu') 52 | 53 | # Can experiment with visualize the graph, e.g. 
54 | # >> fx_model
55 | # >> print(fx_model.graph) # prints the DAG
56 | 
57 | # Prints the graph as a table
58 | print("\nGraph as a Table:\n")
59 | fx_model.graph.print_tabular()
60 | 
61 | # Plots the graph
62 | # Need to install GraphViz and have it on PATH (or as a local PATH variable)
63 | #from torch.fx import passes
64 | #g = passes.graph_drawer.FxGraphDrawer(fx_model, 'fx-model')
65 | #with open("graph.svg", "wb") as f:
66 | #f.write(g.get_dot_graph().create_svg())
67 | 
68 | 
69 | #########################
70 | # SOME GRAPH TECHNIQUES #
71 | #########################
72 | # Experiment with iterator pattern:
73 | # NOTE: taken from https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
74 | from torch.fx.node import Node
75 | from typing import Dict
76 | 
77 | class GraphIteratorStorage:
78 |     """
79 |     A general Iterator over the graph. This class takes a `GraphModule`,
80 |     and a callable `storage` representing a function that will store some
81 |     attribute for each node when the `propagate` method is called.
82 | 
83 |     Its `propagate` method executes the `GraphModule`
84 |     node-by-node with the given arguments, e.g. an example input tensor.
85 |     As each operation executes, the GraphIteratorStorage class stores
86 |     away the result of the callable for the output values of each operation on
87 |     the attributes of the operation's `Node`. For example,
88 |     one could use a callable `store_shape_dtype()` where:
89 | 
90 |     ```
91 |     def store_shape_dtype(node, result):
92 |         if isinstance(result, torch.Tensor):
93 |             node.shape = result.shape
94 |             node.dtype = result.dtype
95 |     ```
96 |     This would store the `shape` and `dtype` of each operation on
97 |     its respective `Node`, for the given input to `propagate`.
98 |     """
99 |     def __init__(self, mod, storage):
100 |         self.mod = mod
101 |         self.graph = mod.graph
102 |         self.modules = dict(self.mod.named_modules())
103 |         self.storage = storage
104 | 
105 |     def propagate(self, *args):
106 |         args_iter = iter(args)
107 |         env : Dict[str, Node] = {}
108 | 
109 |         def load_arg(a):
110 |             return torch.fx.graph.map_arg(a, lambda n: env[n.name])
111 | 
112 |         def fetch_attr(target : str):
113 |             target_atoms = target.split('.')
114 |             attr_itr = self.mod
115 |             for i, atom in enumerate(target_atoms):
116 |                 if not hasattr(attr_itr, atom):
117 |                     raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
118 |                 attr_itr = getattr(attr_itr, atom)
119 |             return attr_itr
120 | 
121 |         # Each of these `result` variables corresponds to the output of the node in question
122 |         for node in self.graph.nodes:
123 |             if node.op == 'placeholder':
124 |                 result = next(args_iter)
125 |             elif node.op == 'get_attr':
126 |                 result = fetch_attr(node.target)
127 |             elif node.op == 'call_function':
128 |                 result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
129 |             elif node.op == 'call_method':
130 |                 self_obj, *args = load_arg(node.args)
131 |                 kwargs = load_arg(node.kwargs)
132 |                 result = getattr(self_obj, node.target)(*args, **kwargs)
133 |             elif node.op == 'call_module':
134 |                 module = self.mod.get_submodule(node.target)  # this is robust to "shared" nodes.
135 |                 result = module(*load_arg(node.args), **load_arg(node.kwargs))
136 | 
137 |             # This is the only code specific to the `storage` function.
138 | # you can delete this line and this `propagate` function becomes 139 | # a generic GraphModule interpreter 140 | self.storage(node, result) 141 | 142 | # Store the output activation in `env` for the given node 143 | env[node.name] = result 144 | 145 | #return load_arg(self.graph.result) 146 | 147 | def store_shape_dtype(node, result): 148 | """ 149 | Function that takes in the current node, and the tensor it is operating on (`result`) 150 | and stores the shape and dtype of `result` on the node as attributes. 151 | """ 152 | if isinstance(result, torch.Tensor): 153 | node.shape = result.shape 154 | node.dtype = result.dtype 155 | # NOTE: I just discovered that they have the `Interpreter` class, which accomplishes the same thing: 156 | # https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter 157 | GraphIteratorStorage(fx_model, store_shape_dtype).propagate(example_inputs[0]) 158 | 159 | print("\nPrinting size and data types:") 160 | for node in fx_model.graph.nodes: 161 | print(node.name, node.shape, node.dtype) 162 | 163 | 164 | 165 | ########################################### 166 | ### Fusing Bn in ConvBnReLU into ConvReLU # 167 | ########################################### 168 | from torch.ao.nn.intrinsic.qat.modules.conv_fused import ConvBnReLU2d, ConvReLU2d, ConvBn2d 169 | from torch.ao.nn.qat import Conv2d 170 | 171 | def fuse_conv_bn_relu_eval(conv: Union[ConvBnReLU2d, ConvBn2d]) -> Union[ConvReLU2d, Conv2d]: 172 | """ 173 | Given a quantizable ConvBnReLU2d Module returns a quantizable ConvReLU2d 174 | module such that the BatchNorm has been fused into the Conv, in inference mode. 175 | Given a ConvBn2d, it does the same to produce a Conv2d. 176 | One could also use `torch.nn.utils.fuse_conv_bn_eval` to produce a Conv, and then quantize that as desired. 177 | """ 178 | assert(not (conv.training or conv.bn.training)), "Fusion only for eval!" 179 | qconfig = conv.qconfig 180 | if type(conv) is ConvBnReLU2d: 181 | new_conv = ConvReLU2d(conv.in_channels, conv.out_channels, conv.kernel_size, 182 | conv.stride, conv.padding, conv.dilation, 183 | conv.groups, conv.bias is not None, 184 | conv.padding_mode, qconfig=qconfig) 185 | elif type(conv) is ConvBn2d: 186 | new_conv = Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, 187 | conv.stride, conv.padding, conv.dilation, 188 | conv.groups, conv.bias is not None, 189 | conv.padding_mode, qconfig=qconfig) 190 | 191 | 192 | new_conv.weight, new_conv.bias = \ 193 | fuse_conv_bn_weights(conv.weight, conv.bias, 194 | conv.bn.running_mean, conv.bn.running_var, conv.bn.eps, conv.bn.weight, conv.bn.bias) 195 | 196 | return new_conv 197 | 198 | def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): 199 | """ 200 | Helper function for fusing a Conv and BatchNorm into a single weight/bias tensor pair. 201 | """ 202 | if conv_b is None: 203 | conv_b = torch.zeros_like(bn_rm) 204 | if bn_w is None: 205 | bn_w = torch.ones_like(bn_rm) 206 | if bn_b is None: 207 | bn_b = torch.zeros_like(bn_rm) 208 | bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) 209 | 210 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 211 | conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b 212 | 213 | return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) 214 | 215 | # Graph manipulation functions for fusing Convs and BatchNorms 216 | def _parent_name(target : str) -> Tuple[str, str]: 217 | """ 218 | Splits a qualname into parent path and last atom. 
219 |     For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
220 |     """
221 |     *parent, name = target.rsplit('.', 1)
222 |     return parent[0] if parent else '', name
223 | 
224 | def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
225 |     """
226 |     Helper function for having `new_module` take the place of `node` in a dict of modules.
227 |     """
228 |     assert(isinstance(node.target, str))
229 |     parent_name, name = _parent_name(node.target)
230 |     #modules[node.target] = new_module
231 |     setattr(modules[parent_name], name, new_module)
232 | 
233 | def convbn_to_conv(fx_model: torch.fx.GraphModule) -> torch.fx.GraphModule:
234 |     """
235 |     Iterates through the graph nodes, and:
236 |     - where it finds a ConvBnReLU it replaces it with ConvReLU
237 |     - where it finds a ConvBn it replaces it with Conv
238 | 
239 |     This function works in-place on `fx_model`.
240 | 
241 |     Inputs:
242 |     fx_model (torch.fx.GraphModule): a graph module, that we want to perform transformations on
243 | 
244 |     Output:
245 |     (torch.fx.GraphModule): a model where we have swapped out the 2d ConvBn/ConvBnReLU for Conv/ConvReLU, and
246 |                             fused the Bns into the Convs.
247 |     """
248 |     modules = dict(fx_model.named_modules())
249 | 
250 |     for node in fx_model.graph.nodes:
251 |         # If the operation the node is doing is to call a module
252 |         if node.op == 'call_module':
253 |             # The current node
254 |             orig = fx_model.get_submodule(node.target)
255 |             if type(orig) in [ConvBnReLU2d, ConvBn2d]:
256 |                 # Produce a fused Bn equivalent.
257 |                 fused_conv = fuse_conv_bn_relu_eval(orig)
258 |                 # This updates `modules` so that `fused_conv` takes the place of what was represented by `node`
259 |                 replace_node_module(node, modules, fused_conv)
260 | 
261 |     return fx_model
262 | 
263 | transformed : torch.fx.GraphModule = convbn_to_conv(deepcopy(fx_model))
264 | input = example_inputs[0]
265 | out = transformed(input)  # Test we can feed something through the model
266 | print('\nTransformed model evaluation:')
267 | evaluate(transformed, 'cpu')
268 | 
269 | 
--------------------------------------------------------------------------------
/Resnet-FX-Graph-Mode-Quant/model/__pycache__/resnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/model/__pycache__/resnet.cpython-310.pyc
--------------------------------------------------------------------------------
/Resnet-FX-Graph-Mode-Quant/model/__pycache__/resnet.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/model/__pycache__/resnet.cpython-311.pyc
--------------------------------------------------------------------------------
/Resnet-FX-Graph-Mode-Quant/model/__pycache__/resnet.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-Graph-Mode-Quant/model/__pycache__/resnet.cpython-312.pyc
--------------------------------------------------------------------------------
/Resnet-FX-Graph-Mode-Quant/model/resnet.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from
typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from torchvision.transforms._presets import ImageClassification 9 | from torchvision.utils import _log_api_usage_once 10 | from torchvision.models._api import register_model, Weights, WeightsEnum 11 | from torchvision.models._meta import _IMAGENET_CATEGORIES 12 | from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface 13 | 14 | 15 | __all__ = [ 16 | "ResNet", 17 | "ResNet18_Weights", 18 | "resnet18", 19 | ] 20 | 21 | 22 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d( 25 | in_planes, 26 | out_planes, 27 | kernel_size=3, 28 | stride=stride, 29 | padding=dilation, 30 | groups=groups, 31 | bias=False, 32 | dilation=dilation, 33 | ) 34 | 35 | 36 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion: int = 1 43 | 44 | def __init__( 45 | self, 46 | inplanes: int, 47 | planes: int, 48 | stride: int = 1, 49 | downsample: Optional[nn.Module] = None, 50 | groups: int = 1, 51 | base_width: int = 64, 52 | dilation: int = 1, 53 | norm_layer: Optional[Callable[..., nn.Module]] = None, 54 | ) -> None: 55 | super().__init__() 56 | if norm_layer is None: 57 | norm_layer = nn.BatchNorm2d 58 | if groups != 1 or base_width != 64: 59 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 60 | if dilation > 1: 61 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 62 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 63 | self.conv1 = conv3x3(inplanes, planes, stride) 64 | self.bn1 = norm_layer(planes) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.conv2 = conv3x3(planes, planes) 67 | self.bn2 = norm_layer(planes) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x: Tensor) -> Tensor: 72 | identity = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | 81 | if self.downsample is not None: 82 | identity = self.downsample(x) 83 | 84 | out += identity 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class Bottleneck(nn.Module): 91 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 92 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 93 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 94 | # This variant is also known as ResNet V1.5 and improves accuracy according to 95 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
96 | 97 | expansion: int = 4 98 | 99 | def __init__( 100 | self, 101 | inplanes: int, 102 | planes: int, 103 | stride: int = 1, 104 | downsample: Optional[nn.Module] = None, 105 | groups: int = 1, 106 | base_width: int = 64, 107 | dilation: int = 1, 108 | norm_layer: Optional[Callable[..., nn.Module]] = None, 109 | ) -> None: 110 | super().__init__() 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | width = int(planes * (base_width / 64.0)) * groups 114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 115 | self.conv1 = conv1x1(inplanes, width) 116 | self.bn1 = norm_layer(width) 117 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 118 | self.bn2 = norm_layer(width) 119 | self.conv3 = conv1x1(width, planes * self.expansion) 120 | self.bn3 = norm_layer(planes * self.expansion) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def forward(self, x: Tensor) -> Tensor: 126 | identity = x 127 | 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv2(out) 133 | out = self.bn2(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv3(out) 137 | out = self.bn3(out) 138 | 139 | if self.downsample is not None: 140 | identity = self.downsample(x) 141 | 142 | out += identity 143 | out = self.relu(out) 144 | 145 | return out 146 | 147 | 148 | class ResNet(nn.Module): 149 | def __init__( 150 | self, 151 | block: Type[Union[BasicBlock, Bottleneck]], 152 | layers: List[int], 153 | num_classes: int = 1000, 154 | zero_init_residual: bool = False, 155 | groups: int = 1, 156 | width_per_group: int = 64, 157 | replace_stride_with_dilation: Optional[List[bool]] = None, 158 | norm_layer: Optional[Callable[..., nn.Module]] = None, 159 | ) -> None: 160 | super().__init__() 161 | _log_api_usage_once(self) 162 | if norm_layer is None: 163 | norm_layer = nn.BatchNorm2d 164 | self._norm_layer = norm_layer 165 | 166 | self.inplanes = 64 167 | self.dilation = 1 168 | if replace_stride_with_dilation is None: 169 | # each element in the tuple indicates if we should replace 170 | # the 2x2 stride with a dilated convolution instead 171 | replace_stride_with_dilation = [False, False, False] 172 | if len(replace_stride_with_dilation) != 3: 173 | raise ValueError( 174 | "replace_stride_with_dilation should be None " 175 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 176 | ) 177 | self.groups = groups 178 | self.base_width = width_per_group 179 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 180 | self.bn1 = norm_layer(self.inplanes) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 183 | self.layer1 = self._make_layer(block, 64, layers[0]) 184 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 185 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 186 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | self.fc = nn.Linear(512 * block.expansion, num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 193 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 194 | nn.init.constant_(m.weight, 1) 195 | 
nn.init.constant_(m.bias, 0) 196 | 197 | # Zero-initialize the last BN in each residual branch, 198 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 199 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 200 | if zero_init_residual: 201 | for m in self.modules(): 202 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 203 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 204 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 205 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 206 | 207 | def _make_layer( 208 | self, 209 | block: Type[Union[BasicBlock, Bottleneck]], 210 | planes: int, 211 | blocks: int, 212 | stride: int = 1, 213 | dilate: bool = False, 214 | ) -> nn.Sequential: 215 | norm_layer = self._norm_layer 216 | downsample = None 217 | previous_dilation = self.dilation 218 | if dilate: 219 | self.dilation *= stride 220 | stride = 1 221 | if stride != 1 or self.inplanes != planes * block.expansion: 222 | downsample = nn.Sequential( 223 | conv1x1(self.inplanes, planes * block.expansion, stride), 224 | norm_layer(planes * block.expansion), 225 | ) 226 | 227 | layers = [] 228 | layers.append( 229 | block( 230 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 231 | ) 232 | ) 233 | self.inplanes = planes * block.expansion 234 | for _ in range(1, blocks): 235 | layers.append( 236 | block( 237 | self.inplanes, 238 | planes, 239 | groups=self.groups, 240 | base_width=self.base_width, 241 | dilation=self.dilation, 242 | norm_layer=norm_layer, 243 | ) 244 | ) 245 | 246 | return nn.Sequential(*layers) 247 | 248 | def _forward_impl(self, x: Tensor) -> Tensor: 249 | # See note [TorchScript super()] 250 | x = self.conv1(x) 251 | x = self.bn1(x) 252 | x = self.relu(x) 253 | x = self.maxpool(x) 254 | 255 | x = self.layer1(x) 256 | x = self.layer2(x) 257 | x = self.layer3(x) 258 | x = self.layer4(x) 259 | 260 | x = self.avgpool(x) 261 | x = torch.flatten(x, 1) 262 | x = self.fc(x) 263 | 264 | return x 265 | 266 | def forward(self, x: Tensor) -> Tensor: 267 | return self._forward_impl(x) 268 | 269 | 270 | def _resnet( 271 | block: Type[Union[BasicBlock, Bottleneck]], 272 | layers: List[int], 273 | weights: Optional[WeightsEnum], 274 | progress: bool, 275 | **kwargs: Any, 276 | ) -> ResNet: 277 | if weights is not None: 278 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 279 | 280 | model = ResNet(block, layers, **kwargs) 281 | 282 | if weights is not None: 283 | model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) 284 | 285 | return model 286 | 287 | 288 | _COMMON_META = { 289 | "min_size": (1, 1), 290 | "categories": _IMAGENET_CATEGORIES, 291 | } 292 | 293 | 294 | class ResNet18_Weights(WeightsEnum): 295 | IMAGENET1K_V1 = Weights( 296 | url="https://download.pytorch.org/models/resnet18-f37072fd.pth", 297 | transforms=partial(ImageClassification, crop_size=224), 298 | meta={ 299 | **_COMMON_META, 300 | "num_params": 11689512, 301 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", 302 | "_metrics": { 303 | "ImageNet-1K": { 304 | "acc@1": 69.758, 305 | "acc@5": 89.078, 306 | } 307 | }, 308 | "_ops": 1.814, 309 | "_file_size": 44.661, 310 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", 311 | }, 312 | ) 313 | DEFAULT = IMAGENET1K_V1 314 | 315 | 316 | 
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) 317 | def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: 318 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 319 | 320 | Args: 321 | weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The 322 | pretrained weights to use. See 323 | :class:`~torchvision.models.ResNet18_Weights` below for 324 | more details, and possible values. By default, no pre-trained 325 | weights are used. 326 | progress (bool, optional): If True, displays a progress bar of the 327 | download to stderr. Default is True. 328 | **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` 329 | base class. Please refer to the `source code 330 | `_ 331 | for more details about this class. 332 | 333 | .. autoclass:: torchvision.models.ResNet18_Weights 334 | :members: 335 | """ 336 | weights = ResNet18_Weights.verify(weights) 337 | 338 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) 339 | 340 | -------------------------------------------------------------------------------- /Resnet-FX-Graph-Mode-Quant/qconfigs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.quantization as tq 3 | from torch.ao.quantization.fake_quantize import FakeQuantize 4 | from torch.ao.quantization._learnable_fake_quantize import ( 5 | _LearnableFakeQuantize as LearnableFakeQuantize, 6 | ) 7 | 8 | learnable_act = lambda range : LearnableFakeQuantize.with_args( 9 | observer=tq.HistogramObserver, 10 | quant_min=0, 11 | quant_max=255, 12 | dtype=torch.quint8, 13 | qscheme=torch.per_tensor_affine, 14 | scale=range / 255.0, 15 | zero_point=0.0, 16 | use_grad_scaling=True, 17 | ) 18 | 19 | learnable_weights = lambda channels : LearnableFakeQuantize.with_args( # need to specify number of channels here 20 | observer=tq.PerChannelMinMaxObserver, 21 | quant_min=-128, 22 | quant_max=127, 23 | dtype=torch.qint8, 24 | qscheme=torch.per_channel_symmetric, 25 | scale=0.1, 26 | zero_point=0.0, 27 | use_grad_scaling=True, 28 | channel_len=channels, 29 | ) 30 | 31 | fake_quant_act = FakeQuantize.with_args( 32 | observer=tq.HistogramObserver.with_args( 33 | quant_min=0, 34 | quant_max=255, 35 | dtype=torch.quint8, 36 | qscheme=torch.per_tensor_affine, 37 | ), 38 | ) -------------------------------------------------------------------------------- /Resnet-FX-QAT/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /Resnet-FX-QAT/README.md: -------------------------------------------------------------------------------- 1 | # How to Quantization Aware Training (QAT): PyTorch ResNet Coding tutorial 2 | 3 | This is the finished code associated with the YouTube tutorial at: 4 | 5 | [![Quantization Aware Training (QAT) With a Custom DataLoader: Beginner's Tutorial to Training Loops](https://ytcards.demolab.com/?id=s3tqqBaRuHE&title=Quantization+Aware+Training+%28QAT%29+With+a+Custom+DataLoader%3A+Beginner%27s+Tutorial+to+Training+Loops&lang=en×tamp=1712648353&background_color=%230d1117&title_color=%23ffffff&stats_color=%23dedede&max_title_lines=1&width=250&border_radius=5 "Quantization Aware Training (QAT) With a Custom DataLoader: Beginner's Tutorial to Training Loops")](https://www.youtube.com/watch?v=s3tqqBaRuHE) 6 | 7 | This code is built from the code for the FX 
Graph mode tutorial, located in `Resnet-FX-Graph-Mode-Quant`. 8 | However, we modularize some of the code, and build up some of the functions a bit more. 9 | 10 | 11 | ### Prerequisites: 12 | #### Installing PyTorch: 13 | To run this code, you need to have PyTorch installed in your environment. If you do not have PyTorch installed, please follow this [official guide](https://pytorch.org/get-started/locally/). 14 | 15 | I created this code with PyTorch version 2.1.1. In case you have any versioning issues, you can revert to that version. 16 | 17 | #### Printing the FX graph: 18 | To run `fx_model.graph.print_tabular()`, one needs to have `tabulate` installed. To do so, activate your (e.g. conda) environment and run 19 | ``` 20 | pip install tabulate 21 | ``` 22 | 23 | #### Evaluation images: 24 | For this tutorial, I downloaded some images from Google Search, one example each for a handful of the ImageNet classes. 25 | You can add whatever ImageNet class examples you want, but make sure you name the images the same as the class names, e.g. `hen.jpg` for the class name `hen` (see the quick check sketched at the bottom of this README). 26 | Or, feel free to generalise the code so that isn't a constraint! 27 | 28 | 29 | ### Running this code: 30 | Once you have PyTorch installed, first navigate to a directory you will be working from. As you follow the next steps, your final file structure will look like this: `your-directory/Resnet-FX-QAT`. 31 | 32 | Next, from `your-directory`, clone the `Quantization-Tutorials` repo. This repo contains different tutorials, but they are all interlinked. Feel no need to do any of the others! I just structured it this way because the tutorials share a lot of code, and it might help people to see different parts in one place. 33 | 34 | You can also `git init` and then `git pull/fetch`, depending on what you prefer. 35 | 36 | To clone the repo, run: 37 | ``` 38 | git clone git@github.com:OscarSavolainenDR/Quantization-Tutorials.git . 39 | ``` 40 | 41 | If you did the cloning in place with the `.` at the end, your folder structure should look like `your-directory/Resnet-FX-QAT`, with various other folders for the other tutorials. 42 | 43 | Next, cd into the Resnet FX QAT tutorial: 44 | ``` 45 | cd Resnet-FX-QAT 46 | ``` 47 | Then, just run `python main.py` from your command line! However, I would obviously recommend that you follow along with the tutorial, so that you learn how it all works and get your hands dirty. 48 | 49 | Let me know if there are any issues!
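As a quick check that your image names line up with the ImageNet classes, you can run something like the following. This is only a sketch (it is not part of the repo's code); it assumes you run it from `Resnet-FX-QAT`, so that the `evaluate/images` folder and `evaluate/imagenet_classes.txt` file are where `evaluate.py` expects them:

```python
from pathlib import Path

# Class names, one per line, as used by evaluate()
with open(Path("evaluate/imagenet_classes.txt")) as f:
    classes = {line.strip() for line in f}

# Flag any image whose filename (minus the extension) is not an ImageNet class name
for img in sorted(Path("evaluate/images").glob("*.jpg")):
    if img.stem not in classes:
        print(f"{img.name} does not match any ImageNet class name")
```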
50 | -------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluate import evaluate -------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | 4 | def evaluate(model, device_str: str, target: str): 5 | # Download an example image from the pytorch website 6 | import urllib 7 | filename = Path(f"evaluate/images/{target}.jpg") 8 | 9 | from PIL import Image 10 | from torchvision import transforms 11 | input_image = Image.open(filename) 12 | preprocess = transforms.Compose([ 13 | transforms.Resize(256), 14 | transforms.CenterCrop(224), 15 | transforms.ToTensor(), 16 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 17 | ]) 18 | input_tensor = preprocess(input_image) 19 | input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model 20 | 21 | # move the input and model to GPU for speed if available, or to CPU if converted 22 | if not (device_str in['cpu', 'cuda']): 23 | raise NotImplementedError("`device_str` should be 'cpu' or 'cuda' ") 24 | if device_str == 'cuda': 25 | assert torch.cuda.is_available(), 'Check CUDA is available' 26 | 27 | input_batch = input_batch.to(device_str) 28 | model.to(device_str) 29 | model.eval() 30 | 31 | with torch.no_grad(): 32 | output = model(input_batch) 33 | # Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes 34 | # print(output[0]) 35 | # The output has unnormalized scores. To get probabilities, you can run a softmax on it. 36 | probabilities = torch.nn.functional.softmax(output[0], dim=0) 37 | # print(probabilities) 38 | 39 | # Read the categories 40 | with open(Path("evaluate/imagenet_classes.txt"), "r") as f: 41 | categories = [s.strip() for s in f.readlines()] 42 | # Show top categories per image 43 | top5_prob, top5_catid = torch.topk(probabilities, 5) 44 | for i in range(top5_prob.size(0)): 45 | print(categories[top5_catid[i]], top5_prob[i].item()) 46 | print('\n') 47 | -------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/images/README.md: -------------------------------------------------------------------------------- 1 | ## The ground truth images 2 | 3 | **IMPORTANT**: the image names should exactly match the ImageNet class names, e.g. `hen`, `Samoyed`, etc. 
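This matters because the class name is used directly to build the image path (`evaluate/images/<class name>.jpg`) and, in `main.py`, to look up the matching row of `imagenet_classes.txt` for the one-hot label. A minimal usage sketch, assuming you already have a (float or quantized) model object called `model`:

```python
from evaluate import evaluate

# Loads evaluate/images/hen.jpg, runs it through the model,
# and prints the top-5 predicted classes with their probabilities
evaluate(model, "cpu", "hen")
```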
-------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/images/Samoyed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-QAT/evaluate/images/Samoyed.jpg -------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/images/clog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-QAT/evaluate/images/clog.jpg -------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/images/hen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-QAT/evaluate/images/hen.jpg -------------------------------------------------------------------------------- /Resnet-FX-QAT/evaluate/images/mail_box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OscarSavolainen/Quantization-Tutorials/f5bb29c45144fc69e5074a4352f4bd8c70d289cf/Resnet-FX-QAT/evaluate/images/mail_box.jpg -------------------------------------------------------------------------------- /Resnet-FX-QAT/main.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Any, Union, List 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.quantization as tq 6 | from torch.ao.quantization.quantize_fx import prepare_qat_fx 7 | from torch.ao.quantization.qconfig_mapping import QConfigMapping 8 | 9 | from evaluate import evaluate 10 | from utils.qconfigs import learnable_act, learnable_weights, fake_quant_act 11 | from utils.ipdb_hook import ipdb_sys_excepthook 12 | from utils.graph_manip import convbn_to_conv 13 | 14 | from model.resnet import resnet18 15 | 16 | # Adds ipdb breakpoint if and where we have an error 17 | ipdb_sys_excepthook() 18 | 19 | # Intialize model 20 | model = resnet18(pretrained=True) 21 | 22 | # Define qconfigs 23 | qconfig_global = tq.QConfig( 24 | activation=fake_quant_act, 25 | weight=tq.default_fused_per_channel_wt_fake_quant 26 | ) 27 | 28 | 29 | # Assign qconfigs 30 | qconfig_mapping = QConfigMapping().set_global(qconfig_global) 31 | 32 | # We loop through the modules so that we can access the `out_channels` attribute 33 | for name, module in model.named_modules(): 34 | if hasattr(module, 'out_channels'): 35 | qconfig = tq.QConfig( 36 | activation=learnable_act(range=2), 37 | weight=learnable_weights(channels=module.out_channels) 38 | ) 39 | qconfig_mapping.set_module_name(name, qconfig) 40 | 41 | # Do symbolic tracing and quantization 42 | example_inputs = (torch.randn(1, 3, 224, 224),) 43 | model.eval() 44 | fx_model = prepare_qat_fx(model, qconfig_mapping, example_inputs) 45 | 46 | # Evaluate model 47 | print('\nOriginal') 48 | evaluate(model, 'cpu', 'Samoyed') 49 | 50 | print('\nFX prepared') 51 | evaluate(fx_model, 'cpu', 'Samoyed') 52 | 53 | # Prints the graph as a table 54 | print("\nGraph as a Table:\n") 55 | fx_model.graph.print_tabular() 56 | 57 | # Fuses Batchnorms into preceding convs 58 | transformed : torch.fx.GraphModule = convbn_to_conv(deepcopy(fx_model)) 59 | input = 
example_inputs[0] 60 | out = transformed(input) # Test we can feed something through the model 61 | print('\nTransformed model evaluation:') 62 | evaluate(transformed, 'cpu', 'Samoyed') 63 | 64 | 65 | 66 | ############################### 67 | # Quantization Aware Training # 68 | ############################### 69 | from torch.optim import Adam 70 | from torch.nn import CrossEntropyLoss 71 | from PIL import Image 72 | from torchvision import transforms 73 | from pathlib import Path 74 | 75 | optim = Adam(transformed.parameters(), lr=1e-5) 76 | loss_fn = CrossEntropyLoss() 77 | 78 | # Used to get the index of the target image in the imagenet classes 79 | def find_row_with_string(file_path, target_string): 80 | with open(file_path, 'r') as file: 81 | for line_number, line in enumerate(file): 82 | if target_string in line: 83 | return line_number 84 | return None # String not found in any row 85 | 86 | def batch_images(targets: List[str], images_path: str, labels_path: Path): 87 | """ 88 | Takes image labels (e.g. 'Samoyed'), and batches the processed image tensor together. 89 | It also produces a batched one-hot tensor, with the different images across the batch dimension. 90 | I.e., given a list of image names, it produces a batch of processed images and their onehot labels. 91 | """ 92 | first_image = True 93 | for target in targets: 94 | # Get label index to create onehot vector 95 | row_number = find_row_with_string(labels_path, target) 96 | one_hot_label = torch.zeros(1000) 97 | one_hot_label[row_number] = 1 98 | 99 | # Get image 100 | input_image = Image.open(Path(f"{images_path}/{target}.jpg")) 101 | preprocess = transforms.Compose([ 102 | transforms.Resize(256), 103 | transforms.CenterCrop(224), 104 | transforms.ToTensor(), 105 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 106 | ]) 107 | input_tensor = preprocess(input_image) 108 | 109 | # Batch image and labels 110 | if first_image: 111 | input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model 112 | label_batch = one_hot_label.unsqueeze(0) 113 | first_image = False 114 | else: 115 | input_batch = torch.cat((input_batch, input_tensor.unsqueeze(0)), dim=0) 116 | label_batch = torch.cat((label_batch, one_hot_label.unsqueeze(0)), dim=0) 117 | 118 | return input_batch, label_batch 119 | 120 | # Get input and ouput batches 121 | images_path = Path("evaluate/images") # Path to the images 122 | labels_path = Path('evaluate/imagenet_classes.txt') # Path to your text file with imagenet class labels 123 | batched_images, batched_labels = batch_images(['hen', 'Samoyed'], images_path=images_path, labels_path=labels_path) 124 | 125 | 126 | def print_scale_and_zp(model, module_name): 127 | module = model.get_submodule(module_name) 128 | scale = module.scale 129 | zero_point = module.zero_point 130 | print(f"{module_name} scale and zero_point: {scale.item():.5}, {zero_point.item()}") 131 | 132 | # Print out some paramaters before we do QAT 133 | print('\nBefore QAT:') 134 | print_scale_and_zp(transformed, 'activation_post_process_0') 135 | 136 | # Training loop where we do QAT 137 | mean_loss, counter = 0, 0 138 | log_freq = 10 139 | print('\nTraining loop') 140 | for epoch in range(50): 141 | y_pred = transformed(batched_images) 142 | probabilities = torch.nn.functional.softmax(y_pred, dim=1) 143 | 144 | loss = loss_fn(probabilities, batched_labels) 145 | optim.zero_grad() 146 | loss.backward() 147 | optim.step() 148 | 149 | counter += 1 150 | if counter % log_freq == 0: 151 | 
mean_loss += loss.item() 152 | print(f"Iter: {counter}, Mean loss: {(mean_loss/log_freq):.5}") 153 | mean_loss = 0 154 | 155 | print('\nAfter QAT:') 156 | print_scale_and_zp(transformed, 'activation_post_process_0') 157 | 158 | # Post QAT evaluations 159 | print('QAT model evaluation (Samoyed):') 160 | evaluate(transformed, 'cpu', 'Samoyed') 161 | 162 | # Check performance on hen 163 | print('QAT model evaluation (hen):') 164 | evaluate(transformed, 'cpu', 'hen') 165 | 166 | # Check performance on clog (which we did not overfit to) 167 | print('QAT model evaluation (clog):') 168 | evaluate(transformed, 'cpu', 'clog') 169 | 170 | # Check performance on clog (which we did not overfit to) 171 | print('QAT model evaluation (mail box):') 172 | evaluate(transformed, 'cpu', 'mail_box') 173 | XXX 174 | -------------------------------------------------------------------------------- /Resnet-FX-QAT/model/resnet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Type, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from torchvision.transforms._presets import ImageClassification 9 | from torchvision.utils import _log_api_usage_once 10 | from torchvision.models._api import register_model, Weights, WeightsEnum 11 | from torchvision.models._meta import _IMAGENET_CATEGORIES 12 | from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface 13 | 14 | 15 | __all__ = [ 16 | "ResNet", 17 | "ResNet18_Weights", 18 | "resnet18", 19 | ] 20 | 21 | 22 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d( 25 | in_planes, 26 | out_planes, 27 | kernel_size=3, 28 | stride=stride, 29 | padding=dilation, 30 | groups=groups, 31 | bias=False, 32 | dilation=dilation, 33 | ) 34 | 35 | 36 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion: int = 1 43 | 44 | def __init__( 45 | self, 46 | inplanes: int, 47 | planes: int, 48 | stride: int = 1, 49 | downsample: Optional[nn.Module] = None, 50 | groups: int = 1, 51 | base_width: int = 64, 52 | dilation: int = 1, 53 | norm_layer: Optional[Callable[..., nn.Module]] = None, 54 | ) -> None: 55 | super().__init__() 56 | if norm_layer is None: 57 | norm_layer = nn.BatchNorm2d 58 | if groups != 1 or base_width != 64: 59 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 60 | if dilation > 1: 61 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 62 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 63 | self.conv1 = conv3x3(inplanes, planes, stride) 64 | self.bn1 = norm_layer(planes) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.conv2 = conv3x3(planes, planes) 67 | self.bn2 = norm_layer(planes) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x: Tensor) -> Tensor: 72 | identity = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | 81 | if self.downsample is not None: 82 | identity = self.downsample(x) 83 | 84 | out += identity 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 
| 90 | class Bottleneck(nn.Module): 91 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 92 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 93 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 94 | # This variant is also known as ResNet V1.5 and improves accuracy according to 95 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 96 | 97 | expansion: int = 4 98 | 99 | def __init__( 100 | self, 101 | inplanes: int, 102 | planes: int, 103 | stride: int = 1, 104 | downsample: Optional[nn.Module] = None, 105 | groups: int = 1, 106 | base_width: int = 64, 107 | dilation: int = 1, 108 | norm_layer: Optional[Callable[..., nn.Module]] = None, 109 | ) -> None: 110 | super().__init__() 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | width = int(planes * (base_width / 64.0)) * groups 114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 115 | self.conv1 = conv1x1(inplanes, width) 116 | self.bn1 = norm_layer(width) 117 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 118 | self.bn2 = norm_layer(width) 119 | self.conv3 = conv1x1(width, planes * self.expansion) 120 | self.bn3 = norm_layer(planes * self.expansion) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def forward(self, x: Tensor) -> Tensor: 126 | identity = x 127 | 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv2(out) 133 | out = self.bn2(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv3(out) 137 | out = self.bn3(out) 138 | 139 | if self.downsample is not None: 140 | identity = self.downsample(x) 141 | 142 | out += identity 143 | out = self.relu(out) 144 | 145 | return out 146 | 147 | 148 | class ResNet(nn.Module): 149 | def __init__( 150 | self, 151 | block: Type[Union[BasicBlock, Bottleneck]], 152 | layers: List[int], 153 | num_classes: int = 1000, 154 | zero_init_residual: bool = False, 155 | groups: int = 1, 156 | width_per_group: int = 64, 157 | replace_stride_with_dilation: Optional[List[bool]] = None, 158 | norm_layer: Optional[Callable[..., nn.Module]] = None, 159 | ) -> None: 160 | super().__init__() 161 | _log_api_usage_once(self) 162 | if norm_layer is None: 163 | norm_layer = nn.BatchNorm2d 164 | self._norm_layer = norm_layer 165 | 166 | self.inplanes = 64 167 | self.dilation = 1 168 | if replace_stride_with_dilation is None: 169 | # each element in the tuple indicates if we should replace 170 | # the 2x2 stride with a dilated convolution instead 171 | replace_stride_with_dilation = [False, False, False] 172 | if len(replace_stride_with_dilation) != 3: 173 | raise ValueError( 174 | "replace_stride_with_dilation should be None " 175 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 176 | ) 177 | self.groups = groups 178 | self.base_width = width_per_group 179 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 180 | self.bn1 = norm_layer(self.inplanes) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 183 | self.layer1 = self._make_layer(block, 64, layers[0]) 184 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 185 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 
dilate=replace_stride_with_dilation[1]) 186 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | self.fc = nn.Linear(512 * block.expansion, num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 193 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 194 | nn.init.constant_(m.weight, 1) 195 | nn.init.constant_(m.bias, 0) 196 | 197 | # Zero-initialize the last BN in each residual branch, 198 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 199 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 200 | if zero_init_residual: 201 | for m in self.modules(): 202 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 203 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 204 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 205 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 206 | 207 | def _make_layer( 208 | self, 209 | block: Type[Union[BasicBlock, Bottleneck]], 210 | planes: int, 211 | blocks: int, 212 | stride: int = 1, 213 | dilate: bool = False, 214 | ) -> nn.Sequential: 215 | norm_layer = self._norm_layer 216 | downsample = None 217 | previous_dilation = self.dilation 218 | if dilate: 219 | self.dilation *= stride 220 | stride = 1 221 | if stride != 1 or self.inplanes != planes * block.expansion: 222 | downsample = nn.Sequential( 223 | conv1x1(self.inplanes, planes * block.expansion, stride), 224 | norm_layer(planes * block.expansion), 225 | ) 226 | 227 | layers = [] 228 | layers.append( 229 | block( 230 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 231 | ) 232 | ) 233 | self.inplanes = planes * block.expansion 234 | for _ in range(1, blocks): 235 | layers.append( 236 | block( 237 | self.inplanes, 238 | planes, 239 | groups=self.groups, 240 | base_width=self.base_width, 241 | dilation=self.dilation, 242 | norm_layer=norm_layer, 243 | ) 244 | ) 245 | 246 | return nn.Sequential(*layers) 247 | 248 | def _forward_impl(self, x: Tensor) -> Tensor: 249 | # See note [TorchScript super()] 250 | x = self.conv1(x) 251 | x = self.bn1(x) 252 | x = self.relu(x) 253 | x = self.maxpool(x) 254 | 255 | x = self.layer1(x) 256 | x = self.layer2(x) 257 | x = self.layer3(x) 258 | x = self.layer4(x) 259 | 260 | x = self.avgpool(x) 261 | x = torch.flatten(x, 1) 262 | x = self.fc(x) 263 | 264 | return x 265 | 266 | def forward(self, x: Tensor) -> Tensor: 267 | return self._forward_impl(x) 268 | 269 | 270 | def _resnet( 271 | block: Type[Union[BasicBlock, Bottleneck]], 272 | layers: List[int], 273 | weights: Optional[WeightsEnum], 274 | progress: bool, 275 | **kwargs: Any, 276 | ) -> ResNet: 277 | if weights is not None: 278 | _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 279 | 280 | model = ResNet(block, layers, **kwargs) 281 | 282 | if weights is not None: 283 | model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True)) 284 | 285 | return model 286 | 287 | 288 | _COMMON_META = { 289 | "min_size": (1, 1), 290 | "categories": _IMAGENET_CATEGORIES, 291 | } 292 | 293 | 294 | class ResNet18_Weights(WeightsEnum): 295 | IMAGENET1K_V1 = Weights( 296 | url="https://download.pytorch.org/models/resnet18-f37072fd.pth", 297 | transforms=partial(ImageClassification, 
crop_size=224), 298 | meta={ 299 | **_COMMON_META, 300 | "num_params": 11689512, 301 | "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", 302 | "_metrics": { 303 | "ImageNet-1K": { 304 | "acc@1": 69.758, 305 | "acc@5": 89.078, 306 | } 307 | }, 308 | "_ops": 1.814, 309 | "_file_size": 44.661, 310 | "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", 311 | }, 312 | ) 313 | DEFAULT = IMAGENET1K_V1 314 | 315 | 316 | @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) 317 | def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: 318 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 319 | 320 | Args: 321 | weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The 322 | pretrained weights to use. See 323 | :class:`~torchvision.models.ResNet18_Weights` below for 324 | more details, and possible values. By default, no pre-trained 325 | weights are used. 326 | progress (bool, optional): If True, displays a progress bar of the 327 | download to stderr. Default is True. 328 | **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` 329 | base class. Please refer to the `source code 330 | `_ 331 | for more details about this class. 332 | 333 | .. autoclass:: torchvision.models.ResNet18_Weights 334 | :members: 335 | """ 336 | weights = ResNet18_Weights.verify(weights) 337 | 338 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) 339 | 340 | -------------------------------------------------------------------------------- /Resnet-FX-QAT/utils/graph_manip.py: -------------------------------------------------------------------------------- 1 | ######################### 2 | # SOME GRAPH TECHNIQUES # 3 | ######################### 4 | # Experiment with iterator pattern: 5 | # NOTE: taken from https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern 6 | import torch 7 | from torch import fx 8 | from torch.fx.node import Node 9 | from typing import Dict, Union, Tuple, Any 10 | 11 | class GraphIteratorStorage: 12 | """ 13 | A general Iterator over the graph. This class takes a `GraphModule`, 14 | and a callable `storage` representing a function that will store some 15 | attribute for each node when the `propagate` method is called. 16 | 17 | Its `propagate` method executes the `GraphModule` 18 | node-by-node with the given arguments, e.g. an example input tensor. 19 | As each operation executes, the GraphIteratorStorage class stores 20 | away the result of the callable for the output values of each operation on 21 | the attributes of the operation's `Node`. For example, 22 | one could use a callable `store_shaped_dtype()` where: 23 | 24 | ``` 25 | def store_shape_dtype(result): 26 | if isinstance(result, torch.Tensor): 27 | node.shape = result.shape 28 | node.dtype = result.dtype 29 | ``` 30 | This would store the `shape` and `dtype` of each operation on 31 | its respective `Node`, for the given input to `propagate`. 
32 | """ 33 | def __init__(self, mod, storage): 34 | self.mod = mod 35 | self.graph = mod.graph 36 | self.modules = dict(self.mod.named_modules()) 37 | self.storage = storage 38 | 39 | def propagate(self, *args): 40 | args_iter = iter(args) 41 | env : Dict[str, Node] = {} 42 | 43 | def load_arg(a): 44 | return torch.fx.graph.map_arg(a, lambda n: env[n.name]) 45 | 46 | def fetch_attr(target : str): 47 | target_atoms = target.split('.') 48 | attr_itr = self.mod 49 | for i, atom in enumerate(target_atoms): 50 | if not hasattr(attr_itr, atom): 51 | raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i])}") 52 | attr_itr = getattr(attr_itr, atom) 53 | return attr_itr 54 | 55 | # Each of these `result` variables correpsonds to the output of the node in question 56 | for node in self.graph.nodes: 57 | if node.op == 'placeholder': 58 | result = next(args_iter) 59 | elif node.op == 'get_attr': 60 | result = fetch_attr(node.target) 61 | elif node.op == 'call_function': 62 | result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) 63 | elif node.op == 'call_method': 64 | self_obj, *args = load_arg(node.args) 65 | kwargs = load_arg(node.kwargs) 66 | result = getattr(self_obj, node.target)(*args, **kwargs) 67 | elif node.op == 'call_module': 68 | module = self.mod.get_submodule(node.target) # this is robust to "shared" nodes. 69 | result = module(*load_arg(node.args), **load_arg(node.kwargs)) 70 | 71 | # This is the only code specific to the `storage` function. 72 | # you can delete this line and this `propagate` function becomes 73 | # a generic GraphModule interpreter 74 | self.storage(node, result) 75 | 76 | # Store the output activation in `env` for the given node 77 | env[node.name] = result 78 | 79 | #return load_arg(self.graph.result) 80 | 81 | def store_shape_dtype(node, result): 82 | """ 83 | Function that takes in the current node, and the tensor it is operating on (`result`) 84 | and stores the shape and dtype of `result` on the node as attributes. 85 | """ 86 | if isinstance(result, torch.Tensor): 87 | node.shape = result.shape 88 | node.dtype = result.dtype 89 | # NOTE: I just discovered that they have the `Interpreter` class, which accomplishes the same thing: 90 | # https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter 91 | # GraphIteratorStorage(fx_model, store_shape_dtype).propagate(example_inputs[0]) 92 | 93 | #print("\nPrinting size and data types:") 94 | #for node in fx_model.graph.nodes: 95 | #print(node.name, node.shape, node.dtype) 96 | 97 | 98 | 99 | ########################################### 100 | ### Fusing Bn in ConvBnReLU into ConvReLU # 101 | ########################################### 102 | from torch.ao.nn.intrinsic.qat.modules.conv_fused import ConvBnReLU2d, ConvReLU2d, ConvBn2d 103 | from torch.ao.nn.qat import Conv2d 104 | 105 | def fuse_conv_bn_relu_eval(conv: Union[ConvBnReLU2d, ConvBn2d]) -> Union[ConvReLU2d, Conv2d]: 106 | """ 107 | Given a quantizable ConvBnReLU2d Module returns a quantizable ConvReLU2d 108 | module such that the BatchNorm has been fused into the Conv, in inference mode. 109 | Given a ConvBn2d, it does the same to produce a Conv2d. 110 | One could also use `torch.nn.utils.fuse_conv_bn_eval` to produce a Conv, and then quantize that as desired. 111 | """ 112 | assert(not (conv.training or conv.bn.training)), "Fusion only for eval!" 
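    # The helper `fuse_conv_bn_weights` below folds the eval-mode BatchNorm statistics
    # into the Conv parameters, per output channel:
    #   W_fused = W * gamma / sqrt(running_var + eps)
    #   b_fused = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
    # so the fused module reproduces Conv -> BN (-> ReLU) exactly at inference time.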
113 | qconfig = conv.qconfig 114 | if type(conv) is ConvBnReLU2d: 115 | new_conv = ConvReLU2d(conv.in_channels, conv.out_channels, conv.kernel_size, 116 | conv.stride, conv.padding, conv.dilation, 117 | conv.groups, conv.bias is not None, 118 | conv.padding_mode, qconfig=qconfig) 119 | elif type(conv) is ConvBn2d: 120 | new_conv = Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size, 121 | conv.stride, conv.padding, conv.dilation, 122 | conv.groups, conv.bias is not None, 123 | conv.padding_mode, qconfig=qconfig) 124 | 125 | 126 | new_conv.weight, new_conv.bias = \ 127 | fuse_conv_bn_weights(conv.weight, conv.bias, 128 | conv.bn.running_mean, conv.bn.running_var, conv.bn.eps, conv.bn.weight, conv.bn.bias) 129 | 130 | return new_conv 131 | 132 | def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): 133 | """ 134 | Helper function for fusing a Conv and BatchNorm into a single weight/bias tensor pair. 135 | """ 136 | if conv_b is None: 137 | conv_b = torch.zeros_like(bn_rm) 138 | if bn_w is None: 139 | bn_w = torch.ones_like(bn_rm) 140 | if bn_b is None: 141 | bn_b = torch.zeros_like(bn_rm) 142 | bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) 143 | 144 | conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 145 | conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b 146 | 147 | return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) 148 | 149 | # Graph manipulation functions for fusing Convs and BatchNorms 150 | def _parent_name(target : str) -> Tuple[str, str]: 151 | """ 152 | Splits a qualname into parent path and last atom. 153 | For example, `foo.bar.baz` -> (`foo.bar`, `baz`) 154 | """ 155 | *parent, name = target.rsplit('.', 1) 156 | return parent[0] if parent else '', name 157 | 158 | def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): 159 | """ 160 | Helper function for having `new_module` take the place of `node` in a dict of modules. 161 | """ 162 | assert(isinstance(node.target, str)) 163 | parent_name, name = _parent_name(node.target) 164 | #modules[node.target] = new_module 165 | setattr(modules[parent_name], name, new_module) 166 | 167 | def convbn_to_conv(fx_model: torch.fx.GraphModule) -> torch.fx.GraphModule: 168 | """ 169 | Iterates through the graph nodes, and: 170 | - where it finds a ConvBnReLU it replaces it with ConvReLU 171 | - where it finds a ConvBn it replaces it with Conv 172 | 173 | This function works in-place on `fx_model`. 174 | 175 | Inputs: 176 | fx_model (torch.fx.GraphModule): a graph module, that we want to perform transformations on 177 | 178 | Output: 179 | (torch.fx.GraphModule): a model where we have swapped out the 2d ConvBn/ConvBnReLU for Conv/ConvReLU, and 180 | fused the Bns into the Convs. 181 | """ 182 | modules = dict(fx_model.named_modules()) 183 | 184 | for node in fx_model.graph.nodes: 185 | # If the operation the node is doing is to call a module 186 | if node.op == 'call_module': 187 | # The current node 188 | orig = fx_model.get_submodule(node.target) 189 | if type(orig) in [ConvBnReLU2d, ConvBn2d]: 190 | # Produce a fused Bn equivalent. 
191 | fused_conv = fuse_conv_bn_relu_eval(orig) 192 | # This updates `modules` so that `fused_conv` takes the place of what was represented by `node` 193 | replace_node_module(node, modules, fused_conv) 194 | 195 | return fx_model -------------------------------------------------------------------------------- /Resnet-FX-QAT/utils/ipdb_hook.py: -------------------------------------------------------------------------------- 1 | import traceback, ipdb 2 | import sys 3 | 4 | def ipdb_sys_excepthook(): 5 | """ 6 | When called this function will set up the system exception hook. 7 | This hook throws one into an ipdb breakpoint if and where a system 8 | exception occurs in one's run. 9 | 10 | E.g. 11 | >>> ipdb_sys_excepthook() 12 | """ 13 | 14 | 15 | def info(type, value, tb): 16 | """ 17 | System excepthook that includes an ipdb breakpoint. 18 | """ 19 | if hasattr(sys, 'ps1') or not sys.stderr.isatty(): 20 | # we are in interactive mode or we don't have a tty-like 21 | # device, so we call the default hook 22 | sys.__excepthook__(type, value, tb) 23 | else: 24 | # we are NOT in interactive mode, print the exception... 25 | traceback.print_exception(type, value, tb) 26 | print 27 | # ...then start the debugger in post-mortem mode. 28 | # pdb.pm() # deprecated 29 | ipdb.post_mortem(tb) # more "modern" 30 | sys.excepthook = info -------------------------------------------------------------------------------- /Resnet-FX-QAT/utils/qconfigs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.quantization as tq 3 | from torch.ao.quantization.fake_quantize import FakeQuantize 4 | from torch.ao.quantization._learnable_fake_quantize import ( 5 | _LearnableFakeQuantize as LearnableFakeQuantize, 6 | ) 7 | 8 | learnable_act = lambda range : LearnableFakeQuantize.with_args( 9 | observer=tq.HistogramObserver, 10 | quant_min=0, 11 | quant_max=255, 12 | dtype=torch.quint8, 13 | qscheme=torch.per_tensor_affine, 14 | scale=range / 255.0, 15 | zero_point=0.0, 16 | use_grad_scaling=True, 17 | ) 18 | 19 | learnable_weights = lambda channels : LearnableFakeQuantize.with_args( # need to specify number of channels here 20 | observer=tq.PerChannelMinMaxObserver, 21 | quant_min=-128, 22 | quant_max=127, 23 | dtype=torch.qint8, 24 | qscheme=torch.per_channel_symmetric, 25 | scale=0.1, 26 | zero_point=0.0, 27 | use_grad_scaling=True, 28 | channel_len=channels, 29 | ) 30 | 31 | fake_quant_act = FakeQuantize.with_args( 32 | observer=tq.HistogramObserver.with_args( 33 | quant_min=0, 34 | quant_max=255, 35 | dtype=torch.quint8, 36 | qscheme=torch.per_tensor_affine, 37 | ), 38 | ) --------------------------------------------------------------------------------
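For reference, a minimal sketch of how these factories are typically turned into `QConfig`s and attached to a model, mirroring what `main.py` in this folder does. The module name `"conv1"` and the channel count `64` are just illustrative values (they happen to match the first conv of the ResNet-18 used in this tutorial):

```python
import torch.quantization as tq
from torch.ao.quantization.qconfig_mapping import QConfigMapping

from utils.qconfigs import fake_quant_act, learnable_act, learnable_weights

# Global default: histogram-observer fake-quant for activations,
# fused per-channel fake-quant for weights.
qconfig_global = tq.QConfig(
    activation=fake_quant_act,
    weight=tq.default_fused_per_channel_wt_fake_quant,
)
qconfig_mapping = QConfigMapping().set_global(qconfig_global)

# Per-module override: learnable quantizers, where the weight fake-quant
# needs to know how many output channels the module has.
qconfig_conv1 = tq.QConfig(
    activation=learnable_act(range=2),
    weight=learnable_weights(channels=64),
)
qconfig_mapping.set_module_name("conv1", qconfig_conv1)
```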