├── nanogpt_metaprogram └── README.md ├── README.md └── integer_only_gan ├── artbench.py ├── diffaug.py ├── integer_only_observer.py ├── dcgan_trace.py ├── fid_score.py └── dcgan_train.py /nanogpt_metaprogram/README.md: -------------------------------------------------------------------------------- 1 | Barebone interface for training nanoGPT 2 | 3 | This project aims to create a DAG of 1k-2k "programs" that describes a full forward-backward pass of training in the nanoGPT model. It will be hardware- and platform-independent, with a reference CUDA implementation. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | research: repository for machine learning algorithms and systems on emergent runtimes 2 | 3 | --------------- 4 | 5 | * [Challenges in compilation of neural networks to zero-knowledge runtimes](https://docs.google.com/presentation/d/1Nb5D6-EW_8McRkSGzUO5lwxZPvm6Rlx3eJyLn7DD6kI/edit?usp=sharing) (Peiyuan Liao, 0xPARC ZKML day, Nov 2022) 6 | * [On-going research on training integer-only GAN](./integer_only_gan) 7 | -------------------------------------------------------------------------------- /integer_only_gan/artbench.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | 3 | class ArtBench10(CIFAR10): 4 | 5 | base_folder = "artbench-10-batches-py" 6 | url = "https://artbench.eecs.berkeley.edu/files/artbench-10-python.tar.gz" 7 | filename = "artbench-10-python.tar.gz" 8 | tgz_md5 = "9df1e998ee026aae36ec60ca7b44960e" 9 | train_list = [ 10 | ["data_batch_1", "c2e02a78dcea81fe6fead5f1540e542f"], 11 | ["data_batch_2", "1102a4dcf41d4dd63e20c10691193448"], 12 | ["data_batch_3", "177fc43579af15ecc80eb506953ec26f"], 13 | ["data_batch_4", "566b2a02ccfbafa026fbb2bcec856ff6"], 14 | ["data_batch_5", "faa6a572469542010a1c8a2a9a7bf436"], 15 | ] 16 | 17 | test_list = [ 18 | ["test_batch", "fa44530c8b8158467e00899609c19e52"], 19 | ] 20 | meta = { 21 | "filename": "meta", 22 | "key": "styles", 23 | "md5": "5bdcafa7398aa6b75d569baaec5cd4aa", 24 | } -------------------------------------------------------------------------------- /integer_only_gan/diffaug.py: -------------------------------------------------------------------------------- 1 | # Differentiable Augmentation for Data-Efficient GAN Training 2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | # https://arxiv.org/pdf/2006.10738 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def DiffAugment(x, policy='', channels_first=True): 10 | if policy: 11 | if not channels_first: 12 | x = x.permute(0, 3, 1, 2) 13 | for p in policy.split(','): 14 | for f in AUGMENT_FNS[p]: 15 | x = f(x) 16 | if not channels_first: 17 | x = x.permute(0, 2, 3, 1) 18 | x = x.contiguous() 19 | return x 20 | 21 | 22 | def rand_brightness(x): 23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 24 | return x 25 | 26 | 27 | def rand_saturation(x): 28 | x_mean = x.mean(dim=1, keepdim=True) 29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 30 | return x 31 | 32 | 33 | def rand_contrast(x): 34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 36 | return x 37 | 38 | 39 | def rand_translation(x, ratio=0.125): 40 | shift_x, shift_y = int(x.size(2) * 
ratio + 0.5), int(x.size(3) * ratio + 0.5) 41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 43 | grid_batch, grid_x, grid_y = torch.meshgrid( 44 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 45 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 46 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 47 | ) 48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 52 | return x 53 | 54 | 55 | def rand_cutout(x, ratio=0.5): 56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 59 | grid_batch, grid_x, grid_y = torch.meshgrid( 60 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 63 | ) 64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 67 | mask[grid_batch, grid_x, grid_y] = 0 68 | x = x * mask.unsqueeze(1) 69 | return x 70 | 71 | 72 | AUGMENT_FNS = { 73 | 'color': [rand_brightness, rand_saturation, rand_contrast], 74 | 'translation': [rand_translation], 75 | 'cutout': [rand_cutout], 76 | } -------------------------------------------------------------------------------- /integer_only_gan/integer_only_observer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, List, Tuple, Optional, Dict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.ao.quantization.utils import ( 7 | check_min_max_valid, calculate_qmin_qmax, is_per_tensor, is_per_channel) 8 | import math 9 | from torch.quantization.observer import MovingAverageMinMaxObserver 10 | 11 | class MovingAverageIntegerMinMaxObserver(MovingAverageMinMaxObserver): 12 | def __init__( 13 | self, 14 | *args, 15 | **kwargs 16 | ) -> None: 17 | super(MovingAverageIntegerMinMaxObserver, self).__init__( 18 | *args, 19 | **kwargs 20 | ) 21 | 22 | @torch.jit.export 23 | def _calculate_qparams( 24 | self, min_val: torch.Tensor, max_val: torch.Tensor 25 | ) -> Tuple[torch.Tensor, torch.Tensor]: 26 | r"""Calculates the quantization parameters, given min and max 27 | value tensors. 
Works for both per tensor and per channel cases 28 | Args: 29 | min_val: Minimum values per channel 30 | max_val: Maximum values per channel 31 | Returns: 32 | scales: Scales tensor of shape (#channels,) 33 | zero_points: Zero points tensor of shape (#channels,) 34 | """ 35 | if not check_min_max_valid(min_val, max_val): 36 | return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type) 37 | 38 | quant_min, quant_max = self.quant_min, self.quant_max 39 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 40 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 41 | 42 | device = min_val_neg.device 43 | scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device) 44 | zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) 45 | 46 | if ( 47 | self.qscheme == torch.per_tensor_symmetric 48 | or self.qscheme == torch.per_channel_symmetric 49 | ): 50 | max_val_pos = torch.max(-min_val_neg, max_val_pos) 51 | scale = max_val_pos / (float(quant_max - quant_min) / 2) 52 | scale = torch.max(scale, self.eps) 53 | if self.dtype == torch.quint8: 54 | if self.has_customized_qrange: 55 | # When customized quantization range is used, down-rounded midpoint of the range is chosen. 56 | zero_point = zero_point.new_full( 57 | zero_point.size(), (quant_min + quant_max) // 2 58 | ) 59 | else: 60 | zero_point = zero_point.new_full(zero_point.size(), 128) 61 | elif self.qscheme == torch.per_channel_affine_float_qparams: 62 | scale = (max_val - min_val) / float(quant_max - quant_min) 63 | scale = torch.where(scale > self.eps, scale, torch.ones_like(scale)) 64 | # We use the quantize function 65 | # xq = Round(Xf * inv_scale + zero_point), 66 | # setting zero_point to (-1 * min *inv_scale) we get 67 | # Xq = Round((Xf - min) * inv_scale) 68 | zero_point = -1 * min_val / scale 69 | else: 70 | scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) 71 | scale = torch.max(scale, self.eps) 72 | zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) 73 | zero_point = torch.clamp(zero_point, quant_min, quant_max) 74 | 75 | # For scalar values, cast them to Tensors of size 1 to keep the shape 76 | # consistent with default values in FakeQuantize. 
77 | if len(scale.shape) == 0: 78 | # TODO: switch to scale.item() after adding JIT support 79 | scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device) 80 | if len(zero_point.shape) == 0: 81 | # TODO: switch to zero_point.item() after adding JIT support 82 | zero_point = torch.tensor( 83 | [int(zero_point)], dtype=zero_point.dtype, device=device 84 | ) 85 | if self.qscheme == torch.per_channel_affine_float_qparams: 86 | zero_point = torch.tensor( 87 | [float(zero_point)], dtype=zero_point.dtype, device=device 88 | ) 89 | 90 | # def power_log(x): return 2**(round(math.log(x, 2))) 91 | # the key: make scale power of 2 92 | scale = torch.clamp(torch.round(torch.log2(scale)), min=-10, max=10) 93 | scale = torch.pow(2.0, scale) 94 | 95 | return scale, zero_point -------------------------------------------------------------------------------- /integer_only_gan/dcgan_trace.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import sys 6 | 7 | import torchvision.transforms as transforms 8 | from torchvision.utils import save_image 9 | from artbench import ArtBench10 10 | 11 | from torch.utils.data import DataLoader 12 | from torchvision import datasets 13 | from tqdm import tqdm 14 | 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.autograd as autograd 18 | import torch 19 | from torch.quantization.qconfig import QConfig 20 | from torch.quantization import QuantStub, DeQuantStub 21 | 22 | import tvm 23 | from tvm import relay 24 | from tvm.relay import transform 25 | from tvm.contrib import graph_executor 26 | 27 | torch.backends.cudnn.enabled = False 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.quantized.engine = 'fbgemm' 30 | 31 | from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_expr, rewrite, DFPatternCallback 32 | 33 | class UpsampleCallback(DFPatternCallback): 34 | def __init__(self, require_type=False): 35 | super().__init__(require_type) 36 | self.x = wildcard() 37 | self.scale1 = is_constant() 38 | self.scale2 = is_constant() 39 | self.zero1 = is_constant() 40 | self.zero2 = is_constant() 41 | subtract = is_op("subtract")(is_op("cast")(self.x), self.zero1) 42 | upsample = is_op("round")(is_op("divide")( 43 | is_op("image.resize2d")(is_op("multiply")(is_op("cast")(subtract), self.scale1)) \ 44 | , self.scale2)) 45 | self.pattern = is_op("cast")(is_op("clip")(is_op("add")(upsample, self.zero2))) 46 | 47 | def callback(self, pre, post, node_map): 48 | x = node_map[self.x][0] 49 | #offset = node_map[self.zero2][0] 50 | #offset = relay.const(offset.data.numpy().astype("int32"), dtype="int32") 51 | 52 | curr_shape = (x._checked_type_.shape) 53 | size = (curr_shape[-2]*2, curr_shape[-1]*2) 54 | x = relay.image.resize2d(x, size=size, roi=[0.0, 0.0, 0.0, 0.0], method="nearest_neighbor", \ 55 | coordinate_transformation_mode="asymmetric", rounding_method="", cubic_alpha=-0.75) 56 | 57 | #offset = x + offset 58 | clipped = relay.op.clip(x, a_min=0, a_max=255) 59 | return relay.op.cast(clipped, dtype="uint8") 60 | 61 | #from torch.ao.quantization.observer import ( 62 | # MovingAverageMinMaxObserver, 63 | #) 64 | from torch.ao.quantization.fake_quantize import ( 65 | FakeQuantize, 66 | #default_fused_wt_fake_quant, 67 | ) 68 | from integer_only_observer import MovingAverageIntegerMinMaxObserver 69 | 70 | #from fid_score import calculate_fid_given_paths 71 | 72 | os.makedirs("images", exist_ok=True) 73 | 
74 | parser = argparse.ArgumentParser() 75 | parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") 76 | parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") 77 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") 78 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 79 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 80 | parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") 81 | parser.add_argument("--latent_dim", type=int, default=16, help="dimensionality of the latent space") 82 | parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension") 83 | parser.add_argument("--channels", type=int, default=3, help="number of image channels") 84 | parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights") 85 | parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") 86 | opt = parser.parse_args() 87 | print(opt) 88 | 89 | img_shape = (opt.channels, opt.img_size, opt.img_size) 90 | 91 | HIDDEN_DIM = 1 92 | class Generator(nn.Module): 93 | def __init__(self): 94 | super(Generator, self).__init__() 95 | self.quant = QuantStub() 96 | self.dequant = DeQuantStub() 97 | 98 | self.init_size = opt.img_size // 4 99 | self.l1 = nn.Linear(opt.latent_dim, HIDDEN_DIM * self.init_size ** 2) 100 | 101 | self.conv_blocks = nn.Sequential( 102 | #nn.BatchNorm2d(HIDDEN_DIM), 103 | nn.Upsample(scale_factor=2), 104 | nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, stride=1, padding=1), 105 | nn.BatchNorm2d(1, 0.8), 106 | nn.ReLU(), 107 | nn.Upsample(scale_factor=2), 108 | nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, stride=1, padding=1), 109 | nn.BatchNorm2d(HIDDEN_DIM, 0.8), 110 | nn.ReLU(), 111 | nn.Conv2d(HIDDEN_DIM, opt.channels, 3, stride=1, padding=1), 112 | ) 113 | 114 | self.tanh = nn.Tanh() 115 | 116 | def forward(self, z): 117 | z = self.quant(z) 118 | 119 | out = self.l1(z) 120 | out = out.view(out.shape[0], HIDDEN_DIM, self.init_size, self.init_size) 121 | img = self.conv_blocks(out) 122 | 123 | img = self.dequant(img) 124 | 125 | img = self.tanh(img) 126 | return img 127 | 128 | # Initialize generator 129 | generator = Generator() 130 | 131 | qconfig = QConfig(activation=FakeQuantize.with_args( 132 | observer=MovingAverageIntegerMinMaxObserver, 133 | quant_min=0, 134 | quant_max=255, 135 | reduce_range=True), 136 | weight=FakeQuantize.with_args( 137 | observer=MovingAverageIntegerMinMaxObserver, 138 | quant_min=-128, 139 | quant_max=127, 140 | dtype=torch.qint8, 141 | qscheme=torch.per_tensor_symmetric 142 | )) 143 | generator.qconfig = qconfig 144 | 145 | generator.eval() 146 | 147 | # fuse the activations to preceding layers, where applicable 148 | # this needs to be done manually depending on the model architecture 149 | generator = torch.quantization.fuse_modules(generator, 150 | [["conv_blocks.1", "conv_blocks.2", "conv_blocks.3"] 151 | ,["conv_blocks.5", "conv_blocks.6", "conv_blocks.7"]] 152 | ) 153 | 154 | generator.train() 155 | 156 | # Prepare the model for QAT. This inserts observers and fake_quants in 157 | # the model that will observe weight and activation tensors during calibration. 
158 | generator = torch.quantization.prepare_qat(generator) 159 | 160 | generator.load_state_dict(torch.load('generator.pt', map_location=torch.device('cpu'))) 161 | generator = generator.cpu() 162 | 163 | # Convert to quantized model 164 | torch.quantization.convert(generator, inplace=True) 165 | print('QAT: Conversion done.') 166 | print(generator) 167 | 168 | 169 | # Configure data loader 170 | os.makedirs("./data/artbench", exist_ok=True) 171 | dataloader = torch.utils.data.DataLoader( 172 | ArtBench10( 173 | "./data/artbench", 174 | train=False, 175 | download=True, 176 | transform=transforms.Compose( 177 | [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] 178 | ), 179 | ), 180 | batch_size=opt.batch_size, 181 | shuffle=False, 182 | drop_last=True 183 | ) 184 | 185 | imgs = list(enumerate(dataloader))[0][1] 186 | 187 | dz = np.random.normal(0, 1, (64, opt.latent_dim)) 188 | z = torch.FloatTensor(dz) 189 | 190 | script_module = torch.jit.trace(generator, example_inputs=[z]).eval() 191 | script_result = script_module(z).numpy() 192 | torch_result = generator(z).numpy() 193 | 194 | print((script_result - torch_result).mean()) 195 | 196 | script_module.save("quantized_jit_generator.pt") 197 | 198 | device = tvm.cpu() 199 | target = "llvm" 200 | input_name = "input" # the input name can be be arbitrary for PyTorch frontend. 201 | input_shapes = [(input_name, (64, opt.latent_dim))] 202 | mod, params = relay.frontend.from_pytorch( 203 | script_module, input_shapes, keep_quantized_weight=True 204 | ) 205 | mod = relay.transform.InferType()(mod) 206 | print(mod) 207 | print("===========") 208 | mod = relay.qnn.transform.CanonicalizeOps()(mod) 209 | seq = tvm.transform.Sequential( 210 | [ 211 | transform.CanonicalizeOps(), 212 | transform.InferType(), 213 | #transform.SimplifyInference(), 214 | transform.FoldConstant(), 215 | #transform.FoldScaleAxis(), 216 | #transform.SimplifyExpr(), 217 | #transform.FoldConstant(), 218 | ] 219 | ) 220 | with tvm.transform.PassContext(opt_level=3): 221 | mod = seq(mod) 222 | print(mod) 223 | print("===========") 224 | mod["main"] = rewrite(UpsampleCallback(), mod["main"]) 225 | mod = relay.transform.InferType()(mod) 226 | print(mod) 227 | 228 | json_str = mod.astext(show_meta_data=True) #tvm.ir.save_json(mod["main"]) 229 | with open("generator_tvm_ir.txt", "w") as fo: 230 | fo.write(json_str) 231 | with open("generator_tvm.params", "wb") as fo: 232 | fo.write(relay.save_param_dict(params)) 233 | 234 | with tvm.transform.PassContext(opt_level=1): 235 | func = relay.create_executor("graph", mod=mod, device=device, target=target).evaluate() 236 | 237 | 238 | for i, (imgs, _) in tqdm(enumerate(dataloader), total=len(dataloader)): 239 | # Sample noise as generator input 240 | z = np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)) 241 | 242 | # Generate a batch of images 243 | input_dict = {input_name: z} 244 | # Generate a batch of images 245 | fake_imgs = func(**input_dict, **params).numpy() 246 | 247 | # print(fake_imgs.shape) 248 | for j in range(imgs.shape[0]): 249 | k = i * imgs.shape[0] + j 250 | save_image(torch.tensor(fake_imgs[j, :, :, :]), "./data/fid_eval/%d.png" % k) 251 | save_image(imgs.data[j], "./data/fid_real/%d.png" % k) 252 | 253 | #fid = calculate_fid_given_paths(paths=('./data/fid_real', './data/fid_eval'), 254 | # batch_size=256, device='cuda', dims=2048) 255 | #print(f"FID after rewrite : {fid}") 256 | -------------------------------------------------------------------------------- 
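A note on the scale handling the trace above depends on: MovingAverageIntegerMinMaxObserver (integer_only_observer.py) snaps every computed scale to a power of two via scale = 2**clamp(round(log2(scale)), -10, 10). The practical consequence is that requantization, y_int = round(x_int * s_in / s_out) + z_out, no longer needs a floating-point multiply: when both scales are powers of two the ratio is itself 2**k and the rescale collapses to a bit shift, which is what keeps the exported graph friendly to integer-only targets such as the runtimes discussed in the top-level README. Below is a minimal NumPy sketch of that equivalence; it is illustrative only (the helper names and example scales are invented, both paths round half-up, and clamping to the quantized range is omitted for brevity).

import numpy as np

def requant_float(x_int, s_in, s_out, z_out):
    # Reference path: rescale in floating point, round half-up, add the zero point.
    return np.floor(x_int * (s_in / s_out) + 0.5).astype(np.int64) + z_out

def requant_shift(x_int, k, z_out):
    # Integer-only path: s_in / s_out == 2**k, so the rescale is a bit shift.
    if k >= 0:
        return (x_int << k) + z_out
    return ((x_int + (1 << (-k - 1))) >> -k) + z_out  # right shift with round half-up

x_int = np.arange(-16, 16, dtype=np.int64)
s_in, s_out, z_out = 2.0 ** -6, 2.0 ** -4, 128   # both scales are powers of two
k = int(round(np.log2(s_in / s_out)))            # here k == -2
print(np.array_equal(requant_float(x_int, s_in, s_out, z_out),
                     requant_shift(x_int, k, z_out)))  # prints True
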
/integer_only_gan/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | The FID metric calculates the distance between two distributions of images. 3 | Typically, we have summary statistics (mean & covariance matrix) of one 4 | of these distributions, while the 2nd distribution is given by a GAN. 5 | When run as a stand-alone program, it compares the distribution of 6 | images that are stored as PNG/JPEG at a specified location with a 7 | distribution given by summary statistics (in pickle format). 8 | The FID is calculated by assuming that X_1 and X_2 are the activations of 9 | the pool_3 layer of the inception net for generated samples and real world 10 | samples respectively. 11 | See --help to see further details. 12 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 13 | of Tensorflow 14 | Copyright 2018 Institute of Bioinformatics, JKU Linz 15 | Licensed under the Apache License, Version 2.0 (the "License"); 16 | you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | http://www.apache.org/licenses/LICENSE-2.0 19 | Unless required by applicable law or agreed to in writing, software 20 | distributed under the License is distributed on an "AS IS" BASIS, 21 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | See the License for the specific language governing permissions and 23 | limitations under the License. 24 | """ 25 | import os 26 | import pathlib 27 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 28 | 29 | import numpy as np 30 | import torch 31 | import torchvision.transforms as TF 32 | from PIL import Image 33 | from scipy import linalg 34 | from torch.nn.functional import adaptive_avg_pool2d 35 | 36 | try: 37 | from tqdm import tqdm 38 | except ImportError: 39 | # If tqdm is not available, provide a mock version of it 40 | def tqdm(x): 41 | return x 42 | 43 | from pytorch_fid.inception import InceptionV3 44 | 45 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 46 | parser.add_argument('--batch-size', type=int, default=50, 47 | help='Batch size to use') 48 | parser.add_argument('--num-workers', type=int, 49 | help=('Number of processes to use for data loading. ' 50 | 'Defaults to `min(8, num_cpus)`')) 51 | parser.add_argument('--device', type=str, default=None, 52 | help='Device to use. Like cuda, cuda:0 or cpu') 53 | parser.add_argument('--dims', type=int, default=2048, 54 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 55 | help=('Dimensionality of Inception features to use. 
' 56 | 'By default, uses pool3 features')) 57 | parser.add_argument('path', type=str, nargs=2, 58 | help=('Paths to the generated images or ' 59 | 'to .npz statistic files')) 60 | 61 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 62 | 'tif', 'tiff', 'webp'} 63 | 64 | 65 | class ImagePathDataset(torch.utils.data.Dataset): 66 | def __init__(self, files, transforms=None): 67 | self.files = files 68 | self.transforms = transforms 69 | 70 | def __len__(self): 71 | return len(self.files) 72 | 73 | def __getitem__(self, i): 74 | path = self.files[i] 75 | img = Image.open(path).convert('RGB') 76 | if self.transforms is not None: 77 | img = self.transforms(img) 78 | return img 79 | 80 | 81 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', 82 | num_workers=1): 83 | """Calculates the activations of the pool_3 layer for all images. 84 | Params: 85 | -- files : List of image files paths 86 | -- model : Instance of inception model 87 | -- batch_size : Batch size of images for the model to process at once. 88 | Make sure that the number of samples is a multiple of 89 | the batch size, otherwise some samples are ignored. This 90 | behavior is retained to match the original FID score 91 | implementation. 92 | -- dims : Dimensionality of features returned by Inception 93 | -- device : Device to run calculations 94 | -- num_workers : Number of parallel dataloader workers 95 | Returns: 96 | -- A numpy array of dimension (num images, dims) that contains the 97 | activations of the given tensor when feeding inception with the 98 | query tensor. 99 | """ 100 | model.eval() 101 | 102 | if batch_size > len(files): 103 | print(('Warning: batch size is bigger than the data size. ' 104 | 'Setting batch size to data size')) 105 | batch_size = len(files) 106 | 107 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 108 | dataloader = torch.utils.data.DataLoader(dataset, 109 | batch_size=batch_size, 110 | shuffle=False, 111 | drop_last=False, 112 | num_workers=num_workers) 113 | 114 | pred_arr = np.empty((len(files), dims)) 115 | 116 | start_idx = 0 117 | 118 | for batch in tqdm(dataloader): 119 | batch = batch.to(device) 120 | 121 | with torch.no_grad(): 122 | pred = model(batch)[0] 123 | 124 | # If model output is not scalar, apply global spatial average pooling. 125 | # This happens if you choose a dimensionality not equal 2048. 126 | if pred.size(2) != 1 or pred.size(3) != 1: 127 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 128 | 129 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 130 | 131 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 132 | 133 | start_idx = start_idx + pred.shape[0] 134 | 135 | return pred_arr 136 | 137 | 138 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 139 | """Numpy implementation of the Frechet Distance. 140 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 141 | and X_2 ~ N(mu_2, C_2) is 142 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 143 | Stable version by Dougal J. Sutherland. 144 | Params: 145 | -- mu1 : Numpy array containing the activations of a layer of the 146 | inception net (like returned by the function 'get_predictions') 147 | for generated samples. 148 | -- mu2 : The sample mean over activations, precalculated on an 149 | representative data set. 150 | -- sigma1: The covariance matrix over activations for generated samples. 151 | -- sigma2: The covariance matrix over activations, precalculated on an 152 | representative data set. 
153 | Returns: 154 | -- : The Frechet Distance. 155 | """ 156 | 157 | mu1 = np.atleast_1d(mu1) 158 | mu2 = np.atleast_1d(mu2) 159 | 160 | sigma1 = np.atleast_2d(sigma1) 161 | sigma2 = np.atleast_2d(sigma2) 162 | 163 | assert mu1.shape == mu2.shape, \ 164 | 'Training and test mean vectors have different lengths' 165 | assert sigma1.shape == sigma2.shape, \ 166 | 'Training and test covariances have different dimensions' 167 | 168 | diff = mu1 - mu2 169 | 170 | # Product might be almost singular 171 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 172 | if not np.isfinite(covmean).all(): 173 | msg = ('fid calculation produces singular product; ' 174 | 'adding %s to diagonal of cov estimates') % eps 175 | print(msg) 176 | offset = np.eye(sigma1.shape[0]) * eps 177 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 178 | 179 | # Numerical error might give slight imaginary component 180 | if np.iscomplexobj(covmean): 181 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 182 | m = np.max(np.abs(covmean.imag)) 183 | raise ValueError('Imaginary component {}'.format(m)) 184 | covmean = covmean.real 185 | 186 | tr_covmean = np.trace(covmean) 187 | 188 | return (diff.dot(diff) + np.trace(sigma1) 189 | + np.trace(sigma2) - 2 * tr_covmean) 190 | 191 | 192 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 193 | device='cpu', num_workers=1): 194 | """Calculation of the statistics used by the FID. 195 | Params: 196 | -- files : List of image files paths 197 | -- model : Instance of inception model 198 | -- batch_size : The images numpy array is split into batches with 199 | batch size batch_size. A reasonable batch size 200 | depends on the hardware. 201 | -- dims : Dimensionality of features returned by Inception 202 | -- device : Device to run calculations 203 | -- num_workers : Number of parallel dataloader workers 204 | Returns: 205 | -- mu : The mean over samples of the activations of the pool_3 layer of 206 | the inception model. 207 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 208 | the inception model. 
209 | """ 210 | act = get_activations(files, model, batch_size, dims, device, num_workers) 211 | mu = np.mean(act, axis=0) 212 | sigma = np.cov(act, rowvar=False) 213 | return mu, sigma 214 | 215 | 216 | def compute_statistics_of_path(path, model, batch_size, dims, device, 217 | num_workers=1): 218 | if path.endswith('.npz'): 219 | with np.load(path) as f: 220 | m, s = f['mu'][:], f['sigma'][:] 221 | else: 222 | path = pathlib.Path(path) 223 | files = sorted([file for ext in IMAGE_EXTENSIONS 224 | for file in path.glob('*.{}'.format(ext))]) 225 | m, s = calculate_activation_statistics(files, model, batch_size, 226 | dims, device, num_workers) 227 | 228 | return m, s 229 | 230 | 231 | def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): 232 | """Calculates the FID of two paths""" 233 | for p in paths: 234 | if not os.path.exists(p): 235 | raise RuntimeError('Invalid path: %s' % p) 236 | 237 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 238 | 239 | model = InceptionV3([block_idx]).to(device) 240 | 241 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 242 | dims, device, num_workers) 243 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 244 | dims, device, num_workers) 245 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 246 | 247 | return fid_value 248 | -------------------------------------------------------------------------------- /integer_only_gan/dcgan_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import sys 6 | 7 | from diffaug import DiffAugment 8 | policy = 'color,translation,cutout' 9 | 10 | import torchvision.transforms as transforms 11 | from torchvision.utils import save_image 12 | from artbench import ArtBench10 13 | 14 | from torch.utils.data import DataLoader 15 | from torchvision import datasets 16 | from tqdm import tqdm 17 | 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.autograd as autograd 21 | import torch 22 | from torch.quantization.qconfig import QConfig 23 | from torch.quantization import QuantStub, DeQuantStub 24 | 25 | #from torch.ao.quantization.observer import ( 26 | # MovingAverageMinMaxObserver, 27 | #) 28 | from torch.ao.quantization.fake_quantize import ( 29 | FakeQuantize, 30 | #default_fused_wt_fake_quant, 31 | ) 32 | from integer_only_observer import MovingAverageIntegerMinMaxObserver 33 | 34 | from fid_score import calculate_fid_given_paths 35 | 36 | os.makedirs("images", exist_ok=True) 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--n_epochs", type=int, default=400, help="number of epochs of training") 40 | parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") 41 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") 42 | parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") 43 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 44 | parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") 45 | parser.add_argument("--latent_dim", type=int, default=16, help="dimensionality of the latent space") 46 | parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension") 47 | parser.add_argument("--channels", type=int, default=3, help="number of image 
channels") 48 | parser.add_argument("--n_critic", type=int, default=1, help="number of training steps for discriminator per iter") 49 | parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights") 50 | parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples") 51 | opt = parser.parse_args() 52 | print(opt) 53 | 54 | img_shape = (opt.channels, opt.img_size, opt.img_size) 55 | 56 | cuda = True if torch.cuda.is_available() else False 57 | 58 | 59 | def weights_init_normal(m): 60 | classname = m.__class__.__name__ 61 | if classname.find("Conv") != -1: 62 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 63 | elif classname.find("BatchNorm2d") != -1: 64 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 65 | torch.nn.init.constant_(m.bias.data, 0.0) 66 | 67 | 68 | HIDDEN_DIM = 1 69 | class Generator(nn.Module): 70 | def __init__(self): 71 | super(Generator, self).__init__() 72 | self.quant = QuantStub() 73 | self.dequant = DeQuantStub() 74 | 75 | self.init_size = opt.img_size // 4 76 | self.l1 = nn.Linear(opt.latent_dim, HIDDEN_DIM * self.init_size ** 2) 77 | 78 | self.conv_blocks = nn.Sequential( 79 | #nn.BatchNorm2d(HIDDEN_DIM), 80 | nn.Upsample(scale_factor=2), 81 | nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, stride=1, padding=1), 82 | nn.BatchNorm2d(1, 0.8), 83 | nn.ReLU(), 84 | nn.Upsample(scale_factor=2), 85 | nn.Conv2d(HIDDEN_DIM, HIDDEN_DIM, 3, stride=1, padding=1), 86 | nn.BatchNorm2d(HIDDEN_DIM, 0.8), 87 | nn.ReLU(), 88 | nn.Conv2d(HIDDEN_DIM, opt.channels, 3, stride=1, padding=1), 89 | ) 90 | 91 | self.tanh = nn.Tanh() 92 | 93 | def forward(self, z): 94 | z = self.quant(z) 95 | 96 | out = self.l1(z) 97 | out = out.view(out.shape[0], HIDDEN_DIM, self.init_size, self.init_size) 98 | img = self.conv_blocks(out) 99 | 100 | img = self.dequant(img) 101 | 102 | img = self.tanh(img) 103 | return img 104 | 105 | 106 | class Discriminator(nn.Module): 107 | def __init__(self): 108 | super(Discriminator, self).__init__() 109 | 110 | def discriminator_block(in_filters, out_filters, bn=True): 111 | block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] 112 | if bn: 113 | block.append(nn.BatchNorm2d(out_filters, 0.8)) 114 | return block 115 | 116 | self.model = nn.Sequential( 117 | *discriminator_block(opt.channels, 16, bn=False), 118 | *discriminator_block(16, 32), 119 | *discriminator_block(32, 64), 120 | *discriminator_block(64, 128), 121 | ) 122 | 123 | # The height and width of downsampled image 124 | ds_size = opt.img_size // 2 ** 4 125 | self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid()) 126 | 127 | def forward(self, img): 128 | out = self.model(img) 129 | out = out.view(out.shape[0], -1) 130 | validity = self.adv_layer(out) 131 | 132 | return validity 133 | 134 | # Loss function 135 | adversarial_loss = torch.nn.BCELoss() 136 | 137 | # Initialize generator and discriminator 138 | generator = Generator() 139 | discriminator = Discriminator() 140 | 141 | if cuda: 142 | generator.cuda() 143 | discriminator.cuda() 144 | adversarial_loss.cuda() 145 | 146 | # Initialize weights 147 | generator.apply(weights_init_normal) 148 | discriminator.apply(weights_init_normal) 149 | 150 | qconfig = QConfig(activation=FakeQuantize.with_args( 151 | observer=MovingAverageIntegerMinMaxObserver, 152 | quant_min=0, 153 | quant_max=255, 154 | reduce_range=True), 155 | weight=FakeQuantize.with_args( 156 | observer=MovingAverageIntegerMinMaxObserver, 
157 | quant_min=-128, 158 | quant_max=127, 159 | dtype=torch.qint8, 160 | qscheme=torch.per_tensor_symmetric 161 | )) 162 | generator.qconfig = qconfig 163 | 164 | generator.eval() 165 | 166 | # fuse the activations to preceding layers, where applicable 167 | # this needs to be done manually depending on the model architecture 168 | generator = torch.quantization.fuse_modules(generator, 169 | [["conv_blocks.1", "conv_blocks.2", "conv_blocks.3"] 170 | ,["conv_blocks.5", "conv_blocks.6", "conv_blocks.7"]] 171 | ) 172 | 173 | generator.train() 174 | 175 | # Prepare the model for QAT. This inserts observers and fake_quants in 176 | # the model that will observe weight and activation tensors during calibration. 177 | generator = torch.quantization.prepare_qat(generator) 178 | 179 | 180 | # Configure data loader 181 | os.makedirs("./data/artbench", exist_ok=True) 182 | os.makedirs("./data/fid_eval", exist_ok=True) 183 | os.makedirs("./data/fid_real", exist_ok=True) 184 | 185 | dataloader = torch.utils.data.DataLoader( 186 | ArtBench10( 187 | "./data/artbench", 188 | train=True, 189 | download=True, 190 | transform=transforms.Compose( 191 | [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] 192 | ), 193 | ), 194 | batch_size=opt.batch_size, 195 | shuffle=True, 196 | ) 197 | 198 | testloader = torch.utils.data.DataLoader( 199 | ArtBench10( 200 | "./data/artbench", 201 | train=False, 202 | download=True, 203 | transform=transforms.Compose( 204 | [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] 205 | ), 206 | ), 207 | batch_size=1024, 208 | shuffle=False, 209 | ) 210 | 211 | 212 | # Optimizers 213 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 214 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 215 | 216 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 217 | 218 | # ---------- 219 | # Training 220 | # ---------- 221 | fid_every = 10 222 | 223 | batches_done = 0 224 | for epoch in range(opt.n_epochs): 225 | 226 | for i, (imgs, _) in enumerate(dataloader): 227 | 228 | # Adversarial ground truths 229 | valid = Tensor(imgs.shape[0], 1).fill_(1.0) 230 | fake = Tensor(imgs.shape[0], 1).fill_(0.0) 231 | 232 | # Configure input 233 | real_imgs = imgs.type(Tensor) 234 | 235 | # ----------------- 236 | # Train Generator 237 | # ----------------- 238 | 239 | optimizer_G.zero_grad() 240 | 241 | # Sample noise as generator input 242 | z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))) 243 | 244 | # Generate a batch of images 245 | gen_imgs = generator(z) 246 | 247 | # Loss measures generator's ability to fool the discriminator 248 | g_loss = adversarial_loss(discriminator(DiffAugment(gen_imgs, policy=policy)), valid) 249 | 250 | g_loss.backward() 251 | optimizer_G.step() 252 | 253 | # --------------------- 254 | # Train Discriminator 255 | # --------------------- 256 | 257 | optimizer_D.zero_grad() 258 | 259 | # Measure discriminator's ability to classify real from generated samples 260 | real_loss = adversarial_loss(discriminator(DiffAugment(real_imgs, policy=policy)), valid) 261 | fake_loss = adversarial_loss(discriminator(DiffAugment(gen_imgs, policy=policy).detach()), fake) 262 | d_loss = (real_loss + fake_loss) / 2 263 | 264 | d_loss.backward() 265 | optimizer_D.step() 266 | 267 | batches_done = epoch * len(dataloader) + i 268 | if batches_done % opt.sample_interval == 0: 269 | print( 270 | "[Epoch 
%d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" 271 | % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()) 272 | ) 273 | save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) 274 | 275 | if epoch % fid_every == fid_every - 1 or epoch == 0: 276 | with torch.no_grad(): 277 | for i, (imgs, _) in tqdm(enumerate(testloader), total = len(testloader)): 278 | # Sample noise as generator input 279 | z = Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))) 280 | 281 | # Generate a batch of images 282 | fake_imgs = generator(z) 283 | 284 | for j in range(imgs.shape[0]): 285 | k = i * imgs.shape[0] + j 286 | save_image(fake_imgs.data[j], "./data/fid_eval/%d.png" % k) 287 | save_image(imgs.data[j], "./data/fid_real/%d.png" % k) 288 | 289 | fid = calculate_fid_given_paths(paths=('./data/fid_real', './data/fid_eval'), 290 | batch_size=256, device='cuda', dims=2048) 291 | print(f"FID at {epoch}: {fid}") 292 | torch.save(generator.state_dict(), f'models/generator_{epoch}.pt') 293 | torch.save(discriminator.state_dict(), f'models/discriminator_{epoch}.pt') 294 | torch.save(generator.state_dict(), 'generator.pt') 295 | 296 | torch.save(generator.cpu().state_dict(), 'generator.pt') 297 | torch.save(discriminator.cpu().state_dict(), 'discriminator.pt') --------------------------------------------------------------------------------
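As a closing sanity check on the FID plumbing shared by dcgan_train.py and (in commented-out form) dcgan_trace.py: calculate_frechet_distance in fid_score.py implements d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), so when the two covariances are identical the trace term vanishes and the value reduces to the squared distance between the means. The toy check below exercises exactly that identity; it is a sketch with made-up values, and importing fid_score requires torch, torchvision, scipy, PIL and pytorch_fid to be installed, since they are imported at module level.

import numpy as np
from fid_score import calculate_frechet_distance

rng = np.random.default_rng(0)
dims = 8
mu1, mu2 = rng.normal(size=dims), rng.normal(size=dims)
a = rng.normal(size=(dims, dims))
sigma = a @ a.T + dims * np.eye(dims)   # one shared positive-definite covariance

# With sigma1 == sigma2, the term Tr(C1 + C2 - 2*sqrt(C1*C2)) is zero, so the
# result should equal the squared mean difference up to numerical error.
d2 = calculate_frechet_distance(mu1, sigma, mu2, sigma)
print(d2, float(np.sum((mu1 - mu2) ** 2)))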