├── README.md
├── datasets.py
├── models.py
├── src
│   ├── TMM1.jpg
│   ├── feature_vis.jpg
│   ├── performance-1.jpg
│   ├── performance-2.jpg
│   └── task.jpg
├── test.py
├── test_resnet50.py
└── util.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Lightweight-Adaptive-Feature-De-drifting-for-Compressed-Image-Classification

[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2401.01724) [![Project](https://img.shields.io/badge/Project-Page-blue.svg)](https://arxiv.org/pdf/2411.10798)

> **Lightweight Adaptive Feature De-Drifting for Compressed Image Classification** <br>
> Long Peng<sup>1</sup>, Yang Cao<sup>1</sup>, Yuejin Sun<sup>1</sup>, Yang Wang<sup>1</sup> <br>
> <sup>1</sup> University of Science and Technology of China

## :bookmark: News!!!
- [x] 2023-12-1 **Accepted by IEEE Transactions on Multimedia.**

## Abstract

JPEG is a widely used compression scheme for efficiently reducing the volume of transmitted images, at the expense of a drop in visual quality. Artifacts appear across block boundaries due to the information loss in the compression process, which not only degrades image quality but also harms subsequent high-level tasks through feature drifting. High-level vision models trained on high-quality images suffer performance degradation when dealing with compressed images, especially on mobile devices. In recent years, numerous learning-based JPEG artifact removal methods have been proposed to handle visual artifacts. However, these methods are not an ideal pre-processing choice for compressed image classification, for two reasons: 1) they are designed for human vision rather than for high-level vision models; 2) they are not efficient enough to serve as pre-processing on resource-constrained devices. To address these issues, this paper proposes a novel lightweight Adaptive Feature De-drifting Module (AFD-Module) to boost the performance of pre-trained image classification models when facing compressed images. First, a Feature Drifting Estimation Network (FDE-Net) is devised to generate the spatial-wise Feature Drifting Map (FDM) in the DCT domain. Next, the estimated FDM is transmitted to the Feature Enhancement Network (FE-Net) to generate the mapping relationship between degraded features and the corresponding high-quality features. Specifically, a simple but effective RepConv block equipped with structural re-parameterization is utilized in FE-Net, which enriches feature representation in the training phase while keeping efficiency in the deployment phase. After training on limited compressed images, the AFD-Module can serve as a "plug-and-play" module for pre-trained classification models to improve their performance on compressed images. Experiments on images compressed once (i.e., ImageNet-C) and multiple times demonstrate that our proposed AFD-Module comprehensively improves the accuracy of pre-trained classification models and significantly outperforms existing methods.

## Motivation
![Motivation](src/task.jpg)
In real-world scenarios, JPEG compression not only discards fine details in the pixel space of an image but also causes significant degradation in the feature space, harming the content representation. As a result, many vision models that perform well on high-quality images fail or become far less effective under such conditions.

## Model
![TMM1](src/TMM1.jpg)

## Performance
![performance-1](src/performance-1.jpg)
![performance-2](src/performance-2.jpg)

## Feature Enhancement
![feature_vis](src/feature_vis.jpg)

## Data

We synthesized JPEG images under various scenarios for testing. The dataset includes images compressed with quality factors (QF) of **7**, **10**, **15**, **18**, and **25**. A lower QF indicates stronger compression and lower image quality.
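For reference, a given QF can be reproduced by re-encoding clean images with any standard JPEG encoder. Below is a minimal sketch using Pillow; the directory names are hypothetical placeholders, and the released dataset should be preferred for exact reproduction:

```python
from pathlib import Path
from PIL import Image

src, dst, qf = Path('clean_images'), Path('jpeg_qf7'), 7  # placeholder paths
dst.mkdir(exist_ok=True)
for p in src.glob('*.png'):
    # re-encode each clean image as JPEG at the chosen quality factor
    Image.open(p).convert('RGB').save(dst / (p.stem + '.jpg'), 'JPEG', quality=qf)
```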
You can download the dataset from the following Google Drive link:

[Google Drive Dataset Link](https://drive.google.com/drive/folders/1_Z96FMjqNCtATiYEbTFHTuKsQEqj-s4k?usp=drive_link)

### Test on ImageNet

#### ResNet50 Pretrained by PyTorch

To test the ResNet50 model pretrained by PyTorch, run the following command:

```bash
python3 test_resnet50.py
```
**Table 1: Performance of the PyTorch-pretrained ResNet50 on JPEG-compressed ImageNet**

| JPEG-QF | Top-1 Acc (%) $\uparrow$ |
| ------------------ | ---------------- |
| 7 | 33.124 |
| 10 | 47.216 |
| 15 | 57.313 |
| 18 | 60.404 |
| 25 | 63.721 |
| Without JPEG | 76.018 |

As the QF decreases, top-1 accuracy drops from 76% on clean (uncompressed) images to 33%, showing that JPEG compression severely degrades ResNet50's performance.

#### ResNet50 with the AFD-Module

Due to a pending patent application covering our method's design, we are releasing a baseline version of the model. Its performance is nevertheless nearly identical to the results reported in the paper, so it remains suitable for further research and practical use.

[Google Drive Model Link](https://drive.google.com/drive/folders/1_Z96FMjqNCtATiYEbTFHTuKsQEqj-s4k?usp=drive_link)

```bash
python3 test.py
```
**Table 2: Performance comparison between the plain ResNet50 and ResNet50 + AFD-Module (Ours)**

| JPEG-QF | Top-1 Acc (%) $\uparrow$ |
| ------------------ | ---------------- |
| 7 | 33.124 |
| 7 (**Ours**) | **60.459** |
| 10 | 47.216 |
| 10 (**Ours**) | **64.711** |
| 15 | 57.313 |
| 15 (**Ours**) | **67.919** |
| 18 | 60.404 |
| 18 (**Ours**) | **68.948** |
| 25 | 63.721 |
| 25 (**Ours**) | **70.171** |

## Cite Us
Contact Long Peng at longp2001@mail.ustc.edu.cn. Please cite us if this work is helpful to you.
```
@ARTICLE{10400436,
  author={Peng, Long and Cao, Yang and Sun, Yuejin and Wang, Yang},
  journal={IEEE Transactions on Multimedia},
  title={Lightweight Adaptive Feature De-Drifting for Compressed Image Classification},
  year={2024},
  volume={26},
  number={},
  pages={6424-6436},
  keywords={Image coding;Transform coding;Discrete cosine transforms;Feature extraction;Performance evaluation;Mobile handsets;Image recognition;Feature drifting;feature enhancement;image classification;JPEG compression},
  doi={10.1109/TMM.2024.3350917}}
```
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import os
import random

import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset
from PIL import Image
from pillow_heif import register_heif_opener

register_heif_opener()


class ImageDataset_DDP(Dataset):
    def __init__(self, input_root, gt_root, all=False, format='.png'):
        self.input_root = input_root
        self.gt_root = gt_root
        self.format = format
        if not all:
            # flat layout: every image sits directly under input_root
            self.img_names = [os.path.join(input_root, x) for x in os.listdir(self.input_root)]
        else:
            # nested layout: one sub-directory per class / degradation level
            self.img_names = []
            for sub_dir in os.listdir(self.input_root):
                for name in os.listdir(os.path.join(input_root, sub_dir)):
                    self.img_names.append(os.path.join(input_root, sub_dir, name))

        self.tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.crop_size = 224

    def __getitem__(self, index):
        img_name = self.img_names[index % len(self.img_names)]
        # the ground-truth file shares the stem before the first '_' in the input name
        gt_name = img_name.split('/')[-1].split('_')[0] + self.format

        img_input = Image.open(img_name)
        img_gt = Image.open(os.path.join(self.gt_root, gt_name))

        # identical random crop for the degraded input and its ground truth
        i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(self.crop_size, self.crop_size))
        img_input = TF.crop(img_input, i, j, h, w)
        img_gt = TF.crop(img_gt, i, j, h, w)

        # random rotation (applied to both) 90% of the time
        if random.random() > 0.1:
            angle = random.randint(-180, 180)
            img_input = TF.rotate(img_input, angle)
            img_gt = TF.rotate(img_gt, angle)

        return self.tf(img_input), self.tf(img_gt)

    def __len__(self):
        return len(self.img_names)
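

# Minimal usage sketch (hedged): the directory names below are hypothetical
# placeholders -- point them at a folder of JPEG-degraded images and the
# matching high-quality originals.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    dataset = ImageDataset_DDP(input_root='./Data/train_jpeg', gt_root='./Data/train_gt')
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
    degraded, clean = next(iter(loader))
    print(degraded.shape, clean.shape)  # expected: [4, 3, 224, 224] for both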
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from util import *


class dct(nn.Module):
    '''
    Block-wise 8x8 DCT implemented as fixed strided convolutions
    (adapted from a Zhihu post).
    '''
    def __init__(self):
        super(dct, self).__init__()

        # 3 x H x W -> 192 x H/8 x W/8 (64 frequencies per color plane)
        self.dct_conv = nn.Conv2d(3, 192, 8, 8, bias=False, groups=3)
        # 64 x 1 x 8 x 8 DCT basis, ordered in zig-zag (low to high frequency)
        self.weight = torch.from_numpy(np.load('models/DCTmtx.npy')).float().permute(2, 0, 1).unsqueeze(1)
        self.dct_conv.weight.data = torch.cat([self.weight] * 3, dim=0)  # 192 x 1 x 8 x 8
        self.dct_conv.weight.requires_grad = False

        self.mean = torch.Tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

        # fixed 1x1 convolutions for the RGB <-> YCbCr color transforms
        rgb2ycbcr = np.array([[0.299, 0.587, 0.114],
                              [-0.169, -0.331, 0.5],
                              [0.5, -0.419, -0.081]])
        self.Ycbcr = nn.Conv2d(3, 3, 1, 1, bias=False)
        self.Ycbcr.weight.data = torch.from_numpy(rgb2ycbcr).float().unsqueeze(2).unsqueeze(3)
        self.Ycbcr.weight.requires_grad = False

        self.reYcbcr = nn.Conv2d(3, 3, 1, 1, bias=False)
        re_matrix = np.linalg.pinv(rgb2ycbcr)
        self.reYcbcr.weight.data = torch.from_numpy(re_matrix).float().unsqueeze(2).unsqueeze(3)

    def forward(self, x):
        ycbcr = self.Ycbcr(x)  # b 3 h w
        dct = self.dct_conv(ycbcr)
        return dct

    def reverse(self, x):
        dct = F.conv_transpose2d(x, torch.cat([self.weight] * 3, 0), bias=None, stride=8, groups=3)
        rgb = self.reYcbcr(dct)
        return rgb


class RepConv(nn.Module):
    '''
    Re-parameterizable block: three parallel branches (1x1 -> 3x3, 3x3 -> 1x1,
    and a plain 1x1) are summed during training and can be folded into a single
    3x3 convolution for deployment.
    '''
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super(RepConv, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1))
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, padding=0))
        self.conv1_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        out = self.conv1(x) + self.conv3(x) + self.conv1_1(x)
        return out


def conv3x3(in_channal, out_channal):
    return nn.Conv2d(in_channal, out_channal, 3, 1, 1)
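

# Sketch (not part of the released code): the deploy-time folding that the
# structural re-parameterization relies on. For the simple case of a 3x3
# branch plus a 1x1 branch, the 1x1 kernel is zero-padded to 3x3 and the
# weights/biases are summed into one convolution. The sequential 1x1->3x3 and
# 3x3->1x1 branches fold the same way via a product of their weight tensors
# (with a small border correction when the inner conv carries a bias).
def _fold_3x3_plus_1x1_demo():
    c3 = nn.Conv2d(16, 16, 3, padding=1)
    c1 = nn.Conv2d(16, 16, 1)
    w = c3.weight + F.pad(c1.weight, [1, 1, 1, 1])  # place the 1x1 tap at the kernel center
    b = c3.bias + c1.bias
    x = torch.randn(2, 16, 32, 32)
    fused = F.conv2d(x, w, b, padding=1)
    assert torch.allclose(c3(x) + c1(x), fused, atol=1e-5)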
Conv_type = RepConv


class UNet_Dropout(nn.Module):
    def __init__(self, in_channal=64, channal=16):
        super(UNet_Dropout, self).__init__()

        self.n_channels = in_channal

        self.inc = Conv_type(in_channal, channal)

        self.down1 = Conv_type(channal, channal * 2)
        self.down2 = Conv_type(channal * 2, channal * 4)

        self.up1 = Conv_type(channal * 4, channal * 2)
        self.up2 = Conv_type(channal * 4, channal)

        self.outc = Conv_type(channal * 2, in_channal)

        self.relu = nn.PReLU()

        self.maxpool = nn.MaxPool2d(stride=2, kernel_size=2)
        self.uppool = nn.UpsamplingBilinear2d(scale_factor=2)
        self.dropout = nn.Dropout2d(p=0.5)

    def forward(self, x):
        x1 = self.relu(self.inc(x))                                       # 112
        d1 = self.maxpool(self.relu(self.down1(x1)))                      # 56
        d2 = self.maxpool(self.relu(self.down2(d1)))                      # 28

        x = self.uppool(self.relu(self.up1(d2)))                          # 56
        x = self.uppool(self.relu(self.up2(torch.cat([x, d1], dim=1))))   # 112
        x = self.dropout(x)
        x = self.outc(torch.cat([x, x1], dim=1))

        return x


class UNet(nn.Module):
    def __init__(self, in_channal=64, channal=16, out_cha=64):
        super(UNet, self).__init__()

        self.n_channels = in_channal

        self.inc = Conv_type(in_channal, channal)

        self.down1 = Conv_type(channal, channal * 2)
        self.down2 = Conv_type(channal * 2, channal * 4)

        self.up1 = Conv_type(channal * 4, channal * 2)
        self.up2 = Conv_type(channal * 4, channal)

        self.outc = Conv_type(channal * 2, out_cha)

        self.relu = nn.PReLU()

        self.maxpool = nn.MaxPool2d(stride=2, kernel_size=2)
        self.uppool = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        x1 = self.relu(self.inc(x))                                       # 112
        d1 = self.maxpool(self.relu(self.down1(x1)))                      # 56
        d2 = self.maxpool(self.relu(self.down2(d1)))                      # 28

        x = self.uppool(self.relu(self.up1(d2)))                          # 56
        x = self.uppool(self.relu(self.up2(torch.cat([x, d1], dim=1))))   # 112
        x = self.outc(torch.cat([x, x1], dim=1))

        return x


class DCNN(nn.Module):
    def __init__(self, in_channal=64, channal=32):
        super(DCNN, self).__init__()

        self.n_channels = in_channal
        self.c1 = nn.Conv2d(in_channal, in_channal, 3, 1, 1)
        self.c2 = nn.Conv2d(in_channal, in_channal, 3, 1, 1)
        self.c3 = nn.Conv2d(in_channal, in_channal, 3, 1, 1)
        self.c4 = nn.Conv2d(in_channal, in_channal, 3, 1, 1)

        self.relu = nn.PReLU()

    def forward(self, x):
        x = self.relu(self.c1(x))
        x = self.relu(self.c2(x))
        x = self.relu(self.c3(x))
        x = self.c4(x)
        return x


class FDM_dct_rep(nn.Module):
    '''
    AFD baseline module: DCT-domain guidance plus a UNet feature enhancer.
    Parameter count: ~119,800.
    '''
    def __init__(self):
        super(FDM_dct_rep, self).__init__()

        self.fdm = UNet(in_channal=64, channal=64, out_cha=64)

        self.reverse = InverShift()
        self.rgb2ycbcr = YcbcrShift()

        # fixed 8x8 DCT basis; adjust this path to the local copy of DCTmtx.npy
        self.dct_conv = nn.Conv2d(1, 64, 8, 8, bias=False)
        self.dct_conv.weight.data = torch.from_numpy(np.load('DCTmtx.npy')).float().permute(2, 0, 1).unsqueeze(1)
        self.dct_conv.weight.requires_grad = False

        # channel-wise guidance predicted from the DCT coefficients
        self.guide = nn.Sequential(
            nn.Conv2d(192, 128, 1, 1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 1, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(64, 64, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, jpg):
        '''
        :param x: feature map, 64 x 112 x 112 (receptive field 7x7)
        :param jpg: normalized JPEG image, 3 x 224 x 224
        :return: enhanced feature map, 64 x 112 x 112
        '''
        rejpg = self.reverse(jpg)      # undo ImageNet normalization -> 0-1 RGB
        ycbcr = self.rgb2ycbcr(rejpg)  # b 3 h w

        # per-plane 8x8 block DCT of Y, Cb, Cr
        dct_y = self.dct_conv(ycbcr[:, 0:1, :, :])
        dct_cb = self.dct_conv(ycbcr[:, 1:2, :, :])
        dct_cr = self.dct_conv(ycbcr[:, 2:3, :, :])
        dct = torch.cat([dct_y, dct_cb, dct_cr], dim=1)

        guide = self.guide(dct)  # channel-wise guidance from the DCT domain
        fd = self.fdm(x)         # estimated feature drift

        out = F.relu(fd * guide + x)

        return out
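

# Sketch (hedged): DCTmtx.npy itself is not in the repository. Assuming it
# stores the 64 8x8 DCT basis filters in (8, 8, 64) layout -- which is what the
# permute(2, 0, 1) calls above expect -- an equivalent file can be generated
# from the DCT helper in util.py:
#
#   from util import DCT
#   basis = DCT(N=8, in_channal=1).mk_coff(N=8, rearrange=True)  # (64, 8, 8)
#   np.save('DCTmtx.npy', basis.transpose(1, 2, 0))              # (8, 8, 64)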
class FDM_dct_rep_HIEF(nn.Module):
    '''
    Channel-configurable variant of FDM_dct_rep.
    Parameter count: ~119,800 (at input_ch=64).
    '''
    def __init__(self, input_ch):
        super(FDM_dct_rep_HIEF, self).__init__()

        self.fdm = UNet(in_channal=input_ch, channal=input_ch // 2, out_cha=input_ch)

        self.reverse = InverShift()
        self.rgb2ycbcr = YcbcrShift()

        # fixed 8x8 DCT basis; adjust this path to the local copy of DCTmtx.npy
        self.dct_conv = nn.Conv2d(1, 64, 8, 8, bias=False)
        self.dct_conv.weight.data = torch.from_numpy(np.load('DCTmtx.npy')).float().permute(2, 0, 1).unsqueeze(1)
        self.dct_conv.weight.requires_grad = False

        self.guide = nn.Sequential(
            nn.Conv2d(192, input_ch, 1, 1),
            nn.ReLU(),
            nn.Conv2d(input_ch, input_ch, 1, 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(input_ch, input_ch, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, jpg):
        '''
        :param x: feature map, input_ch x 112 x 112 (receptive field 7x7)
        :param jpg: normalized JPEG image, 3 x 224 x 224
        :return: enhanced feature map
        '''
        rejpg = self.reverse(jpg)
        ycbcr = self.rgb2ycbcr(rejpg)  # b 3 h w

        dct_y = self.dct_conv(ycbcr[:, 0:1, :, :])
        dct_cb = self.dct_conv(ycbcr[:, 1:2, :, :])
        dct_cr = self.dct_conv(ycbcr[:, 2:3, :, :])
        dct = torch.cat([dct_y, dct_cb, dct_cr], dim=1)

        guide = self.guide(dct)
        fd = self.fdm(x)

        out = F.relu(fd * guide + x)

        return out


class FDM_DDP(nn.Module):
    # parameter count: ~117,472
    def __init__(self):
        super(FDM_DDP, self).__init__()

        self.g1_1 = conv_relu(67, 64, 3, 1, 1)  # 67 = 64 feature channels + 3 channels of the resized JPEG image
        self.g1_2 = conv_relu(64, 64, 3, 1, 1)
        self.g2_1 = conv_relu(64, 32, 3, 1, 1)
        self.g2_2 = conv_relu(32, 32, 3, 1, 1)
        self.g3_1 = conv_relu(32, 16, 3, 1, 1)
        self.g3_2 = conv_relu(16, 16, 3, 1, 1)
        self.w = nn.Conv2d(112, 64, 1, 1, 0)  # 112 = 64 + 32 + 16 concatenated

    def forward(self, x, jpegs):
        jpegs = F.interpolate(jpegs, x.size()[2:])
        x1 = self.g1_1(torch.cat([x, jpegs], dim=1))
        x1 = self.g1_2(x1)

        x2 = self.g2_1(x1)
        x2 = self.g2_2(x2)

        x3 = self.g3_1(x2)
        x3 = self.g3_2(x3)

        out = F.relu(self.w(torch.cat([x1, x2, x3], dim=1)))

        return out


if __name__ == '__main__':
    # smoke test: parameter count and output shape of the AFD baseline
    jpegs = torch.randn(size=(1, 3, 224, 224))
    feature = torch.randn(size=(1, 64, 112, 112))
    net = FDM_dct_rep()

    total = sum(p.numel() for p in net.parameters())
    print(total)
    print(net(feature, jpegs).shape)
--------------------------------------------------------------------------------
/src/TMM1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/peylnog/FDNet/53d0a1c62492c9daa227cbde4154668662b49c0f/src/TMM1.jpg
--------------------------------------------------------------------------------
/src/feature_vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/peylnog/FDNet/53d0a1c62492c9daa227cbde4154668662b49c0f/src/feature_vis.jpg
--------------------------------------------------------------------------------
/src/performance-1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/peylnog/FDNet/53d0a1c62492c9daa227cbde4154668662b49c0f/src/performance-1.jpg
--------------------------------------------------------------------------------
/src/performance-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/peylnog/FDNet/53d0a1c62492c9daa227cbde4154668662b49c0f/src/performance-2.jpg
--------------------------------------------------------------------------------
/src/task.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/peylnog/FDNet/53d0a1c62492c9daa227cbde4154668662b49c0f/src/task.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from models import FDM_dct_rep as Model
from util import AverageMeter, accuracy


# ResNet50 backbone restored from the local checkpoint; the checkpoint was
# saved from a DataParallel model, hence the wrap/unwrap below.
resnet = torchvision.models.resnet50().eval().cuda()
resnet = nn.DataParallel(resnet)
resnet.load_state_dict(torch.load('./ckp/checkpoint.pth.tar')['state_dict'])
resnet.eval().cuda()
resnet = resnet.module


@torch.no_grad()
def get_feature_resnet50(img):
    # shallow stem features that the AFD-Module enhances
    x = resnet.conv1(img)
    x = resnet.bn1(x)
    x = resnet.relu(x)
    return x  # torch.Size([B, 64, 112, 112])


@torch.no_grad()
def inf_resnet(img):
    # plain ResNet50 forward pass (baseline)
    x = resnet.conv1(img)
    x = resnet.bn1(x)
    x = resnet.relu(x)
    x = resnet.maxpool(x)

    x = resnet.layer1(x)
    x = resnet.layer2(x)
    x = resnet.layer3(x)
    x = resnet.layer4(x)

    x = resnet.avgpool(x)
    x = torch.flatten(x, 1)
    x = resnet.fc(x)

    return x


@torch.no_grad()
def inf_resnet_FDM(img, model):
    # "plug-and-play": enhance the stem features with the AFD-Module,
    # then continue through the frozen ResNet50
    x = model(get_feature_resnet50(img), img)

    x = resnet.maxpool(x)

    x = resnet.layer1(x)
    x = resnet.layer2(x)
    x = resnet.layer3(x)
    x = resnet.layer4(x)

    x = resnet.avgpool(x)
    x = torch.flatten(x, 1)
    x = resnet.fc(x)

    return x


def validate(val_loader, model, batch_size=32):
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            # route through the AFD-Module; use inf_resnet(inputs) for the plain baseline
            outputs = inf_resnet_FDM(inputs, model)

            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

    return top1.avg, top5.avg
def val(root):
    print("Val")
    model = Model().cuda()
    # NOTE: load the released AFD-Module weights here (Google Drive link in the
    # README) before evaluating, e.g. via model.load_state_dict(...).
    model.eval()
    import torchvision.datasets as torchdatasets
    val_batch_size = 256

    with torch.no_grad():
        for path in root:
            print(path)
            val_datasets = torchdatasets.ImageFolder(
                path,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ]))
            val_loader = DataLoader(val_datasets,
                                    batch_size=val_batch_size,
                                    shuffle=False,
                                    pin_memory=True,
                                    num_workers=8)

            acc1, acc5 = validate(val_loader, model, batch_size=val_batch_size)
            print("degree:{} acc1:{} acc5:{}".format(path, acc1, acc5))


if __name__ == '__main__':
    val(['./Data/7'])
    val(['./Data/10'])
    val(['./Data/15'])
    val(['./Data/18'])
    val(['./Data/25'])
--------------------------------------------------------------------------------
/test_resnet50.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from util import AverageMeter, accuracy


# ImageNet-pretrained ResNet50 baseline (no AFD-Module)
resnet = torchvision.models.resnet50(pretrained=True).eval().cuda()
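# Note (hedged): `pretrained=True` is the legacy torchvision API; on
# torchvision >= 0.13 the equivalent call is
#   torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)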
def validate(val_loader, model, batch_size=32):
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = resnet(inputs)

            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

    return top1.avg, top5.avg


def val(root):
    print("Val")
    model = resnet
    model.eval()
    import torchvision.datasets as torchdatasets
    val_batch_size = 256

    with torch.no_grad():
        for path in root:
            print(path)
            val_datasets = torchdatasets.ImageFolder(
                path,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
                ]))
            val_loader = DataLoader(val_datasets,
                                    batch_size=val_batch_size,
                                    shuffle=False,
                                    pin_memory=True,
                                    num_workers=8)

            acc1, acc5 = validate(val_loader, model, batch_size=val_batch_size)
            print("degree:{} acc1:{} acc5:{}".format(path, acc1, acc5))


if __name__ == '__main__':
    val(['./Data/7'])
    val(['./Data/10'])
    val(['./Data/15'])
    val(['./Data/18'])
    val(['./Data/25'])

# Recorded results on JPEG-compressed ImageNet validation sets:
# degree:/ILSVRC2012_img_jpg/7 acc1:33.12460140306123 acc5:55.494260204081634
# degree:/ILSVRC2012_img_jpg/10 acc1:47.21619897959184 acc5:71.37276785714286
# degree:/ILSVRC2012_img_jpg/15 acc1:57.313456632653065 acc5:80.71428571428571
# degree:/ILSVRC2012_img_jpg/18 acc1:60.40497448979592 acc5:83.02853954081633
# degree:/ILSVRC2012_img_jpg/25 acc1:63.721699617346935 acc5:85.49346301020408
# degree:/ILSVRC2012_img_val acc1:76.01881377551021 acc5:92.83920599489795
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
-------------------------------------------------
# @Project  :dejpeg
# @File     :util
# @Date     :2021/2/22 9:52 PM
# @Author   :SYJ
# @Email    :JuZiSYJ@gmail.com
# @Software :PyCharm
-------------------------------------------------
"""
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt


def accuracy(output, target, topk=(1, )):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

        return res


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def show(self):
        return self.avg


def adjust_learning_rate(optimizer, lr):
    """Divides the given learning rate by 5 and applies it to every param group"""
    lr = lr / 5
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def conv_dw(inp, oup, kernel_size, stride, pad=0, bias=True):
    # depthwise convolution followed by a pointwise (1x1) projection
    return nn.Sequential(
        nn.Conv2d(inp, inp, kernel_size, stride, pad, bias=bias, groups=inp),
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
    )

def conv_relu(inp, oup, kernel_size, stride, pad=0, bias = True, act='relu'):
    if act == 'relu':
        return nn.Sequential(
            nn.Conv2d(inp,
oup, kernel_size, stride, pad, bias=True), 80 | # nn.BatchNorm2d(oup), 81 | nn.ReLU(inplace=True) 82 | ) 83 | elif act == 'prelu': 84 | return nn.Sequential( 85 | nn.Conv2d(inp, oup, kernel_size, stride, pad, bias=True), 86 | # nn.BatchNorm2d(oup), 87 | nn.PReLU(oup) 88 | ) 89 | 90 | 91 | 92 | class DWTForward(nn.Module): 93 | def __init__(self): 94 | super(DWTForward, self).__init__() 95 | ll = np.array([[0.5, 0.5], [0.5, 0.5]]) 96 | lh = np.array([[-0.5, -0.5], [0.5, 0.5]]) 97 | hl = np.array([[-0.5, 0.5], [-0.5, 0.5]]) 98 | hh = np.array([[0.5, -0.5], [-0.5, 0.5]]) 99 | filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1], 100 | hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]], 101 | axis=0) 102 | self.weight = nn.Parameter( 103 | torch.tensor(filts).to(torch.get_default_dtype()), 104 | requires_grad=False) 105 | 106 | def forward(self, x): 107 | C = x.shape[1] 108 | filters = torch.cat([self.weight, ] * C, dim=0) 109 | y = F.conv2d(x, filters, groups=C, stride=2) 110 | return y 111 | 112 | 113 | class DWTInverse(nn.Module): 114 | def __init__(self): 115 | super(DWTInverse, self).__init__() 116 | ll = np.array([[0.5, 0.5], [0.5, 0.5]]) 117 | lh = np.array([[-0.5, -0.5], [0.5, 0.5]]) 118 | hl = np.array([[-0.5, 0.5], [-0.5, 0.5]]) 119 | hh = np.array([[0.5, -0.5], [-0.5, 0.5]]) 120 | filts = np.stack([ll[None, ::-1, ::-1], lh[None, ::-1, ::-1], 121 | hl[None, ::-1, ::-1], hh[None, ::-1, ::-1]], 122 | axis=0) 123 | self.weight = nn.Parameter( 124 | torch.tensor(filts).to(torch.get_default_dtype()), 125 | requires_grad=False) 126 | 127 | def forward(self, x): 128 | C = int(x.shape[1] / 4) 129 | filters = torch.cat([self.weight, ] * C, dim=0) 130 | y = F.conv_transpose2d(x, filters, groups=C, stride=2) 131 | return y 132 | 133 | class DCT(nn.Module): 134 | def __init__(self, N = 8, in_channal = 3): 135 | super(DCT, self).__init__() 136 | 137 | self.N = N # default is 8 for JPEG 138 | self.fre_len = N * N 139 | self.in_channal = in_channal 140 | self.out_channal = N * N * in_channal 141 | # self.weight = torch.from_numpy(self.mk_coff(N = N)).float().unsqueeze(1) 142 | self.Ycbcr = nn.Conv2d(3, 3, 1, 1, bias=False) # can be moved 143 | trans_matrix = np.array([[0.299, 0.587, 0.114], 144 | [-0.169, -0.331, 0.5], 145 | [0.5, -0.419, -0.081]]) 146 | trans_matrix = torch.from_numpy(trans_matrix).float().unsqueeze( 147 | 2).unsqueeze(3) 148 | self.Ycbcr.weight.data = trans_matrix 149 | self.Ycbcr.weight.requires_grad = False 150 | 151 | 152 | # 3 H W -> N*N H/N W/N 153 | self.dct_conv = nn.Conv2d(self.in_channal, self.out_channal, N, N, bias=False, groups=self.in_channal) 154 | 155 | # 64 *1 * 8 * 8, from low frequency to high fre 156 | self.weight = torch.from_numpy(self.mk_coff(N = N, rearrange=True)).float().unsqueeze(1) 157 | # self.dct_conv = nn.Conv2d(1, self.out_channal, N, N, bias=False) 158 | self.dct_conv.weight.data = torch.cat([self.weight]*self.in_channal, dim=0) # 64 1 8 8 159 | self.dct_conv.weight.requires_grad = False 160 | 161 | 162 | 163 | # self.reDCT = nn.ConvTranspose2d(self.out_channal, 1, self.N, self.N, bias = False) 164 | # self.reDCT.weight.data = self.weight 165 | 166 | 167 | 168 | 169 | 170 | 171 | def forward(self, x): 172 | # jpg = (jpg * self.std) + self.mean # 0-1 173 | ''' 174 | x: B C H W, 0-1. 
RGB 175 | YCbCr: b c h w, YCBCR 176 | DCT: B C*64 H//8 W//8 , Y_L..Y_H Cb_L...Cb_H Cr_l...Cr_H 177 | 178 | ''' 179 | # x = self.Ycbcr(x) # b 3 h w 180 | dct = self.dct_conv(x) 181 | return dct 182 | 183 | def mk_coff(self, N = 8, rearrange = True): 184 | dct_weight = np.zeros((N*N, N, N)) 185 | for k in range(N*N): 186 | u = k // N 187 | v = k % N 188 | for i in range(N): 189 | for j in range(N): 190 | tmp1 = self.get_1d(i, u, N=N) 191 | tmp2 = self.get_1d(j, v, N=N) 192 | tmp = tmp1 * tmp2 193 | tmp = tmp * self.get_c(u, N=N) * self.get_c(v, N=N) 194 | 195 | dct_weight[k, i, j] += tmp 196 | if rearrange: 197 | dct_weight = self.get_order(dct_weight, N = N) # from low frequency to high frequency 198 | return dct_weight # (N*N) * N * N 199 | 200 | def get_1d(self, ij, uv, N=8): 201 | result = math.cos(math.pi * uv * (ij + 0.5) / N) 202 | return result 203 | 204 | def get_c(self, u, N=8): 205 | if u == 0: 206 | return math.sqrt(1 / N) 207 | else: 208 | return math.sqrt(2 / N) 209 | 210 | def get_order(self, src_weight, N = 8): 211 | array_size = N * N 212 | # order_index = np.zeros((N, N)) 213 | i = 0 214 | j = 0 215 | rearrange_weigth = src_weight.copy() # (N*N) * N * N 216 | for k in range(array_size - 1): 217 | if (i == 0 or i == N-1) and j % 2 == 0: 218 | j += 1 219 | elif (j == 0 or j == N-1) and i % 2 == 1: 220 | i += 1 221 | elif (i + j) % 2 == 1: 222 | i += 1 223 | j -= 1 224 | elif (i + j) % 2 == 0: 225 | i -= 1 226 | j += 1 227 | index = i * N + j 228 | rearrange_weigth[k+1, ...] = src_weight[index, ...] 229 | return rearrange_weigth 230 | 231 | class ReDCT(nn.Module): 232 | def __init__(self, N = 4, in_channal = 3): 233 | super(ReDCT, self).__init__() 234 | 235 | self.N = N # default is 8 for JPEG 236 | self.in_channal = in_channal * N * N 237 | self.out_channal = in_channal 238 | self.fre_len = N * N 239 | 240 | self.weight = torch.from_numpy(self.mk_coff(N=N)).float().unsqueeze(1) 241 | 242 | 243 | self.reDCT = nn.ConvTranspose2d(self.in_channal, self.out_channal, self.N, self.N, bias = False, groups=self.out_channal) 244 | self.reDCT.weight.data = torch.cat([self.weight]*self.out_channal, dim=0) 245 | self.reDCT.weight.requires_grad = False 246 | 247 | 248 | def forward(self, dct): 249 | ''' 250 | IDCT from DCT domain to pixle domain 251 | B C*64 H//8 W//8 -> B C H W 252 | ''' 253 | out = self.reDCT(dct) 254 | return out 255 | 256 | def mk_coff(self, N = 8, rearrange = True): 257 | dct_weight = np.zeros((N*N, N, N)) 258 | for k in range(N*N): 259 | u = k // N 260 | v = k % N 261 | for i in range(N): 262 | for j in range(N): 263 | tmp1 = self.get_1d(i, u, N=N) 264 | tmp2 = self.get_1d(j, v, N=N) 265 | tmp = tmp1 * tmp2 266 | tmp = tmp * self.get_c(u, N=N) * self.get_c(v, N=N) 267 | 268 | dct_weight[k, i, j] += tmp 269 | if rearrange: 270 | out_weight = self.get_order(dct_weight, N = N) # from low frequency to high frequency 271 | return out_weight # (N*N) * N * N 272 | 273 | def get_1d(self, ij, uv, N=8): 274 | result = math.cos(math.pi * uv * (ij + 0.5) / N) 275 | return result 276 | 277 | def get_c(self, u, N=8): 278 | if u == 0: 279 | return math.sqrt(1 / N) 280 | else: 281 | return math.sqrt(2 / N) 282 | 283 | def get_order(self, src_weight, N = 8): 284 | array_size = N * N 285 | # order_index = np.zeros((N, N)) 286 | i = 0 287 | j = 0 288 | rearrange_weigth = src_weight.copy() # (N*N) * N * N 289 | for k in range(array_size - 1): 290 | if (i == 0 or i == N-1) and j % 2 == 0: 291 | j += 1 292 | elif (j == 0 or j == N-1) and i % 2 == 1: 293 | i += 1 294 | elif (i + j) % 2 == 
1:
                i += 1
                j -= 1
            elif (i + j) % 2 == 0:
                i -= 1
                j += 1
            index = i * N + j
            rearrange_weigth[k+1, ...] = src_weight[index, ...]
        return rearrange_weigth

class CALayer(nn.Module):
    def __init__(self, channel=64, reduction=16):
        super(CALayer, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Tanh()  # gating activation (Tanh here rather than the more common Sigmoid)
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


class _NonLocalBlockND(nn.Module):
    """
    Non-local block. Typical usage:
        NONLocalBlock2D(in_channels=32), which forwards to
        _NonLocalBlockND(in_channels, inter_channels=inter_channels,
                         dimension=2, sub_sample=sub_sample, bn_layer=bn_layer)
    """

    def __init__(self,
                 in_channels,
                 inter_channels=None,
                 dimension=2,
                 sub_sample=True,
                 bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            # compress to get the reduced channel count
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels,
                         out_channels=self.inter_channels,
                         kernel_size=1,
                         stride=1,
                         padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels,
                        out_channels=self.in_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0), bn(self.in_channels))
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels,
                             out_channels=self.in_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels,
                             out_channels=self.inter_channels,
                             kernel_size=1,
                             stride=1,
                             padding=0)
        self.phi = conv_nd(in_channels=self.in_channels,
                           out_channels=self.inter_channels,
                           kernel_size=1,
                           stride=1,
                           padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, h, w)
        :return: (b, c, h, w)
        '''
        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # [bs, c, w*h]
        g_x = g_x.permute(0, 2, 1)  # b (h*w) c

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)  # b (h*w) c, query

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)  # [bs, c, w*h], key
        f = torch.matmul(theta_x, phi_x)  # pairwise similarity

        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x  # residual connection
        return z


class dwt_rep(nn.Module):
    def __init__(self, in_channal=64, residual=False):
        super(dwt_rep, self).__init__()

        self.in_channal = in_channal
        self.manuan_out = in_channal * 8
        # eight fixed 3x3 pattern filters (box, horizontal, vertical, diagonals,
        # cross/edge detectors), applied depthwise to every input channel
        self.kernel_weigth = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                                           [[0, 0, 0], [1, 1, 1], [0, 0, 0]],
                                           [[0, 1, 0], [0, 1, 0], [0, 1, 0]],
                                           [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
                                           [[0, 0, 1], [0, 1, 0], [1, 0, 0]],
                                           [[1, 0, 1], [0, -1, 0], [1, 0, 1]],
                                           [[0, 1, 0], [-1, 1, -1], [0, 1, 0]],
                                           [[1, 0, 1], [0, 1, 0], [1, 0, 1]]]).unsqueeze(1).float()

        self.kernel_weigth = torch.cat([self.kernel_weigth] * in_channal, dim=0)

        self.manual_conv = nn.Conv2d(in_channels=in_channal, out_channels=self.manuan_out,
                                     kernel_size=3, padding=1, groups=in_channal)
        self.manual_conv.weight.data = self.kernel_weigth
        self.manual_conv.weight.requires_grad = False

        # one learnable 1x1 projection per fixed pattern
        self.conv_list = nn.ModuleList()
        for i in range(8):
            self.conv_list.append(nn.Conv2d(in_channels=in_channal, out_channels=in_channal,
                                            kernel_size=1, padding=0))

        self.relu = nn.PReLU()

    def forward(self, x):
        tmp = self.manual_conv(x)
        b, c, h, w = tmp.shape
        tmp = tmp.reshape(b, c // 8, 8, h, w)
        out = []

        for i in range(8):
            out.append(self.conv_list[i](tmp[:, :, i, :, :]))

        tmp = torch.stack(out, dim=0)
        tmp = torch.sum(tmp, dim=0, keepdim=False)

        out = self.relu(tmp + x)

        return out


class InverShift(nn.Conv2d):
    '''
    Fixed 1x1 convolution that inverts torchvision's Normalize transform:
    x_restored = x * std + mean.
    '''
    def __init__(
            self, rgb_range=1.0,
            rgb_mean=(0.485, 0.456, 0.406), rgb_std=(0.229, 0.224, 0.225)):

        super(InverShift, self).__init__(3, 3, kernel_size=1, bias=True)
        std = torch.Tensor(rgb_std)
        mean = torch.Tensor(rgb_mean)

        self.weight.data = torch.eye(3).view(3, 3, 1, 1) * std.view(3, 1, 1, 1)
        self.bias.data = mean.view(3,)

        for p in self.parameters():
            p.requires_grad = False


class YcbcrShift(nn.Conv2d):
    '''
    Fixed 1x1 convolution implementing the RGB -> YCbCr color transform.
    '''
    def __init__(self):
        super(YcbcrShift, self).__init__(3, 3, kernel_size=1)

        trans_matrix = torch.tensor([[0.299, 0.587, 0.114],
                                     [-0.169, -0.331, 0.5],
                                     [0.5, -0.419, -0.081]])

        self.weight.data = trans_matrix.float().unsqueeze(2).unsqueeze(3)
        self.bias.data.zero_()  # no offset; parameter kept for checkpoint compatibility

        for p in self.parameters():
            p.requires_grad = False
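

# Quick sanity sketch (not in the original file): InverShift should undo the
# ImageNet normalization used throughout this repo.
def _inver_shift_demo():
    x = torch.rand(1, 3, 8, 8)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    normalized = (x - mean) / std
    assert torch.allclose(InverShift()(normalized), x, atol=1e-6)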
def get_coffe_dct_n(N=3):
    '''
    Build the NxN DCT basis.
    :param N: block size
    :return: (N*N) x N x N numpy array
    '''
    def get_1d(ij, uv, N=8):
        return math.cos(math.pi * uv * (ij + 0.5) / N)

    def get_c(u, N=8):
        if u == 0:
            return math.sqrt(1 / N)
        else:
            return math.sqrt(2 / N)

    def mk_coff(N=8):
        dct_weight = np.zeros((N * N, N, N))
        for k in range(N * N):
            u = k // N
            v = k % N
            for i in range(N):
                for j in range(N):
                    tmp = get_1d(i, u, N=N) * get_1d(j, v, N=N)
                    tmp = tmp * get_c(u, N=N) * get_c(v, N=N)
                    dct_weight[k, i, j] += tmp
        return dct_weight  # (N*N) * N * N

    return mk_coff(N=N)


if __name__ == '__main__':
    # visualize the 3x3 DCT basis filters, ordered from low to high frequency
    model = DCT(N=3, in_channal=3)

    for i in range(3):
        for j in range(3):
            plt.subplot(3, 3, i * 3 + j + 1)
            plt.imshow(model.weight[i * 3 + j, 0, ...], cmap='gray')
    plt.show()
--------------------------------------------------------------------------------