├── models
│   ├── __init__.py
│   ├── CSformer.py
│   ├── ViT_helper.py
│   ├── CNN64.py
│   └── Transformer64.py
├── utils
│   ├── __init__.py
│   └── utils.py
├── logs
│   ├── checkpoint_bsd400
│   │   └── pre_trained_weight.txt
│   └── checkpoint_coco
│       └── pre_trained_weight.txt
├── train_script.sh
├── loss
│   └── loss.py
├── eval.py
├── README.md
├── cfg.py
├── train.py
├── datasets.py
└── functions.py
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from models import CSformer
3 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | from utils import utils
6 | 
--------------------------------------------------------------------------------
/logs/checkpoint_bsd400/pre_trained_weight.txt:
--------------------------------------------------------------------------------
1 | Please download the pre-trained weight from:
2 | https://drive.google.com/file/d/1b8mbNcD7zbv5XC2oXJ8VStPoMqqS4g5s/view?usp=sharing
3 | or
4 | Link: https://pan.baidu.com/s/1S2VHdv1WKUo6jyGHLPi1FQ?pwd=9lmz (extraction code: 9lmz)
5 | 
--------------------------------------------------------------------------------
/logs/checkpoint_coco/pre_trained_weight.txt:
--------------------------------------------------------------------------------
1 | Please download the pre-trained weight from:
2 | https://drive.google.com/file/d/1P_HKhmTsYi2H94VMY1TcIU5Ze6H_mIq0/view?usp=sharing
3 | or
4 | Link: https://pan.baidu.com/s/1S2VHdv1WKUo6jyGHLPi1FQ?pwd=9lmz (extraction code: 9lmz)
5 | 
--------------------------------------------------------------------------------
/train_script.sh:
--------------------------------------------------------------------------------
1 | C:/Anaconda3/envs/pytorch_17/python train.py \
2 | --gen_model CSformer \
3 | --exp_name coco_cs4 \
4 | --cs_ratio 4 \
5 | --img_size 64 \
6 | --bottom_width 8 \
7 | --max_iter 500000 \
8 | --g_lr 1e-4 \
9 | --gen_batch_size 10 \
10 | --eval_batch_size 10 \
11 | --gf_dim 128 \
12 | --val_freq 1 \
13 | --print_freq 100 \
14 | --g_window_size 8 \
15 | --num_workers 1 \
16 | --optimizer adam \
17 | --beta1 0.9 \
18 | --beta2 0.999 \
19 | --init_type xavier_uniform \
20 | --g_depth 5,5,5,5 \
21 | --datarange -11 \
22 | --train_patch_size 128 \
23 | --rec_loss_type l2 \
24 | --dataset coco \
25 | --data_path D:/database/package/coco/unlabeled2017
26 | # --dataset BSD400 \
27 | # --data_path C:/dataset/data/BSD400
28 | 
29 | 
30 | 
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class ReconstructionLoss(nn.Module):
5 |     def __init__(self, type='l1'):
6 |         super(ReconstructionLoss, self).__init__()
7 |         if (type == 'l1'):
8 |             self.loss = nn.L1Loss()
9 |         elif (type == 'l2'):
10 |             self.loss = nn.MSELoss()
11 |         else:
12 |             raise SystemExit('Error: no such type of ReconstructionLoss!')
13 | 
14 |     def forward(self, sr, hr):
15 |         return self.loss(sr, hr)
16 | 
17 | 
18 | 
19 | def get_loss_dict(args, logger=None):
20 |     loss = {}
21 |     if (abs(args.rec_w - 0) <= 1e-8):
22 |         raise SystemExit('NotImplementedError: ReconstructionLoss must exist!')
23 |     else:
24 |         loss['rec_loss'] = ReconstructionLoss(type=args.rec_loss_type)
25 |     return loss
26 | 
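A quick way to sanity-check the loss module above is a minimal sketch like the following; it is not part of the repository, and the import assumes it is run from the repository root:

```python
# Hypothetical smoke test for ReconstructionLoss (not shipped with the repo).
import torch
from loss.loss import ReconstructionLoss

sr = torch.zeros(2, 1, 64, 64)  # stand-in for a reconstructed batch
hr = torch.ones(2, 1, 64, 64)   # stand-in for the ground truth

print(ReconstructionLoss(type='l1')(sr, hr).item())  # mean |sr - hr| -> 1.0
print(ReconstructionLoss(type='l2')(sr, hr).item())  # mean (sr - hr)^2 -> 1.0
```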
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import cfg
2 | import torch
3 | from functions import test_overlap, test
4 | from utils.utils import set_log_dir, create_logger
5 | import os
6 | from tensorboardX import SummaryWriter
7 | 
8 | 
9 | torch.backends.cudnn.enabled = True
10 | torch.backends.cudnn.benchmark = True
11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12 | 
13 | def main():
14 |     args = cfg.parse_args()
15 |     torch.cuda.manual_seed(args.random_seed)
16 |     assert args.exp_name
17 |     assert args.load_path.endswith('.pth')
18 |     assert os.path.exists(args.load_path)
19 |     args.path_helper = set_log_dir('logs_eval', args.exp_name)
20 |     logger = create_logger(args.path_helper['log_path'], phase='test')
21 | 
22 | 
23 |     # Get Sensing_matrix
24 |     # ratio_dict = {1: 10, 4: 43, 10: 103, 25: 272, 30: 327, 40: 436, 50: 545}  # 32*32
25 |     ratio_dict = {1: 3, 4: 11, 10: 26, 25: 64, 30: 77, 40: 103, 50: 128}  # 16*16
26 |     n_input = ratio_dict[args.cs_ratio]
27 |     args.n_input = n_input
28 | 
29 |     # import network
30 |     exec('import '+'models.'+args.gen_model)
31 |     gen_net = eval('models.'+args.gen_model+'.CSformer')(args).cuda()
32 |     gpu_ids = [i for i in range(int(torch.cuda.device_count()))]
33 |     print('gpu ids:', gpu_ids)
34 |     gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
35 |     print(f'gen_model: {args.gen_model}')
36 |     print(f'Model params: {sum(param.numel() for param in gen_net.parameters())}')
37 |     print(f'dataset: {args.dataset}')
38 |     print(f'cs ratio: {args.cs_ratio}')
39 | 
40 | 
41 |     # set writer
42 |     logger.info(f'=> resuming from {args.load_path}')
43 |     print(f'=> resuming from {args.load_path}')
44 |     checkpoint_file = args.load_path
45 |     assert os.path.exists(checkpoint_file)
46 |     checkpoint = torch.load(checkpoint_file)
47 |     writer_dict = {
48 |         'writer': SummaryWriter(args.path_helper['log_path']),
49 |     }
50 |     writer = writer_dict['writer']
51 |     if 'avg_gen_state_dict' in checkpoint:
52 |         gen_net.load_state_dict(checkpoint['avg_gen_state_dict'])
53 |         epoch = checkpoint['epoch']
54 |         logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})')
55 |     else:
56 |         gen_net.load_state_dict(checkpoint['gen_state_dict'])
57 |         logger.info(f'=> loaded checkpoint {checkpoint_file}')
58 |     print(f'=> loaded checkpoint {checkpoint_file}')
59 | 
60 |     if args.overlap == True:
61 |         test_overlap(args, gen_net, logger)
62 |     else:
63 |         test(args, gen_net, logger)
64 | 
65 | 
66 | if __name__ == '__main__':
67 |     main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CSformer: Bridging Convolution and Transformer for Compressive Sensing (TIP 2023)
2 | 
3 | Official PyTorch implementation of "**CSformer: Bridging Convolution and Transformer for Compressive Sensing**", published in ***IEEE Transactions on Image Processing (TIP)***.
4 | #### [[Paper-arXiv](https://arxiv.org/abs/2112.15299)] [[Paper-official](https://ieeexplore.ieee.org/document/10124835/)]
5 | Dongjie Ye, [Zhangkai Ni](https://eezkni.github.io/), [Hanli Wang](https://mic.tongji.edu.cn/51/91/c9778a86417/page.htm), [Jian Zhang](https://jianzhang.tech/), [Shiqi Wang](https://www.cs.cityu.edu.hk/~shiqwang/), [Sam Kwong](http://www6.cityu.edu.hk/stfprofile/cssamk.htm)
6 | 
7 | 
8 | 
9 | ## Testing (Running pretrained models)
10 | - Checkpoint
11 | 
12 | Checkpoints trained on the COCO dataset can be found at [Google Drive](https://drive.google.com/file/d/1P_HKhmTsYi2H94VMY1TcIU5Ze6H_mIq0/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1o7Cs9OLjy63PLydgFmQ_qw?pwd=fr6m) (extraction code: fr6m).
13 | 
14 | Checkpoints trained on the BSD400 dataset can be found at [Google Drive](https://drive.google.com/file/d/1b8mbNcD7zbv5XC2oXJ8VStPoMqqS4g5s/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1S2VHdv1WKUo6jyGHLPi1FQ?pwd=9lmz) (extraction code: 9lmz).
15 | 
16 | - Inference
17 | 1. Unzip the checkpoint file and place all the files in the ./logs/checkpoint_coco/ or ./logs/checkpoint_bsd400/ directory.
18 | 2. Edit the ./cfg.py file and set [--testdata_path] to the path of your test datasets.
19 | 3. Execute the test script below:
20 | ```
21 | python eval.py --cs_ratio 1 --exp_name coco_test_CS1 --load_path ./logs/checkpoint_coco/checkpoint_CS1.pth --overlap --overlapstep 8
22 | ```
23 | (The available options for [cs_ratio] in our pre-trained models are 1, 4, 10, 25, and 50.)
24 | 
25 | If you want to test the model without overlapping, you may run the script below:
26 | ```
27 | python eval.py --cs_ratio 1 --exp_name coco_test_CS1 --load_path ./logs/checkpoint_coco/checkpoint_CS1.pth
28 | ```
29 | ## Training (Training from scratch)
30 | 1. Prepare the training dataset.
31 | 2. Edit the train_script.sh file to set your Python path, and modify [--data_path] and [--dataset] to point to your training dataset.
32 | 3. Execute the training script below:
33 | ```
34 | sh train_script.sh
35 | ```
36 | 4. Find the trained weights in the ./logs/[env]/Model/ folder.
37 | 
38 | ## Citation
39 | If this code is useful for your research, please cite our paper:
40 | 
41 | ```
42 | @article{csformer,
43 |   author={Ye, Dongjie and Ni, Zhangkai and Wang, Hanli and Zhang, Jian and Wang, Shiqi and Kwong, Sam},
44 |   journal={IEEE Transactions on Image Processing},
45 |   title={CSformer: Bridging Convolution and Transformer for Compressive Sensing},
46 |   year={2023},
47 |   volume={32},
48 |   number={},
49 |   pages={2827-2842},
50 |   doi={10.1109/TIP.2023.3274988}}
51 | ```
52 | 
53 | ## Contact
54 | 
55 | Thanks for your attention! If you have any suggestions or questions, feel free to leave a message here or contact Dongjie Ye (dj.ye@my.cityu.edu.hk).
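## Note on sampling ratios

CSformer senses each 64×64 tile in 16×16 blocks, and the number of measurements per block for a given [cs_ratio] is hard-coded in the `ratio_dict` of eval.py and train.py. The entries coincide with `ceil(256 * ratio / 100)`; the sketch below is our reading of that table, not code shipped with the repository:

```python
# Our interpretation of ratio_dict for 16x16 blocks; verify against eval.py/train.py.
import math

def measurements_per_block(cs_ratio, block_size=16):
    return math.ceil(block_size ** 2 * cs_ratio / 100)

assert {r: measurements_per_block(r) for r in (1, 4, 10, 25, 30, 40, 50)} == \
    {1: 3, 4: 11, 10: 26, 25: 64, 30: 77, 40: 103, 50: 128}
```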
56 | 
--------------------------------------------------------------------------------
/models/CSformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import models.Transformer64 as Transformer64
5 | import models.CNN64 as CNN64
6 | 
7 | import numpy as np
8 | 
9 | class CSformer(nn.Module):
10 |     def __init__(self, args):
11 |         super(CSformer, self).__init__()
12 | 
13 |         self.args = args
14 |         ## sampling matrix
15 |         # self.register_parameter("sensing_matrix", nn.Parameter(torch.from_numpy(sensing_matrix).float(), requires_grad=True))
16 |         self.n_input = args.n_input
17 |         self.bottom_width = args.bottom_width
18 |         self.embed_dim = args.gf_dim
19 |         # self.l1 = nn.Linear(args.latent_dim, (args.bottom_width ** 2) * args.gf_dim)
20 |         self.outdim = int(np.ceil((args.img_size**2)//(args.bottom_width ** 2)))
21 |         # self.outdim = int(np.ceil(args.latent_dim // (args.bottom_width ** 2)))
22 |         self.iniconv = nn.Sequential(
23 |             # nn.ReflectionPad2d(1),
24 |             nn.Conv2d(self.n_input, 128, 1, 1, 0),
25 |             nn.LeakyReLU(0.2, inplace=True),
26 |             nn.Conv2d(128, 256, 1, 1, 0),
27 |             nn.LeakyReLU(0.2, inplace=True),
28 |             nn.Conv2d(256, 512, 1, 1, 0),
29 |             nn.LeakyReLU(0.2, inplace=True),
30 |             nn.Conv2d(512, 512, 1, 1, 0),
31 |             nn.LeakyReLU(0.2, inplace=True),
32 |             nn.Conv2d(512, 512, 1, 1, 0)
33 |         )
34 |         self.act = nn.LeakyReLU(0.2, inplace=True)
35 | 
36 |         self.Phi = nn.Parameter(torch.nn.init.xavier_normal_(torch.Tensor(self.n_input, 256)))
37 |         self.PhiT = nn.Parameter(torch.nn.init.xavier_normal_(torch.Tensor(256, self.n_input)))
38 |         # self.nsv = args.nsv
39 | 
40 |         # transformer branch
41 |         self.td = Transformer64.Transformer(args)
42 | 
43 | 
44 |         # cnn branch
45 |         self.gs = CNN64.Generator(args)
46 | 
47 | 
48 | 
49 | 
50 |     def together(self, inputs, S, H, L):
51 |         inputs = inputs.squeeze(1)
52 |         inputs = torch.cat(torch.split(inputs, split_size_or_sections=H*S, dim=0), dim=2)
53 |         inputs = torch.cat(torch.split(inputs, split_size_or_sections=S, dim=0), dim=1)
54 |         inputs = inputs.unsqueeze(1)
55 |         return inputs
56 | 
57 |     def forward(self, inputs):
58 | 
59 |         H = int(inputs.shape[2]/64)
60 |         L = int(inputs.shape[3]/64)
61 |         S = inputs.shape[0]
62 |         inputs = torch.squeeze(inputs, dim=1)
63 |         inputs = torch.cat(torch.split(inputs, split_size_or_sections=64, dim=1), dim=0)
64 |         inputs = torch.cat(torch.split(inputs, split_size_or_sections=64, dim=2), dim=0)
65 |         inputs = torch.unsqueeze(inputs, dim=1)
66 | 
67 |         np.random.seed(12345)
68 |         PhiWeight = self.Phi.contiguous().view(self.n_input, 1, 16, 16)
69 |         y = F.conv2d(inputs, PhiWeight, padding=0, stride=16, bias=None)
70 | 
71 | 
72 |         # Initialization-subnet
73 |         PhiTWeight = self.PhiT.contiguous().view(256, self.n_input, 1, 1)
74 |         PhiTb = F.conv2d(y, PhiTWeight, padding=0, bias=None)
75 |         PhiTb = torch.nn.PixelShuffle(16)(PhiTb)
76 | 
77 |         x = self.iniconv(y)
78 |         x = torch.nn.PixelShuffle(2)(x)
79 | 
80 |         x = x.flatten(2).transpose(1, 2).contiguous()
81 |         gsfeatures = self.gs(x)
82 |         output = self.td(x, gsfeatures, PhiTb)
83 |         merge_output = self.together(output, S, H, L)
84 |         merge_PhiTb = self.together(PhiTb, S, H, L)
85 |         # final_output = self.enhance(merge_output, merge_PhiTb)
86 | 
87 | 
88 |         return merge_output, merge_PhiTb, output, PhiTb
89 |         # return merge_output, merge_PhiTb, output, a
90 | 
91 | 
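CSformer.forward tiles the (padded) input into 64×64 sub-images stacked along the batch dimension, and together() inverts that tiling after reconstruction. Below is a self-contained round-trip check of those tensor ops; the shapes are arbitrary, and the sketch only mirrors the split/merge logic above rather than running the model:

```python
# Round trip of the 64x64 block split in CSformer.forward and the merge in together().
import torch

S, H, L = 2, 2, 3                      # batch size, block rows, block cols
x = torch.randn(S, 1, H * 64, L * 64)  # padded input with sides divisible by 64

# split, as in forward(): (S, 1, 64H, 64L) -> (L*H*S, 1, 64, 64)
t = torch.squeeze(x, dim=1)
t = torch.cat(torch.split(t, 64, dim=1), dim=0)
t = torch.cat(torch.split(t, 64, dim=2), dim=0)
t = torch.unsqueeze(t, dim=1)

# merge, as in together(): (L*H*S, 1, 64, 64) -> (S, 1, 64H, 64L)
m = t.squeeze(1)
m = torch.cat(torch.split(m, H * S, dim=0), dim=2)
m = torch.cat(torch.split(m, S, dim=0), dim=1)
m = m.unsqueeze(1)

assert torch.equal(m, x)  # the merge exactly inverts the split
```
--------------------------------------------------------------------------------
/models/ViT_helper.py: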
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | def drop_path(x, drop_prob: float = 0., training: bool = False): 5 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 6 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 7 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 8 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 9 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 10 | 'survival rate' as the argument. 11 | """ 12 | if drop_prob == 0. or not training: 13 | return x 14 | keep_prob = 1 - drop_prob 15 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 16 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 17 | random_tensor.floor_() # binarize 18 | output = x.div(keep_prob) * random_tensor 19 | return output 20 | 21 | 22 | class DropPath(nn.Module): 23 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 24 | """ 25 | def __init__(self, drop_prob=None): 26 | super(DropPath, self).__init__() 27 | self.drop_prob = drop_prob 28 | 29 | def forward(self, x): 30 | return drop_path(x, self.drop_prob, self.training) 31 | 32 | from itertools import repeat 33 | from torch._six import container_abcs 34 | 35 | 36 | # From PyTorch internals 37 | def _ntuple(n): 38 | def parse(x): 39 | if isinstance(x, container_abcs.Iterable): 40 | return x 41 | return tuple(repeat(x, n)) 42 | return parse 43 | 44 | 45 | to_1tuple = _ntuple(1) 46 | to_2tuple = _ntuple(2) 47 | to_3tuple = _ntuple(3) 48 | to_4tuple = _ntuple(4) 49 | 50 | 51 | 52 | import torch 53 | import math 54 | import warnings 55 | 56 | 57 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 58 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 59 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 60 | def norm_cdf(x): 61 | # Computes standard normal cumulative distribution function 62 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 63 | 64 | if (mean < a - 2 * std) or (mean > b + 2 * std): 65 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 66 | "The distribution of values may be incorrect.", 67 | stacklevel=2) 68 | 69 | with torch.no_grad(): 70 | # Values are generated by using a truncated uniform distribution and 71 | # then using the inverse CDF for the normal distribution. 72 | # Get upper and lower cdf values 73 | l = norm_cdf((a - mean) / std) 74 | u = norm_cdf((b - mean) / std) 75 | 76 | # Uniformly fill tensor with values from [l, u], then translate to 77 | # [2l-1, 2u-1]. 78 | tensor.uniform_(2 * l - 1, 2 * u - 1) 79 | 80 | # Use inverse cdf transform for normal distribution to get truncated 81 | # standard normal 82 | tensor.erfinv_() 83 | 84 | # Transform to proper mean, std 85 | tensor.mul_(std * math.sqrt(2.)) 86 | tensor.add_(mean) 87 | 88 | # Clamp to ensure it's in the proper range 89 | tensor.clamp_(min=a, max=b) 90 | return tensor 91 | 92 | 93 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 94 | # type: (Tensor, float, float, float, float) -> Tensor 95 | r"""Fills the input Tensor with values drawn from a truncated 96 | normal distribution. 
The values are effectively drawn from the
97 |     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
98 |     with values outside :math:`[a, b]` redrawn until they are within
99 |     the bounds. The method used for generating the random values works
100 |     best when :math:`a \leq \text{mean} \leq b`.
101 |     Args:
102 |         tensor: an n-dimensional `torch.Tensor`
103 |         mean: the mean of the normal distribution
104 |         std: the standard deviation of the normal distribution
105 |         a: the minimum cutoff value
106 |         b: the maximum cutoff value
107 |     Examples:
108 |         >>> w = torch.empty(3, 5)
109 |         >>> nn.init.trunc_normal_(w)
110 |     """
111 |     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
--------------------------------------------------------------------------------
/cfg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | def parse_args():
5 |     parser = argparse.ArgumentParser()
6 | 
7 |     # train
8 |     parser.add_argument(
9 |         '-gen_bs',
10 |         '--gen_batch_size',
11 |         type=int,
12 |         default=10,
13 |         help='size of the batches')
14 |     parser.add_argument(
15 |         '--max_epoch',
16 |         type=int,
17 |         default=200,
18 |         help='number of epochs of training')
19 |     parser.add_argument(
20 |         '--max_iter',
21 |         type=int,
22 |         default=None,
23 |         help='set the max iteration number')
24 |     parser.add_argument(
25 |         '--g_lr',
26 |         type=float,
27 |         default=0.0001,
28 |         help='adam: gen learning rate')
29 |     parser.add_argument(
30 |         '--lr_decay',
31 |         action='store_true',
32 |         help='learning rate decay or not')
33 |     parser.add_argument(
34 |         '--beta1',
35 |         type=float,
36 |         default=0.0,
37 |         help='adam: decay of first order momentum of gradient')
38 |     parser.add_argument(
39 |         '--beta2',
40 |         type=float,
41 |         default=0.9,
42 |         help='adam: decay of second order momentum of gradient')
43 |     parser.add_argument(
44 |         '--num_workers',
45 |         type=int,
46 |         default=8,
47 |         help='number of cpu threads to use during batch generation')
48 |     parser.add_argument(
49 |         '--val_freq',
50 |         type=int,
51 |         default=1,
52 |         help='interval between each validation')
53 | 
54 |     parser.add_argument(
55 |         '--dataset',
56 |         type=str,
57 |         default='coco',
58 |         help='dataset type')
59 |     parser.add_argument(
60 |         '--data_path',
61 |         type=str,
62 |         default="D:\database\package\coco\\unlabeled2017",
63 |         help='The path of the dataset')
64 |     parser.add_argument('--init_type', type=str, default='xavier_uniform',
65 |                         choices=['normal', 'orth', 'xavier_uniform', 'false'],
66 |                         help='The init type')
67 |     parser.add_argument('--optimizer', type=str, default="adam",
68 |                         help='optimizer')
69 |     parser.add_argument('--rec_w', type=int, default=1, help='penalty for the reconstruction loss')
70 |     parser.add_argument('--rec_loss_type', type=str, default='l1', help='The type of reconstruction loss')
71 |     parser.add_argument('--weight_decay', type=float, default=1e-3, help='penalty for the adam')
72 |     parser.add_argument('--lr_multi', type=int, default=None)
73 |     parser.add_argument('--train_patch_size', type=int, default=128, help='size of training input in dataloader')
74 |     parser.add_argument('--random_seed', type=int, default=12345)
75 | 
76 |     # setting
77 |     parser.add_argument(
78 |         '--exp_name',
79 |         type=str,
80 |         help='The name of exp')
81 |     parser.add_argument(
82 |         '--load_path',
83 |         type=str,
84 |         help='The reload model path')
85 |     parser.add_argument('--cs_ratio', type=int, default=10, help='from {1, 4, 10, 25, 40, 50}')
86 |     parser.add_argument('--datarange', type=str, default='-11',
87 |                         help='input data norm to range')
88 |     parser.add_argument('--torch_vision', action='store_true', default=False, help='Show intermediate results in the TensorBoard dir')
89 |     parser.add_argument(
90 |         '--print_freq',
91 |         type=int,
92 |         default=100,
93 |         help='interval between each verbose')
94 | 
95 |     # test
96 |     parser.add_argument(
97 |         '--testdata_path',
98 |         type=str,
99 |         default=[r'C:\dataset\data\Set11'],
100 |         # default=['D:\database\dataset\\Urban100','C:\dataset\data\Set11','D:\database\dataset\BSD68','D:\database\dataset\Set14','D:\database\dataset\Set5\data'],
101 |         help='The paths of the test sets')
102 |     parser.add_argument(
103 |         '--valdata_path',
104 |         type=str,
105 |         default='C:\dataset\data\Set11',
106 |         help='The path of the val set')
107 |     parser.add_argument('--eval_batch_size', type=int, default=400)
108 |     parser.add_argument('--overlap', action='store_true', help='overlap or not during testing')
109 |     parser.add_argument('--overlapstep', type=int, default=8, help='the overlap step for testing')
110 | 
111 |     # model
112 |     parser.add_argument(
113 |         '--gen_model',
114 |         type=str,
115 |         default='CSformer',
116 |         help='path of gen model')
117 |     parser.add_argument(
118 |         '--bottom_width',
119 |         type=int,
120 |         default=8)
121 |     parser.add_argument(
122 |         '--img_size',
123 |         type=int,
124 |         default=64,
125 |         help='output size')
126 |     parser.add_argument('--gf_dim', type=int, default=128,
127 |                         help='The base channel num of model')
128 | 
129 |     parser.add_argument('--g_depth', type=str, default="5,5,5,5",
130 |                         help='Generator Depth')
131 |     parser.add_argument('--g_window_size', type=int, default=8,
132 |                         help='generator window size')
133 |     parser.add_argument('--num_heads', type=str, default="16,8,4,2",
134 |                         help='num_heads of transformer')
135 |     parser.add_argument('--cnnnorm_type', type=str, default="BatchNorm",
136 |                         help='norm type of cnn')
137 |     parser.add_argument('--g_norm', type=str, default="ln",
138 |                         help='Generator Normalization')
139 |     parser.add_argument('--g_mlp', type=int, default=4,
140 |                         help='generator mlp ratio')
141 |     parser.add_argument('--g_act', type=str, default="gelu",
142 |                         help='Generator activation Layer')
143 |     parser.add_argument('--pretrained', type=str, default="church",
144 |                         help='pretrained dataset')
145 |     parser.add_argument('--seed', default=12345, type=int,
146 |                         help='seed for initializing training. 
') 147 | 148 | 149 | opt = parser.parse_args() 150 | 151 | return opt 152 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import cfg 6 | import torch 7 | import os 8 | import numpy as np 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | from functions import train, validate, load_params, copy_params 12 | import datasets 13 | from utils.utils import set_log_dir, save_checkpoint, create_logger 14 | from tensorboardX import SummaryWriter 15 | from loss.loss import get_loss_dict 16 | import time 17 | 18 | 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.benchmark = True 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | 23 | def main(): 24 | args = cfg.parse_args() 25 | torch.cuda.manual_seed(args.random_seed) 26 | 27 | # Get Sensing_matrix 28 | # ratio_dict = {1: 10, 4: 43, 10: 103, 25: 272, 30: 327, 40: 436, 50: 545} #32*32 29 | ratio_dict = {1: 3, 4: 11, 10: 26, 25: 64, 30: 77, 40: 103, 50: 128} #16*16 30 | n_input = ratio_dict[args.cs_ratio] 31 | args.n_input = n_input 32 | 33 | # import network 34 | exec('import '+'models.'+args.gen_model) 35 | gen_net = eval('models.'+args.gen_model+'.CSformer')(args).cuda() 36 | print(f'model: {args.gen_model}') 37 | print(f'model param {(sum(param.numel() for param in gen_net.parameters()))/1e6}M') 38 | print(f'dataset: {args.dataset}') 39 | print(f'cs ratio: {args.cs_ratio}') 40 | print(f'windows size: {args.g_window_size}') 41 | print(f'transformer depth: {[int(i) for i in args.g_depth.split(",")]}') 42 | print(f'transformer num_heads: {[int(i) for i in args.num_heads.split(",")]}') 43 | print(f'dim: {args.gf_dim}') 44 | print(f'CNN Norm: {args.cnnnorm_type}') 45 | print(f'Training patch size: {args.train_patch_size}') 46 | 47 | # weight init 48 | def weights_init(m): 49 | classname = m.__class__.__name__ 50 | if classname.find('Conv2d') != -1: 51 | if args.init_type == 'normal': 52 | nn.init.normal_(m.weight.data, 0.0, 0.02) 53 | elif args.init_type == 'orth': 54 | nn.init.orthogonal_(m.weight.data) 55 | elif args.init_type == 'xavier_uniform': 56 | nn.init.xavier_uniform_(m.weight.data, 1.) 
57 |             else:
58 |                 raise NotImplementedError('{} unknown initial type'.format(args.init_type))
59 |         elif classname.find('BatchNorm2d') != -1:
60 |             nn.init.normal_(m.weight.data, 1.0, 0.02)
61 |             nn.init.constant_(m.bias.data, 0.0)
62 | 
63 |     gen_net.apply(weights_init)
64 |     gpu_ids = [i for i in range(int(torch.cuda.device_count()))]
65 |     print('gpu ids:', gpu_ids)
66 |     gen_net = torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=gpu_ids)
67 | 
68 | 
69 |     # set optimizer
70 |     if args.lr_multi is not None:
71 |         print('multi lr is not None')
72 |         gs_params_id = list(map(id, gen_net.module.gs.parameters()))
73 |         decoder_params = filter(lambda p: id(p) not in gs_params_id, gen_net.parameters())
74 |         gen_optimizer = torch.optim.Adam([{'params': gen_net.module.gs.parameters(), 'lr': args.g_lr},
75 |                                           {'params': decoder_params, 'lr': args.decoder_lr}], betas=(args.beta1, args.beta2))
76 |     else:
77 |         gen_optimizer = torch.optim.Adam(gen_net.parameters(),
78 |                                          args.g_lr, betas=(args.beta1, args.beta2))
79 | 
80 |     # set up data_loader
81 |     dataset = datasets.ImageDataset(args)
82 |     train_loader = dataset.train
83 | 
84 |     # epoch number
85 |     args.max_epoch = args.max_epoch
86 |     if args.max_iter:
87 |         args.max_epoch = np.ceil(args.max_iter / len(train_loader))
88 | 
89 | 
90 |     # ------------------------------ Cosine LR ------------------------------
91 |     t_max = args.max_epoch
92 |     print(f'max epoch is {t_max}')
93 |     scheduler_lr = torch.optim.lr_scheduler.CosineAnnealingLR(gen_optimizer, T_max=t_max, eta_min=1e-6)
94 | 
95 |     # initial
96 |     start_epoch = 0
97 |     best_psnr = 10
98 |     best_ssim = 0
99 |     psnr_score = 8
100 |     best_epoch = 0
101 | 
102 |     # set up loss
103 |     loss_all = get_loss_dict(args)
104 |     print('loss:', loss_all.keys())
105 |     print(f'rec_loss type: {args.rec_loss_type}\ng_lr:{args.g_lr}')
106 |     # set writer
107 |     if args.load_path:
108 |         print(f'=> resuming from {args.load_path}')
109 |         assert os.path.exists(args.load_path)
110 |         checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
111 |         assert os.path.exists(checkpoint_file)
112 |         checkpoint = torch.load(checkpoint_file)
113 |         start_epoch = checkpoint['epoch']
114 |         gen_net.load_state_dict(checkpoint['gen_state_dict'])
115 |         gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
116 |         args.path_helper = checkpoint['path_helper']
117 |         logger = create_logger(args.path_helper['log_path'])
118 |         logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
119 |     else:
120 |         # create new log dir
121 |         assert args.exp_name
122 |         args.path_helper = set_log_dir('logs', args.exp_name)
123 |         logger = create_logger(args.path_helper['log_path'])
124 | 
125 |     logger.info(args)
126 |     writer_dict = {
127 |         'writer': SummaryWriter(args.path_helper['log_path']),
128 |         'train_global_steps': start_epoch * len(train_loader),
129 |         'valid_global_steps': start_epoch // args.val_freq,
130 |     }
131 | 
132 |     # train loop
133 |     for epoch in tqdm(range(int(start_epoch), int(args.max_epoch)), desc='total progress'):
134 |         # lr
135 |         cur_lr = gen_optimizer.param_groups[0]['lr']
136 |         print(f'epoch[{epoch}] lr: {cur_lr}')
137 |         writer = writer_dict['writer']
138 |         writer.add_scalar('LR/g_lr', cur_lr, epoch)
139 |         # lr
140 |         start_time = time.time()
141 |         train(args, gen_net, gen_optimizer, train_loader, epoch, writer_dict,
142 |               loss_all)
143 |         print(f'training time is {time.time()-start_time}')
144 | 
145 |         if epoch % args.val_freq == 0 or epoch == int(args.max_epoch)-1:
146 | 
147 |             backup_param = copy_params(gen_net)
148 | 
psnr_score,PSNR_cross, ssim_score,SSIM_cross = validate(args, epoch, gen_net, writer_dict) 149 | logger.info(f'@ epoch {epoch} || PSNR score: {psnr_score:.2f} SSIM score: {ssim_score:.4f}.\t @ best epoch {best_epoch} || Best PSNR score: {best_psnr:.2f} SSIM score: {best_ssim:.4f} .') 150 | load_params(gen_net, backup_param) 151 | 152 | if psnr_score > best_psnr: 153 | best_epoch = epoch 154 | best_psnr = psnr_score 155 | is_best = True 156 | best_ssim = ssim_score 157 | logger.info(f'@ epoch {epoch} || Best PSNR score: {psnr_score:.2f} SSIM score: {ssim_score:.4f}.') 158 | else: 159 | is_best = False 160 | else: 161 | is_best = False 162 | 163 | save_checkpoint({ 164 | 'epoch': epoch + 1, 165 | 'gen_state_dict': gen_net.state_dict(), 166 | 'best_psnr': best_psnr, 167 | 'best_ssim': best_ssim, 168 | 'path_helper': args.path_helper 169 | }, is_best, args.path_helper['ckpt_path']) 170 | 171 | #lr 172 | scheduler_lr.step() 173 | 174 | 175 | 176 | if __name__ == '__main__': 177 | main() -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import Dataset 5 | import os 6 | import random 7 | from scipy import io 8 | import cv2 9 | import numpy as np 10 | 11 | class Augment_RGB_torch: 12 | def __init__(self): 13 | pass 14 | def transform0(self, torch_tensor): 15 | return torch_tensor 16 | def transform1(self, torch_tensor): 17 | torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2]) 18 | return torch_tensor 19 | def transform2(self, torch_tensor): 20 | torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2]) 21 | return torch_tensor 22 | def transform3(self, torch_tensor): 23 | torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2]) 24 | return torch_tensor 25 | def transform4(self, torch_tensor): 26 | torch_tensor = torch_tensor.flip(-2) 27 | return torch_tensor 28 | def transform5(self, torch_tensor): 29 | torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2) 30 | return torch_tensor 31 | def transform6(self, torch_tensor): 32 | torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2) 33 | return torch_tensor 34 | def transform7(self, torch_tensor): 35 | torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2) 36 | return torch_tensor 37 | augment = Augment_RGB_torch() 38 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] 39 | 40 | class ImageDataset(object): 41 | def __init__(self, args, cur_img_size=None, bs=None): 42 | bs = args.gen_batch_size if bs == None else bs 43 | img_size = args.img_size 44 | if args.dataset.lower() == 'coco' or args.dataset.lower() == 'div2k': 45 | Dt = ImgData(args) 46 | self.train = torch.utils.data.DataLoader(Dt,batch_size=args.gen_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) 47 | elif args.dataset.lower() == 'bsd400': 48 | Dt = ImgData_BSD400(args) 49 | self.train = torch.utils.data.DataLoader(Dt,batch_size=args.gen_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) 50 | else: 51 | raise NotImplementedError('Unknown dataset: {}'.format(args.dataset)) 52 | 53 | 54 | class ImgData_BSD400(): 55 | """imagent""" 56 | def __init__(self, args,train=True): 57 | self.dataroot = args.data_path 58 | 59 | self.img_list = 
search(os.path.join(self.dataroot), "png") 60 | self.img_list = sorted(self.img_list) 61 | # print(self.img_list) 62 | self.train = train 63 | 64 | self.args = args 65 | self.len = len(self.img_list) 66 | print("data length:", len(self.img_list)) 67 | 68 | def __len__(self): 69 | return len(self.img_list) 70 | 71 | def _get_index(self, idx): 72 | return idx % len(self.img_list) 73 | 74 | def _load_file(self, idx): 75 | idx = self._get_index(idx) 76 | f_lr = self.img_list[idx] 77 | 78 | #CV 79 | 80 | lr_cv = cv2.imread(f_lr) 81 | lr_cv = cv2.cvtColor(lr_cv, cv2.COLOR_BGR2RGB) 82 | lr = rgb2ycbcr(lr_cv) 83 | if len(lr.shape) == 2: 84 | lr = np.expand_dims(lr, axis=2) 85 | 86 | 87 | return lr 88 | 89 | def _np2Tensor(self, img, rgb_range): 90 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 91 | tensor = np_transpose.astype(np.float32) 92 | tensor = tensor * (rgb_range / 255) 93 | return tensor 94 | 95 | def __getitem__(self, idx): 96 | hr = self._load_file(idx % self.len) 97 | hr = self.get_patch(hr) 98 | hr = self._np2Tensor(hr, rgb_range=255) 99 | 100 | if self.args.datarange == '01': 101 | hr = hr/255. 102 | else: 103 | hr = hr / 127.5 - 1 104 | #aug 105 | apply_trans = transforms_aug[random.getrandbits(3)] 106 | hr = torch.from_numpy(hr) 107 | hr = getattr(augment, apply_trans)(hr) 108 | 109 | 110 | return hr 111 | 112 | def get_patch(self, lr, scale=1): 113 | lr = get_patch_img(lr, patch_size=self.args.train_patch_size, scale=scale) 114 | return lr 115 | 116 | class ImgData(): 117 | """imagent""" 118 | def __init__(self, args,train=True): 119 | self.dataroot = args.data_path 120 | if args.dataset.lower() == 'coco': 121 | self.img_list = search(os.path.join(self.dataroot), "jpg") 122 | else: 123 | self.img_list = search(os.path.join(self.dataroot), "png") 124 | self.img_list = sorted(self.img_list)[:(len(self.img_list)//3)] 125 | self.train = train 126 | 127 | self.args = args 128 | self.len = len(self.img_list) 129 | print("data length:", len(self.img_list)) 130 | 131 | def __len__(self): 132 | return len(self.img_list) 133 | 134 | def _get_index(self, idx): 135 | return idx % len(self.img_list) 136 | 137 | def _load_file(self, idx): 138 | idx = self._get_index(idx) 139 | f_lr = self.img_list[idx] 140 | 141 | #CV 142 | 143 | lr_cv = cv2.imread(f_lr) 144 | lr_cv = cv2.cvtColor(lr_cv, cv2.COLOR_BGR2RGB) 145 | lr = rgb2ycbcr(lr_cv) 146 | if len(lr.shape) == 2: 147 | lr = np.expand_dims(lr, axis=2) 148 | 149 | return lr 150 | 151 | def _np2Tensor(self, img, rgb_range): 152 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 153 | tensor = np_transpose.astype(np.float32) 154 | tensor = tensor * (rgb_range / 255) 155 | return tensor 156 | 157 | def __getitem__(self, idx): 158 | hr = self._load_file(idx % self.len) 159 | hr = self.get_patch(hr) 160 | hr = self._np2Tensor(hr, rgb_range=255) 161 | 162 | if self.args.datarange == '01': 163 | hr = hr/255. 
164 |         else:
165 |             hr = hr / 127.5 - 1
166 |         # aug
167 |         apply_trans = transforms_aug[random.getrandbits(3)]
168 |         hr = torch.from_numpy(hr)
169 |         hr = getattr(augment, apply_trans)(hr)
170 | 
171 | 
172 |         return hr
173 | 
174 |     def get_patch(self, lr, scale=1):
175 |         lr = get_patch_img(lr, patch_size=self.args.train_patch_size, scale=scale)
176 |         return lr
177 | 
178 | 
179 | def rgb2ycbcr(img, only_y=True):
180 |     '''same as matlab rgb2ycbcr
181 |     only_y: only return Y channel
182 |     Input:
183 |         uint8, [0, 255]
184 |         float, [0, 1]
185 |     '''
186 |     in_img_type = img.dtype
187 |     img = img.astype(np.float32)
188 |     if in_img_type != np.uint8:
189 |         img *= 255.
190 |     # convert
191 |     if only_y:
192 |         rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
193 |     else:
194 |         rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
195 |                               [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
196 |     if in_img_type == np.uint8:
197 |         rlt = rlt.round()
198 |     else:
199 |         rlt /= 255.
200 |     return rlt.astype(in_img_type)
201 | 
202 | def search(root, target="JPEG"):
203 |     """Recursively collect the files under root whose extension (or name prefix) matches target."""
204 |     item_list = []
205 |     items = os.listdir(root)
206 |     for item in items:
207 |         path = os.path.join(root, item)
208 |         if os.path.isdir(path):
209 |             item_list.extend(search(path, target))
210 |         elif path.split('.')[-1] == target:
211 |             item_list.append(path)
212 |         elif path.split('/')[-1].startswith(target):
213 |             item_list.append(path)
214 |     return item_list
215 | 
216 | 
217 | def get_patch_img(img, patch_size=128, scale=1):
218 |     """Randomly crop a (patch_size*scale)-sized patch; upscale the image first if it is too small."""
219 |     ih, iw = img.shape[:2]
220 |     tp = scale * patch_size
221 |     if (iw - tp) > -1 and (ih-tp) > 1:
222 |         ix = random.randrange(0, iw-tp+1)
223 |         iy = random.randrange(0, ih-tp+1)
224 |         hr = img[iy:iy+tp, ix:ix+tp, :]
225 |     else:
226 |         img = np.resize(img, (ih*2, iw*2, 1))
227 |         ih, iw = img.shape[:2]
228 |         if (iw - tp) > -1 and (ih - tp) > 1:
229 |             tp = scale * patch_size
230 |             ix = random.randrange(0, iw - tp + 1)
231 |             iy = random.randrange(0, ih - tp + 1)
232 |             hr = img[iy:iy + tp, ix:ix + tp, :]
233 |     return hr
234 | 
235 | 
236 | 
--------------------------------------------------------------------------------
/models/CNN64.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.nn.functional as F
5 | import functools
6 | 
7 | 
8 | 
9 | 
10 | 
11 | def pixel_upsample(x, H, W):
12 |     B, N, C = x.size()
13 |     assert N == H*W
14 |     x = x.permute(0, 2, 1)
15 |     x = x.view(-1, C, H, W)
16 |     x = nn.PixelShuffle(2)(x)
17 |     B, C, H, W = x.size()
18 |     x = x.view(-1, C, H*W)
19 |     x = x.permute(0, 2, 1)
20 |     return x, H, W
21 | 
22 | class Upsample(nn.Module):
23 |     def __init__(self, in_channel, out_channel, up_mode='bicubic'):
24 |         super(Upsample, self).__init__()
25 |         self.conv = nn.Sequential(
26 |             nn.Conv2d(in_channel, out_channel, kernel_size=1),
27 |         )
28 |         self.up_mode = up_mode
29 | 
30 |     def forward(self, x):
31 |         # upsample
32 |         x = F.interpolate(x, scale_factor=2, mode=self.up_mode)
33 |         # reduce dim
34 |         out = self.conv(x)  # B H*W C
35 |         return out
36 | 
37 | class ResGenerator(nn.Module):
38 |     def __init__(self, args, upsample=Upsample):
39 |         super(ResGenerator, self).__init__()
40 |         self.args = args
41 |         self.bottom_width = args.bottom_width
42 |         self.embed_dim = conv_dim = args.gf_dim
43 |         self.dec1 = nn.Sequential(
44 |             ResBlock(in_channels=conv_dim, out_channels=conv_dim, norm_fun=args.cnnnorm_type),
45 |             ResBlock(in_channels=conv_dim, out_channels=conv_dim, norm_fun=args.cnnnorm_type))  # 
8*8*128 --> 32*32*256 46 | self.upsample_1 = upsample(conv_dim, conv_dim // 2) 47 | self.dec2 = nn.Sequential( 48 | ResBlock(in_channels=conv_dim// 2, out_channels=conv_dim// 2,norm_fun = args.cnnnorm_type), 49 | ResBlock(in_channels=conv_dim// 2, out_channels=conv_dim// 2,norm_fun = args.cnnnorm_type)) # 16*16*128 --> 32*32*256 50 | self.upsample_2 = upsample(conv_dim//2, conv_dim // 4) 51 | self.dec3 =nn.Sequential( 52 | ResBlock(in_channels=conv_dim// 4, out_channels=conv_dim// 4,norm_fun = args.cnnnorm_type), 53 | ResBlock(in_channels=conv_dim// 4, out_channels=conv_dim// 4,norm_fun = args.cnnnorm_type)) # 32*32*128 --> 32*32*256 54 | self.upsample_3 = upsample(conv_dim // 4, conv_dim // 8) 55 | self.dec4 = nn.Sequential( 56 | ResBlock(in_channels=conv_dim// 8, out_channels=conv_dim// 8,norm_fun = args.cnnnorm_type), 57 | ResBlock(in_channels=conv_dim// 8, out_channels=conv_dim// 8,norm_fun = args.cnnnorm_type)) # 64*64*128 --> 32*32*256 58 | # self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.embed_dim) 59 | 60 | self.to_rgb = nn.ModuleList() 61 | self.padding3 = (3 + (3 - 1) * (1 - 1) - 1) // 2 62 | self.padding7 = (7 + (7 - 1) * (1 - 1) - 1) // 2 63 | # for i in range(3): 64 | # to_rgb = None 65 | # dim = conv_dim // (2**i) 66 | # to_rgb = nn.Sequential( 67 | # nn.ReflectionPad2d(self.padding3), 68 | # nn.Conv2d(dim, 1, 3, 1, 0), 69 | # # nn.ReflectionPad2d(self.padding7), 70 | # # nn.Conv2d(conv_dim, 1, 7, 1, 0), 71 | # nn.Tanh() 72 | # ) 73 | # self.to_rgb.append(to_rgb) 74 | 75 | def forward(self, x): 76 | features = [] 77 | rgb = [] 78 | x = x.permute(0,2,1).contiguous().view(-1,self.embed_dim, self.bottom_width,self.bottom_width) 79 | 80 | #8x8 81 | x = self.dec1(x) 82 | features.append(x) 83 | # rgb.append(self.to_rgb[0](x)) 84 | 85 | #16x16 86 | x = self.upsample_1(x) 87 | x = self.dec2(x) 88 | features.append(x) 89 | # rgb.append(self.to_rgb[1](x)) 90 | 91 | #32x32 92 | x = self.upsample_2(x) 93 | x = self.dec3(x) 94 | features.append(x) 95 | # rgb.append(self.to_rgb[2](x)) 96 | 97 | #64x64 98 | x = self.upsample_3(x) 99 | x = self.dec4(x) 100 | features.append(x) 101 | 102 | return features,rgb 103 | 104 | 105 | class Generator(nn.Module): 106 | def __init__(self, args,upsample=Upsample): 107 | super(Generator, self).__init__() 108 | self.args = args 109 | self.bottom_width = args.bottom_width 110 | self.embed_dim = conv_dim = args.gf_dim 111 | self.dec1 = ConvBlock(in_channels=conv_dim, out_channels=conv_dim,norm_fun = args.cnnnorm_type) # 8*8*128 --> 32*32*256 112 | self.upsample_1 = upsample(conv_dim, conv_dim // 2) 113 | self.dec2 = ConvBlock(in_channels=conv_dim//2, out_channels=conv_dim//2,norm_fun = args.cnnnorm_type) # 16*16*128 --> 32*32*256 114 | self.upsample_2 = upsample(conv_dim//2, conv_dim // 4) 115 | self.dec3 = ConvBlock(in_channels=conv_dim//4, out_channels=conv_dim//4,norm_fun = args.cnnnorm_type) # 32*32*128 --> 32*32*256 116 | self.upsample_3 = upsample(conv_dim // 4, conv_dim // 8) 117 | self.dec4 = ConvBlock(in_channels=conv_dim//8, out_channels=conv_dim//8,norm_fun = args.cnnnorm_type) # 64*64*128 --> 32*32*256 118 | # self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.embed_dim) 119 | 120 | def forward(self, x): 121 | features = [] 122 | x = tf2cnn(x) 123 | 124 | #8x8 125 | x = self.dec1(x) 126 | features.append(x) 127 | 128 | 129 | #16x16 130 | x = self.upsample_1(x) 131 | x = self.dec2(x) 132 | features.append(x) 133 | 134 | 135 | #32x32 136 | x = self.upsample_2(x) 137 | x = self.dec3(x) 138 | 
features.append(x) 139 | 140 | #64x64 141 | x = self.upsample_3(x) 142 | x = self.dec4(x) 143 | features.append(x) 144 | 145 | return features 146 | 147 | 148 | class Generator_nopos(nn.Module): 149 | def __init__(self, args,upsample=Upsample): 150 | super(Generator_nopos, self).__init__() 151 | self.args = args 152 | self.bottom_width = args.bottom_width 153 | self.embed_dim = conv_dim = args.gf_dim 154 | self.dec1 = ConvBlock(in_channels=conv_dim, out_channels=conv_dim,norm_fun = args.cnnnorm_type) # 8*8*128 --> 32*32*256 155 | self.upsample_1 = upsample(conv_dim, conv_dim // 2) 156 | self.dec2 = ConvBlock(in_channels=conv_dim//2, out_channels=conv_dim//2,norm_fun = args.cnnnorm_type) # 16*16*128 --> 32*32*256 157 | self.upsample_2 = upsample(conv_dim//2, conv_dim // 4) 158 | self.dec3 = ConvBlock(in_channels=conv_dim//4, out_channels=conv_dim//4,norm_fun = args.cnnnorm_type) # 32*32*128 --> 32*32*256 159 | self.upsample_3 = upsample(conv_dim // 4, conv_dim // 8) 160 | self.dec4 = ConvBlock(in_channels=conv_dim//8, out_channels=conv_dim//8,norm_fun = args.cnnnorm_type) # 64*64*128 --> 32*32*256 161 | # self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * self.embed_dim) 162 | 163 | def forward(self, x): 164 | features = [] 165 | # x = tf2cnn(x) 166 | 167 | #8x8 168 | x = self.dec1(x) 169 | features.append(x) 170 | 171 | 172 | #16x16 173 | x = self.upsample_1(x) 174 | x = self.dec2(x) 175 | features.append(x) 176 | 177 | 178 | #32x32 179 | x = self.upsample_2(x) 180 | x = self.dec3(x) 181 | features.append(x) 182 | 183 | #64x64 184 | x = self.upsample_3(x) 185 | x = self.dec4(x) 186 | features.append(x) 187 | 188 | return features 189 | 190 | def tf2cnn(x): 191 | B, L, C = x.shape 192 | H = int(math.sqrt(L)) 193 | W = int(math.sqrt(L)) 194 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 195 | return x 196 | 197 | def cnn2tf(x): 198 | B,C,H,W = x.shape 199 | L = H*W 200 | x = x.flatten(2).transpose(1,2).contiguous() # B H*W C 201 | return x,C,H,W 202 | 203 | def pixel_upsample(x, H, W): 204 | B, N, C = x.size() 205 | assert N == H*W 206 | x = x.permute(0, 2, 1) 207 | x = x.view(-1, C, H, W) 208 | x = nn.PixelShuffle(2)(x) 209 | B, C, H, W = x.size() 210 | x = x.view(-1, C, H*W) 211 | x = x.permute(0,2,1) 212 | return x, H, W 213 | 214 | 215 | def bicubic_upsample(x,H,W,up_mode='bicubic'): 216 | B, N, C = x.size() 217 | assert N == H*W 218 | x = x.permute(0, 2, 1) 219 | x = x.view(-1, C, H, W) 220 | x = F.interpolate(x, scale_factor=2, mode=up_mode) 221 | B, C, H, W = x.size() 222 | x = x.view(-1, C, H*W) 223 | x = x.permute(0,2,1) 224 | return x, H, W 225 | 226 | def get_norm_fun(norm_fun_type='none'): 227 | if norm_fun_type == 'BatchNorm': 228 | norm_fun = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 229 | elif norm_fun_type == 'InstanceNorm': 230 | norm_fun = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=True) 231 | elif norm_fun_type == 'none': 232 | norm_fun = lambda x: Identity() 233 | else: 234 | raise NotImplementedError('normalization function [%s] is not found' % norm_fun_type) 235 | return norm_fun 236 | 237 | 238 | class ConvBlock(nn.Module): 239 | def __init__(self, in_channels, out_channels,kernel_size=3,dilation = 1,norm_fun='none'): 240 | super(ConvBlock, self).__init__() 241 | self.padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 242 | norm_fun = get_norm_fun(norm_fun) 243 | self.conv = nn.Sequential( 244 | #1 245 | nn.ReflectionPad2d(self.padding), 246 | nn.Conv2d(in_channels, 
out_channels, 3, 1, 0), 247 | norm_fun(out_channels), 248 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 249 | #2 250 | nn.ReflectionPad2d(self.padding), 251 | nn.Conv2d(in_channels, out_channels, 3, 1, 0), 252 | norm_fun(out_channels), 253 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 254 | ) 255 | 256 | def forward(self, x): 257 | return self.conv(x) 258 | 259 | 260 | 261 | class ResBlock(nn.Module): 262 | def __init__(self, in_channels, out_channels,kernel_size=3,dilation = 1,norm_fun='none'): 263 | super(ResBlock, self).__init__() 264 | self.padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 265 | norm_fun = get_norm_fun(norm_fun) 266 | self.conv = nn.Sequential( 267 | #1 268 | nn.ReflectionPad2d(self.padding), 269 | nn.Conv2d(in_channels, out_channels, 3, 1, 0), 270 | norm_fun(out_channels), 271 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 272 | #2 273 | nn.ReflectionPad2d(self.padding), 274 | nn.Conv2d(in_channels, out_channels, 3, 1, 0), 275 | norm_fun(out_channels), 276 | nn.LeakyReLU(negative_slope=0.2, inplace=True) 277 | ) 278 | 279 | def forward(self, x): 280 | return self.conv(x) + x 281 | 282 | 283 | def get_act_fun(act_fun_type='LeakyReLU'): 284 | if isinstance(act_fun_type, str): 285 | if act_fun_type == 'LeakyReLU': 286 | return nn.LeakyReLU(0.2, inplace=True) 287 | elif act_fun_type == 'ReLU': 288 | return nn.ReLU(inplace=True) 289 | elif act_fun_type == 'SELU': 290 | return nn.SELU(inplace=True) 291 | elif act_fun_type == 'none': 292 | return nn.Sequential() 293 | else: 294 | raise NotImplementedError('activation function [%s] is not found' % act_fun_type) 295 | else: 296 | return act_fun_type() 297 | 298 | class Identity(nn.Module): 299 | def forward(self, x): 300 | return x 301 | 302 | def get_norm_fun(norm_fun_type='none'): 303 | if norm_fun_type == 'BatchNorm': 304 | norm_fun = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 305 | elif norm_fun_type == 'InstanceNorm': 306 | norm_fun = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=True) 307 | elif norm_fun_type == 'none': 308 | norm_fun = lambda x: Identity() 309 | else: 310 | raise NotImplementedError('normalization function [%s] is not found' % norm_fun_type) 311 | return norm_fun 312 | 313 | 314 | 315 | 316 | 317 | 318 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import collections 3 | import logging 4 | import math 5 | import os 6 | import time 7 | from datetime import datetime 8 | 9 | import dateutil.tz 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from scipy.signal import convolve2d 14 | import cv2 15 | 16 | 17 | def create_logger(log_dir, phase='train'): 18 | time_str = time.strftime('%Y-%m-%d-%H-%M') 19 | log_file = '{}_{}.log'.format(time_str, phase) 20 | final_log_file = os.path.join(log_dir, log_file) 21 | head = '%(asctime)-15s %(message)s' 22 | logging.basicConfig(filename=str(final_log_file), 23 | format=head) 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | console = logging.StreamHandler() 27 | logging.getLogger('').addHandler(console) 28 | 29 | return logger 30 | 31 | 32 | def set_log_dir(root_dir, exp_name): 33 | path_dict = {} 34 | os.makedirs(root_dir, exist_ok=True) 35 | 36 | # set log path 37 | exp_path = os.path.join(root_dir, exp_name) 38 | now = datetime.now(dateutil.tz.tzlocal()) 39 | timestamp = 
now.strftime('%Y_%m_%d_%H_%M_%S')
40 |     # prefix = exp_path + '_' + timestamp
41 |     prefix = exp_path
42 |     if not os.path.exists(prefix):
43 |         os.makedirs(prefix)
44 |     path_dict['prefix'] = prefix
45 | 
46 |     # set checkpoint path
47 |     ckpt_path = os.path.join(prefix, 'Model')
48 |     if not os.path.exists(ckpt_path):
49 |         os.makedirs(ckpt_path)
50 |     path_dict['ckpt_path'] = ckpt_path
51 | 
52 |     log_path = os.path.join(prefix, 'Log')
53 |     if not os.path.exists(log_path):
54 |         os.makedirs(log_path)
55 |     path_dict['log_path'] = log_path
56 | 
57 |     # set sample image path for fid calculation
58 |     sample_path = os.path.join(prefix, 'Samples')
59 |     if not os.path.exists(sample_path):
60 |         os.makedirs(sample_path)
61 |     path_dict['sample_path'] = sample_path
62 | 
63 |     return path_dict
64 | 
65 | 
66 | def save_checkpoint(states, is_best, output_dir,
67 |                     filename='checkpoint.pth'):
68 |     torch.save(states, os.path.join(output_dir, filename))
69 |     if is_best:
70 |         torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))
71 | 
72 | 
73 | class RunningStats:
74 |     def __init__(self, WIN_SIZE):
75 |         self.mean = 0
76 |         self.run_var = 0
77 |         self.WIN_SIZE = WIN_SIZE
78 | 
79 |         self.window = collections.deque(maxlen=WIN_SIZE)
80 | 
81 |     def clear(self):
82 |         self.window.clear()
83 |         self.mean = 0
84 |         self.run_var = 0
85 | 
86 |     def is_full(self):
87 |         return len(self.window) == self.WIN_SIZE
88 | 
89 |     def push(self, x):
90 | 
91 |         if len(self.window) == self.WIN_SIZE:
92 |             # Adjusting variance
93 |             x_removed = self.window.popleft()
94 |             self.window.append(x)
95 |             old_m = self.mean
96 |             self.mean += (x - x_removed) / self.WIN_SIZE
97 |             self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
98 |         else:
99 |             # Calculating first variance
100 |             self.window.append(x)
101 |             delta = x - self.mean
102 |             self.mean += delta / len(self.window)
103 |             self.run_var += delta * (x - self.mean)
104 | 
105 |     def get_mean(self):
106 |         return self.mean if len(self.window) else 0.0
107 | 
108 |     def get_var(self):
109 |         return self.run_var / len(self.window) if len(self.window) > 1 else 0.0
110 | 
111 |     def get_std(self):
112 |         return math.sqrt(self.get_var())
113 | 
114 |     def get_all(self):
115 |         return list(self.window)
116 | 
117 |     def __str__(self):
118 |         return "Current window values: {}".format(list(self.window))
119 | 
120 | def imread_CS_py(imgName):
121 |     block_size = 64
122 |     # img = cv2.imread(imgName)
123 |     # img = cv2.cvtColor(imgName, cv2.COLOR_BGR2RGB)
124 |     # Iorg = np.array(img, dtype='float32')
125 |     Iorg = np.array(Image.open(imgName), dtype='float32')  # read the image
126 |     if len(Iorg.shape) == 3:  # RGB to Y
127 |         Iorg = test_rgb2ycbcr(Iorg)
128 |     [row, col] = Iorg.shape  # image shape
129 |     row_pad = block_size-np.mod(row,block_size)  # remainder w.r.t. the block size
130 |     col_pad = block_size-np.mod(col,block_size)  # remainder, used to decide how much zero padding is needed
131 |     Ipad = np.concatenate((Iorg, np.zeros([row, col_pad])), axis=1)
132 |     Ipad = np.concatenate((Ipad, np.zeros([row_pad, col+col_pad])), axis=0)
133 |     [row_new, col_new] = Ipad.shape
134 | 
135 |     return [Iorg, row, col, Ipad, row_new, col_new]
136 | 
137 | 
138 | def imread_CS_py_new(imgName, block_size=8):
139 |     Iorg = np.array(Image.open(imgName), dtype='float32')  # read the image
140 |     if len(Iorg.shape) == 3:  # RGB to Y
141 |         Iorg = test_rgb2ycbcr(Iorg)
142 |     # [row, col] = Iorg.shape  # image shape
143 |     # row_pad = block_size-np.mod(row,block_size)  # remainder w.r.t. the block size
144 |     # col_pad = block_size-np.mod(col,block_size)  # remainder, used to decide how much zero padding is needed
145 |     # Ipad = np.concatenate((Iorg, np.zeros([row, col_pad])), axis=1)
146 |     # Ipad = np.concatenate((Ipad, np.zeros([row_pad, col+col_pad])), axis=0)
147 |     # [row_new, col_new] = Ipad.shape
148 | 
149 |     return Iorg
150 | 
151 | def img2col_py(Ipad, block_size):
152 |     [row, col] = Ipad.shape  # shape of the current image
153 |     row_block = row/block_size
154 |     col_block = col/block_size
155 |     block_num = int(row_block*col_block)  # total number of blocks
156 |     img_col = np.zeros([block_size**2, block_num])  # container: each block goes into one column
157 |     count = 0
158 |     for x in range(0, row-block_size+1, block_size):
159 |         for y in range(0, col-block_size+1, block_size):
160 |             img_col[:, count] = Ipad[x:x+block_size, y:y+block_size].reshape([-1])
161 |             # img_col[:, count] = Ipad[x:x+block_size, y:y+block_size].transpose().reshape([-1])
162 |             count = count + 1
163 |     return img_col
164 | 
165 | def psnr(img1, img2):
166 |     img1 = img1.astype(np.float32)
167 |     img2 = img2.astype(np.float32)
168 |     mse = np.mean((img1 - img2) ** 2)
169 |     if mse == 0:
170 |         return 100
171 |     PIXEL_MAX = 255.0
172 |     return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
173 | 
174 | def col2im_CS_py(X_col, row, col, row_new, col_new):
175 |     block_size = 64
176 |     X0_rec = np.zeros([row_new, col_new])
177 |     count = 0
178 |     for x in range(0, row_new-block_size+1, block_size):
179 |         for y in range(0, col_new-block_size+1, block_size):
180 |             X0_rec[x:x+block_size, y:y+block_size] = X_col[:, count].reshape([block_size, block_size])
181 |             # X0_rec[x:x+block_size, y:y+block_size] = X_col[:, count].reshape([block_size, block_size]).transpose()
182 |             count = count + 1
183 |     X_rec = X0_rec[:row, :col]
184 |     return X_rec
185 | 
186 | def compute_ssim(im1, im2, k1=0.01, k2=0.03, win_size=11, L=255):
187 | 
188 |     if not im1.shape == im2.shape:
189 |         raise ValueError("Input images must have the same dimensions")
190 |     if len(im1.shape) > 2:
191 |         raise ValueError("Please input the images with 1 channel")
192 | 
193 |     M, N = im1.shape
194 |     C1 = (k1*L)**2
195 |     C2 = (k2*L)**2
196 |     window = matlab_style_gauss2D(shape=(win_size,win_size), sigma=1.5)
197 |     window = window/np.sum(np.sum(window))
198 | 
199 |     if im1.dtype == np.uint8:
200 |         im1 = np.double(im1)
201 |     if im2.dtype == np.uint8:
202 |         im2 = np.double(im2)
203 | 
204 |     mu1 = filter2(im1, window, 'valid')
205 |     mu2 = filter2(im2, window, 'valid')
206 |     mu1_sq = mu1 * mu1
207 |     mu2_sq = mu2 * mu2
208 |     mu1_mu2 = mu1 * mu2
209 |     sigma1_sq = filter2(im1*im1, window, 'valid') - mu1_sq
210 |     sigma2_sq = filter2(im2*im2, window, 'valid') - mu2_sq
211 |     sigmal2 = filter2(im1*im2, window, 'valid') - mu1_mu2
212 | 
213 |     ssim_map = ((2*mu1_mu2+C1) * (2*sigmal2+C2)) / ((mu1_sq+mu2_sq+C1) * (sigma1_sq+sigma2_sq+C2))
214 | 
215 |     return np.mean(np.mean(ssim_map))
216 | 
217 | def matlab_style_gauss2D(shape=(3,3), sigma=0.5):
218 |     """
219 |     2D gaussian mask - should give the same result as MATLAB's
220 |     fspecial('gaussian',[shape],[sigma])
221 |     """
222 |     m, n = [(ss-1.)/2. for ss in shape]
223 |     y, x = np.ogrid[-m:m+1, -n:n+1]
224 |     h = np.exp(-(x*x + y*y) / (2.*sigma*sigma))
225 |     h[h < np.finfo(h.dtype).eps*h.max()] = 0
226 |     sumh = h.sum()
227 |     if sumh != 0:
228 |         h /= sumh
229 |     return h
230 | 
231 | def filter2(x, kernel, mode='same'):
232 |     return convolve2d(x, np.rot90(kernel, 2), mode=mode)
233 | 
234 | def batch_rgb2ycbcr(img, only_y=True):
235 |     '''same as matlab rgb2ycbcr
236 |     only_y: only return Y channel
237 |     Input:
238 |         img type: tensor
239 |         size: [batch,channels,h,w] [5,3,64,64]
240 |         device: gpu
241 |         range: [-1,1]
242 |         uint8, [0, 255]
243 |         float, [0, 1]
244 |     '''
245 |     device = img.get_device()
246 |     # img = img.to('cpu').numpy()
247 |     img = (img + 1.) 
* 127.5 248 | img = img.permute((0,2,3,1)) #[batch,h,w,channels] 249 | # convert 250 | w = torch.FloatTensor([65.481, 128.553, 24.966]).to(device) 251 | rlt = torch.matmul(img, w) / 255.0 + 16.0 252 | rlt = rlt/127.5 - 1. #[batch,h,w] 253 | rlt = torch.unsqueeze(rlt,1) #[batch,1,h,w] 254 | 255 | return rlt 256 | 257 | def test_rgb2ycbcr(img, only_y=True): 258 | '''same as matlab rgb2ycbcr 259 | only_y: only return Y channel 260 | Input: 261 | uint8, [0, 255] 262 | float, [0, 1] 263 | ''' 264 | in_img_type = img.dtype 265 | img.astype(np.float32) 266 | # convert 267 | if only_y: 268 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 269 | else: 270 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 271 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 272 | 273 | rlt = rlt.round() 274 | 275 | return rlt.astype(in_img_type) 276 | 277 | 278 | def img2patches(imgs,patch_size:tuple,stride_size:tuple): 279 | """ 280 | Args: 281 | imgs: (H,W)/(H,W,C)/(B,H,W,C) 282 | patch_size: (patch_h, patch_w) 283 | stride_size: (stride_h, stride_w) 284 | """ 285 | 286 | 287 | if imgs.ndim == 2: 288 | # (H,W) -> (1,H,W,1) 289 | imgs = np.expand_dims(imgs,axis=2) 290 | imgs = np.expand_dims(imgs,axis=0) 291 | elif imgs.ndim == 3: 292 | # (H,W,C) -> (1,H,W,C) 293 | imgs = np.expand_dims(imgs,axis=0) 294 | b,h,w,c = imgs.shape 295 | p_h,p_w = patch_size 296 | s_h,s_w = stride_size 297 | 298 | assert (h-p_h) % s_h == 0 and (w-p_w) % s_w == 0 299 | 300 | n_patches_y = (h - p_h) // s_h + 1 301 | n_patches_x = (w - p_w) // s_w + 1 302 | n_patches_per_img = n_patches_y * n_patches_x 303 | n_patches = n_patches_per_img * b 304 | patches = np.empty((n_patches,p_h,p_w,c),dtype=imgs.dtype) 305 | 306 | patch_idx = 0 307 | for img in imgs: 308 | for i in range(n_patches_y): 309 | for j in range(n_patches_x): 310 | y1 = i * s_h 311 | y2 = y1 + p_h 312 | x1 = j * s_w 313 | x2 = x1 + p_w 314 | patches[patch_idx] = img[y1:y2, x1:x2] 315 | patch_idx += 1 316 | return patches 317 | 318 | def unpatch2d(patches, imsize: tuple, stride_size: tuple): 319 | ''' 320 | patches: (n_patches, p_h, p_w,c) 321 | imsize: (img_h, img_w) 322 | ''' 323 | assert len(patches.shape) == 4 324 | 325 | i_h, i_w = imsize 326 | n_patches,p_h,p_w,c = patches.shape 327 | s_h, s_w = stride_size 328 | 329 | assert (i_h - p_h) % s_h == 0 and (i_w - p_w) % s_w == 0 330 | 331 | n_patches_y = (i_h - p_h) // s_h + 1 332 | n_patches_x = (i_w - p_w) // s_w + 1 333 | n_patches_per_img = n_patches_y * n_patches_x 334 | batch_size = n_patches // n_patches_per_img 335 | 336 | imgs = np.zeros((batch_size,i_h,i_w,c)) 337 | weights = np.zeros_like(imgs) 338 | 339 | 340 | for img_idx, (img,weights) in enumerate(zip(imgs,weights)): 341 | start = img_idx * n_patches_per_img 342 | 343 | for i in range(n_patches_y): 344 | for j in range(n_patches_x): 345 | y1 = i * s_h 346 | y2 = y1 + p_h 347 | x1 = j * s_w 348 | x2 = x1 + p_w 349 | patch_idx = start + i*n_patches_x+j 350 | img[y1:y2,x1:x2] += patches[patch_idx] 351 | weights[y1:y2, x1:x2] += 1 352 | imgs /= weights 353 | 354 | return imgs.astype(patches.dtype) 355 | -------------------------------------------------------------------------------- /functions.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from imageio import imsave 4 | from tqdm import tqdm 5 | from torchvision import utils as vutils 6 | import torchvision.transforms as transforms 7 | import glob 8 | from utils.utils import * 9 | from torch.autograd 
import Variable 10 | from torch.utils.data import BatchSampler,SequentialSampler 11 | from torch.utils.data._utils.collate import default_collate as collate_fn 12 | import time 13 | from copy import deepcopy 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | def train(args, gen_net: nn.Module, gen_optimizer,train_loader,epoch, writer_dict, loss_all): 18 | writer = writer_dict['writer'] 19 | gen_step = 0 20 | # train mode 21 | gen_net = gen_net.train() 22 | 23 | for iter_idx, imgs in enumerate(tqdm(train_loader)): 24 | if imgs.size(0) != args.gen_batch_size: 25 | continue 26 | 27 | global_steps = writer_dict['train_global_steps'] 28 | 29 | input_imgs = imgs.type(torch.cuda.FloatTensor).to("cuda:0") 30 | real_imgs = input_imgs 31 | 32 | # --------------------- 33 | # Training 34 | # --------------------- 35 | gen_optimizer.zero_grad() 36 | gen_imgs,init_rgb,tf_patch,ini_patch = gen_net(input_imgs) 37 | # 38 | if iter_idx%200==0: 39 | if args.datarange == '-11': 40 | final_rgb2 = torch.squeeze(gen_imgs[1, :, :, :]) 41 | init_rgb2 = torch.squeeze(init_rgb[1,:,:,:]) 42 | gt_rgb2 = torch.squeeze(real_imgs[1,:,:,:]) 43 | final_rgb2 = np.round((final_rgb2.detach().cpu().numpy() + 1.) * 127.5) 44 | init_rgb2 = np.round((init_rgb2.detach().cpu().numpy() + 1.) * 127.5) 45 | gt_rgb2 = np.round((gt_rgb2.detach().cpu().numpy() + 1.) * 127.5) 46 | # torchvision 47 | if args.torch_vision == True: 48 | checkrgb = torch.from_numpy(np.stack((gt_rgb2, final_rgb2, init_rgb2), axis=0)) 49 | checkrgb = checkrgb.unsqueeze(1) 50 | img_grid = vutils.make_grid(checkrgb, nrow=3, normalize=True, range=(0, 255)) 51 | writer.add_image("train_img___gt___output___ini", img_grid, global_steps) 52 | 53 | # cal loss 54 | rec_loss = args.rec_w * loss_all['rec_loss'](gen_imgs, real_imgs) 55 | loss = rec_loss 56 | rec_loss_print = rec_loss.item() 57 | writer.add_scalar('rec_loss', rec_loss_print, global_steps) 58 | 59 | loss.backward() 60 | gen_optimizer.step() 61 | 62 | gen_step += 1 63 | 64 | # verbose 65 | if gen_step and iter_idx % args.print_freq == 0: 66 | tqdm.write( 67 | "[Epoch %d/%d] [Batch %d/%d] [loss: %f] [Rec loss: %f]" % 68 | (epoch, args.max_epoch, iter_idx % len(train_loader), len(train_loader), loss.item(), 69 | rec_loss_print)) 70 | 71 | writer_dict['train_global_steps'] = global_steps + 1 72 | 73 | def validate(args, epoch, gen_net: nn.Module, writer_dict): 74 | writer = writer_dict['writer'] 75 | global_steps = writer_dict['valid_global_steps'] 76 | # eval mode 77 | gen_net = gen_net.eval() 78 | PSNR_cross = [] 79 | SSIM_cross = [] 80 | 81 | 82 | with torch.no_grad(): 83 | val_set_path = args.valdata_path 84 | val_set_path1 = glob.glob(val_set_path + '/*.tif') 85 | val_set_path2 = glob.glob(val_set_path + '/*.png') 86 | val_set_path = val_set_path1 + val_set_path2 87 | ImgNum = len(val_set_path) 88 | PSNR_All = np.zeros([1, ImgNum], dtype=np.float32) 89 | SSIM_All = np.zeros([1, ImgNum], dtype=np.float32) 90 | 91 | dname = 'val' 92 | stime = time.time() 93 | 94 | 95 | for img_no in tqdm(range(ImgNum)): 96 | imgName = val_set_path[img_no] 97 | Iorg = imread_CS_py_new(imgName) 98 | if args.datarange == '-11': 99 | Ifloat = Iorg/ 127.5 - 1. 
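# note: CSformer reconstructs on 64x64 windows, so before inference the image is
# mirror-extended below (torch.flip + cat along H, then along W) up to the next
# multiple of 64; the reflected border is cropped away again after reconstruction.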
100 | inputs = Variable(torch.from_numpy(Ifloat.astype('float32')).cuda()) 101 | inputs = torch.unsqueeze(inputs, dim=0) 102 | inputs = torch.unsqueeze(inputs, dim=0) 103 | _, _, h_old, w_old = inputs.size() 104 | padding_block = 64 105 | if h_old%padding_block != 0: 106 | h_pad = (h_old // padding_block+1) * padding_block - h_old 107 | else: 108 | h_pad = 0 109 | if w_old%padding_block != 0: 110 | w_pad = (w_old // padding_block+1) * padding_block - w_old 111 | else: 112 | w_pad = 0 113 | inputs = torch.cat([inputs, torch.flip(inputs, [2])], 2)[:, :, :h_old + h_pad, :] 114 | inputs = torch.cat([inputs, torch.flip(inputs, [3])], 3)[:, :, :, :w_old + w_pad] 115 | 116 | 117 | output,ini_rgbs, _, _ = gen_net(inputs) 118 | 119 | output = torch.squeeze(output) 120 | Inirec = torch.squeeze(ini_rgbs) 121 | 122 | 123 | output = output.cpu().data.numpy() 124 | Inirec = Inirec.cpu().data.numpy() 125 | 126 | 127 | images_recovered = output[0:h_old, 0:w_old] 128 | Inirec = Inirec[0:h_old, 0:w_old] 129 | 130 | if args.datarange == '-11': 131 | images_recovered = np.clip(images_recovered, -1, 1).astype(np.float32) 132 | Irec = np.round((images_recovered+1.) * 127.5) 133 | 134 | 135 | rec_PSNR = psnr(Irec, Iorg) 136 | PSNR_All[0, img_no] = rec_PSNR 137 | rec_SSIM = compute_ssim(Irec, Iorg) 138 | SSIM_All[0, img_no] = rec_SSIM 139 | 140 | # 141 | PSNR_mean = np.mean(PSNR_All) 142 | PSNR_cross.append(PSNR_mean) 143 | SSIM_mean = np.mean(SSIM_All) 144 | SSIM_cross.append(SSIM_mean) 145 | tqdm.write( 146 | "[%s] [Epoch %d] [Mean PSNR: %f] [Mean SSIM: %f] [time: %d]" % 147 | (dname, epoch, PSNR_mean, SSIM_mean,time.time()-stime)) 148 | 149 | PSNR_cross_mean = np.mean(PSNR_cross) 150 | SSIM_cross_mean = np.mean(SSIM_cross) 151 | 152 | writer.add_scalar('val_SSIM_score', SSIM_cross_mean, global_steps) 153 | writer.add_scalar('val_PSNR_score', PSNR_cross_mean, global_steps) 154 | 155 | writer_dict['valid_global_steps'] = global_steps + 1 156 | 157 | return PSNR_cross_mean,PSNR_cross,SSIM_cross_mean,SSIM_cross 158 | 159 | def test(args,gen_net: nn.Module, logger): 160 | # eval mode 161 | gen_net = gen_net.eval() 162 | 163 | PSNR_cross = [] 164 | SSIM_cross = [] 165 | stage1_PSNR_cross = [] 166 | stage1_SSIM_cross = [] 167 | 168 | with torch.no_grad(): 169 | for i in range(len(args.testdata_path)): 170 | # for i in range(1): 171 | test_set_path = args.testdata_path[i] 172 | test_set_path1 = glob.glob(test_set_path + '/*.tif') 173 | test_set_path2 = glob.glob(test_set_path + '/*.png') 174 | test_set_path = test_set_path1 + test_set_path2 175 | ImgNum = len(test_set_path) 176 | PSNR_All = np.zeros([1, ImgNum], dtype=np.float32) 177 | SSIM_All = np.zeros([1, ImgNum], dtype=np.float32) 178 | PSNR_stage1ALL = np.zeros([1, ImgNum], dtype=np.float32) 179 | SSIM_stage1ALL = np.zeros([1, ImgNum], dtype=np.float32) 180 | 181 | if ImgNum == 11: 182 | dname = 'set11' 183 | elif ImgNum == 68: 184 | dname = 'BSD68' 185 | elif ImgNum == 14: 186 | dname = 'set14' 187 | elif ImgNum == 5: 188 | dname = 'set5' 189 | elif ImgNum == 100: 190 | dname = 'urban100' 191 | else: 192 | dname = 'test' 193 | 194 | 195 | 196 | save_dir = args.path_helper['sample_path'] 197 | print(f'save dir is {save_dir}') 198 | 199 | if not os.path.exists(os.path.join(save_dir)): 200 | os.makedirs(os.path.join(save_dir)) 201 | 202 | stime = time.time() 203 | for img_no in tqdm(range(ImgNum)): 204 | imgName = test_set_path[img_no] 205 | [Iorg, row, col, Ipad, row_new, col_new] = imread_CS_py(imgName) 206 | if args.datarange == '-11': 207 | # Icol = 
img2col_py(Ipad, 64) / 127.5 - 1. # uint to [-1,1]
208 |                     Ipad = Ipad / 127.5 - 1.
209 |                 inputs = Variable(torch.from_numpy(Ipad.astype('float32')).cuda())
210 |                 inputs = torch.unsqueeze(inputs, dim=0)
211 |                 inputs = torch.unsqueeze(inputs, dim=0)
212 | 
213 |                 output, _, _, _ = gen_net(inputs)
214 | 
215 |                 output = torch.squeeze(output)
216 | 
217 | 
218 |                 output = output.cpu().data.numpy()
219 | 
220 | 
221 |                 images_recovered = output[0:row, 0:col]
222 | 
223 |                 if args.datarange == '-11':
224 |                     Irec = np.round((images_recovered + 1.) * 127.5)
225 | 
226 |                 rec_PSNR = psnr(Irec, Iorg)
227 |                 PSNR_All[0, img_no] = rec_PSNR
228 | 
229 | 
230 |                 rec_SSIM = compute_ssim(Irec, Iorg)
231 |                 SSIM_All[0, img_no] = rec_SSIM
232 | 
233 | 
234 |                 imgname_for_save = os.path.basename(imgName)
235 |                 imgname_for_save = os.path.splitext(imgname_for_save)[0]
236 |                 imgname_for_save = imgname_for_save + '.png'
237 |                 imgname_for_save = os.path.join(save_dir, imgname_for_save)
238 |                 imsave(imgname_for_save, Irec.astype(np.uint8))
239 | 
240 |             PSNR_mean = np.mean(PSNR_All)
241 |             PSNR_cross.append(PSNR_mean)
242 |             SSIM_mean = np.mean(SSIM_All)
243 |             SSIM_cross.append(SSIM_mean)
244 | 
245 | 
246 |             logger.info(f"[{dname}] [Mean PSNR: {PSNR_mean:.2f}] [Mean SSIM: {SSIM_mean:.4f}]")
247 | 
248 | 
249 |     PSNR_cross_mean = np.mean(PSNR_cross)
250 |     SSIM_cross_mean = np.mean(SSIM_cross)
251 | 
252 | 
253 |     logger.info(f"[all cross] [Mean PSNR: {PSNR_cross_mean:.2f}] [Mean SSIM: {SSIM_cross_mean:.4f}]")
254 | 
255 | 
256 | def test_overlap(args, gen_net: nn.Module, logger):
257 |     # eval mode
258 |     gen_net = gen_net.eval()
259 |     step = args.overlapstep
260 |     PSNR_cross = []
261 |     SSIM_cross = []
262 |     logger.info(f'the overlap step is {step}')
263 |     with torch.no_grad():
264 |         #for i in range(len(args.testdata_path)):
265 |         for i in range(1):
266 |             test_set_path = args.testdata_path[i]
267 |             # test_set_path = args.testdata_path
268 |             print(f'test_set_path is {test_set_path} \n')
269 | 
270 |             test_set_path1 = glob.glob(test_set_path + '/*.tif')
271 |             test_set_path2 = glob.glob(test_set_path + '/*.png')
272 |             test_set_path3 = glob.glob(test_set_path + '/*.JPG')
273 |             test_set_path4 = glob.glob(test_set_path + '/*.jpg')
274 |             test_set_path = test_set_path1 + test_set_path2 + test_set_path3 + test_set_path4
275 |             ImgNum = len(test_set_path)
276 |             PSNR_All = np.zeros([1, ImgNum], dtype=np.float32)
277 |             SSIM_All = np.zeros([1, ImgNum], dtype=np.float32)
278 |             # PSNR_stage1ALL = np.zeros([1, ImgNum], dtype=np.float32)
279 |             # SSIM_stage1ALL = np.zeros([1, ImgNum], dtype=np.float32)
280 |             print(f'len is {ImgNum} \n')
281 | 
282 |             if ImgNum == 11:
283 |                 dname = 'set11'
284 |             elif ImgNum == 68:
285 |                 dname = 'BSD68'
286 |             elif ImgNum == 14:
287 |                 dname = 'set14'
288 |             elif ImgNum == 5:
289 |                 dname = 'set5'
290 |             elif ImgNum == 100:
291 |                 dname = 'urban100'
292 |             else: dname = 'test'  # fallback so dname is always bound, as in test()
293 |             save_dir = args.path_helper['sample_path']
294 |             print(f'save dir is {save_dir}')
295 | 
296 | 
297 | 
298 |             for img_no in tqdm(range(ImgNum)):
299 |                 imgName = test_set_path[img_no]
300 |                 [Iorg, row, col, Ipad, row_new, col_new] = imread_CS_py(imgName)
301 |                 if args.datarange == '-11':
302 |                     patches = img2patches(Ipad, (64, 64), (step, step))
303 |                     patches_batch = patches / 127.5 - 1.
304 |                 inputs = Variable(torch.from_numpy(patches_batch.astype('float32')).cuda())
305 |                 inputs = inputs.permute(0, 3, 1, 2)
306 | 
307 | 
308 |                 output = torch.FloatTensor(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3]).cuda()
309 |                 ini_rgbs = torch.FloatTensor(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3]).cuda()
310 | 
311 |                 batch_list = list(BatchSampler(SequentialSampler(output), batch_size=args.eval_batch_size, drop_last=False))
312 |                 for idx, list_data in enumerate(BatchSampler(inputs, batch_size=args.eval_batch_size, drop_last=False)):
313 |                     batch_x = collate_fn(list_data)
314 |                     list_tmp = batch_list[idx]
315 |                     output[list_tmp, :, :, :], ini_rgbs[list_tmp, :, :, :], _, _ = gen_net(batch_x)  # also keep the initial reconstruction, which Inirec reads below
316 | 
317 | 
318 | 
319 |                 output = output.permute(0, 2, 3, 1)
320 |                 Inirec = ini_rgbs.permute(0, 2, 3, 1)
321 | 
322 | 
323 |                 output = output.cpu().data.numpy()
324 |                 Inirec = Inirec.cpu().data.numpy()
325 | 
326 |                 # unpatch
327 |                 output = unpatch2d(output, Ipad.shape, (step, step)).squeeze()
328 |                 Inirec = unpatch2d(Inirec, Ipad.shape, (step, step)).squeeze()
329 | 
330 |                 images_recovered = output[0:row, 0:col]
331 |                 Inirec = Inirec[0:row, 0:col]
332 | 
333 |                 if args.datarange == '-11':
334 |                     Irec = np.round((images_recovered + 1.) * 127.5)
335 | 
336 | 
337 |                 rec_PSNR = psnr(Irec, Iorg)
338 |                 PSNR_All[0, img_no] = rec_PSNR
339 | 
340 |                 rec_SSIM = compute_ssim(Irec, Iorg)
341 |                 SSIM_All[0, img_no] = rec_SSIM
342 | 
343 |                 if not os.path.exists(os.path.join(save_dir)):
344 |                     os.makedirs(os.path.join(save_dir))
345 | 
346 |                 imgname_for_save = os.path.basename(imgName)
347 |                 imgname_for_save = os.path.splitext(imgname_for_save)[0]
348 |                 imgname_for_save = imgname_for_save + '.png'
349 |                 imgname_for_save = os.path.join(save_dir, imgname_for_save)
350 |                 imsave(imgname_for_save, Irec.astype(np.uint8))
351 | 
352 | 
353 |             PSNR_mean = np.mean(PSNR_All)
354 |             PSNR_cross.append(PSNR_mean)
355 |             SSIM_mean = np.mean(SSIM_All)
356 |             SSIM_cross.append(SSIM_mean)
357 | 
358 | 
359 |             logger.info(f"[{dname}] [Mean PSNR: {PSNR_mean:.2f}] [Mean SSIM: {SSIM_mean:.4f}]")
360 | 
361 | 
362 |     PSNR_cross_mean = np.mean(PSNR_cross)
363 |     SSIM_cross_mean = np.mean(SSIM_cross)
364 | 
365 | 
366 |     logger.info(f"[all cross] [Mean PSNR: {PSNR_cross_mean:.2f}] [Mean SSIM: {SSIM_cross_mean:.4f}]")
367 | 
368 | 
369 | 
370 | def pytorch_unnormalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
371 |     inv_normalize = transforms.Normalize(
372 |         mean=[-m / s for m, s in zip(mean, std)],
373 |         std=[1 / s for s in std]
374 |     )
375 |     return inv_normalize(tensor)
376 | 
377 | def inverse_normalize(tensor, mean, std):
378 |     for t, m, s in zip(tensor, mean, std):
379 |         t.mul_(s).add_(m)
380 |     return tensor
381 | 
382 | def load_params(model, new_param):
383 |     for p, new_p in zip(model.parameters(), new_param):
384 |         p.data.copy_(new_p)
385 | 
386 | 
387 | def copy_params(model):
388 |     flatten = deepcopy(list(p.data for p in model.parameters()))
389 |     return flatten
--------------------------------------------------------------------------------
/models/Transformer64.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import torch.nn.functional as F
5 | from models.ViT_helper import DropPath, to_2tuple, trunc_normal_
6 | from einops import rearrange, repeat
7 | 
8 | 
9 | class matmul(nn.Module):
10 |     def __init__(self):
11 |         super().__init__()
12 | 
13 |     def forward(self, x1, x2):
14 |         x = x1 @ x2
15 |         return x
16 | 
17 | def count_matmul(m, x, y):
18 |     num_mul =
x[0].numel() * x[1].size(-1) 19 | # m.total_ops += torch.DoubleTensor([int(num_mul)]) 20 | m.total_ops += torch.DoubleTensor([int(0)]) 21 | 22 | 23 | def gelu(x): 24 | """ Original Implementation of the gelu activation function in Google Bert repo when initialy created. 25 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 26 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 30 | 31 | def leakyrelu(x): 32 | return nn.functional.leaky_relu_(x, 0.2) 33 | class CustomAct(nn.Module): 34 | def __init__(self, act_layer): 35 | super().__init__() 36 | if act_layer == "gelu": 37 | self.act_layer = gelu 38 | elif act_layer == "leakyrelu": 39 | self.act_layer = leakyrelu 40 | 41 | def forward(self, x): 42 | return self.act_layer(x) 43 | 44 | class Mlp(nn.Module): 45 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=gelu, drop=0.): 46 | super().__init__() 47 | out_features = out_features or in_features 48 | hidden_features = hidden_features or in_features 49 | self.fc1 = nn.Linear(in_features, hidden_features) 50 | self.act = CustomAct(act_layer) 51 | self.fc2 = nn.Linear(hidden_features, out_features) 52 | self.drop = nn.Dropout(drop) 53 | def forward(self, x): 54 | x = self.fc1(x) 55 | x = self.act(x) 56 | x = self.drop(x) 57 | x = self.fc2(x) 58 | x = self.drop(x) 59 | return x 60 | 61 | 62 | def get_attn_mask(N, w): 63 | mask = torch.zeros(1, 1, N, N).cuda() 64 | for i in range(N): 65 | if i <= w: 66 | mask[:, :, i, 0:i+w+1] = 1 67 | elif N - i <= w: 68 | mask[:, :, i, i-w:N] = 1 69 | else: 70 | mask[:, :, i, i:i+w+1] = 1 71 | mask[:, :, i, i-w:i] = 1 72 | return mask 73 | 74 | 75 | class Attention(nn.Module): 76 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=16): 77 | super().__init__() 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 81 | self.scale = qk_scale or head_dim ** -0.5 82 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 83 | self.attn_drop = nn.Dropout(attn_drop) 84 | self.proj = nn.Linear(dim, dim) 85 | self.proj_drop = nn.Dropout(proj_drop) 86 | self.mat = matmul() 87 | self.window_size = window_size 88 | 89 | self.relative_position_bias_table = nn.Parameter( 90 | torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 91 | 92 | # get pair-wise relative position index for each token inside the window 93 | coords_h = torch.arange(window_size) 94 | coords_w = torch.arange(window_size) 95 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 96 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 97 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 98 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 99 | relative_coords[:, :, 0] += window_size - 1 # shift to start from 0 100 | relative_coords[:, :, 1] += window_size - 1 101 | relative_coords[:, :, 0] *= 2 * window_size - 1 102 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 103 | self.register_buffer("relative_position_index", relative_position_index) 104 | 105 | trunc_normal_(self.relative_position_bias_table, std=.02) 106 | 107 | def forward(self, x): 108 | B, N, 
C = x.shape 109 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 110 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 111 | attn = (self.mat(q, k.transpose(-2, -1))) * self.scale 112 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 113 | self.window_size * self.window_size, self.window_size * self.window_size, -1) # Wh*Ww,Wh*Ww,nH 114 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 115 | attn = attn + relative_position_bias.unsqueeze(0) 116 | 117 | attn = attn.softmax(dim=-1) 118 | attn = self.attn_drop(attn) 119 | x = self.mat(attn, v).transpose(1, 2).reshape(B, N, C) 120 | x = self.proj(x) 121 | x = self.proj_drop(x) 122 | return x 123 | 124 | ######################################### 125 | ########### window operation############# 126 | def window_partition(x, window_size): 127 | """ 128 | Args: 129 | x: (B, H, W, C) 130 | window_size (int): window size 131 | Returns: 132 | windows: (num_windows*B, window_size, window_size, C) 133 | """ 134 | B, H, W, C = x.shape 135 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 136 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 137 | return windows 138 | 139 | def window_reverse(windows, window_size, H, W): 140 | """ 141 | Args: 142 | windows: (num_windows*B, window_size, window_size, C) 143 | window_size (int): Window size 144 | H (int): Height of image 145 | W (int): Width of image 146 | Returns: 147 | x: (B, H, W, C) 148 | """ 149 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 150 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 151 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 152 | return x 153 | 154 | 155 | class PixelNorm(nn.Module): 156 | def __init__(self, dim): 157 | super().__init__() 158 | def forward(self, input): 159 | return input * torch.rsqrt(torch.mean(input ** 2, dim=2, keepdim=True) + 1e-8) 160 | 161 | class CustomNorm(nn.Module): 162 | def __init__(self, norm_layer, dim): 163 | super().__init__() 164 | self.norm_type = norm_layer 165 | if norm_layer == "ln": 166 | self.norm = nn.LayerNorm(dim) 167 | elif norm_layer == "bn": 168 | self.norm = nn.BatchNorm1d(dim) 169 | elif norm_layer == "in": 170 | self.norm = nn.InstanceNorm1d(dim) 171 | elif norm_layer == "pn": 172 | self.norm = PixelNorm(dim) 173 | 174 | def forward(self, x): 175 | if self.norm_type == "bn" or self.norm_type == "in": 176 | x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 1) 177 | return x 178 | elif self.norm_type == "none": 179 | return x 180 | else: 181 | return self.norm(x) 182 | 183 | class Block(nn.Module): 184 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 185 | drop_path=0., act_layer=gelu, norm_layer=nn.LayerNorm, window_size=16): 186 | super().__init__() 187 | self.norm1 = CustomNorm(norm_layer, dim) 188 | self.attn = Attention( 189 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size) 190 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 191 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 192 | self.norm2 = CustomNorm(norm_layer, dim) 193 | mlp_hidden_dim = int(dim * mlp_ratio) 194 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 195 | def forward(self, x): 196 | x = x + self.drop_path(self.attn(self.norm1(x))) 197 | x = x + self.drop_path(self.mlp(self.norm2(x))) 198 | return x 199 | 200 | class StageBlock(nn.Module): 201 | def __init__(self, depth, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=gelu, norm_layer=nn.LayerNorm, window_size=16): 202 | super().__init__() 203 | self.depth = depth 204 | models = [Block( 205 | dim=dim, 206 | num_heads=num_heads, 207 | mlp_ratio=mlp_ratio, 208 | qkv_bias=qkv_bias, 209 | qk_scale=qk_scale, 210 | drop=drop, 211 | attn_drop=attn_drop, 212 | drop_path=drop_path, 213 | act_layer=act_layer, 214 | norm_layer=norm_layer, 215 | window_size=window_size 216 | ) for i in range(depth)] 217 | self.block = nn.Sequential(*models) 218 | def forward(self, x): 219 | # for blk in self.block: 220 | # # x = blk(x) 221 | # checkpoint.checkpoint(blk, x) 222 | # x = checkpoint.checkpoint(self.block, x) 223 | x = self.block(x) 224 | return x 225 | 226 | def pixel_upsample(x, H, W): 227 | B, N, C = x.size() 228 | assert N == H*W 229 | x = x.permute(0, 2, 1) 230 | x = x.view(-1, C, H, W) 231 | x = nn.PixelShuffle(2)(x) 232 | B, C, H, W = x.size() 233 | x = x.view(-1, C, H*W) 234 | x = x.permute(0,2,1) 235 | return x, H, W 236 | 237 | class Upsample(nn.Module): 238 | def __init__(self, in_channel, out_channel,up_mode='bicubic'): 239 | super(Upsample, self).__init__() 240 | self.conv = nn.Sequential( 241 | nn.Conv2d(in_channel, out_channel, kernel_size=1), 242 | ) 243 | self.up_mode = up_mode 244 | 245 | def forward(self, x): 246 | B, L, C = x.shape 247 | H = int(math.sqrt(L)) 248 | W = int(math.sqrt(L)) 249 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 250 | #up 251 | x = F.interpolate(x, scale_factor=2, mode=self.up_mode) 252 | H = x.shape[2] 253 | W = x.shape[3] 254 | #↓dim 255 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C 256 | return out,H,W 257 | 258 | # --- Build dense --- # 259 | class MakeDense(nn.Module): 260 | def __init__(self, in_channels, growth_rate, num_heads,mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=gelu, norm_layer=nn.LayerNorm, window_size=16): 261 | super(MakeDense, self).__init__() 262 | self.block =Block( 263 | dim=in_channels, 264 | num_heads=num_heads, 265 | mlp_ratio=mlp_ratio, 266 | qkv_bias=qkv_bias, 267 | qk_scale=qk_scale, 268 | drop=drop, 269 | attn_drop=attn_drop, 270 | drop_path=drop_path, 271 | act_layer=act_layer, 272 | norm_layer=norm_layer, 273 | window_size=window_size 274 | ) 275 | self.conv_1x1 = nn.Conv2d(in_channels, growth_rate, kernel_size=1, padding=0) 276 | self.norm = CustomNorm('ln', growth_rate) 277 | 278 | def forward(self, x): 279 | out = self.block(x) 280 | 281 | B, N, C = out.size() 282 | H = W = int(math.sqrt(N)) 283 | out = out.permute(0, 2, 1).contiguous() 284 | out = out.view(-1, C, H, W) 285 | out = self.conv_1x1(out) 286 | B, C, H, W = out.size() 287 | out = out.view(-1, C, H * W) 288 | out = out.permute(0, 2, 1).contiguous() 289 | out = self.norm(out) 290 | 291 | out = torch.cat((x, out), 2) 292 | 293 | return out 294 | 295 | 296 | # --- Build the Residual Dense Block --- # 297 | class RDTB(nn.Module): 298 | def __init__(self, num_dense_layer,in_channels,growth_rate, num_heads,mlp_ratio=4., 
qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=gelu, norm_layer=nn.LayerNorm, window_size=16): 299 | """ 300 | :param in_channels: input channel size 301 | :param num_dense_layer: the number of RDB layers 302 | :param growth_rate: growth_rate 303 | """ 304 | super(RDTB, self).__init__() 305 | _in_channels = in_channels 306 | modules = [] 307 | for i in range(num_dense_layer): 308 | modules.append(MakeDense(in_channels=_in_channels, 309 | growth_rate = growth_rate, 310 | num_heads=num_heads, 311 | mlp_ratio=mlp_ratio, 312 | qkv_bias=qkv_bias, 313 | qk_scale=qk_scale, 314 | drop=drop, 315 | attn_drop=attn_drop, 316 | drop_path=drop_path, 317 | act_layer=act_layer, 318 | norm_layer=norm_layer, 319 | window_size=window_size 320 | )) 321 | _in_channels += growth_rate 322 | self.residual_dense_layers = nn.Sequential(*modules) 323 | self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0) 324 | self.norm = CustomNorm('ln', in_channels) 325 | 326 | def forward(self, x): 327 | out = self.residual_dense_layers(x) 328 | 329 | B, N, C = out.size() 330 | H = W = int(math.sqrt(N)) 331 | out = out.permute(0, 2, 1).contiguous() 332 | out = out.view(-1, C, H, W) 333 | out = self.conv_1x1(out) 334 | B, C, H, W = out.size() 335 | out = out.view(-1, C, H * W) 336 | out = out.permute(0, 2, 1).contiguous() 337 | out = self.norm(out) 338 | out = out + x 339 | return out 340 | 341 | class Generator_tailadd(nn.Module): 342 | def __init__(self, args, img_size=224, patch_size=16, in_chans=3, num_classes=10, embed_dim=384, depth=[2,2,2,2], 343 | num_heads=[8,4,2,1], mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 344 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,upsample=Upsample): 345 | super().__init__() 346 | self.args = args 347 | growth_rate = args.growth_rate 348 | self.ch = embed_dim 349 | self.bottom_width = args.bottom_width 350 | self.embed_dim = embed_dim = args.gf_dim 351 | self.window_size = args.g_window_size 352 | num_dense_layer = args.num_dense_layer 353 | norm_layer = args.g_norm 354 | mlp_ratio = args.g_mlp 355 | depth = [int(i) for i in args.g_depth.split(",")] 356 | act_layer = args.g_act 357 | num_heads = [int(i) for i in args.num_heads.split(",")] 358 | 359 | self.pos_embed_1 = nn.Parameter(torch.zeros(1, self.bottom_width ** 2, embed_dim)) 360 | self.pos_embed_2 = nn.Parameter(torch.zeros(1, (self.bottom_width * 2) ** 2, embed_dim//2)) 361 | self.pos_embed_3 = nn.Parameter(torch.zeros(1, (self.bottom_width * 4) ** 2, embed_dim//4)) 362 | self.pos_embed_4 = nn.Parameter(torch.zeros(1, (self.bottom_width * 8) ** 2, embed_dim//8)) 363 | 364 | self.pos_embed = [ 365 | self.pos_embed_1, 366 | self.pos_embed_2, 367 | self.pos_embed_3, 368 | self.pos_embed_4 369 | ] 370 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth[0])] # stochastic depth decay rule 371 | 372 | self.blocks_1 = StageBlock( 373 | depth=depth[0], 374 | dim=embed_dim, 375 | num_heads=num_heads[0], 376 | mlp_ratio=mlp_ratio, 377 | qkv_bias=qkv_bias, 378 | qk_scale=qk_scale, 379 | drop=drop_rate, 380 | attn_drop=attn_drop_rate, 381 | drop_path=0, 382 | act_layer=act_layer, 383 | norm_layer=norm_layer, 384 | window_size=self.window_size 385 | ) 386 | self.upsample_1 = upsample(embed_dim, embed_dim//2) 387 | 388 | self.blocks_2 = StageBlock( 389 | depth=depth[1], 390 | dim=embed_dim//2, 391 | num_heads=num_heads[1], 392 | mlp_ratio=mlp_ratio, 393 | qkv_bias=qkv_bias, 394 | qk_scale=qk_scale, 395 | drop=drop_rate, 396 | 
attn_drop=attn_drop_rate, 397 | drop_path=0, 398 | act_layer=act_layer, 399 | norm_layer=norm_layer, 400 | window_size=self.window_size 401 | ) 402 | self.upsample_2 = upsample(embed_dim//2, embed_dim // 4) 403 | 404 | self.blocks_3 = StageBlock( 405 | depth=depth[2], 406 | dim=embed_dim//4, 407 | num_heads=num_heads[2], 408 | mlp_ratio=mlp_ratio, 409 | qkv_bias=qkv_bias, 410 | qk_scale=qk_scale, 411 | drop=drop_rate, 412 | attn_drop=attn_drop_rate, 413 | drop_path=0, 414 | act_layer=act_layer, 415 | norm_layer=norm_layer, 416 | window_size=self.window_size 417 | ) 418 | self.upsample_3 = upsample(embed_dim // 4, embed_dim // 8) 419 | 420 | self.blocks_4 = StageBlock( 421 | depth=depth[3], 422 | dim=embed_dim//8, 423 | num_heads=num_heads[3], 424 | mlp_ratio=mlp_ratio, 425 | qkv_bias=qkv_bias, 426 | qk_scale=qk_scale, 427 | drop=drop_rate, 428 | attn_drop=attn_drop_rate, 429 | drop_path=0, 430 | act_layer=act_layer, 431 | norm_layer=norm_layer, 432 | window_size=self.window_size 433 | ) 434 | 435 | 436 | for i in range(len(self.pos_embed)): 437 | trunc_normal_(self.pos_embed[i], std=.02) 438 | 439 | 440 | self.padding3 = (3 + (3 - 1) * (1 - 1) - 1) // 2 441 | self.padding7 = (7 + (7 - 1) * (1 - 1) - 1) // 2 442 | self.to_rgb = nn.ModuleList() 443 | for i in range(4): 444 | to_rgb = None 445 | dim = embed_dim // (2**i) 446 | to_rgb = nn.Sequential( 447 | nn.ReflectionPad2d(self.padding3), 448 | nn.Conv2d(dim, 1, 3, 1, 0), 449 | # nn.ReflectionPad2d(self.padding7), 450 | # nn.Conv2d(self.embed_dim, 1, 7, 1, 0), 451 | nn.Tanh() 452 | ) 453 | self.to_rgb.append(to_rgb) 454 | 455 | self.apply(self._init_weights) 456 | 457 | def _init_weights(self, m): 458 | if isinstance(m, nn.Linear): 459 | trunc_normal_(m.weight, std=.02) 460 | if isinstance(m, nn.Linear) and m.bias is not None: 461 | nn.init.constant_(m.bias, 0) 462 | elif isinstance(m, nn.LayerNorm): 463 | nn.init.constant_(m.bias, 0) 464 | nn.init.constant_(m.weight, 1.0) 465 | 466 | 467 | def forward(self, x, gsfeatures, gsrgb): 468 | 469 | features = [] 470 | fufeatures = [] 471 | tfrgb = [] 472 | outputrgb = [] 473 | self.pos_embed = self.pos_embed 474 | #change 475 | 476 | #-------block 1---------- 477 | x = x + self.pos_embed[0].to(x.get_device()) #8x8 478 | B,_,C = x.size() 479 | H, W = self.bottom_width, self.bottom_width 480 | x = self.blocks_1(x) 481 | #features add 482 | features.append(x.permute(0, 2, 1).contiguous().view(-1, C, H, W)) 483 | tfrgb.append(self.to_rgb[0](x.permute(0, 2, 1).contiguous().view(-1, C, H, W))) 484 | 485 | fu_rgb = tfrgb[0]+gsrgb[0] 486 | fu_feature = features[0] + gsfeatures[0] 487 | outputrgb.append(fu_rgb) 488 | fufeatures.append(fu_feature) 489 | x = fu_feature.view(-1,C,H*W).permute(0,2,1).contiguous() 490 | 491 | # -------block 2---------- 492 | x, H, W = self.upsample_1(x) #16x16 493 | x = x + self.pos_embed[1].to(x.get_device()) 494 | B, _, C = x.size() 495 | x = x.view(B, H, W, C) 496 | x = window_partition(x, self.window_size) 497 | x = x.view(-1, self.window_size * self.window_size, C) 498 | x = self.blocks_2(x) 499 | x = x.view(-1, self.window_size, self.window_size, C) 500 | x = window_reverse(x, self.window_size, H, W).view(B, H * W, C) 501 | #features add 502 | features.append(x.permute(0, 2, 1).contiguous().view(-1, C, H, W)) 503 | tfrgb.append(self.to_rgb[1](x.permute(0, 2, 1).contiguous().view(-1, C, H, W))) 504 | 505 | fu_rgb = tfrgb[1]+gsrgb[1] 506 | fu_feature = features[1] + gsfeatures[1] 507 | outputrgb.append(fu_rgb) 508 | fufeatures.append(fu_feature) 509 | x = 
fu_feature.view(-1,C,H*W).permute(0,2,1).contiguous() 510 | 511 | # -------block 3---------- 512 | x, H, W = self.upsample_2(x) #32x32 513 | x = x + self.pos_embed[2].to(x.get_device()) 514 | B, _, C = x.size() 515 | x = x.view(B, H, W, C) 516 | x = window_partition(x, self.window_size) 517 | x = x.view(-1, self.window_size * self.window_size, C) 518 | x = self.blocks_3(x) 519 | x = x.view(-1, self.window_size, self.window_size, C) 520 | x = window_reverse(x, self.window_size, H, W).view(B, H * W, C) 521 | #features add 522 | features.append(x.permute(0, 2, 1).contiguous().view(-1, C, H, W)) 523 | tfrgb.append(self.to_rgb[2](x.permute(0, 2, 1).contiguous().view(-1, C, H, W))) 524 | 525 | fu_rgb = tfrgb[2]+gsrgb[2] 526 | fu_feature = features[2] + gsfeatures[2] 527 | outputrgb.append(fu_rgb) 528 | fufeatures.append(fu_feature) 529 | x = fu_feature.view(-1,C,H*W).permute(0,2,1).contiguous() 530 | 531 | # -------block 4---------- 532 | x, H, W = self.upsample_3(x) #64x64 533 | x = x + self.pos_embed[3].to(x.get_device()) 534 | B, _, C = x.size() 535 | x = x.view(B, H, W, C) 536 | x = window_partition(x, self.window_size) 537 | x = x.view(-1, self.window_size * self.window_size, C) 538 | x = self.blocks_4(x) 539 | x = x.view(-1, self.window_size, self.window_size, C) 540 | x = window_reverse(x, self.window_size, H, W).view(B, H * W, C) 541 | #features add 542 | features.append(x.permute(0, 2, 1).contiguous().view(-1, C, H, W)) 543 | 544 | 545 | fu_feature = features[3] + gsfeatures[3] 546 | fufeatures.append(fu_feature) 547 | output = self.to_rgb[3](fu_feature) 548 | outputrgb.append(output) 549 | 550 | final_output = outputrgb.pop(-1) 551 | return final_output, outputrgb 552 | 553 | 554 | class Transformer(nn.Module): 555 | def __init__(self, args, img_size=224, patch_size=16, in_chans=3, num_classes=10, embed_dim=384, depth=[2,2,2,2], 556 | num_heads=[16,8,4,2], mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 557 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,upsample=Upsample): 558 | super().__init__() 559 | self.args = args 560 | self.ch = embed_dim 561 | self.bottom_width = args.bottom_width 562 | self.embed_dim = embed_dim = args.gf_dim 563 | self.window_size = args.g_window_size 564 | norm_layer = args.g_norm 565 | mlp_ratio = args.g_mlp 566 | depth = [int(i) for i in args.g_depth.split(",")] 567 | act_layer = args.g_act 568 | num_heads = [int(i) for i in args.num_heads.split(",")] 569 | 570 | self.pos_embed_1 = nn.Parameter(torch.zeros(1, self.bottom_width ** 2, embed_dim*2)) 571 | self.pos_embed_2 = nn.Parameter(torch.zeros(1, (self.bottom_width * 2) ** 2, embed_dim)) 572 | self.pos_embed_3 = nn.Parameter(torch.zeros(1, (self.bottom_width * 4) ** 2, embed_dim//2)) 573 | self.pos_embed_4 = nn.Parameter(torch.zeros(1, (self.bottom_width * 8) ** 2, embed_dim//4)) 574 | 575 | self.pos_embed = [ 576 | self.pos_embed_1, 577 | self.pos_embed_2, 578 | self.pos_embed_3, 579 | self.pos_embed_4 580 | ] 581 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth[0])] # stochastic depth decay rule 582 | 583 | self.blocks_1 = StageBlock( 584 | depth=depth[0], 585 | dim=embed_dim*2, 586 | num_heads=num_heads[0], 587 | mlp_ratio=mlp_ratio, 588 | qkv_bias=qkv_bias, 589 | qk_scale=qk_scale, 590 | drop=drop_rate, 591 | attn_drop=attn_drop_rate, 592 | drop_path=0, 593 | act_layer=act_layer, 594 | norm_layer=norm_layer, 595 | window_size=self.window_size 596 | ) 597 | # self.upsample_1 = upsample(embed_dim, embed_dim//2) 598 | 599 | 
self.blocks_2 = StageBlock( 600 | depth=depth[1], 601 | dim=embed_dim, 602 | num_heads=num_heads[1], 603 | mlp_ratio=mlp_ratio, 604 | qkv_bias=qkv_bias, 605 | qk_scale=qk_scale, 606 | drop=drop_rate, 607 | attn_drop=attn_drop_rate, 608 | drop_path=0, 609 | act_layer=act_layer, 610 | norm_layer=norm_layer, 611 | window_size=self.window_size 612 | ) 613 | # self.upsample_2 = upsample(embed_dim//2, embed_dim // 4) 614 | 615 | self.blocks_3 = StageBlock( 616 | depth=depth[2], 617 | dim=embed_dim//2, 618 | num_heads=num_heads[2], 619 | mlp_ratio=mlp_ratio, 620 | qkv_bias=qkv_bias, 621 | qk_scale=qk_scale, 622 | drop=drop_rate, 623 | attn_drop=attn_drop_rate, 624 | drop_path=0, 625 | act_layer=act_layer, 626 | norm_layer=norm_layer, 627 | window_size=self.window_size 628 | ) 629 | # self.upsample_3 = upsample(embed_dim // 4, embed_dim // 8) 630 | 631 | self.blocks_4 = StageBlock( 632 | depth=depth[3], 633 | dim=embed_dim//4, 634 | num_heads=num_heads[3], 635 | mlp_ratio=mlp_ratio, 636 | qkv_bias=qkv_bias, 637 | qk_scale=qk_scale, 638 | drop=drop_rate, 639 | attn_drop=attn_drop_rate, 640 | drop_path=0, 641 | act_layer=act_layer, 642 | norm_layer=norm_layer, 643 | window_size=self.window_size 644 | ) 645 | 646 | for i in range(len(self.pos_embed)): 647 | trunc_normal_(self.pos_embed[i], std=.02) 648 | 649 | if args.datarange == '01': 650 | rgbact = nn.Sigmoid() 651 | print('You choose sigmoid for range [0,1]') 652 | elif args.datarange == '-11': 653 | rgbact = nn.Tanh() 654 | print('You choose tanh for range [-1,1]') 655 | 656 | self.padding3 = (3 + (3 - 1) * (1 - 1) - 1) // 2 657 | self.padding7 = (7 + (7 - 1) * (1 - 1) - 1) // 2 658 | self.to_rgb = nn.Sequential( 659 | nn.ReflectionPad2d(self.padding3), 660 | nn.Conv2d((embed_dim*2) // (2**3), (embed_dim*2) // (2**3), 3, 1, 0), 661 | nn.ReflectionPad2d(self.padding7), 662 | nn.Conv2d((embed_dim*2) // (2**3), 1, 7, 1, 0), 663 | nn.Tanh() 664 | ) 665 | 666 | self.apply(self._init_weights) 667 | print('apply init weith trunc_normal') 668 | 669 | def _init_weights(self, m): 670 | if isinstance(m, nn.Linear): 671 | trunc_normal_(m.weight, std=.02) 672 | if isinstance(m, nn.Linear) and m.bias is not None: 673 | nn.init.constant_(m.bias, 0) 674 | elif isinstance(m, nn.LayerNorm): 675 | nn.init.constant_(m.bias, 0) 676 | nn.init.constant_(m.weight, 1.0) 677 | 678 | def forward(self, x, gsfeatures, inirgb): 679 | 680 | features = [] 681 | # fufeatures = [] 682 | # tfrgb = [] 683 | outputrgb = [] 684 | 685 | #change 686 | 687 | #-------block 1---------- 688 | #feature cat 689 | x = tf2cnn(x) 690 | x = torch.cat([gsfeatures[0],x],1) 691 | x,C,H,W = cnn2tf(x) 692 | x = x + self.pos_embed[0].to(x.get_device()) 693 | x = self.blocks_1(x) 694 | 695 | 696 | # -------block 2---------- 697 | x, H, W = pixel_upsample(x,H,W) #16x16 698 | x = tf2cnn(x) 699 | x = torch.cat([gsfeatures[1],x],1) 700 | x,C,H,W = cnn2tf(x) 701 | x = x + self.pos_embed[1].to(x.get_device()) 702 | B, _, C = x.size() 703 | x = x.view(B, H, W, C) 704 | x = window_partition(x, self.window_size) 705 | x = x.view(-1, self.window_size * self.window_size, C) 706 | x = self.blocks_2(x) 707 | x = x.view(-1, self.window_size, self.window_size, C) 708 | x = window_reverse(x, self.window_size, H, W).view(B, H * W, C) 709 | 710 | 711 | 712 | # -------block 3---------- 713 | x, H, W = pixel_upsample(x,H,W) #32x32 714 | x = tf2cnn(x) 715 | x = torch.cat([gsfeatures[2],x],1) 716 | x,C,H,W = cnn2tf(x) 717 | x = x + self.pos_embed[2].to(x.get_device()) 718 | B, _, C = x.size() 719 | x = x.view(B, H, W, C) 
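# Window attention: reshape the (B, H, W, C) map into independent
# g_window_size x g_window_size windows, run blocks_3 on each window's tokens,
# then window_reverse below stitches the windows back into (B, H*W, C).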
720 | x = window_partition(x, self.window_size) 721 | x = x.view(-1, self.window_size * self.window_size, C) 722 | x = self.blocks_3(x) 723 | x = x.view(-1, self.window_size, self.window_size, C) 724 | x = window_reverse(x, self.window_size, H, W).view(B, H * W, C) 725 | 726 | 727 | # -------block 4---------- 728 | x, H, W = pixel_upsample(x,H,W) #64x64 729 | x = tf2cnn(x) 730 | x = torch.cat([gsfeatures[3],x],1) 731 | x,C,H,W = cnn2tf(x) 732 | x = x + self.pos_embed[3].to(x.get_device()) 733 | B, _, C = x.size() 734 | x = x.view(B, H, W, C) 735 | x = window_partition(x, self.window_size) 736 | x = x.view(-1, self.window_size * self.window_size, C) 737 | x = self.blocks_4(x) 738 | x = x.view(-1, self.window_size, self.window_size, C) 739 | x = window_reverse(x, self.window_size, H, W).view(B, H * W, C) 740 | rgb_64 =self.to_rgb(x.permute(0, 2, 1).view(-1, C, H, W))+inirgb 741 | 742 | return rgb_64 743 | 744 | def pixel_upsample(x, H, W): 745 | B, N, C = x.size() 746 | assert N == H*W 747 | x = x.permute(0, 2, 1) 748 | x = x.view(-1, C, H, W) 749 | x = nn.PixelShuffle(2)(x) 750 | B, C, H, W = x.size() 751 | x = x.view(-1, C, H*W) 752 | x = x.permute(0,2,1) 753 | return x, H, W 754 | 755 | def tf2cnn(x): 756 | B, L, C = x.shape 757 | H = int(math.sqrt(L)) 758 | W = int(math.sqrt(L)) 759 | x = x.transpose(1, 2).contiguous().view(B, C, H, W) 760 | return x 761 | 762 | def cnn2tf(x): 763 | B,C,H,W = x.shape 764 | L = H*W 765 | x = x.flatten(2).transpose(1,2).contiguous() # B H*W C 766 | return x,C,H,W 767 | 768 | def bicubic_upsample(x,H,W,up_mode='bicubic'): 769 | B, N, C = x.size() 770 | assert N == H*W 771 | x = x.permute(0, 2, 1) 772 | x = x.view(-1, C, H, W) 773 | x = F.interpolate(x, scale_factor=2, mode=up_mode) 774 | B, C, H, W = x.size() 775 | x = x.view(-1, C, H*W) 776 | x = x.permute(0,2,1) 777 | return x, H, W 778 | 779 | 780 | 781 | def _downsample(x): 782 | # Downsample (Mean Avg Pooling with 2x2 kernel) 783 | return nn.AvgPool2d(kernel_size=2)(x) 784 | 785 | 786 | class PatchEmbed(nn.Module): 787 | """ Image to Patch Embedding 788 | """ 789 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 790 | super().__init__() 791 | img_size = to_2tuple(img_size) 792 | patch_size = to_2tuple(patch_size) 793 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 794 | self.img_size = img_size 795 | self.patch_size = patch_size 796 | self.num_patches = num_patches 797 | 798 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 799 | 800 | def forward(self, x): 801 | B, C, H, W = x.shape 802 | # FIXME look at relaxing size constraints 803 | assert H == self.img_size[0] and W == self.img_size[1], \ 804 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 805 | x = self.proj(x).flatten(2).transpose(1, 2) 806 | return x 807 | 808 | 809 | class HybridEmbed(nn.Module): 810 | """ CNN Feature Map Embedding 811 | Extract feature map from CNN, flatten, project to embedding dim. 
812 | """ 813 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 814 | super().__init__() 815 | assert isinstance(backbone, nn.Module) 816 | img_size = to_2tuple(img_size) 817 | self.img_size = img_size 818 | self.backbone = backbone 819 | if feature_size is None: 820 | with torch.no_grad(): 821 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 822 | # map for all networks, the feature metadata has reliable channel and stride info, but using 823 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 824 | training = backbone.training 825 | if training: 826 | backbone.eval() 827 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 828 | feature_size = o.shape[-2:] 829 | feature_dim = o.shape[1] 830 | backbone.train(training) 831 | else: 832 | feature_size = to_2tuple(feature_size) 833 | feature_dim = self.backbone.feature_info.channels()[-1] 834 | self.num_patches = feature_size[0] * feature_size[1] 835 | self.proj = nn.Linear(feature_dim, embed_dim) 836 | 837 | def forward(self, x): 838 | x = self.backbone(x)[-1] 839 | x = x.flatten(2).transpose(1, 2) 840 | x = self.proj(x) 841 | return x 842 | 843 | 844 | def get_activation(activation): 845 | """ 846 | Get the module for a specific activation function and its gain if 847 | it can be calculated. 848 | Arguments: 849 | activation (str, callable, nn.Module): String representing the activation. 850 | Returns: 851 | activation_module (torch.nn.Module): The module representing 852 | the activation function. 853 | gain (float): The gain value. Defaults to 1 if it can not be calculated. 854 | """ 855 | if isinstance(activation, nn.Module) or callable(activation): 856 | return activation 857 | if isinstance(activation, str): 858 | activation = activation.lower() 859 | if activation in [None, 'linear']: 860 | return nn.Identity() 861 | lrelu_strings = ('leaky', 'leakyrely', 'leaky_relu', 'leaky relu', 'lrelu') 862 | if activation.startswith(lrelu_strings): 863 | for l_s in lrelu_strings: 864 | activation = activation.replace(l_s, '') 865 | slope = ''.join( 866 | char for char in activation if char.isdigit() or char == '.') 867 | slope = float(slope) if slope else 0.01 868 | return nn.LeakyReLU(slope) # close enough to true gain 869 | elif activation in ['relu']: 870 | return nn.ReLU() 871 | elif activation in ['elu']: 872 | return nn.ELU() 873 | elif activation in ['prelu']: 874 | return nn.PReLU() 875 | elif activation in ['rrelu', 'randomrelu']: 876 | return nn.RReLU() 877 | elif activation in ['selu']: 878 | return nn.SELU() 879 | elif activation in ['softplus']: 880 | return nn.Softplus() 881 | elif activation in ['softsign']: 882 | return nn.Softsign() # unsure about this gain 883 | elif activation in ['sigmoid', 'logistic']: 884 | return nn.Sigmoid() 885 | elif activation in ['tanh']: 886 | return nn.Tanh() 887 | else: 888 | raise ValueError( 889 | 'Activation "{}" not available.'.format(activation) 890 | ) 891 | 892 | 893 | def _conv_filter(state_dict, patch_size=16): 894 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 895 | out_dict = {} 896 | for k, v in state_dict.items(): 897 | if 'patch_embed.proj.weight' in k: 898 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 899 | out_dict[k] = v 900 | return out_dict 901 | 902 | 903 | 904 | --------------------------------------------------------------------------------