├── overview.png
├── utils
│   ├── __init__.py
│   ├── samplers.py
│   ├── data_utils.py
│   ├── train_utils.py
│   └── utils.py
├── models
│   ├── quantization_utils
│   │   ├── __init__.py
│   │   ├── quant_utils.py
│   │   └── quant_modules.py
│   ├── __init__.py
│   ├── model_utils.py
│   ├── layers_quant.py
│   ├── utils.py
│   ├── vit_quant.py
│   └── swin_quant.py
├── TVM_benchmark
│   ├── README.md
│   ├── models
│   │   ├── build_model.py
│   │   ├── utils.py
│   │   ├── quantized_vit.py
│   │   └── layers.py
│   ├── evaluate_latency.py
│   ├── evaluate_accuracy.py
│   └── convert_model.py
├── README.md
├── LICENSE
└── quant_train.py

/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zkkli/I-ViT/HEAD/overview.png
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_utils import *
2 | from .train_utils import *
3 | 
4 | from .utils import *
--------------------------------------------------------------------------------
/models/quantization_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .quant_modules import QuantLinear, QuantAct, QuantConv2d, QuantMatMul, IntLayerNorm, IntSoftmax, IntGELU
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantization_utils import *
2 | from .vit_quant import *
3 | from .swin_quant import *
4 | 
5 | from .model_utils import *
6 | 
--------------------------------------------------------------------------------
/models/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .quantization_utils import *
3 | 
4 | 
5 | def freeze_model(model):
6 |     """
7 |     freeze the activation range. Recursively invokes layer.fix()
8 |     """
9 |     if type(model) in [QuantAct]:
10 |         model.fix()
11 |     elif type(model) == nn.Sequential:
12 |         for n, m in model.named_children():
13 |             freeze_model(m)
14 |     elif type(model) == nn.ModuleList:
15 |         for n in model:
16 |             freeze_model(n)
17 |     else:
18 |         for attr in dir(model):
19 |             mod = getattr(model, attr)
20 |             if isinstance(mod, nn.Module):
21 |                 freeze_model(mod)
22 | 
23 | 
24 | def unfreeze_model(model):
25 |     """
26 |     unfreeze the activation range. Recursively invokes layer.unfix()
27 |     """
28 |     if type(model) in [QuantAct]:
29 |         model.unfix()
30 |     elif type(model) == nn.Sequential:
31 |         for n, m in model.named_children():
32 |             unfreeze_model(m)
33 |     elif type(model) == nn.ModuleList:
34 |         for n in model:
35 |             unfreeze_model(n)
36 |     else:
37 |         for attr in dir(model):
38 |             mod = getattr(model, attr)
39 |             if isinstance(mod, nn.Module):
40 |                 unfreeze_model(mod)
41 | 
42 | 
--------------------------------------------------------------------------------
/TVM_benchmark/README.md:
--------------------------------------------------------------------------------
1 | ## I-ViT: Integer-only Quantization for Efficient Vision Transformer Inference
2 | 
3 | Below are instructions for performing integer-only inference of DeiT models on an RTX 2080Ti GPU, with separate accuracy and latency evaluations.
4 | 
5 | ## 0. Install TVM
6 | - You can follow the official [tutorial](https://tvm.apache.org/docs/install/from_source.html#install-from-source) to install TVM.
7 | 
8 | ## 1. 
Convert Model 9 | - Save checkpoint of QAT in PyTorch (checkpoint.pth.tar) 10 | - Convert Pytorch parameters to TVM parameters (params.npy): 11 | ```bash 12 | python convert_model --model-path --params-path 13 | 14 | Required arguments: 15 | : Path to saved checkpoint of QAT (checkpoint.pth.tar) 16 | : Path to save TVM parameters 17 | ``` 18 | 19 | ## 2. Evaluation 20 | ### 2.1 Accuracy 21 | - You can evaluate the accuracy of a model using the following command: 22 | 23 | ```bash 24 | python evaluate_accuracy.py --model-name --model-path --params-path 25 | 26 | Required arguments: 27 | : Model name, e.g., 'deit_small_patch16_224' 28 | : Path to saved checkpoint of QAT (checkpoint.pth.tar) 29 | : Path to saved TVM parameters (params.npy) 30 | ``` 31 | 32 | ### 2.2 Latency 33 | - You can perform **TVM auto-tuning** and evaluate the latency of a model using the following command: 34 | 35 | ```bash 36 | python evaluate_latency.py --model-name --log-path --target 37 | 38 | Required arguments: 39 | : Model name, e.g., 'deit_small_patch16_224' 40 | : Path to save tuning log 41 | : TVM target, e.g., 'cuda -model=2080ti' 42 | ``` 43 | 44 | ## Citation 45 | 46 | We appreciate it if you would please cite the following paper if you found the implementation useful for your work: 47 | 48 | ```bash 49 | @article{li2022ivit, 50 | title={I-ViT: integer-only quantization for efficient vision transformer inference}, 51 | author={Li, Zhikai and Gu, Qingyi}, 52 | journal={arXiv preprint arXiv:2207.01405}, 53 | year={2022} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /TVM_benchmark/models/build_model.py: -------------------------------------------------------------------------------- 1 | from .quantized_vit import Q_VisionTransformer 2 | from .utils import create_workload, QuantizeInitializer 3 | 4 | 5 | def get_deit(name, 6 | batch_size, 7 | image_shape=(3, 224, 224), 8 | dtype="int8", 9 | data_layout="NCHW", 10 | kernel_layout="OIHW", 11 | debug_unit=None): 12 | 13 | 14 | if data_layout == 'NCHW': 15 | data_shape = (batch_size,) + image_shape 16 | elif data_layout == 'NHWC': 17 | data_shape = (batch_size, image_shape[1], image_shape[2], image_shape[0]) 18 | elif data_layout == 'HWCN': 19 | data_shape = (image_shape[1], image_shape[2], image_shape[0], batch_size) 20 | elif data_layout == 'HWNC': 21 | data_shape = (image_shape[1], image_shape[2], batch_size, image_shape[0]) 22 | else: 23 | raise RuntimeError("Unsupported data layout {}".format(data_layout)) 24 | 25 | 26 | if name == 'deit_tiny_patch16_224': 27 | embed_dim = 192 28 | num_heads = 3 29 | elif name == 'deit_small_patch16_224': 30 | embed_dim = 384 31 | num_heads = 6 32 | elif name == 'deit_base_patch16_224': 33 | embed_dim = 768 34 | num_heads = 12 35 | else: 36 | raise RuntimeError("Unsupported model {}".format(name)) 37 | 38 | 39 | return Q_VisionTransformer(data_shape=data_shape, 40 | dtype=dtype, 41 | patch_size=16, 42 | num_patches=196, 43 | in_chans=3, 44 | num_classes=1000, 45 | embed_dim=embed_dim, 46 | depth=12, 47 | num_heads=num_heads, 48 | mlp_ratio=4) 49 | 50 | 51 | def get_workload(name, 52 | batch_size=1, 53 | image_shape=(3, 224, 224), 54 | dtype="int8", 55 | data_layout="NCHW", 56 | kernel_layout="OIHW", 57 | debug_unit=None): 58 | 59 | if batch_size != 1: 60 | raise RuntimeError("The released project only supports batch_size = 1.") 61 | 62 | net = get_deit(name, 63 | batch_size, 64 | image_shape=image_shape, 65 | dtype=dtype, 66 | data_layout=data_layout, 67 | 
kernel_layout=kernel_layout, 68 | debug_unit=debug_unit) 69 | 70 | return create_workload(net, QuantizeInitializer()) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # I-ViT: Integer-only Quantization for Efficient Vision Transformer Inference 6 | 7 | This repository contains the official implementation for the paper 8 | *["I-ViT: Integer-only Quantization for Efficient Vision Transformer Inference"](https://arxiv.org/abs/2207.01405).* To the best of our knowledge, this is the first work on integer-only quantization for vision transformers. 9 | 10 | Below are instructions of Pytorch code to reproduce the accuracy results of quantization-aware training (QAT). [**TVM benchmark**](https://github.com/zkkli/I-ViT/tree/main/TVM_benchmark) 11 | is the TVM deployment project for reproducing latency results. 12 | 13 | ## Installation 14 | - TVM version is recommended to be 0.9.dev0. 15 | - Timm version is recommended to be 0.4.12. 16 | - **To install I-ViT** and develop locally: 17 | 18 | ```bash 19 | git clone https://github.com/zkkli/I-ViT.git 20 | cd I-ViT 21 | ``` 22 | 23 | ## QAT Experiments 24 | 25 | - You can quantize and fine-tune a single model using the following command: 26 | 27 | ```bash 28 | python quant_train.py [--model] [--data] [--epochs] [--lr] 29 | 30 | optional arguments: 31 | --model: Model architecture, the choises can be: 32 | deit_tiny, deit_small, deit_base, swin_tiny, swin_small, swin_base. 33 | --data: Path to ImageNet dataset. 34 | --epochs: recommended values are: [30, 60, 90], default=90. 35 | --lr: recommended values are: [2e-7, 5e-7, 1e-6, 2e-6], default=1e-6. 36 | ``` 37 | 38 | - Example: Quantize and fine-tune DeiT-T: 39 | 40 | ```bash 41 | python quant_train.py --model deit_tiny --data --epochs 30 --lr 5e-7 42 | ``` 43 | 44 | ## Results 45 | 46 | Below are the Top-1 (%) accuracy results of our proposed I-ViT that you should get on ImageNet dataset. 47 | 48 | | Model | FP32 | INT8 (I-ViT) | Diff. | 49 | |:------:|:-----:|:------------:|:-----:| 50 | | ViT-S | 81.39 | 81.27 | -0.12 | 51 | | ViT-B | 84.53 | 84.76 | +0.23 | 52 | | DeiT-T | 72.21 | 72.24 | +0.03 | 53 | | DeiT-S | 79.85 | 80.12 | +0.27 | 54 | | DeiT-B | 81.85 | 81.74 | -0.11 | 55 | | Swin-T | 81.35 | 81.50 | +0.15 | 56 | | Swin-S | 83.20 | 83.01 | -0.19 | 57 | 58 | ## Citation 59 | 60 | We appreciate it if you would please cite the following paper if you found the implementation useful for your work: 61 | 62 | ```bash 63 | @inproceedings{li2023vit, 64 | title={I-vit: Integer-only quantization for efficient vision transformer inference}, 65 | author={Li, Zhikai and Gu, Qingyi}, 66 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 67 | pages={17065--17075}, 68 | year={2023} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /utils/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | #return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_selected_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torchvision import datasets, transforms 5 | from torchvision.datasets.folder import ImageFolder, default_loader 6 | 7 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.data import create_transform 9 | 10 | from .samplers import RASampler 11 | import utils 12 | 13 | 14 | def dataloader(args): 15 | model_type = args.model.split("_")[0] 16 | if model_type == "deit" or "swin": 17 | dataset_train = build_dataset(is_train=True, args=args) 18 | dataset_val = build_dataset(is_train=False, args=args) 19 | 20 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 21 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 22 | 23 | # Data 24 | data_loader_train = torch.utils.data.DataLoader( 25 | dataset_train, sampler=sampler_train, 26 | batch_size=args.batch_size, 27 | num_workers=args.num_workers, 28 | pin_memory=args.pin_mem, 29 | drop_last=True, 30 | ) 31 | 32 | data_loader_val = torch.utils.data.DataLoader( 33 | dataset_val, sampler=sampler_val, 34 | batch_size=int(1.5 * args.batch_size), 35 | num_workers=args.num_workers, 36 | pin_memory=args.pin_mem, 37 | 
drop_last=False 38 | ) 39 | else: 40 | raise NotImplementedError 41 | 42 | return data_loader_train, data_loader_val 43 | 44 | 45 | def build_dataset(is_train, args): 46 | transform = build_transform(is_train, args) 47 | 48 | if args.data_set == 'CIFAR': 49 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 50 | nb_classes = 100 51 | elif args.data_set == 'IMNET': 52 | root = os.path.join(args.data, 'train' if is_train else 'val') 53 | dataset = datasets.ImageFolder(root, transform=transform) 54 | nb_classes = 1000 55 | else: 56 | raise NotImplementedError 57 | 58 | return dataset 59 | 60 | 61 | def build_transform(is_train, args): 62 | resize_im = args.input_size > 32 63 | if is_train: 64 | # this should always dispatch to transforms_imagenet_train 65 | transform = create_transform( 66 | input_size=args.input_size, 67 | is_training=True, 68 | color_jitter=args.color_jitter, 69 | auto_augment=args.aa, 70 | interpolation=args.train_interpolation, 71 | re_prob=args.reprob, 72 | re_mode=args.remode, 73 | re_count=args.recount, 74 | ) 75 | if not resize_im: 76 | # replace RandomResizedCropAndInterpolation with 77 | # RandomCrop 78 | transform.transforms[0] = transforms.RandomCrop( 79 | args.input_size, padding=4) 80 | return transform 81 | 82 | t = [] 83 | if resize_im: 84 | size = int((256 / 224) * args.input_size) 85 | t.append( 86 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 87 | ) 88 | t.append(transforms.CenterCrop(args.input_size)) 89 | 90 | t.append(transforms.ToTensor()) 91 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 92 | return transforms.Compose(t) 93 | -------------------------------------------------------------------------------- /TVM_benchmark/evaluate_latency.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | import tvm 5 | from tvm import relay, auto_scheduler 6 | import tvm.relay.testing 7 | from tvm.contrib import graph_executor 8 | 9 | import models.build_model as build_model 10 | 11 | import os 12 | from pathlib import Path 13 | 14 | 15 | parser = argparse.ArgumentParser(description="TVM-Speed") 16 | 17 | parser.add_argument("--model-name", default='deit_tiny_patch16_224', 18 | choices=['deit_tiny_patch16_224', 19 | 'deit_small_patch16_224', 20 | 'deit_base_patch16_224'], 21 | help="model fullname") 22 | parser.add_argument("--log-path", default='result/deit.json', 23 | help="log_file path (.json)") 24 | parser.add_argument("--target", default='cuda -model=2080ti', 25 | help="tvm target") 26 | 27 | 28 | def main(): 29 | args = parser.parse_args() 30 | 31 | # Set target device 32 | target = tvm.target.Target(args.target) 33 | 34 | # Path to save tuning log 35 | log_file = args.log_path 36 | 37 | # Load model 38 | name = args.model_name 39 | batch_size = 1 40 | image_shape = (3, 224, 224) 41 | input_shape = (batch_size, 3, 224, 224) 42 | output_shape = (batch_size, 1000) 43 | data_layout = "NCHW" 44 | kernel_layout = "OIHW" 45 | 46 | mod, params = build_model.get_workload(name=name, 47 | batch_size=batch_size, 48 | image_shape=image_shape, 49 | dtype="int8", 50 | data_layout=data_layout, 51 | kernel_layout=kernel_layout) 52 | 53 | ################################################################################### 54 | print("Extract tasks...") 55 | tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) 56 | 57 | for idx, task in enumerate(tasks): 58 | 
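        # Each task returned by extract_tasks() corresponds to one fused Relay subgraph
        # (e.g. a quantized dense or conv kernel), and task_weights records how often
        # that subgraph appears in the model; the TaskScheduler below uses these
        # weights when allocating tuning trials across tasks.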
print("========== Task %d (workload key: %s) ==========" % (idx, task.workload_key)) 59 | print(task.compute_dag) 60 | 61 | #################################################################################### 62 | print("Begin tuning...") 63 | measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=500, timeout=1000) 64 | 65 | tuner = auto_scheduler.TaskScheduler(tasks, task_weights) 66 | tune_option = auto_scheduler.TuningOptions( 67 | num_measure_trials=50000, 68 | runner=measure_ctx.runner, 69 | measure_callbacks=[auto_scheduler.RecordToFile(log_file)], 70 | ) 71 | 72 | tuner.tune(tune_option) 73 | 74 | #################################################################################### 75 | print("Compile...") 76 | with auto_scheduler.ApplyHistoryBest(log_file): 77 | with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}): 78 | lib = relay.build(mod, target=target, params=params) 79 | 80 | # Create graph executor 81 | dev = tvm.device(str(target), 0) 82 | module = graph_executor.GraphModule(lib["default"](dev)) 83 | data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype('int8')) 84 | module.set_input("data", data_tvm) 85 | 86 | # Evaluate 87 | print("Evaluate inference time cost...") 88 | print(module.benchmark(dev, repeat=1000, min_repeat_ms=500)) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() -------------------------------------------------------------------------------- /TVM_benchmark/evaluate_accuracy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import tvm 4 | from tvm import relay 5 | from tvm import target 6 | from tvm.contrib.download import download_testdata 7 | 8 | from torchvision import transforms 9 | from PIL import Image 10 | import numpy as np 11 | 12 | import models.build_model as build_model 13 | from models.layers import QuantizeContext 14 | 15 | import convert_model 16 | 17 | 18 | parser = argparse.ArgumentParser(description="TVM-Accuracy") 19 | 20 | parser.add_argument("--model-name", default='deit_tiny_patch16_224', 21 | choices=['deit_tiny_patch16_224', 22 | 'deit_small_patch16_224', 23 | 'deit_base_patch16_224'], 24 | help="model fullname") 25 | parser.add_argument("--model-path", default='', 26 | help="saved checkpoint path in QAT (checkpoint.pth.tar)") 27 | parser.add_argument("--params-path", default='', 28 | help="saved parameters path in convert_model.py (params.npy)") 29 | 30 | 31 | def main(): 32 | args = parser.parse_args() 33 | 34 | # Set target device 35 | target = 'cuda' 36 | 37 | # Load params 38 | model = torch.load(args.model_path) 39 | pretrained_params = np.load(args.params_path, allow_pickle=True)[()] 40 | depth = 12 41 | convert_model.load_qconfig(model, depth) 42 | 43 | # Classic cat example! 
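    # The preprocessed float image is quantized to int8 below using the input scale
    # of the patch-embedding quantizer learned during QAT, i.e. roughly
    # x_int8 = round(clip(x / input_scale, -128, 127)), so the integer-only TVM
    # graph receives the same representation it was trained with.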
44 | img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true" 45 | img_path = download_testdata(img_url, "cat.png", module="data") 46 | img = Image.open(img_path).resize((224, 224)) 47 | # Preprocess the image and convert to tensor 48 | my_preprocess = transforms.Compose( 49 | [ 50 | transforms.Resize(256), 51 | transforms.CenterCrop(224), 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 54 | ] 55 | ) 56 | img = my_preprocess(img) 57 | input_image = np.expand_dims(img, 0) 58 | input_image = input_image / QuantizeContext.qconfig_dict['qconfig_embed_conv'].input_scale 59 | input_image = np.clip(input_image, -128, 127) 60 | input_image = np.round(input_image) 61 | input_image = input_image.astype("int8") 62 | 63 | # Load model 64 | name = args.model_name 65 | batch_size = 1 66 | shape = list(input_image.shape) 67 | image_shape = (3, 224, 224) 68 | data_layout = "NCHW" 69 | kernel_layout = "OIHW" 70 | func, params = build_model.get_workload(name=name, 71 | batch_size=batch_size, 72 | image_shape=image_shape, 73 | dtype="int8", 74 | data_layout=data_layout, 75 | kernel_layout=kernel_layout) 76 | 77 | # Build model 78 | pretrained_params = {**pretrained_params} 79 | with tvm.transform.PassContext(opt_level=3): 80 | lib = relay.build(func, target=target, params=pretrained_params) 81 | 82 | runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](tvm.device(target, 0))) 83 | 84 | # Run model 85 | input_data = np.repeat(input_image, batch_size, axis=0) 86 | runtime.set_input('data', input_data) 87 | runtime.run() 88 | 89 | tvm_result = runtime.get_output(0).numpy() 90 | 91 | tvm_top1_labels = np.argsort(tvm_result[0])[::-1][:5] 92 | print("TVM top1 labels:", tvm_top1_labels) 93 | 94 | 95 | if __name__ == "__main__": 96 | main() -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import io 4 | 5 | 6 | class DistillationLoss(torch.nn.Module): 7 | """ 8 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 9 | taking a teacher model prediction and using it as additional supervision. 10 | """ 11 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 12 | distillation_type: str, alpha: float, tau: float): 13 | super().__init__() 14 | self.base_criterion = base_criterion 15 | self.teacher_model = teacher_model 16 | assert distillation_type in ['none', 'soft', 'hard'] 17 | self.distillation_type = distillation_type 18 | self.alpha = alpha 19 | self.tau = tau 20 | 21 | def forward(self, inputs, outputs, labels): 22 | """ 23 | Args: 24 | inputs: The original inputs that are feed to the teacher model 25 | outputs: the outputs of the model to be trained. 
It is expected to be 26 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 27 | in the first position and the distillation predictions as the second output 28 | labels: the labels for the base criterion 29 | """ 30 | outputs_kd = None 31 | if not isinstance(outputs, torch.Tensor): 32 | # assume that the model outputs a tuple of [outputs, outputs_kd] 33 | outputs, outputs_kd = outputs 34 | base_loss = self.base_criterion(outputs, labels) 35 | if self.distillation_type == 'none': 36 | return base_loss 37 | 38 | if outputs_kd is None: 39 | raise ValueError("When knowledge distillation is enabled, the model is " 40 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 41 | "class_token and the dist_token") 42 | # don't backprop throught the teacher 43 | with torch.no_grad(): 44 | teacher_outputs = self.teacher_model(inputs) 45 | 46 | if self.distillation_type == 'soft': 47 | T = self.tau 48 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 49 | # with slight modifications 50 | distillation_loss = F.kl_div( 51 | F.log_softmax(outputs_kd / T, dim=1), 52 | #We provide the teacher's targets in log probability because we use log_target=True 53 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719) 54 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. 55 | F.log_softmax(teacher_outputs / T, dim=1), 56 | reduction='sum', 57 | log_target=True 58 | ) * (T * T) / outputs_kd.numel() 59 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 60 | #But we also experiments output_kd.size(0) 61 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details 62 | elif self.distillation_type == 'hard': 63 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 64 | 65 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 66 | return loss 67 | 68 | 69 | def load_checkpoint_for_ema(model_ema, checkpoint): 70 | """ 71 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 72 | """ 73 | mem_file = io.BytesIO() 74 | torch.save(checkpoint, mem_file) 75 | mem_file.seek(0) 76 | model_ema._load_checkpoint(mem_file) 77 | -------------------------------------------------------------------------------- /TVM_benchmark/models/utils.py: -------------------------------------------------------------------------------- 1 | """Initializer of parameters.""" 2 | import numpy as np 3 | 4 | import tvm 5 | from tvm import relay 6 | 7 | class Initializer(object): 8 | """The base class of an initializer.""" 9 | def __init__(self, **kwargs): 10 | self._kwargs = kwargs 11 | 12 | def __call__(self, desc, arr): 13 | """Initialize an array 14 | Parameters 15 | ---------- 16 | desc : str 17 | Initialization pattern descriptor. 18 | arr : NDArray 19 | The array to be initialized. 
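        Example (illustrative; these names are produced by convert_model.py):
        a key ending in 'weight', e.g. 'block_0_attn_qkv_weight', is dispatched
        to _init_weight, while 'block_0_norm1_bias' is dispatched to _init_bias.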
20 | """ 21 | if desc.endswith('weight'): 22 | self._init_weight(desc, arr) 23 | elif desc.endswith('bias'): 24 | self._init_bias(desc, arr) 25 | elif desc.endswith('gamma'): 26 | self._init_gamma(desc, arr) 27 | elif desc.endswith('beta'): 28 | self._init_beta(desc, arr) 29 | elif desc.endswith('mean'): 30 | self._init_mean(desc, arr) 31 | elif desc.endswith('var'): 32 | self._init_var(desc, arr) 33 | elif desc.endswith('scale'): 34 | self._init_scale(desc, arr) 35 | elif desc.endswith('shift'): 36 | self._init_shift(desc, arr) 37 | else: 38 | self._init_default(desc, arr) 39 | 40 | def _init_bias(self, _, arr): 41 | arr[:] = 0.0 42 | 43 | def _init_gamma(self, _, arr): 44 | arr[:] = 1.0 45 | 46 | def _init_beta(self, _, arr): 47 | arr[:] = 0.0 48 | 49 | def _init_mean(self, _, arr): 50 | arr[:] = 0.0 51 | 52 | def _init_var(self, _, arr): 53 | arr[:] = 1.0 54 | 55 | def _init_scale(self, _, arr): 56 | arr[:] = 2 57 | 58 | def _init_shift(self, _, arr): 59 | arr[:] = 2 60 | 61 | def _init_weight(self, name, arr): 62 | """Abstract method to Initialize weight.""" 63 | raise NotImplementedError("Must override it") 64 | 65 | def _init_default(self, name, _): 66 | raise ValueError( 67 | 'Unknown initialization pattern for %s. ' \ 68 | 'Default initialization is now limited to '\ 69 | '"weight", "bias", "gamma" (1.0), and "beta" (0.0).' \ 70 | 'Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern' % name) 71 | 72 | class Xavier(Initializer): 73 | """ "Xavier" initialization for weights 74 | Parameters 75 | ---------- 76 | rnd_type: str, optional 77 | Random generator type, can be ``'gaussian'`` or ``'uniform'``. 78 | factor_type: str, optional 79 | Can be ``'avg'``, ``'in'``, or ``'out'``. 80 | magnitude: float, optional 81 | Scale of random number. 82 | """ 83 | def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3): 84 | super(Xavier, self).__init__(rnd_type=rnd_type, 85 | factor_type=factor_type, 86 | magnitude=magnitude) 87 | self.rnd_type = rnd_type 88 | self.factor_type = factor_type 89 | self.magnitude = float(magnitude) 90 | 91 | def _init_weight(self, name, arr): 92 | shape = arr.shape 93 | hw_scale = 1. 94 | if len(shape) < 2: 95 | raise ValueError('Xavier initializer cannot be applied to vector {0}. It requires at' 96 | ' least 2D.'.format(name)) 97 | if len(shape) > 2: 98 | hw_scale = np.prod(shape[2:]) 99 | fan_in, fan_out = shape[1] * hw_scale, shape[0] * hw_scale 100 | factor = 1. 
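        # The scale used below is sqrt(magnitude / factor), where factor is
        # (fan_in + fan_out) / 2 for 'avg', fan_in for 'in', and fan_out for 'out';
        # the defaults (rnd_type='uniform', factor_type='avg', magnitude=3)
        # correspond to the classic Xavier/Glorot uniform initialization.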
101 | if self.factor_type == "avg": 102 | factor = (fan_in + fan_out) / 2.0 103 | elif self.factor_type == "in": 104 | factor = fan_in 105 | elif self.factor_type == "out": 106 | factor = fan_out 107 | else: 108 | raise ValueError("Incorrect factor type") 109 | # Hack for mobilenet, because there is less connectivity 110 | if "depthwise" in name: 111 | factor = hw_scale 112 | scale = np.sqrt(self.magnitude / factor) 113 | if self.rnd_type == "uniform": 114 | arr[:] = np.random.uniform(-scale, scale, size=arr.shape) 115 | else: 116 | raise ValueError("Unknown random type") 117 | 118 | class QuantizeInitializer(Initializer): 119 | def _init_weight(self, name, arr): 120 | if arr.dtype == np.float32: 121 | arr[:] = np.random.uniform(-1., 1., size=arr.shape) 122 | elif arr.dtype == np.int8: 123 | arr[:] = np.random.randint(-127, 128, size=arr.shape) 124 | elif arr.dtype == np.uint8: 125 | arr[:] = np.random.randint(0, 256, size=arr.shape) 126 | elif arr.dtype == np.int32: 127 | arr[:] = np.random.randint(-2**31, 2**31, size=arr.shape) 128 | else: 129 | raise ValueError("Unknown random type %s" % (arr.dtype)) 130 | 131 | def _init_bias(self, name, arr): 132 | if arr.dtype == np.int32: 133 | arr[:] = np.random.randint(-200, 200, size=arr.shape) 134 | elif arr.dtype == np.float32: 135 | arr[:] = np.random.uniform(-1., 1., size=arr.shape) 136 | else: 137 | raise ValueError("Unknown random type %s" % (arr.dtype)) 138 | 139 | def _init_scale(self, _, arr): 140 | arr[:] = np.random.randint(-256, 256, size=arr.shape) 141 | 142 | def _init_shift(self, _, arr): 143 | arr[:] = np.random.randint(-256, 256, size=arr.shape) 144 | 145 | def create_workload(net, initializer=None, seed=0): 146 | """Helper function to create benchmark image classification workload. 147 | Parameters 148 | ---------- 149 | net : tvm.relay.Function 150 | The selected function of the network. 151 | initializer : Initializer 152 | The initializer used 153 | seed : int 154 | The seed used in initialization. 155 | Returns 156 | ------- 157 | mod : tvm.IRModule 158 | The created relay module. 159 | params : dict of str to NDArray 160 | The parameters. 
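    Example (as used in build_model.get_workload):
        mod, params = create_workload(net, QuantizeInitializer())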
161 | """ 162 | mod = tvm.IRModule.from_expr(net) 163 | mod = relay.transform.InferType()(mod) 164 | shape_dict = { 165 | v.name_hint : v.checked_type for v in mod["main"].params} 166 | np.random.seed(seed) 167 | initializer = initializer if initializer else Xavier() 168 | params = {} 169 | for k, v in shape_dict.items(): 170 | if k == "data": 171 | continue 172 | 173 | if v.dtype == 'int4' or v.dtype == 'uint4': 174 | pack_shape = list(v.concrete_shape) 175 | pack_shape[-1] = pack_shape[-1] // 8 176 | init_value = np.zeros(pack_shape).astype('int32') 177 | else: 178 | init_value = np.zeros(v.concrete_shape).astype(v.dtype) 179 | 180 | initializer(k, init_value) 181 | params[k] = tvm.nd.array(init_value) #, ctx=tvm.cpu(0) 182 | 183 | return mod, params -------------------------------------------------------------------------------- /models/layers_quant.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from itertools import repeat 4 | import collections.abc 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .quantization_utils import QuantLinear, QuantAct, QuantConv2d, IntGELU 11 | 12 | 13 | def _ntuple(n): 14 | def parse(x): 15 | if isinstance(x, collections.abc.Iterable): 16 | return x 17 | return tuple(repeat(x, n)) 18 | 19 | return parse 20 | 21 | 22 | to_2tuple = _ntuple(2) 23 | 24 | 25 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 26 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 27 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 28 | def norm_cdf(x): 29 | # Computes standard normal cumulative distribution function 30 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 31 | 32 | if (mean < a - 2 * std) or (mean > b + 2 * std): 33 | warnings.warn( 34 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 35 | "The distribution of values may be incorrect.", 36 | stacklevel=2, 37 | ) 38 | 39 | with torch.no_grad(): 40 | # Values are generated by using a truncated uniform distribution and 41 | # then using the inverse CDF for the normal distribution. 42 | # Get upper and lower cdf values 43 | l = norm_cdf((a - mean) / std) 44 | u = norm_cdf((b - mean) / std) 45 | 46 | # Uniformly fill tensor with values from [l, u], then translate to 47 | # [2l-1, 2u-1]. 48 | tensor.uniform_(2 * l - 1, 2 * u - 1) 49 | 50 | # Use inverse cdf transform for normal distribution to get truncated 51 | # standard normal 52 | tensor.erfinv_() 53 | 54 | # Transform to proper mean, std 55 | tensor.mul_(std * math.sqrt(2.0)) 56 | tensor.add_(mean) 57 | 58 | # Clamp to ensure it's in the proper range 59 | tensor.clamp_(min=a, max=b) 60 | return tensor 61 | 62 | 63 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 64 | # type: (Tensor, float, float, float, float) -> Tensor 65 | r"""Fills the input Tensor with values drawn from a truncated 66 | normal distribution. The values are effectively drawn from the 67 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 68 | with values outside :math:`[a, b]` redrawn until they are within 69 | the bounds. The method used for generating the random values works 70 | best when :math:`a \leq \text{mean} \leq b`. 
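    Concretely, the helper above draws :math:`u \sim U(\Phi(\alpha), \Phi(\beta))` with
    :math:`\alpha = (a - \text{mean})/\text{std}` and :math:`\beta = (b - \text{mean})/\text{std}`,
    maps it through the inverse normal CDF as
    :math:`\text{mean} + \text{std}\,\sqrt{2}\,\operatorname{erfinv}(2u - 1)`,
    and finally clamps the result to :math:`[a, b]`.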
71 | Args: 72 | tensor: an n-dimensional `torch.Tensor` 73 | mean: the mean of the normal distribution 74 | std: the standard deviation of the normal distribution 75 | a: the minimum cutoff value 76 | b: the maximum cutoff value 77 | Examples: 78 | >>> w = torch.empty(3, 5) 79 | >>> nn.init.trunc_normal_(w) 80 | """ 81 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 82 | 83 | 84 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 85 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 86 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 87 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 88 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 89 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 90 | 'survival rate' as the argument. 91 | """ 92 | if drop_prob == 0.0 or not training: 93 | return x 94 | keep_prob = 1 - drop_prob 95 | shape = (x.shape[0],) + (1,) * ( 96 | x.ndim - 1 97 | ) # work with diff dim tensors, not just 2D ConvNets 98 | random_tensor = keep_prob + \ 99 | torch.rand(shape, dtype=x.dtype, device=x.device) 100 | random_tensor.floor_() # binarize 101 | output = x.div(keep_prob) * random_tensor 102 | return output 103 | 104 | 105 | class DropPath(nn.Module): 106 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 107 | 108 | def __init__(self, drop_prob=None): 109 | super(DropPath, self).__init__() 110 | self.drop_prob = drop_prob 111 | 112 | def forward(self, x): 113 | return drop_path(x, self.drop_prob, self.training) 114 | 115 | 116 | class Mlp(nn.Module): 117 | def __init__( 118 | self, 119 | in_features, 120 | hidden_features=None, 121 | out_features=None, 122 | act_layer=IntGELU, 123 | drop=0.0): 124 | super().__init__() 125 | out_features = out_features or in_features 126 | hidden_features = hidden_features or in_features 127 | # self.fc1 = nn.Linear(in_features, hidden_features) 128 | self.fc1 = QuantLinear( 129 | in_features, 130 | hidden_features 131 | ) 132 | self.act = act_layer() 133 | self.qact1 = QuantAct() 134 | # self.fc2 = nn.Linear(hidden_features, out_features) 135 | self.fc2 = QuantLinear( 136 | hidden_features, 137 | out_features 138 | ) 139 | self.qact2 = QuantAct(16) 140 | self.drop = nn.Dropout(drop) 141 | 142 | self.qact_gelu = QuantAct() 143 | 144 | def forward(self, x, act_scaling_factor): 145 | x, act_scaling_factor = self.fc1(x, act_scaling_factor) 146 | x, act_scaling_factor = self.qact_gelu(x, act_scaling_factor) 147 | x, act_scaling_factor = self.act(x, act_scaling_factor) 148 | x, act_scaling_factor = self.qact1(x, act_scaling_factor) 149 | x = self.drop(x) 150 | x, act_scaling_factor = self.fc2(x, act_scaling_factor) 151 | x, act_scaling_factor = self.qact2(x, act_scaling_factor) 152 | x = self.drop(x) 153 | return x, act_scaling_factor 154 | 155 | 156 | class PatchEmbed(nn.Module): 157 | """Image to Patch Embedding""" 158 | 159 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): 160 | super().__init__() 161 | img_size = to_2tuple(img_size) 162 | patch_size = to_2tuple(patch_size) 163 | self.img_size = img_size 164 | self.patch_size = patch_size 165 | 166 | self.grid_size = (img_size[0] // patch_size[0], 167 | img_size[1] // patch_size[1]) 168 | self.num_patches = self.grid_size[0] 
* self.grid_size[1] 169 | 170 | self.norm_layer = norm_layer 171 | 172 | self.proj = QuantConv2d( 173 | in_chans, 174 | embed_dim, 175 | kernel_size=patch_size, 176 | stride=patch_size, 177 | ) 178 | if self.norm_layer: 179 | self.qact_before_norm = QuantAct() 180 | self.norm = norm_layer(embed_dim) 181 | self.qact = QuantAct(16) 182 | 183 | 184 | def forward(self, x, act_scaling_factor): 185 | B, C, H, W = x.shape 186 | # FIXME look at relaxing size constraints 187 | assert ( 188 | H == self.img_size[0] and W == self.img_size[1] 189 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 190 | x, act_scaling_factor = self.proj(x, act_scaling_factor) 191 | x = x.flatten(2).transpose(1, 2) 192 | if self.norm_layer: 193 | x, act_scaling_factor = self.qact_before_norm(x, act_scaling_factor) 194 | x, act_scaling_factor = self.norm(x, act_scaling_factor) 195 | x, act_scaling_factor = self.qact(x, act_scaling_factor) 196 | return x, act_scaling_factor 197 | 198 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | Mostly copy-paste from torchvision references. 6 | """ 7 | import io 8 | import os 9 | import time 10 | from collections import defaultdict, deque 11 | import datetime 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | 17 | class SmoothedValue(object): 18 | """Track a series of values and provide access to smoothed values over a 19 | window or the global series average. 20 | """ 21 | 22 | def __init__(self, window_size=20, fmt=None): 23 | if fmt is None: 24 | fmt = "{median:.4f} ({global_avg:.4f})" 25 | self.deque = deque(maxlen=window_size) 26 | self.total = 0.0 27 | self.count = 0 28 | self.fmt = fmt 29 | 30 | def update(self, value, n=1): 31 | self.deque.append(value) 32 | self.count += n 33 | self.total += value * n 34 | 35 | def synchronize_between_processes(self): 36 | """ 37 | Warning: does not synchronize the deque! 
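        Only `count` and `total` (and hence `global_avg`) are all-reduced across
        processes; the window statistics (`median`, `avg`, `max`, `value`) remain
        local to each process.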
38 | """ 39 | if not is_dist_avail_and_initialized(): 40 | return 41 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 42 | dist.barrier() 43 | dist.all_reduce(t) 44 | t = t.tolist() 45 | self.count = int(t[0]) 46 | self.total = t[1] 47 | 48 | @property 49 | def median(self): 50 | d = torch.tensor(list(self.deque)) 51 | return d.median().item() 52 | 53 | @property 54 | def avg(self): 55 | d = torch.tensor(list(self.deque), dtype=torch.float32) 56 | return d.mean().item() 57 | 58 | @property 59 | def global_avg(self): 60 | return self.total / self.count 61 | 62 | @property 63 | def max(self): 64 | return max(self.deque) 65 | 66 | @property 67 | def value(self): 68 | return self.deque[-1] 69 | 70 | def __str__(self): 71 | return self.fmt.format( 72 | median=self.median, 73 | avg=self.avg, 74 | global_avg=self.global_avg, 75 | max=self.max, 76 | value=self.value) 77 | 78 | 79 | class MetricLogger(object): 80 | def __init__(self, delimiter="\t"): 81 | self.meters = defaultdict(SmoothedValue) 82 | self.delimiter = delimiter 83 | 84 | def update(self, **kwargs): 85 | for k, v in kwargs.items(): 86 | if isinstance(v, torch.Tensor): 87 | v = v.item() 88 | assert isinstance(v, (float, int)) 89 | self.meters[k].update(v) 90 | 91 | def __getattr__(self, attr): 92 | if attr in self.meters: 93 | return self.meters[attr] 94 | if attr in self.__dict__: 95 | return self.__dict__[attr] 96 | raise AttributeError("'{}' object has no attribute '{}'".format( 97 | type(self).__name__, attr)) 98 | 99 | def __str__(self): 100 | loss_str = [] 101 | for name, meter in self.meters.items(): 102 | loss_str.append( 103 | "{}: {}".format(name, str(meter)) 104 | ) 105 | return self.delimiter.join(loss_str) 106 | 107 | def synchronize_between_processes(self): 108 | for meter in self.meters.values(): 109 | meter.synchronize_between_processes() 110 | 111 | def add_meter(self, name, meter): 112 | self.meters[name] = meter 113 | 114 | def log_every(self, iterable, print_freq, header=None): 115 | i = 0 116 | if not header: 117 | header = '' 118 | start_time = time.time() 119 | end = time.time() 120 | iter_time = SmoothedValue(fmt='{avg:.4f}') 121 | data_time = SmoothedValue(fmt='{avg:.4f}') 122 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 123 | log_msg = [ 124 | header, 125 | '[{0' + space_fmt + '}/{1}]', 126 | 'eta: {eta}', 127 | '{meters}', 128 | 'time: {time}', 129 | 'data: {data}' 130 | ] 131 | if torch.cuda.is_available(): 132 | log_msg.append('max mem: {memory:.0f}') 133 | log_msg = self.delimiter.join(log_msg) 134 | MB = 1024.0 * 1024.0 135 | for obj in iterable: 136 | data_time.update(time.time() - end) 137 | yield obj 138 | iter_time.update(time.time() - end) 139 | if i % print_freq == 0 or i == len(iterable) - 1: 140 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 141 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 142 | if torch.cuda.is_available(): 143 | print(log_msg.format( 144 | i, len(iterable), eta=eta_string, 145 | meters=str(self), 146 | time=str(iter_time), data=str(data_time), 147 | memory=torch.cuda.max_memory_allocated() / MB)) 148 | else: 149 | print(log_msg.format( 150 | i, len(iterable), eta=eta_string, 151 | meters=str(self), 152 | time=str(iter_time), data=str(data_time))) 153 | i += 1 154 | end = time.time() 155 | total_time = time.time() - start_time 156 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 157 | print('{} Total time: {} ({:.4f} s / it)'.format( 158 | header, total_time_str, total_time / 
len(iterable))) 159 | 160 | 161 | def _load_checkpoint_for_ema(model_ema, checkpoint): 162 | """ 163 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 164 | """ 165 | mem_file = io.BytesIO() 166 | torch.save(checkpoint, mem_file) 167 | mem_file.seek(0) 168 | model_ema._load_checkpoint(mem_file) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | import builtins as __builtin__ 176 | builtin_print = __builtin__.print 177 | 178 | def print(*args, **kwargs): 179 | force = kwargs.pop('force', False) 180 | if is_master or force: 181 | builtin_print(*args, **kwargs) 182 | 183 | __builtin__.print = print 184 | 185 | 186 | def is_dist_avail_and_initialized(): 187 | if not dist.is_available(): 188 | return False 189 | if not dist.is_initialized(): 190 | return False 191 | return True 192 | 193 | 194 | def get_world_size(): 195 | if not is_dist_avail_and_initialized(): 196 | return 1 197 | return dist.get_world_size() 198 | 199 | 200 | def get_rank(): 201 | if not is_dist_avail_and_initialized(): 202 | return 0 203 | return dist.get_rank() 204 | 205 | 206 | def is_main_process(): 207 | return get_rank() == 0 208 | 209 | 210 | def save_on_master(*args, **kwargs): 211 | if is_main_process(): 212 | torch.save(*args, **kwargs) 213 | 214 | 215 | def init_distributed_mode(args): 216 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 217 | args.rank = int(os.environ["RANK"]) 218 | args.world_size = int(os.environ['WORLD_SIZE']) 219 | args.gpu = int(os.environ['LOCAL_RANK']) 220 | elif 'SLURM_PROCID' in os.environ: 221 | args.rank = int(os.environ['SLURM_PROCID']) 222 | args.gpu = args.rank % torch.cuda.device_count() 223 | else: 224 | print('Not using distributed mode') 225 | args.distributed = False 226 | return 227 | 228 | args.distributed = True 229 | 230 | torch.cuda.set_device(args.gpu) 231 | args.dist_backend = 'nccl' 232 | print('| distributed init (rank {}): {}'.format( 233 | args.rank, args.dist_url), flush=True) 234 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 235 | world_size=args.world_size, rank=args.rank) 236 | torch.distributed.barrier() 237 | setup_for_distributed(args.rank == 0) -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | @torch.no_grad() 11 | def load_weights_from_npz(model, url, check_hash=False, progress=False, prefix=''): 12 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 13 | """ 14 | 15 | def _n2p(w, t=True): 16 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 17 | w = w.flatten() 18 | if t: 19 | if w.ndim == 4: 20 | w = w.transpose([3, 2, 0, 1]) 21 | elif w.ndim == 3: 22 | w = w.transpose([2, 0, 1]) 23 | elif w.ndim == 2: 24 | w = w.transpose([1, 0]) 25 | return torch.from_numpy(w) 26 | 27 | def _get_cache_dir(child_dir=''): 28 | """ 29 | Returns the location of the directory where models are cached (and creates it if necessary). 
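        With default settings this resolves to os.path.join(torch.hub.get_dir(),
        'checkpoints', *child_dir), typically ~/.cache/torch/hub/checkpoints.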
30 | """ 31 | hub_dir = torch.hub.get_dir() 32 | child_dir = () if not child_dir else (child_dir,) 33 | model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) 34 | os.makedirs(model_dir, exist_ok=True) 35 | return model_dir 36 | 37 | def _download_cached_file(url, check_hash=True, progress=False): 38 | parts = torch.hub.urlparse(url) 39 | filename = os.path.basename(parts.path) 40 | cached_file = os.path.join(_get_cache_dir(), filename) 41 | if not os.path.exists(cached_file): 42 | hash_prefix = None 43 | if check_hash: 44 | r = torch.hub.HASH_REGEX.search( 45 | filename) # r is Optional[Match[str]] 46 | hash_prefix = r.group(1) if r else None 47 | torch.hub.download_url_to_file( 48 | url, cached_file, hash_prefix, progress=progress) 49 | return cached_file 50 | 51 | def adapt_input_conv(in_chans, conv_weight): 52 | conv_type = conv_weight.dtype 53 | # Some weights are in torch.half, ensure it's float for sum on CPU 54 | conv_weight = conv_weight.float() 55 | O, I, J, K = conv_weight.shape 56 | if in_chans == 1: 57 | if I > 3: 58 | assert conv_weight.shape[1] % 3 == 0 59 | # For models with space2depth stems 60 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) 61 | conv_weight = conv_weight.sum(dim=2, keepdim=False) 62 | else: 63 | conv_weight = conv_weight.sum(dim=1, keepdim=True) 64 | elif in_chans != 3: 65 | if I != 3: 66 | raise NotImplementedError( 67 | 'Weight format not supported by conversion.') 68 | else: 69 | # NOTE this strategy should be better than random init, but there could be other combinations of 70 | # the original RGB input layer weights that'd work better for specific cases. 71 | repeat = int(math.ceil(in_chans / 3)) 72 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[ 73 | :, :in_chans, :, :] 74 | conv_weight *= (3 / float(in_chans)) 75 | conv_weight = conv_weight.to(conv_type) 76 | return conv_weight 77 | 78 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 79 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 80 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 81 | ntok_new = posemb_new.shape[1] 82 | if num_tokens: 83 | posemb_tok, posemb_grid = posemb[:, 84 | :num_tokens], posemb[0, num_tokens:] 85 | ntok_new -= num_tokens 86 | else: 87 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 88 | gs_old = int(math.sqrt(len(posemb_grid))) 89 | if not len(gs_new): # backwards compatibility 90 | gs_new = [int(math.sqrt(ntok_new))] * 2 91 | assert len(gs_new) >= 2 92 | posemb_grid = posemb_grid.reshape( 93 | 1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 94 | posemb_grid = F.interpolate( 95 | posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 96 | posemb_grid = posemb_grid.permute( 97 | 0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 98 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 99 | return posemb 100 | 101 | cached_file = _download_cached_file( 102 | url, check_hash=check_hash, progress=progress) 103 | 104 | w = np.load(cached_file) 105 | if not prefix and 'opt/target/embedding/kernel' in w: 106 | prefix = 'opt/target/' 107 | 108 | if hasattr(model.patch_embed, 'backbone'): 109 | # hybrid 110 | backbone = model.patch_embed.backbone 111 | stem_only = not hasattr(backbone, 'stem') 112 | stem = backbone if stem_only else backbone.stem 113 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 114 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 115 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 116 | if not stem_only: 117 | for i, stage in enumerate(backbone.stages): 118 | for j, block in enumerate(stage.blocks): 119 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 120 | for r in range(3): 121 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 122 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 123 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 124 | if block.downsample is not None: 125 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 126 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 127 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 128 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 129 | else: 130 | embed_conv_w = adapt_input_conv( 131 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 132 | model.patch_embed.proj.weight.copy_(embed_conv_w) 133 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 134 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 135 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 136 | if pos_embed_w.shape != model.pos_embed.shape: 137 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 138 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 139 | model.pos_embed.copy_(pos_embed_w) 140 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 141 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 142 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 143 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 144 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 145 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and 
f'{prefix}pre_logits/bias' in w: 146 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 147 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 148 | for i, block in enumerate(model.blocks.children()): 149 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 150 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 151 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 152 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 153 | block.attn.qkv.weight.copy_(torch.cat([ 154 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 155 | block.attn.qkv.bias.copy_(torch.cat([ 156 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 157 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 158 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 159 | for r in range(2): 160 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 161 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 162 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 163 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 164 | -------------------------------------------------------------------------------- /TVM_benchmark/convert_model.py: -------------------------------------------------------------------------------- 1 | from curses.ascii import isascii 2 | from pytest import param 3 | import torch 4 | import numpy as np 5 | import argparse 6 | 7 | import os 8 | 9 | from models.layers import QConfig, QuantizeContext 10 | 11 | 12 | def save_params(model, depth, save_path): 13 | ## weight and bias (conv and dense) 14 | params = {} 15 | for (key, tensor) in model.items(): 16 | if 'weight_integer' in key: 17 | print(key) 18 | params[key] = tensor.cpu().numpy().astype('int8') 19 | if 'bias_integer' in key: 20 | print(key) 21 | params[key] = tensor.cpu().numpy().astype('int32') 22 | 23 | renamed_params = {} 24 | renamed_params['embed_conv_weight'] = params['patch_embed.proj.weight_integer'] 25 | renamed_params['embed_conv_bias'] = params['patch_embed.proj.bias_integer'].reshape(1, -1, 1, 1) 26 | 27 | for i in range(depth): 28 | for key in ['weight_integer', 'bias_integer']: 29 | old_name = 'blocks.%d.attn.qkv.' % (i) + key 30 | new_name = 'block_%d_attn_qkv_' % (i) + key[:-8] 31 | renamed_params[new_name] = params[old_name] 32 | 33 | old_name = 'blocks.%d.attn.proj.' % (i) + key 34 | new_name = 'block_%d_attn_proj_' % (i) + key[:-8] 35 | renamed_params[new_name] = params[old_name] 36 | 37 | old_name = 'blocks.%d.mlp.fc1.' % (i) + key 38 | new_name = 'block_%d_mlp_fc1_' % (i) + key[:-8] 39 | renamed_params[new_name] = params[old_name] 40 | 41 | old_name = 'blocks.%d.mlp.fc2.' % (i) + key 42 | new_name = 'block_%d_mlp_fc2_' % (i) + key[:-8] 43 | renamed_params[new_name] = params[old_name] 44 | 45 | renamed_params['head_weight'] = params['head.weight_integer'] 46 | renamed_params['head_bias'] = params['head.bias_integer'] 47 | 48 | ## norm 49 | for i in range(depth): 50 | for key in ['bias_integer']: 51 | old_name = 'blocks.%d.norm1.' % (i) + key 52 | new_name = 'block_%d_norm1_' % (i) + key[:-8] 53 | renamed_params[new_name] = model[old_name].cpu().numpy().astype('int32') 54 | 55 | old_name = 'blocks.%d.norm2.' 
% (i) + key 56 | new_name = 'block_%d_norm2_' % (i) + key[:-8] 57 | renamed_params[new_name] = model[old_name].cpu().numpy().astype('int32') 58 | 59 | renamed_params['norm_bias'] = model['norm.bias_integer'].cpu().numpy().astype('int32') 60 | 61 | 62 | ## other params 63 | renamed_params['cls_token_weight'] = model['cls_token'].cpu().numpy() 64 | renamed_params['pos_embed_weight'] = model['pos_embed'].cpu().numpy() 65 | 66 | np.save(os.path.join(save_path, 'params.npy'), renamed_params) 67 | 68 | 69 | def load_qconfig(model, depth): 70 | params = {} 71 | for (key, tensor) in model.items(): 72 | if 'scaling_factor' in key: 73 | #print(key) 74 | tensor_np = tensor.cpu().numpy().reshape((-1)) 75 | params[key] = tensor_np 76 | if "act_scaling_factor" in key and np.ndim(tensor_np) == 1: 77 | tensor_np = tensor_np[0] 78 | params[key] = tensor_np 79 | 80 | QuantizeContext.qconfig_dict['qconfig_pos'] = QConfig(output_scale=params['qact_pos.act_scaling_factor']) 81 | QuantizeContext.qconfig_dict['qconfig_addpos'] = QConfig(input_scale=params['patch_embed.qact.act_scaling_factor'], input_dtype='int16', output_scale=params['qact1.act_scaling_factor']) 82 | ## Embed 83 | conv_input_scale = params['qact_input.act_scaling_factor'] 84 | conv_kernel_scale = params['patch_embed.proj.conv_scaling_factor'] 85 | conv_output_scale = conv_input_scale * conv_kernel_scale 86 | QuantizeContext.qconfig_dict['qconfig_embed_conv'] = \ 87 | QConfig(input_scale=conv_input_scale, kernel_scale=conv_kernel_scale, output_scale=conv_output_scale) 88 | 89 | for i in range(depth): 90 | input_scale = params['qact1.act_scaling_factor'] if i == 0 else params['blocks.%d.qact4.act_scaling_factor' % (i-1)] 91 | output_scale = params['blocks.%d.norm1.norm_scaling_factor' % (i)] 92 | QuantizeContext.qconfig_dict['block_%d_qconfig_norm1' % (i)] = QConfig(input_scale=input_scale, output_scale=output_scale) 93 | 94 | input_scale = params['blocks.%d.qact1.act_scaling_factor' % (i)] 95 | kernel_scale = params['blocks.%d.attn.qkv.fc_scaling_factor' % (i)] 96 | output_scale = input_scale * kernel_scale 97 | QuantizeContext.qconfig_dict['block_%d_qconfig_qkv' % (i)] = QConfig(input_scale=input_scale, kernel_scale=kernel_scale, output_scale=output_scale) 98 | 99 | input_scale = params['blocks.%d.attn.qact1.act_scaling_factor' % (i)] 100 | output_scale = params['blocks.%d.attn.matmul_1.act_scaling_factor' % (i)] 101 | QuantizeContext.qconfig_dict['block_%d_qconfig_matmul_1' % (i)] = QConfig(input_scale=input_scale, output_scale=output_scale) 102 | 103 | input_scale = params['blocks.%d.attn.qact_attn1.act_scaling_factor' % (i)] 104 | output_scale = params['blocks.%d.attn.int_softmax.act_scaling_factor' % (i)] 105 | QuantizeContext.qconfig_dict['block_%d_qconfig_softmax' % (i)] = QConfig(input_scale=input_scale, output_scale=output_scale) 106 | 107 | input_scale = params['blocks.%d.attn.int_softmax.act_scaling_factor' % (i)] 108 | output_scale = params['blocks.%d.attn.matmul_2.act_scaling_factor' % (i)] 109 | QuantizeContext.qconfig_dict['block_%d_qconfig_matmul_2' % (i)] = QConfig(input_scale=input_scale, output_scale=output_scale) 110 | 111 | input_scale = params['blocks.%d.attn.qact2.act_scaling_factor' % (i)] 112 | kernel_scale = params['blocks.%d.attn.proj.fc_scaling_factor' % (i)] 113 | output_scale = input_scale * kernel_scale 114 | QuantizeContext.qconfig_dict['block_%d_qconfig_proj' % (i)] = QConfig(input_scale=input_scale, kernel_scale=kernel_scale, output_scale=output_scale) 115 | 116 | input_scale = 
params['blocks.%d.attn.qact3.act_scaling_factor' % (i)] 117 | output_scale = params['blocks.%d.qact2.act_scaling_factor' % (i)] 118 | QuantizeContext.qconfig_dict['block_%d_qconfig_add1' % (i)] = QConfig(input_scale=input_scale, input_dtype='int16', output_scale=output_scale) 119 | 120 | input_scale = params['blocks.%d.qact2.act_scaling_factor' % (i)] 121 | output_scale = params['blocks.%d.norm2.norm_scaling_factor' % (i)] 122 | QuantizeContext.qconfig_dict['block_%d_qconfig_norm2' % (i)] = QConfig(input_scale=input_scale, output_scale=output_scale) 123 | 124 | input_scale = params['blocks.%d.qact3.act_scaling_factor' % (i)] 125 | kernel_scale = params['blocks.%d.mlp.fc1.fc_scaling_factor' % (i)] 126 | output_scale = input_scale * kernel_scale 127 | QuantizeContext.qconfig_dict['block_%d_qconfig_fc1' % (i)] = QConfig(input_scale=input_scale, kernel_scale=kernel_scale, output_scale=output_scale) 128 | 129 | input_scale = params['blocks.%d.mlp.qact_gelu.act_scaling_factor' % (i)] 130 | output_scale = params['blocks.%d.mlp.act.act_scaling_factor' % (i)] 131 | QuantizeContext.qconfig_dict['block_%d_qconfig_gelu' % (i)] = QConfig(input_scale=input_scale, output_scale=output_scale, input_dtype='int8') 132 | 133 | input_scale = params['blocks.%d.mlp.qact1.act_scaling_factor' % (i)] 134 | kernel_scale = params['blocks.%d.mlp.fc2.fc_scaling_factor' % (i)] 135 | output_scale = input_scale * kernel_scale 136 | QuantizeContext.qconfig_dict['block_%d_qconfig_fc2' % (i)] = QConfig(input_scale=input_scale, kernel_scale=kernel_scale, output_scale=output_scale) 137 | 138 | input_scale = params['blocks.%d.mlp.qact2.act_scaling_factor' % (i)] 139 | output_scale = params['blocks.%d.qact4.act_scaling_factor' % (i)] 140 | QuantizeContext.qconfig_dict['block_%d_qconfig_add2' % (i)] = QConfig(input_scale=input_scale, input_dtype='int16', output_scale=output_scale) 141 | 142 | output_scale = params['norm.norm_scaling_factor'] 143 | QuantizeContext.qconfig_dict['qconfig_norm'] = QConfig(input_scale=input_scale, output_scale=output_scale) 144 | 145 | input_scale = params['qact2.act_scaling_factor'] 146 | kernel_scale = params['head.fc_scaling_factor'] 147 | output_scale = input_scale * kernel_scale 148 | QuantizeContext.qconfig_dict['qconfig_head'] = QConfig(input_scale=input_scale, kernel_scale=kernel_scale, output_scale=output_scale) 149 | 150 | 151 | if __name__ == '__main__': 152 | parser = argparse.ArgumentParser(description='I-ViT convert model', 153 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 154 | parser.add_argument('--model-path', default='', 155 | help='saved checkpoint path in QAT (checkpoint.pth.tar)') 156 | parser.add_argument('--params-path', default='', 157 | help='Saved parameters directory') 158 | parser.add_argument('--depth', default=12, 159 | help='Depth of ViT') 160 | 161 | args = parser.parse_args() 162 | model = torch.load(args.model_path) 163 | # print(model.keys()) 164 | 165 | save_params(model, args.depth, args.params_path) 166 | -------------------------------------------------------------------------------- /models/quantization_utils/quant_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from torch.autograd import Function, Variable 4 | import torch 5 | import bisect 6 | from fractions import Fraction 7 | import decimal 8 | from decimal import Decimal 9 | import time 10 | 11 | 12 | def linear_quantize(input, scale, zero_point, is_weight): 13 | """ 14 | Quantize single-precision input tensor 
to integers with the given scaling factor and zeropoint. 15 | Parameters: 16 | ---------- 17 | input: single-precision input tensor to be quantized 18 | scale: scaling factor for quantization 19 | zero_pint: shift for quantization 20 | """ 21 | 22 | # reshape scale and zeropoint for convolutional weights and activation 23 | if is_weight: 24 | if len(input.shape) == 4: 25 | scale = scale.view(-1, 1, 1, 1) 26 | zero_point = zero_point.view(-1, 1, 1, 1) 27 | # reshape scale and zeropoint for linear weights 28 | elif len(input.shape) == 2: 29 | scale = scale.view(-1, 1) 30 | zero_point = zero_point.view(-1, 1) 31 | else: 32 | scale = scale.view(-1) 33 | zero_point = zero_point.view(-1) 34 | else: 35 | if len(input.shape) == 2: 36 | scale = scale.view(1, -1) 37 | zero_point = zero_point.view(1, -1) 38 | elif len(input.shape) == 3: 39 | scale = scale.view(1, 1, -1) 40 | zero_point = zero_point.view(1, 1, -1) 41 | elif len(input.shape) == 4: 42 | scale = scale.view(1, -1, 1, 1) 43 | zero_point = zero_point.view(1, -1, 1, 1) 44 | else: 45 | raise NotImplementedError 46 | 47 | # quantized = float / scale + zero_point 48 | return torch.round(1. / scale * input + zero_point) 49 | 50 | 51 | def symmetric_linear_quantization_params(num_bits, min_val, max_val): 52 | """ 53 | Compute the scaling factor with the given quantization range for symmetric quantization. 54 | Parameters: 55 | ---------- 56 | saturation_min: lower bound for quantization range 57 | saturation_max: upper bound for quantization range 58 | """ 59 | # in this part, we do not need any gradient computation, 60 | # in order to enfore this, we put torch.no_grad() 61 | with torch.no_grad(): 62 | n = 2 ** (num_bits - 1) - 1 63 | eps = torch.finfo(torch.float32).eps 64 | 65 | max_val = torch.max(-min_val, max_val) 66 | scale = max_val / float(n) 67 | scale.clamp_(eps) 68 | 69 | return scale 70 | 71 | 72 | class SymmetricQuantFunction(Function): 73 | """ 74 | Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth. 75 | """ 76 | 77 | @staticmethod 78 | def forward(ctx, x, k, specified_scale, is_weight): 79 | """ 80 | x: floating point tensor to be quantized 81 | k: quantization bitwidth 82 | Note that the current implementation of SymmetricQuantFunction requires pre-calculated scaling factor. 
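
As a concrete example of the symmetric scheme implemented above, the following minimal sketch (8-bit, made-up tensor values) reproduces what `symmetric_linear_quantization_params` followed by `linear_quantize` with a zero offset computes:

```python
import torch

x = torch.tensor([-1.5, -0.2, 0.0, 0.7, 2.1])           # made-up activation values
n = 2 ** (8 - 1) - 1                                     # n = 127 for 8-bit symmetric quantization
scale = torch.max(-x.min(), x.max()) / n                 # max(|min|, |max|) / 127 ~= 0.0165
x_int = torch.clamp(torch.round(x / scale), -n - 1, n)   # tensor([-91., -12., 0., 42., 127.])
x_hat = x_int * scale                                    # dequantized values, error <= scale / 2
```

The clamp to `[-n-1, n]` mirrors the range used in `SymmetricQuantFunction.forward` below.
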
83 | specified_scale: pre-calculated scaling factor for the tensor x 84 | """ 85 | 86 | scale = specified_scale 87 | 88 | zero_point = torch.tensor(0.).cuda() 89 | 90 | n = 2 ** (k - 1) - 1 91 | new_quant_x = linear_quantize(x, scale, zero_point, is_weight=is_weight) 92 | new_quant_x = torch.clamp(new_quant_x, -n-1, n) 93 | 94 | ctx.scale = scale 95 | ctx.is_weight = is_weight 96 | return new_quant_x 97 | 98 | @staticmethod 99 | def backward(ctx, grad_output): 100 | 101 | scale = ctx.scale 102 | is_weight = ctx.is_weight 103 | if is_weight: 104 | if len(grad_output.shape) == 4: 105 | scale = scale.view(-1, 1, 1, 1) 106 | elif len(grad_output.shape) == 2: 107 | scale = scale.view(-1, 1) 108 | else: 109 | scale = scale.view(-1) 110 | else: 111 | if len(grad_output.shape) == 2: 112 | scale = scale.view(1, -1) 113 | elif len(grad_output.shape) == 3: 114 | scale = scale.view(1, 1, -1) 115 | elif len(grad_output.shape) == 4: 116 | scale = scale.view(1, -1, 1, 1) 117 | else: 118 | raise NotImplementedError 119 | return grad_output.clone() / scale, None, None, None 120 | 121 | 122 | class floor_ste(Function): 123 | """ 124 | Straight-through Estimator(STE) for torch.floor() 125 | """ 126 | 127 | @staticmethod 128 | def forward(ctx, x): 129 | return torch.floor(x) 130 | 131 | @staticmethod 132 | def backward(ctx, grad_output): 133 | return grad_output.clone() 134 | 135 | 136 | class round_ste(Function): 137 | """ 138 | Straight-through Estimator(STE) for torch.round() 139 | """ 140 | 141 | @staticmethod 142 | def forward(ctx, x): 143 | return torch.round(x) 144 | 145 | @staticmethod 146 | def backward(ctx, grad_output): 147 | return grad_output.clone() 148 | 149 | 150 | def batch_frexp(inputs, max_bit=31): 151 | """ 152 | Decompose the scaling factor into mantissa and twos exponent. 153 | Parameters: 154 | ---------- 155 | inputs: scaling factor 156 | return: (mantissa, exponent) 157 | """ 158 | 159 | shape_of_input = inputs.size() 160 | 161 | # trans the input to be a 1-d tensor 162 | inputs = inputs.view(-1) 163 | 164 | output_m, output_e = np.frexp(inputs.cpu().numpy()) 165 | tmp_m = [] 166 | for m in output_m: 167 | int_m_shifted = int(Decimal(m * (2 ** max_bit)).quantize(Decimal('1'), 168 | rounding=decimal.ROUND_HALF_UP)) 169 | tmp_m.append(int_m_shifted) 170 | output_m = np.array(tmp_m) 171 | 172 | output_e = float(max_bit) - output_e 173 | 174 | return torch.from_numpy(output_m).cuda().view(shape_of_input), \ 175 | torch.from_numpy(output_e).cuda().view(shape_of_input) 176 | 177 | 178 | class fixedpoint_mul(Function): 179 | """ 180 | Function to perform fixed-point arthmetic that can match integer arthmetic on hardware. 
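
`fixedpoint_mul` below relies on `batch_frexp` to replace a floating-point rescaling factor with an integer mantissa and a power-of-two shift, so requantization can stay in integer arithmetic. A small NumPy sketch of that approximation (the multiplier and activation values are made up; the PyTorch code rounds the shifted product rather than flooring it):

```python
import numpy as np

rescale = np.array([0.0123])                    # made-up multiplier (input_scale / output_scale)
m, e = np.frexp(rescale)                        # rescale = m * 2**e, with 0.5 <= m < 1
m_int = np.round(m * 2 ** 31).astype(np.int64)  # 31-bit integer mantissa
shift = 31 - e                                  # total right shift that undoes the scaling
x_int = np.int64(1000)                          # made-up integer activation
y_int = (x_int * m_int) >> shift                # ~= round(1000 * 0.0123) = 12, integer-only
```
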
181 | Parameters: 182 | ---------- 183 | pre_act: input tensor 184 | pre_act_scaling_factor: ithe scaling factor of the input tensor 185 | bit_num: quantization bitwidth 186 | quant_mode: The mode for quantization, 'symmetric' or 'asymmetric' 187 | z_scaling_factor: the scaling factor of the output tensor 188 | identity: identity tensor 189 | identity_scaling_factor: the scaling factor of the identity tensor 190 | """ 191 | 192 | @staticmethod 193 | def forward(ctx, pre_act, pre_act_scaling_factor, 194 | bit_num, quant_mode, z_scaling_factor, 195 | identity=None, identity_scaling_factor=None): 196 | 197 | # TODO(Sehoon): May require other type of reshape 198 | if len(pre_act.shape) == 2: 199 | reshape = lambda x: x.view(1, -1) 200 | elif len(pre_act.shape) == 3: 201 | reshape = lambda x: x.view(1, 1, -1) 202 | elif len(pre_act.shape) == 4: 203 | reshape = lambda x: x.view(1, -1, 1, 1) 204 | else: 205 | raise NotImplementedError 206 | ctx.identity = identity 207 | 208 | if quant_mode == 'symmetric': 209 | n = 2 ** (bit_num - 1) - 1 210 | else: 211 | n = 2 ** bit_num - 1 212 | 213 | with torch.no_grad(): 214 | pre_act_scaling_factor = reshape(pre_act_scaling_factor) 215 | if identity is not None: 216 | identity_scaling_factor = reshape(identity_scaling_factor) 217 | 218 | ctx.z_scaling_factor = z_scaling_factor 219 | 220 | z_int = torch.round(pre_act / pre_act_scaling_factor) 221 | _A = pre_act_scaling_factor.type(torch.double) 222 | _B = (z_scaling_factor.type(torch.float)).type(torch.double) 223 | new_scale = _A / _B 224 | # print(new_scale) 225 | # exit() 226 | new_scale = reshape(new_scale) 227 | 228 | m, e = batch_frexp(new_scale) 229 | output = z_int.type(torch.double) * m.type(torch.double) 230 | output = torch.round(output / (2.0 ** e)) 231 | 232 | if identity is not None: 233 | # needs addition of identity activation 234 | wx_int = torch.round(identity / identity_scaling_factor) 235 | 236 | _A = identity_scaling_factor.type(torch.double) 237 | _B = (z_scaling_factor.type(torch.float)).type(torch.double) 238 | new_scale = _A / _B 239 | new_scale = reshape(new_scale) 240 | 241 | m1, e1 = batch_frexp(new_scale) 242 | output1 = wx_int.type(torch.double) * m1.type(torch.double) 243 | output1 = torch.round(output1 / (2.0 ** e1)) 244 | 245 | output = output1 + output 246 | 247 | if bit_num in [4, 8, 16, 32]: 248 | if quant_mode == 'symmetric': 249 | return torch.clamp(output.type(torch.float), -n-1, n) 250 | else: 251 | return torch.clamp(output.type(torch.float), 0, n) 252 | else: 253 | return output.type(torch.float) 254 | 255 | @staticmethod 256 | def backward(ctx, grad_output): 257 | identity_grad = None 258 | if ctx.identity is not None: 259 | identity_grad = grad_output.clone() / ctx.z_scaling_factor 260 | return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, \ 261 | identity_grad, None 262 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /TVM_benchmark/models/quantized_vit.py: -------------------------------------------------------------------------------- 1 | import tvm 2 | from tvm import relay 3 | 4 | from . import layers 5 | 6 | 7 | def Q_Block(data, 8 | name, 9 | dim, 10 | num_heads, 11 | mlp_ratio, 12 | qk_scale, 13 | batch_size, 14 | rounding='TRUNCATE'): 15 | 16 | 'Attention mudule' 17 | shortcut = data 18 | 19 | ## layer_norm 20 | qconfig0 = layers.get_qconfig(name + '_qconfig_norm1') 21 | norm1_bias = relay.var(name + '_norm1_bias', shape=[dim], dtype='int32') 22 | norm1 = layers.quantized_layernorm(data, norm1_bias) 23 | 24 | ## attention 25 | qconfig1 = layers.get_qconfig(name + '_qconfig_qkv') 26 | req1 = layers.requantize(norm1, 27 | input_scale=qconfig0.output_scale, 28 | output_scale=qconfig1.input_scale, 29 | out_dtype=qconfig1.input_dtype) 30 | 31 | req1 = relay.reshape(req1, [-3,0]) 32 | qkv = layers.quantized_dense(data=req1, 33 | name=name + '_attn_qkv', 34 | input_scale=qconfig1.input_scale, 35 | kernel_scale=qconfig1.kernel_scale, 36 | units=dim*3, 37 | kernel_shape=(dim*3, dim), 38 | kernel_dtype='int8', 39 | add_bias=True) 40 | qkv = relay.reshape(qkv, [-4,batch_size,-1,-2]) 41 | 42 | qconfig2 = layers.get_qconfig(name + '_qconfig_matmul_1') 43 | req2 = layers.requantize(qkv, 44 | input_scale=qconfig1.output_scale, 45 | output_scale=qconfig2.input_scale, 46 | out_dtype=qconfig2.input_dtype) 47 | 48 | qkv_reshape = relay.reshape(req2, [0, 0, 3, num_heads, -1]) 49 | qkv = relay.transpose(qkv_reshape, [2, 0, 3, 1, 4]) 50 | qkv = relay.split(qkv, 3, axis=0) 51 | q = relay.reshape(relay.squeeze(qkv[0], axis=[0]), [-3,-2]) 52 | k = relay.reshape(relay.squeeze(qkv[1], axis=[0]), [-3,-2]) 53 | v = relay.reshape(relay.squeeze(qkv[2], axis=[0]), [-3,-2]) 54 | 55 | attn = layers.quantized_matmul(q, k, 56 | input_scale1=qconfig2.input_scale, 57 | input_scale2=qconfig2.input_scale) 58 | 59 | attn = relay.reshape(attn, [-4,-1,num_heads,-2]) 60 | 61 | qconfig3 = layers.get_qconfig(name + '_qconfig_softmax') 62 | req3 = layers.requantize(attn, 63 | input_scale=qconfig2.output_scale * qk_scale, 64 | output_scale=qconfig3.input_scale, 65 | out_dtype=qconfig3.input_dtype) 66 | 67 | attn = layers.quantized_softmax(req3, qconfig3.input_scale) 68 | 69 | qconfig4 = layers.get_qconfig(name + '_qconfig_matmul_2') 70 | attn = relay.reshape(attn,[-3,-2]) 71 | v = relay.transpose(v, [0, 2, 1]) 72 | attn = layers.quantized_matmul(attn, v, 73 | input_scale1=qconfig4.input_scale, 74 | input_scale2=qconfig2.input_scale) 75 | 76 | attn = relay.reshape(attn, [-4,-1,num_heads,-2]) 77 | 78 | attn = relay.transpose(attn, [0, 2, 1, 3]) 79 | attn = relay.reshape(attn, [0, 0, -1]) 80 | 81 | qconfig5 = layers.get_qconfig(name + '_qconfig_proj') 82 | req5 = layers.requantize(attn, 83 | input_scale=qconfig4.output_scale, 84 | output_scale=qconfig5.input_scale, 85 | out_dtype=qconfig5.input_dtype) 86 | 87 | req5 = relay.reshape(req5, [-3,0]) 88 | proj = layers.quantized_dense(data=req5, 89 | name=name + '_attn_proj', 90 | 
input_scale=qconfig5.input_scale, 91 | kernel_scale=qconfig5.kernel_scale, 92 | units=dim, 93 | kernel_shape=(dim, dim), 94 | kernel_dtype='int8', 95 | add_bias=True) 96 | proj = relay.reshape(proj, [-4,batch_size,-1,-2]) 97 | 98 | ## shortcut 99 | qconfig6 = layers.get_qconfig(name + '_qconfig_add1') 100 | req6 = layers.requantize(proj, 101 | input_scale=qconfig5.output_scale, 102 | output_scale=qconfig6.input_scale, 103 | out_dtype=qconfig6.input_dtype) 104 | 105 | add1 = layers.add(lhs=req6, 106 | rhs=shortcut, 107 | lhs_scale=qconfig6.input_scale, 108 | rhs_scale=qconfig0.input_scale, 109 | output_scale=qconfig6.output_scale) 110 | 111 | 'MLP module' 112 | shortcut = add1 113 | ## layer_norm 114 | qconfig7 = layers.get_qconfig(name + '_qconfig_norm2') 115 | norm2_bias = relay.var(name + '_norm2_bias', shape=[dim], dtype='int32') 116 | norm2 = layers.quantized_layernorm(add1, norm2_bias) 117 | 118 | ## dense 119 | qconfig8 = layers.get_qconfig(name + '_qconfig_fc1') 120 | req8 = layers.requantize(norm2, 121 | input_scale=qconfig7.output_scale, 122 | output_scale=qconfig8.input_scale, 123 | out_dtype=qconfig8.input_dtype) 124 | 125 | req8 = relay.reshape(req8, [-3,0]) 126 | fc1 = layers.quantized_dense(data=req8, 127 | name=name + '_mlp_fc1', 128 | input_scale=qconfig8.input_scale, 129 | kernel_scale=qconfig8.kernel_scale, 130 | units=mlp_ratio*dim, 131 | kernel_shape=(mlp_ratio*dim, dim), 132 | kernel_dtype='int8', 133 | add_bias=True) 134 | fc1 = relay.reshape(fc1, [-4,batch_size,-1,-2]) 135 | 136 | qconfig9 = layers.get_qconfig(name + '_qconfig_gelu') 137 | req9 = layers.requantize(fc1, 138 | input_scale=qconfig8.output_scale, 139 | output_scale=qconfig9.input_scale, 140 | out_dtype=qconfig9.input_dtype) 141 | 142 | act = layers.quantized_gelu(req9, qconfig9.input_scale) 143 | 144 | 145 | qconfig10 = layers.get_qconfig(name + '_qconfig_fc2') 146 | req10 = layers.requantize(act, 147 | input_scale=qconfig9.output_scale, 148 | output_scale=qconfig10.input_scale, 149 | out_dtype=qconfig10.input_dtype) 150 | 151 | req10 = relay.reshape(req10, [-3,0]) 152 | fc2 = layers.quantized_dense(data=req10, 153 | name=name + '_mlp_fc2', 154 | input_scale=qconfig10.input_scale, 155 | kernel_scale=qconfig10.kernel_scale, 156 | units=dim, 157 | kernel_shape=(dim, mlp_ratio*dim), 158 | kernel_dtype='int8', 159 | add_bias=True) 160 | fc2 = relay.reshape(fc2, [-4,batch_size,-1,-2]) 161 | 162 | ## shortcut 163 | qconfig11 = layers.get_qconfig(name + '_qconfig_add2') 164 | req11 = layers.requantize(fc2, 165 | input_scale=qconfig10.output_scale, 166 | output_scale=qconfig11.input_scale, 167 | out_dtype=qconfig11.input_dtype) 168 | 169 | add2 = layers.add(lhs=req11, 170 | rhs=shortcut, 171 | lhs_scale=qconfig11.input_scale, 172 | rhs_scale=qconfig7.input_scale, 173 | output_scale=qconfig11.output_scale) 174 | 175 | add2 = relay.annotation.stop_fusion(add2) 176 | 177 | return add2 178 | 179 | 180 | def Q_VisionTransformer(data_shape, 181 | dtype='int8', 182 | patch_size=16, 183 | num_patches=196, 184 | in_chans=3, 185 | num_classes=1000, 186 | embed_dim=192, 187 | depth=12, 188 | num_heads=3, 189 | mlp_ratio=4): 190 | data = relay.var('data', shape=data_shape, dtype=dtype) 191 | 192 | qconfig_embed_conv = layers.get_qconfig('qconfig_embed_conv') 193 | proj = layers.quantized_conv2d(data=data, 194 | name='embed_conv', 195 | add_bias=True, 196 | input_channels=in_chans, 197 | output_channels=embed_dim, 198 | kernel_dtype=qconfig_embed_conv.kernel_dtype, 199 | input_scale=qconfig_embed_conv.input_scale, 200 | 
kernel_scale=qconfig_embed_conv.kernel_scale, 201 | kernel_size=(patch_size, patch_size), 202 | strides=(patch_size, patch_size), 203 | padding=(0, 0), 204 | data_layout='NCHW', 205 | kernel_layout='OIHW') 206 | proj = relay.reshape(proj, [0, 0, -1]) 207 | body = relay.transpose(proj, [0, 2, 1]) 208 | 209 | qconfig_add = layers.get_qconfig('qconfig_addpos') 210 | body = layers.requantize(body, 211 | input_scale=qconfig_embed_conv.output_scale, 212 | output_scale=qconfig_add.input_scale, 213 | out_dtype=qconfig_add.input_dtype) 214 | 215 | 216 | cls_token = relay.var('cls_token_weight', shape=(1, 1, embed_dim)) 217 | cls_token = layers.quantize(cls_token, output_scale=qconfig_add.input_scale, out_dtype=qconfig_add.input_dtype) 218 | cls_tokens = relay.repeat(cls_token, data_shape[0], axis=0) 219 | 220 | body = relay.concatenate([cls_tokens, body], axis=1) 221 | 222 | pos_embed = relay.var('pos_embed_weight', shape=(1, num_patches+1, embed_dim)) 223 | qconfig_pos = layers.get_qconfig('qconfig_pos') 224 | pos_embed = layers.quantize(pos_embed, output_scale=qconfig_pos.output_scale, out_dtype=qconfig_add.input_dtype) 225 | 226 | body = layers.add(lhs=body, 227 | rhs=pos_embed, 228 | lhs_scale=qconfig_add.input_scale, 229 | rhs_scale=qconfig_pos.output_scale, 230 | output_scale=qconfig_add.output_scale) 231 | 232 | body = relay.annotation.stop_fusion(body) 233 | 234 | 235 | qk_scale = (embed_dim//num_heads) ** -0.5 236 | 237 | for i in range(depth): 238 | body = Q_Block(body, 239 | name='block_%d' % (i), 240 | dim=embed_dim, 241 | num_heads=num_heads, 242 | mlp_ratio=mlp_ratio, 243 | qk_scale=qk_scale, 244 | batch_size=data_shape[0], 245 | rounding='TONEAREST') 246 | 247 | 248 | qconfig_norm = layers.get_qconfig('qconfig_norm') 249 | norm_bias = relay.var('norm_bias', shape=[embed_dim], dtype='int32') 250 | norm = layers.quantized_layernorm(body, norm_bias) 251 | 252 | body = relay.split(norm, 197, axis=1) 253 | body = relay.squeeze(body[0], axis=[1]) 254 | 255 | 256 | qconfig_head = layers.get_qconfig('qconfig_head') 257 | req = layers.requantize(body, 258 | input_scale=qconfig_norm.output_scale, 259 | output_scale=qconfig_head.input_scale, 260 | out_dtype=qconfig_head.input_dtype) 261 | 262 | head = layers.quantized_dense(data=req, 263 | name='head', 264 | input_scale=qconfig_head.input_scale, 265 | kernel_scale=qconfig_head.kernel_scale, 266 | units=num_classes, 267 | kernel_shape=(num_classes, embed_dim), 268 | kernel_dtype='int8', 269 | add_bias=True) 270 | 271 | 272 | net = layers.dequantize(head, input_scale=qconfig_head.output_scale) 273 | net = relay.nn.softmax(data=net) 274 | return relay.Function(relay.analysis.free_vars(net), net) 275 | -------------------------------------------------------------------------------- /models/vit_quant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import re 4 | import warnings 5 | from itertools import repeat 6 | import collections.abc 7 | from collections import OrderedDict 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | from .layers_quant import PatchEmbed, Mlp, DropPath, trunc_normal_ 15 | from .quantization_utils import QuantLinear, QuantAct, QuantConv2d, IntLayerNorm, IntSoftmax, IntGELU, QuantMatMul 16 | from .utils import load_weights_from_npz 17 | 18 | 19 | __all__ = ['deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 20 | 'vit_base_patch16_224', 
'vit_large_patch16_224'] 21 | 22 | 23 | class Attention(nn.Module): 24 | def __init__( 25 | self, 26 | dim, 27 | num_heads=8, 28 | qkv_bias=False, 29 | qk_scale=None, 30 | attn_drop=0.0, 31 | proj_drop=0.0): 32 | super().__init__() 33 | self.num_heads = num_heads 34 | head_dim = dim // num_heads 35 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 36 | self.scale = qk_scale or head_dim ** -0.5 37 | 38 | self.qkv = QuantLinear( 39 | dim, 40 | dim * 3, 41 | bias=qkv_bias 42 | ) 43 | self.qact1 = QuantAct() 44 | self.qact_attn1 = QuantAct() 45 | self.qact2 = QuantAct() 46 | self.proj = QuantLinear( 47 | dim, 48 | dim 49 | ) 50 | self.qact3 = QuantAct(16) 51 | self.qact_softmax = QuantAct() 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj_drop = nn.Dropout(proj_drop) 54 | self.int_softmax = IntSoftmax(16) 55 | 56 | self.matmul_1 = QuantMatMul() 57 | self.matmul_2 = QuantMatMul() 58 | 59 | def forward(self, x, act_scaling_factor): 60 | B, N, C = x.shape 61 | x, act_scaling_factor = self.qkv(x, act_scaling_factor) 62 | x, act_scaling_factor_1 = self.qact1(x, act_scaling_factor) 63 | qkv = x.reshape(B, N, 3, self.num_heads, C // 64 | self.num_heads).permute(2, 0, 3, 1, 4) # (BN33) 65 | q, k, v = ( 66 | qkv[0], 67 | qkv[1], 68 | qkv[2], 69 | ) # make torchscript happy (cannot use tensor as tuple) 70 | attn, act_scaling_factor = self.matmul_1(q, act_scaling_factor_1, 71 | k.transpose(-2, -1), act_scaling_factor_1) 72 | attn = attn * self.scale 73 | act_scaling_factor = act_scaling_factor * self.scale 74 | attn, act_scaling_factor = self.qact_attn1(attn, act_scaling_factor) 75 | 76 | attn, act_scaling_factor = self.int_softmax(attn, act_scaling_factor) 77 | 78 | attn = self.attn_drop(attn) 79 | x, act_scaling_factor = self.matmul_2(attn, act_scaling_factor, 80 | v, act_scaling_factor_1) 81 | x = x.transpose(1, 2).reshape(B, N, C) 82 | 83 | x, act_scaling_factor = self.qact2(x, act_scaling_factor) 84 | x, act_scaling_factor = self.proj(x, act_scaling_factor) 85 | x, act_scaling_factor = self.qact3(x, act_scaling_factor) 86 | x = self.proj_drop(x) 87 | 88 | return x, act_scaling_factor 89 | 90 | 91 | class Block(nn.Module): 92 | def __init__( 93 | self, 94 | dim, 95 | num_heads, 96 | mlp_ratio=4.0, 97 | qkv_bias=False, 98 | qk_scale=None, 99 | drop=0.0, 100 | attn_drop=0.0, 101 | drop_path=0.0, 102 | act_layer=nn.GELU, 103 | norm_layer=nn.LayerNorm): 104 | super().__init__() 105 | self.norm1 = norm_layer(dim) 106 | self.qact1 = QuantAct() 107 | self.attn = Attention( 108 | dim, 109 | num_heads=num_heads, 110 | qkv_bias=qkv_bias, 111 | qk_scale=qk_scale, 112 | attn_drop=attn_drop, 113 | proj_drop=drop 114 | ) 115 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 116 | self.drop_path = DropPath( 117 | drop_path) if drop_path > 0.0 else nn.Identity() 118 | self.qact2 = QuantAct(16) 119 | self.norm2 = norm_layer(dim) 120 | self.qact3 = QuantAct() 121 | mlp_hidden_dim = int(dim * mlp_ratio) 122 | self.mlp = Mlp( 123 | in_features=dim, 124 | hidden_features=mlp_hidden_dim, 125 | act_layer=act_layer, 126 | drop=drop 127 | ) 128 | self.qact4 = QuantAct(16) 129 | 130 | def forward(self, x_1, act_scaling_factor_1): 131 | x, act_scaling_factor = self.norm1(x_1, act_scaling_factor_1) 132 | x, act_scaling_factor = self.qact1(x, act_scaling_factor) 133 | x, act_scaling_factor = self.attn(x, act_scaling_factor) 134 | x = self.drop_path(x) 135 | x_2, act_scaling_factor_2 = self.qact2(x, act_scaling_factor, x_1, 
act_scaling_factor_1) 136 | 137 | x, act_scaling_factor = self.norm2(x_2, act_scaling_factor_2) 138 | x, act_scaling_factor = self.qact3(x, act_scaling_factor) 139 | x, act_scaling_factor = self.mlp(x, act_scaling_factor) 140 | x = self.drop_path(x) 141 | x, act_scaling_factor = self.qact4(x, act_scaling_factor, x_2, act_scaling_factor_2) 142 | 143 | return x, act_scaling_factor 144 | 145 | 146 | class VisionTransformer(nn.Module): 147 | """Vision Transformer 148 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 149 | https://arxiv.org/abs/2010.11929 150 | """ 151 | 152 | def __init__( 153 | self, 154 | img_size=224, 155 | patch_size=16, 156 | in_chans=3, 157 | num_classes=1000, 158 | embed_dim=768, 159 | depth=12, 160 | num_heads=12, 161 | mlp_ratio=4.0, 162 | qkv_bias=True, 163 | qk_scale=None, 164 | representation_size=None, 165 | drop_rate=0.0, 166 | attn_drop_rate=0.0, 167 | drop_path_rate=0.0, 168 | norm_layer=None): 169 | super().__init__() 170 | self.num_classes = num_classes 171 | self.num_features = ( 172 | self.embed_dim 173 | ) = embed_dim # num_features for consistency with other models 174 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 175 | 176 | self.qact_input = QuantAct() 177 | 178 | self.patch_embed = PatchEmbed( 179 | img_size=img_size, 180 | patch_size=patch_size, 181 | in_chans=in_chans, 182 | embed_dim=embed_dim 183 | ) 184 | num_patches = self.patch_embed.num_patches 185 | 186 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 187 | self.pos_embed = nn.Parameter( 188 | torch.zeros(1, num_patches + 1, embed_dim)) 189 | 190 | self.pos_drop = nn.Dropout(p=drop_rate) 191 | 192 | self.qact_pos = QuantAct(16) 193 | self.qact1 = QuantAct(16) 194 | 195 | dpr = [ 196 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 197 | ] # stochastic depth decay rule 198 | self.blocks = nn.ModuleList( 199 | [ 200 | Block( 201 | dim=embed_dim, 202 | num_heads=num_heads, 203 | mlp_ratio=mlp_ratio, 204 | qkv_bias=qkv_bias, 205 | qk_scale=qk_scale, 206 | drop=drop_rate, 207 | attn_drop=attn_drop_rate, 208 | drop_path=dpr[i], 209 | act_layer=IntGELU, 210 | norm_layer=norm_layer 211 | ) 212 | for i in range(depth) 213 | ] 214 | ) 215 | self.norm = norm_layer(embed_dim) 216 | self.qact2 = QuantAct() 217 | 218 | # Representation layer 219 | if representation_size: 220 | self.num_features = representation_size 221 | self.pre_logits = nn.Sequential( 222 | OrderedDict( 223 | [ 224 | ("fc", nn.Linear(embed_dim, representation_size)), 225 | ("act", nn.Tanh()), 226 | ] 227 | ) 228 | ) 229 | else: 230 | self.pre_logits = nn.Identity() 231 | 232 | # Classifier head 233 | self.head = ( 234 | QuantLinear( 235 | self.num_features, 236 | num_classes) 237 | if num_classes > 0 238 | else nn.Identity() 239 | ) 240 | self.act_out = QuantAct() 241 | trunc_normal_(self.pos_embed, std=0.02) 242 | trunc_normal_(self.cls_token, std=0.02) 243 | self.apply(self._init_weights) 244 | 245 | def _init_weights(self, m): 246 | if isinstance(m, nn.Linear): 247 | trunc_normal_(m.weight, std=0.02) 248 | if isinstance(m, nn.Linear) and m.bias is not None: 249 | nn.init.constant_(m.bias, 0) 250 | elif isinstance(m, nn.LayerNorm): 251 | nn.init.constant_(m.bias, 0) 252 | nn.init.constant_(m.weight, 1.0) 253 | 254 | def forward_features(self, x): 255 | B = x.shape[0] 256 | 257 | x, act_scaling_factor = self.qact_input(x) 258 | x, act_scaling_factor = self.patch_embed(x, act_scaling_factor) 259 | cls_tokens = self.cls_token.expand( 260 | B, -1, -1 
261 | ) # stole cls_tokens impl from Phil Wang, thanks 262 | x = torch.cat((cls_tokens, x), dim=1) # share scaling_factor 263 | 264 | x_pos, act_scaling_factor_pos = self.qact_pos(self.pos_embed) 265 | x, act_scaling_factor = self.qact1(x, act_scaling_factor, x_pos, act_scaling_factor_pos) 266 | x = self.pos_drop(x) 267 | 268 | for blk in self.blocks: 269 | x, act_scaling_factor = blk(x, act_scaling_factor) 270 | 271 | x, act_scaling_factor = self.norm(x, act_scaling_factor) 272 | x = x[:, 0] 273 | x, act_scaling_factor = self.qact2(x, act_scaling_factor) 274 | x = self.pre_logits(x) 275 | 276 | return x, act_scaling_factor 277 | 278 | def forward(self, x): 279 | x, act_scaling_factor = self.forward_features(x) 280 | x, act_scaling_factor = self.head(x, act_scaling_factor) 281 | #x, _ = self.act_out(x, act_scaling_factor) 282 | return x 283 | 284 | 285 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 286 | model = VisionTransformer( 287 | patch_size=16, 288 | embed_dim=192, 289 | depth=12, 290 | num_heads=3, 291 | mlp_ratio=4, 292 | qkv_bias=True, 293 | norm_layer=partial(IntLayerNorm, eps=1e-6), 294 | **kwargs, 295 | ) 296 | if pretrained: 297 | checkpoint = torch.hub.load_state_dict_from_url( 298 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 299 | map_location="cpu", 300 | check_hash=True, 301 | ) 302 | model.load_state_dict(checkpoint["model"], strict=False) 303 | return model 304 | 305 | 306 | def deit_small_patch16_224(pretrained=False, **kwargs): 307 | model = VisionTransformer( 308 | patch_size=16, 309 | embed_dim=384, 310 | depth=12, 311 | num_heads=6, 312 | mlp_ratio=4, 313 | qkv_bias=True, 314 | norm_layer=partial(IntLayerNorm, eps=1e-6), 315 | **kwargs 316 | ) 317 | if pretrained: 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 320 | map_location="cpu", check_hash=True 321 | ) 322 | model.load_state_dict(checkpoint["model"], strict=False) 323 | return model 324 | 325 | 326 | def deit_base_patch16_224(pretrained=False, **kwargs): 327 | model = VisionTransformer( 328 | patch_size=16, 329 | embed_dim=768, 330 | depth=12, 331 | num_heads=12, 332 | mlp_ratio=4, 333 | qkv_bias=True, 334 | norm_layer=partial(IntLayerNorm, eps=1e-6), 335 | **kwargs 336 | ) 337 | if pretrained: 338 | checkpoint = torch.hub.load_state_dict_from_url( 339 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 340 | map_location="cpu", check_hash=True 341 | ) 342 | model.load_state_dict(checkpoint["model"], strict=False) 343 | return model 344 | 345 | 346 | def vit_base_patch16_224(pretrained=False, **kwargs): 347 | model = VisionTransformer( 348 | patch_size=16, 349 | embed_dim=768, 350 | depth=12, 351 | num_heads=12, 352 | mlp_ratio=4, 353 | qkv_bias=True, 354 | norm_layer=partial(IntLayerNorm, eps=1e-6), 355 | **kwargs 356 | ) 357 | if pretrained: 358 | url = "https://storage.googleapis.com/vit_models/augreg/" + \ 359 | "B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz" 360 | 361 | load_weights_from_npz(model, url, check_hash=True) 362 | return model 363 | 364 | 365 | def vit_large_patch16_224(pretrained=False, **kwargs): 366 | model = VisionTransformer( 367 | patch_size=16, 368 | embed_dim=1024, 369 | depth=24, 370 | num_heads=16, 371 | mlp_ratio=4, 372 | qkv_bias=True, 373 | norm_layer=partial(IntLayerNorm, eps=1e-6), 374 | **kwargs 375 | ) 376 | if pretrained: 377 | url = 
"https://storage.googleapis.com/vit_models/augreg/" + \ 378 | "L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz" 379 | 380 | load_weights_from_npz(model, url, check_hash=True) 381 | return model 382 | -------------------------------------------------------------------------------- /TVM_benchmark/models/layers.py: -------------------------------------------------------------------------------- 1 | """Simple Layer DSL wrapper to ease creation of neural nets.""" 2 | from dataclasses import dataclass 3 | from tvm import relay 4 | from collections import namedtuple 5 | 6 | import numpy as np 7 | from tvm.relay.op.tensor import exp 8 | 9 | 10 | QConfig = namedtuple('QConfig', 'from_dtype, from_scale, from_zero_point, \ 11 | input_dtype, input_scale, input_zero_point, \ 12 | kernel_dtype, kernel_scale, kernel_zero_point, \ 13 | output_dtype, output_scale, output_zero_point', 14 | defaults=('int32', 65.0, 0.0, 'int8', 8.0, 0.0, 'int8', 8.0, 0.0, 'int32', 74.0, 0.0)) 15 | 16 | class QuantizeContext(object): 17 | qconfig_dict = dict() 18 | qconfig_print = dict() 19 | default_qconfig = QConfig() 20 | 21 | @staticmethod 22 | def read_qconfig_from_file(file_path): 23 | pass 24 | 25 | @staticmethod 26 | def set_default_qconfig(qconfig): 27 | QuantizeContext.default_qconfig = qconfig 28 | 29 | def get_qconfig(name): 30 | #print(QuantizeContext.qconfig_dict) 31 | if name in QuantizeContext.qconfig_dict: 32 | return QuantizeContext.qconfig_dict[name] 33 | else: 34 | QuantizeContext.qconfig_print[name] = QuantizeContext.default_qconfig 35 | return QuantizeContext.default_qconfig 36 | 37 | 38 | def quantized_conv2d(data, 39 | kernel_dtype, 40 | name, 41 | input_channels, 42 | kernel_size, 43 | output_channels, 44 | strides=(1, 1), 45 | padding=(0, 0), 46 | weight=None, 47 | add_bias=False, 48 | input_scale=8.0, 49 | kernel_scale=8.0, 50 | input_zero_point=0.0, 51 | kernel_zero_point=0.0, 52 | data_layout='NCHW', 53 | kernel_layout='OIHW', 54 | **kwargs): 55 | 56 | """Wrapper of qnn.conv2d 57 | Parameters 58 | ---------- 59 | data : relay.Expr 60 | The input expression. 61 | weight : relay.Expr 62 | The weight to conv2d. 63 | name : str 64 | The name of this convolution. 65 | input_channels: int 66 | The number of input channels. 67 | out_channels: int 68 | The number of output channels. 69 | input_scale : float 70 | The scale of input. 71 | kernel_scale : float 72 | The scale of kernel. 73 | input_zero_point : float 74 | The zero point of input. 75 | kernel_zero_point : float 76 | The zero point of kernel. 77 | kwargs : dict 78 | Additional arguments. 79 | Returns 80 | ------- 81 | result : relay.Expr 82 | The result. 
83 | """ 84 | 85 | # print("%s, %s, %d, %d, %d, %d, %d" % (, kernel_dtype, input_channels, output_channels, kernel_size[0], strides[0], padding[0])) 86 | 87 | input_zero_point = relay.const(input_zero_point, 'int32') 88 | kernel_zero_point = relay.const(kernel_zero_point, 'int32') 89 | 90 | if isinstance(input_scale, float): 91 | input_scale = relay.const(input_scale, 'float32') 92 | else: 93 | input_scale = relay.const(input_scale.astype('float32'), 'float32') 94 | 95 | if isinstance(kernel_scale, float): 96 | kernel_scale = relay.const(kernel_scale, 'float32') 97 | else: 98 | kernel_scale = relay.const(kernel_scale.astype('float32'), 'float32') 99 | 100 | if kernel_layout == "OIHW": 101 | kernel_shape = (output_channels, input_channels, kernel_size[0], kernel_size[1]) 102 | elif kernel_layout == "HWIO": 103 | kernel_shape = (kernel_size[0], kernel_size[1], input_channels, output_channels) 104 | elif kernel_layout == "HWOI": 105 | kernel_shape = (kernel_size[0], kernel_size[1], output_channels, input_channels) 106 | elif kernel_layout == "OHWI": 107 | kernel_shape = (output_channels, kernel_size[0], kernel_size[1], input_channels) 108 | else: 109 | raise RuntimeError("Unsupported kernel layout {}".format(kernel_layout)) 110 | 111 | if weight is None: 112 | weight = relay.var(name + "_weight", shape=kernel_shape, dtype=kernel_dtype) 113 | 114 | conv2d = relay.qnn.op.conv2d(data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale, 115 | kernel_size=kernel_size, channels=output_channels, data_layout=data_layout, kernel_layout=kernel_layout, strides=strides, padding=padding, **kwargs) 116 | 117 | if add_bias: 118 | if data_layout == 'NCHW': 119 | bias_shape = (1, output_channels, 1, 1) 120 | elif data_layout == 'NHWC': 121 | bias_shape = (1, 1, 1, output_channels) 122 | elif data_layout == 'HWCN': 123 | bias_shape = (1, 1, output_channels, 1) 124 | elif data_layout == 'HWNC': 125 | bias_shape = (1, 1, 1, output_channels) 126 | else: 127 | raise RuntimeError("Unsupported conv2d layout {}".format(data_layout)) 128 | 129 | bias = relay.var(name + "_bias", shape=bias_shape, dtype="int32") 130 | return relay.add(conv2d, bias) 131 | else: 132 | return conv2d 133 | 134 | 135 | def quantize(data, 136 | output_scale=8.0, 137 | output_zero_point=0.0, 138 | axis=-1, 139 | out_dtype='int8'): 140 | 141 | output_scale = relay.const(output_scale, 'float32') 142 | output_zero_point = relay.const(output_zero_point, 'int32') 143 | 144 | return relay.qnn.op.quantize(data, output_scale, output_zero_point, axis, out_dtype) 145 | 146 | 147 | def requantize(data, 148 | input_scale=8.0, 149 | input_zero_point=0.0, 150 | output_scale=8.0, 151 | output_zero_point=0.0, 152 | axis=-1, 153 | rounding="None", 154 | compute_dtype="None", 155 | out_dtype="int8"): 156 | 157 | if isinstance(input_scale, float): 158 | input_scale = relay.const(input_scale, 'float32') 159 | else: 160 | input_scale = relay.const(np.array(input_scale).astype('float32')) 161 | 162 | input_zero_point = relay.const(input_zero_point, 'int32') 163 | 164 | if isinstance(output_scale, float): 165 | output_scale = relay.const(output_scale, 'float32') 166 | else: 167 | output_scale = relay.const(np.array(output_scale).astype('float32')) 168 | 169 | output_zero_point = relay.const(output_zero_point, 'int32') 170 | 171 | return relay.qnn.op.requantize(data, 172 | input_scale, 173 | input_zero_point, 174 | output_scale, 175 | output_zero_point, 176 | axis, 177 | rounding, 178 | compute_dtype, 179 | out_dtype) 180 | 181 | 182 | def 
dequantize(data, 183 | input_scale, 184 | input_zero_point=0.0, 185 | axis=-1): 186 | 187 | if isinstance(input_scale, float): 188 | input_scale = relay.const(input_scale, 'float32') 189 | else: 190 | input_scale = relay.const(input_scale.astype('float32'), 'float32') 191 | 192 | input_zero_point = relay.const(input_zero_point, 'int32') 193 | 194 | return relay.qnn.op.dequantize(data, 195 | input_scale, 196 | input_zero_point, 197 | axis) 198 | 199 | 200 | def add(lhs, 201 | rhs, 202 | lhs_scale, 203 | rhs_scale, 204 | output_scale, 205 | lhs_zero_point=0.0, 206 | rhs_zero_point=0.0, 207 | output_zero_point=0.0): 208 | 209 | lhs_scale = relay.const(lhs_scale, 'float32') 210 | lhs_zero_point = relay.const(lhs_zero_point, 'int32') 211 | 212 | rhs_scale = relay.const(rhs_scale, 'float32') 213 | rhs_zero_point = relay.const(rhs_zero_point, 'int32') 214 | 215 | if np.ndim(output_scale) == 1: 216 | output_scale = output_scale[0] 217 | if np.ndim(output_zero_point) == 1: 218 | output_zero_point = output_zero_point[0] 219 | 220 | output_scale = relay.const(output_scale, 'float32') 221 | output_zero_point = relay.const(output_zero_point, 'int32') 222 | 223 | return relay.qnn.op.add(lhs, 224 | rhs, 225 | lhs_scale, 226 | lhs_zero_point, 227 | rhs_scale, 228 | rhs_zero_point, 229 | output_scale, 230 | output_zero_point) 231 | 232 | 233 | def quantized_dense(data, 234 | name, 235 | units, 236 | kernel_shape, 237 | kernel_dtype, 238 | input_scale=8.0, 239 | kernel_scale=8.0, 240 | input_zero_point=0.0, 241 | kernel_zero_point=0.0, 242 | add_bias=False, 243 | out_dtype="int32"): 244 | """Qnn Dense operator. 245 | Applies a quantized linear transformation 246 | .. math:: 247 | `Y = X * W` 248 | Parameters 249 | ---------- 250 | data : tvm.relay.Expr 251 | The quantized input data to the operator. 252 | weight : tvm.relay.Expr 253 | The quantized weight expressions. 254 | input_zero_point: tvm.relay.Expr 255 | The input zero point. 256 | kernel_zero_point: tvm.relay.Expr 257 | The kernel zero point. 258 | input_scale: tvm.relay.Expr 259 | The scale for the input tensor. 260 | kernel_scale: tvm.relay.Expr 261 | The scale for the weight tensor. The scale for the weight tensor is 262 | stored for access to this during relay. This information is not 263 | needed in the pass pipeline after qnn.conv2d is lowered to the 264 | sequence of steps as in nn.conv2d. See also input_scale in Requantize. 265 | units : int 266 | Number of hidden units of the dense transformation. 267 | out_dtype : str, optional 268 | Specifies the output data type for mixed precision dense can be int32 or int16. 269 | Returns 270 | ------- 271 | result : tvm.relay.Expr 272 | The computed result. 
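
Together with `requantize` above, this wrapper is what `Q_Block` uses to stitch the integer graph together: an int8 dense produces an int32 accumulator, and a requantize folds the accumulator scale into the scale of the next activation. A hedged sketch that builds one such pair (scale values are made up; the shapes correspond to a 197-token, 192-dimensional DeiT-tiny layer):

```python
from tvm import relay

tokens = relay.var("tokens", shape=(197, 192), dtype="int8")   # flattened (batch*tokens, dim)
qkv = quantized_dense(tokens,
                      name="block_0_attn_qkv",
                      units=3 * 192,
                      kernel_shape=(3 * 192, 192),
                      kernel_dtype="int8",
                      input_scale=0.05,           # made-up activation scale
                      kernel_scale=0.002,         # made-up weight scale
                      add_bias=True)              # int32 accumulator + int32 bias
qkv_int8 = requantize(qkv,
                      input_scale=0.05 * 0.002,   # accumulator scale
                      output_scale=0.04,          # made-up scale of the next activation
                      out_dtype="int8")
func = relay.Function(relay.analysis.free_vars(qkv_int8), qkv_int8)
```
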
273 | """ 274 | 275 | input_zero_point = relay.const(input_zero_point, 'int32') 276 | kernel_zero_point = relay.const(kernel_zero_point, 'int32') 277 | if isinstance(input_scale, float): 278 | input_scale = relay.const(input_scale, 'float32') 279 | else: 280 | input_scale = relay.const(input_scale.astype('float32'), 'float32') 281 | 282 | if isinstance(kernel_scale, float): 283 | kernel_scale = relay.const(kernel_scale, 'float32') 284 | else: 285 | kernel_scale = relay.const(kernel_scale.astype('float32'), 'float32') 286 | 287 | weight = relay.var(name + "_weight", shape=kernel_shape, dtype=kernel_dtype) 288 | 289 | dense = relay.qnn.op.dense(data, 290 | weight, 291 | input_zero_point, 292 | kernel_zero_point, 293 | input_scale, 294 | kernel_scale, 295 | units, 296 | out_dtype) 297 | if add_bias: 298 | bias = relay.var(name + "_bias", dtype="int32") 299 | return relay.nn.bias_add(dense, bias, axis=-1) 300 | else: 301 | return dense 302 | 303 | 304 | def quantized_matmul(x, y, 305 | input_scale1, 306 | input_scale2, 307 | x_zero_point=0.0, 308 | y_zero_point=0.0): 309 | x_zero_point = relay.const(x_zero_point, 'int32') 310 | y_zero_point = relay.const(y_zero_point, 'int32') 311 | if isinstance(input_scale1, float): 312 | x_scale = relay.const(input_scale1, 'float32') 313 | else: 314 | x_scale = relay.const(input_scale1.astype('float32'), 'float32') 315 | if isinstance(input_scale2, float): 316 | y_scale = relay.const(input_scale2, 'float32') 317 | else: 318 | y_scale = relay.const(input_scale2.astype('float32'), 'float32') 319 | 320 | matmul = relay.qnn.op.batch_matmul(x, y, 321 | x_zero_point, 322 | y_zero_point, 323 | x_scale, 324 | y_scale, 325 | out_dtype="int32") 326 | return matmul 327 | 328 | 329 | def quantized_layernorm(data, 330 | bias_int): 331 | mean = relay.mean(data, axis=2, keepdims=True) 332 | data = data - mean 333 | 334 | data = relay.cast(data, 'int32') 335 | data_sq = data * data 336 | 337 | data_sq = relay.cast(data_sq, 'uint32') 338 | var = relay.sum(data_sq, axis=2, keepdims=True) 339 | 340 | std = relay.const(2 ** 16, 'uint32') 341 | for _ in range(10): 342 | tmp = (std + var/std)/relay.const(2, 'uint32') 343 | std = tmp 344 | std = relay.cast(std, 'int32') 345 | 346 | factor = relay.const(2**31-1, 'int32') 347 | data = (factor / std) * data / relay.const(2, 'int32') 348 | data = data + bias_int 349 | 350 | return data 351 | 352 | 353 | def shift_exp(data, input_scale, n): 354 | 355 | data = data + relay.right_shift(data, relay.const(1, dtype='int32')) - relay.right_shift(data, relay.const(4, dtype='int32')) 356 | 357 | x0 = relay.const(-1.0/input_scale-1, 'int32') 358 | n = relay.const(n, dtype='int32') 359 | 360 | data = relay.maximum(data, n*x0) 361 | 362 | q = data / x0 363 | r = data - q * x0 364 | 365 | exp_int = relay.right_shift(r, relay.const(1, dtype='int32')) - x0 366 | exp_int = relay.left_shift(exp_int, n-q) 367 | 368 | return exp_int 369 | 370 | 371 | 372 | def quantized_softmax(data, input_scale): 373 | data = relay.cast(data, 'int32') 374 | data_max = relay.max(data, axis=-1, keepdims=True) 375 | data = data - data_max 376 | 377 | exp_int = shift_exp(data, input_scale, 16) 378 | 379 | exp_int_sum = relay.sum(exp_int, axis=-1, keepdims=True) 380 | factor = relay.const(2**31-1, 'int32') 381 | # exp_int = (factor/exp_int_sum) * exp_int / relay.const(2 ** 24, 'int32') 382 | exp_int = relay.right_shift((factor/exp_int_sum) * exp_int, relay.const(24, dtype='int32')) 383 | 384 | exp_int = relay.cast(exp_int, 'int8') 385 | 386 | return exp_int 387 | 388 | 389 
| def quantized_gelu(pre_data, input_scale): 390 | pre_data = relay.cast(pre_data, 'int32') 391 | data_max = relay.max(pre_data, axis=-1, keepdims=True) 392 | data = pre_data - data_max 393 | 394 | exp_int = shift_exp(data, input_scale* 1.702, 23) 395 | exp_int_max = shift_exp(-data_max, input_scale* 1.702, 23) 396 | exp_int_sum = exp_int + exp_int_max 397 | 398 | factor = relay.const(2**31-1, 'int32') 399 | # sigmoid_int = (factor/exp_int_sum) * exp_int / relay.const(2 ** 24, 'int32') 400 | sigmoid_int = relay.right_shift((factor/exp_int_sum) * exp_int, relay.const(24, dtype='int32')) 401 | 402 | gelu = pre_data * sigmoid_int 403 | 404 | return gelu -------------------------------------------------------------------------------- /quant_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import math 5 | import logging 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | from pathlib import Path 11 | 12 | from timm.data import Mixup 13 | from timm.models import create_model 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.scheduler import create_scheduler 16 | from timm.optim import create_optimizer 17 | from timm.utils import NativeScaler, get_state_dict, ModelEma, accuracy 18 | 19 | from models import * 20 | from utils import * 21 | 22 | 23 | parser = argparse.ArgumentParser(description="I-ViT") 24 | 25 | parser.add_argument("--model", default='deit_tiny', 26 | choices=['deit_tiny', 'deit_small', 'deit_base', 27 | 'swin_tiny', 'swin_small', 'swin_base'], 28 | help="model") 29 | parser.add_argument('--data', metavar='DIR', default='/dataset/imagenet/', 30 | help='path to dataset') 31 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET'], 32 | type=str, help='Image Net dataset path') 33 | parser.add_argument("--nb-classes", default=1000, type=int, help="number of classes") 34 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 35 | parser.add_argument("--device", default="cuda", type=str, help="device") 36 | parser.add_argument("--print-freq", default=1000, 37 | type=int, help="print frequency") 38 | parser.add_argument("--seed", default=0, type=int, help="seed") 39 | parser.add_argument('--output-dir', type=str, default='results/', 40 | help='path to save log and quantized model') 41 | 42 | parser.add_argument('--resume', default='', help='resume from checkpoint') 43 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 44 | help='start epoch') 45 | parser.add_argument('--batch-size', default=128, type=int) 46 | parser.add_argument('--epochs', default=90, type=int) 47 | parser.add_argument('--num-workers', default=8, type=int) 48 | parser.add_argument('--pin-mem', action='store_true', 49 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 50 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 51 | help='') 52 | parser.set_defaults(pin_mem=True) 53 | 54 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 55 | help='Dropout rate (default: 0.)') 56 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 57 | help='Drop path rate (default: 0.1)') 58 | 59 | parser.add_argument('--model-ema', action='store_true') 60 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 61 | parser.add_argument('--model-ema-decay', type=float, 
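
For reference, the flags registered in this block can be exercised programmatically; a hypothetical minimal invocation, equivalent to `python quant_train.py --model deit_tiny --data /path/to/imagenet ...` (the dataset path is a placeholder, and the optimizer/augmentation flags added further below keep their defaults):

```python
args = parser.parse_args(['--model', 'deit_tiny',
                          '--data', '/path/to/imagenet',
                          '--batch-size', '128',
                          '--epochs', '90',
                          '--output-dir', 'results/'])
print(args.model, args.nb_classes, args.drop_path)   # -> deit_tiny 1000 0.1
```
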
default=0.99996, help='') 62 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 63 | 64 | # Optimizer parameters 65 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 66 | help='Optimizer (default: "adamw"') 67 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 68 | help='Optimizer Epsilon (default: 1e-8)') 69 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 70 | help='Optimizer Betas (default: None, use opt default)') 71 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 72 | help='Clip gradient norm (default: None, no clipping)') 73 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 74 | help='SGD momentum (default: 0.9)') 75 | parser.add_argument('--weight-decay', type=float, default=1e-4, 76 | help='weight decay (default: 1e-4)') 77 | # Learning rate schedule parameters 78 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 79 | help='LR scheduler (default: "cosine"') 80 | parser.add_argument('--lr', type=float, default=1e-6, metavar='LR', 81 | help='learning rate (default: 1e-6)') 82 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 83 | help='learning rate noise on/off epoch percentages') 84 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 85 | help='learning rate noise limit percent (default: 0.67)') 86 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 87 | help='learning rate noise std-dev (default: 1.0)') 88 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 89 | help='warmup learning rate (default: 1e-6)') 90 | parser.add_argument('--min-lr', type=float, default=5e-7, metavar='LR', 91 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 92 | 93 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 94 | help='epoch interval to decay LR') 95 | parser.add_argument('--warmup-epochs', type=int, default=0, metavar='N', 96 | help='epochs to warmup LR, if scheduler supports') 97 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 98 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 99 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 100 | help='patience epochs for Plateau LR scheduler (default: 10') 101 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 102 | help='LR decay rate (default: 0.1)') 103 | 104 | # Augmentation parameters 105 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 106 | help='Color jitter factor (default: 0.4)') 107 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 108 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 109 | "(default: rand-m9-mstd0.5-inc1)'), 110 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 111 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 112 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 113 | 114 | # * Random Erase params 115 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 116 | help='Random erase prob (default: 0.25)') 117 | parser.add_argument('--remode', type=str, default='pixel', 118 | help='Random erase mode (default: "pixel")') 119 | parser.add_argument('--recount', type=int, default=1, 120 | help='Random erase count (default: 1)') 121 | parser.add_argument('--resplit', action='store_true', default=False, 122 | help='Do not random erase first (clean) augmentation split') 123 | 124 | # * Mixup params 125 | parser.add_argument('--mixup', type=float, default=0.8, 126 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 127 | parser.add_argument('--cutmix', type=float, default=1.0, 128 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 129 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 130 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 131 | parser.add_argument('--mixup-prob', type=float, default=1.0, 132 | help='Probability of performing mixup or cutmix when either/both is enabled') 133 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 134 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 135 | parser.add_argument('--mixup-mode', type=str, default='batch', 136 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 137 | 138 | parser.add_argument('--best-acc1', type=float, default=0, help='best_acc1') 139 | 140 | 141 | def str2model(name): 142 | d = {'deit_tiny': deit_tiny_patch16_224, 143 | 'deit_small': deit_small_patch16_224, 144 | 'deit_base': deit_base_patch16_224, 145 | 'swin_tiny': swin_tiny_patch4_window7_224, 146 | 'swin_small': swin_small_patch4_window7_224, 147 | 'swin_base': swin_base_patch4_window7_224, 148 | } 149 | print('Model: %s' % d[name].__name__) 150 | return d[name] 151 | 152 | 153 | def main(): 154 | args = parser.parse_args() 155 | 156 | seed = args.seed 157 | torch.manual_seed(seed) 158 | torch.cuda.manual_seed(seed) 159 | np.random.seed(seed) 160 | torch.backends.cudnn.benchmark = True 161 | 162 | import warnings 163 | warnings.filterwarnings('ignore') 164 | 165 | if not os.path.exists(args.output_dir): 166 | os.makedirs(args.output_dir) 167 | logging.basicConfig(format='%(asctime)s - %(message)s', 168 | datefmt='%d-%b-%y %H:%M:%S', filename=args.output_dir + 'log.log') 169 | logging.getLogger().setLevel(logging.INFO) 170 | logging.getLogger().addHandler(logging.StreamHandler()) 171 | logging.info(args) 172 | 173 | device = torch.device(args.device) 174 | 175 | # Dataset 176 | train_loader, val_loader = dataloader(args) 177 | 178 | mixup_fn = None 179 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 180 | if mixup_active: 181 | mixup_fn = Mixup( 182 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 183 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 184 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 185 | 186 | # Model 187 | model = str2model(args.model)(pretrained=True, 188 | num_classes=args.nb_classes, 189 | drop_rate=args.drop, 190 | drop_path_rate=args.drop_path) 191 | model.to(device) 192 | 193 | model_ema = None 194 | if args.model_ema: 195 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 196 | model_ema = ModelEma( 197 | model, 198 | decay=args.model_ema_decay, 199 | device='cpu' if args.model_ema_force_cpu else '', 200 | resume='') 201 | 202 | args.min_lr = args.lr / 15 203 | optimizer = create_optimizer(args, model) 204 | loss_scaler = NativeScaler() 205 | lr_scheduler, _ = create_scheduler(args, optimizer) 206 | 207 | if mixup_active: 208 | # smoothing is handled with mixup label transform 209 | criterion = SoftTargetCrossEntropy() 210 | elif args.smoothing: 211 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 212 | else: 213 | criterion = nn.CrossEntropyLoss() 214 | criterion_v = nn.CrossEntropyLoss() 215 | 216 | if args.resume: 217 | if args.resume.startswith('https'): 218 | checkpoint = torch.hub.load_state_dict_from_url( 219 | args.resume, map_location='cpu', check_hash=True) 220 | else: 221 | checkpoint = torch.load(args.resume, map_location='cpu') 222 | model.load_state_dict(checkpoint['model']) 223 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 224 | optimizer.load_state_dict(checkpoint['optimizer']) 225 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 226 | args.start_epoch = checkpoint['epoch'] + 1 227 | if args.model_ema: 228 | load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 229 | if 'scaler' in checkpoint: 230 | loss_scaler.load_state_dict(checkpoint['scaler']) 231 | lr_scheduler.step(args.start_epoch) 232 | 233 | print(f"Start training for {args.epochs} epochs") 234 | best_epoch = 0 235 | for epoch in range(args.start_epoch, args.epochs): 236 | # train for one epoch 237 | train(args, train_loader, model, criterion, optimizer, epoch, 238 | loss_scaler, args.clip_grad, model_ema, mixup_fn, device) 239 | lr_scheduler.step(epoch) 240 | 241 | # if args.output_dir: # this is for resume training 242 | # checkpoint_path = os.path.join(args.output_dir, 'checkpoint.pth.tar') 243 | # torch.save({ 244 | # 'model': model.state_dict(), 245 | # 'optimizer': optimizer.state_dict(), 246 | # 'lr_scheduler': lr_scheduler.state_dict(), 247 | # 'epoch': epoch, 248 | # 'model_ema': get_state_dict(model_ema), 249 | # 'scaler': loss_scaler.state_dict(), 250 | # 'args': args, 251 | # }, checkpoint_path) 252 | 253 | acc1 = validate(args, val_loader, model, criterion_v, device) 254 | 255 | # remember best acc@1 and save checkpoint 256 | is_best = acc1 > args.best_acc1 257 | args.best_acc1 = max(acc1, args.best_acc1) 258 | if is_best: 259 | # record the best epoch 260 | best_epoch = epoch 261 | torch.save(model.state_dict(), os.path.join(args.output_dir, 'checkpoint.pth.tar')) 262 | logging.info(f'Acc at epoch {epoch}: {acc1}') 263 | logging.info(f'Best acc at epoch {best_epoch}: {args.best_acc1}') 264 | 265 | 266 | def train(args, train_loader, model, criterion, optimizer, epoch, loss_scaler, max_norm, model_ema, 
mixup_fn, device): 267 | batch_time = AverageMeter('Time', ':6.3f') 268 | data_time = AverageMeter('Data', ':6.3f') 269 | losses = AverageMeter('Loss', ':.4e') 270 | progress = ProgressMeter( 271 | len(train_loader), 272 | [batch_time, data_time, losses], 273 | prefix="Epoch: [{}]".format(epoch)) 274 | 275 | # switch to train mode 276 | model.train() 277 | unfreeze_model(model) 278 | 279 | end = time.time() 280 | for i, (data, target) in enumerate(train_loader): 281 | # measure data loading time 282 | data_time.update(time.time() - end) 283 | 284 | data = data.to(device, non_blocking=True) 285 | target = target.to(device, non_blocking=True) 286 | 287 | if mixup_fn is not None: 288 | data, target = mixup_fn(data, target) 289 | 290 | output = model(data) 291 | loss = criterion(output, target) 292 | 293 | # measure accuracy and record loss 294 | losses.update(loss.item(), data.size(0)) 295 | 296 | # compute gradient and do SGD step 297 | optimizer.zero_grad() 298 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 299 | loss_scaler(loss, optimizer, clip_grad=max_norm, 300 | parameters=model.parameters(), create_graph=is_second_order) 301 | 302 | torch.cuda.synchronize() 303 | if model_ema is not None: 304 | model_ema.update(model) 305 | 306 | # measure elapsed time 307 | batch_time.update(time.time() - end) 308 | end = time.time() 309 | 310 | if i % args.print_freq == 0: 311 | progress.display(i) 312 | 313 | 314 | def validate(args, val_loader, model, criterion, device): 315 | batch_time = AverageMeter('Time', ':6.3f') 316 | losses = AverageMeter('Loss', ':.4e') 317 | top1 = AverageMeter('Acc@1', ':6.2f') 318 | top5 = AverageMeter('Acc@5', ':6.2f') 319 | progress = ProgressMeter( 320 | len(val_loader), 321 | [batch_time, losses, top1, top5], 322 | prefix='Test: ') 323 | 324 | # switch to evaluate mode 325 | model.eval() 326 | freeze_model(model) 327 | 328 | end = time.time() 329 | for i, (data, target) in enumerate(val_loader): 330 | data = data.to(device, non_blocking=True) 331 | target = target.to(device, non_blocking=True) 332 | 333 | with torch.no_grad(): 334 | output = model(data) 335 | loss = criterion(output, target) 336 | 337 | # measure accuracy and record loss 338 | prec1, prec5 = accuracy(output, target, topk=(1, 5)) 339 | losses.update(loss.data.item(), data.size(0)) 340 | top1.update(prec1.data.item(), data.size(0)) 341 | top5.update(prec5.data.item(), data.size(0)) 342 | 343 | # measure elapsed time 344 | batch_time.update(time.time() - end) 345 | end = time.time() 346 | 347 | if i % args.print_freq == 0: 348 | progress.display(i) 349 | 350 | print(" * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}".format(top1=top1, top5=top5)) 351 | return top1.avg 352 | 353 | 354 | class AverageMeter(object): 355 | """Computes and stores the average and current value""" 356 | 357 | def __init__(self, name, fmt=':f'): 358 | self.name = name 359 | self.fmt = fmt 360 | self.reset() 361 | 362 | def reset(self): 363 | self.val = 0 364 | self.avg = 0 365 | self.sum = 0 366 | self.count = 0 367 | 368 | def update(self, val, n=1): 369 | self.val = val 370 | self.sum += val * n 371 | self.count += n 372 | self.avg = self.sum / self.count 373 | 374 | def __str__(self): 375 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 376 | return fmtstr.format(**self.__dict__) 377 | 378 | 379 | class ProgressMeter(object): 380 | def __init__(self, num_batches, meters, prefix=""): 381 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 382 | self.meters = 
meters 383 | self.prefix = prefix 384 | 385 | def display(self, batch): 386 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 387 | entries += [str(meter) for meter in self.meters] 388 | logging.info('\t'.join(entries)) 389 | 390 | def _get_batch_fmtstr(self, num_batches): 391 | num_digits = len(str(num_batches // 1)) 392 | fmt = '{:' + str(num_digits) + 'd}' 393 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 394 | 395 | 396 | 397 | if __name__ == "__main__": 398 | main() 399 | -------------------------------------------------------------------------------- /models/quantization_utils/quant_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.multiprocessing as mp 7 | from torch.nn import Parameter 8 | 9 | from .quant_utils import * 10 | 11 | 12 | class QuantLinear(nn.Linear): 13 | """ 14 | Class to quantize weights of given Linear layer 15 | 16 | Parameters: 17 | ---------- 18 | weight_bit : int 19 | Bitwidth for quantized weights. 20 | bias_bit : int, default None 21 | Bitwidth for quantized bias. 22 | per_channel : bool, default False 23 | Whether to use channel-wise quantization. 24 | quant_mode : 'none' or 'symmetric', default 'none' 25 | The mode for quantization. 'none' for no quantization. 26 | """ 27 | 28 | def __init__(self, 29 | in_features, 30 | out_features, 31 | bias=True, 32 | weight_bit=8, 33 | bias_bit=32, 34 | per_channel=True, 35 | quant_mode='symmetric'): 36 | super(QuantLinear, self).__init__(in_features, out_features, bias) 37 | self.weight_bit = weight_bit 38 | self.per_channel = per_channel 39 | self.bias_bit = bias_bit 40 | self.quantize_bias = (False if bias_bit is None else True) 41 | self.quant_mode = quant_mode 42 | 43 | if self.quant_mode == "symmetric": 44 | self.weight_function = SymmetricQuantFunction.apply 45 | elif self.quant_mode == "asymmetric": 46 | raise NotImplementedError("unsupported quant mode: {}".format(quant_mode)) 47 | else: 48 | raise ValueError("unknown quant mode: {}".format(self.quant_mode)) 49 | 50 | self.register_buffer('fc_scaling_factor', torch.zeros(self.out_features)) 51 | self.register_buffer('weight_integer', torch.zeros_like(self.weight)) 52 | if self.bias is not None: 53 | self.register_buffer('bias_integer', torch.zeros_like(self.bias)) 54 | 55 | def __repr__(self): 56 | s = super(QuantLinear, self).__repr__() 57 | s = "(" + s + " weight_bit={}, quant_mode={})".format( 58 | self.weight_bit, self.quant_mode) 59 | return s 60 | 61 | def fix(self): 62 | pass 63 | 64 | def unfix(self): 65 | pass 66 | 67 | def forward(self, x, prev_act_scaling_factor=None): 68 | with torch.no_grad(): 69 | w = self.weight 70 | if self.per_channel: 71 | v = w.reshape(w.shape[0], -1) 72 | cur_min = v.min(axis=1).values 73 | cur_max = v.max(axis=1).values 74 | self.min_val = cur_min 75 | self.max_val = cur_max 76 | else: 77 | raise Exception('For weight, we only support per_channel quantization.') 78 | 79 | self.fc_scaling_factor = symmetric_linear_quantization_params( 80 | self.weight_bit, self.min_val, self.max_val) 81 | 82 | self.weight_integer = self.weight_function( 83 | self.weight, self.weight_bit, self.fc_scaling_factor, True) 84 | 85 | bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor 86 | 87 | if self.bias is not None: 88 | self.bias_integer = self.weight_function( 89 | self.bias, self.bias_bit, bias_scaling_factor, True) 90 | else: 
91 | self.bias_integer = None 92 | 93 | prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1) 94 | x_int = x / prev_act_scaling_factor 95 | 96 | return F.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) \ 97 | * bias_scaling_factor, bias_scaling_factor 98 | 99 | 100 | class QuantAct(nn.Module): 101 | """ 102 | Class to quantize given activations 103 | Parameters: 104 | ---------- 105 | activation_bit : int 106 | Bitwidth for quantized activations. 107 | act_range_momentum : float, default 0.95 108 | Momentum for updating the activation quantization range. 109 | running_stat : bool, default True 110 | Whether to use running statistics for activation quantization range. 111 | per_channel : bool, default False 112 | Whether to use channel-wise quantization. 113 | channel_len : int, default None 114 | Specify the channel length when using the per_channel mode. 115 | quant_mode : 'none' or 'asymmetric', default 'none' 116 | The mode for quantization. 'none' for no quantization. 117 | """ 118 | 119 | def __init__(self, 120 | activation_bit=8, 121 | act_range_momentum=0.95, 122 | running_stat=True, 123 | per_channel=False, 124 | quant_mode="symmetric"): 125 | super(QuantAct, self).__init__() 126 | 127 | self.activation_bit = activation_bit 128 | self.act_range_momentum = act_range_momentum 129 | self.running_stat = running_stat 130 | self.quant_mode = quant_mode 131 | self.per_channel = per_channel 132 | 133 | self.min_val = torch.zeros(1) 134 | self.max_val = torch.zeros(1) 135 | self.register_buffer('act_scaling_factor', torch.zeros(1)) 136 | 137 | self.quant_mode = quant_mode 138 | self.per_channel = per_channel 139 | 140 | if self.quant_mode == "symmetric": 141 | self.act_function = SymmetricQuantFunction.apply 142 | elif self.quant_mode == "asymmetric": 143 | raise NotImplementedError("unsupported quant mode: {}".format(self.quant_mode)) 144 | else: 145 | raise ValueError("unknown quant mode: {}".format(self.quant_mode)) 146 | 147 | def __repr__(self): 148 | return "{0}(activation_bit={1}, " \ 149 | "quant_mode: {2}, Act_min: {3:.2f}, " \ 150 | "Act_max: {4:.2f})".format(self.__class__.__name__, self.activation_bit, 151 | self.quant_mode, self.x_min.item(), self.x_max.item()) 152 | 153 | def fix(self): 154 | """ 155 | fix the activation range by setting running stat 156 | """ 157 | self.running_stat = False 158 | 159 | def unfix(self): 160 | """ 161 | unfix the activation range by setting running stat 162 | """ 163 | self.running_stat = True 164 | 165 | def forward(self, x, 166 | pre_act_scaling_factor=None, 167 | identity=None, 168 | identity_scaling_factor=None): 169 | # collect runnng stats 170 | with torch.no_grad(): 171 | x_act = x if identity is None else identity + x 172 | if self.running_stat: 173 | if len(x_act.shape) == 4: 174 | x_act = x_act.permute(0, 2, 3, 1) 175 | v = x_act.reshape(-1, x_act.shape[-1]) 176 | v = v.transpose(0, 1) 177 | 178 | cur_min = v.min(axis=1).values 179 | cur_max = v.max(axis=1).values 180 | if torch.eq(self.min_val, self.max_val).all(): 181 | self.min_val = cur_min 182 | self.max_val = cur_max 183 | else: 184 | self.min_val = self.min_val * self.act_range_momentum + \ 185 | cur_min * (1 - self.act_range_momentum) 186 | self.max_val = self.max_val * self.act_range_momentum + \ 187 | cur_max * (1 - self.act_range_momentum) 188 | self.max_val = self.max_val.max() 189 | self.min_val = self.min_val.min() 190 | 191 | self.act_scaling_factor = symmetric_linear_quantization_params( 192 | self.activation_bit, self.min_val, self.max_val) 
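            # Editor's note (illustrative assumption -- the exact rule is defined by
            # symmetric_linear_quantization_params in quant_utils.py): under the conventional
            # symmetric mapping, activation_bit = 8 and an observed range of, say,
            # [-3.0, 6.0] give scale ~= max(|-3.0|, |6.0|) / (2**7 - 1) = 6.0 / 127 ~= 0.047,
            # so real activations map to integers in [-127, 127] via x / scale.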
193 | 194 | if pre_act_scaling_factor is None: 195 | # this is for the input quantization 196 | quant_act_int = self.act_function(x, self.activation_bit, self.act_scaling_factor, False) 197 | else: 198 | quant_act_int = fixedpoint_mul.apply( 199 | x, pre_act_scaling_factor, 200 | self.activation_bit, self.quant_mode, 201 | self.act_scaling_factor, 202 | identity, identity_scaling_factor) 203 | 204 | correct_output_scale = self.act_scaling_factor.view(-1) 205 | 206 | return quant_act_int * correct_output_scale, self.act_scaling_factor 207 | 208 | 209 | class QuantMatMul(nn.Module): 210 | """ 211 | Class to quantize weights of given matmul layer 212 | """ 213 | def __init__(self): 214 | super(QuantMatMul, self).__init__() 215 | self.register_buffer('act_scaling_factor', torch.zeros(1)) 216 | 217 | def fix(self): 218 | pass 219 | 220 | def unfix(self): 221 | pass 222 | 223 | def forward(self, A, pre_act_scaling_factor_A, B, pre_act_scaling_factor_B): 224 | A_int = A / pre_act_scaling_factor_A 225 | B_int = B / pre_act_scaling_factor_B 226 | act_scaling_factor = pre_act_scaling_factor_A * pre_act_scaling_factor_B 227 | self.act_scaling_factor = act_scaling_factor 228 | return (A_int @ B_int) * act_scaling_factor, act_scaling_factor 229 | 230 | 231 | class QuantConv2d(nn.Conv2d): 232 | """ 233 | Class to quantize weights of given convolutional layer 234 | Parameters: 235 | ---------- 236 | weight_bit : int, default 4 237 | Bitwidth for quantized weights. 238 | bias_bit : int, default None 239 | Bitwidth for quantized bias. 240 | full_precision_flag : bool, default False 241 | If True, use fp32 and skip quantization 242 | quant_mode : 'symmetric' or 'asymmetric', default 'symmetric' 243 | The mode for quantization. 244 | per_channel : bool, default False 245 | Whether to use channel-wise quantization. 246 | fix_flag : bool, default False 247 | Whether the module is in fixed mode or not. 248 | weight_percentile : float, default 0 249 | The percentile to setup quantization range, 0 means no use of percentile, 99.9 means to cut off 0.1%. 
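    Note:
        forward(x, pre_act_scaling_factor) quantizes the weights on the fly and returns the
        tuple (output * output_scale, output_scale), so the scaling factor can be passed on
        to the next quantized module (e.g., a following QuantAct).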
250 | """ 251 | 252 | def __init__(self, 253 | in_channels, 254 | out_channels, 255 | kernel_size, 256 | stride=1, 257 | padding=0, 258 | dilation=1, 259 | groups=1, 260 | bias=True, 261 | weight_bit=8, 262 | bias_bit=32, 263 | quant_mode="symmetric", 264 | per_channel=True, 265 | weight_percentile=0): 266 | super(QuantConv2d, self).__init__(in_channels=in_channels, 267 | out_channels=out_channels, 268 | kernel_size=kernel_size, 269 | stride=stride, 270 | padding=padding, 271 | dilation=dilation, 272 | groups=groups, 273 | bias=bias 274 | ) 275 | self.weight_bit = weight_bit 276 | self.quant_mode = quant_mode 277 | self.per_channel = per_channel 278 | self.weight_percentile = weight_percentile 279 | self.bias_bit = bias_bit 280 | self.quantize_bias = (False if bias_bit is None else True) 281 | 282 | self.register_buffer('conv_scaling_factor', torch.zeros(self.out_channels)) 283 | self.register_buffer('weight_integer', torch.zeros_like(self.weight)) 284 | self.register_buffer('bias_integer', torch.zeros_like(self.bias)) 285 | 286 | def __repr__(self): 287 | s = super(QuantConv2d, self).__repr__() 288 | s = "(" + s + " weight_bit={}, quant_mode={})".format(self.weight_bit, self.quant_mode) 289 | return s 290 | 291 | def fix(self): 292 | pass 293 | 294 | def unfix(self): 295 | pass 296 | 297 | def forward(self, x, pre_act_scaling_factor=None): 298 | if self.quant_mode == "symmetric": 299 | self.weight_function = SymmetricQuantFunction.apply 300 | elif self.quant_mode == "asymmetric": 301 | raise NotImplementedError("unsupported quant mode: {}".format(self.quant_mode)) 302 | else: 303 | raise ValueError("unknown quant mode: {}".format(self.quant_mode)) 304 | 305 | with torch.no_grad(): 306 | w = self.weight 307 | if self.per_channel: 308 | v = w.reshape(w.shape[0], -1) 309 | cur_min = v.min(axis=1).values 310 | cur_max = v.max(axis=1).values 311 | self.min_val = cur_min 312 | self.max_val = cur_max 313 | else: 314 | raise Exception('For weight, we only support per_channel quantization.') 315 | 316 | self.conv_scaling_factor = symmetric_linear_quantization_params( 317 | self.weight_bit, self.min_val, self.max_val) 318 | 319 | self.weight_integer = self.weight_function( 320 | self.weight, self.weight_bit, self.conv_scaling_factor, True) 321 | bias_scaling_factor = self.conv_scaling_factor * pre_act_scaling_factor 322 | self.bias_integer = self.weight_function( 323 | self.bias, self.bias_bit, bias_scaling_factor, True) 324 | 325 | pre_act_scaling_factor = pre_act_scaling_factor.view(1, -1, 1, 1) 326 | x_int = x / pre_act_scaling_factor 327 | correct_output_scale = bias_scaling_factor.view(1, -1, 1, 1) 328 | 329 | return (F.conv2d(x_int, self.weight_integer, self.bias_integer, self.stride, self.padding, 330 | self.dilation, self.groups) * correct_output_scale, correct_output_scale) 331 | 332 | 333 | class IntLayerNorm(nn.LayerNorm): 334 | """ 335 | Implementation of I-LayerNorm 336 | Class to quantize given LayerNorm layer 337 | """ 338 | def __init__(self, 339 | normalized_shape, 340 | eps=1e-5, 341 | elementwise_affine=True): 342 | super(IntLayerNorm, self).__init__(normalized_shape, eps, elementwise_affine) 343 | self.dim_sqrt = None 344 | self.register_buffer('norm_scaling_factor', torch.zeros(1)) 345 | self.register_buffer('bias_integer', torch.zeros_like(self.bias)) 346 | 347 | def fix(self): 348 | pass 349 | 350 | def unfix(self): 351 | pass 352 | 353 | def forward(self, x, scaling_factor=None): 354 | if self.dim_sqrt is None: 355 | n = torch.tensor(x.shape[2], dtype=torch.float) 356 | 
self.dim_sqrt = torch.sqrt(n).cuda() 357 | 358 | # Normalization: computes mean and variance(std) 359 | x_int = x / scaling_factor 360 | mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True)) 361 | y_int = x_int - mean_int 362 | y_sq_int = y_int ** 2 363 | var_int = torch.sum(y_sq_int, axis=2, keepdim=True) 364 | 365 | # Integer Iteration 366 | k = 2 ** 16 367 | for _ in range(10): 368 | k_1 = floor_ste.apply((k + floor_ste.apply(var_int/k))/2) 369 | k = k_1 370 | std_int = k 371 | 372 | factor = floor_ste.apply((2 ** 31-1) / std_int) 373 | y_int = floor_ste.apply(y_int * factor / 2) 374 | scaling_factor = self.dim_sqrt / 2 ** 30 375 | 376 | # scaling and shifting 377 | bias = self.bias.data.detach() / (self.weight.data.detach()) 378 | bias_int = floor_ste.apply(bias / scaling_factor) 379 | 380 | self.bias_integer = bias_int 381 | 382 | y_int = y_int + bias_int 383 | scaling_factor = scaling_factor * self.weight 384 | x = y_int * scaling_factor 385 | self.norm_scaling_factor = scaling_factor 386 | return x, scaling_factor 387 | 388 | 389 | class IntGELU(nn.Module): 390 | """ 391 | Implementation of ShiftGELU 392 | Class to quantize given GELU layer 393 | """ 394 | 395 | def __init__(self, output_bit=8): 396 | super(IntGELU, self).__init__() 397 | self.output_bit = output_bit 398 | 399 | self.n = 23 # sufficiently large integer 400 | #The minimum value for ensuring accuracy (varies depending on models) 401 | 402 | self.register_buffer('act_scaling_factor', torch.zeros(1)) 403 | 404 | def fix(self): 405 | pass 406 | 407 | def unfix(self): 408 | pass 409 | 410 | def int_exp_shift(self, x_int, scaling_factor): 411 | x_int = x_int + floor_ste.apply(x_int / 2) - floor_ste.apply(x_int / 2 ** 4) 412 | 413 | with torch.no_grad(): 414 | x0_int = torch.floor(-1.0 / scaling_factor) 415 | x_int = torch.max(x_int, self.n * x0_int) 416 | 417 | q = floor_ste.apply(x_int / x0_int) 418 | r = x_int - x0_int * q 419 | exp_int = r/2 - x0_int 420 | exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.n - q)), min=0) 421 | scaling_factor = scaling_factor / 2 ** self.n 422 | 423 | return exp_int, scaling_factor 424 | 425 | def forward(self, x, scaling_factor=None): 426 | pre_x_int = x / scaling_factor 427 | scaling_factor_sig = scaling_factor * 1.702 428 | 429 | x_int_max, _ = pre_x_int.max(dim=-1, keepdim=True) 430 | x_int = pre_x_int - x_int_max 431 | 432 | exp_int, _ = self.int_exp_shift(x_int, scaling_factor_sig) # e^(x-x_max) 433 | 434 | exp_int_max, _ = self.int_exp_shift(-x_int_max, scaling_factor_sig) # e^(-x_max) 435 | exp_int_sum = exp_int + exp_int_max 436 | 437 | exp_int_sum.clamp_max_(2**31-1) 438 | factor = floor_ste.apply((2 ** 31-1) / exp_int_sum) 439 | sigmoid_int = floor_ste.apply(exp_int * factor / 2 ** (31-self.output_bit+1)) 440 | sigmoid_scaling_factor = torch.Tensor([1 / 2 ** (self.output_bit-1)]).cuda() 441 | 442 | x_int = pre_x_int * sigmoid_int 443 | scaling_factor = scaling_factor * sigmoid_scaling_factor 444 | self.act_scaling_factor = scaling_factor 445 | return x_int * scaling_factor, scaling_factor 446 | 447 | 448 | class IntSoftmax(nn.Module): 449 | """ 450 | Implementation of Shiftmax 451 | Class to quantize given Softmax layer 452 | """ 453 | 454 | def __init__(self, output_bit=8): 455 | super(IntSoftmax, self).__init__() 456 | self.output_bit = output_bit 457 | 458 | self.n = 15 # sufficiently large integer 459 | #The minimum value for ensuring accuracy (varies depending on models) 460 | 461 | self.register_buffer('act_scaling_factor', torch.zeros(1)) 462 | 463 | def 
fix(self): 464 | pass 465 | 466 | def unfix(self): 467 | pass 468 | 469 | def int_exp_shift(self, x_int, scaling_factor): 470 | x_int = x_int + floor_ste.apply(x_int / 2) - floor_ste.apply(x_int / 2 ** 4) 471 | 472 | with torch.no_grad(): 473 | x0_int = torch.floor(-1.0 / scaling_factor) 474 | x_int = torch.max(x_int, self.n * x0_int) 475 | 476 | q = floor_ste.apply(x_int / x0_int) 477 | r = x_int - x0_int * q 478 | exp_int = r/2 - x0_int 479 | exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.n - q)), min=0) 480 | scaling_factor = scaling_factor / 2 ** self.n 481 | return exp_int, scaling_factor 482 | 483 | def forward(self, x, scaling_factor): 484 | x_int = x / scaling_factor 485 | x_int_max, _ = x_int.max(dim=-1, keepdim=True) 486 | x_int = x_int - x_int_max 487 | 488 | exp_int, _ = self.int_exp_shift(x_int, scaling_factor) 489 | exp_int_sum = exp_int.sum(dim=-1, keepdim=True) 490 | 491 | exp_int_sum.clamp_max_(2**31-1) 492 | factor = floor_ste.apply((2**31-1) / exp_int_sum) 493 | exp_int = floor_ste.apply(exp_int * factor / 2 ** (31-self.output_bit+1)) 494 | scaling_factor = torch.Tensor([1 / 2 ** (self.output_bit-1)]).cuda() 495 | 496 | self.act_scaling_factor = scaling_factor 497 | return exp_int * scaling_factor, scaling_factor 498 | -------------------------------------------------------------------------------- /models/swin_quant.py: -------------------------------------------------------------------------------- 1 | import math 2 | from tkinter import X 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.checkpoint as checkpoint 8 | from functools import partial 9 | 10 | from .layers_quant import PatchEmbed, Mlp, DropPath, trunc_normal_, to_2tuple 11 | from .quantization_utils import QuantLinear, QuantAct, QuantConv2d, IntLayerNorm, IntSoftmax, IntGELU, QuantMatMul 12 | 13 | __all__ = ['swin_tiny_patch4_window7_224', 14 | 'swin_small_patch4_window7_224', 15 | 'swin_base_patch4_window7_224'] 16 | 17 | 18 | def window_partition(x, window_size: int): 19 | """ 20 | Args: 21 | x: (B, H, W, C) 22 | window_size (int): window size 23 | 24 | Returns: 25 | windows: (num_windows*B, window_size, window_size, C) 26 | """ 27 | B, H, W, C = x.shape 28 | x = x.view(B, H // window_size, window_size, 29 | W // window_size, window_size, C) 30 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous( 31 | ).view(-1, window_size, window_size, C) 32 | return windows 33 | 34 | 35 | def window_reverse(windows, window_size: int, H: int, W: int): 36 | """ 37 | Args: 38 | windows: (num_windows*B, window_size, window_size, C) 39 | window_size (int): Window size 40 | H (int): Height of image 41 | W (int): Width of image 42 | 43 | Returns: 44 | x: (B, H, W, C) 45 | """ 46 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 47 | x = windows.view(B, H // window_size, W // window_size, 48 | window_size, window_size, -1) 49 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 50 | return x 51 | 52 | 53 | class WindowAttention(nn.Module): 54 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 55 | It supports both of shifted and non-shifted window. 56 | 57 | Args: 58 | dim (int): Number of input channels. 59 | window_size (tuple[int]): The height and width of the window. 60 | num_heads (int): Number of attention heads. 61 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 62 | attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 63 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 64 | """ 65 | 66 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): 67 | 68 | super().__init__() 69 | self.dim = dim 70 | self.window_size = window_size # Wh, Ww 71 | self.num_heads = num_heads 72 | head_dim = dim // num_heads 73 | self.scale = head_dim ** -0.5 74 | 75 | # define a parameter table of relative position bias 76 | self.relative_position_bias_table = nn.Parameter( 77 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 78 | 79 | # get pair-wise relative position index for each token inside the window 80 | coords_h = torch.arange(self.window_size[0]) 81 | coords_w = torch.arange(self.window_size[1]) 82 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 83 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 84 | relative_coords = coords_flatten[:, :, None] - \ 85 | coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 86 | relative_coords = relative_coords.permute( 87 | 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 88 | relative_coords[:, :, 0] += self.window_size[0] - \ 89 | 1 # shift to start from 0 90 | relative_coords[:, :, 1] += self.window_size[1] - 1 91 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 92 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 93 | self.register_buffer("relative_position_index", 94 | relative_position_index) 95 | 96 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 97 | self.qkv = QuantLinear( 98 | dim, 99 | dim * 3, 100 | bias=qkv_bias 101 | ) 102 | self.qact1 = QuantAct() 103 | self.qact_attn1 = QuantAct() 104 | self.qact_table = QuantAct() 105 | self.qact2 = QuantAct() 106 | 107 | self.attn_drop = nn.Dropout(attn_drop) 108 | self.log_int_softmax = IntSoftmax() 109 | self.qact3 = QuantAct() 110 | self.qact4 = QuantAct(16) 111 | # self.proj = nn.Linear(dim, dim) 112 | self.proj = QuantLinear(dim,dim) 113 | self.proj_drop = nn.Dropout(proj_drop) 114 | 115 | trunc_normal_(self.relative_position_bias_table, std=.02) 116 | #self.softmax = nn.Softmax(dim=-1) 117 | 118 | self.matmul_1 = QuantMatMul() 119 | self.matmul_2 = QuantMatMul() 120 | 121 | def forward(self, x, act_scaling_factor, mask: Optional[torch.Tensor] = None): 122 | """ 123 | Args: 124 | x: input features with shape of (num_windows*B, N, C) 125 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 126 | """ 127 | B_, N, C = x.shape 128 | x, act_scaling_factor = self.qkv(x, act_scaling_factor) 129 | x, act_scaling_factor_1= self.qact1(x, act_scaling_factor) 130 | qkv = x.reshape(B_, N, 3, self.num_heads, C // 131 | self.num_heads).permute(2, 0, 3, 1, 4) 132 | # make torchscript happy (cannot use tensor as tuple) 133 | q, k, v = qkv[0], qkv[1], qkv[2] 134 | 135 | attn, act_scaling_factor = self.matmul_1(q, act_scaling_factor_1, 136 | k.transpose(-2, -1), act_scaling_factor_1) 137 | attn = attn * self.scale 138 | act_scaling_factor = act_scaling_factor * self.scale 139 | 140 | attn, act_scaling_factor = self.qact_attn1(attn, act_scaling_factor) 141 | 142 | relative_position_bias_table_q, act_scaling_factor_tabel = self.qact_table( 143 | self.relative_position_bias_table) 144 | relative_position_bias = relative_position_bias_table_q[self.relative_position_index.view(-1)].view( 145 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 146 | relative_position_bias = relative_position_bias.permute( 147 | 2, 
0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 148 | 149 | attn, act_scaling_factor = self.qact2(attn, act_scaling_factor, relative_position_bias.unsqueeze(0), act_scaling_factor_tabel) 150 | 151 | if mask is not None: 152 | nW = mask.shape[0] 153 | attn = attn.view(B_ // nW, nW, self.num_heads, N, 154 | N) + mask.unsqueeze(1).unsqueeze(0) 155 | attn = attn.view(-1, self.num_heads, N, N) 156 | attn, act_scaling_factor = self.log_int_softmax(attn, act_scaling_factor) 157 | else: 158 | attn, act_scaling_factor = self.log_int_softmax(attn, act_scaling_factor) 159 | 160 | attn = self.attn_drop(attn) 161 | x, act_scaling_factor = self.matmul_2(attn, act_scaling_factor, 162 | v, act_scaling_factor_1) 163 | x = x.transpose(1, 2).reshape(B_, N, C) 164 | x, act_scaling_factor = self.qact3(x, act_scaling_factor) 165 | 166 | x, act_scaling_factor = self.proj(x, act_scaling_factor) 167 | x, act_scaling_factor = self.qact4(x, act_scaling_factor) 168 | x = self.proj_drop(x) 169 | return x, act_scaling_factor 170 | 171 | 172 | class SwinTransformerBlock(nn.Module): 173 | r""" Swin Transformer Block. 174 | 175 | Args: 176 | dim (int): Number of input channels. 177 | input_resolution (tuple[int]): Input resulotion. 178 | num_heads (int): Number of attention heads. 179 | window_size (int): Window size. 180 | shift_size (int): Shift size for SW-MSA. 181 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 182 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 183 | drop (float, optional): Dropout rate. Default: 0.0 184 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 185 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 186 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 187 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 188 | """ 189 | 190 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 191 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., 192 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 193 | super().__init__() 194 | self.dim = dim 195 | self.input_resolution = input_resolution 196 | self.num_heads = num_heads 197 | self.window_size = window_size 198 | self.shift_size = shift_size 199 | self.mlp_ratio = mlp_ratio 200 | if min(self.input_resolution) <= self.window_size: 201 | # if window size is larger than input resolution, we don't partition windows 202 | self.shift_size = 0 203 | self.window_size = min(self.input_resolution) 204 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 205 | 206 | self.norm1 = norm_layer(dim) 207 | self.qact1 = QuantAct() 208 | self.attn = WindowAttention( 209 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, 210 | attn_drop=attn_drop, proj_drop=drop) 211 | 212 | self.drop_path = DropPath( 213 | drop_path) if drop_path > 0. 
else nn.Identity() 214 | self.qact2 = QuantAct(16) 215 | self.norm2 = norm_layer(dim) 216 | self.qact3 = QuantAct() 217 | mlp_hidden_dim = int(dim * mlp_ratio) 218 | self.mlp = Mlp(in_features=dim, 219 | hidden_features=mlp_hidden_dim, 220 | act_layer=act_layer, 221 | drop=drop) 222 | self.qact4 = QuantAct(16) 223 | if self.shift_size > 0: 224 | # calculate attention mask for SW-MSA 225 | H, W = self.input_resolution 226 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 227 | h_slices = (slice(0, -self.window_size), 228 | slice(-self.window_size, -self.shift_size), 229 | slice(-self.shift_size, None)) 230 | w_slices = (slice(0, -self.window_size), 231 | slice(-self.window_size, -self.shift_size), 232 | slice(-self.shift_size, None)) 233 | cnt = 0 234 | for h in h_slices: 235 | for w in w_slices: 236 | img_mask[:, h, w, :] = cnt 237 | cnt += 1 238 | 239 | # nW, window_size, window_size, 1 240 | mask_windows = window_partition(img_mask, self.window_size) 241 | mask_windows = mask_windows.view(-1, 242 | self.window_size * self.window_size) 243 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 244 | attn_mask = attn_mask.masked_fill( 245 | attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 246 | else: 247 | attn_mask = None 248 | 249 | self.register_buffer("attn_mask", attn_mask) 250 | 251 | def forward(self, x_1, act_scaling_factor_1): 252 | H, W = self.input_resolution 253 | B, L, C = x_1.shape 254 | assert L == H * W, "input feature has wrong size" 255 | 256 | x, act_scaling_factor = self.norm1(x_1, act_scaling_factor_1) 257 | x, act_scaling_factor = self.qact1(x, act_scaling_factor) 258 | x = x.view(B, H, W, C) 259 | 260 | # cyclic shift 261 | if self.shift_size > 0: 262 | shifted_x = torch.roll( 263 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 264 | else: 265 | shifted_x = x 266 | 267 | # partition windows 268 | # nW*B, window_size, window_size, C 269 | x_windows = window_partition(shifted_x, self.window_size) 270 | # nW*B, window_size*window_size, C 271 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 272 | 273 | # W-MSA/SW-MSA 274 | # nW*B, window_size*window_size, C 275 | attn_windows, act_scaling_factor = self.attn(x_windows, act_scaling_factor, mask=self.attn_mask) 276 | 277 | # merge windows 278 | attn_windows = attn_windows.view(-1, 279 | self.window_size, self.window_size, C) 280 | shifted_x = window_reverse( 281 | attn_windows, self.window_size, H, W) # B H' W' C 282 | 283 | # reverse cyclic shift 284 | if self.shift_size > 0: 285 | x = torch.roll(shifted_x, shifts=( 286 | self.shift_size, self.shift_size), dims=(1, 2)) 287 | else: 288 | x = shifted_x 289 | x = x.view(B, H * W, C) 290 | 291 | # FFN 292 | x = self.drop_path(x) 293 | x_2, act_scaling_factor_2 = self.qact2(x, act_scaling_factor, x_1, act_scaling_factor_1) 294 | 295 | x, act_scaling_factor = self.norm2(x_2, act_scaling_factor_2) 296 | x, act_scaling_factor = self.qact3(x, act_scaling_factor) 297 | x, act_scaling_factor = self.mlp(x, act_scaling_factor) 298 | x = self.drop_path(x) 299 | x, act_scaling_factor = self.qact4(x, act_scaling_factor, x_2, act_scaling_factor_2) 300 | 301 | return x, act_scaling_factor 302 | 303 | 304 | class PatchMerging(nn.Module): 305 | r""" Patch Merging Layer. 306 | 307 | Args: 308 | input_resolution (tuple[int]): Resolution of input feature. 309 | dim (int): Number of input channels. 310 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 311 | """ 312 | 313 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 314 | super().__init__() 315 | self.input_resolution = input_resolution 316 | self.dim = dim 317 | 318 | self.norm = norm_layer(4 * dim) 319 | self.qact1 = QuantAct() 320 | # self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 321 | self.reduction = QuantLinear( 322 | 4 * dim, 323 | 2 * dim, 324 | bias=False 325 | ) 326 | self.qact2 = QuantAct() 327 | 328 | def forward(self, x, act_scaling_factor): 329 | """ 330 | x: B, H*W, C 331 | """ 332 | H, W = self.input_resolution 333 | B, L, C = x.shape 334 | assert L == H * W, "input feature has wrong size" 335 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 336 | 337 | x = x.view(B, H, W, C) 338 | 339 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 340 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 341 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 342 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 343 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 344 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 345 | x, act_scaling_factor = self.norm(x, act_scaling_factor) 346 | x, act_scaling_factor = self.qact1(x, act_scaling_factor) 347 | x, act_scaling_factor = self.reduction(x, act_scaling_factor) 348 | x, act_scaling_factor = self.qact2(x, act_scaling_factor) 349 | return x, act_scaling_factor 350 | 351 | def extra_repr(self) -> str: 352 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 353 | 354 | def flops(self): 355 | H, W = self.input_resolution 356 | flops = H * W * self.dim 357 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 358 | return flops 359 | 360 | 361 | class BasicLayer(nn.Module): 362 | """ A basic Swin Transformer layer for one stage. 363 | 364 | Args: 365 | dim (int): Number of input channels. 366 | input_resolution (tuple[int]): Input resolution. 367 | depth (int): Number of blocks. 368 | num_heads (int): Number of attention heads. 369 | window_size (int): Local window size. 370 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 371 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 372 | drop (float, optional): Dropout rate. Default: 0.0 373 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 374 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 375 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 376 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 377 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
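        Note:
            forward(x, act_scaling_factor) threads the activation scaling factor through each
            SwinTransformerBlock (and the optional PatchMerging downsample) and returns the
            updated (x, act_scaling_factor) pair.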
378 | """ 379 | 380 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 381 | mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., 382 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 383 | 384 | super().__init__() 385 | self.dim = dim 386 | self.input_resolution = input_resolution 387 | self.depth = depth 388 | self.use_checkpoint = use_checkpoint 389 | 390 | # build blocks 391 | self.blocks = nn.ModuleList([ 392 | SwinTransformerBlock( 393 | dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, 394 | shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, 395 | qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, 396 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, act_layer=IntGELU, norm_layer=norm_layer) 397 | for i in range(depth)]) 398 | 399 | # patch merging layer 400 | if downsample is not None: 401 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 402 | else: 403 | self.downsample = None 404 | 405 | def forward(self, x, act_scaling_factor): 406 | for blk in self.blocks: 407 | if not torch.jit.is_scripting() and self.use_checkpoint: 408 | x = checkpoint.checkpoint(blk, x) 409 | else: 410 | x, act_scaling_factor = blk(x, act_scaling_factor) 411 | if self.downsample is not None: 412 | x, act_scaling_factor = self.downsample(x, act_scaling_factor) 413 | return x, act_scaling_factor 414 | 415 | def extra_repr(self) -> str: 416 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 417 | 418 | 419 | class SwinTransformer(nn.Module): 420 | r""" Swin Transformer 421 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 422 | https://arxiv.org/pdf/2103.14030 423 | 424 | Args: 425 | img_size (int | tuple(int)): Input image size. Default 224 426 | patch_size (int | tuple(int)): Patch size. Default: 4 427 | in_chans (int): Number of input image channels. Default: 3 428 | num_classes (int): Number of classes for classification head. Default: 1000 429 | embed_dim (int): Patch embedding dimension. Default: 96 430 | depths (tuple(int)): Depth of each Swin Transformer layer. 431 | num_heads (tuple(int)): Number of attention heads in different layers. 432 | window_size (int): Window size. Default: 7 433 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 434 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 435 | drop_rate (float): Dropout rate. Default: 0 436 | attn_drop_rate (float): Attention dropout rate. Default: 0 437 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 438 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 439 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 440 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 441 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 442 | """ 443 | 444 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 445 | embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), 446 | window_size=7, mlp_ratio=4., qkv_bias=True, 447 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 448 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 449 | use_checkpoint=False, **kwargs): 450 | super().__init__() 451 | 452 | self.num_classes = num_classes 453 | self.num_layers = len(depths) 454 | self.embed_dim = embed_dim 455 | self.ape = ape 456 | self.patch_norm = patch_norm 457 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 458 | self.mlp_ratio = mlp_ratio 459 | self.qact_input = QuantAct() 460 | # split image into non-overlapping patches 461 | self.patch_embed = PatchEmbed( 462 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 463 | norm_layer=norm_layer if self.patch_norm else None) 464 | num_patches = self.patch_embed.num_patches 465 | self.patch_grid = self.patch_embed.grid_size 466 | 467 | # absolute position embedding 468 | if self.ape: 469 | self.absolute_pos_embed = nn.Parameter( 470 | torch.zeros(1, num_patches, embed_dim)) 471 | trunc_normal_(self.absolute_pos_embed, std=.02) 472 | self.qact_pos = QuantAct(16) 473 | else: 474 | self.absolute_pos_embed = None 475 | self.qact1 = QuantAct(16) 476 | 477 | self.pos_drop = nn.Dropout(p=drop_rate) 478 | 479 | # stochastic depth 480 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 481 | sum(depths))] # stochastic depth decay rule 482 | 483 | # build layers 484 | layers = [] 485 | for i_layer in range(self.num_layers): 486 | layers += [BasicLayer( 487 | dim=int(embed_dim * 2 ** i_layer), 488 | input_resolution=( 489 | self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), 490 | depth=depths[i_layer], 491 | num_heads=num_heads[i_layer], 492 | window_size=window_size, 493 | mlp_ratio=self.mlp_ratio, 494 | qkv_bias=qkv_bias, 495 | drop=drop_rate, 496 | attn_drop=attn_drop_rate, 497 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 498 | norm_layer=norm_layer, 499 | downsample=PatchMerging if ( 500 | i_layer < self.num_layers - 1) else None, 501 | use_checkpoint=use_checkpoint) 502 | ] 503 | self.layers = nn.Sequential(*layers) 504 | 505 | self.norm = norm_layer(self.num_features) 506 | 507 | self.qact2 = QuantAct() 508 | self.avgpool = nn.AdaptiveAvgPool1d(1) 509 | self.qact3 = QuantAct() 510 | self.head = ( 511 | QuantLinear( 512 | self.num_features, 513 | num_classes) 514 | if num_classes > 0 515 | else nn.Identity() 516 | ) 517 | 518 | self.act_out = QuantAct() 519 | self.apply(self._init_weights) 520 | 521 | def _init_weights(self, m): 522 | if isinstance(m, nn.Linear): 523 | trunc_normal_(m.weight, std=.02) 524 | if isinstance(m, nn.Linear) and m.bias is not None: 525 | nn.init.constant_(m.bias, 0) 526 | elif isinstance(m, nn.LayerNorm): 527 | nn.init.constant_(m.bias, 0) 528 | nn.init.constant_(m.weight, 1.0) 529 | 530 | @torch.jit.ignore 531 | def no_weight_decay(self): 532 | return {'absolute_pos_embed'} 533 | 534 | @torch.jit.ignore 535 | def no_weight_decay_keywords(self): 536 | return {'relative_position_bias_table'} 537 | 538 | 539 | def forward_features(self, x): 540 | x, act_scaling_factor = self.qact_input(x) 541 | x, act_scaling_factor = self.patch_embed(x, act_scaling_factor) 542 | if self.absolute_pos_embed is not None: 543 | x_pos, act_scaling_factor_pos = self.qact_pos(self.absolute_pos_embed) 544 | x, act_scaling_factor 
= self.qact1(x, act_scaling_factor, x_pos, act_scaling_factor_pos) 545 | else: 546 | x, act_scaling_factor = self.qact1(x, act_scaling_factor) 547 | x = self.pos_drop(x) 548 | for layer in self.layers: 549 | x, act_scaling_factor = layer(x, act_scaling_factor) 550 | 551 | x, act_scaling_factor = self.norm(x, act_scaling_factor) 552 | x, act_scaling_factor = self.qact2(x, act_scaling_factor) 553 | 554 | x = self.avgpool(x.transpose(1, 2)) # B C 1 555 | x, act_scaling_factor = self.qact3(x, act_scaling_factor) 556 | 557 | x = torch.flatten(x, 1) 558 | return x, act_scaling_factor 559 | 560 | def forward(self, x): 561 | x, act_scaling_factor = self.forward_features(x) 562 | x, act_scaling_factor = self.head(x, act_scaling_factor) 563 | #x, _ = self.act_out(x, act_scaling_factor) 564 | return x 565 | 566 | 567 | def swin_tiny_patch4_window7_224(pretrained=False, quant=False, calibrate=False, cfg=None, **kwargs): 568 | """ Swin-T @ 224x224, trained ImageNet-1k 569 | """ 570 | model = SwinTransformer( 571 | patch_size=4, 572 | window_size=7, 573 | embed_dim=96, 574 | depths=(2, 2, 6, 2), 575 | num_heads=(3, 6, 12, 24), 576 | norm_layer=partial(IntLayerNorm, eps=1e-6), 577 | **kwargs 578 | ) 579 | if pretrained: 580 | checkpoint = torch.hub.load_state_dict_from_url( 581 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth", 582 | map_location="cpu", check_hash=True 583 | ) 584 | model.load_state_dict(checkpoint["model"], strict=False) 585 | return model 586 | 587 | 588 | def swin_small_patch4_window7_224(pretrained=False, quant=False, calibrate=False, cfg=None, **kwargs): 589 | """ Swin-S @ 224x224, trained ImageNet-1k 590 | """ 591 | model = SwinTransformer( 592 | patch_size=4, 593 | window_size=7, 594 | embed_dim=96, 595 | depths=(2, 2, 18, 2), 596 | num_heads=(3, 6, 12, 24), 597 | norm_layer=partial(IntLayerNorm, eps=1e-6), 598 | **kwargs 599 | ) 600 | if pretrained: 601 | checkpoint = torch.hub.load_state_dict_from_url( 602 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth", 603 | map_location="cpu", check_hash=True 604 | ) 605 | model.load_state_dict(checkpoint["model"], strict=False) 606 | return model 607 | 608 | 609 | def swin_base_patch4_window7_224(pretrained=False, quant=False, calibrate=False, cfg=None, **kwargs): 610 | """ Swin-B @ 224x224, trained ImageNet-1k 611 | """ 612 | model = SwinTransformer( 613 | patch_size=4, 614 | window_size=7, 615 | embed_dim=128, 616 | depths=(2, 2, 18, 2), 617 | num_heads=(4, 8, 16, 32), 618 | norm_layer=partial(IntLayerNorm, eps=1e-6), 619 | **kwargs 620 | ) 621 | if pretrained: 622 | checkpoint = torch.hub.load_state_dict_from_url( 623 | url="https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth", 624 | map_location="cpu", check_hash=True 625 | ) 626 | model.load_state_dict(checkpoint["model"], strict=False) 627 | return model 628 | --------------------------------------------------------------------------------
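Editor's appendix (illustrative only, not a file in the repository): the Shiftmax / ShiftGELU kernels above (`IntSoftmax.int_exp_shift`, `IntGELU.int_exp_shift`, and the TVM `shift_exp`) all rely on the same shift-based integer exponential. The sketch below re-implements that arithmetic in plain NumPy so it can be checked against a float softmax. The function names `int_exp_shift` and `shift_softmax` are local to this sketch, and the final normalization uses ordinary division for clarity, whereas the repository keeps it in integer arithmetic via the `(2**31 - 1) / sum` factor followed by a right shift.

```python
import numpy as np


def int_exp_shift(x_int, scale, n=15):
    """Integer-only approximation of exp(x_int * scale) for non-positive inputs."""
    # x * log2(e) approximated with shifts: x + x/2 - x/16 = 1.4375 * x (log2(e) = 1.4427).
    x_int = x_int + np.floor(x_int / 2) - np.floor(x_int / 2 ** 4)
    x0_int = np.floor(-1.0 / scale)                 # integer whose real value is -1
    x_int = np.maximum(x_int, n * x0_int)           # clamp so the left shift below stays in range
    q = np.floor(x_int / x0_int)                    # exp(x) = 2**(-q) * 2**(r * scale), r * scale in (-1, 0]
    r = x_int - x0_int * q
    exp_int = r / 2 - x0_int                        # linear fit 2**t ~ t/2 + 1 on (-1, 0], rescaled by 1/scale
    exp_int = np.maximum(np.floor(exp_int * 2.0 ** (n - q)), 0)
    return exp_int, scale / 2 ** n                  # integer result and its output scaling factor


def shift_softmax(x, scale):
    """Softmax over the last axis using only the integer exponential above."""
    x_int = np.round(x / scale)
    x_int = x_int - x_int.max(axis=-1, keepdims=True)   # shift inputs to be <= 0, as in IntSoftmax.forward
    exp_int, _ = int_exp_shift(x_int, scale)
    return exp_int / exp_int.sum(axis=-1, keepdims=True)


if __name__ == "__main__":
    scale = 0.05                                    # hypothetical activation scaling factor
    logits = np.array([[1.3, -0.7, 0.2, 2.1]])
    shifted = logits - logits.max(axis=-1, keepdims=True)
    print(shift_softmax(logits, scale))             # ~[[0.274, 0.037, 0.096, 0.593]]
    print(np.exp(shifted) / np.exp(shifted).sum())  # float reference ~[[0.271, 0.037, 0.090, 0.602]]
```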