├── .gitignore
├── Dehazing
│   ├── ITS
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data_augment.py
│   │   │   └── data_load.py
│   │   ├── eval.py
│   │   ├── main.py
│   │   ├── models
│   │   │   ├── OKNet.py
│   │   │   └── layers.py
│   │   ├── train.py
│   │   ├── utils.py
│   │   └── valid.py
│   ├── OTS
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data_augment.py
│   │   │   └── data_load.py
│   │   ├── eval.py
│   │   ├── main.py
│   │   ├── models
│   │   │   ├── OKNet.py
│   │   │   └── layers.py
│   │   ├── train.py
│   │   ├── utils.py
│   │   └── valid.py
│   └── README.md
├── Desnowing
│   ├── README.md
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_augment.py
│   │   └── data_load.py
│   ├── eval.py
│   ├── main.py
│   ├── models
│   │   ├── OKNet.py
│   │   └── layers.py
│   ├── train.py
│   ├── utils.py
│   └── valid.py
├── LICENSE
├── README.md
└── pytorch-gradual-warmup-lr
    ├── setup.py
    └── warmup_scheduler
        ├── __init__.py
        ├── run.py
        └── scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/Dehazing/ITS/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFlip, PairToTensor
2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader
3 |
--------------------------------------------------------------------------------
/Dehazing/ITS/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class PairRandomCrop(transforms.RandomCrop):
7 |
8 | def __call__(self, image, label):
9 |
10 | if self.padding is not None:
11 | image = F.pad(image, self.padding, self.fill, self.padding_mode)
12 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
13 |
14 | # pad the width if needed
15 | if self.pad_if_needed and image.size[0] < self.size[1]:
16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
18 | # pad the height if needed
19 | if self.pad_if_needed and image.size[1] < self.size[0]:
20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
21 |             label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)
22 |
23 | i, j, h, w = self.get_params(image, self.size)
24 |
25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)
26 |
27 |
28 | class PairCompose(transforms.Compose):
29 | def __call__(self, image, label):
30 | for t in self.transforms:
31 | image, label = t(image, label)
32 | return image, label
33 |
34 |
35 | class PairRandomHorizontalFlip(transforms.RandomHorizontalFlip):
36 | def __call__(self, img, label):
37 | """
38 | Args:
39 | img (PIL Image): Image to be flipped.
40 |
41 | Returns:
42 | PIL Image: Randomly flipped image.
43 | """
44 | if random.random() < self.p:
45 | return F.hflip(img), F.hflip(label)
46 | return img, label
47 |
48 |
49 | class PairToTensor(transforms.ToTensor):
50 | def __call__(self, pic, label):
51 | """
52 | Args:
53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
54 |
55 | Returns:
56 | Tensor: Converted image.
57 | """
58 | return F.to_tensor(pic), F.to_tensor(label)
59 |
--------------------------------------------------------------------------------
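
A minimal usage sketch of the paired transforms above (not part of the repository; assumes Pillow and torchvision are installed and the snippet is run from Dehazing/ITS so the data package resolves; the 256 crop size matches train_dataloader in data_load.py):

    from PIL import Image
    from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor

    transform = PairCompose([
        PairRandomCrop(256),            # one crop window, applied to both images
        PairRandomHorizontalFlip(),     # one flip decision, applied to both images
        PairToTensor(),                 # PIL -> float tensors in [0, 1]
    ])

    hazy = Image.new('RGB', (620, 460))   # stand-ins for a real hazy/clear pair
    clear = Image.new('RGB', (620, 460))
    hazy_t, clear_t = transform(hazy, clear)  # both 3x256x256, identically cropped/flipped
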
/Dehazing/ITS/data/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image as Image
5 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor
6 | from torchvision.transforms import functional as F
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 |
10 | def train_dataloader(path, batch_size=64, num_workers=0, use_transform=True):
11 | image_dir = os.path.join(path, 'train')
12 |
13 | transform = None
14 | if use_transform:
15 | transform = PairCompose(
16 | [
17 | PairRandomCrop(256),
18 |                 PairRandomHorizontalFlip(),
19 | PairToTensor()
20 | ]
21 | )
22 | dataloader = DataLoader(
23 | DeblurDataset(image_dir, transform=transform),
24 | batch_size=batch_size,
25 | shuffle=True,
26 | num_workers=num_workers,
27 | pin_memory=True
28 | )
29 | return dataloader
30 |
31 |
32 | def test_dataloader(path, batch_size=1, num_workers=0):
33 | image_dir = os.path.join(path, 'test')
34 | dataloader = DataLoader(
35 | DeblurDataset(image_dir, is_test=True),
36 | batch_size=batch_size,
37 | shuffle=False,
38 | num_workers=num_workers,
39 | pin_memory=True
40 | )
41 |
42 | return dataloader
43 |
44 |
45 | def valid_dataloader(path, batch_size=1, num_workers=0):
46 | dataloader = DataLoader(
47 | DeblurDataset(os.path.join(path, 'test')),
48 | batch_size=batch_size,
49 | shuffle=False,
50 | num_workers=num_workers
51 | )
52 |
53 | return dataloader
54 |
55 |
56 | class DeblurDataset(Dataset):
57 | def __init__(self, image_dir, transform=None, is_test=False):
58 | self.image_dir = image_dir
59 | self.image_list = os.listdir(os.path.join(image_dir, 'hazy/'))
60 | self._check_image(self.image_list)
61 | self.image_list.sort()
62 | self.transform = transform
63 | self.is_test = is_test
64 |
65 | def __len__(self):
66 | return len(self.image_list)
67 |
68 | def __getitem__(self, idx):
69 | image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx]))
70 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.png'))
71 |
72 | if self.transform:
73 | image, label = self.transform(image, label)
74 | else:
75 | image = F.to_tensor(image)
76 | label = F.to_tensor(label)
77 | if self.is_test:
78 | name = self.image_list[idx]
79 | return image, label, name
80 | return image, label
81 |
82 | @staticmethod
83 | def _check_image(lst):
84 | for x in lst:
85 | splits = x.split('.')
86 | if splits[-1] not in ['png', 'jpg', 'jpeg']:
87 |                 raise ValueError('unexpected image extension: %s' % x)
88 |
--------------------------------------------------------------------------------
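
DeblurDataset above implies the directory layout sketched below: hazy inputs named '<id>_<k>.png' are matched to ground truth gt/<id>.png. A hedged usage example (the dataset root is a placeholder):

    # <data_dir>/
    #     train/
    #         hazy/  1_1.png, 1_2.png, ...   inputs, named '<id>_<k>.png'
    #         gt/    1.png, 2.png, ...       targets, matched on '<id>'
    #     test/
    #         hazy/  ...
    #         gt/    ...
    from data import train_dataloader

    loader = train_dataloader('/path/to/ITS', batch_size=8, num_workers=4)
    for hazy, clear in loader:   # each batch: two 8x3x256x256 tensors
        break
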
/Dehazing/ITS/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.transforms import functional as F
4 | import numpy as np
5 | from utils import Adder
6 | from data import test_dataloader
7 | from skimage.metrics import peak_signal_noise_ratio
8 | import time
9 | from pytorch_msssim import ssim
10 | import torch.nn.functional as f
11 |
12 | from skimage import img_as_ubyte
13 | import cv2
14 |
15 | def _eval(model, args):
16 | state_dict = torch.load(args.test_model)
17 | model.load_state_dict(state_dict['model'])
18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0)
20 | torch.cuda.empty_cache()
21 | adder = Adder()
22 | model.eval()
23 | factor = 4
24 | with torch.no_grad():
25 | psnr_adder = Adder()
26 | ssim_adder = Adder()
27 |
28 | for iter_idx, data in enumerate(dataloader):
29 | input_img, label_img, name = data
30 |
31 | input_img = input_img.to(device)
32 |
33 | h, w = input_img.shape[2], input_img.shape[3]
34 |             # round each side up to the nearest multiple of factor
35 |             H, W = ((h+factor-1)//factor)*factor, ((w+factor-1)//factor)*factor
36 |             padh, padw = H-h, W-w
37 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
38 |
39 | tm = time.time()
40 |
41 | pred = model(input_img)[2]
42 | pred = pred[:,:,:h,:w]
43 |
44 | elapsed = time.time() - tm
45 | adder(elapsed)
46 |
47 | pred_clip = torch.clamp(pred, 0, 1)
48 |
49 | pred_numpy = pred_clip.squeeze(0).cpu().numpy()
50 | label_numpy = label_img.squeeze(0).cpu().numpy()
51 |
52 |
53 |             label_img = label_img.to(device)
54 | psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img))
55 | down_ratio = max(1, round(min(H, W) / 256))
56 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))),
57 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))),
58 | data_range=1, size_average=False)
59 | print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val))
60 | ssim_adder(ssim_val)
61 |
62 | if args.save_image:
63 | save_name = os.path.join(args.result_dir, name[0])
64 | pred_clip += 0.5 / 255
65 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
66 | pred.save(save_name)
67 |
68 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
69 | psnr_adder(psnr_val)
70 |
71 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed))
72 |
73 | print('==========================================================')
74 | print('The average PSNR is %.2f dB' % (psnr_adder.average()))
75 |     print('The average SSIM is %.5f' % (ssim_adder.average()))
76 |
77 | print("Average time: %f" % adder.average())
78 |
79 |
--------------------------------------------------------------------------------
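
The evaluation loop reflect-pads each input so both sides are multiples of factor=4 (the encoder downsamples twice with stride 2), then crops the prediction back to the original size. The same arithmetic in isolation, as a sketch:

    import torch
    import torch.nn.functional as f

    factor = 4                         # two stride-2 stages: sides must divide by 4
    x = torch.rand(1, 3, 501, 749)
    h, w = x.shape[2], x.shape[3]
    H = ((h + factor - 1) // factor) * factor   # 504
    W = ((w + factor - 1) // factor) * factor   # 752
    x = f.pad(x, (0, W - w, 0, H - h), 'reflect')
    assert x.shape[2] % factor == 0 and x.shape[3] % factor == 0
    # after the forward pass, crop back: pred = pred[:, :, :h, :w]
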
/Dehazing/ITS/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch.backends import cudnn
5 | from models.OKNet import build_net
6 | from train import _train
7 | from eval import _eval
8 | import numpy as np
9 | import random
10 |
11 | def main(args):
12 | # CUDNN
13 | cudnn.benchmark = True
14 |
15 |     if not os.path.exists('results/'):
16 |         os.makedirs('results/')
17 | if not os.path.exists('results/' + args.model_name + '/'):
18 | os.makedirs('results/' + args.model_name + '/')
19 | if not os.path.exists(args.model_save_dir):
20 | os.makedirs(args.model_save_dir)
21 | if not os.path.exists(args.result_dir):
22 | os.makedirs(args.result_dir)
23 |
24 | model = build_net()
25 | print(model)
26 |
27 | if torch.cuda.is_available():
28 | model.cuda()
29 | if args.mode == 'train':
30 | _train(model, args)
31 |
32 | elif args.mode == 'test':
33 | _eval(model, args)
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser()
38 |
39 | # Directories
40 | parser.add_argument('--model_name', default='OKNet', type=str)
41 |
42 | parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str)
43 | parser.add_argument('--data_dir', type=str, default='')
44 |
45 | # Train
46 | parser.add_argument('--batch_size', type=int, default=8)
47 | parser.add_argument('--learning_rate', type=float, default=2e-4)
48 | parser.add_argument('--weight_decay', type=float, default=0)
49 | parser.add_argument('--num_epoch', type=int, default=1000)
50 | parser.add_argument('--print_freq', type=int, default=100)
51 | parser.add_argument('--num_worker', type=int, default=16)
52 | parser.add_argument('--save_freq', type=int, default=20)
53 | parser.add_argument('--valid_freq', type=int, default=20)
54 | parser.add_argument('--resume', type=str, default='')
55 |
56 |
57 | # Test
58 | parser.add_argument('--test_model', type=str, default='')
59 |     parser.add_argument('--save_image', action='store_true')
60 |
61 | args = parser.parse_args()
62 |     args.model_save_dir = os.path.join('results/', args.model_name, 'ITS/')
63 | args.result_dir = os.path.join('results/', args.model_name, 'test')
64 | if not os.path.exists(args.model_save_dir):
65 | os.makedirs(args.model_save_dir)
66 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir
67 | os.system(command)
68 | command = 'cp ' + 'models/OKNet.py ' + args.model_save_dir
69 | os.system(command)
70 | command = 'cp ' + 'train.py ' + args.model_save_dir
71 | os.system(command)
72 | command = 'cp ' + 'main.py ' + args.model_save_dir
73 | os.system(command)
74 | print(args)
75 | main(args)
76 |
--------------------------------------------------------------------------------
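
Typical invocations given the flags above (a sketch; the data and checkpoint paths are placeholders):

    # train from scratch on RESIDE-ITS
    python main.py --mode train --data_dir /path/to/ITS

    # evaluate a checkpoint and write dehazed images to results/OKNet/test/
    python main.py --mode test --data_dir /path/to/ITS \
        --test_model results/OKNet/ITS/Best.pkl --save_image
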
/Dehazing/ITS/models/OKNet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .layers import *
6 |
7 | class EBlock(nn.Module):
8 | def __init__(self, out_channel, num_res=8):
9 | super(EBlock, self).__init__()
10 |
11 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
12 |
13 | self.layers = nn.Sequential(*layers)
14 |
15 | def forward(self, x):
16 | return self.layers(x)
17 |
18 |
19 | class DBlock(nn.Module):
20 | def __init__(self, channel, num_res=8):
21 | super(DBlock, self).__init__()
22 |
23 | layers = [ResBlock(channel, channel) for _ in range(num_res)]
24 | self.layers = nn.Sequential(*layers)
25 |
26 | def forward(self, x):
27 | return self.layers(x)
28 |
29 |
30 | class SCM(nn.Module):
31 | def __init__(self, out_plane):
32 | super(SCM, self).__init__()
33 | self.main = nn.Sequential(
34 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
35 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
36 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
37 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
38 | nn.InstanceNorm2d(out_plane, affine=True)
39 | )
40 |
41 | def forward(self, x):
42 | x = self.main(x)
43 | return x
44 |
45 |
46 | class Bottleneck(nn.Module):
47 | def __init__(self, dim) -> None:
48 | super().__init__()
49 |
50 | ker = 63
51 | pad = ker // 2
52 | self.in_conv = nn.Sequential(
53 | nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1),
54 | nn.GELU()
55 | )
56 | self.out_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1)
57 | self.dw_13 = nn.Conv2d(dim, dim, kernel_size=(1,ker), padding=(0,pad), stride=1, groups=dim)
58 | self.dw_31 = nn.Conv2d(dim, dim, kernel_size=(ker,1), padding=(pad,0), stride=1, groups=dim)
59 | self.dw_33 = nn.Conv2d(dim, dim, kernel_size=ker, padding=pad, stride=1, groups=dim)
60 | self.dw_11 = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=dim)
61 |
62 | self.act = nn.ReLU()
63 |
64 | ### sca ###
65 | self.conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
66 | self.pool = nn.AdaptiveAvgPool2d((1,1))
67 |
68 | ### fca ###
69 | self.fac_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
70 | self.fac_pool = nn.AdaptiveAvgPool2d((1,1))
71 | self.fgm = FGM(dim)
72 |
73 | def forward(self, x):
74 | out = self.in_conv(x)
75 |
76 | ### fca ###
77 | x_att = self.fac_conv(self.fac_pool(out))
78 | x_fft = torch.fft.fft2(out, norm='backward')
79 | x_fft = x_att * x_fft
80 | x_fca = torch.fft.ifft2(x_fft, dim=(-2,-1), norm='backward')
81 | x_fca = torch.abs(x_fca)
82 |
83 | ### fca ###
84 | ### sca ###
85 | x_att = self.conv(self.pool(x_fca))
86 | x_sca = x_att * x_fca
87 | ### sca ###
88 | x_sca = self.fgm(x_sca)
89 |
90 | out = x + self.dw_13(out) + self.dw_31(out) + self.dw_33(out) + self.dw_11(out) + x_sca
91 | out = self.act(out)
92 | return self.out_conv(out)
93 |
94 | class FGM(nn.Module):
95 | def __init__(self, dim) -> None:
96 | super().__init__()
97 |
98 | self.conv = nn.Conv2d(dim, dim*2, 3, 1, 1)
99 |
100 | self.dwconv1 = nn.Conv2d(dim, dim, 1, 1, groups=1)
101 | self.dwconv2 = nn.Conv2d(dim, dim, 1, 1, groups=1)
102 | self.alpha = nn.Parameter(torch.zeros(dim, 1, 1))
103 | self.beta = nn.Parameter(torch.ones(dim, 1, 1))
104 |
105 | def forward(self, x):
106 | # res = x.clone()
107 | fft_size = x.size()[2:]
108 | x1 = self.dwconv1(x)
109 | x2 = self.dwconv2(x)
110 |
111 | x2_fft = torch.fft.fft2(x2, norm='backward')
112 |
113 |         out = x1 * x2_fft  # real-valued spatial gate modulates the complex spectrum
114 |
115 | out = torch.fft.ifft2(out, dim=(-2,-1), norm='backward')
116 | out = torch.abs(out)
117 |
118 | return out * self.alpha + x * self.beta
119 |
120 |
121 | class FAM(nn.Module):
122 | def __init__(self, channel):
123 | super(FAM, self).__init__()
124 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False)
125 |
126 | def forward(self, x1, x2):
127 | return self.merge(torch.cat([x1, x2], dim=1))
128 |
129 | class OKNet(nn.Module):
130 | def __init__(self, num_res=4):
131 | super(OKNet, self).__init__()
132 |
133 | base_channel = 32
134 |
135 | self.Encoder = nn.ModuleList([
136 | EBlock(base_channel, num_res),
137 | EBlock(base_channel*2, num_res),
138 | EBlock(base_channel*4, num_res),
139 | ])
140 |
141 | self.feat_extract = nn.ModuleList([
142 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
143 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
144 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
145 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
146 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
147 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
148 | ])
149 |
150 | self.Decoder = nn.ModuleList([
151 | DBlock(base_channel * 4, num_res),
152 | DBlock(base_channel * 2, num_res),
153 | DBlock(base_channel, num_res)
154 | ])
155 |
156 | self.Convs = nn.ModuleList([
157 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
158 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
159 | ])
160 |
161 | self.ConvsOut = nn.ModuleList(
162 | [
163 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
164 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
165 | ]
166 | )
167 |
168 | self.FAM1 = FAM(base_channel * 4)
169 | self.SCM1 = SCM(base_channel * 4)
170 | self.FAM2 = FAM(base_channel * 2)
171 | self.SCM2 = SCM(base_channel * 2)
172 |
173 |         self.bottleneck = Bottleneck(base_channel * 4)
174 |
175 |
176 | def forward(self, x):
177 | x_2 = F.interpolate(x, scale_factor=0.5)
178 | x_4 = F.interpolate(x_2, scale_factor=0.5)
179 | z2 = self.SCM2(x_2)
180 | z4 = self.SCM1(x_4)
181 |
182 | outputs = list()
183 | # 256
184 | x_ = self.feat_extract[0](x)
185 | res1 = self.Encoder[0](x_)
186 | # 128
187 | z = self.feat_extract[1](res1)
188 | z = self.FAM2(z, z2)
189 | res2 = self.Encoder[1](z)
190 | # 64
191 | z = self.feat_extract[2](res2)
192 | z = self.FAM1(z, z4)
193 | z = self.Encoder[2](z)
194 |         z = self.bottleneck(z)
195 |
196 | z = self.Decoder[0](z)
197 | z_ = self.ConvsOut[0](z)
198 | # 128
199 | z = self.feat_extract[3](z)
200 | outputs.append(z_+x_4)
201 |
202 | z = torch.cat([z, res2], dim=1)
203 | z = self.Convs[0](z)
204 | z = self.Decoder[1](z)
205 | z_ = self.ConvsOut[1](z)
206 | # 256
207 | z = self.feat_extract[4](z)
208 | outputs.append(z_+x_2)
209 |
210 | z = torch.cat([z, res1], dim=1)
211 | z = self.Convs[1](z)
212 | z = self.Decoder[2](z)
213 | z = self.feat_extract[5](z)
214 | outputs.append(z+x)
215 |
216 | return outputs
217 |
218 | def build_net():
219 | return OKNet()
220 |
221 |
--------------------------------------------------------------------------------
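
A quick shape sanity check for the three-scale outputs (a sketch, run from Dehazing/ITS so the relative import in models/OKNet.py resolves; the model predicts at 1/4, 1/2, and full resolution):

    import torch
    from models.OKNet import build_net

    net = build_net().eval()
    x = torch.rand(1, 3, 256, 256)        # sides must be divisible by 4
    with torch.no_grad():
        outs = net(x)
    print([tuple(o.shape) for o in outs])
    # [(1, 3, 64, 64), (1, 3, 128, 128), (1, 3, 256, 256)]
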
/Dehazing/ITS/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
7 | super(BasicConv, self).__init__()
8 | if bias and norm:
9 | bias = False
10 |
11 | padding = kernel_size // 2
12 | layers = list()
13 | if transpose:
14 |             padding = kernel_size // 2 - 1
15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
16 | else:
17 | layers.append(
18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
19 | if norm:
20 | layers.append(nn.BatchNorm2d(out_channel))
21 | if relu:
22 | layers.append(nn.GELU())
23 | self.main = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.main(x)
27 |
28 |
29 | class ResBlock(nn.Module):
30 | def __init__(self, in_channel, out_channel):
31 | super(ResBlock, self).__init__()
32 | self.main = nn.Sequential(
33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
35 | )
36 |
37 | def forward(self, x):
38 | return self.main(x) + x
--------------------------------------------------------------------------------
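
The kernel/stride pairs OKNet passes to BasicConv give exact 2x resampling: kernel 3, stride 2 with padding 3//2=1 halves each side, and the transposed kernel 4, stride 2 with padding 4//2-1=1 doubles it. A small check (sketch):

    import torch
    from models.layers import BasicConv

    x = torch.rand(1, 32, 128, 128)
    down = BasicConv(32, 64, kernel_size=3, stride=2)                # 128 -> 64
    up = BasicConv(64, 32, kernel_size=4, stride=2, transpose=True)  # 64 -> 128
    print(down(x).shape)      # torch.Size([1, 64, 64, 64])
    print(up(down(x)).shape)  # torch.Size([1, 32, 128, 128])
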
/Dehazing/ITS/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data import train_dataloader
4 | from utils import Adder, Timer, check_lr
5 | from torch.utils.tensorboard import SummaryWriter
6 | from valid import _valid
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | from warmup_scheduler import GradualWarmupScheduler
11 |
12 | def _train(model, args):
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | criterion = torch.nn.L1Loss()
15 |
16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
18 | max_iter = len(dataloader)
19 | warmup_epochs=3
20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
22 | scheduler.step()
23 | epoch = 1
24 | if args.resume:
25 | state = torch.load(args.resume)
26 | epoch = state['epoch']
27 | optimizer.load_state_dict(state['optimizer'])
28 | model.load_state_dict(state['model'])
29 | print('Resume from %d'%epoch)
30 | epoch += 1
31 |
32 | writer = SummaryWriter()
33 | epoch_pixel_adder = Adder()
34 | epoch_fft_adder = Adder()
35 | iter_pixel_adder = Adder()
36 | iter_fft_adder = Adder()
37 | epoch_timer = Timer('m')
38 | iter_timer = Timer('m')
39 | best_psnr=-1
40 |
41 | for epoch_idx in range(epoch, args.num_epoch + 1):
42 |
43 | epoch_timer.tic()
44 | iter_timer.tic()
45 | for iter_idx, batch_data in enumerate(dataloader):
46 |
47 | input_img, label_img = batch_data
48 | input_img = input_img.to(device)
49 | label_img = label_img.to(device)
50 |
51 | optimizer.zero_grad()
52 | pred_img = model(input_img)
53 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
54 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')
55 | l1 = criterion(pred_img[0], label_img4)
56 | l2 = criterion(pred_img[1], label_img2)
57 | l3 = criterion(pred_img[2], label_img)
58 | loss_content = l1+l2+l3
59 |
60 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1))
61 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
62 |
63 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1))
64 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
65 |
66 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1))
67 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
68 |
69 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1))
70 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
71 |
72 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
73 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
74 |
75 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1))
76 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
77 |
78 | f1 = criterion(pred_fft1, label_fft1)
79 | f2 = criterion(pred_fft2, label_fft2)
80 | f3 = criterion(pred_fft3, label_fft3)
81 | loss_fft = f1+f2+f3
82 |
83 | loss = loss_content + 0.1 * loss_fft
84 | loss.backward()
85 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001)
86 | optimizer.step()
87 |
88 | iter_pixel_adder(loss_content.item())
89 | iter_fft_adder(loss_fft.item())
90 |
91 | epoch_pixel_adder(loss_content.item())
92 | epoch_fft_adder(loss_fft.item())
93 |
94 | if (iter_idx + 1) % args.print_freq == 0:
95 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
96 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
97 | iter_fft_adder.average()))
98 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
99 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
100 |
101 | iter_timer.tic()
102 | iter_pixel_adder.reset()
103 | iter_fft_adder.reset()
104 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
105 | torch.save({'model': model.state_dict(),
106 | 'optimizer': optimizer.state_dict(),
107 | 'epoch': epoch_idx}, overwrite_name)
108 |
109 | if epoch_idx % args.save_freq == 0:
110 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
111 | torch.save({'model': model.state_dict()}, save_name)
112 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
113 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
114 | epoch_fft_adder.reset()
115 | epoch_pixel_adder.reset()
116 | scheduler.step()
117 | if epoch_idx % args.valid_freq == 0:
118 | val = _valid(model, args, epoch_idx)
119 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val))
120 | writer.add_scalar('PSNR', val, epoch_idx)
121 |             if val >= best_psnr:
122 |                 best_psnr = val
123 |                 torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
124 |         save_name = os.path.join(args.model_save_dir, 'Final.pkl')
125 |         torch.save({'model': model.state_dict()}, save_name)
126 |
--------------------------------------------------------------------------------
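
The training objective above, factored into a standalone function (a sketch: L1 on the three output scales plus an L1 penalty on the stacked real/imaginary FFT parts, weighted 0.1):

    import torch
    import torch.nn.functional as F

    def multiscale_loss(preds, label, fft_weight=0.1):
        """preds: [quarter, half, full] OKNet outputs; label: full-resolution gt."""
        labels = [
            F.interpolate(label, scale_factor=0.25, mode='bilinear'),
            F.interpolate(label, scale_factor=0.5, mode='bilinear'),
            label,
        ]
        content, freq = 0.0, 0.0
        for pred, lab in zip(preds, labels):
            content = content + F.l1_loss(pred, lab)
            pf = torch.fft.fft2(pred, dim=(-2, -1))
            lf = torch.fft.fft2(lab, dim=(-2, -1))
            # L1 on stacked (real, imag) parts, matching the training loop
            freq = freq + F.l1_loss(torch.stack((pf.real, pf.imag), -1),
                                    torch.stack((lf.real, lf.imag), -1))
        return content + fft_weight * freq
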
/Dehazing/ITS/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 | def __init__(self):
7 | self.count = 0
8 | self.num = float(0)
9 |
10 | def reset(self):
11 | self.count = 0
12 | self.num = float(0)
13 |
14 | def __call__(self, num):
15 | self.count += 1
16 | self.num += num
17 |
18 | def average(self):
19 | return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, option='s'):
24 | self.tm = 0
25 | self.option = option
26 | if option == 's':
27 |             self.divider = 1
28 |         elif option == 'm':
29 |             self.divider = 60
30 |         else:
31 |             self.divider = 3600
32 |
33 | def tic(self):
34 | self.tm = time.time()
35 |
36 | def toc(self):
37 |         return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 | for i, param_group in enumerate(optimizer.param_groups):
42 | lr = param_group['lr']
43 | return lr
44 |
--------------------------------------------------------------------------------
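
Usage of the two helpers (a sketch with dummy values):

    from utils import Adder, Timer

    timer = Timer('m')                    # report elapsed time in minutes
    timer.tic()
    psnr_adder = Adder()
    for psnr in (31.2, 30.8, 32.1):       # dummy per-image scores
        psnr_adder(psnr)
    print('avg PSNR: %.2f, elapsed: %.3f min' % (psnr_adder.average(), timer.toc()))
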
/Dehazing/ITS/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | its = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
13 | model.eval()
14 | psnr_adder = Adder()
15 |
16 | with torch.no_grad():
17 | print('Start Evaluation')
18 | factor = 4
19 | for idx, data in enumerate(its):
20 | input_img, label_img = data
21 | input_img = input_img.to(device)
22 |
23 | h, w = input_img.shape[2], input_img.shape[3]
24 |             # round each side up to the nearest multiple of factor
25 |             H, W = ((h+factor-1)//factor)*factor, ((w+factor-1)//factor)*factor
26 |             padh, padw = H-h, W-w
27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 | pred = model(input_img)[2]
33 | pred = pred[:,:,:h,:w]
34 |
35 | pred_clip = torch.clamp(pred, 0, 1)
36 | p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 | label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 | psnr_adder(psnr)
42 | print('\r%03d'%idx, end=' ')
43 |
44 | print('\n')
45 | model.train()
46 | return psnr_adder.average()
47 |
--------------------------------------------------------------------------------
/Dehazing/OTS/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFlip, PairToTensor
2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader
3 |
--------------------------------------------------------------------------------
/Dehazing/OTS/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class PairRandomCrop(transforms.RandomCrop):
7 |
8 | def __call__(self, image, label):
9 |
10 | if self.padding is not None:
11 | image = F.pad(image, self.padding, self.fill, self.padding_mode)
12 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
13 |
14 | # pad the width if needed
15 | if self.pad_if_needed and image.size[0] < self.size[1]:
16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
18 | # pad the height if needed
19 | if self.pad_if_needed and image.size[1] < self.size[0]:
20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
21 |             label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)
22 |
23 | i, j, h, w = self.get_params(image, self.size)
24 |
25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)
26 |
27 |
28 | class PairCompose(transforms.Compose):
29 | def __call__(self, image, label):
30 | for t in self.transforms:
31 | image, label = t(image, label)
32 | return image, label
33 |
34 |
35 | class PairRandomHorizontalFlip(transforms.RandomHorizontalFlip):
36 | def __call__(self, img, label):
37 | """
38 | Args:
39 | img (PIL Image): Image to be flipped.
40 |
41 | Returns:
42 | PIL Image: Randomly flipped image.
43 | """
44 | if random.random() < self.p:
45 | return F.hflip(img), F.hflip(label)
46 | return img, label
47 |
48 |
49 | class PairToTensor(transforms.ToTensor):
50 | def __call__(self, pic, label):
51 | """
52 | Args:
53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
54 |
55 | Returns:
56 | Tensor: Converted image.
57 | """
58 | return F.to_tensor(pic), F.to_tensor(label)
59 |
--------------------------------------------------------------------------------
/Dehazing/OTS/data/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | from PIL import Image as Image
5 | from torchvision.transforms import functional as F
6 | from torch.utils.data import Dataset, DataLoader
7 | from PIL import ImageFile
8 | ImageFile.LOAD_TRUNCATED_IMAGES = True
9 |
10 | def train_dataloader(path, batch_size=64, num_workers=0):
11 | image_dir = os.path.join(path, 'train')
12 |
13 | dataloader = DataLoader(
14 | DeblurDataset(image_dir, ps=256),
15 | batch_size=batch_size,
16 | shuffle=True,
17 | num_workers=num_workers,
18 | pin_memory=True
19 | )
20 | return dataloader
21 |
22 |
23 | def test_dataloader(path, batch_size=1, num_workers=0):
24 | image_dir = os.path.join(path, 'test')
25 | dataloader = DataLoader(
26 | DeblurDataset(image_dir, is_test=True),
27 | batch_size=batch_size,
28 | shuffle=False,
29 | num_workers=num_workers,
30 | pin_memory=True
31 | )
32 |
33 | return dataloader
34 |
35 |
36 | def valid_dataloader(path, batch_size=1, num_workers=0):
37 | dataloader = DataLoader(
38 | DeblurDataset(os.path.join(path, 'test'), is_valid=True),
39 | batch_size=batch_size,
40 | shuffle=False,
41 | num_workers=num_workers
42 | )
43 |
44 | return dataloader
45 |
46 |
47 | class DeblurDataset(Dataset):
48 | def __init__(self, image_dir, transform=None, is_test=False, is_valid=False, ps=None):
49 | self.image_dir = image_dir
50 | self.image_list = os.listdir(os.path.join(image_dir, 'hazy/'))
51 | self._check_image(self.image_list)
52 | self.image_list.sort()
53 | self.transform = transform
54 | self.is_test = is_test
55 | self.is_valid = is_valid
56 | self.ps = ps
57 |
58 | def __len__(self):
59 | return len(self.image_list)
60 |
61 | def __getitem__(self, idx):
62 | image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx])).convert('RGB')
63 | if self.is_valid or self.is_test:
64 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.png')).convert('RGB')
65 | else:
66 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.jpg')).convert('RGB')
67 | ps = self.ps
68 |
69 | if self.ps is not None:
70 | image = F.to_tensor(image)
71 | label = F.to_tensor(label)
72 |
73 | hh, ww = label.shape[1], label.shape[2]
74 |
75 | rr = random.randint(0, hh-ps)
76 | cc = random.randint(0, ww-ps)
77 |
78 | image = image[:, rr:rr+ps, cc:cc+ps]
79 | label = label[:, rr:rr+ps, cc:cc+ps]
80 |
81 | if random.random() < 0.5:
82 | image = image.flip(2)
83 | label = label.flip(2)
84 | else:
85 | image = F.to_tensor(image)
86 | label = F.to_tensor(label)
87 |
88 | if self.is_test:
89 | name = self.image_list[idx]
90 | return image, label, name
91 | return image, label
92 |
93 |
94 |
95 | @staticmethod
96 | def _check_image(lst):
97 | for x in lst:
98 | splits = x.split('.')
99 | if splits[-1] not in ['png', 'jpg', 'jpeg']:
100 |                 raise ValueError('unexpected image extension: %s' % x)
101 |
--------------------------------------------------------------------------------
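
Unlike the ITS loader, this dataset augments in tensor space inside __getitem__: a random 256-pixel crop plus a horizontal flip via tensor.flip(2), with the same window and flip decision applied to both images. The augmentation path in isolation, as a sketch:

    import random
    import torch

    def paired_crop_flip(image, label, ps=256):
        """image, label: CxHxW tensors of equal size, with H and W >= ps."""
        hh, ww = label.shape[1], label.shape[2]
        rr = random.randint(0, hh - ps)
        cc = random.randint(0, ww - ps)
        image = image[:, rr:rr + ps, cc:cc + ps]   # identical window for both
        label = label[:, rr:rr + ps, cc:cc + ps]
        if random.random() < 0.5:                  # identical flip decision
            image = image.flip(2)
            label = label.flip(2)
        return image, label

    img, lab = paired_crop_flip(torch.rand(3, 460, 620), torch.rand(3, 460, 620))
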
/Dehazing/OTS/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.transforms import functional as F
4 | import numpy as np
5 | from utils import Adder
6 | from data import test_dataloader
7 | from skimage.metrics import peak_signal_noise_ratio
8 | import time
9 | from pytorch_msssim import ssim
10 | import torch.nn.functional as f
11 |
12 | from skimage import img_as_ubyte
13 | import cv2
14 | # ---------------------------------------------------
15 |
16 | def _eval(model, args):
17 | state_dict = torch.load(args.test_model)
18 | model.load_state_dict(state_dict['model'])
19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0)
21 | torch.cuda.empty_cache()
22 | adder = Adder()
23 | model.eval()
24 | factor = 4
25 | with torch.no_grad():
26 | psnr_adder = Adder()
27 | ssim_adder = Adder()
28 |
29 | for iter_idx, data in enumerate(dataloader):
30 | input_img, label_img, name = data
31 |
32 | input_img = input_img.to(device)
33 |
34 | h, w = input_img.shape[2], input_img.shape[3]
35 |             # round each side up to the nearest multiple of factor
36 |             H, W = ((h+factor-1)//factor)*factor, ((w+factor-1)//factor)*factor
37 |             padh, padw = H-h, W-w
38 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
39 |
40 | tm = time.time()
41 |
42 | pred = model(input_img)[2]
43 | pred = pred[:,:,:h,:w]
44 |
45 | elapsed = time.time() - tm
46 | adder(elapsed)
47 |
48 | pred_clip = torch.clamp(pred, 0, 1)
49 |
50 | pred_numpy = pred_clip.squeeze(0).cpu().numpy()
51 | label_numpy = label_img.squeeze(0).cpu().numpy()
52 |
53 |             label_img = label_img.to(device)
54 | psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img))
55 | down_ratio = max(1, round(min(H, W) / 256))
56 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))),
57 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))),
58 | data_range=1, size_average=False)
59 | print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val))
60 | ssim_adder(ssim_val)
61 |
62 | if args.save_image:
63 | save_name = os.path.join(args.result_dir, name[0])
64 | pred_clip += 0.5 / 255
65 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
66 | pred.save(save_name)
67 |
68 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
69 | psnr_adder(psnr_val)
70 |
71 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed))
72 |
73 | print('==========================================================')
74 | print('The average PSNR is %.2f dB' % (psnr_adder.average()))
75 |     print('The average SSIM is %.4f' % (ssim_adder.average()))
76 |
77 | print("Average time: %f" % adder.average())
78 |
79 |
--------------------------------------------------------------------------------
/Dehazing/OTS/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch.backends import cudnn
5 | from models.OKNet import build_net
6 | from train import _train
7 | from eval import _eval
8 | import numpy as np
9 | import random
10 |
11 | def main(args):
12 | # CUDNN
13 | cudnn.benchmark = True
14 |
15 |     if not os.path.exists('results/'):
16 |         os.makedirs('results/')
17 | if not os.path.exists('results/' + args.model_name + '/'):
18 | os.makedirs('results/' + args.model_name + '/')
19 | if not os.path.exists(args.model_save_dir):
20 | os.makedirs(args.model_save_dir)
21 | if not os.path.exists(args.result_dir):
22 | os.makedirs(args.result_dir)
23 |
24 | model = build_net()
25 | print(model)
26 |
27 | if torch.cuda.is_available():
28 | model.cuda()
29 | if args.mode == 'train':
30 | _train(model, args)
31 |
32 | elif args.mode == 'test':
33 | _eval(model, args)
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser()
38 |
39 | # Directories
40 |     parser.add_argument('--model_name', default='OKNet', type=str)
41 | parser.add_argument('--data_dir', type=str, default='')
42 | parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str)
43 |
44 | # Train
45 | parser.add_argument('--batch_size', type=int, default=8)
46 | parser.add_argument('--learning_rate', type=float, default=1e-4)
47 | parser.add_argument('--weight_decay', type=float, default=0)
48 | parser.add_argument('--num_epoch', type=int, default=30)
49 | parser.add_argument('--print_freq', type=int, default=100)
50 | parser.add_argument('--num_worker', type=int, default=8)
51 | parser.add_argument('--save_freq', type=int, default=1)
52 | parser.add_argument('--valid_freq', type=int, default=1)
53 | parser.add_argument('--resume', type=str, default='')
54 |
55 |
56 | # Test
57 | parser.add_argument('--test_model', type=str, default='')
58 |     parser.add_argument('--save_image', action='store_true')
59 |
60 | args = parser.parse_args()
61 |     args.model_save_dir = os.path.join('results/', args.model_name, 'ots/')
62 | args.result_dir = os.path.join('results/', args.model_name, 'test')
63 | if not os.path.exists(args.model_save_dir):
64 | os.makedirs(args.model_save_dir)
65 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir
66 | os.system(command)
67 | command = 'cp ' + 'models/OKNet.py ' + args.model_save_dir
68 | os.system(command)
69 | command = 'cp ' + 'train.py ' + args.model_save_dir
70 | os.system(command)
71 | command = 'cp ' + 'main.py ' + args.model_save_dir
72 | os.system(command)
73 | print(args)
74 | main(args)
75 |
--------------------------------------------------------------------------------
/Dehazing/OTS/models/OKNet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .layers import *
6 |
7 | class EBlock(nn.Module):
8 | def __init__(self, out_channel, num_res=8):
9 | super(EBlock, self).__init__()
10 |
11 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
12 |
13 | self.layers = nn.Sequential(*layers)
14 |
15 | def forward(self, x):
16 | return self.layers(x)
17 |
18 |
19 | class DBlock(nn.Module):
20 | def __init__(self, channel, num_res=8):
21 | super(DBlock, self).__init__()
22 |
23 | layers = [ResBlock(channel, channel) for _ in range(num_res)]
24 | self.layers = nn.Sequential(*layers)
25 |
26 | def forward(self, x):
27 | return self.layers(x)
28 |
29 |
30 | class SCM(nn.Module):
31 | def __init__(self, out_plane):
32 | super(SCM, self).__init__()
33 | self.main = nn.Sequential(
34 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
35 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
36 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
37 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
38 | nn.InstanceNorm2d(out_plane, affine=True)
39 | )
40 |
41 | def forward(self, x):
42 | x = self.main(x)
43 | return x
44 |
45 |
46 | class Bottleneck(nn.Module):
47 | def __init__(self, dim) -> None:
48 | super().__init__()
49 |
50 | ker = 63
51 | pad = ker // 2
52 | self.in_conv = nn.Sequential(
53 | nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1),
54 | nn.GELU()
55 | )
56 | self.out_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1)
57 | self.dw_13 = nn.Conv2d(dim, dim, kernel_size=(1,ker), padding=(0,pad), stride=1, groups=dim)
58 | self.dw_31 = nn.Conv2d(dim, dim, kernel_size=(ker,1), padding=(pad,0), stride=1, groups=dim)
59 | self.dw_33 = nn.Conv2d(dim, dim, kernel_size=ker, padding=pad, stride=1, groups=dim)
60 | self.dw_11 = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=dim)
61 |
62 | self.act = nn.ReLU()
63 |
64 | ### sca ###
65 | self.conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
66 | self.pool = nn.AdaptiveAvgPool2d((1,1))
67 |
68 | ### fca ###
69 | self.fac_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
70 | self.fac_pool = nn.AdaptiveAvgPool2d((1,1))
71 | self.fgm = FGM(dim)
72 |
73 | def forward(self, x):
74 | out = self.in_conv(x)
75 |
76 | ### fca ###
77 | x_att = self.fac_conv(self.fac_pool(out))
78 | x_fft = torch.fft.fft2(out, norm='backward')
79 | x_fft = x_att * x_fft
80 | x_fca = torch.fft.ifft2(x_fft, dim=(-2,-1), norm='backward')
81 | x_fca = torch.abs(x_fca)
82 |
83 | ### fca ###
84 | ### sca ###
85 | x_att = self.conv(self.pool(x_fca))
86 | x_sca = x_att * x_fca
87 | ### sca ###
88 | x_sca = self.fgm(x_sca)
89 |
90 | out = x + self.dw_13(out) + self.dw_31(out) + self.dw_33(out) + self.dw_11(out) + x_sca
91 | out = self.act(out)
92 | return self.out_conv(out)
93 |
94 | class FGM(nn.Module):
95 | def __init__(self, dim) -> None:
96 | super().__init__()
97 |
98 | self.conv = nn.Conv2d(dim, dim*2, 3, 1, 1)
99 |
100 | self.dwconv1 = nn.Conv2d(dim, dim, 1, 1, groups=1)
101 | self.dwconv2 = nn.Conv2d(dim, dim, 1, 1, groups=1)
102 | self.alpha = nn.Parameter(torch.zeros(dim, 1, 1))
103 | self.beta = nn.Parameter(torch.ones(dim, 1, 1))
104 |
105 | def forward(self, x):
106 | # res = x.clone()
107 | fft_size = x.size()[2:]
108 | x1 = self.dwconv1(x)
109 | x2 = self.dwconv2(x)
110 |
111 | x2_fft = torch.fft.fft2(x2, norm='backward')
112 |
113 |         out = x1 * x2_fft  # real-valued spatial gate modulates the complex spectrum
114 |
115 | out = torch.fft.ifft2(out, dim=(-2,-1), norm='backward')
116 | out = torch.abs(out)
117 |
118 | return out * self.alpha + x * self.beta
119 |
120 |
121 | class FAM(nn.Module):
122 | def __init__(self, channel):
123 | super(FAM, self).__init__()
124 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False)
125 |
126 | def forward(self, x1, x2):
127 | return self.merge(torch.cat([x1, x2], dim=1))
128 |
129 | class OKNet(nn.Module):
130 | def __init__(self, num_res=4):
131 | super(OKNet, self).__init__()
132 |
133 | base_channel = 32
134 |
135 | self.Encoder = nn.ModuleList([
136 | EBlock(base_channel, num_res),
137 | EBlock(base_channel*2, num_res),
138 | EBlock(base_channel*4, num_res),
139 | ])
140 |
141 | self.feat_extract = nn.ModuleList([
142 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
143 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
144 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
145 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
146 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
147 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
148 | ])
149 |
150 | self.Decoder = nn.ModuleList([
151 | DBlock(base_channel * 4, num_res),
152 | DBlock(base_channel * 2, num_res),
153 | DBlock(base_channel, num_res)
154 | ])
155 |
156 | self.Convs = nn.ModuleList([
157 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
158 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
159 | ])
160 |
161 | self.ConvsOut = nn.ModuleList(
162 | [
163 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
164 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
165 | ]
166 | )
167 |
168 | self.FAM1 = FAM(base_channel * 4)
169 | self.SCM1 = SCM(base_channel * 4)
170 | self.FAM2 = FAM(base_channel * 2)
171 | self.SCM2 = SCM(base_channel * 2)
172 |
173 |         self.bottleneck = Bottleneck(base_channel * 4)
174 |
175 |
176 | def forward(self, x):
177 | x_2 = F.interpolate(x, scale_factor=0.5)
178 | x_4 = F.interpolate(x_2, scale_factor=0.5)
179 | z2 = self.SCM2(x_2)
180 | z4 = self.SCM1(x_4)
181 |
182 | outputs = list()
183 | # 256
184 | x_ = self.feat_extract[0](x)
185 | res1 = self.Encoder[0](x_)
186 | # 128
187 | z = self.feat_extract[1](res1)
188 | z = self.FAM2(z, z2)
189 | res2 = self.Encoder[1](z)
190 | # 64
191 | z = self.feat_extract[2](res2)
192 | z = self.FAM1(z, z4)
193 | z = self.Encoder[2](z)
194 |         z = self.bottleneck(z)
195 |
196 | z = self.Decoder[0](z)
197 | z_ = self.ConvsOut[0](z)
198 | # 128
199 | z = self.feat_extract[3](z)
200 | outputs.append(z_+x_4)
201 |
202 | z = torch.cat([z, res2], dim=1)
203 | z = self.Convs[0](z)
204 | z = self.Decoder[1](z)
205 | z_ = self.ConvsOut[1](z)
206 | # 256
207 | z = self.feat_extract[4](z)
208 | outputs.append(z_+x_2)
209 |
210 | z = torch.cat([z, res1], dim=1)
211 | z = self.Convs[1](z)
212 | z = self.Decoder[2](z)
213 | z = self.feat_extract[5](z)
214 | outputs.append(z+x)
215 |
216 | return outputs
217 |
218 | def build_net():
219 | return OKNet()
220 |
221 |
--------------------------------------------------------------------------------
/Dehazing/OTS/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
7 | super(BasicConv, self).__init__()
8 | if bias and norm:
9 | bias = False
10 |
11 | padding = kernel_size // 2
12 | layers = list()
13 | if transpose:
14 |             padding = kernel_size // 2 - 1
15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
16 | else:
17 | layers.append(
18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
19 | if norm:
20 | layers.append(nn.BatchNorm2d(out_channel))
21 | if relu:
22 | layers.append(nn.GELU())
23 | self.main = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.main(x)
27 |
28 |
29 | class ResBlock(nn.Module):
30 | def __init__(self, in_channel, out_channel):
31 | super(ResBlock, self).__init__()
32 | self.main = nn.Sequential(
33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
35 | )
36 |
37 | def forward(self, x):
38 | return self.main(x) + x
--------------------------------------------------------------------------------
/Dehazing/OTS/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data import train_dataloader
4 | from utils import Adder, Timer, check_lr
5 | from torch.utils.tensorboard import SummaryWriter
6 | from valid import _valid
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | from warmup_scheduler import GradualWarmupScheduler
11 |
12 | def _train(model, args):
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | criterion = torch.nn.L1Loss()
15 |
16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
18 | max_iter = len(dataloader)
19 | warmup_epochs=1
20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
22 | scheduler.step()
23 | epoch = 1
24 | if args.resume:
25 | state = torch.load(args.resume)
26 | epoch = state['epoch']
27 | optimizer.load_state_dict(state['optimizer'])
28 | model.load_state_dict(state['model'])
29 | print('Resume from %d'%epoch)
30 | epoch += 1
31 |
32 | writer = SummaryWriter()
33 | epoch_pixel_adder = Adder()
34 | epoch_fft_adder = Adder()
35 | iter_pixel_adder = Adder()
36 | iter_fft_adder = Adder()
37 | epoch_timer = Timer('m')
38 | iter_timer = Timer('m')
39 | best_psnr=-1
40 |
41 | eval_now = max_iter//6-1
42 |
43 | for epoch_idx in range(epoch, args.num_epoch + 1):
44 |
45 | epoch_timer.tic()
46 | iter_timer.tic()
47 | for iter_idx, batch_data in enumerate(dataloader):
48 |
49 | input_img, label_img = batch_data
50 | input_img = input_img.to(device)
51 | label_img = label_img.to(device)
52 |
53 | optimizer.zero_grad()
54 | pred_img = model(input_img)
55 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
56 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')
57 | l1 = criterion(pred_img[0], label_img4)
58 | l2 = criterion(pred_img[1], label_img2)
59 | l3 = criterion(pred_img[2], label_img)
60 | loss_content = l1+l2+l3
61 |
62 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1))
63 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
64 |
65 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1))
66 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
67 |
68 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1))
69 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
70 |
71 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1))
72 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
73 |
74 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
75 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
76 |
77 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1))
78 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
79 |
80 | f1 = criterion(pred_fft1, label_fft1)
81 | f2 = criterion(pred_fft2, label_fft2)
82 | f3 = criterion(pred_fft3, label_fft3)
83 | loss_fft = f1+f2+f3
84 |
85 | loss = loss_content + 0.1 * loss_fft
86 | loss.backward()
87 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
88 | optimizer.step()
89 |
90 | iter_pixel_adder(loss_content.item())
91 | iter_fft_adder(loss_fft.item())
92 |
93 | epoch_pixel_adder(loss_content.item())
94 | epoch_fft_adder(loss_fft.item())
95 |
96 | if (iter_idx + 1) % args.print_freq == 0:
97 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
98 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
99 | iter_fft_adder.average()))
100 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
101 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
102 |
103 | iter_timer.tic()
104 | iter_pixel_adder.reset()
105 | iter_fft_adder.reset()
106 |
107 |
108 | if iter_idx%eval_now==0 and iter_idx>0 and (epoch_idx>20 or epoch_idx == 1):
109 |
110 | save_name = os.path.join(args.model_save_dir, 'model_%d_%d.pkl' % (epoch_idx, iter_idx))
111 | torch.save({'model': model.state_dict()}, save_name)
112 |
113 | val_gopro = _valid(model, args, epoch_idx)
114 | print('%03d epoch \n Average GOPRO PSNR %.2f dB' % (epoch_idx, val_gopro))
115 | writer.add_scalar('PSNR_GOPRO', val_gopro, epoch_idx)
116 |                 if val_gopro >= best_psnr:
117 |                     best_psnr = val_gopro
118 |                     torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
119 |
120 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
121 | torch.save({'model': model.state_dict()}, overwrite_name)
122 |
123 |
124 | if epoch_idx % args.save_freq == 0:
125 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
126 | torch.save({'model': model.state_dict()}, save_name)
127 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
128 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
129 | epoch_fft_adder.reset()
130 | epoch_pixel_adder.reset()
131 | scheduler.step()
132 |
133 | if epoch_idx % args.valid_freq == 0:
134 | val = _valid(model, args, epoch_idx)
135 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val))
136 | writer.add_scalar('PSNR', val, epoch_idx)
137 |             if val >= best_psnr:
138 |                 best_psnr = val
139 |                 torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
140 |     save_name = os.path.join(args.model_save_dir, 'Final.pkl')
141 |     torch.save({'model': model.state_dict()}, save_name)
142 | 
--------------------------------------------------------------------------------
/Dehazing/OTS/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 | def __init__(self):
7 | self.count = 0
8 | self.num = float(0)
9 |
10 | def reset(self):
11 | self.count = 0
12 | self.num = float(0)
13 |
14 | def __call__(self, num):
15 | self.count += 1
16 | self.num += num
17 |
18 | def average(self):
19 | return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, option='s'):
24 | self.tm = 0
25 | self.option = option
26 |         if option == 's':
27 |             self.divider = 1
28 |         elif option == 'm':
29 |             self.divider = 60
30 |         else:
31 |             self.divider = 3600
32 | 
33 |     def tic(self):
34 |         self.tm = time.time()
35 | 
36 |     def toc(self):
37 |         return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 | for i, param_group in enumerate(optimizer.param_groups):
42 | lr = param_group['lr']
43 | return lr
44 |
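45 | 
46 | if __name__ == '__main__':
47 |     # Quick illustrative self-test, not used by training: Adder keeps a running
48 |     # mean of the values fed to it; Timer measures wall-clock spans in s/m/h.
49 |     a = Adder()
50 |     for v in (1.0, 2.0, 3.0):
51 |         a(v)
52 |     print(a.average())  # 2.0
53 |     t = Timer('s')
54 |     t.tic()
55 |     time.sleep(0.1)
56 |     print('%.2f s elapsed' % t.toc())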
--------------------------------------------------------------------------------
/Dehazing/OTS/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | ots = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
13 | model.eval()
14 | psnr_adder = Adder()
15 |
16 | with torch.no_grad():
17 | print('Start Evaluation')
18 | factor = 4
19 | for idx, data in enumerate(ots):
20 | input_img, label_img = data
21 | input_img = input_img.to(device)
22 |
23 | h, w = input_img.shape[2], input_img.shape[3]
24 |             H, W = ((h + factor - 1) // factor) * factor, ((w + factor - 1) // factor) * factor  # round h, w up to multiples of factor
25 | padh = H-h if h%factor!=0 else 0
26 | padw = W-w if w%factor!=0 else 0
27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 | pred = model(input_img)[2]
33 | pred = pred[:,:,:h,:w]
34 |
35 | pred_clip = torch.clamp(pred, 0, 1)
36 | p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 | label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 | psnr_adder(psnr)
42 | print('\r%03d'%idx, end=' ')
43 |
44 | print('\n')
45 | model.train()
46 | return psnr_adder.average()
47 |
--------------------------------------------------------------------------------
/Dehazing/README.md:
--------------------------------------------------------------------------------
1 | ### Download the Datasets
2 | - reside-indoor [[gdrive](https://drive.google.com/drive/folders/1pbtfTp29j7Ip-mRzDpMpyopCfXd-ZJhC?usp=sharing), [Baidu](https://pan.baidu.com/s/1jD-TU0wdtSoEb4ki-Cut2A?pwd=1lr0)]
3 | - reside-outdoor [[gdrive](https://drive.google.com/drive/folders/1eL4Qs-WNj7PzsKwDRsgUEzmysdjkRs22?usp=sharing)]
4 | - (Separate SOTS test set if needed) [[gdrive](https://drive.google.com/file/d/16j2dwVIa9q_0RtpIXMzhu-7Q6dwz_D1N/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1R6qWri7sG1hC_Ifj-H6DOQ?pwd=o5sk)]
5 | ### Train on RESIDE-Indoor
6 |
7 | ~~~
8 | cd ITS
9 | python main.py --mode train --data_dir your_path/reside-indoor
10 | ~~~
11 |
12 |
13 | ### Train on RESIDE-Outdoor
14 | ~~~
15 | cd OTS
16 | python main.py --mode train --data_dir your_path/reside-outdoor
17 | ~~~
18 |
19 |
20 | ### Evaluation
21 | #### Download the model [here](https://drive.google.com/drive/folders/1jrqqTBFHi2XNvBb9-rm9n0fZHfYZabFw?usp=sharing)
22 | #### Testing on SOTS-Indoor
23 | ~~~
24 | cd ITS
25 | python main.py --data_dir your_path/reside-indoor --test_model path_to_its_model
26 | ~~~
27 | #### Testing on SOTS-Outdoor
28 | ~~~
29 | cd OTS
30 | python main.py --data_dir your_path/reside-outdoor --test_model path_to_ots_model
31 | ~~~
32 |
33 | For training and testing, your directory structure should look like this:
34 |
35 | `Your path`
36 | `├──reside-indoor`
37 | `├──train`
38 | `├──gt`
39 | `└──hazy`
40 | `└──test`
41 | `├──gt`
42 | `└──hazy`
43 | `└──reside-outdoor`
44 | `├──train`
45 | `├──gt`
46 | `└──hazy`
47 | `└──test`
48 | `├──gt`
49 | `└──hazy`
50 |
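51 | ### Single-image inference (optional)
52 | A minimal sketch for restoring one image with a trained checkpoint, assuming it is run from inside `ITS/` or `OTS/`; the checkpoint format `{'model': state_dict}` matches what `train.py` saves, and `hazy.png`/`dehazed.png` are placeholder file names.
53 | ~~~
54 | import torch
55 | import torch.nn.functional as f
56 | from torchvision.transforms import functional as TF
57 | from PIL import Image
58 | from models.OKNet import build_net
59 | 
60 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
61 | model = build_net().to(device)
62 | model.load_state_dict(torch.load('path_to_model.pkl', map_location=device)['model'])
63 | model.eval()
64 | 
65 | img = TF.to_tensor(Image.open('hazy.png').convert('RGB')).unsqueeze(0).to(device)
66 | h, w = img.shape[2], img.shape[3]
67 | # pad to a multiple of 4 with reflection, as valid.py does, so both downsampling stages divide evenly
68 | padh, padw = (4 - h % 4) % 4, (4 - w % 4) % 4
69 | img = f.pad(img, (0, padw, 0, padh), 'reflect')
70 | with torch.no_grad():
71 |     pred = model(img)[2][:, :, :h, :w].clamp(0, 1)  # index 2 is the full-resolution output
72 | TF.to_pil_image(pred.squeeze(0).cpu()).save('dehazed.png')
73 | ~~~
74 | 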
--------------------------------------------------------------------------------
/Desnowing/README.md:
--------------------------------------------------------------------------------
1 | ### Download the Datasets
2 | - SRRS [[gdrive](https://drive.google.com/file/d/11h1cZ0NXx6ev35cl5NKOAL3PCgLlWUl2/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1VXqsamkl12fPsI1Qek97TQ?pwd=vcfg)]
3 | - CSD [[gdrive](https://drive.google.com/file/d/1pns-7uWy-0SamxjA40qOCkkhSu7o7ULb/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1N52Jnx0co9udJeYrbd3blA?pwd=sb4a)]
4 | - Snow100K [[gdrive](https://drive.google.com/file/d/19zJs0cJ6F3G3IlDHLU2BO7nHnCTMNrIS/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1QGd5z9uM6vBKPnD5d7jQmA?pwd=aph4)]
5 |
6 | ### Training
7 |
8 | ~~~
9 | python main.py --mode train --data_dir your_path/CSD
10 | ~~~
11 |
12 | ### Evaluation
13 | #### Download the model [here](https://drive.google.com/drive/folders/1jrqqTBFHi2XNvBb9-rm9n0fZHfYZabFw?usp=sharing)
14 | #### Testing
15 | ~~~
16 | python main.py --data_dir your_path/CSD
17 | ~~~
18 |
19 | For training and testing, your directory structure should look like this:
20 |
21 | `Your path`
22 | `├──CSD`
23 | `├──train2500`
24 | `├──Gt`
25 | `└──Snow`
26 | `└──test2000`
27 | `├──Gt`
28 | `└──Snow`
29 | `├──SRRS`
30 | `├──train2500`
31 | `├──Gt`
32 | `└──Snow`
33 | `└──test2000`
34 | `├──Gt`
35 | `└──Snow`
36 | `└──Snow100K`
37 | `├──train2500`
38 | `├──Gt`
39 | `└──Snow`
40 | `└──test2000`
41 | `├──Gt`
42 | `└──Snow`
43 |
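44 | To catch path mistakes early, the layout above can be verified with a small convenience sketch (`your_path` is a placeholder):
45 | ~~~
46 | import os
47 | 
48 | root = 'your_path'  # adjust to your dataset root
49 | for dataset in ('CSD', 'SRRS', 'Snow100K'):
50 |     for split in ('train2500', 'test2000'):
51 |         for sub in ('Gt', 'Snow'):
52 |             p = os.path.join(root, dataset, split, sub)
53 |             print(p, 'OK' if os.path.isdir(p) else 'MISSING')
54 | ~~~
55 | 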
--------------------------------------------------------------------------------
/Desnowing/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFlip, PairToTensor
2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader
3 |
--------------------------------------------------------------------------------
/Desnowing/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class PairRandomCrop(transforms.RandomCrop):
7 |
8 | def __call__(self, image, label):
9 |
10 | if self.padding is not None:
11 | image = F.pad(image, self.padding, self.fill, self.padding_mode)
12 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
13 |
14 | # pad the width if needed
15 | if self.pad_if_needed and image.size[0] < self.size[1]:
16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
18 | # pad the height if needed
19 | if self.pad_if_needed and image.size[1] < self.size[0]:
20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
21 |             label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)
22 |
23 | i, j, h, w = self.get_params(image, self.size)
24 |
25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)
26 |
27 |
28 | class PairCompose(transforms.Compose):
29 | def __call__(self, image, label):
30 | for t in self.transforms:
31 | image, label = t(image, label)
32 | return image, label
33 |
34 |
35 | class PairRandomHorizontalFlip(transforms.RandomHorizontalFlip):
36 | def __call__(self, img, label):
37 | """
38 | Args:
39 | img (PIL Image): Image to be flipped.
40 |
41 | Returns:
42 | PIL Image: Randomly flipped image.
43 | """
44 | if random.random() < self.p:
45 | return F.hflip(img), F.hflip(label)
46 | return img, label
47 |
48 |
49 | class PairToTensor(transforms.ToTensor):
50 | def __call__(self, pic, label):
51 | """
52 | Args:
53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
54 |
55 | Returns:
56 | Tensor: Converted image.
57 | """
58 | return F.to_tensor(pic), F.to_tensor(label)
59 |
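60 | 
61 | if __name__ == '__main__':
62 |     # Illustrative self-test with synthetic PIL images, not part of the training
63 |     # pipeline: the paired transforms apply the *same* random crop and flip to
64 |     # the degraded image and its ground truth, mirroring train_dataloader.
65 |     from PIL import Image
66 |     transform = PairCompose([PairRandomCrop(256), PairRandomHorizontalFlip(), PairToTensor()])
67 |     image = Image.new('RGB', (480, 320), 'white')
68 |     label = Image.new('RGB', (480, 320), 'black')
69 |     x, y = transform(image, label)
70 |     print(x.shape, y.shape)  # torch.Size([3, 256, 256]) for both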
--------------------------------------------------------------------------------
/Desnowing/data/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image as Image
5 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor
6 | from torchvision.transforms import functional as F
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 |
10 | def train_dataloader(path, batch_size=64, num_workers=0, use_transform=True):
11 | image_dir = os.path.join(path, 'train2500')
12 |
13 | transform = None
14 | if use_transform:
15 | transform = PairCompose(
16 | [
17 | PairRandomCrop(256),
18 |                 PairRandomHorizontalFlip(),
19 | PairToTensor()
20 | ]
21 | )
22 | dataloader = DataLoader(
23 | DeblurDataset(image_dir, transform=transform),
24 | batch_size=batch_size,
25 | shuffle=True,
26 | num_workers=num_workers,
27 | pin_memory=True
28 | )
29 | return dataloader
30 |
31 |
32 | def test_dataloader(path, batch_size=1, num_workers=0):
33 | image_dir = os.path.join(path, 'test2000')
34 | dataloader = DataLoader(
35 | DeblurDataset(image_dir, is_test=True),
36 | batch_size=batch_size,
37 | shuffle=False,
38 | num_workers=num_workers,
39 | pin_memory=True
40 | )
41 |
42 | return dataloader
43 |
44 |
45 | def valid_dataloader(path, batch_size=1, num_workers=0):
46 | dataloader = DataLoader(
47 | DeblurDataset(os.path.join(path, 'test2000')),
48 | batch_size=batch_size,
49 | shuffle=False,
50 | num_workers=num_workers
51 | )
52 |
53 | return dataloader
54 |
55 |
56 | class DeblurDataset(Dataset):
57 | def __init__(self, image_dir, transform=None, is_test=False):
58 | self.image_dir = image_dir
59 | self.image_list = os.listdir(os.path.join(image_dir, 'Snow/'))
60 | # self._check_image(self.image_list)
61 | self.image_list.sort()
62 | self.transform = transform
63 | self.is_test = is_test
64 |
65 | def __len__(self):
66 | return len(self.image_list)
67 |
68 | def __getitem__(self, idx):
69 | image = Image.open(os.path.join(self.image_dir, 'Snow', self.image_list[idx]))
70 | # label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx].split('.')[0]+'.jpg'))#srrs+jpg
71 | label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx]))
72 |
73 | if self.transform:
74 | image, label = self.transform(image, label)
75 | else:
76 | image = F.to_tensor(image)
77 | label = F.to_tensor(label)
78 | if self.is_test:
79 | name = self.image_list[idx]
80 | return image, label, name
81 | return image, label
82 |
83 |
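84 | if __name__ == '__main__':
85 |     # Smoke test; assumes 'your_path/CSD' follows the layout in the README.
86 |     # Each batch is a pair of aligned 256x256 crops: snowy input and clean Gt.
87 |     loader = train_dataloader('your_path/CSD', batch_size=2, num_workers=0)
88 |     images, labels = next(iter(loader))
89 |     print(images.shape, labels.shape)  # torch.Size([2, 3, 256, 256]) twice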
--------------------------------------------------------------------------------
/Desnowing/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.transforms import functional as F
4 | import numpy as np
5 | from utils import Adder
6 | from data import test_dataloader
7 | from skimage.metrics import peak_signal_noise_ratio
8 | import time
9 | from pytorch_msssim import ssim
10 | import torch.nn.functional as f
11 |
12 | from skimage import img_as_ubyte
13 | import cv2
14 |
15 | def _eval(model, args):
16 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 |     state_dict = torch.load(args.test_model, map_location=device)
18 |     model.load_state_dict(state_dict['model'])
19 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0)
20 | torch.cuda.empty_cache()
21 | adder = Adder()
22 | model.eval()
23 | factor = 4
24 | with torch.no_grad():
25 | psnr_adder = Adder()
26 | ssim_adder = Adder()
27 |
28 | for iter_idx, data in enumerate(dataloader):
29 | input_img, label_img, name = data
30 |
31 | input_img = input_img.to(device)
32 |
33 | h, w = input_img.shape[2], input_img.shape[3]
34 |             H, W = ((h + factor - 1) // factor) * factor, ((w + factor - 1) // factor) * factor  # round h, w up to multiples of factor
35 | padh = H-h if h%factor!=0 else 0
36 | padw = W-w if w%factor!=0 else 0
37 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
38 |
39 | tm = time.time()
40 |
41 | pred = model(input_img)[2]
42 | pred = pred[:,:,:h,:w]
43 |
44 | elapsed = time.time() - tm
45 | adder(elapsed)
46 |
47 | pred_clip = torch.clamp(pred, 0, 1)
48 |
49 | pred_numpy = pred_clip.squeeze(0).cpu().numpy()
50 | label_numpy = label_img.squeeze(0).cpu().numpy()
51 |
52 |
53 |             label_img = label_img.to(device)
54 |             psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img))  # PSNR = 10*log10(MAX^2/MSE) with MAX = 1
55 | down_ratio = max(1, round(min(H, W) / 256))
56 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))),
57 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))),
58 | data_range=1, size_average=False)
59 |             print('%d iter PSNR_desnowing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val))
60 | ssim_adder(ssim_val)
61 |
62 | if args.save_image:
63 | save_name = os.path.join(args.result_dir, name[0])
64 | pred_clip += 0.5 / 255
65 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
66 | pred.save(save_name)
67 |
68 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
69 | psnr_adder(psnr_val)
70 |
71 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed))
72 |
73 | print('==========================================================')
74 | print('The average PSNR is %.2f dB' % (psnr_adder.average()))
75 |     print('The average SSIM is %.5f' % (ssim_adder.average()))
76 |
77 | print("Average time: %f" % adder.average())
78 |
79 |
--------------------------------------------------------------------------------
/Desnowing/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch.backends import cudnn
5 | from models.OKNet import build_net
6 | from train import _train
7 | from eval import _eval
8 | import numpy as np
9 | import random
10 |
11 | def main(args):
12 | # CUDNN
13 | cudnn.benchmark = True
14 |
15 |     if not os.path.exists('results/'):
16 |         os.makedirs('results/')
17 | if not os.path.exists('results/' + args.model_name + '/'):
18 | os.makedirs('results/' + args.model_name + '/')
19 | if not os.path.exists(args.model_save_dir):
20 | os.makedirs(args.model_save_dir)
21 | if not os.path.exists(args.result_dir):
22 | os.makedirs(args.result_dir)
23 |
24 | model = build_net()
25 | print(model)
26 |
27 | if torch.cuda.is_available():
28 | model.cuda()
29 | if args.mode == 'train':
30 | _train(model, args)
31 |
32 | elif args.mode == 'test':
33 | _eval(model, args)
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser()
38 |
39 | # Directories
40 | parser.add_argument('--model_name', default='OKNet', type=str)
41 | parser.add_argument('--data_dir', type=str, default='')
42 | parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str)
43 |
44 | # Train
45 | parser.add_argument('--batch_size', type=int, default=8)
46 | parser.add_argument('--learning_rate', type=float, default=2e-4)
47 | parser.add_argument('--weight_decay', type=float, default=0)
48 | parser.add_argument('--num_epoch', type=int, default=2000)
49 | parser.add_argument('--print_freq', type=int, default=100)
50 | parser.add_argument('--num_worker', type=int, default=16)
51 | parser.add_argument('--save_freq', type=int, default=50)
52 | parser.add_argument('--valid_freq', type=int, default=50)
53 | parser.add_argument('--resume', type=str, default='')
54 |
55 | # Test
56 | parser.add_argument('--test_model', type=str, default='')
57 |     parser.add_argument('--save_image', action='store_true')  # pass --save_image to write restored images to result_dir
58 |
59 | args = parser.parse_args()
60 |     args.model_save_dir = os.path.join('results/', args.model_name, 'CSD/')
61 | args.result_dir = os.path.join('results/', args.model_name, 'test')
62 | if not os.path.exists(args.model_save_dir):
63 | os.makedirs(args.model_save_dir)
64 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir
65 | os.system(command)
66 | command = 'cp ' + 'models/OKNet.py ' + args.model_save_dir
67 | os.system(command)
68 | command = 'cp ' + 'train.py ' + args.model_save_dir
69 | os.system(command)
70 | command = 'cp ' + 'main.py ' + args.model_save_dir
71 | os.system(command)
72 | print(args)
73 | main(args)
74 |
--------------------------------------------------------------------------------
/Desnowing/models/OKNet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .layers import *
6 |
7 | class EBlock(nn.Module):
8 | def __init__(self, out_channel, num_res=8):
9 | super(EBlock, self).__init__()
10 |
11 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
12 |
13 | self.layers = nn.Sequential(*layers)
14 |
15 | def forward(self, x):
16 | return self.layers(x)
17 |
18 |
19 | class DBlock(nn.Module):
20 | def __init__(self, channel, num_res=8):
21 | super(DBlock, self).__init__()
22 |
23 | layers = [ResBlock(channel, channel) for _ in range(num_res)]
24 | self.layers = nn.Sequential(*layers)
25 |
26 | def forward(self, x):
27 | return self.layers(x)
28 |
29 |
30 | class SCM(nn.Module):
31 | def __init__(self, out_plane):
32 | super(SCM, self).__init__()
33 | self.main = nn.Sequential(
34 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
35 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
36 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
37 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
38 | nn.InstanceNorm2d(out_plane, affine=True)
39 | )
40 |
41 | def forward(self, x):
42 | x = self.main(x)
43 | return x
44 |
45 |
46 | class BottleNeck(nn.Module):
47 | def __init__(self, dim) -> None:
48 | super().__init__()
49 |
50 | ker = 63
51 | pad = ker // 2
52 | self.in_conv = nn.Sequential(
53 | nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1),
54 | nn.GELU()
55 | )
56 | self.out_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1)
57 | self.dw_13 = nn.Conv2d(dim, dim, kernel_size=(1,ker), padding=(0,pad), stride=1, groups=dim)
58 | self.dw_31 = nn.Conv2d(dim, dim, kernel_size=(ker,1), padding=(pad,0), stride=1, groups=dim)
59 | self.dw_33 = nn.Conv2d(dim, dim, kernel_size=ker, padding=pad, stride=1, groups=dim)
60 | self.dw_11 = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=dim)
61 |
62 | self.act = nn.ReLU()
63 |
64 | ### sca ###
65 | self.conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
66 | self.pool = nn.AdaptiveAvgPool2d((1,1))
67 |
68 | ### fca ###
69 | self.fac_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
70 | self.fac_pool = nn.AdaptiveAvgPool2d((1,1))
71 | self.fgm = FGM(dim)
72 |
73 | def forward(self, x):
74 | out = self.in_conv(x)
75 |
76 | ### fca ###
77 | x_att = self.fac_conv(self.fac_pool(out))
78 | x_fft = torch.fft.fft2(out, norm='backward')
79 | x_fft = x_att * x_fft
80 | x_fca = torch.fft.ifft2(x_fft, dim=(-2,-1), norm='backward')
81 | x_fca = torch.abs(x_fca)
82 |
83 | ### fca ###
84 | ### sca ###
85 | x_att = self.conv(self.pool(x_fca))
86 | x_sca = x_att * x_fca
87 | ### sca ###
88 | x_sca = self.fgm(x_sca)
89 |
90 | out = x + self.dw_13(out) + self.dw_31(out) + self.dw_33(out) + self.dw_11(out) + x_sca
91 | out = self.act(out)
92 | return self.out_conv(out)
93 |
94 | class FGM(nn.Module):
95 | def __init__(self, dim) -> None:
96 | super().__init__()
97 |
98 | self.conv = nn.Conv2d(dim, dim*2, 3, 1, 1)
99 |
100 | self.dwconv1 = nn.Conv2d(dim, dim, 1, 1, groups=1)
101 | self.dwconv2 = nn.Conv2d(dim, dim, 1, 1, groups=1)
102 | self.alpha = nn.Parameter(torch.zeros(dim, 1, 1))
103 | self.beta = nn.Parameter(torch.ones(dim, 1, 1))
104 |
105 | def forward(self, x):
106 | # res = x.clone()
107 | fft_size = x.size()[2:]
108 | x1 = self.dwconv1(x)
109 | x2 = self.dwconv2(x)
110 |
111 | x2_fft = torch.fft.fft2(x2, norm='backward')
112 |
113 | out = x1 * x2_fft
114 |
115 | out = torch.fft.ifft2(out, dim=(-2,-1), norm='backward')
116 | out = torch.abs(out)
117 |
118 | return out * self.alpha + x * self.beta
119 |
120 |
121 | class FAM(nn.Module):
122 | def __init__(self, channel):
123 | super(FAM, self).__init__()
124 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False)
125 |
126 | def forward(self, x1, x2):
127 | return self.merge(torch.cat([x1, x2], dim=1))
128 |
129 | class OKNet(nn.Module):
130 | def __init__(self, num_res=4):
131 | super(OKNet, self).__init__()
132 |
133 | base_channel = 32
134 |
135 | self.Encoder = nn.ModuleList([
136 | EBlock(base_channel, num_res),
137 | EBlock(base_channel*2, num_res),
138 | EBlock(base_channel*4, num_res),
139 | ])
140 |
141 | self.feat_extract = nn.ModuleList([
142 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
143 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
144 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
145 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
146 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
147 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
148 | ])
149 |
150 | self.Decoder = nn.ModuleList([
151 | DBlock(base_channel * 4, num_res),
152 | DBlock(base_channel * 2, num_res),
153 | DBlock(base_channel, num_res)
154 | ])
155 |
156 | self.Convs = nn.ModuleList([
157 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
158 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
159 | ])
160 |
161 | self.ConvsOut = nn.ModuleList(
162 | [
163 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
164 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
165 | ]
166 | )
167 |
168 | self.FAM1 = FAM(base_channel * 4)
169 | self.SCM1 = SCM(base_channel * 4)
170 | self.FAM2 = FAM(base_channel * 2)
171 | self.SCM2 = SCM(base_channel * 2)
172 |
173 |         self.bottleneck = BottleNeck(base_channel * 4)
174 |
175 |
176 | def forward(self, x):
177 | x_2 = F.interpolate(x, scale_factor=0.5)
178 | x_4 = F.interpolate(x_2, scale_factor=0.5)
179 | z2 = self.SCM2(x_2)
180 | z4 = self.SCM1(x_4)
181 |
182 | outputs = list()
183 | # 256
184 | x_ = self.feat_extract[0](x)
185 | res1 = self.Encoder[0](x_)
186 | # 128
187 | z = self.feat_extract[1](res1)
188 | z = self.FAM2(z, z2)
189 | res2 = self.Encoder[1](z)
190 | # 64
191 | z = self.feat_extract[2](res2)
192 | z = self.FAM1(z, z4)
193 | z = self.Encoder[2](z)
194 |         z = self.bottleneck(z)
195 |
196 | z = self.Decoder[0](z)
197 | z_ = self.ConvsOut[0](z)
198 | # 128
199 | z = self.feat_extract[3](z)
200 | outputs.append(z_+x_4)
201 |
202 | z = torch.cat([z, res2], dim=1)
203 | z = self.Convs[0](z)
204 | z = self.Decoder[1](z)
205 | z_ = self.ConvsOut[1](z)
206 | # 256
207 | z = self.feat_extract[4](z)
208 | outputs.append(z_+x_2)
209 |
210 | z = torch.cat([z, res1], dim=1)
211 | z = self.Convs[1](z)
212 | z = self.Decoder[2](z)
213 | z = self.feat_extract[5](z)
214 | outputs.append(z+x)
215 |
216 | return outputs
217 |
218 | def build_net():
219 | return OKNet()
220 |
221 |
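222 | if __name__ == '__main__':
223 |     # Shape sanity check (illustrative; run from the task root as
224 |     # `python -m models.OKNet` so the relative import resolves): OKNet returns
225 |     # three restored images at 1/4, 1/2 and full resolution, each predicted as
226 |     # a residual over the correspondingly downsampled input.
227 |     with torch.no_grad():
228 |         outs = build_net()(torch.randn(1, 3, 256, 256))
229 |     print([tuple(o.shape) for o in outs])  # [(1, 3, 64, 64), (1, 3, 128, 128), (1, 3, 256, 256)]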
--------------------------------------------------------------------------------
/Desnowing/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
7 | super(BasicConv, self).__init__()
8 | if bias and norm:
9 | bias = False
10 |
11 | padding = kernel_size // 2
12 | layers = list()
13 | if transpose:
14 |             padding = kernel_size // 2 - 1
15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
16 | else:
17 | layers.append(
18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
19 | if norm:
20 | layers.append(nn.BatchNorm2d(out_channel))
21 | if relu:
22 | layers.append(nn.GELU())
23 | self.main = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.main(x)
27 |
28 |
29 | class ResBlock(nn.Module):
30 | def __init__(self, in_channel, out_channel):
31 | super(ResBlock, self).__init__()
32 | self.main = nn.Sequential(
33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
35 | )
36 |
37 | def forward(self, x):
38 | return self.main(x) + x
--------------------------------------------------------------------------------
/Desnowing/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data import train_dataloader
4 | from utils import Adder, Timer, check_lr
5 | from torch.utils.tensorboard import SummaryWriter
6 | from valid import _valid
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | from warmup_scheduler import GradualWarmupScheduler
11 |
12 | def _train(model, args):
13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14 | criterion = torch.nn.L1Loss()
15 |
16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
18 | max_iter = len(dataloader)
19 | warmup_epochs=3
20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
22 | scheduler.step()
23 | epoch = 1
24 | if args.resume:
25 | state = torch.load(args.resume)
26 | epoch = state['epoch']
27 | optimizer.load_state_dict(state['optimizer'])
28 | model.load_state_dict(state['model'])
29 | print('Resume from %d'%epoch)
30 | epoch += 1
31 |
32 | writer = SummaryWriter()
33 | epoch_pixel_adder = Adder()
34 | epoch_fft_adder = Adder()
35 | iter_pixel_adder = Adder()
36 | iter_fft_adder = Adder()
37 | epoch_timer = Timer('m')
38 | iter_timer = Timer('m')
39 | best_psnr=-1
40 |
41 | for epoch_idx in range(epoch, args.num_epoch + 1):
42 |
43 | epoch_timer.tic()
44 | iter_timer.tic()
45 | for iter_idx, batch_data in enumerate(dataloader):
46 |
47 | input_img, label_img = batch_data
48 | input_img = input_img.to(device)
49 | label_img = label_img.to(device)
50 |
51 | optimizer.zero_grad()
52 | pred_img = model(input_img)
53 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
54 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')
55 | l1 = criterion(pred_img[0], label_img4)
56 | l2 = criterion(pred_img[1], label_img2)
57 | l3 = criterion(pred_img[2], label_img)
58 | loss_content = l1+l2+l3
59 |
60 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1))
61 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
62 |
63 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1))
64 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
65 |
66 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1))
67 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
68 |
69 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1))
70 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
71 |
72 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
73 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
74 |
75 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1))
76 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
77 |
78 | f1 = criterion(pred_fft1, label_fft1)
79 | f2 = criterion(pred_fft2, label_fft2)
80 | f3 = criterion(pred_fft3, label_fft3)
81 | loss_fft = f1+f2+f3
82 |
83 |             loss = loss_content + 0.1 * loss_fft  # total objective: multi-scale L1 plus 0.1 x frequency-domain L1
84 | loss.backward()
85 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001)
86 | optimizer.step()
87 |
88 | iter_pixel_adder(loss_content.item())
89 | iter_fft_adder(loss_fft.item())
90 |
91 | epoch_pixel_adder(loss_content.item())
92 | epoch_fft_adder(loss_fft.item())
93 |
94 | if (iter_idx + 1) % args.print_freq == 0:
95 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
96 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
97 | iter_fft_adder.average()))
98 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
99 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
100 |
101 | iter_timer.tic()
102 | iter_pixel_adder.reset()
103 | iter_fft_adder.reset()
104 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
105 | torch.save({'model': model.state_dict(),
106 | 'optimizer': optimizer.state_dict(),
107 | 'epoch': epoch_idx}, overwrite_name)
108 |
109 | if epoch_idx % args.save_freq == 0:
110 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
111 | torch.save({'model': model.state_dict()}, save_name)
112 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
113 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
114 | epoch_fft_adder.reset()
115 | epoch_pixel_adder.reset()
116 | scheduler.step()
117 | if epoch_idx % args.valid_freq == 0:
118 | val = _valid(model, args, epoch_idx)
119 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val))
120 | writer.add_scalar('PSNR', val, epoch_idx)
121 |             if val >= best_psnr:
122 |                 best_psnr = val
123 |                 torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
124 |     save_name = os.path.join(args.model_save_dir, 'Final.pkl')
125 |     torch.save({'model': model.state_dict()}, save_name)
126 | 
--------------------------------------------------------------------------------
/Desnowing/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 | def __init__(self):
7 | self.count = 0
8 | self.num = float(0)
9 |
10 | def reset(self):
11 | self.count = 0
12 | self.num = float(0)
13 |
14 | def __call__(self, num):
15 | self.count += 1
16 | self.num += num
17 |
18 | def average(self):
19 | return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, option='s'):
24 | self.tm = 0
25 | self.option = option
26 |         if option == 's':
27 |             self.divider = 1
28 |         elif option == 'm':
29 |             self.divider = 60
30 |         else:
31 |             self.divider = 3600
32 | 
33 |     def tic(self):
34 |         self.tm = time.time()
35 | 
36 |     def toc(self):
37 |         return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 | for i, param_group in enumerate(optimizer.param_groups):
42 | lr = param_group['lr']
43 | return lr
44 |
--------------------------------------------------------------------------------
/Desnowing/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | snow = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
13 | model.eval()
14 | psnr_adder = Adder()
15 |
16 | with torch.no_grad():
17 | print('Start Evaluation')
18 | factor = 4
19 | for idx, data in enumerate(snow):
20 | input_img, label_img = data
21 | input_img = input_img.to(device)
22 |
23 | h, w = input_img.shape[2], input_img.shape[3]
24 |             H, W = ((h + factor - 1) // factor) * factor, ((w + factor - 1) // factor) * factor  # round h, w up to multiples of factor
25 | padh = H-h if h%factor!=0 else 0
26 | padw = W-w if w%factor!=0 else 0
27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 | pred = model(input_img)[2]
33 | pred = pred[:,:,:h,:w]
34 |
35 | pred_clip = torch.clamp(pred, 0, 1)
36 | p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 | label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 | psnr_adder(psnr)
42 | print('\r%03d'%idx, end=' ')
43 |
44 | print('\n')
45 | model.train()
46 | return psnr_adder.average()
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Yuning Cui
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Omni-Kernel Network for Image Restoration [AAAI-24]
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## Installation
13 | The project is built with Python 3.8, PyTorch 1.8.1, CUDA 10.2, and cuDNN 7.6.5.
14 | To install, follow these instructions:
15 | ~~~
16 | conda install pytorch=1.8.1 torchvision=0.9.1 -c pytorch
17 | pip install tensorboard einops scikit-image pytorch_msssim opencv-python
18 | ~~~
19 | Install warmup scheduler:
20 | ~~~
21 | cd pytorch-gradual-warmup-lr/
22 | python setup.py install
23 | cd ..
24 | ~~~
25 |
26 | Please install the pillow package with Conda instead of pip.
27 |
28 |
29 |
30 | ITS FLOPs: 39.67G, Params: 4.72M
31 | Training and testing details can be found in the individual directories.
32 |
33 | ## [Models](https://drive.google.com/drive/folders/1jrqqTBFHi2XNvBb9-rm9n0fZHfYZabFw?usp=sharing)
34 | ## [Images](https://drive.google.com/drive/folders/1FuaHw5Wr9PTSKAKn2qEuE-IN1Ye-O8Xj?usp=sharing)
35 |
36 |
37 |
38 | ## Citation
39 | ~~~
40 | @inproceedings{cui2024omni,
41 | title={Omni-Kernel Network for Image Restoration},
42 | author={Cui, Yuning and Ren, Wenqi and Knoll, Alois},
43 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
44 | volume={38},
45 | number={2},
46 | pages={1426--1434},
47 | year={2024}
48 | }
49 | ~~~
50 |
51 |
52 | ## Contact
53 | Should you have any question, please contact Yuning Cui.
54 |
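55 | ## Parameter Count
56 | The Params figure quoted above can be cross-checked with a few lines of PyTorch. This is a sketch; run it from one of the task directories (e.g. `Desnowing/`) so that `models/` is importable. The FLOPs figure additionally depends on the input resolution and needs a profiler such as ptflops, which is not bundled here.
57 | ~~~
58 | from models.OKNet import build_net
59 | 
60 | model = build_net()
61 | n_params = sum(p.numel() for p in model.parameters())
62 | print('Params: %.2fM' % (n_params / 1e6))
63 | ~~~
64 | 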
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import setuptools
6 |
7 | _VERSION = '0.3'
8 |
9 | REQUIRED_PACKAGES = [
10 | ]
11 |
12 | DEPENDENCY_LINKS = [
13 | ]
14 |
15 | setuptools.setup(
16 | name='warmup_scheduler',
17 | version=_VERSION,
18 | description='Gradually Warm-up LR Scheduler for Pytorch',
19 | install_requires=REQUIRED_PACKAGES,
20 | dependency_links=DEPENDENCY_LINKS,
21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr',
22 | license='MIT License',
23 | package_dir={},
24 | packages=setuptools.find_packages(exclude=['tests']),
25 | )
26 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 |     # scheduler_warmup is chained with scheduler_steplr
13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 | # this zero gradient update is needed to avoid a warning message, issue #8.
17 | optim.zero_grad()
18 | optim.step()
19 |
20 | for epoch in range(1, 20):
21 | scheduler_warmup.step(epoch)
22 | print(epoch, optim.param_groups[0]['lr'])
23 |
24 | optim.step() # backward pass (update network)
25 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 | """ Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier == 1.0, the lr starts from 0 and ramps up to base_lr.
12 |         total_epoch: the target learning rate is reached gradually at total_epoch
13 |         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 | else:
48 | if epoch is None:
49 | self.after_scheduler.step(metrics, None)
50 | else:
51 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 | if type(self.after_scheduler) != ReduceLROnPlateau:
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
64 |
--------------------------------------------------------------------------------