├── loss
│   └── __init__.py
├── models
│   ├── __init__.py
│   ├── SUWNet.py
│   ├── UWCNN.py
│   ├── BOths.py
│   ├── DICAM.py
│   ├── ASNet.py
│   ├── WaterNet.py
│   ├── CLUIE-Net.py
│   ├── USUIR.py
│   ├── NU2Net.py
│   ├── LiteEnhanceNet.py
│   ├── P2CNet.py
│   ├── IACC.py
│   ├── LANet.py
│   ├── ADMNNet.py
│   ├── UIEC2Net.py
│   ├── Deep-WaveNet.py
│   ├── AoSRNet.py
│   ├── UIEPTA.py
│   ├── UIETPA.py
│   ├── SCNet.py
│   ├── TUDA.py
│   ├── FIVE_APLUS.py
│   ├── todo_SGUIE-Net.py
│   ├── UIALN.py
│   ├── RauneNet.py
│   ├── CCMSRNet.py
│   └── Spectroformer.py
├── utils
│   ├── __init__.py
│   └── utils.py
├── config
│   ├── __init__.py
│   └── config.py
├── data
│   ├── __init__.py
│   ├── data_RGB.py
│   └── dataset_RGB.py
├── config.yml
├── LICENSE
├── test.py
├── README.md
├── preprocess
│   └── transform.py
├── .gitignore
└── train.py

/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .loss import *
2 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | 
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import Config
2 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_RGB import get_training_data, get_validation_data, get_testing_data
2 | 
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | VERBOSE: True
2 | 
3 | MODEL:
4 |   SESSION: 'UW'
5 |   INPUT: 'input'
6 |   TARGET: 'target'
7 | 
8 | # Optimization arguments.
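# Note: values set in this file override the defaults declared in config/config.py
# via yacs merge_from_file; any key omitted here keeps its default value.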
9 | OPTIM: 10 | BATCH_SIZE: 4 11 | NUM_EPOCHS: 200 12 | LR_INITIAL: 2e-4 13 | LR_MIN: 1e-6 14 | SEED: 3407 15 | WANDB: False 16 | 17 | TRAINING: 18 | VAL_AFTER_EVERY: 1 19 | RESUME: False 20 | PS_W: 256 21 | PS_H: 256 22 | TRAIN_DIR: '../dataset/UW/train/' 23 | VAL_DIR: '../dataset/UW/test/' 24 | SAVE_DIR: './checkpoints/' 25 | ORI: False 26 | 27 | TESTING: 28 | WEIGHT: './checkpoints/best.pth' 29 | SAVE_IMAGES: True -------------------------------------------------------------------------------- /data/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest 3 | 4 | 5 | def get_training_data(rgb_dir, inp, target, img_options): 6 | assert os.path.exists(rgb_dir) 7 | return DataLoaderTrain(rgb_dir, inp, target, img_options) 8 | 9 | 10 | def get_validation_data(rgb_dir, inp, target, img_options): 11 | assert os.path.exists(rgb_dir) 12 | return DataLoaderVal(rgb_dir, inp, target, img_options) 13 | 14 | 15 | def get_testing_data(rgb_dir, inp, img_options): 16 | assert os.path.exists(rgb_dir) 17 | return DataLoaderTest(rgb_dir, inp, img_options) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Nick Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from collections import OrderedDict
4 | 
5 | import numpy as np
6 | import torch
7 | 
8 | 
9 | def seed_everything(seed=3407):
10 |     os.environ['PYTHONHASHSEED'] = str(seed)
11 |     random.seed(seed)
12 |     np.random.seed(seed)
13 |     torch.manual_seed(seed)
14 |     torch.cuda.manual_seed(seed)
15 |     torch.cuda.manual_seed_all(seed)
16 |     torch.backends.cudnn.deterministic = True
17 |     torch.backends.cudnn.benchmark = False
18 | 
19 | 
20 | def save_checkpoint(state, epoch, model_name, outdir):
21 |     if not os.path.exists(outdir):
22 |         os.makedirs(outdir)
23 |     checkpoint_file = os.path.join(outdir, model_name + '_' + 'epoch_' + str(epoch) + '.pth')
24 |     torch.save(state, checkpoint_file)
25 | 
26 | 
27 | def load_checkpoint(model, weights):
28 |     checkpoint = torch.load(weights, map_location=lambda storage, loc: storage.cuda(0))
29 |     new_state_dict = OrderedDict()
30 |     for key, value in checkpoint['state_dict'].items():
31 |         if key.startswith('module'):
32 |             name = key[7:]
33 |         else:
34 |             name = key
35 |         new_state_dict[name] = value
36 |     model.load_state_dict(new_state_dict)
37 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | 
4 | import torch
5 | from accelerate import Accelerator
6 | from torch.utils.data import DataLoader
7 | from torchvision.utils import save_image
8 | from tqdm import tqdm
9 | 
10 | from config import Config
11 | from data import get_testing_data
12 | from models import *
13 | from utils import *
14 | 
15 | warnings.filterwarnings('ignore')
16 | 
17 | opt = Config('config.yml')
18 | 
19 | seed_everything(opt.OPTIM.SEED)
20 | 
21 | os.makedirs('result', exist_ok=True)
22 | 
23 | def test():
24 |     accelerator = Accelerator()
25 | 
26 |     # Data Loader
27 |     val_dir = opt.TRAINING.VAL_DIR
28 | 
29 |     test_dataset = get_testing_data(val_dir, opt.MODEL.INPUT, {'w': opt.TRAINING.PS_W, 'h': opt.TRAINING.PS_H, 'ori': False})
30 |     test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False,
31 |                              pin_memory=True)
32 | 
33 |     # Model
34 |     model = Model()
35 | 
36 |     load_checkpoint(model, opt.TESTING.WEIGHT)
37 | 
38 |     model, test_loader = accelerator.prepare(model, test_loader)
39 | 
40 |     model.eval()
41 | 
42 |     for _, test_data in enumerate(tqdm(test_loader)):
43 |         # get the inputs; each batch is a list of [input, depth, filename]
44 |         inp = test_data[0].contiguous()
45 |         dep = test_data[1].contiguous()
46 | 
47 |         with torch.no_grad():
48 |             res = model(inp, dep)
49 | 
50 |         save_image(res, os.path.join(os.getcwd(), "result", test_data[2][0]))
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     test()
55 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ucolor-Reimplementation
2 | 
3 | Reimplementation of "Underwater Image Enhancement via Medium Transmission-Guided Multi-Color Space Embedding" (Ucolor)
4 | 
5 | Edited from https://github.com/59Kkk/pytorch_Ucolor_lcy
6 | 
7 | ## Fixed Problems
8 | 
9 | 1. Use the Kornia library for color space conversion
10 | 2. Use accelerate to implement distributed training
11 | 3. Fix a bug in the depth map concatenation
12 | 4. Fix NaN loss
13 | 
14 | ## Dataset Structure
15 | 
16 | The dataset should be organized as below:
17 | 
18 | ```
19 | dataset/
20 | ├─ train/
21 | │  ├─ input/
22 | │  │  ├─ 1.jpg
23 | │  │  ├─ ...
24 | │  ├─ depth/
25 | │  │  ├─ 1.jpg
26 | │  │  ├─ ...
27 | │  └─ target/
28 | │     ├─ 1.jpg
29 | │     ├─ ...
30 | └─ test/
31 |    ├─ input/
32 |    │  ├─ 1.jpg
33 |    │  ├─ ...
34 |    ├─ depth/
35 |    │  ├─ 1.jpg
36 |    │  ├─ ...
37 |    └─ target/
38 |       ├─ 1.jpg
39 |       ├─ ...
40 | 
41 | ```
42 | 
43 | The `input` folder contains the underwater images.
44 | 
45 | The `depth` folder contains the transmission maps.
46 | 
47 | The `target` folder contains the ground-truth images.
48 | 
49 | Each triplet should have the exact same file name and extension.
50 | 
51 | ## Training
52 | 
53 | Download the dataset first, then specify TRAIN_DIR, VAL_DIR and SAVE_DIR in the TRAINING section of `config.yml`.
54 | 
55 | For single-GPU training:
56 | 
57 | ```bash
58 | python train.py
59 | ```
60 | 
61 | For multi-GPU training:
62 | 
63 | ```bash
64 | accelerate config
65 | accelerate launch train.py
66 | ```
67 | 
68 | If you have difficulties with the usage of `accelerate`, please refer to [Accelerate](https://github.com/huggingface/accelerate).
69 | 
70 | ## Inference
71 | 
72 | Please first specify WEIGHT (the checkpoint to evaluate) in the TESTING section of `config.yml`; test images are read from VAL_DIR.
73 | 
74 | ```bash
75 | python test.py
76 | ```
--------------------------------------------------------------------------------
/models/SUWNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class ConvBlock(nn.Module):
5 |     def __init__(self):
6 |         super(ConvBlock, self).__init__()
7 |         self.conv = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
8 |         self.conv1 = nn.Conv2d(in_channels=64, out_channels=61, kernel_size=3, stride=1, padding=1, bias=False)
9 |         self.relu = nn.ReLU(inplace=True)
10 |         self.drop = nn.Dropout2d(p=.20)
11 | 
12 |     def forward(self, x):
13 |         x, input_x = x
14 |         a = self.relu(self.conv1(self.relu(self.drop(self.conv(self.relu(self.drop(self.conv(x))))))))
15 |         out = torch.cat((a, input_x), 1)
16 |         return (out, input_x)
17 | 
18 | class SUWnet(nn.Module):
19 |     def __init__(self, num_layers=3):
20 |         super(SUWnet, self).__init__()
21 |         self.input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
22 |         self.output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
23 |         self.relu = nn.ReLU(inplace=True)
24 |         self.blocks = self.StackBlock(ConvBlock, num_layers)
25 | 
26 |     def StackBlock(self, block, layer_num):
27 |         layers = []
28 |         for _ in range(layer_num):
29 |             layers.append(block())
30 |         return nn.Sequential(*layers)
31 | 
32 |     def forward(self, x):
33 |         input_x = x
34 |         x1 = self.relu(self.input(x))
35 |         out, _ = self.blocks((x1, input_x))
36 |         out = self.output(out)
37 |         return out
38 | 
39 | if __name__ == '__main__':
40 |     t = torch.randn(1, 3, 256, 256).cuda()
41 |     model = SUWnet().cuda()
42 |     res = model(t)
43 |     print(res.shape)
--------------------------------------------------------------------------------
/preprocess/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from typing import Tuple
4 | 
5 | 
6 | def white_balance_transform(im_rgb):
7 |     """
8 |     Requires HWC uint8 input
9 |     Originally in SimplestColorBalance.m
10 |     """
11 |     # Reshape the image into one flat vector per channel.
12 | 
13 |     # if RGB
14 |     if len(im_rgb.shape) == 3:
15 |         R = np.sum(im_rgb[:, :, 0], axis=None)
16 |         G = np.sum(im_rgb[:, :, 1], axis=None)
17 |         B = np.sum(im_rgb[:, :, 2], axis=None)
18 | 
19 |         maxpix = max(R, G, B)
20 |         ratio = np.array([maxpix / R, maxpix / G, maxpix / B])
21 | 
22 |         satLevel1 = 0.005 * ratio
23 |         satLevel2 = 0.005 * ratio
24 | 
25 |         m, n, p = im_rgb.shape
26 |         im_rgb_flat = np.zeros(shape=(p, m * n))
27 |         for i in range(0, p):
28 |             im_rgb_flat[i, :] = np.reshape(im_rgb[:, :, i], (1, m * n))
29 | 
30 |     # if grayscale
31 |     else:
32 |         satLevel1 = np.array([0.001])
33 |         satLevel2 = np.array([0.005])
34 |         m, n = im_rgb.shape
35 |         p = 1
36 |         im_rgb_flat = np.reshape(im_rgb, (1, m * n))
37 | 
38 |     wb = np.zeros(shape=im_rgb_flat.shape)
39 |     for ch in range(p):
40 |         q = [satLevel1[ch], 1 - satLevel2[ch]]
41 |         tiles = np.quantile(im_rgb_flat[ch, :], q)
42 |         temp = im_rgb_flat[ch, :]
43 |         temp[temp < tiles[0]] = tiles[0]
44 |         temp[temp > tiles[1]] = tiles[1]
45 |         wb[ch, :] = temp
46 |         bottom = min(wb[ch, :])
47 |         top = max(wb[ch, :])
48 |         wb[ch, :] = (wb[ch, :] - bottom) * 255 / (top - bottom)
49 | 
50 |     if len(im_rgb.shape) == 3:
51 |         outval = np.zeros(shape=im_rgb.shape)
52 |         for i in range(p):
53 |             outval[:, :, i] = np.reshape(wb[i, :], (m, n))
54 | 
55 |     else:
56 |         outval = np.reshape(wb, (m, n))
57 | 
58 |     return outval.astype(np.uint8)
59 | 
60 | 
61 | def gamma_correction(im):
62 |     gc = np.power(im / 255, 0.7)
63 |     gc = np.clip(255 * gc, 0, 255)
64 |     gc = gc.astype(np.uint8)
65 |     return gc
66 | 
67 | 
68 | def histeq(im_rgb):
69 |     im_lab = cv2.cvtColor(im_rgb, cv2.COLOR_RGB2LAB)
70 | 
71 |     clahe = cv2.createCLAHE(clipLimit=0.1, tileGridSize=(8, 8))
72 |     el = clahe.apply(im_lab[:, :, 0])
73 | 
74 |     im_he = im_lab.copy()
75 |     im_he[:, :, 0] = el
76 |     im_he_rgb = cv2.cvtColor(im_he, cv2.COLOR_LAB2RGB)
77 | 
78 |     return im_he_rgb
79 | 
80 | 
81 | def transform(rgb) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
82 |     """
83 |     transform(rgb) -> wb, gc, he
84 |     """
85 |     # Convenience wrapper
86 |     wb = white_balance_transform(rgb)
87 |     gc = gamma_correction(rgb)
88 |     he = histeq(rgb)
89 | 
90 |     return wb, gc, he
91 | 
--------------------------------------------------------------------------------
/models/UWCNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class UWCNN(nn.Module):
6 |     def __init__(self):
7 |         super(UWCNN, self).__init__()
8 |         self._init_layers()
9 | 
10 |     def _init_layers(self):
11 |         self.conv2d_dehaze1 = nn.Conv2d(3, 16, 3, 1, 1)
12 |         self.dehaze1_relu = nn.ReLU(inplace=True)
13 | 
14 |         self.conv2d_dehaze2 = nn.Conv2d(16, 16, 3, 1, 1)
15 |         self.dehaze2_relu = nn.ReLU(inplace=True)
16 | 
17 |         self.conv2d_dehaze3 = nn.Conv2d(16, 16, 3, 1, 1)
18 |         self.dehaze3_relu = nn.ReLU(inplace=True)
19 | 
20 |         self.conv2d_dehaze4 = nn.Conv2d(3+16+16+16, 16, 3, 1, 1)
21 |         self.dehaze4_relu = nn.ReLU(inplace=True)
22 | 
23 |         self.conv2d_dehaze5 = nn.Conv2d(16, 16, 3, 1, 1)
24 |         self.dehaze5_relu = nn.ReLU(inplace=True)
25 | 
26 |         self.conv2d_dehaze6 = nn.Conv2d(16, 16, 3, 1, 1)
27 |         self.dehaze6_relu = nn.ReLU(inplace=True)
28 | 
29 |         self.conv2d_dehaze7 = nn.Conv2d(51+48, 16, 3, 1, 1)
30 |         self.dehaze7_relu = nn.ReLU(inplace=True)
31 | 
32 |         self.conv2d_dehaze8 = nn.Conv2d(16, 16, 3, 1, 1)
33 |         self.dehaze8_relu = nn.ReLU(inplace=True)
34 | 
35 |         self.conv2d_dehaze9 = nn.Conv2d(16, 16, 3, 1, 1)
36 | 
self.dehaze9_relu = nn.ReLU(inplace=True) 37 | 38 | self.conv2d_dehaze10 = nn.Conv2d(99+48, 3, 3, 1, 1) 39 | 40 | def forward(self, x): 41 | image_conv1 = self.dehaze1_relu(self.conv2d_dehaze1(x)) 42 | image_conv2 = self.dehaze2_relu(self.conv2d_dehaze2(image_conv1)) 43 | image_conv3 = self.dehaze3_relu(self.conv2d_dehaze3(image_conv2)) 44 | 45 | dehaze_concat1 = torch.cat([image_conv1, image_conv2, image_conv3, x], dim=1) 46 | image_conv4 = self.dehaze4_relu(self.conv2d_dehaze4(dehaze_concat1)) 47 | image_conv5 = self.dehaze5_relu(self.conv2d_dehaze5(image_conv4)) 48 | image_conv6 = self.dehaze6_relu(self.conv2d_dehaze6(image_conv5)) 49 | 50 | dehaze_concat2 = torch.cat([dehaze_concat1, image_conv4, image_conv5, image_conv6], dim=1) 51 | image_conv7 = self.dehaze7_relu(self.conv2d_dehaze7(dehaze_concat2)) 52 | image_conv8 = self.dehaze8_relu(self.conv2d_dehaze8(image_conv7)) 53 | image_conv9 = self.dehaze9_relu(self.conv2d_dehaze9(image_conv8)) 54 | 55 | dehaze_concat3 = torch.cat([dehaze_concat2, image_conv7, image_conv8, image_conv9], dim=1) 56 | image_conv10 = self.conv2d_dehaze10(dehaze_concat3) 57 | out = x + image_conv10 58 | 59 | return out 60 | 61 | 62 | if __name__ == '__main__': 63 | t = torch.randn(1, 3, 256, 256).cuda() 64 | model = UWCNN().cuda() 65 | res = model(t) 66 | print(res.shape) 67 | -------------------------------------------------------------------------------- /models/BOths.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class S_D(nn.Module): 8 | def __init__(self): 9 | super(S_D, self).__init__() 10 | self.A = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, dilation=1) 11 | self.B = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=5, dilation=5) 12 | self.LReLU = nn.LeakyReLU(0.2, inplace=True) 13 | self.Sigmoid = nn.Sigmoid() 14 | 15 | def forward(self, x): 16 | A1 = self.LReLU(self.A(x)) 17 | A2 = self.LReLU(self.B(x)) 18 | A3 = A1 - A2 19 | Guided_map = self.Sigmoid(A3) 20 | Detail_map = x * Guided_map 21 | Structure_map = x - Detail_map 22 | 23 | return Structure_map, Detail_map 24 | 25 | 26 | class simam_module(nn.Module): 27 | def __init__(self, e_lambda=1e-4): 28 | super(simam_module, self).__init__() 29 | self.activaton = nn.Sigmoid() 30 | self.e_lambda = e_lambda 31 | 32 | def __repr__(self): 33 | s = self.__class__.__name__ + '(' 34 | s += ('lambda=%f)' % self.e_lambda) 35 | 36 | return s 37 | 38 | def forward(self, x): 39 | b, c, h, w = x.size() 40 | n = w * h - 1 41 | x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) 42 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 43 | 44 | return self.activaton(y) 45 | 46 | 47 | class SD_3D(nn.Module): 48 | def __init__(self): 49 | super(SD_3D, self).__init__() 50 | self.SD = S_D() 51 | self.sim = simam_module() 52 | 53 | def forward(self, x): 54 | Structure_map, Detail_map = self.SD(x) 55 | M1_0 = self.sim(Structure_map) 56 | M2_0 = M1_0 * Detail_map 57 | M1_1 = self.sim(M2_0) 58 | M2_1 = M1_1 * Detail_map 59 | M1_2 = self.sim(M2_1) 60 | M2_2 = M1_2 * Detail_map 61 | 62 | return M2_2, M2_1, M2_0 63 | 64 | 65 | class BOths(nn.Module): 66 | def __init__(self, in_channels=3, out_channels=3): 67 | super(BOths, self).__init__() 68 | # Conv 69 | self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1, bias=True) 70 | # Function 71 | self.LReLU = 
nn.LeakyReLU(negative_slope=0.2, inplace=True) 72 | self.In1 = nn.InstanceNorm2d(16) 73 | # SD_3D 74 | self.SD3D = SD_3D() 75 | # Gate 76 | self.gate = nn.Conv2d(16 * 3, 3, 3, 1, 1, bias=True) 77 | # Final 78 | self.Final = nn.Conv2d(19, out_channels, kernel_size=1, stride=1, padding=0, bias=True) 79 | 80 | def forward(self, x): 81 | en = self.conv1(x) # 16, 256, 256 82 | en = self.LReLU(en) 83 | en = self.In1(en) 84 | # SD3D 85 | sd2, sd1, sd0 = self.SD3D(en) 86 | # Gate 87 | gates = self.gate(torch.cat((sd0, sd1, sd2), dim=1)) 88 | gated_X = sd0 * gates[:, [0], :, :] + sd1 * gates[:, [1], :, :] + sd2 * gates[:, [2], :, :] # 16, 256, 256 89 | gated_X = torch.cat((gated_X, x), dim=1) # 19, 256, 256 90 | # Final 91 | result = self.Final(gated_X) 92 | 93 | return torch.tanh(result) 94 | # 256*256*3 95 | 96 | if __name__ == '__main__': 97 | t = torch.randn(1, 3, 256, 256).cuda() 98 | model = BOths().cuda() 99 | res = model(t) 100 | print(res.shape) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 | 
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 | 
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 | 
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 | 
119 | # SageMath parsed files
120 | *.sage.py
121 | 
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | 
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 | 
135 | # Rope project settings
136 | .ropeproject
137 | 
138 | # mkdocs documentation
139 | /site
140 | 
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 | 
146 | # Pyre type checker
147 | .pyre/
148 | 
149 | # pytype static type analyzer
150 | .pytype/
151 | 
152 | # Cython debug symbols
153 | cython_debug/
154 | 
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | 
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Jul 23 14:35:48 2019
5 | 
6 | @author: aditya
7 | """
8 | 
9 | r"""This module provides package-wide configuration management."""
10 | from typing import Any, List
11 | 
12 | from yacs.config import CfgNode as CN
13 | 
14 | 
15 | class Config(object):
16 |     r"""
17 |     A collection of all the required configuration parameters. This class is a nested dict-like
18 |     structure, with nested keys accessible as attributes. It contains sensible default values for
19 |     all the parameters, which may be overridden (first) through a YAML file and (second) through
20 |     a list of attributes and values.
21 | 
22 |     Extended Summary
23 |     ----------------
24 |     This class definition contains default values corresponding to ``joint_training`` phase, as it
25 |     is the final training phase and uses almost all the configuration parameters. Modification of
26 |     any parameter after instantiating this class is not possible, so you must override required
27 |     parameter values either through the ``config_yaml`` file or the ``config_override`` list.
28 | 
29 |     Parameters
30 |     ----------
31 |     config_yaml: str
32 |         Path to a YAML file containing configuration parameters to override.
33 |     config_override: List[Any], optional (default= [])
34 |         A list of sequential attributes and values of parameters to override. This happens after
35 |         overriding from the YAML file.
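        (The override list is flat, alternating parameter names and values, e.g.
        ``["OPTIM.BATCH_SIZE", 2048, "OPTIM.SEED", 42]``; the values shown here are illustrative.)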
36 | 
37 |     Examples
38 |     --------
39 |     Let a YAML file named "config.yaml" specify these parameters to override::
40 | 
41 |         ALPHA: 1000.0
42 |         BETA: 0.5
43 | 
44 |     >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7])
45 |     >>> _C.ALPHA  # default: 100.0
46 |     1000.0
47 |     >>> _C.OPTIM.BATCH_SIZE  # default: 1
48 |     2048
49 |     >>> _C.BETA  # default: 0.1
50 |     0.7
51 | 
52 |     Attributes
53 |     ----------
54 |     """
55 | 
56 |     def __init__(self, config_yaml: str, config_override: List[Any] = []):
57 |         self._C = CN()
58 |         self._C.GPU = [0]
59 |         self._C.VERBOSE = False
60 | 
61 |         self._C.MODEL = CN()
62 |         self._C.MODEL.SESSION = 'LUT'
63 |         self._C.MODEL.INPUT = 'input'
64 |         self._C.MODEL.TARGET = 'target'
65 | 
66 |         self._C.OPTIM = CN()
67 |         self._C.OPTIM.BATCH_SIZE = 1
68 |         self._C.OPTIM.SEED = 3407
69 |         self._C.OPTIM.NUM_EPOCHS = 100
70 |         self._C.OPTIM.NEPOCH_DECAY = [100]
71 |         self._C.OPTIM.LR_INITIAL = 0.0002
72 |         self._C.OPTIM.LR_MIN = 0.0002
73 |         self._C.OPTIM.BETA1 = 0.5
74 |         self._C.OPTIM.WANDB = False
75 | 
76 |         self._C.TRAINING = CN()
77 |         self._C.TRAINING.VAL_AFTER_EVERY = 3
78 |         self._C.TRAINING.RESUME = False
79 |         self._C.TRAINING.TRAIN_DIR = '../dataset/Jung/train'
80 |         self._C.TRAINING.VAL_DIR = '../dataset/Jung/test'
81 |         self._C.TRAINING.SAVE_DIR = 'checkpoints'
82 |         self._C.TRAINING.PS_W = 512
83 |         self._C.TRAINING.PS_H = 512
84 |         self._C.TRAINING.ORI = False
85 | 
86 |         self._C.TESTING = CN()
87 |         self._C.TESTING.WEIGHT = None
88 |         self._C.TESTING.SAVE_IMAGES = False
89 | 
90 |         # Override parameter values from YAML file first, then from override list.
91 |         self._C.merge_from_file(config_yaml)
92 |         self._C.merge_from_list(config_override)
93 | 
94 |         # Make an instantiated object of this class immutable.
95 |         self._C.freeze()
96 | 
97 |     def dump(self, file_path: str):
98 |         r"""Save config at the specified file path.
99 | 
100 |         Parameters
101 |         ----------
102 |         file_path: str
103 |             (YAML) path to save config at.
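
        Example (file name illustrative): ``Config("config.yml").dump("config_backup.yml")``
        writes the merged configuration back to disk.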
104 | """ 105 | self._C.dump(stream=open(file_path, "w")) 106 | 107 | def __getattr__(self, attr: str): 108 | return self._C.__getattr__(attr) 109 | 110 | def __repr__(self): 111 | return self._C.__repr__() 112 | -------------------------------------------------------------------------------- /models/DICAM.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | class Inc(nn.Module): 4 | def __init__(self,in_channels,filters): 5 | super(Inc, self).__init__() 6 | self.branch1 = nn.Sequential( 7 | nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1,padding=(1-1) // 2), 8 | nn.LeakyReLU(), 9 | nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=(3, 3), stride=(1, 1),dilation=1,padding=(3-1) // 2), 10 | nn.LeakyReLU(), 11 | ) 12 | self.branch2 = nn.Sequential( 13 | nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1,padding=(1-1) // 2), 14 | nn.LeakyReLU(), 15 | nn.Conv2d(in_channels=filters, out_channels=filters, kernel_size=(5, 5), stride=(1, 1),dilation=1,padding=(5-1) // 2), 16 | nn.LeakyReLU(), 17 | ) 18 | self.branch3 = nn.Sequential( 19 | nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=1), 20 | nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1), 21 | nn.LeakyReLU(), 22 | 23 | ) 24 | self.branch4 = nn.Sequential( 25 | nn.Conv2d(in_channels=in_channels, out_channels=filters, kernel_size=(1, 1), stride=(1, 1),dilation=1), 26 | nn.LeakyReLU(), 27 | ) 28 | def forward(self,input): 29 | o1 = self.branch1(input) 30 | o2 = self.branch2(input) 31 | o3 = self.branch3(input) 32 | o4 = self.branch4(input) 33 | return torch.cat([o1,o2,o3,o4],dim=1) 34 | class Flatten(nn.Module): 35 | def forward(self, inp): 36 | return inp.view(inp.size(0), -1) 37 | class CAM(nn.Module): 38 | def __init__(self,in_channels,reduction_ratio): 39 | super(CAM, self).__init__() 40 | self.module = nn.Sequential( 41 | nn.AdaptiveAvgPool2d((1,1)), 42 | Flatten(), 43 | nn.Linear(in_channels, in_channels // reduction_ratio), 44 | nn.Softsign(), 45 | nn.Linear(in_channels // reduction_ratio, in_channels ), 46 | nn.Softsign() 47 | ) 48 | def forward(self,input): 49 | return input* self.module(input).unsqueeze(2).unsqueeze(3).expand_as(input) 50 | 51 | class DICAM(nn.Module): 52 | def __init__(self): 53 | super(DICAM, self).__init__() 54 | self.layer_1_r = Inc(in_channels=1,filters= 64) 55 | self.layer_1_g = Inc(in_channels=1,filters= 64) 56 | self.layer_1_b = Inc(in_channels=1,filters= 64) 57 | 58 | self.layer_2_r = CAM(256,4) 59 | self.layer_2_g = CAM(256,4) 60 | self.layer_2_b = CAM(256,4) 61 | 62 | self.layer_3 = Inc(768,64) 63 | self.layer_4 = CAM(256,4) 64 | 65 | self.layer_tail = nn.Sequential( 66 | nn.Conv2d(in_channels=256,out_channels=24,kernel_size=(3,3),stride=(1, 1),padding=(3-1) // 2), 67 | nn.LeakyReLU(), 68 | nn.Conv2d(in_channels=24,out_channels=3,kernel_size=(1,1),stride=(1, 1),padding=(1-1) // 2), 69 | nn.Sigmoid() 70 | 71 | ) 72 | def forward(self,input): 73 | input_r = torch.unsqueeze(input[:,0,:,:], dim=1) 74 | input_g = torch.unsqueeze(input[:,1,:,:], dim=1) 75 | input_b = torch.unsqueeze(input[:,2,:,:], dim=1) 76 | 77 | layer_1_r = self.layer_1_r(input_r) 78 | layer_1_g = self.layer_1_g(input_g) 79 | layer_1_b = self.layer_1_b(input_b) 80 | 81 | layer_2_r = self.layer_2_r(layer_1_r) 82 | layer_2_g = self.layer_2_g(layer_1_g) 83 | layer_2_b = 
self.layer_2_b(layer_1_b) 84 | 85 | layer_concat = torch.cat([layer_2_r,layer_2_g,layer_2_b],dim=1) 86 | 87 | layer_3 = self.layer_3(layer_concat) 88 | layer_4 = self.layer_4(layer_3) 89 | 90 | output = self.layer_tail(layer_4) 91 | return output 92 | 93 | if __name__ == '__main__': 94 | inp = torch.randn(1, 3, 256, 256).cuda() 95 | model = DICAM().cuda() 96 | res = model(inp) 97 | print(res.shape) -------------------------------------------------------------------------------- /models/ASNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MASEblock(nn.Module): 5 | def __init__(self, in_channels, r=16): 6 | super().__init__() 7 | self.squeeze = nn.AdaptiveMaxPool2d((1,1)) 8 | self.excitation = nn.Sequential( 9 | nn.Linear(in_channels, in_channels // r), 10 | nn.ReLU(), 11 | nn.Linear(in_channels // r, in_channels), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | 17 | x = self.squeeze(x) 18 | x = x.view(x.size(0), -1) 19 | x = self.excitation(x) 20 | x = x.view(x.size(0), x.size(1), 1, 1) 21 | 22 | return x 23 | class MISEblock(nn.Module): 24 | def __init__(self, in_channels, r=16): 25 | super().__init__() 26 | self.squeeze = nn.AdaptiveMaxPool2d((1,1)) 27 | self.excitation = nn.Sequential( 28 | nn.Linear(in_channels, in_channels // r), 29 | nn.ReLU(), 30 | nn.Linear(in_channels // r, in_channels), 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, x): 35 | 36 | x = -self.squeeze(-x) 37 | x = x.view(x.size(0), -1) 38 | x = self.excitation(x) 39 | x = x.view(x.size(0), x.size(1), 1, 1) 40 | 41 | return x 42 | class ANB(nn.Module): 43 | def __init__(self, in_channels): 44 | super().__init__() 45 | 46 | self.maseblock = MASEblock(in_channels) 47 | self.miseblock = MISEblock(in_channels) 48 | 49 | def forward(self, x): 50 | 51 | im_h = self.maseblock(x) 52 | im_l = self.miseblock(x) 53 | 54 | me = torch.tensor(0.00001, dtype=torch.float32).cuda() 55 | 56 | x = (x - im_l) / torch.maximum(im_h - im_l, me) 57 | x = torch.clip(x, 0.0, 1.0) 58 | 59 | return x 60 | class ASB(nn.Module): 61 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 62 | super().__init__() 63 | 64 | self.conv = nn.Sequential( 65 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, **kwargs), 66 | nn.BatchNorm2d(out_channels), 67 | nn.Sigmoid(), 68 | ) 69 | 70 | def forward(self, x): 71 | 72 | x = self.conv(x) 73 | 74 | return x 75 | class BasicConv2d(nn.Module): 76 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 77 | super().__init__() 78 | 79 | self.conv = nn.Sequential( 80 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, **kwargs), 81 | nn.Sigmoid(), 82 | ) 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | return x 87 | class CorlorCorrection(nn.Module): 88 | def __init__(self, in_channels, out_channels, kernel_size, **kwargs): 89 | super().__init__() 90 | 91 | self.conv = nn.Sequential( 92 | ASB(in_channels, out_channels, kernel_size=kernel_size, **kwargs), 93 | ANB(out_channels), 94 | ASB(out_channels, out_channels, kernel_size=kernel_size, **kwargs), 95 | ASB(out_channels, out_channels, kernel_size=kernel_size, **kwargs), 96 | ASB(out_channels, out_channels, kernel_size=kernel_size, **kwargs), 97 | ASB(out_channels, out_channels, kernel_size=kernel_size, **kwargs), 98 | ASB(out_channels, out_channels, kernel_size=kernel_size, **kwargs), 99 | ANB(out_channels), 100 | BasicConv2d(out_channels, in_channels, kernel_size=kernel_size, 
**kwargs), 101 | ) 102 | 103 | def forward(self, x): 104 | 105 | x = self.conv(x) 106 | 107 | return x 108 | 109 | 110 | class ASNet(nn.Module): 111 | 112 | def __init__(self): 113 | super(ASNet, self).__init__() 114 | 115 | self.conv2wb_1 = CorlorCorrection(3, 128, 3, stride=1, padding=1) 116 | 117 | def forward(self, img_haze): 118 | 119 | conv_wb1 = self.conv2wb_1(img_haze) 120 | 121 | return conv_wb1 122 | 123 | 124 | if __name__ == '__main__': 125 | t = torch.randn(1, 3, 256, 256).cuda() 126 | model = ASNet().cuda() 127 | res = model(t) 128 | print(res.shape) -------------------------------------------------------------------------------- /models/WaterNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConfidenceMapGenerator(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | # Confidence maps 9 | # Accepts input of size (N, 3*4, H, W) 10 | self.conv1 = nn.Conv2d( 11 | in_channels=12, out_channels=128, kernel_size=7, dilation=1, padding="same" 12 | ) 13 | self.relu1 = nn.ReLU() 14 | self.conv2 = nn.Conv2d( 15 | in_channels=128, out_channels=128, kernel_size=5, dilation=1, padding="same" 16 | ) 17 | self.relu2 = nn.ReLU() 18 | self.conv3 = nn.Conv2d( 19 | in_channels=128, out_channels=128, kernel_size=3, dilation=1, padding="same" 20 | ) 21 | self.relu3 = nn.ReLU() 22 | self.conv4 = nn.Conv2d( 23 | in_channels=128, out_channels=64, kernel_size=1, dilation=1, padding="same" 24 | ) 25 | self.relu4 = nn.ReLU() 26 | self.conv5 = nn.Conv2d( 27 | in_channels=64, out_channels=64, kernel_size=7, dilation=1, padding="same" 28 | ) 29 | self.relu5 = nn.ReLU() 30 | self.conv6 = nn.Conv2d( 31 | in_channels=64, out_channels=64, kernel_size=5, dilation=1, padding="same" 32 | ) 33 | self.relu6 = nn.ReLU() 34 | self.conv7 = nn.Conv2d( 35 | in_channels=64, out_channels=64, kernel_size=3, dilation=1, padding="same" 36 | ) 37 | self.relu7 = nn.ReLU() 38 | self.conv8 = nn.Conv2d( 39 | in_channels=64, out_channels=3, kernel_size=3, dilation=1, padding="same" 40 | ) 41 | self.sigmoid = nn.Sigmoid() 42 | 43 | def forward(self, x, wb, he, gc): 44 | out = torch.cat([x, wb, he, gc], dim=1) 45 | out = self.relu1(self.conv1(out)) 46 | out = self.relu2(self.conv2(out)) 47 | out = self.relu3(self.conv3(out)) 48 | out = self.relu4(self.conv4(out)) 49 | out = self.relu5(self.conv5(out)) 50 | out = self.relu6(self.conv6(out)) 51 | out = self.relu7(self.conv7(out)) 52 | out = self.sigmoid(self.conv8(out)) 53 | out1, out2, out3 = torch.split(out, [1, 1, 1], dim=1) 54 | return out1, out2, out3 55 | 56 | 57 | class Refiner(nn.Module): 58 | def __init__(self): 59 | super().__init__() 60 | self.conv1 = nn.Conv2d( 61 | in_channels=6, out_channels=32, kernel_size=7, dilation=1, padding="same" 62 | ) 63 | self.conv2 = nn.Conv2d( 64 | in_channels=32, out_channels=32, kernel_size=5, dilation=1, padding="same" 65 | ) 66 | self.conv3 = nn.Conv2d( 67 | in_channels=32, out_channels=3, kernel_size=3, dilation=1, padding="same" 68 | ) 69 | self.relu1 = nn.ReLU() 70 | self.relu2 = nn.ReLU() 71 | self.relu3 = nn.ReLU() 72 | 73 | def forward(self, x, xbar): 74 | out = torch.cat([x, xbar], dim=1) 75 | out = self.relu1(self.conv1(out)) 76 | out = self.relu2(self.conv2(out)) 77 | out = self.relu3(self.conv3(out)) 78 | return out 79 | 80 | 81 | class WaterNet(nn.Module): 82 | """ 83 | waternet = WaterNet() 84 | in = torch.randn(16, 3, 112, 112) 85 | waternet_out = waternet(in, in, in, in) 86 | waternet_out.shape 87 | # torch.Size([16, 
3, 112, 112]) 88 | """ 89 | 90 | def __init__(self): 91 | super().__init__() 92 | self.cmg = ConfidenceMapGenerator() 93 | self.wb_refiner = Refiner() 94 | self.he_refiner = Refiner() 95 | self.gc_refiner = Refiner() 96 | 97 | def forward(self, x, wb, he, gc): 98 | wb_cm, he_cm, gc_cm = self.cmg(x, wb, he, gc) 99 | refined_wb = self.wb_refiner(x, wb) 100 | refined_he = self.he_refiner(x, he) 101 | refined_gc = self.gc_refiner(x, gc) 102 | return ( 103 | torch.mul(refined_wb, wb_cm) 104 | + torch.mul(refined_he, he_cm) 105 | + torch.mul(refined_gc, gc_cm) 106 | ) 107 | 108 | if __name__ == '__main__': 109 | t = torch.randn(1, 3, 256, 256).cuda() 110 | model = WaterNet().cuda() 111 | res = model(t, t, t, t) 112 | print(res.shape) -------------------------------------------------------------------------------- /models/CLUIE-Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class double_conv(nn.Module): 7 | '''(conv => BN => ReLU) * 2''' 8 | def __init__(self, in_ch, out_ch): 9 | super(double_conv, self).__init__() 10 | self.conv = nn.Sequential( 11 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 12 | nn.BatchNorm2d(out_ch), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 15 | nn.BatchNorm2d(out_ch), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | return x 22 | 23 | 24 | class inconv(nn.Module): 25 | def __init__(self, in_ch, out_ch): 26 | super(inconv, self).__init__() 27 | self.conv = double_conv(in_ch, out_ch) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | return x 32 | 33 | 34 | class down(nn.Module): 35 | def __init__(self, in_ch, out_ch): 36 | super(down, self).__init__() 37 | self.mpconv = nn.Sequential( 38 | nn.MaxPool2d(2), 39 | double_conv(in_ch, out_ch) 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.mpconv(x) 44 | return x 45 | 46 | 47 | class up(nn.Module): 48 | def __init__(self, in_ch, out_ch, bilinear=True): 49 | super(up, self).__init__() 50 | 51 | # would be a nice idea if the upsampling could be learned too, 52 | # but my machine do not have enough memory to handle all those weights 53 | if bilinear: 54 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 55 | else: 56 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 57 | 58 | self.conv = double_conv(in_ch, out_ch) 59 | 60 | def forward(self, x1, x2): 61 | x1 = self.up(x1) 62 | 63 | # input is CHW 64 | diffY = x2.size()[2] - x1.size()[2] 65 | diffX = x2.size()[3] - x1.size()[3] 66 | 67 | 68 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 69 | diffY // 2, diffY - diffY//2)) 70 | 71 | x = torch.cat([x2, x1], dim=1) 72 | x = self.conv(x) 73 | return x 74 | 75 | 76 | class outconv(nn.Module): 77 | def __init__(self, in_ch, out_ch): 78 | super(outconv, self).__init__() 79 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 80 | 81 | def forward(self, x): 82 | x = self.conv(x) 83 | return x 84 | 85 | 86 | class Flatten(nn.Module): 87 | def forward(self, input): 88 | return input.view(input.size(0), -1) 89 | 90 | class UNetEncoder(nn.Module): 91 | def __init__(self, n_channels=3): 92 | super(UNetEncoder, self).__init__() 93 | self.inc = inconv(n_channels, 64) 94 | self.down1 = down(64, 128) 95 | self.down2 = down(128, 256) 96 | self.down3 = down(256, 512) 97 | self.down4 = down(512, 512) 98 | 99 | def forward(self, x): 100 | x1 = self.inc(x) 101 | x2 = self.down1(x1) 102 | x3 = self.down2(x2) 
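        # (shape note, added for clarity: with a 3x256x256 input, x1 is 64ch at 256px,
        # x2 128ch at 128px, x3 256ch at 64px; x4 and x5 below are 512ch at 32px and 16px)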
103 | x4 = self.down3(x3) 104 | x5 = self.down4(x4) 105 | return x5, (x1, x2, x3, x4) 106 | 107 | class UNetDecoder(nn.Module): 108 | def __init__(self, n_channels=3): 109 | super(UNetDecoder, self).__init__() 110 | self.up1 = up(1024, 256) 111 | self.up2 = up(512, 128) 112 | self.up3 = up(256, 64) 113 | self.up4 = up(128, 64) 114 | self.outc = outconv(64, n_channels) 115 | self.sigmoid = nn.Sigmoid() 116 | 117 | def forward(self, x, enc_outs): 118 | x = self.sigmoid(x) 119 | x = self.up1(x, enc_outs[3]) 120 | x = self.up2(x, enc_outs[2]) 121 | x = self.up3(x, enc_outs[1]) 122 | x = self.up4(x, enc_outs[0]) 123 | x = self.outc(x) 124 | return nn.Tanh()(x) 125 | 126 | class CLUIE(nn.Module): 127 | def __init__(self): 128 | super(CLUIE, self).__init__() 129 | self.encoder = UNetEncoder() 130 | self.decoder = UNetDecoder() 131 | 132 | def forward(self, x): 133 | x, enc_x = self.encoder(x) 134 | res = self.decoder(x, enc_x) 135 | return res 136 | 137 | if __name__ == '__main__': 138 | t = torch.randn(1, 3, 256, 256).cuda() 139 | model = CLUIE().cuda() 140 | res = model(t) 141 | print(res.shape) 142 | -------------------------------------------------------------------------------- /models/USUIR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image, ImageFilter 4 | from torchvision.transforms import ToTensor 5 | 6 | def np_to_pil(img_np): 7 | """ 8 | Converts image in np.array format to PIL image. 9 | 10 | From C x W x H [0..1] to W x H x C [0...255] 11 | :param img_np: 12 | :return: 13 | """ 14 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 15 | 16 | if img_np.shape[0] == 1: 17 | ar = ar[0] 18 | else: 19 | assert img_np.shape[0] == 3, img_np.shape 20 | ar = ar.transpose(1, 2, 0) 21 | 22 | return Image.fromarray(ar) 23 | 24 | def torch_to_np(img_var): 25 | """ 26 | Converts an image in torch.Tensor format to np.array. 
27 | 28 | From 1 x C x W x H [0..1] to C x W x H [0..1] 29 | :param img_var: 30 | :return: 31 | """ 32 | return img_var.detach().cpu().numpy()[0] 33 | 34 | 35 | def quantize(img, rgb_range): 36 | pixel_range = 255 / rgb_range 37 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 38 | 39 | 40 | def get_A(x): 41 | x_np = np.clip(torch_to_np(x), 0, 1) 42 | x_pil = np_to_pil(x_np) 43 | h, w = x_pil.size 44 | windows = (h + w) / 2 45 | A = x_pil.filter(ImageFilter.GaussianBlur(windows)) 46 | A = ToTensor()(A) 47 | return A.unsqueeze(0) 48 | 49 | class JNet(torch.nn.Module): 50 | def __init__(self, num=64): 51 | super().__init__() 52 | self.conv1 = torch.nn.Sequential( 53 | torch.nn.ReflectionPad2d(1), 54 | torch.nn.Conv2d(3, num, 3, 1, 0), 55 | torch.nn.InstanceNorm2d(num), 56 | torch.nn.ReLU() 57 | ) 58 | self.conv2 = torch.nn.Sequential( 59 | torch.nn.ReflectionPad2d(1), 60 | torch.nn.Conv2d(num, num, 3, 1, 0), 61 | torch.nn.InstanceNorm2d(num), 62 | torch.nn.ReLU() 63 | ) 64 | self.conv3 = torch.nn.Sequential( 65 | torch.nn.ReflectionPad2d(1), 66 | torch.nn.Conv2d(num, num, 3, 1, 0), 67 | torch.nn.InstanceNorm2d(num), 68 | torch.nn.ReLU() 69 | ) 70 | self.conv4 = torch.nn.Sequential( 71 | torch.nn.ReflectionPad2d(1), 72 | torch.nn.Conv2d(num, num, 3, 1, 0), 73 | torch.nn.InstanceNorm2d(num), 74 | torch.nn.ReLU() 75 | ) 76 | self.final = torch.nn.Sequential( 77 | torch.nn.Conv2d(num, 3, 1, 1, 0), 78 | torch.nn.Sigmoid() 79 | ) 80 | 81 | def forward(self, data): 82 | data = self.conv1(data) 83 | data = self.conv2(data) 84 | data = self.conv3(data) 85 | data = self.conv4(data) 86 | data1 = self.final(data) 87 | 88 | return data1 89 | 90 | 91 | class TNet(torch.nn.Module): 92 | def __init__(self, num=64): 93 | super().__init__() 94 | self.conv1 = torch.nn.Sequential( 95 | torch.nn.ReflectionPad2d(1), 96 | torch.nn.Conv2d(3, num, 3, 1, 0), 97 | torch.nn.InstanceNorm2d(num), 98 | torch.nn.ReLU() 99 | ) 100 | self.conv2 = torch.nn.Sequential( 101 | torch.nn.ReflectionPad2d(1), 102 | torch.nn.Conv2d(num, num, 3, 1, 0), 103 | torch.nn.InstanceNorm2d(num), 104 | torch.nn.ReLU() 105 | ) 106 | self.conv3 = torch.nn.Sequential( 107 | torch.nn.ReflectionPad2d(1), 108 | torch.nn.Conv2d(num, num, 3, 1, 0), 109 | torch.nn.InstanceNorm2d(num), 110 | torch.nn.ReLU() 111 | ) 112 | self.conv4 = torch.nn.Sequential( 113 | torch.nn.ReflectionPad2d(1), 114 | torch.nn.Conv2d(num, num, 3, 1, 0), 115 | torch.nn.InstanceNorm2d(num), 116 | torch.nn.ReLU() 117 | ) 118 | self.final = torch.nn.Sequential( 119 | torch.nn.Conv2d(num, 3, 1, 1, 0), 120 | torch.nn.Sigmoid() 121 | ) 122 | 123 | def forward(self, data): 124 | data = self.conv1(data) 125 | data = self.conv2(data) 126 | data = self.conv3(data) 127 | data = self.conv4(data) 128 | data1 = self.final(data) 129 | 130 | return data1 131 | 132 | 133 | class USUIR(torch.nn.Module): 134 | def __init__(self): 135 | super(USUIR, self).__init__() 136 | self.image_net = JNet() 137 | self.mask_net = TNet() 138 | 139 | def forward(self, data): 140 | a_out = get_A(data).to(data.device) 141 | j_out = self.image_net(data) 142 | t_out = self.mask_net(data) 143 | I_rec = j_out * t_out + (1 - t_out) * a_out 144 | return I_rec 145 | 146 | if __name__ == '__main__': 147 | x = torch.randn(1, 3, 256, 256).cuda() 148 | model = USUIR().cuda() 149 | res = model(x) 150 | print(res.shape) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import 
warnings
2 | import os, torch
3 | import torch.optim as optim
4 | from accelerate import Accelerator
5 | from pytorch_msssim import SSIM
6 | from torch.utils.data import DataLoader
7 | from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
8 | from tqdm import tqdm
9 | 
10 | from config import Config
11 | from data import get_training_data, get_validation_data
12 | from models import *
13 | from utils import *
14 | 
15 | warnings.filterwarnings('ignore')
16 | 
17 | opt = Config('config.yml')
18 | 
19 | seed_everything(opt.OPTIM.SEED)
20 | 
21 | def train():
22 |     # Accelerate
23 |     accelerator = Accelerator(log_with='wandb') if opt.OPTIM.WANDB else Accelerator()
24 |     device = accelerator.device
25 |     config = {
26 |         "dataset": opt.TRAINING.TRAIN_DIR
27 |     }
28 |     accelerator.init_trackers(opt.MODEL.SESSION, config=config)
29 | 
30 |     if accelerator.is_local_main_process:
31 |         os.makedirs(opt.TRAINING.SAVE_DIR, exist_ok=True)
32 | 
33 |     # Data Loader
34 |     train_dir = opt.TRAINING.TRAIN_DIR
35 |     val_dir = opt.TRAINING.VAL_DIR
36 | 
37 |     train_dataset = get_training_data(train_dir, opt.MODEL.INPUT, opt.MODEL.TARGET, {'w': opt.TRAINING.PS_W, 'h': opt.TRAINING.PS_H})
38 |     train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16,
39 |                               drop_last=False, pin_memory=True)
40 |     val_dataset = get_validation_data(val_dir, opt.MODEL.INPUT, opt.MODEL.TARGET, {'w': opt.TRAINING.PS_W, 'h': opt.TRAINING.PS_H, 'ori': opt.TRAINING.ORI})
41 |     val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False,
42 |                             pin_memory=True)
43 | 
44 |     # Model & Loss
45 |     model = Model()
46 |     criterion_ssim = SSIM(data_range=1, size_average=True, channel=3).to(device)
47 |     criterion_psnr = torch.nn.MSELoss()
48 | 
49 |     # Optimizer & Scheduler
50 |     optimizer_b = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.OPTIM.LR_INITIAL, betas=(0.9, 0.999), eps=1e-8)
51 |     scheduler_b = optim.lr_scheduler.CosineAnnealingLR(optimizer_b, opt.OPTIM.NUM_EPOCHS, eta_min=opt.OPTIM.LR_MIN)
52 | 
53 |     train_loader, val_loader = accelerator.prepare(train_loader, val_loader)
54 |     model = accelerator.prepare(model)
55 |     optimizer_b, scheduler_b = accelerator.prepare(optimizer_b, scheduler_b)
56 | 
57 |     start_epoch = 1
58 |     best_epoch = 1
59 |     best_psnr = 0
60 |     size = len(val_loader)
61 |     # training
62 |     for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
63 |         model.train()
64 | 
65 |         for i, data in enumerate(tqdm(train_loader, disable=not accelerator.is_local_main_process)):
66 |             # get the inputs; each batch is a list of [input, depth, target]
67 |             inp = data[0].contiguous()
68 |             dep = data[1].contiguous()
69 |             tar = data[2]
70 | 
71 |             # forward
72 |             optimizer_b.zero_grad()
73 |             res = model(inp, dep)
74 | 
75 |             loss_psnr = criterion_psnr(res, tar)
76 |             loss_ssim = 1 - criterion_ssim(res, tar)
77 | 
78 |             train_loss = loss_psnr + 0.4 * loss_ssim
79 | 
80 |             # backward
81 |             accelerator.backward(train_loss)
82 |             optimizer_b.step()
83 | 
84 |         scheduler_b.step()
85 | 
86 |         # testing
87 |         if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
88 |             model.eval()
89 |             psnr = 0
90 |             ssim = 0
91 |             for idx, test_data in enumerate(tqdm(val_loader, disable=not accelerator.is_local_main_process)):
92 |                 # get the inputs; each batch is a list of [input, depth, target]
93 |                 inp = test_data[0].contiguous()
94 |                 dep = test_data[1].contiguous()
95 |                 tar = test_data[2]
96 | 
97 |                 with torch.no_grad():
98 |                     res = model(inp, dep)
99 | 
100 |                 res, tar = 
accelerator.gather((res, tar)) 101 | 102 | psnr += peak_signal_noise_ratio(res, tar, data_range=1) 103 | ssim += structural_similarity_index_measure(res, tar, data_range=1) 104 | 105 | psnr /= size 106 | ssim /= size 107 | 108 | if psnr > best_psnr: 109 | # save model 110 | best_epoch = epoch 111 | best_psnr = psnr 112 | save_checkpoint({ 113 | 'state_dict': model.state_dict(), 114 | }, epoch, opt.MODEL.SESSION, opt.TRAINING.SAVE_DIR) 115 | 116 | if accelerator.is_local_main_process: 117 | accelerator.log({ 118 | "PSNR": psnr, 119 | "SSIM": ssim, 120 | }, step=epoch) 121 | 122 | print( 123 | "epoch: {}, PSNR: {}, SSIM: {}, best PSNR: {}, best epoch: {}" 124 | .format(epoch, psnr, ssim, best_psnr, best_epoch)) 125 | 126 | accelerator.end_training() 127 | 128 | 129 | if __name__ == '__main__': 130 | train() 131 | -------------------------------------------------------------------------------- /models/NU2Net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def normalize_img(img): 6 | if torch.max(img) > 1 or torch.min(img) < 0: 7 | # img: b x c x h x w 8 | b, c, h, w = img.shape 9 | temp_img = img.view(b, c, h*w) 10 | im_max = torch.max(temp_img, dim=2)[0].view(b, c, 1) 11 | im_min = torch.min(temp_img, dim=2)[0].view(b, c, 1) 12 | 13 | temp_img = (temp_img - im_min) / (im_max - im_min + 1e-7) 14 | 15 | img = temp_img.view(b, c, h, w) 16 | 17 | return img 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_channel, out_channel): 22 | super(BasicBlock, self).__init__() 23 | self.out = nn.Sequential( 24 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 25 | nn.InstanceNorm2d(out_channel), 26 | nn.ELU() 27 | ) 28 | 29 | def forward(self, x): 30 | y = self.out(x) 31 | 32 | return y 33 | 34 | 35 | class ChannelAttention(nn.Module): 36 | def __init__(self, channels, factor): 37 | super().__init__() 38 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 39 | self.channel_map = nn.Sequential( 40 | nn.Conv2d(channels, channels // factor, 1, 1, 0), 41 | nn.LeakyReLU(), 42 | nn.Conv2d(channels // factor, channels, 1, 1, 0), 43 | nn.Softmax() 44 | ) 45 | 46 | def forward(self, x): 47 | avg_pool = self.avg_pool(x) 48 | ch_map = self.channel_map(avg_pool) 49 | return x * ch_map 50 | 51 | 52 | class Encoder(nn.Module): 53 | def __init__(self, basic_channel): 54 | super(Encoder, self).__init__() 55 | self.e_stage1 = nn.Sequential( 56 | nn.Conv2d(3, basic_channel, 3, 1, 1), 57 | BasicBlock(basic_channel, basic_channel) 58 | ) 59 | self.e_stage2 = nn.Sequential( 60 | nn.MaxPool2d(kernel_size=2, stride=2), 61 | BasicBlock(basic_channel, basic_channel * 2) 62 | ) 63 | self.e_stage3 = nn.Sequential( 64 | nn.MaxPool2d(kernel_size=2, stride=2), 65 | BasicBlock(basic_channel * 2, basic_channel * 4) 66 | ) 67 | self.e_stage4 = nn.Sequential( 68 | nn.MaxPool2d(kernel_size=2, stride=2), 69 | BasicBlock(basic_channel * 4, basic_channel * 8) 70 | ) 71 | 72 | def forward(self, x): 73 | x1 = self.e_stage1(x) 74 | x2 = self.e_stage2(x1) 75 | x3 = self.e_stage3(x2) 76 | x4 = self.e_stage4(x3) 77 | 78 | return x1, x2, x3, x4 79 | 80 | 81 | class Decoder(nn.Module): 82 | def __init__(self, basic_channel, is_residual=True): 83 | super(Decoder, self).__init__() 84 | self.is_residual = is_residual 85 | self.d_stage4 = nn.Sequential( 86 | BasicBlock(basic_channel * 8, basic_channel * 4), 87 | nn.UpsamplingBilinear2d(scale_factor=2) 88 | ) 89 | self.d_stage3 = nn.Sequential( 90 | BasicBlock(basic_channel * 4, basic_channel * 2), 91 | 
nn.UpsamplingBilinear2d(scale_factor=2)
92 |         )
93 |         self.d_stage2 = nn.Sequential(
94 |             BasicBlock(basic_channel * 2, basic_channel),
95 |             nn.UpsamplingBilinear2d(scale_factor=2)
96 |         )
97 |         self.d_stage1 = nn.Sequential(
98 |             BasicBlock(basic_channel, basic_channel // 4)
99 |         )
100 |         self.output = nn.Sequential(
101 |             nn.Conv2d(basic_channel // 4, 3, 1, 1, 0),
102 |             nn.Tanh()
103 |         )
104 | 
105 |     def forward(self, x, x1, x2, x3, x4):
106 |         y3 = self.d_stage4(x4)
107 |         y2 = self.d_stage3(y3 + x3)
108 |         y1 = self.d_stage2(y2 + x2)
109 |         y = self.output(self.d_stage1(y1 + x1))
110 | 
111 |         if self.is_residual:
112 |             return y + x
113 |         else:
114 |             return y
115 | 
116 | 
117 | class NU2Net(nn.Module):
118 |     def __init__(self, basic_channel=64, is_residual=True, tail='norm'):
119 |         super(NU2Net, self).__init__()
120 |         self.tail = tail
121 |         self.encoder = Encoder(basic_channel)
122 |         self.decoder = Decoder(basic_channel, is_residual=is_residual)
123 |         if self.tail == 'IN+clip' or self.tail == 'IN+sigmoid':
124 |             self.IN = nn.InstanceNorm2d(3)
125 | 
126 |     def forward(self, raw_img, **kwargs):
127 |         # encoder-decoder part
128 |         x1, x2, x3, x4 = self.encoder(raw_img)
129 |         y = self.decoder(raw_img, x1, x2, x3, x4)
130 |         if self.tail == 'norm':
131 |             y = normalize_img(y)
132 |         elif self.tail == 'clip':
133 |             y = torch.clamp(y, min=0.0, max=1.0)
134 |         elif self.tail == 'sigmoid':
135 |             y = torch.sigmoid(y)
136 |         elif self.tail == 'IN+clip':
137 |             y = torch.clamp(self.IN(y), min=0.0, max=1.0)
138 |         elif self.tail == 'IN+sigmoid':
139 |             y = torch.sigmoid(self.IN(y))
140 |         elif self.tail == 'none':
141 |             y = y
142 | 
143 |         return y
144 | 
145 | 
146 | if __name__ == "__main__":
147 |     model = NU2Net().cuda()
148 |     x = torch.rand((1, 3, 512, 512)).cuda()
149 |     y = model(x)
150 |     print(y)
151 | 
--------------------------------------------------------------------------------
/models/LiteEnhanceNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | def swish(x):
5 |     return x * x.sigmoid()
6 | 
7 | def hard_sigmoid(x, inplace=False):
8 |     return nn.ReLU6(inplace=inplace)(x + 3) / 6
9 | 
10 | def hard_swish(x, inplace=False):
11 |     return x * hard_sigmoid(x, inplace)
12 | 
13 | class HardSigmoid(nn.Module):
14 |     def __init__(self, inplace=False):
15 |         super(HardSigmoid, self).__init__()
16 |         self.inplace = inplace
17 | 
18 |     def forward(self, x):
19 |         return hard_sigmoid(x, inplace=self.inplace)
20 | 
21 | class HardSwish(nn.Module):
22 |     def __init__(self, inplace=False):
23 |         super(HardSwish, self).__init__()
24 |         self.inplace = inplace
25 | 
26 |     def forward(self, x):
27 |         return hard_swish(x, inplace=self.inplace)
28 | 
29 | def _make_divisible(v, divisor=8, min_value=None):  ## round the channel count to a multiple of 8
30 |     """
31 |     This function is taken from the original tf repo.
32 |     It ensures that all layers have a channel number that is divisible by 8
33 |     It can be seen here:
34 |     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
35 |     :param v:
36 |     :param divisor:
37 |     :param min_value:
38 |     :return:
39 |     """
40 |     if min_value is None:
41 |         min_value = divisor
42 |     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
43 |     # Make sure that round down does not go down by more than 10%.
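    # (worked example: with divisor=8, v=30 -> int(34)//8*8 = 32 and v=20 -> 24;
    # the check below bumps new_v up one divisor step if rounding cut off more than 10% of v)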
44 |     if new_v < 0.9 * v:
45 |         new_v += divisor
46 |     return new_v
47 | 
48 | 
49 | class SELayer(nn.Module):
50 |     def __init__(self, inp, oup, reduction=4):
51 |         super(SELayer, self).__init__()
52 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
53 |         self.fc = nn.Sequential(
54 |             nn.Conv2d(oup, _make_divisible(inp // reduction), 1, 1, 0,),
55 |             nn.ReLU(),
56 |             nn.Conv2d(_make_divisible(inp // reduction), oup, 1, 1, 0),
57 |             HardSigmoid()
58 |         )
59 | 
60 |     def forward(self, x):
61 |         b, c, _, _ = x.size()
62 |         y = self.avg_pool(x)
63 |         y = self.fc(y).view(b, c, 1, 1)
64 |         return x * y
65 | 
66 | class ConvBlock1(nn.Module):
67 |     def __init__(self):
68 |         super(ConvBlock1, self).__init__()
69 |         self.DW = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, groups=16, padding=1, bias=False)
70 |         self.BN = nn.BatchNorm2d(16)
71 |         self.HS = HardSwish()
72 |         self.PW = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
73 |         self.BNN = nn.BatchNorm2d(32)
74 | 
75 |     def forward(self, x):
76 |         a = self.HS(self.BN(self.DW(x)))
77 |         a = self.HS(self.BNN(self.PW(a)))
78 |         return a
79 | 
80 | class ConvBlock2(nn.Module):
81 |     def __init__(self):
82 |         super(ConvBlock2, self).__init__()
83 |         self.DW = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, groups=32, padding=1, bias=False)
84 |         self.BN = nn.BatchNorm2d(32)
85 |         self.HS = HardSwish()
86 |         self.PW = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False)
87 |         self.BNN = nn.BatchNorm2d(64)
88 | 
89 |     def forward(self, x):
90 |         a = self.HS(self.BN(self.DW(x)))
91 |         a = self.HS(self.BNN(self.PW(a)))
92 |         return a
93 | 
94 | class ConvBlock3(nn.Module):
95 |     def __init__(self):
96 |         super(ConvBlock3, self).__init__()
97 |         self.DW = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, groups=64, padding=1, bias=False)
98 |         self.BN = nn.BatchNorm2d(64)
99 |         self.HS = HardSwish()
100 |         self.PW = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
101 |         self.BNN = nn.BatchNorm2d(32)
102 | 
103 |     def forward(self, x):
104 |         a = self.HS(self.BN(self.DW(x)))
105 |         a = self.HS(self.BNN(self.PW(a)))
106 |         return a
107 | 
108 | class ConvBlock4(nn.Module):
109 |     def __init__(self):
110 |         super(ConvBlock4, self).__init__()
111 |         self.DW = nn.Conv2d(in_channels=80, out_channels=80, kernel_size=3, stride=1, groups=80, padding=1, bias=False)
112 |         self.BN = nn.BatchNorm2d(80)
113 |         self.HS = HardSwish()
114 |         self.PW = nn.Conv2d(in_channels=80, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
115 |         self.BNN = nn.BatchNorm2d(32)
116 |         self.SE = SELayer(80, 80)
117 | 
118 |     def forward(self, x):
119 | 
120 |         a = self.HS(self.BN(self.DW(x)))
121 |         a = self.SE(a)
122 |         a = self.HS(self.BNN(self.PW(a)))
123 |         return a
124 | 
125 | class LiteEnhanceNet(nn.Module):
126 |     def __init__(self):
127 |         super(LiteEnhanceNet, self).__init__()
128 |         self.input = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1, stride=1, padding=0, bias=False)  ## first convolution layer
129 |         self.output = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1, stride=1, padding=0, bias=False)
130 |         self.relu = nn.ReLU(inplace=True)
131 |         self.block1 = ConvBlock1()
132 |         self.block2 = ConvBlock2()
133 |         self.block3 = ConvBlock3()
134 |         self.block4 = ConvBlock4()
135 | 
136 |     def forward(self, x):
137 |         x = self.input(x)
138 |         x1 = self.block1(x)
139 |         x2 = self.block2(x1)
140 |         # x2 = torch.cat((x, x2), 1)
141 |         x3 = self.block3(x2)
142 |         x3 = 
torch.cat((x, x1, x3), 1) 143 | x4 = self.block4(x3) 144 | out = self.output(x4) 145 | return out 146 | 147 | 148 | if __name__ == '__main__': 149 | t = torch.randn(1, 3, 256, 256).cuda() 150 | model = LiteEnhanceNet().cuda() 151 | res = model(t) 152 | print(res.shape) -------------------------------------------------------------------------------- /models/P2CNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import kornia 5 | from timm.models.layers import trunc_normal_ 6 | 7 | 8 | class Volume_2D(nn.Module): 9 | def __init__(self, indim=256): 10 | super(Volume_2D, self).__init__() 11 | self.c2f_dim = 33 12 | self.conv_dim = indim 13 | coarse_range = torch.arange(-1., 1. + 0.01, step=0.0625).reshape(1, -1) # 1 k 14 | self.color_range = nn.Parameter(coarse_range[None, :, :, None, None], requires_grad=False) # 1 1 k 1 1 15 | self.coarse_conv = nn.Sequential( 16 | nn.Conv2d(self.conv_dim, self.conv_dim, kernel_size=1, stride=1, padding=0), 17 | nn.Hardswish(inplace=True), 18 | nn.Conv2d(self.conv_dim, 2 * self.c2f_dim, kernel_size=1, stride=1, padding=0) 19 | ) 20 | 21 | def forward(self, x): 22 | b, c, h, w = x.shape 23 | cost_feature = self.coarse_conv(x).view(b, 2, self.c2f_dim, h, w) 24 | prob = F.softmax(cost_feature, dim=2) # b 2 k h w 25 | exp = torch.sum(prob * self.color_range, dim=2) # b 2 h w 26 | return exp 27 | 28 | 29 | class ColorCompenateNet(nn.Module): 30 | def __init__(self, cont_dim=64, color_dim=64): 31 | super(ColorCompenateNet, self).__init__() 32 | self.color_encoder = nn.Sequential( 33 | nn.Conv2d(2, 64, kernel_size=3, padding=1) 34 | ) 35 | self.context_encoder = nn.Sequential( 36 | nn.Conv2d(1, 64, kernel_size=3, padding=1) 37 | ) 38 | self.encoder_0 = nn.Sequential( 39 | nn.InstanceNorm2d(128, affine=True), 40 | nn.Hardswish(inplace=True), 41 | nn.MaxPool2d(2), 42 | nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=2), 43 | nn.InstanceNorm2d(128, affine=True), 44 | nn.LeakyReLU(0.2, inplace=True) # 1/2 45 | ) 46 | self.encoder_1 = nn.Sequential( 47 | nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=2), 48 | nn.InstanceNorm2d(128, affine=True), 49 | nn.Hardswish(inplace=True), 50 | nn.MaxPool2d(2), 51 | nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=2), 52 | nn.InstanceNorm2d(128, affine=True), 53 | nn.LeakyReLU(0.2, inplace=True) # 1/4 54 | ) 55 | self.encoder_2 = nn.Sequential( 56 | nn.Conv2d(128, 256, kernel_size=3, padding=1, groups=2), 57 | nn.InstanceNorm2d(256, affine=True), 58 | nn.Hardswish(inplace=True), 59 | nn.MaxPool2d(2), 60 | nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=2), 61 | nn.InstanceNorm2d(256, affine=True), 62 | nn.LeakyReLU(0.2, inplace=True) # 1/8 63 | ) 64 | self.encoder_3 = nn.Sequential( 65 | nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=1), 66 | nn.InstanceNorm2d(256, affine=True), 67 | nn.Hardswish(inplace=True), 68 | nn.MaxPool2d(2), 69 | nn.Conv2d(256, 256, kernel_size=3, padding=1, groups=1), 70 | nn.InstanceNorm2d(256, affine=True), 71 | nn.LeakyReLU(0.2, inplace=True) # 1/16 72 | ) 73 | self.color_decoder3 = Volume_2D(indim=256) 74 | self.color_decoder2 = Volume_2D(indim=258) 75 | self.color_decoder1 = Volume_2D(indim=130) 76 | self.color_decoder0 = Volume_2D(indim=130) 77 | self.init_parameters() 78 | 79 | def init_parameters(self): 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d): 82 | nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu') 83 | if m.bias is not 
None: 84 | nn.init.constant_(m.bias, 0) 85 | if isinstance(m, nn.Linear): 86 | trunc_normal_(m.weight, std=.02) 87 | if m.bias is not None: 88 | nn.init.constant_(m.bias, 0) 89 | 90 | def forward(self, x): 91 | # step1: feature extraction 92 | h, w = x.shape[2:] 93 | l = x[:, :1, :, :] 94 | ab = x[:, 1:, :, :] 95 | feat_color = self.color_encoder(ab) 96 | fest_cont = self.context_encoder(l) 97 | feat = torch.cat([feat_color, fest_cont], dim=1) 98 | feat0 = self.encoder_0(feat) # 1/2 128 99 | h0, w0 = feat0.shape[2:] 100 | feat1 = self.encoder_1(feat0) # 1/4 128 101 | h1, w1 = feat1.shape[2:] 102 | feat2 = self.encoder_2(feat1) # 1/8 256 103 | h2, w2 = feat2.shape[2:] 104 | feat3 = self.encoder_3(feat2) # 1/16 256 105 | 106 | # step2: multi-scale probabilistic volumetric fusion 107 | pre_ab3 = self.color_decoder3(feat3) 108 | pre_ab3 = F.interpolate(pre_ab3, size=(h2, w2), mode='bilinear', align_corners=True) # 1/8 2 109 | feat2 = torch.cat([feat2, pre_ab3], dim=1) # 1/8 258 110 | 111 | pre_ab2 = self.color_decoder2(feat2) 112 | pre_ab2 = F.interpolate(pre_ab2, size=(h1, w1), mode='bilinear', align_corners=True) # 1/4 2 113 | feat1 = torch.cat([feat1, pre_ab2], dim=1) # 1/4 130 114 | 115 | pre_ab1 = self.color_decoder1(feat1) 116 | pre_ab1 = F.interpolate(pre_ab1, size=(h0, w0), mode='bilinear', align_corners=True) # 1/2 2 117 | feat0 = torch.cat([feat0, pre_ab1], dim=1) # 1/2 130 118 | 119 | pre_ab0 = self.color_decoder0(feat0) 120 | pre_ab0 = F.interpolate(pre_ab0, size=(h, w), mode='bilinear', align_corners=True) # 1 2 121 | 122 | return l, pre_ab0, pre_ab1, pre_ab2, pre_ab3 123 | 124 | 125 | class P2CNet(nn.Module): 126 | def __init__(self, dim=64): 127 | super(P2CNet, self).__init__() 128 | self.color = ColorCompenateNet(cont_dim=dim, color_dim=dim) 129 | 130 | def forward(self, x): 131 | l, ab0, ab1, ab2, ab3 = self.color(x) 132 | lab = torch.cat([l * 100., ab0 * 127.], dim=1) 133 | rgb = kornia.color.lab_to_rgb(lab) # 0~1 134 | return {'ab_pred0': ab0, 'ab_pred1': ab1, 'ab_pred2': ab2, 'ab_pred3': ab3, 'lab_rgb': rgb} 135 | 136 | 137 | if __name__ == '__main__': 138 | t = torch.randn(1, 3, 256, 256).cuda() 139 | model = P2CNet().cuda() 140 | res = model(t)['lab_rgb'] 141 | print(res.shape) -------------------------------------------------------------------------------- /models/IACC.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | from kornia.color import rgb_to_hsv 5 | import torch 6 | from ptflops import get_model_complexity_info 7 | 8 | 9 | class UIA(nn.Module): 10 | def __init__(self, channels, ks): 11 | super(UIA, self).__init__() 12 | self._c_avg = nn.AdaptiveAvgPool2d((1, 1)) 13 | self._c_conv = nn.Conv2d(channels, channels, 1, bias=False) 14 | self._c_sig = nn.Sigmoid() 15 | self._h_avg = nn.AdaptiveAvgPool2d((1, None)) 16 | self._h_conv = nn.Conv2d(channels, channels, 1, groups=channels, bias=False) 17 | self._w_avg = nn.AdaptiveAvgPool2d((None, 1)) 18 | self._w_conv = nn.Conv2d(channels, channels, 1, groups=channels, bias=False) 19 | self._hw_conv = nn.Conv2d(channels, channels, ks, padding=ks // 2, padding_mode='reflect', 20 | groups=channels, bias=False) 21 | self._chw_conv = nn.Conv2d(channels, 1, 1, bias=False) 22 | self._chw_sig = nn.Sigmoid() 23 | 24 | def forward(self, x): 25 | c_map = self._c_conv(self._c_avg(x)) 26 | c_weight = self._c_sig(c_map) 27 | h_map = self._h_conv(self._h_avg(x)) 28 | w_map = self._w_conv(self._w_avg(x)) 29 | hw_map = self._hw_conv(w_map @ h_map) 30 | 
chw_map = self._chw_conv(c_weight * hw_map) 31 | chw_weight = self._chw_sig(chw_map) 32 | return chw_weight * x 33 | 34 | 35 | class NormGate(nn.Module): 36 | def __init__(self, channels, ks, norm=nn.InstanceNorm2d): 37 | super(NormGate, self).__init__() 38 | self._norm_branch = nn.Sequential( 39 | norm(channels), 40 | nn.Conv2d(channels, channels, ks, padding=ks // 2, padding_mode='reflect', bias=False) 41 | ) 42 | self._sig_branch = nn.Sequential( 43 | nn.Conv2d(channels, channels, ks, padding=ks // 2, padding_mode='reflect', bias=False), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def forward(self, x): 48 | norm = self._norm_branch(x) 49 | sig = self._sig_branch(x) 50 | return norm * sig 51 | 52 | 53 | class UCB(nn.Module): 54 | def __init__(self, channels, ks): 55 | super(UCB, self).__init__() 56 | self._body = nn.Sequential( 57 | nn.Conv2d(channels, channels, kernel_size=ks, padding=ks // 2, 58 | padding_mode='reflect', bias=False), 59 | NormGate(channels, ks), 60 | UIA(channels, ks) 61 | ) 62 | 63 | def forward(self, x): 64 | y = self._body(x) 65 | return y + x 66 | 67 | 68 | class PWConv(nn.Module): 69 | def __init__(self, in_channels, out_channels, kernel_size, bias=False): 70 | super(PWConv, self).__init__() 71 | self._body = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, 72 | padding=kernel_size // 2, padding_mode='reflect', bias=bias) 73 | 74 | def forward(self, x): 75 | return self._body(x) 76 | 77 | 78 | class GlobalColorCompensationNet(nn.Module): 79 | def __init__(self, channel_scale, kernel_size): 80 | super(GlobalColorCompensationNet, self).__init__() 81 | self._body = nn.Sequential( 82 | PWConv(3, channel_scale, kernel_size), 83 | UCB(channel_scale, kernel_size), 84 | UCB(channel_scale, kernel_size), 85 | UCB(channel_scale, kernel_size), 86 | PWConv(channel_scale, 3, kernel_size), 87 | nn.Sigmoid() 88 | ) 89 | 90 | def forward(self, x): 91 | y = self._body(x) 92 | return y 93 | 94 | 95 | class CLCC(nn.Module): 96 | def __init__(self, channel_scale, main_ks, gcc_ks): 97 | super(CLCC, self).__init__() 98 | self._color_branch = GlobalColorCompensationNet(channel_scale, gcc_ks) 99 | self._in_conv = nn.Sequential( 100 | PWConv(3, channel_scale, main_ks), 101 | UIA(channel_scale, main_ks) 102 | ) 103 | self._group1 = nn.Sequential( 104 | *[UCB(channel_scale, main_ks) for _ in range(4)] 105 | ) 106 | self._group2 = nn.Sequential( 107 | *[UCB(channel_scale, main_ks) for _ in range(4)] 108 | ) 109 | self._group3 = nn.Sequential( 110 | *[UCB(channel_scale, main_ks) for _ in range(4)] 111 | ) 112 | self._group1_adaptation = nn.Sequential( 113 | PWConv(3, channel_scale, main_ks), 114 | UCB(channel_scale, main_ks) 115 | ) 116 | self._group2_adaptation = nn.Sequential( 117 | PWConv(3, channel_scale, main_ks), 118 | UCB(channel_scale, main_ks) 119 | ) 120 | self._group3_adaptation = nn.Sequential( 121 | PWConv(3, channel_scale, main_ks), 122 | UCB(channel_scale, main_ks) 123 | ) 124 | self._out_conv = nn.Sequential( 125 | PWConv(channel_scale, 3, main_ks), 126 | nn.Tanh() 127 | ) 128 | 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 132 | # elif isinstance(m, nn.InstanceNorm2d): 133 | # m.weight.data.fill_(1) 134 | # m.bias.data.zero_() 135 | 136 | 137 | def forward(self, x): 138 | color_comp = 1 - x 139 | color_comp_map = self._color_branch(color_comp) 140 | in_feat = self._in_conv(x) 141 | group1_out = self._group1(in_feat) 142 | group1_comp_out = group1_out + self._group1_adaptation(color_comp_map * 
color_comp) 143 | group2_out = self._group2(group1_comp_out) 144 | group2_comp_out = group2_out + self._group2_adaptation(color_comp_map * color_comp) 145 | group3_out = self._group3(group2_comp_out) 146 | group3_comp_out = group3_out + self._group3_adaptation(color_comp_map * color_comp) 147 | out = self._out_conv(group3_comp_out) 148 | return out 149 | 150 | 151 | if __name__ == '__main__': 152 | import torch 153 | x = torch.randn((2, 3, 256, 256)) 154 | model = CLCC(64, 3, 3) 155 | macs, params = get_model_complexity_info(model, (3, 256, 256), verbose=False, print_per_layer_stat=False) 156 | print('MACS: ' + str(macs)) 157 | print('Params: ' + str(params)) 158 | # model = GlobalColorCompensationNet(64) 159 | y = model(x) 160 | print(y.shape) -------------------------------------------------------------------------------- /models/LANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | '''........Laplace operation.......''' 6 | class Laplace(nn.Module): 7 | def __init__(self): 8 | super(Laplace,self).__init__() 9 | self.conv1=nn.Conv2d(in_channels=3,out_channels=1,kernel_size=3,stride=1,padding=0,bias=False) 10 | nn.init.constant_(self.conv1.weight,1) 11 | nn.init.constant_(self.conv1.weight[0,0,1,1],-8) 12 | nn.init.constant_(self.conv1.weight[0,1,1,1],-8) 13 | nn.init.constant_(self.conv1.weight[0,2,1,1],-8) 14 | 15 | def forward(self,x1): 16 | edge_map=self.conv1(x1) 17 | return edge_map 18 | 19 | 20 | class PALayer(nn.Module): 21 | '''........pixel attention(PA).......''' 22 | def __init__(self, channel): 23 | super(PALayer, self).__init__() 24 | 25 | self.PA = nn.Sequential( 26 | nn.Conv2d(channel, channel // 8, 3, padding=1, bias=True), 27 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 28 | nn.InstanceNorm2d(64), 29 | nn.Conv2d(channel // 8, 1, 3, padding=1, bias=True), 30 | # CxHxW -> 1xHxW 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, x): 35 | y = self.PA(x) 36 | return x * y 37 | 38 | class CALayer(nn.Module): 39 | '''........Channel attention(CA).......''' 40 | def __init__(self, channel): 41 | super(CALayer, self).__init__() 42 | 43 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 44 | self.CA = nn.Sequential( 45 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 46 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 47 | nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True), 48 | nn.Sigmoid() 49 | ) 50 | 51 | def forward(self, x): 52 | y = self.avg_pool(x) 53 | y = self.CA(y) 54 | return x * y 55 | 56 | class Block(nn.Module): 57 | '''........parallel attention module(PAM).......''' 58 | 59 | def __init__(self, dim, kernel_size): 60 | super(Block, self).__init__() 61 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 62 | self.conv1 = nn.Conv2d(dim, dim, kernel_size, padding=(kernel_size // 2), bias=True) 63 | self.conv2 = nn.Conv2d(dim, dim, kernel_size, padding=(kernel_size // 2), bias=True) 64 | self.calayer = CALayer(dim) 65 | self.palayer = PALayer(dim) 66 | 67 | def forward(self, x): 68 | res = self.conv1(x) 69 | res1 = self.calayer(res) 70 | res2 = self.palayer(res) 71 | res = res2 + res1 72 | res = self.conv2(res) 73 | res = res + x 74 | return res 75 | 76 | class GS(nn.Module): 77 | '''........Group structure.......''' 78 | def __init__(self, dim, kernel_size, blocks): 79 | super(GS, self).__init__() 80 | modules = [Block(dim, kernel_size) for _ in range(blocks)] 81 | self.gs = nn.Sequential(*modules) 82 | 83 | def forward(self, x): 84 | res = self.gs(x) 85 | return res 86 
| 87 | class Branch(nn.Module): 88 | '''......Branch......''' 89 | def __init__(self, in_channels, out_channels, kernel_size, bias=True): 90 | super(Branch, self).__init__() 91 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias) 92 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias) 93 | self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 94 | self.IN = nn.InstanceNorm2d(out_channels) 95 | 96 | def forward(self, x): 97 | x1 = self.conv1(x) 98 | x2 = self.conv2(x1) 99 | x3 = self.relu(x2) 100 | x4 = self.IN(x3) 101 | x5 = self.conv2(x4) 102 | 103 | return x1, x5 104 | 105 | class LANet(nn.Module): 106 | '''......the structure of LANet......''' 107 | def __init__(self, gps = 3, blocks = 20, dim = 64, kernel_size = 3): 108 | super(LANet, self).__init__() 109 | self.gps = gps 110 | self.dim = dim 111 | self.kernel_size = kernel_size 112 | 113 | self.laplace = Laplace() 114 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 115 | assert self.gps == 3 116 | 117 | self.g1 = GS(self.dim, kernel_size, blocks=blocks) 118 | self.g2 = GS(self.dim, kernel_size, blocks=blocks) 119 | self.g3 = GS(self.dim, kernel_size, blocks=blocks) 120 | 121 | self.brabch_3 = Branch(in_channels = 3, out_channels = self.dim, kernel_size = 3) 122 | self.brabch_5 = Branch(in_channels=3, out_channels=self.dim, kernel_size=5) 123 | self.brabch_7 = Branch(in_channels=3, out_channels=self.dim, kernel_size=7) 124 | 125 | self.fusion = nn.Sequential(*[ 126 | nn.Conv2d(self.dim * self.gps, self.dim // 8, 3, padding=1, bias=True), 127 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 128 | nn.InstanceNorm2d(self.dim), 129 | nn.Conv2d(self.dim // 8, self.gps, 1, padding = 0, bias=True), 130 | nn.Sigmoid() 131 | ]) 132 | 133 | self.Final = nn.Conv2d(self.dim, 3, 1, padding=0, bias=True) 134 | 135 | def forward(self, x): 136 | 137 | '''.....three branch.......''' 138 | x11, x1 = self.brabch_3(x) 139 | x22, x2 = self.brabch_5(x) 140 | x33, x3 = self.brabch_7(x) 141 | 142 | '''......Multiscale Fusion......''' 143 | w = self.fusion(torch.cat([x1, x2, x3], dim=1)) 144 | w = torch.split(w, 1, dim = 1) 145 | x4 = w[0] * x1 + w[1] * x2 + w[2] * x3 146 | 147 | res1 = self.g1(x4) #GS(1) 148 | '''......Adaptive learning Module......''' 149 | x5 = self.avg_pool(x4) 150 | res1= x5 * res1 + x33 151 | 152 | res2 = self.g2(res1) #GS(2) 153 | '''......Adaptive learning Module......''' 154 | x6 = self.avg_pool(res1) 155 | res2 = x6 * res2 + x22 156 | 157 | res3 = self.g3(res2) #GS(3) 158 | '''......Adaptive learning Module......''' 159 | x7 = self.avg_pool(res2) 160 | res3 = x7 * res3 + x11 161 | 162 | out = self.Final(res3) 163 | # Laplace operation 164 | edge_map = self.laplace(out) 165 | 166 | return out, edge_map 167 | 168 | if __name__ == '__main__': 169 | t = torch.randn(1, 3, 256, 256).cuda() 170 | model = LANet().cuda() 171 | res, _ = model(t) 172 | print(res.shape) -------------------------------------------------------------------------------- /models/ADMNNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def similarity_matrix(x): 5 | B, C, H, W = x.size() 6 | scale = ((H * W) ** -0.5) 7 | bmat1 = x.flatten(start_dim=2) 8 | bmat2 = bmat1.transpose(1, 2) 9 | bmat3 = torch.bmm(bmat1, bmat2, out=None) 10 | bmat3 = bmat3 * scale 11 | similarity = nn.Softmax(dim=-1)(bmat3) 12 | return similarity 13 | 14 | 15 | class PoolBlock(nn.Module): 16 | def 
__init__(self, in_channels, stride=2): 17 | super(PoolBlock, self).__init__() 18 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, 19 | bias=False) 20 | self.bn1 = nn.BatchNorm2d(in_channels) 21 | self.prelu1 = nn.PReLU(in_channels) 22 | self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, 23 | bias=False) 24 | self.bn2 = nn.BatchNorm2d(in_channels) 25 | self.prelu2 = nn.PReLU(in_channels) 26 | 27 | def forward(self, x): 28 | x = self.prelu1(self.bn1(self.conv1(x))) 29 | x = self.prelu2(self.bn2(self.conv2(x))) 30 | return x 31 | 32 | 33 | class convLayer(nn.Module): 34 | """ 35 | "Depthwise conv + Pointwise conv" 36 | """ 37 | 38 | def __init__(self, in_channels, out_channels, dilation=1, kernel_size=3, stride=1): 39 | super(convLayer, self).__init__() 40 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=dilation, 41 | groups=in_channels, dilation=dilation, bias=False) 42 | self.bn1 = nn.BatchNorm2d(in_channels) 43 | self.prelu1 = nn.PReLU(in_channels) 44 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 45 | self.bn2 = nn.BatchNorm2d(out_channels) 46 | self.prelu2 = nn.PReLU(out_channels) 47 | 48 | def forward(self, x): 49 | x = self.prelu1(self.bn1(self.conv1(x))) 50 | x = self.prelu2(self.bn2(self.conv2(x))) 51 | return x 52 | 53 | 54 | class conv1_1(nn.Module): 55 | def __init__(self, in_channels, out_channels): 56 | super(conv1_1, self).__init__() 57 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 58 | self.bn1 = nn.BatchNorm2d(out_channels) 59 | self.prelu1 = nn.PReLU(out_channels) 60 | 61 | def forward(self, x): 62 | x = self.prelu1(self.bn1(self.conv1(x))) 63 | return x 64 | 65 | 66 | class CA(nn.Module): 67 | def __init__(self, in_channels): 68 | super(CA, self).__init__() 69 | self.pool1 = PoolBlock(in_channels) 70 | 71 | def forward(self, x): 72 | B, C, H, W = x.size() 73 | pool1 = self.pool1(x) 74 | 75 | attention = similarity_matrix(pool1) 76 | x1 = x.flatten(start_dim=2).transpose(1, 2) 77 | out = torch.bmm(x1, attention, out=None) 78 | out = out.transpose(1, 2).view(B, C, H, W) 79 | 80 | return out 81 | 82 | 83 | class Multiscale(nn.Module): 84 | def __init__(self, in_channels): 85 | super(Multiscale, self).__init__() 86 | self.conv1 = conv1_1(in_channels, in_channels) 87 | self.conv3 = convLayer(in_channels, in_channels) 88 | self.conv5 = convLayer(in_channels, in_channels, dilation=2) 89 | 90 | def forward(self, x): 91 | x1 = self.conv1(x) 92 | 93 | x2 = self.conv3(x) 94 | 95 | x3 = self.conv5(x) 96 | 97 | return x1, x2, x3 98 | 99 | 100 | class DFSM(nn.Module): 101 | def __init__(self, in_channels): 102 | super(DFSM, self).__init__() 103 | self.Multiscale = Multiscale(in_channels) 104 | 105 | self.pool1 = PoolBlock(in_channels) 106 | self.conv1 = conv1_1(in_channels, 1) 107 | 108 | self.pool2 = PoolBlock(in_channels) 109 | self.conv2 = conv1_1(in_channels, 1) 110 | 111 | self.pool3 = PoolBlock(in_channels) 112 | self.conv3 = conv1_1(in_channels, 1) 113 | 114 | def forward(self, x): 115 | B, C, H, W = x.size() 116 | x1, x2, x3 = self.Multiscale(x) 117 | pool1 = self.conv1(self.pool1(x1)) # 1*1 118 | pool2 = self.conv2(self.pool2(x2)) # 3*3 119 | pool3 = self.conv3(self.pool3(x3)) # 5*5 120 | 121 | cat = torch.cat([pool1, pool2, pool3], dim=1) 122 | attention = similarity_matrix(cat) # 3*3 123 | 124 | x1_1 = 
x1.view(B, -1, C, H, W).flatten(start_dim=2) 125 | x2_2 = x2.view(B, -1, C, H, W).flatten(start_dim=2) 126 | x3_3 = x3.view(B, -1, C, H, W).flatten(start_dim=2) 127 | 128 | bmat1 = torch.cat([x1_1, x2_2, x3_3], dim=1).transpose(1, 2) 129 | 130 | out = torch.bmm(bmat1, attention) 131 | out = out.transpose(1, 2) 132 | 133 | x1 = x1 + out[:, 0, :].view(B, C, H, W) 134 | x2 = x2 + out[:, 1, :].view(B, C, H, W) 135 | x3 = x3 + out[:, 2, :].view(B, C, H, W) 136 | 137 | return x1 + x2 + x3 138 | 139 | 140 | class MCAM(nn.Module): 141 | def __init__(self, in_channels): 142 | super(MCAM, self).__init__() 143 | self.Multiscale = Multiscale(in_channels) 144 | 145 | self.ca1 = CA(in_channels) 146 | self.ca2 = CA(in_channels) 147 | self.ca3 = CA(in_channels) 148 | 149 | def forward(self, x): 150 | x1, x2, x3 = self.Multiscale(x) 151 | x_ca1 = self.ca1(x1) 152 | x_ca2 = self.ca2(x2) 153 | x_ca3 = self.ca3(x3) 154 | return x_ca1 + x_ca2 + x_ca3 155 | 156 | 157 | class Block(nn.Module): 158 | def __init__(self, in_channels): 159 | super(Block, self).__init__() 160 | self.dfsm = DFSM(in_channels) 161 | self.mcam = MCAM(in_channels) 162 | 163 | def forward(self, x): 164 | x_dfsm = self.dfsm(x) + x 165 | x_mcam = self.mcam(x_dfsm) + x_dfsm 166 | 167 | return x_mcam + x 168 | 169 | 170 | class ADMNNet(nn.Module): 171 | def __init__(self, in_channels=3, blocks_num=3, out_channels=16): 172 | super(ADMNNet, self).__init__() 173 | self.first = convLayer(in_channels, out_channels) 174 | self.blocks = nn.Sequential(*[Block(out_channels) for i in range(blocks_num)]) 175 | self.conv2_final = nn.Conv2d(out_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=False) 176 | self.tanh = nn.Tanh() 177 | 178 | def forward(self, x): 179 | y = self.first(x) 180 | y = self.blocks(y) 181 | 182 | mask = self.tanh(self.conv2_final(y)) 183 | 184 | out = torch.clamp(mask + x, -1, 1, out=None) 185 | return out 186 | 187 | 188 | if __name__ == '__main__': 189 | t = torch.randn(1, 3, 256, 256) 190 | net = ADMNNet(3, 3) 191 | out = net(t) 192 | print(out.shape) 193 | -------------------------------------------------------------------------------- /data/dataset_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import albumentations as A 4 | import numpy as np 5 | import torchvision.transforms.functional as F 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 12 | 13 | 14 | class DataLoaderTrain(Dataset): 15 | def __init__(self, rgb_dir, inp='input', target='target', img_options=None): 16 | super(DataLoaderTrain, self).__init__() 17 | 18 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) 19 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, target))) 20 | dep_files = sorted(os.listdir(os.path.join(rgb_dir, 'depth'))) 21 | # mas_files = sorted(os.listdir(os.path.join(rgb_dir, 'mask'))) 22 | 23 | self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)] 24 | self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)] 25 | self.dep_filenames = [os.path.join(rgb_dir, 'depth', x) for x in dep_files if is_image_file(x)] 26 | # self.mas_filenames = [os.path.join(rgb_dir, 'mask', x) for x in mas_files if is_image_file(x)] 27 | 28 | self.img_options = img_options 29 | self.sizex = len(self.tar_filenames) # get the size of 
target 30 | 31 | self.transform = A.Compose([ 32 | A.Resize(height=img_options['h'], width=img_options['w']), 33 | A.Transpose(p=0.3), 34 | A.Flip(p=0.3), 35 | A.RandomRotate90(p=0.3), 36 | ], 37 | is_check_shapes=False, 38 | additional_targets={ 39 | 'target': 'image', 40 | 'depth': 'image' 41 | } 42 | ) 43 | 44 | def __len__(self): 45 | return self.sizex 46 | 47 | def __getitem__(self, index): 48 | index_ = index % self.sizex 49 | 50 | inp_path = self.inp_filenames[index_] 51 | tar_path = self.tar_filenames[index_] 52 | dep_path = self.dep_filenames[index_] 53 | 54 | inp_img = Image.open(inp_path).convert('RGB') 55 | tar_img = Image.open(tar_path).convert('RGB') 56 | dep_img = Image.open(dep_path).convert('RGB') 57 | 58 | inp_img = np.array(inp_img) 59 | tar_img = np.array(tar_img) 60 | dep_img = np.array(dep_img) 61 | 62 | transformed = self.transform(image=inp_img, target=tar_img, depth=dep_img) 63 | 64 | inp_img = F.to_tensor(transformed['image']) 65 | tar_img = F.to_tensor(transformed['target']) 66 | dep_img = F.to_tensor(transformed['depth']) 67 | 68 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 69 | 70 | return inp_img, dep_img, tar_img, filename 71 | 72 | 73 | class DataLoaderVal(Dataset): 74 | def __init__(self, rgb_dir, inp='input', target='target', img_options=None): 75 | super(DataLoaderVal, self).__init__() 76 | 77 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) 78 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, target))) 79 | dep_files = sorted(os.listdir(os.path.join(rgb_dir, 'depth'))) 80 | 81 | self.inp_filenames = [os.path.join(rgb_dir, inp, x) for x in inp_files if is_image_file(x)] 82 | self.tar_filenames = [os.path.join(rgb_dir, target, x) for x in tar_files if is_image_file(x)] 83 | self.dep_filenames = [os.path.join(rgb_dir, 'depth', x) for x in dep_files if is_image_file(x)] 84 | 85 | self.img_options = img_options 86 | self.sizex = len(self.tar_filenames) # get the size of target 87 | 88 | self.transform = A.Compose([ 89 | A.Resize(height=img_options['h'], width=img_options['w']), ], 90 | is_check_shapes=False, 91 | additional_targets={ 92 | 'target': 'image', 93 | 'depth': 'image' 94 | } 95 | ) 96 | 97 | def __len__(self): 98 | return self.sizex 99 | 100 | def __getitem__(self, index): 101 | index_ = index % self.sizex 102 | 103 | inp_path = self.inp_filenames[index_] 104 | tar_path = self.tar_filenames[index_] 105 | dep_path = self.dep_filenames[index_] 106 | 107 | inp_img = Image.open(inp_path).convert('RGB') 108 | tar_img = Image.open(tar_path).convert('RGB') 109 | dep_img = Image.open(dep_path).convert('RGB') 110 | 111 | if not self.img_options['ori']: 112 | inp_img = np.array(inp_img) 113 | tar_img = np.array(tar_img) 114 | dep_img = np.array(dep_img) 115 | 116 | transformed = self.transform(image=inp_img, target=tar_img, depth=dep_img) 117 | 118 | inp_img = transformed['image'] 119 | tar_img = transformed['target'] 120 | dep_img = transformed['depth'] 121 | 122 | inp_img = F.to_tensor(inp_img) 123 | tar_img = F.to_tensor(tar_img) 124 | dep_img = F.to_tensor(dep_img) 125 | 126 | filename = os.path.split(tar_path)[-1] 127 | 128 | return inp_img, dep_img, tar_img, filename 129 | 130 | 131 | class DataLoaderTest(Dataset): 132 | def __init__(self, rgb_dir, inp='input', img_options=None): 133 | super(DataLoaderTest, self).__init__() 134 | 135 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, inp))) 136 | dep_files = sorted(os.listdir(os.path.join(rgb_dir, 'depth'))) 137 | 138 | self.inp_filenames = [os.path.join(rgb_dir, 
inp, x) for x in inp_files if is_image_file(x)] 139 | self.dep_filenames = [os.path.join(rgb_dir, 'depth', x) for x in dep_files if is_image_file(x)] 140 | 141 | self.img_options = img_options 142 | self.sizex = len(self.inp_filenames) # get the size of target 143 | 144 | self.transform = A.Compose([ 145 | A.Resize(height=img_options['h'], width=img_options['w']), ], 146 | is_check_shapes=False, 147 | additional_targets={ 148 | 'depth': 'image' 149 | } 150 | ) 151 | 152 | def __len__(self): 153 | return self.sizex 154 | 155 | def __getitem__(self, index): 156 | index_ = index % self.sizex 157 | 158 | inp_path = self.inp_filenames[index_] 159 | dep_path = self.dep_filenames[index_] 160 | 161 | inp_img = Image.open(inp_path).convert('RGB') 162 | dep_img = Image.open(dep_path).convert('RGB') 163 | 164 | if not self.img_options['ori']: 165 | inp_img = np.array(inp_img) 166 | dep_img = np.array(dep_img) 167 | 168 | transformed = self.transform(image=inp_img, depth=dep_img) 169 | 170 | inp_img = transformed['image'] 171 | dep_img = transformed['depth'] 172 | 173 | inp_img = F.to_tensor(inp_img) 174 | dep_img = F.to_tensor(dep_img) 175 | 176 | filename = os.path.split(inp_path)[-1] 177 | 178 | return inp_img, dep_img, filename 179 | -------------------------------------------------------------------------------- /models/UIEC2Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from kornia.color import rgb_to_hsv, hsv_to_rgb 4 | 5 | def piece_function_org(x_m, para_m, M): 6 | b, c, w, h = x_m.shape 7 | r_m = para_m[:, 0].view(b, c, 1, 1).expand(b, c, w, h) 8 | for i in range(M-1): 9 | para = (para_m[:, i + 1] - para_m[:, i]).view(b, c, 1, 1).expand(b, c, w, h) 10 | r_m = r_m + para * \ 11 | sgn_m(M * x_m - i * torch.ones(x_m.shape).to(x_m.device)) 12 | return r_m 13 | 14 | def sgn_m(x): 15 | # x = torch.Tensor(x) 16 | zero_lab = torch.zeros(x.shape).to(x.device) 17 | # print("one_lab",one_lab) 18 | s_t = torch.where(x < 0, zero_lab, x) 19 | one_lab = torch.ones(x.shape).to(x.device) 20 | s = torch.where(s_t > 1, one_lab, s_t) 21 | return s 22 | 23 | class UIEC2Net(nn.Module): 24 | def __init__(self): 25 | super(UIEC2Net, self).__init__() 26 | self._init_layers() 27 | 28 | def _init_layers(self): 29 | self.rgb2hsv = rgb_to_hsv 30 | self.hsv2rgb = hsv_to_rgb 31 | # rgb 32 | self.norm_batch = nn.InstanceNorm2d # choose one 33 | 34 | self.rgb_norm_batch1 = self.norm_batch(64) 35 | self.rgb_con1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 36 | self.rgb_norm_batch2 = self.norm_batch(64) 37 | self.rgb_con2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 38 | self.rgb_norm_batch3 = self.norm_batch(64) 39 | self.rgb_con3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 40 | self.rgb_norm_batch4 = self.norm_batch(64) 41 | self.rgb_con4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 42 | self.rgb_norm_batch5 = self.norm_batch(64) 43 | self.rgb_con5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 44 | self.rgb_norm_batch6 = self.norm_batch(64) 45 | self.rgb_con6 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 46 | self.rgb_con7 = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0) 47 | 48 | self.rgb_fuction_down = nn.LeakyReLU(inplace=True) 49 | self.rgb_fuction_up = nn.ReLU(inplace=True) 50 | 51 | # hsv 52 | self.relu = nn.LeakyReLU(inplace=True) 53 | self.rrelu = nn.ReLU(inplace=True) 54 | self.M = 11 55 | # New /1/./2/./3/ use number_f = 32 56 | number_f = 64 57 | 
self.e_conv1 = nn.Conv2d(6, number_f, 3, 1, 1, bias=True)
58 |         self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
59 |         self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
60 |         self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
61 |         self.e_conv7 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
62 |         self.e_convfc = nn.Linear(number_f, 44)
63 |         self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
64 |         self.avagepool = nn.AdaptiveAvgPool2d(1)
65 |         self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
66 |         # confidence
67 |         self.norm_batch = nn.InstanceNorm2d # choose one
68 | 
69 |         self.norm_batch1 = self.norm_batch(64)
70 |         self.con1 = nn.Conv2d(9, 64, kernel_size=3, stride=1, padding=1)
71 |         self.norm_batch2 = self.norm_batch(64)
72 |         self.con2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
73 |         self.norm_batch3 = self.norm_batch(64)
74 |         self.con3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
75 |         self.norm_batch4 = self.norm_batch(64)
76 |         self.con4 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
77 |         self.norm_batch5 = self.norm_batch(64)
78 |         self.con5 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
79 |         self.norm_batch6 = self.norm_batch(64)
80 |         self.con6 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
81 |         self.con7 = nn.Conv2d(64, 6, kernel_size=1, stride=1, padding=0)
82 | 
83 |         self.fuction_down = nn.LeakyReLU(inplace=True)
84 |         self.fuction_up = nn.ReLU(inplace=True)
85 | 
86 | 
87 |     def forward(self, x):
88 |         h = self.rgb_fuction_down(self.rgb_norm_batch1(self.rgb_con1(x)))
89 |         h = self.rgb_fuction_down(self.rgb_norm_batch2(self.rgb_con2(h)))
90 |         h = self.rgb_fuction_down(self.rgb_norm_batch3(self.rgb_con3(h)))
91 |         h = self.rgb_fuction_down(self.rgb_norm_batch4(self.rgb_con4(h)))
92 |         h = self.rgb_fuction_up(self.rgb_norm_batch5(self.rgb_con5(h)))
93 |         h = self.rgb_fuction_up(self.rgb_norm_batch6(self.rgb_con6(h))) # try to use
94 |         rgb_out = torch.sigmoid(self.rgb_con7(h))
95 |         rgb_out = rgb_out[:, 0:3, :, :]
96 |         hsv_fromrgbout = self.rgb2hsv(rgb_out)
97 |         hsv_frominput = self.rgb2hsv(x)
98 | 
99 |         hsv_input = torch.cat([hsv_fromrgbout, hsv_frominput], dim=1)  # pair HSV of the RGB-branch output with HSV of the raw input
100 |         batch_size = hsv_input.size()[0]
101 |         x1 = self.relu(self.e_conv1(hsv_input))
102 |         x1 = self.maxpool(x1)
103 |         x2 = self.relu(self.e_conv2(x1))
104 |         x2 = self.maxpool(x2)
105 |         x3 = self.relu(self.e_conv3(x2))
106 |         x3 = self.maxpool(x3)
107 |         x4 = self.relu(self.e_conv4(x3))
108 |         x_r = self.relu(self.e_conv7(x4))
109 |         x_r = self.avagepool(x_r).view(batch_size, -1)
110 |         x_r = self.e_convfc(x_r)
111 |         H, S, V, H2S = torch.split(x_r, self.M, dim=1)
112 |         H_in, S_in, V_in = hsv_input[:, 0:1, :, :], hsv_input[:, 1:2, :, :], hsv_input[:, 2:3, :, :],
113 |         H_out = piece_function_org(H_in, H, self.M)
114 |         S_out1 = piece_function_org(S_in, S, self.M)
115 |         V_out = piece_function_org(V_in, V, self.M)
116 | 
117 |         S_out2 = piece_function_org(H_in, H2S, self.M)
118 |         S_out = (S_out1 + S_out2) / 2
119 | 
120 |         zero_lab = torch.zeros(S_out.shape).to(S_out.device)
121 |         s_t = torch.where(S_out < 0, zero_lab, S_out)
122 |         one_lab = torch.ones(S_out.shape).to(S_out.device)
123 |         S_out = torch.where(s_t > 1, one_lab, s_t)
124 | 
125 |         zero_lab = torch.zeros(V_out.shape).to(V_out.device)
126 |         s_t = torch.where(V_out < 0, zero_lab, V_out)
127 |         one_lab = torch.ones(V_out.shape).to(V_out.device)
128 |         V_out = torch.where(s_t > 1, one_lab, s_t)
129 | 
130 |         hsv_out = torch.cat([H_out, S_out, V_out], dim=1)
131 |         curve = torch.cat([H.view(batch_size, 1, -1),
132 |                            S.view(batch_size, 1, -1),
133 |                            V.view(batch_size, 1, -1),
134 |                            H2S.view(batch_size, 1, -1)], dim=1)
135 | 
136 |         hsv_out_rgb = self.hsv2rgb(hsv_out)
137 | 
138 |         confindencenet_input = torch.cat([x,
139 |                                           rgb_out,
140 |                                           hsv_out_rgb], dim=1)
141 | 
142 |         h = self.fuction_down(self.norm_batch1(self.con1(confindencenet_input)))
143 |         h = self.fuction_down(self.norm_batch2(self.con2(h)))
144 |         h = self.fuction_down(self.norm_batch3(self.con3(h)))
145 |         h = self.fuction_down(self.norm_batch4(self.con4(h)))
146 |         h = self.fuction_down(self.norm_batch5(self.con5(h)))
147 |         h = self.fuction_down(self.norm_batch6(self.con6(h))) # try to use
148 |         confindence_out = torch.sigmoid(self.con7(h))
149 | 
150 |         # needs renaming
151 |         confindence_rgb = confindence_out[:, 0:3, :, :]
152 |         confindence_hsv = confindence_out[:, 3:6, :, :]
153 |         output_useconf = 0.5 * confindence_rgb * rgb_out + \
154 |                          0.5 * confindence_hsv * hsv_out_rgb
155 | 
156 |         return output_useconf
157 |         #, rgb_out, hsv_out_rgb
158 | 
159 | 
160 | 
161 | if __name__ == '__main__':
162 |     t = torch.randn(1, 3, 256, 256).cuda()
163 |     model = UIEC2Net().cuda()
164 |     res = model(t)
165 |     print(res.shape)
--------------------------------------------------------------------------------
/models/Deep-WaveNet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | from torch.autograd import Variable
9 | from torch.utils.data import DataLoader, Dataset
10 | 
11 | 
12 | """# Channel and Spatial Attention"""
13 | class BasicConv(nn.Module):
14 |     def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=False, bias=False):
15 |         super(BasicConv, self).__init__()
16 |         self.out_channels = out_planes
17 |         self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
18 |         self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
19 |         self.relu = nn.ReLU() if relu else None
20 | 
21 |     def forward(self, x):
22 |         x = self.conv(x)
23 |         if self.bn is not None:
24 |             x = self.bn(x)
25 |         if self.relu is not None:
26 |             x = self.relu(x)
27 |         return x
28 | 
29 | class Flatten(nn.Module):
30 |     def forward(self, x):
31 |         return x.view(x.size(0), -1)
32 | 
33 | class ChannelGate(nn.Module):
34 |     def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
35 |         super(ChannelGate, self).__init__()
36 |         self.gate_channels = gate_channels
37 |         self.mlp = nn.Sequential(
38 |             Flatten(),
39 |             nn.Linear(gate_channels, gate_channels // reduction_ratio),
40 |             nn.ReLU(),
41 |             nn.Linear(gate_channels // reduction_ratio, gate_channels)
42 |         )
43 |         self.pool_types = pool_types
44 |     def forward(self, x):
45 |         channel_att_sum = None
46 |         for pool_type in self.pool_types:
47 |             if pool_type=='avg':
48 |                 avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
49 |                 channel_att_raw = self.mlp( avg_pool )
50 |             elif pool_type=='max':
51 |                 max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
52 |                 channel_att_raw = self.mlp( max_pool )
53 |             elif pool_type=='lp':
54 |                 lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
55 |                 channel_att_raw = self.mlp( lp_pool )
56 |             elif pool_type=='lse':
57 |                 # LSE pool only
58 |                 lse_pool = 
logsumexp_2d(x) 59 | channel_att_raw = self.mlp( lse_pool ) 60 | 61 | if channel_att_sum is None: 62 | channel_att_sum = channel_att_raw 63 | else: 64 | channel_att_sum = channel_att_sum + channel_att_raw 65 | 66 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 67 | return x * scale 68 | 69 | def logsumexp_2d(tensor): 70 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 71 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 72 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 73 | return outputs 74 | 75 | class ChannelPool(nn.Module): 76 | def forward(self, x): 77 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 78 | 79 | class SpatialGate(nn.Module): 80 | def __init__(self): 81 | super(SpatialGate, self).__init__() 82 | kernel_size = 7 83 | self.compress = ChannelPool() 84 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 85 | def forward(self, x): 86 | x_compress = self.compress(x) 87 | x_out = self.spatial(x_compress) 88 | scale = torch.sigmoid(x_out) # broadcasting 89 | return x * scale 90 | 91 | class CBAM(nn.Module): 92 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 93 | super(CBAM, self).__init__() 94 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 95 | self.no_spatial=no_spatial 96 | if not no_spatial: 97 | self.SpatialGate = SpatialGate() 98 | def forward(self, x): 99 | x_out = self.ChannelGate(x) 100 | if not self.no_spatial: 101 | x_out = self.SpatialGate(x_out) 102 | return x_out 103 | 104 | 105 | class Conv2D_pxp(nn.Module): 106 | 107 | def __init__(self, in_ch, out_ch, k,s,p): 108 | super(Conv2D_pxp, self).__init__() 109 | self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=k, stride=s, padding=p) 110 | self.bn = nn.BatchNorm2d(num_features=out_ch) 111 | self.relu = nn.PReLU(out_ch) 112 | 113 | def forward(self, input): 114 | return self.relu(self.bn(self.conv(input))) 115 | 116 | 117 | class DeepWave(nn.Module): 118 | 119 | def __init__(self): 120 | super(DeepWave, self).__init__() 121 | 122 | self.layer1_1 = Conv2D_pxp(1, 32, 3,1,1) 123 | self.layer1_2 = Conv2D_pxp(1, 32, 5,1,2) 124 | self.layer1_3 = Conv2D_pxp(1, 32, 7,1,3) 125 | 126 | self.layer2_1 = Conv2D_pxp(96, 32, 3,1,1) 127 | self.layer2_2 = Conv2D_pxp(96, 32, 5,1,2) 128 | self.layer2_3 = Conv2D_pxp(96, 32, 7,1,3) 129 | 130 | self.local_attn_r = CBAM(64) 131 | self.local_attn_g = CBAM(64) 132 | self.local_attn_b = CBAM(64) 133 | 134 | self.layer3_1 = Conv2D_pxp(192, 1, 3,1,1) 135 | self.layer3_2 = Conv2D_pxp(192, 1, 5,1,2) 136 | self.layer3_3 = Conv2D_pxp(192, 1, 7,1,3) 137 | 138 | 139 | self.d_conv1 = nn.ConvTranspose2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1) 140 | self.d_bn1 = nn.BatchNorm2d(num_features=32) 141 | self.d_relu1 = nn.PReLU(32) 142 | 143 | self.global_attn_rgb = CBAM(35) 144 | 145 | self.d_conv2 = nn.ConvTranspose2d(in_channels=35, out_channels=3, kernel_size=3, stride=1, padding=1) 146 | self.d_bn2 = nn.BatchNorm2d(num_features=3) 147 | self.d_relu2 = nn.PReLU(3) 148 | 149 | 150 | def forward(self, input): 151 | input_1 = torch.unsqueeze(input[:,0,:,:], dim=1) 152 | input_2 = torch.unsqueeze(input[:,1,:,:], dim=1) 153 | input_3 = torch.unsqueeze(input[:,2,:,:], dim=1) 154 | 155 | #layer 1 156 | l1_1=self.layer1_1(input_1) 157 | l1_2=self.layer1_2(input_2) 158 | l1_3=self.layer1_3(input_3) 159 | 160 | 
#Input to layer 2 161 | input_l2=torch.cat((l1_1,l1_2),1) 162 | input_l2=torch.cat((input_l2,l1_3),1) 163 | 164 | #layer 2 165 | l2_1=self.layer2_1(input_l2) 166 | l2_1 = self.local_attn_r(torch.cat((l2_1, l1_1),1)) 167 | 168 | l2_2=self.layer2_2(input_l2) 169 | l2_2 = self.local_attn_g(torch.cat((l2_2, l1_2),1)) 170 | 171 | l2_3=self.layer2_3(input_l2) 172 | l2_3 = self.local_attn_b(torch.cat((l2_3, l1_3),1)) 173 | 174 | #Input to layer 3 175 | input_l3=torch.cat((l2_1,l2_2),1) 176 | input_l3=torch.cat((input_l3,l2_3),1) 177 | 178 | #layer 3 179 | l3_1=self.layer3_1(input_l3) 180 | l3_2=self.layer3_2(input_l3) 181 | l3_3=self.layer3_3(input_l3) 182 | 183 | #input to decoder unit 184 | temp_d1=torch.add(input_1,l3_1) 185 | temp_d2=torch.add(input_2,l3_2) 186 | temp_d3=torch.add(input_3,l3_3) 187 | 188 | input_d1=torch.cat((temp_d1,temp_d2),1) 189 | input_d1=torch.cat((input_d1,temp_d3),1) 190 | 191 | #decoder 192 | output_d1=self.d_relu1(self.d_bn1(self.d_conv1(input_d1))) 193 | output_d1 = self.global_attn_rgb(torch.cat((output_d1, input_d1),1)) 194 | final_output=self.d_relu2(self.d_bn2(self.d_conv2(output_d1))) 195 | 196 | return final_output 197 | 198 | 199 | if __name__ == '__main__': 200 | t = torch.randn(1, 3, 256, 256).cuda() 201 | model = DeepWave().cuda() 202 | res = model(t) 203 | print(res.shape) -------------------------------------------------------------------------------- /models/AoSRNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AoSRNet(nn.Module): 7 | def __init__(self): 8 | super(AoSRNet,self).__init__() 9 | 10 | self.mns = MainNetworkStructure(3,12) 11 | 12 | def forward(self,x): 13 | 14 | Fout = self.mns(x) 15 | 16 | return Fout# + x 17 | 18 | 19 | class MainNetworkStructure(nn.Module): 20 | def __init__(self,inchannel,channel): 21 | super(MainNetworkStructure,self).__init__() 22 | 23 | self.conv_mv_in = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=1,dilation=1,bias=False) 24 | self.conv_wb_in = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=1,dilation=1,bias=False) 25 | self.conv_gc_in = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=1,dilation=1,bias=False) 26 | 27 | 28 | self.conv_mv = BB(channel) 29 | self.conv_wb = BB(channel) 30 | self.conv_gc = BB(channel) 31 | 32 | self.ED = En_Decoder(channel,3*channel) 33 | 34 | self.wbm = WBM() 35 | self.gcm = GCM() 36 | self.mvp = MVP() 37 | 38 | def forward(self,x): 39 | 40 | mv_x1 = torch.clamp(self.conv_mv_in(x),1e-10,1.0) 41 | mv_x2 = torch.clamp(self.conv_mv_in(x),1e-10,1.0) 42 | mv_x3 = torch.clamp(self.conv_mv_in(x),1e-10,1.0) 43 | mv_x4 = torch.clamp(self.conv_mv_in(x),1e-10,1.0) 44 | 45 | wb_x1 = torch.clamp(self.conv_wb_in(x),1e-10,1.0) 46 | wb_x2 = torch.clamp(self.conv_wb_in(x),1e-10,1.0) 47 | wb_x3 = torch.clamp(self.conv_wb_in(x),1e-10,1.0) 48 | wb_x4 = torch.clamp(self.conv_wb_in(x),1e-10,1.0) 49 | 50 | gc_x1 = torch.clamp(self.conv_gc_in(x),1e-10,1.0) 51 | gc_x2 = torch.clamp(self.conv_gc_in(x),1e-10,1.0) 52 | gc_x3 = torch.clamp(self.conv_gc_in(x),1e-10,1.0) 53 | gc_x4 = torch.clamp(self.conv_gc_in(x),1e-10,1.0) 54 | 55 | 56 | mv = self.conv_mv(self.mvp(mv_x1,mv_x2,mv_x3,mv_x4)) 57 | wb = self.conv_wb(self.wbm(wb_x1,wb_x2,wb_x3,wb_x4)) 58 | gc = self.conv_gc(self.gcm(gc_x1,gc_x2,gc_x3,gc_x4)) 59 | 60 | 61 | x_out = self.ED(mv,wb,gc) 62 | 63 | return x_out# + x 64 | 65 | 66 | class MVP(nn.Module): #Multi-view perception 67 | def __init__(self,norm=False): 68 | 
super(MVP,self).__init__() 69 | 70 | self.convD_3 = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=3,dilation=3,bias=False) 71 | self.convD_6 = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=6,dilation=6,bias=False) 72 | self.convD_9 = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=9,dilation=9,bias=False) 73 | self.convD_12 = nn.Conv2d(3,3,kernel_size=3,stride=1,padding=12,dilation=12,bias=False) 74 | 75 | self.act = nn.PReLU(3) 76 | self.norm = nn.GroupNorm(num_channels=3,num_groups=1)# nn.InstanceNorm2d(channel)# 77 | 78 | def forward(self,x1,x2,x3,x4): 79 | 80 | x1 = self.act(self.norm(self.convD_3(x1))) 81 | x2 = self.act(self.norm(self.convD_6(x2))) 82 | x3 = self.act(self.norm(self.convD_9(x3))) 83 | x4 = self.act(self.norm(self.convD_12(x4))) 84 | 85 | xout = torch.cat((x1,x2,x3,x4),1) 86 | 87 | return xout 88 | 89 | class WBM(nn.Module): # White Balance Model 90 | def __init__(self): 91 | super(WBM,self).__init__() 92 | 93 | self.conv_1 = ConvL(3,3) 94 | self.conv_2 = ConvL(3,3) 95 | self.conv_3 = ConvL(3,3) 96 | self.conv_4 = ConvL(3,3) 97 | 98 | def forward(self,x1,x2,x3,x4): 99 | 100 | x1 = self.conv_1(WhiteBalance(x1,0.05,0.10)) 101 | x2 = self.conv_2(WhiteBalance(x2,0.05,0.15)) 102 | x3 = self.conv_3(WhiteBalance(x3,0.15,0.10)) 103 | x4 = self.conv_4(WhiteBalance(x4,0.15,0.20)) 104 | 105 | xout = torch.cat((x1,x2,x3,x4),1) 106 | 107 | return xout 108 | 109 | 110 | def WhiteBalance(TensorData,pmi,pma): 111 | '''White Balance for recovery priors''' 112 | for i in range(TensorData.shape[0]): 113 | for j in range(3): 114 | tmi = torch.quantile(TensorData[i,j,:,:].clone(),0.01) 115 | tma = torch.quantile(TensorData[i,j,:,:].clone(),0.09) 116 | tpmi = tmi - pmi * (tma - tmi) 117 | tpma = tma + pma * (tma - tmi) 118 | 119 | TensorData[i,j,:,:,] = (TensorData[i,j,:,:].clone() - tpmi) / ((tpma - tpmi) + 1e-10) 120 | 121 | return TensorData 122 | 123 | 124 | class GCM(nn.Module): # Gamma Correction Model 125 | def __init__(self): 126 | super(GCM,self).__init__() 127 | 128 | self.conv_1 = ConvL(3,3) 129 | self.conv_2 = ConvL(3,3) 130 | self.conv_3 = ConvL(3,3) 131 | self.conv_4 = ConvL(3,3) 132 | 133 | def forward(self,x1,x2,x3,x4): 134 | 135 | x1 = self.conv_1(torch.pow(x1,1/4)) 136 | x2 = self.conv_2(torch.pow(x2,1/2)) 137 | x3 = self.conv_3(torch.pow(x3,2)) 138 | x4 = self.conv_4(torch.pow(x4,4)) 139 | 140 | xout = torch.cat((x1,x2,x3,x4),1) 141 | 142 | return xout 143 | 144 | class BB(nn.Module): #Basic Block (BB) 145 | def __init__(self,channel,norm=False): 146 | super(BB,self).__init__() 147 | 148 | self.conv_1 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 149 | self.conv_2 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 150 | self.conv_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 151 | self.conv_out = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 152 | 153 | self.act = nn.PReLU(channel) 154 | self.norm = nn.GroupNorm(num_channels=channel,num_groups=1)# nn.InstanceNorm2d(channel)# 155 | 156 | def forward(self,x): 157 | 158 | x_1 = self.act(self.norm(self.conv_1(x))) 159 | x_2 = self.act(self.norm(self.conv_2(x_1))) 160 | x_out = self.act(self.norm(self.conv_out(x_2)) + x) 161 | 162 | return x_out 163 | 164 | 165 | class ConvL(nn.Module): 166 | def __init__(self,inchannel,channel,norm=False): 167 | super(ConvL,self).__init__() 168 | 169 | self.conv = nn.Conv2d(inchannel,channel,kernel_size=3,stride=1,padding=1,bias=False) 170 | self.act = nn.PReLU(channel) 171 | self.norm = 
nn.GroupNorm(num_channels=channel,num_groups=1) 172 | 173 | def forward(self,x): 174 | 175 | x_out = self.act(self.norm(self.conv(x))) 176 | 177 | return x_out 178 | 179 | 180 | class En_Decoder(nn.Module): 181 | def __init__(self,inchannel,channel): 182 | super(En_Decoder,self).__init__() 183 | 184 | self.el = BB(channel) 185 | self.em = BB(channel*2) 186 | self.es = BB(channel*4) 187 | self.ds = BB(channel*4) 188 | self.dm = BB(channel*2) 189 | self.dl = BB(channel) 190 | 191 | self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 192 | self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False) 193 | self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 194 | self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 195 | 196 | self.conv_in = nn.Conv2d(12,channel,kernel_size=3,stride=1,padding=1,bias=False) 197 | #self.conv_cat_in = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 198 | 199 | self.conv_out = nn.Conv2d(channel,3,kernel_size=3,stride=1,padding=1,bias=False) 200 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 201 | 202 | def _upsample(self,x,y): 203 | _,_,H,W = y.size() 204 | return F.interpolate(x,size=(H,W),mode='bilinear') 205 | 206 | def forward(self,x1,x2,x3): 207 | 208 | x_elin = torch.cat((x1,x2,x3),1) + self.conv_in(x1+x2+x3)# + self.conv_in(x1) 209 | 210 | elout = self.el(x_elin) 211 | emout = self.em(self.conv_eltem(self.maxpool(elout))) 212 | esout = self.es(self.conv_emtes(self.maxpool(emout))) 213 | 214 | dsout = self.ds(esout) 215 | dmout = self.dm(self._upsample(self.conv_dstdm(dsout),emout) + emout) 216 | dlout = self.dl(self._upsample(self.conv_dmtdl(dmout),elout) + elout) 217 | 218 | x_out = self.conv_out(dlout) 219 | 220 | return x_out 221 | 222 | 223 | if __name__ == '__main__': 224 | inp = torch.randn(1, 3, 256, 256).cuda() 225 | model = AoSRNet().cuda() 226 | res = model(inp) 227 | print(res.shape) 228 | -------------------------------------------------------------------------------- /models/UIEPTA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import cv2 5 | import numpy as np 6 | 7 | class MHA(nn.Module): 8 | def __init__(self, channels, num_heads): 9 | super(MHA, self).__init__() 10 | self.num_heads = num_heads 11 | self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1)) 12 | 13 | self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False) 14 | self.qkv_conv = nn.Conv2d(channels * 3, channels * 3, kernel_size=3, padding=1, groups=channels * 3, bias=False) 15 | self.project_out = nn.Conv2d(channels, channels, kernel_size=1, bias=False) 16 | 17 | def forward(self, x): 18 | b, c, h, w = x.shape 19 | q, k, v = self.qkv_conv(self.qkv(x)).chunk(3, dim=1) 20 | 21 | q = q.reshape(b, self.num_heads, -1, h * w) 22 | k = k.reshape(b, self.num_heads, -1, h * w) 23 | v = v.reshape(b, self.num_heads, -1, h * w) 24 | q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1) 25 | 26 | attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1).contiguous()) * self.temperature, dim=-1) 27 | out = self.project_out(torch.matmul(attn, v).reshape(b, -1, h, w)) 28 | return out 29 | 30 | 31 | class GFFN(nn.Module): 32 | def __init__(self, channels, expansion_factor): 33 | super(GFFN, self).__init__() 34 | 35 | hidden_channels = int(channels * expansion_factor) 36 
| self.project_in = nn.Conv2d(channels, hidden_channels * 2, kernel_size=1, bias=False) 37 | self.conv = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, kernel_size=3, padding=1, 38 | groups=hidden_channels * 2, bias=False) 39 | self.project_out = nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=False) 40 | 41 | def forward(self, x): 42 | x1, x2 = self.conv(self.project_in(x)).chunk(2, dim=1) 43 | x = self.project_out(F.gelu(x1) * x2) 44 | return x 45 | 46 | #####-----attention_to_1x1 used for Gray Scale Attention----##### 47 | class attention_to_1x1(nn.Module): 48 | def __init__(self, channels): 49 | super(attention_to_1x1, self).__init__() 50 | self.conv1 = nn.Conv2d(channels, channels*2, kernel_size=1, bias=False) 51 | self.conv2 = nn.Conv2d(channels*2, channels, kernel_size=1, bias=False) 52 | 53 | def forward(self,x): 54 | x=torch.mean(x,-1) 55 | x=torch.mean(x ,-1) 56 | x=torch.unsqueeze(x ,-1) 57 | x=torch.unsqueeze(x ,-1) 58 | xx = self.conv2(self.conv1(x)) 59 | return xx 60 | 61 | 62 | class TransformerBlock(nn.Module): 63 | def __init__(self, channels, num_heads, expansion_factor): 64 | super(TransformerBlock, self).__init__() 65 | 66 | self.norm1 = nn.LayerNorm(channels) 67 | self.attn = MHA(channels, num_heads) 68 | self.norm2 = nn.LayerNorm(channels) 69 | self.ffn = GFFN(channels, expansion_factor) 70 | 71 | def forward(self, x): 72 | b, c, h, w = x.shape 73 | x = x + self.attn(self.norm1(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1) 74 | .contiguous().reshape(b, c, h, w)) 75 | x = x + self.ffn(self.norm2(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1) 76 | .contiguous().reshape(b, c, h, w)) 77 | return x 78 | 79 | 80 | class DownSample(nn.Module): 81 | def __init__(self, channels): 82 | super(DownSample, self).__init__() 83 | self.body = nn.Sequential(nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1, bias=False), 84 | nn.PixelUnshuffle(2)) 85 | 86 | def forward(self, x): 87 | return self.body(x) 88 | 89 | 90 | class UpSample(nn.Module): 91 | def __init__(self, channels): 92 | super(UpSample, self).__init__() 93 | self.body = nn.Sequential(nn.Conv2d(channels, channels * 2, kernel_size=3, padding=1, bias=False), 94 | nn.PixelShuffle(2)) 95 | 96 | def forward(self, x): 97 | return self.body(x) 98 | 99 | def Phase_swap(x,y): 100 | fftn1 = torch.fft.fftn(x) 101 | fftn2 = torch.fft.fftn(y) 102 | out=torch.fft.ifftn(abs(fftn2)*torch.exp(1j*(fftn1.angle()))) #pHASE SWAPPING 103 | return out.real 104 | 105 | 106 | class Model(nn.Module): 107 | def __init__(self, num_blocks=[2, 3, 3, 4], num_heads=[1, 2, 4, 8], channels=[16, 32, 64, 128], num_refinement=4, 108 | expansion_factor=2.66, ch=[16,16,32,64]): 109 | super(Model, self).__init__() 110 | self.sig=nn.Sigmoid() 111 | self.to_1x1 = nn.ModuleList([attention_to_1x1(num_ch) for num_ch in ch]) 112 | 113 | self.embed_conv_rgb = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False) 114 | self.embed_conv_gray = nn.Conv2d(1, channels[0], kernel_size=3, padding=1, bias=False) 115 | 116 | self.encoders = nn.ModuleList([nn.Sequential(*[TransformerBlock( 117 | num_ch, num_ah, expansion_factor) for _ in range(num_tb)]) for num_tb, num_ah, num_ch in 118 | zip(num_blocks, num_heads, channels)]) 119 | # the number of down sample or up sample == the number of encoder - 1 120 | self.downs = nn.ModuleList([DownSample(num_ch) for num_ch in channels[:-1]]) 121 | self.ups = nn.ModuleList([UpSample(num_ch) for num_ch in list(reversed(channels))[:-1]]) 122 | # the number of 
reduce blocks == the number of decoders - 1 123 | self.reduces = nn.ModuleList([nn.Conv2d(channels[i], channels[i - 1], kernel_size=1, bias=False) 124 | for i in reversed(range(2, len(channels)))]) 125 | # the number of decoders == the number of encoders - 1 126 | self.decoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(channels[2], num_heads[2], expansion_factor) 127 | for _ in range(num_blocks[2])])]) 128 | self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[1], expansion_factor) 129 | for _ in range(num_blocks[1])])) 130 | # the channel width of the last decoder stage does not change 131 | self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor) 132 | for _ in range(num_blocks[0])])) 133 | 134 | self.refinement = nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor) 135 | for _ in range(num_refinement)]) 136 | self.output = nn.Conv2d(channels[1], 3, kernel_size=3, padding=1, bias=False) 137 | 138 | 139 | def forward(self,RGB_input, Gray_input): 140 | 141 | ###-------encoder-------#### 142 | 143 | fo_rgb = self.embed_conv_rgb(RGB_input) 144 | fo_gray = self.embed_conv_gray(Gray_input) 145 | out_enc_rgb1 = self.encoders[0](self.sig(self.to_1x1[0](fo_gray)*fo_rgb)) 146 | out_enc_gray1 = self.encoders[0](fo_gray) 147 | out_enc_rgb2 = self.encoders[1](self.downs[0](self.sig(self.to_1x1[1](out_enc_gray1)*out_enc_rgb1))) 148 | out_enc_gray2 = self.encoders[1](self.downs[0](out_enc_gray1)) 149 | out_enc_rgb3 = self.encoders[2](self.downs[1](self.sig(self.to_1x1[2](out_enc_gray2)*out_enc_rgb2))) 150 | out_enc_gray3 = self.encoders[2](self.downs[1](out_enc_gray2)) 151 | out_enc_rgb4 = self.encoders[3](self.downs[2](self.sig(self.to_1x1[3](out_enc_gray3)*out_enc_rgb3))) 152 | 153 | ###-------decoder-------#### 154 | 155 | OUT1 = Phase_swap(out_enc_rgb3, self.ups[0](out_enc_rgb4)) 156 | out_dec3 = self.decoders[0](self.reduces[0](torch.cat([self.ups[0](out_enc_rgb4), OUT1], dim=1))) 157 | OUT2 = Phase_swap(out_enc_rgb2, self.ups[1](out_dec3)) 158 | out_dec2 = self.decoders[1](self.reduces[1](torch.cat([self.ups[1](out_dec3), OUT2], dim=1))) 159 | OUT3 = Phase_swap(out_enc_rgb1, self.ups[2](out_dec2)) 160 | fd = self.decoders[2](torch.cat([self.ups[2](out_dec2), OUT3], dim=1)) 161 | fr = self.refinement(fd) 162 | out = self.output(fr) 163 | return out+RGB_input 164 | 165 | 166 | if __name__ == '__main__': 167 | inp = torch.randn(1, 3, 256, 256).cuda() 168 | gray = torch.randn(1, 1, 256, 256).cuda() 169 | model = Model().cuda() 170 | res = model(inp, gray) 171 | print(res.shape) -------------------------------------------------------------------------------- /models/UIETPA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MHA(nn.Module): 7 | def __init__(self, channels, num_heads): 8 | super(MHA, self).__init__() 9 | self.num_heads = num_heads 10 | self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1)) 11 | 12 | self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False) 13 | self.qkv_conv = nn.Conv2d( 14 | channels * 3, channels * 3, kernel_size=3, padding=1, groups=channels * 3, bias=False) 15 | self.project_out = nn.Conv2d( 16 | channels, channels, kernel_size=1, bias=False) 17 | 18 | def forward(self, x): 19 | b, c, h, w = x.shape 20 | q, k, v = self.qkv_conv(self.qkv(x)).chunk(3, dim=1) 21 | 22 | q = q.reshape(b, self.num_heads, -1, h * w) 23 | k = k.reshape(b, self.num_heads, -1, h
* w) 24 | v = v.reshape(b, self.num_heads, -1, h * w) 25 | q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1) 26 | 27 | attn = torch.softmax(torch.matmul( 28 | q, k.transpose(-2, -1).contiguous()) * self.temperature, dim=-1) 29 | out = self.project_out(torch.matmul(attn, v).reshape(b, -1, h, w)) 30 | return out 31 | 32 | 33 | class GFFN(nn.Module): 34 | def __init__(self, channels, expansion_factor): 35 | super(GFFN, self).__init__() 36 | 37 | hidden_channels = int(channels * expansion_factor) 38 | self.project_in = nn.Conv2d( 39 | channels, hidden_channels * 2, kernel_size=1, bias=False) 40 | self.conv = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, kernel_size=3, padding=1, 41 | groups=hidden_channels * 2, bias=False) 42 | self.project_out = nn.Conv2d( 43 | hidden_channels, channels, kernel_size=1, bias=False) 44 | 45 | def forward(self, x): 46 | x1, x2 = self.conv(self.project_in(x)).chunk(2, dim=1) 47 | x = self.project_out(F.gelu(x1) * x2) 48 | return x 49 | 50 | #####-----attention_to_1x1 used for Gray Scale Attention----##### 51 | 52 | 53 | class attention_to_1x1(nn.Module): 54 | def __init__(self, channels): 55 | super(attention_to_1x1, self).__init__() 56 | self.conv1 = nn.Conv2d(channels, channels*2, kernel_size=1, bias=False) 57 | self.conv2 = nn.Conv2d(channels*2, channels, kernel_size=1, bias=False) 58 | 59 | def forward(self, x): 60 | x = torch.mean(x, -1) 61 | x = torch.mean(x, -1) 62 | x = torch.unsqueeze(x, -1) 63 | x = torch.unsqueeze(x, -1) 64 | xx = self.conv2(self.conv1(x)) 65 | return xx 66 | 67 | 68 | class TransformerBlock(nn.Module): 69 | def __init__(self, channels, num_heads, expansion_factor): 70 | super(TransformerBlock, self).__init__() 71 | 72 | self.norm1 = nn.LayerNorm(channels) 73 | self.attn = MHA(channels, num_heads) 74 | self.norm2 = nn.LayerNorm(channels) 75 | self.ffn = GFFN(channels, expansion_factor) 76 | 77 | def forward(self, x): 78 | b, c, h, w = x.shape 79 | x = x + self.attn(self.norm1(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1) 80 | .contiguous().reshape(b, c, h, w)) 81 | x = x + self.ffn(self.norm2(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1) 82 | .contiguous().reshape(b, c, h, w)) 83 | return x 84 | 85 | 86 | class DownSample(nn.Module): 87 | def __init__(self, channels): 88 | super(DownSample, self).__init__() 89 | self.body = nn.Sequential(nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1, bias=False), 90 | nn.PixelUnshuffle(2)) 91 | 92 | def forward(self, x): 93 | return self.body(x) 94 | 95 | 96 | class UpSample(nn.Module): 97 | def __init__(self, channels): 98 | super(UpSample, self).__init__() 99 | self.body = nn.Sequential(nn.Conv2d(channels, channels * 2, kernel_size=3, padding=1, bias=False), 100 | nn.PixelShuffle(2)) 101 | 102 | def forward(self, x): 103 | return self.body(x) 104 | 105 | 106 | def Phase_swap(x, y): 107 | fftn1 = torch.fft.fftn(x) 108 | fftn2 = torch.fft.fftn(y) 109 | out = torch.fft.ifftn( 110 | abs(fftn2)*torch.exp(1j*(fftn1.angle()))) # phase swapping: magnitude of y, phase of x 111 | return out.real 112 | 113 | 114 | class UIETPA(nn.Module): 115 | def __init__(self, num_blocks=[2, 3, 3, 4], num_heads=[1, 2, 4, 8], channels=[16, 32, 64, 128], num_refinement=4, 116 | expansion_factor=2.66, ch=[16, 16, 32, 64]): 117 | super(UIETPA, self).__init__() 118 | self.sig = nn.Sigmoid() 119 | self.to_1x1 = nn.ModuleList( 120 | [attention_to_1x1(num_ch) for num_ch in ch]) 121 | 122 | self.embed_conv_rgb = nn.Conv2d( 123 | 3, channels[0], kernel_size=3, padding=1, bias=False) 124 |
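# Two embedding stems feed the network: the grayscale branch below serves as
# structural guidance, and its globally pooled statistics (attention_to_1x1)
# are multiplied into the RGB features and squashed by a sigmoid gate at every
# encoder scale (see forward()).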
self.embed_conv_gray = nn.Conv2d( 125 | 1, channels[0], kernel_size=3, padding=1, bias=False) 126 | 127 | self.encoders = nn.ModuleList([nn.Sequential(*[TransformerBlock( 128 | num_ch, num_ah, expansion_factor) for _ in range(num_tb)]) for num_tb, num_ah, num_ch in 129 | zip(num_blocks, num_heads, channels)]) 130 | # the number of down-sampling or up-sampling blocks == the number of encoders - 1 131 | self.downs = nn.ModuleList([DownSample(num_ch) 132 | for num_ch in channels[:-1]]) 133 | self.ups = nn.ModuleList([UpSample(num_ch) 134 | for num_ch in list(reversed(channels))[:-1]]) 135 | # the number of reduce blocks == the number of decoders - 1 136 | self.reduces = nn.ModuleList([nn.Conv2d(channels[i], channels[i - 1], kernel_size=1, bias=False) 137 | for i in reversed(range(2, len(channels)))]) 138 | # the number of decoders == the number of encoders - 1 139 | self.decoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(channels[2], num_heads[2], expansion_factor) 140 | for _ in range(num_blocks[2])])]) 141 | self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[1], expansion_factor) 142 | for _ in range(num_blocks[1])])) 143 | # the channel width of the last decoder stage does not change 144 | self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor) 145 | for _ in range(num_blocks[0])])) 146 | 147 | self.refinement = nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor) 148 | for _ in range(num_refinement)]) 149 | self.output = nn.Conv2d( 150 | channels[1], 3, kernel_size=3, padding=1, bias=False) 151 | 152 | def forward(self, RGB_input, Gray_input): 153 | 154 | ###-------encoder-------#### 155 | 156 | fo_rgb = self.embed_conv_rgb(RGB_input) 157 | fo_gray = self.embed_conv_gray(Gray_input) 158 | out_enc_rgb1 = self.encoders[0]( 159 | self.sig(self.to_1x1[0](fo_gray)*fo_rgb)) 160 | out_enc_gray1 = self.encoders[0](fo_gray) 161 | out_enc_rgb2 = self.encoders[1](self.downs[0]( 162 | self.sig(self.to_1x1[1](out_enc_gray1)*out_enc_rgb1))) 163 | out_enc_gray2 = self.encoders[1](self.downs[0](out_enc_gray1)) 164 | out_enc_rgb3 = self.encoders[2](self.downs[1]( 165 | self.sig(self.to_1x1[2](out_enc_gray2)*out_enc_rgb2))) 166 | out_enc_gray3 = self.encoders[2](self.downs[1](out_enc_gray2)) 167 | out_enc_rgb4 = self.encoders[3](self.downs[2]( 168 | self.sig(self.to_1x1[3](out_enc_gray3)*out_enc_rgb3))) 169 | 170 | ###-------decoder-------#### 171 | 172 | OUT1 = Phase_swap(out_enc_rgb3, self.ups[0](out_enc_rgb4)) 173 | out_dec3 = self.decoders[0](self.reduces[0]( 174 | torch.cat([self.ups[0](out_enc_rgb4), OUT1], dim=1))) 175 | OUT2 = Phase_swap(out_enc_rgb2, self.ups[1](out_dec3)) 176 | out_dec2 = self.decoders[1](self.reduces[1]( 177 | torch.cat([self.ups[1](out_dec3), OUT2], dim=1))) 178 | OUT3 = Phase_swap(out_enc_rgb1, self.ups[2](out_dec2)) 179 | fd = self.decoders[2](torch.cat([self.ups[2](out_dec2), OUT3], dim=1)) 180 | fr = self.refinement(fd) 181 | out = self.output(fr) 182 | return out+RGB_input 183 | -------------------------------------------------------------------------------- /models/SCNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def PONO(x, epsilon=1e-5): 6 | mean = x.mean(dim=1, keepdim=True) 7 | std = x.var(dim=1, keepdim=True).add(epsilon).sqrt() 8 | out = (x - mean) / std 9 | return out, mean, std 10 | 11 | def MS(x, beta, gamma): 12 | return x * gamma + beta 13 | 14 | 15 | class Whiten2d(nn.Module): 16 | def __init__(self,
num_features, t=5, eps=1e-5, affine=True): 17 | super(Whiten2d, self).__init__() 18 | self.T = t 19 | self.eps = eps 20 | self.affine = affine 21 | self.num_features = num_features 22 | if self.affine: 23 | self.weight = nn.Parameter(torch.ones(num_features)) 24 | self.bias = nn.Parameter(torch.zeros(num_features)) 25 | 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | if self.affine: 30 | nn.init.ones_(self.weight) 31 | nn.init.zeros_(self.bias) 32 | 33 | def forward(self, x): 34 | 35 | N, C, H, W = x.size() 36 | 37 | # N x C x (H x W) 38 | in_data = x.view(N, C, -1) 39 | 40 | eye = in_data.data.new().resize_(C, C) 41 | eye = torch.nn.init.eye_(eye).view(1, C, C).expand(N, C, C) 42 | 43 | # calculate other statistics 44 | # N x C x 1 45 | mean_in = in_data.mean(-1, keepdim=True) 46 | x_in = in_data - mean_in 47 | # N x C x C 48 | cov_in = torch.bmm(x_in, torch.transpose(x_in, 1, 2)).div(H * W) 49 | # N x C x 1 50 | mean = mean_in 51 | cov = cov_in + self.eps * eye 52 | 53 | # perform whitening using Newton's iteration 54 | Ng, c, _ = cov.size() 55 | P = torch.eye(c).to(cov).expand(Ng, c, c) 56 | # reciprocal of trace of covariance 57 | rTr = (cov * P).sum((1, 2), keepdim=True).reciprocal_() 58 | cov_N = cov * rTr 59 | for k in range(self.T): 60 | P = torch.baddbmm(P, torch.matrix_power(P, 3), cov_N, beta=1.5, alpha=-0.5) 61 | # keyword beta/alpha form: the legacy positional signature torch.baddbmm(beta, input, alpha, b1, b2) was removed in modern PyTorch 62 | # whiten matrix: the matrix inverse of covariance, i.e., cov^{-1/2} 63 | wm = P.mul_(rTr.sqrt()) 64 | 65 | x_hat = torch.bmm(wm, in_data - mean) 66 | x_hat = x_hat.view(N, C, H, W) 67 | if self.affine: 68 | x_hat = x_hat * self.weight.view(1, self.num_features, 1, 1) + \ 69 | self.bias.view(1, self.num_features, 1, 1) 70 | 71 | return x_hat 72 | 73 | class SELayer(torch.nn.Module): 74 | def __init__(self, num_filter): 75 | super(SELayer, self).__init__() 76 | self.global_pool = torch.nn.AdaptiveAvgPool2d(1) 77 | self.conv_double = torch.nn.Sequential( 78 | torch.nn.Conv2d(num_filter, num_filter // 16, 1, 1, 0, bias=True), 79 | torch.nn.LeakyReLU(0.2), 80 | torch.nn.Conv2d(num_filter // 16, num_filter, 1, 1, 0, bias=True), 81 | torch.nn.Sigmoid()) 82 | 83 | def forward(self, x): 84 | mask = self.global_pool(x) 85 | mask = self.conv_double(mask) 86 | x = x * mask 87 | return x 88 | 89 | 90 | class ResBlock(nn.Module): 91 | def __init__(self, num_filter): 92 | super(ResBlock, self).__init__() 93 | body = [] 94 | for i in range(2): 95 | body.append(nn.ReflectionPad2d(1)) 96 | body.append(nn.Conv2d(num_filter, num_filter, kernel_size=3, padding=0)) 97 | if i == 0: 98 | body.append(nn.LeakyReLU(0.2)) 99 | body.append(SELayer(num_filter)) 100 | self.body = nn.Sequential(*body) 101 | 102 | def forward(self, x): 103 | res = self.body(x) 104 | x = res + x 105 | return x 106 | 107 | 108 | class Up(nn.Module): 109 | def __init__(self): 110 | super(Up, self).__init__() 111 | self.up = nn.Sequential( 112 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 113 | ) 114 | 115 | def forward(self, x): 116 | x = self.up(x) 117 | return x 118 | 119 | 120 | class ConvBlock(nn.Module): 121 | def __init__(self, ch_in, ch_out): 122 | super(ConvBlock, self).__init__() 123 | self.conv = nn.Sequential( 124 | nn.ReflectionPad2d(1), 125 | nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=0), 126 | nn.LeakyReLU(0.2), 127 | nn.ReflectionPad2d(1), 128 | nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=0), 129 | nn.LeakyReLU(0.2), 130 | ) 131 | 132 | def forward(self, x): 133 | x = self.conv(x) 134 |
return x 135 | 136 | 137 | class Encoder(nn.Module): 138 | def __init__(self): 139 | super(Encoder, self).__init__() 140 | self.conv_in = ConvBlock(ch_in=3, ch_out=64) 141 | self.conv1 = ConvBlock(ch_in=64, ch_out=64) 142 | self.conv2 = ConvBlock(ch_in=64, ch_out=64) 143 | self.conv3 = ConvBlock(ch_in=64, ch_out=64) 144 | self.conv4 = ConvBlock(ch_in=64, ch_out=64) 145 | self.IW1 = Whiten2d(64) 146 | self.IW2 = Whiten2d(64) 147 | self.IW3 = Whiten2d(64) 148 | self.IW4 = Whiten2d(64) 149 | self.pool = nn.MaxPool2d(2) 150 | 151 | def forward(self, x): 152 | x = self.conv_in(x) 153 | 154 | x1, x1_mean, x1_std = PONO(x) 155 | x1 = self.conv1(x) 156 | x2 = self.pool(x1) 157 | 158 | x2, x2_mean, x2_std = PONO(x2) 159 | x2 = self.conv2(x2) 160 | x3 = self.pool(x2) 161 | 162 | x3, x3_mean, x3_std = PONO(x3) 163 | x3 = self.conv3(x3) 164 | x4 = self.pool(x3) 165 | 166 | x4, x4_mean, x4_std = PONO(x4) 167 | x4 = self.conv4(x4) 168 | 169 | x4_iw = self.IW4(x4) 170 | x3_iw = self.IW3(x3) 171 | x2_iw = self.IW2(x2) 172 | x1_iw = self.IW1(x1) 173 | 174 | return x1_iw, x2_iw, x3_iw, x4_iw, x1_mean, x2_mean, x3_mean, x4_mean, x1_std, x2_std, x3_std, x4_std 175 | 176 | 177 | class Decoder(nn.Module): 178 | def __init__(self): 179 | super(Decoder, self).__init__() 180 | 181 | self.encoder = Encoder() 182 | self.UpConv4 = ConvBlock(ch_in=64, ch_out=64) 183 | self.Up3 = Up() 184 | self.UpConv3 = ConvBlock(ch_in=128, ch_out=64) 185 | self.Up2 = Up() 186 | self.UpConv2 = ConvBlock(ch_in=128, ch_out=64) 187 | self.Up1 = Up() 188 | self.UpConv1 = ConvBlock(ch_in=128, ch_out=64) 189 | 190 | self.conv_u4 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 191 | self.conv_s4 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 192 | self.conv_u3 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 193 | self.conv_s3 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 194 | self.conv_u2 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 195 | self.conv_s2 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 196 | self.conv_u1 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 197 | self.conv_s1 = nn.Conv2d(1, 64, kernel_size=1, padding=0) 198 | 199 | out_conv = [] 200 | for i in range(1): 201 | out_conv.append(ResBlock(64)) 202 | out_conv.append(nn.ReflectionPad2d(1)) 203 | out_conv.append(nn.Conv2d(64, 3, kernel_size=3, padding=0)) 204 | self.out_conv = nn.Sequential(*out_conv) 205 | 206 | def forward(self, Input): 207 | x1, x2, x3, x4, x1_mean, x2_mean, x3_mean, x4_mean, x1_std, x2_std, x3_std, x4_std = self.encoder(Input) 208 | 209 | # x4->x3 210 | x4_mean = self.conv_u4(x4_mean) 211 | x4_std = self.conv_s4(x4_std) 212 | x4 = MS(x4, x4_mean, x4_std) 213 | x4 = self.UpConv4(x4) 214 | d3 = self.Up3(x4) 215 | # x3->x2 216 | d3 = torch.cat((x3, d3), dim=1) 217 | d3 = self.UpConv3(d3) 218 | x3_mean = self.conv_u3(x3_mean) 219 | x3_std = self.conv_s3(x3_std) 220 | d3 = MS(d3, x3_mean, x3_std) 221 | d2 = self.Up2(d3) 222 | # x2->x1 223 | d2 = torch.cat((x2, d2), dim=1) 224 | d2 = self.UpConv2(d2) 225 | x2_mean = self.conv_u2(x2_mean) 226 | x2_std = self.conv_s2(x2_std) 227 | d2 = MS(d2, x2_mean, x2_std) 228 | d1 = self.Up1(d2) 229 | # x1->out 230 | d1 = torch.cat((x1, d1), dim=1) 231 | d1 = self.UpConv1(d1) 232 | x1_mean = self.conv_u1(x1_mean) 233 | x1_std = self.conv_s1(x1_std) 234 | d1 = MS(d1, x1_mean, x1_std) 235 | out = self.out_conv(d1) 236 | 237 | return out 238 | 239 | 240 | class SCNet(nn.Module): 241 | def __init__(self): 242 | super(SCNet, self).__init__() 243 | self.decoder = Decoder() 244 | 245 | def forward(self, Input): 246 | return self.decoder(Input) 
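# The Whiten2d block above approximates cov^{-1/2} with a few Newton-Schulz steps
# instead of an explicit eigendecomposition, which keeps the whitening batched,
# differentiable, and GPU-friendly. A minimal standalone sketch of that iteration
# (hypothetical helper, not part of the original SCNet; assumes cov is a batch of
# symmetric positive semi-definite matrices):
def _newton_schulz_isqrt(cov, iters=5, eps=1e-5):
    N, C, _ = cov.shape
    eye = torch.eye(C, device=cov.device, dtype=cov.dtype).expand(N, C, C)
    cov = cov + eps * eye  # regularize ill-conditioned covariances
    r_tr = (cov * eye).sum((1, 2), keepdim=True).reciprocal()  # 1 / trace(cov)
    cov_n = cov * r_tr  # pre-scale so the iteration converges
    P = eye.clone()
    for _ in range(iters):
        # P <- 1.5 * P - 0.5 * P^3 @ cov_n
        P = torch.baddbmm(P, torch.matrix_power(P, 3), cov_n, beta=1.5, alpha=-0.5)
    return P * r_tr.sqrt()  # undo the trace pre-scaling
# Usage sketch: wm = _newton_schulz_isqrt(cov); x_white = torch.bmm(wm, x_centered)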
247 | 248 | 249 | if __name__ == '__main__': 250 | t = torch.randn(1, 3, 256, 256).cuda() 251 | model = SCNet().cuda() 252 | res = model(t) 253 | print(res.shape) -------------------------------------------------------------------------------- /models/TUDA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Dense block with 4 convs 5 | class Dense_Block_IN(nn.Module): 6 | def __init__(self, block_num, inter_channel, channel): 7 | super(Dense_Block_IN, self).__init__() 8 | # 9 | concat_channels = channel + block_num * inter_channel 10 | channels_now = channel 11 | 12 | self.group_list = nn.ModuleList([]) 13 | for i in range(block_num): 14 | group = nn.Sequential( 15 | nn.Conv2d(in_channels=channels_now, out_channels=inter_channel, kernel_size=3, 16 | stride=1, padding=1), 17 | nn.InstanceNorm2d(inter_channel, affine=True), 18 | nn.ReLU(), 19 | ) 20 | self.add_module(name='group_%d' % i, module=group) 21 | self.group_list.append(group) 22 | 23 | channels_now += inter_channel 24 | 25 | assert channels_now == concat_channels 26 | # 27 | self.fusion = nn.Sequential( 28 | nn.Conv2d(concat_channels, channel, kernel_size=1, stride=1, padding=0), 29 | nn.InstanceNorm2d(channel, affine=True), 30 | nn.ReLU(), 31 | ) 32 | # 33 | 34 | def forward(self, x): 35 | feature_list = [x] 36 | 37 | for group in self.group_list: 38 | inputs = torch.cat(feature_list, dim=1) 39 | outputs = group(inputs) 40 | feature_list.append(outputs) 41 | 42 | inputs = torch.cat(feature_list, dim=1) 43 | # 44 | fusion_outputs = self.fusion(inputs) 45 | # 46 | block_outputs = fusion_outputs + x 47 | 48 | return block_outputs 49 | 50 | class CALayer(nn.Module): 51 | def __init__(self, channel): 52 | super(CALayer, self).__init__() 53 | out_channel = channel 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv1 = nn.Conv2d(channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False) 56 | # 57 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 58 | self.ca = nn.Sequential( 59 | nn.Conv2d(out_channel, out_channel // 8, 1, padding=0, bias=True), 60 | nn.ReLU(inplace=True), 61 | nn.Conv2d(out_channel // 8, channel, 1, padding=0, bias=True), 62 | nn.Sigmoid() 63 | ) 64 | 65 | def forward(self, x): 66 | t1 = self.conv1(x) # in 67 | t2 = self.relu(t1) # in, 64 68 | y = self.avg_pool(t2) # torch.Size([1, in, 1, 1]) 69 | y = self.ca(y) # torch.Size([1, in, 1, 1]) 70 | m = t2 * y # torch.Size([1, in, 64, 64]) * torch.Size([1, in, 1, 1]) 71 | return x + m 72 | 73 | 74 | class PALayer(nn.Module): 75 | def __init__(self, channel): 76 | super(PALayer, self).__init__() 77 | self.pa = nn.Sequential( 78 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), 81 | nn.Sigmoid() 82 | ) 83 | 84 | def forward(self, x): 85 | y = self.pa(x) 86 | return x * y 87 | 88 | # upsample 89 | class Trans_Up(nn.Module): 90 | def __init__(self, in_planes, out_planes): 91 | super(Trans_Up, self).__init__() 92 | self.conv0 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1) 93 | self.IN1 = nn.InstanceNorm2d(out_planes) 94 | self.relu = nn.ReLU(inplace=True) 95 | 96 | def forward(self, x): 97 | out = self.relu(self.IN1(self.conv0(x))) 98 | return out 99 | 100 | 101 | # downsample 102 | class Trans_Down(nn.Module): 103 | def __init__(self, in_planes, out_planes): 104 | super(Trans_Down, self).__init__() 105 | self.conv0 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=2, padding=1) 106 | self.IN1 = nn.InstanceNorm2d(out_planes) 107 | self.relu = nn.ReLU(inplace=True) 108 | 109 | def forward(self, x): 110 | out = self.relu(self.IN1(self.conv0(x))) 111 | return out 112 | 113 | class TUDA(nn.Module): 114 | def __init__(self, input_nc=3, output_nc=3): 115 | super(TUDA, self).__init__() 116 | 117 | self.conv1 = nn.Sequential( 118 | nn.Conv2d(input_nc, 64, 3, 1, 1), 119 | nn.InstanceNorm2d(64, affine=True), 120 | nn.ReLU(), 121 | ) 122 | # 123 | self.conv2 = nn.Sequential( 124 | nn.Conv2d(64, 64, 3, 1, 1), 125 | nn.InstanceNorm2d(64, affine=True), 126 | nn.ReLU(), 127 | ) 128 | # args: number of convs, intermediate channels, input channels 129 | self.up_1 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 256 -> 128 130 | self.up_2 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 128 -> 64 131 | self.up_3 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 64 -> 32 132 | self.up_4 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 32 -> 16 133 | # 134 | self.Latent = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 16 -> 8 135 | # 136 | self.down_4 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 16 137 | self.down_3 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 32 138 | self.down_2 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 64 139 | self.down_1 = Dense_Block_IN(block_num=3, inter_channel=32, channel=64) # 128 140 | # 141 | self.CALayer4 = CALayer(128) 142 | self.CALayer3 = CALayer(128) 143 | self.CALayer2 = CALayer(128) 144 | self.CALayer1 = CALayer(128) 145 | # 146 | self.trans_down1 = Trans_Down(64, 64) 147 | self.trans_down2 = Trans_Down(64, 64) 148 | self.trans_down3 = Trans_Down(64, 64) 149 | self.trans_down4 = Trans_Down(64, 64) 150 | # 151 | self.trans_up4 = Trans_Up(64, 64) 152 | self.trans_up3 = Trans_Up(64, 64) 153 | self.trans_up2 = Trans_Up(64, 64) 154 | self.trans_up1 = Trans_Up(64, 64) 155 | # 156 | self.down_4_fusion = nn.Sequential( 157 | nn.Conv2d(64 + 64, 64, 1, 1, 0), 158 | nn.InstanceNorm2d(64, affine=True), 159 | nn.ReLU(), 160 | ) 161 | self.down_3_fusion = nn.Sequential( 162 | nn.Conv2d(64 + 64, 64, 1, 1, 0), 163 | nn.InstanceNorm2d(64, affine=True), 164 | nn.ReLU(), 165 | ) 166 | self.down_2_fusion = nn.Sequential( 167 | nn.Conv2d(64 + 64, 64, 1, 1, 0), 168 | nn.InstanceNorm2d(64, affine=True), 169 | nn.ReLU(), 170 | ) 171 | self.down_1_fusion = nn.Sequential( 172 | nn.Conv2d(64 + 64, 64, 1, 1, 0), 173 | nn.InstanceNorm2d(64, affine=True), 174 | nn.ReLU(), 175 | ) 176 | # 177 | self.fusion = nn.Sequential( 178 | nn.Conv2d(64, 64, 1, 1, 0), 179 | nn.InstanceNorm2d(64, affine=True), 180 | nn.ReLU(), 181 | nn.Conv2d(64, 64, 3, 1, 1), 182 | nn.InstanceNorm2d(64, affine=True), 183 | nn.ReLU() 184 | ) 185 | self.fusion2 = nn.Sequential( 186 | nn.Conv2d(64, output_nc, 3, 1, 1), 187 | nn.Tanh(), 188 | ) 189 | # 190 | 191 | def forward(self, x): # 1, 3, 256, 256 192 | # 193 | feature_neg_1 = self.conv1(x) # 1, 64, 256, 256 194 | feature_0 = self.conv2(feature_neg_1) # 1, 64, 256, 256 195 | ####################################################### 196 | up_11 = self.up_1(feature_0) # 1, 64, 256, 256 197 | up_1 = self.trans_down1(up_11) # 1, 64, 128, 128 198 | 199 | up_21 = self.up_2(up_1) # 1, 64, 128, 128 200 | up_2 = self.trans_down2(up_21) # 1, 64, 64, 64 201 | 202 | up_31 = self.up_3(up_2) # 1, 64, 64, 64 203 | up_3 = self.trans_down3(up_31) # 1, 64, 32, 32 204 | 205 | up_41 = self.up_4(up_3) # 1, 64, 32, 32 206 | up_4 = self.trans_down4(up_41) # 1, 64, 16, 16
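# Encoder recap: four Dense_Block_IN stages, each followed by a stride-2 conv,
# halve the resolution 256 -> 128 -> 64 -> 32 -> 16 while the width stays at 64
# channels; the pre-downsampling features up_11/up_21/up_31/up_41 are kept as
# skip connections for the decoder concatenations below.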
207 | 208 | ####################################################### 209 | Latent = self.Latent(up_4) # 1, 64, 16, 16 210 | ####################################################### 211 | 212 | down_4 = self.trans_up4(Latent) # 1, 64, 32, 32 213 | down_4 = torch.cat([up_41, down_4], dim=1) # 1, 128, 32, 32 214 | down_41 = self.CALayer4(down_4) # 1, 128, 32, 32 215 | down_4 = self.down_4_fusion(down_41) # 1, 64, 32, 32 216 | down_4 = self.down_4(down_4) # 1, 64, 32, 32 217 | 218 | down_3 = self.trans_up3(down_4) # 1, 64, 64, 64 219 | down_3 = torch.cat([up_31, down_3], dim=1) # 1, 128, 64, 64 220 | down_31 = self.CALayer3(down_3) # 1, 128, 64, 64 221 | down_3 = self.down_3_fusion(down_31) # 1, 64, 64, 64 222 | down_3 = self.down_3(down_3) # 1, 64, 64, 64 223 | 224 | down_2 = self.trans_up2(down_3) # 1, 64, 128,128 225 | down_2 = torch.cat([up_21, down_2], dim=1) # 1, 128, 128,128 226 | down_21 = self.CALayer2(down_2) # 1, 128, 128,128 227 | down_2 = self.down_2_fusion(down_21) # 1, 64, 128,128 228 | down_2 = self.down_2(down_2) # 1, 64, 128,128 229 | 230 | down_1 = self.trans_up1(down_2) # 1, 64, 256, 256 231 | down_1 = torch.cat([up_11, down_1], dim=1) # 1, 128, 256, 256 232 | down_11 = self.CALayer1(down_1) # 1, 128, 256, 256 233 | down_1 = self.down_1_fusion(down_11) # 1, 64, 256, 256 234 | down_1 = self.down_1(down_1) # 1, 64, 256, 256 235 | # 236 | feature = self.fusion(down_1) # 1, 64, 256, 256 237 | # 238 | feature = feature + feature_neg_1 # 1, 64, 256, 256 239 | # 240 | outputs = self.fusion2(feature) 241 | return outputs 242 | 243 | 244 | if __name__ == '__main__': 245 | t = torch.randn(1, 3, 256, 256).cuda() 246 | model = TUDA().cuda() 247 | res = model(t) 248 | print(res.shape) -------------------------------------------------------------------------------- /models/FIVE_APLUS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PALayer(nn.Module): 7 | def __init__(self, channel): 8 | super(PALayer, self).__init__() 9 | self.pa = nn.Sequential( 10 | 11 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 12 | nn.ReLU(inplace=True), 13 | nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), 14 | nn.Sigmoid() 15 | ) 16 | 17 | def forward(self, x): 18 | y = self.pa(x) 19 | return x * y 20 | 21 | #Multi-Scale Pyramid Module 22 | class Enhance(nn.Module): 23 | def __init__(self): 24 | super(Enhance, self).__init__() 25 | 26 | self.relu=nn.ReLU(inplace=True) 27 | 28 | self.tanh=nn.Tanh() 29 | self.refine2= nn.Conv2d(16, 16, kernel_size=3,stride=1,padding=1) 30 | 31 | self.conv1010 = nn.Conv2d(16, 1, kernel_size=1,stride=1,padding=0) # 1mm 32 | self.conv1020 = nn.Conv2d(16, 1, kernel_size=1,stride=1,padding=0) # 1mm 33 | self.conv1030 = nn.Conv2d(16, 1, kernel_size=1,stride=1,padding=0) # 1mm 34 | self.refine3= nn.Conv2d(16+3, 16, kernel_size=3,stride=1,padding=1) 35 | self.upsample = F.interpolate 36 | 37 | 38 | def forward(self, x): 39 | dehaze = self.relu((self.refine2(x))) 40 | shape_out = dehaze.data.size() 41 | # print(shape_out) 42 | shape_out = shape_out[2:4] 43 | 44 | x101 = F.avg_pool2d(dehaze, 128) 45 | 46 | x102 = F.avg_pool2d(dehaze, 64) 47 | 48 | x103 = F.avg_pool2d(dehaze, 32) 49 | 50 | x1010 = self.upsample(self.relu(self.conv1010(x101)),size=shape_out) 51 | x1020 = self.upsample(self.relu(self.conv1020(x102)),size=shape_out) 52 | x1030 = self.upsample(self.relu(self.conv1030(x103)),size=shape_out) 53 | 54 | dehaze = torch.cat((x1010, 
x1020, x1030, dehaze), 1) 55 | dehaze= self.tanh(self.refine3(dehaze)) 56 | 57 | return dehaze 58 | 59 | 60 | class SFDIM(nn.Module): 61 | def __init__(self, n_feats): 62 | super().__init__() 63 | # i_feats =n_feats*2 64 | 65 | self.Conv1 =nn.Sequential( 66 | nn.Conv2d(n_feats,2*n_feats,1,1,0), 67 | nn.LeakyReLU(0.1,inplace=True), 68 | nn.Conv2d(2*n_feats,n_feats,1,1,0)) 69 | self.Conv1_1 =nn.Sequential( 70 | nn.Conv2d(n_feats,2*n_feats,1,1,0), 71 | nn.LeakyReLU(0.1,inplace=True), 72 | nn.Conv2d(2*n_feats,n_feats,1,1,0)) 73 | 74 | self.Conv2 = nn.Conv2d(n_feats, n_feats, 1, 1, 0) 75 | self.FF = FreBlock() 76 | self.scale = nn.Parameter(torch.zeros((1, n_feats, 1, 1)), requires_grad=True) 77 | 78 | def forward(self, x,y): 79 | b,c,H,W=x.shape 80 | a = 0.1 81 | mix = x+y 82 | mix_mag,mix_pha = self.FF(mix) 83 | #Ghost Expand 84 | mix_mag = self.Conv1(mix_mag) 85 | mix_pha = self.Conv1_1(mix_pha) 86 | 87 | real_main = mix_mag * torch.cos(mix_pha) 88 | imag_main = mix_mag * torch.sin(mix_pha) 89 | x_out_main = torch.complex(real_main, imag_main) 90 | x_out_main = torch.abs(torch.fft.irfft2(x_out_main, s=(H, W), norm='backward'))+1e-8 91 | 92 | return self.Conv2(a*x_out_main+(1-a)*mix) 93 | 94 | class FreBlock(nn.Module): 95 | def __init__(self): 96 | super(FreBlock, self).__init__() 97 | 98 | 99 | def forward(self,x): 100 | x = x+1e-8 101 | mag = torch.abs(x) 102 | pha = torch.angle(x) 103 | 104 | return mag,pha 105 | 106 | # Multi-branch Color Enhancement Module 107 | class MCEM(nn.Module): 108 | def __init__(self, in_channels, channels): 109 | super(MCEM, self).__init__() 110 | self.conv_first_r = nn.Conv2d(in_channels//4, channels//2, kernel_size=1, stride=1, padding=0, bias=False) 111 | self.conv_first_g = nn.Conv2d(in_channels//4, channels//2, kernel_size=1, stride=1, padding=0, bias=False) 112 | self.conv_first_b = nn.Conv2d(in_channels//4, channels//2, kernel_size=1, stride=1, padding=0, bias=False) 113 | self.instance_r = nn.InstanceNorm2d(channels//2, affine=True) 114 | self.instance_g = nn.InstanceNorm2d(channels//2, affine=True) 115 | self.instance_b = nn.InstanceNorm2d(channels//2, affine=True) 116 | 117 | self.conv_out_r = nn.Conv2d( channels//2,in_channels//4, kernel_size=1, stride=1, padding=0, bias=False) 118 | self.conv_out_g = nn.Conv2d( channels//2, in_channels//4,kernel_size=1, stride=1, padding=0, bias=False) 119 | self.conv_out_b = nn.Conv2d( channels//2,in_channels//4, kernel_size=1, stride=1, padding=0, bias=False) 120 | 121 | def forward(self, x): 122 | 123 | x1,x2, x3,x4= torch.chunk(x, 4, dim=1) 124 | 125 | x_1 = self.conv_first_r(x1) 126 | x_2 = self.conv_first_g(x2) 127 | x_3 = self.conv_first_b(x3) 128 | 129 | out_instance_r = self.instance_r(x_1) 130 | out_instance_g = self.instance_g(x_2) 131 | out_instance_b = self.instance_b(x_3) 132 | 133 | out_instance_r=self.conv_out_r(out_instance_r) 134 | out_instance_g=self.conv_out_g(out_instance_g) 135 | out_instance_b=self.conv_out_b(out_instance_b) 136 | 137 | mix = out_instance_r+out_instance_g+out_instance_b+x4 138 | 139 | out_instance= torch.cat((out_instance_r, out_instance_g,out_instance_b,mix),dim=1) 140 | 141 | return out_instance 142 | 143 | class MCEM_2(nn.Module): 144 | def __init__(self, in_channels, channels): 145 | super(MCEM_2, self).__init__() 146 | self.conv_first_r = nn.Conv2d(in_channels//4, channels//2, kernel_size=1, stride=1, padding=0, bias=False) 147 | self.conv_first_g = nn.Conv2d(in_channels//4, channels//2, kernel_size=1, stride=1, padding=0, bias=False) 148 | self.conv_first_b =
nn.Conv2d(in_channels//4, channels//2, kernel_size=1, stride=1, padding=0, bias=False) 149 | self.instance_r = nn.InstanceNorm2d(channels//2, affine=True) 150 | self.instance_g = nn.InstanceNorm2d(channels//2, affine=True) 151 | self.instance_b = nn.InstanceNorm2d(channels//2, affine=True) 152 | 153 | self.conv_out_r = nn.Conv2d( channels//2,in_channels//4, kernel_size=1, stride=1, padding=0, bias=False) 154 | self.conv_out_g = nn.Conv2d( channels//2, in_channels//4,kernel_size=1, stride=1, padding=0, bias=False) 155 | self.conv_out_b = nn.Conv2d( channels//2,in_channels//4, kernel_size=1, stride=1, padding=0, bias=False) 156 | 157 | def forward(self, x): 158 | 159 | x1,x2, x3,x4= torch.chunk(x, 4, dim=1) 160 | 161 | x_1 = self.conv_first_r(x1) 162 | x_2 = self.conv_first_g(x2) 163 | x_3 = self.conv_first_b(x3) 164 | 165 | out_instance_r = self.instance_r(x_1) 166 | out_instance_g = self.instance_g(x_2) 167 | out_instance_b = self.instance_b(x_3) 168 | 169 | out_instance_r=self.conv_out_r(out_instance_r) 170 | out_instance_g=self.conv_out_g(out_instance_g) 171 | out_instance_b=self.conv_out_b(out_instance_b) 172 | 173 | mix = out_instance_r+out_instance_g+out_instance_b+x4 174 | 175 | out_instance= torch.cat((out_instance_r, out_instance_g,out_instance_b,mix),dim=1) 176 | # out_instance = self.act(self.conv2(out_instance)) 177 | 178 | return out_instance 179 | # MAIN-Net 180 | class FIVE_APLUSNet(nn.Module): 181 | def __init__(self, in_nc=3, out_nc=3, base_nf=16): 182 | super(FIVE_APLUSNet, self).__init__() 183 | 184 | self.base_nf = base_nf 185 | self.out_nc = out_nc 186 | self.pyramid_enhance = Enhance() 187 | # self.encoder = Condition() 188 | self.color_cer_1 = MCEM(base_nf,base_nf*2) 189 | self.color_cer_2 = MCEM_2(base_nf,base_nf*2) 190 | 191 | self.fusion_mixer=SFDIM(base_nf) 192 | self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True) 193 | self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True) 194 | self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) 195 | # self.conv4 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) 196 | self.stage2 = PALayer(base_nf) 197 | self.act = nn.ReLU(inplace=True) 198 | 199 | 200 | def forward(self, x): 201 | # cond = self.cond_net(x) 202 | 203 | out = self.conv1(x) 204 | out_1=self.color_cer_1(out) 205 | out_2=self.pyramid_enhance(out) 206 | mix_out = self.fusion_mixer(out_1,out_2) 207 | 208 | out_stage2 = self.act(mix_out) 209 | # out_stage2_head = self.conv4(out_stage2) 210 | 211 | out_stage2 = self.conv2(out_stage2) 212 | out_stage2=self.color_cer_2(out_stage2) 213 | out = self.stage2(out_stage2) 214 | out = self.act(out) 215 | 216 | out = self.conv3(out) 217 | 218 | return out 219 | 220 | 221 | if __name__ == '__main__': 222 | t = torch.randn(1, 3, 256, 256).cuda() 223 | model = FIVE_APLUSNet().cuda() 224 | res = model(t) 225 | print(res.shape) 226 | 227 | -------------------------------------------------------------------------------- /models/todo_SGUIE-Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ############################################################################## 6 | # Classes 7 | ############################################################################## 8 | 9 | def conv(in_channels, out_channels, kernel_size, bias=False, stride=1): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size // 2), bias=bias, stride=stride) 13 | 14 | 15 | ## Channel Attention 
Layer 16 | class CALayer(nn.Module): 17 | def __init__(self, channel, reduction=8, bias=True): 18 | super(CALayer, self).__init__() 19 | # global average pooling: feature --> point 20 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 21 | # feature channel downscale and upscale --> channel weight 22 | self.conv_du = nn.Sequential( 23 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 26 | nn.Sigmoid() 27 | ) 28 | 29 | def forward(self, x): 30 | y = self.avg_pool(x) 31 | y = self.conv_du(y) 32 | return x * y 33 | 34 | 35 | ## Pixel Attention Layer 36 | class PALayer(nn.Module): 37 | def __init__(self, channel, reduction=8, bias=True): 38 | super(PALayer, self).__init__() 39 | self.pa = nn.Sequential( 40 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(channel // reduction, 1, 1, padding=0, bias=bias), 43 | nn.Sigmoid() 44 | ) 45 | 46 | def forward(self, x): 47 | y = self.pa(x) 48 | return x * y 49 | 50 | 51 | ## Features Attention include Channel attention and Pixel Attention 52 | class FABlock(nn.Module): 53 | def __init__(self, dim, kernel_size, ): 54 | super(FABlock, self).__init__() 55 | self.conv1 = conv(dim, dim, kernel_size, bias=True) 56 | self.act1 = nn.ReLU(inplace=True) 57 | self.conv2 = conv(dim, dim, kernel_size, bias=True) 58 | self.calayer = CALayer(dim) 59 | self.palayer = PALayer(dim) 60 | 61 | def forward(self, x): 62 | res = self.act1(self.conv1(x)) 63 | res = res + x 64 | res = self.conv2(res) 65 | res = self.calayer(res) 66 | res = self.palayer(res) 67 | res = res + x 68 | return res 69 | 70 | 71 | ## Channel Attention Block(CAB) 72 | class CAB(nn.Module): 73 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 74 | super(CAB, self).__init__() 75 | modules_body = [] 76 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 77 | modules_body.append(act) 78 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 79 | 80 | self.CA = CALayer(n_feat, reduction, bias=bias) 81 | self.body = nn.Sequential(*modules_body) 82 | 83 | def forward(self, x): 84 | res = self.body(x) 85 | res = self.CA(res) 86 | res = res + x 87 | return res 88 | 89 | 90 | # Attention augmented feature 91 | class FA(nn.Module): 92 | def __init__(self, in_channels=3, out_channels=3, dim=64, kernel_size=3): 93 | super(FA, self).__init__() 94 | self.conv1 = conv(in_channels, dim, kernel_size) 95 | self.fab = FABlock(dim, kernel_size) 96 | self.unet = UNet(dim, dim) 97 | self.conv2 = conv(dim, out_channels, kernel_size) 98 | 99 | def forward(self, raw): 100 | x = self.conv1(raw) 101 | x1 = self.fab(x) 102 | x2 = self.unet(x1) 103 | x3 = self.conv2(x2.clone()) 104 | img = x3 + raw 105 | return x2, img 106 | 107 | 108 | class Group(nn.Module): 109 | def __init__(self, dim, kernel_size, blocks): 110 | super(Group, self).__init__() 111 | modules = [FABlock(dim, kernel_size) for _ in range(blocks)] 112 | modules.append(conv(dim, dim, kernel_size)) 113 | self.gp = nn.Sequential(*modules) 114 | 115 | def forward(self, x): 116 | res = self.gp(x) 117 | res = res + x 118 | return res 119 | 120 | 121 | ############################################# 122 | # Unet 123 | ############################################# 124 | 125 | # Unet parts 126 | class DoubleConv(nn.Module): 127 | """(convolution => [BN] => ReLU) * 2""" 128 | 129 | def __init__(self, in_channels, out_channels, mid_channels=None): 130 | 
super().__init__() 131 | if not mid_channels: 132 | mid_channels = out_channels 133 | self.double_conv = nn.Sequential( 134 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 135 | nn.BatchNorm2d(mid_channels), 136 | nn.ReLU(inplace=True), 137 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 138 | nn.BatchNorm2d(out_channels), 139 | nn.ReLU(inplace=True) 140 | ) 141 | 142 | def forward(self, x): 143 | return self.double_conv(x) 144 | 145 | 146 | class Down(nn.Module): 147 | """Downscaling with maxpool then double conv""" 148 | 149 | def __init__(self, in_channels, out_channels): 150 | super().__init__() 151 | self.maxpool_conv = nn.Sequential( 152 | nn.MaxPool2d(2), 153 | DoubleConv(in_channels, out_channels) 154 | ) 155 | 156 | def forward(self, x): 157 | return self.maxpool_conv(x) 158 | 159 | 160 | class Up(nn.Module): 161 | """Upscaling then double conv""" 162 | 163 | def __init__(self, in_channels, out_channels, bilinear=True): 164 | super().__init__() 165 | 166 | # if bilinear, use the normal convolutions to reduce the number of channels 167 | if bilinear: 168 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 169 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 170 | else: 171 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 172 | self.conv = DoubleConv(in_channels, out_channels) 173 | 174 | def forward(self, x1, x2): 175 | x1 = self.up(x1) 176 | # input is CHW 177 | diffY = x2.size()[2] - x1.size()[2] 178 | diffX = x2.size()[3] - x1.size()[3] 179 | 180 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 181 | diffY // 2, diffY - diffY // 2]) 182 | # if you have padding issues, see 183 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 184 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 185 | x = torch.cat([x2, x1], dim=1) 186 | return self.conv(x) 187 | 188 | 189 | class OutConv(nn.Module): 190 | def __init__(self, in_channels, out_channels): 191 | super(OutConv, self).__init__() 192 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 193 | 194 | def forward(self, x): 195 | return self.conv(x) 196 | 197 | 198 | class UNet(nn.Module): 199 | def __init__(self, n_channels=11, n_classes=3, bilinear=True): 200 | super(UNet, self).__init__() 201 | self.n_channels = n_channels 202 | self.n_classes = n_classes 203 | self.bilinear = bilinear 204 | 205 | self.inc = DoubleConv(n_channels, 64) 206 | self.down1 = Down(64, 128) 207 | self.down2 = Down(128, 256) 208 | self.down3 = Down(256, 512) 209 | self.down4 = Down(512, 512) 210 | self.up1 = Up(1024, 256) 211 | self.up2 = Up(512, 128) 212 | self.up3 = Up(256, 64) 213 | self.up4 = Up(128, 64) 214 | self.outc = OutConv(64, n_classes) 215 | 216 | def forward(self, x): 217 | # x = torch.nn.Sigmoid()(x) 218 | x1 = self.inc(x) 219 | x2 = self.down1(x1) 220 | x3 = self.down2(x2) 221 | x4 = self.down3(x3) 222 | x5 = self.down4(x4) 223 | x = self.up1(x5, x4) 224 | x = self.up2(x, x3) 225 | x = self.up3(x, x2) 226 | x = self.up4(x, x1) 227 | x = self.outc(x) 228 | x = torch.nn.Tanh()(x) 229 | # print(x.shape) 230 | return x 231 | 232 | 233 | class SGUIENet(nn.Module): 234 | def __init__(self, in_channels=3, out_channels=3, dim=64, kernel_size=3, blocks=4, groups=3): 235 | super(SGUIENet, self).__init__() 236 | self.conv1 = conv(in_channels, dim, kernel_size) 237 | self.fab = FABlock(dim, kernel_size) 238 | 
self.orb = nn.Sequential(*[Group(dim, kernel_size, blocks) for _ in range(groups)]) 239 | self.conv2 = conv(dim, out_channels, kernel_size) 240 | 241 | def forward(self, raw): 242 | # main branch 243 | x1 = self.conv1(raw) 244 | x1 = self.fab(x1) 245 | 246 | x1 = self.orb(x1) 247 | x1 = self.conv2(x1) 248 | x1 += raw 249 | x1 = torch.tanh(x1) 250 | 251 | return x1 252 | 253 | if __name__ == '__main__': 254 | t = torch.randn(1, 3, 256, 256).cuda() 255 | model = SGUIENet().cuda() 256 | res = model(t) 257 | print(res.shape) -------------------------------------------------------------------------------- /models/UIALN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | # Paper: 10.1109/TCSVT.2023.3237993 6 | # https://ieeexplore.ieee.org/abstract/document/10019314 7 | # UIALN: Enhancement for Underwater Image with Artificial Light 8 | class Retinex_Decomposition_net(nn.Module): 9 | def __init__(self, in_channels=1, out_channels=2): 10 | super(Retinex_Decomposition_net, self).__init__() 11 | self.relu = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1) 13 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 14 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 15 | self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1) 16 | self.conv5 = nn.Conv2d(32, out_channels, kernel_size=3, stride=1, padding=1) 17 | 18 | def forward(self, x): 19 | x = self.conv1(x) 20 | # ReLU activations 21 | x = self.relu(self.conv2(x)) 22 | x = self.relu(self.conv3(x)) 23 | x = self.relu(self.conv4(x)) 24 | x = self.relu(self.conv5(x)) 25 | return x 26 | 27 | 28 | class Illumination_Correction(nn.Module): 29 | def __init__(self, in_channels=2, out_channels=1): 30 | super(Illumination_Correction, self).__init__() 31 | self.down_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2) 32 | self.down_2 = nn.Conv2d(32, 64, kernel_size=3, stride=2) 33 | self.down_3 = nn.Conv2d(64, 128, kernel_size=3, stride=2) 34 | self.up_1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2) 35 | self.up_2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2) 36 | self.up_3 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, output_padding=1) 37 | # equivalent to two consecutive transposed convolutions 38 | self.up_4_1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2) ################# open question 39 | self.up_4_2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, output_padding=1) 40 | self.up_5 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, output_padding=1) 41 | self.conv1 = nn.Conv2d(32 * 3, out_channels, kernel_size=3, stride=1, padding=1) 42 | 43 | def forward(self, x): 44 | x = self.down_1(x) 45 | x = self.down_2(x) 46 | x = self.down_3(x) 47 | x1 = self.up_1(x) 48 | print(x1.shape) 49 | x2 = self.up_2(x1) 50 | print(x2.shape) 51 | x = self.up_3(x2) 52 | x1 = self.up_4_2(self.up_4_1(x1)) 53 | x2 = self.up_5(x2) 54 | print(x.shape, x1.shape, x2.shape) 55 | x = torch.cat((x, x1, x2), dim=1) 56 | x = self.conv1(x) 57 | return x 58 | 59 | 60 | # Residual Dense Block 61 | class Dense_Block_IN(nn.Module): 62 | def __init__(self, block_num, inter_channel, channel, with_residual=True): 63 | super(Dense_Block_IN, self).__init__() 64 | concat_channels = channel + block_num * inter_channel 65 | channels_now = channel 66 | 67 | self.group_list = nn.ModuleList([]) 68 | for i in range(block_num): 69 | group = nn.Sequential( 70 | nn.Conv2d(in_channels=channels_now, out_channels=inter_channel, kernel_size=3, 71 |
stride=1, padding=1), 72 | nn.InstanceNorm2d(inter_channel, affine=True), 73 | nn.ReLU(), 74 | ) 75 | self.add_module(name='group_%d' % i, module=group) 76 | self.group_list.append(group) 77 | channels_now += inter_channel 78 | assert channels_now == concat_channels 79 | self.fusion = nn.Sequential( 80 | nn.Conv2d(concat_channels, channel, kernel_size=1, stride=1, padding=0), 81 | nn.InstanceNorm2d(channel, affine=True), 82 | nn.ReLU(), 83 | ) 84 | self.with_residual = with_residual 85 | 86 | def forward(self, x): 87 | feature_list = [x] 88 | for group in self.group_list: 89 | inputs = torch.cat(feature_list, dim=1) 90 | outputs = group(inputs) 91 | feature_list.append(outputs) 92 | inputs = torch.cat(feature_list, dim=1) 93 | fusion_outputs = self.fusion(inputs) 94 | if self.with_residual: 95 | block_outputs = fusion_outputs + x 96 | else: 97 | block_outputs = fusion_outputs 98 | 99 | return block_outputs 100 | 101 | 102 | class AL_Area_Selfguidance_Color_Correction(nn.Module): 103 | def __init__(self, in_channels=2, out_channels=2): 104 | super(AL_Area_Selfguidance_Color_Correction, self).__init__() 105 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1) 106 | self.conv2 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1) 107 | self.RDB1 = Dense_Block_IN(4, 32, 64) 108 | self.Down_1 = nn.Conv2d(64, 128, kernel_size=3, stride=2) 109 | self.RDB2 = Dense_Block_IN(4, 32, 128) 110 | self.Down_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2) 111 | self.RDB3 = Dense_Block_IN(4, 32, 256) 112 | self.Up_1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2) 113 | self.RDB4 = Dense_Block_IN(4, 32, 128 + 128) 114 | self.Up_2 = nn.ConvTranspose2d(128 + 128, 64, kernel_size=3, stride=2, output_padding=1) 115 | self.RDB5 = Dense_Block_IN(4, 32, 64 + 64) 116 | self.conv3 = nn.Conv2d(64 + 64, out_channels, kernel_size=3, stride=1, padding=1) 117 | 118 | def forward(self, x, y): 119 | x = x * y 120 | x = self.conv1(x) 121 | y = self.conv2(y) 122 | x = torch.cat((x, y), dim=1) 123 | x1 = self.RDB1(x) 124 | x2 = self.RDB2(self.Down_1(x1)) 125 | x = self.RDB3(self.Down_2(x2)) 126 | x = self.Up_1(x) 127 | x = torch.cat((x, x2), dim=1) 128 | x = self.Up_2(self.RDB4(x)) 129 | x = torch.cat((x, x1), dim=1) 130 | x = self.RDB5(x) 131 | x = self.conv3(x) 132 | return x 133 | 134 | 135 | class Detail_Enhancement(nn.Module): 136 | def __init__(self, in_channels=1, out_channels=1): 137 | super(Detail_Enhancement, self).__init__() 138 | self.Down_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2) 139 | self.DB_1 = Dense_Block_IN(4, 32, 32, with_residual=False) 140 | self.Down_2 = nn.Conv2d(32, 64, kernel_size=3, stride=2) 141 | self.DB_2 = Dense_Block_IN(4, 32, 64, with_residual=False) 142 | self.Down_3 = nn.Conv2d(64, 128, kernel_size=3, stride=2) 143 | self.DB_3 = Dense_Block_IN(4, 32, 128, with_residual=False) 144 | self.Down_4 = nn.Conv2d(128, 256, kernel_size=3, stride=2) 145 | self.DB_4 = Dense_Block_IN(4, 32, 256, with_residual=False) 146 | self.Up_1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2) 147 | self.DB_5 = Dense_Block_IN(4, 32, 128 + 128, with_residual=False) 148 | self.Up_2 = nn.ConvTranspose2d(128 + 128, 64, kernel_size=3, stride=2) 149 | self.DB_6 = Dense_Block_IN(4, 32, 64 + 64, with_residual=False) 150 | self.Up_3 = nn.ConvTranspose2d(64 + 64, 32, kernel_size=3, stride=2) 151 | self.DB_7 = Dense_Block_IN(4, 32, 32 + 32, with_residual=False) 152 | self.Up_4 = nn.ConvTranspose2d(32 + 32, 16, kernel_size=3, stride=2, output_padding=1) 153 | 
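# Shape bookkeeping: the stride-2 downsampling convs use no padding, so a 256-px
# input shrinks 256 -> 127 -> 63 -> 31 -> 15; the k=3, s=2 transposed convs then
# grow (15-1)*2+3 = 31 -> 63 -> 127, and Up_4 needs output_padding=1 to land on
# (127-1)*2+3+1 = 256 so the concat with the full-resolution input x0 lines up.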
self.DB_8 = Dense_Block_IN(4, 32, 16 + in_channels, with_residual=False) 154 | self.conv1 = nn.Conv2d(16 + in_channels, out_channels, kernel_size=3, stride=1, padding=1) 155 | 156 | def forward(self, x): 157 | x0 = x 158 | x1 = self.DB_1(self.Down_1(x)) 159 | x2 = self.DB_2(self.Down_2(x1)) 160 | x3 = self.DB_3(self.Down_3(x2)) 161 | x = self.DB_4(self.Down_4(x3)) 162 | x = self.Up_1(x) 163 | x = torch.cat((x, x3), dim=1) 164 | x = self.Up_2(self.DB_5(x)) 165 | x = torch.cat((x, x2), dim=1) 166 | x = self.Up_3(self.DB_6(x)) 167 | x = torch.cat((x, x1), dim=1) 168 | x = self.Up_4(self.DB_7(x)) 169 | x = torch.cat((x, x0), dim=1) 170 | x = self.DB_8(x) 171 | x = self.conv1(x) 172 | return x 173 | 174 | 175 | class Channels_Fusion(nn.Module): 176 | def __init__(self, in_channels=3, out_channels=3): 177 | super(Channels_Fusion, self).__init__() 178 | self.relu = nn.ReLU(inplace=True) 179 | self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1) 180 | self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 181 | self.conv3 = nn.Conv2d(32, out_channels, kernel_size=3, stride=1, padding=1) 182 | 183 | def forward(self, x): 184 | x = self.relu(self.conv1(x)) 185 | x = self.relu(self.conv2(x)) 186 | x = self.conv3(x) 187 | return x 188 | 189 | 190 | class UIALN(nn.Module): 191 | def __init__(self, in_channels=3, out_channels=3): 192 | super(UIALN, self).__init__() 193 | self.chan1 = Detail_Enhancement() 194 | self.chan2 = Detail_Enhancement() 195 | self.chan3 = Detail_Enhancement() 196 | self.fuse = Channels_Fusion() 197 | 198 | def forward(self, x): 199 | chan1 = self.chan1(x[:, 0:1, :, :]) # 0:1 keeps the channel dim so each branch sees (B, 1, H, W) 200 | chan2 = self.chan2(x[:, 1:2, :, :]) 201 | chan3 = self.chan3(x[:, 2:3, :, :]) 202 | res = self.fuse(torch.cat((chan1, chan2, chan3), dim=1)) 203 | return res 204 | 205 | if __name__ == '__main__': 206 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 207 | model = UIALN().to(device) 208 | t = torch.randn(1, 3, 256, 256).to(device) 209 | res = model(t) 210 | print(res.shape) -------------------------------------------------------------------------------- /models/RauneNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn import init 4 | 5 | 6 | class ChannelAttention(nn.Module): 7 | def __init__(self,channel,reduction=16): 8 | super().__init__() 9 | self.maxpool=nn.AdaptiveMaxPool2d(1) 10 | self.avgpool=nn.AdaptiveAvgPool2d(1) 11 | self.se=nn.Sequential( 12 | nn.Conv2d(channel,channel//reduction,1,bias=False), 13 | nn.ReLU(), 14 | nn.Conv2d(channel//reduction,channel,1,bias=False) 15 | ) 16 | self.sigmoid=nn.Sigmoid() 17 | 18 | def forward(self, x): 19 | max_result=self.maxpool(x) 20 | avg_result=self.avgpool(x) 21 | max_out=self.se(max_result) 22 | avg_out=self.se(avg_result) 23 | output=self.sigmoid(max_out+avg_out) 24 | return output 25 | 26 | 27 | class SpatialAttention(nn.Module): 28 | def __init__(self,kernel_size=7): 29 | super().__init__() 30 | self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2) 31 | self.sigmoid=nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | max_result,_=torch.max(x,dim=1,keepdim=True) 35 | avg_result=torch.mean(x,dim=1,keepdim=True) 36 | result=torch.cat([max_result,avg_result],1) 37 | output=self.conv(result) 38 | output=self.sigmoid(output) 39 | return output 40 | 41 | 42 | class CBAMBlock(nn.Module): 43 | def __init__(self, channel=512,reduction=16,kernel_size=7): 44 | super().__init__() 45 |
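# CBAM = channel attention (a shared two-layer 1x1-conv MLP over max- and
# avg-pooled descriptors, summed and passed through a sigmoid) followed by
# spatial attention (a 7x7 conv over the channel-wise max and mean maps);
# forward() applies both gates multiplicatively and adds a residual connection.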
self.ca=ChannelAttention(channel=channel,reduction=reduction) 46 | self.sa=SpatialAttention(kernel_size=kernel_size) 47 | 48 | 49 | def init_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv2d): 52 | init.kaiming_normal_(m.weight, mode='fan_out') 53 | if m.bias is not None: 54 | init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.BatchNorm2d): 56 | init.constant_(m.weight, 1) 57 | init.constant_(m.bias, 0) 58 | elif isinstance(m, nn.Linear): 59 | init.normal_(m.weight, std=0.001) 60 | if m.bias is not None: 61 | init.constant_(m.bias, 0) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | residual=x 66 | out=x*self.ca(x) 67 | out=out*self.sa(out) 68 | return out+residual 69 | 70 | class ResnetBlock(nn.Module): 71 | """Resnet block. 72 | 73 | Adapted from "https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix". 74 | """ 75 | 76 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 77 | """Initialize the Resnet block. 78 | 79 | A resnet block is a conv block with skip connections 80 | We construct a conv block with build_conv_block function, 81 | and implement skip connections in function. 82 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 83 | """ 84 | super(ResnetBlock, self).__init__() 85 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 86 | 87 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 88 | """Construct a convolutional block. 89 | 90 | Parameters: 91 | dim (int) -- the number of channels in the conv layer. 92 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 93 | norm_layer -- normalization layer 94 | use_dropout (bool) -- if use dropout layers. 95 | use_bias (bool) -- if the conv layer uses bias or not 96 | 97 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 98 | """ 99 | conv_block = [] 100 | p = 0 101 | if padding_type == 'reflect': 102 | conv_block += [nn.ReflectionPad2d(1)] 103 | elif padding_type == 'replicate': 104 | conv_block += [nn.ReplicationPad2d(1)] 105 | elif padding_type == 'zero': 106 | p = 1 107 | else: 108 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 109 | 110 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 111 | if use_dropout: 112 | conv_block += [nn.Dropout(0.5)] 113 | 114 | p = 0 115 | if padding_type == 'reflect': 116 | conv_block += [nn.ReflectionPad2d(1)] 117 | elif padding_type == 'replicate': 118 | conv_block += [nn.ReplicationPad2d(1)] 119 | elif padding_type == 'zero': 120 | p = 1 121 | else: 122 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 123 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 124 | 125 | return nn.Sequential(*conv_block) 126 | 127 | def forward(self, x): 128 | """Forward function (with skip connections)""" 129 | out = x + self.conv_block(x) # add skip connections 130 | return out 131 | 132 | 133 | class RauneNet(nn.Module): 134 | """Residual and Attention-driven underwater enhancement Network. 135 | """ 136 | def __init__(self, input_nc=3, output_nc=3, n_blocks=30, n_down=2, ngf=64, 137 | padding_type='reflect', use_dropout=False, use_att_down=True, use_att_up=False, 138 | norm_layer=nn.InstanceNorm2d): 139 | """Initializes the RAUNE-Net. 140 | 141 | Args: 142 | input_nc: Number of channels of input images. 
143 | output_nc: Number of channels of output images. 144 | n_blocks: Number of residual blocks. 145 | n_down: Number of down-sampling blocks. 146 | ngf: Number of kernels of Conv2d layer in `WRPM`. 147 | padding_type: Type of padding layer in Residual Block. 148 | use_dropout: Whether to use dropout. 149 | use_att_down: Whether to use attention block in down-sampling. 150 | use_att_up: Whether to use attention block in up-sampling. 151 | norm_layer: Type of Normalization layer. 152 | """ 153 | assert (n_blocks >= 0 and n_down >= 0) 154 | super().__init__() 155 | use_bias = False if norm_layer else True 156 | 157 | model = [] 158 | 159 | # Wide-range Perception Module (WRPM) 160 | model.append(nn.Sequential( 161 | nn.ReflectionPad2d(3), 162 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=False), 163 | norm_layer(ngf), 164 | nn.ReLU(True) 165 | )) 166 | 167 | # Attention Down-sampling Module (ADM) 168 | for i in range(n_down): 169 | mult = 2 ** i 170 | model.append(self._down(ngf*mult, ngf*mult*2, norm_layer=norm_layer, use_att=use_att_down, use_dropout=use_dropout)) 171 | 172 | # High-level Features Residual Learning Module (HFRLM) 173 | mult = 2 ** n_down 174 | for i in range(n_blocks): 175 | model.append(ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, 176 | use_dropout=use_dropout, use_bias=use_bias)) 177 | 178 | # Up-sampling Module (UM) 179 | for i in range(n_down): 180 | mult = 2 ** (n_down - i) 181 | model.append(self._up(ngf * mult, int(ngf * mult / 2), use_att=use_att_up, use_dropout=use_dropout)) 182 | 183 | # Feature Map Smoothing Module (FMSM) and Tanh Activation Layer 184 | model.append(nn.Sequential( 185 | nn.ReflectionPad2d(3), 186 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), 187 | nn.Tanh() 188 | )) 189 | 190 | self.model = nn.Sequential(*model) 191 | 192 | def _down(self, in_channels, out_channels, norm_layer=None, use_att=True, use_dropout=False, dropout_rate=0.5): 193 | """Attention Down-sampling Block. 194 | 195 | Args: 196 | in_channels: Number of channels of input tensor. 197 | out_channels: Number of channels of output tensor. 198 | norm_layer: Type of Normalization layer. 199 | use_att: Whether to use attention. 200 | use_dropout: Whether to use dropout. 201 | dropout_rate: Probability of dropout layer. 202 | """ 203 | use_bias = False if norm_layer else True 204 | layers = [nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=use_bias)] 205 | if norm_layer: 206 | layers.append(norm_layer(out_channels)) 207 | layers.append(nn.LeakyReLU(0.2)) 208 | if use_dropout: 209 | layers.append(nn.Dropout(dropout_rate)) 210 | if use_att: 211 | layers.append(CBAMBlock(out_channels)) 212 | return nn.Sequential(*layers) 213 | 214 | def _up(self, in_channels, out_channels, use_att=False, use_dropout=False, dropout_rate=0.5): 215 | """Up-sampling Block. 216 | 217 | Args: 218 | in_channels: Number of channels of input tensor. 219 | out_channels: Number of channels of output tensor. 220 | use_att: Whether to use attention. 221 | use_dropout: Whether to use dropout. 222 | dropout_rate: Probability of dropout layer. 223 | """ 224 | layers = [ 225 | nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False), 226 | nn.InstanceNorm2d(out_channels), 227 | nn.ReLU(), 228 | ] 229 | if use_dropout: 230 | layers.append(nn.Dropout(dropout_rate)) 231 | if use_att: 232 | layers.append(CBAMBlock(out_channels)) 233 | return nn.Sequential(*layers) 234 | 235 | def forward(self, input): 236 | """Forward function.
--------------------------------------------------------------------------------
/models/CCMSRNet.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.transforms.functional as TF
6 | from torch.utils.data import Dataset
7 | from torch.utils.data import DataLoader, random_split
8 | import os
9 | import numpy as np
10 | import random
11 | import time
12 | 
13 | def normalize_img(img):
14 |     if torch.max(img) > 1 or torch.min(img) < 0:
15 |         # img: b x c x h x w
16 |         b, c, h, w = img.shape
17 |         temp_img = img.view(b, c, h*w)
18 |         im_max = torch.max(temp_img, dim=2)[0].view(b, c, 1)
19 |         im_min = torch.min(temp_img, dim=2)[0].view(b, c, 1)
20 | 
21 |         temp_img = (temp_img - im_min) / (im_max - im_min + 1e-7)
22 | 
23 |         img = temp_img.view(b, c, h, w)
24 | 
25 |     return img
26 | 
27 | 
28 | class DoubleConv(nn.Module):
29 | 
30 |     def __init__(self, in_channels, out_channels, mid_channels=None):
31 |         super().__init__()
32 |         if not mid_channels:
33 |             mid_channels = out_channels
34 |         self.double_conv = nn.Sequential(
35 |             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
36 |             nn.BatchNorm2d(mid_channels),
37 |             nn.ReLU(inplace=True),
38 |             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
39 |             nn.BatchNorm2d(out_channels),
40 |             nn.ReLU(inplace=True)
41 |         )
42 | 
43 |     def forward(self, x):
44 |         return self.double_conv(x)
45 | 
46 | class Down_Axial_onlyV(nn.Module):
47 |     def __init__(self, in_channels, out_channels, key_dim, num_heads):
48 |         super().__init__()
49 |         self.pool = nn.MaxPool2d(2)
50 |         self.conv = nn.Sequential(
51 |             DoubleConv(in_channels, out_channels),
52 |         )
53 |         self.pool4trans = nn.AdaptiveAvgPool2d((16,16))
54 |         self.attn = Sea_Attention_onlyV(dim=in_channels, key_dim=key_dim, num_heads=num_heads)
55 |         self.transition = nn.Conv2d(in_channels, out_channels, 1)
56 | 
57 |     def forward(self, x):
58 |         x = self.pool(x)
59 |         x_conv = self.conv(x)
60 |         b, c, h, w = x_conv.shape
61 | 
62 |         x_trans = self.pool4trans(x)
63 |         x_trans = self.attn(x_trans)
64 |         x_trans = self.transition(x_trans)
65 |         x_trans = F.interpolate(x_trans, size=(h,w), mode='bilinear')
66 |         return x_conv + x_trans
67 | 
68 | class Conv2d_BN(nn.Module):
69 |     def __init__(self, in_channel, out_channel, ks=1, stride=1, pad=0, dilation=1, groups=1):
70 |         super().__init__()
71 |         self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=ks, stride=stride, padding=pad, dilation=dilation, groups=groups)
72 |         self.bn = nn.BatchNorm2d(out_channel)
73 |     def forward(self, x):
74 |         return self.bn(self.conv(x))
75 | 
76 | 
77 | class SqueezeAxialPositionalEmbedding(nn.Module):
78 |     def __init__(self, dim, shape):
79 |         super().__init__()
80 | 
81 |         self.pos_embed = nn.Parameter(torch.randn([1, dim, shape]), requires_grad=True)
82 | 
83 |     def forward(self, x):
84 |         B, C, N = x.shape
85 |         x = x + F.interpolate(self.pos_embed, size=(N), mode='linear', align_corners=False)
86 |         return x
87 | 
88 | 
89 | class Sea_Attention_onlyV(torch.nn.Module):
90 |     def __init__(self, dim, key_dim, num_heads,
91 |                  attn_ratio=2,
92 |                  activation=nn.ReLU,
93 |                  norm_cfg=dict(type='BN', requires_grad=True), ):
94 |         super().__init__()
95 |         self.num_heads = num_heads
96 |         self.scale = key_dim ** -0.5
97 |         self.key_dim = key_dim
98 |         self.nh_kd = nh_kd = key_dim * num_heads  # num_heads * key_dim
99 |         self.d = int(attn_ratio * key_dim)
100 |         self.dh = int(attn_ratio * key_dim) * num_heads
101 | 
102 |         self.attn_ratio = attn_ratio
103 | 
104 |         self.to_q = Conv2d_BN(dim, nh_kd, 1)
105 |         self.to_k = Conv2d_BN(dim, nh_kd, 1)
106 |         self.to_v = Conv2d_BN(dim, self.dh, 1)
107 | 
108 |         self.proj = torch.nn.Sequential(activation(), Conv2d_BN(
109 |             self.dh, dim))
110 |         self.proj_encode_row = torch.nn.Sequential(activation(), Conv2d_BN(
111 |             self.dh, self.dh))
112 |         self.pos_emb_rowq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
113 |         self.pos_emb_rowk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
114 |         self.proj_encode_column = torch.nn.Sequential(activation(), Conv2d_BN(
115 |             self.dh, self.dh))
116 |         self.pos_emb_columnq = SqueezeAxialPositionalEmbedding(nh_kd, 16)
117 |         self.pos_emb_columnk = SqueezeAxialPositionalEmbedding(nh_kd, 16)
118 | 
119 |         self.dwconv = Conv2d_BN(self.dh, self.dh, ks=3, stride=1, pad=1, dilation=1,
120 |                                 groups=self.dh)
121 |         self.act = activation()
122 |         self.pwconv = Conv2d_BN(self.dh, dim, ks=1)
123 | 
124 |     def forward(self, x):  # x: (B, C, H, W)
125 |         B, C, H, W = x.shape
126 | 
127 |         q = self.to_q(x)
128 |         k = self.to_k(x)
129 |         v = self.to_v(x)
130 | 
131 |         ## squeeze row
132 |         qrow = self.pos_emb_rowq(q.mean(-1)).reshape(B, self.num_heads, -1, H).permute(0, 1, 3, 2)
133 |         krow = self.pos_emb_rowk(k.mean(-1)).reshape(B, self.num_heads, -1, H)
134 |         vrow = v.mean(-1).reshape(B, self.num_heads, -1, H).permute(0, 1, 3, 2)
135 | 
136 |         attn_row = torch.matmul(qrow, krow) * self.scale
137 |         attn_row = attn_row.softmax(dim=-1)
138 |         xx_row = torch.matmul(attn_row, vrow)  # B nH H C
139 |         xx_row = self.proj_encode_row(xx_row.permute(0, 1, 3, 2).reshape(B, self.dh, H, 1))
140 | 
141 | 
142 |         ## squeeze column
143 |         qcolumn = self.pos_emb_columnq(q.mean(-2)).reshape(B, self.num_heads, -1, W).permute(0, 1, 3, 2)
144 |         kcolumn = self.pos_emb_columnk(k.mean(-2)).reshape(B, self.num_heads, -1, W)
145 |         vcolumn = v.mean(-2).reshape(B, self.num_heads, -1, W).permute(0, 1, 3, 2)
146 | 
147 |         attn_column = torch.matmul(qcolumn, kcolumn) * self.scale
148 |         attn_column = attn_column.softmax(dim=-1)
149 |         xx_column = torch.matmul(attn_column, vcolumn)  # B nH W C
150 |         xx_column = self.proj_encode_column(xx_column.permute(0, 1, 3, 2).reshape(B, self.dh, 1, W))
151 | 
152 |         xx = xx_row.add(xx_column)
153 |         xx = v.add(xx)
154 |         xx = self.proj(xx)
155 |         return xx
156 | 
157 | class Up_Axial_onlyV(nn.Module):
158 |     """Upscaling then double conv"""
159 | 
160 |     def __init__(self, in_channels, out_channels, key_dim, num_heads, bilinear=True):
161 |         super().__init__()
162 | 
163 |         # if bilinear, use the normal convolutions to reduce the number of channels
164 |         if bilinear:
165 |             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
166 |             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
167 |         else:
168 |             self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
169 |             self.conv = DoubleConv(in_channels, out_channels)
170 |         self.pool4trans = nn.AdaptiveAvgPool2d((16,16))
171 |         self.attn = Sea_Attention_onlyV(dim=in_channels, key_dim=key_dim, num_heads=num_heads)
172 |         self.transition = nn.Conv2d(in_channels, out_channels, 1)
173 | 
174 |     def forward(self, x1, x2):
175 |         x1 = self.up(x1)
176 |         # input is CHW
177 |         diffY = x2.size()[2] - x1.size()[2]
178 |         diffX = x2.size()[3] - x1.size()[3]
179 | 
180 |         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
181 |                         diffY // 2, diffY - diffY // 2])
182 |         x = torch.cat([x2, x1], dim=1)
183 |         x_conv = self.conv(x)
184 |         b, c, h, w = x_conv.shape
185 | 
186 |         x_trans = self.pool4trans(x)
187 |         x_trans = self.attn(x_trans)
188 |         x_trans = self.transition(x_trans)
189 |         x_trans = F.interpolate(x_trans, size=(h,w), mode='bilinear')
190 |         return x_conv + x_trans
191 | 
192 | class OutConv(nn.Module):
193 |     def __init__(self, in_channels, out_channels):
194 |         super(OutConv, self).__init__()
195 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
196 | 
197 |     def forward(self, x):
198 |         return self.conv(x)
199 | 
200 | class CCNet(nn.Module):
201 |     def __init__(self, n_channels, n_classes, img_size, bilinear=True):
202 |         super().__init__()
203 |         self.n_channels = n_channels
204 |         self.n_classes = n_classes
205 |         factor = 2 if bilinear else 1
206 |         self.bilinear = bilinear
207 | 
208 |         #self.b = nn.Linear(1024//factor,n_channels)
209 | 
210 | 
211 |         self.inc = DoubleConv(n_channels, 16)
212 |         self.down1 = Down_Axial_onlyV(16, 32, 16, 4)
213 |         self.down2 = Down_Axial_onlyV(32, 64, 16, 4)
214 |         self.down3 = Down_Axial_onlyV(64, 128, 16, 4)
215 |         self.down4 = Down_Axial_onlyV(128, 128, 16, 4)
216 | 
217 |         self.up1 = Up_Axial_onlyV(256, 64, 16, 4)
218 |         self.up2 = Up_Axial_onlyV(128, 32, 16, 4)
219 |         self.up3 = Up_Axial_onlyV(64, 16, 16, 4)
220 |         self.up4 = Up_Axial_onlyV(32, 16, 16, 4)
221 |         self.outc = OutConv(16, 2)
222 |         self.sigmoid = nn.Sigmoid()
223 | 
224 | 
225 |     def forward(self, x):
226 |         ### color balance
227 |         b, c, h, w = x.shape
228 |         I = x
229 |         x1 = self.inc(x)
230 |         x2 = self.down1(x1)
231 |         x3 = self.down2(x2)
232 |         x4 = self.down3(x3)
233 |         x5 = self.down4(x4)
234 | 
235 |         x = self.up1(x5, x4)
236 |         x = self.up2(x, x3)
237 |         x = self.up3(x, x2)
238 |         x = self.up4(x, x1)
239 |         maps = self.outc(x)
240 |         maps = self.sigmoid(maps)
241 |         r_map = maps[:,0,:,:]
242 |         b_map = maps[:,1,:,:]
243 |         R, G, B = I[:,0,:,:], I[:,1,:,:], I[:,2,:,:]
244 | 
245 |         R = R + r_map*(G-R)*(1-R)*G
246 |         B = B + b_map*(G-B)*(1-B)*G
247 | 
248 |         R = R.unsqueeze(1)
249 |         G = G.unsqueeze(1)
250 |         B = B.unsqueeze(1)
251 |         out = torch.cat([R,G,B], dim=1)
252 |         x_cc = normalize_img(out)
253 | 
254 |         return x_cc
255 | 
256 | class UNetAxialFuser(nn.Module):
257 |     def __init__(self, n_channels, n_classes, img_size, bilinear=True):
258 |         super().__init__()
259 |         self.n_channels = n_channels
260 |         self.n_classes = n_classes
261 |         factor = 2 if bilinear else 1
262 |         self.bilinear = bilinear
263 | 
264 |         self.inc = DoubleConv(n_channels, 64)
265 |         self.down1 = Down_Axial_onlyV(64, 128, 16, 4)
266 |         self.down2 = Down_Axial_onlyV(128, 256, 16, 4)
267 |         self.down3 = Down_Axial_onlyV(256, 512, 16, 4)
268 | 
269 |         self.down4 = Down_Axial_onlyV(512, 1024 // factor, 16, 4)
270 |         self.up1 = Up_Axial_onlyV(1024, 512 // factor, 16, 4, bilinear)
271 |         self.up2 = Up_Axial_onlyV(512, 256 // factor, 16, 4, bilinear)
272 |         self.up3 = Up_Axial_onlyV(256, 128 // factor, 16, 4, bilinear)
273 |         self.up4 = Up_Axial_onlyV(128, 64, 16, 4, bilinear)
274 | 
275 |         self.outc = OutConv(64, n_classes)
276 | 
277 |     def forward(self, x):
278 | 
279 |         b, c, h, w = x.shape
280 | 
281 |         x1 = self.inc(x)
282 |         x2 = self.down1(x1)
283 |         x3 = self.down2(x2)
284 |         x4 = self.down3(x3)
285 |         x5 = self.down4(x4)
286 | 
287 |         ### F decoder
288 | 
289 |         x = self.up1(x5, x4)
290 |         x = self.up2(x, x3)
291 |         x = self.up3(x, x2)
292 |         x = self.up4(x, x1)
293 | 
294 |         out = self.outc(x)
295 |         return out
296 | 
297 | class CCMSRNet(nn.Module):
298 |     def __init__(self, n_channels=3, n_classes=3, img_size=256, bilinear=True):
299 |         super(CCMSRNet, self).__init__()
300 |         self.n_channels = n_channels
301 |         self.n_classes = n_classes
302 |         factor = 2 if bilinear else 1
303 |         self.bilinear = bilinear
304 |         self.ccnet = CCNet(n_channels, n_classes, img_size, bilinear)
305 |         self.fuser = UNetAxialFuser(9, 3, img_size, bilinear)
306 | 
307 |     def forward(self, x):
308 |         x_cc = self.ccnet(x)
309 | 
310 |         I = x_cc
311 |         ssr1 = torch.log(I+1/255)*(1 - torch.log(TF.gaussian_blur(I+1/255, kernel_size=3)))
312 |         ssr2 = torch.log(I+1/255)*(1 - torch.log(TF.gaussian_blur(I+1/255, kernel_size=7)))
313 |         ssr3 = torch.log(I+1/255)*(1 - torch.log(TF.gaussian_blur(I+1/255, kernel_size=11)))
314 |         msr_cat = torch.cat([ssr1, ssr2, ssr3], dim=1)
315 |         msr_fuse = self.fuser(msr_cat)
316 | 
317 |         msr = normalize_img(msr_fuse)
318 |         return msr
319 | 
320 | if __name__ == '__main__':
321 |     inp = torch.randn(1, 3, 256, 256).cuda()
322 |     model = CCMSRNet().cuda()
323 |     res = model(inp)
324 |     print(res.shape)
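325 |     # Editor's sketch, not in the original script: an inference-mode smoke
326 |     # test; the output should lie in [0, 1] after the final normalize_img call.
327 |     model.eval()
328 |     with torch.no_grad():
329 |         out = model(inp)
330 |     print(out.min().item(), out.max().item())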
--------------------------------------------------------------------------------
/models/Spectroformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | import torch.fft as fft
6 | 
7 | 
8 | class AGSSF(nn.Module):
9 |     def __init__(self, channels, b=1, gamma=2):
10 |         super(AGSSF, self).__init__()
11 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
12 |         self.channels = channels
13 |         self.b = b
14 |         self.gamma = gamma
15 |         self.conv = nn.Conv1d(1, 1, kernel_size=self.kernel_size(), padding=(self.kernel_size() - 1) // 2, bias=False)
16 |         self.sigmoid = nn.Sigmoid()
17 | 
18 |     def kernel_size(self):
19 |         k = int(abs((math.log2(self.channels) / self.gamma) + self.b / self.gamma))  # ECA-style adaptive kernel size
20 |         out = k if k % 2 else k + 1  # force an odd kernel
21 |         return out
22 | 
23 |     def forward(self, x):
24 | 
25 |         # x1=inv_mag(x)
26 |         # feature descriptor on the global spatial information
27 |         y = self.avg_pool(x)
28 | 
29 | 
30 |         # Two different branches of ECA module
31 |         y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
32 | 
33 | 
34 |         # Multi-scale information fusion
35 |         y = self.sigmoid(y)
36 | 
37 | 
38 |         return x * y.expand_as(x)
39 | 
40 | 
41 | class SFCA(nn.Module):
42 |     def __init__(self, channels, relu_slope=0.2, gamma=2):
43 |         super(SFCA, self).__init__()
44 |         self.identity1 = nn.Conv2d(channels, channels, 1)
45 |         self.identity2 = nn.Conv2d(channels, channels, 1)
46 |         self.conv_1 = nn.Conv2d(channels, 2*channels, kernel_size=1, bias=True)
47 |         self.relu_1 = nn.LeakyReLU(relu_slope)
48 |         self.conv_2 = nn.Conv2d(2*channels, channels, kernel_size=3, padding=1, groups=channels, bias=True)
49 |         self.relu_2 = nn.LeakyReLU(relu_slope)
50 | 
51 |         self.conv_f1 = nn.Conv2d(channels, 2*channels, kernel_size=1)
52 |         self.conv_f2 = nn.Conv2d(2*channels, channels, kernel_size=1)
53 |         self.con2X1 = nn.Conv2d(2*channels, channels, kernel_size=1)
54 |         self.agssf = AGSSF(channels)
55 | 
56 |     def forward(self, x):
57 |         out = self.conv_1(x)
58 |         out_1, out_2 = torch.chunk(out, 2, dim=1)
59 |         out = torch.cat([out_1, out_2], dim=1)  # chunk then re-concatenate along dim=1 (identity here)
60 |         out = self.relu_1(out)
61 |         out = self.relu_2(self.conv_2(out))
62 |         # print(self.identity1(x).shape, out.shape)
63 |         out += self.identity1(x)
64 | 
65 |         x_fft = fft.fftn(x, dim=(-2, -1)).real
66 |         x_fft = F.gelu(self.conv_f1(x_fft))
67 |         x_fft = self.conv_f2(x_fft)
68 |         x_reconstructed = fft.ifftn(x_fft, dim=(-2, -1)).real
69 |         x_reconstructed += self.identity2(x)
70 | 
71 |         f_out = self.con2X1(torch.cat([out, x_reconstructed], dim=1))
72 | 
73 |         return self.agssf(f_out)
74 | 
75 | 
76 | class MDTA(nn.Module):
77 |     def __init__(self, channels, num_heads):
78 |         super(MDTA, self).__init__()
79 |         self.num_heads = num_heads
80 |         self.temperature = nn.Parameter(torch.ones(1, num_heads, 1, 1))
81 |         self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1, bias=False)
82 |         self.qkv_conv = nn.Conv2d(channels * 3, channels * 3, kernel_size=3, padding=1, groups=channels * 3, bias=False)
83 |         self.project_out = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
84 | 
85 |         # frequency branch
86 | 
87 |         self.kv = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=False)
88 |         self.q1X1_1 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
89 |         self.q1X1_2 = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
90 |         self.kv_conv = nn.Conv2d(channels * 2, channels * 2, kernel_size=3, padding=1, groups=channels * 2, bias=False)
91 |         self.project_outf = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
92 | 
93 | 
94 | 
95 |     def forward(self, x):
96 |         b, c, h, w = x.shape
97 |         q, k, v = self.qkv_conv(self.qkv(x)).chunk(3, dim=1)
98 |         q = q.reshape(b, self.num_heads, -1, h * w)
99 |         k = k.reshape(b, self.num_heads, -1, h * w)
100 |         v = v.reshape(b, self.num_heads, -1, h * w)
101 |         q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
102 |         attn = torch.softmax(torch.matmul(q, k.transpose(-2, -1).contiguous()) * self.temperature, dim=-1)
103 |         out = self.project_out(torch.matmul(attn, v).reshape(b, -1, h, w))
104 | 
105 |         # frequency branch: queries derived in the Fourier domain
106 | 
107 |         x_fft = fft.fftn(x, dim=(-2, -1)).real
108 |         x_fft1 = self.q1X1_1(x_fft)
109 |         x_fft2 = F.gelu(x_fft1)
110 |         x_fft3 = self.q1X1_2(x_fft2)
111 |         qf = fft.ifftn(x_fft3, dim=(-2, -1)).real
112 | 
113 | 
114 |         kf, vf = self.kv_conv(self.kv(out)).chunk(2, dim=1)
115 |         qf = qf.reshape(b, self.num_heads, -1, h * w)
116 |         kf = kf.reshape(b, self.num_heads, -1, h * w)
117 |         vf = vf.reshape(b, self.num_heads, -1, h * w)
118 |         qf, kf = F.normalize(qf, dim=-1), F.normalize(kf, dim=-1)
119 |         attnf = torch.softmax(torch.matmul(qf, kf.transpose(-2, -1).contiguous()) * self.temperature, dim=-1)
120 |         outf = self.project_outf(torch.matmul(attnf, vf).reshape(b, -1, h, w))
121 |         return outf
122 | 
123 | 
124 | class GDFN(nn.Module):
125 |     def __init__(self, channels, expansion_factor):
126 |         super(GDFN, self).__init__()
127 | 
128 |         hidden_channels = int(channels * expansion_factor)
129 |         self.project_in = nn.Conv2d(channels, hidden_channels * 2, kernel_size=1, bias=False)
130 |         self.conv = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, kernel_size=3, padding=1,
131 |                               groups=hidden_channels * 2, bias=False)
132 |         self.project_out = nn.Conv2d(hidden_channels, channels, kernel_size=1, bias=False)
133 | 
134 |     def forward(self, x):
135 |         x1, x2 = self.conv(self.project_in(x)).chunk(2, dim=1)
136 |         x = self.project_out(F.gelu(x1) * x2)
137 |         return x
138 | 
139 | 
140 | 
141 | 
142 | class TransformerBlock(nn.Module):
143 |     def __init__(self, channels, num_heads, expansion_factor):
144 |         super(TransformerBlock, self).__init__()
145 | 
146 |         self.norm1 = nn.LayerNorm(channels)
147 |         self.attn = MDTA(channels, num_heads)
148 |         self.norm2 = nn.LayerNorm(channels)
149 |         self.ffn = GDFN(channels, expansion_factor)
150 | 
151 |     def forward(self, x):
152 |         b, c, h, w = x.shape
153 |         x = x + self.attn(self.norm1(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1)
154 |                           .contiguous().reshape(b, c, h, w))
155 |         x = x + self.ffn(self.norm2(x.reshape(b, c, -1).transpose(-2, -1).contiguous()).transpose(-2, -1)
156 |                          .contiguous().reshape(b, c, h, w))
157 |         return x
158 | 
159 | 
160 | class DownSample(nn.Module):
161 |     def __init__(self, channels):
162 |         super(DownSample, self).__init__()
163 |         self.body = nn.Sequential(nn.Conv2d(channels, channels // 2, kernel_size=3, padding=1, bias=False),
164 |                                   nn.PixelUnshuffle(2))
165 | 
166 |     def forward(self, x):
167 |         return self.body(x)
168 | 
169 | 
170 | 
171 | class UpSample(nn.Module):
172 |     def __init__(self, channels, channel_red):
173 |         super(UpSample, self).__init__()
174 | 
175 |         self.amp_fuse = nn.Sequential(nn.Conv2d(channels, channels, 1, 1, 0), nn.LeakyReLU(0.1, inplace=False),
176 |                                       nn.Conv2d(channels, channels, 1, 1, 0))
177 |         self.pha_fuse = nn.Sequential(nn.Conv2d(channels, channels, 1, 1, 0), nn.LeakyReLU(0.1, inplace=False),
178 |                                       nn.Conv2d(channels, channels, 1, 1, 0))
179 |         if channel_red:
180 |             self.post = nn.Conv2d(channels, channels//2, 1, 1, 0)
181 | 
182 |         else:
183 |             self.post = nn.Conv2d(channels, channels, 1, 1, 0)
184 | 
185 | 
186 |     def forward(self, x):
187 |         N, C, H, W = x.shape
188 |         fft_x = torch.fft.fft2(x)
189 |         mag_x = torch.abs(fft_x)
190 |         pha_x = torch.angle(fft_x)
191 |         Mag = self.amp_fuse(mag_x)
192 |         Pha = self.pha_fuse(pha_x)
193 |         amp_fuse = torch.tile(Mag, (2, 2))
194 |         pha_fuse = torch.tile(Pha, (2, 2))
195 |         real = amp_fuse * torch.cos(pha_fuse)
196 |         imag = amp_fuse * torch.sin(pha_fuse)
197 |         out = torch.complex(real, imag)
198 |         output = torch.fft.ifft2(out)
199 |         output = torch.abs(output)
200 |         return self.post(output)
201 | 
202 | 
203 | class UpSample1(nn.Module):
204 |     def __init__(self, channels):
205 |         super(UpSample1, self).__init__()
206 |         self.body = nn.Sequential(nn.Conv2d(channels, channels * 2, kernel_size=3, padding=1, bias=False),
207 |                                   nn.PixelShuffle(2))
208 | 
209 |     def forward(self, x):
210 |         return self.body(x)
211 | 
212 | 
213 | class UpS(nn.Module):
214 |     def __init__(self, channels):
215 |         super(UpS, self).__init__()
216 |         self.Fups = UpSample(channels, True)
217 |         self.Sups = UpSample1(channels)
218 |         self.reduce = nn.Conv2d(channels, channels // 2, kernel_size=1, bias=False)
219 | 
220 |     def forward(self, x):
221 |         out = torch.cat([self.Fups(x), self.Sups(x)], dim=1)
222 |         # print(out.shape)
223 |         return self.reduce(out)
224 | 
225 | 
226 | class Model(nn.Module):
227 |     def __init__(self, num_blocks=[2, 3, 3, 4], num_heads=[1, 2, 4, 8], channels=[16, 32, 64, 128], num_refinement=4,
228 |                  expansion_factor=2.66, ch=[64, 32, 16, 64]):
229 |         super(Model, self).__init__()
230 | 
231 |         self.attention = nn.ModuleList([SFCA(num_ch) for num_ch in ch])
232 |         self.embed_conv_rgb = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, bias=False)
233 |         self.encoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(num_ch, num_ah, expansion_factor) for _ in range(num_tb)]) for num_tb, num_ah, num_ch in
234 |                                        zip(num_blocks, num_heads, channels)])
235 | 
236 |         self.down1 = DownSample(channels[0])
237 |         self.down2 = DownSample(channels[1])
238 |         self.down3 = DownSample(channels[2])
239 |         self.ups_1 = UpS(128)
240 |         self.ups_2 = UpS(64)
241 |         self.ups_3 = UpS(32)
242 |         self.ups_4 = UpS(3)
243 | 
244 |         self.ups1 = UpSample1(32)
245 |         self.reduces2 = nn.Conv2d(64, 32, kernel_size=1, bias=False)
246 |         self.reduces1 = nn.Conv2d(128, 64, kernel_size=1, bias=False)
247 | 
248 |         self.decoders = nn.ModuleList([nn.Sequential(*[TransformerBlock(channels[2], num_heads[2], expansion_factor)
249 |                                        for _ in range(num_blocks[2])])])
250 |         self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[1], expansion_factor)
251 |                                        for _ in range(num_blocks[1])]))
252 | 
253 |         self.decoders.append(nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor) for _ in range(num_blocks[0])]))
254 | 
255 |         self.refinement = nn.Sequential(*[TransformerBlock(channels[1], num_heads[0], expansion_factor)
256 |                                           for _ in range(num_refinement)])
257 |         self.output = nn.Conv2d(8, 3, kernel_size=3, padding=1, bias=False)
258 |         self.output1 = nn.Conv2d(16, 8, kernel_size=3, padding=1, bias=False)
259 | 
260 |         self.ups2 = UpSample1(16)
261 |         self.outputl = nn.Conv2d(32, 8, kernel_size=3, padding=1, bias=False)
262 | 
263 |     def forward(self, RGB_input):
264 |         ###-------Encoder for RGB-------###
265 |         fo_rgb = self.embed_conv_rgb(RGB_input)
266 |         out_enc_rgb1 = self.encoders[0](fo_rgb)
267 |         out_enc_rgb2 = self.encoders[1](self.down1(out_enc_rgb1))
268 |         # print(out_enc_rgb2.shape)
269 | 
270 |         out_enc_rgb3 = self.encoders[2](self.down2(out_enc_rgb2))
271 |         # print(out_enc_rgb3.shape)
272 |         out_enc_rgb4 = self.encoders[3](self.down3(out_enc_rgb3))
273 |         # print(out_enc_rgb4.shape)
274 | 
275 |         ###-------Decoder-------###
276 |         out_dec3 = self.decoders[0](self.reduces1(torch.cat([(self.ups_1(out_enc_rgb4)), out_enc_rgb3], dim=1)))
277 |         # print(out_dec3.shape)
278 |         out_dec2 = self.decoders[1](self.reduces2(torch.cat([self.ups_2(out_dec3), out_enc_rgb2], dim=1)))
279 |         # print(out_dec2.shape)
280 |         fd = self.decoders[2](torch.cat([self.ups_3(out_dec2), out_enc_rgb1], dim=1))
281 |         # print(fd.shape)
282 |         # print('lasst',fd_FP.shape)
283 |         fr = self.refinement(fd)
284 |         return self.output(self.outputl(fr))
285 | 
286 | 
287 | if __name__ == '__main__':
288 |     t = torch.randn(1, 3, 256, 256).cuda()
289 |     model = Model().cuda()
290 |     res = model(t)
291 |     print(res.shape)
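292 |     # Editor's sketch, not in the original script: a rough timing loop.
293 |     # Numbers are indicative only and assume a CUDA device, as above.
294 |     import time
295 |     torch.cuda.synchronize()
296 |     start = time.time()
297 |     for _ in range(10):
298 |         with torch.no_grad():
299 |             res = model(t)
300 |     torch.cuda.synchronize()
301 |     print(f'avg forward time: {(time.time() - start) / 10:.4f}s')
--------------------------------------------------------------------------------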