├── img
│   └── 1.png
├── data.py
├── ssim_psnr.py
├── README.md
├── Test_MP.py
├── utils.py
├── Test_MC.py
├── Train.py
├── dataset.py
├── PUIENet_MC.py
└── PUIENet_MP.py
/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhenqifu/PUIE-Net/HEAD/img/1.png
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | from torchvision.transforms import Compose, ToTensor, Resize
3 | from dataset import DatasetFromFolderEval, DatasetFromFolder
4 |
5 |
6 | def transform():
7 | return Compose([
8 | # Resize((512, 512)),
9 | ToTensor(),
10 | ])
11 |
12 |
13 | def get_training_set(path, label, data, patch_size, data_augmentation):
14 | label_dir = join(path, label)
15 | data_dir = join(path, data)
16 |
17 | return DatasetFromFolder(label_dir, data_dir, patch_size, data_augmentation, transform=transform())
18 |
19 |
20 | def get_eval_set(data_dir, label_dir):
21 | return DatasetFromFolderEval(data_dir, label_dir, transform=transform())
22 |
--------------------------------------------------------------------------------
/ssim_psnr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import cv2
4 | from skimage.metrics import structural_similarity as ssim
5 | from skimage.metrics import peak_signal_noise_ratio as psnr
6 |
7 |
8 | # settings
9 | parser = argparse.ArgumentParser(description='Performance')
10 | parser.add_argument('--input_dir', default='results/UIEBD/mc')
11 | parser.add_argument('--reference_dir', default='dataset/new_UIEBD/test/label/O')
12 |
13 | opt = parser.parse_args()
14 | print(opt)
15 |
16 |
17 | im_path = opt.input_dir
18 | re_path = opt.reference_dir
19 | avg_psnr = 0
20 | avg_ssim = 0
21 | n = 0
22 |
23 | for filename in os.listdir(im_path):
24 | print(im_path + '/' + filename)
25 | n = n + 1
26 |     im1 = cv2.imread(im_path + '/' + filename)  # enhanced result
27 |     im2 = cv2.imread(re_path + '/' + filename)  # reference image
28 | 
29 |     (h, w, c) = im2.shape
30 |     im1 = cv2.resize(im1, (w, h))  # resize the result to the reference size
31 | 
32 |     score_psnr = psnr(im1, im2)
33 |     score_ssim = ssim(im1, im2, multichannel=True)  # for skimage >= 0.19, use channel_axis=-1 instead
34 |
35 | avg_psnr += score_psnr
36 | avg_ssim += score_ssim
37 |
38 | avg_psnr = avg_psnr / n
39 | avg_ssim = avg_ssim / n
40 | print("===> Avg.PSNR: {:.4f} dB ".format(avg_psnr))
41 | print("===> Avg.SSIM: {:.4f} ".format(avg_ssim))
42 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Uncertainty Inspired Underwater Image Enhancement (ECCV 2022) ([Paper](https://arxiv.org/pdf/2207.09689.pdf))

The PyTorch implementation of "Uncertainty Inspired Underwater Image Enhancement".

![PUIE-Net overview](img/1.png)

## Introduction
This project was developed on Ubuntu 16.04.5 with Python 3.7, PyTorch 1.7.1, and a single NVIDIA RTX 2080 Ti GPU.

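A minimal environment sketch (the OS/Python/PyTorch versions are the authors'; the remaining packages are inferred from this repository's imports, and leaving their versions unpinned is an assumption):

```
pip install torch==1.7.1 torchvision opencv-python scikit-image pillow numpy
```
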
## Running

### Testing

Download the [pretrained model](https://drive.google.com/file/d/1rkGm0l826ybOk_RSJNSZwbKpJc_z2ZkU/view?usp=sharing) and place it at `weights/puie.pth`, the default `--model` path of the test scripts.

Test_MC.py averages 20 sampled enhancements (Monte Carlo inference), while Test_MP.py produces a single enhancement from the mean of the predicted distribution (maximum probability). Check the model and image paths in Test_MC.py and Test_MP.py, and then run:

```
python Test_MC.py
```
```
python Test_MP.py
```
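
Enhanced results can then be scored with ssim_psnr.py; the example below simply restates the script's own default paths (the MC outputs against the `O` reference set):

```
python ssim_psnr.py --input_dir results/UIEBD/mc --reference_dir dataset/new_UIEBD/test/label/O
```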

### Training

To train the model, you need to prepare our [dataset](https://drive.google.com/file/d/1YXdyNT9ac6CCpQTNKP7SnKtlRyugauvh/view?usp=sharing).
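
The default paths in Train.py, the test scripts, and ssim_psnr.py assume the layout below (dataset.py samples one of the four reference sub-folders O, S, C, G per training image):

```
dataset/new_UIEBD/
├── train
│   ├── image
│   └── label        # reference sub-folders O, S, C, G
└── test
    ├── image
    └── label        # reference sub-folder O
```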

Check the dataset path in Train.py, and then run:
```
python Train.py
```
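
Training options are exposed as argparse flags; the run below simply restates the script's defaults:

```
python Train.py --data_dir dataset/new_UIEBD/train --batchSize 4 --nEpochs 500 --lr 1e-4
```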

## Citation

If you find PUIE-Net useful in your research, please cite our paper:

```
@inproceedings{Fu_2022,
  title={Uncertainty Inspired Underwater Image Enhancement},
  author={Fu, Zhenqi and Wang, Wu and Huang, Yue and Ding, Xinghao and Ma, Kai-Kuang},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2022},
  pages={465--482},
}
```

--------------------------------------------------------------------------------
/Test_MP.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2
4 | import time
5 | import argparse
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | from PUIENet_MP import mynet
9 | from data import get_eval_set
10 |
11 |
12 | # settings
13 | parser = argparse.ArgumentParser(description='PyTorch PUIE-Net')
14 | parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
15 | parser.add_argument('--gpu_mode', type=bool, default=True)  # note: type=bool parses any non-empty string as True; edit the default to run on CPU
16 | parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
17 | parser.add_argument('--seed', type=int, default=123, help='random seed to use Default=123')
18 | parser.add_argument('--device', type=str, default='cuda:2')
19 | parser.add_argument('--input_dir', type=str, default='dataset/new_UIEBD/test/image')
20 | parser.add_argument('--output', default='results/UIEBD', help='Location to save enhanced results')
21 | parser.add_argument('--reference_out', type=str, default='mp')
22 | parser.add_argument('--model', default='weights/puie.pth', help='Pretrained base model')
23 |
24 |
25 | opt = parser.parse_args()
26 | print(opt)
27 | device = torch.device(opt.device)
28 | cuda = opt.gpu_mode
29 | if cuda and not torch.cuda.is_available():
30 |     raise Exception("No GPU found, please set gpu_mode to False")
31 |
32 | # torch.manual_seed(opt.seed)
33 | # if cuda:
34 | # torch.cuda.manual_seed(opt.seed)
35 |
36 | print('===> Loading datasets')
37 | test_set = get_eval_set(opt.input_dir, opt.input_dir)  # no ground truth at test time, so the input dir doubles as the label dir
38 | testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
39 |
40 | print('===> Building model')
41 |
42 | model = mynet(opt)
43 |
44 | model.load_state_dict(torch.load(opt.model, map_location=lambda storage, loc: storage))
45 | print('Pre-trained model is loaded.')
46 |
47 | if cuda:
48 | model = model.cuda(device)
49 |
50 |
51 | def eval():
52 | model.eval()
53 | torch.set_grad_enabled(False)
54 | for batch in testing_data_loader:
55 | with torch.no_grad():
56 | input, _, name = Variable(batch[0]), Variable(batch[1]), batch[2]
57 | if cuda:
58 | input = input.cuda(device)
59 |
60 | with torch.no_grad():
61 |
62 | model.forward(input, input, training=False)
63 | t0 = time.time()
64 |             prediction = model.sample(testing=True)  # MP inference: a single deterministic pass using the latent mean
65 | t1 = time.time()
66 | print("===> Processing: %s || Timer: %.4f sec." % (name[0], (t1 - t0)))
67 | save_img_2(prediction.cpu().data, name[0], opt.reference_out)
68 |
69 |
70 | def save_img_2(img, img_name, out):
71 |     save_img = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)  # CHW -> HWC, RGB, float in [0, 1]
72 | # save img
73 | save_dir = os.path.join(opt.output, out)
74 | if not os.path.exists(save_dir):
75 | os.makedirs(save_dir)
76 |
77 | name_list = img_name.split('.', 1)
78 | save_fn = save_dir + '/' + name_list[0] + '.' + name_list[1]
79 |     cv2.imwrite(save_fn, cv2.cvtColor((save_img * 255).astype('uint8'), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])  # tensor is RGB; cv2 writes BGR
80 |
81 | eval()
82 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import Normal, Independent
4 | import numpy as np
5 |
6 |
7 | class SELayer(torch.nn.Module):
8 | def __init__(self, num_filter):
9 | super(SELayer, self).__init__()
10 | self.global_pool = torch.nn.AdaptiveAvgPool2d(1)
11 | self.conv_double = torch.nn.Sequential(
12 | nn.Conv2d(num_filter, num_filter // 16, 1, 1, 0, bias=True),
13 | nn.LeakyReLU(),
14 | nn.Conv2d(num_filter // 16, num_filter, 1, 1, 0, bias=True),
15 | nn.Sigmoid())
16 |
17 | def forward(self, x):
18 | mask = self.global_pool(x)
19 | mask = self.conv_double(mask)
20 | x = x * mask
21 | return x
22 |
23 |
24 | class ResBlock(nn.Module):
25 | def __init__(self, num_filter):
26 | super(ResBlock, self).__init__()
27 | body = []
28 | for i in range(2):
29 | body.append(nn.ReflectionPad2d(1))
30 | body.append(nn.Conv2d(num_filter, num_filter, kernel_size=3, padding=0))
31 | if i == 0:
32 | body.append(nn.LeakyReLU())
33 | body.append(SELayer(num_filter))
34 | self.body = nn.Sequential(*body)
35 |
36 | def forward(self, x):
37 | res = self.body(x)
38 | x = res + x
39 | return x
40 |
41 |
42 | class Up(nn.Module):
43 | def __init__(self):
44 | super(Up, self).__init__()
45 | self.up = nn.Sequential(
46 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
47 | )
48 |
49 | def forward(self, x):
50 | x = self.up(x)
51 | return x
52 |
53 |
54 | class ConvBlock(nn.Module):
55 | def __init__(self, ch_in, ch_out):
56 | super(ConvBlock, self).__init__()
57 | self.conv = nn.Sequential(
58 | nn.ReflectionPad2d(1),
59 | nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=0),
60 | nn.LeakyReLU(),
61 | nn.ReflectionPad2d(1),
62 | nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=0),
63 | nn.LeakyReLU(),
64 | )
65 |
66 | def forward(self, x):
67 | x = self.conv(x)
68 | return x
69 |
70 |
71 | class Compute_z(nn.Module):
72 | def __init__(self, latent_dim):
73 | super(Compute_z, self).__init__()
74 | self.latent_dim = latent_dim
75 | self.u_conv_layer = nn.Conv2d(128, 2 * self.latent_dim, kernel_size=1, padding=0)
76 | self.s_conv_layer = nn.Conv2d(128, 2 * self.latent_dim, kernel_size=1, padding=0)
77 |
78 | def forward(self, x):
79 |         u_encoding = torch.mean(x, dim=2, keepdim=True)            # mean-pool over H ...
80 |         u_encoding = torch.mean(u_encoding, dim=3, keepdim=True)   # ... then over W -> (B, 128, 1, 1)
81 |         u_mu_log_sigma = self.u_conv_layer(u_encoding)
82 |         u_mu_log_sigma = torch.squeeze(u_mu_log_sigma, dim=2)
83 |         u_mu_log_sigma = torch.squeeze(u_mu_log_sigma, dim=2)      # -> (B, 2 * latent_dim)
84 |         u_mu = u_mu_log_sigma[:, :self.latent_dim]
85 |         u_log_sigma = u_mu_log_sigma[:, self.latent_dim:]
86 |         u_dist = Independent(Normal(loc=u_mu, scale=torch.exp(u_log_sigma)), 1)  # diagonal Gaussian over the latent
87 |
88 |         s_encoding = torch.std(x, dim=2, keepdim=True)             # the second branch pools standard deviations instead of means
89 |         s_encoding = torch.std(s_encoding, dim=3, keepdim=True)
90 | s_mu_log_sigma = self.s_conv_layer(s_encoding)
91 | s_mu_log_sigma = torch.squeeze(s_mu_log_sigma, dim=2)
92 | s_mu_log_sigma = torch.squeeze(s_mu_log_sigma, dim=2)
93 | s_mu = s_mu_log_sigma[:, :self.latent_dim]
94 | s_log_sigma = s_mu_log_sigma[:, self.latent_dim:]
95 | s_dist = Independent(Normal(loc=s_mu, scale=torch.exp(s_log_sigma)), 1)
96 | return u_dist, s_dist, u_mu, s_mu, torch.exp(u_log_sigma), torch.exp(s_log_sigma)
97 |
98 |
--------------------------------------------------------------------------------
/Test_MC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2
4 | import time
5 | import argparse
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | from PUIENet_MC import mynet
9 | from data import get_eval_set
10 |
11 |
12 | # settings
13 | parser = argparse.ArgumentParser(description='PyTorch PUIE-Net')
14 | parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
15 | parser.add_argument('--gpu_mode', type=bool, default=True)  # note: type=bool parses any non-empty string as True; edit the default to run on CPU
16 | parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
17 | parser.add_argument('--seed', type=int, default=123, help='random seed to use Default=123')
18 | parser.add_argument('--device', type=str, default='cuda:2')
19 | parser.add_argument('--input_dir', type=str, default='dataset/new_UIEBD/test/image')
20 | parser.add_argument('--output', default='results/UIEBD', help='Location to save enhanced results')
21 | parser.add_argument('--sample_out', type=str, default='sample')
22 | parser.add_argument('--reference_out', type=str, default='mc')
23 | parser.add_argument('--model', default='weights/puie.pth', help='Pretrained base model')
24 |
25 |
26 | opt = parser.parse_args()
27 | print(opt)
28 | device = torch.device(opt.device)
29 | cuda = opt.gpu_mode
30 | if cuda and not torch.cuda.is_available():
31 |     raise Exception("No GPU found, please set gpu_mode to False")
32 |
33 | # torch.manual_seed(opt.seed)
34 | # if cuda:
35 | # torch.cuda.manual_seed(opt.seed)
36 |
37 | print('===> Loading datasets')
38 | test_set = get_eval_set(opt.input_dir, opt.input_dir)  # no ground truth at test time, so the input dir doubles as the label dir
39 | testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
40 |
41 | print('===> Building model')
42 |
43 | model = mynet(opt)
44 |
45 | model.load_state_dict(torch.load(opt.model, map_location=lambda storage, loc: storage))
46 | print('Pre-trained model is loaded.')
47 |
48 | if cuda:
49 | model = model.cuda(device)
50 |
51 |
52 | def eval():
53 | model.eval()
54 | torch.set_grad_enabled(False)
55 | for batch in testing_data_loader:
56 | with torch.no_grad():
57 | input, _, name = Variable(batch[0]), Variable(batch[1]), batch[2]
58 | if cuda:
59 | input = input.cuda(device)
60 |
61 | with torch.no_grad():
62 |
63 | model.forward(input, input, training=False)
64 |             avg_pre = 0
65 |             for i in range(20):  # Monte Carlo inference: draw 20 enhancements from the latent distribution
66 |                 t0 = time.time()
67 |                 prediction = model.sample(testing=True)
68 |                 t1 = time.time()
69 |                 avg_pre = avg_pre + prediction / 20  # running mean of the 20 samples
70 |                 save_img_1(prediction.cpu().data, name[0], i, opt.sample_out)
71 | 
72 |             print("===> Processing: %s || Timer: %.4f sec." % (name[0], (t1 - t0)))  # times the last sample only
73 |             save_img_2(avg_pre.cpu().data, name[0], opt.reference_out)  # the averaged (consensus) result
74 |
75 |
76 | def save_img_1(img, img_name, i, out):
77 | save_img = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)
78 | # save img
79 | save_dir = os.path.join(opt.output, out)
80 | if not os.path.exists(save_dir):
81 | os.makedirs(save_dir)
82 |
83 | name_list = img_name.split('.', 1)
84 | save_fn = save_dir + '/' + name_list[0] + '_' + str(i) + '.' + name_list[1]
85 |     cv2.imwrite(save_fn, cv2.cvtColor((save_img * 255).astype('uint8'), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])  # tensor is RGB; cv2 writes BGR
86 |
87 |
88 | def save_img_2(img, img_name, out):
89 | save_img = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)
90 | # save img
91 | save_dir = os.path.join(opt.output, out)
92 | if not os.path.exists(save_dir):
93 | os.makedirs(save_dir)
94 |
95 | name_list = img_name.split('.', 1)
96 | save_fn = save_dir + '/' + name_list[0] + '.' + name_list[1]
97 |     cv2.imwrite(save_fn, cv2.cvtColor((save_img * 255).astype('uint8'), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, 0])  # tensor is RGB; cv2 writes BGR
98 |
99 | eval()
100 |
--------------------------------------------------------------------------------
/Train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import socket
3 | import time
4 | import argparse
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim.lr_scheduler as lrs
8 | from torch.utils.data import DataLoader
9 | from PUIENet_MC import mynet
10 | from torch.autograd import Variable
11 | from data import get_training_set
12 |
13 |
14 | # Training settings
15 | parser = argparse.ArgumentParser(description='PyTorch PUIE-Net')
16 | parser.add_argument('--device', type=str, default='cuda:0')
17 | parser.add_argument('--batchSize', type=int, default=4, help='training batch size')
18 | parser.add_argument('--nEpochs', type=int, default=500, help='number of epochs to train for')
19 | parser.add_argument('--snapshots', type=int, default=2, help='checkpoint interval in epochs')
20 | parser.add_argument('--start_iter', type=int, default=1, help='Starting Epoch')
21 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate. Default=1e-4')
22 | parser.add_argument('--data_dir', type=str, default='dataset/new_UIEBD/train')
23 | parser.add_argument('--label_train_dataset', type=str, default='label')
24 | parser.add_argument('--data_train_dataset', type=str, default='image')
25 | parser.add_argument('--patch_size', type=int, default=256, help='Size of cropped image')
26 | parser.add_argument('--save_folder', default='weights/', help='Location to save checkpoint models')
27 | parser.add_argument('--gpu_mode', type=bool, default=True)
28 | parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
29 | parser.add_argument('--decay', type=int, default=10000, help='epochs between learning-rate decay steps')
30 | parser.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
31 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
32 | parser.add_argument('--data_augmentation', type=bool, default=True)
33 |
34 |
35 | opt = parser.parse_args()
36 | device = torch.device(opt.device)
37 | hostname = str(socket.gethostname())
38 | cudnn.benchmark = True
39 | print(opt)
40 |
41 |
42 | def train(epoch):
43 | epoch_loss = 0
44 | model.train()
45 | for iteration, batch in enumerate(training_data_loader, 1):
46 | input, target = Variable(batch[0]), Variable(batch[1])
47 | if cuda:
48 | input = input.to(device)
49 | target = target.to(device)
50 |
51 | t0 = time.time()
52 | model.forward(input, target, training=True)
53 | loss = model.elbo(target)
54 | optimizer.zero_grad()
55 | loss.backward()
56 | epoch_loss += loss.item()
57 | optimizer.step()
58 | t1 = time.time()
59 |
60 | print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Learning rate: lr={} || Timer: {:.4f} sec.".format(epoch, iteration,
61 | len(training_data_loader), loss.item(), optimizer.param_groups[0]['lr'], (t1 - t0)))
62 | print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))
63 |
64 |
65 | def checkpoint(epoch):
66 | model_out_path = opt.save_folder+"epoch_{}.pth".format(epoch)
67 | torch.save(model.state_dict(), model_out_path)
68 | print("Checkpoint saved to {}".format(model_out_path))
69 |
70 |
71 | cuda = opt.gpu_mode
72 | if cuda and not torch.cuda.is_available():
73 | raise Exception("No GPU found, please run without --cuda")
74 |
75 | torch.manual_seed(opt.seed)
76 | if cuda:
77 | torch.cuda.manual_seed(opt.seed)
78 |
79 | print('===> Loading datasets')
80 |
81 | train_set = get_training_set(opt.data_dir, opt.label_train_dataset, opt.data_train_dataset, opt.patch_size, opt.data_augmentation)
82 | training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
83 |
84 | model = mynet(opt)
85 |
86 | # print('---------- Networks architecture -------------')
87 | # print_network(model)
88 | # print('----------------------------------------------')
89 |
90 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8)
91 |
92 | milestones = []
93 | for i in range(1, opt.nEpochs+1):  # with the defaults (decay=10000, nEpochs=500) no milestone is reached, so the LR stays constant
94 |     if i % opt.decay == 0:
95 |         milestones.append(i)
96 |
97 | scheduler = lrs.MultiStepLR(optimizer, milestones, opt.gamma)
98 |
99 | for epoch in range(opt.start_iter, opt.nEpochs + 1):
100 |
101 | train(epoch)
102 | scheduler.step()
103 |
104 | if (epoch+1) % opt.snapshots == 0:
105 | checkpoint(epoch)
106 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch.utils.data as data
4 | from os import listdir
5 | from os.path import join
6 | from PIL import Image, ImageOps
7 |
8 |
9 | def is_image_file(filename):
10 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".bmp"])
11 |
12 |
13 | def load_img(filepath):
14 | img = Image.open(filepath).convert('RGB')
15 | return img
16 |
17 |
18 | def rescale_img(img_in, scale):
19 | size_in = img_in.size
20 | new_size_in = tuple([int(x * scale) for x in size_in])
21 | img_in = img_in.resize(new_size_in, resample=Image.BICUBIC)
22 | return img_in
23 |
24 |
25 | def get_patch(img_in, img_tar, patch_size, scale=1, ix=-1, iy=-1):
26 |     (iw, ih) = img_in.size  # PIL size is (width, height)
27 |
28 | patch_mult = scale
29 | tp = patch_mult * patch_size
30 | ip = tp // scale
31 |
32 | if ix == -1:
33 | ix = random.randrange(0, iw - ip + 1)
34 | if iy == -1:
35 | iy = random.randrange(0, ih - ip + 1)
36 |
37 | (tx, ty) = (scale * ix, scale * iy)
38 |
39 |     img_in = img_in.crop((ix, iy, ix + ip, iy + ip))    # PIL crop box is (left, upper, right, lower)
40 |     img_tar = img_tar.crop((tx, ty, tx + tp, ty + tp))
41 |
42 | info_patch = {
43 | 'ix': ix, 'iy': iy, 'ip': ip, 'tx': tx, 'ty': ty, 'tp': tp}
44 |
45 | return img_in, img_tar, info_patch
46 |
47 |
48 | def augment(img_in, img_tar, flip_h=True, rot=True):
49 | info_aug = {'flip_h': False, 'flip_v': False, 'trans': False}
50 |
51 |     if random.random() < 0.5 and flip_h:
52 |         img_in = ImageOps.flip(img_in)      # ImageOps.flip is a top-bottom flip
53 |         img_tar = ImageOps.flip(img_tar)
54 |         info_aug['flip_h'] = True
55 | 
56 |     if rot:
57 |         if random.random() < 0.5:
58 |             img_in = ImageOps.mirror(img_in)    # ImageOps.mirror is a left-right flip
59 |             img_tar = ImageOps.mirror(img_tar)
60 |             info_aug['flip_v'] = True
61 |         if random.random() < 0.5:
62 |             img_in = img_in.rotate(180)
63 |             img_tar = img_tar.rotate(180)
64 |             info_aug['trans'] = True
65 |
66 | return img_in, img_tar, info_aug
67 |
68 |
69 | class DatasetFromFolder(data.Dataset):
70 | def __init__(self, label_dir, data_dir, patch_size, data_augmentation, transform=None):
71 | super(DatasetFromFolder, self).__init__()
72 | self.label_path = label_dir
73 | data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
74 | data_filenames.sort()
75 | self.data_filenames = data_filenames
76 | self.patch_size = patch_size
77 | self.transform = transform
78 | self.data_augmentation = data_augmentation
79 |
80 | def __getitem__(self, index):
81 | _, file = os.path.split(self.data_filenames[index])
82 |
83 |         k = random.randint(0, 3)  # pick one of the four reference sub-folders at random
84 |         if k == 0:
85 |             label_filenames = self.label_path + '/O/' + file
86 |         elif k == 1:
87 |             label_filenames = self.label_path + '/S/' + file
88 |         elif k == 2:
89 |             label_filenames = self.label_path + '/C/' + file
90 |         else:  # k == 3
91 |             label_filenames = self.label_path + '/G/' + file
92 |
93 |
94 | target = load_img(label_filenames)
95 | input = load_img(self.data_filenames[index])
96 |
97 | input = input.resize((512, 512), resample=Image.BICUBIC)
98 | target = target.resize((512, 512), resample=Image.BICUBIC)
99 | input, target, _ = get_patch(input, target, self.patch_size)
100 |
101 | if self.data_augmentation:
102 | input, target, _ = augment(input, target)
103 |
104 | if self.transform:
105 | input = self.transform(input)
106 | target = self.transform(target)
107 |
108 | return input, target, file
109 |
110 | def __len__(self):
111 | return len(self.data_filenames)
112 |
113 |
114 | class DatasetFromFolderEval(data.Dataset):
115 | def __init__(self, data_dir, label_dir, transform=None):
116 | super(DatasetFromFolderEval, self).__init__()
117 | data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
118 | data_filenames.sort()
119 | self.data_filenames = data_filenames
120 |
121 | label_filenames = [join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)]
122 | label_filenames.sort()
123 | self.label_filenames = label_filenames
124 |
125 | self.transform = transform
126 |
127 | def __getitem__(self, index):
128 | input = load_img(self.data_filenames[index])
129 | label = load_img(self.label_filenames[index])
130 | _, file = os.path.split(self.data_filenames[index])
131 |
132 |         (iw, ih) = input.size  # PIL size is (width, height)
133 |         dw = iw % 8
134 |         dh = ih % 8
135 |         new_w, new_h = iw - dw, ih - dh  # crop to a multiple of 8 for the three 2x down/up-samplings
136 | 
137 |         input = input.resize((new_w, new_h))
138 |         label = label.resize((new_w, new_h))
139 |
140 | if self.transform:
141 | input = self.transform(input)
142 | label = self.transform(label)
143 |
144 | return input, label, file
145 |
146 | def __len__(self):
147 | return len(self.data_filenames)
148 |
149 |
150 |
--------------------------------------------------------------------------------
/PUIENet_MC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.vgg import vgg16
4 | from torch.distributions import kl
5 | from utils import ResBlock, ConvBlock, Up, Compute_z
6 |
7 |
8 | class Encoder(nn.Module):
9 | def __init__(self, ch):
10 | super(Encoder, self).__init__()
11 | self.conv1 = ConvBlock(ch_in=ch, ch_out=64)
12 | self.conv2 = ConvBlock(ch_in=64, ch_out=64)
13 | self.conv3 = ConvBlock(ch_in=64, ch_out=64)
14 | self.conv4 = ResBlock(64)
15 | self.pool1 = nn.MaxPool2d(2)
16 | self.pool2 = nn.MaxPool2d(2)
17 | self.pool3 = nn.MaxPool2d(2)
18 | def forward(self, x):
19 | x1 = self.conv1(x)
20 | x2 = self.pool1(x1)
21 | x2 = self.conv2(x2)
22 | x3 = self.pool2(x2)
23 | x3 = self.conv3(x3)
24 | x4 = self.pool3(x3)
25 | x4 = self.conv4(x4)
26 | return x1, x2, x3, x4
27 |
28 |
29 | class Decoder(nn.Module):
30 | def __init__(self, device):
31 | super(Decoder, self).__init__()
32 | self.device = device
33 | self.pr_encoder = Encoder(3)
34 | self.po_encoder = Encoder(6)
35 |
36 | self.pr_conv = ResBlock(64)
37 | self.po_conv = ResBlock(64)
38 |
39 | self.pr_Up3 = Up()
40 | self.pr_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
41 | self.pr_Up2 = Up()
42 | self.pr_UpConv2 = ConvBlock(ch_in=128, ch_out=64)
43 | self.pr_Up1 = Up()
44 | self.pr_UpConv1 = ConvBlock(ch_in=128, ch_out=64)
45 |
46 | self.po_Up3 = Up()
47 | self.po_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
48 | self.po_Up2 = Up()
49 | self.po_UpConv2 = ConvBlock(ch_in=128, ch_out=64)
50 |
51 | out_conv = []
52 | out_conv.append(ResBlock(64))
53 | out_conv.append(ResBlock(64))
54 | out_conv.append(nn.Conv2d(64, 3, kernel_size=1, padding=0))
55 | self.out_conv = nn.Sequential(*out_conv)
56 |
57 | z = 20
58 |
59 | self.compute_z_pr = Compute_z(z)
60 | self.compute_z_po = Compute_z(z)
61 |
62 | self.conv_u = nn.Conv2d(z, 128, kernel_size=1, padding=0)
63 | self.conv_s = nn.Conv2d(z, 128, kernel_size=1, padding=0)
64 |
65 | self.insnorm = nn.InstanceNorm2d(128)
66 | self.sigmoid = nn.Sigmoid()
67 |
68 |
69 | def forward(self, Input, Target, training=True):
70 |
71 | pr_x1, pr_x2, pr_x3, pr_x4 = self.pr_encoder.forward(Input)
72 |
73 | if training:
74 | po_x1, po_x2, po_x3, po_x4 = self.po_encoder.forward(torch.cat((Input, Target), dim=1))
75 | # x4->x3
76 | pr_x4 = self.pr_conv(pr_x4)
77 | po_x4 = self.po_conv(po_x4)
78 | pr_d3 = self.pr_Up3(pr_x4)
79 | po_d3 = self.po_Up3(po_x4)
80 | # x3->x2
81 | pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
82 | po_d3 = torch.cat((po_x3, po_d3), dim=1)
83 | pr_d3 = self.pr_UpConv3(pr_d3)
84 | po_d3 = self.po_UpConv3(po_d3)
85 | pr_d2 = self.pr_Up2(pr_d3)
86 | po_d2 = self.po_Up2(po_d3)
87 | # x2->x1
88 | pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
89 | po_d2 = torch.cat((po_x2, po_d2), dim=1)
90 | pr_d2 = self.pr_UpConv2(pr_d2)
91 | po_d2 = self.po_UpConv2(po_d2)
92 | pr_d1 = self.pr_Up1(pr_d2)
93 |             po_d1 = self.pr_Up1(po_d2)  # Up has no parameters, so sharing pr_Up1 with the posterior path is harmless
94 | # cat
95 | pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
96 | po_d1 = torch.cat((po_x1, po_d1), dim=1)
97 | # x1->dis
98 | pr_u_dist, pr_s_dist, _, _, _, _ = self.compute_z_pr(pr_d1)
99 | po_u_dist, po_s_dist, _, _, _, _ = self.compute_z_po(po_d1)
100 | po_latent_u = po_u_dist.rsample()
101 | po_latent_s = po_s_dist.rsample()
102 | po_latent_u = torch.unsqueeze(po_latent_u, -1)
103 | po_latent_u = torch.unsqueeze(po_latent_u, -1)
104 | po_latent_s = torch.unsqueeze(po_latent_s, -1)
105 | po_latent_s = torch.unsqueeze(po_latent_s, -1)
106 | po_u = self.conv_u(po_latent_u)
107 | po_s = self.conv_s(po_latent_s)
108 |             pr_d1 = self.insnorm(pr_d1) * torch.abs(po_s) + po_u  # AdaIN-style: scale and shift the normalized prior features with the sampled posterior latent
109 | # x1->out
110 | pr_d1 = self.pr_UpConv1(pr_d1)
111 | out = self.out_conv(pr_d1)
112 |
113 | return out, pr_u_dist, pr_s_dist, po_u_dist, po_s_dist
114 |
115 | else:
116 | # x4->x3
117 | pr_x4 = self.pr_conv(pr_x4)
118 | pr_d3 = self.pr_Up3(pr_x4)
119 | # x3->x2
120 | pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
121 | pr_d3 = self.pr_UpConv3(pr_d3)
122 | pr_d2 = self.pr_Up2(pr_d3)
123 | # x2->x1
124 | pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
125 | pr_d2 = self.pr_UpConv2(pr_d2)
126 | pr_d1 = self.pr_Up1(pr_d2)
127 | # cat
128 | pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
129 | # x1->dis
130 | pr_u_dist, pr_s_dist, _, _, _, _ = self.compute_z_pr(pr_d1)
131 | pr_latent_u = pr_u_dist.rsample()
132 | pr_latent_s = pr_s_dist.rsample()
133 | pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
134 | pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
135 | pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
136 | pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
137 | pr_u = self.conv_u(pr_latent_u)
138 | pr_s = self.conv_s(pr_latent_s)
139 | pr_d1 = self.insnorm(pr_d1) * torch.abs(pr_s) + pr_u
140 | # x1->out
141 | pr_d1 = self.pr_UpConv1(pr_d1)
142 | out = self.out_conv(pr_d1)
143 | return out
144 |
145 |
146 | class PerceptionLoss(nn.Module):
147 | def __init__(self):
148 | super(PerceptionLoss, self).__init__()
149 |         vgg = vgg16(pretrained=True)  # torchvision >= 0.13 replaces pretrained=True with a weights= argument
150 | loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
151 | for param in loss_network.parameters():
152 | param.requires_grad = False
153 | self.loss_network = loss_network
154 | self.mse_loss = nn.MSELoss()
155 |
156 | def forward(self, out_images, target_images):
157 | perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
158 | return perception_loss
159 |
160 |
161 | class mynet(nn.Module):
162 | def __init__(self, opt):
163 | super(mynet, self).__init__()
164 | self.device = torch.device(opt.device)
165 | self.decoder = Decoder(device=self.device).to(self.device)
166 | self.criterion = nn.MSELoss().to(self.device)
167 | self.VGG16 = PerceptionLoss().to(self.device)
168 |
169 |     def forward(self, Input, label, training=True):  # at test time this only caches the inputs; sample() runs the decoder
170 | self.Input = Input
171 | self.label = label
172 | if training:
173 | self.out, self.pr_u, self.pr_s, self.po_u, self.po_s = self.decoder.forward(Input, label)
174 |
175 | def sample(self, testing=False):
176 | if testing:
177 | self.out = self.decoder.forward(self.Input, self.label, training=False)
178 | return self.out
179 |
180 | def kl_divergence(self, analytic=True):
181 |         if analytic:  # only the closed-form KL between the diagonal Gaussians is implemented
182 | kl_div_u = torch.mean(kl.kl_divergence(self.po_u, self.pr_u))
183 | kl_div_s = torch.mean(kl.kl_divergence(self.po_s, self.pr_s))
184 | return kl_div_u + kl_div_s
185 |
186 | def elbo(self, target, analytic_kl=True):
187 | self.kl_loss = self.kl_divergence(analytic=analytic_kl)
188 | self.reconstruction_loss = self.criterion(self.out, target)
189 | self.vgg16_loss = self.VGG16(self.out, target)
190 | return self.reconstruction_loss + self.vgg16_loss + self.kl_loss
191 |
--------------------------------------------------------------------------------
/PUIENet_MP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.vgg import vgg16
4 | from torch.distributions import kl
5 | from utils import ResBlock, ConvBlock, Up, Compute_z
6 |
7 |
8 | class Encoder(nn.Module):
9 | def __init__(self, ch):
10 | super(Encoder, self).__init__()
11 | self.conv1 = ConvBlock(ch_in=ch, ch_out=64)
12 | self.conv2 = ConvBlock(ch_in=64, ch_out=64)
13 | self.conv3 = ConvBlock(ch_in=64, ch_out=64)
14 | self.conv4 = ResBlock(64)
15 | self.pool1 = nn.MaxPool2d(2)
16 | self.pool2 = nn.MaxPool2d(2)
17 | self.pool3 = nn.MaxPool2d(2)
18 | def forward(self, x):
19 | x1 = self.conv1(x)
20 | x2 = self.pool1(x1)
21 | x2 = self.conv2(x2)
22 | x3 = self.pool2(x2)
23 | x3 = self.conv3(x3)
24 | x4 = self.pool3(x3)
25 | x4 = self.conv4(x4)
26 | return x1, x2, x3, x4
27 |
28 |
29 | class Decoder(nn.Module):
30 | def __init__(self, device):
31 | super(Decoder, self).__init__()
32 | self.device = device
33 | self.pr_encoder = Encoder(3)
34 | self.po_encoder = Encoder(6)
35 |
36 | self.pr_conv = ResBlock(64)
37 | self.po_conv = ResBlock(64)
38 |
39 | self.pr_Up3 = Up()
40 | self.pr_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
41 | self.pr_Up2 = Up()
42 | self.pr_UpConv2 = ConvBlock(ch_in=128, ch_out=64)
43 | self.pr_Up1 = Up()
44 | self.pr_UpConv1 = ConvBlock(ch_in=128, ch_out=64)
45 |
46 | self.po_Up3 = Up()
47 | self.po_UpConv3 = ConvBlock(ch_in=128, ch_out=64)
48 | self.po_Up2 = Up()
49 | self.po_UpConv2 = ConvBlock(ch_in=128, ch_out=64)
50 |
51 | out_conv = []
52 | out_conv.append(ResBlock(64))
53 | out_conv.append(ResBlock(64))
54 | out_conv.append(nn.Conv2d(64, 3, kernel_size=1, padding=0))
55 | self.out_conv = nn.Sequential(*out_conv)
56 |
57 | z = 20
58 |
59 | self.compute_z_pr = Compute_z(z)
60 | self.compute_z_po = Compute_z(z)
61 |
62 | self.conv_u = nn.Conv2d(z, 128, kernel_size=1, padding=0)
63 | self.conv_s = nn.Conv2d(z, 128, kernel_size=1, padding=0)
64 |
65 | self.insnorm = nn.InstanceNorm2d(128)
66 | self.sigmoid = nn.Sigmoid()
67 |
68 |
69 | def forward(self, Input, Target, training=True):
70 |
71 | pr_x1, pr_x2, pr_x3, pr_x4 = self.pr_encoder.forward(Input)
72 |
73 | if training:
74 | po_x1, po_x2, po_x3, po_x4 = self.po_encoder.forward(torch.cat((Input, Target), dim=1))
75 | # x4->x3
76 | pr_x4 = self.pr_conv(pr_x4)
77 | po_x4 = self.po_conv(po_x4)
78 | pr_d3 = self.pr_Up3(pr_x4)
79 | po_d3 = self.po_Up3(po_x4)
80 | # x3->x2
81 | pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
82 | po_d3 = torch.cat((po_x3, po_d3), dim=1)
83 | pr_d3 = self.pr_UpConv3(pr_d3)
84 | po_d3 = self.po_UpConv3(po_d3)
85 | pr_d2 = self.pr_Up2(pr_d3)
86 | po_d2 = self.po_Up2(po_d3)
87 | # x2->x1
88 | pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
89 | po_d2 = torch.cat((po_x2, po_d2), dim=1)
90 | pr_d2 = self.pr_UpConv2(pr_d2)
91 | po_d2 = self.po_UpConv2(po_d2)
92 | pr_d1 = self.pr_Up1(pr_d2)
93 |             po_d1 = self.pr_Up1(po_d2)  # Up has no parameters, so sharing pr_Up1 with the posterior path is harmless
94 | # cat
95 | pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
96 | po_d1 = torch.cat((po_x1, po_d1), dim=1)
97 | # x1->dis
98 | pr_u_dist, pr_s_dist, _, _, _, _ = self.compute_z_pr(pr_d1)
99 | po_u_dist, po_s_dist, _, _, _, _ = self.compute_z_po(po_d1)
100 | po_latent_u = po_u_dist.rsample()
101 | po_latent_s = po_s_dist.rsample()
102 | po_latent_u = torch.unsqueeze(po_latent_u, -1)
103 | po_latent_u = torch.unsqueeze(po_latent_u, -1)
104 | po_latent_s = torch.unsqueeze(po_latent_s, -1)
105 | po_latent_s = torch.unsqueeze(po_latent_s, -1)
106 | po_u = self.conv_u(po_latent_u)
107 | po_s = self.conv_s(po_latent_s)
108 | pr_d1 = self.insnorm(pr_d1) * torch.abs(po_s) + po_u
109 | # x1->out
110 | pr_d1 = self.pr_UpConv1(pr_d1)
111 | out = self.out_conv(pr_d1)
112 |
113 | return out, pr_u_dist, pr_s_dist, po_u_dist, po_s_dist
114 |
115 | else:
116 | # x4->x3
117 | pr_x4 = self.pr_conv(pr_x4)
118 | pr_d3 = self.pr_Up3(pr_x4)
119 | # x3->x2
120 | pr_d3 = torch.cat((pr_x3, pr_d3), dim=1)
121 | pr_d3 = self.pr_UpConv3(pr_d3)
122 | pr_d2 = self.pr_Up2(pr_d3)
123 | # x2->x1
124 | pr_d2 = torch.cat((pr_x2, pr_d2), dim=1)
125 | pr_d2 = self.pr_UpConv2(pr_d2)
126 | pr_d1 = self.pr_Up1(pr_d2)
127 | # cat
128 | pr_d1 = torch.cat((pr_x1, pr_d1), dim=1)
129 | # x1->dis
130 | pr_u_dist, pr_s_dist, u_mu, s_mu, u_sigma, s_sigma = self.compute_z_pr(pr_d1)
131 |
132 |             pr_latent_u = u_mu + u_sigma * 0  # MP: take the most probable latent (the mean); the sigma term is deliberately zeroed
133 |             pr_latent_s = s_mu + s_sigma * 0
134 |
135 | pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
136 | pr_latent_u = torch.unsqueeze(pr_latent_u, -1)
137 | pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
138 | pr_latent_s = torch.unsqueeze(pr_latent_s, -1)
139 | pr_u = self.conv_u(pr_latent_u)
140 | pr_s = self.conv_s(pr_latent_s)
141 | pr_d1 = self.insnorm(pr_d1) * torch.abs(pr_s) + pr_u
142 | # x1->out
143 | pr_d1 = self.pr_UpConv1(pr_d1)
144 | out = self.out_conv(pr_d1)
145 | return out
146 |
147 |
148 | class PerceptionLoss(nn.Module):
149 | def __init__(self):
150 | super(PerceptionLoss, self).__init__()
151 |         vgg = vgg16(pretrained=True)  # torchvision >= 0.13 replaces pretrained=True with a weights= argument
152 | loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
153 | for param in loss_network.parameters():
154 | param.requires_grad = False
155 | self.loss_network = loss_network
156 | self.mse_loss = nn.MSELoss()
157 |
158 | def forward(self, out_images, target_images):
159 | perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
160 | return perception_loss
161 |
162 |
163 | class mynet(nn.Module):
164 | def __init__(self, opt):
165 | super(mynet, self).__init__()
166 | self.device = torch.device(opt.device)
167 | self.decoder = Decoder(device=self.device).to(self.device)
168 | self.criterion = nn.MSELoss().to(self.device)
169 | self.VGG16 = PerceptionLoss().to(self.device)
170 |
171 |     def forward(self, Input, label, training=True):  # at test time this only caches the inputs; sample() runs the decoder
172 | self.Input = Input
173 | self.label = label
174 | if training:
175 | self.out, self.pr_u, self.pr_s, self.po_u, self.po_s = self.decoder.forward(Input, label)
176 |
177 | def sample(self, testing=False):
178 | if testing:
179 | self.out = self.decoder.forward(self.Input, self.label, training=False)
180 | return self.out
181 |
182 | def kl_divergence(self, analytic=True):
183 | if analytic:
184 | kl_div_u = torch.mean(kl.kl_divergence(self.po_u, self.pr_u))
185 | kl_div_s = torch.mean(kl.kl_divergence(self.po_s, self.pr_s))
186 | return kl_div_u + kl_div_s
187 |
188 | def elbo(self, target, analytic_kl=True):
189 | self.kl_loss = self.kl_divergence(analytic=analytic_kl)
190 | self.reconstruction_loss = self.criterion(self.out, target)
191 | self.vgg16_loss = self.VGG16(self.out, target)
192 | return self.reconstruction_loss + self.vgg16_loss + self.kl_loss
193 |
--------------------------------------------------------------------------------