├── img
│   └── 1.png
├── data.py
├── ssim_psnr.py
├── README.md
├── Test_MP.py
├── utils.py
├── Test_MC.py
├── Train.py
├── dataset.py
├── PUIENet_MC.py
└── PUIENet_MP.py

/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenqifu/PUIE-Net/HEAD/img/1.png
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
from os.path import join
from torchvision.transforms import Compose, ToTensor, Resize
from dataset import DatasetFromFolderEval, DatasetFromFolder


def transform():
    return Compose([
        # Resize((512, 512)),
        ToTensor(),
    ])


def get_training_set(path, label, data, patch_size, data_augmentation):
    label_dir = join(path, label)
    data_dir = join(path, data)

    return DatasetFromFolder(label_dir, data_dir, patch_size, data_augmentation, transform=transform())


def get_eval_set(data_dir, label_dir):
    return DatasetFromFolderEval(data_dir, label_dir, transform=transform())
--------------------------------------------------------------------------------
/ssim_psnr.py:
--------------------------------------------------------------------------------
import argparse
import os
import cv2
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr


# settings
parser = argparse.ArgumentParser(description='Performance')
parser.add_argument('--input_dir', default='results/UIEBD/mc')
parser.add_argument('--reference_dir', default='dataset/new_UIEBD/test/label/O')

opt = parser.parse_args()
print(opt)


im_path = opt.input_dir
re_path = opt.reference_dir
avg_psnr = 0
avg_ssim = 0
n = 0

for filename in os.listdir(im_path):
    print(os.path.join(im_path, filename))
    n = n + 1
    im1 = cv2.imread(os.path.join(im_path, filename))  # enhanced result
    im2 = cv2.imread(os.path.join(re_path, filename))  # reference image

    (h, w, c) = im2.shape
    im1 = cv2.resize(im1, (w, h))  # match the reference size

    # peak_signal_noise_ratio expects (image_true, image_test)
    score_psnr = psnr(im2, im1)
    score_ssim = ssim(im1, im2, multichannel=True)

    avg_psnr += score_psnr
    avg_ssim += score_ssim

avg_psnr = avg_psnr / n
avg_ssim = avg_ssim / n
print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))
print("===> Avg. SSIM: {:.4f}".format(avg_ssim))
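
# Note: `multichannel=True` is deprecated on scikit-image >= 0.19 in favor of
# `channel_axis`; on those versions the equivalent call would be
#
#     score_ssim = ssim(im1, im2, channel_axis=2)
#
# The form used above matches the older environment listed in the README.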
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Uncertainty Inspired Underwater Image Enhancement (ECCV 2022) ([Paper](https://arxiv.org/pdf/2207.09689.pdf))
The PyTorch implementation of "Uncertainty Inspired Underwater Image Enhancement".

<div align="center"><img src="img/1.png"></div>

## Introduction
This project was developed with Ubuntu 16.04.5, Python 3.7, PyTorch 1.7.1, and one NVIDIA RTX 2080Ti GPU.

## Running

### Testing

Download the [pretrained model](https://drive.google.com/file/d/1rkGm0l826ybOk_RSJNSZwbKpJc_z2ZkU/view?usp=sharing).

Check the model and image paths in Test_MC.py and Test_MP.py, and then run:

```
python Test_MC.py
```
```
python Test_MP.py
```
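
All paths and the GPU id are exposed as argparse flags, so they can also be overridden on the command line instead of editing the scripts, and the PSNR/SSIM of the saved results can be computed with ssim_psnr.py. For example (the paths below simply restate the defaults):

```
python Test_MC.py --device cuda:0 --model weights/puie.pth --input_dir dataset/new_UIEBD/test/image --output results/UIEBD
python ssim_psnr.py --input_dir results/UIEBD/mc --reference_dir dataset/new_UIEBD/test/label/O
```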

### Training

To train the model, you need to prepare our [dataset](https://drive.google.com/file/d/1YXdyNT9ac6CCpQTNKP7SnKtlRyugauvh/view?usp=sharing).

Check the dataset path in Train.py, and then run:
```
python Train.py
```

## Citation

If you find PUIE-Net useful in your research, please cite our paper:

```
@inproceedings{Fu_2022,
  title={Uncertainty Inspired Underwater Image Enhancement},
  author={Fu, Zhenqi and Wang, Wu and Huang, Yue and Ding, Xinghao and Ma, Kai-Kuang},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022},
  pages={465--482},
}
```
--------------------------------------------------------------------------------
/Test_MP.py:
--------------------------------------------------------------------------------
import os
import torch
import cv2
import time
import argparse
from torch.utils.data import DataLoader
from PUIENet_MP import mynet
from data import get_eval_set


# settings
parser = argparse.ArgumentParser(description='PyTorch PUIE-Net')
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
parser.add_argument('--gpu_mode', type=bool, default=True)
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--device', type=str, default='cuda:2')
parser.add_argument('--input_dir', type=str, default='dataset/new_UIEBD/test/image')
parser.add_argument('--output', default='results/UIEBD', help='location to save the enhanced images')
parser.add_argument('--reference_out', type=str, default='mp')
parser.add_argument('--model', default='weights/puie.pth', help='pretrained base model')


opt = parser.parse_args()
print(opt)
device = torch.device(opt.device)
cuda = opt.gpu_mode
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please set --gpu_mode to False to run on the CPU")

# torch.manual_seed(opt.seed)
# if cuda:
#     torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
test_set = get_eval_set(opt.input_dir, opt.input_dir)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building model')

model = mynet(opt)

model.load_state_dict(torch.load(opt.model, map_location=lambda storage, loc: storage))
print('Pre-trained model is loaded.')

if cuda:
    model = model.cuda(device)


def eval():
    model.eval()
    for batch in testing_data_loader:
        with torch.no_grad():
            input, _, name = batch[0], batch[1], batch[2]
            if cuda:
                input = input.cuda(device)

            model.forward(input, input, training=False)
            t0 = time.time()
            # MP inference is deterministic: PUIENet_MP sets the latent code
            # to the distribution mean, so a single forward pass is enough
            prediction = model.sample(testing=True)
            t1 = time.time()
            print("===> Processing: %s || Timer: %.4f sec." % (name[0], (t1 - t0)))
            save_img_2(prediction.cpu().data, name[0], opt.reference_out)


def save_img_2(img, img_name, out):
    save_img = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)
    # save img
    save_dir = os.path.join(opt.output, out)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    name_list = img_name.split('.', 1)
    save_fn = save_dir + '/' + name_list[0] + '.' + name_list[1]
    # the network output is RGB; swap channels to BGR for cv2.imwrite
    cv2.imwrite(save_fn, cv2.cvtColor(save_img * 255, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])


eval()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent


class SELayer(nn.Module):
    """Squeeze-and-excitation channel attention."""
    def __init__(self, num_filter):
        super(SELayer, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_double = nn.Sequential(
            nn.Conv2d(num_filter, num_filter // 16, 1, 1, 0, bias=True),
            nn.LeakyReLU(),
            nn.Conv2d(num_filter // 16, num_filter, 1, 1, 0, bias=True),
            nn.Sigmoid())

    def forward(self, x):
        mask = self.global_pool(x)
        mask = self.conv_double(mask)
        x = x * mask
        return x


class ResBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions with SE attention and a skip connection."""
    def __init__(self, num_filter):
        super(ResBlock, self).__init__()
        body = []
        for i in range(2):
            body.append(nn.ReflectionPad2d(1))
            body.append(nn.Conv2d(num_filter, num_filter, kernel_size=3, padding=0))
            if i == 0:
                body.append(nn.LeakyReLU())
        body.append(SELayer(num_filter))
        self.body = nn.Sequential(*body)

    def forward(self, x):
        res = self.body(x)
        x = res + x
        return x


class Up(nn.Module):
    """Parameter-free 2x bilinear upsampling."""
    def __init__(self):
        super(Up, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        )

    def forward(self, x):
        x = self.up(x)
        return x


class ConvBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions, each followed by LeakyReLU."""
    def __init__(self, ch_in, ch_out):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=0),
            nn.LeakyReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=0),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class Compute_z(nn.Module):
    """Map a 128-channel feature map to two diagonal Gaussian latent
    distributions: one from the spatial mean of the features and one from
    their spatial standard deviation."""
    def __init__(self, latent_dim):
        super(Compute_z, self).__init__()
        self.latent_dim = latent_dim
        self.u_conv_layer = nn.Conv2d(128, 2 * self.latent_dim, kernel_size=1, padding=0)
        self.s_conv_layer = nn.Conv2d(128, 2 * self.latent_dim, kernel_size=1, padding=0)

    def forward(self, x):
        u_encoding = torch.mean(x, dim=2, keepdim=True)
        u_encoding = torch.mean(u_encoding, dim=3, keepdim=True)
        u_mu_log_sigma = self.u_conv_layer(u_encoding)
        u_mu_log_sigma = torch.squeeze(u_mu_log_sigma, dim=2)
        u_mu_log_sigma = torch.squeeze(u_mu_log_sigma, dim=2)
        u_mu = u_mu_log_sigma[:, :self.latent_dim]
        u_log_sigma = u_mu_log_sigma[:, self.latent_dim:]
        u_dist = Independent(Normal(loc=u_mu, scale=torch.exp(u_log_sigma)), 1)

        s_encoding = torch.std(x, dim=2, keepdim=True)
        s_encoding = torch.std(s_encoding, dim=3, keepdim=True)
        s_mu_log_sigma = self.s_conv_layer(s_encoding)
        s_mu_log_sigma = torch.squeeze(s_mu_log_sigma, dim=2)
        s_mu_log_sigma = torch.squeeze(s_mu_log_sigma, dim=2)
        s_mu = s_mu_log_sigma[:, :self.latent_dim]
        s_log_sigma = s_mu_log_sigma[:, self.latent_dim:]
        s_dist = Independent(Normal(loc=s_mu, scale=torch.exp(s_log_sigma)), 1)
        return u_dist, s_dist, u_mu, s_mu, torch.exp(u_log_sigma), torch.exp(s_log_sigma)
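
# A minimal shape check (a sketch, not part of the training/testing pipeline):
# Compute_z expects a 128-channel feature map, as produced by the decoders in
# PUIENet_MC.py / PUIENet_MP.py.
if __name__ == "__main__":
    x = torch.randn(2, 128, 64, 64)
    u_dist, s_dist, u_mu, s_mu, u_sigma, s_sigma = Compute_z(20)(x)
    print(u_dist.rsample().shape)       # torch.Size([2, 20])
    print(s_mu.shape, s_sigma.shape)    # torch.Size([2, 20]) torch.Size([2, 20])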
--------------------------------------------------------------------------------
/Test_MC.py:
--------------------------------------------------------------------------------
import os
import torch
import cv2
import time
import argparse
from torch.utils.data import DataLoader
from PUIENet_MC import mynet
from data import get_eval_set


# settings
parser = argparse.ArgumentParser(description='PyTorch PUIE-Net')
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
parser.add_argument('--gpu_mode', type=bool, default=True)
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--device', type=str, default='cuda:2')
parser.add_argument('--input_dir', type=str, default='dataset/new_UIEBD/test/image')
parser.add_argument('--output', default='results/UIEBD', help='location to save the enhanced images')
parser.add_argument('--sample_out', type=str, default='sample')
parser.add_argument('--reference_out', type=str, default='mc')
parser.add_argument('--model', default='weights/puie.pth', help='pretrained base model')


opt = parser.parse_args()
print(opt)
device = torch.device(opt.device)
cuda = opt.gpu_mode
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please set --gpu_mode to False to run on the CPU")

# torch.manual_seed(opt.seed)
# if cuda:
#     torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
test_set = get_eval_set(opt.input_dir, opt.input_dir)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building model')

model = mynet(opt)

model.load_state_dict(torch.load(opt.model, map_location=lambda storage, loc: storage))
print('Pre-trained model is loaded.')

if cuda:
    model = model.cuda(device)


def eval():
    model.eval()
    for batch in testing_data_loader:
        with torch.no_grad():
            input, _, name = batch[0], batch[1], batch[2]
            if cuda:
                input = input.cuda(device)

            model.forward(input, input, training=False)
            avg_pre = 0
            # draw 20 latent samples; save each one and accumulate their mean
            for i in range(20):
                t0 = time.time()
                prediction = model.sample(testing=True)
                t1 = time.time()
                avg_pre = avg_pre + prediction / 20
                save_img_1(prediction.cpu().data, name[0], i, opt.sample_out)

            print("===> Processing: %s || Timer: %.4f sec." % (name[0], (t1 - t0)))
            save_img_2(avg_pre.cpu().data, name[0], opt.reference_out)


def save_img_1(img, img_name, i, out):
    save_img = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)
    # save img
    save_dir = os.path.join(opt.output, out)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    name_list = img_name.split('.', 1)
    save_fn = save_dir + '/' + name_list[0] + '_' + str(i) + '.' + name_list[1]
    # the network output is RGB; swap channels to BGR for cv2.imwrite
    cv2.imwrite(save_fn, cv2.cvtColor(save_img * 255, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])


def save_img_2(img, img_name, out):
    save_img = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)
    # save img
    save_dir = os.path.join(opt.output, out)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    name_list = img_name.split('.', 1)
    save_fn = save_dir + '/' + name_list[0] + '.' + name_list[1]
    # the network output is RGB; swap channels to BGR for cv2.imwrite
    cv2.imwrite(save_fn, cv2.cvtColor(save_img * 255, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])


eval()
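
# Averaging the 20 stochastic outputs above is a Monte Carlo estimate of the
# predictive mean,
#
#     y_hat = (1/S) * sum_{s=1..S} f(x, z_s),  z_s ~ p(z | x),  S = 20,
#
# which is the "mc" result evaluated by ssim_psnr.py. The per-sample images
# written by save_img_1 can likewise be used to estimate a per-pixel variance,
#
#     var = (1/S) * sum_{s=1..S} (f(x, z_s) - y_hat) ** 2,
#
# as a sketch of how the saved samples relate, not extra functionality.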
--------------------------------------------------------------------------------
/Train.py:
--------------------------------------------------------------------------------
import torch
import socket
import time
import argparse
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler as lrs
from torch.utils.data import DataLoader
from PUIENet_MC import mynet
from data import get_training_set


# Training settings
parser = argparse.ArgumentParser(description='PyTorch PUIE-Net')
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--batchSize', type=int, default=4, help='training batch size')
parser.add_argument('--nEpochs', type=int, default=500, help='number of epochs to train for')
parser.add_argument('--snapshots', type=int, default=2, help='snapshot interval (epochs)')
parser.add_argument('--start_iter', type=int, default=1, help='starting epoch')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate. Default=1e-4')
parser.add_argument('--data_dir', type=str, default='dataset/new_UIEBD/train')
parser.add_argument('--label_train_dataset', type=str, default='label')
parser.add_argument('--data_train_dataset', type=str, default='image')
parser.add_argument('--patch_size', type=int, default=256, help='size of cropped image')
parser.add_argument('--save_folder', default='weights/', help='location to save checkpoint models')
parser.add_argument('--gpu_mode', type=bool, default=True)
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
parser.add_argument('--decay', type=int, default=10000, help='epochs between learning-rate decay steps')
parser.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--data_augmentation', type=bool, default=True)


opt = parser.parse_args()
device = torch.device(opt.device)
hostname = str(socket.gethostname())
cudnn.benchmark = True
print(opt)


def train(epoch):
    epoch_loss = 0
    model.train()
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = batch[0], batch[1]
        if cuda:
            input = input.to(device)
            target = target.to(device)

        t0 = time.time()
        model.forward(input, target, training=True)
        loss = model.elbo(target)
        optimizer.zero_grad()
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
        t1 = time.time()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Learning rate: lr={} || Timer: {:.4f} sec.".format(
            epoch, iteration, len(training_data_loader), loss.item(), optimizer.param_groups[0]['lr'], (t1 - t0)))
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))


def checkpoint(epoch):
    model_out_path = opt.save_folder + "epoch_{}.pth".format(epoch)
    torch.save(model.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))


cuda = opt.gpu_mode
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please set --gpu_mode to False to run on the CPU")

torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')

train_set = get_training_set(opt.data_dir, opt.label_train_dataset, opt.data_train_dataset, opt.patch_size, opt.data_augmentation)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)

model = mynet(opt)

# print('---------- Networks architecture -------------')
# print_network(model)
# print('----------------------------------------------')

optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8)

# With the default --decay of 10000 no milestone falls within 500 epochs,
# i.e. the learning rate stays constant unless --decay is lowered.
milestones = []
for i in range(1, opt.nEpochs + 1):
    if i % opt.decay == 0:
        milestones.append(i)

scheduler = lrs.MultiStepLR(optimizer, milestones, opt.gamma)

for epoch in range(opt.start_iter, opt.nEpochs + 1):

    train(epoch)
    scheduler.step()

    if (epoch + 1) % opt.snapshots == 0:
        checkpoint(epoch)
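
# To resume from a saved snapshot (a sketch; 'epoch_99.pth' is an illustrative
# name), load the weights before the training loop above and pass a matching
# --start_iter:
#
#     model.load_state_dict(torch.load('weights/epoch_99.pth', map_location=device))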
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import random
import torch.utils.data as data
from os import listdir
from os.path import join
from PIL import Image, ImageOps


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".bmp"])


def load_img(filepath):
    img = Image.open(filepath).convert('RGB')
    return img


def rescale_img(img_in, scale):
    size_in = img_in.size
    new_size_in = tuple([int(x * scale) for x in size_in])
    img_in = img_in.resize(new_size_in, resample=Image.BICUBIC)
    return img_in


def get_patch(img_in, img_tar, patch_size, scale=1, ix=-1, iy=-1):
    # note: PIL's Image.size is (width, height); both images are resized to
    # 512 x 512 before this is called, so the swapped naming below is harmless
    (ih, iw) = img_in.size

    patch_mult = scale
    tp = patch_mult * patch_size
    ip = tp // scale

    if ix == -1:
        ix = random.randrange(0, iw - ip + 1)
    if iy == -1:
        iy = random.randrange(0, ih - ip + 1)

    (tx, ty) = (scale * ix, scale * iy)

    img_in = img_in.crop((iy, ix, iy + ip, ix + ip))
    img_tar = img_tar.crop((ty, tx, ty + tp, tx + tp))

    info_patch = {
        'ix': ix, 'iy': iy, 'ip': ip, 'tx': tx, 'ty': ty, 'tp': tp}

    return img_in, img_tar, info_patch


def augment(img_in, img_tar, flip_h=True, rot=True):
    # ImageOps.flip is a top-bottom flip and ImageOps.mirror a left-right
    # flip; the info_aug keys keep the original (loose) naming
    info_aug = {'flip_h': False, 'flip_v': False, 'trans': False}

    if random.random() < 0.5 and flip_h:
        img_in = ImageOps.flip(img_in)
        img_tar = ImageOps.flip(img_tar)
        info_aug['flip_h'] = True

    if rot:
        if random.random() < 0.5:
            img_in = ImageOps.mirror(img_in)
            img_tar = ImageOps.mirror(img_tar)
            info_aug['flip_v'] = True
        if random.random() < 0.5:
            img_in = img_in.rotate(180)
            img_tar = img_tar.rotate(180)
            info_aug['trans'] = True

    return img_in, img_tar, info_aug


class DatasetFromFolder(data.Dataset):
    def __init__(self, label_dir, data_dir, patch_size, data_augmentation, transform=None):
        super(DatasetFromFolder, self).__init__()
        self.label_path = label_dir
        data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
        data_filenames.sort()
        self.data_filenames = data_filenames
        self.patch_size = patch_size
        self.transform = transform
        self.data_augmentation = data_augmentation

    def __getitem__(self, index):
        _, file = os.path.split(self.data_filenames[index])

        # randomly pick one of the four reference sub-folders (O, S, C, G)
        # as the label for this sample
        k = random.randint(0, 3)
        label_filenames = join(self.label_path, ['O', 'S', 'C', 'G'][k], file)

        target = load_img(label_filenames)
        input = load_img(self.data_filenames[index])

        input = input.resize((512, 512), resample=Image.BICUBIC)
        target = target.resize((512, 512), resample=Image.BICUBIC)
        input, target, _ = get_patch(input, target, self.patch_size)

        if self.data_augmentation:
            input, target, _ = augment(input, target)

        if self.transform:
            input = self.transform(input)
            target = self.transform(target)

        return input, target, file

    def __len__(self):
        return len(self.data_filenames)

class DatasetFromFolderEval(data.Dataset):
    def __init__(self, data_dir, label_dir, transform=None):
        super(DatasetFromFolderEval, self).__init__()
        data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
        data_filenames.sort()
        self.data_filenames = data_filenames

        label_filenames = [join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)]
        label_filenames.sort()
        self.label_filenames = label_filenames

        self.transform = transform

    def __getitem__(self, index):
        input = load_img(self.data_filenames[index])
        label = load_img(self.label_filenames[index])
        _, file = os.path.split(self.data_filenames[index])

        # the network downsamples three times, so both spatial dimensions
        # must be multiples of 8
        (ih, iw) = input.size
        dh = ih % 8
        dw = iw % 8
        new_h, new_w = ih - dh, iw - dw

        input = input.resize((new_h, new_w))
        label = label.resize((new_h, new_w))

        if self.transform:
            input = self.transform(input)
            label = self.transform(label)

        return input, label, file

    def __len__(self):
        return len(self.data_filenames)
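
# A minimal usage sketch (assumes the dataset from the README has been
# unpacked to the default paths used in Train.py):
if __name__ == "__main__":
    from data import get_training_set
    train_set = get_training_set('dataset/new_UIEBD/train', 'label', 'image', 256, True)
    inp, tgt, name = train_set[0]
    print(inp.shape, tgt.shape, name)  # torch.Size([3, 256, 256]) torch.Size([3, 256, 256]) <filename>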
--------------------------------------------------------------------------------
/PUIENet_MC.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models.vgg import vgg16
from torch.distributions import kl
from utils import ResBlock, ConvBlock, Up, Compute_z


class Encoder(nn.Module):
    def __init__(self, ch):
        super(Encoder, self).__init__()
        self.conv1 = ConvBlock(ch_in=ch, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=64)
        self.conv3 = ConvBlock(ch_in=64, ch_out=64)
        self.conv4 = ResBlock(64)
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.pool1(x1)
        x2 = self.conv2(x2)
        x3 = self.pool2(x2)
        x3 = self.conv3(x3)
        x4 = self.pool3(x3)
        x4 = self.conv4(x4)
        return x1, x2, x3, x4


class Decoder(nn.Module):
    def __init__(self, device):
        super(Decoder, self).__init__()
        self.device = device
        # the prior branch sees only the input; the posterior branch also
        # sees the reference image (hence 6 input channels)
        self.pr_encoder = Encoder(3)
        self.po_encoder = Encoder(6)

        self.pr_conv = ResBlock(64)
        self.po_conv = ResBlock(64)

        self.pr_Up3 = Up()
        self.pr_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
        self.pr_Up2 = Up()
        self.pr_UpConv2 = ConvBlock(ch_in=128, ch_out=64)
        self.pr_Up1 = Up()
        self.pr_UpConv1 = ConvBlock(ch_in=128, ch_out=64)

        self.po_Up3 = Up()
        self.po_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
        self.po_Up2 = Up()
        self.po_UpConv2 = ConvBlock(ch_in=128, ch_out=64)

        out_conv = []
        out_conv.append(ResBlock(64))
        out_conv.append(ResBlock(64))
        out_conv.append(nn.Conv2d(64, 3, kernel_size=1, padding=0))
        self.out_conv = nn.Sequential(*out_conv)

        z = 20  # latent dimension

        self.compute_z_pr = Compute_z(z)
        self.compute_z_po = Compute_z(z)

        self.conv_u = nn.Conv2d(z, 128, kernel_size=1, padding=0)
        self.conv_s = nn.Conv2d(z, 128, kernel_size=1, padding=0)

        self.insnorm = nn.InstanceNorm2d(128)
        self.sigmoid = nn.Sigmoid()

    def forward(self, Input, Target, training=True):

        pr_x1, pr_x2, pr_x3, pr_x4 = self.pr_encoder.forward(Input)

        if training:
            po_x1, po_x2, po_x3, po_x4 = self.po_encoder.forward(torch.cat((Input, Target), dim=1))
            # x4->x3
            pr_x4 = self.pr_conv(pr_x4)
            po_x4 = self.po_conv(po_x4)
            pr_d3 = self.pr_Up3(pr_x4)
            po_d3 = self.po_Up3(po_x4)
            # x3->x2
            pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
            po_d3 = torch.cat((po_x3, po_d3), dim=1)
            pr_d3 = self.pr_UpConv3(pr_d3)
            po_d3 = self.po_UpConv3(po_d3)
            pr_d2 = self.pr_Up2(pr_d3)
            po_d2 = self.po_Up2(po_d3)
            # x2->x1
            pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
            po_d2 = torch.cat((po_x2, po_d2), dim=1)
            pr_d2 = self.pr_UpConv2(pr_d2)
            po_d2 = self.po_UpConv2(po_d2)
            pr_d1 = self.pr_Up1(pr_d2)
            po_d1 = self.pr_Up1(po_d2)  # Up() is parameter-free, so reusing pr_Up1 here is harmless
            # cat
            pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
            po_d1 = torch.cat((po_x1, po_d1), dim=1)
            # x1->dis
            pr_u_dist, pr_s_dist, _, _, _, _ = self.compute_z_pr(pr_d1)
            po_u_dist, po_s_dist, _, _, _, _ = self.compute_z_po(po_d1)
            po_latent_u = po_u_dist.rsample()
            po_latent_s = po_s_dist.rsample()
            po_latent_u = torch.unsqueeze(po_latent_u, -1)
            po_latent_u = torch.unsqueeze(po_latent_u, -1)
            po_latent_s = torch.unsqueeze(po_latent_s, -1)
            po_latent_s = torch.unsqueeze(po_latent_s, -1)
            po_u = self.conv_u(po_latent_u)
            po_s = self.conv_s(po_latent_s)
            # AdaIN-style modulation: scale and shift the normalized prior
            # features with statistics predicted from the latent code
            pr_d1 = self.insnorm(pr_d1) * torch.abs(po_s) + po_u
            # x1->out
            pr_d1 = self.pr_UpConv1(pr_d1)
            out = self.out_conv(pr_d1)

            return out, pr_u_dist, pr_s_dist, po_u_dist, po_s_dist

        else:
            # x4->x3
            pr_x4 = self.pr_conv(pr_x4)
            pr_d3 = self.pr_Up3(pr_x4)
            # x3->x2
            pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
            pr_d3 = self.pr_UpConv3(pr_d3)
            pr_d2 = self.pr_Up2(pr_d3)
            # x2->x1
            pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
            pr_d2 = self.pr_UpConv2(pr_d2)
            pr_d1 = self.pr_Up1(pr_d2)
            # cat
            pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
            # x1->dis
            pr_u_dist, pr_s_dist, _, _, _, _ = self.compute_z_pr(pr_d1)
            # Monte Carlo (MC) inference: draw a random latent from the prior,
            # so repeated calls produce different plausible enhancements
            pr_latent_u = pr_u_dist.rsample()
            pr_latent_s = pr_s_dist.rsample()
            pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
            pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
            pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
            pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
            pr_u = self.conv_u(pr_latent_u)
            pr_s = self.conv_s(pr_latent_s)
            pr_d1 = self.insnorm(pr_d1) * torch.abs(pr_s) + pr_u
            # x1->out
            pr_d1 = self.pr_UpConv1(pr_d1)
            out = self.out_conv(pr_d1)
            return out


class PerceptionLoss(nn.Module):
    def __init__(self):
        super(PerceptionLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()

    def forward(self, out_images, target_images):
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        return perception_loss


class mynet(nn.Module):
    def __init__(self, opt):
        super(mynet, self).__init__()
        self.device = torch.device(opt.device)
        self.decoder = Decoder(device=self.device).to(self.device)
        self.criterion = nn.MSELoss().to(self.device)
        self.VGG16 = PerceptionLoss().to(self.device)

    def forward(self, Input, label, training=True):
        self.Input = Input
        self.label = label
        if training:
            self.out, self.pr_u, self.pr_s, self.po_u, self.po_s = self.decoder.forward(Input, label)

    def sample(self, testing=False):
        # at test time the decoder runs here, once per call
        if testing:
            self.out = self.decoder.forward(self.Input, self.label, training=False)
            return self.out

    def kl_divergence(self, analytic=True):
        # analytic KL between the posterior and prior latent distributions
        if analytic:
            kl_div_u = torch.mean(kl.kl_divergence(self.po_u, self.pr_u))
            kl_div_s = torch.mean(kl.kl_divergence(self.po_s, self.pr_s))
            return kl_div_u + kl_div_s

    def elbo(self, target, analytic_kl=True):
        # training objective: reconstruction + perceptual + KL terms
        self.kl_loss = self.kl_divergence(analytic=analytic_kl)
        self.reconstruction_loss = self.criterion(self.out, target)
        self.vgg16_loss = self.VGG16(self.out, target)
        return self.reconstruction_loss + self.vgg16_loss + self.kl_loss
--------------------------------------------------------------------------------
/PUIENet_MP.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models.vgg import vgg16
from torch.distributions import kl
from utils import ResBlock, ConvBlock, Up, Compute_z


class Encoder(nn.Module):
    def __init__(self, ch):
        super(Encoder, self).__init__()
        self.conv1 = ConvBlock(ch_in=ch, ch_out=64)
        self.conv2 = ConvBlock(ch_in=64, ch_out=64)
        self.conv3 = ConvBlock(ch_in=64, ch_out=64)
        self.conv4 = ResBlock(64)
        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.pool1(x1)
        x2 = self.conv2(x2)
        x3 = self.pool2(x2)
        x3 = self.conv3(x3)
        x4 = self.pool3(x3)
        x4 = self.conv4(x4)
        return x1, x2, x3, x4


class Decoder(nn.Module):
    def __init__(self, device):
        super(Decoder, self).__init__()
        self.device = device
        # the prior branch sees only the input; the posterior branch also
        # sees the reference image (hence 6 input channels)
        self.pr_encoder = Encoder(3)
        self.po_encoder = Encoder(6)

        self.pr_conv = ResBlock(64)
        self.po_conv = ResBlock(64)

        self.pr_Up3 = Up()
        self.pr_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
        self.pr_Up2 = Up()
        self.pr_UpConv2 = ConvBlock(ch_in=128, ch_out=64)
        self.pr_Up1 = Up()
        self.pr_UpConv1 = ConvBlock(ch_in=128, ch_out=64)

        self.po_Up3 = Up()
        self.po_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
        self.po_Up2 = Up()
        self.po_UpConv2 = ConvBlock(ch_in=128, ch_out=64)

        out_conv = []
        out_conv.append(ResBlock(64))
        out_conv.append(ResBlock(64))
        out_conv.append(nn.Conv2d(64, 3, kernel_size=1, padding=0))
        self.out_conv = nn.Sequential(*out_conv)

        z = 20  # latent dimension

        self.compute_z_pr = Compute_z(z)
        self.compute_z_po = Compute_z(z)

        self.conv_u = nn.Conv2d(z, 128, kernel_size=1, padding=0)
        self.conv_s = nn.Conv2d(z, 128, kernel_size=1, padding=0)

        self.insnorm = nn.InstanceNorm2d(128)
        self.sigmoid = nn.Sigmoid()

    def forward(self, Input, Target, training=True):

        pr_x1, pr_x2, pr_x3, pr_x4 = self.pr_encoder.forward(Input)

        if training:
            po_x1, po_x2, po_x3, po_x4 = self.po_encoder.forward(torch.cat((Input, Target), dim=1))
            # x4->x3
            pr_x4 = self.pr_conv(pr_x4)
            po_x4 = self.po_conv(po_x4)
            pr_d3 = self.pr_Up3(pr_x4)
            po_d3 = self.po_Up3(po_x4)
            # x3->x2
            pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
            po_d3 = torch.cat((po_x3, po_d3), dim=1)
            pr_d3 = self.pr_UpConv3(pr_d3)
            po_d3 = self.po_UpConv3(po_d3)
            pr_d2 = self.pr_Up2(pr_d3)
            po_d2 = self.po_Up2(po_d3)
            # x2->x1
            pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
            po_d2 = torch.cat((po_x2, po_d2), dim=1)
            pr_d2 = self.pr_UpConv2(pr_d2)
            po_d2 = self.po_UpConv2(po_d2)
            pr_d1 = self.pr_Up1(pr_d2)
            po_d1 = self.pr_Up1(po_d2)  # Up() is parameter-free, so reusing pr_Up1 here is harmless
            # cat
            pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
            po_d1 = torch.cat((po_x1, po_d1), dim=1)
            # x1->dis
            pr_u_dist, pr_s_dist, _, _, _, _ = self.compute_z_pr(pr_d1)
            po_u_dist, po_s_dist, _, _, _, _ = self.compute_z_po(po_d1)
            po_latent_u = po_u_dist.rsample()
            po_latent_s = po_s_dist.rsample()
            po_latent_u = torch.unsqueeze(po_latent_u, -1)
            po_latent_u = torch.unsqueeze(po_latent_u, -1)
            po_latent_s = torch.unsqueeze(po_latent_s, -1)
            po_latent_s = torch.unsqueeze(po_latent_s, -1)
            po_u = self.conv_u(po_latent_u)
            po_s = self.conv_s(po_latent_s)
            # AdaIN-style modulation: scale and shift the normalized prior
            # features with statistics predicted from the latent code
            pr_d1 = self.insnorm(pr_d1) * torch.abs(po_s) + po_u
            # x1->out
            pr_d1 = self.pr_UpConv1(pr_d1)
            out = self.out_conv(pr_d1)

            return out, pr_u_dist, pr_s_dist, po_u_dist, po_s_dist

        else:
            # x4->x3
            pr_x4 = self.pr_conv(pr_x4)
            pr_d3 = self.pr_Up3(pr_x4)
            # x3->x2
            pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
            pr_d3 = self.pr_UpConv3(pr_d3)
            pr_d2 = self.pr_Up2(pr_d3)
            # x2->x1
            pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
            pr_d2 = self.pr_UpConv2(pr_d2)
            pr_d1 = self.pr_Up1(pr_d2)
            # cat
            pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
            # x1->dis
            pr_u_dist, pr_s_dist, u_mu, s_mu, u_sigma, s_sigma = self.compute_z_pr(pr_d1)

            # Most-probable (MP) inference: use the distribution means rather
            # than a random sample (the "* 0" zeroes the noise term), so the
            # output is deterministic
            pr_latent_u = u_mu + u_sigma * 0
            pr_latent_s = s_mu + s_sigma * 0

            pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
            pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
            pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
            pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
            pr_u = self.conv_u(pr_latent_u)
            pr_s = self.conv_s(pr_latent_s)
            pr_d1 = self.insnorm(pr_d1) * torch.abs(pr_s) + pr_u
            # x1->out
            pr_d1 = self.pr_UpConv1(pr_d1)
            out = self.out_conv(pr_d1)
            return out


class PerceptionLoss(nn.Module):
    def __init__(self):
        super(PerceptionLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()

    def forward(self, out_images, target_images):
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        return perception_loss


class mynet(nn.Module):
    def __init__(self, opt):
        super(mynet, self).__init__()
        self.device = torch.device(opt.device)
        self.decoder = Decoder(device=self.device).to(self.device)
        self.criterion = nn.MSELoss().to(self.device)
        self.VGG16 = PerceptionLoss().to(self.device)

    def forward(self, Input, label, training=True):
        self.Input = Input
        self.label = label
        if training:
            self.out, self.pr_u, self.pr_s, self.po_u, self.po_s = self.decoder.forward(Input, label)

    def sample(self, testing=False):
        # at test time the decoder runs here, once per call
        if testing:
            self.out = self.decoder.forward(self.Input, self.label, training=False)
            return self.out

    def kl_divergence(self, analytic=True):
        # analytic KL between the posterior and prior latent distributions
        if analytic:
            kl_div_u = torch.mean(kl.kl_divergence(self.po_u, self.pr_u))
            kl_div_s = torch.mean(kl.kl_divergence(self.po_s, self.pr_s))
            return kl_div_u + kl_div_s

    def elbo(self, target, analytic_kl=True):
        # training objective: reconstruction + perceptual + KL terms
        self.kl_loss = self.kl_divergence(analytic=analytic_kl)
        self.reconstruction_loss = self.criterion(self.out, target)
        self.vgg16_loss = self.VGG16(self.out, target)
        return self.reconstruction_loss + self.vgg16_loss + self.kl_loss
--------------------------------------------------------------------------------
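
The only functional difference between PUIENet_MC.py and PUIENet_MP.py is the choice of the test-time latent code; a minimal side-by-side of the two strategies, using the values returned by utils.Compute_z:

    # MC (PUIENet_MC.py): a fresh random draw per call -> stochastic outputs
    pr_latent_u = pr_u_dist.rsample()
    # MP (PUIENet_MP.py): the distribution mean -> one deterministic output
    pr_latent_u = u_mu + u_sigma * 0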