├── .circleci
│   └── config.yml
├── .github
│   ├── CODE_OF_CONDUCT.md
│   ├── CONTRIBUTING.md
│   └── deit.png
├── .gitignore
├── README.md
├── finetune.sh
├── ipt.py
├── main.py
├── pretrain.sh
├── requirements.txt
└── test.sh
/.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | jobs: 4 | python_lint: 5 | docker: 6 | - image: circleci/python:3.7 7 | steps: 8 | - checkout 9 | - run: 10 | command: | 11 | pip install --user --progress-bar off flake8 typing 12 | flake8 . 13 | test: 14 | docker: 15 | - image: circleci/python:3.7 16 | steps: 17 | - checkout 18 | - run: 19 | command: | 20 | pip install --user --progress-bar off pytest 21 | pip install --user --progress-bar off torch torchvision 22 | pip install --user --progress-bar off timm==0.3.2 23 | pytest . 24 | 25 | workflows: 26 | build: 27 | jobs: 28 | - python_lint 29 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DeiT 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 4 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/) 36 | 37 | ## License 38 | By contributing to DeiT, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree.
40 | -------------------------------------------------------------------------------- /.github/deit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perseveranceLX/ImageProcessingTransformer/bc54c184652af1fe7b965d63fc8a9b521e5808c6/.github/deit.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | **/__pycache__/** 3 | imnet_resnet50_scratch/timm_temp/ 4 | .dumbo.json 5 | checkpoints/ 6 | ckpt/ 7 | *.log 8 | logs/ 9 | __pycache__/ 10 | dataset/__pycache__/ 11 | test_data 12 | .vscode/ 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ImageProcessingTransformer 2 | Third-party PyTorch implementation of the Image Processing Transformer ("Pre-Trained Image Processing Transformer", arXiv:2012.00364v2) 3 | 4 | The latest version contains some important modifications based on the official MindSpore implementation, which make convergence a lot faster. Please make sure you update to the latest version. 5 | 6 | This repository currently contains only the model definition and the train/test code. The dataloader file is not yet released; you can implement your own (a minimal sketch is appended at the end of this listing). It may be released in the next version. 7 | 8 | To pretrain with a random task per iteration: 9 | 10 | python main.py --seed 0 \ 11 | --lr 5e-5 \ 12 | --save-path "./ckpt" \ 13 | --epochs 300 \ 14 | --data path-to-data \ 15 | --batch-size 256 16 | 17 | To finetune on a specific task: 18 | 19 | python main.py --seed 0 \ 20 | --lr 2e-5 \ 21 | --save-path "./ckpt" \ 22 | --epochs 30 \ 23 | --reset-epoch \ 24 | --data path-to-data \ 25 | --batch-size 256 \ 26 | --resume path-to-pretrain-model \ 27 | --task "dehaze" 28 | 29 | To evaluate on a specific task: 30 | 31 | python main.py --seed 0 \ 32 | --eval-data path-to-val-data \ 33 | --batch-size 256 \ 34 | --eval \ 35 | --resume path-to-model \ 36 | --task "dehaze" 37 | -------------------------------------------------------------------------------- /finetune.sh: -------------------------------------------------------------------------------- 1 | python main.py --seed 0 \ 2 | --lr 2e-5 \ 3 | --save-path "./ckpt" \ 4 | --epochs 30 \ 5 | --data path-to-data \ 6 | --batch-size 256 \ 7 | --resume path-to-checkpoint \ 8 | --task "dehaze" 9 | # finetune task options: "denoise30", "denoise50", "SRx2", "SRx3", "SRx4", "dehaze" -------------------------------------------------------------------------------- /ipt.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implementation of Vision Transformers as described in 4 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 5 | 6 | The official jax code is released and available at https://github.com/google-research/vision_transformer 7 | 8 | Status/TODO: 9 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 10 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 11 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 12 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
13 | 14 | Acknowledgments: 15 | * The paper authors for releasing code and weights, thanks! 16 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 17 | for some einops/einsum fun 18 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 19 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 20 | 21 | Hacked together by / Copyright 2020 Ross Wightman 22 | """ 23 | import torch 24 | from torch.functional import Tensor 25 | import torch.nn as nn 26 | from functools import partial 27 | import math 28 | import warnings 29 | 30 | 31 | 32 | 33 | class Ffn(nn.Module): 34 | # feed forward network layer after attention 35 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.): 36 | super().__init__() 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = nn.Linear(in_features, hidden_features) 40 | self.act = act_layer(inplace=True) 41 | self.fc2 = nn.Linear(hidden_features, out_features) 42 | self.drop = nn.Dropout(drop) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop(x) 48 | x = self.fc2(x) 49 | x = self.drop(x) 50 | return x 51 | 52 | class Attention(nn.Module): 53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 54 | super().__init__() 55 | self.num_heads = num_heads 56 | head_dim = dim // num_heads 57 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 58 | self.scale = qk_scale or head_dim ** -0.5 59 | 60 | self.query = nn.Linear(dim, dim, bias=qkv_bias) 61 | self.key = nn.Linear(dim, dim, bias=qkv_bias) 62 | self.value = nn.Linear(dim, dim, bias=qkv_bias) 63 | self.attn_drop = nn.Dropout(attn_drop) 64 | self.proj = nn.Linear(dim, dim) 65 | self.proj_drop = nn.Dropout(proj_drop) 66 | 67 | def forward(self, q, k, v): 68 | N, L, D = q.shape 69 | q, k, v = self.query(q), self.key(k), self.value(v) 70 | q = q.reshape(N, L, self.num_heads, D // self.num_heads).permute(0, 2, 1, 3) 71 | k = k.reshape(N, L, self.num_heads, D // self.num_heads).permute(0, 2, 1, 3) 72 | v = v.reshape(N, L, self.num_heads, D // self.num_heads).permute(0, 2, 1, 3) 73 | 74 | attn = (q @ k.transpose(-2, -1)) * self.scale 75 | attn = attn.softmax(dim=-1) 76 | attn = self.attn_drop(attn) 77 | 78 | x = (attn @ v).transpose(1, 2).reshape(N, L, D) 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | 83 | 84 | class EncoderLayer(nn.Module): 85 | 86 | def __init__(self, dim, num_heads, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 87 | act_layer=nn.ReLU, norm_layer=nn.LayerNorm): 88 | super().__init__() 89 | self.norm1 = norm_layer(dim) 90 | self.attn = Attention( 91 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 92 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 93 | # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 94 | self.norm2 = norm_layer(dim) 95 | ffn_hidden_dim = int(dim * ffn_ratio) 96 | self.ffn = Ffn(in_features=dim, hidden_features=ffn_hidden_dim, act_layer=act_layer, drop=drop) 97 | 98 | def forward(self, x, pos): 99 | x = self.norm1(x) 100 | q, k, v = x + pos, x + pos, x 101 | x = x + self.attn(q, k, v) 102 | x = x + self.ffn(self.norm2(x)) 103 | return x 104 | 105 | class DecoderLayer(nn.Module): 106 | 107 | def __init__(self, dim, num_heads, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 108 | act_layer=nn.ReLU, norm_layer=nn.LayerNorm): 109 | super().__init__() 110 | self.norm1 = norm_layer(dim) 111 | self.attn1 = Attention( 112 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 113 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 114 | # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 115 | self.norm2 = norm_layer(dim) 116 | self.attn2 = Attention( 117 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 118 | self.norm3 = norm_layer(dim) 119 | ffn_hidden_dim = int(dim * ffn_ratio) 120 | self.ffn = Ffn(in_features=dim, hidden_features=ffn_hidden_dim, act_layer=act_layer, drop=drop) 121 | 122 | def forward(self, x, pos, task_embed): 123 | memory = x 124 | x = self.norm1(x) 125 | q, k, v = x + task_embed, x + task_embed, x 126 | x = x + self.attn1(q, k, v) 127 | x = self.norm2(x) 128 | q, k, v = x + task_embed, memory + pos, memory 129 | x = x + self.attn2(q, k, v) 130 | x = x + self.ffn(self.norm3(x)) 131 | return x 132 | 133 | 134 | class ResBlock(nn.Module): 135 | 136 | def __init__(self, channels): 137 | super(ResBlock, self).__init__() 138 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, 139 | padding=2, bias=False) 140 | # self.bn1 = nn.BatchNorm2d(channels) 141 | self.relu = nn.ReLU(inplace=True) 142 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, 143 | padding=2, bias=False) 144 | # self.bn2 = nn.BatchNorm2d(channels) 145 | 146 | def forward(self, x): 147 | residual = x 148 | 149 | out = self.conv1(x) 150 | # out = self.bn1(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv2(out) 154 | # out = self.bn2(out) 155 | 156 | out += residual 157 | # out = self.relu(out) 158 | 159 | return out 160 | 161 | class Head(nn.Module): 162 | """ Head consisting of convolution layers 163 | Extract features from corrupted images, mapping N3HW images into NCHW feature map. 
164 | """ 165 | def __init__(self, in_channels, out_channels): 166 | super(Head, self).__init__() 167 | self.conv1= nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, 168 | padding=1, bias=False) 169 | # self.bn1 = nn.BatchNorm2d(out_channels) if task_id in [0, 1, 5] else nn.Identity() 170 | # self.relu = nn.ReLU(inplace=True) 171 | self.resblock1 = ResBlock(out_channels) 172 | self.resblock2 = ResBlock(out_channels) 173 | 174 | def forward(self, x): 175 | out = self.conv1(x) 176 | # out = self.bn1(out) 177 | # out = self.relu(out) 178 | 179 | out = self.resblock1(out) 180 | out = self.resblock2(out) 181 | 182 | return out 183 | 184 | class PatchEmbed(nn.Module): 185 | """ Feature to Patch Embedding 186 | input : N C H W 187 | output: N num_patch P^2*C 188 | """ 189 | def __init__(self, patch_size=1, in_channels=64): 190 | super().__init__() 191 | self.patch_size = patch_size 192 | self.dim = self.patch_size ** 2 * in_channels 193 | 194 | def forward(self, x): 195 | N, C, H, W = ori_shape = x.shape 196 | 197 | p = self.patch_size 198 | num_patches = (H // p) * (W // p) 199 | out = torch.zeros((N, num_patches, self.dim)).to(x.device) 200 | #print(f"feature map size: {ori_shape}, embedding size: {out.shape}") 201 | i, j = 0, 0 202 | for k in range(num_patches): 203 | if i + p > W: 204 | i = 0 205 | j += p 206 | out[:, k, :] = x[:, :, i:i+p, j:j+p].flatten(1) 207 | i += p 208 | return out, ori_shape 209 | 210 | class DePatchEmbed(nn.Module): 211 | """ Patch Embedding to Feature 212 | input : N num_patch P^2*C 213 | output: N C H W 214 | """ 215 | def __init__(self, patch_size=1, in_channels=64): 216 | super().__init__() 217 | self.patch_size = patch_size 218 | self.num_patches = None 219 | self.dim = self.patch_size ** 2 * in_channels 220 | 221 | def forward(self, x, ori_shape): 222 | N, num_patches, dim = x.shape 223 | _, C, H, W = ori_shape 224 | p = self.patch_size 225 | out = torch.zeros(ori_shape).to(x.device) 226 | i, j = 0, 0 227 | for k in range(num_patches): 228 | if i + p > W: 229 | i = 0 230 | j += p 231 | out[:, :, i:i+p, j:j+p] = x[:, k, :].reshape(N, C, p, p) 232 | #out[:, k, :] = x[:, :, i:i+p, j:j+p].flatten(1) 233 | i += p 234 | return out 235 | 236 | 237 | class Tail(nn.Module): 238 | """ Tail consisting of convolution layers and pixel shuffle layers 239 | NCHW -> N3HW. 
240 | """ 241 | def __init__(self, task_id, in_channels, out_channels): 242 | super(Tail, self).__init__() 243 | assert 0 <= task_id <= 5 244 | # 0, 1 for noise 30, 50; 2, 3, 4 for sr x2, x3, x4, 5 for defog 245 | upscale_map = [1, 1, 2, 3, 4, 1] 246 | scale = upscale_map[task_id] 247 | m = [] 248 | # for SR task 249 | if scale > 1: 250 | m.append(nn.Conv2d(in_channels, in_channels * scale * scale, kernel_size=3, stride=1, 251 | padding=1, bias=False)) 252 | if (scale & (scale - 1)) == 0: 253 | for _ in range(int(math.log(scale, 2))): 254 | m.append(nn.PixelShuffle(2)) 255 | elif scale == 3: 256 | m.append(nn.PixelShuffle(3)) 257 | else: 258 | raise NameError("Only support x3 and x2^n SR") 259 | 260 | m.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, 261 | padding=1, bias=False)) 262 | self.m = nn.Sequential(*m) 263 | 264 | def forward(self, x): 265 | out = self.m(x) 266 | #print("task_id:", self.task_id) 267 | #print("shape of tail's output:", x.shape) 268 | # out = self.bn1(out) 269 | return out 270 | 271 | class ImageProcessingTransformer(nn.Module): 272 | """ Vision Transformer with support for patch or hybrid CNN input stage 273 | """ 274 | def __init__(self, patch_size=1, in_channels=3, mid_channels=64, num_classes=1000, depth=12, 275 | num_heads=8, ffn_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 276 | norm_layer=nn.LayerNorm): 277 | super(ImageProcessingTransformer, self).__init__() 278 | 279 | self.task_id = None 280 | self.num_classes = num_classes 281 | self.embed_dim = patch_size * patch_size * mid_channels 282 | self.headsets = nn.ModuleList([Head(in_channels, mid_channels) for _ in range(6)]) 283 | self.patch_embedding = PatchEmbed(patch_size=patch_size, in_channels=mid_channels) 284 | self.embed_dim = self.patch_embedding.dim 285 | if self.embed_dim % num_heads != 0: 286 | raise RuntimeError("Embedding dim must be devided by numbers of heads") 287 | 288 | self.pos_embed = nn.Parameter(torch.zeros(1, (48 // patch_size) ** 2, self.embed_dim)) 289 | self.task_embed = nn.Parameter(torch.zeros(6, 1, (48 // patch_size) ** 2, self.embed_dim)) 290 | self.encoder = nn.ModuleList([ 291 | EncoderLayer( 292 | dim=self.embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 293 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 294 | for _ in range(depth)]) 295 | self.decoder = nn.ModuleList([ 296 | DecoderLayer( 297 | dim=self.embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 298 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer) 299 | for _ in range(depth)]) 300 | #self.norm = norm_layer(self.embed_dim) 301 | 302 | self.de_patch_embedding = DePatchEmbed(patch_size=patch_size, in_channels=mid_channels) 303 | # tail 304 | self.tailsets = nn.ModuleList([Tail(id, mid_channels, in_channels) for id in range(6)]) 305 | 306 | trunc_normal_(self.pos_embed, std=.02) 307 | self.apply(self._init_weights) 308 | 309 | def set_task(self, task_id): 310 | self.task_id = task_id 311 | 312 | def _init_weights(self, m): 313 | if isinstance(m, nn.Linear): 314 | trunc_normal_(m.weight, std=.02) 315 | if isinstance(m, nn.Linear) and m.bias is not None: 316 | nn.init.constant_(m.bias, 0) 317 | elif isinstance(m, nn.LayerNorm): 318 | nn.init.constant_(m.bias, 0) 319 | nn.init.constant_(m.weight, 1.0) 320 | 321 | def forward(self, x): 322 | assert 0 <= self.task_id <= 5 323 | # print("input shape:", x.shape, x.device) 324 | x = 
self.headsets[self.task_id](x) 325 | x, ori_shape = self.patch_embedding(x) 326 | # print("embedding shape:", x.shape) 327 | # print(x.device, self.pos_embed.device) 328 | for blk in self.encoder: 329 | x = blk(x, self.pos_embed[:, :x.shape[1]]) 330 | for blk in self.decoder: 331 | x = blk(x, self.pos_embed[:, :x.shape[1]], self.task_embed[self.task_id, :, :x.shape[1]]) 332 | x = self.de_patch_embedding(x, ori_shape) 333 | x = self.tailsets[self.task_id](x) 334 | #x = self.norm(x) 335 | return x 336 | 337 | 338 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 339 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 340 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 341 | def norm_cdf(x): 342 | # Computes standard normal cumulative distribution function 343 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 344 | 345 | if (mean < a - 2 * std) or (mean > b + 2 * std): 346 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 347 | "The distribution of values may be incorrect.", 348 | stacklevel=2) 349 | 350 | with torch.no_grad(): 351 | # Values are generated by using a truncated uniform distribution and 352 | # then using the inverse CDF for the normal distribution. 353 | # Get upper and lower cdf values 354 | l = norm_cdf((a - mean) / std) 355 | u = norm_cdf((b - mean) / std) 356 | 357 | # Uniformly fill tensor with values from [l, u], then translate to 358 | # [2l-1, 2u-1]. 359 | tensor.uniform_(2 * l - 1, 2 * u - 1) 360 | 361 | # Use inverse cdf transform for normal distribution to get truncated 362 | # standard normal 363 | tensor.erfinv_() 364 | 365 | # Transform to proper mean, std 366 | tensor.mul_(std * math.sqrt(2.)) 367 | tensor.add_(mean) 368 | 369 | # Clamp to ensure it's in the proper range 370 | tensor.clamp_(min=a, max=b) 371 | return tensor 372 | 373 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 374 | # type: (Tensor, float, float, float, float) -> Tensor 375 | r"""Fills the input Tensor with values drawn from a truncated 376 | normal distribution. The values are effectively drawn from the 377 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 378 | with values outside :math:`[a, b]` redrawn until they are within 379 | the bounds. The method used for generating the random values works 380 | best when :math:`a \leq \text{mean} \leq b`. 
381 | Args: 382 | tensor: an n-dimensional `torch.Tensor` 383 | mean: the mean of the normal distribution 384 | std: the standard deviation of the normal distribution 385 | a: the minimum cutoff value 386 | b: the maximum cutoff value 387 | Examples: 388 | >>> w = torch.empty(3, 5) 389 | >>> nn.init.trunc_normal_(w) 390 | """ 391 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 392 | 393 | 394 | def ipt_base(**kwargs): 395 | model = ImageProcessingTransformer( 396 | patch_size=4, depth=12, num_heads=8, ffn_ratio=4, qkv_bias=True, 397 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 398 | return model 399 | 400 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import warnings 6 | from datetime import datetime 7 | from collections import OrderedDict 8 | import math 9 | import random 10 | 11 | import torch 12 | from torch.cuda.amp.autocast_mode import autocast 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | import torch.optim 18 | import torch.multiprocessing as mp 19 | import torch.cuda.amp as amp 20 | import torch.utils.data 21 | import torch.utils.data.distributed 22 | import torchvision.transforms as transforms 23 | import torchvision.datasets as datasets 24 | #import torchvision.models as models 25 | from ipt import ipt_base 26 | from dataset.dataset import * 27 | from datetime import datetime 28 | 29 | 30 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 31 | parser.add_argument('-d','--data', metavar='DIR', default='./data', 32 | help='path to dataset') 33 | parser.add_argument('--eval-data', metavar='DIR', default='./data', 34 | help='path to eval dataset') 35 | parser.add_argument('-s','--save-path', metavar='DIR', default='./ckpt', 36 | help='path to save checkpoints') 37 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 38 | help='number of data loading workers (default: 4)') 39 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 40 | help='number of total epochs to run') 41 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 42 | help='manual epoch number (useful on restarts)') 43 | parser.add_argument('-b', '--batch-size', default=256, type=int, 44 | metavar='N', 45 | help='mini-batch size (default: 256), this is the total ' 46 | 'batch size of all GPUs on the current node when ' 47 | 'using Data Parallel or Distributed Data Parallel') 48 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, 49 | metavar='LR', help='initial learning rate', dest='lr') 50 | parser.add_argument('--lr-policy', default='naive', 51 | help='lr policy') 52 | parser.add_argument('--warmup-epochs', default=0, type=int, metavar='N', 53 | help='number of warmup epochs') 54 | parser.add_argument('--warmup-lr-multiplier', default=0.1, type=float, metavar='W', 55 | help='warmup lr multiplier') 56 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 57 | help='momentum') 58 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 59 | metavar='W', help='weight decay (default: 1-4)', 60 | dest='weight_decay') 61 | parser.add_argument('--power', default=1.0, type=float, 62 | metavar='P', help='power for poly learning-rate decay') 63 | parser.add_argument('-p', 
'--print-freq', default=10, type=int, 64 | metavar='N', help='print frequency (default: 10)') 65 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 66 | help='path to latest checkpoint (default: none)') 67 | parser.add_argument('--reset-epoch', action='store_true', 68 | help='whether to reset epoch') 69 | parser.add_argument('--eval', action='store_true', 70 | help='only do evaluation') 71 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 72 | help='evaluate model on validation set') 73 | parser.add_argument('--task', default='', type=str, metavar='string', 74 | help='specific a task' 75 | '["denoise30", "denoise50", "SRx2", "SRx3", "SRx4", "dehaze"] (default: none)') 76 | parser.add_argument('--world-size', default=-1, type=int, 77 | help='number of nodes for distributed training') 78 | parser.add_argument('--rank', default=-1, type=int, 79 | help='node rank for distributed training') 80 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 81 | help='url used to set up distributed training') 82 | parser.add_argument('--dist-backend', default='nccl', type=str, 83 | help='distributed backend') 84 | parser.add_argument('--seed', default=None, type=int, 85 | help='seed for initializing training. ') 86 | parser.add_argument('--gpu', default=None, type=int, 87 | help='GPU id to use.') 88 | parser.add_argument('--multiprocessing-distributed', action='store_true', 89 | help='Use multi-processing distributed training to launch ' 90 | 'N processes per node, which has N GPUs. This is the ' 91 | 'fastest way to use PyTorch for either single node or ' 92 | 'multi node data parallel training') 93 | parser.add_argument('--fp16',action='store_true', default=False, help="\ 94 | use fp16 instead of fp32.") 95 | 96 | 97 | best_acc1 = 0 98 | # set task sets 99 | 100 | 101 | def main(): 102 | args = parser.parse_args() 103 | 104 | now = datetime.now() 105 | timestr = now.strftime("%m-%d-%H_%M_%S") 106 | args.save_path = os.path.join(args.save_path, f"{args.task}" if args.task else "train") 107 | #args.save_path = os.path.join(args.save_path, timestr) 108 | save_path = args.save_path 109 | 110 | if not os.path.exists(save_path): 111 | os.makedirs(save_path) 112 | 113 | if args.seed is not None: 114 | random.seed(args.seed) 115 | torch.manual_seed(args.seed) 116 | cudnn.deterministic = True 117 | warnings.warn('You have chosen to seed training. ' 118 | 'This will turn on the CUDNN deterministic setting, ' 119 | 'which can slow down your training considerably! ' 120 | 'You may see unexpected behavior when restarting ' 121 | 'from checkpoints.') 122 | 123 | if args.gpu is not None: 124 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 125 | 'disable data parallelism.') 126 | 127 | if args.dist_url == "env://" and args.world_size == -1: 128 | args.world_size = int(os.environ["WORLD_SIZE"]) 129 | 130 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 131 | 132 | ngpus_per_node = torch.cuda.device_count() 133 | if args.multiprocessing_distributed: 134 | # Since we have ngpus_per_node processes per node, the total world_size 135 | # needs to be adjusted accordingly 136 | args.world_size = ngpus_per_node * args.world_size 137 | # Use torch.multiprocessing.spawn to launch distributed processes: the 138 | # main_worker process function 139 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 140 | else: 141 | # Simply call main_worker function 142 | main_worker(args.gpu, ngpus_per_node, args) 143 | 144 | 145 | def main_worker(gpu, ngpus_per_node, args): 146 | args.gpu = gpu 147 | 148 | if args.gpu is not None: 149 | print("Use GPU: {} for training".format(args.gpu)) 150 | 151 | if args.distributed: 152 | if args.dist_url == "env://" and args.rank == -1: 153 | args.rank = int(os.environ["RANK"]) 154 | if args.multiprocessing_distributed: 155 | # For multiprocessing distributed training, rank needs to be the 156 | # global rank among all the processes 157 | args.rank = args.rank * ngpus_per_node + gpu 158 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 159 | world_size=args.world_size, rank=args.rank) 160 | 161 | print("=> creating model '{}'".format("ipt_base")) 162 | model = ipt_base().cuda() 163 | 164 | 165 | 166 | # define loss function (criterion) and optimizer 167 | 168 | # IPT uses L1 loss function 169 | #criterion = nn.CrossEntropyLoss().cuda(args.gpu) 170 | criterion = nn.L1Loss() 171 | 172 | optimizer = torch.optim.Adam(model.parameters(), args.lr, 173 | betas=(0.9, 0.999), 174 | weight_decay=args.weight_decay) 175 | 176 | # optionally resume from a checkpoint 177 | if args.resume: 178 | if os.path.isfile(args.resume): 179 | print("=> loading checkpoint '{}'".format(args.resume)) 180 | checkpoint = torch.load(args.resume) 181 | if not args.reset_epoch: 182 | args.start_epoch = checkpoint['epoch'] 183 | #args.start_epoch = 10 184 | model.load_state_dict(checkpoint['state_dict']) 185 | #optimizer.load_state_dict(checkpoint['optimizer']) 186 | print("=> loaded checkpoint '{}' (epoch {})" 187 | .format(args.resume, checkpoint['epoch'])) 188 | else: 189 | print("=> no checkpoint found at '{}'".format(args.resume)) 190 | 191 | cudnn.benchmark = True 192 | 193 | if args.distributed: 194 | # For multiprocessing distributed, DistributedDataParallel constructor 195 | # should always set the single device scope, otherwise, 196 | # DistributedDataParallel will use all available devices. 
197 | if args.gpu is not None: 198 | torch.cuda.set_device(args.gpu) 199 | model.cuda(args.gpu) 200 | # When using a single GPU per process and per 201 | # DistributedDataParallel, we need to divide the batch size 202 | # ourselves based on the total number of GPUs we have 203 | args.batch_size = int(args.batch_size / ngpus_per_node) 204 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 205 | else: 206 | model.cuda() 207 | # DistributedDataParallel will divide and allocate batch_size to all 208 | # available GPUs if device_ids are not set 209 | model = torch.nn.parallel.DistributedDataParallel(model) 210 | elif args.gpu is not None: 211 | torch.cuda.set_device(args.gpu) 212 | model = model.cuda(args.gpu) 213 | else: 214 | # DataParallel will divide and allocate batch_size to all available GPUs 215 | model = torch.nn.DataParallel(model).cuda() 216 | input_size = 48 217 | 218 | # Data loading code 219 | 220 | trans = transforms.Compose([transforms.ToTensor(), 221 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 222 | ]) 223 | if args.eval: 224 | val_dataset = ImageProcessDataset(args.eval_data, transform=trans) 225 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) 226 | #raise RuntimeError("evaluate dataloader not implemented") 227 | validate(val_loader, model, criterion, args) 228 | return 229 | 230 | train_dataset = ImageProcessDataset(args.data, transform=trans) 231 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16) 232 | 233 | args.epoch_size = len(train_loader) 234 | print(f"Each epoch contains {args.epoch_size} iterations") 235 | 236 | print(f"Using {args.lr_policy} learning rate") 237 | 238 | if args.distributed: 239 | raise RuntimeError("distributed not implemented") 240 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 241 | else: 242 | train_sampler = None 243 | 244 | scaler = amp.GradScaler() if args.fp16 else None 245 | print(args) 246 | for epoch in range(args.start_epoch, args.epochs): 247 | if args.distributed: 248 | train_sampler.set_epoch(epoch) 249 | # adjust_learning_rate(optimizer, epoch, args) 250 | # train for one epoch 251 | train(train_loader, model, criterion, optimizer, epoch, args, scaler) 252 | 253 | # evaluate on validation set 254 | # validate(val_loader, model, criterion, args) 255 | 256 | 257 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 258 | and args.rank % ngpus_per_node == 0): 259 | model_to_save = getattr(model, "module", model) 260 | save_checkpoint({ 261 | 'epoch': epoch + 1, 262 | 'state_dict': model_to_save.state_dict(), 263 | 'optimizer' : optimizer.state_dict(), 264 | }, path=args.save_path) 265 | 266 | 267 | task_map = {"denoise30": 0, "denoise50": 1, "SRx2": 2, "SRx3": 3, "SRx4": 4, "dehaze": 5} 268 | 269 | def train(train_loader, model, criterion, optimizer, epoch, args, scaler=None): 270 | # train for one epoch 271 | batch_time = AverageMeter() 272 | data_time = AverageMeter() 273 | losses = AverageMeter() 274 | psnr_out = AverageMeter() 275 | 276 | # switch to train mode 277 | model.train() 278 | 279 | end = time.time() 280 | 281 | if args.lr_policy == 'naive': 282 | local_lr = adjust_learning_rate_naive(optimizer, epoch, args) 283 | elif args.lr_policy == 'step': 284 | local_lr = adjust_learning_rate(optimizer, epoch, args) 285 | elif args.lr_policy == 'epoch_poly': 286 | local_lr = adjust_learning_rate_epoch_poly(optimizer, epoch, 
args) 287 | 288 | 289 | for i, (target, input_group) in enumerate(train_loader): 290 | 291 | # set random task 292 | task_id = random.randint(0, 5) if not args.task else task_map[args.task] 293 | input = input_group[task_id] 294 | model.module.set_task(task_id) 295 | #print(f"Iter {i}, task_id: {task_id}") 296 | #for m in model.module.modules(): 297 | # if isinstance(m, ) 298 | #print(m.weight.device) 299 | global_iter = epoch * args.epoch_size + i 300 | 301 | if args.lr_policy == 'iter_poly': 302 | local_lr = adjust_learning_rate_poly(optimizer, global_iter, args) 303 | elif args.lr_policy == 'cosine': 304 | local_lr = adjust_learning_rate_cosine(optimizer, global_iter, args) 305 | 306 | # measure data loading time 307 | data_time.update(time.time() - end) 308 | 309 | if args.gpu is not None: 310 | input = input.cuda(args.gpu, non_blocking=True) 311 | target = target.cuda(args.gpu, non_blocking=True) 312 | 313 | target = target.cuda() 314 | if scaler is None: 315 | # compute output 316 | output = model(input) 317 | #print(output.device, target.device) 318 | loss = criterion(output, target) 319 | else: 320 | with autocast(): 321 | # compute output 322 | output = model(input) 323 | #print(output.device, target.device) 324 | loss = criterion(output, target) 325 | 326 | # measure accuracy and record loss 327 | output = (output * 0.5 + 0.5) * 255. 328 | target = (target * 0.5 + 0.5) * 255. 329 | psnr = PSNR()(output, target) 330 | losses.update(loss.item(), input.size(0)) 331 | psnr_out.update(psnr.item(), input.size(0)) 332 | 333 | # compute gradient and do SGD step 334 | optimizer.zero_grad() 335 | 336 | if scaler is None: 337 | # compute gradient and do SGD step 338 | loss.backward() 339 | optimizer.step() 340 | else: 341 | scaler.scale(loss).backward() 342 | scaler.step(optimizer) 343 | scaler.update() 344 | 345 | # measure elapsed time 346 | batch_time.update(time.time() - end) 347 | end = time.time() 348 | 349 | if i % args.print_freq == 0: 350 | print('Epoch: [{0}][{1}/{2}]\t' 351 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 352 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 353 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 354 | 'PSNR {psnr.val:.3f} ({psnr.avg:.3f})\t' 355 | 'LR: {lr: .6f}'.format( 356 | epoch, i, args.epoch_size, batch_time=batch_time, 357 | data_time=data_time, loss=losses, psnr=psnr_out, lr=local_lr)) 358 | 359 | 360 | def validate(val_loader, model, criterion, args): 361 | batch_time = AverageMeter() 362 | losses = AverageMeter() 363 | psnr_out = AverageMeter() 364 | psnr_in = AverageMeter() 365 | 366 | # switch to evaluate mode 367 | model.eval() 368 | P = PSNR() 369 | with torch.no_grad(): 370 | end = time.time() 371 | for i, (target, input_group) in enumerate(val_loader): 372 | task_id = task_map[args.task] 373 | input = input_group[task_id] 374 | model.module.set_task(task_id) 375 | if args.gpu is not None: 376 | input = input.cuda(args.gpu, non_blocking=True) 377 | target = target.cuda(args.gpu, non_blocking=True) 378 | target = target.cuda() 379 | # compute output 380 | output = model(input) 381 | loss = criterion(output, target) 382 | 383 | # measure accuracy and record loss 384 | output = (output * 0.5 + 0.5) * 255. 385 | target = (target * 0.5 + 0.5) * 255. 
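# note: the dataloader normalizes with mean = std = 0.5, so x * 0.5 + 0.5 undoes the normalization (back to [0, 1]) and the factor 255. rescales both tensors to the [0, 255] range expected by the PSNR class defined below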
386 | psnr1 = P(output, target) 387 | # psnr2 = P(input.cuda(), target) 388 | losses.update(loss.item(), input.size(0)) 389 | psnr_out.update(psnr1.item(), input.size(0)) 390 | # psnr_in.update(psnr2.item(), input.size(0)) 391 | 392 | # measure elapsed time 393 | batch_time.update(time.time() - end) 394 | end = time.time() 395 | 396 | if i % args.print_freq == 0: 397 | print('Test: [{0}/{1}]\t' 398 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 399 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 400 | 'PSNR_Out {psnr1.val:.3f} ({psnr1.avg:.3f})\t' 401 | 'PSNR_In {psnr2.val:.3f} ({psnr2.avg:.3f})'.format( 402 | i, len(val_loader), batch_time=batch_time, loss=losses, psnr1=psnr_out, psnr2=psnr_in 403 | )) 404 | 405 | print(' * PSNR_Out {psnr1.val:.3f} ({psnr1.avg:.3f})\t' 406 | 'PSNR_In {psnr2.val:.3f} ({psnr2.avg:.3f})'.format(psnr1=psnr_out, psnr2=psnr_in)) 407 | 408 | return psnr_out.avg 409 | 410 | 411 | def save_checkpoint(state, path='./', filename='checkpoint'): 412 | saved_path = os.path.join(path, filename+'.pth.tar') 413 | torch.save(state, saved_path) 414 | ''' 415 | if is_best: 416 | state_dict = state['state_dict'] 417 | new_state_dict = OrderedDict() 418 | best_path = os.path.join(path, 'model_best.pth') 419 | for key in state_dict.keys(): 420 | if 'module.' in key: 421 | new_state_dict[key.replace('module.', '')] = state_dict[key].cpu() 422 | else: 423 | new_state_dict[key] = state_dict[key].cpu() 424 | torch.save(new_state_dict, best_path) 425 | ''' 426 | 427 | class AverageMeter(object): 428 | """Computes and stores the average and current value""" 429 | def __init__(self): 430 | self.reset() 431 | 432 | def reset(self): 433 | self.val = 0 434 | self.avg = 0 435 | self.sum = 0 436 | self.count = 0 437 | 438 | def update(self, val, n=1): 439 | self.val = val 440 | self.sum += val * n 441 | self.count += n 442 | self.avg = self.sum / self.count 443 | 444 | def adjust_learning_rate_naive(optimizer, epoch, args): 445 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 446 | lr = args.lr if epoch < 200 else 2/5 * args.lr 447 | for param_group in optimizer.param_groups: 448 | param_group['lr'] = lr 449 | return lr 450 | 451 | def adjust_learning_rate(optimizer, epoch, args): 452 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 453 | lr = args.lr * (0.1 ** (epoch // 30)) 454 | for param_group in optimizer.param_groups: 455 | param_group['lr'] = lr 456 | return lr 457 | 458 | def adjust_learning_rate_epoch_poly(optimizer, epoch, args): 459 | """Sets epoch poly learning rate""" 460 | lr = args.lr * ((1 - epoch * 1.0 / args.epochs) ** args.power) 461 | for param_group in optimizer.param_groups: 462 | param_group['lr'] = lr 463 | return lr 464 | 465 | def adjust_learning_rate_poly(optimizer, global_iter, args): 466 | """Sets iter poly learning rate""" 467 | lr = args.lr * ((1 - global_iter * 1.0 / (args.epochs * args.epoch_size)) ** args.power) 468 | for param_group in optimizer.param_groups: 469 | param_group['lr'] = lr 470 | return lr 471 | 472 | def adjust_learning_rate_cosine(optimizer, global_iter, args): 473 | warmup_lr = args.lr * args.warmup_lr_multiplier 474 | max_iter = args.epochs * args.epoch_size 475 | warmup_iter = args.warmup_epochs * args.epoch_size 476 | if global_iter < warmup_iter: 477 | slope = (args.lr - warmup_lr) / warmup_iter 478 | lr = slope * global_iter + warmup_lr 479 | else: 480 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * (global_iter - warmup_iter) / (max_iter - warmup_iter))) 481 | 482 | for 
param_group in optimizer.param_groups: 483 | param_group['lr'] = lr 484 | return lr 485 | 486 | class PSNR: 487 | """Peak Signal to Noise Ratio 488 | img1 and img2 have range [0, 255]""" 489 | 490 | def __init__(self): 491 | self.name = "PSNR" 492 | 493 | @staticmethod 494 | def __call__(img1, img2): 495 | mse = torch.mean((img1 - img2) ** 2) 496 | return 20 * torch.log10(255.0 / torch.sqrt(mse)) 497 | ''' 498 | class SSIM: 499 | """Structure Similarity 500 | img1, img2: [0, 255]""" 501 | 502 | def __init__(self): 503 | self.name = "SSIM" 504 | 505 | @staticmethod 506 | def __call__(img1, img2): 507 | if not img1.shape == img2.shape: 508 | raise ValueError("Input images must have the same dimensions.") 509 | if img1.ndim == 2: # Grey or Y-channel image 510 | return self._ssim(img1, img2) 511 | elif img1.ndim == 3: 512 | if img1.shape[2] == 3: 513 | ssims = [] 514 | for i in range(3): 515 | ssims.append(ssim(img1, img2)) 516 | return np.array(ssims).mean() 517 | elif img1.shape[2] == 1: 518 | return self._ssim(np.squeeze(img1), np.squeeze(img2)) 519 | else: 520 | raise ValueError("Wrong input image dimensions.") 521 | 522 | @staticmethod 523 | def _ssim(img1, img2): 524 | C1 = (0.01 * 255) ** 2 525 | C2 = (0.03 * 255) ** 2 526 | 527 | img1 = img1.astype(np.float64) 528 | img2 = img2.astype(np.float64) 529 | kernel = cv2.getGaussianKernel(11, 1.5) 530 | window = np.outer(kernel, kernel.transpose()) 531 | 532 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 533 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 534 | mu1_sq = mu1 ** 2 535 | mu2_sq = mu2 ** 2 536 | mu1_mu2 = mu1 * mu2 537 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 538 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 539 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 540 | 541 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 542 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 543 | ) 544 | return ssim_map.mean() 545 | ''' 546 | if __name__ == '__main__': 547 | main() 548 | -------------------------------------------------------------------------------- /pretrain.sh: -------------------------------------------------------------------------------- 1 | python main.py --seed 0 \ 2 | --lr 5e-5 \ 3 | --save-path "./ckpt" \ 4 | --epochs 300 \ 5 | --data path-to-data \ 6 | --batch-size 256 7 | 8 | # pretraining samples a random task each iteration; pass --task only for finetuning ("denoise30", "denoise50", "SRx2", "SRx3", "SRx4", "dehaze") -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python main.py --seed 0 \ 2 | --eval-data path-to-data \ 3 | --batch-size 256 \ 4 | --eval \ 5 | --resume path-to-ckpt \ 6 | --task "SRx3" 7 | 8 | # task options: "denoise30", "denoise50", "SRx2", "SRx3", "SRx4", "dehaze" --------------------------------------------------------------------------------
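Note on the missing dataloader: main.py does `from dataset.dataset import *` and expects that module to provide `ImageProcessDataset` (and, implicitly, `DataLoader`), but the file is not part of this repository. The sketch below is one possible minimal implementation, not the authors' code: the directory layout, the degradation recipes (Gaussian noise with sigma 30/50, bicubic downsampling for SRx2/x3/x4, a crude blend standing in for haze) and the helper names are all assumptions, chosen only to match the `(target, input_group)` batches and the six task ids that main.py and ipt.py expect.

# dataset/dataset.py -- hypothetical sketch, NOT the official dataloader.
# Assumes `root` is a flat folder of clean RGB images, each at least 48x48.
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader  # re-export DataLoader for main.py's star import
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF


class ImageProcessDataset(Dataset):
    """Returns (target, inputs): target is the transformed clean 48x48 crop,
    inputs is a list of six degraded versions indexed by task id
    (0/1: noise sigma 30/50, 2/3/4: bicubic SRx2/x3/x4, 5: "hazy")."""

    def __init__(self, root, crop_size=48, transform=None):
        exts = ('.png', '.jpg', '.jpeg', '.bmp')
        self.paths = [os.path.join(root, f) for f in sorted(os.listdir(root))
                      if f.lower().endswith(exts)]
        self.crop_size = crop_size
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def _noisy(self, clean, sigma):
        # clean is a [0, 1] tensor; add Gaussian noise, clamp, convert back to PIL
        noisy = (clean + torch.randn_like(clean) * (sigma / 255.0)).clamp(0.0, 1.0)
        return TF.to_pil_image(noisy)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        # random clean crop; this is the target for every task
        i, j, h, w = transforms.RandomCrop.get_params(img, (self.crop_size, self.crop_size))
        hr = TF.crop(img, i, j, h, w)
        clean = TF.to_tensor(hr)

        inputs = []
        for sigma in (30, 50):                       # tasks 0, 1: denoising
            inputs.append(self.transform(self._noisy(clean, sigma)))
        for scale in (2, 3, 4):                      # tasks 2, 3, 4: super-resolution
            size = self.crop_size // scale
            inputs.append(self.transform(hr.resize((size, size), Image.BICUBIC)))
        hazy = (0.7 * clean + 0.3).clamp(0.0, 1.0)   # task 5: crude haze stand-in
        inputs.append(self.transform(TF.to_pil_image(hazy)))

        return self.transform(hr), inputs


if __name__ == "__main__":
    # standalone shape check of the model alone (no data needed);
    # 48 matches input_size in main.py and the positional embedding in ipt.py
    from ipt import ipt_base
    model = ipt_base()
    model.set_task(5)                                # 5 = "dehaze", see task_map in main.py
    out = model(torch.randn(1, 3, 48, 48))
    print(out.shape)                                 # torch.Size([1, 3, 48, 48]); SR tasks upscale 2x/3x/4x

The SR inputs above are smaller than the 48x48 target on purpose: the Tail module in ipt.py upsamples by the task's scale factor, so a 24x24 SRx2 input produces a 48x48 prediction that can be compared against the clean crop with the L1 loss.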