├── LICENSE
├── README.md
├── docs
│   └── formats.png
├── examples
│   ├── inference
│   │   ├── README.md
│   │   ├── bert
│   │   │   ├── README.txt
│   │   │   ├── cmd_inference.sh
│   │   │   ├── download_squad_dataset.sh
│   │   │   ├── download_squad_finetuned_model.sh
│   │   │   ├── modeling_bert.py
│   │   │   ├── requirements.txt
│   │   │   └── run_squad.py
│   │   └── classifier
│   │       ├── README.md
│   │       ├── imagenet_qat.py
│   │       ├── imagenet_test.py
│   │       ├── launch.py
│   │       ├── test_scaleshift.py
│   │       ├── train.py
│   │       └── utils.py
│   └── training
│       ├── README.md
│       ├── bert
│       │   ├── launch_fp8_training.sh
│       │   ├── modeling_bert.py
│       │   ├── requirements.txt
│       │   ├── run_qa_beam_search_no_trainer.py
│       │   ├── run_qa_no_trainer.py
│       │   └── utils_qa.py
│       ├── nvidia_apex_cpu_patch_05091d4.patch
│       └── resnet
│           ├── README.md
│           ├── lr_scheduler.py
│           ├── main_amp.py
│           ├── main_amp_cpu.py
│           ├── train_cpu.sh
│           └── train_gpu.sh
├── mpemu
│   ├── __init__.py
│   ├── bfloat16_emu.py
│   ├── cmodel
│   │   ├── __init__.py
│   │   ├── simple.py
│   │   ├── simple
│   │   │   ├── simple_conv2d.cpp
│   │   │   ├── simple_conv2d_impl.cpp
│   │   │   ├── simple_gemm.cpp
│   │   │   ├── simple_gemm_impl.cpp
│   │   │   ├── simple_mm_engine.cpp
│   │   │   └── vla.h
│   │   └── tests
│   │       ├── conv_grad_test.py
│   │       ├── conv_test.py
│   │       ├── gemm_grad_test.py
│   │       ├── gemm_irregular_test.py
│   │       ├── gemm_test.py
│   │       ├── linear_test.py
│   │       └── net.py
│   ├── e3m4_emu.py
│   ├── e4m3_emu.py
│   ├── e5m2_emu.py
│   ├── hybrid_emu.py
│   ├── module_wrappers
│   │   ├── __init__.py
│   │   ├── adasparse.py
│   │   ├── aggregate.py
│   │   ├── eltwise.py
│   │   └── matmul.py
│   ├── mpt_emu.py
│   ├── pytquant
│   │   ├── __init__.py
│   │   ├── cpp
│   │   │   ├── __init__.py
│   │   │   ├── fpemu.py
│   │   │   └── fpemu_impl.cpp
│   │   ├── cuda
│   │   │   ├── __init__.py
│   │   │   ├── fpemu.py
│   │   │   ├── fpemu_impl.cpp
│   │   │   └── fpemu_kernels.cu
│   │   ├── hip
│   │   │   ├── __init__.py
│   │   │   ├── fpemu.py
│   │   │   ├── fpemu_impl.cpp
│   │   │   └── fpemu_kernels.hip
│   │   └── test.py
│   ├── qutils.py
│   ├── scale_shift.py
│   ├── sparse_utils.py
│   └── stats_collector.py
├── requirements.txt
└── setup.py
/LICENSE: --------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Intel Labs
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # FP8 Emulation Toolkit
2 |
3 | > [!CAUTION]
4 | > **PROJECT NOT UNDER ACTIVE MANAGEMENT**
5 | > * This project will no longer be maintained by Intel.
6 | > * Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.
7 | > * Intel no longer accepts patches to this project.
8 | > * If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.
9 |
10 | ## Introduction
11 | This repository provides PyTorch tools to emulate the new `FP8` formats on top of existing floating point hardware from Intel, AMD and NVIDIA. In addition to the two formats `E5M2` and `E4M3` defined in the joint specification from ARM-Intel-NVIDIA, the toolkit also supports a third variant named `E3M4` which follows the guidelines established for the `E4M3` format.
12 |
13 | The following table shows the binary formats and their numeric ranges:
14 |
15 | | | E5M2 | E4M3 | E3M4 |
16 | | -------------- | ---------------------------------- | ---------------------------------- | ---------------------------------- |
17 | | Exponent Bias | 15 | 7 | 3 |
18 | | Infinities | S.11111.00₂ | N/A | N/A |
19 | | NaNs | S.11111.{01, 10, 11}₂ | S.1111.111₂ | S.111.1111₂ |
20 | | Zeros | S.00000.00₂ | S.0000.000₂ | S.000.0000₂ |
21 | | Max normal | S.11110.11₂ = 1.75 * 2^15 = 57344.0 | S.1111.110₂ = 1.75 * 2^8 = 448.0 | S.111.1110₂ = 1.875 * 2^4 = 30.0 |
22 | | Min normal | S.00001.00₂ = 2^-14 = 6.1e-05 | S.0001.000₂ = 2^-6 = 1.5e-02 | S.001.0000₂ = 2^-2 = 2.5e-01 |
23 | | Max subnormal | S.00000.11₂ = 0.75 * 2^-14 = 4.5e-05 | S.0000.111₂ = 0.875 * 2^-6 = 1.3e-02 | S.000.1111₂ = 0.9375 * 2^-2 = 2.3e-01 |
24 | | Min subnormal | S.00000.01₂ = 2^-16 = 1.5e-05 | S.0000.001₂ = 2^-9 = 1.9e-03 | S.000.0001₂ = 2^-6 = 1.5e-02 |
25 |
26 | ![DataFormats](./docs/formats.png)
27 |
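The range entries follow directly from the bit layouts. The short script below (illustrative only, not part of the toolkit; the helper name `fp8_extremes` is made up for this example) recomputes the normal and subnormal extremes of each format:

```
# Recompute the range rows of the table above from the bit layouts.
# E5M2 reserves the all-ones exponent field for Inf/NaN; E4M3 and E3M4
# instead reserve only the all-ones mantissa of the top exponent for NaN.
def fp8_extremes(e_bits, m_bits, bias, reserves_inf):
    top_exp = (2**e_bits - 1) - (1 if reserves_inf else 0) - bias
    top_man = 2.0 - 2.0**-(m_bits if reserves_inf else m_bits - 1)
    return {
        "max_normal": top_man * 2.0**top_exp,
        "min_normal": 2.0**(1 - bias),
        "max_subnormal": (1.0 - 2.0**-m_bits) * 2.0**(1 - bias),
        "min_subnormal": 2.0**(1 - bias - m_bits),
    }

print(fp8_extremes(5, 2, 15, True))    # E5M2 -> max_normal 57344.0
print(fp8_extremes(4, 3, 7, False))    # E4M3 -> max_normal 448.0
print(fp8_extremes(3, 4, 3, False))    # E3M4 -> max_normal 30.0
```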
28 | ## Installation
29 |
30 | Follow the instructions below to install the FP8 Emulation Toolkit in a Python virtual environment.
31 | Alternatively, this installation can also be performed in a Docker environment.
32 |
33 | ### Requirements
34 | This package can be installed on the following hardware.
35 |
36 | * x86 CPUs from AMD and Intel
37 | * GPU devices from NVIDIA (CUDA) and AMD (HIP)
38 |
39 | Install or upgrade the following packages on your Linux machine.
40 |
41 | * Python >= 3.8.5
42 | * NVIDIA CUDA >= 11.1 or AMD ROCm >= 5.6
43 | * gcc >= 8.4.0
44 |
45 | Make sure these versions are reflected in the `$PATH`.
46 |
47 | #### Target Hardware
48 | * CPU : any x86
49 | * GPU : V100 (or newer), MI2XX
50 |
51 | ### Create a Python virtual environment
52 | ```
53 | $ python3 -m venv ~/py-venv
54 | $ cd ~/py-venv
55 | $ source bin/activate
56 | $ pip3 install --upgrade pip
57 | ```
58 | ### Clone and install FP8 Emulation Toolkit
59 | ```
60 | $ git clone https://github.com/IntelLabs/FP8-Emulation-Toolkit.git
61 | $ cd FP8-Emulation-Toolkit
62 | $ pip3 install -r requirements.txt
63 | $ python setup.py install
64 | ```
65 |
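To confirm the installation, the toolkit's top-level modules should import cleanly (a minimal smoke test; `mpt_emu` and `qutils` are modules packaged under `mpemu` per the source tree):

```
# installation smoke test: both imports should succeed from any directory
from mpemu import mpt_emu
from mpemu import qutils
```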
66 | ## Usage Examples
67 | The emulated FP8 formats can be experimented with by integrating them into standard deep learning flows. Follow the links below for detailed instructions and code samples for exploring training and inference flows using FP8 data formats.
68 |
69 | * [Post-training quantization](./examples/inference)
70 | * [Mixed precision training](./examples/training)
71 |
72 |
73 | ## Related Work
74 | This implementation is based on the following research. Check out the source material for more details on the training and inference methods.
75 |
76 | ```
77 | @article{shen2023efficient,
78 | title={Efficient Post-training Quantization with FP8 Formats},
79 | author={Shen, Haihao and Mellempudi, Naveen and He, Xin and Gao, Qun and Wang, Chang and Wang, Mengni},
80 | journal={arXiv preprint arXiv:2309.14592},
81 | year={2023}
82 | }
83 | ```
84 | ```
85 | @misc{mellempudi2019mixed,
86 | title={Mixed Precision Training With 8-bit Floating Point},
87 | author={Naveen Mellempudi and Sudarshan Srinivasan and Dipankar Das and Bharat Kaul},
88 | year={2019},
89 | eprint={1905.12334},
90 | archivePrefix={arXiv},
91 | primaryClass={cs.LG}
92 | }
93 | ```
94 | ```
95 | @misc{micikevicius2022fp8,
96 | title={FP8 Formats for Deep Learning},
97 | author={Paulius Micikevicius and Dusan Stosic and Neil Burgess and Marius Cornea and Pradeep Dubey and Richard Grisenthwaite and Sangwon Ha and Alexander Heinecke and Patrick Judd and John Kamalu and Naveen Mellempudi and Stuart Oberman and Mohammad Shoeybi and Michael Siu and Hao Wu},
98 | year={2022},
99 | eprint={2209.05433},
100 | archivePrefix={arXiv},
101 | primaryClass={cs.LG}
102 | }
103 | ```
104 |
-------------------------------------------------------------------------------- /docs/formats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/FP8-Emulation-Toolkit/fcd26e9a61e0e2aa4c559e38c15b29955c71149e/docs/formats.png
-------------------------------------------------------------------------------- /examples/inference/README.md: --------------------------------------------------------------------------------
1 | # FP8 Post-training Quantization
2 |
3 | The following example demonstrates the post-training quantization flow for converting pre-trained models to use FP8 for inference.
4 |
5 | ```
6 | # import the emulator
7 | from mpemu import mpt_emu
8 |
9 | ...
10 |
11 | # layers exempt from e4m3 conversion
12 | list_exempt_layers = ["conv1","fc"]
13 |
14 | model, emulator = mpt_emu.quantize_model(model, dtype="e4m3_rne", hw_patch="None",
15 |                        list_exempt_layers=list_exempt_layers)
16 |
17 | # calibrate the model on a few batches of training data
18 | evaluate(model, criterion, train_loader, device,
19 |               num_batches=num_calibration_batches, train=True)
20 |
21 | # Fuse BatchNorm layers and quantize the model
22 | model = emulator.fuse_bnlayers_and_quantize_model(model)
23 |
24 | # Evaluate the quantized model
25 | evaluate(model, criterion, test_loader, device)
26 |
27 | ```
28 | An example demonstrating post-training quantization can be found [here](./classifier/imagenet_test.py).
29 |
30 |
-------------------------------------------------------------------------------- /examples/inference/bert/README.txt: --------------------------------------------------------------------------------
1 |
2 | Prerequisite:
3 | ------------
4 |
5 | Install squad task specific requirements (one time):
6 | $pip install -r requirements.txt
7 |
8 | Download SQUAD dataset:
9 | ----------------------
10 | $bash download_squad_dataset.sh
11 |
12 | Download Squad fine-tuned model for inference:
13 | ---------------------------------------------
14 | $bash download_squad_finetuned_model.sh
15 |
16 | To run the squad baseline inference task:
17 | $bash cmd_inference.sh
18 |
19 | To run squad inference with a specific FP8 quantization type (defaults to 'hybrid'):
20 | $bash cmd_inference.sh <quant_type>
21 |
22 |
-------------------------------------------------------------------------------- /examples/inference/bert/cmd_inference.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | QUANT_TYPE=${1:-'hybrid'}
3 |
4 | # set dataset and model_path
5 | if test -z $dataset || ! test -d $dataset ; then
6 |   if test -d ./SQUAD1 ; then
7 |     dataset=./SQUAD1
8 |   else
9 |     echo "Unable to find dataset path!!"
10 |     echo "Download SQuAD dataset using the command ./download_squad_dataset.sh"
11 |     exit 1
12 |   fi
13 | fi
14 |
15 | if test -z $model_path || ! test -d $model_path ; then
16 |   if test -d ./squad_finetuned_checkpoint ; then
17 |     model_path=./squad_finetuned_checkpoint
18 |   else
19 |     echo "Unable to find pre-trained model path!!"
20 |     echo "Download the pre-trained SQuAD model using the command ./download_squad_finetuned_model.sh"
21 |     exit 1
22 |   fi
23 | fi
24 |
25 | $NUMA_ARGS $GDB_ARGS python -u run_squad.py \
26 |   --model_type bert \
27 |   --model_name_or_path $model_path \
28 |   --do_eval \
29 |   --do_lower_case \
30 |   --predict_file $dataset/dev-v1.1.json \
31 |   --per_gpu_eval_batch_size 24 \
32 |   --max_seq_length 384 \
33 |   --doc_stride 128 \
34 |   --quant_data_type=$QUANT_TYPE \
35 |   --output_dir /tmp/debug_squad/ \
36 |   #--no_cuda
-------------------------------------------------------------------------------- /examples/inference/bert/download_squad_dataset.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p SQUAD1 && cd SQUAD1 && wget --no-check-certificate https://data.deepai.org/squad1.1.zip && unzip squad1.1.zip && cd ..
3 |
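The archive extracts the SQuAD 1.1 files, including the `dev-v1.1.json` file that `cmd_inference.sh` expects under `SQUAD1/`. An optional sanity check (not part of the repository) that the dev set downloaded intact:

```
# verify the extracted SQuAD 1.1 dev set parses as JSON
import json
with open("SQUAD1/dev-v1.1.json") as f:
    squad = json.load(f)
print(squad["version"], "articles:", len(squad["data"]))
```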
-------------------------------------------------------------------------------- /examples/inference/bert/download_squad_finetuned_model.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | mkdir -p squad_finetuned_checkpoint && cd squad_finetuned_checkpoint
4 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json
5 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/pytorch_model.bin
6 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/tokenizer.json
7 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/tokenizer_config.json
8 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt
9 |
-------------------------------------------------------------------------------- /examples/inference/bert/requirements.txt: --------------------------------------------------------------------------------
1 | transformers>=4.11.0
2 |
-------------------------------------------------------------------------------- /examples/inference/classifier/README.md: --------------------------------------------------------------------------------
1 | #### Post Training Quantization
2 | ```
3 | python launch.py
4 | ```
5 | Refer to launch.py for more details.
6 |
7 | #### QAT (Quantization-Aware Training)
8 | ```
9 | python imagenet_qat.py \
10 |     --data-path <path-to-imagenet> \
11 |     --arch=mobilenet_v2 \
12 |     --lr=0.001 \
13 |     --epochs=15 \
14 |     --lr-step-size=5 \
15 |     --lr-gamma=0.1 \
16 |     --data-type bf8
17 | ```
18 |
-------------------------------------------------------------------------------- /examples/inference/classifier/imagenet_qat.py: --------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import datetime 11 | import os 12 | import time 13 | import sys 14 | import copy 15 | import collections 16 | 17 | import torch 18 | import torch.utils.data 19 | from torch import nn 20 | import torchvision 21 | from torchvision import transforms 22 | import torch.quantization 23 | import utils 24 | try: 25 | from apex import amp 26 | except ImportError: 27 | amp = None 28 | 29 | #from train import train_one_epoch, evaluate, load_data 30 | from mpemu import mpt_emu 31 | 32 | def train_one_epoch(args, model, criterion, optimizer, emulator, data_loader, 33 | device, epoch, print_freq, num_batches=None): 34 | model.train() 35 | metric_logger = utils.MetricLogger(delimiter=" ") 36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 37 | metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) 38 | 39 | header = 'Epoch: [{}]'.format(epoch) 40 | i = 0 41 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 42 | i += 1 43 | start_time = time.time() 44 | image, target = image.to(device), target.to(device) 45 | output = model(image) 46 | loss = criterion(output, target) 47 | 48 | optimizer.zero_grad() 49 | if args.data_type.lower() == "bf8": 50 | with amp.scale_loss(loss, optimizer) as scaled_loss: 51 | scaled_loss.backward() 52 | else: 53 | loss.backward() 54 | optimizer.step() 55 | 56 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 57 | batch_size = image.shape[0] 58 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 59 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 60 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 61 | metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) 62 | 63 | if num_batches is not None and i%num_batches == 0: 64 | break 65 | 66 | def evaluate(model, criterion, data_loader, device, num_batches=None, print_freq=100, train=False): 67 | if not train: 68 | model.eval() 69 | header = 'Test:' 70 | else: 71 | model.train() 72 | header = 'Train:' 73 | 74 | metric_logger = utils.MetricLogger(delimiter=" ") 75 | 76 | with torch.no_grad(): 77 | batch_id = 0 78 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 79 | image = image.to(device, non_blocking=True) 80 | target = target.to(device, non_blocking=True) 81 | output = model(image) 82 | loss = criterion(output, target) 83 | 84 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 85 | # FIXME need to take into account that the datasets 86 | # could have been padded in distributed setup 87 | batch_size = image.shape[0] 88 | metric_logger.update(loss=loss.item()) 89 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 90 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 91 | 92 | if num_batches is not None and batch_id+1 == num_batches: 93 | break; 94 | else: 95 | batch_id += 1 96 | # gather the stats from all processes 97 | metric_logger.synchronize_between_processes() 98 | 99 | print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' 100 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) 101 | return metric_logger.acc1.global_avg 102 | 
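# Note: evaluate(..., train=True) doubles as a BatchNorm (re)calibration pass:
# model.train() lets BatchNorm layers update their running statistics during
# the forward passes, while torch.no_grad() keeps the weights untouched.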
103 |
104 | def main(args):
105 |     if args.output_dir:
106 |         utils.mkdir(args.output_dir)
107 |
108 |     # creating device and setting backend
109 |     device = torch.device(args.device)
110 |     torch.backends.cudnn.benchmark = True
111 |
112 |     # data loading code
113 |     print("Loading data")
114 |     train_dir = os.path.join(args.data_path, 'train')
115 |     val_dir = os.path.join(args.data_path, 'val')
116 |     crop_size, val_size = 224, 256  # assumed ImageNet defaults
117 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
118 |     dataset = torchvision.datasets.ImageFolder(
119 |         train_dir,
120 |         transforms.Compose([
121 |             transforms.RandomResizedCrop(crop_size),
122 |             transforms.RandomHorizontalFlip(),
123 |             transforms.ToTensor(), normalize, ]))
124 |     dataset_test = torchvision.datasets.ImageFolder(val_dir, transforms.Compose([
125 |             transforms.Resize(val_size),
126 |             transforms.CenterCrop(crop_size),
127 |             transforms.ToTensor(), normalize, ]))
128 |     distributed = False  # assumed: this standalone example runs single-process
129 |     if distributed:
130 |         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
131 |         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
132 |     else:
133 |         train_sampler = torch.utils.data.RandomSampler(dataset)
134 |         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
135 |     data_loader = torch.utils.data.DataLoader(
136 |         dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
137 |         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
138 |
139 |     data_loader_test = torch.utils.data.DataLoader(
140 |         dataset_test, batch_size=args.eval_batch_size, shuffle=False,
141 |         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
142 |
143 |     print("Creating model", args.arch)
144 |     # Loading fp32 model
145 |     model = torchvision.models.__dict__[args.arch](pretrained=True)
146 |     model.to(device)
147 |
148 |     # SGD optimizer
149 |     optimizer = torch.optim.SGD(
150 |         model.parameters(), lr=args.lr, momentum=args.momentum,
151 |         weight_decay=args.weight_decay)
152 |
153 |     if args.data_type.lower() == "bf8" :
154 |         opt_level = "O2" if args.device == "cuda" else "O0"
155 |         loss_scale = "dynamic" if args.device == "cuda" else 1.0
156 |         model, optimizer = amp.initialize(model, optimizer,
157 |                                           opt_level=opt_level,
158 |                                           keep_batchnorm_fp32=True,
159 |                                           loss_scale=loss_scale
160 |                                           )
161 |     # LR scheduler
162 |     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
163 |                                 step_size=args.lr_step_size,
164 |                                 gamma=args.lr_gamma)
165 |     # Loss function
166 |     criterion = nn.CrossEntropyLoss()
167 |
168 |     if args.resume:
169 |         checkpoint = torch.load(args.resume, map_location='cpu')
170 |         model.load_state_dict(checkpoint['model'])
171 |         optimizer.load_state_dict(checkpoint['optimizer'])
172 |         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
173 |         args.start_epoch = checkpoint['epoch'] + 1
174 |
175 |     ########################################################
176 |     list_exempt_layers = []
177 |     list_layers_output_fused = []
178 |     if "resnet" in args.arch or "resnext" in args.arch:
179 |         list_exempt_layers = ["conv1","fc"]
180 |     elif args.arch == "vgg19_bn":
181 |         list_exempt_layers = ["features.0", "classifier.6"]
182 |     elif args.arch == "inception_v3":
183 |         list_exempt_layers = ["Conv2d_1a_3x3.conv", "fc"]
184 |
185 |     print("list of exempted layers : ", list_exempt_layers)
186 |     model, emulator = mpt_emu.quantize_model(model, optimizer=optimizer, dtype=args.data_type.lower(), hw_patch=args.patch_ops,
187 |                                 list_exempt_layers=list_exempt_layers, list_layers_output_fused=list_layers_output_fused,
188 |                                 device=device, verbose=True)
189 |
190 |     emulator.set_default_inference_qconfig()
191 |     start_time = time.time()
192 |     evaluate(model, criterion, data_loader_test, device=device)
193 |
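# QAT fine-tuning: the quantization hooks installed by mpt_emu.quantize_model()
# above remain active here, so the forward/backward passes in the loop below
# run on emulated-FP8 tensors.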
print('quantization-aware training : Fine-tuning the network for {} epochs'.format(args.epochs)) 194 | for epoch in range(args.start_epoch, args.epochs): 195 | train_one_epoch(args, model, criterion, optimizer, emulator, 196 | data_loader, device, epoch, args.print_freq) 197 | lr_scheduler.step() 198 | with torch.no_grad(): 199 | print('evaluating fine-tuned model') 200 | eval_model = emulator.fuse_bnlayers_and_quantize_model(model) 201 | #quantized_eval_model = copy.deepcopy(model) 202 | evaluate(eval_model, criterion, data_loader_test, device=device) 203 | #quantized_eval_model = emulator.fuse_bnlayers_and_quantize_model(quantized_eval_model) 204 | #print('evaluate quantized model') 205 | #evaluate(quantized_eval_model, criterion, data_loader_test, device=device) 206 | 207 | model.train() 208 | 209 | if args.output_dir: 210 | checkpoint = { 211 | 'model': model.state_dict(), 212 | 'model_qconfig_dict' : emulator.emulator.model_qconfig_dict, 213 | 'optimizer': optimizer.state_dict(), 214 | 'lr_scheduler': lr_scheduler.state_dict(), 215 | 'epoch': epoch, 216 | 'args': args} 217 | utils.save_on_master( 218 | checkpoint, 219 | os.path.join(args.output_dir, 'checkpoint.pth')) 220 | print('saving models after {} epochs'.format(epoch)) 221 | 222 | total_time = time.time() - start_time 223 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 224 | print('Training time {}'.format(total_time_str)) 225 | 226 | def parse_args(): 227 | import argparse 228 | parser = argparse.ArgumentParser(description='PyTorch Classification Training') 229 | 230 | parser.add_argument('--data-path', 231 | default='/fastdata/imagenet/', 232 | help='dataset') 233 | parser.add_argument('--arch', 234 | default='mobilenet_v2', 235 | help='model') 236 | parser.add_argument('--data-type', 237 | default='bf8', 238 | help='supported types : e5m2, e4m3, bf16') 239 | parser.add_argument('--patch-ops', default='None', 240 | help='Patch Ops to enable custom gemm implementation') 241 | parser.add_argument('--pruning-algo', 242 | default="None", 243 | help='Pruning method: fine-grained, unstructured, adaptive') 244 | parser.add_argument('--device', 245 | default='cuda', 246 | help='device') 247 | 248 | parser.add_argument('-b', '--batch-size', default=256, type=int, 249 | help='batch size for calibration/training') 250 | parser.add_argument('--eval-batch-size', default=256, type=int, 251 | help='batch size for evaluation') 252 | parser.add_argument('--epochs', default=5, type=int, metavar='N', 253 | help='number of total epochs to run') 254 | 255 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 256 | help='number of data loading workers (default: 16)') 257 | parser.add_argument('--lr', 258 | default=0.1, type=float, 259 | help='initial learning rate') 260 | parser.add_argument('--momentum', 261 | default=0.9, type=float, metavar='M', 262 | help='momentum') 263 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 264 | metavar='W', help='weight decay (default: 1e-4)', 265 | dest='weight_decay') 266 | parser.add_argument('--lr-step-size', default=30, type=int, 267 | help='decrease lr every step-size epochs') 268 | parser.add_argument('--lr-gamma', default=0.1, type=float, 269 | help='decrease lr by a factor of lr-gamma') 270 | parser.add_argument('--print-freq', default=100, type=int, 271 | help='print frequency') 272 | parser.add_argument('--output-dir', default='.', help='path where to save') 273 | parser.add_argument('--resume', default='', help='resume from 
checkpoint')
274 |     parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
275 |                         help='start epoch')
276 |     parser.add_argument("--cache-dataset", dest="cache_dataset",\
277 |         help="Cache the datasets for quicker initialization. \
278 |               It also serializes the transforms",
279 |         action="store_true",
280 |     )
281 |
282 |     args = parser.parse_args()
283 |     return args
284 |
285 |
286 | if __name__ == "__main__":
287 |     args = parse_args()
288 |     print(args)
289 |     main(args)
290 |
-------------------------------------------------------------------------------- /examples/inference/classifier/imagenet_test.py: --------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import os
11 | import time
12 | import argparse
13 |
14 | import torch
15 | import torchvision
16 | import torchvision.transforms as transforms
17 | import torchvision.datasets as datasets
18 |
19 | from train import evaluate
20 | from mpemu import mpt_emu
21 |
22 |
23 |
24 | def get_model_exempt_layers(args, model):
25 |
26 |     list_exempt_layers = []
27 |     list_layers_output_fused = []
28 |     if args.model == "alexnet":
29 |         list_exempt_layers = ["features.0", "classifier.6"]
30 |     elif args.model == "vgg16_bn":
31 |         list_exempt_layers = ["features.0", "features.1", "classifier.6"]
32 |     elif args.model == "inception_v3":
33 |         list_exempt_layers = ["Conv2d_1a_3x3.conv", "fc"]
34 |         #list_exempt_layers = ["fc"]
35 |     elif args.model == "squeezenet1_1":
36 |         list_exempt_layers = ["features.0", "classifier.1"]
37 |         #list_exempt_layers = ["classifier.1"]
38 |     elif "resnet" in args.model or "resnext" in args.model:
39 |         list_exempt_layers = ["conv1","bn1","fc"]
40 |         #list_exempt_layers = ["fc"]
41 |     elif "densenet" in args.model:
42 |         list_exempt_layers = ["features.conv0","features.norm0", "classifier"]
43 |         #list_exempt_layers = ["classifier"]
44 |     elif args.model == "mobilenet_v3_small" or args.model == "mobilenet_v3_large":
45 |         # checked before the generic "mobilenet" match below, which would otherwise shadow it
46 |         list_exempt_layers = ["features.0.0","features.0.1","classifier.0","classifier.3"]
47 |         #list_exempt_layers = ["classifier.0","classifier.3"]
48 |     elif "mobilenet" in args.model or "efficientnet" in args.model:
49 |         # Exempting Features[0] and classifier
50 |         list_exempt_layers = ["features.0.0","features.0.1","classifier.1"]
51 |         #list_exempt_layers = ["classifier.1"]
52 |     elif args.model == "wide_resnet50_2":
53 |         list_exempt_layers = ["features.0.0","features.0.1","classifier.3"]
54 |         #list_exempt_layers = ["classifier.3"]
55 |     elif args.model == "mnasnet1_0":
56 |         list_exempt_layers = ["layers.0","layers.1","classifier.1"]
57 |     elif args.model == "shufflenet_v2_x1_0":
58 |         list_exempt_layers = ["conv1.0","conv1.1","fc"]
59 |         #list_exempt_layers = ["fc"]
60 |     prev_name = None
61 |     prev_module = None
62 |     for name,module in model.named_modules():
63 |         if type(module) == torch.nn.BatchNorm2d and type(prev_module) == torch.nn.Conv2d:
64 |             list_layers_output_fused.append(prev_name)
65 |         if type(module) == torch.nn.Linear:
66 |             list_layers_output_fused.append(name)
67
| 68 | prev_module = module 69 | prev_name = name 70 | 71 | return list_exempt_layers, list_layers_output_fused 72 | 73 | def get_data_loaders(args): 74 | ### Data loader construction 75 | traindir = os.path.join(args.data_path, 'train') 76 | valdir = os.path.join(args.data_path, 'val') 77 | 78 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 79 | std=[0.229, 0.224, 0.225]) 80 | 81 | dataset = datasets.ImageFolder( 82 | traindir, 83 | transforms.Compose([ 84 | transforms.RandomResizedCrop(224), 85 | transforms.RandomHorizontalFlip(), 86 | transforms.ToTensor(), 87 | normalize, 88 | ])) 89 | 90 | train_loader = torch.utils.data.DataLoader( 91 | dataset, batch_size=args.batch_size, shuffle=True, 92 | num_workers=args.workers, pin_memory=True) 93 | 94 | test_loader = torch.utils.data.DataLoader( 95 | datasets.ImageFolder(valdir, transforms.Compose([ 96 | transforms.Resize(256), 97 | transforms.CenterCrop(224), 98 | transforms.ToTensor(), 99 | normalize, 100 | ])), 101 | batch_size=args.eval_batch_size, shuffle=False, 102 | num_workers=args.workers, pin_memory=True) 103 | 104 | return train_loader, test_loader 105 | 106 | def print_bn_running_stats(model): 107 | rmean = None 108 | rvar = None 109 | for name, module in model.named_children(): 110 | if isinstance(module, torch.nn.BatchNorm2d): 111 | if name == "features.0.1": #"bn1": 112 | print("--> layer: {}, tracking running stats : {}".format( name, module.track_running_stats)) 113 | #break 114 | rmean = module.running_mean 115 | rvar = module.running_var 116 | #print(module.running_mean) 117 | #print(module.running_var) 118 | #return module.running_mean, module.running_var #rmean, rvar 119 | return rmean, rvar 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument("--data-path", default="/fastdata/imagenet/", help="Path to imagenet dataset") 124 | parser.add_argument("--model", default="resnet50", 125 | choices=["alexnet", "vgg16_bn", "resnet18", "resnet50","resnext50_32x4d","densenet121",\ 126 | "densenet201","mobilenet_v2","shufflenet_v2_x1_0","mobilenet_v3_large","mobilenet_v3_small",\ 127 | "wide_resnet50_2","resnext101_32x8d","mnasnet1_0","efficientnet_b7","regnet_x_32gf",\ 128 | "inception_v3","squeezenet1_1","efficientnet_b0"], 129 | help="Name of the neural network architecture") 130 | parser.add_argument('--data-type', default='e5m2', 131 | help='supported types : e5m2, e4m3, bf16' ) 132 | parser.add_argument('--patch-ops', default='None', 133 | help='Patch Ops to enable custom gemm implementation') 134 | parser.add_argument('--device',default='cuda', help='device') 135 | parser.add_argument('--verbose', action="store_true", default=False, help='show debug messages') 136 | 137 | # Post training quantization(PTQ) related 138 | parser.add_argument('--recalibrate-bn', action="store_true", default=False, 139 | help='Perform batchnorm recalibration') 140 | parser.add_argument("--num-calibration-batches", type=int, default=500, 141 | help="Number of batches for BatchNorm calibration") 142 | parser.add_argument("--batch-size", type=int, default=256, help="Batch size for train/calibrate") 143 | parser.add_argument('--eval-batch-size', default=32, type=int, 144 | help='batch size for evaluation') 145 | parser.add_argument("--workers", type=int, default=16) 146 | 147 | # Logging and storing model 148 | parser.add_argument("--output-dir", default=".", help="Output directory to dump model") 149 | 150 | args = parser.parse_args() 151 | print(args) 152 | model_dict = { 153 | 
"alexnet" : "AlexNet_Weights.DEFAULT", 154 | "vgg16_bn": "VGG16_BN_Weights.DEFAULT", 155 | "resnet18" : "ResNet18_Weights.DEFAULT", 156 | "resnet50" : "ResNet50_Weights.DEFAULT", 157 | "resnext50_32x4d" : "ResNeXt50_32X4D_Weights.DEFAULT", 158 | "resnext101_32x8d" : "ResNeXt101_32X8D_Weights.DEFAULT", 159 | "wide_resnet50_2" : "Wide_ResNet50_2_Weights.DEFAULT", 160 | "densenet121" : "DenseNet121_Weights.DEFAULT", 161 | "densenet201" : "DenseNet201_Weights.DEFAULT", 162 | "mobilenet_v2" : "MobileNet_V2_Weights.DEFAULT", 163 | "shufflenet_v2_x1_0" : "ShuffleNet_V2_Weights.DEFAULT", 164 | "mobilenet_v3_large" : "MobileNet_V3_Large_Weights.DEFAULT", 165 | "mobilenet_v3_small" : "MobileNet_V3_Small_Weights.DEFAULT", 166 | "mnasnet1_0" : "MNASNet1_0_Weights.DEFAULT", 167 | "efficientnet_b7" : "EfficientNet_B7_Weights.DEFAULT", 168 | "regnet_x_32gf" : "RegNet_X_32GF_Weights.DEFAULT", 169 | "inception_v3" : "Inception_V3_Weights.DEFAULT", 170 | "squeezenet1_1" : "SqueezeNet1_1_Weights.DEFAULT", 171 | "efficientnet_b0" : "EfficientNet_B0_Weights.DEFAULT", 172 | } 173 | # Creating output directory 174 | if not os.path.exists(args.output_dir): 175 | os.makedirs(args.output_dir) 176 | 177 | # Create the model and move to GPU 178 | device = torch.device(args.device) 179 | # Data loaders and loss function 180 | train_loader,test_loader = get_data_loaders(args) 181 | criterion = torch.nn.CrossEntropyLoss() 182 | # Model 183 | model = torchvision.models.__dict__[args.model](pretrained=True) 184 | #model = torchvision.models.__dict__[args.model](weights=model_dict[args.model]) 185 | model = model.to(device) 186 | model.eval() 187 | #print(model) 188 | 189 | print("Evaluating original {} model to establish baseline".format(args.model)) 190 | evaluate(model, criterion, test_loader, device) 191 | 192 | # Create a list of exempt_layers 193 | list_exempt_layers, list_layers_output_fused = get_model_exempt_layers(args, model) 194 | print("Preparing the {} model for {} quantization".format(args.model, args.data_type.lower())) 195 | print("List of exempt layers : ", list_exempt_layers) 196 | #print(list_layers_output_fused) 197 | model, emulator = mpt_emu.quantize_model (model, dtype=args.data_type.lower(), calibrate=args.recalibrate_bn, hw_patch=args.patch_ops, 198 | list_exempt_layers=list_exempt_layers, list_layers_output_fused=list_layers_output_fused, 199 | device=device, verbose=args.verbose) 200 | 201 | if args.recalibrate_bn == True: 202 | print("Calibrating the {} model for {} batches of training data".format(args.model, args.num_calibration_batches)) 203 | model.train() 204 | evaluate(model, criterion, train_loader, device, 205 | num_batches=args.num_calibration_batches, train=True) 206 | 207 | model.eval() 208 | print("Fusing BatchNorm layers") 209 | model = emulator.fuse_bnlayers_and_quantize_model(model) 210 | print("Evaluating {} fused {} model. ".format(args.model, args.data_type.lower())) 211 | evaluate(model, criterion, test_loader, device) 212 | -------------------------------------------------------------------------------- /examples/inference/classifier/launch.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import subprocess
11 | import itertools
12 | import os
13 |
14 | model_choices = []
15 | model_choices += ["resnet50"]
16 | model_choices += ["wide_resnet50_2"]
17 | model_choices += ["resnext50_32x4d"]
18 | model_choices += ["resnext101_32x8d"]
19 | model_choices += ["densenet121"]
20 | model_choices += ["densenet201"]
21 | model_choices += ["mobilenet_v2"]
22 | model_choices += ["mobilenet_v3_small"]
23 | model_choices += ["mobilenet_v3_large"]
24 | model_choices += ["inception_v3"]
25 | model_choices += ["squeezenet1_1"]
26 | model_choices += ["efficientnet_b0"]
27 |
28 | device_choices = ["cuda:0"]
29 | batch_size = 64
30 | data_type = "e4m3"  # Supported types : e5m2, e4m3, e3m4, hybrid
31 | recalibrate_bn = True  # set to False to skip BatchNorm re-calibration
32 | num_calibration_batches = 100
33 | finetune_lr = 0.0002
34 | finetune_epochs = 2
35 | cmodel = "none"
36 | BASE_GPU_ID = 2
37 | NUM_GPUS = 1
38 | print_results = False
39 | verbose = True
40 | data_dir = "/fastdata/imagenet/"
41 | cur_sbps = []
42 |
43 | for exp_id, exp_config in enumerate(itertools.product(model_choices, device_choices)):
44 |     model, device = exp_config
45 |
46 |     # Creating output directory
47 |     output_dir = "calib_experiments/{}-{}".format(model, device)
48 |     dump_fp = os.path.join(output_dir,"log.txt")
49 |     if not os.path.exists(output_dir):
50 |         os.makedirs(output_dir)
51 |
52 |     gpu_id = BASE_GPU_ID + (exp_id % NUM_GPUS)
53 |     cmd = ""
54 |     if "cuda" in device:
55 |         cmd = "CUDA_VISIBLE_DEVICES={} ".format(gpu_id)
56 |     cmd += "python imagenet_test.py --model {} --batch-size {} --data-type {} --device {} ".format(
57 |             model, batch_size, data_type, device)
58 |     if cmodel != "none" and device == "cpu":
59 |         cmd += "--patch-ops {} ".format(cmodel)
60 |     if recalibrate_bn:
61 |         cmd += "--recalibrate-bn --num-calibration-batches {} ".format(num_calibration_batches)
62 |     if verbose:
63 |         cmd += "--verbose "
64 |     cmd += "--data-path {} --output-dir {} 2>&1 | tee {}".format(data_dir, output_dir, dump_fp)
65 |
66 |     print(cmd)
67 |
68 |     if print_results:
69 |         fh = open(dump_fp)
70 |         accuracy = float(fh.readlines()[-1].strip().split()[2].strip())
71 |         print("{:.2f}, ".format(accuracy), end="")
72 |         fh.close()
73 |         continue
74 |
75 |     p = subprocess.Popen(cmd, shell=True)
76 |     cur_sbps.append(p)
77 |
78 |     if exp_id%NUM_GPUS == NUM_GPUS-1:
79 |         exit_codes = [p.wait() for p in cur_sbps]
80 |         cur_sbps = [] # Emptying the process list
81 |
-------------------------------------------------------------------------------- /examples/inference/classifier/test_scaleshift.py: --------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import os
11 | import torch
12 | import torchvision
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 |
16 |
17 | def get_test_loader(data_path, batch_size=256):
18 |     val_dir = os.path.join(data_path, 'val')
19 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 |                                      std=[0.229, 0.224, 0.225])
21 |
22 |     test_loader = torch.utils.data.DataLoader(
23 |         datasets.ImageFolder(val_dir, transforms.Compose([
24 |             transforms.Resize(256),
25 |             transforms.CenterCrop(224),
26 |             transforms.ToTensor(),
27 |             normalize,
28 |         ])),
29 |         batch_size=batch_size, shuffle=False,
30 |         num_workers=4, pin_memory=True)
31 |
32 |     return test_loader
33 |
34 | arch = "resnet50"
35 | model = torchvision.models.__dict__[arch](pretrained=True)
36 | device = torch.device("cuda")
37 | model.to(device)
38 |
39 | test_loader = get_test_loader("/fastdata/imagenet")
40 | criterion = torch.nn.CrossEntropyLoss()
41 |
42 | from train import evaluate
43 | def test():
44 |     evaluate(model, criterion, test_loader, device, num_batches=2, print_freq=1)
45 |
46 | test()
47 |
48 | from mpemu import scale_shift
49 | model = scale_shift.replace_batchnorms_with_scaleshifts(model) # Replacing BN with scaleshift layers
50 | model.to(device)
51 | test()
52 |
53 |
54 |
-------------------------------------------------------------------------------- /examples/inference/classifier/train.py: --------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import datetime 11 | import os 12 | import time 13 | 14 | import torch 15 | import torch.utils.data 16 | from torch import nn 17 | import torchvision 18 | from torchvision import transforms 19 | 20 | import utils 21 | 22 | try: 23 | from apex import amp 24 | except ImportError: 25 | amp = None 26 | 27 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, 28 | apex=False, num_batches=None): 29 | model.train() 30 | metric_logger = utils.MetricLogger(delimiter=" ") 31 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) 32 | metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) 33 | 34 | header = 'Epoch: [{}]'.format(epoch) 35 | i = 0 36 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 37 | i += 1 38 | start_time = time.time() 39 | image, target = image.to(device), target.to(device) 40 | output = model(image) 41 | loss = criterion(output, target) 42 | 43 | optimizer.zero_grad() 44 | if apex: 45 | with amp.scale_loss(loss, optimizer) as scaled_loss: 46 | scaled_loss.backward() 47 | else: 48 | loss.backward() 49 | optimizer.step() 50 | 51 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 52 | batch_size = image.shape[0] 53 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) 54 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 55 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 56 | metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) 57 | 58 | if num_batches is not None and i%num_batches == 0: 59 | break 60 | 61 | 62 | def evaluate(model, criterion, data_loader, device, num_batches=None, print_freq=100, train=False): 63 | if not train: 64 | model.eval() 65 | header = 'Test:' 66 | else: 67 | model.train() 68 | header = 'Train:' 69 | 70 | metric_logger = utils.MetricLogger(delimiter=" ") 71 | 72 | with torch.no_grad(): 73 | batch_id = 0 74 | for image, target in metric_logger.log_every(data_loader, print_freq, header): 75 | image = image.to(device, non_blocking=True) 76 | target = target.to(device, non_blocking=True) 77 | output = model(image) 78 | loss = criterion(output, target) 79 | 80 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 81 | # FIXME need to take into account that the datasets 82 | # could have been padded in distributed setup 83 | batch_size = image.shape[0] 84 | metric_logger.update(loss=loss.item()) 85 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 86 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 87 | 88 | if num_batches is not None and batch_id+1 == num_batches: 89 | break; 90 | else: 91 | batch_id += 1 92 | # gather the stats from all processes 93 | metric_logger.synchronize_between_processes() 94 | 95 | print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' 96 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) 97 | return metric_logger.acc1.global_avg 98 | 99 | def main(args): 100 | if args.apex and amp is None: 101 | raise RuntimeError("Failed to import apex. 
Please install apex from https://www.github.com/nvidia/apex "
102 |                            "to enable mixed-precision training.")
103 |
104 |     if args.output_dir:
105 |         utils.mkdir(args.output_dir)
106 |
107 |     #utils.init_distributed_mode(args)
108 |     args.distributed = False
109 |     print(args)
110 |
111 |     device = torch.device(args.device)
112 |
113 |     torch.backends.cudnn.benchmark = True
114 |
115 |     train_dir = os.path.join(args.data_path, 'train')
116 |     val_dir = os.path.join(args.data_path, 'val')
117 |     crop_size, val_size = 224, 256  # assumed ImageNet defaults
118 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
119 |     dataset = torchvision.datasets.ImageFolder(
120 |         train_dir,
121 |         transforms.Compose([
122 |             transforms.RandomResizedCrop(crop_size),
123 |             transforms.RandomHorizontalFlip(),
124 |             transforms.ToTensor(), normalize, ]))
125 |     dataset_test = torchvision.datasets.ImageFolder(val_dir, transforms.Compose([
126 |             transforms.Resize(val_size),
127 |             transforms.CenterCrop(crop_size),
128 |             transforms.ToTensor(), normalize, ]))
129 |     if args.distributed:
130 |         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
131 |         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
132 |     else:
133 |         train_sampler = torch.utils.data.RandomSampler(dataset)
134 |         test_sampler = torch.utils.data.SequentialSampler(dataset_test)
135 |
136 |     data_loader = torch.utils.data.DataLoader(
137 |         dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
138 |         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
139 |
140 |     data_loader_test = torch.utils.data.DataLoader(
141 |         dataset_test, batch_size=args.batch_size, shuffle=False,
142 |         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
143 |
144 |     print("Creating model")
145 |     model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
146 |     model.to(device)
147 |     if args.distributed and args.sync_bn:
148 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
149 |
150 |     criterion = nn.CrossEntropyLoss()
151 |
152 |     optimizer = torch.optim.SGD(
153 |         model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
154 |
155 |     if args.apex:
156 |         model, optimizer = amp.initialize(model, optimizer,
157 |                                           opt_level=args.apex_opt_level
158 |                                           )
159 |
160 |     #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
161 |     lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8,12], gamma=args.lr_gamma)
162 |
163 |     model_without_ddp = model
164 |     if args.distributed:
165 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
166 |         model_without_ddp = model.module
167 |
168 |     if args.resume:
169 |         checkpoint = torch.load(args.resume, map_location='cpu')
170 |         model_without_ddp.load_state_dict(checkpoint['model'])
171 |         optimizer.load_state_dict(checkpoint['optimizer'])
172 |         lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
173 |         args.start_epoch = checkpoint['epoch'] + 1
174 |
175 |     if args.test_only:
176 |         evaluate(model, criterion, data_loader_test, device=device)
177 |         return
178 |
179 |     ################### Quantization and Sparsity ##########################
180 |     import collections
181 |     filter_module_types = [torch.nn.Conv2d, torch.nn.Linear] # Only quantizing convolution and linear modules
182 |     exempt_modules = ["conv1","fc"]
183 |     is_training = True
184 |
185 |     ## QUANTIZATION ##
186 |     model_qconfig_dict = collections.OrderedDict()
187 |     if args.do_quantization:
188 |         from mpemu import qutils
189 |         ### Preparing the quantization configuration dictionary ###
190 |         mod_qconfig = qutils.get_module_quant_config(args.qlevel, args.qdtype,
args.qscheme, is_training=is_training) 191 | model_qconfig_dict = qutils.get_or_update_model_quant_config_dict(model, filter_module_types, mod_qconfig, 192 | exempt_modules=exempt_modules) 193 | print("Model quantization configuration") 194 | for layer,qconfig in model_qconfig_dict.items(): 195 | print(layer, qconfig) 196 | print() 197 | 198 | # Setting up the model for QAT 199 | qutils.reset_quantization_setup(model, model_qconfig_dict) 200 | qhooks = qutils.add_quantization_hooks(model, model_qconfig_dict, is_training=is_training) 201 | 202 | print("Start training") 203 | start_time = time.time() 204 | for epoch in range(args.start_epoch, args.epochs): 205 | if args.distributed: 206 | train_sampler.set_epoch(epoch) 207 | 208 | train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, 209 | args.print_freq, apex=args.apex, num_batches=None) 210 | 211 | lr_scheduler.step() 212 | evaluate(model, criterion, data_loader_test, device=device) 213 | if args.output_dir: 214 | checkpoint = { 215 | 'model': model_without_ddp.state_dict(), 216 | 'optimizer': optimizer.state_dict(), 217 | 'lr_scheduler': lr_scheduler.state_dict(), 218 | 'epoch': epoch, 219 | 'args': args} 220 | utils.save_on_master( 221 | checkpoint, 222 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) 223 | utils.save_on_master( 224 | checkpoint, 225 | os.path.join(args.output_dir, 'checkpoint.pth')) 226 | 227 | total_time = time.time() - start_time 228 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 229 | print('Training time {}'.format(total_time_str)) 230 | 231 | 232 | def parse_args(): 233 | import argparse 234 | parser = argparse.ArgumentParser(description='PyTorch Classification Training') 235 | 236 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') 237 | parser.add_argument('--model', default='resnet18', help='model') 238 | parser.add_argument('--device', default='cuda', help='device') 239 | parser.add_argument('-b', '--batch-size', default=256, type=int) 240 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 241 | help='number of total epochs to run') 242 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', 243 | help='number of data loading workers (default: 16)') 244 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') 245 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 246 | help='momentum') 247 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 248 | metavar='W', help='weight decay (default: 1e-4)', 249 | dest='weight_decay') 250 | parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') 251 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') 252 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 253 | parser.add_argument('--output-dir', default='.', help='path where to save') 254 | parser.add_argument('--resume', default='', help='resume from checkpoint') 255 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 256 | help='start epoch') 257 | parser.add_argument( 258 | "--cache-dataset", 259 | dest="cache_dataset", 260 | help="Cache the datasets for quicker initialization. 
It also serializes the transforms", 261 | action="store_true", 262 | ) 263 | parser.add_argument( 264 | "--sync-bn", 265 | dest="sync_bn", 266 | help="Use sync batch norm", 267 | action="store_true", 268 | ) 269 | parser.add_argument( 270 | "--test-only", 271 | dest="test_only", 272 | help="Only test the model", 273 | action="store_true", 274 | ) 275 | parser.add_argument( 276 | "--pretrained", 277 | dest="pretrained", 278 | help="Use pre-trained models from the modelzoo", 279 | action="store_true", 280 | ) 281 | 282 | # Mixed precision training parameters 283 | parser.add_argument('--apex', action='store_true', 284 | help='Use apex for mixed precision training') 285 | parser.add_argument('--apex-opt-level', default='O1', type=str, 286 | help='For apex mixed precision training' 287 | 'O0 for FP32 training, O1 for mixed precision training.' 288 | 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' 289 | ) 290 | 291 | # distributed training parameters 292 | parser.add_argument('--world-size', default=1, type=int, 293 | help='number of distributed processes') 294 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 295 | 296 | 297 | # Pruning arguments 298 | # Quantization arguments 299 | parser.add_argument("--do-quantization", action="store_true", help="Whether to do quantization or not") 300 | parser.add_argument("--qdtype", default="bfloat8", help="Quantization data type") 301 | parser.add_argument("--qscheme", default="rne", help="Quantization scheme") 302 | parser.add_argument("--qlevel", type=int, default=3, choices=[0,1,3,6,7], 303 | help="0->No quantization, 1->wt only, 3->(wt+iact), 6->(iact+oact) 7->(wt+iact+oact)") 304 | parser.add_argument("--patch_ops", action="store_true") 305 | 306 | # Pruning arguments 307 | parser.add_argument("--do-pruning", action="store_true", help="Whether or not to do pruning") 308 | parser.add_argument("--sbh", type=int, default=-1, help="Scope block height") 309 | parser.add_argument("--sbw", type=int, default=-1, help="Scope block width") 310 | parser.add_argument("--sparsity", type=float, default=0.5, help="Amount of sparsity to impose [0,1]") 311 | 312 | args = parser.parse_args() 313 | 314 | return args 315 | 316 | 317 | if __name__ == "__main__": 318 | args = parse_args() 319 | main(args) 320 | -------------------------------------------------------------------------------- /examples/inference/classifier/utils.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | from collections import defaultdict, deque 11 | import datetime 12 | import time 13 | import torch 14 | import torch.distributed as dist 15 | 16 | import errno 17 | import os 18 | 19 | 20 | class SmoothedValue(object): 21 | """Track a series of values and provide access to smoothed values over a 22 | window or the global series average. 
23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | def synchronize_between_processes(self): 39 | """ 40 | Warning: does not synchronize the deque! 41 | """ 42 | if not is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format( 100 | type(self).__name__, attr)) 101 | 102 | def __str__(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {}".format(name, str(meter)) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | 117 | def log_every(self, iterable, print_freq, header=None): 118 | i = 0 119 | if not header: 120 | header = '' 121 | start_time = time.time() 122 | end = time.time() 123 | iter_time = SmoothedValue(fmt='{avg:.4f}') 124 | data_time = SmoothedValue(fmt='{avg:.4f}') 125 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 126 | if torch.cuda.is_available(): 127 | log_msg = self.delimiter.join([ 128 | header, 129 | '[{0' + space_fmt + '}/{1}]', 130 | 'eta: {eta}', 131 | '{meters}', 132 | 'time: {time}', 133 | 'data: {data}', 134 | 'max mem: {memory:.0f}' 135 | ]) 136 | else: 137 | log_msg = self.delimiter.join([ 138 | header, 139 | '[{0' + space_fmt + '}/{1}]', 140 | 'eta: {eta}', 141 | '{meters}', 142 | 'time: {time}', 143 | 'data: {data}' 144 | ]) 145 | MB = 1024.0 * 1024.0 146 | for obj in iterable: 147 | data_time.update(time.time() - end) 148 | yield obj 149 | iter_time.update(time.time() - end) 150 | if i % print_freq == 0: 151 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 152 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 153 | if torch.cuda.is_available(): 154 | 
print(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time), 158 | memory=torch.cuda.max_memory_allocated() / MB)) 159 | else: 160 | print(log_msg.format( 161 | i, len(iterable), eta=eta_string, 162 | meters=str(self), 163 | time=str(iter_time), data=str(data_time))) 164 | i += 1 165 | end = time.time() 166 | total_time = time.time() - start_time 167 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 168 | print('{} Total time: {}'.format(header, total_time_str)) 169 | 170 | 171 | def accuracy(output, target, topk=(1,)): 172 | """Computes the accuracy over the k top predictions for the specified values of k""" 173 | with torch.no_grad(): 174 | maxk = max(topk) 175 | batch_size = target.size(0) 176 | 177 | _, pred = output.topk(maxk, 1, True, True) 178 | pred = pred.t() 179 | correct = pred.eq(target[None]) 180 | 181 | res = [] 182 | for k in topk: 183 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 184 | res.append(correct_k * (100.0 / batch_size)) 185 | return res 186 | 187 | 188 | def mkdir(path): 189 | try: 190 | os.makedirs(path) 191 | except OSError as e: 192 | if e.errno != errno.EEXIST: 193 | raise 194 | 195 | 196 | def setup_for_distributed(is_master): 197 | """ 198 | This function disables printing when not in master process 199 | """ 200 | import builtins as __builtin__ 201 | builtin_print = __builtin__.print 202 | 203 | def print(*args, **kwargs): 204 | force = kwargs.pop('force', False) 205 | if is_master or force: 206 | builtin_print(*args, **kwargs) 207 | 208 | __builtin__.print = print 209 | 210 | 211 | def is_dist_avail_and_initialized(): 212 | if not dist.is_available(): 213 | return False 214 | if not dist.is_initialized(): 215 | return False 216 | return True 217 | 218 | 219 | def get_world_size(): 220 | if not is_dist_avail_and_initialized(): 221 | return 1 222 | return dist.get_world_size() 223 | 224 | 225 | def get_rank(): 226 | if not is_dist_avail_and_initialized(): 227 | return 0 228 | return dist.get_rank() 229 | 230 | 231 | def is_main_process(): 232 | return get_rank() == 0 233 | 234 | 235 | def save_on_master(*args, **kwargs): 236 | if is_main_process(): 237 | torch.save(*args, **kwargs) 238 | 239 | 240 | def init_distributed_mode(args): 241 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 242 | args.rank = int(os.environ["RANK"]) 243 | args.world_size = int(os.environ['WORLD_SIZE']) 244 | args.gpu = int(os.environ['LOCAL_RANK']) 245 | elif 'SLURM_PROCID' in os.environ: 246 | args.rank = int(os.environ['SLURM_PROCID']) 247 | args.gpu = args.rank % torch.cuda.device_count() 248 | elif hasattr(args, "rank"): 249 | pass 250 | else: 251 | print('Not using distributed mode') 252 | args.distributed = False 253 | return 254 | 255 | args.distributed = True 256 | 257 | torch.cuda.set_device(args.gpu) 258 | args.dist_backend = 'nccl' 259 | print('| distributed init (rank {}): {}'.format( 260 | args.rank, args.dist_url), flush=True) 261 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 262 | world_size=args.world_size, rank=args.rank) 263 | setup_for_distributed(args.rank == 0) 264 | -------------------------------------------------------------------------------- /examples/training/README.md: -------------------------------------------------------------------------------- 1 | # FP8 Mixed Precision Training 2 | Follow the examples listed here to train models using FP8 emulation toolkit. 
3 | The toolkit supports two different training methods:
4 | 
5 | ## Direct Conversion Method
6 | This method uses a single FP8 format (`E5M2`) for both forward and backward computations. Operators (e.g. Convolution, Linear) perform dot-product computations on input tensors expressed in `E5M2` format and accumulate into FP32 output tensors. The output tensors are converted directly to `E5M2` before they are written to memory; gradients are scaled using an automatic loss-scaling method. Full details of the training algorithm are covered in the paper [Mixed Precision Training With 8-bit Floating Point](https://arxiv.org/pdf/1905.12334.pdf).
7 | 
8 | ## Scaled-Hybrid Method
9 | This method uses a hybrid FP8 approach: `E4M3` for forward computations and `E5M2` for representing error gradients. The weight and activation tensors use `per-tensor scaling` in the forward pass to compensate for the limited dynamic range of the `E4M3` format. The weight and activation scaling factors are computed at run time every iteration and then used to quantize the FP32 output tensor to `E4M3` format. The gradients are scaled using standard automatic loss-scaling methods. This method is based on the algorithm discussed in the paper [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
10 | 
11 | ## Loss Scaling
12 | A modified NVIDIA [Apex](https://github.com/NVIDIA/apex) library is used to extend loss-scaling capabilities to CPU platforms.
13 | Follow the instructions below to patch the Apex library to enable CPU support.
14 | 
15 | ```
16 | $ git clone https://github.com/NVIDIA/apex
17 | $ cd apex
18 | $ git checkout 05091d4
19 | $ git apply <path_to_fp8_emulation_toolkit>/examples/training/nvidia_apex_cpu_patch_05091d4.patch
20 | ```
21 | Install the Python-only build.
22 | ```
23 | $ pip3 install -v --no-cache-dir ./
24 | ```
25 | 
26 | ## Usage Example
27 | 
28 | Modify the training script as follows to enable FP8 emulation.
29 | 
30 | ```
31 | # import the emulator
32 | from mpemu import mpt_emu
33 | 
34 | ...
35 | 
36 | # layers exempt from FP8 conversion
37 | list_exempt_layers = ["conv1","bn1","fc"]
38 | # fused layers are exempt from converting their output tensor to FP8; the following layer reads from the FP32 buffer
39 | list_layers_output_fused = None
40 | # use the 'direct' training method; options: 'direct', 'hybrid'
41 | model, emulator = mpt_emu.initialize(model, optimizer, training_algo="direct",
42 |                      list_exempt_layers=list_exempt_layers, list_layers_output_fused=list_layers_output_fused,
43 |                      device="cpu", verbose=True)
44 | 
45 | ...
46 | 
47 | # training loop
48 | for epoch in range(args.start_epoch, args.epochs):
49 |     ...
50 | 
51 |     emulator.update_global_steps(epoch*len(train_loader))
52 | 
53 |     ...
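    # note: emulator.optimizer_step() below is assumed to perform the
    # loss-scale bookkeeping and the master-weight update, standing in
    # for the usual optimizer.step() call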
54 | 55 | emulator.optimizer_step(optimizer) 56 | 57 | ``` 58 | -------------------------------------------------------------------------------- /examples/training/bert/launch_fp8_training.sh: -------------------------------------------------------------------------------- 1 | 2 | export MIXED_PRECISION=fp16 3 | 4 | TRAINING_ALGO=${1:-'direct'} 5 | 6 | python run_qa_no_trainer.py --model_name_or_path bert-large-uncased --dataset_name squad --per_device_train_batch_size 12 --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --fp8_training --fp8_algo=$TRAINING_ALGO --output_dir ./bert-large-uncased-fp8 |& tee bert-large-uncased-fp8.log 7 | -------------------------------------------------------------------------------- /examples/training/bert/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers >= 4.11.0 2 | accelerate 3 | datasets >= 1.8.0 4 | torch >= 1.3.0 5 | evaluate 6 | -------------------------------------------------------------------------------- /examples/training/nvidia_apex_cpu_patch_05091d4.patch: -------------------------------------------------------------------------------- 1 | diff --git a/apex/amp/_initialize.py b/apex/amp/_initialize.py 2 | index 3ae6fde..04be0ae 100644 3 | --- a/apex/amp/_initialize.py 4 | +++ b/apex/amp/_initialize.py 5 | @@ -87,12 +87,14 @@ def check_params_fp32(models): 6 | "When using amp.initialize, you do not need to call .half() on your model\n" 7 | "before passing it, no matter what optimization level you choose.".format( 8 | name, param.type())) 9 | + ''' 10 | elif not param.is_cuda: 11 | warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" 12 | "When using amp.initialize, you need to provide a model with parameters\n" 13 | "located on a CUDA device before passing it no matter what optimization level\n" 14 | "you chose. Use model.to('cuda') to use the default device.".format( 15 | name, param.type())) 16 | + ''' 17 | 18 | # Backward compatibility for PyTorch 0.4 19 | if hasattr(model, 'named_buffers'): 20 | @@ -110,12 +112,14 @@ def check_params_fp32(models): 21 | "When using amp.initialize, you do not need to call .half() on your model\n" 22 | "before passing it, no matter what optimization level you choose.".format( 23 | name, buf.type())) 24 | + ''' 25 | elif not buf.is_cuda: 26 | warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n" 27 | "When using amp.initialize, you need to provide a model with buffers\n" 28 | "located on a CUDA device before passing it no matter what optimization level\n" 29 | "you chose. Use model.to('cuda') to use the default device.".format( 30 | name, buf.type())) 31 | + ''' 32 | 33 | 34 | def check_optimizers(optimizers): 35 | diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py 36 | index 471289b..3d72008 100644 37 | --- a/apex/amp/_process_optimizer.py 38 | +++ b/apex/amp/_process_optimizer.py 39 | @@ -54,6 +54,9 @@ def lazy_init_with_master_weights(self): 40 | # .format(param.size())) 41 | fp32_params_this_group.append(param) 42 | param_group['params'][i] = param 43 | + elif param.type() == 'torch.FloatTensor': 44 | + fp32_params_this_group.append(param) 45 | + param_group['params'][i] = param 46 | else: 47 | raise TypeError("Optimizer's parameters must be either " 48 | "torch.cuda.FloatTensor or torch.cuda.HalfTensor. 
" 49 | @@ -212,6 +215,8 @@ def lazy_init_no_master_weights(self): 50 | stash.all_fp16_params.append(param) 51 | elif param.type() == 'torch.cuda.FloatTensor': 52 | stash.all_fp32_params.append(param) 53 | + elif param.type() == 'torch.FloatTensor': 54 | + stash.all_fp32_params.append(param) 55 | else: 56 | raise TypeError("Optimizer's parameters must be either " 57 | "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " 58 | diff --git a/apex/amp/scaler.py b/apex/amp/scaler.py 59 | index 99888bc..cbf0f22 100644 60 | --- a/apex/amp/scaler.py 61 | +++ b/apex/amp/scaler.py 62 | @@ -6,7 +6,11 @@ from itertools import product 63 | def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False): 64 | # Exception handling for 18.04 compatibility 65 | if check_overflow: 66 | - cpu_sum = float(model_grad.float().sum()) 67 | + # handling sparse gradients 68 | + if (model_grad.is_sparse) : 69 | + cpu_sum = torch.sparse.sum(model_grad) 70 | + else : 71 | + cpu_sum = float(model_grad.float().sum()) 72 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 73 | return True 74 | 75 | @@ -19,7 +23,11 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F 76 | def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False): 77 | # Exception handling for 18.04 compatibility 78 | if check_overflow: 79 | - cpu_sum = float(model_grad.float().sum()) 80 | + # handling sparse gradients 81 | + if (model_grad.is_sparse) : 82 | + cpu_sum = torch.sparse.sum(model_grad) 83 | + else: 84 | + cpu_sum = float(model_grad.float().sum()) 85 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 86 | return True 87 | 88 | @@ -53,7 +61,7 @@ class LossScaler(object): 89 | self._scale_seq_len = scale_window 90 | self._unskipped = 0 91 | self._has_overflow = False 92 | - self._overflow_buf = torch.cuda.IntTensor([0]) 93 | + self._overflow_buf = torch.IntTensor([0]) 94 | if multi_tensor_applier.available: 95 | import amp_C 96 | LossScaler.has_fused_kernel = multi_tensor_applier.available 97 | -------------------------------------------------------------------------------- /examples/training/resnet/README.md: -------------------------------------------------------------------------------- 1 | ``` 2 | python imagenet_main.py --arch resnet50 --enable-bf8 --master-weight-precision='fp16' --resume checkpoint.pth.tar /fastdata/imagenet/ |& tee $LOGFILE 3 | ``` 4 | -------------------------------------------------------------------------------- /examples/training/resnet/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | from math import pi, cos 4 | """Learning Rate Schedulers""" 5 | 6 | class LRScheduler(object): 7 | r"""Learning Rate Scheduler 8 | For mode='step', we multiply lr with `decay_factor` at each epoch in `step`. 9 | For mode='poly':: 10 | lr = targetlr + (baselr - targetlr) * (1 - iter / maxiter) ^ power 11 | For mode='cosine':: 12 | lr = targetlr + (baselr - targetlr) * (1 + cos(pi * iter / maxiter)) / 2 13 | If warmup_epochs > 0, a warmup stage will be inserted before the main lr scheduler. 14 | For warmup_mode='linear':: 15 | lr = warmup_lr + (baselr - warmup_lr) * iter / max_warmup_iter 16 | For warmup_mode='constant':: 17 | lr = warmup_lr 18 | Parameters 19 | ---------- 20 | mode : str 21 | Modes for learning rate scheduler. 22 | Currently it supports 'step', 'poly' and 'cosine'. 
23 | niters : int 24 | Number of iterations in each epoch. 25 | base_lr : float 26 | Base learning rate, i.e. the starting learning rate. 27 | epochs : int 28 | Number of training epochs. 29 | step : list 30 | A list of epochs to decay the learning rate. 31 | decay_factor : float 32 | Learning rate decay factor. 33 | targetlr : float 34 | Target learning rate for poly and cosine, as the ending learning rate. 35 | power : float 36 | Power of poly function. 37 | warmup_epochs : int 38 | Number of epochs for the warmup stage. 39 | warmup_lr : float 40 | The base learning rate for the warmup stage. 41 | warmup_mode : str 42 | Modes for the warmup stage. 43 | Currently it supports 'linear' and 'constant'. 44 | """ 45 | def __init__(self, optimizer, niters, args): 46 | super(LRScheduler, self).__init__() 47 | 48 | self.mode = args.lr_mode 49 | self.warmup_mode = args.warmup_mode if hasattr(args,'warmup_mode') else 'linear' 50 | assert(self.mode in ['step', 'poly', 'cosine']) 51 | assert(self.warmup_mode in ['linear', 'constant']) 52 | 53 | self.optimizer = optimizer 54 | 55 | #self.base_lr = args.base_lr if hasattr(args,'base_lr') else 0.1 56 | self.base_lr = args.lr if hasattr(args,'lr') else 0.1 57 | self.learning_rate = self.base_lr 58 | self.niters = niters 59 | 60 | self.step = [int(i) for i in args.step.split(',')] if hasattr(args,'step') else [30, 60, 90] 61 | self.decay_factor = args.decay_factor if hasattr(args,'decay_factor') else 0.1 62 | self.targetlr = args.targetlr if hasattr(args,'targetlr') else 0.0 63 | self.power = args.power if hasattr(args,'power') else 2.0 64 | self.warmup_lr = args.warmup_lr if hasattr(args,'warmup_lr') else 0.0 65 | self.max_iter = args.epochs * niters 66 | self.warmup_iters = (args.warmup_epochs if hasattr(args,'warmup_epochs') else 0) * niters 67 | 68 | 69 | print(self.base_lr, self.mode, self.warmup_mode, self.warmup_iters) 70 | 71 | def update(self, i, epoch): 72 | T = epoch * self.niters + i 73 | assert (T >= 0 and T <= self.max_iter) 74 | 75 | if self.warmup_iters > T: 76 | # Warm-up Stage 77 | if self.warmup_mode == 'linear': 78 | self.learning_rate = self.warmup_lr + (self.base_lr - self.warmup_lr) * \ 79 | T / self.warmup_iters 80 | elif self.warmup_mode == 'constant': 81 | self.learning_rate = self.warmup_lr 82 | else: 83 | raise NotImplementedError 84 | else: 85 | if self.mode == 'step': 86 | count = sum([1 for s in self.step if s <= epoch]) 87 | self.learning_rate = self.base_lr * pow(self.decay_factor, count) 88 | elif self.mode == 'poly': 89 | self.learning_rate = self.targetlr + (self.base_lr - self.targetlr) * \ 90 | pow(1 - (T - self.warmup_iters) / (self.max_iter - self.warmup_iters), self.power) 91 | elif self.mode == 'cosine': 92 | self.learning_rate = self.targetlr + (self.base_lr - self.targetlr) * \ 93 | (1 + cos(pi * (T - self.warmup_iters) / (self.max_iter - self.warmup_iters))) / 2 94 | else: 95 | raise NotImplementedError 96 | 97 | for i, param_group in enumerate(self.optimizer.param_groups): 98 | param_group['lr'] = self.learning_rate 99 | -------------------------------------------------------------------------------- /examples/training/resnet/train_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES="" 3 | 4 | HOST_ADDR=${1:-$HOSTNAME} 5 | BATCH=${2:-64} 6 | TRAINING_ALGO=${3:-'direct'} 7 | HW_PATCH=${4:-'None'} 8 | MODEL_PREC=${5:-'fp16'} 9 | PRUNING_ALG=${6:-'None'} # None, Unstructured, Adaptive, Auto 10 | 
DATA_DIR=${7:-'/fastdata/imagenet/'} 11 | OUTPUT_DIR=${8:-'.'} 12 | ARCH=${9:-'resnet50'} 13 | LR_MODE=${10:-'cosine'} 14 | 15 | WORLD_SIZE=${SLURM_JOB_NUM_NODES:-1} 16 | NODE_RANK=${SLURM_NODEID:-0} 17 | HOST_PORT=12345 18 | LOGFILE="train.log" 19 | 20 | CMD=" torch.distributed.launch --nnodes $WORLD_SIZE --nproc_per_node 1 --node_rank $NODE_RANK --master_addr $HOST_ADDR --master_port $HOST_PORT main_amp_cpu.py --local_rank $NODE_RANK --arch $ARCH --batch-size=$BATCH --lr-mode=$LR_MODE --fp8-training --fp8-algo=$TRAINING_ALGO --master-weight-precision=$MODEL_PREC --patch-ops=$HW_PATCH --pruning-algo $PRUNING_ALG --output-dir $OUTPUT_DIR --resume $OUTPUT_DIR/checkpoint.pth.tar $DATA_DIR " 21 | 22 | echo $CMD 23 | 24 | python -m $CMD |& tee $LOGFILE 25 | -------------------------------------------------------------------------------- /examples/training/resnet/train_gpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_VISIBLE_DEVICES=0,1,2,3 3 | 4 | HOST_ADDR=${1:-$HOSTNAME} 5 | BATCH=${2:-256} 6 | TRAINING_ALGO=${3:-'direct'} 7 | HW_PATCH=${4:-'None'} 8 | MODEL_PREC=${5:-'fp16'} 9 | PRUNING_ALG=${6:-'None'} # None, Unstructured, Adaptive, Auto 10 | DATA_DIR=${7:-'/fastdata/imagenet/'} 11 | OUTPUT_DIR=${8:-'.'} 12 | ARCH=${9:-'resnet50'} 13 | LR_MODE=${10:-'cosine'} 14 | 15 | WORLD_SIZE=${SLURM_JOB_NUM_NODES:-4} 16 | NODE_RANK=${SLURM_NODEID:-0} 17 | HOST_PORT=12345 18 | LOGFILE="train.log" 19 | 20 | #CMD="torchrun --nnodes=1:4 --nproc_per_node 4 --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR main_amp.py --arch $ARCH --batch-size=$BATCH --lr-mode=$LR_MODE --fp8-training --fp8-algo=$TRAINING_ALGO --master-weight-precision=$MODEL_PREC --pruning-algo $PRUNING_ALG --momentum 0.875 --lr 0.256 --weight-decay 3.0517578125e-05 --output-dir $OUTPUT_DIR --resume $OUTPUT_DIR/checkpoint.pth.tar $DATA_DIR " 21 | CMD=" torch.distributed.launch --nproc_per_node 4 --master_addr $HOST_ADDR --master_port $HOST_PORT main_amp.py --arch $ARCH --batch-size=$BATCH --lr-mode=$LR_MODE --fp8-training --fp8-algo=$TRAINING_ALGO --master-weight-precision=$MODEL_PREC --pruning-algo $PRUNING_ALG --momentum 0.875 --lr 0.256 --weight-decay 3.0517578125e-05 --output-dir $OUTPUT_DIR --resume $OUTPUT_DIR/checkpoint.pth.tar $DATA_DIR " 22 | 23 | echo $CMD 24 | 25 | python -m $CMD |& tee $LOGFILE 26 | -------------------------------------------------------------------------------- /mpemu/__init__.py: -------------------------------------------------------------------------------- 1 | from . import pytquant 2 | from . import cmodel 3 | from . import module_wrappers 4 | -------------------------------------------------------------------------------- /mpemu/cmodel/__init__.py: -------------------------------------------------------------------------------- 1 | from . import simple 2 | -------------------------------------------------------------------------------- /mpemu/cmodel/simple.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | from torch import nn
12 | from torch.autograd import Function
13 | import sys
14 | import warnings
15 | import simple_gemm_dev
16 | import simple_conv2d_dev
17 | # backup the original torch functions
18 | fallback_addmm = torch.addmm
19 | fallback_matmul = torch.matmul
20 | fallback_mm = torch.mm
21 | 
22 | def is_transposed(input):
23 |     if input.is_contiguous():
24 |         return input, False
25 |     elif input.t().is_contiguous():
26 |         return input.t(), True
27 |     else:
28 |         return input.contiguous(), False
29 | 
30 | def addmm(input, mat1, mat2, beta=1.0, alpha=1.0, out=None):
31 |     if input.dtype == torch.float32 and mat1.dtype == torch.float32 and \
32 |        mat1.dim() == 2 and mat2.dim() == 2 and mat1.size(1) == mat2.size(0):
33 |         a_mat, a_trans = is_transposed(mat1)
34 |         b_mat, b_trans = is_transposed(mat2)
35 |         output = out
36 |         if out is None:
37 |             output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), b_mat.size(0) if b_trans else b_mat.size(1)])
38 | 
39 |         return SimpleAddmm.apply(output, input, a_mat, b_mat, alpha, beta, a_trans, b_trans)
40 |     else:
41 |         warnings.warn('simple.addmm does not support the input dimensions - input :{}, mat1: {}, mat2: {}, falling back to torch.addmm'.format(
42 |             input.size(), mat1.size(), mat2.size()))
43 |         return fallback_addmm(input, mat1, mat2, beta=beta, alpha=alpha, out=out)
44 | 
45 | def matmul(input, other, out=None):
46 |     if input.dtype == torch.float32 and other.dtype == torch.float32 and \
47 |        input.dim() == 2 and other.dim() == 2 and input.size(1) == other.size(0):
48 |         a_mat, a_trans = is_transposed(input)
49 |         b_mat, b_trans = is_transposed(other)
50 |         output = out
51 |         if out is None:
52 |             output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), b_mat.size(0) if b_trans else b_mat.size(1)])
53 | 
54 |         return SimpleMatmul.apply(output, a_mat, b_mat, 1.0, a_trans, b_trans)
55 | 
56 |     # Batch MatMul implementation
57 |     elif input.dtype == torch.float32 and other.dtype == torch.float32 and \
58 |          input.dim() == 3 and other.dim() == 2 and input.size(2) == other.size(0):
59 |         a_mat, a_trans = is_transposed(input)
60 |         b_mat, b_trans = is_transposed(other)
61 |         output = out
62 |         if out is None:
63 |             output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), a_mat.size(1) if not a_trans else a_mat.size(0),
64 |                 b_mat.size(0) if b_trans else b_mat.size(1)])
65 | 
66 |         return torch.stack(tuple([SimpleMatmul.apply(out1, a_mat1, b_mat, 1.0, a_trans, b_trans) \
67 |             for a_mat1, out1 in zip(a_mat, output)]))
68 |     else:
69 |         warnings.warn('simple.matmul does not support the input dimensions - input :{}, other: {}, falling back to torch.matmul'.format(
70 |             input.size(), other.size()))
71 |         return fallback_matmul(input, other, out=out)
72 | 
73 | def mm(input, mat2, out=None):
74 |     if input.dtype == torch.float32 and mat2.dtype == torch.float32 and \
75 |        input.dim() == 2 and mat2.dim() == 2 and input.size(1) == mat2.size(0):
76 |         a_mat, a_trans = is_transposed(input)
77 |         b_mat, b_trans = is_transposed(mat2)
78 |         output = out
79 |         if out is None:
80 |             output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), b_mat.size(0) if b_trans else b_mat.size(1)])
81 | 
82 |         return SimpleMatmul.apply(output,
a_mat, b_mat, 1.0, a_trans, b_trans) 83 | else: 84 | warnings.warn('simple.mm does not support the input dimensions - input :{}, mat2: {}, falling back to torch.mm'.format( 85 | input.size(), mat2.size())) 86 | return fallback_mm(input, mat2, out=out) 87 | 88 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 89 | N = input.size()[0] 90 | C = input.size()[1] 91 | H = input.size()[2] 92 | W = input.size()[3] 93 | K = weight.size()[0] 94 | C1 = weight.size()[1] 95 | R = weight.size()[2] 96 | S = weight.size()[3] 97 | 98 | if dilation[0] > 1: 99 | sys.exit("ERROR: simple_conv2d does not support dilated convolutions.") 100 | if padding[0] != padding[1]: 101 | sys.exit("ERROR: simple_conv2d does not support non-uniform padding; pad_h must be equal to pad_w.") 102 | if groups > 1: 103 | sys.exit("ERROR: simple_conv2d does not support grouped convolutions; set groups to 1.") 104 | 105 | H_out = ((H + (2*padding[0]) - dilation[0] * (R-1) -1)/stride[0]) + 1 106 | W_out = ((W + (2*padding[1]) - dilation[1] * (S-1) -1)/stride[1]) + 1 107 | output = torch.empty([N, K, int(H_out), int(W_out)]) 108 | output = SimpleConv2dFunction.apply(output, input, weight, bias, stride, padding, dilation, groups) 109 | return output 110 | 111 | class SimpleAddmm(Function): 112 | @staticmethod 113 | def forward(ctx, output, input, mat1, mat2, alpha, beta, a_trans, b_trans): 114 | ctx.save_for_backward(mat1, mat2) 115 | ctx.a_trans = a_trans 116 | ctx.b_trans = b_trans 117 | ctx.alpha = alpha 118 | 119 | simple_gemm_dev.gemm(output, mat1, mat2, alpha, a_trans, b_trans) 120 | output += beta * input; 121 | ctx.mark_dirty(output) 122 | return output 123 | 124 | @staticmethod 125 | def backward(ctx, grad_output): 126 | mat1, mat2 = ctx.saved_tensors 127 | 128 | alpha = ctx.alpha 129 | a_trans = ctx.a_trans 130 | b_trans = ctx.b_trans 131 | 132 | grad_mat1 = torch.zeros_like(mat1) 133 | grad_mat2 = torch.zeros_like(mat2) 134 | grad_out, out_trans = is_transposed(grad_output) 135 | 136 | if a_trans: 137 | simple_gemm_dev.gemm(grad_mat1, mat2, grad_out, alpha, b_trans, not out_trans) 138 | else: 139 | simple_gemm_dev.gemm(grad_mat1, grad_out, mat2, alpha, out_trans, not b_trans) 140 | 141 | if b_trans: 142 | simple_gemm_dev.gemm(grad_mat2, grad_out, mat1, alpha, not out_trans, a_trans) 143 | else: 144 | simple_gemm_dev.gemm(grad_mat2, mat1, grad_out, alpha, not a_trans, out_trans) 145 | 146 | return (grad_output, grad_output, grad_mat1, grad_mat2, None, None, None, None) 147 | 148 | class SimpleMatmul(Function): 149 | @staticmethod 150 | def forward(ctx, output, mat1, mat2, alpha, a_trans, b_trans): 151 | ctx.save_for_backward(mat1, mat2) 152 | ctx.a_trans = a_trans 153 | ctx.b_trans = b_trans 154 | ctx.alpha = alpha 155 | 156 | simple_gemm_dev.gemm(output, mat1, mat2, alpha, a_trans, b_trans) 157 | ctx.mark_dirty(output) 158 | return output 159 | 160 | @staticmethod 161 | def backward(ctx, grad_output): 162 | mat1, mat2 = ctx.saved_tensors 163 | alpha = ctx.alpha 164 | a_trans = ctx.a_trans 165 | b_trans = ctx.b_trans 166 | 167 | grad_mat1 = torch.empty_like(mat1) 168 | grad_mat2 = torch.empty_like(mat2) 169 | grad_out, out_trans = is_transposed(grad_output) 170 | 171 | if a_trans: 172 | simple_gemm_dev.gemm(grad_mat1, mat2, grad_out, alpha, b_trans, not out_trans) 173 | else: 174 | simple_gemm_dev.gemm(grad_mat1, grad_out, mat2, alpha, out_trans, not b_trans) 175 | 176 | if b_trans: 177 | simple_gemm_dev.gemm(grad_mat2, grad_out, mat1, alpha, not out_trans, a_trans) 178 | else: 179 | 
simple_gemm_dev.gemm(grad_mat2, mat1, grad_out, alpha, not a_trans, out_trans) 180 | return (grad_output, grad_mat1, grad_mat2, None, None, None) 181 | 182 | 183 | class SimpleConv2dFunction(Function): 184 | @staticmethod 185 | def forward(ctx, output, inputs, weights, bias, stride, padding, dilation, groups): 186 | #print("### conv2d fwd called input size: ", inputs.size(), weights.size(), stride, padding, dilation, groups) 187 | ctx.save_for_backward(inputs, weights)#, bias) 188 | ctx.stride = stride#[0] 189 | ctx.padding = padding#[0] 190 | ctx.dilation = dilation#[0] 191 | ctx.groups = groups 192 | 193 | if bias is None: 194 | bias_fw = torch.zeros(output.size()[1]) 195 | else : 196 | bias_fw = bias 197 | 198 | simple_conv2d_dev.conv2d_fp(output, inputs, weights, bias_fw, stride[0], padding[0], dilation[0], groups) 199 | ctx.mark_dirty(output) 200 | return output 201 | 202 | @staticmethod 203 | def backward(ctx, grad_output): 204 | #inputs, weights, bias = ctx.saved_tensors 205 | inputs, weights = ctx.saved_tensors 206 | stride = ctx.stride 207 | padding = ctx.padding 208 | dilation = ctx.dilation 209 | groups = ctx.groups 210 | #print("### conv2d bwd called input size: ", inputs.size(), weights.size(), stride, padding, dilation, groups) 211 | grad_inp = torch.zeros_like(inputs) 212 | grad_wts = torch.zeros_like(weights) 213 | 214 | simple_conv2d_dev.conv2d_bp(grad_inp, grad_output, weights, stride[0], padding[0], dilation[0], groups) 215 | simple_conv2d_dev.conv2d_wu(grad_wts, grad_output, inputs, stride[0], padding[0], dilation[0], groups) 216 | return (grad_output, grad_inp, grad_wts, None, None, None, None, None) 217 | -------------------------------------------------------------------------------- /mpemu/cmodel/simple/simple_conv2d.cpp: -------------------------------------------------------------------------------- 1 | /*----------------------------------------------------------------------------* 2 | * Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #include <torch/extension.h>
11 | #include <vector>
12 | #include <stdio.h>
13 | #include <stdlib.h>
14 | #include <time.h>
15 | #include <unistd.h>
16 | #include <sys/syscall.h>
17 | 
18 | extern int simple_conv2d_impl_fp(float* outputs, float *inputs, float *weights, float* bias, int N, int C, int iH, int iW,
19 |                 int K, int R, int S, int stride, int padding, int dilation, int groups);
20 | extern int simple_conv2d_impl_bp(float* inputs, float *outputs, float *weights, int N, int C, int iH, int iW,
21 |                 int K, int R, int S, int stride, int padding, int dilation, int groups);
22 | extern int simple_conv2d_impl_wu(float *weights, float *outputs, float *inputs, int N, int C, int iH, int iW,
23 |                 int K, int R, int S, int stride, int padding, int dilation, int groups);
24 | 
25 | #define gettid() ((int)syscall(SYS_gettid))
26 | 
27 | using namespace torch::autograd::profiler;
28 | 
29 | 
30 | double get_time() {
31 |   static bool init_done = false;
32 |   static struct timespec stp = {0,0};
33 |   struct timespec tp;
34 |   clock_gettime(CLOCK_REALTIME, &tp);
35 | 
36 |   if(!init_done) {
37 |     init_done = true;
38 |     stp = tp;
39 |   }
40 |   double ret = (tp.tv_sec - stp.tv_sec) * 1e3 + (tp.tv_nsec - stp.tv_nsec)*1e-6;
41 |   return ret;
42 | }
43 | 
44 | at::Tensor simple_conv2d_fp(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, torch::Tensor bias,
45 |                 int stride, int padding, int dilation, int groups)
46 | {
47 |   RECORD_FUNCTION("simple_conv2d_fp", std::vector<c10::IValue>({input, weight, bias}));
48 | 
49 |   auto N = input.size(0);
50 |   auto C = input.size(1);
51 |   auto H = input.size(2);
52 |   auto W = input.size(3);
53 | 
54 |   auto K = weight.size(0);
55 |   //auto C1 = weight.size(1);
56 |   auto R = weight.size(2);
57 |   auto S = weight.size(3);
58 | 
59 |   float *input_ptr = input.data_ptr<float>();
60 |   float *weight_ptr = weight.data_ptr<float>();
61 |   float *output_ptr = output.data_ptr<float>();
62 |   float *bias_ptr = bias.data_ptr<float>();
63 | 
64 |   simple_conv2d_impl_fp(output_ptr, input_ptr, weight_ptr, bias_ptr, N, C, H, W,
65 |                 K, R, S, stride, padding, dilation, groups);
66 | 
67 |   //thnn_conv2d_out(output, input, weight,
68 |   return output;
69 | }
70 | 
71 | at::Tensor simple_conv2d_bp(torch::Tensor& input, torch::Tensor output, torch::Tensor weight,
72 |                 int stride, int padding, int dilation, int groups)
73 | {
74 |   RECORD_FUNCTION("simple_conv2d_bp", std::vector<c10::IValue>({output, weight}));
75 | 
76 |   auto N = input.size(0);
77 |   auto C = input.size(1);
78 |   auto H = input.size(2);
79 |   auto W = input.size(3);
80 | 
81 |   auto K = weight.size(0);
82 |   //auto C1 = weight.size(1);
83 |   auto R = weight.size(2);
84 |   auto S = weight.size(3);
85 | 
86 |   float *input_ptr = input.data_ptr<float>();
87 |   float *weight_ptr = weight.data_ptr<float>();
88 |   float *output_ptr = output.data_ptr<float>();
89 | 
90 |   simple_conv2d_impl_bp(input_ptr, output_ptr, weight_ptr, N, C, H, W,
91 |                 K, R, S, stride, padding, dilation, groups);
92 | 
93 |   //thnn_conv2d_out(output, input, weight,
94 |   return input;
95 | }
96 | 
97 | at::Tensor simple_conv2d_wu(torch::Tensor& weight, torch::Tensor output, torch::Tensor input,
98 |                 int stride, int padding, int dilation, int groups)
99 | {
100 |   RECORD_FUNCTION("simple_conv2d_wu", std::vector<c10::IValue>({output, input}));
101 | 
102 |   auto N = input.size(0);
103 |   auto C = input.size(1);
104 |   auto H = input.size(2);
105 |   auto W = input.size(3);
106 | 
107 |   auto K = weight.size(0);
108 |   //auto C1 = weight.size(1);
109 |   auto R = weight.size(2);
110 |   auto S = weight.size(3);
111 | 
112 |   float *input_ptr = input.data_ptr<float>();
113 |   float *weight_ptr = weight.data_ptr<float>();
114 |   float *output_ptr = output.data_ptr<float>();
115 | 
116 |   simple_conv2d_impl_wu(weight_ptr, output_ptr, input_ptr, N, C, H, W,
117 |                 K, R, S, stride, padding, dilation, groups);
118 | 
119 |   //thnn_conv2d_out(output, input, weight,
120 |   return weight;
121 | }
122 | 
123 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
124 |   m.def("conv2d_fp", &simple_conv2d_fp, "simple conv_fp implementation");
125 |   m.def("conv2d_bp", &simple_conv2d_bp, "simple conv_bp implementation");
126 |   m.def("conv2d_wu", &simple_conv2d_wu, "simple conv_wu implementation");
127 | }
128 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_gemm.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 |  * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #include <torch/extension.h>
11 | #include <vector>
12 | #include <stdio.h>
13 | #include <stdlib.h>
14 | #include <time.h>
15 | #include <unistd.h>
16 | #include <sys/syscall.h>
17 | #if 10
18 | //extern "C" {
19 | extern int simple_sgemm_impl( char* transa, char* transb, int m, int n, int k,
20 |                 float alpha, float* a, int lda, float* b, int ldb,
21 |                 float beta, float* c, int ldc );
22 | //}
23 | #endif
24 | #define gettid() ((int)syscall(SYS_gettid))
25 | 
26 | using namespace torch::autograd::profiler;
27 | using namespace torch;
28 | using namespace torch::autograd;
29 | using at::Tensor;
30 | 
31 | double get_time() {
32 |   static bool init_done = false;
33 |   static struct timespec stp = {0,0};
34 |   struct timespec tp;
35 |   clock_gettime(CLOCK_REALTIME, &tp);
36 | 
37 |   if(!init_done) {
38 |     init_done = true;
39 |     stp = tp;
40 |   }
41 |   double ret = (tp.tv_sec - stp.tv_sec) * 1e3 + (tp.tv_nsec - stp.tv_nsec)*1e-6;
42 |   return ret;
43 | }
44 | 
45 | at::Tensor simple_gemm(torch::Tensor& C, torch::Tensor A, torch::Tensor B, float alpha, bool a_trans, bool b_trans)
46 | {
47 |   RECORD_FUNCTION("simple_gemm", std::vector<c10::IValue>({A, B, alpha}));
48 | 
49 |   const char *aT = a_trans ? "T" : "N";
50 |   const char *bT = b_trans ? "T" : "N";
51 | 
52 |   auto M = C.size(0);
53 |   auto N = C.size(1);
54 |   auto K = a_trans ? A.size(0) : A.size(1);
55 |   auto lda = A.size(1);
56 |   auto ldb = B.size(1);
57 |   auto ldc = C.size(1);
58 | 
59 |   float beta = 0.0;
60 | 
61 |   float *Aptr = A.data_ptr<float>();
62 |   float *Bptr = B.data_ptr<float>();
63 |   float *Cptr = C.data_ptr<float>();
64 | 
65 |   simple_sgemm_impl((char*)bT, (char*)aT, N, M, K, alpha, Bptr, ldb, Aptr, lda, beta, Cptr, ldc);
66 |   return C;
67 | }
68 | 
69 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
70 |   m.def("gemm", &simple_gemm, "Simple Matrix Engine");
71 | }
72 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_gemm_impl.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 |  * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #include <stdio.h>
11 | #include <stdlib.h>
12 | #include <string.h>
13 | #include <math.h>
14 | #include <assert.h>
15 | #include <stddef.h>
16 | #include <immintrin.h>
17 | 
18 | #define SCRATCH_SIZE 4294967296
19 | int lazy_init = 1;
20 | float* myscratch = NULL;
21 | 
22 | extern void MMEngine_avx2_ps(int m, int n, int k, float alpha, float *A, int lda,
23 |                 float *B, int ldb, float beta, float *C, int ldc);
24 | 
25 | __extern_always_inline
26 | void copy_matrix_and_pad(const float* in, float* out, const int LD, const int OD, int ld_pad, int od_pad)
27 | {
28 |   int LDP, ODP;
29 |   LDP = LD + ld_pad;
30 |   ODP = OD + od_pad;
31 |   int ld, od;
32 |   int lp, op;
33 | 
34 |   for ( op = 0; op < OD; op++ ) {
35 |     for ( lp = LD-1; lp < LDP; lp++ ) {
36 |       out[op*LDP+lp] = 0.0;
37 |     }
38 |   }
39 |   for ( op = OD-1; op < ODP; op++ ) {
40 |     for ( lp = 0; lp < LDP; lp++ ) {
41 |       out[op*LDP+lp] = 0.0;
42 |     }
43 |   }
44 | 
45 | #if defined(_OPENMP)
46 | #pragma omp parallel for private(ld, od)
47 | #endif
48 |   for ( od = 0; od < OD; od++ ) {
49 |     for ( ld = 0; ld < LD; ld++ ) {
50 |       out[od*LDP+ld] = in[od*LD+ld];
51 |     }
52 |   }
53 | }
54 | 
55 | __extern_always_inline
56 | void copy_matrix_and_strip_pading(const float* in, float* out, const int LD, const int OD, int ld_pad, int od_pad)
57 | {
58 |   int LDP;
59 |   LDP = LD + ld_pad;
60 |   int ld, od;
61 | #if defined(_OPENMP)
62 | #pragma omp parallel for private(ld, od)
63 | #endif
64 |   for ( od = 0; od < OD; od++ ) {
65 |     for ( ld = 0; ld < LD; ld++ ) {
66 |       out[od*LD+ld] = in[od*LDP+ld];
67 |     }
68 |   }
69 | }
70 | 
71 | int simple_sgemm_impl( char* transa, char* transb, int m, int n, int k,
72 |                 float alpha, float* a, int lda, float* b, int ldb,
73 |                 float beta, float* c, int ldc ) {
74 |   float* myA = NULL;
75 |   float* myB = NULL;
76 |   float* myC = NULL;
77 |   size_t ptlda = (size_t)(lda);
78 |   size_t ptldb = (size_t)(ldb);
79 |   size_t mylda = (size_t)(lda);
80 |   size_t myldb = (size_t)(ldb);
81 |   size_t myldc = (size_t)(ldc);
82 |   int mym = (size_t)(m);
83 |   int myn = (size_t)(n);
84 |   int myk = (size_t)(k);
85 |   int m_pad = 0;
86 |   int n_pad = 0;
87 |   int k_pad = 0;
88 | 
89 |   int o,p,q,pp,oo;
90 | 
91 |   /* check for size matching our TMUL emulation */
92 |   if ( mym % 16 != 0 ) {
93 |     m_pad = (16-(mym%16));
94 |     mym += m_pad;
95 |   }
96 |   if ( myk % 64 != 0 ) {
97 |     k_pad = (64-(myk%64));
98 |     myk += k_pad;
99 |   }
100 |   if ( myn % 16 != 0 ) {
101 |     n_pad = (16-(myn%16));
102 |     myn += n_pad;
103 |   }
104 | 
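  /* note: the blocked loop further below feeds the emulated matrix engine in
     fixed 16x16x64 tiles (MMEngine_avx2_ps from simple_mm_engine.cpp), which
     is why m and n are padded up to multiples of 16 and k to a multiple of 64 */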
/* update leading dimensions with padded values */ 105 | mylda = mym; 106 | myldb = myk; 107 | myldc = mym; 108 | 109 | /* lazy init of fp8_gemm state */ 110 | if (lazy_init != 0) { 111 | lazy_init = 0; 112 | myscratch = (float*) _mm_malloc( SCRATCH_SIZE*sizeof(float), 4096 ); 113 | } 114 | /* check for sufficient scratch size */ 115 | if ( (*transa == 'N') && (*transb == 'N') ) { 116 | if ( ((ptlda*myk)+(ptldb*myn)) > SCRATCH_SIZE ) { 117 | return -1; 118 | } 119 | } else if ( (*transa == 'T') && (*transb == 'N') ) { 120 | if ( ((ptlda*mym)+(ptldb*myn)) > SCRATCH_SIZE ) { 121 | return -2; 122 | } 123 | mylda = mym; 124 | } else if ( (*transa == 'N') && (*transb == 'T') ) { 125 | if ( ((ptlda*myk)+(ptldb*myk)) > SCRATCH_SIZE ) { 126 | return -3; 127 | } 128 | myldb = myk; 129 | } else if ( (*transa == 'T') && (*transb == 'T') ) { 130 | if ( ((ptlda*mym)+(ptldb*myk)) > SCRATCH_SIZE ) { 131 | return -4; 132 | } 133 | mylda = mym; 134 | myldb = myk; 135 | } else { 136 | assert((0 && "Error : Invalid parameters")); 137 | return -5; 138 | } 139 | 140 | /* set temp A and B pointers */ 141 | myA = myscratch; 142 | myB = myscratch + (mylda*myk); 143 | myC = myscratch + (mylda*myk) + (myldb*myn); 144 | 145 | if ( *transa == 'T' ) { 146 | /* fill the padding with zeros */ 147 | for ( p = 0; p < k; p++ ) { 148 | for ( o = m-1; o < mym; o++ ) { 149 | myA[p*mylda+o] = 0.0; 150 | } 151 | } 152 | for ( p = k-1; p < myk; p++ ) { 153 | for ( o = 0; o < mym; o++ ) { 154 | myA[p*mylda+o] = 0.0; 155 | } 156 | } 157 | 158 | /* let's transpose data first */ 159 | #if defined(_OPENMP) 160 | #pragma omp parallel for private(o,p) collapse(2) 161 | #endif 162 | for ( p = 0; p < k; p++ ) { 163 | for ( o = 0; o < m; o++ ) { 164 | myA[(p*mylda)+o] = a[(o*ptlda)+p]; 165 | } 166 | } 167 | } else if ( m_pad > 0 || k_pad > 0 ) { 168 | copy_matrix_and_pad(a, myA, m, k, m_pad, k_pad); 169 | } else { 170 | myA = a; 171 | } 172 | 173 | if ( *transb == 'T' ) { 174 | /* fill the padding with zeros */ 175 | for ( p = 0; p < n; p++ ) { 176 | for ( o = k-1; o < myk; o++ ) { 177 | myB[p*myldb+o] = 0.0; 178 | } 179 | } 180 | for ( p = n-1; p < myn; p++ ) { 181 | for ( o = 0; o < myk; o++ ) { 182 | myB[p*myldb+o] = 0.0; 183 | } 184 | } 185 | 186 | /* let's transpose data first */ 187 | #if defined(_OPENMP) 188 | #pragma omp parallel for private(o,p) collapse(2) 189 | #endif 190 | for ( p = 0; p < n; p++ ) { 191 | for ( o = 0; o < k; o++ ) { 192 | myB[(p*myldb)+o] = b[(o*ptldb)+p]; 193 | } 194 | } 195 | } else if ( k_pad > 0 || n_pad > 0 ) { 196 | copy_matrix_and_pad(b, myB, k, n, k_pad, n_pad); 197 | } else { 198 | myB = b; 199 | } 200 | 201 | if ( m_pad > 0 || n_pad > 0 ) { 202 | copy_matrix_and_pad(c, myC, m, n, m_pad, n_pad); 203 | } else { 204 | myC = c; 205 | } 206 | /* run gemm */ 207 | #if defined(_OPENMP) 208 | #pragma omp parallel for private(o,p,q,pp,oo) collapse(2) 209 | #endif 210 | for ( o = 0; o < mym; o += 16 ) { 211 | for ( p = 0; p < myn; p += 16 ) { 212 | 213 | /* C initialized to zero */ 214 | __attribute__ ((aligned(64))) float ctmp[256]; 215 | for ( pp = 0; pp < 16; pp++ ) { 216 | for ( oo = 0; oo < 16; oo++ ) { 217 | ctmp[(pp*16)+oo] = 0.0f; 218 | } 219 | } 220 | /* compute a 16x16x64 block */ 221 | for ( q = 0; q < myk; q += 64 ) { 222 | MMEngine_avx2_ps(16, 16, 64, alpha, &(myA[(mylda*q)+o]), mylda, 223 | &(myB[(myldb*p)+q]), myldb, beta, ctmp, 16); 224 | } 225 | /* post accumulation */ 226 | for ( pp = 0; pp < 16; pp++ ) { 227 | for ( oo = 0; oo < 16; oo++ ) { 228 | myC[((p+pp)*myldc)+(o+oo)] += 
((alpha)*ctmp[(pp*16)+oo]) + ((beta)*myC[((p+pp)*myldc)+(o+oo)]);
229 |         }
230 |       }
231 |     }
232 |   }
233 |   if ( m_pad > 0 || n_pad > 0 ) {
234 |     copy_matrix_and_strip_pading(myC, c, m, n, m_pad, n_pad);
235 |   }
236 |   return 0;
237 | }
238 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_mm_engine.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 |  * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #include <immintrin.h>
11 | 
12 | /* column-major data format */
13 | void MMEngine_avx2_ps(int m, int n, int k, float alpha, float *A, int lda,
14 |                 float *B, int ldb, float beta, float *C, int ldc)
15 | {
16 |   for (int j = 0; j < n; j++) {
17 |     for (int i = 0; i < m; i+=8) {
18 |       __m256 cij = _mm256_loadu_ps(&C[ j*ldc + i]);
19 |       for (int l = 0; l < k; l++) {
20 |         __m256 aik = _mm256_loadu_ps(&A[ i + l * lda]);
21 |         __m256 bkj = _mm256_broadcast_ss(&B[ l + j * ldb]);
22 |         cij = _mm256_add_ps(cij, _mm256_mul_ps(aik, bkj));
23 |       }
24 |       _mm256_storeu_ps(&C[ j * ldc + i ], cij);
25 |     }
26 |   }
27 | }
28 | 
29 | /* column-major data format */
30 | void MMEngine_strideB_avx2_ps(int m, int n, int k, float alpha, float *A, int lda,
31 |                 float *B, int ldb, float beta, float *C, int ldc, int strideB)
32 | {
33 |   for (int j = 0; j < n; j++) {
34 |     for (int i = 0; i < m; i+=8) {
35 |       __m256 cij = _mm256_loadu_ps(&C[ j*ldc + i]);
36 |       for (int l = 0; l < k; l++) {
37 |         __m256 aik = _mm256_loadu_ps(&A[ i + l * lda]);
38 |         __m256 bkj = _mm256_broadcast_ss(&B[ l*strideB + j * ldb]);
39 |         cij = _mm256_add_ps(cij, _mm256_mul_ps(aik, bkj));
40 |       }
41 |       _mm256_storeu_ps(&C[ j * ldc + i], cij);
42 |     }
43 |   }
44 | }
45 | 
46 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/vla.h:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 |  * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #ifndef VLA_H
11 | #define VLA_H
12 | #include <stddef.h>
13 | #include <stdint.h>
14 | #include <stdlib.h>
15 | #include <stdio.h>
16 | 
17 | #define ALWAYS_INLINE __attribute__((always_inline))
18 | #define INLINE inline
19 | #define RESTRICT __restrict__
20 | #define VLA_POSTFIX _
21 | 
22 | #define INDEX1_1(...) 
((size_t)SELECT_HEAD(__VA_ARGS__)) 23 | #define INDEX1_2(I0, I1, S1) (INDEX1_1(I0) * ((size_t)S1) + (size_t)I1) 24 | #define INDEX1_3(I0, I1, I2, S1, S2) (INDEX1_2(I0, I1, S1) * ((size_t)S2) + (size_t)I2) 25 | #define INDEX1_4(I0, I1, I2, I3, S1, S2, S3) (INDEX1_3(I0, I1, I2, S1, S2) * ((size_t)S3) + (size_t)I3) 26 | #define INDEX1_5(I0, I1, I2, I3, I4, S1, S2, S3, S4) (INDEX1_4(I0, I1, I2, I3, S1, S2, S3) * ((size_t)S4) + (size_t)I4) 27 | #define INDEX1_6(I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, S5) (INDEX1_5(I0, I1, I2, I3, I4, S1, S2, S3, S4) * ((size_t)S5) + (size_t)I5) 28 | #define INDEX1_7(I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, S6) (INDEX1_6(I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, S5) * ((size_t)S6) + (size_t)I6) 29 | #define INDEX1_8(I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, S7) (INDEX1_7(I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, S6) * ((size_t)S7) + (size_t)I7) 30 | #define INDEX1_9(I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, S8) (INDEX1_8(I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, S7) * ((size_t)S8) + (size_t)I8) 31 | #define INDEX1_10(I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, S1, S2, S3, S4, S5, S6, S7, S8, S9) (INDEX1_9(I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, S8) * ((size_t)S9) + (size_t)I9) 32 | 33 | #define EXPAND(...) __VA_ARGS__ 34 | #define CONCATENATE2(A, B) A##B 35 | #define CONCATENATE(A, B) CONCATENATE2(A, B) 36 | #define INDEX1(NDIMS, ...) CONCATENATE(INDEX1_, NDIMS)(__VA_ARGS__) 37 | 38 | #define SELECT_HEAD_AUX(A, ...) (A) 39 | #define SELECT_HEAD(...) EXPAND(SELECT_HEAD_AUX(__VA_ARGS__, 0)) 40 | #define SELECT_TAIL(A, ...) __VA_ARGS__ 41 | 42 | #define ACCESS_VLA(NDIMS, ARRAY, ...) CONCATENATE(ARRAY, VLA_POSTFIX)[INDEX1(NDIMS, __VA_ARGS__)] 43 | #define DECLARE_VLA(NDIMS, ELEMENT_TYPE, ARRAY_VAR, ...) \ 44 | ELEMENT_TYPE *RESTRICT CONCATENATE(ARRAY_VAR, VLA_POSTFIX) = SELECT_HEAD(__VA_ARGS__) \ 45 | + 0 * INDEX1(NDIMS, SELECT_TAIL(__VA_ARGS__, SELECT_TAIL(__VA_ARGS__, 0))) 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /mpemu/cmodel/tests/conv_grad_test.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | import time
12 | from mpemu.cmodel import simple
13 | import numpy as np
14 | 
15 | n = 64
16 | c = 256
17 | h = 28
18 | w = 28
19 | k = 256
20 | r = 3
21 | s = 3
22 | stride = 2
23 | pad = 2
24 | 
25 | a = torch.rand((n,c,h,w), dtype=torch.float32)
26 | b = torch.rand((k,c,r,s), dtype=torch.float32)
27 | bias = torch.rand((k), dtype=torch.float32)
28 | 
29 | #a = torch.rand((n,c,h,w), dtype=torch.float32)
30 | #b = torch.rand((k,c,r,s), dtype=torch.float32)
31 | #bias = torch.rand((k), dtype=torch.float32)
32 | 
33 | a64 = a.to(dtype=torch.float64, copy=True)
34 | b64 = b.to(dtype=torch.float64, copy=True)
35 | bias64 = bias.to(dtype=torch.float64, copy=True)
36 | 
37 | a.requires_grad=True
38 | b.requires_grad=True
39 | a64.requires_grad=True
40 | b64.requires_grad=True
41 | 
42 | ref_time = time.time()
43 | z = torch.nn.functional.conv2d(a64, b64, bias64, stride=(stride,stride), padding=(pad,pad), dilation=(1,1), groups=1)
44 | ref_time = time.time()-ref_time
45 | 
46 | simple_time = time.time()
47 | z2 = simple.conv2d(a, b, bias, stride=(stride,stride), padding=(pad,pad), dilation=(1,1), groups=1)
48 | simple_time = time.time()-simple_time
49 | 
50 | ref_time = time.time()
51 | z3 = torch.nn.functional.conv2d(a, b, bias, stride=(stride,stride), padding=(pad,pad), dilation=(1,1), groups=1)
52 | ref_time = time.time()-ref_time
53 | 
54 | #print("Forward Time : ref_time: {}, simple_time: {} ".format(ref_time, simple_time))
55 | print('Forward: L2 distance f32 output : ', torch.dist(z3.to(dtype=torch.float64, copy=True), z, 2).item())
56 | print('Forward: L2 distance output : ', torch.dist(z2.to(dtype=torch.float64, copy=True), z, 2).item())
57 | 
58 | ref_time = time.time()
59 | (z[0, 0] + z[0,1]).sum().backward()
60 | ref_time = time.time()-ref_time
61 | 
62 | simple_time = time.time()
63 | (z2[0, 0] + z2[0,1]).sum().backward()
64 | simple_time = time.time()-simple_time
65 | 
66 | #print("BackProp Time : ref_time: {}, simple_time: {}".format(ref_time, simple_time))
67 | torch.set_printoptions(profile="full")
68 | print('Backward: L2 distance input_grad: ', torch.dist(a.grad.to(dtype=torch.float64, copy=True), a64.grad, 2).item())
69 | print('Backward: L2 distance weight_grad: ', torch.dist(b.grad.to(dtype=torch.float64, copy=True), b64.grad, 2).item())
70 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/conv_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | from mpemu.cmodel import simple
12 | 
13 | n = 16
14 | c = 256
15 | h = 28
16 | w = 28
17 | k = 256
18 | r = 3
19 | s = 3
20 | 
21 | input = torch.rand((n,c,h,w), dtype=torch.float32)
22 | weight = torch.rand((k,c,r,s), dtype=torch.float32)
23 | bias = torch.rand((k), dtype=torch.float32)
24 | 
25 | output_ref = torch.nn.functional.conv2d(input, weight, bias, stride=(2,2), padding=(2,2), dilation=(1,1), groups=1)
26 | output_simple = simple.conv2d(input, weight, bias, stride=(2,2), padding=(2,2), dilation=(1,1), groups=1)
27 | 
28 | torch.set_printoptions(profile="full")
29 | 
30 | print('Forward : L2 distance (simple) : ', torch.dist(output_ref, output_simple, 2).item())
31 | 
32 | #print(output_ref-output_simple)
33 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/gemm_grad_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | from mpemu.cmodel import simple
12 | 
13 | m=1024 #356320 #356305
14 | n=1024 #120 #256
15 | k=1024 #576 #128 #2020
16 | 
17 | # Start here
18 | a = torch.rand((m, k), dtype=torch.float32)
19 | b = torch.rand((n, k), dtype=torch.float32)
20 | c = torch.zeros((m, n), dtype=torch.float32)
21 | 
22 | a64 = a.to(dtype=torch.float64, copy=True)
23 | b64 = b.to(dtype=torch.float64, copy=True)
24 | c64 = c.to(dtype=torch.float64, copy=True)
25 | 
26 | a.requires_grad=True
27 | b.requires_grad=True
28 | c.requires_grad=True
29 | a64.requires_grad=True
30 | b64.requires_grad=True
31 | c64.requires_grad=True
32 | 
33 | z = torch.addmm(c64, a64, b64.t())
34 | z2 = simple.addmm(c, a, b.t())
35 | z3 = torch.addmm(c, a, b.t())
36 | 
37 | print('Forward :L2 distance f32 output: ', torch.dist(z3.to(dtype=torch.float64, copy=True), z, 2).item())
38 | print('Forward :L2 distance output: ', torch.dist(z2.to(dtype=torch.float64, copy=True), z, 2).item())
39 | 
40 | (z2[0, 0] + z2[0,1]).sum().backward()
41 | (z[0, 0] + z[0,1]).sum().backward()
42 | 
43 | print('Backward : L2 distance a_grad: ', torch.dist(a.grad.to(dtype=torch.float64, copy=True), a64.grad, 2).item())
44 | print('Backward : L2 distance b_grad: ', torch.dist(b.grad.to(dtype=torch.float64, copy=True), b64.grad, 2).item())
45 | print('Backward : L2 distance c_grad: ', torch.dist(c.grad.to(dtype=torch.float64, copy=True), c64.grad, 2).item())
46 | 
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/gemm_irregular_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import torch 11 | import fuse 12 | 13 | m=1 #356320 #356305 14 | n=120 #256 15 | k=576 #128 #2020 16 | 17 | # Start here 18 | a = torch.rand((5, 107, 1024), dtype=torch.float32) 19 | b = torch.rand((1024, 1024), dtype=torch.float32) 20 | c = torch.zeros((m, n), dtype=torch.float32) 21 | 22 | a64 = a.to(dtype=torch.float64, copy=True) 23 | b64 = b.to(dtype=torch.float64, copy=True) 24 | c64 = c.to(dtype=torch.float64, copy=True) 25 | 26 | z = torch.matmul(a64, b64) 27 | z3 = torch.matmul(a, b) 28 | for a1 in a : 29 | print("-->", a1.size()) 30 | 31 | #z2l = tuple([fuse.tmul.matmul(a1, b) for a1 in a]) 32 | #z2 = torch.stack(z2l) 33 | #z2 = torch.stack(tuple([fuse.tmul.matmul(a1, b) for a1 in a])) 34 | z2 = fuse.tmul.matmul(a, b) 35 | 36 | #print (z3.size(), z2.size()) 37 | 38 | #print("torch :", z3.size(), z3) 39 | #print("Ours : ", z2.size(), z2) 40 | print('32b: L2 distance : ', torch.dist(z, z3.to(dtype=torch.float64, copy=True), 2)) 41 | print('ours: L2 distance : ', torch.dist(z, z2.to(dtype=torch.float64, copy=True), 2)) 42 | 43 | -------------------------------------------------------------------------------- /mpemu/cmodel/tests/gemm_test.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import torch 11 | from mpemu.cmodel import simple 12 | 13 | def get_grads(variables): 14 | return [var.grad.clone() for var in variables] 15 | 16 | m=256 17 | n=512 18 | k=1024 19 | 20 | a = torch.rand((k, m), dtype=torch.float32) 21 | b = torch.rand((n, k), dtype=torch.float32) 22 | c = torch.zeros((m, n), dtype=torch.float32) 23 | 24 | a64 = a.to(dtype=torch.float64, copy=True) 25 | b64 = b.to(dtype=torch.float64, copy=True) 26 | c64 = c.to(dtype=torch.float64, copy=True) 27 | a32 = a.to(dtype=torch.float32, copy=True) 28 | b32 = b.to(dtype=torch.float32, copy=True) 29 | c32 = c.to(dtype=torch.float32, copy=True) 30 | print(a) 31 | print(b) 32 | 33 | z = torch.matmul(a64.t(), b64.t()) 34 | z2 = simple.matmul(a.t(), b.t(), out=c) 35 | z3 = torch.matmul(a32.t(), b32.t()) 36 | print(z) 37 | print(c-z2) 38 | print(z2.size()) 39 | print('output 32b: L2 distance : ', torch.dist(z3.to(dtype=torch.float64, copy=True), z, 2)) 40 | print('output : L2 distance : ', torch.dist(z2.to(dtype=torch.float64, copy=True), z, 2)) 41 | -------------------------------------------------------------------------------- /mpemu/cmodel/tests/linear_test.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 | from mpemu.cmodel import simple
15 | 
16 | class Net(nn.Module):
17 | 
18 |     def __init__(self):
19 |         super(Net, self).__init__()
20 |         self.fc = nn.Linear(576, 120)
21 | 
22 |     def forward(self, x):
23 |         x = self.fc(x)
24 |         return x
25 | 
26 |     def num_flat_features(self, x):
27 |         size = x.size()[1:]  # all dimensions except the batch dimension
28 |         num_features = 1
29 |         for s in size:
30 |             num_features *= s
31 |         return num_features
32 | 
33 | net = Net()
34 | net1 = Net(); net1.load_state_dict(net.state_dict())  # identical weights, so the gradients are comparable
35 | print(net)
36 | input = torch.randn((1, 576), dtype=torch.float32, device="cpu")
37 | input_new = input.to(dtype=torch.float32, copy=True)
38 | 
39 | output = net(input)
40 | 
41 | target = torch.randn(120, dtype=torch.float32, device="cpu")  # a dummy target, for example
42 | target = target.view(1, -1)  # make it the same shape as output
43 | criterion = nn.MSELoss()
44 | loss = criterion(output, target)
45 | net.zero_grad()  # zeroes the gradient buffers of all parameters
46 | #loss.backward(retain_graph=True)
47 | loss.backward()
48 | 
49 | 
50 | torch.addmm_back = torch.addmm
51 | torch.matmul_back = torch.matmul
52 | torch.addmm = simple.addmm
53 | torch.matmul = simple.matmul
54 | 
55 | output1 = net1(input_new)
56 | 
57 | loss1 = criterion(output1, target)
58 | net1.zero_grad()  # zeroes the gradient buffers of all parameters
59 | loss1.backward()
60 | 
61 | print('Linear wtgrads L2 distance : ', torch.dist(net.fc.weight.grad, net1.fc.weight.grad, 2))
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/net.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | from mpemu.cmodel import simple 15 | 16 | torch_conv2d = torch.nn.functional.conv2d 17 | torch_addmm = torch.addmm 18 | torch.addmm = simple.addmm 19 | torch.nn.functional.conv2d = simple.conv2d 20 | 21 | class Net(nn.Module): 22 | 23 | def __init__(self): 24 | super(Net, self).__init__() 25 | # 1 input image channel, 6 output channels, 3x3 square convolution 26 | # kernel 27 | self.conv1 = nn.Conv2d(1, 6, 3) 28 | self.conv2 = nn.Conv2d(6, 16, 3) 29 | # an affine operation: y = Wx + b 30 | self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension 31 | self.fc2 = nn.Linear(120, 84) 32 | self.fc3 = nn.Linear(84, 10) 33 | 34 | def forward(self, x): 35 | # Max pooling over a (2, 2) window 36 | x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) 37 | # If the size is a square you can only specify a single number 38 | x = F.max_pool2d(F.relu(self.conv2(x)), 2) 39 | x = x.view(-1, self.num_flat_features(x)) 40 | x = F.relu(self.fc1(x)) 41 | x = F.relu(self.fc2(x)) 42 | x = self.fc3(x) 43 | return x 44 | 45 | def num_flat_features(self, x): 46 | size = x.size()[1:] # all dimensions except the batch dimension 47 | num_features = 1 48 | for s in size: 49 | num_features *= s 50 | return num_features 51 | 52 | 53 | net = Net() 54 | print(net) 55 | params = list(net.parameters()) 56 | input = torch.randn(1, 1, 32, 32, dtype=torch.float32, device="cpu") 57 | # create your optimizer 58 | optimizer = optim.SGD(net.parameters(), lr=0.1) 59 | optimizer.zero_grad() # zero the gradient buffers 60 | 61 | output = net(input) 62 | 63 | print("fc3 output:", output) 64 | target = torch.randn(10, dtype=torch.float32, device="cpu") # a dummy target, for example 65 | target = target.view(1, -1) # make it the same shape as output 66 | criterion = nn.MSELoss() 67 | 68 | loss = criterion(output, target) 69 | 70 | net.zero_grad() # zeroes the gradient buffers of all parameters 71 | 72 | loss.backward() 73 | optimizer.step() # Does the update 74 | -------------------------------------------------------------------------------- /mpemu/e3m4_emu.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | from collections import OrderedDict
11 | import torch
12 | from .qutils import TensorQuantConfig, ModuleQuantConfig
13 | from .qutils import get_or_update_model_quant_config_dict
14 | from .qutils import reset_quantization_setup,add_quantization_hooks
15 | from .qutils import quantize_model_weights,set_quantize_weights_flag
16 | from .scale_shift import replace_batchnorms_with_scaleshifts
17 | from .module_wrappers import BatchMatmul,Matmul,AddMatmul,EltwiseMul,EltwiseAdd,EltwiseDiv
18 | from .module_wrappers import SparseConv2d,SparseLinear
19 | 
20 | '''
21 | E3M4 Emulator
22 | '''
23 | class E3M4Emulator(object):
24 |     def __init__(self, model, optimizer, sparse_config=None, device="cuda", verbose=False, tensor_stats=False):
25 |         super(E3M4Emulator, self).__init__()
26 |         self.whitelist = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding, torch.nn.EmbeddingBag]
27 |         self.whitelist += [Matmul, BatchMatmul, AddMatmul]
28 |         self.whitelist += [EltwiseAdd, EltwiseMul, EltwiseDiv]
29 |         self.whitelist += [SparseConv2d, SparseLinear]
30 |         self.blacklist = []
31 |         self.list_unpatched = []
32 |         self.is_training = False
33 |         self.list_exempt_layers = None
34 |         self.list_layers_output_fused = None
35 |         self.device = device
36 |         self.patch_ops = False
37 |         self.patch_impl = "NONE"
38 |         self.model = model
39 |         self.patchlist = ["SIMPLE"]
40 |         self.sparse_config = sparse_config
41 |         # default configuration
42 |         self.data_emulation = True
43 |         self.mod_qconfig = None
44 |         self.model_qconfig_dict = None #OrderedDict()
45 |         self.emb_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
46 |         self.wt_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
47 |         self.iact_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
48 |         self.oact_qconfig = None #TensorQuantConfig("e3m4", "rne", "per-tensor")
49 |         self.hook_handles = None
50 |         self.verbose = verbose
51 | 
52 |     def blacklist_modules(self, list_modules) :
53 |         if self.verbose:
54 |             print("Current module list {}".format(self.whitelist))
55 |             print("Blacklisting modules {}".format(list_modules))
56 |         for module in list_modules :
57 |             self.blacklist.append(module)
58 |             self.whitelist.remove(module)
59 |         self.create_or_update_hooks(self.model)
60 |         if self.verbose:
61 |             print("Updated module list {}".format(self.whitelist))
62 |             self.print_config()
63 | 
64 |     def whitelist_modules(self, list_modules) :
65 |         if self.verbose:
66 |             print("Current module list {}".format(self.whitelist))
67 |             print("Whitelisting modules {}".format(list_modules))
68 |         for module in list_modules :
69 |             self.whitelist.append(module)
70 |             self.blacklist.remove(module)
71 |         self.create_or_update_hooks(self.model)
72 |         if self.verbose:
73 |             print("Updated module list {}".format(self.whitelist))
74 |             self.print_config()
75 | 
76 |     def create_or_update_hooks(self, model):
77 |         self.model_qconfig_dict = get_or_update_model_quant_config_dict(model,
78 |                                     self.whitelist, self.mod_qconfig,
79 |                                     model_qconfig_dict=self.model_qconfig_dict,
80 |                                     override=True)
81 | 
82 |         if self.list_exempt_layers is not None :
83 |             for exempt_layer in self.list_exempt_layers:
84 |                 if self.model_qconfig_dict.get(exempt_layer) is not None:
85 |                     self.model_qconfig_dict.pop(exempt_layer)
86 | 
87 |         # 
Disable output quantization for these layers;
88 |         # they are followed by precision-sensitive layers such as SoftMax.
89 |         # In the final implementation, sensitive layers are fused with the preceding layer.
90 |         if self.list_layers_output_fused is not None :
91 |             for name,module in model.named_modules():
92 |                 if name in self.list_layers_output_fused and name not in self.list_exempt_layers \
93 |                 and module in self.whitelist :
94 |                     self.model_qconfig_dict[name].oact_qconfig = None
95 |                     self.model_qconfig_dict[name].ograd_qconfig = None
96 | 
97 |         # additional handling of HW patching
98 |         for name,module in model.named_modules():
99 |             if type(module) in [torch.nn.Conv2d] and name in self.model_qconfig_dict:
100 |                 if module.in_channels < 64 or module.out_channels < 64:
101 |                     self.model_qconfig_dict[name].patch_ops = False
102 |                     self.model_qconfig_dict[name].patch_impl = "NONE"
103 |                     self.list_unpatched += [name]
104 | 
105 |         # Except for Conv2d and Linear modules, disable quantization on weights
106 |         for name,module in model.named_modules():
107 |             if type(module) not in [torch.nn.Conv2d, torch.nn.Linear]\
108 |                 and name in self.model_qconfig_dict:
109 |                 self.model_qconfig_dict[name].wt_qconfig = None
110 |                 self.model_qconfig_dict[name].wtgrad_qconfig = None
111 | 
112 |         for name,module in model.named_modules():
113 |             if ((type(module) == torch.nn.Embedding) or (type(module) == torch.nn.EmbeddingBag))\
114 |                 and name in self.model_qconfig_dict:
115 |                 self.model_qconfig_dict[name].wt_qconfig = self.emb_qconfig
116 |                 self.model_qconfig_dict[name].iact_qconfig = None
117 |                 self.model_qconfig_dict[name].igrad_qconfig = None
118 |                 self.model_qconfig_dict[name].ograd_qconfig = None
119 |                 self.model_qconfig_dict[name].oact_qconfig = None
120 | 
121 |         for name,module in model.named_modules():
122 |             if type(module) in [BatchMatmul]\
123 |                 and name in self.model_qconfig_dict:
124 |                 self.model_qconfig_dict[name].wt_qconfig = None
125 |                 self.model_qconfig_dict[name].wtgrad_qconfig = None
126 |                 self.model_qconfig_dict[name].oact_qconfig = None
127 |                 self.model_qconfig_dict[name].ograd_qconfig = None
128 | 
129 |         reset_quantization_setup(model, self.model_qconfig_dict)
130 |         # Add hooks for quantizing inputs.
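        # (add_quantization_hooks attaches forward/backward hooks that apply each
        # layer's qconfig from model_qconfig_dict -- see qutils; the returned
        # handles are kept so the hooks can be detached later.)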
131 |         self.hook_handles = add_quantization_hooks(model, self.model_qconfig_dict, is_training=self.is_training)
132 |         if not self.is_training :
133 |             print("e3m4 : quantizing model weights..")
134 |             quantize_model_weights(model, self.model_qconfig_dict)
135 |             set_quantize_weights_flag(model, self.model_qconfig_dict, False)
136 | 
137 |     def prepare_model(self, model, list_exempt_layers, list_layers_output_fused):
138 |         mod_qconfig = ModuleQuantConfig(wt_qconfig=self.wt_qconfig,
139 |                                     iact_qconfig=self.iact_qconfig,
140 |                                     oact_qconfig=self.oact_qconfig,
141 |                                     patch_ops=self.patch_ops)
142 |         mod_qconfig.device = self.device
143 |         mod_qconfig.patch_impl = self.patch_impl
144 |         mod_qconfig.sparse_config = self.sparse_config
145 |         self.mod_qconfig = mod_qconfig
146 |         self.list_exempt_layers = list_exempt_layers
147 |         self.list_layers_output_fused = list_layers_output_fused
148 |         self.create_or_update_hooks(model)
149 | 
150 |     def enable_hw_patching(self, patch_ops):
151 |         if patch_ops != 'NONE':
152 |             if patch_ops in self.patchlist :
153 |                 self.patch_ops = True
154 |                 self.patch_impl = patch_ops
155 |                 print("e3m4_emulator: PyTorch Ops are monkey-patched to use {} kernels : {}".format(self.patch_impl, self.patch_ops))
156 |             else :
157 |                 raise RuntimeError("e3m4_emulator: HW patching is not supported for {}, supported list of options : {}".format(patch_ops, self.patchlist))
158 | 
159 |     def set_calibration_qconfig(self):
160 |         self.emb_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
161 |         self.wt_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
162 |         self.iact_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
163 |         self.oact_qconfig = None
164 | 
165 |     def set_default_inference_qconfig(self):
166 |         self.emb_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
167 |         self.wt_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
168 |         self.iact_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
169 |         self.oact_qconfig = None
170 | 
171 |     def fuse_layers_and_quantize_model(self, model):
172 |         if self.is_training :
173 |             print("Warning : emulator.is_training is set to True, returning the model unchanged")
174 |             return model
175 |         if self.verbose :
176 |             print("Fusing Batchnorm layers and replacing them with scale and shift")
177 | 
178 |         model = replace_batchnorms_with_scaleshifts(model)
179 |         self.is_training = False
180 |         self.set_default_inference_qconfig()
181 |         self.prepare_model(model, self.list_exempt_layers, self.list_layers_output_fused)
182 | 
183 |         #reset_quantization_setup(model, self.model_qconfig_dict)
184 |         #add_quantization_hooks(model, self.model_qconfig_dict)
185 |         ##quantize_model_weights(model, self.model_qconfig_dict) # added new
186 |         #set_quantize_weights_flag(model, self.model_qconfig_dict, False)
187 |         model = model.to(self.device)
188 |         if self.verbose :
189 |             self.print_config()
190 | 
191 |         return model
192 | 
193 |     def print_config(self):
194 |         for key in self.model_qconfig_dict:
195 |             print("{} {:40s}".format(self.model_qconfig_dict[key], key))
196 | 
197 |     def __repr__(self):
198 |         train_infer = "inference"
199 |         if self.is_training :
200 |             train_infer = "training"
201 |         return "[Configured to run {} on {}, using AMP: {}]".format(str(train_infer), self.device, str(getattr(self, "using_apex", False)))
202 | 
--------------------------------------------------------------------------------
/mpemu/e4m3_emu.py:
--------------------------------------------------------------------------------
1 | 
#------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | from collections import OrderedDict
11 | import torch
12 | from .qutils import TensorQuantConfig, ModuleQuantConfig
13 | from .qutils import get_or_update_model_quant_config_dict
14 | from .qutils import reset_quantization_setup,add_quantization_hooks
15 | from .qutils import quantize_model_weights,set_quantize_weights_flag
16 | from .scale_shift import replace_batchnorms_with_scaleshifts
17 | from .module_wrappers import BatchMatmul,Matmul,AddMatmul,EltwiseMul,EltwiseAdd,EltwiseDiv
18 | from .module_wrappers import SparseConv2d,SparseLinear
19 | 
20 | '''
21 | E4M3 Mixed Precision Emulator
22 | '''
23 | class E4M3Emulator(object):
24 |     def __init__(self, model, optimizer, sparse_config=None, device="cuda", train=False, verbose=False, tensor_stats=False):
25 |         super(E4M3Emulator, self).__init__()
26 |         self.whitelist = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding, torch.nn.EmbeddingBag]
27 |         self.whitelist += [Matmul, BatchMatmul, AddMatmul]
28 |         self.whitelist += [EltwiseAdd, EltwiseMul, EltwiseDiv]
29 |         self.whitelist += [SparseConv2d, SparseLinear]
30 |         self.blacklist = []
31 |         self.list_unpatched = []
32 |         self.is_training = train
33 |         self.list_exempt_layers = None
34 |         self.list_layers_output_fused = None
35 |         self.device = device
36 |         self.model = model
37 |         self.patch_ops = False
38 |         self.patch_impl = "NONE"
39 |         # supported HW patch implementations
40 |         self.patchlist = ["SIMPLE"]
41 |         self.sparse_config = sparse_config
42 |         # default configuration
43 |         self.data_emulation = True
44 |         self.mod_qconfig = None
45 |         self.model_qconfig_dict = None #OrderedDict()
46 |         self.emb_qconfig = TensorQuantConfig("e4m3", "rne")#, "per-channel")
47 |         self.wt_qconfig = TensorQuantConfig("e4m3", "rne")#, "per-channel")
48 |         self.iact_qconfig = TensorQuantConfig("e4m3", "rne")#, "per-tensor")
49 |         self.oact_qconfig = None #TensorQuantConfig("e4m3", "rne")
50 |         self.hook_handles = None
51 |         self.verbose = verbose
52 | 
53 |     def blacklist_modules(self, list_modules) :
54 |         if self.verbose:
55 |             print("Current module list {}".format(self.whitelist))
56 |             print("Blacklisting modules {}".format(list_modules))
57 |         for module in list_modules :
58 |             self.blacklist.append(module)
59 |             self.whitelist.remove(module)
60 |         self.create_or_update_hooks(self.model)
61 |         if self.verbose:
62 |             print("Updated module list {}".format(self.whitelist))
63 |             self.print_config()
64 | 
65 |     def whitelist_modules(self, list_modules) :
66 |         if self.verbose:
67 |             print("Current module list {}".format(self.whitelist))
68 |             print("Whitelisting modules {}".format(list_modules))
69 |         for module in list_modules :
70 |             self.whitelist.append(module)
71 |             self.blacklist.remove(module)
72 |         self.create_or_update_hooks(self.model)
73 |         if self.verbose:
74 |             print("Updated module list {}".format(self.whitelist))
75 |             self.print_config()
76 | 
77 |     def create_or_update_hooks(self, model):
78 |         self.model_qconfig_dict = get_or_update_model_quant_config_dict(model,
79 |                                     self.whitelist, self.mod_qconfig,
80 |                                     model_qconfig_dict=self.model_qconfig_dict,
81 |                                     override=True)
82 | 
83 |         if 
self.list_exempt_layers is not None :
84 |             for exempt_layer in self.list_exempt_layers:
85 |                 if self.model_qconfig_dict.get(exempt_layer) is not None:
86 |                     self.model_qconfig_dict.pop(exempt_layer)
87 | 
88 |         # Disable output quantization for these layers;
89 |         # they are followed by precision-sensitive layers such as SoftMax.
90 |         # In the final implementation, sensitive layers are fused with the preceding layer.
91 |         if self.list_layers_output_fused is not None :
92 |             for name,module in model.named_modules():
93 |                 if name in self.list_layers_output_fused and name not in self.list_exempt_layers \
94 |                 and module in self.whitelist :
95 |                     self.model_qconfig_dict[name].oact_qconfig = None
96 |                     self.model_qconfig_dict[name].ograd_qconfig = None
97 | 
98 |         # additional handling of HW patching
99 |         for name,module in model.named_modules():
100 |             if type(module) in [torch.nn.Conv2d] and name in self.model_qconfig_dict:
101 |                 if module.in_channels < 64 or module.out_channels < 64:
102 |                     self.model_qconfig_dict[name].patch_ops = False
103 |                     self.model_qconfig_dict[name].patch_impl = "NONE"
104 |                     self.list_unpatched += [name]
105 | 
106 |         # Except for Conv2d and Linear modules, disable quantization on weights
107 |         for name,module in model.named_modules():
108 |             if type(module) not in [torch.nn.Conv2d, torch.nn.Linear]\
109 |                 and name in self.model_qconfig_dict:
110 |                 self.model_qconfig_dict[name].wt_qconfig = None
111 |                 self.model_qconfig_dict[name].wtgrad_qconfig = None
112 | 
113 |         for name,module in model.named_modules():
114 |             if ((type(module) == torch.nn.Embedding) or (type(module) == torch.nn.EmbeddingBag))\
115 |                 and name in self.model_qconfig_dict:
116 |                 self.model_qconfig_dict[name].wt_qconfig = self.emb_qconfig
117 |                 self.model_qconfig_dict[name].iact_qconfig = None
118 |                 self.model_qconfig_dict[name].igrad_qconfig = None
119 |                 self.model_qconfig_dict[name].ograd_qconfig = None
120 |                 self.model_qconfig_dict[name].oact_qconfig = None
121 | 
122 |         for name,module in model.named_modules():
123 |             if type(module) in [BatchMatmul]\
124 |                 and name in self.model_qconfig_dict:
125 |                 self.model_qconfig_dict[name].wt_qconfig = None
126 |                 self.model_qconfig_dict[name].wtgrad_qconfig = None
127 |                 self.model_qconfig_dict[name].oact_qconfig = None
128 |                 self.model_qconfig_dict[name].ograd_qconfig = None
129 | 
130 |         reset_quantization_setup(model, self.model_qconfig_dict)
131 |         # Add hooks for quantizing inputs.
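        # (In inference mode the weights are quantized once below, and the
        # per-step weight re-quantization is then switched off via
        # set_quantize_weights_flag.)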
132 | self.hook_handles = add_quantization_hooks(model, self.model_qconfig_dict, is_training=self.is_training) 133 | if not self.is_training : 134 | print("e4m3 : quantizing model weights..") 135 | quantize_model_weights(model, self.model_qconfig_dict) 136 | set_quantize_weights_flag(model, self.model_qconfig_dict, False) 137 | 138 | def prepare_model(self, model, list_exempt_layers, list_layers_output_fused): 139 | mod_qconfig = ModuleQuantConfig(wt_qconfig=self.wt_qconfig, 140 | iact_qconfig=self.iact_qconfig, 141 | oact_qconfig=self.oact_qconfig, 142 | patch_ops=self.patch_ops) 143 | mod_qconfig.device = self.device 144 | mod_qconfig.patch_impl = self.patch_impl 145 | mod_qconfig.sparse_config = self.sparse_config 146 | self.mod_qconfig = mod_qconfig 147 | self.list_exempt_layers = list_exempt_layers 148 | self.list_layers_output_fused = list_layers_output_fused 149 | self.create_or_update_hooks(model) 150 | 151 | def enable_hw_patching(self, patch_ops): 152 | if patch_ops != 'NONE': 153 | if patch_ops in self.patchlist : 154 | self.patch_ops = True 155 | self.patch_impl = patch_ops 156 | print("e4m3_emulator: PyTorch Ops are monkey-patched to use {} kernels : {}".format(self.patch_impl, self.patch_ops)) 157 | else : 158 | raise RuntimeError("e4m3_emulator: HW patching is not supported for {}, supported list of options : {}".format(patch_ops, self.patchlist)) 159 | 160 | def fuse_batchnorm_with_convolution(self, model): 161 | from torch.nn.utils.fusion import fuse_conv_bn_eval 162 | temp = [] 163 | for name, module in model.named_children(): 164 | if list(module.named_children()): 165 | self.fuse_batchnorm_with_convolution(module) 166 | 167 | if isinstance(module, torch.nn.BatchNorm2d): 168 | if isinstance(temp[-1][1], torch.nn.Conv2d): 169 | setattr(model, temp[-1][0], fuse_conv_bn_eval(temp[-1][1], module)) 170 | setattr(model, name, torch.nn.Identity()) 171 | else: 172 | temp.append((name, module)) 173 | return model 174 | 175 | def set_calibration_qconfig(self): 176 | self.emb_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel") 177 | self.wt_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel") 178 | self.iact_qconfig = TensorQuantConfig("e4m3", "rne", "per-tensor") 179 | self.oact_qconfig = None 180 | 181 | def set_default_inference_qconfig(self): 182 | self.emb_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel") 183 | self.wt_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel") 184 | self.iact_qconfig = TensorQuantConfig("e4m3", "rne", "per-tensor") 185 | self.oact_qconfig = None 186 | 187 | def fuse_layers_and_quantize_model(self, model): 188 | if self.is_training : 189 | print("Warning : emulator.is_training is set to True, returning the model unchanged") 190 | return model 191 | if self.verbose : 192 | print("Fusing Batchnorm layers and replacing them with scale and shift") 193 | 194 | model = replace_batchnorms_with_scaleshifts(model) 195 | #model = self.fuse_batchnorm_with_convolution(model) 196 | self.is_training = False 197 | self.set_default_inference_qconfig() 198 | self.prepare_model(model, self.list_exempt_layers, self.list_layers_output_fused) 199 | #reset_quantization_setup(model, self.model_qconfig_dict) 200 | #add_quantization_hooks(model, self.model_qconfig_dict) 201 | #quantize_model_weights(model, self.model_qconfig_dict) # added new 202 | #set_quantize_weights_flag(model, self.model_qconfig_dict, False) 203 | model = model.to(self.device) 204 | if self.verbose : 205 | self.print_config() 206 | 207 | return model 208 | 209 | def 
print_config(self):
210 |         for key in self.model_qconfig_dict:
211 |             print("{} {:40s}".format(self.model_qconfig_dict[key], key))
212 | 
213 |     def __repr__(self):
214 |         train_infer = "inference"
215 |         if self.is_training :
216 |             train_infer = "training"
217 |         return "[Configured to run {} on {}, using AMP: {}]".format(str(train_infer), self.device, str(getattr(self, "using_apex", False)))
218 | 
--------------------------------------------------------------------------------
/mpemu/module_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | INTEL CONFIDENTIAL
3 | Copyright (C) 2020 Intel Corporation
4 | This software and the related documents are Intel copyrighted materials, and your
5 | use of them is governed by the express license under which they were provided
6 | to you ("License"). Unless the License provides otherwise, you may not use, modify,
7 | copy, publish, distribute, disclose or transmit this software or the related
8 | documents without Intel's prior written permission. This software and the related
9 | documents are provided as is, with no express or implied warranties, other than
10 | those that are expressly stated in the License.
11 | '''
12 | '''
13 | Author(s) : Dharma Teja Vooturi, Naveen Mellempudi
14 | '''
15 | from .eltwise import EltwiseAdd, EltwiseMul, EltwiseDiv
16 | from .matmul import Matmul, BatchMatmul, AddMatmul
17 | from .aggregate import Norm, Mean
18 | from .adasparse import SparseLinear, SparseConv2d
--------------------------------------------------------------------------------
/mpemu/module_wrappers/adasparse.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import math
11 | import torch
12 | import torch.nn as nn
13 | 
14 | """
15 | Step function for weight-mask binarization
16 | """
17 | class WeightMaskStep(torch.autograd.Function):
18 |     @staticmethod
19 |     def forward(ctx, input):
20 |         ctx.save_for_backward(input)
21 |         return (input>0.).to(input.dtype)
22 | 
23 |     @staticmethod
24 |     def backward(ctx, grad_output):
25 |         input, = ctx.saved_tensors
26 |         grad_input = grad_output.clone()
27 |         zero_index = torch.abs(input) > 1
28 |         middle_index = (torch.abs(input) <= 1) * (torch.abs(input) > 0.4)
29 |         additional = 2-4*torch.abs(input)
30 |         additional[zero_index] = 0.
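        # piecewise surrogate gradient: 2-4|x| for |x| <= 0.4, a flat 0.4 for
        # 0.4 < |x| <= 1, and 0 once |x| > 1 (a straight-through-style estimator
        # for the step function above)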
31 | additional[middle_index] = 0.4 32 | return grad_input*additional 33 | 34 | class SparseLinear(nn.Module): 35 | def __init__(self, in_features, out_features, bias=True): 36 | super(SparseLinear, self).__init__() 37 | self.in_features = in_features 38 | self.out_features = out_features 39 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 40 | if bias: 41 | self.bias = nn.Parameter(torch.Tensor(out_features)) 42 | else: 43 | self.register_parameter('bias', None) 44 | 45 | self.threshold = nn.Parameter(torch.Tensor(out_features)) 46 | self.weight_mask = WeightMaskStep.apply 47 | #self.mask = None 48 | self.reset_parameters() 49 | 50 | 51 | def reset_parameters(self): 52 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 53 | if self.bias is not None: 54 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 55 | bound = 1 / math.sqrt(fan_in) 56 | nn.init.uniform_(self.bias, -bound, bound) 57 | with torch.no_grad(): 58 | #std = self.weight.std() 59 | self.threshold.data.fill_(0) 60 | 61 | def forward(self, input): 62 | abs_weight = torch.abs(self.weight) 63 | threshold = self.threshold.view(abs_weight.shape[0], -1) 64 | abs_weight = abs_weight-threshold 65 | mask = self.weight_mask(abs_weight) 66 | ratio = torch.sum(mask) / mask.numel() 67 | #print("keep ratio {:.2f}".format(ratio)) 68 | if ratio <= 0.01: 69 | with torch.no_grad(): 70 | #std = self.weight.std() 71 | self.threshold.data.fill_(0) 72 | abs_weight = torch.abs(self.weight) 73 | threshold = self.threshold.view(abs_weight.shape[0], -1) 74 | abs_weight = abs_weight-threshold 75 | mask = self.weight_mask(abs_weight) 76 | masked_weight = self.weight * mask 77 | output = torch.nn.functional.linear(input, masked_weight, self.bias) 78 | return output 79 | def extra_repr(self) -> str: 80 | return 'in_features={}, out_features={}, bias={}'.format( 81 | self.in_features, self.out_features, self.bias is not None 82 | ) 83 | 84 | class SparseConv2d(nn.Module): 85 | def __init__(self, in_c, out_c, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode="zeros"): 86 | super(SparseConv2d, self).__init__() 87 | self.in_channels = in_c 88 | self.out_channels = out_c 89 | self.kernel_size = kernel_size 90 | self.stride = stride 91 | self.padding = padding 92 | self.dilation = dilation 93 | self.groups = groups 94 | self.padding_mode = padding_mode 95 | 96 | ## define weight 97 | self.weight = nn.Parameter(torch.Tensor( 98 | out_c, in_c // groups, *kernel_size 99 | )) 100 | if bias: 101 | self.bias = nn.Parameter(torch.Tensor(out_c)) 102 | else: 103 | self.register_parameter('bias', None) 104 | self.threshold = nn.Parameter(torch.Tensor(out_c)) 105 | self.weight_mask = WeightMaskStep.apply 106 | self.reset_parameters() 107 | 108 | def reset_parameters(self): 109 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 110 | if self.bias is not None: 111 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 112 | bound = 1 / math.sqrt(fan_in) 113 | nn.init.uniform_(self.bias, -bound, bound) 114 | with torch.no_grad(): 115 | self.threshold.data.fill_(0.) 
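            # threshold starts at zero, so the initial mask keeps every weight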
116 | 117 | def forward(self, x): 118 | weight_shape = self.weight.shape 119 | threshold = self.threshold.view(weight_shape[0], -1) 120 | weight = torch.abs(self.weight) 121 | weight = weight.view(weight_shape[0], -1) 122 | weight = weight - threshold 123 | mask = self.weight_mask(weight) 124 | mask = mask.view(weight_shape) 125 | ratio = torch.sum(mask) / mask.numel() 126 | # print("threshold {:3f}".format(self.threshold[0])) 127 | # print("keep ratio {:.2f}".format(ratio)) 128 | if ratio <= 0.01: 129 | with torch.no_grad(): 130 | self.threshold.data.fill_(0.) 131 | threshold = self.threshold.view(weight_shape[0], -1) 132 | weight = torch.abs(self.weight) 133 | weight = weight.view(weight_shape[0], -1) 134 | weight = weight - threshold 135 | mask = self.weight_mask(weight) 136 | mask = mask.view(weight_shape) 137 | masked_weight = self.weight * mask 138 | 139 | conv_out = torch.nn.functional.conv2d(x, masked_weight, bias=self.bias, stride=self.stride, 140 | padding=self.padding, dilation=self.dilation, groups=self.groups) 141 | return conv_out 142 | 143 | def extra_repr(self): 144 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 145 | ', stride={stride}') 146 | if self.padding != (0,) * len(self.padding): 147 | s += ', padding={padding}' 148 | if self.dilation != (1,) * len(self.dilation): 149 | s += ', dilation={dilation}' 150 | if self.groups != 1: 151 | s += ', groups={groups}' 152 | if self.bias is None: 153 | s += ', bias=False' 154 | if self.padding_mode != 'zeros': 155 | s += ', padding_mode={padding_mode}' 156 | return s.format(**self.__dict__) 157 | 158 | def __setstate__(self, state): 159 | super(SparseConv2d, self).__setstate__(state) 160 | if not hasattr(self, 'padding_mode'): 161 | self.padding_mode = 'zeros' 162 | -------------------------------------------------------------------------------- /mpemu/module_wrappers/aggregate.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | import torch.nn as nn
12 | 
13 | class Norm(nn.Module):
14 |     def __init__(self, p='fro', dim=None, keepdim=False):
15 |         super(Norm, self).__init__()
16 |         self.p = p
17 |         self.dim = dim
18 |         self.keepdim = keepdim
19 | 
20 |     def forward(self, x: torch.Tensor):
21 |         return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
22 |     def extra_repr(self) -> str:
23 |         return 'p={}, dim={}, keepdim: {}'.format(self.p, self.dim, self.keepdim)
24 | 
25 | class Mean(nn.Module):
26 |     def __init__(self, *args, **kwargs):
27 |         super(Mean, self).__init__()
28 |         self.args = args
29 |         self.kwargs = kwargs
30 | 
31 |     def forward(self, x: torch.Tensor):
32 |         return torch.mean(x, *self.args, **self.kwargs)
33 |     def extra_repr(self) -> str:
34 |         return 'args={}, kwargs={}'.format(self.args, self.kwargs)
35 | 
36 | 
--------------------------------------------------------------------------------
/mpemu/module_wrappers/eltwise.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | import torch.nn as nn
12 | 
13 | class EltwiseAdd(nn.Module):
14 |     def __init__(self, inplace=False):
15 |         super(EltwiseAdd, self).__init__()
16 | 
17 |         self.inplace = inplace
18 | 
19 |     def forward(self, *input):
20 |         res = input[0]
21 |         if self.inplace:
22 |             for t in input[1:]:
23 |                 res += t
24 |         else:
25 |             for t in input[1:]:
26 |                 res = res + t
27 |         return res
28 | 
29 |     def extra_repr(self) -> str:
30 |         return 'inplace={}'.format(self.inplace)
31 | 
32 | class EltwiseMul(nn.Module):
33 |     def __init__(self, inplace=False):
34 |         super(EltwiseMul, self).__init__()
35 |         self.inplace = inplace
36 | 
37 |     def forward(self, *input):
38 |         res = input[0]
39 |         if self.inplace:
40 |             for t in input[1:]:
41 |                 res *= t
42 |         else:
43 |             for t in input[1:]:
44 |                 res = res * t
45 |         return res
46 |     def extra_repr(self) -> str:
47 |         return 'inplace={}'.format(self.inplace)
48 | 
49 | 
50 | class EltwiseDiv(nn.Module):
51 |     def __init__(self, inplace=False):
52 |         super(EltwiseDiv, self).__init__()
53 |         self.inplace = inplace
54 | 
55 |     def forward(self, x: torch.Tensor, y):
56 |         if self.inplace:
57 |             return x.div_(y)
58 |         return x.div(y)
59 |     def extra_repr(self) -> str:
60 |         return 'inplace={}'.format(self.inplace)
61 | 
62 | 
--------------------------------------------------------------------------------
/mpemu/module_wrappers/matmul.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | import torch.nn as nn
12 | 
13 | class Matmul(nn.Module):
14 |     def __init__(self):
15 |         super(Matmul, self).__init__()
16 | 
17 |     def forward(self, a: torch.Tensor, b: torch.Tensor):
18 |         return torch.matmul(a, b)
19 | 
20 | class BatchMatmul(nn.Module):
21 |     def __init__(self):
22 |         super(BatchMatmul, self).__init__()
23 | 
24 |     def forward(self, a: torch.Tensor, b:torch.Tensor):
25 |         return torch.bmm(a, b)
26 | 
27 | class AddMatmul(nn.Module):
28 |     def __init__(self):
29 |         super(AddMatmul, self).__init__()
30 | 
31 |     def forward(self, input:torch.Tensor, mat1: torch.Tensor, mat2:torch.Tensor, beta=1, alpha=1):
32 |         return torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
33 | 
--------------------------------------------------------------------------------
/mpemu/mpt_emu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | from collections import OrderedDict
11 | import torch
12 | import os
13 | from .qutils import TensorQuantConfig, ModuleQuantConfig
14 | from .qutils import get_or_update_model_quant_config_dict
15 | from .qutils import reset_quantization_setup,add_quantization_hooks
16 | from .qutils import quantize_model_weights,set_quantize_weights_flag
17 | from .e5m2_emu import E5M2Emulator
18 | from .e4m3_emu import E4M3Emulator
19 | from .e3m4_emu import E3M4Emulator
20 | from .hybrid_emu import HybridEmulator
21 | from .bfloat16_emu import Bfloat16Emulator
22 | from .scale_shift import replace_batchnorms_with_scaleshifts
23 | from .module_wrappers import SparseLinear, SparseConv2d
24 | from .sparse_utils import SparseConfig
25 | 
26 | '''
27 | Mixed Precision Emulator
28 | '''
29 | class MPTEmulator(object):
30 |     def __init__(self, device='cuda', training_algo='direct', hw_patch='none', pruning_algo='none'):
31 |         super(MPTEmulator, self).__init__()
32 |         #self.valid_dtypes = ["fp32", "fp16", "bf16", "e5m2", "e4m3", "e3m4"]
33 |         self.valid_training_methods = ["direct", "hybrid"]
34 |         self.valid_hw_patches = ["none", "simple"]
35 |         self.valid_pruning_methods = ["none", "fine-grained", "unstructured", "adaptive", "auto"]
36 |         # basic checks
37 |         #if dtype.lower() not in self.valid_dtypes:
38 |         #    raise RuntimeError("mpt_emulator: the requested data type {} is not supported, use one of the following: {}"
39 |         #                       .format(dtype, self.valid_dtypes))
40 |         if training_algo.lower() not in self.valid_training_methods:
41 |             raise RuntimeError("mpt_emulator: the requested training method {} is not supported, use one of the following: {}"
42 |                                .format(training_algo, self.valid_training_methods))
43 |         if hw_patch.lower() not in self.valid_hw_patches:
44 |             raise RuntimeError("mpt_emulator: the requested hardware emulation {} is not supported, use one of the 
following: {}" 45 | .format(hw_patch, self.valid_hw_patches)) 46 | if pruning_algo.lower() not in self.valid_pruning_methods: 47 | raise RuntimeError("mpt_emulator: the requested pruning method {} is not supported, use one of the following: {}" 48 | .format(pruning_algo.lower(), self.valid_pruning_methods)) 49 | 50 | self.device = device 51 | self.training_algo = training_algo.lower() 52 | self.hw_patch = hw_patch.lower() 53 | self.pruning_method = pruning_algo.lower() 54 | self.wt_sparsity = 0.5 55 | self.grad_sparsity = 0.5 56 | 57 | # emiulator object 58 | self.emulator = None 59 | self.sparse_config = None 60 | 61 | def blacklist_modules(self, list_modules): 62 | self.emulator.blacklist_modules(list_modules) 63 | 64 | def whitelist_modules(self, list_modules): 65 | self.emulator.whitelist_modules(list_modules) 66 | 67 | def optimizer_step(self, optimizer): 68 | self.emulator.optimizer_step(optimizer) 69 | if self.sparse_config != None: 70 | self.sparse_config.optimizer_step(optimizer) 71 | 72 | def update_global_steps(self, global_steps): 73 | self.emulator.global_steps = global_steps 74 | 75 | def set_master_param_precision(self, master_params): 76 | self.emulator.set_master_param_precision(master_params) 77 | 78 | def set_embedding_precision(self, emb_precision, emb_norm=False): 79 | self.emulator.set_embedding_precision(emb_precision, emb_norm) 80 | 81 | def enable_tensor_stats(self, summary_writer=None): 82 | self.emulator.enable_tensor_stats(summary_writer) 83 | 84 | def set_tensor_bindump_schedule(self, list_bindump_schedule): 85 | self.emulator.set_tensor_bindump_schedule(list_bindump_schedule) 86 | 87 | def enable_tensorboard_logging(self, summary_writer=None): 88 | self.emulator.enable_tensor_stats(summary_writer) 89 | 90 | def set_target_sparsity_weight(self, wt_sparsity=0.5): 91 | if self.sparse_config != None: 92 | self.wt_sparsity = float(wt_sparsity) 93 | self.sparse_config.weight_factor = self.wt_sparsity 94 | print("mpt-emulator: weight sparsity has been set to : {}".format(self.sparse_config.weight_factor)) 95 | else: 96 | print("mpt-emulator: set_target_sparsity_weight has no effect; sparse training is not enabled") 97 | 98 | def set_target_sparsity_gradient(self, grad_sparsity=0.5): 99 | if self.sparse_config != None: 100 | self.grad_sparsity = float(grad_sparsity) 101 | self.sparse_config.outgrad_factor = self.grad_sparsity 102 | print("mpt-emulator: gradient sparsity has been set to : {}".format(self.sparse_config.outgrad_factor)) 103 | else: 104 | print("mpt-emulator: set_target_sparsity_gradient has no effect; sparse training is not enabled") 105 | 106 | def fuse_bnlayers_and_quantize_model(self, model): 107 | self.emulator.is_training = False 108 | return self.emulator.fuse_layers_and_quantize_model(model) 109 | 110 | def set_calibration_qconfig(self): 111 | self.emulator.set_calibration_qconfig() 112 | 113 | def set_default_inference_qconfig(self): 114 | self.emulator.set_default_inference_qconfig() 115 | 116 | def __repr__(self): 117 | train_infer = "inference" 118 | if self.is_training : 119 | train_infer = "training" 120 | return "[Configured to run {} on {}, using AMP: {}]".format(str(train_infer), self.device, str(self.using_apex)) 121 | 122 | def rewrite_model_with_adasparse_ops(model, exempt_layers): 123 | list_exempt_layers = exempt_layers.copy() 124 | if isinstance(model, torch.nn.Conv2d): 125 | return SparseConv2d(model.in_channels, model.out_channels, model.kernel_size, 126 | stride=model.stride, padding=model.padding, dilation=model.dilation, 
groups=model.groups,
127 |                             bias=model.bias is not None).to(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
128 |     elif isinstance(model, torch.nn.Linear):
129 |         return SparseLinear(model.in_features, model.out_features,
130 |                             bias=model.bias is not None).to(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
131 | 
132 |     if list_exempt_layers is None:
133 |         for name, module in model.named_children():
134 |             module = rewrite_model_with_adasparse_ops(module, None)
135 |             setattr(model, name, module)
136 |     else :
137 |         for name, module in model.named_children():
138 |             if name in list_exempt_layers:
139 |                 setattr(model, name, module)
140 |                 list_exempt_layers.remove(name)
141 |             else:
142 |                 module = rewrite_model_with_adasparse_ops(module, list_exempt_layers)
143 |                 setattr(model, name, module)
144 |     return model
145 | 
146 | def initialize(model, optimizer, training_algo='direct', hw_patch='none', pruning_algo='none',
147 |                list_exempt_layers=None, list_layers_output_fused=None, device="cuda", verbose=False ):
148 |     if model is None: #or optimizer is None:
149 |         raise RuntimeError("mpt_emulator: Undefined model and optimizer; call this after the model and optimizer are initialized.")
150 |     if device == 'cuda' and hw_patch.lower() != 'none':
151 |         raise RuntimeError("mpt_emulator: HW patching ops is only allowed on the 'cpu' device.")
152 | 
153 |     mpt = MPTEmulator(device=device, training_algo=training_algo, hw_patch=hw_patch, pruning_algo=pruning_algo)
154 | 
155 |     if mpt.pruning_method == 'adaptive':
156 |         model = rewrite_model_with_adasparse_ops(model, list_exempt_layers)
157 |         print("mpt_emulator: Adaptive pruning method enabled; Adaptive (weights) only!")
158 |     elif mpt.pruning_method == 'unstructured':
159 |         mpt.sparse_config = SparseConfig()
160 |         mpt.sparse_config.weight = True
161 |         mpt.sparse_config.outgrad = True
162 |         mpt.sparse_config.weight_factor = mpt.wt_sparsity
163 |         mpt.sparse_config.outgrad_factor = mpt.grad_sparsity
164 |         mpt.sparse_config.print_stats = False
165 |         print("mpt_emulator: Unstructured pruning method enabled; TopK(weights : {}), Stochastic(gradients : {})."
166 |               .format(mpt.sparse_config.weight_factor, mpt.sparse_config.outgrad_factor))
167 |     elif mpt.pruning_method == 'auto':
168 |         model = rewrite_model_with_adasparse_ops(model, list_exempt_layers)
169 |         mpt.sparse_config = SparseConfig()
170 |         mpt.sparse_config.outgrad = True
171 |         mpt.sparse_config.outgrad_factor = mpt.grad_sparsity
172 |         mpt.sparse_config.print_stats = False
173 |         print("mpt_emulator: Auto pruning method enabled; Adaptive(weights), Stochastic(gradients : {})."
174 |               .format(mpt.sparse_config.outgrad_factor))
175 | 
176 |     if mpt.training_algo == 'hybrid':
177 |         mpt.emulator = HybridEmulator(model, optimizer, mpt.sparse_config, device=device, verbose=verbose)
178 |         mpt.emulator.set_master_param_precision("fp16")
179 |     else: # direct
180 |         mpt.emulator = E5M2Emulator(model, optimizer, mpt.sparse_config, device=device, verbose=verbose)
181 |         mpt.emulator.set_master_param_precision("fp16")
182 | 
183 |     mpt.emulator.enable_hw_patching(mpt.hw_patch.upper())
184 |     mpt.emulator.prepare_model(model, list_exempt_layers, list_layers_output_fused)
185 |     if mpt.emulator.patch_ops == True and len(mpt.emulator.list_unpatched):
186 |         print("mpt_emulator: The following layers are not HW-patched because their dimensions do not match the hardware : {} ".format(mpt.emulator.list_unpatched))
187 | 
188 |     if verbose :
189 |         mpt.emulator.print_config()
190 | 
191 |     return model, mpt
192 | 
193 | def quantize_model(model, optimizer=None, dtype="none", calibrate=False, hw_patch="none", fuse_bn=False,
194 |                    list_exempt_layers=None, list_layers_output_fused=None, device="cuda", verbose=False ):
195 |     if model is None :
196 |         raise RuntimeError("mpt_emulator: Undefined model; call this after the model is initialized.")
197 |     if dtype == 'fp16' and device != 'cuda':
198 |         raise RuntimeError("mpt_emulator: the requested data type {} is not supported on {}.".format(dtype, device))
199 |     if device == 'cuda' and hw_patch.lower() != 'none':
200 |         raise RuntimeError("mpt_emulator: HW patching ops is only allowed on the 'cpu' device.")
201 | 
202 |     mpt = MPTEmulator(device=device, hw_patch=hw_patch)
203 | 
204 |     if dtype.upper() == 'E5M2':
205 |         mpt.emulator = E5M2Emulator(model, optimizer, None, device=device, verbose=verbose)
206 |     elif dtype.upper() == 'E4M3':
207 |         mpt.emulator = E4M3Emulator(model, optimizer, None, device=device, verbose=verbose)
208 |     elif dtype.upper() == 'E3M4':
209 |         mpt.emulator = E3M4Emulator(model, optimizer, None, device=device, verbose=verbose)
210 |     elif dtype.upper() == 'HYBRID':
211 |         mpt.emulator = HybridEmulator(model, optimizer, None, device=device, verbose=verbose)
212 | 
213 |     if fuse_bn :
214 |         model = mpt.fuse_bnlayers_and_quantize_model(model)  # requires the emulator created above
215 | 
216 |     if calibrate:
217 |         print("mpt_emulator : preparing model for calibration")
218 |         mpt.emulator.is_training = False
219 |         mpt.set_calibration_qconfig()
220 |     else:
221 |         mpt.emulator.is_training = False
222 |         mpt.set_default_inference_qconfig()
223 | 
224 |     mpt.emulator.enable_hw_patching(mpt.hw_patch.upper())
225 |     mpt.emulator.prepare_model(model, list_exempt_layers, list_layers_output_fused)
226 |     if mpt.emulator.patch_ops == True and len(mpt.emulator.list_unpatched):
227 |         print("mpt_emulator: The following layers are not HW-patched because they cannot use the hardware configuration: {} ".format(mpt.emulator.list_unpatched))
228 | 
229 |     if verbose :
230 |         mpt.emulator.print_config()
231 | 
232 |     return model, mpt
233 | 
--------------------------------------------------------------------------------
/mpemu/pytquant/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import torch
3 | 
4 | from . import cpp
5 | from .cpp import fpemu_cpp
6 | 
7 | if torch.cuda.is_available() and torch.version.cuda:
8 |     from . import cuda
9 |     from .cuda import fpemu_cuda as fpemu_cuda
10 | 
11 | if torch.cuda.is_available() and torch.version.hip:
12 |     from . 
import hip 13 | from .hip import fpemu_hip as fpemu_cuda 14 | -------------------------------------------------------------------------------- /mpemu/pytquant/cpp/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fpemu as fpemu_cpp 2 | -------------------------------------------------------------------------------- /mpemu/pytquant/cpp/fpemu.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import math 11 | from torch import nn 12 | from torch.autograd import Function 13 | import torch 14 | import numpy 15 | import fpemu_cpp 16 | 17 | from enum import Enum 18 | 19 | torch.manual_seed(42) 20 | 21 | """ 22 | NONE 23 | E5M2_RTZ 24 | E5M2_STOCHASTIC 25 | E5M2_RNE 26 | E5M2_RNAZ 27 | E5M2_RNTZ 28 | E5M2_RPINF 29 | E5M2_RNINF 30 | E5M2_DAZ_STOCHASTIC 31 | E5M2_DAZ_RNE 32 | E5M2_DAZ_RNAZ 33 | E5M2_DAZ_RNTZ 34 | BFLOAT16_STOCHASTIC 35 | BFLOAT16_RNE 36 | FLOAT16_RNE 37 | FLOAT16_STOCHASTIC 38 | FLOAT16_DAZ_RNE 39 | E4M3_RNE 40 | E4M3_STOCHASTIC 41 | """ 42 | 43 | class FPEmuOp(Function): 44 | @staticmethod 45 | def forward(ctx, input, mode='NONE', inplace=False, scale=1.0, blocknorm=False, blocksize=1): 46 | if mode == 'NONE' : 47 | ctx.mark_dirty(input) 48 | return input 49 | else : 50 | if input.is_sparse : 51 | input = input.coalesce() 52 | size = input.values().nelement() 53 | if inplace == True: 54 | outputs = fpemu_cpp.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize) 55 | output = input 56 | else : 57 | outputs = fpemu_cpp.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize) 58 | output = torch.sparse.FloatTensor(input.indices(), outputs[0], input.size()) 59 | else : 60 | size = input.nelement() 61 | outputs = fpemu_cpp.forward(input.contiguous(), mode, size, inplace, scale, blocknorm, blocksize) 62 | output = outputs[0] 63 | 64 | if inplace == True: 65 | ctx.mark_dirty(input) 66 | return output 67 | 68 | @staticmethod 69 | def backward(ctx, output_grad): 70 | # straight-through estimator 71 | return output_grad, None, None, None, None 72 | -------------------------------------------------------------------------------- /mpemu/pytquant/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | if torch.cuda.is_available() and torch.version.cuda: 3 | from . import fpemu as fpemu_cuda 4 | -------------------------------------------------------------------------------- /mpemu/pytquant/cuda/fpemu.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import math
11 | from torch import nn
12 | from torch.autograd import Function
13 | import torch
14 | import numpy
15 | import fpemu_cuda
16 | 
17 | from enum import Enum
18 | 
19 | torch.manual_seed(42)
20 | 
21 | """
22 | NONE
23 | E5M2_RTZ
24 | E5M2_STOCHASTIC
25 | E5M2_RNE
26 | E5M2_RNAZ
27 | E5M2_RNTZ
28 | E5M2_RPINF
29 | E5M2_RNINF
30 | E5M2_DAZ_STOCHASTIC
31 | E5M2_DAZ_RNE
32 | E5M2_DAZ_RNAZ
33 | E5M2_DAZ_RNTZ
34 | BFLOAT16_STOCHASTIC
35 | BFLOAT16_RNE
36 | FLOAT16_RNE
37 | FLOAT16_STOCHASTIC
38 | FLOAT16_DAZ_RNE
39 | E4M3_RNE
40 | E4M3_STOCHASTIC
41 | """
42 | 
43 | class FPEmuOp(Function):
44 |     @staticmethod
45 |     def forward(ctx, input, mode='NONE', inplace=False, scale=1.0, blocknorm=False, blocksize=1):
46 |         if mode == 'NONE' :
47 |             ctx.mark_dirty(input)
48 |             return input
49 |         else :
50 |             if input.is_sparse :
51 |                 input = input.coalesce()
52 |                 size = input.values().nelement()
53 |                 if inplace == True:
54 |                     outputs = fpemu_cuda.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
55 |                     output = input
56 |                 else :
57 |                     outputs = fpemu_cuda.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
58 |                     output = torch.sparse.FloatTensor(input.indices(), outputs[0], input.size())
59 |             else :
60 |                 size = input.nelement()
61 |                 outputs = fpemu_cuda.forward(input.contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
62 |                 output = outputs[0]
63 | 
64 |             if inplace == True:
65 |                 ctx.mark_dirty(input)
66 |             return output
67 | 
68 |     @staticmethod
69 |     def backward(ctx, output_grad):
70 |         # straight-through estimator
71 |         return output_grad, None, None, None, None
72 | 
--------------------------------------------------------------------------------
/mpemu/pytquant/cuda/fpemu_impl.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 |  * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #include <torch/extension.h>
11 | #include <vector>
12 | 
13 | std::vector<torch::Tensor> fpemu_cuda_forward(
14 |     torch::Tensor input,
15 |     std::string mode,
16 |     int size,
17 |     bool inplace,
18 |     float scale,
19 |     bool block_norm,
20 |     int block_size);
21 | 
22 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
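// The CHECK_* macros below validate every tensor coming in from Python:
// it must live on the GPU and be contiguous before the kernels touch it.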
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 | 
27 | std::vector<torch::Tensor> fpemu_forward(
28 |     torch::Tensor input,
29 |     std::string mode,
30 |     int size,
31 |     bool inplace,
32 |     float scale,
33 |     bool block_norm,
34 |     int block_size) {
35 |   CHECK_INPUT(input);
36 |   return fpemu_cuda_forward(input, mode, size, inplace, scale, block_norm, block_size);
37 | }
38 | 
39 | std::vector<torch::Tensor> fpemu_backward(
40 |     torch::Tensor grad,
41 |     std::string mode,
42 |     int size,
43 |     bool inplace,
44 |     float scale,
45 |     bool block_norm,
46 |     int block_size) {
47 |   CHECK_INPUT(grad);
48 |   return fpemu_cuda_forward(grad, mode, size, inplace, scale, block_norm, block_size);
49 | }
50 | 
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 |   m.def("forward", &fpemu_forward, "FPEmu forward (CUDA)");
53 |   m.def("backward", &fpemu_backward, "FPEmu backward (CUDA)");
54 | }
55 | 
--------------------------------------------------------------------------------
/mpemu/pytquant/hip/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | if torch.cuda.is_available() and torch.version.hip:
3 |     from . import fpemu as fpemu_hip
4 | 
--------------------------------------------------------------------------------
/mpemu/pytquant/hip/fpemu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import math
11 | from torch import nn
12 | from torch.autograd import Function
13 | import torch
14 | import numpy
15 | import fpemu_hip
16 | 
17 | from enum import Enum
18 | 
19 | torch.manual_seed(42)
20 | 
21 | """
22 | NONE
23 | E5M2_RTZ
24 | E5M2_STOCHASTIC
25 | E5M2_RNE
26 | E5M2_RNAZ
27 | E5M2_RNTZ
28 | E5M2_RPINF
29 | E5M2_RNINF
30 | E5M2_DAZ_STOCHASTIC
31 | E5M2_DAZ_RNE
32 | E5M2_DAZ_RNAZ
33 | E5M2_DAZ_RNTZ
34 | BFLOAT16_STOCHASTIC
35 | BFLOAT16_RNE
36 | FLOAT16_RNE
37 | FLOAT16_STOCHASTIC
38 | FLOAT16_DAZ_RNE
39 | E4M3_RNE
40 | E4M3_STOCHASTIC
41 | """
42 | 
43 | class FPEmuOp(Function):
44 |     @staticmethod
45 |     def forward(ctx, input, mode='NONE', inplace=False, scale=1.0, blocknorm=False, blocksize=1):
46 |         if mode == 'NONE' :
47 |             ctx.mark_dirty(input)
48 |             return input
49 |         else :
50 |             if input.is_sparse :
51 |                 input = input.coalesce()
52 |                 size = input.values().nelement()
53 |                 if inplace == True:
54 |                     outputs = fpemu_hip.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
55 |                     output = input
56 |                 else :
57 |                     outputs = fpemu_hip.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
58 |                     output = torch.sparse.FloatTensor(input.indices(), outputs[0], input.size())
59 |             else :
60 |                 size = input.nelement()
61 |                 outputs = fpemu_hip.forward(input.contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
62 |                 output = outputs[0]
63 | 
64 |             if inplace == True:
65 |                 ctx.mark_dirty(input)
66 |             return output
67 | 
68 |     @staticmethod
69 |     def 
70 |         # straight-through estimator: pass the gradient through unchanged;
71 |         # one None per non-tensor forward argument (mode, inplace, scale, blocknorm, blocksize)
72 |         return output_grad, None, None, None, None, None
--------------------------------------------------------------------------------
/mpemu/pytquant/hip/fpemu_impl.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 |  * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 |  * This file is part of FP8-Emulation-Toolkit
4 |  *
5 |  * SPDX-License-Identifier: BSD-3-Clause
6 |  *----------------------------------------------------------------------------*
7 |  * Naveen Mellempudi (Intel Corporation)
8 |  *----------------------------------------------------------------------------*/
9 | 
10 | #include <torch/extension.h>
11 | #include <vector>
12 | 
13 | std::vector<torch::Tensor> fpemu_cuda_forward(
14 |     torch::Tensor input,
15 |     std::string mode,
16 |     int size,
17 |     bool inplace,
18 |     float scale,
19 |     bool block_norm,
20 |     int block_size);
21 | 
22 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 | 
27 | std::vector<torch::Tensor> fpemu_forward(
28 |     torch::Tensor input,
29 |     std::string mode,
30 |     int size,
31 |     bool inplace,
32 |     float scale,
33 |     bool block_norm,
34 |     int block_size) {
35 |   CHECK_INPUT(input);
36 |   return fpemu_cuda_forward(input, mode, size, inplace, scale, block_norm, block_size);
37 | }
38 | 
39 | std::vector<torch::Tensor> fpemu_backward(
40 |     torch::Tensor grad,
41 |     std::string mode,
42 |     int size,
43 |     bool inplace,
44 |     float scale,
45 |     bool block_norm,
46 |     int block_size) {
47 |   CHECK_INPUT(grad);
48 |   // backward applies the same quantization kernel to the gradient
49 |   return fpemu_cuda_forward(grad, mode, size, inplace, scale, block_norm, block_size);
50 | }
51 | 
52 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
53 |   m.def("forward", &fpemu_forward, "FPEmu forward (HIP)");
54 |   m.def("backward", &fpemu_backward, "FPEmu backward (HIP)");
55 | }
--------------------------------------------------------------------------------
/mpemu/pytquant/test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | from __future__ import division
11 | from __future__ import print_function
12 | 
13 | import argparse
14 | import sys
15 | import numpy as np
16 | import torch
17 | import cpp.fpemu as fpemu_cpp
18 | 
19 | # make the repository root importable so that the mpemu package resolves
20 | sys.path.append('../..')
21 | import mpemu
22 | from mpemu.qutils import fpemu_device_fn
23 | 
24 | 
25 | if torch.cuda.is_available() and torch.version.cuda:
26 |     import cuda.fpemu as fpemu_cuda
27 | if torch.cuda.is_available() and torch.version.hip:
28 |     import hip.fpemu as fpemu_hip
29 | 
30 | def check_equal(first, second, verbose):
31 |     if verbose:
32 |         print()
33 |     for i, (x, y) in enumerate(zip(first, second)):
34 |         x = x.cpu().detach().numpy()
35 |         y = y.cpu().detach().numpy()
36 | 
37 |         if verbose:
38 |             print("x = {}".format(x.flatten()))
39 |             print("y = {}".format(y.flatten()))
40 |             print('-' * 80)
41 | 
42 |         np.testing.assert_allclose(x, y, rtol=0.125, err_msg="Conflict : ", verbose=verbose)
43 | 
44 | def print_tensor(first, second, verbose, sparse):
45 |     print('printing ...')
46 |     if verbose:
47 |         print()
48 |     for i, (x, y) in enumerate(zip(first, second)):
49 |         if sparse:
50 |             x = x.cpu().to_dense().detach().numpy()
51 |             y = y.cpu().to_dense().detach().numpy()
52 |         else:
53 |             x = x.cpu().detach().numpy()
54 |             y = y.cpu().detach().numpy()
55 |         if verbose:
56 |             print("x = {}".format(x.flatten()))
57 |             print("y = {}".format(y.flatten()))
58 |             print('-' * 80)
59 | 
60 | 
61 | def zero_grad(variables):
62 |     for variable in variables:
63 |         variable.grad.zero_()
64 | 
65 | 
66 | def get_grads(variables):
67 |     return [var.grad.clone() for var in variables]
68 | 
69 | def check_forward(variables, data_format, with_cuda, verbose, sparse):
70 |     if with_cuda:
71 |         if not torch.cuda.is_available():
72 |             print('CUDA is not supported on this platform ... ')
73 |         elif torch.version.hip:
74 |             cuda_values = fpemu_hip.FPEmuOp.apply(variables.cuda(), data_format.upper())
75 |             print('Forward: HIP ... ', end='')
76 |             print_tensor(variables, cuda_values, verbose, sparse)
77 |         else:
78 |             cuda_values = fpemu_cuda.FPEmuOp.apply(variables.cuda(), data_format.upper())
79 |             print('Forward: CUDA ... ', end='')
80 |             print_tensor(variables, cuda_values, verbose, sparse)
81 |     else:
82 |         cpp_values = fpemu_cpp.FPEmuOp.apply(variables, data_format.upper())
83 |         print('Forward: C++ ... ', end='')
84 |         print_tensor(variables, cpp_values, verbose, sparse)
85 | 
86 |     if not verbose:
87 |         print('Done.')
88 |         print('Use the option --verbose to see results')
89 |     else:
90 |         print('Done.')
91 | 
92 | 
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument('-sp', '--sparse', action='store_true')
95 | parser.add_argument('-c', '--cuda', action='store_true')
96 | parser.add_argument('-v', '--verbose', action='store_true')
97 | parser.add_argument('-df', '--dformat', default='e4m3_rne')
98 | 
99 | options = parser.parse_args()
100 | if options.cuda:
101 |     device = torch.device("cuda")
102 | else:
103 |     device = torch.device("cpu")
104 | 
105 | kwargs = {'dtype': torch.float32,
106 |           'device': device,
107 |           'requires_grad': True}
108 | 
109 | '''
110 | #FP4 values
111 | input = torch.tensor(np.array([[ 0.0000000,
112 |                                  0.000244140625,
113 |                                  0.0009766,
114 |                                  0.0039063,
115 |                                  0.0156250,
116 |                                  0.0625000,
117 |                                  0.2500000,
118 |                                  1.0000000 ]]), dtype=torch.float32)
119 | '''
120 | input = torch.tensor(np.array([[ 57344.00 , -61440.0, 65504.0, -500.0,
121 |     448.0 , -480.0, 30.0, -31.0,
122 |     26.0 , 15.0, -7.6505613e-00, 5.9832452e-00,
123 |     1.5625032e-02, 3.1725775e-02, 4.3268750e-02, 6.2655000e-02,
124 |     1.9545313e-03, 3.9045625e-03, 5.8845638e-03, 7.8089750e-03,
125 |     -6.0151856e-02, 6.9784373e-03, -1.6634936e+03, -7.6505613e-04,
126 |     1.9545313e-03, 3.9045625e-03, 5.8845638e-03, 7.8089750e-03,
127 |     1.5629950e-02, 1.5225650e-02, -3.1256500e-02, 9.3857500e-02,
128 |     7.9284608e-01, 2.8815269e-01, -1.1787039e-01, 6.1035156e-05,
129 |     -1.1787039e-06, -1.0749481e-01, -1.3085605e+00, 6.5981364e-01,
130 |     -7.0325255e-02, 2.7448297e-01, 5.5694544e-01, -2.3220782e-01,
131 |     -5.9746221e-02, 15.23213444 , -0.00004323 , 1.9767435e-04,
132 |     -1.2203161e+00, 2.9099861e-01, -7.9642259e-02, 1.3200364e+00,
133 |     -1.5196867e+00, -1.2530587e+00, -2.0159689e-03, -1.9767643e+00,
134 |     6.0834163e-04, 7.8943473e-05, 7.8247029e-04, -6.4658634e-05,
135 |     -2.3020705e-06, -1.5630834e-05, -7.4762434e-07, 2.1336775e-06]]), dtype=torch.float32)
136 | #'''
137 | #input = torch.randn(4, 16)
138 | dformats=['e5m2_rne', 'e5m2_stochastic',
139 |           'e4m3_rne', 'e4m3_stochastic',
140 |           'e3m4_rne', 'e3m4_stochastic',
141 |          ]
142 | 
143 | if options.dformat not in dformats:
144 |     print("data format {} is not supported".format(options.dformat))
145 |     exit()
146 | 
147 | if options.sparse:
148 |     check_forward(input.to_sparse(), options.dformat, options.cuda, options.verbose, options.sparse)
149 | else:
150 |     check_forward(input, options.dformat, options.cuda, options.verbose, options.sparse)
151 | 
--------------------------------------------------------------------------------
/mpemu/scale_shift.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | 
12 | class ScaleShift(torch.nn.Module):
13 |     def __init__(self, num_features):
14 |         super(ScaleShift, self).__init__()
15 |         self.num_features = num_features
16 |         self.weight = torch.nn.Parameter(torch.Tensor(num_features))
17 |         self.bias = torch.nn.Parameter(torch.Tensor(num_features))
18 |         self.reset_parameters()
19 | 
20 |     def reset_parameters(self):
21 |         torch.nn.init.ones_(self.weight)
22 |         torch.nn.init.zeros_(self.bias)
23 | 
24 |     def forward(self, input):
25 |         # ASSUMPTION: input layout is (N, C, ...); move channels last so weight/bias broadcast per channel
26 |         input_t = input.transpose(1, -1)
27 |         input_t = input_t * self.weight + self.bias
28 |         output = input_t.transpose(1, -1)
29 | 
30 |         return output
31 | 
32 |     # For call to repr(). Prints object's information
33 |     def __repr__(self):
34 |         return 'ScaleShift({})'.format(self.num_features)
35 | 
36 |     @staticmethod
37 |     def generate_from_batchnorm(module: torch.nn.BatchNorm2d):
38 |         """
39 |         Helper function for folding a BatchNorm2d into an equivalent ScaleShift
40 |         """
41 |         # BatchNorm state_dict
42 |         bn_state_dict = module.state_dict()
43 | 
44 |         # BatchNorm parameters
45 |         num_features = module.num_features
46 |         eps = module.eps
47 |         rmean = bn_state_dict['running_mean']
48 |         rvar = bn_state_dict['running_var']
49 |         gamma = bn_state_dict['weight']
50 |         beta = bn_state_dict['bias']
51 | 
52 |         # Creating ScaleShift module: y = gamma*(x-mean)/sqrt(var+eps) + beta = scale*x + shift
53 |         ss_module = ScaleShift(num_features)
54 |         with torch.no_grad():
55 |             denom = torch.sqrt(rvar + eps)
56 |             scale = gamma.div(denom)
57 |             shift = beta - gamma.mul(rmean).div(denom)
58 | 
59 |             ss_module.weight.copy_(scale)
60 |             ss_module.bias.copy_(shift)
61 | 
62 |         return ss_module
63 | 
64 | def replace_batchnorms_with_scaleshifts(model):
65 |     """
66 |     Given a model, replace all BatchNorm2d layers with ScaleShift layers.
67 | """ 68 | if isinstance(model, torch.nn.BatchNorm2d): 69 | return ScaleShift.generate_from_batchnorm(model) 70 | for name, module in model.named_children(): 71 | module = replace_batchnorms_with_scaleshifts(module) 72 | setattr(model, name, module) 73 | return model 74 | 75 | 76 | if __name__ == '__main__': 77 | #### Testing ScaleShift layer #### 78 | num_features = 2 79 | ss = ScaleShift(num_features) 80 | ss.weight.data = ss.weight.data * torch.arange(1,num_features+1, dtype=ss.weight.dtype) 81 | ss.bias.data = torch.arange(1,num_features+1, dtype=ss.weight.dtype) 82 | 83 | input = torch.arange(num_features*2*2).reshape(1,num_features,2,2) 84 | print("Input") 85 | print(input) 86 | print("Scales") 87 | print(ss.weight.data) 88 | print("Shifts") 89 | print(ss.bias.data) 90 | 91 | output = ss(input) 92 | print("Output") 93 | print(output) 94 | 95 | #### Testing BN replacement with ScaleShift layer #### 96 | import torchvision 97 | 98 | model = torchvision.models.__dict__["resnet50"](pretrained=True) 99 | model.eval() 100 | 101 | # Original model 102 | input = torch.randn(1,3,224,224) 103 | output = model(input) 104 | 105 | # Replacing BNs 106 | model = replace_batchnorms_with_scaleshifts(model) 107 | output_s = model(input) 108 | 109 | print(torch.norm(output-output_s)) 110 | print(output.flatten()[:10]) 111 | print(output_s.flatten()[:10]) 112 | -------------------------------------------------------------------------------- /mpemu/sparse_utils.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------------------------------------ 2 | # Copyright (c) 2023, Intel Corporation - All rights reserved. 3 | # This file is part of FP8-Emulation-Toolkit 4 | # 5 | # SPDX-License-Identifier: BSD-3-Clause 6 | #------------------------------------------------------------------------------ 7 | # Naveen Mellempudi (Intel Corporation) 8 | #------------------------------------------------------------------------------ 9 | 10 | import torch 11 | import os 12 | import sys 13 | 14 | class SparseConfig(object): 15 | def __init__(self, weight=False, ingrad=False, outgrad=False, wtgrad=False): 16 | self.weight = weight 17 | self.ingrad = ingrad 18 | self.outgrad = outgrad 19 | self.wtgrad = wtgrad 20 | self.print_stats = False 21 | 22 | self.global_step = 0 23 | self.alpha_window = 50 24 | self.weight_alpha = 65504.0 25 | self.weight_factor = 0.0 26 | self.outgrad_alpha = 65504.0 27 | self.outgrad_factor = 0.0 28 | 29 | def __repr__(self): 30 | return "[weight: {}, outgrad: {}, alpha_window : {}, weight_sparsity: {}, outgrd_sparsity :{} ]".format( 31 | self.weight, self.outgrad, self.alpha_window, self.weight_factor, self.outgrad_factor) 32 | 33 | def sparsify_ingrad_tensor(self, in_grad): 34 | return in_grad 35 | 36 | def sparsify_outgrad_tensor(self, out_grad): 37 | if self.global_step != 0 and (self.global_step%self.alpha_window) == 0 : 38 | self.outgrad_alpha = Stochastic_Pruning_Threshold(out_grad, self.outgrad_factor) 39 | 40 | out_grad.data = Stochastic_Pruning(out_grad.data, self.outgrad_alpha) 41 | return out_grad 42 | 43 | def sparsify_weight_tensor(self, weight): 44 | sample_factor = 0.1 45 | if self.global_step != 0 and (self.global_step%self.alpha_window) == 0 : 46 | self.weight_alpha = Topk_Threshold_Sampled(weight.data, self.weight_factor, sample_factor) 47 | 48 | weight.data = Topk_Pruning(weight.data, self.weight_alpha) 49 | return weight 50 | 51 | def sparsify_wtgrad_tensor(self, wt_grad): 52 | return wt_grad 
53 | 
54 |     def optimizer_step(self, optimizer):
55 |         '''
56 |         This routine should handle any sparsity-related operations that need to be performed inside the optimizer
57 |         '''
58 |         return
59 | 
60 | from scipy.optimize import root_scalar
61 | 
62 | 
63 | def print_sparse_stats(module, grad_input, grad_output):
64 |     for in_grad, out_grad in zip(grad_input, grad_output):
65 |         wt_sp = 0.0
66 |         if type(module) in [torch.nn.Conv2d, torch.nn.Linear]:
67 |             wt_sp = 1 - (torch.count_nonzero(module.weight.data)/module.weight.data.numel()).item()
68 |         grad_sp = 1 - (torch.count_nonzero(out_grad.data)/out_grad.data.numel()).item()
69 |         print("{} , weight_sparsity : {}, alpha_weight : {}, grad_sparsity : {}".format(module.name, wt_sp, module.qconfig.weight_ALPHA, grad_sp))
70 | 
71 | def Stochastic_Pruning(X, ALPHA):
72 |     rand = ALPHA * torch.rand(X.shape, device=X.device, dtype=X.dtype)
73 |     X_abs = X.abs()
74 |     X = torch.where(X_abs < ALPHA, ALPHA * torch.sign(X), X)
75 |     X = torch.where(X_abs < rand, torch.tensor([0], device=X.device, dtype=X.dtype), X)
76 |     del X_abs
77 |     return X
78 | 
79 | def Stochastic_Pruning_Threshold(X, Sparsity):
80 |     X = X.abs()
81 |     X_sp = 1 - (torch.count_nonzero(X)/X.numel()).item()
82 |     if Sparsity <= X_sp:
83 |         # already at the target sparsity; return a zero threshold so nothing more is pruned
84 |         return torch.tensor([0.0], device=X.device, dtype=X.dtype)
85 |     Target_sparsity = Sparsity - X_sp
86 | 
87 |     Y = torch.log(X[X != 0.0])
88 |     mu = torch.mean(Y, dtype=torch.float)
89 |     sigma = torch.std(Y)
90 |     del Y
91 |     guess = torch.tensor([1], device=X.device, dtype=X.dtype)
92 |     bracket = [torch.exp(torch.tensor([-9], device=X.device, dtype=torch.float)),
93 |                torch.exp(torch.tensor([5], device=X.device, dtype=torch.float))]
94 |     sol = root_scalar(equationStochastic, x0=guess, bracket=bracket, args=(Sparsity, torch.tensor([0.0], device=X.device), sigma))
95 |     ALPHA = torch.tensor([sol.root], device=X.device, dtype=X.dtype)
96 |     return torch.exp(torch.log(ALPHA) + mu)
97 | 
98 | def Topk_Pruning(weight, alpha):
99 |     Weight_Mask = torch.where(torch.abs(weight) < alpha, torch.tensor([0], device=weight.device, dtype=torch.short),
100 |                               torch.tensor([1], device=weight.device, dtype=torch.short))
101 |     weight.mul_(Weight_Mask.to(device=weight.device))
102 |     del Weight_Mask
103 |     return weight
104 | 
105 | def Topk_Threshold_Sampled(weight, Sparsity_Weight, sample_factor):
106 |     Total_Weight_El = weight.numel()
107 |     sampled_El = int(Total_Weight_El*sample_factor)
108 |     sampled_id = torch.randint(0, Total_Weight_El, (1, sampled_El))
109 |     X_sampled = torch.abs(weight.view(-1))[sampled_id]
110 |     topk_count_sample = int(sampled_El*(1-Sparsity_Weight))
111 |     __, topk_w_id_sampled = torch.topk(X_sampled, topk_count_sample)
112 |     alpha = X_sampled[0, topk_w_id_sampled[0, topk_count_sample-1]]
113 |     return alpha
114 | 
115 | def equationStochastic(alpha, sparsity, mu, sigma):
116 |     sqrt2 = torch.sqrt(torch.tensor([2], device=mu.device, dtype=torch.float))
117 |     alpha = torch.tensor([alpha], device=mu.device, dtype=torch.float)
118 |     pt1 = torch.exp((sigma**2)/2) * torch.erf(sigma/sqrt2 - torch.log(alpha/torch.exp(mu))/(sqrt2 * sigma))
119 |     pt2 = alpha/torch.exp(mu) * torch.erf(torch.log(alpha/torch.exp(mu))/(sqrt2 * sigma))
120 |     pt3 = torch.exp((sigma**2)/2)
121 |     return 0.5 - sparsity + torch.exp(mu)/(2*alpha) * (pt1 + pt2 - pt3)
122 | 
--------------------------------------------------------------------------------
/mpemu/stats_collector.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 | 
10 | import torch
11 | 
12 | class TensorFullIntQuantParams(object):
13 |     """
14 |     min_val : float
15 |     max_val : float
16 |     qconfig : TensorQuantConfig
17 |     """
18 |     def __init__(self, min_val, max_val, qconfig):
19 |         super(TensorFullIntQuantParams, self).__init__()
20 |         self.qconfig = qconfig
21 |         self.min_val, self.max_val, self.scale, self.zero_point = self._calculate_int8_qparams_base(qconfig.dtype,
22 |             qconfig.scheme, min_val, max_val)
23 | 
24 | 
25 |     def quantize(self, tensor_f):
26 |         # Clamping the values
27 |         #tensor_f = torch.clamp(tensor_f, self.min_val, self.max_val)
28 |         # Quantizing the tensor
29 |         tensor_int = torch.round(tensor_f/self.scale + self.zero_point)
30 | 
31 |         min_int = -128
32 |         max_int = 127
33 |         if self.qconfig.dtype == "uint8":
34 |             min_int = 0
35 |             max_int = 255
36 | 
37 |         # Clamping the values in integer domain
38 |         tensor_int = torch.clamp(tensor_int, min_int, max_int)
39 | 
40 |         return tensor_int
41 | 
42 |     def dequantize(self, tensor_int):
43 |         tensor_f = (tensor_int - self.zero_point)*self.scale
44 |         return tensor_f
45 | 
46 |     def quant_dequant(self, tensor_f):
47 |         return self.dequantize(self.quantize(tensor_f))
48 | 
49 |     def __repr__(self):
50 |         return "{} Quantization range [{:.4f},{:.4f}] ".format(self.qconfig, self.min_val, self.max_val)
51 | 
52 |     # Calculating quantization parameters for INT8/UINT8
53 |     @staticmethod
54 |     def _calculate_int8_qparams_base(dtype, scheme, min_val, max_val):
55 |         """
56 |         Adapted from https://github.com/pytorch/pytorch/blob/8074779328fa471f484fb74cc6c50d95392fe2c2/torch/quantization/observer.py#L193
57 |         """
58 |         assert min_val <= max_val, "Minimum value {} has to be less than or equal to Maximum value {}".format(min_val, max_val)
59 |         eps = torch.finfo(torch.float32).eps
60 | 
61 |         if dtype == "uint8":
62 |             qmin = 0
63 |             qmax = 255
64 |         elif dtype == "int8":
65 |             qmin = -128
66 |             qmax = 127
67 | 
68 |         # Including zero in the range
69 |         min_val, max_val = float(min_val), float(max_val)
70 |         min_val = min(0.0, min_val)
71 |         max_val = max(0.0, max_val)
72 | 
73 |         if min_val == max_val:
74 |             scale = 1.0
75 |             zero_point = 0
76 |         else:
77 |             if scheme == "sym_full" or scheme == "sym_channel":
78 |                 max_val = max(-min_val, max_val)
79 |                 scale = max_val / ((qmax - qmin) / 2)
80 |                 scale = max(scale, eps)
81 |                 zero_point = 0 if dtype == "int8" else 128
82 | 
83 |                 min_val = -1*max_val
84 |             elif scheme == "asym_full" or scheme == "asym_channel":
85 |                 scale = (max_val - min_val) / float(qmax - qmin)
86 |                 scale = max(scale, eps)
87 |                 zero_point = qmin - round(min_val / scale)
88 |                 zero_point = max(qmin, zero_point)
89 |                 zero_point = min(qmax, zero_point)
90 |                 zero_point = int(zero_point)
91 | 
92 | 
93 |         return min_val, max_val, scale, zero_point
94 | 
95 | 
96 | class TensorChannelIntQuantParams(object):
97 |     def __init__(self, min_vals, max_vals, qconfig):
98 |         super(TensorChannelIntQuantParams, self).__init__()
99 |         assert(len(min_vals) == len(max_vals))
100 |         self.num_channels = len(min_vals)
101 |         self.channel_qparams = []
102 | 
103 |         for min_val, max_val in zip(min_vals, max_vals):
104 |             self.channel_qparams.append(TensorFullIntQuantParams(min_val, max_val, qconfig))
105 | 
106 |     def quant_dequant(self, tensor_f):
107 |         tensor_q = torch.zeros_like(tensor_f)
108 |         for chan_id in range(self.num_channels):
109 |             tensor_q[chan_id] = self.channel_qparams[chan_id].quant_dequant(tensor_f[chan_id])
110 |         return tensor_q
111 | 
112 | 
113 | class TensorDump(torch.nn.Module):
114 |     def __init__(self, name=""):
115 |         super(TensorDump, self).__init__()
116 |         self.name = name
117 |         self.tensors = []
118 | 
119 |     def forward(self, tensor):
120 |         self.tensors.append(tensor)
121 | 
122 |     def dump(self):
123 |         import pickle
124 |         import numpy as np
125 |         pickle.dump(np.array([tensor.detach().numpy() for tensor in self.tensors]), open(self.name+".pickle", "wb"))
126 | 
127 | class TensorDumpListWrapper(torch.nn.Module):
128 |     """Wraps one TensorDump per element of a (possibly tupled) input"""
129 |     def __init__(self, name=""):
130 |         super(TensorDumpListWrapper, self).__init__()
131 |         self.initiated = False
132 |         self.tupled_input = False
133 |         self.tensor_dump_list = []
134 | 
135 |         self.name = name
136 | 
137 |     def forward(self, input):
138 |         if not self.initiated:
139 |             if type(input) == tuple:
140 |                 self.tupled_input = True
141 |                 for i in range(len(input)):
142 |                     self.tensor_dump_list.append(TensorDump(name=self.name+"_{}".format(i)))
143 |             else:
144 |                 self.tensor_dump_list = [TensorDump(name=self.name+"_0")]
145 |             self.initiated = True
146 | 
147 |         if self.tupled_input:
148 |             for i in range(len(input)):
149 |                 self.tensor_dump_list[i].forward(input[i])
150 |         else:
151 |             self.tensor_dump_list[0].forward(input)
152 | 
153 | 
154 |     def dump(self):
155 |         for tensor_dump in self.tensor_dump_list:
156 |             tensor_dump.dump()
157 | 
158 | 
159 | class ArchiveStats(torch.nn.Module):
160 |     def __init__(self):
161 |         super(ArchiveStats, self).__init__()
162 |         self.tensors = []
163 | 
164 |     def forward(self, tensor):
165 |         self.tensors.append(tensor)
166 | 
167 | class MinMaxStats(torch.nn.Module):
168 |     """Tracks global min/max statistics of a tensor"""
169 |     def __init__(self, archive_tensors=False):
170 |         super(MinMaxStats, self).__init__()
171 |         self.min_val = None
172 |         self.max_val = None
173 | 
174 |         self.archive_tensors = archive_tensors
175 |         self.tensors = None
176 | 
177 |     def forward(self, tensor):
178 |         min_val = torch.min(tensor).item()
179 |         max_val = torch.max(tensor).item()
180 | 
181 |         if self.min_val is None:
182 |             self.min_val = min_val
183 |         else:
184 |             if min_val < self.min_val:
185 |                 self.min_val = min_val
186 | 
187 |         if self.max_val is None:
188 |             self.max_val = max_val
189 |         else:
190 |             if max_val > self.max_val:
191 |                 self.max_val = max_val
192 | 
193 |         if self.archive_tensors:
194 |             if self.tensors is None:
195 |                 self.tensors = [tensor]
196 |             else:
197 |                 self.tensors.append(tensor)
198 | 
199 |     def get_tensor_quant_params(self, ten_qconfig):
200 |         assert(ten_qconfig.dtype in ["uint8", "int8"])
201 |         ten_qparams = TensorFullIntQuantParams(self.min_val, self.max_val, ten_qconfig)
202 |         return ten_qparams
203 | 
204 |     def print(self, name=""):
205 |         print("{:7.3f} {:7.3f} {}".format(self.min_val, self.max_val, name))
206 | 
207 | 
208 | class RunningMinMaxStats(torch.nn.Module):
209 |     """Tracks both global and running-average min/max statistics of a tensor"""
210 |     def __init__(self, archive_tensors=False):
211 |         super(RunningMinMaxStats, self).__init__()
212 | 
213 |         self.min_val = None
214 |         self.max_val = None
215 | 
216 |         self.running_min_val = None
217 |         self.running_max_val = None
218 | 
219 |         self.running_steps = 0
220 | 
221 |         self.archive_tensors = archive_tensors
222 |         self.tensors = None
223 | 
224 |     def forward(self, tensor):
225 |         min_val = torch.min(tensor).item()
226 |         max_val = torch.max(tensor).item()
227 | 
228 |         if self.min_val is None:
229 |             self.min_val = min_val
230 |         else:
231 |             if min_val < self.min_val:
232 |                 self.min_val = min_val
233 | 
234 |         if self.max_val is None:
235 |             self.max_val = max_val
236 |         else:
237 |             if max_val > self.max_val:
238 |                 self.max_val = max_val
239 | 
240 |         ## Running averages of the per-step min and max
241 |         if self.running_min_val is None:
242 |             self.running_min_val = min_val
243 |         else:
244 |             self.running_min_val = (self.running_min_val*self.running_steps + min_val)/(self.running_steps+1)
245 | 
246 |         if self.running_max_val is None:
247 |             self.running_max_val = max_val
248 |         else:
249 |             self.running_max_val = (self.running_max_val*self.running_steps + max_val)/(self.running_steps+1)
250 | 
251 |         self.running_steps += 1
252 | 
253 |         if self.archive_tensors:
254 |             if self.tensors is None:
255 |                 self.tensors = [tensor]
256 |             else:
257 |                 self.tensors.append(tensor)
258 | 
259 |     def get_tensor_quant_params(self, ten_qconfig):
260 |         assert(ten_qconfig.dtype in ["uint8", "int8"])
261 |         ten_qparams = TensorFullIntQuantParams(self.running_min_val, self.running_max_val, ten_qconfig)
262 |         return ten_qparams
263 | 
264 |     def print(self, name=""):
265 |         print("{:7.3f} {:7.3f} {:7.3f} {:7.3f} {}".format(self.min_val, self.max_val,
266 |             self.running_min_val, self.running_max_val, name))
267 | 
268 | 
269 | class StatsListWrapper(torch.nn.Module):
270 |     """Applies a stats collector to each element of a (possibly tupled) input"""
271 |     def __init__(self, stats_class, archive_tensors):
272 |         super(StatsListWrapper, self).__init__()
273 |         self.initiated = False
274 |         self.tupled_input = False
275 |         self.stats_list = []
276 | 
277 |         self.stats_class = stats_class
278 |         self.archive_tensors = archive_tensors
279 | 
280 |     def forward(self, input):
281 |         if not self.initiated:
282 |             if type(input) == tuple:
283 |                 self.tupled_input = True
284 |                 for i in range(len(input)):
285 |                     self.stats_list.append(self.stats_class(archive_tensors=self.archive_tensors))
286 |             else:
287 |                 self.stats_list = [self.stats_class(archive_tensors=self.archive_tensors)]
288 |             self.initiated = True
289 | 
290 |         if self.tupled_input:
291 |             for i in range(len(input)):
292 |                 self.stats_list[i].forward(input[i])
293 |         else:
294 |             self.stats_list[0].forward(input)
295 | 
296 |     def get_tensor_quant_params(self, ten_qconfig):
297 |         ret_list = [stats_obj.get_tensor_quant_params(ten_qconfig) for stats_obj in self.stats_list]
298 | 
299 |         if self.tupled_input:
300 |             return ret_list
301 |         else:
302 |             return ret_list[0]
303 | 
304 |     def print(self, name=""):
305 |         for i in range(len(self.stats_list)):
306 |             self.stats_list[i].print(name+"[{}]".format(i))
307 | 
308 | 
309 | class ChannleWiseMinMaxStats(torch.nn.Module):
310 | 
311 |     def __init__(self):
312 |         super(ChannleWiseMinMaxStats, self).__init__()
313 |         self.min_vals = None
314 |         self.max_vals = None
315 | 
316 |     def forward(self, tensor):
317 |         num_chans = tensor.shape[0]
318 |         if self.min_vals is None or self.max_vals is None:
319 |             self.min_vals = [None]*num_chans
320 |             self.max_vals = [None]*num_chans
321 | 
322 |         for chan_id in range(num_chans):
323 |             min_val = torch.min(tensor[chan_id]).item()
324 |             max_val = torch.max(tensor[chan_id]).item()
325 | 
326 |             if self.min_vals[chan_id] is None:
327 |                 self.min_vals[chan_id] = min_val
328 |             else:
329 |                 # track the smallest value seen per channel
330 |                 if min_val < self.min_vals[chan_id]:
331 |                     self.min_vals[chan_id] = min_val
332 | 
333 |             if self.max_vals[chan_id] is None:
334 |                 self.max_vals[chan_id] = max_val
335 |             else:
336 |                 # track the largest value seen per channel
337 |                 if max_val > self.max_vals[chan_id]:
338 |                     self.max_vals[chan_id] = max_val
339 | 
340 |     def get_tensor_quant_params(self, ten_qconfig):
341 |         ten_qparams = TensorChannelIntQuantParams(self.min_vals, self.max_vals, ten_qconfig)
342 |         return ten_qparams
343 | 
344 |     def print(self):
345 |         print(self.min_vals, self.max_vals, end=' ')
346 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | scipy
3 | setuptools
4 | torch
5 | torchaudio
6 | torchvision
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from setuptools import setup, find_packages
4 | from torch.utils.cpp_extension import BuildExtension, CppExtension
5 | 
6 | cmdclass = {}
7 | ext_modules = []
8 | 
9 | with open("README.md", "r") as fh:
10 |     long_description = fh.read()
11 | 
12 | ext_modules.append(
13 |     CppExtension('fpemu_cpp',
14 |         ['mpemu/pytquant/cpp/fpemu_impl.cpp'],
15 |         extra_compile_args = ["-mf16c", "-march=native", "-mlzcnt", "-fopenmp", "-Wdeprecated-declarations"]
16 |     ),)
17 | 
18 | if torch.cuda.is_available():
19 |     from torch.utils.cpp_extension import BuildExtension, CUDAExtension
20 |     if torch.version.hip:
21 |         ext_modules.append(
22 |             CUDAExtension('fpemu_hip', [
23 |                 'mpemu/pytquant/hip/fpemu_impl.cpp',
24 |                 'mpemu/pytquant/hip/fpemu_kernels.hip'],
25 |             ),)
26 |     elif torch.version.cuda:
27 |         ext_modules.append(
28 |             CUDAExtension('fpemu_cuda', [
29 |                 'mpemu/pytquant/cuda/fpemu_impl.cpp',
30 |                 'mpemu/pytquant/cuda/fpemu_kernels.cu'],
31 |             ),)
32 | 
33 | ext_modules.append(
34 |     CppExtension('simple_gemm_dev',
35 |         ['mpemu/cmodel/simple/simple_gemm.cpp', 'mpemu/cmodel/simple/simple_gemm_impl.cpp', 'mpemu/cmodel/simple/simple_mm_engine.cpp'],
36 |         extra_compile_args=['-march=native', '-fopenmp', '-Wunused-but-set-variable', '-Wunused-variable'],
37 |         include_dirs=[os.path.join(os.getcwd(), 'mpemu/cmodel/simple')],
38 |     ),)
39 | 
40 | ext_modules.append(
41 |     CppExtension('simple_conv2d_dev',
42 |         ['mpemu/cmodel/simple/simple_conv2d.cpp', 'mpemu/cmodel/simple/simple_conv2d_impl.cpp', 'mpemu/cmodel/simple/simple_mm_engine.cpp'],
43 |         extra_compile_args=['-march=native', '-fopenmp', '-Wunused-but-set-variable', '-Wunused-variable'],
44 |         include_dirs=[os.path.join(os.getcwd(), 'mpemu/cmodel/simple')],
45 |     ),)
46 | 
47 | cmdclass['build_ext'] = BuildExtension
48 | 
49 | setup(
50 |     name="mpemu",
51 |     version="1.0",
52 |     ext_modules=ext_modules,
53 |     cmdclass=cmdclass,
54 |     author="Naveen Mellempudi",
55 |     description="FP8 Emulation Toolkit",
56 |     long_description=long_description,
57 |     long_description_content_type="text/markdown",
58 |     url="",
59 |     packages=find_packages(),
60 |     classifiers=[
61 |         "Programming Language :: Python :: 3",
62 |         "Operating System :: OS Independent",
63 |     ],
64 |     python_requires='>=3.6',
65 | )
--------------------------------------------------------------------------------
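A minimal usage sketch of the emulator op (this is not a file from the repository; the module path and the `E4M3_RNE` mode string follow `mpemu/pytquant/test.py` and the mode list in `fpemu.py`, and it assumes the extensions were built via `setup.py`):

```python
# Quantize a float32 tensor to emulated E4M3 with round-to-nearest-even.
# FPEmuOp is an autograd Function: forward quantizes the values, backward is a
# straight-through estimator, so it can wrap tensors inside a training graph.
import torch
from mpemu.pytquant.cpp import fpemu as fpemu_cpp  # CPU path; cuda/hip variants exist

x = torch.randn(4, 16, dtype=torch.float32)
x_fp8 = fpemu_cpp.FPEmuOp.apply(x, "E4M3_RNE")
print(x_fp8)
```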