├── README.md ├── figs ├── HST.png └── framework ├── model.py ├── srdata.py ├── swin_block.py ├── test.sh ├── test_HST.py ├── util_calculate_psnr_ssim.py └── utils_logger.py /README.md: -------------------------------------------------------------------------------- 1 | # HST 2 | HST: Hierarchical Swin Transformer for Compressed Image Super-resolution 3 | > [**HST**](https://arxiv.org/abs/2208.09885), Bingchen Li, Xin Li, et al. 4 | 5 | > Achieved **the fifth place** in the competition of the **AIM2022 compressed image super-resolution** track. 6 | 7 | > Accepted by ECCV2022 Workshop 8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-div2k)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-div2k?p=hst-hierarchical-swin-transformer-for) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-div2k-1)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-div2k-1?p=hst-hierarchical-swin-transformer-for) 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-div2k-2)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-div2k-2?p=hst-hierarchical-swin-transformer-for) 12 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-div2k-3)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-div2k-3?p=hst-hierarchical-swin-transformer-for) 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-set5-q10)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-set5-q10?p=hst-hierarchical-swin-transformer-for) 14 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-set14)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-set14?p=hst-hierarchical-swin-transformer-for) 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-bsd100)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-bsd100?p=hst-hierarchical-swin-transformer-for) 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-urban100)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-urban100?p=hst-hierarchical-swin-transformer-for) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/hst-hierarchical-swin-transformer-for/compressed-image-super-resolution-on-manga109)](https://paperswithcode.com/sota/compressed-image-super-resolution-on-manga109?p=hst-hierarchical-swin-transformer-for) 18 | 19 | ![image](https://github.com/lixinustc/HST-Hierarchical-Swin-Transformer-for-Compressed-Image-Super-resolution/blob/main/figs/HST.png) 20 | 21 | ## Abstract 22 | Compressed Image Super-resolution has achieved great attention in recent years, where images are degraded with compression artifacts and low-resolution artifacts. 
Due to the complex hybrid distortions, 23 | it is hard to restore the distorted image with the simple cooperation 24 | of super-resolution and compression artifact removal. In this paper, 25 | we take a step forward and propose the Hierarchical Swin Transformer 26 | (HST) network to restore the low-resolution compressed image, which 27 | jointly captures the hierarchical feature representations and enhances 28 | the representation at each scale with a Swin Transformer. Moreover, 29 | we find that pre-training with a super-resolution (SR) task is vital 30 | for compressed image super-resolution. To explore the effects of different SR pre-training, we take the commonly used SR tasks (e.g., bicubic 31 | and several real-world super-resolution simulations) as our pre-training tasks, 32 | and reveal that SR plays an irreplaceable role in compressed image super-resolution. With the cooperation of HST and pre-training, our 33 | HST achieves fifth place in the AIM 2022 challenge on the low-quality 34 | compressed image super-resolution track, with a PSNR of 23.51 dB. Extensive experiments and ablation studies have validated the effectiveness 35 | of our proposed methods. 36 | 37 | ## Usage 38 | More details will be described progressively. 39 | 40 | **The checkpoints for HST are released**: 41 | - [checkpoint_comp10_x4](https://drive.google.com/file/d/1ZtGxO6ghT1YFLgu_PIHBt7VpDV52CsjS/view?usp=sharing) 42 | - [checkpoint_comp20_x4](https://drive.google.com/file/d/1ldXbI5c9KHxsHQZS3hRRK2jR9HRvfqyD/view?usp=sharing) 43 | - [checkpoint_comp30_x4](https://drive.google.com/file/d/1ANqQkYW7JKPLdJLKq1xHixaSZtn3e0q-/view?usp=sharing) 44 | - [checkpoint_comp40_x4](https://drive.google.com/file/d/1SlvhcFSEr4jM5gUB_we8c-EYmpIUntIT/view?usp=sharing) 45 | 46 |
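A minimal test command, following `test.sh` (the `--ckpt` path should point to a downloaded checkpoint; the test sets such as `Set5_comp40` and their ground-truth folders are assumed to sit at the locations hard-coded in `test_HST.py`, and `--use_ensemble` optionally enables self-ensemble testing):

```bash
python test_HST.py --ckpt checkpoint/model_comp40.pt --comp_level 40
```

47 | ## Cite Us 48 | Please cite us if this work is helpful to you.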
49 | ``` 50 | @inproceedings{li2022hst, 51 | title={HST: Hierarchical Swin Transformer for Compressed Image Super-resolution}, 52 | author={Li, Bingchen and Li, Xin and Lu, Yiting and Liu, Sen and Feng, Ruoyu and Chen, Zhibo}, 53 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV) Workshops}, 54 | year={2022} 55 | } 56 | ``` 57 | 58 | The model is implemented based on the works: 59 | [MSGDN](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w7/Li_Multi-Scale_Grouped_Dense_Network_for_VVC_Intra_Coding_CVPRW_2020_paper.pdf), [SwinIR](https://github.com/JingyunLiang/SwinIR), [SwinTransformer](https://arxiv.org/abs/2103.14030) 60 | -------------------------------------------------------------------------------- /figs/HST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTC-IMCL/HST-for-Compressed-Image-SR/8798e244062c7a17259abcda74d93477584a57ac/figs/HST.png -------------------------------------------------------------------------------- /figs/framework: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from swin_block import RSTB, PatchEmbed, PatchUnEmbed 6 | from timm.models.layers import trunc_normal_ 7 | 8 | 9 | class GRSTB(nn.Module): 10 | def __init__(self, num_features, img_size, use_embed=True, patch_size=1, 11 | depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6], 12 | window_size=7, mlp_ratio=2., qkv_bias=True, qk_scale=None, 13 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 14 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 15 | use_checkpoint=False, resi_connection='1conv', 16 | **kwargs): 17 | super(GRSTB, self).__init__() 18 | 19 | self.num_layers = len(depths) 20 | patches_resolution = (img_size, img_size) 21 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 22 | self.mlp_ratio = mlp_ratio 23 | self.patch_norm = patch_norm 24 | self.ape = ape 25 | self.use_embed = use_embed 26 | 27 | self.layers = nn.ModuleList() 28 | for i_layer in range(self.num_layers): 29 | layer = RSTB(dim=num_features, 30 | input_resolution=(patches_resolution[0], 31 | patches_resolution[1]), 32 | depth=depths[i_layer], 33 | num_heads=num_heads[i_layer], 34 | window_size=window_size, 35 | mlp_ratio=self.mlp_ratio, 36 | qkv_bias=qkv_bias, qk_scale=qk_scale, 37 | drop=drop_rate, attn_drop=attn_drop_rate, 38 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results 39 | norm_layer=norm_layer, 40 | downsample=None, 41 | use_checkpoint=use_checkpoint, 42 | img_size=img_size, 43 | patch_size=patch_size, 44 | resi_connection=resi_connection 45 | 46 | ) 47 | self.layers.append(layer) 48 | 49 | self.patch_embed = PatchEmbed( 50 | img_size=img_size, patch_size=patch_size, in_chans=num_features, embed_dim=num_features, 51 | norm_layer=norm_layer if self.patch_norm else None) 52 | self.patch_unembed = PatchUnEmbed( 53 | img_size=img_size, patch_size=patch_size, in_chans=num_features, embed_dim=num_features, 54 | norm_layer=norm_layer if self.patch_norm else None) 55 | self.pos_drop = nn.Dropout(p=drop_rate) 56 | self.norm = norm_layer(num_features) 57 | 58 | def forward(self, x, x_size=None): 59 | if self.use_embed: 60 | x_size 
= (x.shape[2], x.shape[3]) 61 | x = self.patch_embed(x) 62 | 63 | x = self.pos_drop(x) 64 | 65 | for layer in self.layers: 66 | x = layer(x, x_size) 67 | 68 | x = self.norm(x) # B L C 69 | x = self.patch_unembed(x, x_size) 70 | 71 | return x 72 | 73 | 74 | class HST(nn.Module): 75 | def __init__(self, img_size, num_features=[60, 60, 60], scale=4, window_size=8): 76 | super(HST, self).__init__() 77 | self.img_size_h = img_size 78 | self.img_size_m = self.img_size_h // 2 79 | self.img_size_l = self.img_size_m // 2 80 | num_fea_h, num_fea_m, num_fea_l = num_features 81 | self.window_size = window_size 82 | self.scale = scale 83 | 84 | rgb_mean = (0.4488, 0.4371, 0.4040) 85 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 86 | 87 | self.GRSTB_1 = GRSTB(num_fea_l, self.img_size_l, depths=[6, 6], num_heads=[6, 6], window_size=self.window_size) 88 | self.conv_after_grstb1 = nn.Conv2d(num_fea_l, num_fea_l, 3, 1, 1) 89 | self.GRSTB_2 = GRSTB(num_fea_m*2, self.img_size_m, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], window_size=self.window_size) 90 | self.conv_after_grstb2 = nn.Conv2d(num_fea_m*2, num_fea_m*2, 3, 1, 1) 91 | self.GRSTB_3 = GRSTB(num_fea_h*3, self.img_size_h, window_size=self.window_size) 92 | self.conv_after_grstb3 = nn.Conv2d(num_fea_h*3, num_fea_h*3, 3, 1, 1) 93 | 94 | self.conv_fuse_1 = nn.Conv2d(num_fea_m*2, num_fea_m*2, 1, 1, 0) 95 | self.conv_fuse_2 = nn.Conv2d(num_fea_h*3, num_fea_h*3, 1, 1, 0) 96 | 97 | self.conv1 = nn.Conv2d(3, num_fea_m, kernel_size=5, stride=2, padding=2) 98 | self.conv2 = nn.Conv2d(num_fea_m, num_fea_l, kernel_size=3, stride=2, padding=1) 99 | self.conv3 = nn.Conv2d(3, num_fea_h, kernel_size=7, stride=1, padding=3) 100 | 101 | self.up = nn.PixelShuffle(2) 102 | self.upconv1 = nn.Conv2d(num_fea_l, num_fea_m*4, kernel_size=3, stride=1, padding=1) 103 | self.upconv2 = nn.Conv2d(num_fea_m*2, num_fea_h*8, kernel_size=3, stride=1, padding=1) 104 | 105 | self.conv_before_up = nn.Sequential(nn.Conv2d(num_fea_h*3, 64, 3, 1, 1), 106 | nn.LeakyReLU(inplace=True)) 107 | self.upsample = nn.Sequential( 108 | nn.Conv2d(64, 4 * 64, kernel_size=3, padding=1), 109 | nn.PixelShuffle(2), 110 | nn.Conv2d(64, 4 * 64, kernel_size=3, padding=1), 111 | nn.PixelShuffle(2) 112 | ) 113 | 114 | self.conv_last = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) 115 | 116 | self.apply(self._init_weights) 117 | 118 | def _init_weights(self, m): 119 | if isinstance(m, nn.Linear): 120 | trunc_normal_(m.weight, std=.02) 121 | if isinstance(m, nn.Linear) and m.bias is not None: 122 | nn.init.constant_(m.bias, 0) 123 | elif isinstance(m, nn.LayerNorm): 124 | nn.init.constant_(m.bias, 0) 125 | nn.init.constant_(m.weight, 1.0) 126 | 127 | def check_image_size(self, x, t): 128 | _, _, h, w = x.size() 129 | # t = self.window_size * self.scale 130 | mod_pad_h = (t - h % t) % t 131 | mod_pad_w = (t - w % t) % t 132 | # mod_pad_h = (self.window_size - h % self.window_size) % self.window_size 133 | # mod_pad_w = (self.window_size - w % self.window_size) % self.window_size 134 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'replicate') 135 | return x 136 | 137 | def forward(self, x): 138 | _, _, H, W = x.size() 139 | x = self.check_image_size(x, 8) 140 | self.mean = self.mean.type_as(x) 141 | x = x - self.mean 142 | x_m = self.conv1(x) 143 | x_l = self.conv2(x_m) 144 | x_h = self.conv3(x) 145 | res_l = x_l 146 | 147 | x_l = self.GRSTB_1(x_l) 148 | x_l = self.conv_after_grstb1(x_l) 149 | 150 | x_l += res_l 151 | x_l = self.upconv1(x_l) 152 | x_l = self.up(x_l) 153 | x_lm = torch.cat([x_l, x_m], 
dim=1) 154 | x_lm = self.conv_fuse_1(x_lm) 155 | 156 | res_lm = x_lm 157 | 158 | x_lm = self.GRSTB_2(x_lm) 159 | x_lm = self.conv_after_grstb2(x_lm) 160 | 161 | x_lm += res_lm 162 | x_lm = self.upconv2(x_lm) 163 | x_lm = self.up(x_lm) 164 | x_h = torch.cat([x_lm, x_h], dim=1) 165 | x_h = self.conv_fuse_2(x_h) 166 | 167 | res_h = x_h 168 | 169 | x_h = self.GRSTB_3(x_h) 170 | x_h = self.conv_after_grstb3(x_h) 171 | x_h += res_h 172 | x_h = self.conv_before_up(x_h) 173 | x_h = self.upsample(x_h) 174 | x_h = self.conv_last(x_h) 175 | x_h += self.mean 176 | x_h = x_h[:, :, :H*self.scale, :W*self.scale] 177 | 178 | return x_h 179 | -------------------------------------------------------------------------------- /srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | from torchvision.transforms import ToTensor 7 | import cv2 8 | import random 9 | import glob 10 | 11 | 12 | class Data_Train(data.Dataset): 13 | def __init__(self, patch_size=64): 14 | self.scale = 4 15 | self.patch_size = patch_size 16 | 17 | self.dir_hr = sorted(glob.glob('PATH_TO_CLIC/*') + glob.glob('PATH_TO_DF2K/*')) 18 | self.dir_lr = sorted(glob.glob('PATH_TO_CLIC_COMP/*') + glob.glob('PATH_TO_DF2K_COMP/*')) 19 | 20 | def __getitem__(self, idx): 21 | name_hr = self.dir_hr[idx] 22 | name_lr = self.dir_lr[idx] 23 | 24 | number_hr = os.path.basename(name_hr).split('.')[0] 25 | number_lr = os.path.basename(name_lr).split('.')[0] 26 | 27 | assert number_hr == number_lr 28 | 29 | hr = cv2.cvtColor(cv2.imread(name_hr), cv2.COLOR_BGR2RGB) 30 | lr = cv2.cvtColor(cv2.imread(name_lr), cv2.COLOR_BGR2RGB) 31 | hmax, wmax, _ = lr.shape 32 | 33 | crop_h = np.random.randint(0, hmax-self.patch_size) 34 | crop_w = np.random.randint(0, wmax-self.patch_size) 35 | 36 | hr = hr[crop_h*4:(crop_h+self.patch_size)*4, crop_w*4:(crop_w+self.patch_size)*4, ...] 37 | lr = lr[crop_h:crop_h+self.patch_size, crop_w:crop_w+self.patch_size, ...] 
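# Note: with scale = 4, the HR crop above is the x4-scaled counterpart of the LR crop window,
# so the HR/LR patches stay spatially aligned before the same augmentation mode is applied to both.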
38 | 39 | mode = random.randint(0, 7) 40 | 41 | lr, hr = augment_img(lr, mode=mode), augment_img(hr, mode=mode) 42 | 43 | lr = ToTensor()(lr.copy()) 44 | hr = ToTensor()(hr.copy()) 45 | 46 | output = {'L': lr, 'H': hr, 'N': number_hr} 47 | return output 48 | 49 | def __len__(self): 50 | return len(self.dir_hr) 51 | 52 | 53 | class Data_Test(data.Dataset): 54 | def __init__(self): 55 | 56 | self.dir_hr = 'PATH_TO_DIV2K_VALID' 57 | self.dir_lr = 'PATH_TO_DIV2K_VALID_COMP' 58 | self.name_hr = sorted(os.listdir(self.dir_hr)) 59 | 60 | def __getitem__(self, idx): 61 | name = self.name_hr[idx] 62 | 63 | number = name.split('.')[0] 64 | 65 | hr = cv2.cvtColor(cv2.imread(os.path.join(self.dir_hr, name)), cv2.COLOR_BGR2RGB) 66 | lr = cv2.cvtColor(cv2.imread(os.path.join(self.dir_lr, number+'.jpg')), cv2.COLOR_BGR2RGB) 67 | 68 | lr = ToTensor()(lr) 69 | hr = ToTensor()(hr) 70 | output = {'L': lr, 'H': hr, 'N': number} 71 | return output 72 | 73 | def __len__(self): 74 | return len(self.name_hr) 75 | 76 | 77 | def augment_img(img, mode=0): 78 | '''Kai Zhang (github: https://github.com/cszn) 79 | ''' 80 | if mode == 0: 81 | return img 82 | elif mode == 1: 83 | return np.flipud(np.rot90(img)) 84 | elif mode == 2: 85 | return np.flipud(img) 86 | elif mode == 3: 87 | return np.rot90(img, k=3) 88 | elif mode == 4: 89 | return np.flipud(np.rot90(img, k=2)) 90 | elif mode == 5: 91 | return np.rot90(img) 92 | elif mode == 6: 93 | return np.rot90(img, k=2) 94 | elif mode == 7: 95 | return np.flipud(np.rot90(img, k=3)) -------------------------------------------------------------------------------- /swin_block.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------------- 2 | # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 3 | # Originally Written by Ze Liu, Modified by Jingyun Liang. 
4 | # ----------------------------------------------------------------------------------- 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint as checkpoint 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | 32 | 33 | def window_partition(x, window_size): 34 | """ 35 | Args: 36 | x: (B, H, W, C) 37 | window_size (int): window size 38 | 39 | Returns: 40 | windows: (num_windows*B, window_size, window_size, C) 41 | """ 42 | B, H, W, C = x.shape 43 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 44 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 45 | return windows 46 | 47 | 48 | def window_reverse(windows, window_size, H, W): 49 | """ 50 | Args: 51 | windows: (num_windows*B, window_size, window_size, C) 52 | window_size (int): Window size 53 | H (int): Height of image 54 | W (int): Width of image 55 | 56 | Returns: 57 | x: (B, H, W, C) 58 | """ 59 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 60 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 61 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 62 | return x 63 | 64 | 65 | class WindowAttention(nn.Module): 66 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 67 | It supports both of shifted and non-shifted window. 68 | 69 | Args: 70 | dim (int): Number of input channels. 71 | window_size (tuple[int]): The height and width of the window. 72 | num_heads (int): Number of attention heads. 73 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 74 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 75 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 76 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 77 | """ 78 | 79 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 80 | 81 | super().__init__() 82 | self.dim = dim 83 | self.window_size = window_size # Wh, Ww 84 | self.num_heads = num_heads 85 | head_dim = dim // num_heads 86 | self.scale = qk_scale or head_dim ** -0.5 87 | 88 | # define a parameter table of relative position bias 89 | self.relative_position_bias_table = nn.Parameter( 90 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 91 | 92 | # get pair-wise relative position index for each token inside the window 93 | coords_h = torch.arange(self.window_size[0]) 94 | coords_w = torch.arange(self.window_size[1]) 95 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 96 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 97 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 98 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 99 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 100 | relative_coords[:, :, 1] += self.window_size[1] - 1 101 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 102 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 103 | self.register_buffer("relative_position_index", relative_position_index) 104 | 105 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 106 | self.attn_drop = nn.Dropout(attn_drop) 107 | self.proj = nn.Linear(dim, dim) 108 | 109 | self.proj_drop = nn.Dropout(proj_drop) 110 | 111 | trunc_normal_(self.relative_position_bias_table, std=.02) 112 | self.softmax = nn.Softmax(dim=-1) 113 | 114 | def forward(self, x, mask=None): 115 | """ 116 | Args: 117 | x: input features with shape of (num_windows*B, N, C) 118 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 119 | """ 120 | B_, N, C = x.shape 121 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 122 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 123 | 124 | q = q * self.scale 125 | attn = (q @ k.transpose(-2, -1)) 126 | 127 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 128 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 129 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 130 | attn = attn + relative_position_bias.unsqueeze(0) 131 | 132 | if mask is not None: 133 | nW = mask.shape[0] 134 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 135 | attn = attn.view(-1, self.num_heads, N, N) 136 | attn = self.softmax(attn) 137 | else: 138 | attn = self.softmax(attn) 139 | 140 | attn = self.attn_drop(attn) 141 | 142 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 143 | x = self.proj(x) 144 | x = self.proj_drop(x) 145 | return x 146 | 147 | def extra_repr(self) -> str: 148 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 149 | 150 | def flops(self, N): 151 | # calculate flops for 1 window with token length of N 152 | flops = 0 153 | # qkv = self.qkv(x) 154 | flops += N * self.dim * 3 * self.dim 155 | # attn = (q @ k.transpose(-2, -1)) 156 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 157 | # x = (attn @ v) 158 | flops += self.num_heads * N * N * (self.dim // 
self.num_heads) 159 | # x = self.proj(x) 160 | flops += N * self.dim * self.dim 161 | return flops 162 | 163 | 164 | class SwinTransformerBlock(nn.Module): 165 | r""" Swin Transformer Block. 166 | 167 | Args: 168 | dim (int): Number of input channels. 169 | input_resolution (tuple[int]): Input resulotion. 170 | num_heads (int): Number of attention heads. 171 | window_size (int): Window size. 172 | shift_size (int): Shift size for SW-MSA. 173 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 174 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 175 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 176 | drop (float, optional): Dropout rate. Default: 0.0 177 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 178 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 179 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 180 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 181 | """ 182 | 183 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 184 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 185 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 186 | super().__init__() 187 | self.dim = dim 188 | self.input_resolution = input_resolution 189 | self.num_heads = num_heads 190 | self.window_size = window_size 191 | self.shift_size = shift_size 192 | self.mlp_ratio = mlp_ratio 193 | if min(self.input_resolution) <= self.window_size: 194 | # if window size is larger than input resolution, we don't partition windows 195 | self.shift_size = 0 196 | self.window_size = min(self.input_resolution) 197 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 198 | 199 | self.norm1 = norm_layer(dim) 200 | self.attn = WindowAttention( 201 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 202 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 203 | 204 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 205 | self.norm2 = norm_layer(dim) 206 | mlp_hidden_dim = int(dim * mlp_ratio) 207 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 208 | 209 | if self.shift_size > 0: 210 | attn_mask = self.calculate_mask(self.input_resolution) 211 | else: 212 | attn_mask = None 213 | 214 | self.register_buffer("attn_mask", attn_mask) 215 | 216 | def calculate_mask(self, x_size): 217 | # calculate attention mask for SW-MSA 218 | H, W = x_size 219 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 220 | h_slices = (slice(0, -self.window_size), 221 | slice(-self.window_size, -self.shift_size), 222 | slice(-self.shift_size, None)) 223 | w_slices = (slice(0, -self.window_size), 224 | slice(-self.window_size, -self.shift_size), 225 | slice(-self.shift_size, None)) 226 | cnt = 0 227 | for h in h_slices: 228 | for w in w_slices: 229 | img_mask[:, h, w, :] = cnt 230 | cnt += 1 231 | 232 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 233 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 234 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 235 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 236 | 237 | return attn_mask 238 | 239 | def forward(self, x, x_size): 240 | H, W = x_size 241 | B, L, C = x.shape 242 | # assert L == H * W, "input feature has wrong size" 243 | 244 | shortcut = x 245 | x = self.norm1(x) 246 | x = x.view(B, H, W, C) 247 | 248 | # cyclic shift 249 | if self.shift_size > 0: 250 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 251 | else: 252 | shifted_x = x 253 | 254 | # partition windows 255 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 256 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 257 | 258 | # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size 259 | if self.input_resolution == x_size: 260 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 261 | else: 262 | attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) 263 | 264 | # merge windows 265 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 266 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 267 | 268 | # reverse cyclic shift 269 | if self.shift_size > 0: 270 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 271 | else: 272 | x = shifted_x 273 | x = x.view(B, H * W, C) 274 | 275 | # FFN 276 | x = shortcut + self.drop_path(x) 277 | x = x + self.drop_path(self.mlp(self.norm2(x))) 278 | 279 | return x 280 | 281 | def extra_repr(self) -> str: 282 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 283 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 284 | 285 | def flops(self): 286 | flops = 0 287 | H, W = self.input_resolution 288 | # norm1 289 | flops += self.dim * H * W 290 | # W-MSA/SW-MSA 291 | nW = H * W / self.window_size / self.window_size 292 | flops += nW * self.attn.flops(self.window_size * self.window_size) 293 | # mlp 294 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 295 | # norm2 296 | flops += self.dim * H * W 297 | return flops 298 | 299 | 300 | class 
BasicLayer(nn.Module): 301 | """ A basic Swin Transformer layer for one stage. 302 | 303 | Args: 304 | dim (int): Number of input channels. 305 | input_resolution (tuple[int]): Input resolution. 306 | depth (int): Number of blocks. 307 | num_heads (int): Number of attention heads. 308 | window_size (int): Local window size. 309 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 310 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 311 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 312 | drop (float, optional): Dropout rate. Default: 0.0 313 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 314 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 315 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 316 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 317 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 318 | """ 319 | 320 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 321 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 322 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 323 | 324 | super().__init__() 325 | self.dim = dim 326 | self.input_resolution = input_resolution 327 | self.depth = depth 328 | self.use_checkpoint = use_checkpoint 329 | 330 | # build blocks 331 | self.blocks = nn.ModuleList([ 332 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 333 | num_heads=num_heads, window_size=window_size, 334 | shift_size=0 if (i % 2 == 0) else window_size // 2, 335 | mlp_ratio=mlp_ratio, 336 | qkv_bias=qkv_bias, qk_scale=qk_scale, 337 | drop=drop, attn_drop=attn_drop, 338 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 339 | norm_layer=norm_layer) 340 | for i in range(depth)]) 341 | 342 | # patch merging layer 343 | if downsample is not None: 344 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 345 | else: 346 | self.downsample = None 347 | 348 | def forward(self, x, x_size): 349 | for blk in self.blocks: 350 | if self.use_checkpoint: 351 | x = checkpoint.checkpoint(blk, x, x_size) 352 | else: 353 | x = blk(x, x_size) 354 | if self.downsample is not None: 355 | x = self.downsample(x) 356 | return x 357 | 358 | def extra_repr(self) -> str: 359 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 360 | 361 | def flops(self): 362 | flops = 0 363 | for blk in self.blocks: 364 | flops += blk.flops() 365 | if self.downsample is not None: 366 | flops += self.downsample.flops() 367 | return flops 368 | 369 | 370 | class RSTB(nn.Module): 371 | """Residual Swin Transformer Block (RSTB). 372 | 373 | Args: 374 | dim (int): Number of input channels. 375 | input_resolution (tuple[int]): Input resolution. 376 | depth (int): Number of blocks. 377 | num_heads (int): Number of attention heads. 378 | window_size (int): Local window size. 379 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 380 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 381 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 382 | drop (float, optional): Dropout rate. Default: 0.0 383 | attn_drop (float, optional): Attention dropout rate. 
Default: 0.0 384 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 385 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 386 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 387 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 388 | img_size: Input image size. 389 | patch_size: Patch size. 390 | resi_connection: The convolutional block before residual connection. 391 | """ 392 | 393 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 394 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 395 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, 396 | img_size=224, patch_size=4, resi_connection='1conv'): 397 | super(RSTB, self).__init__() 398 | 399 | self.dim = dim 400 | self.input_resolution = input_resolution 401 | 402 | self.residual_group = BasicLayer(dim=dim, 403 | input_resolution=input_resolution, 404 | depth=depth, 405 | num_heads=num_heads, 406 | window_size=window_size, 407 | mlp_ratio=mlp_ratio, 408 | qkv_bias=qkv_bias, qk_scale=qk_scale, 409 | drop=drop, attn_drop=attn_drop, 410 | drop_path=drop_path, 411 | norm_layer=norm_layer, 412 | downsample=downsample, 413 | use_checkpoint=use_checkpoint) 414 | 415 | if resi_connection == '1conv': 416 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 417 | elif resi_connection == '3conv': 418 | # to save parameters and memory 419 | self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), 420 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), 421 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 422 | nn.Conv2d(dim // 4, dim, 3, 1, 1)) 423 | 424 | self.patch_embed = PatchEmbed( 425 | img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, 426 | norm_layer=None) 427 | 428 | self.patch_unembed = PatchUnEmbed( 429 | img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, 430 | norm_layer=None) 431 | 432 | def forward(self, x, x_size): 433 | return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x 434 | 435 | def flops(self): 436 | flops = 0 437 | flops += self.residual_group.flops() 438 | H, W = self.input_resolution 439 | flops += H * W * self.dim * self.dim * 9 440 | flops += self.patch_embed.flops() 441 | flops += self.patch_unembed.flops() 442 | 443 | return flops 444 | 445 | 446 | class PatchEmbed(nn.Module): 447 | r""" Image to Patch Embedding 448 | 449 | Args: 450 | img_size (int): Image size. Default: 224. 451 | patch_size (int): Patch token size. Default: 4. 452 | in_chans (int): Number of input image channels. Default: 3. 453 | embed_dim (int): Number of linear projection output channels. Default: 96. 454 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 455 | """ 456 | 457 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 458 | super().__init__() 459 | img_size = to_2tuple(img_size) 460 | patch_size = to_2tuple(patch_size) 461 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 462 | self.img_size = img_size 463 | self.patch_size = patch_size 464 | self.patches_resolution = patches_resolution 465 | self.num_patches = patches_resolution[0] * patches_resolution[1] 466 | 467 | self.in_chans = in_chans 468 | self.embed_dim = embed_dim 469 | 470 | if norm_layer is not None: 471 | self.norm = norm_layer(embed_dim) 472 | else: 473 | self.norm = None 474 | 475 | def forward(self, x): 476 | x = x.flatten(2).transpose(1, 2) # B Ph*Pw C 477 | if self.norm is not None: 478 | x = self.norm(x) 479 | return x 480 | 481 | def flops(self): 482 | flops = 0 483 | H, W = self.img_size 484 | if self.norm is not None: 485 | flops += H * W * self.embed_dim 486 | return flops 487 | 488 | 489 | class PatchUnEmbed(nn.Module): 490 | r""" Image to Patch Unembedding 491 | 492 | Args: 493 | img_size (int): Image size. Default: 224. 494 | patch_size (int): Patch token size. Default: 4. 495 | in_chans (int): Number of input image channels. Default: 3. 496 | embed_dim (int): Number of linear projection output channels. Default: 96. 497 | norm_layer (nn.Module, optional): Normalization layer. Default: None 498 | """ 499 | 500 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 501 | super().__init__() 502 | img_size = to_2tuple(img_size) 503 | patch_size = to_2tuple(patch_size) 504 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 505 | self.img_size = img_size 506 | self.patch_size = patch_size 507 | self.patches_resolution = patches_resolution 508 | self.num_patches = patches_resolution[0] * patches_resolution[1] 509 | 510 | self.in_chans = in_chans 511 | self.embed_dim = embed_dim 512 | 513 | def forward(self, x, x_size): 514 | B, HW, C = x.shape 515 | x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C 516 | return x 517 | 518 | def flops(self): 519 | flops = 0 520 | return flops 521 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test_HST.py --ckpt checkpoint/model_comp40.pt --comp_level 40 -------------------------------------------------------------------------------- /test_HST.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn.utils as utils 5 | import cv2 6 | import numpy as np 7 | from collections import OrderedDict 8 | from model import HST 9 | import util_calculate_psnr_ssim as util 10 | 11 | parser = argparse.ArgumentParser(description='Test HST') 12 | 13 | parser.add_argument('--ckpt', type=str, default='', help='path to load checkpoint') 14 | parser.add_argument('--scale', type=int, default=4, help='SR scale, 4 is used in the competition') 15 | parser.add_argument('--window_size', type=int, default=8, help='window size, 8 is default') 16 | parser.add_argument('--comp_level', type=int, default=40, help='compression level, support 10, 20, 30, 40') 17 | parser.add_argument('--use_ensemble', action='store_true') 18 | 19 | args = parser.parse_args() 20 | 21 | 22 | weight = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 
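# The map_location lambda keeps the checkpoint tensors on the CPU while loading; the model is only
# moved to the GPU (model.cuda()) after the state dict has been loaded below.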
23 | 24 | model = HST(img_size=64) 25 | 26 | model.load_state_dict(weight) 27 | model = model.cuda() 28 | 29 | test_paths = ['Set5_comp'+str(args.comp_level), 'Set14_comp'+str(args.comp_level), 'BSD100_comp'+str(args.comp_level), 'urban100_comp'+str(args.comp_level), 'manga109_comp'+str(args.comp_level)] 30 | 31 | output_paths = ['Set5_out', 'Set14_out', 'BSD100_out', 'urban100_out', 'manga109_out'] 32 | 33 | gts = ['Set5', 'Set14', 'BSD100', 'urban100', 'manga109'] 34 | 35 | 36 | def test(model, img): 37 | _, _, h_old, w_old = img.size() 38 | padding = args.scale * args.window_size 39 | h_pad = (h_old // padding + 1) * padding - h_old 40 | w_pad = (w_old // padding + 1) * padding - w_old 41 | img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h_old + h_pad, :] 42 | img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w_old + w_pad] 43 | 44 | img = model(img) 45 | img = img[..., :h_old * 4, :w_old * 4] 46 | return img 47 | 48 | for i in range(len(gts)): 49 | 50 | output_path = output_paths[i] 51 | test_path = test_paths[i] 52 | gt = gts[i] 53 | 54 | if not os.path.exists(output_path): 55 | os.makedirs(output_path) 56 | 57 | f = open(os.path.join(output_path, 'log.txt'),'w') 58 | 59 | model.eval() 60 | count = 0 61 | with torch.no_grad(): 62 | 63 | p = 0 64 | s = 0 65 | py = 0 66 | sy = 0 67 | 68 | for img_n in sorted(os.listdir(test_path)): 69 | count += 1 70 | lr = cv2.imread(os.path.join(test_path, img_n)) 71 | hr_n = img_n.split(".")[0] + '.png' 72 | hr = cv2.imread(os.path.join(gt, hr_n)) 73 | lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB) 74 | 75 | img = np.ascontiguousarray(lr.transpose((2, 0, 1))) 76 | img = torch.from_numpy(img).float() 77 | img /= 255. 78 | img = img.unsqueeze(0).cuda() 79 | E = test(model, img) 80 | if args.use_ensemble: 81 | E1 = test(model, img.flip(-1)).flip(-1) 82 | E2 = test(model, img.flip(-2)).flip(-2) 83 | E3 = test(model, img.flip(-1, -2)).flip(-1, -2) 84 | L_t = img.transpose(-2, -1) 85 | E4 = test(model, L_t).transpose(-2, -1) 86 | E5 = test(model, L_t.flip(-1)).flip(-1).transpose(-2, -1) 87 | E6 = test(model, L_t.flip(-2)).flip(-2).transpose(-2, -1) 88 | E7 = test(model, L_t.flip(-1, -2)).flip(-1, -2).transpose(-2, -1) 89 | 90 | E = (E.clamp_(0, 1) + E1.clamp_(0, 1) + E2.clamp_(0, 1) + E3.clamp_(0, 1) + E4.clamp_(0, 1) + E5.clamp_(0, 1) + E6.clamp_(0, 1) + E7.clamp_(0, 1)) / 8.0 91 | 92 | img = E 93 | sr = img.detach().cpu().squeeze(0).numpy().transpose(1, 2, 0) 94 | 95 | sr = sr * 255. 96 | sr = np.clip(sr.round(), 0, 255).astype(np.uint8) 97 | 98 | sr = cv2.cvtColor(sr, cv2.COLOR_RGB2BGR) 99 | 100 | psnr = util.calculate_psnr(sr.copy(), hr.copy(), crop_border=4, test_y_channel=False) 101 | psnr_y = util.calculate_psnr(sr.copy(), hr.copy(), crop_border=4, test_y_channel=True) 102 | 103 | ssim = util.calculate_ssim(sr.copy(), hr.copy(), crop_border=4, test_y_channel=False) 104 | ssim_y = util.calculate_ssim(sr.copy(), hr.copy(), crop_border=4, test_y_channel=True) 105 | 106 | p += psnr 107 | s += ssim 108 | py += psnr_y 109 | sy += ssim_y 110 | f.write('{}: PSNR, {}. PSNR_Y, {}. SSIM, {}. 
SSIM_Y, {}.\n'.format(img_n, psnr, psnr_y, ssim, ssim_y)) 111 | 112 | cv2.imwrite(os.path.join(output_path, hr_n), sr) 113 | 114 | 115 | p /= count 116 | s /= count 117 | py /= count 118 | sy /= count 119 | print(p, py, s, sy) 120 | f.write('avg PSNR: {}, PSNR_Y: {}, SSIM: {}, SSIM_Y: {}.'.format(p, py, s, sy)) 121 | 122 | f.close() 123 | -------------------------------------------------------------------------------- /util_calculate_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 7 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 8 | 9 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 10 | 11 | Args: 12 | img1 (ndarray): Images with range [0, 255]. 13 | img2 (ndarray): Images with range [0, 255]. 14 | crop_border (int): Cropped pixels in each edge of an image. These 15 | pixels are not involved in the PSNR calculation. 16 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 17 | Default: 'HWC'. 18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 19 | 20 | Returns: 21 | float: psnr result. 22 | """ 23 | 24 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 25 | if input_order not in ['HWC', 'CHW']: 26 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 27 | img1 = reorder_image(img1, input_order=input_order) 28 | img2 = reorder_image(img2, input_order=input_order) 29 | img1 = img1.astype(np.float64) 30 | img2 = img2.astype(np.float64) 31 | 32 | if crop_border != 0: 33 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 34 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 35 | 36 | if test_y_channel: 37 | img1 = to_y_channel(img1) 38 | img2 = to_y_channel(img2) 39 | 40 | mse = np.mean((img1 - img2) ** 2) 41 | if mse == 0: 42 | return float('inf') 43 | return 20. * np.log10(255. / np.sqrt(mse)) 44 | 45 | 46 | def _ssim(img1, img2): 47 | """Calculate SSIM (structural similarity) for one channel images. 48 | 49 | It is called by func:`calculate_ssim`. 50 | 51 | Args: 52 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 53 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 54 | 55 | Returns: 56 | float: ssim result. 57 | """ 58 | 59 | C1 = (0.01 * 255) ** 2 60 | C2 = (0.03 * 255) ** 2 61 | 62 | img1 = img1.astype(np.float64) 63 | img2 = img2.astype(np.float64) 64 | kernel = cv2.getGaussianKernel(11, 1.5) 65 | window = np.outer(kernel, kernel.transpose()) 66 | 67 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 68 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 69 | mu1_sq = mu1 ** 2 70 | mu2_sq = mu2 ** 2 71 | mu1_mu2 = mu1 * mu2 72 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 73 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 74 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 75 | 76 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 77 | return ssim_map.mean() 78 | 79 | 80 | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 81 | """Calculate SSIM (structural similarity). 
82 | 83 | Ref: 84 | Image quality assessment: From error visibility to structural similarity 85 | 86 | The results are the same as that of the official released MATLAB code in 87 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 88 | 89 | For three-channel images, SSIM is calculated for each channel and then 90 | averaged. 91 | 92 | Args: 93 | img1 (ndarray): Images with range [0, 255]. 94 | img2 (ndarray): Images with range [0, 255]. 95 | crop_border (int): Cropped pixels in each edge of an image. These 96 | pixels are not involved in the SSIM calculation. 97 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 98 | Default: 'HWC'. 99 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 100 | 101 | Returns: 102 | float: ssim result. 103 | """ 104 | 105 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 106 | if input_order not in ['HWC', 'CHW']: 107 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 108 | img1 = reorder_image(img1, input_order=input_order) 109 | img2 = reorder_image(img2, input_order=input_order) 110 | img1 = img1.astype(np.float64) 111 | img2 = img2.astype(np.float64) 112 | 113 | if crop_border != 0: 114 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 115 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 116 | 117 | if test_y_channel: 118 | img1 = to_y_channel(img1) 119 | img2 = to_y_channel(img2) 120 | 121 | ssims = [] 122 | for i in range(img1.shape[2]): 123 | ssims.append(_ssim(img1[..., i], img2[..., i])) 124 | return np.array(ssims).mean() 125 | 126 | 127 | def _blocking_effect_factor(im): 128 | block_size = 8 129 | 130 | block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8) 131 | block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8) 132 | 133 | horizontal_block_difference = ( 134 | (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum( 135 | 3).sum(2).sum(1) 136 | vertical_block_difference = ( 137 | (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum( 138 | 2).sum(1) 139 | 140 | nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions) 141 | nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions) 142 | 143 | horizontal_nonblock_difference = ( 144 | (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum( 145 | 3).sum(2).sum(1) 146 | vertical_nonblock_difference = ( 147 | (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum( 148 | 3).sum(2).sum(1) 149 | 150 | n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1) 151 | n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1) 152 | boundary_difference = (horizontal_block_difference + vertical_block_difference) / ( 153 | n_boundary_horiz + n_boundary_vert) 154 | 155 | n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz 156 | n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert 157 | nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / ( 158 | n_nonboundary_horiz + n_nonboundary_vert) 159 | 160 | scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]])) 161 | bef = scaler * (boundary_difference - nonboundary_difference) 162 | 163 | 
bef[boundary_difference <= nonboundary_difference] = 0 164 | return bef 165 | 166 | 167 | def calculate_psnrb(img1, img2, crop_border, input_order='HWC', test_y_channel=False): 168 | """Calculate PSNR-B (Peak Signal-to-Noise Ratio). 169 | 170 | Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation 171 | # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py 172 | 173 | Args: 174 | img1 (ndarray): Images with range [0, 255]. 175 | img2 (ndarray): Images with range [0, 255]. 176 | crop_border (int): Cropped pixels in each edge of an image. These 177 | pixels are not involved in the PSNR calculation. 178 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 179 | Default: 'HWC'. 180 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 181 | 182 | Returns: 183 | float: psnr result. 184 | """ 185 | 186 | assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 187 | if input_order not in ['HWC', 'CHW']: 188 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') 189 | img1 = reorder_image(img1, input_order=input_order) 190 | img2 = reorder_image(img2, input_order=input_order) 191 | img1 = img1.astype(np.float64) 192 | img2 = img2.astype(np.float64) 193 | 194 | if crop_border != 0: 195 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 196 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 197 | 198 | if test_y_channel: 199 | img1 = to_y_channel(img1) 200 | img2 = to_y_channel(img2) 201 | 202 | # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py 203 | img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255. 204 | img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255. 205 | 206 | total = 0 207 | for c in range(img1.shape[1]): 208 | mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none') 209 | bef = _blocking_effect_factor(img1[:, c:c + 1, :, :]) 210 | 211 | mse = mse.view(mse.shape[0], -1).mean(1) 212 | total += 10 * torch.log10(1 / (mse + bef)) 213 | 214 | return float(total) / img1.shape[1] 215 | 216 | 217 | def reorder_image(img, input_order='HWC'): 218 | """Reorder images to 'HWC' order. 219 | 220 | If the input_order is (h, w), return (h, w, 1); 221 | If the input_order is (c, h, w), return (h, w, c); 222 | If the input_order is (h, w, c), return as it is. 223 | 224 | Args: 225 | img (ndarray): Input image. 226 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 227 | If the input image shape is (h, w), input_order will not have 228 | effects. Default: 'HWC'. 229 | 230 | Returns: 231 | ndarray: reordered image. 232 | """ 233 | 234 | if input_order not in ['HWC', 'CHW']: 235 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") 236 | if len(img.shape) == 2: 237 | img = img[..., None] 238 | if input_order == 'CHW': 239 | img = img.transpose(1, 2, 0) 240 | return img 241 | 242 | 243 | def to_y_channel(img): 244 | """Change to Y channel of YCbCr. 245 | 246 | Args: 247 | img (ndarray): Images with range [0, 255]. 248 | 249 | Returns: 250 | (ndarray): Images with range [0, 255] (float type) without round. 251 | """ 252 | img = img.astype(np.float32) / 255. 253 | if img.ndim == 3 and img.shape[2] == 3: 254 | img = bgr2ycbcr(img, y_only=True) 255 | img = img[..., None] 256 | return img * 255. 
257 | 258 | 259 | def _convert_input_type_range(img): 260 | """Convert the type and range of the input image. 261 | 262 | It converts the input image to np.float32 type and range of [0, 1]. 263 | It is mainly used for pre-processing the input image in colorspace 264 | convertion functions such as rgb2ycbcr and ycbcr2rgb. 265 | 266 | Args: 267 | img (ndarray): The input image. It accepts: 268 | 1. np.uint8 type with range [0, 255]; 269 | 2. np.float32 type with range [0, 1]. 270 | 271 | Returns: 272 | (ndarray): The converted image with type of np.float32 and range of 273 | [0, 1]. 274 | """ 275 | img_type = img.dtype 276 | img = img.astype(np.float32) 277 | if img_type == np.float32: 278 | pass 279 | elif img_type == np.uint8: 280 | img /= 255. 281 | else: 282 | raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') 283 | return img 284 | 285 | 286 | def _convert_output_type_range(img, dst_type): 287 | """Convert the type and range of the image according to dst_type. 288 | 289 | It converts the image to desired type and range. If `dst_type` is np.uint8, 290 | images will be converted to np.uint8 type with range [0, 255]. If 291 | `dst_type` is np.float32, it converts the image to np.float32 type with 292 | range [0, 1]. 293 | It is mainly used for post-processing images in colorspace convertion 294 | functions such as rgb2ycbcr and ycbcr2rgb. 295 | 296 | Args: 297 | img (ndarray): The image to be converted with np.float32 type and 298 | range [0, 255]. 299 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 300 | converts the image to np.uint8 type with range [0, 255]. If 301 | dst_type is np.float32, it converts the image to np.float32 type 302 | with range [0, 1]. 303 | 304 | Returns: 305 | (ndarray): The converted image with desired type and range. 306 | """ 307 | if dst_type not in (np.uint8, np.float32): 308 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') 309 | if dst_type == np.uint8: 310 | img = img.round() 311 | else: 312 | img /= 255. 313 | return img.astype(dst_type) 314 | 315 | 316 | def bgr2ycbcr(img, y_only=False): 317 | """Convert a BGR image to YCbCr image. 318 | 319 | The bgr version of rgb2ycbcr. 320 | It implements the ITU-R BT.601 conversion for standard-definition 321 | television. See more details in 322 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 323 | 324 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. 325 | In OpenCV, it implements a JPEG conversion. See more details in 326 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 327 | 328 | Args: 329 | img (ndarray): The input image. It accepts: 330 | 1. np.uint8 type with range [0, 255]; 331 | 2. np.float32 type with range [0, 1]. 332 | y_only (bool): Whether to only return Y channel. Default: False. 333 | 334 | Returns: 335 | ndarray: The converted YCbCr image. The output image has the same type 336 | and range as input image. 
337 | """ 338 | img_type = img.dtype 339 | img = _convert_input_type_range(img) 340 | if y_only: 341 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 342 | else: 343 | out_img = np.matmul( 344 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] 345 | out_img = _convert_output_type_range(out_img, img_type) 346 | return out_img 347 | -------------------------------------------------------------------------------- /utils_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | import logging 4 | 5 | 6 | ''' 7 | # -------------------------------------------- 8 | # Kai Zhang (github: https://github.com/cszn) 9 | # 03/Mar/2019 10 | # -------------------------------------------- 11 | # https://github.com/xinntao/BasicSR 12 | # -------------------------------------------- 13 | ''' 14 | 15 | 16 | def log(*args, **kwargs): 17 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 18 | 19 | 20 | ''' 21 | # -------------------------------------------- 22 | # logger 23 | # -------------------------------------------- 24 | ''' 25 | 26 | 27 | def logger_info(logger_name, log_path='default_logger.log', mode='a'): 28 | ''' set up logger 29 | modified by Kai Zhang (github: https://github.com/cszn) 30 | ''' 31 | log = logging.getLogger(logger_name) 32 | if log.hasHandlers(): 33 | print('LogHandlers exist!') 34 | else: 35 | print('LogHandlers setup!') 36 | level = logging.INFO 37 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') 38 | fh = logging.FileHandler(log_path, mode=mode) 39 | fh.setFormatter(formatter) 40 | log.setLevel(level) 41 | log.addHandler(fh) 42 | # print(len(log.handlers)) 43 | 44 | sh = logging.StreamHandler() 45 | sh.setFormatter(formatter) 46 | log.addHandler(sh) 47 | 48 | 49 | ''' 50 | # -------------------------------------------- 51 | # print to file and std_out simultaneously 52 | # -------------------------------------------- 53 | ''' 54 | 55 | 56 | class logger_print(object): 57 | def __init__(self, log_path="default.log"): 58 | self.terminal = sys.stdout 59 | self.log = open(log_path, 'a') 60 | 61 | def write(self, message): 62 | self.terminal.write(message) 63 | self.log.write(message) # write the message 64 | 65 | def flush(self): 66 | pass 67 | --------------------------------------------------------------------------------