├── README.md
├── full_model
│   ├── ckp_s_q1
│   │   └── readme.md
│   ├── ckp_s_q2
│   │   └── readme.md
│   ├── ckp_s_q3
│   │   └── readme.md
│   ├── datasets.py
│   ├── engine.py
│   ├── losses.py
│   ├── main.py
│   ├── model.py
│   ├── optim.py
│   ├── pretrain_s
│   │   └── readme.md
│   ├── test.sh
│   ├── train.sh
│   └── utils.py
└── pretrained_model
    ├── datasets.py
    ├── engine.py
    ├── hubconf.py
    ├── losses.py
    ├── main.py
    ├── model.py
    ├── pretrain_s
    │   └── readme.md
    ├── samplers.py
    ├── test.sh
    ├── train.sh
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
## Towards End-to-End Image Compression and Analysis with Transformers

Source code of our AAAI 2022 paper "Towards End-to-End Image Compression and Analysis with Transformers".

## Usage
The code has been run with `Python 3.7`, `PyTorch 1.8.1`, `timm 0.4.9` and `CompressAI 1.1.4`.

### Data preparation
Download and extract the ImageNet train and val images from http://image-net.org/.
The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/vision/stable/datasets.html?highlight=imagefolder#torchvision.datasets.ImageFolder), and the training and validation data are expected to be in the `train` and `val` folders, respectively:

```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```

### Pretrained model
The `./pretrained_model` directory provides the pretrained model without compression.
* Test

Please adjust `--data-path` and run `sh test.sh`:
```
python main.py --eval --resume ./pretrain_s/checkpoint.pth --model pretrained_model --data-path /path/to/imagenet/ --output_dir ./eval
```
The `./pretrain_s/checkpoint.pth` can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1RFXeKEzRn7mWk7ay0mQh_Q), with access code `aaai`.
* Train

Please adjust `--data-path` and run `sh train.sh`:
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model pretrained_model --no-model-ema --clip-grad 1.0 --batch-size 128 --num_workers 16 --data-path /path/to/imagenet/ --output_dir ./ckp_pretrain
```

### Full model
The `./full_model` directory provides the full model with compression.
* Test

Please adjust `--data-path` and `--resume`, then run `sh test.sh`:
```
python main.py --eval --resume ./ckp_s_q1/checkpoint.pth --model full_model --no-pretrained --data-path /path/to/imagenet/ --output_dir ./eval
```
The `./ckp_s_q1/checkpoint.pth`, `./ckp_s_q2/checkpoint.pth` and `./ckp_s_q3/checkpoint.pth` can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1RFXeKEzRn7mWk7ay0mQh_Q), with access code `aaai`.

* Train

Please download `./pretrain_s/checkpoint.pth` from [Baidu Netdisk](https://pan.baidu.com/s/1RFXeKEzRn7mWk7ay0mQh_Q) with access code `aaai`, then adjust `--data-path` and `--quality`. Each quality level selects the loss weights `alpha` and `beta`:

| quality | alpha | beta |
| :---: | :---: | :---: |
| 1 | 0.1 | 0.001 |
| 2 | 0.3 | 0.003 |
| 3 | 0.6 | 0.006 |

Run `sh train.sh`:
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model full_model --batch-size 128 --num_workers 16 --clip-grad 1.0 --quality 1 --data-path /path/to/imagenet/ --output_dir ./ckp_full
```

## Citation
```
@InProceedings{Bai2022AAAI,
  title={Towards End-to-End Image Compression and Analysis with Transformers},
  author={Bai, Yuanchao and Yang, Xu and Liu, Xianming and Jiang, Junjun and Wang, Yaowei and Ji, Xiangyang and Gao, Wen},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2022}
}
```

--------------------------------------------------------------------------------
/full_model/ckp_s_q1/readme.md:
--------------------------------------------------------------------------------
Place the corresponding checkpoint.pth in this folder.
--------------------------------------------------------------------------------
/full_model/ckp_s_q2/readme.md:
--------------------------------------------------------------------------------
Place the corresponding checkpoint.pth in this folder.
--------------------------------------------------------------------------------
/full_model/ckp_s_q3/readme.md:
--------------------------------------------------------------------------------
Place the corresponding checkpoint.pth in this folder.
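
Which checkpoint is the "corresponding" one is not spelled out in these folders. A minimal sketch, assuming that folder `ckp_s_qN` holds the model trained with `--quality N` (an inference from the README commands, not something the repository states), ties each quality level to its loss weights and expected checkpoint location. The helper name `checkpoint_for_quality` is hypothetical and not part of the repository.

```python
from pathlib import Path

# (alpha, beta) per quality level, copied from the README table.
ALPHA_BETA = {1: (0.1, 0.001), 2: (0.3, 0.003), 3: (0.6, 0.006)}


def checkpoint_for_quality(quality: int, root: str = ".") -> Path:
    """Return the expected checkpoint path for a quality level.

    Hypothetical helper: assumes folder ckp_s_q<N> belongs to --quality N.
    """
    if quality not in ALPHA_BETA:
        raise ValueError(f"quality must be one of {sorted(ALPHA_BETA)}, got {quality}")
    return Path(root) / f"ckp_s_q{quality}" / "checkpoint.pth"


if __name__ == "__main__":
    # Report, for every quality level, the weights and whether the file is in place.
    for q, (alpha, beta) in ALPHA_BETA.items():
        path = checkpoint_for_quality(q, root="./full_model")
        status = "found" if path.is_file() else "missing"
        print(f"quality={q} alpha={alpha} beta={beta} {path} [{status}]")
```

Running the script from the repository root lists all three levels and flags any checkpoint that has not been downloaded yet.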
--------------------------------------------------------------------------------
/full_model/datasets.py:
--------------------------------------------------------------------------------
import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:
        # Fail early instead of hitting an UnboundLocalError at the return below.
        raise ValueError(f"Unsupported dataset: {args.data_set}")

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=None,
            interpolation=args.train_interpolation,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t.
            # 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
--------------------------------------------------------------------------------
/full_model/engine.py:
--------------------------------------------------------------------------------
"""
Train and eval functions used in main.py
"""
import math
import sys
import os
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from losses import JointLoss, DenormalizedMSELoss
import utils


def train_one_epoch(model: torch.nn.Module, criterion: JointLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer, aux_optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0,
                    set_training_mode=True):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        outputs, aux_loss = model(samples)
        loss = criterion(samples, outputs, targets)
        loss_value = loss.item()
        loss.backward()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        if max_norm is not None and max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        aux_loss_value = aux_loss.item()
        aux_loss.backward()
        aux_optimizer.step()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.update(aux_loss=aux_loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, output_dir, write_img=True):
    criterion_cls = torch.nn.CrossEntropyLoss()
    criterion_rec = DenormalizedMSELoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        B, _, H, W = images.shape
        num_pixels = B * H * W

        # compute output
        output, _ = model(images)
        loss_cls = criterion_cls(output[0], target)
        loss_rec = criterion_rec(output[1], images)
        # estimated rate in bits per pixel: -sum(log2(likelihoods)) / num_pixels
        loss_bpp = (torch.log(output[2]).sum() + torch.log(output[3]).sum()) / (-math.log(2) * num_pixels)
        psnr = utils.img_distortion(output[1], images)

        # save the first four originals and reconstructions of the first batch only
        if write_img:
            utils.imwrite(images[:4], os.path.join(output_dir,'example_org.png'))
            utils.imwrite(output[1][:4], os.path.join(output_dir,'example_rec.png'))
            write_img = False

        acc1, acc5 = accuracy(output[0], target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss_cls=loss_cls.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['loss_rec'].update(loss_rec.item(), n=batch_size)
        metric_logger.meters['loss_bpp'].update(loss_bpp.item(), n=batch_size)
        metric_logger.meters['psnr'].update(psnr, n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss_cls {losses_cls.global_avg:.3f} loss_rec {losses_rec.global_avg:.3f} loss bpp {losses_bpp.global_avg:.3f} psnr {psnr.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses_cls=metric_logger.loss_cls, losses_rec=metric_logger.loss_rec, losses_bpp=metric_logger.loss_bpp, psnr=metric_logger.psnr))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate_real(data_loader, model, device, output_dir, write_img=True):
    criterion_cls = torch.nn.CrossEntropyLoss()
    criterion_rec = DenormalizedMSELoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    model.update(force=True)
    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        B, _, H, W = images.shape
        num_pixels = B * H * W

        # compute output from the actual bitstreams rather than from likelihoods
        out_enc = model.compress(images)
        output = model.decompress(out_enc['strings'], out_enc['shape'])
        loss_cls = criterion_cls(output[0], target)
        loss_rec = criterion_rec(output[1], images)
        psnr = utils.img_distortion(output[1], images)

        # coded size in bytes over the batch: both strings per image, plus 16 bytes counted per image
        bitstream = sum([len(out_enc['strings'][0][i]) + len(out_enc['strings'][1][i]) + 16 for i in range(B)])
        loss_bpp = bitstream * 8 / num_pixels

        if write_img:
            utils.imwrite(images[:4], os.path.join(output_dir,'example_org.png'))
            utils.imwrite(output[1][:4], os.path.join(output_dir,'example_rec.png'))
            write_img = False

        acc1, acc5 = accuracy(output[0], target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss_cls=loss_cls.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
        metric_logger.meters['loss_rec'].update(loss_rec.item(), n=batch_size)
        metric_logger.meters['loss_bpp'].update(loss_bpp, n=batch_size)
        metric_logger.meters['psnr'].update(psnr, n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss_cls {losses_cls.global_avg:.3f} loss_rec {losses_rec.global_avg:.3f} loss bpp {losses_bpp.global_avg:.3f} psnr {psnr.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses_cls=metric_logger.loss_cls, losses_rec=metric_logger.loss_rec, losses_bpp=metric_logger.loss_bpp, psnr=metric_logger.psnr))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

--------------------------------------------------------------------------------
/full_model/losses.py:
--------------------------------------------------------------------------------
"""
Implements the loss functions
"""
import torch
import torch.nn as nn
from torch.nn import functional as F

import math

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


class DenormalizedMSELoss(nn.Module):

    def __init__(self, scale=255.):
        super().__init__()
        self.imagenet_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1,3,1,1)
        self.scale = scale

    def forward(self, x, y):
        # undo the ImageNet std normalization and rescale to the 0-255 range before the MSE
        diff = (x - y) * self.imagenet_std.to(x.device) * self.scale
        mse_loss = torch.mean(diff ** 2)

        return mse_loss


class JointLoss(nn.Module):

    def __init__(self, base_criterion: torch.nn.Module, alpha: float, beta: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.alpha = alpha
        self.d_mse = DenormalizedMSELoss()
        self.beta = beta

    def forward(self, inputs, outputs, labels):
        B, _, H, W = inputs.size()
        num_pixels = B * H * W

        outputs_cls, outputs_rec, outputs_y_likelihoods, outputs_z_likelihoods = outputs

        cls_loss = self.base_criterion(outputs_cls, labels)
        mse_loss = self.d_mse(outputs_rec, inputs)
        # rate term: estimated bits per pixel from the y and z likelihoods
        bpp_loss = (torch.log(outputs_y_likelihoods).sum() + torch.log(outputs_z_likelihoods).sum()) / (-math.log(2) * num_pixels)

        # total = alpha * classification + beta * distortion + rate
        return self.alpha * cls_loss + self.beta * mse_loss + bpp_loss

--------------------------------------------------------------------------------
/full_model/main.py:
--------------------------------------------------------------------------------
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json

from pathlib import Path

from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from optim import configure_optimizers
from timm.utils import NativeScaler, get_state_dict

from datasets import build_dataset
from engine import train_one_epoch, evaluate, evaluate_real
import model
import utils

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from losses import JointLoss

# quality level -> (alpha, beta) loss weights passed to JointLoss
alpha_beta = {
    1: (0.1, 0.001),
    2: (0.3, 0.003),
    3: (0.6, 0.006)
}

def get_args_parser():
    parser = argparse.ArgumentParser('JCCTransformer training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='finetune_tiny_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--quality', default=1, type=int,
                        help='quality level (1, 2 or 3); selects the (alpha, beta) loss weights')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--no-pretrained', action='store_false', dest='pretrained')
    parser.set_defaults(pretrained=True)

    # Optimizer parameters
    parser.add_argument('--opt', default='adam', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adam")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.,
                        help='weight decay (default: 0.)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--aux-lr', type=float, default=1e-3, metavar='LR',
                        help='aux learning rate (default: 1e-3)')
    parser.add_argument('--lr-noise',
                        type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set',
                        default='IMNET', choices=['IMNET', 'INAT19'],
                        type=str, help='dataset to use')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser


def main(args):
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # note: this overrides the module-level cudnn.benchmark = False set at import time
    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True,
                                                   args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(2.
                       * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # interpolate position embedding
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = \
            pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        checkpoint_model['pos_embed'] = new_pos_embed

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    optimizer, aux_optimizer = configure_optimizers(args, model_without_ddp)

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    alpha = alpha_beta[args.quality][0]
    beta = alpha_beta[args.quality][1]
    criterion = JointLoss(criterion, alpha=alpha, beta=beta)

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            aux_optimizer.load_state_dict(checkpoint['aux_optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1

    if args.eval:
        # update()/compress()/decompress() are defined on the underlying module,
        # not on the DistributedDataParallel wrapper, so pass the unwrapped model
        test_stats = evaluate_real(data_loader_val, model_without_ddp, device, output_dir)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return


    print(f"Start \
training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, aux_optimizer, device, epoch,
            args.clip_grad, set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'aux_optimizer': aux_optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                }, checkpoint_path)

        test_stats = evaluate(data_loader_val, model, device, output_dir)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('JCCTransformer training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)

--------------------------------------------------------------------------------
/full_model/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from functools import partial

import math

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_, to_2tuple

from compressai.entropy_models import EntropyBottleneck, GaussianConditional

from compressai.models.utils import update_registered_buffers


__all__ = [
    "full_model"
]

def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }

# From Balle's tensorflow compression examples
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(
    min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS
):  # pylint: disable=W0622
    return torch.exp(torch.linspace(math.log(min), math.log(max), levels))


def conv(in_channels, out_channels, kernel_size=5, stride=2):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )


def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
    )


class Img2Embed(nn.Module):
    """ Encoding Image to
Embedding 65 | """ 66 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 67 | super().__init__() 68 | img_size = to_2tuple(img_size) 69 | patch_size = to_2tuple(patch_size) 70 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 71 | self.img_size = img_size 72 | self.patch_size = patch_size 73 | self.num_patches = num_patches 74 | 75 | middle_chans = [128, 128, 128] 76 | self.proj = nn.Sequential( 77 | conv(in_chans, middle_chans[0]), 78 | nn.LeakyReLU(inplace=True), 79 | conv(middle_chans[0], middle_chans[1]), 80 | nn.LeakyReLU(inplace=True), 81 | conv(middle_chans[1], middle_chans[2]), 82 | nn.LeakyReLU(inplace=True), 83 | conv(middle_chans[2], embed_dim) 84 | ) 85 | 86 | def forward(self, x): 87 | B, C, H, W = x.shape 88 | # FIXME look at relaxing size constraints 89 | assert H == self.img_size[0] and W == self.img_size[1], \ 90 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 91 | x = self.proj(x).flatten(2).transpose(1, 2) 92 | return x 93 | 94 | 95 | class Embed2Img(nn.Module): 96 | """ Decode Embedding to Image 97 | """ 98 | def __init__(self, img_size=224, patch_size=16, out_chans=3, embed_dim=768): 99 | super().__init__() 100 | img_size = to_2tuple(img_size) 101 | patch_size = to_2tuple(patch_size) 102 | embed_size = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) 103 | num_patches = embed_size[0] * embed_size[1] 104 | self.img_size = img_size 105 | self.patch_size = patch_size 106 | self.embed_size = embed_size 107 | self.num_patches = num_patches 108 | 109 | middle_chans = [128, 128, 128] 110 | self.proj = nn.Sequential( 111 | deconv(embed_dim, middle_chans[0]), 112 | nn.LeakyReLU(inplace=True), 113 | deconv(middle_chans[0], middle_chans[1]), 114 | nn.LeakyReLU(inplace=True), 115 | deconv(middle_chans[1], middle_chans[2]), 116 | nn.LeakyReLU(inplace=True), 117 | deconv(middle_chans[2], out_chans) 118 | ) 119 | 120 | def forward(self, x): 121 | 
B, HW, C = x.shape 122 | assert HW == self.num_patches, \ 123 | f"Input embedding size ({HW}) doesn't match patches size ({self.num_patches})." 124 | x = self.proj(x.transpose(1, 2).reshape(B, C, self.embed_size[0], self.embed_size[1])) 125 | return x 126 | 127 | 128 | class Mlp(nn.Module): 129 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 130 | super().__init__() 131 | out_features = out_features or in_features 132 | hidden_features = hidden_features or in_features 133 | self.fc1 = nn.Linear(in_features, hidden_features) 134 | self.act = act_layer() 135 | self.fc2 = nn.Linear(hidden_features, out_features) 136 | self.drop = nn.Dropout(drop) 137 | 138 | def forward(self, x): 139 | x = self.fc1(x) 140 | x = self.act(x) 141 | x = self.drop(x) 142 | x = self.fc2(x) 143 | x = self.drop(x) 144 | return x 145 | 146 | 147 | class Attention(nn.Module): 148 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 149 | super().__init__() 150 | self.num_heads = num_heads 151 | head_dim = dim // num_heads 152 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 153 | self.scale = qk_scale or head_dim ** -0.5 154 | 155 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 156 | self.attn_drop = nn.Dropout(attn_drop) 157 | self.proj = nn.Linear(dim, dim) 158 | self.proj_drop = nn.Dropout(proj_drop) 159 | 160 | def forward(self, x): 161 | B, N, C = x.shape 162 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 163 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 164 | 165 | attn = (q @ k.transpose(-2, -1)) * self.scale 166 | attn = attn.softmax(dim=-1) 167 | attn = self.attn_drop(attn) 168 | 169 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 170 | x = self.proj(x) 171 | x = self.proj_drop(x) 172 | return x 173 | 174 | 175 | class 
Block(nn.Module): 176 | 177 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 178 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 179 | super().__init__() 180 | self.norm1 = norm_layer(dim) 181 | self.attn = Attention( 182 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 183 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 184 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 185 | self.norm2 = norm_layer(dim) 186 | mlp_hidden_dim = int(dim * mlp_ratio) 187 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 188 | 189 | def forward(self, x): 190 | x = x + self.drop_path(self.attn(self.norm1(x))) 191 | x = x + self.drop_path(self.mlp(self.norm2(x))) 192 | return x 193 | 194 | 195 | class JCCTransformer(nn.Module): 196 | """ Joint Compression and Classification Transformer 197 | """ 198 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 199 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 200 | drop_path_rate=0., norm_layer=nn.LayerNorm): 201 | super().__init__() 202 | self.num_classes = num_classes 203 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 204 | self.depth = depth 205 | 206 | self.patch_embed = Img2Embed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=192) 207 | self.chans_embed = nn.Linear(192, embed_dim) 208 | num_patches = self.patch_embed.num_patches 209 | 210 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 211 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 212 | self.pos_drop = nn.Dropout(p=drop_rate) 213 | 214 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 
215 | self.blocks = nn.ModuleList([ 216 | Block( 217 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 218 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 219 | for i in range(depth)]) 220 | self.norm = norm_layer(embed_dim) 221 | 222 | # Classifier head 223 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 224 | 225 | # Reconstruction head 226 | self.head_rec = Embed2Img(img_size=img_size, patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim) 227 | 228 | # fusion 229 | self.fusion0 = nn.Linear(embed_dim, embed_dim // 4) 230 | self.fusion1 = nn.Linear(embed_dim, embed_dim // 4) 231 | self.fusion2 = nn.Linear(embed_dim, embed_dim // 4) 232 | self.fusion3 = nn.Linear(embed_dim, embed_dim // 4) 233 | self.fusion = nn.Linear(embed_dim, embed_dim) 234 | 235 | # Compression model 236 | hyper_dim = 128 237 | self.entropy_bottleneck = EntropyBottleneck(hyper_dim) 238 | self.h_a = nn.Sequential( 239 | conv(192, hyper_dim, stride=1, kernel_size=3), 240 | nn.LeakyReLU(inplace=True), 241 | conv(hyper_dim, hyper_dim), 242 | nn.LeakyReLU(inplace=True), 243 | conv(hyper_dim, hyper_dim, stride=1), 244 | ) 245 | self.h_s = nn.Sequential( 246 | conv(hyper_dim, 192, stride=1), 247 | nn.LeakyReLU(inplace=True), 248 | deconv(192, 192 * 3 // 2), 249 | nn.LeakyReLU(inplace=True), 250 | conv(192 * 3 //2, 192 * 2, stride=1, kernel_size=3), 251 | ) 252 | self.gaussian_conditional = GaussianConditional(None) 253 | 254 | trunc_normal_(self.pos_embed, std=.02) 255 | trunc_normal_(self.cls_token, std=.02) 256 | self.apply(self._init_weights) 257 | 258 | def aux_loss(self): 259 | """Return the aggregated loss over the auxiliary entropy bottleneck 260 | module(s). 
261 | """ 262 | aux_loss = sum( 263 | m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck) 264 | ) 265 | return aux_loss 266 | 267 | def _update_entropybottleneck(self, force=False): 268 | """Updates the entropy bottleneck(s) CDF values. 269 | 270 | Needs to be called once after training to be able to later perform the 271 | evaluation with an actual entropy coder. 272 | 273 | Args: 274 | force (bool): overwrite previous values (default: False) 275 | 276 | Returns: 277 | updated (bool): True if one of the EntropyBottlenecks was updated. 278 | 279 | """ 280 | updated = False 281 | for m in self.children(): 282 | if not isinstance(m, EntropyBottleneck): 283 | continue 284 | rv = m.update(force=force) 285 | updated |= rv 286 | return updated 287 | 288 | def update(self, scale_table=None, force=False): 289 | if scale_table is None: 290 | scale_table = get_scale_table() 291 | updated = self.gaussian_conditional.update_scale_table(scale_table, force=force) 292 | updated |= self._update_entropybottleneck(force=force) 293 | return updated 294 | 295 | def load_state_dict(self, state_dict, pretrained=False, **kwargs): 296 | # Dynamically update the entropy bottleneck buffers related to the CDFs 297 | if not pretrained: 298 | update_registered_buffers( 299 | self.entropy_bottleneck, 300 | "entropy_bottleneck", 301 | ["_quantized_cdf", "_offset", "_cdf_length"], 302 | state_dict, 303 | ) 304 | update_registered_buffers( 305 | self.gaussian_conditional, 306 | "gaussian_conditional", 307 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], 308 | state_dict, 309 | ) 310 | super().load_state_dict(state_dict, **kwargs) 311 | 312 | def _init_weights(self, m): 313 | if isinstance(m, nn.Linear): 314 | trunc_normal_(m.weight, std=.02) 315 | if isinstance(m, nn.Linear) and m.bias is not None: 316 | nn.init.constant_(m.bias, 0) 317 | elif isinstance(m, nn.LayerNorm): 318 | nn.init.constant_(m.bias, 0) 319 | nn.init.constant_(m.weight, 1.0) 320 | 321 | 
@torch.jit.ignore 322 | def no_weight_decay(self): 323 | return {'pos_embed', 'cls_token'} 324 | 325 | def get_classifier(self): 326 | return self.head 327 | 328 | def reset_classifier(self, num_classes, global_pool=''): 329 | self.num_classes = num_classes 330 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 331 | 332 | def forward_features(self, x): 333 | x = self.patch_embed(x) 334 | B, N, C = x.shape 335 | H = W = int(N**0.5) 336 | 337 | # Bottleneck Compression 338 | y = x.transpose(1, 2).reshape(B, C, H, W) 339 | z = self.h_a(y) 340 | z_hat, z_likelihoods = self.entropy_bottleneck(z) 341 | gaussian_params = self.h_s(z_hat) 342 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 343 | y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) 344 | 345 | # Transformer 346 | y_hat = y_hat.flatten(2).transpose(1, 2) 347 | y_hat = self.chans_embed(y_hat) 348 | 349 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 350 | y0 = torch.cat((cls_tokens, y_hat), dim=1) 351 | y0 = y0 + self.pos_embed 352 | y0 = self.pos_drop(y0) 353 | 354 | y1 = self.blocks[0](y0) 355 | y2 = self.blocks[1](y1) 356 | y3 = self.blocks[2](y2) 357 | y4 = self.blocks[3](y3) 358 | y5 = self.blocks[4](y4) 359 | y6 = self.blocks[5](y5) 360 | y7 = self.blocks[6](y6) 361 | y8 = self.blocks[7](y7) 362 | y9 = self.blocks[8](y8) 363 | y10 = self.blocks[9](y9) 364 | y11 = self.blocks[10](y10) 365 | y12 = self.blocks[11](y11) 366 | 367 | y_out = self.norm(y12) 368 | 369 | y0 = self.fusion0(y_hat) 370 | y1 = self.fusion1(y1[:, 1:]) 371 | y2 = self.fusion2(y2[:, 1:]) 372 | y3 = self.fusion3(y3[:, 1:]) 373 | 374 | y_rec = torch.cat((y0, y1, y2, y3), dim = 2) 375 | y_rec = self.fusion(y_rec) 376 | 377 | return (y_out[:, 0], y_rec, y_likelihoods, z_likelihoods) 378 | 379 | def forward(self, x): 380 | y_cls, y_rec, y_likelihoods, z_likelihoods = self.forward_features(x) 381 | cls = self.head(y_cls) 
382 | rec = self.head_rec(y_rec) 383 | return (cls, rec, y_likelihoods, z_likelihoods), self.aux_loss() 384 | 385 | def compress(self, x): 386 | x = self.patch_embed(x) 387 | B, N, C = x.shape 388 | H = W = int(N**0.5) 389 | 390 | # Bottleneck Compression 391 | y = x.transpose(1, 2).reshape(B, C, H, W) 392 | z = self.h_a(y) 393 | z_strings = self.entropy_bottleneck.compress(z) 394 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 395 | 396 | gaussian_params = self.h_s(z_hat) 397 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 398 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 399 | y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat) 400 | 401 | return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} 402 | 403 | def decompress(self, strings, shape): 404 | assert isinstance(strings, list) and len(strings) == 2 405 | z_hat = self.entropy_bottleneck.decompress(strings[1], shape) 406 | gaussian_params = self.h_s(z_hat) 407 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 408 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 409 | y_hat = self.gaussian_conditional.decompress( 410 | strings[0], indexes, means=means_hat 411 | ) 412 | 413 | # Transformer 414 | y_hat = y_hat.flatten(2).transpose(1, 2) 415 | y_hat = self.chans_embed(y_hat) 416 | 417 | B = y_hat.shape[0] 418 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 419 | y0 = torch.cat((cls_tokens, y_hat), dim=1) 420 | y0 = y0 + self.pos_embed 421 | y0 = self.pos_drop(y0) 422 | 423 | y1 = self.blocks[0](y0) 424 | y2 = self.blocks[1](y1) 425 | y3 = self.blocks[2](y2) 426 | y4 = self.blocks[3](y3) 427 | y5 = self.blocks[4](y4) 428 | y6 = self.blocks[5](y5) 429 | y7 = self.blocks[6](y6) 430 | y8 = self.blocks[7](y7) 431 | y9 = self.blocks[8](y8) 432 | y10 = self.blocks[9](y9) 433 | y11 = self.blocks[10](y10) 434 | y12 = self.blocks[11](y11) 435 | 436 | y_out = self.norm(y12) 437 | 
438 | y0 = self.fusion0(y_hat) 439 | y1 = self.fusion1(y1[:, 1:]) 440 | y2 = self.fusion2(y2[:, 1:]) 441 | y3 = self.fusion3(y3[:, 1:]) 442 | 443 | y_rec = torch.cat((y0, y1, y2, y3), dim = 2) 444 | y_rec = self.fusion(y_rec) 445 | 446 | cls = self.head(y_out[:, 0]) 447 | rec = self.head_rec(y_rec) 448 | 449 | return cls, rec 450 | 451 | 452 | @register_model 453 | def full_model(pretrained=False, **kwargs): 454 | model = JCCTransformer( 455 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 456 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 457 | model.default_cfg = _cfg() 458 | if pretrained: 459 | checkpoint = torch.load("./pretrain_s/checkpoint.pth", map_location='cpu') 460 | model.load_state_dict(checkpoint["model"], pretrained=True, strict=False) 461 | return model 462 | 463 | -------------------------------------------------------------------------------- /full_model/optim.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | 7 | 8 | def configure_optimizers(args, model): 9 | 10 | parameters = { 11 | n 12 | for n, p in model.named_parameters() 13 | if not n.endswith(".quantiles") and p.requires_grad 14 | } 15 | aux_parameters = { 16 | n 17 | for n, p in model.named_parameters() 18 | if n.endswith(".quantiles") and p.requires_grad 19 | } 20 | 21 | # Make sure we don't have an intersection of parameters 22 | params_dict = dict(model.named_parameters()) 23 | inter_params = parameters & aux_parameters 24 | union_params = parameters | aux_parameters 25 | 26 | assert len(inter_params) == 0 27 | assert len(union_params) - len(params_dict.keys()) == 0 28 | 29 | if args.opt =="adam": 30 | optimizer = optim.Adam( 31 | (params_dict[n] for n in sorted(parameters)), 32 | lr=args.lr, 33 | weight_decay=args.weight_decay, 34 | eps=args.opt_eps, 35 | ) 36 | aux_optimizer = optim.Adam( 37 | 
(params_dict[n] for n in sorted(aux_parameters)), 38 | lr=args.aux_lr, 39 | weight_decay=args.weight_decay, 40 | eps=args.opt_eps, 41 | ) 42 | 43 | return optimizer, aux_optimizer 44 | -------------------------------------------------------------------------------- /full_model/pretrain_s/readme.md: -------------------------------------------------------------------------------- 1 | Place the corresponding checkpoint.pth in this folder. -------------------------------------------------------------------------------- /full_model/test.sh: -------------------------------------------------------------------------------- 1 | python main.py --eval --resume ./ckp_s_q1/checkpoint.pth --model full_model --no-pretrained --data-path /path/to/imagenet/ --output_dir ./eval -------------------------------------------------------------------------------- /full_model/train.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model full_model --batch-size 128 --num_workers 16 --clip-grad 1.0 --quality 1 --data-path /path/to/imagenet/ --output_dir ./ckp_full -------------------------------------------------------------------------------- /full_model/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers. 3 | 4 | Mostly copy-paste from torchvision references. 5 | """ 6 | import io 7 | import os 8 | import time 9 | from collections import defaultdict, deque 10 | import datetime 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torchvision 15 | from torch.nn.functional import mse_loss 16 | 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | class SmoothedValue(object): 21 | """Track a series of values and provide access to smoothed values over a 22 | window or the global series average. 
23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | def synchronize_between_processes(self): 39 | """ 40 | Warning: does not synchronize the deque! 41 | """ 42 | if not is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format( 100 | type(self).__name__, attr)) 101 | 102 
| def __str__(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {}".format(name, str(meter)) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | 117 | def log_every(self, iterable, print_freq, header=None): 118 | i = 0 119 | if not header: 120 | header = '' 121 | start_time = time.time() 122 | end = time.time() 123 | iter_time = SmoothedValue(fmt='{avg:.4f}') 124 | data_time = SmoothedValue(fmt='{avg:.4f}') 125 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 126 | log_msg = [ 127 | header, 128 | '[{0' + space_fmt + '}/{1}]', 129 | 'eta: {eta}', 130 | '{meters}', 131 | 'time: {time}', 132 | 'data: {data}' 133 | ] 134 | if torch.cuda.is_available(): 135 | log_msg.append('max mem: {memory:.0f}') 136 | log_msg = self.delimiter.join(log_msg) 137 | MB = 1024.0 * 1024.0 138 | for obj in iterable: 139 | data_time.update(time.time() - end) 140 | yield obj 141 | iter_time.update(time.time() - end) 142 | if i % print_freq == 0 or i == len(iterable) - 1: 143 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 144 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 145 | if torch.cuda.is_available(): 146 | print(log_msg.format( 147 | i, len(iterable), eta=eta_string, 148 | meters=str(self), 149 | time=str(iter_time), data=str(data_time), 150 | memory=torch.cuda.max_memory_allocated() / MB)) 151 | # ipdb.set_trace() 152 | else: 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time))) 157 | i += 1 158 | end = time.time() 159 | total_time = time.time() - start_time 160 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 161 | print('{} Total time: {} ({:.4f} s / it)'.format( 162 | 
header, total_time_str, total_time / len(iterable))) 163 | 164 | 165 | def _load_checkpoint_for_ema(model_ema, checkpoint): 166 | """ 167 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 168 | """ 169 | mem_file = io.BytesIO() 170 | torch.save(checkpoint, mem_file) 171 | mem_file.seek(0) 172 | model_ema._load_checkpoint(mem_file) 173 | 174 | 175 | def setup_for_distributed(is_master): 176 | """ 177 | This function disables printing when not in master process 178 | """ 179 | import builtins as __builtin__ 180 | builtin_print = __builtin__.print 181 | 182 | def print(*args, **kwargs): 183 | force = kwargs.pop('force', False) 184 | if is_master or force: 185 | builtin_print(*args, **kwargs) 186 | 187 | __builtin__.print = print 188 | 189 | 190 | def is_dist_avail_and_initialized(): 191 | if not dist.is_available(): 192 | return False 193 | if not dist.is_initialized(): 194 | return False 195 | return True 196 | 197 | 198 | def get_world_size(): 199 | if not is_dist_avail_and_initialized(): 200 | return 1 201 | return dist.get_world_size() 202 | 203 | 204 | def get_rank(): 205 | if not is_dist_avail_and_initialized(): 206 | return 0 207 | return dist.get_rank() 208 | 209 | 210 | def is_main_process(): 211 | return get_rank() == 0 212 | 213 | 214 | def save_on_master(*args, **kwargs): 215 | if is_main_process(): 216 | torch.save(*args, **kwargs) 217 | 218 | 219 | def init_distributed_mode(args): 220 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 221 | args.rank = int(os.environ["RANK"]) 222 | args.world_size = int(os.environ['WORLD_SIZE']) 223 | args.gpu = int(os.environ['LOCAL_RANK']) 224 | elif 'SLURM_PROCID' in os.environ: 225 | args.rank = int(os.environ['SLURM_PROCID']) 226 | args.gpu = args.rank % torch.cuda.device_count() 227 | else: 228 | print('Not using distributed mode') 229 | args.distributed = False 230 | return 231 | 232 | args.distributed = True 233 | 234 | torch.cuda.set_device(args.gpu) 235 | args.dist_backend = 
'nccl' 236 | print('| distributed init (rank {}): {}'.format( 237 | args.rank, args.dist_url), flush=True) 238 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 239 | world_size=args.world_size, rank=args.rank) 240 | torch.distributed.barrier() 241 | setup_for_distributed(args.rank == 0) 242 | 243 | 244 | def imwrite(imgs, path): 245 | imgs = imgs.cpu() 246 | imagenet_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1, 3, 1, 1) 247 | imagenet_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).reshape(1, 3, 1, 1) 248 | 249 | imgs = torchvision.utils.make_grid(torch.clamp(imgs * imagenet_std + imagenet_mean, min=0., max=1.)) 250 | 251 | imgs_pil = torchvision.transforms.ToPILImage()(imgs) 252 | imgs_pil.save(path) 253 | 254 | 255 | def img_distortion(recs, orgs): 256 | 257 | device = recs.device  # follow the input tensors instead of hard-coding CUDA 258 | 259 | imagenet_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1, 3, 1, 1).to(device) 260 | imagenet_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).reshape(1, 3, 1, 1).to(device) 261 | 262 | org_imgs = torch.clamp(orgs * imagenet_std + imagenet_mean, min=0., max=1.) * 255. 263 | rec_imgs = torch.clamp(recs * imagenet_std + imagenet_mean, min=0., max=1.) * 255. 264 | 265 | mse_no_reduction = torch.mean(mse_loss(org_imgs, rec_imgs, reduction='none'), dim=(1, 2, 3)) 266 | psnr = torch.mean(10. * torch.log10(255. 
** 2 / mse_no_reduction)) 267 | 268 | return psnr.item() 269 | 270 | 271 | -------------------------------------------------------------------------------- /pretrained_model/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchvision import datasets, transforms 5 | from torchvision.datasets.folder import ImageFolder, default_loader 6 | 7 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.data import create_transform 9 | 10 | 11 | class INatDataset(ImageFolder): 12 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 13 | category='name', loader=default_loader): 14 | self.transform = transform 15 | self.loader = loader 16 | self.target_transform = target_transform 17 | self.year = year 18 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 19 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 20 | with open(path_json) as json_file: 21 | data = json.load(json_file) 22 | 23 | with open(os.path.join(root, 'categories.json')) as json_file: 24 | data_catg = json.load(json_file) 25 | 26 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 27 | 28 | with open(path_json_for_targeter) as json_file: 29 | data_for_targeter = json.load(json_file) 30 | 31 | targeter = {} 32 | indexer = 0 33 | for elem in data_for_targeter['annotations']: 34 | king = [] 35 | king.append(data_catg[int(elem['category_id'])][category]) 36 | if king[0] not in targeter.keys(): 37 | targeter[king[0]] = indexer 38 | indexer += 1 39 | self.nb_classes = len(targeter) 40 | 41 | self.samples = [] 42 | for elem in data['images']: 43 | cut = elem['file_name'].split('/') 44 | target_current = int(cut[2]) 45 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 46 | 47 | categors = data_catg[target_current] 48 | target_current_true = 
targeter[categors[category]] 49 | self.samples.append((path_current, target_current_true)) 50 | 51 | # __getitem__ and __len__ inherited from ImageFolder 52 | 53 | 54 | def build_dataset(is_train, args): 55 | transform = build_transform(is_train, args) 56 | 57 | if args.data_set == 'IMNET': 58 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 59 | dataset = datasets.ImageFolder(root, transform=transform) 60 | nb_classes = 1000 61 | elif args.data_set == 'INAT19': 62 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 63 | category=args.inat_category, transform=transform) 64 | nb_classes = dataset.nb_classes 65 | 66 | return dataset, nb_classes 67 | 68 | 69 | def build_transform(is_train, args): 70 | resize_im = args.input_size > 32 71 | if is_train: 72 | # this should always dispatch to transforms_imagenet_train 73 | transform = create_transform( 74 | input_size=args.input_size, 75 | is_training=True, 76 | color_jitter=args.color_jitter, 77 | auto_augment=args.aa, 78 | interpolation=args.train_interpolation, 79 | re_prob=args.reprob, 80 | re_mode=args.remode, 81 | re_count=args.recount, 82 | ) 83 | if not resize_im: 84 | # replace RandomResizedCropAndInterpolation with 85 | # RandomCrop 86 | transform.transforms[0] = transforms.RandomCrop( 87 | args.input_size, padding=4) 88 | return transform 89 | 90 | t = [] 91 | if resize_im: 92 | size = int((256 / 224) * args.input_size) 93 | t.append( 94 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images 95 | ) 96 | t.append(transforms.CenterCrop(args.input_size)) 97 | 98 | t.append(transforms.ToTensor()) 99 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 100 | return transforms.Compose(t) 101 | -------------------------------------------------------------------------------- /pretrained_model/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | import os 7 | from typing import Iterable, Optional 8 | 9 | import torch 10 | 11 | from timm.data import Mixup 12 | from timm.utils import accuracy, ModelEma 13 | 14 | from losses import JointLoss, DenormalizedMSELoss 15 | import utils 16 | 17 | 18 | def train_one_epoch(model: torch.nn.Module, criterion: JointLoss, 19 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 20 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 21 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 22 | set_training_mode=True): 23 | model.train(set_training_mode) 24 | metric_logger = utils.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 10 28 | 29 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 30 | samples = samples.to(device, non_blocking=True) 31 | targets = targets.to(device, non_blocking=True) 32 | 33 | if mixup_fn is not None: 34 | samples, targets = mixup_fn(samples, targets) 35 | 36 | with torch.cuda.amp.autocast(): 37 | outputs = model(samples) 38 | loss = criterion(samples, outputs, targets) 39 | 40 | loss_value = loss.item() 41 | 42 | if not math.isfinite(loss_value): 43 | print("Loss is {}, stopping training".format(loss_value)) 44 | sys.exit(1) 45 | 46 | optimizer.zero_grad() 47 | 48 | # this attribute is added by timm on one optimizer (adahessian) 49 
| is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 50 | loss_scaler(loss, optimizer, clip_grad=max_norm, 51 | parameters=model.parameters(), create_graph=is_second_order) 52 | 53 | torch.cuda.synchronize() 54 | if model_ema is not None: 55 | model_ema.update(model) 56 | 57 | metric_logger.update(loss=loss_value) 58 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 59 | # gather the stats from all processes 60 | metric_logger.synchronize_between_processes() 61 | print("Averaged stats:", metric_logger) 62 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 63 | 64 | 65 | @torch.no_grad() 66 | def evaluate(data_loader, model, device, output_dir): 67 | criterion_cls = torch.nn.CrossEntropyLoss() 68 | criterion_rec = DenormalizedMSELoss() 69 | 70 | metric_logger = utils.MetricLogger(delimiter=" ") 71 | header = 'Test:' 72 | 73 | # switch to evaluation mode 74 | model.eval() 75 | write_img = True 76 | for images, target in metric_logger.log_every(data_loader, 10, header): 77 | images = images.to(device, non_blocking=True) 78 | target = target.to(device, non_blocking=True) 79 | 80 | # compute output 81 | with torch.cuda.amp.autocast(): 82 | output = model(images) 83 | loss_cls = criterion_cls(output[0], target) 84 | loss_rec = criterion_rec(output[1], images) 85 | psnr = utils.img_distortion(output[1], images) 86 | 87 | if write_img: 88 | utils.imwrite(images[:4], os.path.join(output_dir,'example_org.png')) 89 | utils.imwrite(output[1][:4], os.path.join(output_dir,'example_rec.png')) 90 | write_img = False 91 | 92 | acc1, acc5 = accuracy(output[0], target, topk=(1, 5)) 93 | 94 | batch_size = images.shape[0] 95 | metric_logger.update(loss_cls=loss_cls.item()) 96 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 97 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 98 | metric_logger.meters['loss_rec'].update(loss_rec.item(), n=batch_size) 99 | 
metric_logger.meters['psnr'].update(psnr, n=batch_size) 100 | 101 | # gather the stats from all processes 102 | metric_logger.synchronize_between_processes() 103 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss_cls {losses_cls.global_avg:.3f} loss_rec {losses_rec.global_avg:.3f} psnr {psnr.global_avg:.3f}' 104 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses_cls=metric_logger.loss_cls, losses_rec=metric_logger.loss_rec, psnr=metric_logger.psnr)) 105 | 106 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 107 | 108 | 109 | -------------------------------------------------------------------------------- /pretrained_model/hubconf.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | 3 | dependencies = ["torch", "torchvision", "timm"] 4 | -------------------------------------------------------------------------------- /pretrained_model/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the loss functions 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | 10 | 11 | class DenormalizedMSELoss(nn.Module): 12 | 13 | def __init__(self, scale=255.): 14 | super().__init__() 15 | self.imagenet_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1,3,1,1) 16 | self.scale = scale 17 | 18 | def forward(self, x, y): 19 | diff = (x - y) * self.imagenet_std.to(x.device) * self.scale 20 | mse_loss = torch.mean(diff ** 2) 21 | 22 | return mse_loss 23 | 24 | 25 | class JointLoss(nn.Module): 26 | 27 | def __init__(self, base_criterion: torch.nn.Module, alpha: float): 28 | super().__init__() 29 | self.base_criterion = base_criterion 30 | self.alpha = alpha 31 | self.d_mse = DenormalizedMSELoss() 32 | 33 | def forward(self, inputs, outputs, labels): 34 | outputs_cls, outputs_rec
= outputs 35 | 36 | cls_loss = self.base_criterion(outputs_cls, labels) 37 | mse_loss = self.d_mse(outputs_rec, inputs) 38 | 39 | return cls_loss + self.alpha * mse_loss 40 | -------------------------------------------------------------------------------- /pretrained_model/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import os 9 | 10 | from pathlib import Path 11 | 12 | from timm.data import Mixup 13 | from timm.models import create_model 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.scheduler import create_scheduler 16 | from timm.optim import create_optimizer 17 | from timm.utils import NativeScaler, get_state_dict, ModelEma 18 | 19 | from datasets import build_dataset 20 | from engine import train_one_epoch, evaluate 21 | 22 | from samplers import RASampler 23 | 24 | import utils 25 | 26 | from losses import JointLoss 27 | import model 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser('Pretraining and evaluation script', add_help=False) 32 | parser.add_argument('--batch-size', default=64, type=int) 33 | parser.add_argument('--epochs', default=300, type=int) 34 | 35 | # Model parameters 36 | parser.add_argument('--model', default='pretrained_model', type=str, metavar='MODEL', 37 | help='Name of model to train') 38 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 39 | 40 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 41 | help='Dropout rate (default: 0.)') 42 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 43 | help='Drop path rate (default: 0.1)') 44 | 45 | parser.add_argument('--model-ema', action='store_true') 46 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
47 | parser.set_defaults(model_ema=True) 48 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 49 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 50 | 51 | # Optimizer parameters 52 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 53 | help='Optimizer (default: "adamw"') 54 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 55 | help='Optimizer Epsilon (default: 1e-8)') 56 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 57 | help='Optimizer Betas (default: None, use opt default)') 58 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 59 | help='Clip gradient norm (default: None, no clipping)') 60 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 61 | help='SGD momentum (default: 0.9)') 62 | parser.add_argument('--weight-decay', type=float, default=0.05, 63 | help='weight decay (default: 0.05)') 64 | # Learning rate schedule parameters 65 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 66 | help='LR scheduler (default: "cosine"') 67 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 68 | help='learning rate (default: 5e-4)') 69 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 70 | help='learning rate noise on/off epoch percentages') 71 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 72 | help='learning rate noise limit percent (default: 0.67)') 73 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 74 | help='learning rate noise std-dev (default: 1.0)') 75 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 76 | help='warmup learning rate (default: 1e-6)') 77 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 78 | help='lower lr bound 
for cyclic schedulers that hit 0 (1e-5)') 79 | 80 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 81 | help='epoch interval to decay LR') 82 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 83 | help='epochs to warmup LR, if scheduler supports') 84 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 85 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 86 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 87 | help='patience epochs for Plateau LR scheduler (default: 10') 88 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 89 | help='LR decay rate (default: 0.1)') 90 | 91 | # Augmentation parameters 92 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 93 | help='Color jitter factor (default: 0.4)') 94 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 95 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 96 | "(default: rand-m9-mstd0.5-inc1)'), 97 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 98 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 99 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 100 | 101 | parser.add_argument('--repeated-aug', action='store_true') 102 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 103 | parser.set_defaults(repeated_aug=False) 104 | 105 | # * Random Erase params 106 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 107 | help='Random erase prob (default: 0.25)') 108 | parser.add_argument('--remode', type=str, default='pixel', 109 | help='Random erase mode (default: "pixel")') 110 | parser.add_argument('--recount', type=int, default=1, 111 | help='Random erase count (default: 1)') 112 | parser.add_argument('--resplit', action='store_true', default=False, 113 | help='Do not random erase first (clean) augmentation split') 114 | 115 | # * Mixup params 116 | parser.add_argument('--mixup', type=float, default=0.8, 117 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 118 | parser.add_argument('--cutmix', type=float, default=1.0, 119 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 120 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 121 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 122 | parser.add_argument('--mixup-prob', type=float, default=1.0, 123 | help='Probability of performing mixup or cutmix when either/both is enabled') 124 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 125 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 126 | parser.add_argument('--mixup-mode', type=str, default='batch', 127 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 128 | 129 | # Distillation parameters 130 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 131 | help='Name of teacher model to train (default: "regnety_160"') 132 | parser.add_argument('--teacher-path', type=str, default='') 133 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 134 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 135 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 136 | 137 | # * Finetuning params 138 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 139 | 140 | # Dataset parameters 141 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 142 | help='dataset path') 143 | parser.add_argument('--data-set', default='IMNET', choices=['IMNET', 'INAT19'], 144 | type=str, help='Image Net dataset path') 145 | parser.add_argument('--inat-category', default='name', 146 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 147 | type=str, help='semantic granularity') 148 | 149 | parser.add_argument('--output_dir', default='', 150 | help='path where to save, empty for no saving') 151 | parser.add_argument('--device', default='cuda', 152 | help='device to use for training / testing') 153 | parser.add_argument('--seed', default=0, type=int) 154 | parser.add_argument('--resume', default='', help='resume from checkpoint') 155 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 156 | help='start epoch') 157 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 158 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 159 | parser.add_argument('--num_workers', default=10, type=int) 160 | parser.add_argument('--pin-mem', action='store_true', 161 | 
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 162 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 163 | help='') 164 | parser.set_defaults(pin_mem=True) 165 | 166 | # distributed training parameters 167 | parser.add_argument('--world_size', default=1, type=int, 168 | help='number of distributed processes') 169 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 170 | return parser 171 | 172 | 173 | def main(args): 174 | utils.init_distributed_mode(args) 175 | 176 | print(args) 177 | 178 | if args.distillation_type != 'none' and args.finetune and not args.eval: 179 | raise NotImplementedError("Finetuning with distillation not yet supported") 180 | 181 | device = torch.device(args.device) 182 | 183 | # fix the seed for reproducibility 184 | seed = args.seed + utils.get_rank() 185 | torch.manual_seed(seed) 186 | np.random.seed(seed) 187 | 188 | cudnn.benchmark = True 189 | 190 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 191 | dataset_val, _ = build_dataset(is_train=False, args=args) 192 | 193 | if True: # args.distributed: 194 | num_tasks = utils.get_world_size() 195 | global_rank = utils.get_rank() 196 | if args.repeated_aug: 197 | sampler_train = RASampler( 198 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 199 | ) 200 | else: 201 | sampler_train = torch.utils.data.DistributedSampler( 202 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 203 | ) 204 | if args.dist_eval: 205 | if len(dataset_val) % num_tasks != 0: 206 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 207 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 208 | 'equal num of samples per-process.') 209 | sampler_val = torch.utils.data.DistributedSampler( 210 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 211 | else: 212 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 213 | else: 214 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 215 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 216 | 217 | data_loader_train = torch.utils.data.DataLoader( 218 | dataset_train, sampler=sampler_train, 219 | batch_size=args.batch_size, 220 | num_workers=args.num_workers, 221 | pin_memory=args.pin_mem, 222 | drop_last=True, 223 | ) 224 | 225 | data_loader_val = torch.utils.data.DataLoader( 226 | dataset_val, sampler=sampler_val, 227 | batch_size=int(1.5 * args.batch_size), 228 | num_workers=args.num_workers, 229 | pin_memory=args.pin_mem, 230 | drop_last=False 231 | ) 232 | 233 | mixup_fn = None 234 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 235 | if mixup_active: 236 | mixup_fn = Mixup( 237 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 238 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 239 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 240 | 241 | print(f"Creating model: {args.model}") 242 | model = create_model( 243 | args.model, 244 | pretrained=False, 245 | num_classes=args.nb_classes, 246 | drop_rate=args.drop, 247 | drop_path_rate=args.drop_path, 248 | drop_block_rate=None, 249 | ) 250 | 251 | if args.finetune: 252 | if args.finetune.startswith('https'): 253 | checkpoint = torch.hub.load_state_dict_from_url( 254 | args.finetune, map_location='cpu', check_hash=True) 255 | else: 256 | checkpoint = torch.load(args.finetune, map_location='cpu') 257 | 258 | checkpoint_model = checkpoint['model'] 259 | state_dict = model.state_dict() 260 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 261 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 262 | print(f"Removing key {k} from pretrained checkpoint") 263 | del checkpoint_model[k] 264 | 265 | # interpolate position embedding 266 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 267 | embedding_size = pos_embed_checkpoint.shape[-1] 268 | num_patches = model.patch_embed.num_patches 269 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 270 | # height (== width) for the checkpoint position embedding 271 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 272 | # height (== width) for the new position embedding 273 | new_size = int(num_patches ** 0.5) 274 | # class_token and dist_token are kept unchanged 275 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 276 | # only the position tokens are interpolated 277 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 278 | pos_tokens = pos_tokens.reshape(-1, orig_size, 
orig_size, embedding_size).permute(0, 3, 1, 2) 279 | pos_tokens = torch.nn.functional.interpolate( 280 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 281 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 282 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 283 | checkpoint_model['pos_embed'] = new_pos_embed 284 | 285 | model.load_state_dict(checkpoint_model, strict=False) 286 | 287 | model.to(device) 288 | 289 | model_ema = None 290 | if args.model_ema: 291 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 292 | model_ema = ModelEma( 293 | model, 294 | decay=args.model_ema_decay, 295 | device='cpu' if args.model_ema_force_cpu else '', 296 | resume='') 297 | 298 | model_without_ddp = model 299 | if args.distributed: 300 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 301 | model_without_ddp = model.module 302 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 303 | print('number of params:', n_parameters) 304 | 305 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 306 | args.lr = linear_scaled_lr 307 | optimizer = create_optimizer(args, model_without_ddp) 308 | loss_scaler = NativeScaler() 309 | 310 | lr_scheduler, _ = create_scheduler(args, optimizer) 311 | 312 | criterion = LabelSmoothingCrossEntropy() 313 | 314 | if args.mixup > 0.: 315 | # smoothing is handled with mixup label transform 316 | criterion = SoftTargetCrossEntropy() 317 | elif args.smoothing: 318 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 319 | else: 320 | criterion = torch.nn.CrossEntropyLoss() 321 | 322 | teacher_model = None 323 | if args.distillation_type != 'none': 324 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 325 | print(f"Creating teacher model: {args.teacher_model}") 326 | teacher_model = create_model( 327 | args.teacher_model, 328 
| pretrained=False, 329 | num_classes=args.nb_classes, 330 | global_pool='avg', 331 | ) 332 | if args.teacher_path.startswith('https'): 333 | checkpoint = torch.hub.load_state_dict_from_url( 334 | args.teacher_path, map_location='cpu', check_hash=True) 335 | else: 336 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 337 | teacher_model.load_state_dict(checkpoint['model']) 338 | teacher_model.to(device) 339 | teacher_model.eval() 340 | 341 | criterion = JointLoss(criterion, alpha=1e-3) 342 | 343 | output_dir = Path(args.output_dir) 344 | if args.resume: 345 | if args.resume.startswith('https'): 346 | checkpoint = torch.hub.load_state_dict_from_url( 347 | args.resume, map_location='cpu', check_hash=True) 348 | else: 349 | checkpoint = torch.load(args.resume, map_location='cpu') 350 | model_without_ddp.load_state_dict(checkpoint['model']) 351 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 352 | optimizer.load_state_dict(checkpoint['optimizer']) 353 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 354 | args.start_epoch = checkpoint['epoch'] + 1 355 | if args.model_ema: 356 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 357 | if 'scaler' in checkpoint: 358 | loss_scaler.load_state_dict(checkpoint['scaler']) 359 | 360 | if args.eval: 361 | if not os.path.exists("./eval"): 362 | os.makedirs('./eval') 363 | 364 | test_stats = evaluate(data_loader_val, model, device, "./eval") 365 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 366 | return 367 | 368 | print(f"Start training for {args.epochs} epochs") 369 | start_time = time.time() 370 | max_accuracy = 0.0 371 | for epoch in range(args.start_epoch, args.epochs): 372 | if args.distributed: 373 | data_loader_train.sampler.set_epoch(epoch) 374 | 375 | train_stats = train_one_epoch( 376 | model, criterion, data_loader_train, 377 | optimizer, device, epoch, 
loss_scaler, 378 | args.clip_grad, model_ema, mixup_fn, 379 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 380 | ) 381 | 382 | lr_scheduler.step(epoch) 383 | if args.output_dir and args.model_ema: 384 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 385 | for checkpoint_path in checkpoint_paths: 386 | utils.save_on_master({ 387 | 'model': model_without_ddp.state_dict(), 388 | 'optimizer': optimizer.state_dict(), 389 | 'lr_scheduler': lr_scheduler.state_dict(), 390 | 'epoch': epoch, 391 | 'model_ema': get_state_dict(model_ema), 392 | 'scaler': loss_scaler.state_dict(), 393 | 'args': args, 394 | }, checkpoint_path) 395 | elif args.output_dir: 396 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 397 | for checkpoint_path in checkpoint_paths: 398 | utils.save_on_master({ 399 | 'model': model_without_ddp.state_dict(), 400 | 'optimizer': optimizer.state_dict(), 401 | 'lr_scheduler': lr_scheduler.state_dict(), 402 | 'epoch': epoch, 403 | 'scaler': loss_scaler.state_dict(), 404 | 'args': args, 405 | }, checkpoint_path) 406 | 407 | test_stats = evaluate(data_loader_val, model, device, args.output_dir) 408 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 409 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 410 | print(f'Max accuracy: {max_accuracy:.2f}%') 411 | 412 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 413 | **{f'test_{k}': v for k, v in test_stats.items()}, 414 | 'epoch': epoch, 415 | 'n_parameters': n_parameters} 416 | 417 | if args.output_dir and utils.is_main_process(): 418 | with (output_dir / "log.txt").open("a") as f: 419 | f.write(json.dumps(log_stats) + "\n") 420 | 421 | total_time = time.time() - start_time 422 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 423 | print('Training time {}'.format(total_time_str)) 424 | 425 | 426 | if __name__ == '__main__': 427 | parser = argparse.ArgumentParser('Pretraining and 
evaluation script', parents=[get_args_parser()]) 428 | args = parser.parse_args() 429 | if args.output_dir: 430 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 431 | main(args) 432 | -------------------------------------------------------------------------------- /pretrained_model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 6 | from timm.models.registry import register_model 7 | from timm.models.layers import DropPath, trunc_normal_, to_2tuple 8 | 9 | 10 | 11 | __all__ = [ 12 | 'pretrained_model' 13 | ] 14 | 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 22 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | def conv(in_channels, out_channels, kernel_size=5, stride=2): 28 | return nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size=kernel_size, 32 | stride=stride, 33 | padding=kernel_size // 2, 34 | ) 35 | 36 | 37 | def deconv(in_channels, out_channels, kernel_size=5, stride=2): 38 | return nn.ConvTranspose2d( 39 | in_channels, 40 | out_channels, 41 | kernel_size=kernel_size, 42 | stride=stride, 43 | output_padding=stride - 1, 44 | padding=kernel_size // 2, 45 | ) 46 | 47 | 48 | class Img2Embed(nn.Module): 49 | """ Encoding Image to Embedding 50 | """ 51 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 52 | super().__init__() 53 | img_size = to_2tuple(img_size) 54 | patch_size = to_2tuple(patch_size) 55 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 56 | self.img_size = img_size 57 | self.patch_size = patch_size 58 | self.num_patches = 
num_patches 59 | 60 | middle_chans = [128, 128, 128] 61 | self.proj = nn.Sequential( 62 | conv(in_chans, middle_chans[0]), 63 | nn.LeakyReLU(inplace=True), 64 | conv(middle_chans[0], middle_chans[1]), 65 | nn.LeakyReLU(inplace=True), 66 | conv(middle_chans[1], middle_chans[2]), 67 | nn.LeakyReLU(inplace=True), 68 | conv(middle_chans[2], embed_dim) 69 | ) 70 | 71 | def forward(self, x): 72 | B, C, H, W = x.shape 73 | # FIXME look at relaxing size constraints 74 | assert H == self.img_size[0] and W == self.img_size[1], \ 75 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 76 | x = self.proj(x).flatten(2).transpose(1, 2) 77 | return x 78 | 79 | 80 | class Embed2Img(nn.Module): 81 | """ Decode Embedding to Image 82 | """ 83 | def __init__(self, img_size=224, patch_size=16, out_chans=3, embed_dim=768): 84 | super().__init__() 85 | img_size = to_2tuple(img_size) 86 | patch_size = to_2tuple(patch_size) 87 | embed_size = (img_size[1] // patch_size[1], img_size[0] // patch_size[0]) 88 | num_patches = embed_size[0] * embed_size[1] 89 | self.img_size = img_size 90 | self.patch_size = patch_size 91 | self.embed_size = embed_size 92 | self.num_patches = num_patches 93 | 94 | middle_chans = [128, 128, 128] 95 | self.proj = nn.Sequential( 96 | deconv(embed_dim, middle_chans[0]), 97 | nn.LeakyReLU(inplace=True), 98 | deconv(middle_chans[0], middle_chans[1]), 99 | nn.LeakyReLU(inplace=True), 100 | deconv(middle_chans[1], middle_chans[2]), 101 | nn.LeakyReLU(inplace=True), 102 | deconv(middle_chans[2], out_chans) 103 | ) 104 | 105 | def forward(self, x): 106 | B, HW, C = x.shape 107 | assert HW == self.num_patches, \ 108 | f"Input embedding size ({HW}) doesn't match patches size ({self.num_patches})."
109 | x = self.proj(x.transpose(1, 2).reshape(B, C, self.embed_size[0], self.embed_size[1])) 110 | return x 111 | 112 | 113 | class Mlp(nn.Module): 114 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 115 | super().__init__() 116 | out_features = out_features or in_features 117 | hidden_features = hidden_features or in_features 118 | self.fc1 = nn.Linear(in_features, hidden_features) 119 | self.act = act_layer() 120 | self.fc2 = nn.Linear(hidden_features, out_features) 121 | self.drop = nn.Dropout(drop) 122 | 123 | def forward(self, x): 124 | x = self.fc1(x) 125 | x = self.act(x) 126 | x = self.drop(x) 127 | x = self.fc2(x) 128 | x = self.drop(x) 129 | return x 130 | 131 | 132 | class Attention(nn.Module): 133 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 134 | super().__init__() 135 | self.num_heads = num_heads 136 | head_dim = dim // num_heads 137 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 138 | self.scale = qk_scale or head_dim ** -0.5 139 | 140 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 141 | self.attn_drop = nn.Dropout(attn_drop) 142 | self.proj = nn.Linear(dim, dim) 143 | self.proj_drop = nn.Dropout(proj_drop) 144 | 145 | def forward(self, x): 146 | B, N, C = x.shape 147 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 148 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 149 | 150 | attn = (q @ k.transpose(-2, -1)) * self.scale 151 | attn = attn.softmax(dim=-1) 152 | attn = self.attn_drop(attn) 153 | 154 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 155 | x = self.proj(x) 156 | x = self.proj_drop(x) 157 | return x 158 | 159 | 160 | class Block(nn.Module): 161 | 162 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 163 | drop_path=0., 
act_layer=nn.GELU, norm_layer=nn.LayerNorm): 164 | super().__init__() 165 | self.norm1 = norm_layer(dim) 166 | self.attn = Attention( 167 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 168 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 169 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 170 | self.norm2 = norm_layer(dim) 171 | mlp_hidden_dim = int(dim * mlp_ratio) 172 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 173 | 174 | def forward(self, x): 175 | x = x + self.drop_path(self.attn(self.norm1(x))) 176 | x = x + self.drop_path(self.mlp(self.norm2(x))) 177 | return x 178 | 179 | 180 | class PretrainTransformer(nn.Module): 181 | """ Pretrain Transformer 182 | """ 183 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 184 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 185 | drop_path_rate=0., norm_layer=nn.LayerNorm): 186 | super().__init__() 187 | self.num_classes = num_classes 188 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 189 | self.depth = depth 190 | 191 | self.patch_embed = Img2Embed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=192) 192 | self.chans_embed = nn.Linear(192, embed_dim) 193 | num_patches = self.patch_embed.num_patches 194 | 195 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 196 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 197 | self.pos_drop = nn.Dropout(p=drop_rate) 198 | 199 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 200 | self.blocks = nn.ModuleList([ 201 | Block( 202 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 203 | drop=drop_rate, 
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        # Reconstruction head
        self.head_rec = Embed2Img(img_size=img_size, patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim)

        # fusion
        self.fusion0 = nn.Linear(embed_dim, embed_dim // 4)
        self.fusion1 = nn.Linear(embed_dim, embed_dim // 4)
        self.fusion2 = nn.Linear(embed_dim, embed_dim // 4)
        self.fusion3 = nn.Linear(embed_dim, embed_dim // 4)
        self.fusion = nn.Linear(embed_dim, embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = self.chans_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x0 = torch.cat((cls_tokens, x), dim=1)
        x0 = x0 + self.pos_embed
        x0 = self.pos_drop(x0)

        x1 = self.blocks[0](x0)
        x2 = self.blocks[1](x1)
        x3 = self.blocks[2](x2)
        x4 = self.blocks[3](x3)
        x5 = self.blocks[4](x4)
        x6 = self.blocks[5](x5)
        x7 = self.blocks[6](x6)
        x8 = self.blocks[7](x7)
        x9 = self.blocks[8](x8)
        x10 = self.blocks[9](x9)
        x11 = self.blocks[10](x10)
        x12 = self.blocks[11](x11)

        x_out = self.norm(x12)

        y0 = self.fusion0(x)
        y1 = self.fusion1(x1[:, 1:])
        y2 = self.fusion2(x2[:, 1:])
        y3 = self.fusion3(x3[:, 1:])

        y = torch.cat((y0, y1, y2, y3), dim=2)
        y = self.fusion(y)

        return x_out[:, 0], y

    def forward(self, x):
        x, y = self.forward_features(x)
        cls = self.head(x)
        rec = self.head_rec(y)
        return (cls, rec)


@register_model
def pretrained_model(pretrained=False, **kwargs):
    model = PretrainTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()

    return model

--------------------------------------------------------------------------------
/pretrained_model/pretrain_s/readme.md:
--------------------------------------------------------------------------------
Place the corresponding checkpoint.pth in this folder.

--------------------------------------------------------------------------------
/pretrained_model/samplers.py:
--------------------------------------------------------------------------------
import torch
import torch.distributed as dist
import math


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices = [ele for ele in indices for i in range(3)]
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

--------------------------------------------------------------------------------
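The index schedule in `RASampler.__iter__` can be traced without a GPU or a process group. The following is a minimal sketch, not part of the repository: `ra_indices` is a hypothetical helper, and `random.Random(epoch)` stands in for `torch.randperm` with a seeded generator, so the shuffled order will differ from what the real sampler produces. The repeat count, padding, strided split and truncation follow the code above.

```python
import math
import random


def ra_indices(dataset_len, num_replicas, rank, epoch, repeats=3, shuffle=True):
    """Pure-Python sketch of the index schedule in RASampler.__iter__ (hypothetical helper)."""
    num_samples = int(math.ceil(dataset_len * repeats / num_replicas))
    total_size = num_samples * num_replicas
    # Mirrors the sampler: truncate the dataset length to a multiple of 256, then split by rank.
    num_selected = int(math.floor(dataset_len // 256 * 256 / num_replicas))

    indices = list(range(dataset_len))
    if shuffle:
        # Assumption: a seeded stdlib shuffle replaces torch.randperm(generator=g).
        random.Random(epoch).shuffle(indices)

    # Repeat every index `repeats` times, then pad so the list splits evenly.
    indices = [idx for idx in indices for _ in range(repeats)]
    indices += indices[:total_size - len(indices)]
    assert len(indices) == total_size

    # Each rank takes every num_replicas-th entry, starting at its own offset.
    indices = indices[rank:total_size:num_replicas]
    assert len(indices) == num_samples
    return indices[:num_selected]


if __name__ == "__main__":
    per_rank = [ra_indices(512, num_replicas=3, rank=r, epoch=0, shuffle=False) for r in range(3)]
    print([len(p) for p in per_rank])
    print([p[:4] for p in per_rank])
```

With three replicas and three repeats, the three copies of each index land on three different ranks, so every rank iterates over the same sample ids and only the random augmentation applied by the dataset differs between them.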
/pretrained_model/test.sh:
--------------------------------------------------------------------------------
python main.py --eval --resume ./pretrain_s/checkpoint.pth --model pretrained_model --data-path /path/to/imagenet/ --output_dir ./eval

--------------------------------------------------------------------------------
/pretrained_model/train.sh:
--------------------------------------------------------------------------------
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model pretrained_model --no-model-ema --clip-grad 1.0 --batch-size 128 --num_workers 16 --data-path /path/to/imagenet/ --output_dir ./ckp_pretrain

--------------------------------------------------------------------------------
/pretrained_model/utils.py:
--------------------------------------------------------------------------------
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist
import torchvision
from torch.nn.functional import mse_loss

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# import ipdb


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save(checkpoint, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def imwrite(imgs, path):
    imgs = imgs.cpu()
    imagenet_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1, 3, 1, 1)
    imagenet_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).reshape(1, 3, 1, 1)

    imgs = torchvision.utils.make_grid(torch.clamp(imgs * imagenet_std + imagenet_mean, min=0., max=1.))

    imgs_pil = torchvision.transforms.ToPILImage()(imgs)
    imgs_pil.save(path)


def img_distortion(recs, orgs):
    device = torch.device('cuda')

    imagenet_std = torch.tensor(IMAGENET_DEFAULT_STD).reshape(1, 3, 1, 1).to(device)
    imagenet_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).reshape(1, 3, 1, 1).to(device)

    org_imgs = torch.clamp(orgs * imagenet_std + imagenet_mean, min=0., max=1.) * 255.
    rec_imgs = torch.clamp(recs * imagenet_std + imagenet_mean, min=0., max=1.) * 255.

    mse_no_reduction = torch.mean(mse_loss(org_imgs, rec_imgs, reduction='none'), dim=(1, 2, 3))
    psnr = torch.mean(10. * torch.log10(255. ** 2 / mse_no_reduction))

    return psnr.item()

--------------------------------------------------------------------------------
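The PSNR arithmetic in `img_distortion` can be checked by hand without torch. The sketch below is illustrative only: `psnr_from_mse` and `batch_psnr` are hypothetical names, and images are represented as flat lists of pixel values on a 0 to 255 scale rather than as normalized tensors. Like the function above, it computes one MSE per image, converts each to dB, and averages the dB values. Unlike the torch version, it raises `ZeroDivisionError` when an image is reconstructed exactly.

```python
import math


def psnr_from_mse(mse, max_value=255.0):
    """PSNR in dB for a per-image mean squared error on a [0, max_value] pixel scale."""
    return 10.0 * math.log10(max_value ** 2 / mse)


def batch_psnr(originals, reconstructions, max_value=255.0):
    """Average of per-image PSNR values; each image is a flat list of pixels (assumed layout)."""
    scores = []
    for org, rec in zip(originals, reconstructions):
        # Per-image MSE, matching the mean over (C, H, W) in img_distortion.
        mse = sum((o - r) ** 2 for o, r in zip(org, rec)) / len(org)
        scores.append(psnr_from_mse(mse, max_value))
    return sum(scores) / len(scores)


if __name__ == "__main__":
    originals = [[0.0, 0.0, 0.0, 0.0], [100.0, 100.0, 100.0, 100.0]]
    reconstructions = [[2.0, 0.0, 0.0, 0.0], [101.0, 99.0, 101.0, 99.0]]
    print(batch_psnr(originals, reconstructions))
```

Averaging per-image PSNR is not the same as computing PSNR from the batch-averaged MSE; the former weights every image equally on the dB scale, which is the behavior of the function above.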