├── PTQ4DM ├── QDrop │ ├── LICENSE │ ├── README.md │ ├── data │ │ └── imagenet.py │ ├── exp │ │ ├── config.sh │ │ └── run.sh │ ├── hubconf.py │ ├── main_imagenet.py │ ├── models │ │ ├── __init__.py │ │ ├── mnasnet.py │ │ ├── mobilenetv2.py │ │ ├── regnet.py │ │ ├── resnet.py │ │ └── utils.py │ └── quant │ │ ├── __init__.py │ │ ├── adaptive_rounding.py │ │ ├── block_recon.py │ │ ├── data_utils.py │ │ ├── fold_bn.py │ │ ├── layer_recon.py │ │ ├── quant_block.py │ │ ├── quant_layer.py │ │ ├── quant_model.py │ │ ├── set_act_quantize_params.py │ │ └── set_weight_quantize_params.py ├── baseline.sh ├── guided-diffusion │ ├── LICENSE │ ├── README.md │ ├── datasets │ │ ├── README.md │ │ └── lsun_bedroom.py │ ├── evaluations │ │ ├── README.md │ │ ├── evaluator.py │ │ └── requirements.txt │ ├── guided_diffusion │ │ ├── __init__.py │ │ ├── dist_util.py │ │ ├── fp16_util.py │ │ ├── gaussian_diffusion.py │ │ ├── image_datasets.py │ │ ├── logger.py │ │ ├── losses.py │ │ ├── nn.py │ │ ├── resample.py │ │ ├── respace.py │ │ ├── script_util.py │ │ ├── train_util.py │ │ └── unet.py │ ├── model-card.md │ ├── scripts │ │ ├── classifier_sample.py │ │ ├── classifier_train.py │ │ ├── image_nll.py │ │ ├── image_sample.py │ │ ├── image_train.py │ │ ├── quant_image_sample.py │ │ ├── super_res_sample.py │ │ └── super_res_train.py │ └── setup.py ├── improved-diffusion │ ├── LICENSE │ ├── README.md │ ├── datasets │ │ ├── README.md │ │ ├── cifar10.py │ │ └── lsun_bedroom.py │ ├── improved_diffusion │ │ ├── __init__.py │ │ ├── dist_util.py │ │ ├── fp16_util.py │ │ ├── gaussian_diffusion.py │ │ ├── image_datasets.py │ │ ├── logger.py │ │ ├── losses.py │ │ ├── nn.py │ │ ├── resample.py │ │ ├── respace.py │ │ ├── script_util.py │ │ ├── train_util.py │ │ └── unet.py │ ├── scripts │ │ ├── image_nll.py │ │ ├── image_sample.py │ │ ├── image_train.py │ │ ├── quant_image_sample.py │ │ ├── super_res_sample.py │ │ └── super_res_train.py │ └── setup.py ├── quant_sample.sh ├── quant_sample_ddim_in_backward_DNTC.sh ├── quant_sample_ddim_in_forward.sh ├── quant_sample_ddim_in_random.sh └── quant_sample_ddim_in_raw.sh ├── README.md └── activation_hist.png /PTQ4DM/QDrop/README.md: -------------------------------------------------------------------------------- 1 | # QDrop 2 | PyTorch implementation of QDrop: Randomly Dropping Quantization for Extremely Low-bit Post-Training Quantization 3 | 4 | ## Overview 5 | 6 | QDrop is a simple yet effective approach that randomly drops the quantization of activations during reconstruction in order to pursue a flatter model on both calibration and test data. QDrop is easy to implement for various neural networks, including CNNs and Transformers, and is plug-and-play with little additional computational cost. 7 | 8 | ## Integrated into MQBench 9 | Our method has been integrated into the quantization benchmark [MQBench](https://github.com/ModelTC/MQBench), and the corresponding docs can be found there. 10 | 11 | **Moreover, following the design style of the quantization structure in MQBench, we also implement another form of QDrop in the "qdrop" branch. You can use whichever branch you like; details are given in the README of the "qdrop" branch.** 12 | 13 | 14 | ## Usage 15 | 16 | Go into the exp directory, which contains run.sh and config.sh. run.sh is an example for resnet18 at W2A2. You can run config.sh to produce similar scripts across bit-widths and architectures.
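The core idea can be sketched in a few lines of PyTorch. The snippet below is an illustration only (the tensor and function names are made up for the example; the actual logic lives in the `quant/` modules, e.g. the input mixing in `layer_recon.py`): during reconstruction, each activation element keeps its quantized value with probability `prob` and otherwise falls back to the full-precision value.

```
# Illustration only -- not the repository's actual API.
import torch

def random_drop_quant(x_fp, x_q, prob=0.5):
    """Keep the quantized activation with probability `prob`,
    otherwise fall back to the full-precision activation."""
    keep_q = torch.rand_like(x_fp) < prob
    return torch.where(keep_q, x_q, x_fp)
```

Setting the probability to 1.0 keeps every activation quantized and therefore recovers ordinary reconstruction without dropping, which is exactly the "No Drop" baseline discussed further below. The scripts that drive the full pipeline are listed next.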
17 | 18 | run.sh 19 | ``` 20 | #!/bin/bash 21 | PYTHONPATH=../../../../:$PYTHONPATH \ 22 | python ../../../main_imagenet.py --data_path data_path \ 23 | --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 2 --act_quant --order together --wwq --waq --awq --aaq \ 24 | --weight 0.01 --input_prob 0.5 --prob 0.5 25 | ``` 26 | 27 | config.sh 28 | 29 | ``` 30 | #!/bin/bash 31 | # pretrain models and hyperparameters following BRECQ 32 | arch=('resnet18' 'resnet50' 'mobilenetv2' 'regnetx_600m' 'regnetx_3200m' 'mnasnet') 33 | weight=(0.01 0.01 0.1 0.01 0.01 0.2) 34 | w_bit=(3 2 2 4) 35 | a_bit=(3 4 2 4) 36 | for((i=0;i<6;i++)) 37 | do 38 | for((j=0;j<4;j++)) 39 | do 40 | path=w${w_bit[j]}a${a_bit[j]}/${arch[i]} 41 | mkdir -p $path 42 | echo $path 43 | cp run.sh $path/run.sh 44 | sed -re "s/weight([[:space:]]+)0.01/weight ${weight[i]}/" -i $path/run.sh 45 | sed -re "s/resnet18/${arch[i]}/" -i $path/run.sh 46 | sed -re "s/n_bits_w([[:space:]]+)2/n_bits_w ${w_bit[j]}/" -i $path/run.sh 47 | sed -re "s/n_bits_a([[:space:]]+)2/n_bits_a ${a_bit[j]}/" -i $path/run.sh 48 | done 49 | done 50 | ``` 51 | This produces a series of scripts that you can run directly to reproduce the following results. 52 | ## Results 53 | 54 | Top-1 accuracy (%) on ImageNet with low-bit weights and activations. 55 | 56 | | Methods | Bits (W/A) | Res18 | Res50 | MNV2 | Reg600M | Reg3.2G | MNasx2 | 57 | | ------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 58 | | Full Prec. | 32/32 | 71.06 | 77.00 | 72.49 | 73.71 | 78.36 | 76.68 | 59 | |QDrop| 4/4 | 69.07 | 75.03 | 67.91 | 70.81 | 76.36 | 72.81 | 60 | |QDrop| 2/4 | 64.49 | 70.09 | 53.62 | 63.36 | 71.69 | 63.22 | 61 | |QDrop| 3/3 | 65.57 | 71.28 | 55.00 | 64.84 | 71.70 | 64.44 | 62 | |QDrop| 2/2 | 51.76 | 55.36 | 10.21 | 38.35 | 54.00 | 23.62 | 63 | 64 | 65 | 66 | ## More experiments 67 | 68 | **Case 1, Case 2, Case 3** 69 | 70 | To compare the results of the three cases mentioned in the observation part of the method, use the following commands. 71 | 72 | Case 1: weight tuning sees no activation quantization 73 | 74 | Case 2: weight tuning sees full activation quantization 75 | 76 | Case 3: weight tuning sees partial activation quantization 77 | 78 | ``` 79 | # Case 1 80 | python main_imagenet.py --data_path data_path --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 2 --act_quant --order after --wwq --awq --aaq --input_prob 1.0 --prob 1.0 81 | # Case 2 82 | python main_imagenet.py --data_path data_path --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 2 --act_quant --order before --wwq --waq --aaq --input_prob 1.0 --prob 1.0 83 | # Case 3 84 | python main_imagenet.py --data_path data_path --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 2 --act_quant --order after --wwq --waq --awq --aaq --input_prob 1.0 --prob 1.0 85 | ``` 86 | 87 | **No Drop** 88 | 89 | To compare with QDrop, the No Drop baseline can be obtained by setting the probability to 1.0, which disables dropping quantization during weight tuning.
90 | 91 | ``` 92 | python main_imagenet.py --data_path data_path --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 2 --act_quant --order together --wwq --waq --awq --aaq --input_prob 1.0 --prob 1.0 93 | ``` 94 | 95 | ## Reference 96 | 97 | If you find this repo useful for your research, please consider citing the paper: 98 | 99 | @article{wei2022qdrop, 100 | title={QDrop: Randomly Dropping Quantization for Extremely Low-bit Post-Training Quantization}, 101 | author={Wei, Xiuying and Gong, Ruihao and Li, Yuhang and Liu, Xianglong and Yu, Fengwei}, 102 | journal={arXiv preprint arXiv:2203.05740}, 103 | year={2022} 104 | } -------------------------------------------------------------------------------- /PTQ4DM/QDrop/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | 6 | 7 | def build_imagenet_data(data_path: str = '', input_size: int = 224, batch_size: int = 64, workers: int = 4): 8 | print('==> Using Pytorch Dataset') 9 | 10 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 11 | std=[0.229, 0.224, 0.225]) 12 | # traindir = '/mnt/lustre/share/ImageNet-Pytorch/train' 13 | # valdir = '/mnt/lustre/share/ImageNet-Pytorch/val' 14 | traindir = os.path.join(data_path, 'train') 15 | valdir = os.path.join(data_path, 'val') 16 | # torchvision.set_image_backend('accimage') 17 | train_dataset = datasets.ImageFolder( 18 | traindir, 19 | transforms.Compose([ 20 | transforms.RandomResizedCrop(input_size), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | normalize, 24 | ])) 25 | 26 | train_loader = torch.utils.data.DataLoader( 27 | train_dataset, batch_size=batch_size, shuffle=True, 28 | num_workers=workers, pin_memory=True) 29 | val_loader = torch.utils.data.DataLoader( 30 | datasets.ImageFolder(valdir, transforms.Compose([ 31 | transforms.Resize(256), 32 | transforms.CenterCrop(input_size), 33 | transforms.ToTensor(), 34 | normalize, 35 | ])), 36 | batch_size=batch_size, shuffle=False, 37 | num_workers=workers, pin_memory=True) 38 | return train_loader, val_loader 39 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/exp/config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | arch=('resnet18' 'resnet50' 'mobilenetv2' 'regnetx_600m' 'regnetx_3200m' 'mnasnet') 3 | weight=(0.01 0.01 0.1 0.01 0.01 0.2) 4 | w_bit=(3 2 2 4) 5 | a_bit=(3 4 2 4) 6 | for((i=0;i<6;i++)) 7 | do 8 | for((j=0;j<4;j++)) 9 | do 10 | path=w${w_bit[j]}a${a_bit[j]}/${arch[i]} 11 | mkdir -p $path 12 | echo $path 13 | cp run.sh $path/run.sh 14 | sed -re "s/weight([[:space:]]+)0.01/weight ${weight[i]}/" -i $path/run.sh 15 | sed -re "s/resnet18/${arch[i]}/" -i $path/run.sh 16 | sed -re "s/n_bits_w([[:space:]]+)2/n_bits_w ${w_bit[j]}/" -i $path/run.sh 17 | sed -re "s/n_bits_a([[:space:]]+)2/n_bits_a ${a_bit[j]}/" -i $path/run.sh 18 | # tmux kill-session -t ${arch[i]}_w${w_bit[j]}a${a_bit[j]} 19 | # cd $path 20 | # tmux new -s ${arch[i]}_w${w_bit[j]}a${a_bit[j]} -d ./run.sh 21 | # cd ../.. 
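# Note (added comment, describing the loop body above): the sed substitutions rewrite the
# copied run.sh in place, swapping in the architecture-specific --weight value and arch name
# (index i) and the per-config --n_bits_w / --n_bits_a values (index j). Uncommenting the
# tmux lines would additionally launch each generated script in its own detached tmux session.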
22 | done 23 | done 24 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/exp/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PYTHONPATH=../../../../:$PYTHONPATH \ 3 | python ../../../main_imagenet.py --data_path data_path \ 4 | --arch resnet18 --n_bits_w 2 --channel_wise --n_bits_a 2 --act_quant --order together --wwq --waq --awq --aaq \ 5 | --weight 0.01 --input_prob 0.5 --prob 0.5 -------------------------------------------------------------------------------- /PTQ4DM/QDrop/hubconf.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from QDrop.models.resnet import resnet18 as _resnet18 3 | from QDrop.models.resnet import resnet50 as _resnet50 4 | from QDrop.models.mobilenetv2 import mobilenetv2 as _mobilenetv2 5 | from QDrop.models.mnasnet import mnasnet as _mnasnet 6 | from QDrop.models.regnet import regnetx_600m as _regnetx_600m 7 | from QDrop.models.regnet import regnetx_3200m as _regnetx_3200m 8 | import torch 9 | dependencies = ['torch'] 10 | prefix = '/mnt/lustre/weixiuying' 11 | model_path = { 12 | 'resnet18': prefix+'/model_zoo/resnet18_imagenet.pth.tar', 13 | 'resnet50': prefix+'/model_zoo/resnet50_imagenet.pth.tar', 14 | 'mbv2': prefix+'/model_zoo/mobilenetv2.pth.tar', 15 | 'reg600m': prefix+'/model_zoo/regnet_600m.pth.tar', 16 | 'reg3200m': prefix+'/model_zoo/regnet_3200m.pth.tar', 17 | 'mnasnet': prefix+'/model_zoo/mnasnet.pth.tar', 18 | 'spring_resnet50': prefix+'/model_zoo/spring_resnet50.pth', 19 | } 20 | 21 | 22 | def resnet18(pretrained=False, **kwargs): 23 | # Call the model, load pretrained weights 24 | model = _resnet18(**kwargs) 25 | if pretrained: 26 | checkpoint = torch.load(model_path['resnet18'], map_location='cpu') 27 | model.load_state_dict(checkpoint) 28 | return model 29 | 30 | 31 | def resnet50(pretrained=False, **kwargs): 32 | # Call the model, load pretrained weights 33 | model = _resnet50(**kwargs) 34 | if pretrained: 35 | checkpoint = torch.load(model_path['resnet50'], map_location='cpu') 36 | model.load_state_dict(checkpoint) 37 | return model 38 | 39 | 40 | def spring_resnet50(pretrained=False, **kwargs): 41 | # Call the model, load pretrained weights 42 | model = _resnet50(**kwargs) 43 | if pretrained: 44 | checkpoint = torch.load(model_path['spring_resnet50'], map_location='cpu') 45 | q = OrderedDict() 46 | for k, v in checkpoint.items(): 47 | q[k[7:]] = v 48 | model.load_state_dict(q) 49 | return model 50 | 51 | 52 | def mobilenetv2(pretrained=False, **kwargs): 53 | # Call the model, load pretrained weights 54 | model = _mobilenetv2(**kwargs) 55 | if pretrained: 56 | checkpoint = torch.load(model_path['mbv2'], map_location='cpu') 57 | model.load_state_dict(checkpoint['model']) 58 | return model 59 | 60 | 61 | def regnetx_600m(pretrained=False, **kwargs): 62 | # Call the model, load pretrained weights 63 | model = _regnetx_600m(**kwargs) 64 | if pretrained: 65 | checkpoint = torch.load(model_path['reg600m'], map_location='cpu') 66 | model.load_state_dict(checkpoint) 67 | return model 68 | 69 | 70 | def regnetx_3200m(pretrained=False, **kwargs): 71 | # Call the model, load pretrained weights 72 | model = _regnetx_3200m(**kwargs) 73 | if pretrained: 74 | checkpoint = torch.load(model_path['reg3200m'], map_location='cpu') 75 | model.load_state_dict(checkpoint) 76 | return model 77 | 78 | 79 | def mnasnet(pretrained=False, **kwargs): 80 | # Call the model, load pretrained 
weights 81 | model = _mnasnet(**kwargs) 82 | if pretrained: 83 | checkpoint = torch.load(model_path['mnasnet'], map_location='cpu') 84 | model.load_state_dict(checkpoint) 85 | return model 86 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/42Shawn/PTQ4DM/180a4d15d400316e2971f54d10b96c53f8673455/PTQ4DM/QDrop/models/__init__.py -------------------------------------------------------------------------------- /PTQ4DM/QDrop/models/mnasnet.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = ['mnasnet'] 7 | 8 | 9 | class _InvertedResidual(nn.Module): 10 | 11 | def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor): 12 | super(_InvertedResidual, self).__init__() 13 | assert stride in [1, 2] 14 | assert kernel_size in [3, 5] 15 | mid_ch = in_ch * expansion_factor 16 | self.apply_residual = (in_ch == out_ch and stride == 1) 17 | self.layers = nn.Sequential( 18 | # Pointwise 19 | nn.Conv2d(in_ch, mid_ch, 1, bias=False), 20 | BN(mid_ch), 21 | nn.ReLU(inplace=True), 22 | # Depthwise 23 | nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, 24 | stride=stride, groups=mid_ch, bias=False), 25 | BN(mid_ch), 26 | nn.ReLU(inplace=True), 27 | # Linear pointwise. Note that there's no activation. 28 | nn.Conv2d(mid_ch, out_ch, 1, bias=False), 29 | BN(out_ch)) 30 | 31 | def forward(self, input): 32 | if self.apply_residual: 33 | return self.layers(input) + input 34 | else: 35 | return self.layers(input) 36 | 37 | 38 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats): 39 | """ Creates a stack of inverted residuals. """ 40 | assert repeats >= 1 41 | # First one has no skip, because feature map size changes. 42 | first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor) 43 | remaining = [] 44 | for _ in range(1, repeats): 45 | remaining.append( 46 | _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor)) 47 | return nn.Sequential(first, *remaining) 48 | 49 | 50 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9): 51 | """ Asymmetric rounding to make `val` divisible by `divisor`. With default 52 | bias, will round up, unless the number is no more than 10% greater than the 53 | smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ 54 | assert 0.0 < round_up_bias < 1.0 55 | new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) 56 | return new_val if new_val >= round_up_bias * val else new_val + divisor 57 | 58 | 59 | def _get_depths(scale): 60 | """ Scales tensor depths as in reference MobileNet code, prefers rouding up 61 | rather than down. """ 62 | depths = [32, 16, 24, 40, 80, 96, 192, 320] 63 | return [_round_to_multiple_of(depth * scale, 8) for depth in depths] 64 | 65 | 66 | class MNASNet(torch.nn.Module): 67 | # Version 2 adds depth scaling in the initial stages of the network. 68 | _version = 2 69 | 70 | def __init__(self, scale=2.0, num_classes=1000, dropout=0.0): 71 | super(MNASNet, self).__init__() 72 | 73 | global BN 74 | BN = nn.BatchNorm2d 75 | 76 | assert scale > 0.0 77 | self.scale = scale 78 | self.num_classes = num_classes 79 | depths = _get_depths(scale) 80 | layers = [ 81 | # First layer: regular conv. 
82 | nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False), 83 | BN(depths[0]), 84 | nn.ReLU(inplace=True), 85 | # Depthwise separable, no skip. 86 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, 87 | groups=depths[0], bias=False), 88 | BN(depths[0]), 89 | nn.ReLU(inplace=True), 90 | nn.Conv2d(depths[0], depths[1], 1, 91 | padding=0, stride=1, bias=False), 92 | BN(depths[1]), 93 | # MNASNet blocks: stacks of inverted residuals. 94 | _stack(depths[1], depths[2], 3, 2, 3, 3), 95 | _stack(depths[2], depths[3], 5, 2, 3, 3), 96 | _stack(depths[3], depths[4], 5, 2, 6, 3), 97 | _stack(depths[4], depths[5], 3, 1, 6, 2), 98 | _stack(depths[5], depths[6], 5, 2, 6, 4), 99 | _stack(depths[6], depths[7], 3, 1, 6, 1), 100 | # Final mapping to classifier input. 101 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False), 102 | BN(1280), 103 | nn.ReLU(inplace=True), 104 | ] 105 | self.layers = nn.Sequential(*layers) 106 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), 107 | nn.Linear(1280, num_classes)) 108 | self._initialize_weights() 109 | 110 | def forward(self, x): 111 | x = self.layers(x) 112 | # Equivalent to global avgpool and removing H and W dimensions. 113 | x = x.mean([2, 3]) 114 | return self.classifier(x) 115 | 116 | def _initialize_weights(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 120 | nonlinearity="relu") 121 | if m.bias is not None: 122 | nn.init.zeros_(m.bias) 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.ones_(m.weight) 125 | nn.init.zeros_(m.bias) 126 | elif isinstance(m, nn.Linear): 127 | nn.init.kaiming_uniform_(m.weight, mode="fan_out", 128 | nonlinearity="sigmoid") 129 | nn.init.zeros_(m.bias) 130 | 131 | 132 | def mnasnet(**kwargs): 133 | model = MNASNet(**kwargs) 134 | return model 135 | 136 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | # !/usr/bin/env python 2 | 3 | import torch.nn as nn 4 | import math 5 | import torch 6 | 7 | 8 | def conv_bn(inp, oup, stride): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | nn.BatchNorm2d(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def conv_1x1_bn(inp, oup): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU6(inplace=True) 21 | ) 22 | 23 | 24 | class InvertedResidual(nn.Module): 25 | def __init__(self, inp, oup, stride, expand_ratio): 26 | super(InvertedResidual, self).__init__() 27 | self.stride = stride 28 | assert stride in [1, 2] 29 | 30 | hidden_dim = round(inp * expand_ratio) 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | self.expand_ratio = expand_ratio 33 | if expand_ratio == 1: 34 | self.conv = nn.Sequential( 35 | # dw 36 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 37 | nn.BatchNorm2d(hidden_dim), 38 | nn.ReLU6(inplace=True), 39 | # pw-linear 40 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 41 | nn.BatchNorm2d(oup), 42 | ) 43 | else: 44 | self.conv = nn.Sequential( 45 | # pw 46 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(hidden_dim), 48 | nn.ReLU6(inplace=True), 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 51 | nn.BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | 
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 55 | nn.BatchNorm2d(oup), 56 | ) 57 | 58 | def forward(self, x): 59 | if self.use_res_connect: 60 | return x + self.conv(x) 61 | else: 62 | return self.conv(x) 63 | 64 | 65 | class MobileNetV2(nn.Module): 66 | def __init__(self, n_class=1000, input_size=224, width_mult=1., dropout=0.0): 67 | super(MobileNetV2, self).__init__() 68 | block = InvertedResidual 69 | input_channel = 32 70 | last_channel = 1280 71 | interverted_residual_setting = [ 72 | # t, c, n, s 73 | [1, 16, 1, 1], 74 | [6, 24, 2, 2], 75 | [6, 32, 3, 2], 76 | [6, 64, 4, 2], 77 | [6, 96, 3, 1], 78 | [6, 160, 3, 2], 79 | [6, 320, 1, 1], 80 | ] 81 | 82 | # building first layer 83 | assert input_size % 32 == 0 84 | input_channel = int(input_channel * width_mult) 85 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 86 | self.features = [conv_bn(3, input_channel, 2)] 87 | # building inverted residual blocks 88 | for t, c, n, s in interverted_residual_setting: 89 | output_channel = int(c * width_mult) 90 | for i in range(n): 91 | if i == 0: 92 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 93 | else: 94 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 95 | input_channel = output_channel 96 | # building last several layers 97 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 98 | # self.features.append(nn.AvgPool2d(input_size // 32)) 99 | # make it nn.Sequential 100 | self.features = nn.Sequential(*self.features) 101 | 102 | # building classifier 103 | self.classifier = nn.Sequential( 104 | nn.Dropout(dropout), 105 | nn.Linear(self.last_channel, n_class), 106 | ) 107 | 108 | self._initialize_weights() 109 | 110 | def forward(self, x): 111 | x = self.features(x) 112 | x = x.mean([2, 3]) 113 | x = self.classifier(x) 114 | return x 115 | 116 | def _initialize_weights(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. / n)) 121 | if m.bias is not None: 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | n = m.weight.size(1) 128 | m.weight.data.normal_(0, 0.01) 129 | m.bias.data.zero_() 130 | 131 | 132 | def mobilenetv2(**kwargs): 133 | """ 134 | Constructs a MobileNetV2 model. 
135 | """ 136 | model = MobileNetV2(**kwargs) 137 | return model -------------------------------------------------------------------------------- /PTQ4DM/QDrop/models/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/42Shawn/PTQ4DM/180a4d15d400316e2971f54d10b96c53f8673455/PTQ4DM/QDrop/models/utils.py -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/__init__.py: -------------------------------------------------------------------------------- 1 | from .block_recon import block_reconstruction 2 | from .layer_recon import layer_reconstruction 3 | from .quant_block import BaseQuantBlock 4 | from .quant_layer import QuantModule 5 | from .quant_model import QuantModel 6 | from .set_weight_quantize_params import set_weight_quantize_params, get_init, save_quantized_weight 7 | from .set_act_quantize_params import set_act_quantize_params 8 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/adaptive_rounding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .quant_layer import UniformAffineQuantizer, round_ste 4 | 5 | 6 | class AdaRoundQuantizer(nn.Module): 7 | """ 8 | Adaptive Rounding Quantizer, used to optimize the rounding policy 9 | by reconstructing the intermediate output. 10 | Based on 11 | Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568 12 | 13 | :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer 14 | :param round_mode: controls the forward pass in this quantizer 15 | :param weight_tensor: initialize alpha 16 | """ 17 | 18 | def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_round_sigmoid'): 19 | super(AdaRoundQuantizer, self).__init__() 20 | # copying all attributes from UniformAffineQuantizer 21 | self.n_bits = uaq.n_bits 22 | self.sym = uaq.sym 23 | self.delta = uaq.delta 24 | self.zero_point = uaq.zero_point 25 | self.n_levels = uaq.n_levels 26 | 27 | self.round_mode = round_mode 28 | self.alpha = None 29 | self.soft_targets = False 30 | 31 | # params for sigmoid function 32 | self.gamma, self.zeta = -0.1, 1.1 33 | self.beta = 2/3 34 | self.init_alpha(x=weight_tensor.clone()) 35 | 36 | def forward(self, x): 37 | if self.round_mode == 'nearest': 38 | x_int = torch.round(x / self.delta) 39 | elif self.round_mode == 'nearest_ste': 40 | x_int = round_ste(x / self.delta) 41 | elif self.round_mode == 'stochastic': 42 | x_floor = torch.floor(x / self.delta) 43 | rest = (x / self.delta) - x_floor # rest of rounding 44 | x_int = x_floor + torch.bernoulli(rest) 45 | print('Draw stochastic sample') 46 | elif self.round_mode == 'learned_hard_sigmoid': 47 | x_floor = torch.floor(x / self.delta) 48 | if self.soft_targets: 49 | x_int = x_floor + self.get_soft_targets() 50 | else: 51 | x_int = x_floor + (self.alpha >= 0).float() 52 | else: 53 | raise ValueError('Wrong rounding mode') 54 | 55 | x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1) 56 | x_float_q = (x_quant - self.zero_point) * self.delta 57 | 58 | return x_float_q 59 | 60 | def get_soft_targets(self): 61 | return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1) 62 | 63 | def init_alpha(self, x: torch.Tensor): 64 | x_floor = torch.floor(x / self.delta) 65 | if self.round_mode == 
'learned_hard_sigmoid': 66 | print('Init alpha to be FP32') 67 | rest = (x / self.delta) - x_floor # rest of rounding [0, 1) 68 | alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest 69 | self.alpha = nn.Parameter(alpha) 70 | else: 71 | raise NotImplementedError 72 | 73 | @torch.jit.export 74 | def extra_repr(self): 75 | return 'bit={}'.format(self.n_bits) 76 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .quant_layer import QuantModule, Union 4 | from .quant_model import QuantModel 5 | from .quant_block import BaseQuantBlock 6 | 7 | 8 | def save_inp_oup_data( 9 | model: QuantModel, 10 | layer: Union[QuantModule, BaseQuantBlock], 11 | cali_data: torch.Tensor, 12 | wq: bool = False, 13 | aq: bool = False, 14 | batch_size: int = 32, 15 | keep_gpu: bool = True, 16 | input_prob: bool = False, 17 | ): 18 | """ 19 | Save input data and output data of a particular layer/block over calibration dataset. 20 | 21 | :param model: QuantModel 22 | :param layer: QuantModule or QuantBlock 23 | :param cali_data: calibration data set 24 | :param weight_quant: use weight_quant quantization 25 | :param act_quant: use act_quant quantization 26 | :param batch_size: mini-batch size for calibration 27 | :param keep_gpu: put saved data on GPU for faster optimization 28 | :return: input and output data 29 | """ 30 | device = next(model.parameters()).device 31 | get_inp_out = GetLayerInpOut( 32 | model, layer, device=device, wq=wq, aq=aq, input_prob=input_prob 33 | ) 34 | cached_batches = [] 35 | 36 | for i in range(int(cali_data[0].size(0) / batch_size)): 37 | if input_prob: 38 | cur_inp, cur_out, cur_sym = get_inp_out( 39 | [_[i * batch_size : (i + 1) * batch_size] for _ in cali_data] 40 | ) 41 | cached_batches.append((cur_inp.cpu(), cur_out.cpu(), cur_sym.cpu())) 42 | else: 43 | cur_inp, cur_out = get_inp_out( 44 | [_[i * batch_size : (i + 1) * batch_size] for _ in cali_data] 45 | ) 46 | cached_batches.append((cur_inp.cpu(), cur_out.cpu())) 47 | cached_inps = torch.cat([x[0] for x in cached_batches]) 48 | cached_outs = torch.cat([x[1] for x in cached_batches]) 49 | if input_prob: 50 | cached_sym = torch.cat([x[2] for x in cached_batches]) 51 | torch.cuda.empty_cache() 52 | if keep_gpu: 53 | cached_inps = cached_inps.to(device) 54 | cached_outs = cached_outs.to(device) 55 | if input_prob: 56 | cached_sym = cached_sym.to(device) 57 | if input_prob: 58 | return (cached_inps, cached_sym), cached_outs 59 | return (cached_inps,), cached_outs 60 | 61 | 62 | class StopForwardException(Exception): 63 | """ 64 | Used to throw and catch an exception to stop traversing the graph 65 | """ 66 | 67 | pass 68 | 69 | 70 | class DataSaverHook: 71 | """ 72 | Forward hook that stores the input and output of a block 73 | """ 74 | 75 | def __init__(self, store_input=False, store_output=False, stop_forward=False): 76 | self.store_input = store_input 77 | self.store_output = store_output 78 | self.stop_forward = stop_forward 79 | 80 | self.input_store = None 81 | self.output_store = None 82 | 83 | def __call__(self, module, input_batch, output_batch): 84 | if self.store_input: 85 | self.input_store = input_batch 86 | if self.store_output: 87 | self.output_store = output_batch 88 | if self.stop_forward: 89 | raise StopForwardException 90 | 91 | 92 | class GetLayerInpOut: 93 | def 
__init__( 94 | self, 95 | model: QuantModel, 96 | layer: Union[QuantModule, BaseQuantBlock], 97 | device: torch.device, 98 | wq: bool = False, 99 | aq: bool = False, 100 | input_prob: bool = False, 101 | ): 102 | self.model = model 103 | self.layer = layer 104 | self.device = device 105 | self.wq = wq 106 | self.aq = aq 107 | self.data_saver = DataSaverHook( 108 | store_input=True, store_output=True, stop_forward=True 109 | ) 110 | self.input_prob = input_prob 111 | 112 | def __call__(self, model_input): 113 | self.model.set_quant_state(False, False) 114 | 115 | handle = self.layer.register_forward_hook(self.data_saver) 116 | with torch.no_grad(): 117 | try: 118 | _ = self.model(*[_.to(self.device) for _ in model_input]) 119 | except StopForwardException: 120 | pass 121 | if self.input_prob: 122 | input_sym = self.data_saver.input_store[0].detach() 123 | if self.wq or self.aq: 124 | # Recalculate input with network quantized 125 | self.data_saver.store_output = False 126 | self.model.set_quant_state(weight_quant=self.wq, act_quant=self.aq) 127 | try: 128 | _ = self.model(*[_.to(self.device) for _ in model_input]) 129 | except StopForwardException: 130 | pass 131 | 132 | self.data_saver.store_output = True 133 | handle.remove() 134 | 135 | if self.input_prob: 136 | return ( 137 | self.data_saver.input_store[0].detach(), 138 | self.data_saver.output_store.detach(), 139 | input_sym, 140 | ) 141 | return ( 142 | self.data_saver.input_store[0].detach(), 143 | self.data_saver.output_store.detach(), 144 | ) 145 | 146 | 147 | class GradSaverHook: 148 | def __init__(self, store_grad=True): 149 | self.store_grad = store_grad 150 | self.stop_backward = False 151 | self.grad_out = None 152 | 153 | def __call__(self, module, grad_input, grad_output): 154 | if self.store_grad: 155 | self.grad_out = grad_output[0] 156 | if self.stop_backward: 157 | raise StopForwardException 158 | 159 | 160 | class GetLayerGrad: 161 | def __init__( 162 | self, 163 | model: QuantModel, 164 | layer: Union[QuantModule, BaseQuantBlock], 165 | device: torch.device, 166 | act_quant: bool = False, 167 | ): 168 | self.model = model 169 | self.layer = layer 170 | self.device = device 171 | self.act_quant = act_quant 172 | self.data_saver = GradSaverHook(True) 173 | 174 | def __call__(self, model_input): 175 | """ 176 | Compute the gradients of block output, note that we compute the 177 | gradient by calculating the KL loss between fp model and quant model 178 | 179 | :param model_input: calibration data samples 180 | :return: gradients 181 | """ 182 | self.model.eval() 183 | 184 | handle = self.layer.register_backward_hook(self.data_saver) 185 | with torch.enable_grad(): 186 | try: 187 | self.model.zero_grad() 188 | inputs = model_input.to(self.device) 189 | self.model.set_quant_state(False, False) 190 | out_fp = self.model(inputs) 191 | quantize_model_till(self.model, self.layer, self.act_quant) 192 | out_q = self.model(inputs) 193 | loss = F.kl_div( 194 | F.log_softmax(out_q, dim=1), 195 | F.softmax(out_fp, dim=1), 196 | reduction="batchmean", 197 | ) 198 | loss.backward() 199 | except StopForwardException: 200 | pass 201 | 202 | handle.remove() 203 | self.model.set_quant_state(False, False) 204 | self.layer.set_quant_state(True, self.act_quant) 205 | self.model.train() 206 | return self.data_saver.grad_out.data 207 | 208 | 209 | def quantize_model_till( 210 | model: QuantModule, 211 | layer: Union[QuantModule, BaseQuantBlock], 212 | act_quant: bool = False, 213 | ): 214 | """ 215 | We assumes modules are correctly ordered, 
holds for all models considered 216 | :param model: quantized_model 217 | :param layer: a block or a single layer. 218 | """ 219 | model.set_quant_state(False, False) 220 | for name, module in model.named_modules(): 221 | if isinstance(module, (QuantModule, BaseQuantBlock)): 222 | module.set_quant_state(True, act_quant) 223 | if module == layer: 224 | break 225 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/fold_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | 5 | 6 | class StraightThrough(nn.Module): 7 | def __int__(self): 8 | super().__init__() 9 | 10 | def forward(self, input): 11 | return input 12 | 13 | 14 | def _fold_bn(conv_module, bn_module): 15 | w = conv_module.weight.data 16 | y_mean = bn_module.running_mean 17 | y_var = bn_module.running_var 18 | safe_std = torch.sqrt(y_var + bn_module.eps) 19 | w_view = (conv_module.out_channels, 1, 1, 1) 20 | if bn_module.affine: 21 | weight = w * (bn_module.weight / safe_std).view(w_view) 22 | beta = bn_module.bias - bn_module.weight * y_mean / safe_std 23 | if conv_module.bias is not None: 24 | bias = bn_module.weight * conv_module.bias / safe_std + beta 25 | else: 26 | bias = beta 27 | else: 28 | weight = w / safe_std.view(w_view) 29 | beta = -y_mean / safe_std 30 | if conv_module.bias is not None: 31 | bias = conv_module.bias / safe_std + beta 32 | else: 33 | bias = beta 34 | return weight, bias 35 | 36 | 37 | def fold_bn_into_conv(conv_module, bn_module): 38 | w, b = _fold_bn(conv_module, bn_module) 39 | if conv_module.bias is None: 40 | conv_module.bias = nn.Parameter(b) 41 | else: 42 | conv_module.bias.data = b 43 | conv_module.weight.data = w 44 | # set bn running stats 45 | bn_module.running_mean = bn_module.bias.data 46 | bn_module.running_var = bn_module.weight.data ** 2 47 | 48 | 49 | def reset_bn(module: nn.BatchNorm2d): 50 | if module.track_running_stats: 51 | module.running_mean.zero_() 52 | module.running_var.fill_(1-module.eps) 53 | # we do not reset numer of tracked batches here 54 | # self.num_batches_tracked.zero_() 55 | if module.affine: 56 | init.ones_(module.weight) 57 | init.zeros_(module.bias) 58 | 59 | 60 | def is_bn(m): 61 | return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d) 62 | 63 | 64 | def is_absorbing(m): 65 | return (isinstance(m, nn.Conv2d)) or isinstance(m, nn.Linear) 66 | 67 | 68 | def search_fold_and_remove_bn(model): 69 | model.eval() 70 | prev = None 71 | for n, m in model.named_children(): 72 | if is_bn(m) and is_absorbing(prev): 73 | fold_bn_into_conv(prev, m) 74 | # set the bn module to straight through 75 | setattr(model, n, StraightThrough()) 76 | elif is_absorbing(m): 77 | prev = m 78 | else: 79 | prev = search_fold_and_remove_bn(m) 80 | return prev 81 | 82 | 83 | def search_fold_and_reset_bn(model): 84 | model.eval() 85 | prev = None 86 | for n, m in model.named_children(): 87 | if is_bn(m) and is_absorbing(prev): 88 | fold_bn_into_conv(prev, m) 89 | # reset_bn(m) 90 | else: 91 | search_fold_and_reset_bn(m) 92 | prev = m 93 | 94 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/layer_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .quant_layer import QuantModule, lp_loss 3 | from .quant_model import QuantModel 4 | from .block_recon import LinearTempDecay 5 | from .adaptive_rounding 
import AdaRoundQuantizer 6 | from .set_weight_quantize_params import weight_get_quant_state, get_init 7 | from .set_act_quantize_params import set_act_quantize_params 8 | 9 | 10 | def layer_reconstruction(model: QuantModel, layer: QuantModule, cali_data: torch.Tensor, 11 | batch_size: int = 32, iters: int = 20000, weight: float = 0.001, opt_mode: str = 'mse', 12 | act_quant: bool = False, b_range: tuple = (20, 2), 13 | warmup: float = 0.0, p: float = 2.0, lr: float = 4e-5, wwq: bool = True, waq: bool = True, 14 | order: str = 'together', input_prob: float = 1.0, keep_gpu: bool = True): 15 | """ 16 | Block reconstruction to optimize the output from each layer. 17 | 18 | :param model: QuantModel 19 | :param layer: QuantModule that needs to be optimized 20 | :param cali_data: data for calibration, typically 1024 training images, as described in AdaRound 21 | :param batch_size: mini-batch size for reconstruction 22 | :param iters: optimization iterations for reconstruction, 23 | :param weight: the weight of rounding regularization term 24 | :param opt_mode: optimization mode 25 | :param asym: asymmetric optimization designed in AdaRound, use quant input to reconstruct fp output 26 | :param include_act_func: optimize the output after activation function 27 | :param b_range: temperature range 28 | :param warmup: proportion of iterations that no scheduling for temperature 29 | :param act_quant: use activation quantization or not. 30 | :param lr: learning rate for act delta learning 31 | :param p: L_p norm minimization 32 | """ 33 | 34 | '''get input and set scale''' 35 | cached_inps, cached_outs = get_init(model, layer, cali_data, wq=wwq, aq=waq, batch_size=batch_size, 36 | input_prob=True, keep_gpu=keep_gpu) 37 | if act_quant and order == 'together': 38 | set_act_quantize_params(layer, cali_data=cached_inps[0][:min(256, cached_inps[0].size(0))], awq=True, order=order) 39 | 40 | '''set state''' 41 | cur_weight, cur_act = weight_get_quant_state(order, act_quant) 42 | layer.set_quant_state(cur_weight, cur_act) 43 | 44 | '''set quantizer''' 45 | round_mode = 'learned_hard_sigmoid' 46 | # Replace weight quantizer to AdaRoundQuantizer 47 | w_para, a_para = [], [] 48 | w_opt, a_opt = None, None 49 | scheduler, a_scheduler = None, None 50 | '''weight''' 51 | layer.weight_quantizer = AdaRoundQuantizer(uaq=layer.weight_quantizer, round_mode=round_mode, 52 | weight_tensor=layer.org_weight.data) 53 | layer.weight_quantizer.soft_targets = True 54 | w_para += [layer.weight_quantizer.alpha] 55 | 56 | '''activation''' 57 | if act_quant and order == 'together' and layer.act_quantizer.delta is not None: 58 | layer.act_quantizer.delta = torch.nn.Parameter(torch.tensor(layer.act_quantizer.delta)) 59 | a_para += [layer.act_quantizer.delta] 60 | layer.act_quantizer.is_training = True 61 | 62 | if len(w_para) != 0: 63 | w_opt = torch.optim.Adam(w_para) 64 | if len(a_para) != 0: 65 | a_opt = torch.optim.Adam(a_para, lr=lr) 66 | a_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(a_opt, T_max=iters, eta_min=0.) 
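    # --- Reconstruction loop (descriptive summary of the code below) ---
    # Each iteration draws a random mini-batch of cached calibration inputs and, when
    # input_prob < 1.0, mixes the quantized input with the cached full-precision input
    # element-wise via torch.where -- this is QDrop's random dropping of quantization.
    # The LossFunction combines the lp reconstruction loss against the FP output with the
    # AdaRound rounding regularizer (annealed by LinearTempDecay), and Adam updates the
    # rounding variables alpha (and, if enabled, the activation step size delta, whose
    # learning rate follows the cosine schedule created above).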
67 | loss_mode = 'relaxation' 68 | rec_loss = opt_mode 69 | loss_func = LossFunction(layer, round_loss=loss_mode, weight=weight, 70 | max_count=iters, rec_loss=rec_loss, b_range=b_range, 71 | decay_start=0, warmup=warmup, p=p) 72 | device = 'cuda' 73 | sz = cached_inps[0].size(0) 74 | for i in range(iters): 75 | idx = torch.randint(0, sz, (batch_size,)) 76 | # cur_inp = cached_inps[0][idx].to(device) 77 | cur_inp, cur_sym = cached_inps[0][idx].to(device), cached_inps[1][idx].to(device) 78 | if input_prob < 1.0: 79 | cur_inp = torch.where(torch.rand_like(cur_inp) < input_prob, cur_inp, cur_sym) 80 | cur_out = cached_outs[idx].to(device) 81 | 82 | w_opt.zero_grad() 83 | if a_opt: 84 | a_opt.zero_grad() 85 | out_quant = layer(cur_inp) 86 | 87 | err = loss_func(out_quant, cur_out) 88 | 89 | err.backward(retain_graph=True) 90 | w_opt.step() 91 | if a_opt: 92 | a_opt.step() 93 | if scheduler: 94 | scheduler.step() 95 | if a_scheduler: 96 | a_scheduler.step() 97 | torch.cuda.empty_cache() 98 | 99 | layer.weight_quantizer.soft_targets = False 100 | layer.act_quantizer.is_training = False 101 | '''Case 3''' 102 | if act_quant and order == 'after' and waq: 103 | set_act_quantize_params(layer, cached_inps[0], awq=True, order=order) 104 | 105 | 106 | 107 | class LossFunction: 108 | def __init__(self, 109 | layer: QuantModule, 110 | round_loss: str = 'relaxation', 111 | weight: float = 1., 112 | rec_loss: str = 'mse', 113 | max_count: int = 2000, 114 | b_range: tuple = (10, 2), 115 | decay_start: float = 0.0, 116 | warmup: float = 0.0, 117 | p: float = 2.): 118 | 119 | self.layer = layer 120 | self.round_loss = round_loss 121 | self.weight = weight 122 | self.rec_loss = rec_loss 123 | self.loss_start = max_count * warmup 124 | self.p = p 125 | 126 | self.temp_decay = LinearTempDecay(max_count, rel_start_decay=warmup + (1 - warmup) * decay_start, 127 | start_b=b_range[0], end_b=b_range[1]) 128 | self.count = 0 129 | 130 | def __call__(self, pred, tgt, grad=None): 131 | """ 132 | Compute the total loss for adaptive rounding: 133 | rec_loss is the quadratic output reconstruction loss, round_loss is 134 | a regularization term to optimize the rounding policy 135 | 136 | :param pred: output from quantized model 137 | :param tgt: output from FP model 138 | :param grad: gradients to compute fisher information 139 | :return: total loss function 140 | """ 141 | self.count += 1 142 | if self.rec_loss == 'mse': 143 | rec_loss = lp_loss(pred, tgt, p=self.p) 144 | elif self.rec_loss == 'fisher_diag': 145 | rec_loss = ((pred - tgt).pow(2) * grad.pow(2)).sum(1).mean() 146 | elif self.rec_loss == 'fisher_full': 147 | a = (pred - tgt).abs() 148 | grad = grad.abs() 149 | batch_dotprod = torch.sum(a * grad, (1, 2, 3)).view(-1, 1, 1, 1) 150 | rec_loss = (batch_dotprod * a * grad).mean() / 100 151 | else: 152 | raise ValueError('Not supported reconstruction loss function: {}'.format(self.rec_loss)) 153 | 154 | b = self.temp_decay(self.count) 155 | if self.count < self.loss_start or self.round_loss == 'none': 156 | b = round_loss = 0 157 | elif self.round_loss == 'relaxation': 158 | round_loss = 0 159 | round_vals = self.layer.weight_quantizer.get_soft_targets() 160 | round_loss += self.weight * (1 - ((round_vals - .5).abs() * 2).pow(b)).sum() 161 | else: 162 | raise NotImplementedError 163 | 164 | total_loss = rec_loss + round_loss 165 | if self.count % 500 == 0: 166 | print('Total loss:\t{:.3f} (rec:{:.3f}, round:{:.3f})\tb={:.2f}\tcount={}'.format( 167 | float(total_loss), float(rec_loss), float(round_loss), b, 
self.count)) 168 | return total_loss 169 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/quant_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .quant_layer import QuantModule, UniformAffineQuantizer 3 | from QDrop.models.resnet import BasicBlock, Bottleneck 4 | from QDrop.models.regnet import ResBottleneckBlock 5 | from QDrop.models.mobilenetv2 import InvertedResidual 6 | from QDrop.models.mnasnet import _InvertedResidual 7 | 8 | 9 | class BaseQuantBlock(nn.Module): 10 | """ 11 | Base implementation of block structures for all networks. 12 | Due to the branch architecture, we have to perform activation function 13 | and quantization after the elemental-wise add operation, therefore, we 14 | put this part in this class. 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | self.use_weight_quant = False 19 | self.use_act_quant = False 20 | self.ignore_reconstruction = False 21 | 22 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 23 | # setting weight quantization here does not affect actual forward pass 24 | self.use_weight_quant = weight_quant 25 | self.use_act_quant = act_quant 26 | for m in self.modules(): 27 | if isinstance(m, QuantModule): 28 | m.set_quant_state(weight_quant, act_quant) 29 | 30 | 31 | class QuantBasicBlock(BaseQuantBlock): 32 | """ 33 | Implementation of Quantized BasicBlock used in ResNet-18 and ResNet-34. 34 | """ 35 | def __init__(self, basic_block: BasicBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}): 36 | super().__init__() 37 | self.conv1 = QuantModule(basic_block.conv1, weight_quant_params, act_quant_params) 38 | self.conv1.activation_function = basic_block.relu1 39 | self.conv2 = QuantModule(basic_block.conv2, weight_quant_params, act_quant_params, disable_act_quant=True) 40 | 41 | if basic_block.downsample is None: 42 | self.downsample = None 43 | else: 44 | self.downsample = QuantModule(basic_block.downsample[0], weight_quant_params, act_quant_params, 45 | disable_act_quant=True) 46 | self.activation_function = basic_block.relu2 47 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 48 | 49 | def forward(self, x): 50 | residual = x if self.downsample is None else self.downsample(x) 51 | out = self.conv1(x) 52 | out = self.conv2(out) 53 | out += residual 54 | out = self.activation_function(out) 55 | if self.use_act_quant: 56 | out = self.act_quantizer(out) 57 | return out 58 | 59 | 60 | class QuantBottleneck(BaseQuantBlock): 61 | """ 62 | Implementation of Quantized Bottleneck Block used in ResNet-50, -101 and -152. 
63 | """ 64 | 65 | def __init__(self, bottleneck: Bottleneck, weight_quant_params: dict = {}, act_quant_params: dict = {}): 66 | super().__init__() 67 | self.conv1 = QuantModule(bottleneck.conv1, weight_quant_params, act_quant_params) 68 | self.conv1.activation_function = bottleneck.relu1 69 | self.conv2 = QuantModule(bottleneck.conv2, weight_quant_params, act_quant_params) 70 | self.conv2.activation_function = bottleneck.relu2 71 | self.conv3 = QuantModule(bottleneck.conv3, weight_quant_params, act_quant_params, disable_act_quant=True) 72 | 73 | if bottleneck.downsample is None: 74 | self.downsample = None 75 | else: 76 | self.downsample = QuantModule(bottleneck.downsample[0], weight_quant_params, act_quant_params, 77 | disable_act_quant=True) 78 | # modify the activation function to ReLU 79 | self.activation_function = bottleneck.relu3 80 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 81 | 82 | def forward(self, x): 83 | residual = x if self.downsample is None else self.downsample(x) 84 | out = self.conv1(x) 85 | out = self.conv2(out) 86 | out = self.conv3(out) 87 | out += residual 88 | out = self.activation_function(out) 89 | if self.use_act_quant: 90 | out = self.act_quantizer(out) 91 | return out 92 | 93 | 94 | class QuantResBottleneckBlock(BaseQuantBlock): 95 | """ 96 | Implementation of Quantized Bottleneck Blockused in RegNetX (no SE module). 97 | """ 98 | 99 | def __init__(self, bottleneck: ResBottleneckBlock, weight_quant_params: dict = {}, act_quant_params: dict = {}): 100 | super().__init__() 101 | self.conv1 = QuantModule(bottleneck.f.a, weight_quant_params, act_quant_params) 102 | self.conv1.activation_function = bottleneck.f.a_relu 103 | self.conv2 = QuantModule(bottleneck.f.b, weight_quant_params, act_quant_params) 104 | self.conv2.activation_function = bottleneck.f.b_relu 105 | self.conv3 = QuantModule(bottleneck.f.c, weight_quant_params, act_quant_params, disable_act_quant=True) 106 | 107 | if bottleneck.proj_block: 108 | self.downsample = QuantModule(bottleneck.proj, weight_quant_params, act_quant_params, 109 | disable_act_quant=True) 110 | else: 111 | self.downsample = None 112 | # copying all attributes in original block 113 | self.proj_block = bottleneck.proj_block 114 | 115 | self.activation_function = bottleneck.relu 116 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 117 | 118 | def forward(self, x): 119 | residual = x if not self.proj_block else self.downsample(x) 120 | out = self.conv1(x) 121 | out = self.conv2(out) 122 | out = self.conv3(out) 123 | out += residual 124 | out = self.activation_function(out) 125 | if self.use_act_quant: 126 | out = self.act_quantizer(out) 127 | return out 128 | 129 | 130 | class QuantInvertedResidual(BaseQuantBlock): 131 | """ 132 | Implementation of Quantized Inverted Residual Block used in MobileNetV2. 133 | Inverted Residual does not have activation function. 
134 | """ 135 | 136 | def __init__(self, inv_res: InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}): 137 | super().__init__() 138 | 139 | self.use_res_connect = inv_res.use_res_connect 140 | self.expand_ratio = inv_res.expand_ratio 141 | if self.expand_ratio == 1: 142 | self.conv = nn.Sequential( 143 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), 144 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params, disable_act_quant=True), 145 | ) 146 | self.conv[0].activation_function = nn.ReLU6() 147 | else: 148 | self.conv = nn.Sequential( 149 | QuantModule(inv_res.conv[0], weight_quant_params, act_quant_params), 150 | QuantModule(inv_res.conv[3], weight_quant_params, act_quant_params), 151 | QuantModule(inv_res.conv[6], weight_quant_params, act_quant_params, disable_act_quant=True), 152 | ) 153 | self.conv[0].activation_function = nn.ReLU6() 154 | self.conv[1].activation_function = nn.ReLU6() 155 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 156 | 157 | def forward(self, x): 158 | if self.use_res_connect: 159 | out = x + self.conv(x) 160 | else: 161 | out = self.conv(x) 162 | if self.use_act_quant: 163 | out = self.act_quantizer(out) 164 | return out 165 | 166 | 167 | class _QuantInvertedResidual(BaseQuantBlock): 168 | def __init__(self, _inv_res: _InvertedResidual, weight_quant_params: dict = {}, act_quant_params: dict = {}): 169 | super().__init__() 170 | 171 | self.apply_residual = _inv_res.apply_residual 172 | self.conv = nn.Sequential( 173 | QuantModule(_inv_res.layers[0], weight_quant_params, act_quant_params), 174 | QuantModule(_inv_res.layers[3], weight_quant_params, act_quant_params), 175 | QuantModule(_inv_res.layers[6], weight_quant_params, act_quant_params, disable_act_quant=True), 176 | ) 177 | self.conv[0].activation_function = nn.ReLU() 178 | self.conv[1].activation_function = nn.ReLU() 179 | self.act_quantizer = UniformAffineQuantizer(**act_quant_params) 180 | 181 | def forward(self, x): 182 | if self.apply_residual: 183 | out = x + self.conv(x) 184 | else: 185 | out = self.conv(x) 186 | if self.use_act_quant: 187 | out = self.act_quantizer(out) 188 | return out 189 | 190 | 191 | specials = { 192 | BasicBlock: QuantBasicBlock, 193 | Bottleneck: QuantBottleneck, 194 | ResBottleneckBlock: QuantResBottleneckBlock, 195 | InvertedResidual: QuantInvertedResidual, 196 | _InvertedResidual: _QuantInvertedResidual, 197 | } 198 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/quant_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .quant_block import specials, BaseQuantBlock 3 | from .quant_layer import QuantModule, StraightThrough, UniformAffineQuantizer 4 | from .fold_bn import search_fold_and_remove_bn 5 | 6 | 7 | class QuantModel(nn.Module): 8 | def __init__( 9 | self, 10 | model: nn.Module, 11 | weight_quant_params: dict = {}, 12 | act_quant_params: dict = {}, 13 | ): 14 | super().__init__() 15 | search_fold_and_remove_bn(model) 16 | self.model = model 17 | self.quant_module_refactor(self.model, weight_quant_params, act_quant_params) 18 | 19 | def quant_module_refactor( 20 | self, 21 | module: nn.Module, 22 | weight_quant_params: dict = {}, 23 | act_quant_params: dict = {}, 24 | ): 25 | """ 26 | Recursively replace the normal conv2d and Linear layer to QuantModule 27 | :param module: nn.Module with nn.Conv2d or nn.Linear in its children 28 | :param weight_quant_params: 
quantization parameters like n_bits for weight quantizer 29 | :param act_quant_params: quantization parameters like n_bits for activation quantizer 30 | """ 31 | prev_quantmodule = None 32 | for name, child_module in module.named_children(): 33 | if type(child_module) in specials: 34 | setattr( 35 | module, 36 | name, 37 | specials[type(child_module)]( 38 | child_module, weight_quant_params, act_quant_params 39 | ), 40 | ) 41 | elif isinstance(child_module, (nn.Conv2d, nn.Linear)): 42 | setattr( 43 | module, 44 | name, 45 | QuantModule(child_module, weight_quant_params, act_quant_params), 46 | ) 47 | prev_quantmodule = getattr(module, name) 48 | 49 | elif isinstance(child_module, (nn.ReLU, nn.ReLU6)): 50 | if prev_quantmodule is not None: 51 | prev_quantmodule.activation_function = child_module 52 | setattr(module, name, StraightThrough()) 53 | else: 54 | continue 55 | 56 | elif isinstance(child_module, StraightThrough): 57 | continue 58 | 59 | else: 60 | self.quant_module_refactor( 61 | child_module, weight_quant_params, act_quant_params 62 | ) 63 | 64 | def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False): 65 | for m in self.model.modules(): 66 | if isinstance(m, (QuantModule, BaseQuantBlock)): 67 | m.set_quant_state(weight_quant, act_quant) 68 | 69 | def forward(self, *args, **kwargs): 70 | if len(args)==1 and type(args[0]) in [tuple,list]: 71 | return self.model(*args[0]) 72 | else: 73 | return self.model(*args, **kwargs) 74 | 75 | def set_first_last_layer_to_8bit(self): 76 | w_list, a_list = [], [] 77 | for module in self.model.modules(): 78 | if isinstance(module, UniformAffineQuantizer): 79 | if module.leaf_param: 80 | a_list.append(module) 81 | else: 82 | w_list.append(module) 83 | w_list[0].bitwidth_refactor(8) 84 | w_list[-1].bitwidth_refactor(8) 85 | "the image input has been in 0~255, set the last layer's input to 8-bit" 86 | a_list[-2].bitwidth_refactor(8) 87 | 88 | def set_cosine_embedding_layer_to_32bit(self): 89 | w_list, a_list = [], [] 90 | for module in self.model.modules(): 91 | if isinstance(module, UniformAffineQuantizer): 92 | if module.leaf_param: 93 | a_list.append(module) 94 | else: 95 | w_list.append(module) 96 | w_list[0].bitwidth_refactor(32) 97 | a_list[0].bitwidth_refactor(32) 98 | # a_list[1].bitwidth_refactor(32) 99 | # a_list[2].bitwidth_refactor(32) 100 | w_list[-1].bitwidth_refactor(8) 101 | "the image input has been in 0~255, set the last layer's input to 8-bit" 102 | a_list[-2].bitwidth_refactor(8) 103 | 104 | def disable_network_output_quantization(self): 105 | module_list = [] 106 | for m in self.model.modules(): 107 | if isinstance(m, QuantModule): 108 | module_list += [m] 109 | module_list[-1].disable_act_quant = True 110 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/set_act_quantize_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .quant_layer import QuantModule 3 | from .quant_block import BaseQuantBlock 4 | from .quant_model import QuantModel 5 | from typing import Union 6 | 7 | 8 | def set_act_quantize_params( 9 | module: Union[QuantModel, QuantModule, BaseQuantBlock], 10 | cali_data, 11 | awq: bool = False, 12 | order: str = "before", 13 | batch_size: int = 256, 14 | ): 15 | weight_quant, act_quant = act_get_quant_state(order, awq) 16 | module.set_quant_state(weight_quant, act_quant) 17 | 18 | for t in module.modules(): 19 | if isinstance(t, (QuantModule, BaseQuantBlock)): 20 | 
t.act_quantizer.set_inited(False) 21 | 22 | """set or init step size and zero point in the activation quantizer""" 23 | if not isinstance(cali_data, (tuple, list)): 24 | batch_size = min(batch_size, cali_data.size(0)) 25 | with torch.no_grad(): 26 | for i in range(int(cali_data.size(0) / batch_size)): 27 | module(cali_data[i * batch_size : (i + 1) * batch_size].cuda()) 28 | torch.cuda.empty_cache() 29 | 30 | for t in module.modules(): 31 | if isinstance(t, (QuantModule, BaseQuantBlock)): 32 | t.act_quantizer.set_inited(True) 33 | else: 34 | batch_size = min(batch_size, cali_data[0].size(0)) 35 | with torch.no_grad(): 36 | for i in range(int(cali_data[0].size(0) / batch_size)): 37 | module( 38 | *[ 39 | _[i * batch_size : (i + 1) * batch_size].cuda() 40 | for _ in cali_data 41 | ] 42 | ) 43 | torch.cuda.empty_cache() 44 | 45 | for t in module.modules(): 46 | if isinstance(t, (QuantModule, BaseQuantBlock)): 47 | t.act_quantizer.set_inited(True) 48 | 49 | 50 | def act_get_quant_state(order, awq): 51 | if order == "before": 52 | weight_quant, act_quant = False, True 53 | elif order == "after": 54 | weight_quant, act_quant = awq, True 55 | elif order == "together": 56 | weight_quant, act_quant = True, True 57 | else: 58 | raise NotImplementedError 59 | return weight_quant, act_quant 60 | -------------------------------------------------------------------------------- /PTQ4DM/QDrop/quant/set_weight_quantize_params.py: -------------------------------------------------------------------------------- 1 | from .quant_layer import QuantModule 2 | from .data_utils import save_inp_oup_data 3 | 4 | 5 | def get_init( 6 | model, 7 | block, 8 | cali_data, 9 | wq, 10 | aq, 11 | batch_size, 12 | input_prob: bool = False, 13 | keep_gpu: bool = True, 14 | ): 15 | cached_inps, cached_outs = save_inp_oup_data( 16 | model, 17 | block, 18 | cali_data, 19 | wq, 20 | aq, 21 | batch_size, 22 | input_prob=input_prob, 23 | keep_gpu=keep_gpu, 24 | ) 25 | return cached_inps, cached_outs 26 | 27 | 28 | def set_weight_quantize_params(model): 29 | print(f"set_weight_quantize_params") 30 | for name, module in model.named_modules(): 31 | if isinstance(module, QuantModule): 32 | module.weight_quantizer.set_inited(False) 33 | """caculate the step size and zero point for weight quantizer""" 34 | module.weight_quantizer(module.weight) 35 | module.weight_quantizer.set_inited(True) 36 | 37 | 38 | def weight_get_quant_state(order, act_quant): 39 | if not act_quant: 40 | return True, False 41 | if order == "before": 42 | weight_quant, act_quant = True, True 43 | elif order == "after": 44 | weight_quant, act_quant = True, False 45 | elif order == "together": 46 | weight_quant, act_quant = True, True 47 | else: 48 | raise NotImplementedError 49 | return weight_quant, act_quant 50 | 51 | 52 | def save_quantized_weight(model): 53 | for module in model.modules(): 54 | if isinstance(module, QuantModule): 55 | module.weight.data = module.weight_quantizer(module.weight) 56 | -------------------------------------------------------------------------------- /PTQ4DM/baseline.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | export PYTHONPATH=".:guided-diffusion:improved-diffusion" 3 | # MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 4 | # DIFFUSION_FLAGS="--diffusion_steps 4000 --timestep_respacing 2000 --use_ddim True --noise_schedule cosine" 5 | # DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" 6 | # python 
guided-diffusion/scripts/image_sample.py $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples 100 --batch_size 100 7 | 8 | 9 | # MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3" 10 | # DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" 11 | # python improved-diffusion/scripts/image_sample.py $MODEL_FLAGS --model_path guided-diffusion/models/cifar10_uncond_50M_500K.pt $DIFFUSION_FLAGS --num_samples 10 --batch_size 10 12 | 13 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 14 | DIFFUSION_FLAGS="--diffusion_steps 4000 --timestep_respacing 250 --use_ddim True --noise_schedule cosine" 15 | # python guided-diffusion/scripts/image_sample.py $MODEL_FLAGS --model_path guided-diffusion/models/cifar10_uncond_50M_500K.pt $DIFFUSION_FLAGS --num_samples 10 --batch_size 10 16 | python improved-diffusion/scripts/image_sample.py $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples 10000 --batch_size 32 -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase. 4 | 5 | ## Class-conditional ImageNet 6 | 7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 8 | 9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. 
To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: 10 | 11 | ``` 12 | for file in *.tar; do tar xf "$file"; rm "$file"; done 13 | ``` 14 | 15 | This will extract and remove each tar file in turn. 16 | 17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels. 18 | 19 | ## LSUN bedroom 20 | 21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: 22 | 23 | ``` 24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir 25 | ``` 26 | 27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. 28 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/datasets/lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert an LSUN lmdb database into a directory of images. 3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/evaluations/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. 
These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files. 4 | 5 | # Download batches 6 | 7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format. 8 | 9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall. 10 | 11 | Here are links to download all of the sample and reference batches: 12 | 13 | * LSUN 14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz) 15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz) 16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz) 17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz) 18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz) 19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz) 20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz) 21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz) 22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz) 23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz) 24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz) 25 | 26 | * ImageNet 27 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz) 28 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz) 29 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz) 30 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz) 31 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz) 32 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz) 33 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz) 34 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz) 35 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz) 36 | * ImageNet 256x256: [reference 
batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) 37 |      * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz) 38 |      * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz) 39 |      * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz) 40 |      * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz) 41 |      * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz) 42 |      * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz) 43 |    * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) 44 |      * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz) 45 |      * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz) 46 |      * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz) 47 |      * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz) 48 |      * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz) 49 |      * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz) 50 | 51 | # Run evaluations 52 | 53 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`. 54 | 55 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB. 56 | 57 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging: 58 | 59 | ``` 60 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz 61 | ... 62 | computing reference batch activations... 63 | computing/reading reference batch statistics... 64 | computing sample batch activations... 65 | computing/reading sample batch statistics... 66 | Computing evaluations...
67 | Inception Score: 215.8370361328125 68 | FID: 3.9425574129223264 69 | sFID: 6.140433703346162 70 | Precision: 0.8265 71 | Recall: 0.5309 72 | ``` 73 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/evaluations/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu>=2.0 2 | scipy 3 | requests 4 | tqdm -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 3 | """ 4 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | if not os.environ["CUDA_VISIBLE_DEVICES"]: 28 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 29 | 30 | comm = MPI.COMM_WORLD 31 | backend = "gloo" if not th.cuda.is_available() else "nccl" 32 | 33 | if backend == "gloo": 34 | hostname = "localhost" 35 | else: 36 | hostname = socket.gethostbyname(socket.getfqdn()) 37 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 38 | os.environ["RANK"] = str(comm.rank) 39 | os.environ["WORLD_SIZE"] = str(comm.size) 40 | 41 | port = comm.bcast(_find_free_port(), root=0) 42 | os.environ["MASTER_PORT"] = str(port) 43 | dist.init_process_group(backend=backend, init_method="env://") 44 | 45 | 46 | def dev(): 47 | """ 48 | Get the device to use for torch.distributed. 49 | """ 50 | if th.cuda.is_available(): 51 | return th.device(f"cuda") 52 | return th.device("cpu") 53 | 54 | 55 | def load_state_dict(path, **kwargs): 56 | """ 57 | Load a PyTorch file without redundant fetches across MPI ranks. 58 | """ 59 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 60 | if MPI.COMM_WORLD.Get_rank() == 0: 61 | with bf.BlobFile(path, "rb") as f: 62 | data = f.read() 63 | num_chunks = len(data) // chunk_size 64 | if len(data) % chunk_size: 65 | num_chunks += 1 66 | MPI.COMM_WORLD.bcast(num_chunks) 67 | for i in range(0, len(data), chunk_size): 68 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 69 | else: 70 | num_chunks = MPI.COMM_WORLD.bcast(None) 71 | data = bytes() 72 | for _ in range(num_chunks): 73 | data += MPI.COMM_WORLD.bcast(None) 74 | 75 | return th.load(io.BytesIO(data), **kwargs) 76 | 77 | 78 | def sync_params(params): 79 | """ 80 | Synchronize a sequence of Tensors across ranks from rank 0. 
81 | """ 82 | for p in params: 83 | with th.no_grad(): 84 | dist.broadcast(p, 0) 85 | 86 | 87 | def _find_free_port(): 88 | try: 89 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 90 | s.bind(("", 0)) 91 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 92 | return s.getsockname()[1] 93 | finally: 94 | s.close() 95 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 
154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 
99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 
14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 
113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/model-card.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | 3 | These are diffusion models and noised image classifiers described in the paper [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233). 4 | Included in this release are the following models: 5 | 6 | * Noisy ImageNet classifiers at resolutions 64x64, 128x128, 256x256, 512x512 7 | * A class-unconditional ImageNet diffusion model at resolution 256x256 8 | * Class conditional ImageNet diffusion models at 64x64, 128x128, 256x256, 512x512 resolutions 9 | * Class-conditional ImageNet upsampling diffusion models: 64x64->256x256, 128x128->512x512 10 | * Diffusion models trained on three LSUN classes at 256x256 resolution: cat, horse, bedroom 11 | 12 | # Datasets 13 | 14 | All of the models we are releasing were either trained on the [ILSVRC 2012 subset of ImageNet](http://www.image-net.org/challenges/LSVRC/2012/) or on single classes of [LSUN](https://arxiv.org/abs/1506.03365). 15 | Here, we describe characteristics of these datasets which impact model behavior: 16 | 17 | **LSUN**: This dataset was collected in 2015 using a combination of human labeling (from Amazon Mechanical Turk) and automated data labeling. 18 | * Each of the three classes we consider contain over a million images. 19 | * The dataset creators found that the label accuracy was roughly 90% across the entire LSUN dataset when measured by trained experts. 20 | * Images are scraped from the internet, and LSUN cat images in particular tend to often follow a “meme” format. 21 | * We found that there are occasionally humans in these photos, including faces, especially within the cat class. 22 | 23 | **ILSVRC 2012 subset of ImageNet**: This dataset was curated in 2012 and consists of roughly one million images, each belonging to one of 1000 classes. 24 | * A large portion of the classes in this dataset are animals, plants, and other naturally-occurring objects. 25 | * Many images contain humans, although usually these humans aren’t reflected by the class label (e.g. the class “Tench, tinca tinca” contains many photos of people holding fish). 26 | 27 | # Performance 28 | 29 | These models are intended to generate samples consistent with their training distributions. 30 | This has been measured in terms of FID, Precision, and Recall. 31 | These metrics all rely on the representations of a [pre-trained Inception-V3 model](https://arxiv.org/abs/1512.00567), 32 | which was trained on ImageNet, and so is likely to focus more on the ImageNet classes (such as animals) than on other visual features (such as human faces). 33 | 34 | Qualitatively, the samples produced by these models often look highly realistic, especially when a diffusion model is combined with a noisy classifier. 
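The FID referred to above is the Fréchet distance between Gaussians fitted to pooled Inception-V3 features of a reference batch and a sample batch; the full metric suite is implemented in `evaluations/evaluator.py`. A minimal sketch of the final FID computation, assuming the pooled activations have already been extracted into two NumPy arrays (`act_ref` and `act_gen` are illustrative names, not part of this repository):

```
import numpy as np
from scipy import linalg


def frechet_distance(act_ref, act_gen):
    """Illustrative FID between two (N, D) arrays of pooled Inception-V3 features."""
    mu1, sigma1 = act_ref.mean(axis=0), np.cov(act_ref, rowvar=False)
    mu2, sigma2 = act_gen.mean(axis=0), np.cov(act_gen, rowvar=False)
    diff = mu1 - mu2
    # Matrix square root of the covariance product; tiny imaginary parts caused
    # by numerical error are discarded.
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

This sketch covers only FID; Precision and Recall rely on the same Inception features but use a different estimator in the evaluator script.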
35 | 36 | # Intended Use 37 | 38 | These models are intended to be used for research purposes only. 39 | In particular, they can be used as a baseline for generative modeling research, or as a starting point to build off of for such research. 40 | 41 | These models are not intended to be commercially deployed. 42 | Additionally, they are not intended to be used to create propaganda or offensive imagery. 43 | 44 | Before releasing these models, we probed their ability to ease the creation of targeted imagery, since doing so could be potentially harmful. 45 | We did this either by fine-tuning our ImageNet models on a target LSUN class, or through classifier guidance with publicly available [CLIP models](https://github.com/openai/CLIP). 46 | * To probe fine-tuning capabilities, we restricted our compute budget to roughly $100 and tried both standard fine-tuning, 47 | and a diffusion-specific approach where we train a specialized classifier for the LSUN class. The resulting FIDs were significantly worse than publicly available GAN models, indicating that fine-tuning an ImageNet diffusion model does not significantly lower the cost of image generation. 48 | * To probe guidance with CLIP, we tried two approaches for using pre-trained CLIP models for classifier guidance. Either we fed the noised image to CLIP directly and used its gradients, or we fed the diffusion model's denoised prediction to the CLIP model and differentiated through the whole process. In both cases, we found that it was difficult to recover information from the CLIP model, indicating that these diffusion models are unlikely to make it significantly easier to extract knowledge from CLIP compared to existing GAN models. 49 | 50 | # Limitations 51 | 52 | These models sometimes produce highly unrealistic outputs, particularly when generating images containing human faces. 53 | This may stem from ImageNet's emphasis on non-human objects. 54 | 55 | While classifier guidance can improve sample quality, it reduces diversity, resulting in some modes of the data distribution being underrepresented. 56 | This can potentially amplify existing biases in the training dataset such as gender and racial biases. 57 | 58 | Because ImageNet and LSUN contain images from the internet, they include photos of real people, and the model may have memorized some of the information contained in these photos. 59 | However, these images are already publicly available, and existing generative models trained on ImageNet have not demonstrated significant leakage of this information. -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/classifier_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Like image_sample.py, but use a noisy image classifier to guide the sampling 3 | process towards more realistic images. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | import torch.nn.functional as F 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | NUM_CLASSES, 17 | model_and_diffusion_defaults, 18 | classifier_defaults, 19 | create_model_and_diffusion, 20 | create_classifier, 21 | add_dict_to_argparser, 22 | args_to_dict, 23 | ) 24 | 25 | 26 | def main(): 27 | args = create_argparser().parse_args() 28 | 29 | dist_util.setup_dist() 30 | logger.configure() 31 | 32 | logger.log("creating model and diffusion...") 33 | model, diffusion = create_model_and_diffusion( 34 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 35 | ) 36 | model.load_state_dict( 37 | dist_util.load_state_dict(args.model_path, map_location="cpu") 38 | ) 39 | model.to(dist_util.dev()) 40 | if args.use_fp16: 41 | model.convert_to_fp16() 42 | model.eval() 43 | 44 | logger.log("loading classifier...") 45 | classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys())) 46 | classifier.load_state_dict( 47 | dist_util.load_state_dict(args.classifier_path, map_location="cpu") 48 | ) 49 | classifier.to(dist_util.dev()) 50 | if args.classifier_use_fp16: 51 | classifier.convert_to_fp16() 52 | classifier.eval() 53 | 54 | def cond_fn(x, t, y=None): 55 | assert y is not None 56 | with th.enable_grad(): 57 | x_in = x.detach().requires_grad_(True) 58 | logits = classifier(x_in, t) 59 | log_probs = F.log_softmax(logits, dim=-1) 60 | selected = log_probs[range(len(logits)), y.view(-1)] 61 | return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale 62 | 63 | def model_fn(x, t, y=None): 64 | assert y is not None 65 | return model(x, t, y if args.class_cond else None) 66 | 67 | logger.log("sampling...") 68 | all_images = [] 69 | all_labels = [] 70 | while len(all_images) * args.batch_size < args.num_samples: 71 | model_kwargs = {} 72 | classes = th.randint( 73 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 74 | ) 75 | model_kwargs["y"] = classes 76 | sample_fn = ( 77 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 78 | ) 79 | sample = sample_fn( 80 | model_fn, 81 | (args.batch_size, 3, args.image_size, args.image_size), 82 | clip_denoised=args.clip_denoised, 83 | model_kwargs=model_kwargs, 84 | cond_fn=cond_fn, 85 | device=dist_util.dev(), 86 | ) 87 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 88 | sample = sample.permute(0, 2, 3, 1) 89 | sample = sample.contiguous() 90 | 91 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 92 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 93 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 94 | gathered_labels = [th.zeros_like(classes) for _ in range(dist.get_world_size())] 95 | dist.all_gather(gathered_labels, classes) 96 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 97 | logger.log(f"created {len(all_images) * args.batch_size} samples") 98 | 99 | arr = np.concatenate(all_images, axis=0) 100 | arr = arr[: args.num_samples] 101 | label_arr = np.concatenate(all_labels, axis=0) 102 | label_arr = label_arr[: args.num_samples] 103 | if dist.get_rank() == 0: 104 | shape_str = "x".join([str(x) for x in arr.shape]) 105 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 106 | logger.log(f"saving to {out_path}") 107 | 
np.savez(out_path, arr, label_arr) 108 | 109 | dist.barrier() 110 | logger.log("sampling complete") 111 | 112 | 113 | def create_argparser(): 114 | defaults = dict( 115 | clip_denoised=True, 116 | num_samples=10000, 117 | batch_size=16, 118 | use_ddim=False, 119 | model_path="", 120 | classifier_path="", 121 | classifier_scale=1.0, 122 | ) 123 | defaults.update(model_and_diffusion_defaults()) 124 | defaults.update(classifier_defaults()) 125 | parser = argparse.ArgumentParser() 126 | add_dict_to_argparser(parser, defaults) 127 | return parser 128 | 129 | 130 | if __name__ == "__main__": 131 | main() 132 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/classifier_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a noised image classifier on ImageNet. 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import blobfile as bf 9 | import torch as th 10 | import torch.distributed as dist 11 | import torch.nn.functional as F 12 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 13 | from torch.optim import AdamW 14 | 15 | from guided_diffusion import dist_util, logger 16 | from guided_diffusion.fp16_util import MixedPrecisionTrainer 17 | from guided_diffusion.image_datasets import load_data 18 | from guided_diffusion.resample import create_named_schedule_sampler 19 | from guided_diffusion.script_util import ( 20 | add_dict_to_argparser, 21 | args_to_dict, 22 | classifier_and_diffusion_defaults, 23 | create_classifier_and_diffusion, 24 | ) 25 | from guided_diffusion.train_util import parse_resume_step_from_filename, log_loss_dict 26 | 27 | 28 | def main(): 29 | args = create_argparser().parse_args() 30 | 31 | dist_util.setup_dist() 32 | logger.configure() 33 | 34 | logger.log("creating model and diffusion...") 35 | model, diffusion = create_classifier_and_diffusion( 36 | **args_to_dict(args, classifier_and_diffusion_defaults().keys()) 37 | ) 38 | model.to(dist_util.dev()) 39 | if args.noised: 40 | schedule_sampler = create_named_schedule_sampler( 41 | args.schedule_sampler, diffusion 42 | ) 43 | 44 | resume_step = 0 45 | if args.resume_checkpoint: 46 | resume_step = parse_resume_step_from_filename(args.resume_checkpoint) 47 | if dist.get_rank() == 0: 48 | logger.log( 49 | f"loading model from checkpoint: {args.resume_checkpoint}... at {resume_step} step" 50 | ) 51 | model.load_state_dict( 52 | dist_util.load_state_dict( 53 | args.resume_checkpoint, map_location=dist_util.dev() 54 | ) 55 | ) 56 | 57 | # Needed for creating correct EMAs and fp16 parameters. 
58 | dist_util.sync_params(model.parameters()) 59 | 60 | mp_trainer = MixedPrecisionTrainer( 61 | model=model, use_fp16=args.classifier_use_fp16, initial_lg_loss_scale=16.0 62 | ) 63 | 64 | model = DDP( 65 | model, 66 | device_ids=[dist_util.dev()], 67 | output_device=dist_util.dev(), 68 | broadcast_buffers=False, 69 | bucket_cap_mb=128, 70 | find_unused_parameters=False, 71 | ) 72 | 73 | logger.log("creating data loader...") 74 | data = load_data( 75 | data_dir=args.data_dir, 76 | batch_size=args.batch_size, 77 | image_size=args.image_size, 78 | class_cond=True, 79 | random_crop=True, 80 | ) 81 | if args.val_data_dir: 82 | val_data = load_data( 83 | data_dir=args.val_data_dir, 84 | batch_size=args.batch_size, 85 | image_size=args.image_size, 86 | class_cond=True, 87 | ) 88 | else: 89 | val_data = None 90 | 91 | logger.log(f"creating optimizer...") 92 | opt = AdamW(mp_trainer.master_params, lr=args.lr, weight_decay=args.weight_decay) 93 | if args.resume_checkpoint: 94 | opt_checkpoint = bf.join( 95 | bf.dirname(args.resume_checkpoint), f"opt{resume_step:06}.pt" 96 | ) 97 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") 98 | opt.load_state_dict( 99 | dist_util.load_state_dict(opt_checkpoint, map_location=dist_util.dev()) 100 | ) 101 | 102 | logger.log("training classifier model...") 103 | 104 | def forward_backward_log(data_loader, prefix="train"): 105 | batch, extra = next(data_loader) 106 | labels = extra["y"].to(dist_util.dev()) 107 | 108 | batch = batch.to(dist_util.dev()) 109 | # Noisy images 110 | if args.noised: 111 | t, _ = schedule_sampler.sample(batch.shape[0], dist_util.dev()) 112 | batch = diffusion.q_sample(batch, t) 113 | else: 114 | t = th.zeros(batch.shape[0], dtype=th.long, device=dist_util.dev()) 115 | 116 | for i, (sub_batch, sub_labels, sub_t) in enumerate( 117 | split_microbatches(args.microbatch, batch, labels, t) 118 | ): 119 | logits = model(sub_batch, timesteps=sub_t) 120 | loss = F.cross_entropy(logits, sub_labels, reduction="none") 121 | 122 | losses = {} 123 | losses[f"{prefix}_loss"] = loss.detach() 124 | losses[f"{prefix}_acc@1"] = compute_top_k( 125 | logits, sub_labels, k=1, reduction="none" 126 | ) 127 | losses[f"{prefix}_acc@5"] = compute_top_k( 128 | logits, sub_labels, k=5, reduction="none" 129 | ) 130 | log_loss_dict(diffusion, sub_t, losses) 131 | del losses 132 | loss = loss.mean() 133 | if loss.requires_grad: 134 | if i == 0: 135 | mp_trainer.zero_grad() 136 | mp_trainer.backward(loss * len(sub_batch) / len(batch)) 137 | 138 | for step in range(args.iterations - resume_step): 139 | logger.logkv("step", step + resume_step) 140 | logger.logkv( 141 | "samples", 142 | (step + resume_step + 1) * args.batch_size * dist.get_world_size(), 143 | ) 144 | if args.anneal_lr: 145 | set_annealed_lr(opt, args.lr, (step + resume_step) / args.iterations) 146 | forward_backward_log(data) 147 | mp_trainer.optimize(opt) 148 | if val_data is not None and not step % args.eval_interval: 149 | with th.no_grad(): 150 | with model.no_sync(): 151 | model.eval() 152 | forward_backward_log(val_data, prefix="val") 153 | model.train() 154 | if not step % args.log_interval: 155 | logger.dumpkvs() 156 | if ( 157 | step 158 | and dist.get_rank() == 0 159 | and not (step + resume_step) % args.save_interval 160 | ): 161 | logger.log("saving model...") 162 | save_model(mp_trainer, opt, step + resume_step) 163 | 164 | if dist.get_rank() == 0: 165 | logger.log("saving model...") 166 | save_model(mp_trainer, opt, step + resume_step) 167 | dist.barrier() 168 | 169 | 
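# Worked example of the linear learning-rate anneal implemented by set_annealed_lr
# below, using this script's own defaults (lr=3e-4, iterations=150000) when
# --anneal_lr is enabled: at step 75000 with no resume,
# frac_done = 75000 / 150000 = 0.5, so lr = 3e-4 * (1 - 0.5) = 1.5e-4.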
170 | def set_annealed_lr(opt, base_lr, frac_done): 171 | lr = base_lr * (1 - frac_done) 172 | for param_group in opt.param_groups: 173 | param_group["lr"] = lr 174 | 175 | 176 | def save_model(mp_trainer, opt, step): 177 | if dist.get_rank() == 0: 178 | th.save( 179 | mp_trainer.master_params_to_state_dict(mp_trainer.master_params), 180 | os.path.join(logger.get_dir(), f"model{step:06d}.pt"), 181 | ) 182 | th.save(opt.state_dict(), os.path.join(logger.get_dir(), f"opt{step:06d}.pt")) 183 | 184 | 185 | def compute_top_k(logits, labels, k, reduction="mean"): 186 | _, top_ks = th.topk(logits, k, dim=-1) 187 | if reduction == "mean": 188 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 189 | elif reduction == "none": 190 | return (top_ks == labels[:, None]).float().sum(dim=-1) 191 | 192 | 193 | def split_microbatches(microbatch, *args): 194 | bs = len(args[0]) 195 | if microbatch == -1 or microbatch >= bs: 196 | yield tuple(args) 197 | else: 198 | for i in range(0, bs, microbatch): 199 | yield tuple(x[i : i + microbatch] if x is not None else None for x in args) 200 | 201 | 202 | def create_argparser(): 203 | defaults = dict( 204 | data_dir="", 205 | val_data_dir="", 206 | noised=True, 207 | iterations=150000, 208 | lr=3e-4, 209 | weight_decay=0.0, 210 | anneal_lr=False, 211 | batch_size=4, 212 | microbatch=-1, 213 | schedule_sampler="uniform", 214 | resume_checkpoint="", 215 | log_interval=10, 216 | eval_interval=5, 217 | save_interval=10000, 218 | ) 219 | defaults.update(classifier_and_diffusion_defaults()) 220 | parser = argparse.ArgumentParser() 221 | add_dict_to_argparser(parser, defaults) 222 | return parser 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/image_nll.py: -------------------------------------------------------------------------------- 1 | """ 2 | Approximate the bits/dimension for an image model. 
3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import torch.distributed as dist 10 | 11 | from guided_diffusion import dist_util, logger 12 | from guided_diffusion.image_datasets import load_data 13 | from guided_diffusion.script_util import ( 14 | model_and_diffusion_defaults, 15 | create_model_and_diffusion, 16 | add_dict_to_argparser, 17 | args_to_dict, 18 | ) 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model and diffusion...") 28 | model, diffusion = create_model_and_diffusion( 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.load_state_dict( 32 | dist_util.load_state_dict(args.model_path, map_location="cpu") 33 | ) 34 | model.to(dist_util.dev()) 35 | model.eval() 36 | 37 | logger.log("creating data loader...") 38 | data = load_data( 39 | data_dir=args.data_dir, 40 | batch_size=args.batch_size, 41 | image_size=args.image_size, 42 | class_cond=args.class_cond, 43 | deterministic=True, 44 | ) 45 | 46 | logger.log("evaluating...") 47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised) 48 | 49 | 50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): 51 | all_bpd = [] 52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []} 53 | num_complete = 0 54 | while num_complete < num_samples: 55 | batch, model_kwargs = next(data) 56 | batch = batch.to(dist_util.dev()) 57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 58 | minibatch_metrics = diffusion.calc_bpd_loop( 59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs 60 | ) 61 | 62 | for key, term_list in all_metrics.items(): 63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size() 64 | dist.all_reduce(terms) 65 | term_list.append(terms.detach().cpu().numpy()) 66 | 67 | total_bpd = minibatch_metrics["total_bpd"] 68 | total_bpd = total_bpd.mean() / dist.get_world_size() 69 | dist.all_reduce(total_bpd) 70 | all_bpd.append(total_bpd.item()) 71 | num_complete += dist.get_world_size() * batch.shape[0] 72 | 73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}") 74 | 75 | if dist.get_rank() == 0: 76 | for name, terms in all_metrics.items(): 77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz") 78 | logger.log(f"saving {name} terms to {out_path}") 79 | np.savez(out_path, np.mean(np.stack(terms), axis=0)) 80 | 81 | dist.barrier() 82 | logger.log("evaluation complete") 83 | 84 | 85 | def create_argparser(): 86 | defaults = dict( 87 | data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path="" 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | import time 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | NUM_CLASSES, 17 | model_and_diffusion_defaults, 18 | create_model_and_diffusion, 19 | add_dict_to_argparser, 20 | args_to_dict, 21 | ) 22 | 23 | 24 | def main(): 25 | args = create_argparser().parse_args() 26 | 27 | dist_util.setup_dist() 28 | logger.configure() 29 | 30 | logger.log("creating model and diffusion...") 31 | model, diffusion = create_model_and_diffusion( 32 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 33 | ) 34 | model.load_state_dict( 35 | dist_util.load_state_dict(args.model_path, map_location="cpu") 36 | ) 37 | model.to(dist_util.dev()) 38 | if args.use_fp16: 39 | model.convert_to_fp16() 40 | model.eval() 41 | 42 | logger.log("sampling...") 43 | all_images = [] 44 | all_labels = [] 45 | while len(all_images) * args.batch_size < args.num_samples: 46 | st=time.time() 47 | model_kwargs = {} 48 | if args.class_cond: 49 | classes = th.randint( 50 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 51 | ) 52 | model_kwargs["y"] = classes 53 | sample_fn = ( 54 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 55 | ) 56 | sample = sample_fn( 57 | model, 58 | (args.batch_size, 3, args.image_size, args.image_size), 59 | clip_denoised=args.clip_denoised, 60 | model_kwargs=model_kwargs, 61 | ) 62 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 63 | sample = sample.permute(0, 2, 3, 1) 64 | sample = sample.contiguous() 65 | 66 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 67 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 68 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 69 | if args.class_cond: 70 | gathered_labels = [ 71 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 72 | ] 73 | dist.all_gather(gathered_labels, classes) 74 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 75 | ed=time.time() 76 | logger.log(f"created {len(all_images) * args.batch_size} samples in {int(ed-st)}s") 77 | 78 | arr = np.concatenate(all_images, axis=0) 79 | arr = arr[: args.num_samples] 80 | if args.class_cond: 81 | label_arr = np.concatenate(all_labels, axis=0) 82 | label_arr = label_arr[: args.num_samples] 83 | if dist.get_rank() == 0: 84 | shape_str = "x".join([str(x) for x in arr.shape]) 85 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 86 | logger.log(f"saving to {out_path}") 87 | if args.class_cond: 88 | np.savez(out_path, arr, label_arr) 89 | else: 90 | np.savez(out_path, arr) 91 | 92 | dist.barrier() 93 | logger.log("sampling complete") 94 | 95 | 96 | def create_argparser(): 97 | defaults = dict( 98 | clip_denoised=True, 99 | num_samples=10000, 100 | batch_size=16, 101 | use_ddim=False, 102 | model_path="", 103 | ) 104 | defaults.update(model_and_diffusion_defaults()) 105 | parser = argparse.ArgumentParser() 106 | add_dict_to_argparser(parser, defaults) 107 | return parser 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 
3 | """ 4 | 5 | import argparse 6 | 7 | from guided_diffusion import dist_util, logger 8 | from guided_diffusion.image_datasets import load_data 9 | from guided_diffusion.resample import create_named_schedule_sampler 10 | from guided_diffusion.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from guided_diffusion.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist() 23 | logger.configure() 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 28 | ) 29 | model.to(dist_util.dev()) 30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 31 | 32 | logger.log("creating data loader...") 33 | data = load_data( 34 | data_dir=args.data_dir, 35 | batch_size=args.batch_size, 36 | image_size=args.image_size, 37 | class_cond=args.class_cond, 38 | ) 39 | 40 | logger.log("training...") 41 | TrainLoop( 42 | model=model, 43 | diffusion=diffusion, 44 | data=data, 45 | batch_size=args.batch_size, 46 | microbatch=args.microbatch, 47 | lr=args.lr, 48 | ema_rate=args.ema_rate, 49 | log_interval=args.log_interval, 50 | save_interval=args.save_interval, 51 | resume_checkpoint=args.resume_checkpoint, 52 | use_fp16=args.use_fp16, 53 | fp16_scale_growth=args.fp16_scale_growth, 54 | schedule_sampler=schedule_sampler, 55 | weight_decay=args.weight_decay, 56 | lr_anneal_steps=args.lr_anneal_steps, 57 | ).run_loop() 58 | 59 | 60 | def create_argparser(): 61 | defaults = dict( 62 | data_dir="", 63 | schedule_sampler="uniform", 64 | lr=1e-4, 65 | weight_decay=0.0, 66 | lr_anneal_steps=0, 67 | batch_size=1, 68 | microbatch=-1, # -1 disables microbatches 69 | ema_rate="0.9999", # comma-separated list of EMA values 70 | log_interval=10, 71 | save_interval=10000, 72 | resume_checkpoint="", 73 | use_fp16=False, 74 | fp16_scale_growth=1e-3, 75 | ) 76 | defaults.update(model_and_diffusion_defaults()) 77 | parser = argparse.ArgumentParser() 78 | add_dict_to_argparser(parser, defaults) 79 | return parser 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/super_res_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of samples from a super resolution model, given a batch 3 | of samples from a regular model from image_sample.py. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import blobfile as bf 10 | import numpy as np 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | from guided_diffusion import dist_util, logger 15 | from guided_diffusion.script_util import ( 16 | sr_model_and_diffusion_defaults, 17 | sr_create_model_and_diffusion, 18 | args_to_dict, 19 | add_dict_to_argparser, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model...") 30 | model, diffusion = sr_create_model_and_diffusion( 31 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | if args.use_fp16: 38 | model.convert_to_fp16() 39 | model.eval() 40 | 41 | logger.log("loading data...") 42 | data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond) 43 | 44 | logger.log("creating samples...") 45 | all_images = [] 46 | while len(all_images) * args.batch_size < args.num_samples: 47 | model_kwargs = next(data) 48 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 49 | sample = diffusion.p_sample_loop( 50 | model, 51 | (args.batch_size, 3, args.large_size, args.large_size), 52 | clip_denoised=args.clip_denoised, 53 | model_kwargs=model_kwargs, 54 | ) 55 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 56 | sample = sample.permute(0, 2, 3, 1) 57 | sample = sample.contiguous() 58 | 59 | all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 60 | dist.all_gather(all_samples, sample) # gather not supported with NCCL 61 | for sample in all_samples: 62 | all_images.append(sample.cpu().numpy()) 63 | logger.log(f"created {len(all_images) * args.batch_size} samples") 64 | 65 | arr = np.concatenate(all_images, axis=0) 66 | arr = arr[: args.num_samples] 67 | if dist.get_rank() == 0: 68 | shape_str = "x".join([str(x) for x in arr.shape]) 69 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 70 | logger.log(f"saving to {out_path}") 71 | np.savez(out_path, arr) 72 | 73 | dist.barrier() 74 | logger.log("sampling complete") 75 | 76 | 77 | def load_data_for_worker(base_samples, batch_size, class_cond): 78 | with bf.BlobFile(base_samples, "rb") as f: 79 | obj = np.load(f) 80 | image_arr = obj["arr_0"] 81 | if class_cond: 82 | label_arr = obj["arr_1"] 83 | rank = dist.get_rank() 84 | num_ranks = dist.get_world_size() 85 | buffer = [] 86 | label_buffer = [] 87 | while True: 88 | for i in range(rank, len(image_arr), num_ranks): 89 | buffer.append(image_arr[i]) 90 | if class_cond: 91 | label_buffer.append(label_arr[i]) 92 | if len(buffer) == batch_size: 93 | batch = th.from_numpy(np.stack(buffer)).float() 94 | batch = batch / 127.5 - 1.0 95 | batch = batch.permute(0, 3, 1, 2) 96 | res = dict(low_res=batch) 97 | if class_cond: 98 | res["y"] = th.from_numpy(np.stack(label_buffer)) 99 | yield res 100 | buffer, label_buffer = [], [] 101 | 102 | 103 | def create_argparser(): 104 | defaults = dict( 105 | clip_denoised=True, 106 | num_samples=10000, 107 | batch_size=16, 108 | use_ddim=False, 109 | base_samples="", 110 | model_path="", 111 | ) 112 | defaults.update(sr_model_and_diffusion_defaults()) 113 | parser = argparse.ArgumentParser() 114 | add_dict_to_argparser(parser, defaults) 115 | return parser 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | 
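The two sampling scripts above are meant to be chained: `image_sample.py` writes a `samples_<shape>.npz` file whose `arr_0` holds the base images (and `arr_1` the class labels when class conditioning is on), and `super_res_sample.py` reads that file back through `--base_samples`, feeding each batch to the upsampler as the `low_res` conditioning input via `load_data_for_worker`. A minimal sketch of the two-stage flow follows; the checkpoint paths are placeholders, and `$MODEL_FLAGS`/`$SR_FLAGS` stand in for whatever model and diffusion options the respective checkpoints were trained with.

```bash
# Stage 1: draw base samples; writes samples_<shape>.npz into the logging directory.
python scripts/image_sample.py \
    --model_path /path/to/base_model.pt \
    --num_samples 10000 --batch_size 16 $MODEL_FLAGS

# Stage 2: upsample the saved batch. load_data_for_worker() reads arr_0
# (and arr_1 for class-conditional models) and yields it as "low_res".
python scripts/super_res_sample.py \
    --base_samples /path/to/samples_10000x64x64x3.npz \
    --model_path /path/to/upsampler.pt \
    --num_samples 10000 --batch_size 16 $SR_FLAGS
```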
-------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/scripts/super_res_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a super-resolution model. 3 | """ 4 | 5 | import argparse 6 | 7 | import torch.nn.functional as F 8 | 9 | from guided_diffusion import dist_util, logger 10 | from guided_diffusion.image_datasets import load_data 11 | from guided_diffusion.resample import create_named_schedule_sampler 12 | from guided_diffusion.script_util import ( 13 | sr_model_and_diffusion_defaults, 14 | sr_create_model_and_diffusion, 15 | args_to_dict, 16 | add_dict_to_argparser, 17 | ) 18 | from guided_diffusion.train_util import TrainLoop 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model...") 28 | model, diffusion = sr_create_model_and_diffusion( 29 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_superres_data( 36 | args.data_dir, 37 | args.batch_size, 38 | large_size=args.large_size, 39 | small_size=args.small_size, 40 | class_cond=args.class_cond, 41 | ) 42 | 43 | logger.log("training...") 44 | TrainLoop( 45 | model=model, 46 | diffusion=diffusion, 47 | data=data, 48 | batch_size=args.batch_size, 49 | microbatch=args.microbatch, 50 | lr=args.lr, 51 | ema_rate=args.ema_rate, 52 | log_interval=args.log_interval, 53 | save_interval=args.save_interval, 54 | resume_checkpoint=args.resume_checkpoint, 55 | use_fp16=args.use_fp16, 56 | fp16_scale_growth=args.fp16_scale_growth, 57 | schedule_sampler=schedule_sampler, 58 | weight_decay=args.weight_decay, 59 | lr_anneal_steps=args.lr_anneal_steps, 60 | ).run_loop() 61 | 62 | 63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False): 64 | data = load_data( 65 | data_dir=data_dir, 66 | batch_size=batch_size, 67 | image_size=large_size, 68 | class_cond=class_cond, 69 | ) 70 | for large_batch, model_kwargs in data: 71 | model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area") 72 | yield large_batch, model_kwargs 73 | 74 | 75 | def create_argparser(): 76 | defaults = dict( 77 | data_dir="", 78 | schedule_sampler="uniform", 79 | lr=1e-4, 80 | weight_decay=0.0, 81 | lr_anneal_steps=0, 82 | batch_size=1, 83 | microbatch=-1, 84 | ema_rate="0.9999", 85 | log_interval=10, 86 | save_interval=10000, 87 | resume_checkpoint="", 88 | use_fp16=False, 89 | fp16_scale_growth=1e-3, 90 | ) 91 | defaults.update(sr_model_and_diffusion_defaults()) 92 | parser = argparse.ArgumentParser() 93 | add_dict_to_argparser(parser, defaults) 94 | return parser 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /PTQ4DM/guided-diffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="guided-diffusion", 5 | py_modules=["guided_diffusion"], 6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"], 7 | ) 8 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | 
Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/README.md: -------------------------------------------------------------------------------- 1 | # improved-diffusion 2 | 3 | This is the codebase for [Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672). 4 | 5 | # Usage 6 | 7 | This section of the README walks through how to train and sample from a model. 8 | 9 | ## Installation 10 | 11 | Clone this repository and navigate to it in your terminal. Then run: 12 | 13 | ``` 14 | pip install -e . 15 | ``` 16 | 17 | This should install the `improved_diffusion` python package that the scripts depend on. 18 | 19 | ## Preparing Data 20 | 21 | The training code reads images from a directory of image files. In the [datasets](datasets) folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10. 22 | 23 | For creating your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names). 24 | 25 | The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass `--data_dir path/to/images` to the training script, and it will take care of the rest. 26 | 27 | ## Training 28 | 29 | To train your model, you should first decide some hyperparameters. We will split up our hyperparameters into three groups: model architecture, diffusion process, and training flags. 
Here are some reasonable defaults for a baseline: 30 | 31 | ``` 32 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3" 33 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear" 34 | TRAIN_FLAGS="--lr 1e-4 --batch_size 128" 35 | ``` 36 | 37 | Here are some changes we experiment with, and how to set them in the flags: 38 | 39 | * **Learned sigmas:** add `--learn_sigma True` to `MODEL_FLAGS` 40 | * **Cosine schedule:** change `--noise_schedule linear` to `--noise_schedule cosine` 41 | * **Importance-sampled VLB:** add `--use_kl True` to `DIFFUSION_FLAGS` and add `--schedule_sampler loss-second-moment` to `TRAIN_FLAGS`. 42 | * **Class-conditional:** add `--class_cond True` to `MODEL_FLAGS`. 43 | 44 | Once you have setup your hyper-parameters, you can run an experiment like so: 45 | 46 | ``` 47 | python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS 48 | ``` 49 | 50 | You may also want to train in a distributed manner. In this case, run the same command with `mpiexec`: 51 | 52 | ``` 53 | mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS 54 | ``` 55 | 56 | When training in a distributed manner, you must manually divide the `--batch_size` argument by the number of ranks. In lieu of distributed training, you may use `--microbatch 16` (or `--microbatch 1` in extreme memory-limited cases) to reduce memory usage. 57 | 58 | The logs and saved models will be written to a logging directory determined by the `OPENAI_LOGDIR` environment variable. If it is not set, then a temporary directory will be created in `/tmp`. 59 | 60 | ## Sampling 61 | 62 | The above training script saves checkpoints to `.pt` files in the logging directory. These checkpoints will have names like `ema_0.9999_200000.pt` and `model200000.pt`. You will likely want to sample from the EMA models, since those produce much better samples. 63 | 64 | Once you have a path to your model, you can generate a large batch of samples like so: 65 | 66 | ``` 67 | python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS 68 | ``` 69 | 70 | Again, this will save results to a logging directory. Samples are saved as a large `npz` file, where `arr_0` in the file is a large batch of samples. 71 | 72 | Just like for training, you can run `image_sample.py` through MPI to use multiple GPUs and machines. 73 | 74 | You can change the number of sampling steps using the `--timestep_respacing` argument. For example, `--timestep_respacing 250` uses 250 steps to sample. Passing `--timestep_respacing ddim250` is similar, but uses the uniform stride from the [DDIM paper](https://arxiv.org/abs/2010.02502) rather than our stride. 75 | 76 | To sample using [DDIM](https://arxiv.org/abs/2010.02502), pass `--use_ddim True`. 77 | 78 | ## Models and Hyperparameters 79 | 80 | This section includes model checkpoints and run flags for the main models in the paper. 81 | 82 | Note that the batch sizes are specified for single-GPU training, even though most of these runs will not naturally fit on a single GPU. To address this, either set `--microbatch` to a small value (e.g. 4) to train on one GPU, or run with MPI and divide `--batch_size` by the number of GPUs. 
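As a concrete example of this batch-size bookkeeping, the sketch below (paths are placeholders; it assumes 8 GPUs and the baseline flags from the Training section) trains with the global batch of 128 split as `--batch_size 16` per rank, then samples from the resulting EMA checkpoint with 250 DDIM steps. The checkpoint name is illustrative and depends on how long the run was trained.

```bash
NUM_GPUS=8
MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear"

# Global batch of 128 across 8 ranks: pass 128 / 8 = 16 per rank.
mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images \
    $MODEL_FLAGS $DIFFUSION_FLAGS --lr 1e-4 --batch_size 16

# Sample from the EMA weights using the uniform DDIM stride with 250 steps.
mpiexec -n $NUM_GPUS python scripts/image_sample.py \
    --model_path $OPENAI_LOGDIR/ema_0.9999_200000.pt \
    --timestep_respacing ddim250 --use_ddim True $MODEL_FLAGS $DIFFUSION_FLAGS
```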
83 | 84 | Unconditional ImageNet-64 with our `L_hybrid` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_uncond_100M_1500K.pt)]: 85 | 86 | ```bash 87 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 88 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" 89 | TRAIN_FLAGS="--lr 1e-4 --batch_size 128" 90 | ``` 91 | 92 | Unconditional CIFAR-10 with our `L_hybrid` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/cifar10_uncond_50M_500K.pt)]: 93 | 94 | ```bash 95 | MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3" 96 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" 97 | TRAIN_FLAGS="--lr 1e-4 --batch_size 128" 98 | ``` 99 | 100 | Class-conditional ImageNet-64 model (270M parameters, trained for 250K iterations) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_cond_270M_250K.pt)]: 101 | 102 | ```bash 103 | MODEL_FLAGS="--image_size 64 --num_channels 192 --num_res_blocks 3 --learn_sigma True --class_cond True" 104 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine --rescale_learned_sigmas False --rescale_timesteps False" 105 | TRAIN_FLAGS="--lr 3e-4 --batch_size 2048" 106 | ``` 107 | 108 | Upsampling 256x256 model (280M parameters, trained for 500K iterations) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/upsample_cond_500K.pt)]: 109 | 110 | ```bash 111 | MODEL_FLAGS="--num_channels 192 --num_res_blocks 2 --learn_sigma True --class_cond True" 112 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False" 113 | TRAIN_FLAGS="--lr 3e-4 --batch_size 256" 114 | ``` 115 | 116 | LSUN bedroom model (lr=1e-4) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_1200K_bs128.pt)]: 117 | 118 | ```bash 119 | MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16" 120 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False" 121 | TRAIN_FLAGS="--lr 1e-4 --batch_size 128" 122 | ``` 123 | 124 | LSUN bedroom model (lr=2e-5) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_2400K_bs64.pt)]: 125 | 126 | ```bash 127 | MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16" 128 | DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --use_scale_shift_norm False" 129 | TRAIN_FLAGS="--lr 2e-5 --batch_size 128" 130 | ``` 131 | 132 | Unconditional ImageNet-64 with the `L_vlb` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_uncond_vlb_100M_1500K.pt)]: 133 | 134 | ```bash 135 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 136 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine --use_kl True" 137 | TRAIN_FLAGS="--lr 1e-4 --batch_size 128 --schedule_sampler loss-second-moment" 138 | ``` 139 | 140 | Unconditional CIFAR-10 with the `L_vlb` objective and cosine noise 
schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/cifar10_uncond_vlb_50M_500K.pt)]: 141 | 142 | ```bash 143 | MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3" 144 | DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine --use_kl True" 145 | TRAIN_FLAGS="--lr 1e-4 --batch_size 128 --schedule_sampler loss-second-moment" 146 | ``` 147 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Downloading datasets 2 | 3 | This directory includes instructions and scripts for downloading ImageNet, LSUN bedrooms, and CIFAR-10 for use in this codebase. 4 | 5 | ## ImageNet-64 6 | 7 | To download unconditional ImageNet-64, go to [this page on image-net.org](http://www.image-net.org/small/download.php) and click on "Train (64x64)". Simply download the file and unzip it, and use the resulting directory as the data directory (the `--data_dir` argument for the training script). 8 | 9 | ## Class-conditional ImageNet 10 | 11 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. 12 | 13 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: 14 | 15 | ``` 16 | for file in *.tar; do tar xf "$file"; rm "$file"; done 17 | ``` 18 | 19 | This will extract and remove each tar file in turn. 20 | 21 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels. 22 | 23 | ## CIFAR-10 24 | 25 | For CIFAR-10, we created a script [cifar10.py](cifar10.py) that creates `cifar_train` and `cifar_test` directories. These directories contain files named like `truck_49997.png`, so that the class name is discernable to the data loader. 26 | 27 | The `cifar_train` and `cifar_test` directories can be passed directly to the training scripts via the `--data_dir` argument. 28 | 29 | ## LSUN bedroom 30 | 31 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: 32 | 33 | ``` 34 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir 35 | ``` 36 | 37 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. 
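Putting the CIFAR-10 pieces together, the sketch below (not part of the repo; flags copied from the unconditional CIFAR-10 settings in the main README, with `--class_cond True` added) prepares the `cifar_train` directory and starts a class-conditional run. No label file is needed because the data loader recovers the ten classes from the `truck_`, `plane_`, ... filename prefixes.

```bash
# One-time prep: writes cifar_train/ and cifar_test/ with files named like truck_49997.png.
python datasets/cifar10.py

# Class-conditional training on the prepared directory.
MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3 --class_cond True"
DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine"
TRAIN_FLAGS="--lr 1e-4 --batch_size 128"
python scripts/image_train.py --data_dir cifar_train $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```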
38 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import torchvision 5 | from tqdm.auto import tqdm 6 | 7 | CLASSES = ( 8 | "plane", 9 | "car", 10 | "bird", 11 | "cat", 12 | "deer", 13 | "dog", 14 | "frog", 15 | "horse", 16 | "ship", 17 | "truck", 18 | ) 19 | 20 | 21 | def main(): 22 | for split in ["train", "test"]: 23 | out_dir = f"cifar_{split}" 24 | if os.path.exists(out_dir): 25 | print(f"skipping split {split} since {out_dir} already exists.") 26 | continue 27 | 28 | print("downloading...") 29 | with tempfile.TemporaryDirectory() as tmp_dir: 30 | dataset = torchvision.datasets.CIFAR10( 31 | root=tmp_dir, train=split == "train", download=True 32 | ) 33 | 34 | print("dumping images...") 35 | os.mkdir(out_dir) 36 | for i in tqdm(range(len(dataset))): 37 | image, label = dataset[i] 38 | filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png") 39 | image.save(filename) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/datasets/lsun_bedroom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert an LSUN lmdb database into a directory of images. 3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Codebase for "Improved Denoising Diffusion Probabilistic Models". 
3 | """ 4 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | 28 | comm = MPI.COMM_WORLD 29 | backend = "gloo" if not th.cuda.is_available() else "nccl" 30 | 31 | if backend == "gloo": 32 | hostname = "localhost" 33 | else: 34 | hostname = socket.gethostbyname(socket.getfqdn()) 35 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 36 | os.environ["RANK"] = str(comm.rank) 37 | os.environ["WORLD_SIZE"] = str(comm.size) 38 | 39 | port = comm.bcast(_find_free_port(), root=0) 40 | os.environ["MASTER_PORT"] = str(port) 41 | dist.init_process_group(backend=backend, init_method="env://") 42 | 43 | 44 | def dev(): 45 | """ 46 | Get the device to use for torch.distributed. 47 | """ 48 | if th.cuda.is_available(): 49 | return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") 50 | return th.device("cpu") 51 | 52 | 53 | def load_state_dict(path, **kwargs): 54 | """ 55 | Load a PyTorch file without redundant fetches across MPI ranks. 56 | """ 57 | if MPI.COMM_WORLD.Get_rank() == 0: 58 | with bf.BlobFile(path, "rb") as f: 59 | data = f.read() 60 | else: 61 | data = None 62 | data = MPI.COMM_WORLD.bcast(data) 63 | return th.load(io.BytesIO(data), **kwargs) 64 | 65 | 66 | def sync_params(params): 67 | """ 68 | Synchronize a sequence of Tensors across ranks from rank 0. 69 | """ 70 | for p in params: 71 | with th.no_grad(): 72 | dist.broadcast(p, 0) 73 | 74 | 75 | def _find_free_port(): 76 | try: 77 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 78 | s.bind(("", 0)) 79 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 80 | return s.getsockname()[1] 81 | finally: 82 | s.close() 83 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 
31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 74 | if param.grad is not None: 75 | param.grad.detach_() 76 | param.grad.zero_() 77 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import blobfile as bf 3 | from mpi4py import MPI 4 | import numpy as np 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | def load_data( 9 | *, data_dir, batch_size, image_size, class_cond=False, deterministic=False 10 | ): 11 | """ 12 | For a dataset, create a generator over (images, kwargs) pairs. 13 | 14 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 15 | more keys, each of which map to a batched Tensor of their own. 16 | The kwargs dict can be used for class labels, in which case the key is "y" 17 | and the values are integer tensors of class labels. 18 | 19 | :param data_dir: a dataset directory. 20 | :param batch_size: the batch size of each returned pair. 21 | :param image_size: the size to which images are resized. 22 | :param class_cond: if True, include a "y" key in returned dicts for class 23 | label. If classes are not available and this is true, an 24 | exception will be raised. 25 | :param deterministic: if True, yield results in a deterministic order. 26 | """ 27 | if not data_dir: 28 | raise ValueError("unspecified data directory") 29 | all_files = _list_image_files_recursively(data_dir) 30 | classes = None 31 | if class_cond: 32 | # Assume classes are the first part of the filename, 33 | # before an underscore. 
34 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 35 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 36 | classes = [sorted_classes[x] for x in class_names] 37 | dataset = ImageDataset( 38 | image_size, 39 | all_files, 40 | classes=classes, 41 | shard=MPI.COMM_WORLD.Get_rank(), 42 | num_shards=MPI.COMM_WORLD.Get_size(), 43 | ) 44 | if deterministic: 45 | loader = DataLoader( 46 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 47 | ) 48 | else: 49 | loader = DataLoader( 50 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 51 | ) 52 | while True: 53 | yield from loader 54 | 55 | 56 | def _list_image_files_recursively(data_dir): 57 | results = [] 58 | for entry in sorted(bf.listdir(data_dir)): 59 | full_path = bf.join(data_dir, entry) 60 | ext = entry.split(".")[-1] 61 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 62 | results.append(full_path) 63 | elif bf.isdir(full_path): 64 | results.extend(_list_image_files_recursively(full_path)) 65 | return results 66 | 67 | 68 | class ImageDataset(Dataset): 69 | def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1): 70 | super().__init__() 71 | self.resolution = resolution 72 | self.local_images = image_paths[shard:][::num_shards] 73 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 74 | 75 | def __len__(self): 76 | return len(self.local_images) 77 | 78 | def __getitem__(self, idx): 79 | path = self.local_images[idx] 80 | with bf.BlobFile(path, "rb") as f: 81 | pil_image = Image.open(f) 82 | pil_image.load() 83 | 84 | # We are not on a new enough PIL to support the `reducing_gap` 85 | # argument, which uses BOX downsampling at powers of two first. 86 | # Thus, we do it by hand to improve downsample quality. 87 | while min(*pil_image.size) >= 2 * self.resolution: 88 | pil_image = pil_image.resize( 89 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 90 | ) 91 | 92 | scale = self.resolution / min(*pil_image.size) 93 | pil_image = pil_image.resize( 94 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 95 | ) 96 | 97 | arr = np.array(pil_image.convert("RGB")) 98 | crop_y = (arr.shape[0] - self.resolution) // 2 99 | crop_x = (arr.shape[1] - self.resolution) // 2 100 | arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution] 101 | arr = arr.astype(np.float32) / 127.5 - 1 102 | 103 | out_dict = {} 104 | if self.local_classes is not None: 105 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 106 | return np.transpose(arr, [2, 0, 1]), out_dict 107 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 
18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 
45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 
134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 
51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 
121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/improved_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def _wrap_model(self, model): 99 | if isinstance(model, _WrappedModel): 100 | return model 101 | return _WrappedModel( 102 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 103 | ) 104 | 105 | def _scale_timesteps(self, t): 106 | # Scaling is done by the wrapped model. 
107 | return t 108 | 109 | 110 | class _WrappedModel: 111 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 112 | self.model = model 113 | self.timestep_map = timestep_map 114 | self.rescale_timesteps = rescale_timesteps 115 | self.original_num_steps = original_num_steps 116 | 117 | def __call__(self, x, ts, **kwargs): 118 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 119 | new_ts = map_tensor[ts] 120 | if self.rescale_timesteps: 121 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 122 | return self.model(x, new_ts, **kwargs) 123 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/scripts/image_nll.py: -------------------------------------------------------------------------------- 1 | """ 2 | Approximate the bits/dimension for an image model. 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | import numpy as np 9 | import torch.distributed as dist 10 | 11 | from improved_diffusion import dist_util, logger 12 | from improved_diffusion.image_datasets import load_data 13 | from improved_diffusion.script_util import ( 14 | model_and_diffusion_defaults, 15 | create_model_and_diffusion, 16 | add_dict_to_argparser, 17 | args_to_dict, 18 | ) 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model and diffusion...") 28 | model, diffusion = create_model_and_diffusion( 29 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 30 | ) 31 | model.load_state_dict( 32 | dist_util.load_state_dict(args.model_path, map_location="cpu") 33 | ) 34 | model.to(dist_util.dev()) 35 | model.eval() 36 | 37 | logger.log("creating data loader...") 38 | data = load_data( 39 | data_dir=args.data_dir, 40 | batch_size=args.batch_size, 41 | image_size=args.image_size, 42 | class_cond=args.class_cond, 43 | deterministic=True, 44 | ) 45 | 46 | logger.log("evaluating...") 47 | run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised) 48 | 49 | 50 | def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): 51 | all_bpd = [] 52 | all_metrics = {"vb": [], "mse": [], "xstart_mse": []} 53 | num_complete = 0 54 | while num_complete < num_samples: 55 | batch, model_kwargs = next(data) 56 | batch = batch.to(dist_util.dev()) 57 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 58 | minibatch_metrics = diffusion.calc_bpd_loop( 59 | model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs 60 | ) 61 | 62 | for key, term_list in all_metrics.items(): 63 | terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size() 64 | dist.all_reduce(terms) 65 | term_list.append(terms.detach().cpu().numpy()) 66 | 67 | total_bpd = minibatch_metrics["total_bpd"] 68 | total_bpd = total_bpd.mean() / dist.get_world_size() 69 | dist.all_reduce(total_bpd) 70 | all_bpd.append(total_bpd.item()) 71 | num_complete += dist.get_world_size() * batch.shape[0] 72 | 73 | logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}") 74 | 75 | if dist.get_rank() == 0: 76 | for name, terms in all_metrics.items(): 77 | out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz") 78 | logger.log(f"saving {name} terms to {out_path}") 79 | np.savez(out_path, np.mean(np.stack(terms), axis=0)) 80 | 81 | dist.barrier() 82 | logger.log("evaluation complete") 83 | 84 | 85 | def create_argparser(): 86 | defaults = dict( 87 | data_dir="", 
clip_denoised=True, num_samples=1000, batch_size=1, model_path="" 88 | ) 89 | defaults.update(model_and_diffusion_defaults()) 90 | parser = argparse.ArgumentParser() 91 | add_dict_to_argparser(parser, defaults) 92 | return parser 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/scripts/image_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of image samples from a model and save them as a large 3 | numpy array. This can be used to produce samples for FID evaluation. 4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | from improved_diffusion import dist_util, logger 14 | from improved_diffusion.script_util import ( 15 | NUM_CLASSES, 16 | model_and_diffusion_defaults, 17 | create_model_and_diffusion, 18 | add_dict_to_argparser, 19 | args_to_dict, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model and diffusion...") 30 | model, diffusion = create_model_and_diffusion( 31 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | model.eval() 38 | 39 | logger.log("sampling...") 40 | all_images = [] 41 | all_labels = [] 42 | while len(all_images) * args.batch_size < args.num_samples: 43 | model_kwargs = {} 44 | if args.class_cond: 45 | classes = th.randint( 46 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 47 | ) 48 | model_kwargs["y"] = classes 49 | sample_fn = ( 50 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 51 | ) 52 | sample = sample_fn( 53 | model, 54 | (args.batch_size, 3, args.image_size, args.image_size), 55 | clip_denoised=args.clip_denoised, 56 | model_kwargs=model_kwargs, 57 | ) 58 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 59 | sample = sample.permute(0, 2, 3, 1) 60 | sample = sample.contiguous() 61 | 62 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 63 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 64 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) 65 | if args.class_cond: 66 | gathered_labels = [ 67 | th.zeros_like(classes) for _ in range(dist.get_world_size()) 68 | ] 69 | dist.all_gather(gathered_labels, classes) 70 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) 71 | logger.log(f"created {len(all_images) * args.batch_size} samples") 72 | 73 | arr = np.concatenate(all_images, axis=0) 74 | arr = arr[: args.num_samples] 75 | if args.class_cond: 76 | label_arr = np.concatenate(all_labels, axis=0) 77 | label_arr = label_arr[: args.num_samples] 78 | if dist.get_rank() == 0: 79 | _shape=arr.shape[1:] 80 | shape_str = "x".join([str(x) for x in _shape]) 81 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 82 | logger.log(f"saving to {out_path}") 83 | if args.class_cond: 84 | np.savez(out_path, arr, label_arr) 85 | else: 86 | np.savez(out_path, arr) 87 | 88 | dist.barrier() 89 | logger.log("sampling complete") 90 | 91 | 92 | def create_argparser(): 93 | defaults = dict( 94 | clip_denoised=True, 95 | num_samples=10000, 96 | 
batch_size=16, 97 | use_ddim=False, 98 | model_path="", 99 | ) 100 | defaults.update(model_and_diffusion_defaults()) 101 | parser = argparse.ArgumentParser() 102 | add_dict_to_argparser(parser, defaults) 103 | return parser 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/scripts/image_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a diffusion model on images. 3 | """ 4 | 5 | import argparse 6 | 7 | from improved_diffusion import dist_util, logger 8 | from improved_diffusion.image_datasets import load_data 9 | from improved_diffusion.resample import create_named_schedule_sampler 10 | from improved_diffusion.script_util import ( 11 | model_and_diffusion_defaults, 12 | create_model_and_diffusion, 13 | args_to_dict, 14 | add_dict_to_argparser, 15 | ) 16 | from improved_diffusion.train_util import TrainLoop 17 | 18 | 19 | def main(): 20 | args = create_argparser().parse_args() 21 | 22 | dist_util.setup_dist() 23 | logger.configure() 24 | 25 | logger.log("creating model and diffusion...") 26 | model, diffusion = create_model_and_diffusion( 27 | **args_to_dict(args, model_and_diffusion_defaults().keys()) 28 | ) 29 | model.to(dist_util.dev()) 30 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 31 | 32 | logger.log("creating data loader...") 33 | data = load_data( 34 | data_dir=args.data_dir, 35 | batch_size=args.batch_size, 36 | image_size=args.image_size, 37 | class_cond=args.class_cond, 38 | ) 39 | 40 | logger.log("training...") 41 | TrainLoop( 42 | model=model, 43 | diffusion=diffusion, 44 | data=data, 45 | batch_size=args.batch_size, 46 | microbatch=args.microbatch, 47 | lr=args.lr, 48 | ema_rate=args.ema_rate, 49 | log_interval=args.log_interval, 50 | save_interval=args.save_interval, 51 | resume_checkpoint=args.resume_checkpoint, 52 | use_fp16=args.use_fp16, 53 | fp16_scale_growth=args.fp16_scale_growth, 54 | schedule_sampler=schedule_sampler, 55 | weight_decay=args.weight_decay, 56 | lr_anneal_steps=args.lr_anneal_steps, 57 | ).run_loop() 58 | 59 | 60 | def create_argparser(): 61 | defaults = dict( 62 | data_dir="", 63 | schedule_sampler="uniform", 64 | lr=1e-4, 65 | weight_decay=0.0, 66 | lr_anneal_steps=0, 67 | batch_size=1, 68 | microbatch=-1, # -1 disables microbatches 69 | ema_rate="0.9999", # comma-separated list of EMA values 70 | log_interval=10, 71 | save_interval=10000, 72 | resume_checkpoint="", 73 | use_fp16=False, 74 | fp16_scale_growth=1e-3, 75 | ) 76 | defaults.update(model_and_diffusion_defaults()) 77 | parser = argparse.ArgumentParser() 78 | add_dict_to_argparser(parser, defaults) 79 | return parser 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/scripts/super_res_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate a large batch of samples from a super resolution model, given a batch 3 | of samples from a regular model from image_sample.py. 
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import blobfile as bf 10 | import numpy as np 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | from improved_diffusion import dist_util, logger 15 | from improved_diffusion.script_util import ( 16 | sr_model_and_diffusion_defaults, 17 | sr_create_model_and_diffusion, 18 | args_to_dict, 19 | add_dict_to_argparser, 20 | ) 21 | 22 | 23 | def main(): 24 | args = create_argparser().parse_args() 25 | 26 | dist_util.setup_dist() 27 | logger.configure() 28 | 29 | logger.log("creating model...") 30 | model, diffusion = sr_create_model_and_diffusion( 31 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 32 | ) 33 | model.load_state_dict( 34 | dist_util.load_state_dict(args.model_path, map_location="cpu") 35 | ) 36 | model.to(dist_util.dev()) 37 | model.eval() 38 | 39 | logger.log("loading data...") 40 | data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond) 41 | 42 | logger.log("creating samples...") 43 | all_images = [] 44 | while len(all_images) * args.batch_size < args.num_samples: 45 | model_kwargs = next(data) 46 | model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} 47 | sample = diffusion.p_sample_loop( 48 | model, 49 | (args.batch_size, 3, args.large_size, args.large_size), 50 | clip_denoised=args.clip_denoised, 51 | model_kwargs=model_kwargs, 52 | ) 53 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) 54 | sample = sample.permute(0, 2, 3, 1) 55 | sample = sample.contiguous() 56 | 57 | all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 58 | dist.all_gather(all_samples, sample) # gather not supported with NCCL 59 | for sample in all_samples: 60 | all_images.append(sample.cpu().numpy()) 61 | logger.log(f"created {len(all_images) * args.batch_size} samples") 62 | 63 | arr = np.concatenate(all_images, axis=0) 64 | arr = arr[: args.num_samples] 65 | if dist.get_rank() == 0: 66 | shape_str = "x".join([str(x) for x in arr.shape]) 67 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") 68 | logger.log(f"saving to {out_path}") 69 | np.savez(out_path, arr) 70 | 71 | dist.barrier() 72 | logger.log("sampling complete") 73 | 74 | 75 | def load_data_for_worker(base_samples, batch_size, class_cond): 76 | with bf.BlobFile(base_samples, "rb") as f: 77 | obj = np.load(f) 78 | image_arr = obj["arr_0"] 79 | if class_cond: 80 | label_arr = obj["arr_1"] 81 | rank = dist.get_rank() 82 | num_ranks = dist.get_world_size() 83 | buffer = [] 84 | label_buffer = [] 85 | while True: 86 | for i in range(rank, len(image_arr), num_ranks): 87 | buffer.append(image_arr[i]) 88 | if class_cond: 89 | label_buffer.append(label_arr[i]) 90 | if len(buffer) == batch_size: 91 | batch = th.from_numpy(np.stack(buffer)).float() 92 | batch = batch / 127.5 - 1.0 93 | batch = batch.permute(0, 3, 1, 2) 94 | res = dict(low_res=batch) 95 | if class_cond: 96 | res["y"] = th.from_numpy(np.stack(label_buffer)) 97 | yield res 98 | buffer, label_buffer = [], [] 99 | 100 | 101 | def create_argparser(): 102 | defaults = dict( 103 | clip_denoised=True, 104 | num_samples=10000, 105 | batch_size=16, 106 | use_ddim=False, 107 | base_samples="", 108 | model_path="", 109 | ) 110 | defaults.update(sr_model_and_diffusion_defaults()) 111 | parser = argparse.ArgumentParser() 112 | add_dict_to_argparser(parser, defaults) 113 | return parser 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | 
-------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/scripts/super_res_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a super-resolution model. 3 | """ 4 | 5 | import argparse 6 | 7 | import torch.nn.functional as F 8 | 9 | from improved_diffusion import dist_util, logger 10 | from improved_diffusion.image_datasets import load_data 11 | from improved_diffusion.resample import create_named_schedule_sampler 12 | from improved_diffusion.script_util import ( 13 | sr_model_and_diffusion_defaults, 14 | sr_create_model_and_diffusion, 15 | args_to_dict, 16 | add_dict_to_argparser, 17 | ) 18 | from improved_diffusion.train_util import TrainLoop 19 | 20 | 21 | def main(): 22 | args = create_argparser().parse_args() 23 | 24 | dist_util.setup_dist() 25 | logger.configure() 26 | 27 | logger.log("creating model...") 28 | model, diffusion = sr_create_model_and_diffusion( 29 | **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) 30 | ) 31 | model.to(dist_util.dev()) 32 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 33 | 34 | logger.log("creating data loader...") 35 | data = load_superres_data( 36 | args.data_dir, 37 | args.batch_size, 38 | large_size=args.large_size, 39 | small_size=args.small_size, 40 | class_cond=args.class_cond, 41 | ) 42 | 43 | logger.log("training...") 44 | TrainLoop( 45 | model=model, 46 | diffusion=diffusion, 47 | data=data, 48 | batch_size=args.batch_size, 49 | microbatch=args.microbatch, 50 | lr=args.lr, 51 | ema_rate=args.ema_rate, 52 | log_interval=args.log_interval, 53 | save_interval=args.save_interval, 54 | resume_checkpoint=args.resume_checkpoint, 55 | use_fp16=args.use_fp16, 56 | fp16_scale_growth=args.fp16_scale_growth, 57 | schedule_sampler=schedule_sampler, 58 | weight_decay=args.weight_decay, 59 | lr_anneal_steps=args.lr_anneal_steps, 60 | ).run_loop() 61 | 62 | 63 | def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False): 64 | data = load_data( 65 | data_dir=data_dir, 66 | batch_size=batch_size, 67 | image_size=large_size, 68 | class_cond=class_cond, 69 | ) 70 | for large_batch, model_kwargs in data: 71 | model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area") 72 | yield large_batch, model_kwargs 73 | 74 | 75 | def create_argparser(): 76 | defaults = dict( 77 | data_dir="", 78 | schedule_sampler="uniform", 79 | lr=1e-4, 80 | weight_decay=0.0, 81 | lr_anneal_steps=0, 82 | batch_size=1, 83 | microbatch=-1, 84 | ema_rate="0.9999", 85 | log_interval=10, 86 | save_interval=10000, 87 | resume_checkpoint="", 88 | use_fp16=False, 89 | fp16_scale_growth=1e-3, 90 | ) 91 | defaults.update(sr_model_and_diffusion_defaults()) 92 | parser = argparse.ArgumentParser() 93 | add_dict_to_argparser(parser, defaults) 94 | return parser 95 | 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /PTQ4DM/improved-diffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="improved-diffusion", 5 | py_modules=["improved_diffusion"], 6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"], 7 | ) 8 | -------------------------------------------------------------------------------- /PTQ4DM/quant_sample.sh: -------------------------------------------------------------------------------- 1 | export 
PYTHONPATH=".:guided-diffusion" 2 | # MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 3 | # python scripts/image_sample.py --image_size 256 --num_channels 128 --num_res_blocks 3 \ 4 | # --learn_sigma True --dropout 0.3 --diffusion_steps 4000 --noise_schedule cosine \ 5 | # --batch_size 1000 --num_samples 50001 --timestep_respacing 100 --use_ddim False \ 6 | # --model_path "/home/yzh/docs/pytorch/PTQDiffusionModel/guided-diffusion/models/256x256_diffusion_uncond.pt" 7 | # python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS 8 | QUANT_FLAGS="--n_bits_w 8 --channel_wise --n_bits_a 8 --act_quant --order together --wwq --waq --awq --aaq \ 9 | --weight 0.01 --input_prob 0.5 --prob 0.5 --iters_w 100 --calib_num_samples 128" 10 | MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --dropout 0.1 --image_size 64 --learn_sigma True --noise_schedule cosine --num_channels 192 --num_head_channels 64 --num_res_blocks 3 --resblock_updown True --use_new_attention_order True --use_fp16 True --use_scale_shift_norm True" 11 | python guided-diffusion/scripts/quant_image_sample.py $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/64x64_diffusion.pt $SAMPLE_FLAGS --num_samples 10 --batch_size 16 12 | # python -m cProfile -o quant.pstats scripts/quant_image_sample.py $QUANT_FLAGS $MODEL_FLAGS --model_path models/64x64_diffusion.pt $SAMPLE_FLAGS --num_samples 100 --batch_size 16 13 | -------------------------------------------------------------------------------- /PTQ4DM/quant_sample_ddim_in_backward_DNTC.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=".:guided-diffusion:improved-diffusion" 2 | 3 | QUANT_FLAGS="--n_bits_w 8 --channel_wise --n_bits_a 8 --act_quant --order together --wwq --waq --awq --aaq \ 4 | --weight 0.01 --input_prob 0.5 --prob 0.5 --iters_w 100 --calib_num_samples 1024 \ 5 | --data_dir /datasets/imagenet --calib_im_mode noise_backward_t" 6 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 7 | DIFFUSION_FLAGS="--diffusion_steps 4000 --timestep_respacing 250 --use_ddim True --noise_schedule cosine" 8 | 9 | BATCH_SIZE=500 10 | NUM_SAMPLES=10000 11 | 12 | export CUDA_VISIBLE_DEVICES="7" 13 | CALIB_FLAGS="--calib_t_mode normal --calib_t_mode_normal_mean 0.4 --calib_t_mode_normal_std 0.4 --out_path /home/shangyuzhang/diffusion_models/PTQ4DM/results/random8-normalmean04std040_ddim250.npz" 14 | #python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path /pretrained-model-path/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 15 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path /home/shangyuzhang/diffusion_models/ptqdiffusionmodel/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 16 | -------------------------------------------------------------------------------- /PTQ4DM/quant_sample_ddim_in_forward.sh: -------------------------------------------------------------------------------- 1 | 2 | export 
PYTHONPATH=".:guided-diffusion:improved-diffusion" 3 | # MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 4 | # python scripts/image_sample.py --image_size 256 --num_channels 128 --num_res_blocks 3 \ 5 | # --learn_sigma True --dropout 0.3 --diffusion_steps 4000 --noise_schedule cosine \ 6 | # --batch_size 1000 --num_samples 50001 --timestep_respacing 100 --use_ddim False \ 7 | # --model_path "/home/yzh/docs/pytorch/PTQDiffusionModel/guided-diffusion/models/256x256_diffusion_uncond.pt" 8 | # python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS 9 | QUANT_FLAGS="--n_bits_w 8 --channel_wise --n_bits_a 8 --act_quant --order together --wwq --waq --awq --aaq \ 10 | --weight 0.01 --input_prob 0.5 --prob 0.5 --iters_w 100 --calib_num_samples 1024 \ 11 | --data_dir /datasets/imagenet --calib_im_mode raw_forward_t" 12 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 13 | DIFFUSION_FLAGS="--diffusion_steps 4000 --timestep_respacing 250 --use_ddim True --noise_schedule cosine" 14 | 15 | BATCH_SIZE=32 16 | NUM_SAMPLES=10000 17 | 18 | export CUDA_VISIBLE_DEVICES="2" 19 | CALIB_FLAGS="--calib_t_mode random" 20 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 21 | 22 | # export CUDA_VISIBLE_DEVICES="1" 23 | # CALIB_FLAGS="--calib_t_mode -1" 24 | # python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 25 | 26 | # export CUDA_VISIBLE_DEVICES="2" 27 | # CALIB_FLAGS="--calib_t_mode 1" 28 | # python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 29 | 30 | export CUDA_VISIBLE_DEVICES="3" 31 | CALIB_FLAGS="--calib_t_mode uniform" 32 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 33 | 34 | 35 | # export CUDA_VISIBLE_DEVICES="4" 36 | # CALIB_FLAGS="--calib_t_mode mean" 37 | # python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 38 | 39 | -------------------------------------------------------------------------------- /PTQ4DM/quant_sample_ddim_in_random.sh: -------------------------------------------------------------------------------- 1 | 2 | export PYTHONPATH=".:guided-diffusion:improved-diffusion" 3 | # MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown 
True --use_fp16 True --use_scale_shift_norm True" 4 | # python scripts/image_sample.py --image_size 256 --num_channels 128 --num_res_blocks 3 \ 5 | # --learn_sigma True --dropout 0.3 --diffusion_steps 4000 --noise_schedule cosine \ 6 | # --batch_size 1000 --num_samples 50001 --timestep_respacing 100 --use_ddim False \ 7 | # --model_path "/home/yzh/docs/pytorch/PTQDiffusionModel/guided-diffusion/models/256x256_diffusion_uncond.pt" 8 | # python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS 9 | QUANT_FLAGS="--n_bits_w 8 --channel_wise --n_bits_a 8 --act_quant --order together --wwq --waq --awq --aaq \ 10 | --weight 0.01 --input_prob 0.5 --prob 0.5 --iters_w 100 --calib_num_samples 128 \ 11 | --data_dir /datasets/imagenet --calib_im_mode random" 12 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 13 | DIFFUSION_FLAGS="--diffusion_steps 4000 --timestep_respacing 250 --use_ddim True --noise_schedule cosine" 14 | 15 | BATCH_SIZE=600 16 | NUM_SAMPLES=10000 17 | 18 | export CUDA_VISIBLE_DEVICES="1" 19 | CALIB_FLAGS="--calib_t_mode random --out_path outputs_mixup_quant_on_cosine/random8-5bit.npz" 20 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path /home/shangyuzhang/diffusion_models/ptqdiffusionmodel/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 21 | 22 | # export CUDA_VISIBLE_DEVICES="1" 23 | # CALIB_FLAGS="--calib_t_mode -1" 24 | # python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 25 | 26 | # export CUDA_VISIBLE_DEVICES="2" 27 | # CALIB_FLAGS="--calib_t_mode 1" 28 | # python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 29 | 30 | #export CUDA_VISIBLE_DEVICES="1" 31 | #CALIB_FLAGS="--calib_t_mode uniform" 32 | #python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 33 | 34 | 35 | # export CUDA_VISIBLE_DEVICES="4" 36 | # CALIB_FLAGS="--calib_t_mode mean" 37 | # python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 38 | 39 | -------------------------------------------------------------------------------- /PTQ4DM/quant_sample_ddim_in_raw.sh: -------------------------------------------------------------------------------- 1 | 2 | export PYTHONPATH=".:guided-diffusion:improved-diffusion" 3 | # MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" 4 | # python scripts/image_sample.py --image_size 256 --num_channels 128 --num_res_blocks 3 \ 5 | # --learn_sigma True --dropout 0.3 
--diffusion_steps 4000 --noise_schedule cosine \ 6 | # --batch_size 1000 --num_samples 50001 --timestep_respacing 100 --use_ddim False \ 7 | # --model_path "/home/yzh/docs/pytorch/PTQDiffusionModel/guided-diffusion/models/256x256_diffusion_uncond.pt" 8 | # python classifier_sample.py $MODEL_FLAGS --classifier_scale 10.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS 9 | QUANT_FLAGS="--n_bits_w 8 --channel_wise --n_bits_a 8 --act_quant --order together --wwq --waq --awq --aaq \ 10 | --weight 0.01 --input_prob 0.5 --prob 0.5 --iters_w 100 --calib_num_samples 128 \ 11 | --data_dir /datasets/imagenet --calib_im_mode raw" 12 | MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" 13 | DIFFUSION_FLAGS="--diffusion_steps 4000 --timestep_respacing 250 --use_ddim True --noise_schedule cosine" 14 | 15 | BATCH_SIZE=32 16 | NUM_SAMPLES=10000 17 | 18 | export CUDA_VISIBLE_DEVICES="0" 19 | CALIB_FLAGS="--calib_t_mode random" 20 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 21 | 22 | export CUDA_VISIBLE_DEVICES="1" 23 | CALIB_FLAGS="--calib_t_mode -1" 24 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 25 | 26 | export CUDA_VISIBLE_DEVICES="2" 27 | CALIB_FLAGS="--calib_t_mode 1" 28 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 29 | 30 | export CUDA_VISIBLE_DEVICES="3" 31 | CALIB_FLAGS="--calib_t_mode uniform" 32 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 33 | 34 | 35 | export CUDA_VISIBLE_DEVICES="4" 36 | CALIB_FLAGS="--calib_t_mode mean" 37 | python improved-diffusion/scripts/quant_image_sample.py $CALIB_FLAGS $QUANT_FLAGS $MODEL_FLAGS --model_path guided-diffusion/models/imagenet64_uncond_100M_1500K.pt $DIFFUSION_FLAGS --num_samples $NUM_SAMPLES --batch_size $BATCH_SIZE & 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PTQ4DM: Post-training Quantization on Diffusion Models 2 | [Yuzhang Shang*](https://42shawn.github.io/), [Zhihang Yuan*](http://hahnyuan.com/), Bin Xie, Bingzhe Wu, and Yan Yan 3 | (* denotes equal contribution) 4 | 5 | The code for Post-training Quantization on Diffusion Models (PTQ4DM), accepted to CVPR 2023. [paper](https://arxiv.org/abs/2211.15736) 6 | 7 | 8 | 9 | **_Key Observation_**: Study of the activation distributions w.r.t. time-step. **(Upper)** Per (output) channel activation ranges of the first depthwise-separable layer in the diffusion model at different time-steps. In the boxplot, the min and max values, the 2nd and 3rd quartiles, and the median are plotted for each channel.
We only include the layers in the decoder of the UNet for noise estimation, as the ranges of the encoder and decoder are quite different. **(Bottom)** Histograms of activations at different time-steps for various layers. We can observe that the distribution of activations changes dramatically with the time-step, which makes **traditional single-time-step PTQ calibration methods inapplicable for diffusion models**. 10 | 11 | ## Quick Start 12 | First, download our repo: 13 | ```bash 14 | git clone https://github.com/42Shawn/PTQ4DM.git 15 | cd PTQ4DM 16 | ``` 17 | Then, run the DNTC script: 18 | ```bash 19 | bash quant_sample_ddim_in_backward_DNTC.sh 20 | ``` 21 | 22 | **Demo Result** 23 | baseline (full-precision IDDPM) => 8-bit PTQ4DM 24 | FID 21.7 => 24.3 25 | 26 | # Reference 27 | If you find our code useful for your research, please cite our paper. 28 | ``` 29 | @inproceedings{ 30 | shang2023ptqdm, 31 | title={Post-training Quantization on Diffusion Models}, 32 | author={Yuzhang Shang and Zhihang Yuan and Bin Xie and Bingzhe Wu and Yan Yan}, 33 | booktitle={CVPR}, 34 | year={2023} 35 | } 36 | ``` 37 | 38 | **Related Work** 39 | Our repo is developed on top of the PyTorch implementations of Improved Diffusion ([IDDPM](https://github.com/openai/improved-diffusion), ICML 2021) and QDrop ([QDrop](https://github.com/wimh966/QDrop), ICLR 2022). Thanks to the authors for releasing their codebases! 40 | -------------------------------------------------------------------------------- /activation_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/42Shawn/PTQ4DM/180a4d15d400316e2971f54d10b96c53f8673455/activation_hist.png --------------------------------------------------------------------------------
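The DNTC script shipped with this repo (quant_sample_ddim_in_backward_DNTC.sh) passes `--calib_t_mode normal --calib_t_mode_normal_mean 0.4 --calib_t_mode_normal_std 0.4`, i.e. calibration time-steps are drawn from a normal distribution over the sampling schedule instead of a single fixed time-step. The sketch below is one plausible reading of those flags, not the repository's implementation; it assumes the mean and std are given as fractions of the total number of time-steps and clips samples to the valid range.

```python
# Hedged sketch (not the repository implementation): draw calibration time-steps
# from a normal distribution, assuming --calib_t_mode_normal_mean/std are
# interpreted as fractions of the total number of diffusion time-steps.
import numpy as np

def sample_calib_timesteps(num_timesteps, num_samples, mean_frac=0.4, std_frac=0.4, seed=0):
    rng = np.random.default_rng(seed)
    t = rng.normal(loc=mean_frac * num_timesteps, scale=std_frac * num_timesteps, size=num_samples)
    # Clip to the valid [0, num_timesteps - 1] range and round to integer indices.
    return np.clip(np.round(t), 0, num_timesteps - 1).astype(np.int64)

print(sample_calib_timesteps(num_timesteps=250, num_samples=8))  # e.g. 8 calibration steps out of 250
```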