├── LICENSE
├── README.md
├── docs
│   └── formats.png
├── examples
│   ├── inference
│   │   ├── README.md
│   │   ├── bert
│   │   │   ├── README.txt
│   │   │   ├── cmd_inference.sh
│   │   │   ├── download_squad_dataset.sh
│   │   │   ├── download_squad_finetuned_model.sh
│   │   │   ├── modeling_bert.py
│   │   │   ├── requirements.txt
│   │   │   └── run_squad.py
│   │   └── classifier
│   │       ├── README.md
│   │       ├── imagenet_qat.py
│   │       ├── imagenet_test.py
│   │       ├── launch.py
│   │       ├── test_scaleshift.py
│   │       ├── train.py
│   │       └── utils.py
│   └── training
│       ├── README.md
│       ├── bert
│       │   ├── launch_fp8_training.sh
│       │   ├── modeling_bert.py
│       │   ├── requirements.txt
│       │   ├── run_qa_beam_search_no_trainer.py
│       │   ├── run_qa_no_trainer.py
│       │   └── utils_qa.py
│       ├── nvidia_apex_cpu_patch_05091d4.patch
│       └── resnet
│           ├── README.md
│           ├── lr_scheduler.py
│           ├── main_amp.py
│           ├── main_amp_cpu.py
│           ├── train_cpu.sh
│           └── train_gpu.sh
├── mpemu
│   ├── __init__.py
│   ├── bfloat16_emu.py
│   ├── cmodel
│   │   ├── __init__.py
│   │   ├── simple.py
│   │   ├── simple
│   │   │   ├── simple_conv2d.cpp
│   │   │   ├── simple_conv2d_impl.cpp
│   │   │   ├── simple_gemm.cpp
│   │   │   ├── simple_gemm_impl.cpp
│   │   │   ├── simple_mm_engine.cpp
│   │   │   └── vla.h
│   │   └── tests
│   │       ├── conv_grad_test.py
│   │       ├── conv_test.py
│   │       ├── gemm_grad_test.py
│   │       ├── gemm_irregular_test.py
│   │       ├── gemm_test.py
│   │       ├── linear_test.py
│   │       └── net.py
│   ├── e3m4_emu.py
│   ├── e4m3_emu.py
│   ├── e5m2_emu.py
│   ├── hybrid_emu.py
│   ├── module_wrappers
│   │   ├── __init__.py
│   │   ├── adasparse.py
│   │   ├── aggregate.py
│   │   ├── eltwise.py
│   │   └── matmul.py
│   ├── mpt_emu.py
│   ├── pytquant
│   │   ├── __init__.py
│   │   ├── cpp
│   │   │   ├── __init__.py
│   │   │   ├── fpemu.py
│   │   │   └── fpemu_impl.cpp
│   │   ├── cuda
│   │   │   ├── __init__.py
│   │   │   ├── fpemu.py
│   │   │   ├── fpemu_impl.cpp
│   │   │   └── fpemu_kernels.cu
│   │   ├── hip
│   │   │   ├── __init__.py
│   │   │   ├── fpemu.py
│   │   │   ├── fpemu_impl.cpp
│   │   │   └── fpemu_kernels.hip
│   │   └── test.py
│   ├── qutils.py
│   ├── scale_shift.py
│   ├── sparse_utils.py
│   └── stats_collector.py
├── requirements.txt
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Intel Labs
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FP8 Emulation Toolkit
2 |
3 | > [!CAUTION]
4 | > **PROJECT NOT UNDER ACTIVE MANAGEMENT**
5 | > * This project will no longer be maintained by Intel.
6 | > * Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.
7 | > * Intel no longer accepts patches to this project.
8 | > * If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.
9 |
10 | ## Introduction
11 | This repository provides PyTorch tools to emulate the new `FP8` formats on top of existing floating point hardware from Intel, AMD and NVIDIA. In addition to the two formats `E5M2` and `E4M3` defined in the joint ARM-Intel-NVIDIA specification, the toolkit also supports a third variant named `E3M4`, which follows the guidelines established for the `E4M3` format.
12 |
13 | The following table shows the binary formats and their numeric ranges:
14 |
15 | |                | E5M2                                | E4M3                                | E3M4                                 |
16 | | -------------- | ----------------------------------- | ----------------------------------- | ------------------------------------ |
17 | | Exponent Bias  | 15                                  | 7                                   | 3                                    |
18 | | Infinities     | S.11111.00₂                         | N/A                                 | N/A                                  |
19 | | NaNs           | S.11111.{01, 10, 11}₂               | S.1111.111₂                         | S.111.1111₂                          |
20 | | Zeros          | S.00000.00₂                         | S.0000.000₂                         | S.000.0000₂                          |
21 | | Max normal     | S.11110.11₂ = 1.75 × 2¹⁵ = 57344.0  | S.1111.110₂ = 1.75 × 2⁸ = 448.0     | S.111.1110₂ = 1.875 × 2⁴ = 30.0      |
22 | | Min normal     | S.00001.00₂ = 2⁻¹⁴ = 6.1e-05        | S.0001.000₂ = 2⁻⁶ = 1.5e-02         | S.001.0000₂ = 2⁻² = 2.5e-01          |
23 | | Max subnormal  | S.00000.11₂ = 0.75 × 2⁻¹⁴ = 4.5e-05 | S.0000.111₂ = 0.875 × 2⁻⁶ = 1.3e-02 | S.000.1111₂ = 0.9375 × 2⁻² = 2.3e-01 |
24 | | Min subnormal  | S.00000.01₂ = 2⁻¹⁶ = 1.5e-05        | S.0000.001₂ = 2⁻⁹ = 1.9e-03         | S.000.0001₂ = 2⁻⁶ = 1.5e-02          |
25 |
26 | 
27 |
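The max/min values in the table follow directly from the bit layout: multiply the largest (or smallest) representable mantissa by two raised to the corresponding exponent. The short standalone Python sketch below (illustrative only, independent of the toolkit) reproduces the range rows of the table:

```
# Derive the range rows above from (exponent bits, mantissa bits, bias).
# E5M2 is IEEE-like: the all-ones exponent is reserved for Inf/NaN.
# E4M3 and E3M4 reserve only the all-ones mantissa at max exponent for NaN.
formats = {
    "E5M2": (5, 2, 15, True),
    "E4M3": (4, 3, 7, False),
    "E3M4": (3, 4, 3, False),
}

for name, (e, m, bias, ieee_like) in formats.items():
    max_exp = (2**e - 2 if ieee_like else 2**e - 1) - bias
    max_mant = (2 - 2.0**-m) if ieee_like else (2 - 2.0**(1 - m))
    print(f"{name}: max normal = {max_mant * 2.0**max_exp}, "
          f"min normal = {2.0**(1 - bias):.1e}, "
          f"max subnormal = {(1 - 2.0**-m) * 2.0**(1 - bias):.1e}, "
          f"min subnormal = {2.0**(1 - bias - m):.1e}")
```
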
28 | ## Installation
29 |
30 | Follow the instructions below to install the FP8 Emulation Toolkit in a Python virtual environment.
31 | Alternatively, the installation can also be performed in a Docker environment.
32 |
33 | ### Requirements
34 | This package can be installed on the following hardware.
35 |
36 | * x86 CPUs from AMD and Intel
37 | * GPU devices from NVIDIA (CUDA) and AMD (HIP)
38 |
39 | Install or upgrade the following packages on your Linux machine.
40 |
41 | * Python >= 3.8.5
42 | * NVIDIA CUDA >= 11.1 or AMD ROCm >= 5.6
43 | * gcc >= 8.4.0
44 |
45 | Make sure these versions are reflected in your `$PATH`.
46 |
47 | #### Target Hardware
48 | * CPU: any x86 CPU
49 | * GPU: NVIDIA V100 or newer, AMD MI2xx series
50 |
51 | ### Create a Python virtual environment
52 | ```
53 | $ python3 -m venv ~/py-venv
54 | $ cd ~/py-venv
55 | $ source bin/activate
56 | $ pip3 install --upgrade pip
57 | ```
58 | ### Clone and install FP8 Emulation Toolkit
59 | ```
60 | $ git clone https://github.com/IntelLabs/FP8-Emulation-Toolkit.git
61 | $ cd FP8-Emulation-Toolkit
62 | $ pip3 install -r requirements.txt
63 | $ python setup.py install
64 | ```
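A quick import check (a smoke test only; assumes the build completed without errors) confirms the package is visible to Python:
```
$ python -c "from mpemu import mpt_emu; print('mpemu OK')"
```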
65 |
66 | ## Usage Examples
67 | The emulated FP8 formats can be experimented with by integrating them into standard deep learning flows. Follow the links below for detailed instructions and code samples for exploring training and inference flows using FP8 data formats.
68 |
69 | * [Post-training quantization](./examples/inference)
70 | * [Mixed precision training](./examples/training)
71 |
72 |
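Both flows follow the same basic pattern: wrap the model with `mpt_emu.quantize_model`, run calibration or fine-tuning, then fuse and quantize for evaluation. A minimal inference-flow sketch (the model and exempt-layer names are illustrative; see the linked examples for complete scripts):

```
import torchvision
from mpemu import mpt_emu

# Load a pre-trained FP32 model.
model = torchvision.models.resnet50(pretrained=True).eval()

# Wrap the model for FP8 emulation; the first and last layers are
# typically exempted from FP8 conversion.
model, emulator = mpt_emu.quantize_model(model, dtype="e4m3",
                                         list_exempt_layers=["conv1", "fc"])

# ... run a few batches of calibration data through the model ...

# Fuse BatchNorm layers and quantize the model for evaluation.
model = emulator.fuse_bnlayers_and_quantize_model(model)
```
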
73 | ## Related Work
74 | This implementation is based on the following research. Check out the source material for more details on the training and inference methods.
75 |
76 | ```
77 | @article{shen2023efficient,
78 | title={Efficient Post-training Quantization with FP8 Formats},
79 | author={Shen, Haihao and Mellempudi, Naveen and He, Xin and Gao, Qun and Wang, Chang and Wang, Mengni},
80 | journal={arXiv preprint arXiv:2309.14592},
81 | year={2023}
82 | }
83 | ```
84 | ```
85 | @misc{mellempudi2019mixed,
86 | title={Mixed Precision Training With 8-bit Floating Point},
87 | author={Naveen Mellempudi and Sudarshan Srinivasan and Dipankar Das and Bharat Kaul},
88 | year={2019},
89 | eprint={1905.12334},
90 | archivePrefix={arXiv},
91 | primaryClass={cs.LG}
92 | }
93 | ```
94 | ```
95 | @misc{micikevicius2022fp8,
96 | title={FP8 Formats for Deep Learning},
97 | author={Paulius Micikevicius and Dusan Stosic and Neil Burgess and Marius Cornea and Pradeep Dubey and Richard Grisenthwaite and Sangwon Ha and Alexander Heinecke and Patrick Judd and John Kamalu and Naveen Mellempudi and Stuart Oberman and Mohammad Shoeybi and Michael Siu and Hao Wu},
98 | year={2022},
99 | eprint={2209.05433},
100 | archivePrefix={arXiv},
101 | primaryClass={cs.LG}
102 | }
103 | ```
104 |
--------------------------------------------------------------------------------
/docs/formats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/FP8-Emulation-Toolkit/fcd26e9a61e0e2aa4c559e38c15b29955c71149e/docs/formats.png
--------------------------------------------------------------------------------
/examples/inference/README.md:
--------------------------------------------------------------------------------
1 | # FP8 Post-training Quantization
2 |
3 | The following example demonstrates the post-training quantization flow for converting pre-trained models to use FP8 for inference.
4 |
5 | ```
6 | # import the emulator
7 | from mpemu import mpt_emu
8 |
9 | ...
10 |
11 | # layers exempt from e4m3 conversion
12 | list_exempt_layers = ["conv1","fc"]
13 |
14 | model, emulator = mpt_emu.quantize_model(model, dtype="e4m3_rne", hw_patch="None",
15 |                        list_exempt_layers=list_exempt_layers)
16 |
17 | # calibrate the model for a few batches of training data
18 | evaluate(model, criterion, train_loader, device,
19 |          num_batches=100, train=True)
20 |
21 | # Fuse BatchNorm layers and quantize the model
22 | model = emulator.fuse_bnlayers_and_quantize_model(model)
23 |
24 | # Evaluate the quantized model
25 | evaluate(model, criterion, test_loader, device)
26 |
27 | ```
28 | An example demonstrating post-training quantization can be found [here](./classifier/imagenet_test.py).
29 |
30 |
--------------------------------------------------------------------------------
/examples/inference/bert/README.txt:
--------------------------------------------------------------------------------
1 |
2 | Prerequisite:
3 | -------------
4 |
5 | Install SQuAD task-specific requirements (one time):
6 | $pip install -r requirements.txt
7 |
8 | Download the SQuAD dataset:
9 | ---------------------------
10 | $bash download_squad_dataset.sh
11 |
12 | Download the SQuAD fine-tuned model for inference:
13 | --------------------------------------------------
14 | $bash download_squad_finetuned_model.sh
15 |
16 | To run the SQuAD baseline inference task (default quantization type: hybrid):
17 | $bash cmd_inference.sh
18 | 
19 | To run SQuAD inference with a specific FP8 quantization type (e.g. e4m3):
20 | $bash cmd_inference.sh e4m3
21 |
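The inference script looks for the dataset in ./SQUAD1 and the checkpoint in
./squad_finetuned_checkpoint by default; the dataset and model_path environment
variables override these locations, e.g. (illustrative path):
$dataset=/data/SQUAD1 bash cmd_inference.sh e4m3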
22 |
--------------------------------------------------------------------------------
/examples/inference/bert/cmd_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | QUANT_TYPE=${1:-'hybrid'}
3 |
4 | # set dataset and model_path
5 | if test -z "$dataset" || ! test -d "$dataset" ; then
6 | if test -d ./SQUAD1 ; then
7 | dataset=./SQUAD1
8 | else
9 | echo "Unable to find dataset path!!"
10 | echo "Download SQuAD dataset using the command ./download_squad_dataset.sh"
11 | exit 1
12 | fi
13 | fi
14 |
15 | if test -z "$model_path" || ! test -d "$model_path" ; then
16 | if test -d ./squad_finetuned_checkpoint ; then
17 | model_path=./squad_finetuned_checkpoint
18 | else
19 | echo "Unable to find pre-trained model path!!"
20 | echo "Download the pre-trained SQuAD model using the command ./download_squad_finetuned_model.sh"
21 | exit 1
22 | fi
23 | fi
24 |
25 | $NUMA_ARGS $GDB_ARGS python -u run_squad.py \
26 | --model_type bert \
27 | --model_name_or_path $model_path \
28 | --do_eval \
29 | --do_lower_case \
30 | --predict_file $dataset/dev-v1.1.json \
31 | --per_gpu_eval_batch_size 24 \
32 | --max_seq_length 384 \
33 | --doc_stride 128 \
34 | --quant_data_type=$QUANT_TYPE \
35 | --output_dir /tmp/debug_squad/ \
36 | #--no_cuda
37 |
--------------------------------------------------------------------------------
/examples/inference/bert/download_squad_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p SQUAD1 && cd SQUAD1 && wget --no-check-certificate https://data.deepai.org/squad1.1.zip && unzip squad1.1.zip && cd ..
3 |
--------------------------------------------------------------------------------
/examples/inference/bert/download_squad_finetuned_model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | mkdir -p squad_finetuned_checkpoint && cd squad_finetuned_checkpoint
4 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json
5 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/pytorch_model.bin
6 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/tokenizer.json
7 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/tokenizer_config.json
8 | wget -c https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt
9 |
--------------------------------------------------------------------------------
/examples/inference/bert/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers>=4.11.0
2 |
--------------------------------------------------------------------------------
/examples/inference/classifier/README.md:
--------------------------------------------------------------------------------
1 | #### Post-Training Quantization
2 | ```
3 | python launch.py
4 | ```
5 | Refer to launch.py for more details.
6 |
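launch.py sweeps a list of models and data types by invoking imagenet_test.py; a single model can also be quantized directly (replace the dataset path placeholder):
```
python imagenet_test.py \
    --model=resnet50 \
    --data-path=<path-to-imagenet> \
    --data-type=e4m3 \
    --recalibrate-bn --num-calibration-batches=100
```
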
7 | #### QAT (Quantization-Aware Training)
8 | ```
9 | python imagenet_qat.py \
10 | --data-path=<path-to-imagenet> \
11 | --arch=mobilenet_v2 \
12 | --lr=0.001 \
13 | --epochs=15 \
14 | --lr-step-size=5 \
15 | --lr-gamma=0.1 \
16 | --data-type=bf8
17 | ```
18 |
--------------------------------------------------------------------------------
/examples/inference/classifier/imagenet_qat.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import datetime
11 | import os
12 | import time
13 | import sys
14 | import copy
15 | import collections
16 |
17 | import torch
18 | import torch.utils.data
19 | from torch import nn
20 | import torchvision
21 | from torchvision import datasets, transforms
22 | import torch.quantization
23 | import utils
24 | try:
25 | from apex import amp
26 | except ImportError:
27 | amp = None
28 |
29 | #from train import train_one_epoch, evaluate, load_data
30 | from mpemu import mpt_emu
31 |
32 | def train_one_epoch(args, model, criterion, optimizer, emulator, data_loader,
33 | device, epoch, print_freq, num_batches=None):
34 | model.train()
35 | metric_logger = utils.MetricLogger(delimiter=" ")
36 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
37 | metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
38 |
39 | header = 'Epoch: [{}]'.format(epoch)
40 | i = 0
41 | for image, target in metric_logger.log_every(data_loader, print_freq, header):
42 | i += 1
43 | start_time = time.time()
44 | image, target = image.to(device), target.to(device)
45 | output = model(image)
46 | loss = criterion(output, target)
47 |
48 | optimizer.zero_grad()
49 | if args.data_type.lower() == "bf8":
50 | with amp.scale_loss(loss, optimizer) as scaled_loss:
51 | scaled_loss.backward()
52 | else:
53 | loss.backward()
54 | optimizer.step()
55 |
56 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
57 | batch_size = image.shape[0]
58 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
59 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
60 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
61 | metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
62 |
63 | if num_batches is not None and i%num_batches == 0:
64 | break
65 |
66 | def evaluate(model, criterion, data_loader, device, num_batches=None, print_freq=100, train=False):
67 | if not train:
68 | model.eval()
69 | header = 'Test:'
70 | else:
71 | model.train()
72 | header = 'Train:'
73 |
74 | metric_logger = utils.MetricLogger(delimiter=" ")
75 |
76 | with torch.no_grad():
77 | batch_id = 0
78 | for image, target in metric_logger.log_every(data_loader, print_freq, header):
79 | image = image.to(device, non_blocking=True)
80 | target = target.to(device, non_blocking=True)
81 | output = model(image)
82 | loss = criterion(output, target)
83 |
84 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
85 | # FIXME need to take into account that the datasets
86 | # could have been padded in distributed setup
87 | batch_size = image.shape[0]
88 | metric_logger.update(loss=loss.item())
89 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
90 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
91 |
92 | if num_batches is not None and batch_id+1 == num_batches:
93 |                 break
94 | else:
95 | batch_id += 1
96 | # gather the stats from all processes
97 | metric_logger.synchronize_between_processes()
98 |
99 | print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
100 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5))
101 | return metric_logger.acc1.global_avg
102 |
103 |
104 | def main(args):
105 | if args.output_dir:
106 | utils.mkdir(args.output_dir)
107 |
108 | # creating device and setting backend
109 | device = torch.device(args.device)
110 | torch.backends.cudnn.benchmark = True
111 |
112 | # data loading code
113 | print("Loading data")
114 | train_dir = os.path.join(args.data_path, 'train')
115 | val_dir = os.path.join(args.data_path, 'val')
116 |
117 |     dataset = datasets.ImageFolder(
118 |         train_dir,
119 |         transforms.Compose([
120 |             transforms.RandomResizedCrop(224),  # standard ImageNet crop size
121 |             transforms.RandomHorizontalFlip(), transforms.ToTensor(),
122 |         ]))
123 |     dataset_test = datasets.ImageFolder(val_dir, transforms.Compose([
124 |         transforms.Resize(256),
125 |         transforms.CenterCrop(224), transforms.ToTensor(),
126 |         ]))
127 |
128 |     if getattr(args, "distributed", False):  # this example runs single-process
129 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
130 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
131 | else:
132 | train_sampler = torch.utils.data.RandomSampler(dataset)
133 | test_sampler = torch.utils.data.SequentialSampler(dataset_test)
134 |
135 | data_loader = torch.utils.data.DataLoader(
136 | dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
137 | sampler=train_sampler, num_workers=args.workers, pin_memory=True)
138 |
139 | data_loader_test = torch.utils.data.DataLoader(
140 | dataset_test, batch_size=args.eval_batch_size, shuffle=False,
141 | sampler=test_sampler, num_workers=args.workers, pin_memory=True)
142 |
143 | print("Creating model", args.arch)
144 | # Loading fp32 model
145 | model = torchvision.models.__dict__[args.arch](pretrained=True)
146 | model.to(device)
147 |
148 | # SGD optimizer
149 | optimizer = torch.optim.SGD(
150 | model.parameters(), lr=args.lr, momentum=args.momentum,
151 | weight_decay=args.weight_decay)
152 |
153 | if args.data_type.lower() == "bf8" :
154 | opt_level = "O2" if args.device == "cuda" else "O0"
155 | loss_scale = "dynamic" if args.device == "cuda" else 1.0
156 | model, optimizer = amp.initialize(model, optimizer,
157 | opt_level=opt_level,
158 | keep_batchnorm_fp32=True,
159 | loss_scale=loss_scale
160 | )
161 | # LR scheduler
162 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
163 | step_size=args.lr_step_size,
164 | gamma=args.lr_gamma)
165 | # Loss function
166 | criterion = nn.CrossEntropyLoss()
167 |
168 | if args.resume:
169 | checkpoint = torch.load(args.resume, map_location='cpu')
170 | model.load_state_dict(checkpoint['model'])
171 | optimizer.load_state_dict(checkpoint['optimizer'])
172 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
173 | args.start_epoch = checkpoint['epoch'] + 1
174 |
175 | ########################################################
176 | list_exempt_layers = []
177 | list_layers_output_fused = []
178 | if "resnet" in args.arch or "resnext" in args.arch:
179 | list_exempt_layers = ["conv1","fc"]
180 | elif args.arch == "vgg19_bn":
181 | list_exempt_layers = ["features.0", "classifier.6"]
182 | elif args.arch == "inception_v3":
183 | list_exempt_layers = ["Conv2d_1a_3x3.conv", "fc"]
184 |
185 | print("list of exempted layers : ", list_exempt_layers)
186 | model, emulator = mpt_emu.quantize_model(model, optimizer=optimizer, dtype=args.data_type.lower(), hw_patch=args.patch_ops,
187 | list_exempt_layers=list_exempt_layers, list_layers_output_fused=list_layers_output_fused,
188 | device=device, verbose=True)
189 |
190 | emulator.set_default_inference_qconfig()
191 | start_time = time.time()
192 | evaluate(model, criterion, data_loader_test, device=device)
193 | print('quantization-aware training : Fine-tuning the network for {} epochs'.format(args.epochs))
194 | for epoch in range(args.start_epoch, args.epochs):
195 | train_one_epoch(args, model, criterion, optimizer, emulator,
196 | data_loader, device, epoch, args.print_freq)
197 | lr_scheduler.step()
198 | with torch.no_grad():
199 | print('evaluating fine-tuned model')
200 | eval_model = emulator.fuse_bnlayers_and_quantize_model(model)
201 | #quantized_eval_model = copy.deepcopy(model)
202 | evaluate(eval_model, criterion, data_loader_test, device=device)
203 | #quantized_eval_model = emulator.fuse_bnlayers_and_quantize_model(quantized_eval_model)
204 | #print('evaluate quantized model')
205 | #evaluate(quantized_eval_model, criterion, data_loader_test, device=device)
206 |
207 | model.train()
208 |
209 | if args.output_dir:
210 | checkpoint = {
211 | 'model': model.state_dict(),
212 | 'model_qconfig_dict' : emulator.emulator.model_qconfig_dict,
213 | 'optimizer': optimizer.state_dict(),
214 | 'lr_scheduler': lr_scheduler.state_dict(),
215 | 'epoch': epoch,
216 | 'args': args}
217 | utils.save_on_master(
218 | checkpoint,
219 | os.path.join(args.output_dir, 'checkpoint.pth'))
220 | print('saving models after {} epochs'.format(epoch))
221 |
222 | total_time = time.time() - start_time
223 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
224 | print('Training time {}'.format(total_time_str))
225 |
226 | def parse_args():
227 | import argparse
228 | parser = argparse.ArgumentParser(description='PyTorch Classification Training')
229 |
230 | parser.add_argument('--data-path',
231 | default='/fastdata/imagenet/',
232 | help='dataset')
233 | parser.add_argument('--arch',
234 | default='mobilenet_v2',
235 | help='model')
236 | parser.add_argument('--data-type',
237 | default='bf8',
238 |                         help='supported types : bf8, e5m2, e4m3, bf16')
239 | parser.add_argument('--patch-ops', default='None',
240 | help='Patch Ops to enable custom gemm implementation')
241 | parser.add_argument('--pruning-algo',
242 | default="None",
243 | help='Pruning method: fine-grained, unstructured, adaptive')
244 | parser.add_argument('--device',
245 | default='cuda',
246 | help='device')
247 |
248 | parser.add_argument('-b', '--batch-size', default=256, type=int,
249 | help='batch size for calibration/training')
250 | parser.add_argument('--eval-batch-size', default=256, type=int,
251 | help='batch size for evaluation')
252 | parser.add_argument('--epochs', default=5, type=int, metavar='N',
253 | help='number of total epochs to run')
254 |
255 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
256 | help='number of data loading workers (default: 16)')
257 | parser.add_argument('--lr',
258 | default=0.1, type=float,
259 | help='initial learning rate')
260 | parser.add_argument('--momentum',
261 | default=0.9, type=float, metavar='M',
262 | help='momentum')
263 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
264 | metavar='W', help='weight decay (default: 1e-4)',
265 | dest='weight_decay')
266 | parser.add_argument('--lr-step-size', default=30, type=int,
267 | help='decrease lr every step-size epochs')
268 | parser.add_argument('--lr-gamma', default=0.1, type=float,
269 | help='decrease lr by a factor of lr-gamma')
270 | parser.add_argument('--print-freq', default=100, type=int,
271 | help='print frequency')
272 | parser.add_argument('--output-dir', default='.', help='path where to save')
273 | parser.add_argument('--resume', default='', help='resume from checkpoint')
274 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
275 | help='start epoch')
276 | parser.add_argument("--cache-dataset", dest="cache_dataset",\
277 | help="Cache the datasets for quicker initialization. \
278 | It also serializes the transforms",
279 | action="store_true",
280 | )
281 |
282 | args = parser.parse_args()
283 | return args
284 |
285 |
286 | if __name__ == "__main__":
287 | args = parse_args()
288 | print(args)
289 | main(args)
290 |
--------------------------------------------------------------------------------
/examples/inference/classifier/imagenet_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import os
11 | import time
12 | import argparse
13 | 
14 | import torch
15 | import torchvision
16 | import torchvision.transforms as transforms
17 | import torchvision.datasets as datasets
18 | 
19 | from train import evaluate
20 | from mpemu import mpt_emu
21 | 
22 | 
23 |
24 | def get_model_exempt_layers(args, model):
25 |
26 | list_exempt_layers = []
27 | list_layers_output_fused = []
28 | if args.model == "alexnet":
29 | list_exempt_layers = ["features.0", "classifier.6"]
30 | elif args.model == "vgg16_bn":
31 | list_exempt_layers = ["features.0", "features.1", "classifier.6"]
32 | elif args.model == "inception_v3":
33 | list_exempt_layers = ["Conv2d_1a_3x3.conv", "fc"]
34 | #list_exempt_layers = ["fc"]
35 | elif args.model == "squeezenet1_1":
36 | list_exempt_layers = ["features.0", "classifier.1"]
37 | #list_exempt_layers = ["classifier.1"]
38 | elif "resnet" in args.model or "resnext" in args.model:
39 | list_exempt_layers = ["conv1","bn1","fc"]
40 | #list_exempt_layers = ["fc"]
41 | elif "densenet" in args.model:
42 | list_exempt_layers = ["features.conv0","features.norm0", "classifier"]
43 | #list_exempt_layers = ["classifier"]
44 | elif "mobilenet" in args.model or "efficientnet" in args.model:
45 | # Exempting Features[0] and classifier
46 | list_exempt_layers = ["features.0.0","features.0.1","classifier.1"]
47 | #list_exempt_layers = ["classifier.1"]
48 | elif args.model == "mobilenet_v3_small" or args.model == "mobilenet_v3_large":
49 | list_exempt_layers = ["features.0.0","features.0.1","classifier.0","classifier.3"]
50 | #list_exempt_layers = ["classifier.0","classifier.3"]
51 | elif args.model == "wide_resnet50_2":
52 | list_exempt_layers = ["features.0.0","features.0.1","classifier.3"]
53 | #list_exempt_layers = ["classifier.3"]
54 | elif args.model == "mnasnet1_0":
55 | list_exempt_layers = ["layers.0","layers.1","classifier.1"]
56 | elif args.model == "shufflenet_v2_x1_0":
57 | list_exempt_layers = ["conv1.0","conv1.1","fc"]
58 | #list_exempt_layers = ["fc"]
59 |
60 | prev_name = None
61 | prev_module = None
62 | for name,module in model.named_modules():
63 | if type(module) == torch.nn.BatchNorm2d and type(prev_module) == torch.nn.Conv2d:
64 | list_layers_output_fused.append(prev_name)
65 | if type(module) == torch.nn.Linear:
66 | list_layers_output_fused.append(name)
67 |
68 | prev_module = module
69 | prev_name = name
70 |
71 | return list_exempt_layers, list_layers_output_fused
72 |
73 | def get_data_loaders(args):
74 | ### Data loader construction
75 | traindir = os.path.join(args.data_path, 'train')
76 | valdir = os.path.join(args.data_path, 'val')
77 |
78 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
79 | std=[0.229, 0.224, 0.225])
80 |
81 | dataset = datasets.ImageFolder(
82 | traindir,
83 | transforms.Compose([
84 | transforms.RandomResizedCrop(224),
85 | transforms.RandomHorizontalFlip(),
86 | transforms.ToTensor(),
87 | normalize,
88 | ]))
89 |
90 | train_loader = torch.utils.data.DataLoader(
91 | dataset, batch_size=args.batch_size, shuffle=True,
92 | num_workers=args.workers, pin_memory=True)
93 |
94 | test_loader = torch.utils.data.DataLoader(
95 | datasets.ImageFolder(valdir, transforms.Compose([
96 | transforms.Resize(256),
97 | transforms.CenterCrop(224),
98 | transforms.ToTensor(),
99 | normalize,
100 | ])),
101 | batch_size=args.eval_batch_size, shuffle=False,
102 | num_workers=args.workers, pin_memory=True)
103 |
104 | return train_loader, test_loader
105 |
106 | def print_bn_running_stats(model):
107 | rmean = None
108 | rvar = None
109 | for name, module in model.named_children():
110 | if isinstance(module, torch.nn.BatchNorm2d):
111 | if name == "features.0.1": #"bn1":
112 | print("--> layer: {}, tracking running stats : {}".format( name, module.track_running_stats))
113 | #break
114 | rmean = module.running_mean
115 | rvar = module.running_var
116 | #print(module.running_mean)
117 | #print(module.running_var)
118 | #return module.running_mean, module.running_var #rmean, rvar
119 | return rmean, rvar
120 |
121 | if __name__ == "__main__":
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument("--data-path", default="/fastdata/imagenet/", help="Path to imagenet dataset")
124 | parser.add_argument("--model", default="resnet50",
125 | choices=["alexnet", "vgg16_bn", "resnet18", "resnet50","resnext50_32x4d","densenet121",\
126 | "densenet201","mobilenet_v2","shufflenet_v2_x1_0","mobilenet_v3_large","mobilenet_v3_small",\
127 | "wide_resnet50_2","resnext101_32x8d","mnasnet1_0","efficientnet_b7","regnet_x_32gf",\
128 | "inception_v3","squeezenet1_1","efficientnet_b0"],
129 | help="Name of the neural network architecture")
130 | parser.add_argument('--data-type', default='e5m2',
131 | help='supported types : e5m2, e4m3, bf16' )
132 | parser.add_argument('--patch-ops', default='None',
133 | help='Patch Ops to enable custom gemm implementation')
134 | parser.add_argument('--device',default='cuda', help='device')
135 | parser.add_argument('--verbose', action="store_true", default=False, help='show debug messages')
136 |
137 | # Post training quantization(PTQ) related
138 | parser.add_argument('--recalibrate-bn', action="store_true", default=False,
139 | help='Perform batchnorm recalibration')
140 | parser.add_argument("--num-calibration-batches", type=int, default=500,
141 | help="Number of batches for BatchNorm calibration")
142 | parser.add_argument("--batch-size", type=int, default=256, help="Batch size for train/calibrate")
143 | parser.add_argument('--eval-batch-size', default=32, type=int,
144 | help='batch size for evaluation')
145 | parser.add_argument("--workers", type=int, default=16)
146 |
147 | # Logging and storing model
148 | parser.add_argument("--output-dir", default=".", help="Output directory to dump model")
149 |
150 | args = parser.parse_args()
151 | print(args)
152 | model_dict = {
153 | "alexnet" : "AlexNet_Weights.DEFAULT",
154 | "vgg16_bn": "VGG16_BN_Weights.DEFAULT",
155 | "resnet18" : "ResNet18_Weights.DEFAULT",
156 | "resnet50" : "ResNet50_Weights.DEFAULT",
157 | "resnext50_32x4d" : "ResNeXt50_32X4D_Weights.DEFAULT",
158 | "resnext101_32x8d" : "ResNeXt101_32X8D_Weights.DEFAULT",
159 | "wide_resnet50_2" : "Wide_ResNet50_2_Weights.DEFAULT",
160 | "densenet121" : "DenseNet121_Weights.DEFAULT",
161 | "densenet201" : "DenseNet201_Weights.DEFAULT",
162 | "mobilenet_v2" : "MobileNet_V2_Weights.DEFAULT",
163 |         "shufflenet_v2_x1_0" : "ShuffleNet_V2_X1_0_Weights.DEFAULT",
164 | "mobilenet_v3_large" : "MobileNet_V3_Large_Weights.DEFAULT",
165 | "mobilenet_v3_small" : "MobileNet_V3_Small_Weights.DEFAULT",
166 | "mnasnet1_0" : "MNASNet1_0_Weights.DEFAULT",
167 | "efficientnet_b7" : "EfficientNet_B7_Weights.DEFAULT",
168 | "regnet_x_32gf" : "RegNet_X_32GF_Weights.DEFAULT",
169 | "inception_v3" : "Inception_V3_Weights.DEFAULT",
170 | "squeezenet1_1" : "SqueezeNet1_1_Weights.DEFAULT",
171 | "efficientnet_b0" : "EfficientNet_B0_Weights.DEFAULT",
172 | }
173 | # Creating output directory
174 | if not os.path.exists(args.output_dir):
175 | os.makedirs(args.output_dir)
176 |
177 | # Create the model and move to GPU
178 | device = torch.device(args.device)
179 | # Data loaders and loss function
180 | train_loader,test_loader = get_data_loaders(args)
181 | criterion = torch.nn.CrossEntropyLoss()
182 | # Model
183 | model = torchvision.models.__dict__[args.model](pretrained=True)
184 | #model = torchvision.models.__dict__[args.model](weights=model_dict[args.model])
185 | model = model.to(device)
186 | model.eval()
187 | #print(model)
188 |
189 | print("Evaluating original {} model to establish baseline".format(args.model))
190 | evaluate(model, criterion, test_loader, device)
191 |
192 | # Create a list of exempt_layers
193 | list_exempt_layers, list_layers_output_fused = get_model_exempt_layers(args, model)
194 | print("Preparing the {} model for {} quantization".format(args.model, args.data_type.lower()))
195 | print("List of exempt layers : ", list_exempt_layers)
196 | #print(list_layers_output_fused)
197 | model, emulator = mpt_emu.quantize_model (model, dtype=args.data_type.lower(), calibrate=args.recalibrate_bn, hw_patch=args.patch_ops,
198 | list_exempt_layers=list_exempt_layers, list_layers_output_fused=list_layers_output_fused,
199 | device=device, verbose=args.verbose)
200 |
201 | if args.recalibrate_bn == True:
202 | print("Calibrating the {} model for {} batches of training data".format(args.model, args.num_calibration_batches))
203 | model.train()
204 | evaluate(model, criterion, train_loader, device,
205 | num_batches=args.num_calibration_batches, train=True)
206 |
207 | model.eval()
208 | print("Fusing BatchNorm layers")
209 | model = emulator.fuse_bnlayers_and_quantize_model(model)
210 | print("Evaluating {} fused {} model. ".format(args.model, args.data_type.lower()))
211 | evaluate(model, criterion, test_loader, device)
212 |
--------------------------------------------------------------------------------
/examples/inference/classifier/launch.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import subprocess
11 | import itertools
12 | import os
13 |
14 | model_choices = []
15 | model_choices += ["resnet50"]
16 | model_choices += ["wide_resnet50_2"]
17 | model_choices += ["resnext50_32x4d"]
18 | model_choices += ["resnext101_32x8d"]
19 | model_choices += ["densenet121"]
20 | model_choices += ["densenet201"]
21 | model_choices += ["mobilenet_v2"]
22 | model_choices += ["mobilenet_v3_small"]
23 | model_choices += ["mobilenet_v3_large"]
24 | model_choices += ["inception_v3"]
25 | model_choices += ["squeezenet1_1"]
26 | model_choices += ["efficientnet_b0"]
27 |
28 | device_choices = ["cuda:0"]
29 | batch_size = 64
30 | data_type = "e4m3"  # Supported types : e5m2, e4m3, e3m4, hybrid
31 | recalibrate_bn = True  # False if data_type == "e4m3" else True
32 | num_calibration_batches = 100
33 | finetune_lr=0.0002
34 | finetune_epochs=2
35 | cmodel="none"
36 | BASE_GPU_ID = 2
37 | NUM_GPUS = 1
38 | print_results = False
39 | verbose=True
40 | data_dir="/fastdata/imagenet/"
41 | cur_sbps = []
42 |
43 | for exp_id, exp_config in enumerate(itertools.product(model_choices, device_choices)):
44 | model, device = exp_config
45 |
46 | # Creating output directory
47 | output_dir = "calib_experiments/{}-{}".format(model, device)
48 | dump_fp = os.path.join(output_dir,"log.txt")
49 | if not os.path.exists(output_dir):
50 | os.makedirs(output_dir)
51 |
52 | gpu_id = BASE_GPU_ID + (exp_id % NUM_GPUS)
53 | cmd = ""
54 | if "cuda" in device:
55 | cmd = "CUDA_VISIBLE_DEVICES={} ".format(gpu_id)
56 | cmd += "python imagenet_test.py --model {} --batch-size {} --data-type {} --device {} ".format(
57 | model, batch_size, data_type, device)
58 | if cmodel != "none" and device == "cpu":
59 | cmd += "--patch-ops {} ".format(cmodel)
60 | if recalibrate_bn:
61 | cmd += "--recalibrate-bn --num-calibration-batches {} ".format(num_calibration_batches)
62 | if verbose:
63 | cmd += "--verbose "
64 | cmd += "--data-path {} --output-dir {} 2>&1 | tee {}".format(data_dir, output_dir, dump_fp)
65 |
66 | print(cmd)
67 |
68 | if print_results:
69 | fh = open(dump_fp)
70 | accuracy = float(fh.readlines()[-1].strip().split()[2].strip())
71 | print("{:.2f}, ".format(accuracy), end="")
72 | fh.close()
73 | continue
74 |
75 | p = subprocess.Popen(cmd, shell=True)
76 | cur_sbps.append(p)
77 |
78 | if exp_id%NUM_GPUS == NUM_GPUS-1:
79 | exit_codes = [p.wait() for p in cur_sbps]
80 | cur_sbps = [] # Emptying the process list
81 |
--------------------------------------------------------------------------------
/examples/inference/classifier/test_scaleshift.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import os
11 | import torch
12 | import torchvision
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 |
16 |
17 | def get_test_loader(data_path, batch_size=256):
18 | val_dir = os.path.join(data_path, 'val')
19 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 | std=[0.229, 0.224, 0.225])
21 |
22 | test_loader = torch.utils.data.DataLoader(
23 | datasets.ImageFolder(val_dir, transforms.Compose([
24 | transforms.Resize(256),
25 | transforms.CenterCrop(224),
26 | transforms.ToTensor(),
27 | normalize,
28 | ])),
29 | batch_size=batch_size, shuffle=False,
30 | num_workers=4, pin_memory=True)
31 |
32 | return test_loader
33 |
34 | arch = "resnet50"
35 | model = torchvision.models.__dict__[arch](pretrained=True)
36 | device = torch.device("cuda")
37 | model.to(device)
38 |
39 | test_loader = get_test_loader("/fastdata/imagenet")
40 | criterion = torch.nn.CrossEntropyLoss()
41 |
42 | from train import evaluate
43 | def test():
44 | evaluate(model, criterion, test_loader, device, num_batches=2, print_freq=1)
45 |
46 | test()
47 |
48 | from mpemu import scale_shift
49 | model = scale_shift.replace_batchnorms_with_scaleshifts(model) # Replacing BN with scaleshift layers
50 | model.to(device)
51 | test()
52 |
53 |
54 |
--------------------------------------------------------------------------------
/examples/inference/classifier/train.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import datetime
11 | import os
12 | import time
13 |
14 | import torch
15 | import torch.utils.data
16 | from torch import nn
17 | import torchvision
18 | from torchvision import datasets, transforms
19 |
20 | import utils
21 |
22 | try:
23 | from apex import amp
24 | except ImportError:
25 | amp = None
26 |
27 | def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq,
28 | apex=False, num_batches=None):
29 | model.train()
30 | metric_logger = utils.MetricLogger(delimiter=" ")
31 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
32 | metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
33 |
34 | header = 'Epoch: [{}]'.format(epoch)
35 | i = 0
36 | for image, target in metric_logger.log_every(data_loader, print_freq, header):
37 | i += 1
38 | start_time = time.time()
39 | image, target = image.to(device), target.to(device)
40 | output = model(image)
41 | loss = criterion(output, target)
42 |
43 | optimizer.zero_grad()
44 | if apex:
45 | with amp.scale_loss(loss, optimizer) as scaled_loss:
46 | scaled_loss.backward()
47 | else:
48 | loss.backward()
49 | optimizer.step()
50 |
51 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
52 | batch_size = image.shape[0]
53 | metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
54 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
55 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
56 | metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
57 |
58 | if num_batches is not None and i%num_batches == 0:
59 | break
60 |
61 |
62 | def evaluate(model, criterion, data_loader, device, num_batches=None, print_freq=100, train=False):
63 | if not train:
64 | model.eval()
65 | header = 'Test:'
66 | else:
67 | model.train()
68 | header = 'Train:'
69 |
70 | metric_logger = utils.MetricLogger(delimiter=" ")
71 |
72 | with torch.no_grad():
73 | batch_id = 0
74 | for image, target in metric_logger.log_every(data_loader, print_freq, header):
75 | image = image.to(device, non_blocking=True)
76 | target = target.to(device, non_blocking=True)
77 | output = model(image)
78 | loss = criterion(output, target)
79 |
80 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
81 | # FIXME need to take into account that the datasets
82 | # could have been padded in distributed setup
83 | batch_size = image.shape[0]
84 | metric_logger.update(loss=loss.item())
85 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
86 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
87 |
88 | if num_batches is not None and batch_id+1 == num_batches:
89 |                 break
90 | else:
91 | batch_id += 1
92 | # gather the stats from all processes
93 | metric_logger.synchronize_between_processes()
94 |
95 | print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
96 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5))
97 | return metric_logger.acc1.global_avg
98 |
99 | def main(args):
100 | if args.apex and amp is None:
101 | raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
102 | "to enable mixed-precision training.")
103 |
104 | if args.output_dir:
105 | utils.mkdir(args.output_dir)
106 |
107 | #utils.init_distributed_mode(args)
108 | args.distributed = False
109 | print(args)
110 |
111 | device = torch.device(args.device)
112 |
113 | torch.backends.cudnn.benchmark = True
114 |
115 | train_dir = os.path.join(args.data_path, 'train')
116 | val_dir = os.path.join(args.data_path, 'val')
117 |
118 |     dataset = datasets.ImageFolder(
119 |         train_dir,
120 |         transforms.Compose([
121 |             transforms.RandomResizedCrop(224),  # standard ImageNet crop size
122 |             transforms.RandomHorizontalFlip(), transforms.ToTensor(),
123 |         ]))
124 |     dataset_test = datasets.ImageFolder(val_dir, transforms.Compose([
125 |         transforms.Resize(256),
126 |         transforms.CenterCrop(224), transforms.ToTensor(),
127 |         ]))
128 |
129 |     if args.distributed:
130 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
131 | test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
132 | else:
133 | train_sampler = torch.utils.data.RandomSampler(dataset)
134 | test_sampler = torch.utils.data.SequentialSampler(dataset_test)
135 |
136 | data_loader = torch.utils.data.DataLoader(
137 | dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
138 | sampler=train_sampler, num_workers=args.workers, pin_memory=True)
139 |
140 | data_loader_test = torch.utils.data.DataLoader(
141 | dataset_test, batch_size=args.batch_size, shuffle=False,
142 | sampler=test_sampler, num_workers=args.workers, pin_memory=True)
143 |
144 | print("Creating model")
145 | model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
146 | model.to(device)
147 | if args.distributed and args.sync_bn:
148 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
149 |
150 | criterion = nn.CrossEntropyLoss()
151 |
152 | optimizer = torch.optim.SGD(
153 | model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
154 |
155 | if args.apex:
156 | model, optimizer = amp.initialize(model, optimizer,
157 | opt_level=args.apex_opt_level
158 | )
159 |
160 | #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
161 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8,12], gamma=args.lr_gamma)
162 |
163 | model_without_ddp = model
164 | if args.distributed:
165 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
166 | model_without_ddp = model.module
167 |
168 | if args.resume:
169 | checkpoint = torch.load(args.resume, map_location='cpu')
170 | model_without_ddp.load_state_dict(checkpoint['model'])
171 | optimizer.load_state_dict(checkpoint['optimizer'])
172 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
173 | args.start_epoch = checkpoint['epoch'] + 1
174 |
175 | if args.test_only:
176 | evaluate(model, criterion, data_loader_test, device=device)
177 | return
178 |
179 | ################### Quantization and Sparsity ##########################
180 | import collections
181 | filter_module_types = [torch.nn.Conv2d, torch.nn.Linear] # Only quantizing convolution and linear modules
182 | exempt_modules = ["conv1","fc"]
183 | is_training = True
184 |
185 | ## QUANTIZATION ##
186 | model_qconfig_dict = collections.OrderedDict()
187 | if args.do_quantization:
188 |         from mpemu import qutils
189 | ### Preparing the quantization configuration dictionary ###
190 | mod_qconfig = qutils.get_module_quant_config(args.qlevel, args.qdtype, args.qscheme, is_training=is_training)
191 | model_qconfig_dict = qutils.get_or_update_model_quant_config_dict(model, filter_module_types, mod_qconfig,
192 | exempt_modules=exempt_modules)
193 | print("Model quantization configuration")
194 | for layer,qconfig in model_qconfig_dict.items():
195 | print(layer, qconfig)
196 | print()
197 |
198 | # Setting up the model for QAT
199 | qutils.reset_quantization_setup(model, model_qconfig_dict)
200 | qhooks = qutils.add_quantization_hooks(model, model_qconfig_dict, is_training=is_training)
201 |
202 | print("Start training")
203 | start_time = time.time()
204 | for epoch in range(args.start_epoch, args.epochs):
205 | if args.distributed:
206 | train_sampler.set_epoch(epoch)
207 |
208 | train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
209 | args.print_freq, apex=args.apex, num_batches=None)
210 |
211 | lr_scheduler.step()
212 | evaluate(model, criterion, data_loader_test, device=device)
213 | if args.output_dir:
214 | checkpoint = {
215 | 'model': model_without_ddp.state_dict(),
216 | 'optimizer': optimizer.state_dict(),
217 | 'lr_scheduler': lr_scheduler.state_dict(),
218 | 'epoch': epoch,
219 | 'args': args}
220 | utils.save_on_master(
221 | checkpoint,
222 | os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
223 | utils.save_on_master(
224 | checkpoint,
225 | os.path.join(args.output_dir, 'checkpoint.pth'))
226 |
227 | total_time = time.time() - start_time
228 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
229 | print('Training time {}'.format(total_time_str))
230 |
231 |
232 | def parse_args():
233 | import argparse
234 | parser = argparse.ArgumentParser(description='PyTorch Classification Training')
235 |
236 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
237 | parser.add_argument('--model', default='resnet18', help='model')
238 | parser.add_argument('--device', default='cuda', help='device')
239 | parser.add_argument('-b', '--batch-size', default=256, type=int)
240 | parser.add_argument('--epochs', default=100, type=int, metavar='N',
241 | help='number of total epochs to run')
242 | parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
243 | help='number of data loading workers (default: 16)')
244 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
245 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
246 | help='momentum')
247 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
248 | metavar='W', help='weight decay (default: 1e-4)',
249 | dest='weight_decay')
250 | parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
251 | parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
252 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
253 | parser.add_argument('--output-dir', default='.', help='path where to save')
254 | parser.add_argument('--resume', default='', help='resume from checkpoint')
255 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
256 | help='start epoch')
257 | parser.add_argument(
258 | "--cache-dataset",
259 | dest="cache_dataset",
260 | help="Cache the datasets for quicker initialization. It also serializes the transforms",
261 | action="store_true",
262 | )
263 | parser.add_argument(
264 | "--sync-bn",
265 | dest="sync_bn",
266 | help="Use sync batch norm",
267 | action="store_true",
268 | )
269 | parser.add_argument(
270 | "--test-only",
271 | dest="test_only",
272 | help="Only test the model",
273 | action="store_true",
274 | )
275 | parser.add_argument(
276 | "--pretrained",
277 | dest="pretrained",
278 | help="Use pre-trained models from the modelzoo",
279 | action="store_true",
280 | )
281 |
282 | # Mixed precision training parameters
283 | parser.add_argument('--apex', action='store_true',
284 | help='Use apex for mixed precision training')
285 | parser.add_argument('--apex-opt-level', default='O1', type=str,
286 |                         help='For apex mixed precision training. '
287 |                              'O0 for FP32 training, O1 for mixed precision training. '
288 |                              'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
289 | )
290 |
291 | # distributed training parameters
292 | parser.add_argument('--world-size', default=1, type=int,
293 | help='number of distributed processes')
294 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
295 |
296 |
297 | # Pruning arguments
298 | # Quantization arguments
299 | parser.add_argument("--do-quantization", action="store_true", help="Whether to do quantization or not")
300 | parser.add_argument("--qdtype", default="bfloat8", help="Quantization data type")
301 | parser.add_argument("--qscheme", default="rne", help="Quantization scheme")
302 | parser.add_argument("--qlevel", type=int, default=3, choices=[0,1,3,6,7],
303 | help="0->No quantization, 1->wt only, 3->(wt+iact), 6->(iact+oact) 7->(wt+iact+oact)")
304 | parser.add_argument("--patch_ops", action="store_true")
305 |
306 | # Pruning arguments
307 | parser.add_argument("--do-pruning", action="store_true", help="Whether or not to do pruning")
308 | parser.add_argument("--sbh", type=int, default=-1, help="Scope block height")
309 | parser.add_argument("--sbw", type=int, default=-1, help="Scope block width")
310 | parser.add_argument("--sparsity", type=float, default=0.5, help="Amount of sparsity to impose [0,1]")
311 |
312 | args = parser.parse_args()
313 |
314 | return args
315 |
316 |
317 | if __name__ == "__main__":
318 | args = parse_args()
319 | main(args)
320 |
--------------------------------------------------------------------------------
/examples/inference/classifier/utils.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | from collections import defaultdict, deque
11 | import datetime
12 | import time
13 | import torch
14 | import torch.distributed as dist
15 |
16 | import errno
17 | import os
18 |
19 |
20 | class SmoothedValue(object):
21 | """Track a series of values and provide access to smoothed values over a
22 | window or the global series average.
23 | """
24 |
25 | def __init__(self, window_size=20, fmt=None):
26 | if fmt is None:
27 | fmt = "{median:.4f} ({global_avg:.4f})"
28 | self.deque = deque(maxlen=window_size)
29 | self.total = 0.0
30 | self.count = 0
31 | self.fmt = fmt
32 |
33 | def update(self, value, n=1):
34 | self.deque.append(value)
35 | self.count += n
36 | self.total += value * n
37 |
38 | def synchronize_between_processes(self):
39 | """
40 | Warning: does not synchronize the deque!
41 | """
42 | if not is_dist_avail_and_initialized():
43 | return
44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
45 | dist.barrier()
46 | dist.all_reduce(t)
47 | t = t.tolist()
48 | self.count = int(t[0])
49 | self.total = t[1]
50 |
51 | @property
52 | def median(self):
53 | d = torch.tensor(list(self.deque))
54 | return d.median().item()
55 |
56 | @property
57 | def avg(self):
58 | d = torch.tensor(list(self.deque), dtype=torch.float32)
59 | return d.mean().item()
60 |
61 | @property
62 | def global_avg(self):
63 | return self.total / self.count
64 |
65 | @property
66 | def max(self):
67 | return max(self.deque)
68 |
69 | @property
70 | def value(self):
71 | return self.deque[-1]
72 |
73 | def __str__(self):
74 | return self.fmt.format(
75 | median=self.median,
76 | avg=self.avg,
77 | global_avg=self.global_avg,
78 | max=self.max,
79 | value=self.value)
80 |
81 |
82 | class MetricLogger(object):
83 | def __init__(self, delimiter="\t"):
84 | self.meters = defaultdict(SmoothedValue)
85 | self.delimiter = delimiter
86 |
87 | def update(self, **kwargs):
88 | for k, v in kwargs.items():
89 | if isinstance(v, torch.Tensor):
90 | v = v.item()
91 | assert isinstance(v, (float, int))
92 | self.meters[k].update(v)
93 |
94 | def __getattr__(self, attr):
95 | if attr in self.meters:
96 | return self.meters[attr]
97 | if attr in self.__dict__:
98 | return self.__dict__[attr]
99 | raise AttributeError("'{}' object has no attribute '{}'".format(
100 | type(self).__name__, attr))
101 |
102 | def __str__(self):
103 | loss_str = []
104 | for name, meter in self.meters.items():
105 | loss_str.append(
106 | "{}: {}".format(name, str(meter))
107 | )
108 | return self.delimiter.join(loss_str)
109 |
110 | def synchronize_between_processes(self):
111 | for meter in self.meters.values():
112 | meter.synchronize_between_processes()
113 |
114 | def add_meter(self, name, meter):
115 | self.meters[name] = meter
116 |
117 | def log_every(self, iterable, print_freq, header=None):
118 | i = 0
119 | if not header:
120 | header = ''
121 | start_time = time.time()
122 | end = time.time()
123 | iter_time = SmoothedValue(fmt='{avg:.4f}')
124 | data_time = SmoothedValue(fmt='{avg:.4f}')
125 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
126 | if torch.cuda.is_available():
127 | log_msg = self.delimiter.join([
128 | header,
129 | '[{0' + space_fmt + '}/{1}]',
130 | 'eta: {eta}',
131 | '{meters}',
132 | 'time: {time}',
133 | 'data: {data}',
134 | 'max mem: {memory:.0f}'
135 | ])
136 | else:
137 | log_msg = self.delimiter.join([
138 | header,
139 | '[{0' + space_fmt + '}/{1}]',
140 | 'eta: {eta}',
141 | '{meters}',
142 | 'time: {time}',
143 | 'data: {data}'
144 | ])
145 | MB = 1024.0 * 1024.0
146 | for obj in iterable:
147 | data_time.update(time.time() - end)
148 | yield obj
149 | iter_time.update(time.time() - end)
150 | if i % print_freq == 0:
151 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
152 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
153 | if torch.cuda.is_available():
154 | print(log_msg.format(
155 | i, len(iterable), eta=eta_string,
156 | meters=str(self),
157 | time=str(iter_time), data=str(data_time),
158 | memory=torch.cuda.max_memory_allocated() / MB))
159 | else:
160 | print(log_msg.format(
161 | i, len(iterable), eta=eta_string,
162 | meters=str(self),
163 | time=str(iter_time), data=str(data_time)))
164 | i += 1
165 | end = time.time()
166 | total_time = time.time() - start_time
167 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
168 | print('{} Total time: {}'.format(header, total_time_str))
169 |
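   | # Typical usage of MetricLogger (illustrative sketch, grounded in the class above):
   | #   logger = MetricLogger(delimiter="  ")
   | #   for images, targets in logger.log_every(data_loader, 10, header="Test:"):
   | #       ...                       # run the model
   | #       logger.update(loss=loss)  # meters are created on first update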
170 |
171 | def accuracy(output, target, topk=(1,)):
172 | """Computes the accuracy over the k top predictions for the specified values of k"""
173 | with torch.no_grad():
174 | maxk = max(topk)
175 | batch_size = target.size(0)
176 |
177 | _, pred = output.topk(maxk, 1, True, True)
178 | pred = pred.t()
179 | correct = pred.eq(target[None])
180 |
181 | res = []
182 | for k in topk:
183 | correct_k = correct[:k].flatten().sum(dtype=torch.float32)
184 | res.append(correct_k * (100.0 / batch_size))
185 | return res
186 |
187 |
188 | def mkdir(path):
189 | try:
190 | os.makedirs(path)
191 | except OSError as e:
192 | if e.errno != errno.EEXIST:
193 | raise
194 |
195 |
196 | def setup_for_distributed(is_master):
197 | """
198 |     Disable printing when this process is not the master process.
199 | """
200 | import builtins as __builtin__
201 | builtin_print = __builtin__.print
202 |
203 | def print(*args, **kwargs):
204 | force = kwargs.pop('force', False)
205 | if is_master or force:
206 | builtin_print(*args, **kwargs)
207 |
208 | __builtin__.print = print
209 |
210 |
211 | def is_dist_avail_and_initialized():
212 | if not dist.is_available():
213 | return False
214 | if not dist.is_initialized():
215 | return False
216 | return True
217 |
218 |
219 | def get_world_size():
220 | if not is_dist_avail_and_initialized():
221 | return 1
222 | return dist.get_world_size()
223 |
224 |
225 | def get_rank():
226 | if not is_dist_avail_and_initialized():
227 | return 0
228 | return dist.get_rank()
229 |
230 |
231 | def is_main_process():
232 | return get_rank() == 0
233 |
234 |
235 | def save_on_master(*args, **kwargs):
236 | if is_main_process():
237 | torch.save(*args, **kwargs)
238 |
239 |
240 | def init_distributed_mode(args):
241 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
242 | args.rank = int(os.environ["RANK"])
243 | args.world_size = int(os.environ['WORLD_SIZE'])
244 | args.gpu = int(os.environ['LOCAL_RANK'])
245 | elif 'SLURM_PROCID' in os.environ:
246 | args.rank = int(os.environ['SLURM_PROCID'])
247 | args.gpu = args.rank % torch.cuda.device_count()
248 | elif hasattr(args, "rank"):
249 | pass
250 | else:
251 | print('Not using distributed mode')
252 | args.distributed = False
253 | return
254 |
255 | args.distributed = True
256 |
257 | torch.cuda.set_device(args.gpu)
258 | args.dist_backend = 'nccl'
259 | print('| distributed init (rank {}): {}'.format(
260 | args.rank, args.dist_url), flush=True)
261 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
262 | world_size=args.world_size, rank=args.rank)
263 | setup_for_distributed(args.rank == 0)
264 |
--------------------------------------------------------------------------------
/examples/training/README.md:
--------------------------------------------------------------------------------
1 | # FP8 Mixed Precision Training
2 | Follow the examples listed here to train models using the FP8 emulation toolkit.
3 | The toolkit supports two different training methods:
4 |
5 | ## Direct Conversion Method
6 | This method uses a single FP8 format (`E5M2`) for both forward and backward computations. Operators (e.g., Convolution, Linear) perform dot-product computations on input tensors expressed in `E5M2` format and accumulate the results into FP32 output tensors. The output tensors are converted directly to `E5M2` before they are written to memory; gradients are scaled using an automatic loss-scaling method. Full details of the training algorithm are covered in the paper [Mixed Precision Training With 8-bit Floating Point](https://arxiv.org/pdf/1905.12334.pdf).
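   | 
   | For intuition, a direct `E5M2` conversion can be emulated by truncating a half-precision value to 2 mantissa bits. The snippet below is a minimal illustrative sketch of that idea, not the toolkit's implementation (the emulation kernels also handle rounding modes and denormals):
   | 
   | ```
   | import torch
   | 
   | def to_e5m2(t: torch.Tensor) -> torch.Tensor:
   |     # E5M2 shares float16's 5 exponent bits; truncating the float16
   |     # mantissa from 10 bits to 2 emulates the storage format.
   |     h = t.to(torch.float16).view(torch.int16)
   |     h = h & -256   # keep sign(1) + exponent(5) + top-2 mantissa bits
   |     return h.view(torch.float16).to(torch.float32)
   | ```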
7 |
8 | ## Scaled-Hybrid Method
9 | This method uses a hybrid FP8 approach: `E4M3` for forward computations and `E5M2` for representing error gradients. The weight and activation tensors use `per-tensor scaling` in the forward pass to compensate for the limited dynamic range of the `E4M3` format. The weight and activation scaling factors are computed at run-time every iteration and are then used to quantize the FP32 output tensor to `E4M3` format. The gradients are scaled using standard automatic loss-scaling methods. This method is based on the algorithm discussed in the paper [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
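   | 
   | A minimal sketch of the per-tensor scaling idea (illustrative only; the toolkit computes and applies these scales inside its emulation kernels):
   | 
   | ```
   | import torch
   | 
   | E4M3_MAX = 448.0  # largest representable E4M3 magnitude
   | 
   | def scale_for_e4m3(t: torch.Tensor):
   |     # Map the tensor's absolute maximum onto the E4M3 dynamic range;
   |     # the scale is returned so the consumer can undo it after the
   |     # quantized computation.
   |     scale = E4M3_MAX / t.abs().max().clamp(min=1e-12)
   |     return (t * scale).clamp(-E4M3_MAX, E4M3_MAX), scale
   | ```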
10 |
11 | ## Loss Scaling
12 | A modified NVIDIA [Apex](https://github.com/NVIDIA/apex) library is used to extend loss-scaling capabilities to CPU platforms.
13 | Follow the instructions below to patch the Apex library and enable CPU support.
14 |
15 | ```
16 | $ git clone https://github.com/NVIDIA/apex
17 | $ cd apex
18 | $ git checkout 05091d4
19 | $ git apply /nvidia_apex_cpu_patch_05091d4.patch
20 | ```
21 | Install python-only build.
22 | ```
23 | $ pip3 install -v --no-cache-dir ./
24 | ```
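   | 
   | Conceptually, the dynamic loss scaler behaves as sketched below (a simplified illustration of the patched Apex `LossScaler`, which halves the scale on gradient overflow and grows it after a window of overflow-free steps):
   | 
   | ```
   | class SimpleLossScaler:
   |     def __init__(self, init_scale=2.**16, scale_window=2000):
   |         self.scale = init_scale
   |         self.scale_window = scale_window
   |         self.unskipped = 0
   | 
   |     def step(self, grads_overflowed):
   |         if grads_overflowed:
   |             # skip the update and back off the scale
   |             self.scale = max(self.scale / 2.0, 1.0)
   |             self.unskipped = 0
   |         else:
   |             self.unskipped += 1
   |             if self.unskipped == self.scale_window:
   |                 self.scale *= 2.0
   |                 self.unskipped = 0
   | ```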
25 |
26 | ## Usage Example
27 |
28 | Modify the training script as follows to enable FP8 emulation.
29 |
30 | ```
31 | # import the emulator
32 | from mpemu import mpt_emu
33 |
34 | ...
35 |
36 | # layers exempt from FP8 conversion
37 | list_exempt_layers = ["conv1","bn1","fc"]
38 | # fused layers are exempt from converting their output tensors to FP8; the following layer reads from the FP32 buffer.
39 | list_layers_output_fused = None
40 | # use the 'direct' training method; options: 'direct', 'hybrid'
41 | model, emulator = mpt_emu.initialize(model, optimizer, training_algo="direct",
42 | list_exempt_layers=list_exempt_layers, list_layers_output_fused=list_layers_output_fused,
43 | device="cpu", verbose=True)
44 |
45 | ...
46 |
47 | # training loop
48 | for epoch in range(args.start_epoch, args.epochs):
49 | ...
50 |
51 | emulator.update_global_steps(epoch*len(train_loader))
52 |
53 | ...
54 |
55 | emulator.optimizer_step(optimizer)
56 |
57 | ```
58 |
--------------------------------------------------------------------------------
/examples/training/bert/launch_fp8_training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export MIXED_PRECISION=fp16
3 |
4 | TRAINING_ALGO=${1:-'direct'}
5 |
6 | python run_qa_no_trainer.py --model_name_or_path bert-large-uncased --dataset_name squad --per_device_train_batch_size 12 --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --fp8_training --fp8_algo=$TRAINING_ALGO --output_dir ./bert-large-uncased-fp8 |& tee bert-large-uncased-fp8.log
7 |
--------------------------------------------------------------------------------
/examples/training/bert/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers >= 4.11.0
2 | accelerate
3 | datasets >= 1.8.0
4 | torch >= 1.3.0
5 | evaluate
6 |
--------------------------------------------------------------------------------
/examples/training/nvidia_apex_cpu_patch_05091d4.patch:
--------------------------------------------------------------------------------
1 | diff --git a/apex/amp/_initialize.py b/apex/amp/_initialize.py
2 | index 3ae6fde..04be0ae 100644
3 | --- a/apex/amp/_initialize.py
4 | +++ b/apex/amp/_initialize.py
5 | @@ -87,12 +87,14 @@ def check_params_fp32(models):
6 | "When using amp.initialize, you do not need to call .half() on your model\n"
7 | "before passing it, no matter what optimization level you choose.".format(
8 | name, param.type()))
9 | + '''
10 | elif not param.is_cuda:
11 | warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
12 | "When using amp.initialize, you need to provide a model with parameters\n"
13 | "located on a CUDA device before passing it no matter what optimization level\n"
14 | "you chose. Use model.to('cuda') to use the default device.".format(
15 | name, param.type()))
16 | + '''
17 |
18 | # Backward compatibility for PyTorch 0.4
19 | if hasattr(model, 'named_buffers'):
20 | @@ -110,12 +112,14 @@ def check_params_fp32(models):
21 | "When using amp.initialize, you do not need to call .half() on your model\n"
22 | "before passing it, no matter what optimization level you choose.".format(
23 | name, buf.type()))
24 | + '''
25 | elif not buf.is_cuda:
26 | warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
27 | "When using amp.initialize, you need to provide a model with buffers\n"
28 | "located on a CUDA device before passing it no matter what optimization level\n"
29 | "you chose. Use model.to('cuda') to use the default device.".format(
30 | name, buf.type()))
31 | + '''
32 |
33 |
34 | def check_optimizers(optimizers):
35 | diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py
36 | index 471289b..3d72008 100644
37 | --- a/apex/amp/_process_optimizer.py
38 | +++ b/apex/amp/_process_optimizer.py
39 | @@ -54,6 +54,9 @@ def lazy_init_with_master_weights(self):
40 | # .format(param.size()))
41 | fp32_params_this_group.append(param)
42 | param_group['params'][i] = param
43 | + elif param.type() == 'torch.FloatTensor':
44 | + fp32_params_this_group.append(param)
45 | + param_group['params'][i] = param
46 | else:
47 | raise TypeError("Optimizer's parameters must be either "
48 | "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
49 | @@ -212,6 +215,8 @@ def lazy_init_no_master_weights(self):
50 | stash.all_fp16_params.append(param)
51 | elif param.type() == 'torch.cuda.FloatTensor':
52 | stash.all_fp32_params.append(param)
53 | + elif param.type() == 'torch.FloatTensor':
54 | + stash.all_fp32_params.append(param)
55 | else:
56 | raise TypeError("Optimizer's parameters must be either "
57 | "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
58 | diff --git a/apex/amp/scaler.py b/apex/amp/scaler.py
59 | index 99888bc..cbf0f22 100644
60 | --- a/apex/amp/scaler.py
61 | +++ b/apex/amp/scaler.py
62 | @@ -6,7 +6,11 @@ from itertools import product
63 | def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False):
64 | # Exception handling for 18.04 compatibility
65 | if check_overflow:
66 | - cpu_sum = float(model_grad.float().sum())
67 | + # handling sparse gradients
68 | + if (model_grad.is_sparse) :
69 | + cpu_sum = torch.sparse.sum(model_grad)
70 | + else :
71 | + cpu_sum = float(model_grad.float().sum())
72 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
73 | return True
74 |
75 | @@ -19,7 +23,11 @@ def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=F
76 | def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False):
77 | # Exception handling for 18.04 compatibility
78 | if check_overflow:
79 | - cpu_sum = float(model_grad.float().sum())
80 | + # handling sparse gradients
81 | + if (model_grad.is_sparse) :
82 | + cpu_sum = torch.sparse.sum(model_grad)
83 | + else:
84 | + cpu_sum = float(model_grad.float().sum())
85 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
86 | return True
87 |
88 | @@ -53,7 +61,7 @@ class LossScaler(object):
89 | self._scale_seq_len = scale_window
90 | self._unskipped = 0
91 | self._has_overflow = False
92 | - self._overflow_buf = torch.cuda.IntTensor([0])
93 | + self._overflow_buf = torch.IntTensor([0])
94 | if multi_tensor_applier.available:
95 | import amp_C
96 | LossScaler.has_fused_kernel = multi_tensor_applier.available
97 |
--------------------------------------------------------------------------------
/examples/training/resnet/README.md:
--------------------------------------------------------------------------------
1 | ```
2 | python main_amp.py --arch resnet50 --fp8-training --master-weight-precision='fp16' --resume checkpoint.pth.tar /fastdata/imagenet/ |& tee $LOGFILE
3 | ```
4 |
--------------------------------------------------------------------------------
/examples/training/resnet/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | from math import pi, cos
4 | """Learning Rate Schedulers"""
5 |
6 | class LRScheduler(object):
7 | r"""Learning Rate Scheduler
8 | For mode='step', we multiply lr with `decay_factor` at each epoch in `step`.
9 | For mode='poly'::
10 | lr = targetlr + (baselr - targetlr) * (1 - iter / maxiter) ^ power
11 | For mode='cosine'::
12 | lr = targetlr + (baselr - targetlr) * (1 + cos(pi * iter / maxiter)) / 2
13 | If warmup_epochs > 0, a warmup stage will be inserted before the main lr scheduler.
14 | For warmup_mode='linear'::
15 | lr = warmup_lr + (baselr - warmup_lr) * iter / max_warmup_iter
16 | For warmup_mode='constant'::
17 | lr = warmup_lr
18 | Parameters
19 | ----------
20 | mode : str
21 | Modes for learning rate scheduler.
22 | Currently it supports 'step', 'poly' and 'cosine'.
23 | niters : int
24 | Number of iterations in each epoch.
25 | base_lr : float
26 | Base learning rate, i.e. the starting learning rate.
27 | epochs : int
28 | Number of training epochs.
29 | step : list
30 | A list of epochs to decay the learning rate.
31 | decay_factor : float
32 | Learning rate decay factor.
33 | targetlr : float
34 | Target learning rate for poly and cosine, as the ending learning rate.
35 | power : float
36 | Power of poly function.
37 | warmup_epochs : int
38 | Number of epochs for the warmup stage.
39 | warmup_lr : float
40 | The base learning rate for the warmup stage.
41 | warmup_mode : str
42 | Modes for the warmup stage.
43 | Currently it supports 'linear' and 'constant'.
44 | """
45 | def __init__(self, optimizer, niters, args):
46 | super(LRScheduler, self).__init__()
47 |
48 | self.mode = args.lr_mode
49 | self.warmup_mode = args.warmup_mode if hasattr(args,'warmup_mode') else 'linear'
50 | assert(self.mode in ['step', 'poly', 'cosine'])
51 | assert(self.warmup_mode in ['linear', 'constant'])
52 |
53 | self.optimizer = optimizer
54 |
55 | #self.base_lr = args.base_lr if hasattr(args,'base_lr') else 0.1
56 | self.base_lr = args.lr if hasattr(args,'lr') else 0.1
57 | self.learning_rate = self.base_lr
58 | self.niters = niters
59 |
60 | self.step = [int(i) for i in args.step.split(',')] if hasattr(args,'step') else [30, 60, 90]
61 | self.decay_factor = args.decay_factor if hasattr(args,'decay_factor') else 0.1
62 | self.targetlr = args.targetlr if hasattr(args,'targetlr') else 0.0
63 | self.power = args.power if hasattr(args,'power') else 2.0
64 | self.warmup_lr = args.warmup_lr if hasattr(args,'warmup_lr') else 0.0
65 | self.max_iter = args.epochs * niters
66 | self.warmup_iters = (args.warmup_epochs if hasattr(args,'warmup_epochs') else 0) * niters
67 |
68 |
69 | print(self.base_lr, self.mode, self.warmup_mode, self.warmup_iters)
70 |
71 | def update(self, i, epoch):
72 | T = epoch * self.niters + i
73 | assert (T >= 0 and T <= self.max_iter)
74 |
75 | if self.warmup_iters > T:
76 | # Warm-up Stage
77 | if self.warmup_mode == 'linear':
78 | self.learning_rate = self.warmup_lr + (self.base_lr - self.warmup_lr) * \
79 | T / self.warmup_iters
80 | elif self.warmup_mode == 'constant':
81 | self.learning_rate = self.warmup_lr
82 | else:
83 | raise NotImplementedError
84 | else:
85 | if self.mode == 'step':
86 | count = sum([1 for s in self.step if s <= epoch])
87 | self.learning_rate = self.base_lr * pow(self.decay_factor, count)
88 | elif self.mode == 'poly':
89 | self.learning_rate = self.targetlr + (self.base_lr - self.targetlr) * \
90 | pow(1 - (T - self.warmup_iters) / (self.max_iter - self.warmup_iters), self.power)
91 | elif self.mode == 'cosine':
92 | self.learning_rate = self.targetlr + (self.base_lr - self.targetlr) * \
93 | (1 + cos(pi * (T - self.warmup_iters) / (self.max_iter - self.warmup_iters))) / 2
94 | else:
95 | raise NotImplementedError
96 |
97 |         for param_group in self.optimizer.param_groups:
98 | param_group['lr'] = self.learning_rate
99 |
--------------------------------------------------------------------------------
/examples/training/resnet/train_cpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=""
3 |
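   | # Usage: ./train_cpu.sh [HOST_ADDR] [BATCH] [TRAINING_ALGO] [HW_PATCH] [MODEL_PREC] [PRUNING_ALG] [DATA_DIR] [OUTPUT_DIR] [ARCH] [LR_MODE]
   | # All arguments are optional; the defaults below are used when omitted.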
4 | HOST_ADDR=${1:-$HOSTNAME}
5 | BATCH=${2:-64}
6 | TRAINING_ALGO=${3:-'direct'}
7 | HW_PATCH=${4:-'None'}
8 | MODEL_PREC=${5:-'fp16'}
9 | PRUNING_ALG=${6:-'None'} # None, Unstructured, Adaptive, Auto
10 | DATA_DIR=${7:-'/fastdata/imagenet/'}
11 | OUTPUT_DIR=${8:-'.'}
12 | ARCH=${9:-'resnet50'}
13 | LR_MODE=${10:-'cosine'}
14 |
15 | WORLD_SIZE=${SLURM_JOB_NUM_NODES:-1}
16 | NODE_RANK=${SLURM_NODEID:-0}
17 | HOST_PORT=12345
18 | LOGFILE="train.log"
19 |
20 | CMD=" torch.distributed.launch --nnodes $WORLD_SIZE --nproc_per_node 1 --node_rank $NODE_RANK --master_addr $HOST_ADDR --master_port $HOST_PORT main_amp_cpu.py --local_rank $NODE_RANK --arch $ARCH --batch-size=$BATCH --lr-mode=$LR_MODE --fp8-training --fp8-algo=$TRAINING_ALGO --master-weight-precision=$MODEL_PREC --patch-ops=$HW_PATCH --pruning-algo $PRUNING_ALG --output-dir $OUTPUT_DIR --resume $OUTPUT_DIR/checkpoint.pth.tar $DATA_DIR "
21 |
22 | echo $CMD
23 |
24 | python -m $CMD |& tee $LOGFILE
25 |
--------------------------------------------------------------------------------
/examples/training/resnet/train_gpu.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_VISIBLE_DEVICES=0,1,2,3
3 |
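   | # Usage: ./train_gpu.sh [HOST_ADDR] [BATCH] [TRAINING_ALGO] [HW_PATCH] [MODEL_PREC] [PRUNING_ALG] [DATA_DIR] [OUTPUT_DIR] [ARCH] [LR_MODE]
   | # All arguments are optional; the defaults below are used when omitted.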
4 | HOST_ADDR=${1:-$HOSTNAME}
5 | BATCH=${2:-256}
6 | TRAINING_ALGO=${3:-'direct'}
7 | HW_PATCH=${4:-'None'}
8 | MODEL_PREC=${5:-'fp16'}
9 | PRUNING_ALG=${6:-'None'} # None, Unstructured, Adaptive, Auto
10 | DATA_DIR=${7:-'/fastdata/imagenet/'}
11 | OUTPUT_DIR=${8:-'.'}
12 | ARCH=${9:-'resnet50'}
13 | LR_MODE=${10:-'cosine'}
14 |
15 | WORLD_SIZE=${SLURM_JOB_NUM_NODES:-4}
16 | NODE_RANK=${SLURM_NODEID:-0}
17 | HOST_PORT=12345
18 | LOGFILE="train.log"
19 |
20 | #CMD="torchrun --nnodes=1:4 --nproc_per_node 4 --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR main_amp.py --arch $ARCH --batch-size=$BATCH --lr-mode=$LR_MODE --fp8-training --fp8-algo=$TRAINING_ALGO --master-weight-precision=$MODEL_PREC --pruning-algo $PRUNING_ALG --momentum 0.875 --lr 0.256 --weight-decay 3.0517578125e-05 --output-dir $OUTPUT_DIR --resume $OUTPUT_DIR/checkpoint.pth.tar $DATA_DIR "
21 | CMD=" torch.distributed.launch --nproc_per_node 4 --master_addr $HOST_ADDR --master_port $HOST_PORT main_amp.py --arch $ARCH --batch-size=$BATCH --lr-mode=$LR_MODE --fp8-training --fp8-algo=$TRAINING_ALGO --master-weight-precision=$MODEL_PREC --pruning-algo $PRUNING_ALG --momentum 0.875 --lr 0.256 --weight-decay 3.0517578125e-05 --output-dir $OUTPUT_DIR --resume $OUTPUT_DIR/checkpoint.pth.tar $DATA_DIR "
22 |
23 | echo $CMD
24 |
25 | python -m $CMD |& tee $LOGFILE
26 |
--------------------------------------------------------------------------------
/mpemu/__init__.py:
--------------------------------------------------------------------------------
1 | from . import pytquant
2 | from . import cmodel
3 | from . import module_wrappers
4 |
--------------------------------------------------------------------------------
/mpemu/cmodel/__init__.py:
--------------------------------------------------------------------------------
1 | from . import simple
2 |
--------------------------------------------------------------------------------
/mpemu/cmodel/simple.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | from torch import nn
12 | from torch.autograd import Function
13 | import sys
14 |
15 | import simple_gemm_dev
16 | import simple_conv2d_dev
17 | # backup the original torch functions
18 | fallback_addmm = torch.addmm
19 | fallback_matmul = torch.matmul
20 | fallback_mm = torch.mm
21 |
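   | # Return a contiguous view of `input` together with a flag indicating whether
   | # that view is the transpose of the original tensor; this lets the GEMM kernel
   | # consume non-contiguous inputs without an extra copy.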
22 | def is_transposed(input):
23 | if input.is_contiguous():
24 | return input, False
25 | elif input.t().is_contiguous():
26 | return input.t(), True
27 | else:
28 | return input.contiguous(), False
29 |
30 | def addmm(input, mat1, mat2, beta=1.0, alpha=1.0, out=None):
31 | if input.dtype == torch.float32 and mat1.dtype == torch.float32 and \
32 | mat1.dim() == 2 and mat2.dim() == 2 and mat1.size(1) == mat2.size(0):
33 | a_mat, a_trans = is_transposed(mat1)
34 | b_mat, b_trans = is_transposed(mat2)
35 | output = out
36 | if out is None:
37 | output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), b_mat.size(0) if b_trans else b_mat.size(1)])
38 |
39 | return SimpleAddmm.apply(output, input, a_mat, b_mat, alpha, beta, a_trans, b_trans)
40 | else:
41 | warnings.warn('simple.addmm does not support the input dimensions - input :{}, mat1: {}, mat2: {}, falling back to torch.addmm'.format(
42 | input.size(), mat1.size(), mat2.size()))
43 | return fallback_addmm(input, mat1, mat2, beta=beta, alpha=alpha, out=out)
44 |
45 | def matmul(input, other, out=None):
46 | if input.dtype == torch.float32 and other.dtype == torch.float32 and \
47 | input.dim() == 2 and other.dim() == 2 and input.size(1) == other.size(0):
48 | a_mat, a_trans = is_transposed(input)
49 | b_mat, b_trans = is_transposed(other)
50 | output = out
51 | if out is None:
52 | output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), b_mat.size(0) if b_trans else b_mat.size(1)])
53 |
54 | return SimpleMatmul.apply(output, a_mat, b_mat, 1.0, a_trans, b_trans)
55 |
56 | # Batch MatMul implementation
57 | elif input.dtype == torch.float32 and other.dtype == torch.float32 and \
58 | input.dim() == 3 and other.dim() == 2 and input.size(2) == other.size(0):
59 | a_mat, a_trans = is_transposed(input)
60 | b_mat, b_trans = is_transposed(other)
61 | output = out
62 | if out is None:
63 | output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), a_mat.size(1) if not a_trans else a_mat.size(0),
64 | b_mat.size(0) if b_trans else b_mat.size(1)])
65 |
66 | return torch.stack(tuple([SimpleMatmul.apply(out1, a_mat1, b_mat, 1.0, a_trans, b_trans) \
67 | for a_mat1, out1 in zip(a_mat, output)]))
68 | else:
69 | warnings.warn('simple.matmul does not support the input dimensions - input :{}, other: {}, falling back to torch.matmul'.format(
70 | input.size(), other.size()))
71 | return fallback_matmul(input, other, out=out)
72 |
73 | def mm(input, mat2, out=None):
74 | if input.dtype == torch.float32 and mat2.dtype == torch.float32 and \
75 | input.dim() == 2 and mat2.dim() == 2 and input.size(1) == mat2.size(0):
76 | a_mat, a_trans = is_transposed(input)
77 | b_mat, b_trans = is_transposed(mat2)
78 | output = out
79 | if out is None:
80 | output = torch.zeros([a_mat.size(0) if not a_trans else a_mat.size(1), b_mat.size(0) if b_trans else b_mat.size(1)])
81 |
82 | return SimpleMatmul.apply(output, a_mat, b_mat, 1.0, a_trans, b_trans)
83 | else:
84 | warnings.warn('simple.mm does not support the input dimensions - input :{}, mat2: {}, falling back to torch.mm'.format(
85 | input.size(), mat2.size()))
86 | return fallback_mm(input, mat2, out=out)
87 |
88 | def conv2d(input, weight, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
89 | N = input.size()[0]
90 | C = input.size()[1]
91 | H = input.size()[2]
92 | W = input.size()[3]
93 | K = weight.size()[0]
94 | C1 = weight.size()[1]
95 | R = weight.size()[2]
96 | S = weight.size()[3]
97 |
98 | if dilation[0] > 1:
99 | sys.exit("ERROR: simple_conv2d does not support dilated convolutions.")
100 | if padding[0] != padding[1]:
101 | sys.exit("ERROR: simple_conv2d does not support non-uniform padding; pad_h must be equal to pad_w.")
102 | if groups > 1:
103 | sys.exit("ERROR: simple_conv2d does not support grouped convolutions; set groups to 1.")
104 |
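   |     # standard convolution output size:
   |     # H_out = floor((H + 2*pad - dilation*(R-1) - 1) / stride) + 1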
105 | H_out = ((H + (2*padding[0]) - dilation[0] * (R-1) -1)/stride[0]) + 1
106 | W_out = ((W + (2*padding[1]) - dilation[1] * (S-1) -1)/stride[1]) + 1
107 | output = torch.empty([N, K, int(H_out), int(W_out)])
108 | output = SimpleConv2dFunction.apply(output, input, weight, bias, stride, padding, dilation, groups)
109 | return output
110 |
111 | class SimpleAddmm(Function):
112 | @staticmethod
113 | def forward(ctx, output, input, mat1, mat2, alpha, beta, a_trans, b_trans):
114 | ctx.save_for_backward(mat1, mat2)
115 | ctx.a_trans = a_trans
116 | ctx.b_trans = b_trans
117 | ctx.alpha = alpha
118 |
119 | simple_gemm_dev.gemm(output, mat1, mat2, alpha, a_trans, b_trans)
120 |         output += beta * input
121 | ctx.mark_dirty(output)
122 | return output
123 |
124 | @staticmethod
125 | def backward(ctx, grad_output):
126 | mat1, mat2 = ctx.saved_tensors
127 |
128 | alpha = ctx.alpha
129 | a_trans = ctx.a_trans
130 | b_trans = ctx.b_trans
131 |
132 | grad_mat1 = torch.zeros_like(mat1)
133 | grad_mat2 = torch.zeros_like(mat2)
134 | grad_out, out_trans = is_transposed(grad_output)
135 |
136 | if a_trans:
137 | simple_gemm_dev.gemm(grad_mat1, mat2, grad_out, alpha, b_trans, not out_trans)
138 | else:
139 | simple_gemm_dev.gemm(grad_mat1, grad_out, mat2, alpha, out_trans, not b_trans)
140 |
141 | if b_trans:
142 | simple_gemm_dev.gemm(grad_mat2, grad_out, mat1, alpha, not out_trans, a_trans)
143 | else:
144 | simple_gemm_dev.gemm(grad_mat2, mat1, grad_out, alpha, not a_trans, out_trans)
145 |
146 | return (grad_output, grad_output, grad_mat1, grad_mat2, None, None, None, None)
147 |
148 | class SimpleMatmul(Function):
149 | @staticmethod
150 | def forward(ctx, output, mat1, mat2, alpha, a_trans, b_trans):
151 | ctx.save_for_backward(mat1, mat2)
152 | ctx.a_trans = a_trans
153 | ctx.b_trans = b_trans
154 | ctx.alpha = alpha
155 |
156 | simple_gemm_dev.gemm(output, mat1, mat2, alpha, a_trans, b_trans)
157 | ctx.mark_dirty(output)
158 | return output
159 |
160 | @staticmethod
161 | def backward(ctx, grad_output):
162 | mat1, mat2 = ctx.saved_tensors
163 | alpha = ctx.alpha
164 | a_trans = ctx.a_trans
165 | b_trans = ctx.b_trans
166 |
167 | grad_mat1 = torch.empty_like(mat1)
168 | grad_mat2 = torch.empty_like(mat2)
169 | grad_out, out_trans = is_transposed(grad_output)
170 |
171 | if a_trans:
172 | simple_gemm_dev.gemm(grad_mat1, mat2, grad_out, alpha, b_trans, not out_trans)
173 | else:
174 | simple_gemm_dev.gemm(grad_mat1, grad_out, mat2, alpha, out_trans, not b_trans)
175 |
176 | if b_trans:
177 | simple_gemm_dev.gemm(grad_mat2, grad_out, mat1, alpha, not out_trans, a_trans)
178 | else:
179 | simple_gemm_dev.gemm(grad_mat2, mat1, grad_out, alpha, not a_trans, out_trans)
180 | return (grad_output, grad_mat1, grad_mat2, None, None, None)
181 |
182 |
183 | class SimpleConv2dFunction(Function):
184 | @staticmethod
185 | def forward(ctx, output, inputs, weights, bias, stride, padding, dilation, groups):
186 | #print("### conv2d fwd called input size: ", inputs.size(), weights.size(), stride, padding, dilation, groups)
187 | ctx.save_for_backward(inputs, weights)#, bias)
188 | ctx.stride = stride#[0]
189 | ctx.padding = padding#[0]
190 | ctx.dilation = dilation#[0]
191 | ctx.groups = groups
192 |
193 | if bias is None:
194 | bias_fw = torch.zeros(output.size()[1])
195 | else :
196 | bias_fw = bias
197 |
198 | simple_conv2d_dev.conv2d_fp(output, inputs, weights, bias_fw, stride[0], padding[0], dilation[0], groups)
199 | ctx.mark_dirty(output)
200 | return output
201 |
202 | @staticmethod
203 | def backward(ctx, grad_output):
204 | #inputs, weights, bias = ctx.saved_tensors
205 | inputs, weights = ctx.saved_tensors
206 | stride = ctx.stride
207 | padding = ctx.padding
208 | dilation = ctx.dilation
209 | groups = ctx.groups
210 | #print("### conv2d bwd called input size: ", inputs.size(), weights.size(), stride, padding, dilation, groups)
211 | grad_inp = torch.zeros_like(inputs)
212 | grad_wts = torch.zeros_like(weights)
213 |
214 | simple_conv2d_dev.conv2d_bp(grad_inp, grad_output, weights, stride[0], padding[0], dilation[0], groups)
215 | simple_conv2d_dev.conv2d_wu(grad_wts, grad_output, inputs, stride[0], padding[0], dilation[0], groups)
216 | return (grad_output, grad_inp, grad_wts, None, None, None, None, None)
217 |
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_conv2d.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #include <torch/extension.h>
11 | #include <ATen/ATen.h>
12 | #include <vector>
13 | #include <time.h>
14 | #include <sys/types.h>
15 | #include <sys/syscall.h>
16 | #include <unistd.h>
17 |
18 | extern int simple_conv2d_impl_fp(float* outputs, float *inputs, float *weights, float* bias, int N, int C, int iH, int iW,
19 | int K, int R, int S, int stride, int padding, int dilation, int groups);
20 | extern int simple_conv2d_impl_bp(float* inputs, float *outputs, float *weights, int N, int C, int iH, int iW,
21 | int K, int R, int S, int stride, int padding, int dilation, int groups);
22 | extern int simple_conv2d_impl_wu(float *weights, float *outputs, float *inputs, int N, int C, int iH, int iW,
23 | int K, int R, int S, int stride, int padding, int dilation, int groups);
24 |
25 | #define gettid() ((int)syscall(SYS_gettid))
26 |
27 | using namespace torch::autograd::profiler;
28 |
29 |
30 | double get_time() {
31 | static bool init_done = false;
32 | static struct timespec stp = {0,0};
33 | struct timespec tp;
34 | clock_gettime(CLOCK_REALTIME, &tp);
35 |
36 | if(!init_done) {
37 | init_done = true;
38 | stp = tp;
39 | }
40 | double ret = (tp.tv_sec - stp.tv_sec) * 1e3 + (tp.tv_nsec - stp.tv_nsec)*1e-6;
41 | return ret;
42 | }
43 |
44 | at::Tensor simple_conv2d_fp(torch::Tensor& output, torch::Tensor input, torch::Tensor weight, torch::Tensor bias,
45 | int stride, int padding, int dilation, int groups)
46 | {
47 | RECORD_FUNCTION("simple_conv2d_fp", std::vector<c10::IValue>({input, weight, bias}));
48 |
49 | auto N = input.size(0);
50 | auto C = input.size(1);
51 | auto H = input.size(2);
52 | auto W = input.size(3);
53 |
54 | auto K = weight.size(0);
55 | //auto C1 = weight.size(1);
56 | auto R = weight.size(2);
57 | auto S = weight.size(3);
58 |
59 | float *input_ptr = input.data_ptr<float>();
60 | float *weight_ptr = weight.data_ptr<float>();
61 | float *output_ptr = output.data_ptr<float>();
62 | float *bias_ptr = bias.data_ptr<float>();
63 |
64 | simple_conv2d_impl_fp(output_ptr, input_ptr, weight_ptr, bias_ptr, N, C, H, W,
65 | K, R, S, stride, padding, dilation, groups);
66 |
67 | //thnn_conv2d_out(output, input, weight,
68 | return output;
69 | }
70 |
71 | at::Tensor simple_conv2d_bp(torch::Tensor& input, torch::Tensor output, torch::Tensor weight,
72 | int stride, int padding, int dilation, int groups)
73 | {
74 | RECORD_FUNCTION("simple_conv2d_bp", std::vector<c10::IValue>({output, weight}));
75 |
76 | auto N = input.size(0);
77 | auto C = input.size(1);
78 | auto H = input.size(2);
79 | auto W = input.size(3);
80 |
81 | auto K = weight.size(0);
82 | //auto C1 = weight.size(1);
83 | auto R = weight.size(2);
84 | auto S = weight.size(3);
85 |
86 | float *input_ptr = input.data_ptr<float>();
87 | float *weight_ptr = weight.data_ptr<float>();
88 | float *output_ptr = output.data_ptr<float>();
89 |
90 | simple_conv2d_impl_bp(input_ptr, output_ptr, weight_ptr, N, C, H, W,
91 | K, R, S, stride, padding, dilation, groups);
92 |
93 | //thnn_conv2d_out(output, input, weight,
94 | return input;
95 | }
96 |
97 | at::Tensor simple_conv2d_wu(torch::Tensor& weight, torch::Tensor output, torch::Tensor input,
98 | int stride, int padding, int dilation, int groups)
99 | {
100 | RECORD_FUNCTION("simple_conv2d_wu", std::vector<c10::IValue>({output, input}));
101 |
102 | auto N = input.size(0);
103 | auto C = input.size(1);
104 | auto H = input.size(2);
105 | auto W = input.size(3);
106 |
107 | auto K = weight.size(0);
108 | //auto C1 = weight.size(1);
109 | auto R = weight.size(2);
110 | auto S = weight.size(3);
111 |
112 | float *input_ptr = input.data_ptr<float>();
113 | float *weight_ptr = weight.data_ptr<float>();
114 | float *output_ptr = output.data_ptr<float>();
115 |
116 | simple_conv2d_impl_wu(weight_ptr, output_ptr, input_ptr, N, C, H, W,
117 | K, R, S, stride, padding, dilation, groups);
118 |
119 | //thnn_conv2d_out(output, input, weight,
120 | return weight;
121 | }
122 |
123 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
124 | m.def("conv2d_fp", &simple_conv2d_fp, "simple conv_fp implementation");
125 | m.def("conv2d_bp", &simple_conv2d_bp, "simple conv_bp implementation");
126 | m.def("conv2d_wu", &simple_conv2d_wu, "simple conv_wu implementation");
127 | }
128 |
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_gemm.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #include <torch/extension.h>
11 | #include <ATen/ATen.h>
12 | #include <vector>
13 | #include <time.h>
14 | #include <sys/types.h>
15 | #include <sys/syscall.h>
16 | #include <unistd.h>
17 | #if 10
18 | //extern "C" {
19 | extern int simple_sgemm_impl( char* transa, char* transb, int m, int n, int k,
20 | float alpha, float* a, int lda, float* b, int ldb,
21 | float beta, float* c, int ldc );
22 | //}
23 | #endif
24 | #define gettid() ((int)syscall(SYS_gettid))
25 |
26 | using namespace torch::autograd::profiler;
27 | using namespace torch;
28 | using namespace torch::autograd;
29 | using at::Tensor;
30 |
31 | double get_time() {
32 | static bool init_done = false;
33 | static struct timespec stp = {0,0};
34 | struct timespec tp;
35 | clock_gettime(CLOCK_REALTIME, &tp);
36 |
37 | if(!init_done) {
38 | init_done = true;
39 | stp = tp;
40 | }
41 | double ret = (tp.tv_sec - stp.tv_sec) * 1e3 + (tp.tv_nsec - stp.tv_nsec)*1e-6;
42 | return ret;
43 | }
44 |
45 | at::Tensor simple_gemm(torch::Tensor& C, torch::Tensor A, torch::Tensor B, float alpha, bool a_trans, bool b_trans)
46 | {
47 | RECORD_FUNCTION("simple_gemm", std::vector<c10::IValue>({A, B, alpha}));
48 |
49 | const char *aT = a_trans ? "T" : "N";
50 | const char *bT = b_trans ? "T" : "N";
51 |
52 | auto M = C.size(0);
53 | auto N = C.size(1);
54 | auto K = a_trans ? A.size(0) : A.size(1);
55 | auto lda = A.size(1);
56 | auto ldb = B.size(1);
57 | auto ldc = C.size(1);
58 |
59 | float beta = 0.0;
60 |
61 | float *Aptr = A.data_ptr<float>();
62 | float *Bptr = B.data_ptr<float>();
63 | float *Cptr = C.data_ptr<float>();
64 |
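   | /* the row-major GEMM is mapped onto the column-major kernel by
   |    swapping the A/B operands and the M/N dimensions */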
65 | simple_sgemm_impl((char*)bT, (char*)aT, N, M, K, alpha, Bptr, ldb, Aptr, lda, beta, Cptr, ldc);
66 | return C;
67 | }
68 |
69 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
70 | m.def("gemm", &simple_gemm, "Simple Matrix Engine");
71 | }
72 |
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_gemm_impl.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #include <stdio.h>
11 | #include <stdlib.h>
12 | #include <string.h>
13 | #include <math.h>
14 | #include <assert.h>
15 | #include <immintrin.h>
16 | #include <omp.h>
17 |
18 | #define SCRATCH_SIZE 4294967296
19 | int lazy_init = 1;
20 | float* myscratch = NULL;
21 |
22 | extern void MMEngine_avx2_ps(int m, int n, int k, float alpha, float *A, int lda,
23 | float *B, int ldb, float beta, float *C, int ldc);
24 |
25 | __extern_always_inline
26 | void copy_matrix_and_pad(const float* in, float* out, const int LD, const int OD, int ld_pad, int od_pad)
27 | {
28 | int LDP, ODP;
29 | LDP = LD + ld_pad;
30 | ODP = OD + od_pad;
31 | int ld, od;
32 | int lp, op;
33 |
34 | for ( op = 0; op < OD; op++ ) {
35 | for ( lp = LD-1; lp < LDP; lp++ ) {
36 | out[op*LDP+lp] = 0.0;
37 | }
38 | }
39 | for ( op = OD-1; op < ODP; op++ ) {
40 | for ( lp = 0; lp < LDP; lp++ ) {
41 | out[op*LDP+lp] = 0.0;
42 | }
43 | }
44 |
45 | #if defined(_OPENMP)
46 | #pragma omp parallel for private(ld, od)
47 | #endif
48 | for ( od = 0; od < OD; od++ ) {
49 | for ( ld = 0; ld < LD; ld++ ) {
50 | out[od*LDP+ld] = in[od*LD+ld];
51 | }
52 | }
53 | }
54 |
55 | __extern_always_inline
56 | void copy_matrix_and_strip_padding(const float* in, float* out, const int LD, const int OD, int ld_pad, int od_pad)
57 | {
58 | int LDP;
59 | LDP = LD + ld_pad;
60 | int ld, od;
61 | #if defined(_OPENMP)
62 | #pragma omp parallel for private(ld, od)
63 | #endif
64 | for ( od = 0; od < OD; od++ ) {
65 | for ( ld = 0; ld < LD; ld++ ) {
66 | out[od*LD+ld] = in[od*LDP+ld];
67 | }
68 | }
69 | }
70 |
71 | int simple_sgemm_impl( char* transa, char* transb, int m, int n, int k,
72 | float alpha, float* a, int lda, float* b, int ldb,
73 | float beta, float* c, int ldc ) {
74 | float* myA = NULL;
75 | float* myB = NULL;
76 | float* myC = NULL;
77 | size_t ptlda = (size_t)(lda);
78 | size_t ptldb = (size_t)(ldb);
79 | size_t mylda = (size_t)(lda);
80 | size_t myldb = (size_t)(ldb);
81 | size_t myldc = (size_t)(ldc);
82 | int mym = (size_t)(m);
83 | int myn = (size_t)(n);
84 | int myk = (size_t)(k);
85 | int m_pad = 0;
86 | int n_pad = 0;
87 | int k_pad = 0;
88 |
89 | int o,p,q,pp,oo;
90 |
91 | /* check for size matching our TMUL emulation */
92 | if ( mym % 16 != 0 ) {
93 | m_pad = (16-(mym%16));
94 | mym += m_pad;
95 | }
96 | if ( myk % 64 != 0 ) {
97 | k_pad = (64-(myk%64));
98 | myk += k_pad;
99 | }
100 | if ( myn % 16 != 0 ) {
101 | n_pad = (16-(myn%16));
102 | myn += n_pad;
103 | }
104 | /* update leading dimensions with padded values */
105 | mylda = mym;
106 | myldb = myk;
107 | myldc = mym;
108 |
109 | /* lazy init of fp8_gemm state */
110 | if (lazy_init != 0) {
111 | lazy_init = 0;
112 | myscratch = (float*) _mm_malloc( SCRATCH_SIZE*sizeof(float), 4096 );
113 | }
114 | /* check for sufficient scratch size */
115 | if ( (*transa == 'N') && (*transb == 'N') ) {
116 | if ( ((ptlda*myk)+(ptldb*myn)) > SCRATCH_SIZE ) {
117 | return -1;
118 | }
119 | } else if ( (*transa == 'T') && (*transb == 'N') ) {
120 | if ( ((ptlda*mym)+(ptldb*myn)) > SCRATCH_SIZE ) {
121 | return -2;
122 | }
123 | mylda = mym;
124 | } else if ( (*transa == 'N') && (*transb == 'T') ) {
125 | if ( ((ptlda*myk)+(ptldb*myk)) > SCRATCH_SIZE ) {
126 | return -3;
127 | }
128 | myldb = myk;
129 | } else if ( (*transa == 'T') && (*transb == 'T') ) {
130 | if ( ((ptlda*mym)+(ptldb*myk)) > SCRATCH_SIZE ) {
131 | return -4;
132 | }
133 | mylda = mym;
134 | myldb = myk;
135 | } else {
136 | assert((0 && "Error : Invalid parameters"));
137 | return -5;
138 | }
139 |
140 | /* set temp A and B pointers */
141 | myA = myscratch;
142 | myB = myscratch + (mylda*myk);
143 | myC = myscratch + (mylda*myk) + (myldb*myn);
144 |
145 | if ( *transa == 'T' ) {
146 | /* fill the padding with zeros */
147 | for ( p = 0; p < k; p++ ) {
148 | for ( o = m-1; o < mym; o++ ) {
149 | myA[p*mylda+o] = 0.0;
150 | }
151 | }
152 | for ( p = k-1; p < myk; p++ ) {
153 | for ( o = 0; o < mym; o++ ) {
154 | myA[p*mylda+o] = 0.0;
155 | }
156 | }
157 |
158 | /* let's transpose data first */
159 | #if defined(_OPENMP)
160 | #pragma omp parallel for private(o,p) collapse(2)
161 | #endif
162 | for ( p = 0; p < k; p++ ) {
163 | for ( o = 0; o < m; o++ ) {
164 | myA[(p*mylda)+o] = a[(o*ptlda)+p];
165 | }
166 | }
167 | } else if ( m_pad > 0 || k_pad > 0 ) {
168 | copy_matrix_and_pad(a, myA, m, k, m_pad, k_pad);
169 | } else {
170 | myA = a;
171 | }
172 |
173 | if ( *transb == 'T' ) {
174 | /* fill the padding with zeros */
175 | for ( p = 0; p < n; p++ ) {
176 | for ( o = k-1; o < myk; o++ ) {
177 | myB[p*myldb+o] = 0.0;
178 | }
179 | }
180 | for ( p = n-1; p < myn; p++ ) {
181 | for ( o = 0; o < myk; o++ ) {
182 | myB[p*myldb+o] = 0.0;
183 | }
184 | }
185 |
186 | /* let's transpose data first */
187 | #if defined(_OPENMP)
188 | #pragma omp parallel for private(o,p) collapse(2)
189 | #endif
190 | for ( p = 0; p < n; p++ ) {
191 | for ( o = 0; o < k; o++ ) {
192 | myB[(p*myldb)+o] = b[(o*ptldb)+p];
193 | }
194 | }
195 | } else if ( k_pad > 0 || n_pad > 0 ) {
196 | copy_matrix_and_pad(b, myB, k, n, k_pad, n_pad);
197 | } else {
198 | myB = b;
199 | }
200 |
201 | if ( m_pad > 0 || n_pad > 0 ) {
202 | copy_matrix_and_pad(c, myC, m, n, m_pad, n_pad);
203 | } else {
204 | myC = c;
205 | }
206 | /* run gemm */
207 | #if defined(_OPENMP)
208 | #pragma omp parallel for private(o,p,q,pp,oo) collapse(2)
209 | #endif
210 | for ( o = 0; o < mym; o += 16 ) {
211 | for ( p = 0; p < myn; p += 16 ) {
212 |
213 | /* C initialized to zero */
214 | __attribute__ ((aligned(64))) float ctmp[256];
215 | for ( pp = 0; pp < 16; pp++ ) {
216 | for ( oo = 0; oo < 16; oo++ ) {
217 | ctmp[(pp*16)+oo] = 0.0f;
218 | }
219 | }
220 | /* compute a 16x16x64 block */
221 | for ( q = 0; q < myk; q += 64 ) {
222 | MMEngine_avx2_ps(16, 16, 64, alpha, &(myA[(mylda*q)+o]), mylda,
223 | &(myB[(myldb*p)+q]), myldb, beta, ctmp, 16);
224 | }
225 | /* post accumulation */
226 | for ( pp = 0; pp < 16; pp++ ) {
227 | for ( oo = 0; oo < 16; oo++ ) {
228 | myC[((p+pp)*myldc)+(o+oo)] += ((alpha)*ctmp[(pp*16)+oo]) + ((beta)*myC[((p+pp)*myldc)+(o+oo)]);
229 | }
230 | }
231 | }
232 | }
233 | if ( m_pad > 0 || n_pad > 0 ) {
234 | copy_matrix_and_strip_padding(myC, c, m, n, m_pad, n_pad);
235 | }
236 | return 0;
237 | }
238 |
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/simple_mm_engine.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #include <immintrin.h>
11 |
12 | /* column-major data format */
13 | void MMEngine_avx2_ps(int m, int n, int k, float alpha, float *A, int lda,
14 | float *B, int ldb, float beta, float *C, int ldc)
15 | {
16 | for (int j = 0; j < n; j++) {
17 | for (int i = 0; i < m; i+=8) {
18 | __m256 cij = _mm256_loadu_ps(&C[ j*ldc + i]);
19 | for (int l = 0; l < k; l++) {
20 | __m256 aik = _mm256_loadu_ps(&A[ i + l * lda]);
21 | __m256 bkj = _mm256_broadcast_ss(&B[ l + j * ldb]);
22 | cij = _mm256_add_ps(cij, _mm256_mul_ps(aik, bkj));
23 | }
24 | _mm256_storeu_ps(&C[ j * ldc + i ], cij);
25 | }
26 | }
27 | }
28 |
29 | /* column-major data format */
30 | void MMEngine_strideB_avx2_ps(int m, int n, int k, float alpha, float *A, int lda,
31 | float *B, int ldb, float beta, float *C, int ldc, int strideB)
32 | {
33 | for (int j = 0; j < n; j++) {
34 | for (int i = 0; i < m; i+=8) {
35 | __m256 cij = _mm256_loadu_ps(&C[ j*ldc + i]);
36 | for (int l = 0; l < k; l++) {
37 | __m256 aik = _mm256_loadu_ps(&A[ i + l * lda]);
38 | __m256 bkj = _mm256_broadcast_ss(&B[ l*strideB + j * ldb]);
39 | cij = _mm256_add_ps(cij, _mm256_mul_ps(aik, bkj));
40 | }
41 | _mm256_storeu_ps(&C[ j * ldc + i], cij);
42 | }
43 | }
44 | }
45 |
46 |
--------------------------------------------------------------------------------
/mpemu/cmodel/simple/vla.h:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #ifndef VLA_H
11 | #define VLA_H
12 | #include <stdio.h>
13 | #include <stdlib.h>
14 | #include <stddef.h>
15 | #include <assert.h>
16 |
17 | #define ALWAYS_INLINE __attribute__((always_inline))
18 | #define INLINE inline
19 | #define RESTRICT __restrict__
20 | #define VLA_POSTFIX _
21 |
22 | #define INDEX1_1(...) ((size_t)SELECT_HEAD(__VA_ARGS__))
23 | #define INDEX1_2(I0, I1, S1) (INDEX1_1(I0) * ((size_t)S1) + (size_t)I1)
24 | #define INDEX1_3(I0, I1, I2, S1, S2) (INDEX1_2(I0, I1, S1) * ((size_t)S2) + (size_t)I2)
25 | #define INDEX1_4(I0, I1, I2, I3, S1, S2, S3) (INDEX1_3(I0, I1, I2, S1, S2) * ((size_t)S3) + (size_t)I3)
26 | #define INDEX1_5(I0, I1, I2, I3, I4, S1, S2, S3, S4) (INDEX1_4(I0, I1, I2, I3, S1, S2, S3) * ((size_t)S4) + (size_t)I4)
27 | #define INDEX1_6(I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, S5) (INDEX1_5(I0, I1, I2, I3, I4, S1, S2, S3, S4) * ((size_t)S5) + (size_t)I5)
28 | #define INDEX1_7(I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, S6) (INDEX1_6(I0, I1, I2, I3, I4, I5, S1, S2, S3, S4, S5) * ((size_t)S6) + (size_t)I6)
29 | #define INDEX1_8(I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, S7) (INDEX1_7(I0, I1, I2, I3, I4, I5, I6, S1, S2, S3, S4, S5, S6) * ((size_t)S7) + (size_t)I7)
30 | #define INDEX1_9(I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, S8) (INDEX1_8(I0, I1, I2, I3, I4, I5, I6, I7, S1, S2, S3, S4, S5, S6, S7) * ((size_t)S8) + (size_t)I8)
31 | #define INDEX1_10(I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, S1, S2, S3, S4, S5, S6, S7, S8, S9) (INDEX1_9(I0, I1, I2, I3, I4, I5, I6, I7, I8, S1, S2, S3, S4, S5, S6, S7, S8) * ((size_t)S9) + (size_t)I9)
32 |
33 | #define EXPAND(...) __VA_ARGS__
34 | #define CONCATENATE2(A, B) A##B
35 | #define CONCATENATE(A, B) CONCATENATE2(A, B)
36 | #define INDEX1(NDIMS, ...) CONCATENATE(INDEX1_, NDIMS)(__VA_ARGS__)
37 |
38 | #define SELECT_HEAD_AUX(A, ...) (A)
39 | #define SELECT_HEAD(...) EXPAND(SELECT_HEAD_AUX(__VA_ARGS__, 0))
40 | #define SELECT_TAIL(A, ...) __VA_ARGS__
41 |
42 | #define ACCESS_VLA(NDIMS, ARRAY, ...) CONCATENATE(ARRAY, VLA_POSTFIX)[INDEX1(NDIMS, __VA_ARGS__)]
43 | #define DECLARE_VLA(NDIMS, ELEMENT_TYPE, ARRAY_VAR, ...) \
44 | ELEMENT_TYPE *RESTRICT CONCATENATE(ARRAY_VAR, VLA_POSTFIX) = SELECT_HEAD(__VA_ARGS__) \
45 | + 0 * INDEX1(NDIMS, SELECT_TAIL(__VA_ARGS__, SELECT_TAIL(__VA_ARGS__, 0)))
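   | 
   | /* Illustrative usage: view a flat float buffer as a 2-D array.
   |  *   DECLARE_VLA(2, float, mat, buf, ncols);
   |  *   ACCESS_VLA(2, mat, i, j, ncols) = 0.0f;  // expands to mat_[i*ncols + j]
   |  */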
46 |
47 | #endif
48 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/conv_grad_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import time
12 | from mpemu.cmodel import simple
13 | import numpy as np
14 |
15 | n = 64
16 | c = 256
17 | h = 28
18 | w = 28
19 | k = 256
20 | r = 3
21 | s = 3
22 | stride = 2
23 | pad = 2
24 |
25 | a = torch.rand((n,c,h,w), dtype=torch.float32)
26 | b = torch.rand((k,c,r,s), dtype=torch.float32)
27 | bias = torch.rand((k), dtype=torch.float32)
28 |
29 | #a = torch.rand((n,c,h,w), dtype=torch.float32)
30 | #b = torch.rand((k,c,r,s), dtype=torch.float32)
31 | #bias = torch.rand((k), dtype=torch.float32)
32 |
33 | a64 = a.to(dtype=torch.float64, copy=True)
34 | b64 = b.to(dtype=torch.float64, copy=True)
35 | bias64 = bias.to(dtype=torch.float64, copy=True)
36 |
37 | a.requires_grad=True
38 | b.requires_grad=True
39 | a64.requires_grad=True
40 | b64.requires_grad=True
41 |
42 | ref_time = time.time()
43 | z = torch.nn.functional.conv2d(a64, b64, bias64, stride=(stride,stride), padding=(pad,pad), dilation=(1,1), groups=1)
44 | ref_time = time.time()-ref_time
45 |
46 | simple_time = time.time()
47 | z2 = simple.conv2d(a, b, bias, stride=(stride,stride), padding=(pad,pad), dilation=(1,1), groups=1)
48 | simple_time = time.time()-simple_time
49 |
50 | ref_time = time.time()
51 | z3 = torch.nn.functional.conv2d(a, b, bias, stride=(stride,stride), padding=(pad,pad), dilation=(1,1), groups=1)
52 | ref_time = time.time()-ref_time
53 |
54 | #print("Forward Time : ref_time: {}, simple_time: {} ".format(ref_time, simple_time))
55 | print('Forward: L2 distance f32 output : ', torch.dist(z3.to(dtype=torch.float64, copy=True), z, 2).item())
56 | print('Forward: L2 distance output : ', torch.dist(z2.to(dtype=torch.float64, copy=True), z, 2).item())
57 |
58 | ref_time = time.time()
59 | (z[0, 0] + z[0,1]).sum().backward()
60 | ref_time = time.time()-ref_time
61 |
62 | simple_time = time.time()
63 | (z2[0, 0] + z2[0,1]).sum().backward()
64 | simple_time = time.time()-simple_time
65 |
66 | #print("BackProp Time : ref_time: {}, simple_time: {}".format(ref_time, simple_time))
67 | torch.set_printoptions(profile="full")
68 | print('Backward: L2 distance input_grad: ', torch.dist(a.grad.to(dtype=torch.float64, copy=True), a64.grad, 2).item())
69 | print('Backward: L2 distance weight_grad: ', torch.dist(b.grad.to(dtype=torch.float64, copy=True), b64.grad, 2).item())
70 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/conv_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | from mpemu.cmodel import simple
12 |
13 | n = 16
14 | c = 256
15 | h = 28
16 | w = 28
17 | k = 256
18 | r = 3
19 | s = 3
20 |
21 | input = torch.rand((n,c,h,w), dtype=torch.float32)
22 | weight = torch.rand((k,c,r,s), dtype=torch.float32)
23 | bias = torch.rand((k), dtype=torch.float32)
24 |
25 | output_ref = torch.nn.functional.conv2d(input, weight, bias, stride=(2,2), padding=(2,2), dilation=(1,1), groups=1)
26 | output_simple = simple.conv2d(input, weight, bias, stride=(2,2), padding=(2,2), dilation=(1,1), groups=1)
27 |
28 | torch.set_printoptions(profile="full")
29 |
30 | print('Forward : L2 distance (simple) : ', torch.dist(output_ref, output_simple, 2).item())
31 |
32 | #print(output_ref-output_simple)
33 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/gemm_grad_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | from mpemu.cmodel import simple
12 |
13 | m=1024 #356320 #356305
14 | n=1024 #120 #256
15 | k=1024 #576 #128 #2020
16 |
17 | # Start here
18 | a = torch.rand((m, k), dtype=torch.float32)
19 | b = torch.rand((n, k), dtype=torch.float32)
20 | c = torch.zeros((m, n), dtype=torch.float32)
21 |
22 | a64 = a.to(dtype=torch.float64, copy=True)
23 | b64 = b.to(dtype=torch.float64, copy=True)
24 | c64 = c.to(dtype=torch.float64, copy=True)
25 |
26 | a.requires_grad=True
27 | b.requires_grad=True
28 | c.requires_grad=True
29 | a64.requires_grad=True
30 | b64.requires_grad=True
31 | c64.requires_grad=True
32 |
33 | z = torch.addmm(c64, a64, b64.t())
34 | z2 = simple.addmm(c, a, b.t())
35 | z3 = torch.addmm(c, a, b.t())
36 |
37 | print('Forward :L2 distance f32 output: ', torch.dist(z3.to(dtype=torch.float64, copy=True), z, 2).item())
38 | print('Forward :L2 distance output: ', torch.dist(z2.to(dtype=torch.float64, copy=True), z, 2).item())
39 |
40 | (z2[0, 0] + z2[0,1]).sum().backward()
41 | (z[0, 0] + z[0,1]).sum().backward()
42 |
43 | print('Backward : L2 distance a_grad: ', torch.dist(a.grad.to(dtype=torch.float64, copy=True), a64.grad, 2).item())
44 | print('Backward : L2 distance b_grad: ', torch.dist(b.grad.to(dtype=torch.float64, copy=True), b64.grad, 2).item())
45 | print('Backward : L2 distance c_grad: ', torch.dist(c.grad.to(dtype=torch.float64, copy=True), c64.grad, 2).item())
46 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/gemm_irregular_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | from mpemu.cmodel import simple
12 |
13 | m=1 #356320 #356305
14 | n=120 #256
15 | k=576 #128 #2020
16 |
17 | # Start here
18 | a = torch.rand((5, 107, 1024), dtype=torch.float32)
19 | b = torch.rand((1024, 1024), dtype=torch.float32)
20 | c = torch.zeros((m, n), dtype=torch.float32)
21 |
22 | a64 = a.to(dtype=torch.float64, copy=True)
23 | b64 = b.to(dtype=torch.float64, copy=True)
24 | c64 = c.to(dtype=torch.float64, copy=True)
25 |
26 | z = torch.matmul(a64, b64)
27 | z3 = torch.matmul(a, b)
28 | for a1 in a :
29 | print("-->", a1.size())
30 |
31 | #z2l = tuple([simple.matmul(a1, b) for a1 in a])
32 | #z2 = torch.stack(z2l)
33 | #z2 = torch.stack(tuple([simple.matmul(a1, b) for a1 in a]))
34 | z2 = simple.matmul(a, b)
35 |
36 | #print (z3.size(), z2.size())
37 |
38 | #print("torch :", z3.size(), z3)
39 | #print("Ours : ", z2.size(), z2)
40 | print('32b: L2 distance : ', torch.dist(z, z3.to(dtype=torch.float64, copy=True), 2))
41 | print('ours: L2 distance : ', torch.dist(z, z2.to(dtype=torch.float64, copy=True), 2))
42 |
43 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/gemm_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
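10 | # Purpose (added note): compares the C-model matmul (writing into an out= tensor)
11 | # against float64 and float32 torch.matmul references on transposed operands.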
10 | import torch
11 | from mpemu.cmodel import simple
12 |
13 | def get_grads(variables):
14 | return [var.grad.clone() for var in variables]
15 |
16 | m=256
17 | n=512
18 | k=1024
19 |
20 | a = torch.rand((k, m), dtype=torch.float32)
21 | b = torch.rand((n, k), dtype=torch.float32)
22 | c = torch.zeros((m, n), dtype=torch.float32)
23 |
24 | a64 = a.to(dtype=torch.float64, copy=True)
25 | b64 = b.to(dtype=torch.float64, copy=True)
27 | a32 = a.to(dtype=torch.float32, copy=True)
28 | b32 = b.to(dtype=torch.float32, copy=True)
32 |
33 | z = torch.matmul(a64.t(), b64.t())
34 | z2 = simple.matmul(a.t(), b.t(), out=c)
35 | z3 = torch.matmul(a32.t(), b32.t())
36 | print(z2.size())
39 | print('output 32b: L2 distance : ', torch.dist(z3.to(dtype=torch.float64, copy=True), z, 2))
40 | print('output : L2 distance : ', torch.dist(z2.to(dtype=torch.float64, copy=True), z, 2))
41 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/linear_test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 | from mpemu.cmodel import simple
15 |
16 | class Net(nn.Module):
17 |
18 | def __init__(self):
19 | super(Net, self).__init__()
20 | self.fc = nn.Linear(576, 120)
21 |
22 | def forward(self, x):
23 | x = self.fc(x)
24 | return x
25 |
26 | def num_flat_features(self, x):
27 | size = x.size()[1:] # all dimensions except the batch dimension
28 | num_features = 1
29 | for s in size:
30 | num_features *= s
31 | return num_features
32 |
33 | net = Net()
34 | net1 = Net()
35 | net1.load_state_dict(net.state_dict())  # give both nets identical weights so the gradient comparison is meaningful
35 | print(net)
36 | input = torch.randn((1, 576), dtype=torch.float32, device="cpu")
37 | input_new = input.to(dtype=torch.float32, copy=True)
38 |
39 | output = net(input)
40 |
41 | target = torch.randn(120, dtype=torch.float32, device="cpu") # a dummy target, for example
42 | target = target.view(1, -1) # make it the same shape as output
43 | criterion = nn.MSELoss()
44 | loss = criterion(output, target)
45 | net.zero_grad() # zeroes the gradient buffers of all parameters
47 | loss.backward()
48 |
49 |
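50 | # Monkey-patch torch.addmm/torch.matmul with the 'simple' C-model kernels so the
51 | # second network's forward/backward pass runs through the emulated GEMM path.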
50 | torch.addmm_back = torch.addmm
51 | torch.matmul_back = torch.matmul
52 | torch.addmm = simple.addmm
53 | torch.matmul = simple.matmul
54 |
55 | output1 = net1(input_new)
56 |
57 | loss1 = criterion(output1, target)
58 | net1.zero_grad() # zeroes the gradient buffers of all parameters
59 | loss1.backward()
60 |
61 | print('Linear wtgrads L2 distance : ', torch.dist(net.fc.weight.grad, net1.fc.weight.grad, 2))
62 |
--------------------------------------------------------------------------------
/mpemu/cmodel/tests/net.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.optim as optim
14 | from mpemu.cmodel import simple
15 |
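16 | # Patch conv2d/addmm globally before building the network so both the conv and
17 | # fully-connected layers execute through the 'simple' C-model kernels.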
16 | torch_conv2d = torch.nn.functional.conv2d
17 | torch_addmm = torch.addmm
18 | torch.addmm = simple.addmm
19 | torch.nn.functional.conv2d = simple.conv2d
20 |
21 | class Net(nn.Module):
22 |
23 | def __init__(self):
24 | super(Net, self).__init__()
25 | # 1 input image channel, 6 output channels, 3x3 square convolution
26 | # kernel
27 | self.conv1 = nn.Conv2d(1, 6, 3)
28 | self.conv2 = nn.Conv2d(6, 16, 3)
29 | # an affine operation: y = Wx + b
30 | self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
31 | self.fc2 = nn.Linear(120, 84)
32 | self.fc3 = nn.Linear(84, 10)
33 |
34 | def forward(self, x):
35 | # Max pooling over a (2, 2) window
36 | x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
37 | # If the size is a square you can only specify a single number
38 | x = F.max_pool2d(F.relu(self.conv2(x)), 2)
39 | x = x.view(-1, self.num_flat_features(x))
40 | x = F.relu(self.fc1(x))
41 | x = F.relu(self.fc2(x))
42 | x = self.fc3(x)
43 | return x
44 |
45 | def num_flat_features(self, x):
46 | size = x.size()[1:] # all dimensions except the batch dimension
47 | num_features = 1
48 | for s in size:
49 | num_features *= s
50 | return num_features
51 |
52 |
53 | net = Net()
54 | print(net)
55 | params = list(net.parameters())
56 | input = torch.randn(1, 1, 32, 32, dtype=torch.float32, device="cpu")
57 | # create your optimizer
58 | optimizer = optim.SGD(net.parameters(), lr=0.1)
59 | optimizer.zero_grad() # zero the gradient buffers
60 |
61 | output = net(input)
62 |
63 | print("fc3 output:", output)
64 | target = torch.randn(10, dtype=torch.float32, device="cpu") # a dummy target, for example
65 | target = target.view(1, -1) # make it the same shape as output
66 | criterion = nn.MSELoss()
67 |
68 | loss = criterion(output, target)
69 |
70 | net.zero_grad() # zeroes the gradient buffers of all parameters
71 |
72 | loss.backward()
73 | optimizer.step() # Does the update
74 |
--------------------------------------------------------------------------------
/mpemu/e3m4_emu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | from collections import OrderedDict
11 | import torch
12 | from .qutils import TensorQuantConfig, ModuleQuantConfig
13 | from .qutils import get_or_update_model_quant_config_dict
14 | from .qutils import reset_quantization_setup,add_quantization_hooks
15 | from .qutils import quantize_model_weights,set_quantize_weights_flag
16 | from .scale_shift import replace_batchnorms_with_scaleshifts
17 | from .module_wrappers import BatchMatmul,Matmul,AddMatmul,EltwiseMul,EltwiseAdd,EltwiseDiv
18 | from .module_wrappers import SparseConv2d,SparseLinear
19 |
20 | '''
21 | E3M4 Emulator
22 | '''
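23 | # Minimal usage sketch (assumes the mpt_emu front-end below in this package,
24 | # which constructs this class for dtype="e3m4"):
25 | #   from mpemu import mpt_emu
26 | #   model, mpt = mpt_emu.quantize_model(model, dtype="e3m4", device="cpu")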
23 | class E3M4Emulator(object):
24 | def __init__(self, model, optimizer, sparse_config=None, device="cuda", verbose=False, tensor_stats=False):
25 | super(E3M4Emulator, self).__init__()
26 | self.whitelist = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding, torch.nn.EmbeddingBag]
27 | self.whitelist += [Matmul, BatchMatmul, AddMatmul]
28 | self.whitelist += [EltwiseAdd, EltwiseMul, EltwiseDiv]
29 | self.whitelist += [SparseConv2d, SparseLinear]
30 | self.blacklist = []
31 | self.list_unpatched = []
32 | self.is_training = False
33 | self.list_exempt_layers = None
34 | self.list_layers_output_fused = None
35 |         self.device = device
36 |         self.model = model
37 |         self.patch_ops = False
38 |         self.patch_impl = "NONE"
39 |         self.patchlist = ["SIMPLE"]
40 | self.sparse_config = sparse_config
41 | # default configuration
42 | self.data_emulation = True
43 | self.mod_qconfig = None
44 | self.model_qconfig_dict = None #OrderedDict()
45 | self.emb_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
46 | self.wt_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
47 | self.iact_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
48 | self.oact_qconfig = None #TensorQuantConfig("e3m4", "rne", "per-tensor")
49 | self.hook_handles = None
50 | self.verbose = verbose
51 |
52 | def blacklist_modules(self, list_modules) :
53 | if self.verbose:
54 | print("Current module list {}".format(self.whitelist))
55 | print("Blacklist-ing modules {}".format(list_modules))
56 | for module in list_modules :
57 | self.blacklist.append(module)
58 | self.whitelist.remove(module)
59 |         self.create_or_update_hooks(self.model)
60 | if self.verbose:
61 | print("Updated module list {}".format(self.whitelist))
62 | self.print_config()
63 |
64 | def whitelist_modules(self, list_modules) :
65 | if self.verbose:
66 | print("Current module list {}".format(self.whitelist))
67 | print("Whitelist-ing modules {}".format(list_modules))
68 | for module in list_modules :
69 | self.whitelist.append(module)
70 | self.blacklist.remove(module)
71 |         self.create_or_update_hooks(self.model)
72 | if self.verbose:
73 | print("Updated module list {}".format(self.whitelist))
74 | self.print_config()
75 |
76 | def create_or_update_hooks(self, model):
77 | self.model_qconfig_dict = get_or_update_model_quant_config_dict(model,
78 | self.whitelist, self.mod_qconfig,
79 | model_qconfig_dict=self.model_qconfig_dict,
80 | override=True)
81 |
82 | if self.list_exempt_layers is not None :
83 | for exempt_layer in self.list_exempt_layers:
84 | if self.model_qconfig_dict.get(exempt_layer) is not None:
85 | self.model_qconfig_dict.pop(exempt_layer)
86 |
87 |         # Disable output quantization for these layers;
88 |         # they are followed by precision-sensitive layers such as SoftMax.
89 |         # In the final implementation, sensitive layers are fused with the preceding layer.
90 | if self.list_layers_output_fused is not None :
91 | for name,module in model.named_modules():
92 |                 if name in self.list_layers_output_fused and name not in self.list_exempt_layers \
93 |                    and type(module) in self.whitelist :
94 | self.model_qconfig_dict[name].oact_qconfig = None
95 | self.model_qconfig_dict[name].ograd_qconfig = None
96 |
97 | # additional handling of HW patching
98 | for name,module in model.named_modules():
99 | if type(module) in [torch.nn.Conv2d] and name in self.model_qconfig_dict:
100 | if module.in_channels < 64 or module.out_channels < 64:
101 | self.model_qconfig_dict[name].patch_ops = False
102 | self.model_qconfig_dict[name].patch_impl = "NONE"
103 | self.list_unpatched += [name]
104 |
105 |         # Disable weight quantization for all modules except Conv2d and Linear
106 | for name,module in model.named_modules():
107 | if type(module) not in [torch.nn.Conv2d, torch.nn.Linear]\
108 | and name in self.model_qconfig_dict:
109 | self.model_qconfig_dict[name].wt_qconfig = None
110 | self.model_qconfig_dict[name].wtgrad_qconfig = None
111 |
112 | for name,module in model.named_modules():
113 | if ((type(module) == torch.nn.Embedding) or (type(module) == torch.nn.EmbeddingBag))\
114 | and name in self.model_qconfig_dict:
115 | self.model_qconfig_dict[name].wt_qconfig = self.emb_qconfig
116 | self.model_qconfig_dict[name].iact_qconfig = None
117 | self.model_qconfig_dict[name].igrad_qconfig = None
118 | self.model_qconfig_dict[name].ograd_qconfig = None
119 | self.model_qconfig_dict[name].oact_qconfig = None
120 |
121 | for name,module in model.named_modules():
122 | if type(module) in [BatchMatmul]\
123 | and name in self.model_qconfig_dict:
124 | self.model_qconfig_dict[name].wt_qconfig = None
125 | self.model_qconfig_dict[name].wtgrad_qconfig = None
126 | self.model_qconfig_dict[name].oact_qconfig = None
127 | self.model_qconfig_dict[name].ograd_qconfig = None
128 |
129 | reset_quantization_setup(model, self.model_qconfig_dict)
130 | # Adding hooks for quantizing input.
131 | self.hook_handles = add_quantization_hooks(model, self.model_qconfig_dict, is_training=self.is_training)
132 | if not self.is_training :
133 | print("e3m4 : quantizing model weights..")
134 | quantize_model_weights(model, self.model_qconfig_dict)
135 | set_quantize_weights_flag(model, self.model_qconfig_dict, False)
136 |
137 | def prepare_model(self, model, list_exempt_layers, list_layers_output_fused):
138 | mod_qconfig = ModuleQuantConfig(wt_qconfig=self.wt_qconfig,
139 | iact_qconfig=self.iact_qconfig,
140 | oact_qconfig=self.oact_qconfig,
141 | patch_ops=self.patch_ops)
142 | mod_qconfig.device = self.device
143 | mod_qconfig.patch_impl = self.patch_impl
144 | mod_qconfig.sparse_config = self.sparse_config
145 | self.mod_qconfig = mod_qconfig
146 | self.list_exempt_layers = list_exempt_layers
147 | self.list_layers_output_fused = list_layers_output_fused
148 | self.create_or_update_hooks(model)
149 |
150 | def enable_hw_patching(self, patch_ops):
151 | if patch_ops != 'NONE':
152 | if patch_ops in self.patchlist :
153 | self.patch_ops = True
154 | self.patch_impl = patch_ops
155 | print("e3m4_emulator: PyTorch Ops are monkey-patched to use {} kernels : {}".format(self.patch_impl, self.patch_ops))
156 | else :
157 | raise RuntimeError("e3m4_emulator: HW patching is not supported for {}, supported list of options : {}".format(patch_ops, self.patchlist))
158 |
159 | def set_calibration_qconfig(self):
160 | self.emb_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
161 | self.wt_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
162 | self.iact_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
163 | self.oact_qconfig = None
164 |
165 | def set_default_inference_qconfig(self):
166 | self.emb_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
167 | self.wt_qconfig = TensorQuantConfig("e3m4", "rne", "per-channel")
168 | self.iact_qconfig = TensorQuantConfig("e3m4", "rne", "per-tensor")
169 | self.oact_qconfig = None
170 |
171 | def fuse_layers_and_quantize_model(self, model):
172 | if self.is_training :
173 | print("Warning : emulator.is_training is set to True, returning the model unchanged")
174 | return model
175 | if self.verbose :
176 | print("Fusing Batchnorm layers and replacing them with scale and shift")
177 |
178 | model = replace_batchnorms_with_scaleshifts(model)
179 | self.is_training = False
180 | self.set_default_inference_qconfig()
181 | self.prepare_model(model, self.list_exempt_layers, self.list_layers_output_fused)
182 |
187 | model = model.to(self.device)
188 | if self.verbose :
189 | self.print_config()
190 |
191 | return model
192 |
193 | def print_config(self):
194 | for key in self.model_qconfig_dict:
195 | print("{} {:40s}".format(self.model_qconfig_dict[key], key))
196 |
197 |     def __repr__(self):
198 |         train_infer = "inference"
199 |         if self.is_training :
200 |             train_infer = "training"
201 |         return "[Configured to run {} on {}]".format(train_infer, self.device)
202 |
--------------------------------------------------------------------------------
/mpemu/e4m3_emu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | from collections import OrderedDict
11 | import torch
12 | from .qutils import TensorQuantConfig, ModuleQuantConfig
13 | from .qutils import get_or_update_model_quant_config_dict
14 | from .qutils import reset_quantization_setup,add_quantization_hooks
15 | from .qutils import quantize_model_weights,set_quantize_weights_flag
16 | from .scale_shift import replace_batchnorms_with_scaleshifts
17 | from .module_wrappers import BatchMatmul,Matmul,AddMatmul,EltwiseMul,EltwiseAdd,EltwiseDiv
18 | from .module_wrappers import SparseConv2d,SparseLinear
19 |
20 | '''
21 | E4M3 Mixed Precision Emulator
22 | '''
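23 | # E4M3 allocates 4 exponent / 3 mantissa bits (wider dynamic range), whereas
24 | # E3M4 trades one exponent bit for an extra mantissa bit (more precision).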
23 | class E4M3Emulator(object):
24 | def __init__(self, model, optimizer, sparse_config=None, device="cuda", train=False, verbose=False, tensor_stats=False):
25 | super(E4M3Emulator, self).__init__()
26 | self.whitelist = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.Embedding, torch.nn.EmbeddingBag]
27 | self.whitelist += [Matmul, BatchMatmul, AddMatmul]
28 | self.whitelist += [EltwiseAdd, EltwiseMul, EltwiseDiv]
29 | self.whitelist += [SparseConv2d, SparseLinear]
30 | self.blacklist = []
31 | self.list_unpatched = []
32 | self.is_training = train
33 | self.list_exempt_layers = None
34 | self.list_layers_output_fused = None
35 | self.device = device
36 | self.model = model
37 | self.patch_ops = False
38 | self.patch_impl = "NONE"
39 |         self.patchlist = ["SIMPLE"]
41 | self.sparse_config = sparse_config
42 | # default configuration
43 | self.data_emulation = True
44 | self.mod_qconfig = None
45 | self.model_qconfig_dict = None #OrderedDict()
46 | self.emb_qconfig = TensorQuantConfig("e4m3", "rne")#, "per-channel")
47 | self.wt_qconfig = TensorQuantConfig("e4m3", "rne")#, "per-channel")
48 | self.iact_qconfig = TensorQuantConfig("e4m3", "rne")#, "per-tensor")
49 | self.oact_qconfig = None #TensorQuantConfig("e4m3", "rne")
50 | self.hook_handles = None
51 | self.verbose = verbose
52 |
53 | def blacklist_modules(self, list_modules) :
54 | if self.verbose:
55 | print("Current module list {}".format(self.whitelist))
56 | print("Blacklist-ing modules {}".format(list_modules))
57 | for module in list_modules :
58 | self.blacklist.append(module)
59 | self.whitelist.remove(module)
60 |         self.create_or_update_hooks(self.model)
61 | if self.verbose:
62 | print("Updated module list {}".format(self.whitelist))
63 | self.print_config()
64 |
65 | def whitelist_modules(self, list_modules) :
66 | if self.verbose:
67 | print("Current module list {}".format(self.whitelist))
68 | print("Whitelist-ing modules {}".format(list_modules))
69 | for module in list_modules :
70 | self.whitelist.append(module)
71 | self.blacklist.remove(module)
72 |         self.create_or_update_hooks(self.model)
73 | if self.verbose:
74 | print("Updated module list {}".format(self.whitelist))
75 | self.print_config()
76 |
77 | def create_or_update_hooks(self, model):
78 | self.model_qconfig_dict = get_or_update_model_quant_config_dict(model,
79 | self.whitelist, self.mod_qconfig,
80 | model_qconfig_dict=self.model_qconfig_dict,
81 | override=True)
82 |
83 | if self.list_exempt_layers is not None :
84 | for exempt_layer in self.list_exempt_layers:
85 | if self.model_qconfig_dict.get(exempt_layer) is not None:
86 | self.model_qconfig_dict.pop(exempt_layer)
87 |
88 |         # Disable output quantization for these layers;
89 |         # they are followed by precision-sensitive layers such as SoftMax.
90 |         # In the final implementation, sensitive layers are fused with the preceding layer.
91 | if self.list_layers_output_fused is not None :
92 | for name,module in model.named_modules():
93 |                 if name in self.list_layers_output_fused and name not in self.list_exempt_layers \
94 |                    and type(module) in self.whitelist :
95 | self.model_qconfig_dict[name].oact_qconfig = None
96 | self.model_qconfig_dict[name].ograd_qconfig = None
97 |
98 | # additional handling of HW patching
99 | for name,module in model.named_modules():
100 | if type(module) in [torch.nn.Conv2d] and name in self.model_qconfig_dict:
101 | if module.in_channels < 64 or module.out_channels < 64:
102 | self.model_qconfig_dict[name].patch_ops = False
103 | self.model_qconfig_dict[name].patch_impl = "NONE"
104 | self.list_unpatched += [name]
105 |
106 |         # Disable weight quantization for all modules except Conv2d and Linear
107 | for name,module in model.named_modules():
108 | if type(module) not in [torch.nn.Conv2d, torch.nn.Linear]\
109 | and name in self.model_qconfig_dict:
110 | self.model_qconfig_dict[name].wt_qconfig = None
111 | self.model_qconfig_dict[name].wtgrad_qconfig = None
112 |
113 | for name,module in model.named_modules():
114 | if ((type(module) == torch.nn.Embedding) or (type(module) == torch.nn.EmbeddingBag))\
115 | and name in self.model_qconfig_dict:
116 | self.model_qconfig_dict[name].wt_qconfig = self.emb_qconfig
117 | self.model_qconfig_dict[name].iact_qconfig = None
118 | self.model_qconfig_dict[name].igrad_qconfig = None
119 | self.model_qconfig_dict[name].ograd_qconfig = None
120 | self.model_qconfig_dict[name].oact_qconfig = None
121 |
122 | for name,module in model.named_modules():
123 | if type(module) in [BatchMatmul]\
124 | and name in self.model_qconfig_dict:
125 | self.model_qconfig_dict[name].wt_qconfig = None
126 | self.model_qconfig_dict[name].wtgrad_qconfig = None
127 | self.model_qconfig_dict[name].oact_qconfig = None
128 | self.model_qconfig_dict[name].ograd_qconfig = None
129 |
130 | reset_quantization_setup(model, self.model_qconfig_dict)
131 | # Adding hooks for quantizing input.
132 | self.hook_handles = add_quantization_hooks(model, self.model_qconfig_dict, is_training=self.is_training)
133 | if not self.is_training :
134 | print("e4m3 : quantizing model weights..")
135 | quantize_model_weights(model, self.model_qconfig_dict)
136 | set_quantize_weights_flag(model, self.model_qconfig_dict, False)
137 |
138 | def prepare_model(self, model, list_exempt_layers, list_layers_output_fused):
139 | mod_qconfig = ModuleQuantConfig(wt_qconfig=self.wt_qconfig,
140 | iact_qconfig=self.iact_qconfig,
141 | oact_qconfig=self.oact_qconfig,
142 | patch_ops=self.patch_ops)
143 | mod_qconfig.device = self.device
144 | mod_qconfig.patch_impl = self.patch_impl
145 | mod_qconfig.sparse_config = self.sparse_config
146 | self.mod_qconfig = mod_qconfig
147 | self.list_exempt_layers = list_exempt_layers
148 | self.list_layers_output_fused = list_layers_output_fused
149 | self.create_or_update_hooks(model)
150 |
151 | def enable_hw_patching(self, patch_ops):
152 | if patch_ops != 'NONE':
153 | if patch_ops in self.patchlist :
154 | self.patch_ops = True
155 | self.patch_impl = patch_ops
156 | print("e4m3_emulator: PyTorch Ops are monkey-patched to use {} kernels : {}".format(self.patch_impl, self.patch_ops))
157 | else :
158 | raise RuntimeError("e4m3_emulator: HW patching is not supported for {}, supported list of options : {}".format(patch_ops, self.patchlist))
159 |
160 | def fuse_batchnorm_with_convolution(self, model):
161 | from torch.nn.utils.fusion import fuse_conv_bn_eval
162 | temp = []
163 | for name, module in model.named_children():
164 | if list(module.named_children()):
165 | self.fuse_batchnorm_with_convolution(module)
166 |
167 | if isinstance(module, torch.nn.BatchNorm2d):
168 |                 if temp and isinstance(temp[-1][1], torch.nn.Conv2d):
169 | setattr(model, temp[-1][0], fuse_conv_bn_eval(temp[-1][1], module))
170 | setattr(model, name, torch.nn.Identity())
171 | else:
172 | temp.append((name, module))
173 | return model
174 |
175 | def set_calibration_qconfig(self):
176 | self.emb_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel")
177 | self.wt_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel")
178 | self.iact_qconfig = TensorQuantConfig("e4m3", "rne", "per-tensor")
179 | self.oact_qconfig = None
180 |
181 | def set_default_inference_qconfig(self):
182 | self.emb_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel")
183 | self.wt_qconfig = TensorQuantConfig("e4m3", "rne", "per-channel")
184 | self.iact_qconfig = TensorQuantConfig("e4m3", "rne", "per-tensor")
185 | self.oact_qconfig = None
186 |
187 | def fuse_layers_and_quantize_model(self, model):
188 | if self.is_training :
189 | print("Warning : emulator.is_training is set to True, returning the model unchanged")
190 | return model
191 | if self.verbose :
192 | print("Fusing Batchnorm layers and replacing them with scale and shift")
193 |
194 | model = replace_batchnorms_with_scaleshifts(model)
195 | #model = self.fuse_batchnorm_with_convolution(model)
196 | self.is_training = False
197 | self.set_default_inference_qconfig()
198 | self.prepare_model(model, self.list_exempt_layers, self.list_layers_output_fused)
203 | model = model.to(self.device)
204 | if self.verbose :
205 | self.print_config()
206 |
207 | return model
208 |
209 | def print_config(self):
210 | for key in self.model_qconfig_dict:
211 | print("{} {:40s}".format(self.model_qconfig_dict[key], key))
212 |
213 |     def __repr__(self):
214 |         train_infer = "inference"
215 |         if self.is_training :
216 |             train_infer = "training"
217 |         return "[Configured to run {} on {}]".format(train_infer, self.device)
218 |
--------------------------------------------------------------------------------
/mpemu/module_wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
15 | from .eltwise import EltwiseAdd, EltwiseMul, EltwiseDiv
16 | from .matmul import Matmul, BatchMatmul, AddMatmul
17 | from .aggregate import Norm, Mean
18 | from .adasparse import SparseLinear, SparseConv2d
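19 | # These wrappers expose functional ops (element-wise math, matmul variants,
20 | # reductions) as nn.Modules so the emulators can attach quantization hooks to
21 | # them; see the whitelists in the e5m2/e4m3/e3m4 emulators.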
19 |
--------------------------------------------------------------------------------
/mpemu/module_wrappers/adasparse.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import math
11 | import torch
12 | import torch.nn as nn
13 |
14 | """
15 | Function for activation binarization
16 | """
17 | class WeightMaskStep(torch.autograd.Function):
18 | @staticmethod
19 | def forward(ctx, input):
20 | ctx.save_for_backward(input)
21 | return (input>0.).to(input.dtype)
22 |
23 | @staticmethod
24 | def backward(ctx, grad_output):
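25 |         # Straight-through surrogate gradient for the step function, scaled by a
26 |         # piecewise hat: 0 where |x| > 1, a constant 0.4 for 0.4 < |x| <= 1,
27 |         # and 2 - 4|x| near zero (|x| <= 0.4), which is continuous at |x| = 0.4.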
25 | input, = ctx.saved_tensors
26 | grad_input = grad_output.clone()
27 | zero_index = torch.abs(input) > 1
28 | middle_index = (torch.abs(input) <= 1) * (torch.abs(input) > 0.4)
29 | additional = 2-4*torch.abs(input)
30 | additional[zero_index] = 0.
31 | additional[middle_index] = 0.4
32 | return grad_input*additional
33 |
34 | class SparseLinear(nn.Module):
35 | def __init__(self, in_features, out_features, bias=True):
36 | super(SparseLinear, self).__init__()
37 | self.in_features = in_features
38 | self.out_features = out_features
39 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
40 | if bias:
41 | self.bias = nn.Parameter(torch.Tensor(out_features))
42 | else:
43 | self.register_parameter('bias', None)
44 |
45 | self.threshold = nn.Parameter(torch.Tensor(out_features))
46 | self.weight_mask = WeightMaskStep.apply
47 | #self.mask = None
48 | self.reset_parameters()
49 |
50 |
51 | def reset_parameters(self):
52 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
53 | if self.bias is not None:
54 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
55 | bound = 1 / math.sqrt(fan_in)
56 | nn.init.uniform_(self.bias, -bound, bound)
57 | with torch.no_grad():
58 | #std = self.weight.std()
59 | self.threshold.data.fill_(0)
60 |
61 | def forward(self, input):
62 | abs_weight = torch.abs(self.weight)
63 | threshold = self.threshold.view(abs_weight.shape[0], -1)
64 | abs_weight = abs_weight-threshold
65 | mask = self.weight_mask(abs_weight)
66 | ratio = torch.sum(mask) / mask.numel()
67 | #print("keep ratio {:.2f}".format(ratio))
68 | if ratio <= 0.01:
69 | with torch.no_grad():
70 | #std = self.weight.std()
71 | self.threshold.data.fill_(0)
72 | abs_weight = torch.abs(self.weight)
73 | threshold = self.threshold.view(abs_weight.shape[0], -1)
74 | abs_weight = abs_weight-threshold
75 | mask = self.weight_mask(abs_weight)
76 | masked_weight = self.weight * mask
77 | output = torch.nn.functional.linear(input, masked_weight, self.bias)
78 | return output
79 | def extra_repr(self) -> str:
80 | return 'in_features={}, out_features={}, bias={}'.format(
81 | self.in_features, self.out_features, self.bias is not None
82 | )
83 |
84 | class SparseConv2d(nn.Module):
85 | def __init__(self, in_c, out_c, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode="zeros"):
86 | super(SparseConv2d, self).__init__()
87 | self.in_channels = in_c
88 | self.out_channels = out_c
89 | self.kernel_size = kernel_size
90 | self.stride = stride
91 | self.padding = padding
92 | self.dilation = dilation
93 | self.groups = groups
94 | self.padding_mode = padding_mode
95 |
96 | ## define weight
97 | self.weight = nn.Parameter(torch.Tensor(
98 | out_c, in_c // groups, *kernel_size
99 | ))
100 | if bias:
101 | self.bias = nn.Parameter(torch.Tensor(out_c))
102 | else:
103 | self.register_parameter('bias', None)
104 | self.threshold = nn.Parameter(torch.Tensor(out_c))
105 | self.weight_mask = WeightMaskStep.apply
106 | self.reset_parameters()
107 |
108 | def reset_parameters(self):
109 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
110 | if self.bias is not None:
111 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
112 | bound = 1 / math.sqrt(fan_in)
113 | nn.init.uniform_(self.bias, -bound, bound)
114 | with torch.no_grad():
115 | self.threshold.data.fill_(0.)
116 |
117 | def forward(self, x):
118 | weight_shape = self.weight.shape
119 | threshold = self.threshold.view(weight_shape[0], -1)
120 | weight = torch.abs(self.weight)
121 | weight = weight.view(weight_shape[0], -1)
122 | weight = weight - threshold
123 | mask = self.weight_mask(weight)
124 | mask = mask.view(weight_shape)
125 | ratio = torch.sum(mask) / mask.numel()
126 | # print("threshold {:3f}".format(self.threshold[0]))
127 | # print("keep ratio {:.2f}".format(ratio))
128 | if ratio <= 0.01:
129 | with torch.no_grad():
130 | self.threshold.data.fill_(0.)
131 | threshold = self.threshold.view(weight_shape[0], -1)
132 | weight = torch.abs(self.weight)
133 | weight = weight.view(weight_shape[0], -1)
134 | weight = weight - threshold
135 | mask = self.weight_mask(weight)
136 | mask = mask.view(weight_shape)
137 | masked_weight = self.weight * mask
138 |
139 | conv_out = torch.nn.functional.conv2d(x, masked_weight, bias=self.bias, stride=self.stride,
140 | padding=self.padding, dilation=self.dilation, groups=self.groups)
141 | return conv_out
142 |
143 | def extra_repr(self):
144 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
145 | ', stride={stride}')
146 | if self.padding != (0,) * len(self.padding):
147 | s += ', padding={padding}'
148 | if self.dilation != (1,) * len(self.dilation):
149 | s += ', dilation={dilation}'
150 | if self.groups != 1:
151 | s += ', groups={groups}'
152 | if self.bias is None:
153 | s += ', bias=False'
154 | if self.padding_mode != 'zeros':
155 | s += ', padding_mode={padding_mode}'
156 | return s.format(**self.__dict__)
157 |
158 | def __setstate__(self, state):
159 | super(SparseConv2d, self).__setstate__(state)
160 | if not hasattr(self, 'padding_mode'):
161 | self.padding_mode = 'zeros'
162 |
--------------------------------------------------------------------------------
/mpemu/module_wrappers/aggregate.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | class Norm(nn.Module):
14 | def __init__(self, p='fro', dim=None, keepdim=False):
15 | super(Norm, self).__init__()
16 | self.p = p
17 | self.dim = dim
18 | self.keepdim = keepdim
19 |
20 | def forward(self, x: torch.Tensor):
21 | return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
22 | def extra_repr(self) -> str:
23 | return 'p={}, dim={}, keepdim: {}'.format(self.p, self.dim, self.keepdim)
24 |
25 | class Mean(nn.Module):
26 | def __init__(self, *args, **kwargs):
27 | super(Mean, self).__init__()
28 | self.args = args
29 | self.kwargs = kwargs
30 |
31 | def forward(self, x: torch.Tensor):
32 | return torch.mean(x, *self.args, **self.kwargs)
33 | def extra_repr(self) -> str:
34 | return 'args={}, kwargs={}'.format(self.args, self.kwargs)
35 |
36 |
--------------------------------------------------------------------------------
/mpemu/module_wrappers/eltwise.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | class EltwiseAdd(nn.Module):
14 | def __init__(self, inplace=False):
15 | super(EltwiseAdd, self).__init__()
16 |
17 | self.inplace = inplace
18 |
19 | def forward(self, *input):
20 | res = input[0]
21 | if self.inplace:
22 | for t in input[1:]:
23 | res += t
24 | else:
25 | for t in input[1:]:
26 | res = res + t
27 | return res
28 |
29 | def extra_repr(self) -> str:
30 | return 'inplace={}'.format(self.inplace)
31 |
32 | class EltwiseMul(nn.Module):
33 | def __init__(self, inplace=False):
34 |         super(EltwiseMul, self).__init__()
35 | self.inplace = inplace
36 |
37 | def forward(self, *input):
38 | res = input[0]
39 | if self.inplace:
40 | for t in input[1:]:
41 | res *= t
42 | else:
43 | for t in input[1:]:
44 | res = res * t
45 | return res
46 | def extra_repr(self) -> str:
47 | return 'inplace={}'.format(self.inplace)
48 |
49 |
50 | class EltwiseDiv(nn.Module):
51 | def __init__(self, inplace=False):
52 | super(EltwiseDiv, self).__init__()
53 | self.inplace = inplace
54 |
55 | def forward(self, x: torch.Tensor, y):
56 | if self.inplace:
57 | return x.div_(y)
58 | return x.div(y)
59 | def extra_repr(self) -> str:
60 | return 'inplace={}'.format(self.inplace)
61 |
62 |
--------------------------------------------------------------------------------
/mpemu/module_wrappers/matmul.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | class Matmul(nn.Module):
14 | def __init__(self):
15 | super(Matmul, self).__init__()
16 |
17 | def forward(self, a: torch.Tensor, b: torch.Tensor):
18 | return torch.matmul(a, b)
19 |
20 | class BatchMatmul(nn.Module):
21 | def __init__(self):
22 | super(BatchMatmul, self).__init__()
23 |
24 | def forward(self, a: torch.Tensor, b:torch.Tensor):
25 | return torch.bmm(a, b)
26 |
27 | class AddMatmul(nn.Module):
28 | def __init__(self):
29 | super(AddMatmul, self).__init__()
30 |
31 | def forward(self, input:torch.Tensor, mat1: torch.Tensor, mat2:torch.Tensor, beta=1, alpha=1):
32 | return torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha)
33 |
--------------------------------------------------------------------------------
/mpemu/mpt_emu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | from collections import OrderedDict
11 | import torch
12 | import os
13 | from .qutils import TensorQuantConfig, ModuleQuantConfig
14 | from .qutils import get_or_update_model_quant_config_dict
15 | from .qutils import reset_quantization_setup,add_quantization_hooks
16 | from .qutils import quantize_model_weights,set_quantize_weights_flag
17 | from .e5m2_emu import E5M2Emulator
18 | from .e4m3_emu import E4M3Emulator
19 | from .e3m4_emu import E3M4Emulator
20 | from .hybrid_emu import HybridEmulator
21 | from .bfloat16_emu import Bfloat16Emulator
22 | from .scale_shift import replace_batchnorms_with_scaleshifts
23 | from .module_wrappers import SparseLinear, SparseConv2d
24 | from .sparse_utils import SparseConfig
25 |
26 | '''
27 | Mixed Precision Emulator
28 | '''
29 | class MPTEmulator(object):
30 | def __init__(self, device='cuda', training_algo='direct', hw_patch='none', pruning_algo='none'):
31 | super(MPTEmulator, self).__init__()
32 | #self.valid_dtypes = ["fp32", "fp16", "bf16", "e5m2", "e4m3", "e3m4"]
33 | self.valid_training_methods = ["direct", "hybrid"]
34 | self.valid_hw_patches = ["none", "simple"]
35 | self.valid_pruning_methods = ["none", "fine-grained", "unstructured", "adaptive", 'auto']
36 | # basic checks
37 | #if dtype.lower() not in self.valid_dtypes:
38 | # raise RuntimeError("mpt_emulator: the requested data type {} is not supported, use one of the following: {}"
39 | # .format(dtype, self.valid_dtypes))
40 | if training_algo.lower() not in self.valid_training_methods:
41 | raise RuntimeError("mpt_emulator: the requested training method {} is not supported, use one of the following: {}"
42 | .format(hw_patch, self.valid_training_methods))
43 | if hw_patch.lower() not in self.valid_hw_patches:
44 | raise RuntimeError("mpt_emulator: the requested hardware emulation {} is not supported, use one of the following: {}"
45 | .format(hw_patch, self.valid_hw_patches))
46 | if pruning_algo.lower() not in self.valid_pruning_methods:
47 | raise RuntimeError("mpt_emulator: the requested pruning method {} is not supported, use one of the following: {}"
48 | .format(pruning_algo.lower(), self.valid_pruning_methods))
49 |
50 | self.device = device
51 | self.training_algo = training_algo.lower()
52 | self.hw_patch = hw_patch.lower()
53 | self.pruning_method = pruning_algo.lower()
54 | self.wt_sparsity = 0.5
55 | self.grad_sparsity = 0.5
56 |
57 |         # emulator object
58 | self.emulator = None
59 | self.sparse_config = None
60 |
61 | def blacklist_modules(self, list_modules):
62 | self.emulator.blacklist_modules(list_modules)
63 |
64 | def whitelist_modules(self, list_modules):
65 | self.emulator.whitelist_modules(list_modules)
66 |
67 | def optimizer_step(self, optimizer):
68 | self.emulator.optimizer_step(optimizer)
69 |         if self.sparse_config is not None:
70 | self.sparse_config.optimizer_step(optimizer)
71 |
72 | def update_global_steps(self, global_steps):
73 | self.emulator.global_steps = global_steps
74 |
75 | def set_master_param_precision(self, master_params):
76 | self.emulator.set_master_param_precision(master_params)
77 |
78 | def set_embedding_precision(self, emb_precision, emb_norm=False):
79 | self.emulator.set_embedding_precision(emb_precision, emb_norm)
80 |
81 | def enable_tensor_stats(self, summary_writer=None):
82 | self.emulator.enable_tensor_stats(summary_writer)
83 |
84 | def set_tensor_bindump_schedule(self, list_bindump_schedule):
85 | self.emulator.set_tensor_bindump_schedule(list_bindump_schedule)
86 |
87 | def enable_tensorboard_logging(self, summary_writer=None):
88 | self.emulator.enable_tensor_stats(summary_writer)
89 |
90 | def set_target_sparsity_weight(self, wt_sparsity=0.5):
91 |         if self.sparse_config is not None:
92 | self.wt_sparsity = float(wt_sparsity)
93 | self.sparse_config.weight_factor = self.wt_sparsity
94 | print("mpt-emulator: weight sparsity has been set to : {}".format(self.sparse_config.weight_factor))
95 | else:
96 | print("mpt-emulator: set_target_sparsity_weight has no effect; sparse training is not enabled")
97 |
98 | def set_target_sparsity_gradient(self, grad_sparsity=0.5):
99 |         if self.sparse_config is not None:
100 | self.grad_sparsity = float(grad_sparsity)
101 | self.sparse_config.outgrad_factor = self.grad_sparsity
102 | print("mpt-emulator: gradient sparsity has been set to : {}".format(self.sparse_config.outgrad_factor))
103 | else:
104 | print("mpt-emulator: set_target_sparsity_gradient has no effect; sparse training is not enabled")
105 |
106 | def fuse_bnlayers_and_quantize_model(self, model):
107 | self.emulator.is_training = False
108 | return self.emulator.fuse_layers_and_quantize_model(model)
109 |
110 | def set_calibration_qconfig(self):
111 | self.emulator.set_calibration_qconfig()
112 |
113 | def set_default_inference_qconfig(self):
114 | self.emulator.set_default_inference_qconfig()
115 |
116 |     def __repr__(self):
117 |         train_infer = "inference"
118 |         if self.emulator is not None and self.emulator.is_training :
119 |             train_infer = "training"
120 |         return "[Configured to run {} on {}]".format(train_infer, self.device)
121 |
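122 | # Recursively replaces Conv2d/Linear modules with SparseConv2d/SparseLinear
123 | # wrappers (skipping any layers named in exempt_layers), preserving each
124 | # module's device and dtype.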
122 | def rewrite_model_with_adasparse_ops(model, exempt_layers):
123 | list_exempt_layers = exempt_layers.copy()
124 | if isinstance(model, torch.nn.Conv2d):
125 | return SparseConv2d(model.in_channels, model.out_channels, model.kernel_size,
126 | stride=model.stride, padding=model.padding, dilation=model.dilation, groups=model.groups,
127 |             bias=model.bias is not None).to(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
128 | elif isinstance(model, torch.nn.Linear):
129 | return SparseLinear(model.in_features, model.out_features,
130 |             bias=model.bias is not None).to(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype)
131 |
132 | if list_exempt_layers is None:
133 | for name, module in model.named_children():
134 | module = rewrite_model_with_adasparse_ops(module, None)
135 | setattr(model, name, module)
136 | else :
137 | for name, module in model.named_children():
138 | if name in list_exempt_layers:
139 | setattr(model, name, module)
140 | list_exempt_layers.remove(name)
141 | else:
142 | module = rewrite_model_with_adasparse_ops(module, list_exempt_layers)
143 | setattr(model, name, module)
144 | return model
145 |
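146 | # Minimal training-flow sketch (assumes model and optimizer are already built;
147 | # see examples/training for the full recipes):
148 | #   from mpemu import mpt_emu
149 | #   model, mpt = mpt_emu.initialize(model, optimizer, training_algo="hybrid",
150 | #                                   device="cuda", verbose=True)
151 | #   # ...forward/backward as usual; let the emulator post-process each update,
152 | #   # e.g. mpt.optimizer_step(optimizer)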
146 | def initialize(model, optimizer, training_algo='direct', hw_patch='none', pruning_algo='none',
147 | list_exempt_layers=None, list_layers_output_fused=None, device="cuda", verbose=False ):
148 | if model is None: #or optimizer is None:
149 | raise RuntimeError("mpt_emulator: Undefined model and optimizer, call this after model and optimizer are initilized.")
150 | if device == 'cuda' and hw_patch.lower() != 'none':
151 | raise RuntimeError("mpt_emulator: HW patching ops is only alowed on 'cpu' device.")
152 |
153 | mpt = MPTEmulator(device=device, training_algo=training_algo, hw_patch=hw_patch, pruning_algo=pruning_algo)
154 |
155 | if mpt.pruning_method == 'adaptive':
156 | model = rewrite_model_with_adasparse_ops(model, list_exempt_layers)
157 | print("mpt_emulator: Adaptive pruning method enabled; Adaptive (weights) only!")
158 | elif mpt.pruning_method == 'unstructured':
159 |         mpt.sparse_config = SparseConfig()
160 | mpt.sparse_config.weight = True
161 | mpt.sparse_config.outgrad = True
162 | mpt.sparse_config.weight_factor = mpt.wt_sparsity
163 | mpt.sparse_config.outgrad_factor = mpt.grad_sparsity
164 | mpt.sparse_config.print_stats = False
165 | print("mpt_emulator: Unstructured pruning method enabled; TopK(weights : {}), Stochastic(gradients : {})."
166 | .format(mpt.sparse_config.weight_factor, mpt.sparse_config.outgrad_factor))
167 | elif mpt.pruning_method == 'auto':
168 | model = rewrite_model_with_adasparse_ops(model, list_exempt_layers)
169 |         mpt.sparse_config = SparseConfig()
170 | mpt.sparse_config.outgrad = True
171 | mpt.sparse_config.outgrad_factor = mpt.grad_sparsity
172 | mpt.sparse_config.print_stats = False
173 | print("mpt_emulator: Auto pruning method enabled; Adaptive(weights), Stochastic(gradients : {})."
174 | .format(mpt.sparse_config.outgrad_factor))
175 |
176 | if mpt.training_algo == 'hybrid':
177 | mpt.emulator = HybridEmulator(model, optimizer, mpt.sparse_config, device=device, verbose=verbose)
178 | mpt.emulator.set_master_param_precision("fp16")
179 | else: # direct
180 | mpt.emulator = E5M2Emulator(model, optimizer, mpt.sparse_config, device=device, verbose=verbose)
181 | mpt.emulator.set_master_param_precision("fp16")
182 |
183 | mpt.emulator.enable_hw_patching(mpt.hw_patch.upper())
184 | mpt.emulator.prepare_model(model, list_exempt_layers, list_layers_output_fused)
185 |     if mpt.emulator.patch_ops and len(mpt.emulator.list_unpatched):
186 |         print("mpt_emulator: The following layers are not HW_PATCH'ed because their dimensions do not match the hardware : {}".format(mpt.emulator.list_unpatched))
187 |
188 | if verbose :
189 | mpt.emulator.print_config()
190 |
191 | return model, mpt
192 |
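193 | # Post-training quantization sketch (assumes a trained fp32 model; pass
194 | # calibrate=True to prepare the model for calibration instead):
195 | #   model, mpt = quantize_model(model, dtype="e4m3", fuse_bn=True, device="cuda")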
193 | def quantize_model(model, optimizer=None, dtype="none", calibrate=False, hw_patch="none", fuse_bn=False,
194 | list_exempt_layers=None, list_layers_output_fused=None, device="cuda", verbose=False ):
195 | if model is None :
196 | raise RuntimeError("mpt_emulator: Undefined model , call this after model is initilized.")
197 | if dtype == 'fp16' and device != 'cuda':
198 | raise RuntimeError("mpt_emulator: the requested data type {} is not supported on {}.".format(dtype, device))
199 | if device == 'cuda' and hw_patch.lower() != 'none':
200 | raise RuntimeError("mpt_emulator: HW patching ops is only alowed on 'cpu' device.")
201 |
202 | mpt = MPTEmulator(device=device, hw_patch=hw_patch)
203 | if fuse_bn :
204 | model = mpt.fuse_bnlayers_and_quantize_model(model)
205 |
206 | if dtype.upper() == 'E5M2':
207 | mpt.emulator = E5M2Emulator(model, optimizer, None, device=device, verbose=verbose)
208 | elif dtype.upper() == 'E4M3':
209 | mpt.emulator = E4M3Emulator(model, optimizer, None, device=device, verbose=verbose)
210 | elif dtype.upper() == 'E3M4':
211 | mpt.emulator = E3M4Emulator(model, optimizer, None, device=device, verbose=verbose)
212 | elif dtype.upper() == 'HYBRID':
213 | mpt.emulator = HybridEmulator(model, optimizer, None, device=device, verbose=verbose)
214 |
215 | if calibrate:
216 | print("mpt_emulator : preparing model for calibration")
217 | mpt.emulator.is_training = False
218 | mpt.set_calibration_qconfig()
219 | else:
220 | mpt.emulator.is_training = False
221 | mpt.set_default_inference_qconfig()
222 |
223 | mpt.emulator.enable_hw_patching(mpt.hw_patch.upper())
224 | mpt.emulator.prepare_model(model, list_exempt_layers, list_layers_output_fused)
225 |     if mpt.emulator.patch_ops and len(mpt.emulator.list_unpatched):
226 |         print("mpt_emulator: The following layers are not HW_PATCH'ed because they cannot use the hardware configuration: {}".format(mpt.emulator.list_unpatched))
227 |
228 | if verbose :
229 | mpt.emulator.print_config()
230 |
231 | return model, mpt
232 |
--------------------------------------------------------------------------------
/mpemu/pytquant/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
3 |
4 | from . import cpp
5 | from .cpp import fpemu_cpp
6 |
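7 | # Backend dispatch: the C++ CPU kernels are always available as fpemu_cpp;
8 | # when a GPU runtime is present, the matching extension is exposed under the
9 | # common name fpemu_cuda (CUDA builds) or aliased to it (HIP builds).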
7 | if torch.cuda.is_available() and torch.version.cuda:
8 | from . import cuda
9 | from .cuda import fpemu_cuda as fpemu_cuda
10 |
11 | if torch.cuda.is_available() and torch.version.hip:
12 | from . import hip
13 | from .hip import fpemu_hip as fpemu_cuda
14 |
--------------------------------------------------------------------------------
/mpemu/pytquant/cpp/__init__.py:
--------------------------------------------------------------------------------
1 | from . import fpemu as fpemu_cpp
2 |
--------------------------------------------------------------------------------
/mpemu/pytquant/cpp/fpemu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import math
11 | from torch import nn
12 | from torch.autograd import Function
13 | import torch
14 | import numpy
15 | import fpemu_cpp
16 |
17 | from enum import Enum
18 |
19 | torch.manual_seed(42)
20 |
21 | """
22 | NONE
23 | E5M2_RTZ
24 | E5M2_STOCHASTIC
25 | E5M2_RNE
26 | E5M2_RNAZ
27 | E5M2_RNTZ
28 | E5M2_RPINF
29 | E5M2_RNINF
30 | E5M2_DAZ_STOCHASTIC
31 | E5M2_DAZ_RNE
32 | E5M2_DAZ_RNAZ
33 | E5M2_DAZ_RNTZ
34 | BFLOAT16_STOCHASTIC
35 | BFLOAT16_RNE
36 | FLOAT16_RNE
37 | FLOAT16_STOCHASTIC
38 | FLOAT16_DAZ_RNE
39 | E4M3_RNE
40 | E4M3_STOCHASTIC
41 | """
42 |
43 | class FPEmuOp(Function):
44 | @staticmethod
45 | def forward(ctx, input, mode='NONE', inplace=False, scale=1.0, blocknorm=False, blocksize=1):
46 | if mode == 'NONE' :
47 | ctx.mark_dirty(input)
48 | return input
49 | else :
50 | if input.is_sparse :
51 | input = input.coalesce()
52 | size = input.values().nelement()
53 | if inplace == True:
54 | outputs = fpemu_cpp.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
55 | output = input
56 | else :
57 | outputs = fpemu_cpp.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
58 | output = torch.sparse.FloatTensor(input.indices(), outputs[0], input.size())
59 | else :
60 | size = input.nelement()
61 | outputs = fpemu_cpp.forward(input.contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
62 | output = outputs[0]
63 |
64 | if inplace == True:
65 | ctx.mark_dirty(input)
66 | return output
67 |
68 | @staticmethod
69 | def backward(ctx, output_grad):
70 | # straight-through estimator
71 | return output_grad, None, None, None, None
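72 | # Usage sketch (x is any fp32 tensor; mode is a string from the list above):
73 | #   x_q = FPEmuOp.apply(x, "E5M2_RNE")          # quantize-dequantize to E5M2
74 | #   x_q = FPEmuOp.apply(x, "E4M3_STOCHASTIC")   # stochastic-rounding variant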
72 |
--------------------------------------------------------------------------------
/mpemu/pytquant/cuda/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | if torch.cuda.is_available() and torch.version.cuda:
3 | from . import fpemu as fpemu_cuda
4 |
--------------------------------------------------------------------------------
/mpemu/pytquant/cuda/fpemu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import math
11 | from torch import nn
12 | from torch.autograd import Function
13 | import torch
14 | import numpy
15 | import fpemu_cuda
16 |
17 | from enum import Enum
18 |
19 | torch.manual_seed(42)
20 |
21 | """
22 | NONE
23 | E5M2_RTZ
24 | E5M2_STOCHASTIC
25 | E5M2_RNE
26 | E5M2_RNAZ
27 | E5M2_RNTZ
28 | E5M2_RPINF
29 | E5M2_RNINF
30 | E5M2_DAZ_STOCHASTIC
31 | E5M2_DAZ_RNE
32 | E5M2_DAZ_RNAZ
33 | E5M2_DAZ_RNTZ
34 | BFLOAT16_STOCHASTIC
35 | BFLOAT16_RNE
36 | FLOAT16_RNE
37 | FLOAT16_STOCHASTIC
38 | FLOAT16_DAZ_RNE
39 | E4M3_RNE
40 | E4M3_STOCHASTIC
41 | """
42 |
43 | class FPEmuOp(Function):
44 | @staticmethod
45 | def forward(ctx, input, mode='NONE', inplace=False, scale=1.0, blocknorm=False, blocksize=1):
46 | if mode == 'NONE' :
47 | ctx.mark_dirty(input)
48 | return input
49 | else :
50 | if input.is_sparse :
51 | input = input.coalesce()
52 | size = input.values().nelement()
53 | if inplace == True:
54 | outputs = fpemu_cuda.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
55 | output = input
56 | else :
57 | outputs = fpemu_cuda.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
58 | output = torch.sparse.FloatTensor(input.indices(), outputs[0], input.size())
59 | else :
60 | size = input.nelement()
61 | outputs = fpemu_cuda.forward(input.contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
62 | output = outputs[0]
63 |
64 | if inplace == True:
65 | ctx.mark_dirty(input)
66 | return output
67 |
68 | @staticmethod
69 | def backward(ctx, output_grad):
70 | # straight-through estimator
71 | return output_grad, None, None, None, None
72 |
--------------------------------------------------------------------------------
/mpemu/pytquant/cuda/fpemu_impl.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #include <torch/extension.h>
11 | #include <vector>
12 |
13 | std::vector<torch::Tensor> fpemu_cuda_forward(
14 | torch::Tensor input,
15 | std::string mode,
16 | int size,
17 | bool inplace,
18 | float scale,
19 | bool block_norm,
20 | int block_size);
21 |
22 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | std::vector<torch::Tensor> fpemu_forward(
28 | torch::Tensor input,
29 | std::string mode,
30 | int size,
31 | bool inplace,
32 | float scale,
33 | bool block_norm,
34 | int block_size) {
35 | CHECK_INPUT(input);
36 | return fpemu_cuda_forward(input, mode, size, inplace, scale, block_norm, block_size);
37 | }
38 |
39 | std::vector<torch::Tensor> fpemu_backward(
40 | torch::Tensor grad,
41 | std::string mode,
42 | int size,
43 | bool inplace,
44 | float scale,
45 | bool block_norm,
46 | int block_size) {
47 | CHECK_INPUT(grad);
48 | return fpemu_cuda_forward(grad, mode, size, inplace, scale, block_norm, block_size);
49 | }
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 | m.def("forward", &fpemu_forward, "FPEmu forward (CUDA)");
53 | m.def("backward", &fpemu_backward, "FPEmu backward (CUDA)");
54 | }
55 |
--------------------------------------------------------------------------------
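For quick experiments, the binding above can also be JIT-compiled with torch.utils.cpp_extension.load instead of installing the package (a sketch; paths are relative to the repository root and a working CUDA toolchain is assumed):

    from torch.utils.cpp_extension import load

    fpemu_cuda = load(
        name="fpemu_cuda",
        sources=["mpemu/pytquant/cuda/fpemu_impl.cpp",
                 "mpemu/pytquant/cuda/fpemu_kernels.cu"],
        verbose=True,
    )
    # exposes fpemu_cuda.forward(tensor, mode, size, inplace, scale, block_norm, block_size)
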
/mpemu/pytquant/hip/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | if torch.cuda.is_available() and torch.version.hip:
3 | from . import fpemu as fpemu_hip
4 |
--------------------------------------------------------------------------------
/mpemu/pytquant/hip/fpemu.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import math
11 | from torch import nn
12 | from torch.autograd import Function
13 | import torch
14 | import numpy
15 | import fpemu_hip
16 |
17 | from enum import Enum
18 |
19 | torch.manual_seed(42)
20 |
21 | """
22 | NONE
23 | E5M2_RTZ
24 | E5M2_STOCHASTIC
25 | E5M2_RNE
26 | E5M2_RNAZ
27 | E5M2_RNTZ
28 | E5M2_RPINF
29 | E5M2_RNINF
30 | E5M2_DAZ_STOCHASTIC
31 | E5M2_DAZ_RNE
32 | E5M2_DAZ_RNAZ
33 | E5M2_DAZ_RNTZ
34 | BFLOAT16_STOCHASTIC
35 | BFLOAT16_RNE
36 | FLOAT16_RNE
37 | FLOAT16_STOCHASTIC
38 | FLOAT16_DAZ_RNE
39 | E4M3_RNE
40 | E4M3_STOCHASTIC
41 | """
42 |
43 | class FPEmuOp(Function):
44 | @staticmethod
45 | def forward(ctx, input, mode='NONE', inplace=False, scale=1.0, blocknorm=False, blocksize=1):
46 | if mode == 'NONE' :
47 | ctx.mark_dirty(input)
48 | return input
49 | else :
50 | if input.is_sparse :
51 | input = input.coalesce()
52 | size = input.values().nelement()
54 |                 # a single kernel call serves both paths; inplace=True updates the values tensor directly
55 |                 outputs = fpemu_hip.forward(input._values().contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
56 |                 if inplace :
57 |                     output = input
58 |                 else :
59 |                     output = torch.sparse.FloatTensor(input.indices(), outputs[0], input.size())
59 | else :
60 | size = input.nelement()
61 | outputs = fpemu_hip.forward(input.contiguous(), mode, size, inplace, scale, blocknorm, blocksize)
62 | output = outputs[0]
63 |
64 |             if inplace :
65 | ctx.mark_dirty(input)
66 | return output
67 |
68 | @staticmethod
69 | def backward(ctx, output_grad):
70 |         # straight-through estimator: one gradient per forward input
71 |         return output_grad, None, None, None, None, None
72 |
--------------------------------------------------------------------------------
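The CPU, CUDA, and HIP front ends expose the same FPEmuOp interface, so callers can select a backend from the PyTorch build, mirroring the guards used in test.py below (a sketch; the fpemu_backend alias is illustrative):

    import torch

    if torch.cuda.is_available() and torch.version.hip:
        from mpemu.pytquant.hip import fpemu as fpemu_backend
    elif torch.cuda.is_available() and torch.version.cuda:
        from mpemu.pytquant.cuda import fpemu as fpemu_backend
    else:
        from mpemu.pytquant.cpp import fpemu as fpemu_backend
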
/mpemu/pytquant/hip/fpemu_impl.cpp:
--------------------------------------------------------------------------------
1 | /*----------------------------------------------------------------------------*
2 | * Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | * This file is part of FP8-Emulation-Toolkit
4 | *
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | *----------------------------------------------------------------------------*
7 | * Naveen Mellempudi (Intel Corporation)
8 | *----------------------------------------------------------------------------*/
9 |
10 | #include <torch/extension.h>
11 | #include <vector>
12 |
13 | std::vector<torch::Tensor> fpemu_cuda_forward(
14 | torch::Tensor input,
15 | std::string mode,
16 | int size,
17 | bool inplace,
18 | float scale,
19 | bool block_norm,
20 | int block_size);
21 |
22 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
23 | #define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | std::vector<torch::Tensor> fpemu_forward(
28 | torch::Tensor input,
29 | std::string mode,
30 | int size,
31 | bool inplace,
32 | float scale,
33 | bool block_norm,
34 | int block_size) {
35 | CHECK_INPUT(input);
36 | return fpemu_cuda_forward(input, mode, size, inplace, scale, block_norm, block_size);
37 | }
38 |
39 | std::vector<torch::Tensor> fpemu_backward(
40 | torch::Tensor grad,
41 | std::string mode,
42 | int size,
43 | bool inplace,
44 | float scale,
45 | bool block_norm,
46 | int block_size) {
47 | CHECK_INPUT(grad);
48 | return fpemu_cuda_forward(grad, mode, size, inplace, scale, block_norm, block_size);
49 | }
50 |
51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52 |   m.def("forward", &fpemu_forward, "FPEmu forward (HIP)");
53 |   m.def("backward", &fpemu_backward, "FPEmu backward (HIP)");
54 | }
55 |
--------------------------------------------------------------------------------
/mpemu/pytquant/test.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import argparse
14 | import numpy as np
15 | import torch
16 | import sys
17 | sys.path.append('../..')  # repository root, so the mpemu package resolves when run in-place
18 |
19 | import cpp.fpemu as fpemu_cpp
20 | import mpemu
21 | from mpemu.qutils import fpemu_device_fn
22 |
23 |
24 |
25 | if torch.cuda.is_available() and torch.version.cuda:
26 | import cuda.fpemu as fpemu_cuda
27 | if torch.cuda.is_available() and torch.version.hip:
28 | import hip.fpemu as fpemu_hip
29 |
30 | def check_equal(first, second, verbose):
31 | if verbose:
32 | print()
33 | for i, (x, y) in enumerate(zip(first, second)):
34 | x = x.cpu().detach().numpy()
35 | y = y.cpu().detach().numpy()
36 |
37 | if verbose:
38 | print("x = {}".format(x.flatten()))
39 | print("y = {}".format(y.flatten()))
40 | print('-' * 80)
41 |
42 | np.testing.assert_allclose(x, y, rtol=0.125, err_msg="Conflict : ", verbose=verbose)
43 |
44 | def print_tensor(first, second, verbose, sparse):
45 | print('printing ...')
46 | if verbose:
47 | print()
48 | for i, (x, y) in enumerate(zip(first, second)):
49 | if sparse :
50 | x = x.cpu().to_dense().detach().numpy()
51 | y = y.cpu().to_dense().detach().numpy()
52 | else :
53 | x = x.cpu().detach().numpy()
54 | y = y.cpu().detach().numpy()
55 | if verbose:
56 | print("x = {}".format(x.flatten()))
57 | print("y = {}".format(y.flatten()))
58 | print('-' * 80)
59 |
60 |
61 | def zero_grad(variables):
62 | for variable in variables:
63 | variable.grad.zero_()
64 |
65 |
66 | def get_grads(variables):
67 | return [var.grad.clone() for var in variables]
68 |
69 | def check_forward(variables, data_format, with_cuda, verbose, sparse):
70 | if with_cuda:
71 | if not torch.cuda.is_available():
72 | print('CUDA is not supported on this platform ... ')
73 | elif torch.version.hip:
74 | cuda_values = fpemu_hip.FPEmuOp.apply(variables.cuda(), data_format.upper())
75 | print('Forward: HIP ... ', end='')
76 | print_tensor(variables, cuda_values, verbose, sparse)
77 | else :
78 | cuda_values = fpemu_cuda.FPEmuOp.apply(variables.cuda(), data_format.upper())
79 | print('Forward: CUDA ... ', end='')
80 | print_tensor(variables, cuda_values, verbose, sparse)
81 | else :
82 | cpp_values = fpemu_cpp.FPEmuOp.apply(variables, data_format.upper())
83 | print('Forward: C++ ... ', end='')
84 | print_tensor(variables, cpp_values, verbose, sparse)
85 |
86 | if not verbose :
87 | print('Done.')
88 | print('Use the option --verbose to see results')
89 | else :
90 | print('Done.')
91 |
92 |
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument('-sp', '--sparse', action='store_true')
95 | parser.add_argument('-c', '--cuda', action='store_true')
96 | parser.add_argument('-v', '--verbose', action='store_true')
97 | parser.add_argument('-df', '--dformat', default='e4m3_rne')
98 |
99 | options = parser.parse_args()
100 | if options.cuda:
101 | device = torch.device("cuda")
102 | else:
103 | device = torch.device("cpu")
104 |
105 | kwargs = {'dtype': torch.float32,
106 | 'device': device,
107 | 'requires_grad': True}
108 |
109 | '''
110 | #FP4 values
111 | input = torch.tensor(np.array([[ 0.0000000,
112 | 0.000244140625,
113 | 0.0009766,
114 | 0.0039063,
115 | 0.0156250,
116 | 0.0625000,
117 | 0.2500000,
118 | 1.0000000 ]]), dtype=torch.float32)
119 | '''
120 | input = torch.tensor(np.array([[ 57344.00 , -61440.0, 65504.0, -500.0,
121 | 448.0 , -480.0, 30.0, -31.0,
122 | 26.0 , 15.0, -7.6505613e-00, 5.9832452e-00,
123 | 1.5625032e-02, 3.1725775e-02, 4.3268750e-02, 6.2655000e-02,
124 | 1.9545313e-03, 3.9045625e-03, 5.8845638e-03, 7.8089750e-03,
125 | -6.0151856e-02, 6.9784373e-03, -1.6634936e+03, -7.6505613e-04,
126 | 1.9545313e-03, 3.9045625e-03, 5.8845638e-03, 7.8089750e-03,
127 | 1.5629950e-02, 1.5225650e-02, -3.1256500e-02, 9.3857500e-02,
128 | 7.9284608e-01, 2.8815269e-01, -1.1787039e-01, 6.1035156e-05,
129 | -1.1787039e-06, -1.0749481e-01, -1.3085605e+00, 6.5981364e-01,
130 | -7.0325255e-02, 2.7448297e-01, 5.5694544e-01, -2.3220782e-01,
131 | -5.9746221e-02, 15.23213444 , -0.00004323 , 1.9767435e-04,
132 | -1.2203161e+00, 2.9099861e-01, -7.9642259e-02, 1.3200364e+00,
133 | -1.5196867e+00, -1.2530587e+00, -2.0159689e-03, -1.9767643e+00,
134 | 6.0834163e-04, 7.8943473e-05, 7.8247029e-04, -6.4658634e-05,
135 | -2.3020705e-06, -1.5630834e-05, -7.4762434e-07, 2.1336775e-06]]), dtype=torch.float32)
136 | #'''
137 | #input = torch.randn(4, 16)
138 | dformats=['e5m2_rne', 'e5m2_stochastic',
139 | 'e4m3_rne', 'e4m3_stochastic',
140 | 'e3m4_rne', 'e3m4_stochastic',
141 | ]
142 |
143 | if options.dformat not in dformats:
144 | print("data format {} is not supported".format(options.dformat))
145 |     sys.exit(1)
146 |
147 | if options.sparse :
148 | check_forward(input.to_sparse(), options.dformat, options.cuda, options.verbose, options.sparse)
149 | else :
150 | check_forward(input, options.dformat, options.cuda, options.verbose, options.sparse)
151 |
--------------------------------------------------------------------------------
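On PyTorch builds that ship native float8 dtypes (2.1 and later), the RNE modes exercised by this test can be cross-checked against a round trip through torch.float8_e4m3fn (a sketch; the native dtype may handle saturation and NaN differently from the emulator at the format boundaries):

    import torch
    from mpemu.pytquant.cpp import fpemu as fpemu_cpp

    x = torch.randn(64)
    ref = x.to(torch.float8_e4m3fn).to(torch.float32)  # native RNE cast and back
    emu = fpemu_cpp.FPEmuOp.apply(x, "E4M3_RNE")
    print(torch.allclose(ref, emu))
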
/mpemu/scale_shift.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi, Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 |
12 | class ScaleShift(torch.nn.Module):
13 | def __init__(self, num_features):
14 | super(ScaleShift, self).__init__()
15 | self.num_features = num_features
16 | self.weight = torch.nn.Parameter(torch.Tensor(num_features))
17 | self.bias = torch.nn.Parameter(torch.Tensor(num_features))
18 | self.reset_parameters()
19 |
20 | def reset_parameters(self):
21 | torch.nn.init.ones_(self.weight)
22 | torch.nn.init.zeros_(self.bias)
23 |
24 | def forward(self,input):
25 | # ASSUMPTION : First dimension is batch_size
26 | input_t = input.transpose(1,-1)
27 | input_t = input_t * self.weight + self.bias
28 | output = input_t.transpose(1,-1)
29 |
30 | return output
31 |
32 | # For call to repr(). Prints object's information
33 | def __repr__(self):
34 | return 'ScaleShift({})'.format(self.num_features)
35 |
36 | @staticmethod
37 | def generate_from_batchnorm(module: torch.nn.BatchNorm2d):
38 | """
39 | Helper function for converting Batchnorm2d to ScaleShift
40 | """
41 |         # BatchNorm state_dict
42 | bn_state_dict = module.state_dict()
43 |
44 | # BatchNorm parameters
45 | num_features = module.num_features
46 | eps = module.eps
47 | rmean = bn_state_dict['running_mean']
48 | rvar = bn_state_dict['running_var']
49 | gamma = bn_state_dict['weight']
50 | beta = bn_state_dict['bias']
51 |
52 | # Creating ScaleShift module
53 | ss_module = ScaleShift(num_features)
54 | with torch.no_grad():
55 | denom = torch.sqrt(rvar + eps)
56 | scale = gamma.div(denom)
57 | shift = beta - gamma.mul(rmean).div(denom)
58 |
59 | ss_module.weight.copy_(scale)
60 | ss_module.bias.copy_(shift)
61 |
62 | return ss_module
63 |
64 | def replace_batchnorms_with_scaleshifts(model):
65 | """
66 | Given a model, replace all BatchNorm2d layers with ScaleShift layers.
67 | """
68 | if isinstance(model, torch.nn.BatchNorm2d):
69 | return ScaleShift.generate_from_batchnorm(model)
70 | for name, module in model.named_children():
71 | module = replace_batchnorms_with_scaleshifts(module)
72 | setattr(model, name, module)
73 | return model
74 |
75 |
76 | if __name__ == '__main__':
77 | #### Testing ScaleShift layer ####
78 | num_features = 2
79 | ss = ScaleShift(num_features)
80 | ss.weight.data = ss.weight.data * torch.arange(1,num_features+1, dtype=ss.weight.dtype)
81 | ss.bias.data = torch.arange(1,num_features+1, dtype=ss.weight.dtype)
82 |
83 | input = torch.arange(num_features*2*2).reshape(1,num_features,2,2)
84 | print("Input")
85 | print(input)
86 | print("Scales")
87 | print(ss.weight.data)
88 | print("Shifts")
89 | print(ss.bias.data)
90 |
91 | output = ss(input)
92 | print("Output")
93 | print(output)
94 |
95 | #### Testing BN replacement with ScaleShift layer ####
96 | import torchvision
97 |
98 | model = torchvision.models.__dict__["resnet50"](pretrained=True)
99 | model.eval()
100 |
101 | # Original model
102 | input = torch.randn(1,3,224,224)
103 | output = model(input)
104 |
105 | # Replacing BNs
106 | model = replace_batchnorms_with_scaleshifts(model)
107 | output_s = model(input)
108 |
109 | print(torch.norm(output-output_s))
110 | print(output.flatten()[:10])
111 | print(output_s.flatten()[:10])
112 |
--------------------------------------------------------------------------------
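In eval mode, BatchNorm2d computes y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta, which folds into the affine form y = scale * x + shift with scale = gamma / sqrt(running_var + eps) and shift = beta - gamma * running_mean / sqrt(running_var + eps); this is exactly what generate_from_batchnorm stores. A quick equivalence check (a sketch; the feature count and inputs are arbitrary):

    import torch
    from mpemu.scale_shift import ScaleShift

    bn = torch.nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1.0, 1.0)  # give the running stats non-trivial values
    bn.running_var.uniform_(0.5, 2.0)

    ss = ScaleShift.generate_from_batchnorm(bn)
    x = torch.randn(2, 8, 4, 4)
    print(torch.allclose(bn(x), ss(x), atol=1e-5))  # True
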
/mpemu/sparse_utils.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Naveen Mellempudi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 | import os
12 | import sys
13 |
14 | class SparseConfig(object):
15 | def __init__(self, weight=False, ingrad=False, outgrad=False, wtgrad=False):
16 | self.weight = weight
17 | self.ingrad = ingrad
18 | self.outgrad = outgrad
19 | self.wtgrad = wtgrad
20 | self.print_stats = False
21 |
22 | self.global_step = 0
23 | self.alpha_window = 50
24 | self.weight_alpha = 65504.0
25 | self.weight_factor = 0.0
26 | self.outgrad_alpha = 65504.0
27 | self.outgrad_factor = 0.0
28 |
29 | def __repr__(self):
30 | return "[weight: {}, outgrad: {}, alpha_window : {}, weight_sparsity: {}, outgrd_sparsity :{} ]".format(
31 | self.weight, self.outgrad, self.alpha_window, self.weight_factor, self.outgrad_factor)
32 |
33 | def sparsify_ingrad_tensor(self, in_grad):
34 | return in_grad
35 |
36 | def sparsify_outgrad_tensor(self, out_grad):
37 | if self.global_step != 0 and (self.global_step%self.alpha_window) == 0 :
38 | self.outgrad_alpha = Stochastic_Pruning_Threshold(out_grad, self.outgrad_factor)
39 |
40 | out_grad.data = Stochastic_Pruning(out_grad.data, self.outgrad_alpha)
41 | return out_grad
42 |
43 | def sparsify_weight_tensor(self, weight):
44 | sample_factor = 0.1
45 | if self.global_step != 0 and (self.global_step%self.alpha_window) == 0 :
46 | self.weight_alpha = Topk_Threshold_Sampled(weight.data, self.weight_factor, sample_factor)
47 |
48 | weight.data = Topk_Pruning(weight.data, self.weight_alpha)
49 | return weight
50 |
51 | def sparsify_wtgrad_tensor(self, wt_grad):
52 | return wt_grad
53 |
54 | def optimizer_step (self, optimizer):
55 | '''
56 |         This routine should handle any sparsity-related operations that need to be performed inside the optimizer.
57 | '''
58 | return
59 |
60 | from scipy.optimize import root_scalar
61 |
62 |
63 | def print_sparse_stats (module, grad_input, grad_output) :
64 | for in_grad, out_grad in zip(grad_input, grad_output):
65 | wt_sp = 0.0
66 | if type(module) in [torch.nn.Conv2d, torch.nn.Linear] :
67 | wt_sp = 1 - (torch.count_nonzero(module.weight.data)/module.weight.data.numel()).item()
68 | grad_sp = 1 - (torch.count_nonzero(out_grad.data)/out_grad.data.numel()).item()
69 |         print ("{} , weight_sparsity : {}, alpha_weight : {}, grad_sparsity : {}".format(module.name, wt_sp, module.qconfig.weight_alpha, grad_sp))
70 |
71 | def Stochastic_Pruning(X, ALPHA):
72 | rand = ALPHA * torch.rand(X.shape, device=X.device,dtype=X.dtype)
73 | X_abs = X.abs()
74 | X = torch.where(X_abs < ALPHA, ALPHA * torch.sign(X), X)
75 | X = torch.where(X_abs < rand, torch.tensor([0], device=X.device, dtype=X.dtype), X)
76 | del X_abs
77 | return X
78 |
79 | def Stochastic_Pruning_Threshold(X, Sparsity):
80 | X = X.abs()
81 | X_sp = 1 - (torch.count_nonzero(X)/X.numel()).item()
82 | if Sparsity <= X_sp:
83 |         return torch.tensor([0.0], device=X.device, dtype=X.dtype)  # already sparse enough; a zero threshold disables further pruning
84 | Target_sparsity = Sparsity - X_sp
85 |
86 | Y = torch.log(X[X!=0.0])
87 | mu = torch.mean(Y, dtype=torch.float)
88 | sigma = torch.std(Y)
89 | del Y
90 | guess = torch.tensor([1], device=X.device, dtype=X.dtype)
91 | bracket = [torch.exp(torch.tensor([-9], device=X.device, dtype=torch.float)),
92 | torch.exp(torch.tensor([5], device=X.device, dtype=torch.float))]
93 | sol = root_scalar(equationStochastic, x0=guess, bracket=bracket, args=(Sparsity, torch.tensor([0.0], device=X.device), sigma))
94 | ALPHA = torch.tensor([sol.root], device=X.device, dtype=X.dtype)
95 | return torch.exp(torch.log(ALPHA) + mu)
96 |
97 | def Topk_Pruning(weight, alpha):
98 | Weight_Mask = torch.where(torch.abs(weight) < alpha, torch.tensor([0], device=weight.device, dtype=torch.short),\
99 | torch.tensor([1], device=weight.device, dtype=torch.short))
100 | weight.mul_(Weight_Mask.to(device=weight.device))
101 | del Weight_Mask
102 | return weight
103 |
104 | def Topk_Threshold_Sampled(weight, Sparsity_Weight, sample_factor):
105 | Total_Weight_El = weight.numel()
106 | sampled_El = int(Total_Weight_El*sample_factor)
107 | sampled_id = torch.randint(0, Total_Weight_El, (1,sampled_El))
108 | X_sampled = torch.abs(weight.view(-1))[sampled_id]
109 |
110 | topk_count_sample = int(sampled_El*(1-Sparsity_Weight))
111 | __ , topk_w_id_sampled = torch.topk(X_sampled, topk_count_sample)
112 | alpha = X_sampled[0,topk_w_id_sampled[0,topk_count_sample-1]]
113 | return alpha
114 |
115 | def equationStochastic(alpha, sparsity, mu, sigma):
116 |     sqrt2 = torch.sqrt(torch.tensor([2], device=sigma.device, dtype=torch.float))
117 |     alpha = torch.tensor([alpha], device=sigma.device, dtype=torch.float)
118 | pt1 = torch.exp((sigma**2)/2) * torch.erf(sigma/sqrt2 - torch.log(alpha/torch.exp(mu))/(sqrt2 * sigma))
119 | pt2 = alpha/torch.exp(mu) * torch.erf(torch.log(alpha/torch.exp(mu))/(sqrt2 * sigma))
120 | pt3 = torch.exp((sigma**2)/2)
121 | return 0.5 - sparsity + torch.exp(mu)/(2*alpha) * (pt1 + pt2 - pt3)
122 |
--------------------------------------------------------------------------------
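Stochastic_Pruning keeps an element x with |x| < ALPHA alive with probability |x|/ALPHA and promotes survivors to ALPHA * sign(x), so the pruned tensor is unbiased in expectation. A small sanity check (a sketch; sizes and the input distribution are illustrative):

    import torch
    from mpemu.sparse_utils import Stochastic_Pruning

    torch.manual_seed(0)
    x = torch.empty(1000000).uniform_(-0.1, 0.1)
    y = Stochastic_Pruning(x.clone(), torch.tensor([0.1]))
    print((y == 0).float().mean().item())    # ~0.5 for this input distribution
    print(x.mean().item(), y.mean().item())  # close: the estimator is unbiased
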
/mpemu/stats_collector.py:
--------------------------------------------------------------------------------
1 | #------------------------------------------------------------------------------
2 | # Copyright (c) 2023, Intel Corporation - All rights reserved.
3 | # This file is part of FP8-Emulation-Toolkit
4 | #
5 | # SPDX-License-Identifier: BSD-3-Clause
6 | #------------------------------------------------------------------------------
7 | # Dharma Teja Vooturi (Intel Corporation)
8 | #------------------------------------------------------------------------------
9 |
10 | import torch
11 |
12 | class TensorFullIntQuantParams(object):
13 | """
14 | min_val : float
15 | max_val : float
16 | qconfig : TensorQuantConfig
17 | """
18 | def __init__(self, min_val, max_val, qconfig):
19 | super(TensorFullIntQuantParams, self).__init__()
20 | self.qconfig = qconfig
21 | self.min_val, self.max_val, self.scale, self.zero_point = self._calculate_int8_qparams_base(qconfig.dtype,
22 | qconfig.scheme, min_val, max_val)
23 |
24 |
25 | def quantize(self, tensor_f):
26 | # Clamping the values
27 | #tensor_f = torch.clamp(tensor_f, self.min_val, self.max_val)
28 | # Quantizing the tensor
29 | tensor_int = torch.round(tensor_f/self.scale + self.zero_point)
30 |
31 | min_int = -128
32 | max_int = 127
33 | if self.qconfig.dtype == "uint8":
34 | min_int = 0
35 | max_int = 255
36 |
37 | # Clamping the values in integer domain
38 | tensor_int = torch.clamp(tensor_int, min_int, max_int)
39 |
40 | return tensor_int
41 |
42 | def dequantize(self, tensor_int):
43 | tensor_f = (tensor_int - self.zero_point)*self.scale
44 | return tensor_f
45 |
46 | def quant_dequant(self, tensor_f):
47 | return self.dequantize(self.quantize(tensor_f))
48 |
49 | def __repr__(self):
50 | return "{} Quantization range [{:.4f},{:.4f}] ".format(self.qconfig, self.min_val, self.max_val)
51 |
52 | # Calculating quantization parameters for INT8/UINT8
53 | @staticmethod
54 | def _calculate_int8_qparams_base(dtype, scheme, min_val, max_val):
55 | """
56 | Adapted from https://github.com/pytorch/pytorch/blob/8074779328fa471f484fb74cc6c50d95392fe2c2/torch/quantization/observer.py#L193
57 | """
58 | assert min_val <= max_val, "Minimum value {} has to be less than Maximum value {}".format(min_val,max_val)
59 | eps = torch.finfo(torch.float32).eps
60 |
61 | if dtype == "uint8":
62 | qmin = 0
63 | qmax = 255
64 | elif dtype == "int8":
65 | qmin = -128
66 | qmax = 127
67 |
68 | # Including zero in the range
69 | min_val,max_val = float(min_val),float(max_val)
70 | min_val = min(0.0, min_val)
71 | max_val = max(0.0, max_val)
72 |
73 | if min_val == max_val:
74 | scale = 1.0
75 | zero_point = 0
76 | else:
77 | if scheme == "sym_full" or scheme == "sym_channel":
78 | max_val = max(-min_val, max_val)
79 | scale = max_val / ((qmax - qmin) / 2)
80 | scale = max(scale, eps)
81 | zero_point = 0 if dtype == "int8" else 128
82 |
83 | min_val = -1*max_val
84 | elif scheme == "asym_full" or scheme == "asym_channel":
85 | scale = (max_val - min_val) / float(qmax - qmin)
86 | scale = max(scale, eps)
87 | zero_point = qmin - round(min_val / scale)
88 | zero_point = max(qmin, zero_point)
89 | zero_point = min(qmax, zero_point)
90 | zero_point = int(zero_point)
91 |
92 |
93 | return min_val, max_val, scale, zero_point
94 |
95 |
96 | class TensorChannelIntQuantParams(object):
97 | def __init__(self, min_vals, max_vals, qconfig):
98 | super(TensorChannelIntQuantParams, self).__init__()
99 | assert(len(min_vals) == len(max_vals))
100 | self.num_channels = len(min_vals)
101 | self.channel_qparams = []
102 |
103 | for min_val,max_val in zip(min_vals,max_vals):
104 | self.channel_qparams.append(TensorFullIntQuantParams(min_val, max_val, qconfig))
105 |
106 | def quant_dequant(self, tensor_f):
107 | tensor_q = torch.zeros_like(tensor_f)
108 | for chan_id in range(self.num_channels):
109 |             tensor_q[chan_id] = self.channel_qparams[chan_id].quant_dequant(tensor_f[chan_id])
110 | return tensor_q
111 |
112 |
113 | class TensorDump(torch.nn.Module):
114 |     def __init__(self, name=""):
115 |         super(TensorDump, self).__init__()
116 |         self.name = name
117 |         self.tensors = []
118 |
119 |     def forward(self, tensor):
120 |         self.tensors.append(tensor)
121 |
122 |     def dump(self):
123 |         import pickle, numpy as np
124 |         pickle.dump(np.array([tensor.detach().numpy() for tensor in self.tensors]), open(self.name+".pickle","wb"))
125 |
126 | class TensorDumpListWrapper(torch.nn.Module):
127 | """docstring for MinMaxStats of a tensor"""
128 | def __init__(self, name=""):
129 | super(TensorDumpListWrapper, self).__init__()
130 | self.initiated = False
131 | self.tupled_input = False
132 | self.tensor_dump_list = []
133 |
134 | self.name = name
135 |
136 | def forward(self, input):
137 | if not self.initiated:
138 |             if isinstance(input, tuple):
139 | self.tupled_input = True
140 | for i in range(len(input)):
141 | self.tensor_dump_list.append(TensorDump(name=self.name+"_{}".format(i)))
142 | else:
143 | self.tensor_dump_list = [TensorDump(name=self.name+"_0")]
144 | self.initiated = True
145 |
146 | if self.tupled_input:
147 | for i in range(len(input)):
148 | self.tensor_dump_list[i].forward(input[i])
149 | else:
150 | self.tensor_dump_list[0].forward(input)
151 |
152 |
153 | def dump(self):
154 | for tensor_dump in self.tensor_dump_list:
155 | tensor_dump.dump()
156 |
157 | class ArchiveStats(torch.nn.Module):
158 |     def __init__(self):
159 |         super(ArchiveStats, self).__init__()
160 |         self.tensors = []
161 |
162 | def forward(self, tensor):
163 | self.tensors.append(tensor)
164 |
165 | class MinMaxStats(torch.nn.Module):
166 | """docstring for MinMaxStats of a tensor"""
167 | def __init__(self, archive_tensors=False):
168 | super(MinMaxStats, self).__init__()
169 | self.min_val = None
170 | self.max_val = None
171 |
172 | self.archive_tensors = archive_tensors
173 | self.tensors = None
174 |
175 | def forward(self, tensor):
176 | min_val = torch.min(tensor).item()
177 | max_val = torch.max(tensor).item()
178 |
179 | if self.min_val == None:
180 | self.min_val = min_val
181 | else:
182 | if min_val < self.min_val:
183 | self.min_val = min_val
184 |
185 | if self.max_val == None:
186 | self.max_val = max_val
187 | else:
188 | if max_val > self.max_val:
189 | self.max_val = max_val
190 |
191 | if self.archive_tensors:
192 | if self.tensors is None:
193 | self.tensors = [tensor]
194 | else:
195 | self.tensors.append(tensor)
196 |
197 | def get_tensor_quant_params(self, ten_qconfig):
198 | assert(ten_qconfig.dtype in ["uint8","int8"])
199 | ten_qparams = TensorFullIntQuantParams(self.min_val, self.max_val, ten_qconfig)
200 | return ten_qparams
201 |
202 | def print(self, name=""):
203 | print("{:7.3f} {:7.3f} {}".format(self.min_val, self.max_val, name))
204 |
205 |
206 | class RunningMinMaxStats(torch.nn.Module):
207 | """docstring for MinMaxStats of a tensor"""
208 | def __init__(self, archive_tensors=False):
209 | super(RunningMinMaxStats, self).__init__()
210 |
211 | self.min_val = None
212 | self.max_val = None
213 |
214 | self.running_min_val = None
215 | self.running_max_val = None
216 |
217 | self.running_steps = 0
218 |
219 | self.archive_tensors = archive_tensors
220 | self.tensors = None
221 |
222 | def forward(self, tensor):
223 | min_val = torch.min(tensor).item()
224 | max_val = torch.max(tensor).item()
225 |
226 | if self.min_val == None:
227 | self.min_val = min_val
228 | else:
229 | if min_val < self.min_val:
230 | self.min_val = min_val
231 |
232 | if self.max_val == None:
233 | self.max_val = max_val
234 | else:
235 | if max_val > self.max_val:
236 | self.max_val = max_val
237 |
238 | ## Running max and running mean
239 | if self.running_min_val == None:
240 | self.running_min_val = min_val
241 | else:
242 | self.running_min_val = (self.running_min_val*self.running_steps + min_val)/(self.running_steps+1)
243 |
244 | if self.running_max_val == None:
245 | self.running_max_val = max_val
246 | else:
247 | self.running_max_val = (self.running_max_val*self.running_steps + max_val)/(self.running_steps+1)
248 |
249 | self.running_steps += 1
250 |
251 | if self.archive_tensors:
252 | if self.tensors is None:
253 | self.tensors = [tensor]
254 | else:
255 | self.tensors.append(tensor)
256 |
257 | def get_tensor_quant_params(self, ten_qconfig):
258 | assert(ten_qconfig.dtype in ["uint8","int8"])
259 | ten_qparams = TensorFullIntQuantParams(self.running_min_val, self.running_max_val, ten_qconfig)
260 | return ten_qparams
261 |
262 | def print(self, name=""):
263 | print("{:7.3f} {:7.3f} {:7.3f} {:7.3f} {}".format(self.min_val, self.max_val,
264 | self.running_min_val, self.running_max_val, name))
265 |
266 |
267 | class StatsListWrapper(torch.nn.Module):
268 | """docstring for MinMaxStats of a tensor"""
269 | def __init__(self, stats_class, archive_tensors):
270 | super(StatsListWrapper, self).__init__()
271 | self.initiated = False
272 | self.tupled_input = False
273 | self.stats_list = []
274 |
275 | self.stats_class = stats_class
276 | self.archive_tensors = archive_tensors
277 |
278 | def forward(self, input):
279 | if not self.initiated:
280 |             if isinstance(input, tuple):
281 | self.tupled_input = True
282 | for i in range(len(input)):
283 | self.stats_list.append(self.stats_class(archive_tensors=self.archive_tensors))
284 | else:
285 | self.stats_list = [self.stats_class(archive_tensors=self.archive_tensors)]
286 | self.initiated = True
287 |
288 | if self.tupled_input:
289 | for i in range(len(input)):
290 | self.stats_list[i].forward(input[i])
291 | else:
292 | self.stats_list[0].forward(input)
293 |
294 | def get_tensor_quant_params(self, ten_qconfig):
295 | ret_list = [stats_obj.get_tensor_quant_params(ten_qconfig) for stats_obj in self.stats_list]
296 |
297 | if self.tupled_input:
298 | return ret_list
299 | else:
300 | return ret_list[0]
301 |
302 | def print(self, name=""):
303 | for i in range(len(self.stats_list)):
304 | self.stats_list[i].print(name+"[{}]".format(i))
305 |
306 |
307 | class ChannleWiseMinMaxStats(torch.nn.Module):
308 |
309 | def __init__(self):
310 | super(ChannleWiseMinMaxStats, self).__init__()
311 | self.min_vals = None
312 | self.max_vals = None
313 |
314 | def forward(self, tensor):
315 | num_chans = tensor.shape[0]
316 |         if self.min_vals is None or self.max_vals is None:
317 | self.min_vals = [None]*num_chans
318 | self.max_vals = [None]*num_chans
319 |
320 | for chan_id in range(num_chans):
321 | min_val = torch.min(tensor[chan_id]).item()
322 | max_val = torch.max(tensor[chan_id]).item()
323 |
324 |             if self.min_vals[chan_id] is None:
325 |                 self.min_vals[chan_id] = min_val
326 |             else:
327 |                 if min_val < self.min_vals[chan_id]:
328 |                     self.min_vals[chan_id] = min_val
329 |
330 |             if self.max_vals[chan_id] is None:
331 |                 self.max_vals[chan_id] = max_val
332 |             else:
333 |                 if max_val > self.max_vals[chan_id]:
334 |                     self.max_vals[chan_id] = max_val
335 |
336 | def get_tensor_quant_params(self, ten_qconfig):
337 | ten_qparams = TensorChannelIntQuantParams(self.min_vals, self.max_vals, ten_qconfig)
338 | return ten_qparams
339 |
340 | def print(self):
341 | print(self.min_vals, self.max_vals, end=' ')
342 |
--------------------------------------------------------------------------------
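The collectors above are plain observer modules; one way to use them is to attach them with forward hooks during a calibration pass (a sketch with a toy model and random stand-in data; all names are illustrative):

    import torch
    from mpemu.stats_collector import MinMaxStats

    model = torch.nn.Sequential(torch.nn.Linear(16, 8), torch.nn.ReLU())
    stats = {}

    def make_hook(name):
        stats[name] = MinMaxStats()
        def hook(module, inputs, output):
            stats[name].forward(output)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            module.register_forward_hook(make_hook(name))

    with torch.no_grad():
        for _ in range(8):  # stand-in for a calibration data loader
            model(torch.randn(4, 16))

    for name, s in stats.items():
        s.print(name)
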
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | scipy
3 | setuptools
4 | torch
5 | torchaudio
6 | torchvision
7 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from setuptools import setup, find_packages
4 | from torch.utils.cpp_extension import BuildExtension, CppExtension
5 |
6 | cmdclass = {}
7 | ext_modules = []
8 |
9 | with open("README.md", "r") as fh:
10 | long_description = fh.read()
11 |
12 | ext_modules.append(
13 | CppExtension('fpemu_cpp',
14 | ['mpemu/pytquant/cpp/fpemu_impl.cpp'],
15 | extra_compile_args = ["-mf16c", "-march=native", "-mlzcnt", "-fopenmp", "-Wdeprecated-declarations"]
16 | ),)
17 |
18 | if torch.cuda.is_available():
19 |     from torch.utils.cpp_extension import CUDAExtension
20 | if torch.version.hip:
21 | ext_modules.append(
22 | CUDAExtension('fpemu_hip', [
23 | 'mpemu/pytquant/hip/fpemu_impl.cpp',
24 | 'mpemu/pytquant/hip/fpemu_kernels.hip'],
25 | ),)
26 | elif torch.version.cuda:
27 | ext_modules.append(
28 | CUDAExtension('fpemu_cuda', [
29 | 'mpemu/pytquant/cuda/fpemu_impl.cpp',
30 | 'mpemu/pytquant/cuda/fpemu_kernels.cu'],
31 | ),)
32 |
33 | ext_modules.append(
34 | CppExtension('simple_gemm_dev',
35 | ['mpemu/cmodel/simple/simple_gemm.cpp', 'mpemu/cmodel/simple/simple_gemm_impl.cpp', 'mpemu/cmodel/simple/simple_mm_engine.cpp'],
36 | extra_compile_args=[ '-march=native', '-fopenmp','-Wunused-but-set-variable','-Wunused-variable'],
37 |                   include_dirs=[os.path.join(os.getcwd(), 'mpemu', 'cmodel', 'simple')],
38 | ),)
39 |
40 | ext_modules.append(
41 | CppExtension('simple_conv2d_dev',
42 | ['mpemu/cmodel/simple/simple_conv2d.cpp', 'mpemu/cmodel/simple/simple_conv2d_impl.cpp', 'mpemu/cmodel/simple/simple_mm_engine.cpp'],
43 | extra_compile_args=[ '-march=native', '-fopenmp','-Wunused-but-set-variable','-Wunused-variable'],
44 |                   include_dirs=[os.path.join(os.getcwd(), 'mpemu', 'cmodel', 'simple')],
45 | ),)
46 |
47 | cmdclass['build_ext'] = BuildExtension
48 |
49 | setup(
50 | name="mpemu",
51 | version="1.0",
52 | ext_modules=ext_modules,
53 | cmdclass=cmdclass,
54 | author="Naveen Mellempudi",
55 | description="FP8 Emulation Toolkit",
56 | long_description=long_description,
57 | long_description_content_type="text/markdown",
58 | url="",
59 | packages=find_packages(),
60 | classifiers=[
61 | "Programming Language :: Python :: 3",
62 | "Operating System :: OS Independent",
63 | ],
64 | python_requires='>=3.6',
65 | )
66 |
--------------------------------------------------------------------------------
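After building (for example, pip install . from the repository root), a quick smoke test that the compiled extensions import on the current platform (a sketch; fpemu_cuda and fpemu_hip exist only on matching GPU builds):

    import torch
    import fpemu_cpp  # always built

    if torch.cuda.is_available():
        if torch.version.cuda:
            import fpemu_cuda
        elif torch.version.hip:
            import fpemu_hip
    print("extensions imported")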