├── figs
│   ├── model.png
│   └── result.png
├── LICENSE
├── README.md
├── test_transfer.py
├── saver.py
├── train.py
├── dataset.py
├── test.py
├── options.py
├── model.py
└── networks.py

/figs/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lkpengcs/DCDA/HEAD/figs/model.png
--------------------------------------------------------------------------------

/figs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lkpengcs/DCDA/HEAD/figs/result.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Linkai Peng
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Domain Adaptation for Cross-Modality Retinal Vessel Segmentation via Disentangling Representation Style Transfer and Collaborative Consistency Learning
2 | 
3 | PyTorch implementation of our unsupervised domain adaptation framework, applied to retinal vessel segmentation. We use style transfer and collaborative consistency learning to train a segmentation model on the target domain.
4 | 
5 | ![image](https://github.com/lkpengcs/DCDA/blob/main/figs/model.png)
6 | 
7 | ## Paper
8 | 
9 | Please cite our [paper](https://arxiv.org/abs/2201.04812) if you find the code useful for your research.
10 | 
11 | ```
12 | @article{peng2022unsupervised,
13 |   title={Unsupervised Domain Adaptation for Cross-Modality Retinal Vessel Segmentation via Disentangling Representation Style Transfer and Collaborative Consistency Learning},
14 |   author={Peng, Linkai and Lin, Li and Cheng, Pujin and Huang, Ziqi and Tang, Xiaoying},
15 |   journal={arXiv preprint},
16 |   url={arXiv:2201.04812},
17 |   year={2022}
18 | }
19 | ```
20 | 
21 | ## Example Results
22 | 
23 | ![image](https://github.com/lkpengcs/DCDA/blob/main/figs/result.png)
24 | 
25 | ## Usage
26 | 
27 | ### Prerequisites
28 | 
29 | - Python 3.7+
30 | - PyTorch 1.4.0
31 | 
32 | ### Acknowledgement
33 | 
34 | Code adapted from [DRIT](https://github.com/HsinYingLee/DRIT).
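35 | 
36 | ### Training and testing
37 | 
38 | A minimal sketch of the intended invocation, inferred from `options.py` and `dataset.py` (only `--dataroot` is required for training, plus `--resume` for testing; the dataloaders expect `<phase>img` / `<phase>gt` subfolders such as `trainimg`, `traingt`, `testimg`, `testgt` — the dataset path below is a placeholder):
39 | 
40 | ```bash
41 | # checkpoints and images go to <result_dir>/<name> (../results/trial by default)
42 | python train.py --dataroot ./datasets/retina --name trial
43 | # evaluation also writes detail/mean/std metric CSVs under --result_dir
44 | python test.py --dataroot ./datasets/retina --name trial --resume ../results/trial/last.pth
45 | ```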
46 | 
47 | 
--------------------------------------------------------------------------------

/test_transfer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from options import TestOptions
3 | from dataset import dataset_single
4 | from model import DCDA
5 | from saver import save_imgs
6 | import os
7 | 
8 | def main():
9 |   # parse options
10 |   parser = TestOptions()
11 |   opts = parser.parse()
12 | 
13 |   # data loader
14 |   print('\n--- load dataset ---')
15 |   datasetA = dataset_single(opts, 'A', opts.input_dim_a)
16 |   datasetB = dataset_single(opts, 'B', opts.input_dim_b)
17 |   if opts.a2b:
18 |     loader = torch.utils.data.DataLoader(datasetA, batch_size=1, num_workers=opts.nThreads)
19 |     loader_attr = torch.utils.data.DataLoader(datasetB, batch_size=1, num_workers=opts.nThreads, shuffle=True)
20 |   else:
21 |     loader = torch.utils.data.DataLoader(datasetB, batch_size=1, num_workers=opts.nThreads)
22 |     loader_attr = torch.utils.data.DataLoader(datasetA, batch_size=1, num_workers=opts.nThreads, shuffle=True)
23 | 
24 |   # model
25 |   print('\n--- load model ---')
26 |   model = DCDA(opts)
27 |   model.setgpu(opts.gpu)
28 |   model.resume(opts.resume, train=False)
29 |   model.eval()
30 | 
31 |   # directory
32 |   result_dir = os.path.join(opts.result_dir, opts.name)
33 |   if not os.path.exists(result_dir):
34 |     os.makedirs(result_dir)
35 | 
36 |   # test
37 |   print('\n--- testing ---')
38 |   # dataset_single yields (image, attr image, gt); only the images are needed here
39 |   for idx1, (img1, _, _) in enumerate(loader):
40 |     print('{}/{}'.format(idx1, len(loader)))
41 |     img1 = img1.cuda(opts.gpu)
42 |     imgs = [img1]
43 |     names = ['input']
44 |     for idx2, (img2, _, _) in enumerate(loader_attr):
45 |       if idx2 == opts.num:
46 |         break
47 |       img2 = img2.cuda(opts.gpu)
48 |       with torch.no_grad():
49 |         if opts.a2b:
50 |           img = model.test_forward_transfer(img1, img2, a2b=True)
51 |         else:
52 |           img = model.test_forward_transfer(img2, img1, a2b=False)
53 |       imgs.append(img)
54 |       names.append('output_{}'.format(idx2))
55 |     save_imgs(imgs, names, os.path.join(result_dir, '{}'.format(idx1)))
56 | 
57 |   return
58 | 
59 | if __name__ == '__main__':
60 |   main()
--------------------------------------------------------------------------------

/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torchvision
3 | from tensorboardX import SummaryWriter
4 | import numpy as np
5 | from PIL import Image
6 | 
7 | # tensor to PIL Image
8 | def tensor2img(img):
9 |   img = img[0].cpu().float().numpy()
10 |   if img.shape[0] == 1:
11 |     img = np.tile(img, (3, 1, 1))
12 |   img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
13 |   return img.astype(np.uint8)
14 | 
15 | # save a set of images into <path>/real and <path>/fake
16 | def save_imgs(imgs, names, path):
17 |   os.makedirs(os.path.join(path, 'real'), exist_ok=True)
18 |   os.makedirs(os.path.join(path, 'fake'), exist_ok=True)
19 |   for img, name in zip(imgs, names):
20 |     img = tensor2img(img)
21 |     img = Image.fromarray(img)
22 |     if 'input' in name:
23 |       img.save(os.path.join(path, 'real', name + '.png'))
24 |     if 'output' in name:
25 |       img.save(os.path.join(path, 'fake', name + '.png'))
26 | 
27 | class Saver():
28 |   def __init__(self, opts):
29 |     self.display_dir = os.path.join(opts.display_dir, opts.name)
30 |     self.model_dir = os.path.join(opts.result_dir, opts.name)
31 |     self.image_dir = os.path.join(self.model_dir, 'images')
32 |     self.display_freq = opts.display_freq
33 |     self.img_save_freq = opts.img_save_freq
34 |     self.model_save_freq = opts.model_save_freq
35 | 
36 |     # make directories
37 |     if not os.path.exists(self.display_dir):
38 |       os.makedirs(self.display_dir)
39 |     if not os.path.exists(self.model_dir):
40 |       os.makedirs(self.model_dir)
41 |     if not os.path.exists(self.image_dir):
42 |       os.makedirs(self.image_dir)
43 | 
44 |     # create tensorboard writer
45 |     self.writer = SummaryWriter(logdir=self.display_dir)
46 | 
47 |   # write losses and images to tensorboard
48 |   def write_display(self, total_it, model):
49 |     if (total_it + 1) % self.display_freq == 0:
50 |       # write loss: every non-callable attribute whose name contains 'loss'
51 |       members = [attr for attr in dir(model) if not callable(getattr(model, attr)) and not attr.startswith("__") and 'loss' in attr]
52 |       for m in members:
53 |         self.writer.add_scalar(m, getattr(model, m), total_it)
54 |       # write img
55 |       image_dis = torchvision.utils.make_grid(model.image_display, nrow=model.image_display.size(0)//2)/2 + 0.5
56 |       self.writer.add_image('Image', image_dis, total_it)
57 | 
58 |   # save result images (ep == -1 marks the final save)
59 |   def write_img(self, ep, model):
60 |     if ep == -1:
61 |       assembled_images = model.assemble_outputs()
62 |       img_filename = '%s/gen_last.jpg' % self.image_dir
63 |       torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1)
64 |     elif (ep + 1) % self.img_save_freq == 0:
65 |       assembled_images = model.assemble_outputs()
66 |       img_filename = '%s/gen_%05d.jpg' % (self.image_dir, ep)
67 |       torchvision.utils.save_image(assembled_images / 2 + 0.5, img_filename, nrow=1)
68 | 
69 |   # save model (ep == -2 marks the final save)
70 |   def write_model(self, ep, total_it, model):
71 |     if ep == -2:
72 |       model.save('%s/last.pth' % self.model_dir, ep, total_it)
73 |     elif (ep + 1) % self.model_save_freq == 0:
74 |       print('--- save the model @ ep %d ---' % (ep))
75 |       model.save('%s/%05d.pth' % (self.model_dir, ep), ep, total_it)
76 | 
77 | 
--------------------------------------------------------------------------------
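One convention in `saver.py` worth calling out: `Saver.write_display` discovers its scalars purely by reflection, so a model exposes a new TensorBoard curve simply by assigning an attribute whose name contains `loss`. A minimal, self-contained sketch of that contract (`DummyModel` is hypothetical, not part of this repo):

```python
# Hypothetical holder demonstrating Saver.write_display's discovery rule.
class DummyModel:
    def __init__(self):
        self.gan_loss_a = 0.73  # logged: non-callable, name contains 'loss'
        self.total_it = 10      # skipped: no 'loss' in the name

m = DummyModel()
members = [attr for attr in dir(m) if not callable(getattr(m, attr))
           and not attr.startswith("__") and 'loss' in attr]
print(members)  # -> ['gan_loss_a']
```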
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from options import TrainOptions
3 | from dataset import dataset_unpair
4 | from model import DCDA
5 | from saver import Saver
6 | import segmentation_models_pytorch as smp
7 | 
8 | def main():
9 |   # parse options
10 |   parser = TrainOptions()
11 |   opts = parser.parse()
12 | 
13 |   # data loader
14 |   print('\n--- load dataset ---')
15 |   dataset = dataset_unpair(opts)
16 |   train_loader = torch.utils.data.DataLoader(dataset, batch_size=opts.batch_size, shuffle=True, num_workers=opts.nThreads)
17 |   pre_seg_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=opts.nThreads)
18 | 
19 |   # model
20 |   print('\n--- load model ---')
21 |   # NOTE: the U-Nets are built with in_channels=1, i.e. single-channel inputs are assumed here
22 |   pre_seg_model = smp.Unet(encoder_name="resnet34", encoder_depth=5, encoder_weights='imagenet', decoder_channels=[256, 128, 64, 32, 16], in_channels=1, classes=1, activation='sigmoid')
23 |   pre_seg_optimizer = torch.optim.Adam([
24 |     dict(params=pre_seg_model.parameters(), lr=1e-4)
25 |   ])
26 |   pre_seg_model.cuda(opts.gpu)
27 |   pre_loss = smp.losses.DiceLoss(mode='binary', from_logits=False)
28 | 
29 |   model = DCDA(opts)
30 |   model.setgpu(opts.gpu)
31 |   if opts.resume is None:
32 |     model.initialize()
33 |     ep0 = -1
34 |     total_it = 0
35 |   else:
36 |     ep0, total_it = model.resume(opts.resume)
37 |   model.set_scheduler(opts, last_ep=ep0)
38 |   ep0 += 1
39 |   print('start the training at epoch %d' % (ep0))
40 | 
41 |   # saver for display and output
42 |   saver = Saver(opts)
43 | 
44 |   # train
45 |   print('\n--- train ---')
46 | 
47 |   # pre-train the source-domain segmentation model on (source image, gt) pairs
48 |   for ep in range(ep0, 100):
49 |     for it, (images_a, images_b, gt, _) in enumerate(pre_seg_loader):
50 |       images_a = images_a.cuda(opts.gpu)
51 |       images_b = images_b.cuda(opts.gpu)
52 |       gt = gt.cuda(opts.gpu)
53 |       pre_seg_optimizer.zero_grad()
54 |       outputs = pre_seg_model(images_a)
55 |       losses = pre_loss(outputs, gt)
56 |       losses.backward()
57 |       pre_seg_optimizer.step()
58 | 
59 |   torch.autograd.set_detect_anomaly(True)
60 |   max_it = 500000
61 |   for ep in range(ep0, opts.n_ep):
62 |     for it, (images_a, images_b, gt, _) in enumerate(train_loader):
63 |       if images_a.size(0) != opts.batch_size or images_b.size(0) != opts.batch_size:
64 |         continue
65 | 
66 |       # input data
67 |       images_a = images_a.cuda(opts.gpu).detach()
68 |       images_b = images_b.cuda(opts.gpu).detach()
69 |       # update model
70 |       if (it + 1) % opts.d_iter != 0 and it < len(train_loader) - 2:
71 |         model.update_D_content(images_a[:, 0, :, :].unsqueeze(1), images_b[:, 0, :, :].unsqueeze(1))
72 |         continue
73 |       else:
74 |         model.update_D(images_a[:, 0, :, :].unsqueeze(1), images_b[:, 0, :, :].unsqueeze(1))
75 |         model.update_EG()
76 |         # the collaborative dual-segmentation update only kicks in after epoch 100
77 |         if ep >= 100:
78 |           model.update_dual_seg(images_a, images_b, gt, pre_seg_model, True)
79 |       # save to display file
80 |       if not opts.no_display_img:
81 |         saver.write_display(total_it, model)
82 | 
83 |       print('total_it: %d (ep %d, it %d), lr %08f' % (total_it, ep, it, model.gen_opt.param_groups[0]['lr']))
84 |       total_it += 1
85 |       if total_it >= max_it:
86 |         saver.write_img(-1, model)
87 |         saver.write_model(-2, total_it, model)
88 |         break
89 | 
90 |     # decay learning rate
91 |     if opts.n_ep_decay > -1:
92 |       model.update_lr()
93 | 
94 |     # save result image
95 |     saver.write_img(ep, model)
96 | 
97 |     # save network weights
98 |     saver.write_model(ep, total_it, model)
99 | 
100 |   return
101 | 
102 | if __name__ == '__main__':
103 |   main()
--------------------------------------------------------------------------------

/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | from PIL import Image
4 | from PIL import ImageOps
5 | from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize
6 | import random
7 | import numpy as np
8 | import cv2
9 | import torch
10 | 
11 | class dataset_single(data.Dataset):
12 |   def __init__(self, opts, setname, input_dim):
13 |     self.dataroot = opts.dataroot
14 |     images = os.listdir(os.path.join(self.dataroot, opts.test_phase1 + 'img'))
15 |     self.img = [os.path.join(self.dataroot, opts.test_phase1 + 'img', x) for x in images]
16 |     images2 = os.listdir(os.path.join(self.dataroot, opts.test_phase2 + 'img'))
17 |     self.img2 = [os.path.join(self.dataroot, opts.test_phase2 + 'img', x) for x in images2]
18 |     gts = os.listdir(os.path.join(self.dataroot, opts.test_phase1 + 'gt'))
19 |     self.gt = [os.path.join(self.dataroot, opts.test_phase1 + 'gt', x) for x in gts]
20 |     self.size = len(self.img)
21 |     self.input_dim = input_dim
22 | 
23 |     # setup image transformation
24 |     transforms = [Resize((opts.resize_size, opts.resize_size), Image.BICUBIC)]
25 |     transforms.append(CenterCrop(opts.crop_size))
26 |     transforms.append(ToTensor())
27 |     transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
28 |     self.transforms = Compose(transforms)
29 |     print('%s: %d images' % (setname, self.size))
30 |     return
31 | 
32 |   def __getitem__(self, index):
33 |     data1 = self.load_img(self.img[index], self.input_dim, 1)
34 |     data2 = self.load_img(self.img2[index], self.input_dim, 2)
35 |     gt = cv2.imread(self.img[index].replace('img', 'gt'), 0)
36 |     gt = np.where(gt == 255, 1, 0)
37 |     return data1, data2, gt
38 | 
39 |   def load_img(self, img_name, input_dim, seq):
40 |     img = Image.open(img_name).convert('RGB')
41 |     img = self.transforms(img)
42 |     if input_dim == 1:
43 |       # ITU-R BT.601 luma conversion to a single channel
44 |       img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
45 |       img = img.unsqueeze(0)
46 |     return img
47 | 
48 |   def __len__(self):
49 |     return self.size
50 | 
51 | class dataset_unpair(data.Dataset):
52 |   def __init__(self, opts):
53 |     self.direction = opts.direction
54 |     self.dataroot = opts.dataroot
55 |     self.seed = random.randint(0, 6546498451351)
56 |     # A
57 |     images_A = os.listdir(os.path.join(self.dataroot, opts.phase1 + 'img'))
58 |     self.A = [os.path.join(self.dataroot, opts.phase1 + 'img', x) for x in images_A]
59 | 
60 |     # B
61 |     images_B = os.listdir(os.path.join(self.dataroot, opts.phase2 + 'img'))
62 |     self.B = [os.path.join(self.dataroot, opts.phase2 + 'img', x) for x in images_B]
63 | 
64 |     # gt
65 |     gts_A = os.listdir(os.path.join(self.dataroot, opts.phase1 + 'gt'))
66 |     self.gt = [os.path.join(self.dataroot, opts.phase1 + 'gt', x) for x in gts_A]
67 | 
68 |     gts_B = os.listdir(os.path.join(self.dataroot, opts.phase2 + 'gt'))
69 |     self.gt2 = [os.path.join(self.dataroot, opts.phase2 + 'gt', x) for x in gts_B]
70 | 
71 |     self.A_size = len(self.A)
72 |     self.B_size = len(self.B)
73 |     self.dataset_size = max(self.A_size, self.B_size)
74 |     self.input_dim_A = opts.input_dim_a
75 |     self.input_dim_B = opts.input_dim_b
76 | 
77 |     # setup image transformation
78 |     transforms = [Resize((opts.resize_size, opts.resize_size), Image.BICUBIC)]
79 |     gt_transforms = [Resize((opts.resize_size, opts.resize_size), Image.BICUBIC)]
80 |     transforms.append(ToTensor())
81 |     transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
82 |     self.transforms = Compose(transforms)
83 |     self.gt_transforms = Compose(gt_transforms)
84 |     print('A: %d, B: %d images' % (self.A_size, self.B_size))
85 |     return
86 | 
87 |   def __getitem__(self, index):
88 |     gt = cv2.imread(self.A[index].replace('img', 'gt'), 0)
89 |     gt = np.where(gt == 255, 1, 0)
90 |     gt = Image.fromarray(np.uint8(gt))
91 |     gt = self.gt_transforms(gt)
92 |     gt = np.array(gt)
93 | 
94 |     gt2 = cv2.imread(self.B[index].replace('img', 'gt'), 0)
95 |     gt2 = np.where(gt2 == 255, 1, 0)
96 |     gt2 = Image.fromarray(np.uint8(gt2))
97 |     gt2 = self.gt_transforms(gt2)
98 |     gt2 = np.array(gt2)
99 |     if self.dataset_size == self.A_size:
100 |       data_A = self.load_img(self.A[index], self.input_dim_A, 'A')
101 |       data_B = self.load_img(self.B[random.randint(0, self.B_size - 1)], self.input_dim_B, 'B')
102 |     else:
103 |       data_A = self.load_img(self.A[random.randint(0, self.A_size - 1)], self.input_dim_A, 'A')
104 |       data_B = self.load_img(self.B[index], self.input_dim_B, 'B')
105 |     return data_A, data_B, gt, gt2
106 | 
107 |   def load_img(self, img_name, input_dim, domain='A'):
108 |     img = Image.open(img_name).convert('RGB')
109 |     img = self.transforms(img)
110 |     if input_dim == 1:
111 |       img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
112 |       img = img.unsqueeze(0)
113 |     return img
114 | 
115 |   def __len__(self):
116 |     return self.dataset_size
--------------------------------------------------------------------------------

/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from options import TestOptions
3 | from dataset import dataset_single, dataset_unpair
4 | from model import DCDA
5 | from saver import save_imgs
6 | import os
7 | import numpy as np
8 | from numpy import mean, std
9 | import pandas as pd
10 | import scipy.spatial
11 | import surface_distance
12 | 
13 | 
14 | def getDSC(testImage, resultImage):
15 |   """Compute the Dice Similarity Coefficient."""
16 |   testArray = testImage.flatten()
17 |   resultArray = resultImage.flatten()
18 | 
19 |   return 1.0 - scipy.spatial.distance.dice(testArray, resultArray)
20 | 
21 | 
22 | def getJaccard(testImage, resultImage):
23 |   """Compute the Jaccard index."""
24 |   testArray = testImage.flatten()
25 |   resultArray = resultImage.flatten()
26 | 
27 |   return 1.0 - scipy.spatial.distance.jaccard(testArray, resultArray)
28 | 
29 | 
30 | def getPrecisionAndRecall(testImage, resultImage):
31 |   """Compute precision and recall from the confusion counts."""
32 |   testArray = testImage.flatten()
33 |   resultArray = resultImage.flatten()
34 | 
35 |   TP = np.sum(testArray * resultArray)
36 |   FP = np.sum((1 - testArray) * resultArray)
37 |   FN = np.sum(testArray * (1 - resultArray))
38 | 
39 |   precision = TP / (TP + FP)
40 |   recall = TP / (TP + FN)
41 | 
42 |   return precision, recall
43 | 
44 | 
45 | def getHD_ASSD(seg_preds, seg_labels):
46 |   label_seg = np.array(seg_labels, dtype=bool)
47 |   predict = np.array(seg_preds, dtype=bool)
48 | 
49 |   surface_distances = surface_distance.compute_surface_distances(
50 |     label_seg, predict, spacing_mm=(1, 1))
51 | 
52 |   HD = surface_distance.compute_robust_hausdorff(surface_distances, 95)
53 | 
54 |   distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
55 |   distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
56 |   surfel_areas_gt = surface_distances["surfel_areas_gt"]
57 |   surfel_areas_pred = surface_distances["surfel_areas_pred"]
58 | 
59 |   ASSD = (np.sum(distances_pred_to_gt * surfel_areas_pred) + np.sum(distances_gt_to_pred * surfel_areas_gt)) / (np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred))
60 | 
61 |   return HD, ASSD
62 | 
63 | 
64 | def main():
65 |   # parse options
66 |   parser = TestOptions()
67 |   opts = parser.parse()
68 | 
69 |   # data loader
70 |   print('\n--- load dataset ---')
71 |   if opts.a2b:
72 |     dataset = dataset_single(opts, 'A', opts.input_dim_a)
73 |   else:
74 |     dataset = dataset_single(opts, 'B', opts.input_dim_b)
75 |   loader = torch.utils.data.DataLoader(dataset, batch_size=1, num_workers=opts.nThreads, shuffle=False)
76 | 
77 |   # model
78 |   print('\n--- load model ---')
79 |   model = DCDA(opts)
80 |   model.setgpu(opts.gpu)
81 |   model.resume(opts.resume, train=False)
82 |   model.eval()
83 | 
84 |   # directory
85 |   result_dir = os.path.join(opts.result_dir, opts.name)
86 |   if not os.path.exists(result_dir):
87 |     os.makedirs(result_dir)
88 | 
89 |   # test
90 |   print('\n--- testing ---')
91 |   dices = []
92 |   pd_names = []
93 |   jaccards = []
94 |   hds = []
95 |   assds = []
96 |   for idx1, (img1, img2, gt1) in enumerate(loader):
97 |     print('{}/{}'.format(idx1, len(loader)))
98 |     pd_names.append(str(idx1))
99 |     img1 = img1.cuda(opts.gpu)
100 |     img2 = img2.cuda(opts.gpu)
101 |     gt1 = gt1.cuda(opts.gpu)
102 |     imgs = [img1]
103 |     names = ['input' + str(idx1)]
104 |     # segment the target-domain input once per case
105 |     with torch.no_grad():
106 |       outputs = model.tar_seg_model(img1)
107 |     preds = torch.round(outputs).squeeze().cpu().numpy()
108 |     gt_np = gt1.squeeze().cpu().numpy()
109 |     dices.append(getDSC(gt_np, preds))
110 |     jaccards.append(getJaccard(gt_np, preds))
111 |     hd, assd = getHD_ASSD(preds, gt_np)
112 |     hds.append(hd)
113 |     assds.append(assd)
114 |     # sample opts.num style translations for visual inspection
115 |     for idx2 in range(opts.num):
116 |       with torch.no_grad():
117 |         img = model.test_forward(img2, a2b=opts.a2b)
118 |       imgs.append(img)
119 |       names.append('output_{}_{}'.format(idx1, idx2))
120 |     #save_imgs(imgs, names, os.path.join(result_dir, '{}'.format(idx1)))
121 |     save_imgs(imgs, names, result_dir)
122 | 
123 |   dataframe = pd.DataFrame({'case': pd_names,
124 |                             'rv_dice': dices,
125 |                             'rv_jaccard': jaccards,
126 |                             'rv_HD': hds, 'rv_ASSD': assds
127 |                             })
128 |   dataframe.to_csv(opts.result_dir + "/detail_metrics.csv",
129 |                    index=False, sep=',')
130 |   print('Per-case metrics CSV generated!')
131 |   mean_resultframe = pd.DataFrame({
132 |     'rv_dice': mean(dices), 'rv_jaccard': mean(jaccards),
133 |     'rv_HD': mean(hds), 'rv_ASSD': mean(assds)}, index=[1])
134 |   mean_resultframe.to_csv(opts.result_dir + "/mean_metrics.csv", index=False)
135 |   std_resultframe = pd.DataFrame({
136 |     'rv_dice': std(dices, ddof=1), 'rv_jaccard': std(jaccards, ddof=1),
137 |     'rv_HD': std(hds, ddof=1), 'rv_ASSD': std(assds, ddof=1)}, index=[1])
138 |   std_resultframe.to_csv(opts.result_dir + "/std_metrics.csv", index=False)
139 |   print('Mean/std metrics CSVs generated!')
140 |   return
141 | 
142 | if __name__ == '__main__':
143 |   main()
--------------------------------------------------------------------------------
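A quick hand-checked example of what the overlap metrics in `test.py` return, using synthetic 1-D masks (with TP=2, FP=1, FN=1: Dice = 2·TP/(2·TP+FP+FN) = 4/6, Jaccard = TP/(TP+FP+FN) = 2/4):

```python
import numpy as np
import scipy.spatial

gt   = np.array([1, 1, 0, 0, 1, 0], dtype=bool)  # 3 foreground pixels
pred = np.array([1, 0, 0, 0, 1, 1], dtype=bool)  # TP=2, FP=1, FN=1

# mirrors getDSC / getJaccard above (1 - scipy's dissimilarity)
print(1.0 - scipy.spatial.distance.dice(gt, pred))     # 0.666...
print(1.0 - scipy.spatial.distance.jaccard(gt, pred))  # 0.5
```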
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | class TrainOptions():
4 |   def __init__(self):
5 |     self.parser = argparse.ArgumentParser()
6 | 
7 |     # data loader related
8 |     self.parser.add_argument('--dataroot', type=str, required=True, help='path of data')
9 |     self.parser.add_argument('--phase1', type=str, default='train', help='phase for dataloading')
10 |     self.parser.add_argument('--phase2', type=str, default='train', help='phase for dataloading')
11 |     self.parser.add_argument('--batch_size', type=int, default=2, help='batch size')
12 |     self.parser.add_argument('--resize_size', type=int, default=256, help='resized image size for training')
13 |     self.parser.add_argument('--crop_size', type=int, default=216, help='cropped image size for training')
14 |     self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A')
15 |     self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B')
16 |     self.parser.add_argument('--nThreads', type=int, default=8, help='# of threads for data loader')
17 |     self.parser.add_argument('--no_flip', action='store_true', help='specified if no flipping')
18 | 
19 |     # output related
20 |     self.parser.add_argument('--name', type=str, default='trial', help='folder name to save outputs')
21 |     self.parser.add_argument('--display_dir', type=str, default='../logs', help='path for saving display results')
22 |     self.parser.add_argument('--result_dir', type=str, default='../results', help='path for saving result images and models')
23 |     self.parser.add_argument('--display_freq', type=int, default=1, help='freq (iteration) of display')
24 |     self.parser.add_argument('--img_save_freq', type=int, default=5, help='freq (epoch) of saving images')
25 |     self.parser.add_argument('--model_save_freq', type=int, default=10, help='freq (epoch) of saving models')
26 |     self.parser.add_argument('--no_display_img', action='store_true', help='specified if no display')
27 | 
28 |     # training related
29 |     self.parser.add_argument('--direction', type=str, default='AtoB', help='image translation direction')
30 |     self.parser.add_argument('--no_ms', action='store_true', help='disable mode seeking regularization')
31 |     self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation, set 0 for using feature-wise transform')
32 |     self.parser.add_argument('--dis_scale', type=int, default=3, help='scale of discriminator')
33 |     self.parser.add_argument('--dis_norm', type=str, default='None', help='normalization layer in discriminator [None, Instance]')
34 |     self.parser.add_argument('--dis_spectral_norm', action='store_true', help='use spectral normalization in discriminator')
35 |     self.parser.add_argument('--lr_policy', type=str, default='lambda', help='type of learning rate decay')
36 |     self.parser.add_argument('--n_ep', type=int, default=1200, help='number of epochs') # 400 * d_iter
37 |     self.parser.add_argument('--n_ep_decay', type=int, default=600, help='epoch to start learning rate decay, set -1 for no decay') # 200 * d_iter
38 |     self.parser.add_argument('--resume', type=str, default=None, help='directory of a saved model to resume training from')
39 |     self.parser.add_argument('--d_iter', type=int, default=3, help='# of iterations for updating content discriminator')
40 |     self.parser.add_argument('--gpu', type=int, default=0, help='gpu')
41 | 
42 |   def parse(self):
43 |     self.opt = self.parser.parse_args()
44 |     args = vars(self.opt)
45 |     print('\n--- load options ---')
46 |     for name, value in sorted(args.items()):
47 |       print('%s: %s' % (str(name), str(value)))
48 |     return self.opt
49 | 
50 | class TestOptions():
51 |   def __init__(self):
52 |     self.parser = argparse.ArgumentParser()
53 | 
54 |     # data loader related
55 |     self.parser.add_argument('--dataroot', type=str, required=True, help='path of data')
56 |     #self.parser.add_argument('--test_phase', type=str, default='test', help='phase for dataloading')
57 |     self.parser.add_argument('--resize_size', type=int, default=256, help='resized image size for testing')
58 |     self.parser.add_argument('--crop_size', type=int, default=216, help='cropped image size for testing')
59 |     self.parser.add_argument('--nThreads', type=int, default=4, help='# of threads for data loader')
60 |     self.parser.add_argument('--input_dim_a', type=int, default=3, help='# of input channels for domain A')
61 |     self.parser.add_argument('--input_dim_b', type=int, default=3, help='# of input channels for domain B')
62 |     self.parser.add_argument('--a2b', type=int, default=1, help='translation direction, 1 for a2b, 0 for b2a')
63 |     self.parser.add_argument('--test_phase1', type=str, default='test', help='phase for dataloading')
64 |     self.parser.add_argument('--test_phase2', type=str, default='test', help='phase for dataloading')
65 |     # output related
66 |     self.parser.add_argument('--num', type=int, default=5, help='number of outputs per image')
67 |     self.parser.add_argument('--name', type=str, default='trial', help='folder name to save outputs')
68 |     self.parser.add_argument('--direction', type=str, default='AtoB', help='transfer direction')
69 |     self.parser.add_argument('--result_dir', type=str, default='../outputs', help='path for saving result images and models')
70 | 
71 |     # model related
72 |     self.parser.add_argument('--concat', type=int, default=1, help='concatenate attribute features for translation, set 0 for using feature-wise transform')
73 |     self.parser.add_argument('--no_ms', action='store_true', help='disable mode seeking regularization')
74 |     self.parser.add_argument('--resume', type=str, required=True, help='path of the saved model to load')
75 |     self.parser.add_argument('--gpu', type=int, default=0, help='gpu')
76 | 
77 |   def parse(self):
78 |     self.opt = self.parser.parse_args()
79 |     args = vars(self.opt)
80 |     print('\n--- load options ---')
81 |     for name, value in sorted(args.items()):
82 |       print('%s: %s' % (str(name), str(value)))
83 |     # set irrelevant options
84 |     self.opt.dis_scale = 3
85 |     self.opt.dis_norm = 'None'
86 |     self.opt.dis_spectral_norm = False
87 |     return self.opt
88 | 
--------------------------------------------------------------------------------
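A hypothetical smoke test of the CLI surface defined above (only `--dataroot` is required for training; the asserted values are the defaults from `TrainOptions`, and the dataset path is a placeholder):

```python
import sys
from options import TrainOptions

sys.argv = ['train.py', '--dataroot', './datasets/retina']  # placeholder path
opts = TrainOptions().parse()  # prints the '--- load options ---' table
assert opts.batch_size == 2 and opts.n_ep == 1200 and opts.d_iter == 3
```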
/model.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import networks
3 | import torch
4 | import torch.nn as nn
5 | import segmentation_models_pytorch as smp
6 | import numpy as np
7 | import cv2
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class DCDA(nn.Module):
12 |   def __init__(self, opts):
13 |     super(DCDA, self).__init__()
14 | 
15 |     # parameters
16 |     lr = 0.0001
17 |     lr_dcontent = lr / 2.5
18 |     self.nz = 8
19 |     self.concat = opts.concat
20 |     self.no_ms = opts.no_ms
21 | 
22 |     # discriminators
23 |     if opts.dis_scale > 1:
24 |       self.disA = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
25 |       self.disB = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
26 |       self.disA2 = networks.MultiScaleDis(opts.input_dim_a, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
27 |       self.disB2 = networks.MultiScaleDis(opts.input_dim_b, opts.dis_scale, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
28 |     else:
29 |       self.disA = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
30 |       self.disB = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
31 |       self.disA2 = networks.Dis(opts.input_dim_a, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
32 |       self.disB2 = networks.Dis(opts.input_dim_b, norm=opts.dis_norm, sn=opts.dis_spectral_norm)
33 |     self.disContent = networks.Dis_content()
34 | 
35 |     # encoders
36 |     self.enc_c = networks.E_content(opts.input_dim_a, opts.input_dim_b)
37 |     if self.concat:
38 |       self.enc_a = networks.E_attr_concat(opts.input_dim_a, opts.input_dim_b, self.nz, \
39 |         norm_layer=None, nl_layer=networks.get_non_linearity(layer_type='lrelu'))
40 |     else:
41 |       self.enc_a = networks.E_attr(opts.input_dim_a, opts.input_dim_b, self.nz)
42 | 
43 |     # generator
44 |     if self.concat:
45 |       self.gen = networks.G_concat(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
46 |     else:
47 |       self.gen = networks.G(opts.input_dim_a, opts.input_dim_b, nz=self.nz)
48 | 
49 |     # optimizers
50 |     self.disA_opt = torch.optim.Adam(self.disA.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
51 |     self.disB_opt = torch.optim.Adam(self.disB.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
52 |     self.disA2_opt = torch.optim.Adam(self.disA2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
53 |     self.disB2_opt = torch.optim.Adam(self.disB2.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001)
54 |     self.disContent_opt = torch.optim.Adam(self.disContent.parameters(), lr=lr_dcontent, betas=(0.5, 0.999), weight_decay=0.0001)
55 |     self.enc_c_opt = 
torch.optim.Adam(self.enc_c.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001) 56 | self.enc_a_opt = torch.optim.Adam(self.enc_a.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001) 57 | self.gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=0.0001) 58 | 59 | #segmentation model 60 | self.src_seg_model = smp.Unet(encoder_name="resnet34", encoder_depth=5, encoder_weights='imagenet', decoder_channels=[256, 128, 64, 32, 16], in_channels=1, classes=1, activation='sigmoid') 61 | self.src_seg_optimizer = torch.optim.Adam([ 62 | dict(params=self.src_seg_model.parameters(), lr=1e-4) 63 | ]) 64 | self.tar_seg_model = smp.Unet(encoder_name="resnet34", encoder_depth=5, encoder_weights='imagenet', decoder_channels=[256, 128, 64, 32, 16], in_channels=1, classes=1, activation='sigmoid') 65 | self.tar_seg_optimizer = torch.optim.Adam([ 66 | dict(params=self.tar_seg_model.parameters(), lr=1e-4) 67 | ]) 68 | 69 | # Setup the loss function for training 70 | self.criterionL1 = torch.nn.L1Loss() 71 | self.dice_loss = smp.losses.DiceLoss(mode='binary', from_logits=False) 72 | self.ce_loss = nn.BCELoss() 73 | 74 | def initialize(self): 75 | self.disA.apply(networks.gaussian_weights_init) 76 | self.disB.apply(networks.gaussian_weights_init) 77 | self.disA2.apply(networks.gaussian_weights_init) 78 | self.disB2.apply(networks.gaussian_weights_init) 79 | self.disContent.apply(networks.gaussian_weights_init) 80 | self.gen.apply(networks.gaussian_weights_init) 81 | self.enc_c.apply(networks.gaussian_weights_init) 82 | self.enc_a.apply(networks.gaussian_weights_init) 83 | 84 | def set_scheduler(self, opts, last_ep=0): 85 | self.disA_sch = networks.get_scheduler(self.disA_opt, opts, last_ep) 86 | self.disB_sch = networks.get_scheduler(self.disB_opt, opts, last_ep) 87 | self.disA2_sch = networks.get_scheduler(self.disA2_opt, opts, last_ep) 88 | self.disB2_sch = networks.get_scheduler(self.disB2_opt, opts, last_ep) 89 | self.disContent_sch = networks.get_scheduler(self.disContent_opt, opts, last_ep) 90 | self.enc_c_sch = networks.get_scheduler(self.enc_c_opt, opts, last_ep) 91 | self.enc_a_sch = networks.get_scheduler(self.enc_a_opt, opts, last_ep) 92 | self.gen_sch = networks.get_scheduler(self.gen_opt, opts, last_ep) 93 | 94 | def setgpu(self, gpu): 95 | self.gpu = gpu 96 | self.disA.cuda(self.gpu) 97 | self.disB.cuda(self.gpu) 98 | self.disA2.cuda(self.gpu) 99 | self.disB2.cuda(self.gpu) 100 | self.disContent.cuda(self.gpu) 101 | self.enc_c.cuda(self.gpu) 102 | self.enc_a.cuda(self.gpu) 103 | self.gen.cuda(self.gpu) 104 | self.src_seg_model.cuda(self.gpu) 105 | self.tar_seg_model.cuda(self.gpu) 106 | 107 | def get_z_random(self, batchSize, nz, random_type='gauss'): 108 | z = torch.randn(batchSize, nz).cuda(self.gpu) 109 | return z 110 | 111 | def test_forward(self, image, a2b=True): 112 | self.z_random = self.get_z_random(image.size(0), self.nz, 'gauss') 113 | if a2b: 114 | self.z_content = self.enc_c.forward_a(image) 115 | output = self.gen.forward_b(self.z_content, self.z_random) 116 | else: 117 | self.z_content = self.enc_c.forward_b(image) 118 | output = self.gen.forward_a(self.z_content, self.z_random) 119 | return output 120 | 121 | def test_forward_transfer(self, image_a, image_b, a2b=True): 122 | self.z_content_a, self.z_content_b = self.enc_c.forward(image_a, image_b) 123 | if self.concat: 124 | self.mu_a, self.logvar_a, self.mu_b, self.logvar_b = self.enc_a.forward(image_a, image_b) 125 | std_a = self.logvar_a.mul(0.5).exp_() 126 | eps = 
self.get_z_random(std_a.size(0), std_a.size(1), 'gauss') 127 | self.z_attr_a = eps.mul(std_a).add_(self.mu_a) 128 | std_b = self.logvar_b.mul(0.5).exp_() 129 | eps = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss') 130 | self.z_attr_b = eps.mul(std_b).add_(self.mu_b) 131 | else: 132 | self.z_attr_a, self.z_attr_b = self.enc_a.forward(image_a, image_b) 133 | if a2b: 134 | output = self.gen.forward_b(self.z_content_a, self.z_attr_b) 135 | else: 136 | output = self.gen.forward_a(self.z_content_b, self.z_attr_a) 137 | return output 138 | 139 | def forward(self, half_size=1): 140 | # input images 141 | half_size = half_size 142 | real_A = self.input_A 143 | real_B = self.input_B 144 | self.real_A_encoded = real_A[0:half_size] 145 | self.real_A_random = real_A[half_size:] 146 | self.real_B_encoded = real_B[0:half_size] 147 | self.real_B_random = real_B[half_size:] 148 | 149 | # get encoded z_c 150 | self.z_content_a, self.z_content_b = self.enc_c.forward(self.real_A_encoded, self.real_B_encoded) 151 | 152 | # get encoded z_a 153 | if self.concat: 154 | self.mu_a, self.logvar_a, self.mu_b, self.logvar_b = self.enc_a.forward(self.real_A_encoded, self.real_B_encoded) 155 | std_a = self.logvar_a.mul(0.5).exp_() 156 | eps_a = self.get_z_random(std_a.size(0), std_a.size(1), 'gauss') 157 | self.z_attr_a = eps_a.mul(std_a).add_(self.mu_a) 158 | std_b = self.logvar_b.mul(0.5).exp_() 159 | eps_b = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss') 160 | self.z_attr_b = eps_b.mul(std_b).add_(self.mu_b) 161 | else: 162 | self.z_attr_a, self.z_attr_b = self.enc_a.forward(self.real_A_encoded, self.real_B_encoded) 163 | 164 | # get random z_a 165 | self.z_random = self.get_z_random(self.real_A_encoded.size(0), self.nz, 'gauss') 166 | if not self.no_ms: 167 | self.z_random2 = self.get_z_random(self.real_A_encoded.size(0), self.nz, 'gauss') 168 | 169 | # first cross translation 170 | if not self.no_ms: 171 | input_content_forA = torch.cat((self.z_content_b, self.z_content_a, self.z_content_b, self.z_content_b),0) 172 | input_content_forB = torch.cat((self.z_content_a, self.z_content_b, self.z_content_a, self.z_content_a),0) 173 | input_attr_forA = torch.cat((self.z_attr_a, self.z_attr_a, self.z_random, self.z_random2),0) 174 | input_attr_forB = torch.cat((self.z_attr_b, self.z_attr_b, self.z_random, self.z_random2),0) 175 | output_fakeA = self.gen.forward_a(input_content_forA, input_attr_forA) 176 | output_fakeB = self.gen.forward_b(input_content_forB, input_attr_forB) 177 | self.fake_A_encoded, self.fake_AA_encoded, self.fake_A_random, self.fake_A_random2 = torch.split(output_fakeA, self.z_content_a.size(0), dim=0) 178 | self.fake_B_encoded, self.fake_BB_encoded, self.fake_B_random, self.fake_B_random2 = torch.split(output_fakeB, self.z_content_a.size(0), dim=0) 179 | else: 180 | input_content_forA = torch.cat((self.z_content_b, self.z_content_a, self.z_content_b),0) 181 | input_content_forB = torch.cat((self.z_content_a, self.z_content_b, self.z_content_a),0) 182 | input_attr_forA = torch.cat((self.z_attr_a, self.z_attr_a, self.z_random),0) 183 | input_attr_forB = torch.cat((self.z_attr_b, self.z_attr_b, self.z_random),0) 184 | output_fakeA = self.gen.forward_a(input_content_forA, input_attr_forA) 185 | output_fakeB = self.gen.forward_b(input_content_forB, input_attr_forB) 186 | self.fake_A_encoded, self.fake_AA_encoded, self.fake_A_random = torch.split(output_fakeA, self.z_content_a.size(0), dim=0) 187 | self.fake_B_encoded, self.fake_BB_encoded, self.fake_B_random = torch.split(output_fakeB, 
self.z_content_a.size(0), dim=0) 188 | 189 | # get reconstructed encoded z_c 190 | self.z_content_recon_b, self.z_content_recon_a = self.enc_c.forward(self.fake_A_encoded, self.fake_B_encoded) 191 | 192 | # get reconstructed encoded z_a 193 | if self.concat: 194 | self.mu_recon_a, self.logvar_recon_a, self.mu_recon_b, self.logvar_recon_b = self.enc_a.forward(self.fake_A_encoded, self.fake_B_encoded) 195 | std_a = self.logvar_recon_a.mul(0.5).exp_() 196 | eps_a = self.get_z_random(std_a.size(0), std_a.size(1), 'gauss') 197 | self.z_attr_recon_a = eps_a.mul(std_a).add_(self.mu_recon_a) 198 | std_b = self.logvar_recon_b.mul(0.5).exp_() 199 | eps_b = self.get_z_random(std_b.size(0), std_b.size(1), 'gauss') 200 | self.z_attr_recon_b = eps_b.mul(std_b).add_(self.mu_recon_b) 201 | else: 202 | self.z_attr_recon_a, self.z_attr_recon_b = self.enc_a.forward(self.fake_A_encoded, self.fake_B_encoded) 203 | 204 | # second cross translation 205 | self.fake_A_recon = self.gen.forward_a(self.z_content_recon_a, self.z_attr_recon_a) 206 | self.fake_B_recon = self.gen.forward_b(self.z_content_recon_b, self.z_attr_recon_b) 207 | 208 | # for display 209 | self.image_display = torch.cat((self.real_A_encoded[0:1].detach().cpu(), self.fake_B_encoded[0:1].detach().cpu(), \ 210 | self.fake_B_random[0:1].detach().cpu(), self.fake_AA_encoded[0:1].detach().cpu(), self.fake_A_recon[0:1].detach().cpu(), \ 211 | self.real_B_encoded[0:1].detach().cpu(), self.fake_A_encoded[0:1].detach().cpu(), \ 212 | self.fake_A_random[0:1].detach().cpu(), self.fake_BB_encoded[0:1].detach().cpu(), self.fake_B_recon[0:1].detach().cpu()), dim=0) 213 | 214 | # for latent regression 215 | if self.concat: 216 | self.mu2_a, _, self.mu2_b, _ = self.enc_a.forward(self.fake_A_random, self.fake_B_random) 217 | else: 218 | self.z_attr_random_a, self.z_attr_random_b = self.enc_a.forward(self.fake_A_random, self.fake_B_random) 219 | 220 | def forward_content(self): 221 | half_size = 1 222 | self.real_A_encoded = self.input_A[0:half_size] 223 | self.real_B_encoded = self.input_B[0:half_size] 224 | # get encoded z_c 225 | self.z_content_a, self.z_content_b = self.enc_c.forward(self.real_A_encoded, self.real_B_encoded) 226 | 227 | 228 | def update_dual_seg(self, image_a, image_b, gt, src_seg, pretrained=False): 229 | if pretrained == True: 230 | self.src_seg_model = src_seg 231 | self.src_seg_model.train() 232 | self.tar_seg_model.train() 233 | torch.set_grad_enabled(True) 234 | self.src_seg_optimizer.zero_grad() 235 | self.tar_seg_optimizer.zero_grad() 236 | gt = gt[0:1] 237 | gt = gt.to(self.gpu) 238 | self.real_A_encoded = self.real_A_encoded.to(self.gpu) 239 | self.real_B_encoded = self.real_B_encoded.to(self.gpu) 240 | self.fake_B_encoded = self.fake_B_encoded.to(self.gpu) 241 | self.fake_A_encoded = self.fake_A_encoded.to(self.gpu) 242 | self.fake_A_recon = self.fake_A_recon.to(self.gpu) 243 | self.fake_B_recon = self.fake_B_recon.to(self.gpu) 244 | 245 | outa_real = self.src_seg_model(self.real_A_encoded) 246 | outb_real = self.tar_seg_model(self.real_B_encoded) 247 | outa2b = self.tar_seg_model(self.fake_B_encoded) 248 | outb2a = self.src_seg_model(self.fake_A_encoded) 249 | 250 | dice_a_real = self.dice_loss(outa_real, gt) 251 | dice_a2b = self.dice_loss(outa2b, gt) 252 | bce_a = self.ce_loss(outa2b, torch.round(outa_real).detach()) 253 | bce_b = self.ce_loss(outb_real, torch.round(outb2a).detach()) 254 | all_loss = dice_a_real + dice_a2b + bce_a + bce_b 255 | all_loss.backward() 256 | self.src_seg_optimizer.step() 257 | 
self.tar_seg_optimizer.step() 258 | 259 | def update_D_content(self, image_a, image_b): 260 | self.input_A = image_a 261 | self.input_B = image_b 262 | self.forward_content() 263 | self.disContent_opt.zero_grad() 264 | loss_D_Content = self.backward_contentD(self.z_content_a, self.z_content_b) 265 | self.disContent_loss = loss_D_Content.item() 266 | nn.utils.clip_grad_norm_(self.disContent.parameters(), 5) 267 | self.disContent_opt.step() 268 | 269 | def update_D(self, image_a, image_b): 270 | self.input_A = image_a 271 | self.input_B = image_b 272 | self.forward() 273 | 274 | # update disA 275 | self.disA_opt.zero_grad() 276 | loss_D1_A = self.backward_D(self.disA, self.real_A_encoded, self.fake_A_encoded) 277 | self.disA_loss = loss_D1_A.item() 278 | self.disA_opt.step() 279 | 280 | # update disA2 281 | self.disA2_opt.zero_grad() 282 | loss_D2_A = self.backward_D(self.disA2, self.real_A_random, self.fake_A_random) 283 | self.disA2_loss = loss_D2_A.item() 284 | if not self.no_ms: 285 | loss_D2_A2 = self.backward_D(self.disA2, self.real_A_random, self.fake_A_random2) 286 | self.disA2_loss += loss_D2_A2.item() 287 | self.disA2_opt.step() 288 | 289 | # update disB 290 | self.disB_opt.zero_grad() 291 | loss_D1_B = self.backward_D(self.disB, self.real_B_encoded, self.fake_B_encoded) 292 | self.disB_loss = loss_D1_B.item() 293 | self.disB_opt.step() 294 | 295 | # update disB2 296 | self.disB2_opt.zero_grad() 297 | loss_D2_B = self.backward_D(self.disB2, self.real_B_random, self.fake_B_random) 298 | self.disB2_loss = loss_D2_B.item() 299 | if not self.no_ms: 300 | loss_D2_B2 = self.backward_D(self.disB2, self.real_B_random, self.fake_B_random2) 301 | self.disB2_loss += loss_D2_B2.item() 302 | self.disB2_opt.step() 303 | 304 | # update disContent 305 | self.disContent_opt.zero_grad() 306 | loss_D_Content = self.backward_contentD(self.z_content_a, self.z_content_b) 307 | self.disContent_loss = loss_D_Content.item() 308 | nn.utils.clip_grad_norm_(self.disContent.parameters(), 5) 309 | self.disContent_opt.step() 310 | 311 | def backward_D(self, netD, real, fake): 312 | pred_fake = netD.forward(fake.detach()) 313 | pred_real = netD.forward(real) 314 | loss_D = 0 315 | for it, (out_a, out_b) in enumerate(zip(pred_fake, pred_real)): 316 | out_fake = nn.functional.sigmoid(out_a) 317 | out_real = nn.functional.sigmoid(out_b) 318 | all0 = torch.zeros_like(out_fake).cuda(self.gpu) 319 | all1 = torch.ones_like(out_real).cuda(self.gpu) 320 | ad_fake_loss = nn.functional.binary_cross_entropy(out_fake, all0) 321 | ad_true_loss = nn.functional.binary_cross_entropy(out_real, all1) 322 | loss_D += ad_true_loss + ad_fake_loss 323 | loss_D.backward() 324 | return loss_D 325 | 326 | def backward_contentD(self, imageA, imageB): 327 | pred_fake = self.disContent.forward(imageA.detach()) 328 | pred_real = self.disContent.forward(imageB.detach()) 329 | for it, (out_a, out_b) in enumerate(zip(pred_fake, pred_real)): 330 | out_fake = nn.functional.sigmoid(out_a) 331 | out_real = nn.functional.sigmoid(out_b) 332 | all1 = torch.ones((out_real.size(0))).cuda(self.gpu) 333 | all0 = torch.zeros((out_fake.size(0))).cuda(self.gpu) 334 | ad_true_loss = nn.functional.binary_cross_entropy(out_real, all1) 335 | ad_fake_loss = nn.functional.binary_cross_entropy(out_fake, all0) 336 | loss_D = ad_true_loss + ad_fake_loss 337 | loss_D.backward() 338 | return loss_D 339 | 340 | def update_EG(self): 341 | # update G, Ec, Ea 342 | self.enc_c_opt.zero_grad() 343 | self.enc_a_opt.zero_grad() 344 | self.gen_opt.zero_grad() 345 | 
self.backward_EG() 346 | self.enc_c_opt.step() 347 | self.enc_a_opt.step() 348 | self.gen_opt.step() 349 | 350 | # update G, Ec 351 | self.enc_c_opt.zero_grad() 352 | self.gen_opt.zero_grad() 353 | self.backward_G_alone() 354 | self.enc_c_opt.step() 355 | self.gen_opt.step() 356 | 357 | def backward_EG(self): 358 | # content Ladv for generator 359 | loss_G_GAN_Acontent = self.backward_G_GAN_content(self.z_content_a) 360 | loss_G_GAN_Bcontent = self.backward_G_GAN_content(self.z_content_b) 361 | 362 | # Ladv for generator 363 | loss_G_GAN_A = self.backward_G_GAN(self.fake_A_encoded, self.disA) 364 | loss_G_GAN_B = self.backward_G_GAN(self.fake_B_encoded, self.disB) 365 | 366 | # KL loss - z_a 367 | if self.concat: 368 | kl_element_a = self.mu_a.pow(2).add_(self.logvar_a.exp()).mul_(-1).add_(1).add_(self.logvar_a) 369 | loss_kl_za_a = torch.sum(kl_element_a).mul_(-0.5) * 0.01 370 | kl_element_b = self.mu_b.pow(2).add_(self.logvar_b.exp()).mul_(-1).add_(1).add_(self.logvar_b) 371 | loss_kl_za_b = torch.sum(kl_element_b).mul_(-0.5) * 0.01 372 | else: 373 | loss_kl_za_a = self._l2_regularize(self.z_attr_a) * 0.01 374 | loss_kl_za_b = self._l2_regularize(self.z_attr_b) * 0.01 375 | 376 | # KL loss - z_c 377 | loss_kl_zc_a = self._l2_regularize(self.z_content_a) * 0.01 378 | loss_kl_zc_b = self._l2_regularize(self.z_content_b) * 0.01 379 | 380 | # cross cycle consistency loss 381 | loss_G_L1_A = self.criterionL1(self.fake_A_recon, self.real_A_encoded) * 10 382 | loss_G_L1_B = self.criterionL1(self.fake_B_recon, self.real_B_encoded) * 10 383 | loss_G_L1_AA = self.criterionL1(self.fake_AA_encoded, self.real_A_encoded) * 10 384 | loss_G_L1_BB = self.criterionL1(self.fake_BB_encoded, self.real_B_encoded) * 10 385 | 386 | loss_G = loss_G_GAN_A + loss_G_GAN_B + \ 387 | loss_G_GAN_Acontent + loss_G_GAN_Bcontent + \ 388 | loss_G_L1_AA + loss_G_L1_BB + \ 389 | loss_G_L1_A + loss_G_L1_B + \ 390 | loss_kl_zc_a + loss_kl_zc_b + \ 391 | loss_kl_za_a + loss_kl_za_b 392 | 393 | loss_G.backward(retain_graph=True) 394 | 395 | self.gan_loss_a = loss_G_GAN_A.item() 396 | self.gan_loss_b = loss_G_GAN_B.item() 397 | self.gan_loss_acontent = loss_G_GAN_Acontent.item() 398 | self.gan_loss_bcontent = loss_G_GAN_Bcontent.item() 399 | self.kl_loss_za_a = loss_kl_za_a.item() 400 | self.kl_loss_za_b = loss_kl_za_b.item() 401 | self.kl_loss_zc_a = loss_kl_zc_a.item() 402 | self.kl_loss_zc_b = loss_kl_zc_b.item() 403 | self.l1_recon_A_loss = loss_G_L1_A.item() 404 | self.l1_recon_B_loss = loss_G_L1_B.item() 405 | self.l1_recon_AA_loss = loss_G_L1_AA.item() 406 | self.l1_recon_BB_loss = loss_G_L1_BB.item() 407 | self.G_loss = loss_G.item() 408 | 409 | def backward_G_GAN_content(self, data): 410 | outs = self.disContent.forward(data) 411 | for out in outs: 412 | outputs_fake = nn.functional.sigmoid(out) 413 | all_half = 0.5*torch.ones((outputs_fake.size(0))).cuda(self.gpu) 414 | ad_loss = nn.functional.binary_cross_entropy(outputs_fake, all_half) 415 | return ad_loss 416 | 417 | def backward_G_GAN(self, fake, netD=None): 418 | outs_fake = netD.forward(fake) 419 | loss_G = 0 420 | for out_a in outs_fake: 421 | outputs_fake = nn.functional.sigmoid(out_a) 422 | all_ones = torch.ones_like(outputs_fake).cuda(self.gpu) 423 | loss_G += nn.functional.binary_cross_entropy(outputs_fake, all_ones) 424 | return loss_G 425 | 426 | def backward_G_alone(self): 427 | # Ladv for generator 428 | loss_G_GAN2_A = self.backward_G_GAN(self.fake_A_random, self.disA2) 429 | loss_G_GAN2_B = self.backward_G_GAN(self.fake_B_random, self.disB2) 430 | if 
not self.no_ms: 431 | loss_G_GAN2_A2 = self.backward_G_GAN(self.fake_A_random2, self.disA2) 432 | loss_G_GAN2_B2 = self.backward_G_GAN(self.fake_B_random2, self.disB2) 433 | 434 | # mode seeking loss for A-->B and B-->A 435 | if not self.no_ms: 436 | lz_AB = torch.mean(torch.abs(self.fake_B_random2 - self.fake_B_random)) / torch.mean(torch.abs(self.z_random2 - self.z_random)) 437 | lz_BA = torch.mean(torch.abs(self.fake_A_random2 - self.fake_A_random)) / torch.mean(torch.abs(self.z_random2 - self.z_random)) 438 | eps = 1 * 1e-5 439 | loss_lz_AB = 1 / (lz_AB + eps) 440 | loss_lz_BA = 1 / (lz_BA + eps) 441 | # latent regression loss 442 | if self.concat: 443 | loss_z_L1_a = torch.mean(torch.abs(self.mu2_a - self.z_random)) * 10 444 | loss_z_L1_b = torch.mean(torch.abs(self.mu2_b - self.z_random)) * 10 445 | else: 446 | loss_z_L1_a = torch.mean(torch.abs(self.z_attr_random_a - self.z_random)) * 10 447 | loss_z_L1_b = torch.mean(torch.abs(self.z_attr_random_b - self.z_random)) * 10 448 | 449 | loss_z_L1 = loss_z_L1_a + loss_z_L1_b + loss_G_GAN2_A + loss_G_GAN2_B 450 | if not self.no_ms: 451 | loss_z_L1 += (loss_G_GAN2_A2 + loss_G_GAN2_B2) 452 | loss_z_L1 += (loss_lz_AB + loss_lz_BA) 453 | loss_z_L1.backward(retain_graph=True) 454 | self.l1_recon_z_loss_a = loss_z_L1_a.item() 455 | self.l1_recon_z_loss_b = loss_z_L1_b.item() 456 | if not self.no_ms: 457 | self.gan2_loss_a = loss_G_GAN2_A.item() + loss_G_GAN2_A2.item() 458 | self.gan2_loss_b = loss_G_GAN2_B.item() + loss_G_GAN2_B2.item() 459 | self.lz_AB = loss_lz_AB.item() 460 | self.lz_BA = loss_lz_BA.item() 461 | else: 462 | self.gan2_loss_a = loss_G_GAN2_A.item() 463 | self.gan2_loss_b = loss_G_GAN2_B.item() 464 | def update_lr(self): 465 | self.disA_sch.step() 466 | self.disB_sch.step() 467 | self.disA2_sch.step() 468 | self.disB2_sch.step() 469 | self.disContent_sch.step() 470 | self.enc_c_sch.step() 471 | self.enc_a_sch.step() 472 | self.gen_sch.step() 473 | 474 | def _l2_regularize(self, mu): 475 | mu_2 = torch.pow(mu, 2) 476 | encoding_loss = torch.mean(mu_2) 477 | return encoding_loss 478 | 479 | def resume(self, model_dir, train=True): 480 | checkpoint = torch.load(model_dir) 481 | # weight 482 | if train: 483 | self.disA.load_state_dict(checkpoint['disA']) 484 | self.disA2.load_state_dict(checkpoint['disA2']) 485 | self.disB.load_state_dict(checkpoint['disB']) 486 | self.disB2.load_state_dict(checkpoint['disB2']) 487 | self.disContent.load_state_dict(checkpoint['disContent']) 488 | self.enc_c.load_state_dict(checkpoint['enc_c']) 489 | self.enc_a.load_state_dict(checkpoint['enc_a']) 490 | self.gen.load_state_dict(checkpoint['gen']) 491 | self.src_seg_model.load_state_dict(checkpoint['src_seg']) 492 | self.tar_seg_model.load_state_dict(checkpoint['tar_seg']) 493 | # optimizer 494 | if train: 495 | self.disA_opt.load_state_dict(checkpoint['disA_opt']) 496 | self.disA2_opt.load_state_dict(checkpoint['disA2_opt']) 497 | self.disB_opt.load_state_dict(checkpoint['disB_opt']) 498 | self.disB2_opt.load_state_dict(checkpoint['disB2_opt']) 499 | self.disContent_opt.load_state_dict(checkpoint['disContent_opt']) 500 | self.enc_c_opt.load_state_dict(checkpoint['enc_c_opt']) 501 | self.enc_a_opt.load_state_dict(checkpoint['enc_a_opt']) 502 | self.gen_opt.load_state_dict(checkpoint['gen_opt']) 503 | return checkpoint['ep'], checkpoint['total_it'] 504 | 505 | def save(self, filename, ep, total_it): 506 | state = { 507 | 'disA': self.disA.state_dict(), 508 | 'disA2': self.disA2.state_dict(), 509 | 'disB': self.disB.state_dict(), 510 | 'disB2': 
self.disB2.state_dict(),
511 |       'disContent': self.disContent.state_dict(),
512 |       'enc_c': self.enc_c.state_dict(),
513 |       'enc_a': self.enc_a.state_dict(),
514 |       'gen': self.gen.state_dict(),
515 |       'disA_opt': self.disA_opt.state_dict(),
516 |       'disA2_opt': self.disA2_opt.state_dict(),
517 |       'disB_opt': self.disB_opt.state_dict(),
518 |       'disB2_opt': self.disB2_opt.state_dict(),
519 |       'disContent_opt': self.disContent_opt.state_dict(),
520 |       'enc_c_opt': self.enc_c_opt.state_dict(),
521 |       'enc_a_opt': self.enc_a_opt.state_dict(),
522 |       'gen_opt': self.gen_opt.state_dict(),
523 |       'src_seg': self.src_seg_model.state_dict(),
524 |       'tar_seg': self.tar_seg_model.state_dict(),
525 |       'ep': ep,
526 |       'total_it': total_it
527 |     }
528 |     torch.save(state, filename)
529 |     return
530 | 
531 |   def assemble_outputs(self):
532 |     images_a = self.normalize_image(self.real_A_encoded).detach()
533 |     images_b = self.normalize_image(self.real_B_encoded).detach()
534 |     images_a1 = self.normalize_image(self.fake_A_encoded).detach()
535 |     images_a2 = self.normalize_image(self.fake_A_random).detach()
536 |     images_a3 = self.normalize_image(self.fake_A_recon).detach()
537 |     images_a4 = self.normalize_image(self.fake_AA_encoded).detach()
538 |     images_b1 = self.normalize_image(self.fake_B_encoded).detach()
539 |     images_b2 = self.normalize_image(self.fake_B_random).detach()
540 |     images_b3 = self.normalize_image(self.fake_B_recon).detach()
541 |     images_b4 = self.normalize_image(self.fake_BB_encoded).detach()
542 |     row1 = torch.cat((images_a[0:1, ::], images_b1[0:1, ::], images_b2[0:1, ::], images_a4[0:1, ::], images_a3[0:1, ::]), 3)
543 |     row2 = torch.cat((images_b[0:1, ::], images_a1[0:1, ::], images_a2[0:1, ::], images_b4[0:1, ::], images_b3[0:1, ::]), 3)
544 |     return torch.cat((row1, row2), 2)
545 | 
546 |   def normalize_image(self, x):
547 |     return x[:, 0:3, :, :]
548 | 
--------------------------------------------------------------------------------

/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import functools
5 | from torch.optim import lr_scheduler
6 | import torch.nn.functional as F
7 | from torch.nn.utils import spectral_norm  # used by the discriminators below; missing from the original imports
8 | ####################################################################
9 | #------------------------- Discriminators --------------------------
10 | ####################################################################
11 | class Dis_content(nn.Module):
12 |   def __init__(self):
13 |     super(Dis_content, self).__init__()
14 |     model = []
15 |     model += [LeakyReLUConv2d(256, 256, kernel_size=7, stride=2, padding=1, norm='Instance')]
16 |     model += [LeakyReLUConv2d(256, 256, kernel_size=7, stride=2, padding=1, norm='Instance')]
17 |     model += [LeakyReLUConv2d(256, 256, kernel_size=7, stride=2, padding=1, norm='Instance')]
18 |     model += [LeakyReLUConv2d(256, 256, kernel_size=4, stride=1, padding=0)]
19 |     model += [nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0)]
20 |     self.model = nn.Sequential(*model)
21 | 
22 |   def forward(self, x):
23 |     out = self.model(x)
24 |     out = out.view(-1)
25 |     outs = []
26 |     outs.append(out)
27 |     return outs
28 | 
29 | class MultiScaleDis(nn.Module):
30 |   def __init__(self, input_dim, n_scale=3, n_layer=4, norm='None', sn=False):
31 |     super(MultiScaleDis, self).__init__()
32 |     ch = 64
33 |     self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
34 |     self.Diss = nn.ModuleList()
35 |     for _ in range(n_scale):
36 |       self.Diss.append(self._make_net(ch, input_dim, 
n_layer, norm, sn)) 37 | 38 | def _make_net(self, ch, input_dim, n_layer, norm, sn): 39 | model = [] 40 | model += [LeakyReLUConv2d(input_dim, ch, 4, 2, 1, norm, sn)] 41 | tch = ch 42 | for _ in range(1, n_layer): 43 | model += [LeakyReLUConv2d(tch, tch * 2, 4, 2, 1, norm, sn)] 44 | tch *= 2 45 | if sn: 46 | model += [spectral_norm(nn.Conv2d(tch, 1, 1, 1, 0))] 47 | else: 48 | model += [nn.Conv2d(tch, 1, 1, 1, 0)] 49 | return nn.Sequential(*model) 50 | 51 | def forward(self, x): 52 | outs = [] 53 | for Dis in self.Diss: 54 | outs.append(Dis(x)) 55 | x = self.downsample(x) 56 | return outs 57 | 58 | class Dis(nn.Module): 59 | def __init__(self, input_dim, norm='None', sn=False): 60 | super(Dis, self).__init__() 61 | ch = 64 62 | n_layer = 6 63 | self.model = self._make_net(ch, input_dim, n_layer, norm, sn) 64 | 65 | def _make_net(self, ch, input_dim, n_layer, norm, sn): 66 | model = [] 67 | model += [LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)] #16 68 | tch = ch 69 | for i in range(1, n_layer-1): 70 | model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)] # 8 71 | tch *= 2 72 | model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm='None', sn=sn)] # 2 73 | tch *= 2 74 | if sn: 75 | model += [spectral_norm(nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0))] # 1 76 | else: 77 | model += [nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0)] # 1 78 | return nn.Sequential(*model) 79 | 80 | def cuda(self,gpu): 81 | self.model.cuda(gpu) 82 | 83 | def forward(self, x_A): 84 | out_A = self.model(x_A) 85 | out_A = out_A.view(-1) 86 | outs_A = [] 87 | outs_A.append(out_A) 88 | return outs_A 89 | 90 | #################################################################### 91 | #---------------------------- Encoders ----------------------------- 92 | #################################################################### 93 | class E_content(nn.Module): 94 | def __init__(self, input_dim_a, input_dim_b): 95 | super(E_content, self).__init__() 96 | encA_c = [] 97 | tch = 64 98 | encA_c += [LeakyReLUConv2d(input_dim_a, tch, kernel_size=7, stride=1, padding=3)] 99 | for i in range(1, 3): 100 | encA_c += [ReLUINSConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)] 101 | tch *= 2 102 | for i in range(0, 3): 103 | encA_c += [INSResBlock(tch, tch)] 104 | 105 | encB_c = [] 106 | tch = 64 107 | encB_c += [LeakyReLUConv2d(input_dim_b, tch, kernel_size=7, stride=1, padding=3)] 108 | for i in range(1, 3): 109 | encB_c += [ReLUINSConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1)] 110 | tch *= 2 111 | for i in range(0, 3): 112 | encB_c += [INSResBlock(tch, tch)] 113 | 114 | enc_share = [] 115 | for i in range(0, 1): 116 | enc_share += [INSResBlock(tch, tch)] 117 | enc_share += [GaussianNoiseLayer()] 118 | self.conv_share = nn.Sequential(*enc_share) 119 | 120 | self.convA = nn.Sequential(*encA_c) 121 | self.convB = nn.Sequential(*encB_c) 122 | 123 | def forward(self, xa, xb): 124 | outputA = self.convA(xa) 125 | outputB = self.convB(xb) 126 | outputA = self.conv_share(outputA) 127 | outputB = self.conv_share(outputB) 128 | return outputA, outputB 129 | 130 | def forward_a(self, xa): 131 | outputA = self.convA(xa) 132 | outputA = self.conv_share(outputA) 133 | return outputA 134 | 135 | def forward_b(self, xb): 136 | outputB = self.convB(xb) 137 | outputB = self.conv_share(outputB) 138 | return outputB 139 | 140 | class E_attr(nn.Module): 141 | def __init__(self, input_dim_a, input_dim_b, output_nc=8): 
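# E_attr maps an image to a flat style (attribute) code of length output_nc
# (8 by default). AdaptiveAvgPool2d(1) collapses the spatial dimensions, so the
# encoder accepts arbitrary input resolutions; forward/forward_a/forward_b
# below return tensors of shape (batch, output_nc).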
142 | super(E_attr, self).__init__() 143 | dim = 64 144 | self.model_a = nn.Sequential( 145 | nn.ReflectionPad2d(3), 146 | nn.Conv2d(input_dim_a, dim, 7, 1), 147 | nn.ReLU(inplace=True), 148 | nn.ReflectionPad2d(1), 149 | nn.Conv2d(dim, dim*2, 4, 2), 150 | nn.ReLU(inplace=True), 151 | nn.ReflectionPad2d(1), 152 | nn.Conv2d(dim*2, dim*4, 4, 2), 153 | nn.ReLU(inplace=True), 154 | nn.ReflectionPad2d(1), 155 | nn.Conv2d(dim*4, dim*4, 4, 2), 156 | nn.ReLU(inplace=True), 157 | nn.ReflectionPad2d(1), 158 | nn.Conv2d(dim*4, dim*4, 4, 2), 159 | nn.ReLU(inplace=True), 160 | nn.AdaptiveAvgPool2d(1), 161 | nn.Conv2d(dim*4, output_nc, 1, 1, 0)) 162 | self.model_b = nn.Sequential( 163 | nn.ReflectionPad2d(3), 164 | nn.Conv2d(input_dim_b, dim, 7, 1), 165 | nn.ReLU(inplace=True), 166 | nn.ReflectionPad2d(1), 167 | nn.Conv2d(dim, dim*2, 4, 2), 168 | nn.ReLU(inplace=True), 169 | nn.ReflectionPad2d(1), 170 | nn.Conv2d(dim*2, dim*4, 4, 2), 171 | nn.ReLU(inplace=True), 172 | nn.ReflectionPad2d(1), 173 | nn.Conv2d(dim*4, dim*4, 4, 2), 174 | nn.ReLU(inplace=True), 175 | nn.ReflectionPad2d(1), 176 | nn.Conv2d(dim*4, dim*4, 4, 2), 177 | nn.ReLU(inplace=True), 178 | nn.AdaptiveAvgPool2d(1), 179 | nn.Conv2d(dim*4, output_nc, 1, 1, 0)) 180 | return 181 | 182 | def forward(self, xa, xb): 183 | xa = self.model_a(xa) 184 | xb = self.model_b(xb) 185 | output_A = xa.view(xa.size(0), -1) 186 | output_B = xb.view(xb.size(0), -1) 187 | return output_A, output_B 188 | 189 | def forward_a(self, xa): 190 | xa = self.model_a(xa) 191 | output_A = xa.view(xa.size(0), -1) 192 | return output_A 193 | 194 | def forward_b(self, xb): 195 | xb = self.model_b(xb) 196 | output_B = xb.view(xb.size(0), -1) 197 | return output_B 198 | 199 | class E_attr_concat(nn.Module): 200 | def __init__(self, input_dim_a, input_dim_b, output_nc=8, norm_layer=None, nl_layer=None): 201 | super(E_attr_concat, self).__init__() 202 | 203 | ndf = 64 204 | n_blocks=4 205 | max_ndf = 4 206 | 207 | conv_layers_A = [nn.ReflectionPad2d(1)] 208 | conv_layers_A += [nn.Conv2d(input_dim_a, ndf, kernel_size=4, stride=2, padding=0, bias=True)] 209 | for n in range(1, n_blocks): 210 | input_ndf = ndf * min(max_ndf, n) # 2**(n-1) 211 | output_ndf = ndf * min(max_ndf, n+1) # 2**n 212 | conv_layers_A += [BasicBlock(input_ndf, output_ndf, norm_layer, nl_layer)] 213 | conv_layers_A += [nl_layer(), nn.AdaptiveAvgPool2d(1)] # AvgPool2d(13) 214 | self.fc_A = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 215 | self.fcVar_A = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 216 | self.conv_A = nn.Sequential(*conv_layers_A) 217 | 218 | conv_layers_B = [nn.ReflectionPad2d(1)] 219 | conv_layers_B += [nn.Conv2d(input_dim_b, ndf, kernel_size=4, stride=2, padding=0, bias=True)] 220 | for n in range(1, n_blocks): 221 | input_ndf = ndf * min(max_ndf, n) # 2**(n-1) 222 | output_ndf = ndf * min(max_ndf, n+1) # 2**n 223 | conv_layers_B += [BasicBlock(input_ndf, output_ndf, norm_layer, nl_layer)] 224 | conv_layers_B += [nl_layer(), nn.AdaptiveAvgPool2d(1)] # AvgPool2d(13) 225 | self.fc_B = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 226 | self.fcVar_B = nn.Sequential(*[nn.Linear(output_ndf, output_nc)]) 227 | self.conv_B = nn.Sequential(*conv_layers_B) 228 | 229 | def forward(self, xa, xb): 230 | x_conv_A = self.conv_A(xa) 231 | conv_flat_A = x_conv_A.view(xa.size(0), -1) 232 | output_A = self.fc_A(conv_flat_A) 233 | outputVar_A = self.fcVar_A(conv_flat_A) 234 | x_conv_B = self.conv_B(xb) 235 | conv_flat_B = x_conv_B.view(xb.size(0), -1) 236 | output_B = self.fc_B(conv_flat_B) 
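# fc_* and fcVar_* act as the mean and log-variance heads of a VAE-style
# attribute encoder; the concat branch of model.py appears to reparameterize
# the style code from these two outputs (cf. mu2_a / mu2_b above).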
237 | outputVar_B = self.fcVar_B(conv_flat_B) 238 | return output_A, outputVar_A, output_B, outputVar_B 239 | 240 | def forward_a(self, xa): 241 | x_conv_A = self.conv_A(xa) 242 | conv_flat_A = x_conv_A.view(xa.size(0), -1) 243 | output_A = self.fc_A(conv_flat_A) 244 | outputVar_A = self.fcVar_A(conv_flat_A) 245 | return output_A, outputVar_A 246 | 247 | def forward_b(self, xb): 248 | x_conv_B = self.conv_B(xb) 249 | conv_flat_B = x_conv_B.view(xb.size(0), -1) 250 | output_B = self.fc_B(conv_flat_B) 251 | outputVar_B = self.fcVar_B(conv_flat_B) 252 | return output_B, outputVar_B 253 | 254 | #################################################################### 255 | #--------------------------- Generators ---------------------------- 256 | #################################################################### 257 | class G(nn.Module): 258 | def __init__(self, output_dim_a, output_dim_b, nz): 259 | super(G, self).__init__() 260 | self.nz = nz 261 | ini_tch = 256 262 | tch_add = ini_tch 263 | tch = ini_tch 264 | self.tch_add = tch_add 265 | self.decA1 = MisINSResBlock(tch, tch_add) 266 | self.decA2 = MisINSResBlock(tch, tch_add) 267 | self.decA3 = MisINSResBlock(tch, tch_add) 268 | self.decA4 = MisINSResBlock(tch, tch_add) 269 | 270 | decA5 = [] 271 | decA5 += [ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)] 272 | tch = tch//2 273 | decA5 += [ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)] 274 | tch = tch//2 275 | decA5 += [nn.ConvTranspose2d(tch, output_dim_a, kernel_size=1, stride=1, padding=0)] 276 | decA5 += [nn.Tanh()] 277 | self.decA5 = nn.Sequential(*decA5) 278 | 279 | tch = ini_tch 280 | self.decB1 = MisINSResBlock(tch, tch_add) 281 | self.decB2 = MisINSResBlock(tch, tch_add) 282 | self.decB3 = MisINSResBlock(tch, tch_add) 283 | self.decB4 = MisINSResBlock(tch, tch_add) 284 | decB5 = [] 285 | decB5 += [ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)] 286 | tch = tch//2 287 | decB5 += [ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1)] 288 | tch = tch//2 289 | decB5 += [nn.ConvTranspose2d(tch, output_dim_b, kernel_size=1, stride=1, padding=0)] 290 | decB5 += [nn.Tanh()] 291 | self.decB5 = nn.Sequential(*decB5) 292 | 293 | self.mlpA = nn.Sequential( 294 | nn.Linear(8, 256), 295 | nn.ReLU(inplace=True), 296 | nn.Linear(256, 256), 297 | nn.ReLU(inplace=True), 298 | nn.Linear(256, tch_add*4)) 299 | self.mlpB = nn.Sequential( 300 | nn.Linear(8, 256), 301 | nn.ReLU(inplace=True), 302 | nn.Linear(256, 256), 303 | nn.ReLU(inplace=True), 304 | nn.Linear(256, tch_add*4)) 305 | return 306 | 307 | def forward_a(self, x, z): 308 | z = self.mlpA(z) 309 | z1, z2, z3, z4 = torch.split(z, self.tch_add, dim=1) 310 | z1, z2, z3, z4 = z1.contiguous(), z2.contiguous(), z3.contiguous(), z4.contiguous() 311 | out1 = self.decA1(x, z1) 312 | out2 = self.decA2(out1, z2) 313 | out3 = self.decA3(out2, z3) 314 | out4 = self.decA4(out3, z4) 315 | out = self.decA5(out4) 316 | return out 317 | 318 | def forward_b(self, x, z): 319 | z = self.mlpB(z) 320 | z1, z2, z3, z4 = torch.split(z, self.tch_add, dim=1) 321 | z1, z2, z3, z4 = z1.contiguous(), z2.contiguous(), z3.contiguous(), z4.contiguous() 322 | out1 = self.decB1(x, z1) 323 | out2 = self.decB2(out1, z2) 324 | out3 = self.decB3(out2, z3) 325 | out4 = self.decB4(out3, z4) 326 | out = self.decB5(out4) 327 | return out 328 | 329 | class G_concat(nn.Module): 330 | def __init__(self, output_dim_a, 
output_dim_b, nz): 331 | super(G_concat, self).__init__() 332 | self.nz = nz 333 | tch = 256 334 | dec_share = [] 335 | dec_share += [INSResBlock(tch, tch)] 336 | self.dec_share = nn.Sequential(*dec_share) 337 | tch = 256+self.nz 338 | decA1 = [] 339 | for i in range(0, 3): 340 | decA1 += [INSResBlock(tch, tch)] 341 | tch = tch + self.nz 342 | decA2 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1) 343 | tch = tch//2 344 | tch = tch + self.nz 345 | decA3 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1) 346 | tch = tch//2 347 | tch = tch + self.nz 348 | decA4 = [nn.ConvTranspose2d(tch, output_dim_a, kernel_size=1, stride=1, padding=0)]+[nn.Tanh()] 349 | self.decA1 = nn.Sequential(*decA1) 350 | self.decA2 = nn.Sequential(*[decA2]) 351 | self.decA3 = nn.Sequential(*[decA3]) 352 | self.decA4 = nn.Sequential(*decA4) 353 | 354 | tch = 256+self.nz 355 | decB1 = [] 356 | for i in range(0, 3): 357 | decB1 += [INSResBlock(tch, tch)] 358 | tch = tch + self.nz 359 | decB2 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1) 360 | tch = tch//2 361 | tch = tch + self.nz 362 | decB3 = ReLUINSConvTranspose2d(tch, tch//2, kernel_size=3, stride=2, padding=1, output_padding=1) 363 | tch = tch//2 364 | tch = tch + self.nz 365 | decB4 = [nn.ConvTranspose2d(tch, output_dim_b, kernel_size=1, stride=1, padding=0)]+[nn.Tanh()] 366 | self.decB1 = nn.Sequential(*decB1) 367 | self.decB2 = nn.Sequential(*[decB2]) 368 | self.decB3 = nn.Sequential(*[decB3]) 369 | self.decB4 = nn.Sequential(*decB4) 370 | 371 | def forward_a(self, x, z): 372 | out0 = self.dec_share(x) 373 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3)) 374 | x_and_z = torch.cat([out0, z_img], 1) 375 | out1 = self.decA1(x_and_z) 376 | z_img2 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out1.size(2), out1.size(3)) 377 | x_and_z2 = torch.cat([out1, z_img2], 1) 378 | out2 = self.decA2(x_and_z2) 379 | z_img3 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out2.size(2), out2.size(3)) 380 | x_and_z3 = torch.cat([out2, z_img3], 1) 381 | out3 = self.decA3(x_and_z3) 382 | z_img4 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out3.size(2), out3.size(3)) 383 | x_and_z4 = torch.cat([out3, z_img4], 1) 384 | out4 = self.decA4(x_and_z4) 385 | return out4 386 | 387 | def forward_b(self, x, z): 388 | out0 = self.dec_share(x) 389 | z_img = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3)) 390 | x_and_z = torch.cat([out0, z_img], 1) 391 | out1 = self.decB1(x_and_z) 392 | z_img2 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out1.size(2), out1.size(3)) 393 | x_and_z2 = torch.cat([out1, z_img2], 1) 394 | out2 = self.decB2(x_and_z2) 395 | z_img3 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out2.size(2), out2.size(3)) 396 | x_and_z3 = torch.cat([out2, z_img3], 1) 397 | out3 = self.decB3(x_and_z3) 398 | z_img4 = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), out3.size(2), out3.size(3)) 399 | x_and_z4 = torch.cat([out3, z_img4], 1) 400 | out4 = self.decB4(x_and_z4) 401 | return out4 402 | 403 | #################################################################### 404 | #------------------------- Basic Functions ------------------------- 405 | #################################################################### 406 | def get_scheduler(optimizer, opts, cur_ep=-1): 
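# Under the 'lambda' policy the base LR is scaled by lr_l, which stays at 1.0
# until opts.n_ep_decay and then decays linearly to ~0 at opts.n_ep.
# Worked example (values assumed for illustration): n_ep=1200, n_ep_decay=600
#   ep=600  -> lr_l = 1 - 0/601   = 1.0
#   ep=900  -> lr_l = 1 - 300/601 ~ 0.50
#   ep=1200 -> lr_l = 1 - 600/601 ~ 0.002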
407 | if opts.lr_policy == 'lambda': 408 | def lambda_rule(ep): 409 | lr_l = 1.0 - max(0, ep - opts.n_ep_decay) / float(opts.n_ep - opts.n_ep_decay + 1) 410 | return lr_l 411 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=cur_ep) 412 | elif opts.lr_policy == 'step': 413 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opts.n_ep_decay, gamma=0.1, last_epoch=cur_ep) 414 | else: 415 | raise NotImplementedError('no such learning rate policy') 416 | return scheduler 417 | 418 | def meanpoolConv(inplanes, outplanes): 419 | sequence = [] 420 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)] 421 | sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0, bias=True)] 422 | return nn.Sequential(*sequence) 423 | 424 | def convMeanpool(inplanes, outplanes): 425 | sequence = [] 426 | sequence += conv3x3(inplanes, outplanes) 427 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)] 428 | return nn.Sequential(*sequence) 429 | 430 | def get_norm_layer(layer_type='instance'): 431 | if layer_type == 'batch': 432 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 433 | elif layer_type == 'instance': 434 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 435 | elif layer_type == 'none': 436 | norm_layer = None 437 | else: 438 | raise NotImplementedError('normalization layer [%s] is not found' % layer_type) 439 | return norm_layer 440 | 441 | def get_non_linearity(layer_type='relu'): 442 | if layer_type == 'relu': 443 | nl_layer = functools.partial(nn.ReLU, inplace=True) 444 | elif layer_type == 'lrelu': 445 | nl_layer = functools.partial(nn.LeakyReLU, negative_slope=0.2, inplace=False) 446 | elif layer_type == 'elu': 447 | nl_layer = functools.partial(nn.ELU, inplace=True) 448 | else: 449 | raise NotImplementedError('nonlinearity activation [%s] is not found' % layer_type) 450 | return nl_layer 451 | def conv3x3(in_planes, out_planes): 452 | return [nn.ReflectionPad2d(1), nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=0, bias=True)] 453 | 454 | def gaussian_weights_init(m): 455 | classname = m.__class__.__name__ 456 | if classname.find('Conv') == 0: # class name starts with 'Conv' 457 | m.weight.data.normal_(0.0, 0.02) 458 | 459 | #################################################################### 460 | #-------------------------- Basic Blocks -------------------------- 461 | #################################################################### 462 | 463 | ## The code of LayerNorm is modified from MUNIT (https://github.com/NVlabs/MUNIT) 464 | class LayerNorm(nn.Module): 465 | def __init__(self, n_out, eps=1e-5, affine=True): 466 | super(LayerNorm, self).__init__() 467 | self.n_out = n_out 468 | self.affine = affine 469 | if self.affine: 470 | self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) 471 | self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) 472 | return 473 | def forward(self, x): 474 | normalized_shape = x.size()[1:] 475 | if self.affine: 476 | return F.layer_norm(x, normalized_shape, self.weight.expand(normalized_shape), self.bias.expand(normalized_shape)) 477 | else: 478 | return F.layer_norm(x, normalized_shape) 479 | 480 | class BasicBlock(nn.Module): 481 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None): 482 | super(BasicBlock, self).__init__() 483 | layers = [] 484 | if norm_layer is not None: 485 | layers += [norm_layer(inplanes)] 486 | layers += [nl_layer()] 487 | layers += conv3x3(inplanes, inplanes) 488 | if norm_layer is not None: 489 | layers += [norm_layer(inplanes)]
490 | layers += [nl_layer()] 491 | layers += [convMeanpool(inplanes, outplanes)] 492 | self.conv = nn.Sequential(*layers) 493 | self.shortcut = meanpoolConv(inplanes, outplanes) 494 | def forward(self, x): 495 | out = self.conv(x) + self.shortcut(x) 496 | return out 497 | 498 | class LeakyReLUConv2d(nn.Module): 499 | def __init__(self, n_in, n_out, kernel_size, stride, padding=0, norm='None', sn=False): 500 | super(LeakyReLUConv2d, self).__init__() 501 | model = [] 502 | model += [nn.ReflectionPad2d(padding)] 503 | if sn: 504 | model += [spectral_norm(nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True))] 505 | else: 506 | model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True)] 507 | if norm == 'Instance': 508 | model += [nn.InstanceNorm2d(n_out, affine=False)] 509 | model += [nn.LeakyReLU(inplace=True)] 510 | self.model = nn.Sequential(*model) 511 | self.model.apply(gaussian_weights_init) 512 | # (other norm types, e.g. 'Group', are not implemented) 513 | def forward(self, x): 514 | return self.model(x) 515 | 516 | class ReLUINSConv2d(nn.Module): 517 | def __init__(self, n_in, n_out, kernel_size, stride, padding=0): 518 | super(ReLUINSConv2d, self).__init__() 519 | model = [] 520 | model += [nn.ReflectionPad2d(padding)] 521 | model += [nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=0, bias=True)] 522 | model += [nn.InstanceNorm2d(n_out, affine=False)] 523 | model += [nn.ReLU(inplace=True)] 524 | self.model = nn.Sequential(*model) 525 | self.model.apply(gaussian_weights_init) 526 | def forward(self, x): 527 | return self.model(x) 528 | 529 | class INSResBlock(nn.Module): 530 | def conv3x3(self, inplanes, out_planes, stride=1): 531 | return [nn.ReflectionPad2d(1), nn.Conv2d(inplanes, out_planes, kernel_size=3, stride=stride)] 532 | def __init__(self, inplanes, planes, stride=1, dropout=0.0): 533 | super(INSResBlock, self).__init__() 534 | model = [] 535 | model += self.conv3x3(inplanes, planes, stride) 536 | model += [nn.InstanceNorm2d(planes)] 537 | model += [nn.ReLU(inplace=True)] 538 | model += self.conv3x3(planes, planes) 539 | model += [nn.InstanceNorm2d(planes)] 540 | if dropout > 0: 541 | model += [nn.Dropout(p=dropout)] 542 | self.model = nn.Sequential(*model) 543 | self.model.apply(gaussian_weights_init) 544 | def forward(self, x): 545 | residual = x 546 | out = self.model(x) 547 | out += residual 548 | return out 549 | 550 | class MisINSResBlock(nn.Module): 551 | def conv3x3(self, dim_in, dim_out, stride=1): 552 | return nn.Sequential(nn.ReflectionPad2d(1), nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=stride)) 553 | def conv1x1(self, dim_in, dim_out): 554 | return nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0) 555 | def __init__(self, dim, dim_extra, stride=1, dropout=0.0): 556 | super(MisINSResBlock, self).__init__() 557 | self.conv1 = nn.Sequential( 558 | self.conv3x3(dim, dim, stride), 559 | nn.InstanceNorm2d(dim)) 560 | self.conv2 = nn.Sequential( 561 | self.conv3x3(dim, dim, stride), 562 | nn.InstanceNorm2d(dim)) 563 | self.blk1 = nn.Sequential( 564 | self.conv1x1(dim + dim_extra, dim + dim_extra), 565 | nn.ReLU(inplace=False), 566 | self.conv1x1(dim + dim_extra, dim), 567 | nn.ReLU(inplace=False)) 568 | self.blk2 = nn.Sequential( 569 | self.conv1x1(dim + dim_extra, dim + dim_extra), 570 | nn.ReLU(inplace=False), 571 | self.conv1x1(dim + dim_extra, dim), 572 | nn.ReLU(inplace=False)) 573 | model = [] 574 | if dropout > 0: 575 | model += [nn.Dropout(p=dropout)] 576 |
self.model = nn.Sequential(*model) 577 | self.model.apply(gaussian_weights_init) 578 | self.conv1.apply(gaussian_weights_init) 579 | self.conv2.apply(gaussian_weights_init) 580 | self.blk1.apply(gaussian_weights_init) 581 | self.blk2.apply(gaussian_weights_init) 582 | def forward(self, x, z): 583 | residual = x 584 | z_expand = z.view(z.size(0), z.size(1), 1, 1).expand(z.size(0), z.size(1), x.size(2), x.size(3)) 585 | o1 = self.conv1(x) 586 | o2 = self.blk1(torch.cat([o1, z_expand], dim=1)) 587 | o3 = self.conv2(o2) 588 | out = self.blk2(torch.cat([o3, z_expand], dim=1)) 589 | out += residual 590 | return out 591 | 592 | class GaussianNoiseLayer(nn.Module): 593 | def __init__(self): 594 | super(GaussianNoiseLayer, self).__init__() 595 | def forward(self, x): 596 | if not self.training: 597 | return x 598 | noise = torch.randn_like(x) # draw noise directly on x's device/dtype 599 | return x + noise 600 | 601 | class ReLUINSConvTranspose2d(nn.Module): 602 | def __init__(self, n_in, n_out, kernel_size, stride, padding, output_padding): 603 | super(ReLUINSConvTranspose2d, self).__init__() 604 | model = [] 605 | model += [nn.ConvTranspose2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=True)] 606 | model += [LayerNorm(n_out)] # LayerNorm despite the class name 607 | model += [nn.ReLU(inplace=True)] 608 | self.model = nn.Sequential(*model) 609 | self.model.apply(gaussian_weights_init) 610 | def forward(self, x): 611 | return self.model(x) 612 | 613 | 614 | #################################################################### 615 | #--------------------- Spectral Normalization --------------------- 616 | # This part of code is copied from pytorch master branch (0.5.0) 617 | #################################################################### 618 | class SpectralNorm(object): 619 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 620 | self.name = name 621 | self.dim = dim 622 | if n_power_iterations <= 0: 623 | raise ValueError('Expected n_power_iterations to be positive, but ' 624 | 'got n_power_iterations={}'.format(n_power_iterations)) 625 | self.n_power_iterations = n_power_iterations 626 | self.eps = eps 627 | def compute_weight(self, module): 628 | weight = getattr(module, self.name + '_orig') 629 | u = getattr(module, self.name + '_u') 630 | weight_mat = weight 631 | if self.dim != 0: 632 | # permute dim to front 633 | weight_mat = weight_mat.permute(self.dim, 634 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 635 | height = weight_mat.size(0) 636 | weight_mat = weight_mat.reshape(height, -1) 637 | with torch.no_grad(): 638 | for _ in range(self.n_power_iterations): 639 | v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) 640 | u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) 641 | sigma = torch.dot(u, torch.matmul(weight_mat, v)) 642 | weight = weight / sigma 643 | return weight, u 644 | def remove(self, module): 645 | weight = getattr(module, self.name) 646 | delattr(module, self.name) 647 | delattr(module, self.name + '_u') 648 | delattr(module, self.name + '_orig') 649 | module.register_parameter(self.name, torch.nn.Parameter(weight)) 650 | def __call__(self, module, inputs): 651 | if module.training: 652 | weight, u = self.compute_weight(module) 653 | setattr(module, self.name, weight) 654 | setattr(module, self.name + '_u', u) 655 | else: 656 | r_g = getattr(module, self.name + '_orig').requires_grad 657 | getattr(module, self.name).detach_().requires_grad_(r_g) 658 | 659 | @staticmethod 660
| def apply(module, name, n_power_iterations, dim, eps): 661 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 662 | weight = module._parameters[name] 663 | height = weight.size(dim) 664 | u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) 665 | delattr(module, fn.name) 666 | module.register_parameter(fn.name + "_orig", weight) 667 | module.register_buffer(fn.name, weight.data) 668 | module.register_buffer(fn.name + "_u", u) 669 | module.register_forward_pre_hook(fn) 670 | return fn 671 | 672 | def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): 673 | if dim is None: 674 | if isinstance(module, (torch.nn.ConvTranspose1d, 675 | torch.nn.ConvTranspose2d, 676 | torch.nn.ConvTranspose3d)): 677 | dim = 1 678 | else: 679 | dim = 0 680 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 681 | return module 682 | 683 | def remove_spectral_norm(module, name='weight'): 684 | for k, hook in module._forward_pre_hooks.items(): 685 | if isinstance(hook, SpectralNorm) and hook.name == name: 686 | hook.remove(module) 687 | del module._forward_pre_hooks[k] 688 | return module 689 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) 690 | 691 | --------------------------------------------------------------------------------
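A minimal smoke-test sketch of how the building blocks in networks.py compose (not a file from this repository; batch size, channel counts, and the 256x256 input resolution are assumptions for illustration):

```
import torch
from networks import E_content, E_attr, G, MultiScaleDis

B = 2
xa = torch.randn(B, 3, 256, 256)  # domain-A batch (shapes assumed)
xb = torch.randn(B, 3, 256, 256)  # domain-B batch (shapes assumed)

enc_c = E_content(input_dim_a=3, input_dim_b=3).eval()  # eval() bypasses GaussianNoiseLayer
enc_a = E_attr(input_dim_a=3, input_dim_b=3, output_nc=8).eval()
gen = G(output_dim_a=3, output_dim_b=3, nz=8).eval()
dis = MultiScaleDis(input_dim=3).eval()

with torch.no_grad():
    content_a, content_b = enc_c(xa, xb)        # each: (B, 256, 64, 64)
    style_a, style_b = enc_a(xa, xb)            # each: (B, 8)
    fake_b = gen.forward_b(content_a, style_b)  # A content + B style -> (B, 3, 256, 256)
    fake_a = gen.forward_a(content_b, style_a)  # B content + A style -> (B, 3, 256, 256)
    logits = dis(fake_b)                        # 3 patch-logit maps: (B,1,16,16), (B,1,8,8), (B,1,4,4)

print(fake_a.shape, fake_b.shape, len(logits))
```

In model.py these same calls are combined with the discriminators, the mode-seeking and latent-regression losses shown above, and the src_seg/tar_seg segmentation models handled by resume()/save().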