├── raindrop ├── gen_raindrop │ ├── __init__.py │ ├── texture.png │ ├── get_position_matrix.py │ ├── composite_img.py │ └── gen_raindrop.py └── README.md ├── imgs ├── BID.gif ├── network.png ├── taskII.png └── taskIII.png ├── util ├── __init__.py ├── image_pool.py ├── html.py ├── get_data.py ├── util.py └── visualizer.py ├── options ├── __init__.py ├── test_options.py ├── train_options.py └── base_options.py ├── requirements.txt ├── datasets ├── bibtex │ ├── VGGflower.tex │ ├── reflection.tex │ ├── rain100L100H.tex │ ├── snow100K.tex │ ├── LVM.tex │ ├── ISTD.tex │ ├── SRD.tex │ ├── foggycityscape.tex │ └── cityscapes.tex └── prepare_cityscapes_dataset.py ├── environment.yml ├── metrics ├── brisque_niqe.m ├── ssim_psnr.m ├── rmse_srd_v2.m └── rmse_istd_v1.m ├── experiments ├── __init__.py └── __main__.py ├── models ├── losses.py ├── __init__.py ├── base_model.py ├── biden2_model.py └── biden3_model.py ├── data ├── image_folder.py ├── unaligned2_dataset.py ├── template_dataset.py ├── __init__.py ├── unaligned3_dataset.py ├── unaligned4_dataset.py ├── jointremoval_dataset.py ├── unaligned5_dataset.py ├── rainb_dataset.py ├── raina_dataset.py ├── unaligned8_dataset.py └── base_dataset.py ├── test2.py ├── test.py ├── train.py ├── README.md ├── train_fid.py └── LICENSE /raindrop/gen_raindrop/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/BID.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunlinHan/BID/HEAD/imgs/BID.gif -------------------------------------------------------------------------------- /imgs/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunlinHan/BID/HEAD/imgs/network.png -------------------------------------------------------------------------------- /imgs/taskII.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunlinHan/BID/HEAD/imgs/taskII.png -------------------------------------------------------------------------------- /imgs/taskIII.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunlinHan/BID/HEAD/imgs/taskIII.png -------------------------------------------------------------------------------- /raindrop/gen_raindrop/texture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JunlinHan/BID/HEAD/raindrop/gen_raindrop/texture.png -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | from util import * 3 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.1 2 | torchvision>=0.8.2 3 | dominate>=2.4.0 4 | 
visdom>=0.1.8.8
5 | packaging
6 | GPUtil>=1.4.0
7 | scipy
8 | Pillow>=6.1.0
9 | numpy>=1.16.4
10 | opencv-python>=3.4.2.17
11 |
12 |
13 |
--------------------------------------------------------------------------------
/datasets/bibtex/VGGflower.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{VGGflower2006,
2 | title={A visual vocabulary for flower classification},
3 | author={M-E Nilsback and Andrew Zisserman},
4 | booktitle={CVPR},
5 | year={2006}
6 | }
7 |
8 |
9 |
--------------------------------------------------------------------------------
/datasets/bibtex/reflection.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{zhang2018single,
2 | title={Single image reflection removal with perceptual losses},
3 | author={Zhang, Xuaner and Ng, Ren and Chen, Qifeng},
4 | year={2018},
5 | booktitle={CVPR}
6 | }
--------------------------------------------------------------------------------
/datasets/bibtex/rain100L100H.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{Yang2017deep,
2 | title = {Deep joint rain detection and removal from a single image},
3 | author = {Wenhan Yang and Robby T. Tan and Jiashi Feng and Jiaying Liu and Zong-ming Guo and Shuicheng Yan},
4 | booktitle = {CVPR},
5 | year = {2017}
6 | }
7 |
--------------------------------------------------------------------------------
/datasets/bibtex/snow100K.tex:
--------------------------------------------------------------------------------
1 | @article{liu2018desnownet,
2 | author = {Yun-Fu Liu and Da-Wei Jaw and Shih-Chia Huang and Jenq-Neng Hwang},
3 | title = {DesnowNet: Context-aware deep network for snow removal},
4 | journal = {IEEE Transactions on Image Processing (TIP)},
5 | year = {2018}
6 | }
7 |
8 |
--------------------------------------------------------------------------------
/raindrop/README.md:
--------------------------------------------------------------------------------
1 | Generation steps:
2 | 1. Run gen_raindrop/gen_raindrop.py to get the alpha map (mask) and the corresponding texture map.
3 | 2. Run gen_raindrop/composite_img.py to get the composited image.
4 |
5 | Note: to composite images in real time, you need to run get_position_matrix.py and save the position matrix in advance.
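
A minimal sketch of the two steps above, assuming gen_raindrop.py has already written the texture map and alpha map to disk (the file names here are illustrative, not fixed by the scripts):

    import cv2
    from get_position_matrix import get_position_matrix
    from composite_img import composition_img

    img = cv2.imread('input.png')        # clean image to corrupt
    texture = cv2.imread('texture.png')  # texture map from step 1
    alpha = cv2.imread('alpha.png')      # alpha map (mask) from step 1

    # Precompute the position matrix once, then reuse it for real-time compositing.
    position_matrix, alpha = get_position_matrix(texture, alpha, img.shape[:2], img)
    raindrop_img = composition_img(img, alpha, position_matrix)
    cv2.imwrite('raindrop.png', raindrop_img)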
6 | 7 | 8 | -------------------------------------------------------------------------------- /datasets/bibtex/LVM.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{liu2021wdnet, 2 | title={WDNet: Watermark-Decomposition Network for Visible Watermark Removal}, 3 | author={Liu, Yang and Zhu, Zhen and Bai, Xiang}, 4 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 5 | pages={3685--3693}, 6 | year={2021} 7 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: BID 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.7.7 7 | - pytorch=1.7.1 8 | - torchvision=0.8.2 9 | - scipy 10 | - pip: 11 | - dominate==2.4.0 12 | - Pillow==6.1.0 13 | - numpy==1.16.4 14 | - visdom==0.1.8 15 | - packaging 16 | - GPUtil==1.4.0 17 | - opencv-python ==3.4.2.17 18 | -------------------------------------------------------------------------------- /datasets/bibtex/ISTD.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{wang2018stacked, 2 | title={Stacked conditional generative adversarial networks for jointly learning shadow detection and shadow removal}, 3 | author={Wang, Jifeng and Li, Xiang and Yang, Jian}, 4 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 5 | pages={1788--1797}, 6 | year={2018} 7 | } -------------------------------------------------------------------------------- /datasets/bibtex/SRD.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{qu2017deshadownet, 2 | title={Deshadownet: A multi-context embedding deep network for shadow removal}, 3 | author={Qu, Liangqiong and Tian, Jiandong and He, Shengfeng and Tang, Yandong and Lau, Rynson WH}, 4 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 5 | pages={4067--4075}, 6 | year={2017} 7 | } -------------------------------------------------------------------------------- /datasets/bibtex/foggycityscape.tex: -------------------------------------------------------------------------------- 1 | @article{sakaridis2018semantic, 2 | title={Semantic foggy scene understanding with synthetic data}, 3 | author={Sakaridis, Christos and Dai, Dengxin and Van Gool, Luc}, 4 | journal={International Journal of Computer Vision}, 5 | volume={126}, 6 | number={9}, 7 | pages={973--992}, 8 | year={2018}, 9 | publisher={Springer} 10 | } 11 | 12 | -------------------------------------------------------------------------------- /datasets/bibtex/cityscapes.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{Cordts2016Cityscapes, 2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, 3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | year={2016} 6 | } 7 | -------------------------------------------------------------------------------- /metrics/brisque_niqe.m: -------------------------------------------------------------------------------- 1 | clear all; 2 | % The folder to your generated/separated images. 
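% Note: brisque() and niqe() are no-reference quality metrics from MATLAB's
% Image Processing Toolbox (R2017b or later); lower scores indicate better quality.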
3 | folder1 ='your_path';
4 | files1 = dir(folder1);
5 |
6 | image_num=1329; % choose how many images to process: 185 for rain streak,
7 | % 249 for raindrop, and 1329 for snow.
8 | count_brisque=0;
9 | count_niqe=0;
10 | for i=3:image_num+2 % dir() lists '.' and '..' first, so image files start at index 3
11 | image1=uint8(imread(strcat(folder1,'\',files1(i).name)));
12 | bri_score = brisque(image1);
13 | niq_score = niqe(image1);
14 | count_brisque = count_brisque + bri_score;
15 | count_niqe = count_niqe + niq_score;
16 | end
17 | count_brisque=count_brisque/image_num;
18 | count_niqe=count_niqe/image_num;
19 | disp("BRISQUE result");
20 | disp(count_brisque);
21 | disp("NIQE result");
22 | disp(count_niqe);
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/metrics/ssim_psnr.m:
--------------------------------------------------------------------------------
1 | clear all;
2 | % The folder to your generated/separated images.
3 | folder1 ='C:\Users\76454\Desktop\BID\results\biden3\test_latest\images\fake_B';
4 | % The folder to real images.
5 | folder2= 'C:\Users\76454\Desktop\BID\results\biden3\test_latest\images\real_B';
6 | files1 = dir(folder1);
7 | files2 = dir(folder2);
8 |
9 | image_num=300; % choose how many images to process: 300 for Task I, 500 for Task II.
10 | count_ssim=0;
11 | count_psnr=0;
12 | for i=3:image_num+2 % dir() lists '.' and '..' first, so image files start at index 3
13 | image1=uint8(imread(strcat(folder1,'\',files1(i).name)));
14 | image2=uint8(imread(strcat(folder2,'\',files2(i).name)));
15 | [ssimval,ssimmap]=ssim(image1,image2);
16 | [peaksnr, snr] = psnr(image1,image2);
17 | count_ssim = count_ssim + ssimval;
18 | count_psnr = count_psnr + peaksnr;
19 | end
20 | count_ssim=count_ssim/image_num;
21 | count_psnr=count_psnr/image_num;
22 | disp("SSIM result");
23 | disp(count_ssim);
24 | disp("PSNR result");
25 | disp(count_psnr);
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 |
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 |
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)  # define shared options
12 |         parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
13 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
14 |         # Dropout and BatchNorm have different behavior during training and test.
15 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
16 |         # Set the default; we use 5000 here so that the whole test set is evaluated.
17 | parser.add_argument('--num_test', type=int, default=5000, help='how many test images to run') 18 | 19 | # To avoid cropping, the load_size should be the same as crop_size 20 | parser.set_defaults(load_size=parser.get_default('crop_size')) 21 | self.isTrain = False 22 | return parser 23 | -------------------------------------------------------------------------------- /raindrop/gen_raindrop/get_position_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import random 6 | 7 | # random crop the alpha map to fit the shape of image 8 | def get_random_crop(texture, alpha, crop_height): 9 | 10 | max_y = texture.shape[0] - crop_height 11 | y = np.random.randint(0, max_y) 12 | crop_texture = texture[y: y + crop_height, :, :] 13 | crop_alpha = alpha[y: y + crop_height, :, :] 14 | 15 | return crop_texture,crop_alpha 16 | 17 | # 2048 * 1024 18 | def get_position_matrix(texture,alpha,output_size,img): 19 | h,w = output_size 20 | 21 | texture_size = texture.shape[0] 22 | factor = h/w 23 | 24 | crop_w = texture_size 25 | crop_h = int(crop_w*factor) 26 | 27 | texture,alpha = get_random_crop(texture,alpha, crop_h) 28 | 29 | texture = cv2.resize(texture,(output_size[1],output_size[0])) 30 | alpha = cv2.resize(alpha, (output_size[1],output_size[0])) 31 | alpha = cv2.blur(alpha,(5,5)) 32 | 33 | position_matrix = np.mgrid[0:h,0:w] 34 | 35 | position_matrix[0,:,:] = position_matrix[0,:,:] + texture[:,:,2]*(texture[:,:,0]/255) 36 | position_matrix[1,:, :] = position_matrix[1,:, :] + texture[:, :, 1]*(texture[:,:,0]/255) 37 | position_matrix = position_matrix*(alpha[:,:,0]>255*0.3) 38 | 39 | 40 | return position_matrix,alpha 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | def find_launcher_using_name(launcher_name): 6 | # cur_dir = os.path.dirname(os.path.abspath(__file__)) 7 | # pythonfiles = glob.glob(cur_dir + '/**/*.py') 8 | launcher_filename = "experiments.{}_launcher".format(launcher_name) 9 | launcherlib = importlib.import_module(launcher_filename) 10 | 11 | # In the file, the class called LauncherNameLauncher() will 12 | # be instantiated. It has to be a subclass of BaseLauncher, 13 | # and it is case-insensitive. 14 | launcher = None 15 | target_launcher_name = launcher_name.replace('_', '') + 'launcher' 16 | for name, cls in launcherlib.__dict__.items(): 17 | if name.lower() == target_launcher_name.lower(): 18 | launcher = cls 19 | 20 | if launcher is None: 21 | raise ValueError("In %s.py, there should be a subclass of BaseLauncher " 22 | "with class name that matches %s in lowercase." 
% 23 | (launcher_filename, target_launcher_name)) 24 | 25 | return launcher 26 | 27 | 28 | if __name__ == "__main__": 29 | import sys 30 | import pickle 31 | 32 | assert len(sys.argv) >= 3 33 | 34 | name = sys.argv[1] 35 | Launcher = find_launcher_using_name(name) 36 | 37 | cache = "/tmp/tmux_launcher/{}".format(name) 38 | if os.path.isfile(cache): 39 | instance = pickle.load(open(cache, 'r')) 40 | else: 41 | instance = Launcher() 42 | 43 | cmd = sys.argv[2] 44 | if cmd == "launch": 45 | instance.launch() 46 | elif cmd == "stop": 47 | instance.stop() 48 | elif cmd == "send": 49 | expid = int(sys.argv[3]) 50 | cmd = int(sys.argv[4]) 51 | instance.send_command(expid, cmd) 52 | 53 | os.makedirs("/tmp/tmux_launcher/", exist_ok=True) 54 | pickle.dump(instance, open(cache, 'w')) 55 | -------------------------------------------------------------------------------- /raindrop/gen_raindrop/composite_img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import random 6 | from get_position_matrix import get_position_matrix 7 | 8 | 9 | def composition_img(img,alpha,position_matrix,length=2): 10 | h, w = img.shape[0:2] 11 | dis_img = img.copy() 12 | 13 | for x in range(h): 14 | for y in range(w): 15 | u,v = int(position_matrix[0,x,y]/length),int(position_matrix[1,x,y]/length) 16 | if (u != 0 and v != 0): 17 | if((u= opt.num_test: # only apply our model to opt.num_test images. 37 | break 38 | model.set_input(data) # unpack data from data loader 39 | model.test() # run inference 40 | visuals = model.get_current_visuals() # get image results 41 | img_path = model.get_image_paths() # get image paths 42 | if i % 50 == 0: # save images to an HTML file 43 | print('processing (%04d)-th image... %s' % (i, img_path)) 44 | save_images(webpage, visuals, img_path, width=opt.display_winsize) 45 | webpage.save() # save the HTML 46 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | By 50/100, the buffer will return input images. 32 | By 50/100, the buffer will return images previously stored in the buffer, 33 | and insert the current images to the buffer. 
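        This history-buffer trick follows Shrivastava et al. (CVPR 2017): updating the
        discriminator on a mixture of current and past generator outputs dampens
        oscillations during adversarial training.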
34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /metrics/rmse_srd_v2.m: -------------------------------------------------------------------------------- 1 | %% compute rmse for SRD (V2) 2 | clear;close all;clc 3 | 4 | % mask directory|mask 5 | maskdir = 'C:\Users\76454\Desktop\BID\datasets\jointremoval_v2\testB\'; 6 | MD = dir([maskdir '/*.jpg']); 7 | 8 | % ground truth directory|GT 9 | freedir = 'C:\Users\76454\Desktop\BID\datasets\jointremoval_v2\testA\'; 10 | FD = dir([freedir '/*.jpg']); 11 | 12 | % predicted result directory 13 | shadowdir = 'C:\Users\76454\Desktop\BID\results\task3_v2\test_latest\images\fake_A\'; 14 | SD = dir([shadowdir '/*.png']); 15 | 16 | total_dists = 0; 17 | total_pixels = 0; 18 | total_distn = 0; 19 | total_pixeln = 0; 20 | rl=zeros(1,size(SD,1)); 21 | ra=zeros(1,size(SD,1)); 22 | rb=zeros(1,size(SD,1)); 23 | nrl=zeros(1,size(SD,1)); 24 | nra=zeros(1,size(SD,1)); 25 | nrb=zeros(1,size(SD,1)); 26 | srl=zeros(1,size(SD,1)); 27 | sra=zeros(1,size(SD,1)); 28 | srb=zeros(1,size(SD,1)); 29 | ppsnr=zeros(1,size(SD,1)); 30 | ppsnrs=zeros(1,size(SD,1)); 31 | ppsnrn=zeros(1,size(SD,1)); 32 | sssim=zeros(1,size(SD,1)); 33 | sssims=zeros(1,size(SD,1)); 34 | sssimn=zeros(1,size(SD,1)); 35 | % ISTD dataset image size 480*640 36 | tic; 37 | cform = makecform('srgb2lab'); 38 | 39 | for i=1:size(SD) 40 | sname = strcat(shadowdir,SD(i).name); 41 | fname = strcat(freedir,FD(i).name); 42 | mname = strcat(maskdir,MD(i).name); 43 | s=imread(sname); 44 | f=imread(fname); 45 | m=imread(mname); 46 | s=imresize(s,[256 256]); 47 | f=imresize(f,[256 256]); 48 | m=imresize(m,[256 256]); 49 | mask = ones([size(f,1),size(f,2)]); 50 | 51 | nmask=~m; 52 | smask=~nmask; 53 | 54 | f = double(f)/255; 55 | s = double(s)/255; 56 | 57 | 58 | f = applycform(f,cform); 59 | s = applycform(s,cform); 60 | 61 | 62 | %abs lab 63 | absl=abs(f(:,:,1) - s(:,:,1)); 64 | absa=abs(f(:,:,2) - s(:,:,2)); 65 | absb=abs(f(:,:,3) - s(:,:,3)); 66 | 67 | % rmse 68 | summask=sum(mask(:)); 69 | rl(i)=sum(absl(:))/summask; 70 | ra(i)=sum(absa(:))/summask; 71 | rb(i)=sum(absb(:))/summask; 72 | 73 | %% non-shadow, ours, per image 74 | distl = absl.* nmask; 75 | dista = absa.* nmask; 76 | distb = absb.* nmask; 77 | sumnmask=sum(nmask(:)); 78 | nrl(i)=sum(distl(:))/sumnmask; 79 | nra(i)=sum(dista(:))/sumnmask; 80 | nrb(i)=sum(distb(:))/sumnmask; 81 | 82 | %% rmse in shadow, original way, per pixel 83 | dist = abs((f - s).* repmat(smask,[1 1 3])); 84 | total_dists = total_dists + sum(dist(:)); 85 | total_pixels = 
total_pixels + sum(smask(:)); 86 | % rmse in non-shadow, original way, per pixel 87 | dist = abs((f - s).* repmat(nmask,[1 1 3])); 88 | total_distn = total_distn + sum(dist(:)); 89 | total_pixeln = total_pixeln + sum(nmask(:)); 90 | end 91 | toc; 92 | %% rmse in shadow, original way, per pixel 93 | fprintf('\tall,\tnon-shadow,\tshadow:\n%f\t%f\t%f\n\n',mean(rl)+mean(ra)+mean(rb),total_distn/total_pixeln,total_dists/total_pixels); 94 | 95 | -------------------------------------------------------------------------------- /metrics/rmse_istd_v1.m: -------------------------------------------------------------------------------- 1 | %% compute rmse for ISTD (V1) 2 | clear;close all;clc 3 | 4 | % mask directory|mask 5 | maskdir = 'C:\Users\76454\Desktop\BID\datasets\jointremoval_v1\testB\'; 6 | MD = dir([maskdir '/*.png']); 7 | 8 | % ground truth directory|GT 9 | freedir = 'C:\Users\76454\Desktop\BID\datasets\jointremoval_v1\testA\'; 10 | FD = dir([freedir '/*.png']); 11 | 12 | % predicted result directory 13 | shadowdir = 'C:\Users\76454\Desktop\BID\results\task3_v1\test_latest\images\fake_A\'; 14 | SD = dir([shadowdir '/*.png']); 15 | 16 | 17 | total_dists = 0; 18 | total_pixels = 0; 19 | total_distn = 0; 20 | total_pixeln = 0; 21 | rl=zeros(1,size(SD,1)); 22 | ra=zeros(1,size(SD,1)); 23 | rb=zeros(1,size(SD,1)); 24 | nrl=zeros(1,size(SD,1)); 25 | nra=zeros(1,size(SD,1)); 26 | nrb=zeros(1,size(SD,1)); 27 | srl=zeros(1,size(SD,1)); 28 | sra=zeros(1,size(SD,1)); 29 | srb=zeros(1,size(SD,1)); 30 | ppsnr=zeros(1,size(SD,1)); 31 | ppsnrs=zeros(1,size(SD,1)); 32 | ppsnrn=zeros(1,size(SD,1)); 33 | sssim=zeros(1,size(SD,1)); 34 | sssims=zeros(1,size(SD,1)); 35 | sssimn=zeros(1,size(SD,1)); 36 | % ISTD dataset image size 480*640 37 | tic; 38 | mask = ones([480,640]); 39 | cform = makecform('srgb2lab'); 40 | 41 | for i=1:size(SD) 42 | sname = strcat(shadowdir,SD(i).name); 43 | fname = strcat(freedir,FD(i).name); 44 | mname = strcat(maskdir,MD(i).name); 45 | s=imread(sname); 46 | f=imread(fname); 47 | m=imread(mname); 48 | s=imresize(s,[256 256]); 49 | f=imresize(f,[256 256]); 50 | m=imresize(m,[256 256]); 51 | mask = ones([size(f,1),size(f,2)]); 52 | 53 | nmask=~m; 54 | smask=~nmask; 55 | 56 | f = double(f)/255; 57 | s = double(s)/255; 58 | 59 | 60 | f = applycform(f,cform); 61 | s = applycform(s,cform); 62 | 63 | 64 | %abs lab 65 | absl=abs(f(:,:,1) - s(:,:,1)); 66 | absa=abs(f(:,:,2) - s(:,:,2)); 67 | absb=abs(f(:,:,3) - s(:,:,3)); 68 | 69 | % rmse 70 | summask=sum(mask(:)); 71 | rl(i)=sum(absl(:))/summask; 72 | ra(i)=sum(absa(:))/summask; 73 | rb(i)=sum(absb(:))/summask; 74 | 75 | %% non-shadow, ours, per image 76 | distl = absl.* nmask; 77 | dista = absa.* nmask; 78 | distb = absb.* nmask; 79 | sumnmask=sum(nmask(:)); 80 | nrl(i)=sum(distl(:))/sumnmask; 81 | nra(i)=sum(dista(:))/sumnmask; 82 | nrb(i)=sum(distb(:))/sumnmask; 83 | 84 | %% rmse in shadow, original way, per pixel 85 | dist = abs((f - s).* repmat(smask,[1 1 3])); 86 | total_dists = total_dists + sum(dist(:)); 87 | total_pixels = total_pixels + sum(smask(:)); 88 | % rmse in non-shadow, original way, per pixel 89 | dist = abs((f - s).* repmat(nmask,[1 1 3])); 90 | total_distn = total_distn + sum(dist(:)); 91 | total_pixeln = total_pixeln + sum(nmask(:)); 92 | end 93 | toc; 94 | %% rmse in shadow, original way, per pixel 95 | fprintf('\tall,\tnon-shadow,\tshadow:\n%f\t%f\t%f\n\n',mean(rl)+mean(ra)+mean(rb),total_distn/total_pixeln,total_dists/total_pixels); 96 | 97 | 
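
For readers without MATLAB, here is a rough NumPy/scikit-image equivalent of the per-region error computed by the two scripts above. Note that, despite their file names, the scripts average absolute CIELAB differences (the usual "RMSE" convention in the shadow-removal literature); this sketch is not part of the repository and assumes scikit-image is installed:

    import numpy as np
    from skimage.color import rgb2lab

    def lab_region_error(pred, gt, mask):
        """pred, gt: HxWx3 uint8 RGB images; mask: HxW boolean region mask."""
        pred_lab = rgb2lab(pred / 255.0)  # same role as applycform(..., makecform('srgb2lab'))
        gt_lab = rgb2lab(gt / 255.0)
        diff = np.abs(pred_lab - gt_lab)  # per-channel absolute Lab difference
        # Sum over the three channels inside the region, averaged per masked pixel.
        return diff[mask].sum() / mask.sum()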
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 |     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 |     -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 |     -- self.model_names (str list): define networks used in our training.
14 |     -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 |     """Import the module "models/[model_name]_model.py".
27 |
28 |     In the file, the class called [ModelName]Model() will
29 |     be instantiated. It has to be a subclass of BaseModel,
30 |     and it is case-insensitive.
31 |     """
32 |     model_filename = "models." + model_name + "_model"
33 |     modellib = importlib.import_module(model_filename)
34 |     model = None
35 |     target_model_name = model_name.replace('_', '') + 'model'
36 |     for name, cls in modellib.__dict__.items():
37 |         if name.lower() == target_model_name.lower() \
38 |            and issubclass(cls, BaseModel):
39 |             model = cls
40 |
41 |     if model is None:
42 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 |         exit(0)
44 |
45 |     return model
46 |
47 |
48 | def get_option_setter(model_name):
49 |     """Return the static method <modify_commandline_options> of the model class."""
50 |     model_class = find_model_using_name(model_name)
51 |     return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 |     """Create a model given the option.
56 |
57 |     This function wraps the model class.
58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /experiments/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | 5 | def find_launcher_using_name(launcher_name): 6 | # cur_dir = os.path.dirname(os.path.abspath(__file__)) 7 | # pythonfiles = glob.glob(cur_dir + '/**/*.py') 8 | launcher_filename = "experiments.{}_launcher".format(launcher_name) 9 | launcherlib = importlib.import_module(launcher_filename) 10 | 11 | # In the file, the class called LauncherNameLauncher() will 12 | # be instantiated. It has to be a subclass of BaseLauncher, 13 | # and it is case-insensitive. 14 | launcher = None 15 | # target_launcher_name = launcher_name.replace('_', '') + 'launcher' 16 | for name, cls in launcherlib.__dict__.items(): 17 | if name.lower() == "launcher": 18 | launcher = cls 19 | 20 | if launcher is None: 21 | raise ValueError("In %s.py, there should be a class named Launcher") 22 | 23 | return launcher 24 | 25 | 26 | if __name__ == "__main__": 27 | import argparse 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('name') 31 | parser.add_argument('cmd') 32 | parser.add_argument('id', nargs='+', type=str) 33 | parser.add_argument('--mode', default=None) 34 | parser.add_argument('--which_epoch', default=None) 35 | parser.add_argument('--continue_train', action='store_true') 36 | parser.add_argument('--subdir', default='') 37 | parser.add_argument('--title', default='') 38 | parser.add_argument('--gpu_id', default=None, type=int) 39 | parser.add_argument('--phase', default='test') 40 | 41 | opt = parser.parse_args() 42 | 43 | name = opt.name 44 | Launcher = find_launcher_using_name(name) 45 | 46 | instance = Launcher() 47 | 48 | cmd = opt.cmd 49 | ids = 'all' if 'all' in opt.id else [int(i) for i in opt.id] 50 | if cmd == "launch": 51 | instance.launch(ids, continue_train=opt.continue_train) 52 | elif cmd == "stop": 53 | instance.stop() 54 | elif cmd == "send": 55 | assert False 56 | elif cmd == "close": 57 | instance.close() 58 | elif cmd == "dry": 59 | instance.dry() 60 | elif cmd == "relaunch": 61 | instance.close() 62 | instance.launch(ids, continue_train=opt.continue_train) 63 | elif cmd == "run" or cmd == "train": 64 | assert len(ids) == 1, '%s is invalid for run command' % (' '.join(opt.id)) 65 | expid = ids[0] 66 | instance.run_command(instance.commands(), expid, 67 | continue_train=opt.continue_train, 68 | gpu_id=opt.gpu_id) 69 | elif cmd == 'launch_test': 70 | instance.launch(ids, test=True) 71 | elif cmd == "run_test" or cmd == "test": 72 | test_commands = instance.test_commands() 73 | if ids == "all": 74 | ids = list(range(len(test_commands))) 75 | for expid in ids: 76 | instance.run_command(test_commands, expid, opt.which_epoch, 77 | gpu_id=opt.gpu_id) 78 | if expid < len(ids) - 1: 79 | os.system("sleep 5s") 80 | elif cmd == "print_names": 81 | instance.print_names(ids, test=False) 82 | elif cmd == "print_test_names": 83 | instance.print_names(ids, test=True) 84 | elif cmd == "create_comparison_html": 85 | instance.create_comparison_html(name, ids, opt.subdir, opt.title, opt.phase) 86 | else: 87 | raise 
ValueError("Command not recognized") 88 | -------------------------------------------------------------------------------- /data/unaligned2_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class Unaligned2Dataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 23 | 24 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 25 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 26 | self.dir_A = os.path.join(opt.dataroot, "valA") 27 | self.dir_B = os.path.join(opt.dataroot, "valB") 28 | 29 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 30 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 31 | self.A_size = len(self.A_paths) # get the size of dataset A 32 | self.B_size = len(self.B_paths) # get the size of dataset B 33 | 34 | def __getitem__(self, index): 35 | """Return a data point and its metadata information. 36 | 37 | Parameters: 38 | index (int) -- a random integer for data indexing 39 | 40 | Returns a dictionary that contains A, B, A_paths and B_paths 41 | A (tensor) -- an image in the input domain 42 | B (tensor) -- its corresponding image in the target domain 43 | A_paths (str) -- image paths 44 | B_paths (str) -- image paths 45 | """ 46 | A_path = self.A_paths[index % self.A_size] # make sure index is within then range 47 | if self.opt.serial_batches: # make sure index is within then range 48 | index_B = index % self.B_size 49 | else: # randomize the index for domain B to avoid fixed pairs. 50 | index_B = random.randint(0, self.B_size - 1) 51 | B_path = self.B_paths[index_B] 52 | A_img = Image.open(A_path).convert('RGB') 53 | B_img = Image.open(B_path).convert('RGB') 54 | 55 | # Apply image transformation 56 | # For FastCUT mode, if in finetuning phase (learning rate is decaying), 57 | # do not perform resize-crop data augmentation of CycleGAN. 58 | # print('current_epoch', self.current_epoch) 59 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 60 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 61 | transform = get_transform(modified_opt) 62 | A = transform(A_img) 63 | B = transform(B_img) 64 | 65 | return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path} 66 | 67 | def __len__(self): 68 | """Return the total number of images in the dataset. 
69 |
70 |         As we have two datasets with potentially different numbers of images,
71 |         we take the maximum of the two.
72 |         """
73 |         return max(self.A_size, self.B_size)
74 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 |
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML classes
16 |
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str) -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 |
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 |
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 |
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 |
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 |
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """add images to the HTML file
50 |
51 |         Parameters:
52 |             ims (str list) -- a list of image paths
53 |             txts (str list) -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 |
68 |     def save(self):
69 |         """save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 |
75 |
76 | if __name__ == '__main__':  # we show an example usage here.
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 |
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 |
--------------------------------------------------------------------------------
/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.
8 | You need to implement the following functions:
9 |     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 |     -- <__init__>: Initialize this dataset class.
11 |     -- <__getitem__>: Return a data point and its metadata information.
12 |     -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 | # from data.image_folder import make_dataset
16 | # from PIL import Image
17 |
18 |
19 | class TemplateDataset(BaseDataset):
20 |     """A template dataset class for you to implement custom datasets."""
21 |     @staticmethod
22 |     def modify_commandline_options(parser, is_train):
23 |         """Add new dataset-specific options, and rewrite default values for existing options.
24 |
25 |         Parameters:
26 |             parser -- original option parser
27 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28 |
29 |         Returns:
30 |             the modified parser.
31 |         """
32 |         parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33 |         parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
34 |         return parser
35 |
36 |     def __init__(self, opt):
37 |         """Initialize this dataset class.
38 |
39 |         Parameters:
40 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41 |
42 |         A few things can be done here.
43 |         - save the options (have been done in BaseDataset)
44 |         - get image paths and meta information of the dataset.
45 |         - define the image transformation.
46 |         """
47 |         # save the option and dataset root
48 |         BaseDataset.__init__(self, opt)
49 |         # get the image paths of your dataset;
50 |         self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51 |         # define the default transform function. You can use <base_dataset.get_transform>; you can also define your custom transform function
52 |         self.transform = get_transform(opt)
53 |
54 |     def __getitem__(self, index):
55 |         """Return a data point and its metadata information.
56 |
57 |         Parameters:
58 |             index -- a random integer for data indexing
59 |
60 |         Returns:
61 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
62 |
63 |         Step 1: get a random image path: e.g., path = self.image_paths[index]
64 |         Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65 |         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66 |         Step 4: return a data point as a dictionary.
67 |         """
68 |         path = 'temp'  # needs to be a string
69 |         data_A = None  # needs to be a tensor
70 |         data_B = None  # needs to be a tensor
71 |         return {'data_A': data_A, 'data_B': data_B, 'path': path}
72 |
73 |     def __len__(self):
74 |         """Return the total number of images."""
75 |         return len(self.image_paths)
76 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 |     """This class includes training options.
6 |
7 |     It also includes shared options defined in BaseOptions.
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=None, help='window id of the web display. Default is random window id') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=100, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--evaluation_freq', type=int, default=10000, help='evaluation freq') 26 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 27 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 28 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 29 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 30 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') 31 | # training parameters 32 | parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') 33 | parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero') 34 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 35 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') 36 | parser.add_argument('--lr', type=float, default=0.0003, help='initial learning rate for adam') 37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp| hinge]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 38 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 39 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') 40 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 41 | 42 | self.isTrain = True 43 | return parser 44 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 
70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads), 80 | drop_last=True if opt.isTrain else False, 81 | ) 82 | 83 | def set_epoch(self, epoch): 84 | self.dataset.current_epoch = epoch 85 | 86 | def load_data(self): 87 | return self 88 | 89 | def __len__(self): 90 | """Return the number of data in the dataset""" 91 | return min(len(self.dataset), self.opt.max_dataset_size) 92 | 93 | def __iter__(self): 94 | """Return a batch of data""" 95 | for i, data in enumerate(self.dataloader): 96 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 97 | break 98 | yield data 99 | -------------------------------------------------------------------------------- /data/unaligned3_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class Unaligned3Dataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 23 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') # create a path '/path/to/data/trainC' 24 | 25 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 26 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 27 | self.dir_A = os.path.join(opt.dataroot, "valA") 28 | self.dir_B = os.path.join(opt.dataroot, "valB") 29 | self.dir_C = os.path.join(opt.dataroot, "valC") 30 | 31 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 32 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 33 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 34 | self.A_size = len(self.A_paths) # get the size of dataset A 35 | self.B_size = len(self.B_paths) # get the size of dataset B 36 | self.C_size = len(self.C_paths) # get the size of dataset C 37 | 38 | def __getitem__(self, index): 39 | """Return a data point and its metadata information. 40 | 41 | Parameters: 42 | index (int) -- a random integer for data indexing 43 | 44 | Returns a dictionary that contains A, B, A_paths and B_paths 45 | A (tensor) -- an image in the input domain 46 | B (tensor) -- its corresponding image in the target domain 47 | A_paths (str) -- image paths 48 | B_paths (str) -- image paths 49 | """ 50 | A_path = self.A_paths[index % self.A_size] # make sure index is within then range 51 | if self.opt.serial_batches: # make sure index is within then range 52 | index_B = index % self.B_size 53 | index_C = index % self.C_size 54 | else: # randomize the index for domain B to avoid fixed pairs. 
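            # Domain C is randomized independently for the same reason: with unpaired
            # data, fixed A/B/C triplets could let the model exploit spurious
            # correspondences between the mixture components.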
55 | index_B = random.randint(0, self.B_size - 1) 56 | index_C = random.randint(0, self.C_size - 1) 57 | 58 | B_path = self.B_paths[index_B] 59 | C_path = self.C_paths[index_C] 60 | A_img = Image.open(A_path).convert('RGB') 61 | B_img = Image.open(B_path).convert('RGB') 62 | C_img = Image.open(C_path).convert('RGB') 63 | 64 | # Apply image transformation 65 | # For FastCUT mode, if in finetuning phase (learning rate is decaying), 66 | # do not perform resize-crop data augmentation of CycleGAN. 67 | # print('current_epoch', self.current_epoch) 68 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 69 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 70 | transform = get_transform(modified_opt) 71 | A = transform(A_img) 72 | B = transform(B_img) 73 | C = transform(C_img) 74 | 75 | return {'A': A, 'B': B, 'C': C, 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path} 76 | 77 | def __len__(self): 78 | """Return the total number of images in the dataset. 79 | 80 | As we have two datasets with potentially different number of images, 81 | we take a maximum of 82 | """ 83 | return max(self.A_size, self.B_size, self.C_size) 84 | -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """A Python script for downloading CycleGAN or pix2pix datasets. 13 | 14 | Parameters: 15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 
25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset. 83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /datasets/prepare_cityscapes_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | 5 | help_msg = """ 6 | The dataset can be downloaded from https://cityscapes-dataset.com. 7 | Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. 8 | gtFine contains the semantics segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory. 9 | leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory. 10 | The processed images will be placed at --output_dir. 
11 | 12 | Example usage: 13 | 14 | python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/ 15 | """ 16 | 17 | 18 | def load_resized_img(path): 19 | return Image.open(path).convert('RGB').resize((256, 256)) 20 | 21 | 22 | def check_matching_pair(segmap_path, photo_path): 23 | segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '') 24 | photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '') 25 | 26 | assert segmap_identifier == photo_identifier, \ 27 | "[%s] and [%s] don't seem to match. Aborting." % (segmap_path, photo_path) 28 | 29 | 30 | def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase): 31 | save_phase = 'test' if phase == 'val' else 'train' 32 | savedir = os.path.join(output_dir, save_phase) 33 | os.makedirs(savedir, exist_ok=True) 34 | os.makedirs(savedir + 'A', exist_ok=True) 35 | os.makedirs(savedir + 'B', exist_ok=True) 36 | print("Directory structure prepared at %s" % output_dir) 37 | 38 | segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png" 39 | segmap_paths = glob.glob(segmap_expr) 40 | segmap_paths = sorted(segmap_paths) 41 | 42 | photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png" 43 | photo_paths = glob.glob(photo_expr) 44 | photo_paths = sorted(photo_paths) 45 | 46 | assert len(segmap_paths) == len(photo_paths), \ 47 | "%d images that match [%s], and %d images that match [%s]. Aborting." % (len(segmap_paths), segmap_expr, len(photo_paths), photo_expr) 48 | 49 | for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)): 50 | check_matching_pair(segmap_path, photo_path) 51 | segmap = load_resized_img(segmap_path) 52 | photo = load_resized_img(photo_path) 53 | 54 | # data for pix2pix where the two images are placed side-by-side 55 | sidebyside = Image.new('RGB', (512, 256)) 56 | sidebyside.paste(segmap, (256, 0)) 57 | sidebyside.paste(photo, (0, 0)) 58 | savepath = os.path.join(savedir, "%d.jpg" % i) 59 | sidebyside.save(savepath, format='JPEG', subsampling=0, quality=100) 60 | 61 | # data for cyclegan where the two images are stored at two distinct directories 62 | savepath = os.path.join(savedir + 'A', "%d_A.jpg" % i) 63 | photo.save(savepath, format='JPEG', subsampling=0, quality=100) 64 | savepath = os.path.join(savedir + 'B', "%d_B.jpg" % i) 65 | segmap.save(savepath, format='JPEG', subsampling=0, quality=100) 66 | 67 | if i % (len(segmap_paths) // 10) == 0: 68 | print("%d / %d: last image saved at %s" % (i, len(segmap_paths), savepath)) 69 | 70 | 71 | if __name__ == '__main__': 72 | import argparse 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--gtFine_dir', type=str, required=True, 75 | help='Path to the Cityscapes gtFine directory.') 76 | parser.add_argument('--leftImg8bit_dir', type=str, required=True, 77 | help='Path to the Cityscapes leftImg8bit_trainvaltest directory.') 78 | parser.add_argument('--output_dir', type=str, required=True, 79 | default='./datasets/cityscapes', 80 | help='Directory the output images will be written to.') 81 | opt = parser.parse_args() 82 | 83 | print(help_msg) 84 | 85 | print('Preparing Cityscapes Dataset for val phase') 86 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val") 87 | print('Preparing Cityscapes Dataset for train phase') 88 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train") 89 | 90 | print('Done') 91 |
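# --- Annotation (added; not part of the original script): a minimal sketch of how
# a consumer could split the side-by-side pix2pix images written above back into
# their halves. Per process_cityscapes, the photo is pasted at (0, 0) and the
# segmap at (256, 0) on a 512x256 canvas. The function name is hypothetical.
def split_sidebyside(path):
    """Return (photo, segmap) from one of the 512x256 side-by-side JPEGs."""
    pair = Image.open(path).convert('RGB')
    photo = pair.crop((0, 0, 256, 256))     # left half: leftImg8bit photo
    segmap = pair.crop((256, 0, 512, 256))  # right half: gtFine color map
    return photo, segmap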
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from options.train_options import TrainOptions 4 | from data import create_dataset 5 | from models import create_model 6 | from util.visualizer import Visualizer 7 | 8 | 9 | if __name__ == '__main__': 10 | opt = TrainOptions().parse() # get training options 11 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 12 | dataset_size = len(dataset) # get the number of images in the dataset. 13 | 14 | model = create_model(opt) # create a model given opt.model and other options 15 | print('The number of training images = %d' % dataset_size) 16 | 17 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots 18 | opt.visualizer = visualizer 19 | total_iters = 0 # the total number of training iterations 20 | 21 | optimize_time = 0.1 22 | times = [] 23 | for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by , + 24 | epoch_start_time = time.time() # timer for entire epoch 25 | iter_data_time = time.time() # timer for data loading per iteration 26 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 27 | visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch 28 | 29 | dataset.set_epoch(epoch) 30 | for i, data in enumerate(dataset): # inner loop within one epoch 31 | iter_start_time = time.time() # timer for computation per iteration 32 | if total_iters % opt.print_freq == 0: 33 | t_data = iter_start_time - iter_data_time 34 | 35 | batch_size = data["A"].size(0) 36 | total_iters += batch_size 37 | epoch_iter += batch_size 38 | if len(opt.gpu_ids) > 0: 39 | torch.cuda.synchronize() 40 | optimize_start_time = time.time() 41 | if epoch == opt.epoch_count and i == 0: 42 | model.data_dependent_initialize(data) 43 | model.setup(opt) # regular setup: load and print networks; create schedulers 44 | model.parallelize() 45 | model.set_input(data) # unpack data from dataset and apply preprocessing 46 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 47 | if len(opt.gpu_ids) > 0: 48 | torch.cuda.synchronize() 49 | optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time 50 | 51 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file 52 | save_result = total_iters % opt.update_html_freq == 0 53 | model.compute_visuals() 54 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 55 | 56 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 57 | losses = model.get_current_losses() 58 | visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data) 59 | if opt.display_id is None or opt.display_id > 0: 60 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 61 | 62 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every iterations 63 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 64 | print(opt.name) # it's useful to occasionally show the experiment name on console 65 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 66 | 
model.save_networks(save_suffix) 67 | 68 | iter_data_time = time.time() 69 | 70 | if epoch % opt.save_epoch_freq == 0: # cache our model every epochs 71 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 72 | model.save_networks('latest') 73 | model.save_networks(epoch) 74 | 75 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 76 | model.update_learning_rate() # update learning rates at the end of every epoch. 77 | -------------------------------------------------------------------------------- /data/unaligned4_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class Unaligned4Dataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 23 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') # create a path '/path/to/data/trainC' 24 | self.dir_D = os.path.join(opt.dataroot, opt.phase + 'D') # create a path '/path/to/data/trainB' 25 | 26 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 27 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 28 | self.dir_A = os.path.join(opt.dataroot, "valA") 29 | self.dir_B = os.path.join(opt.dataroot, "valB") 30 | self.dir_C = os.path.join(opt.dataroot, "valC") 31 | self.dir_D = os.path.join(opt.dataroot, "valD") 32 | 33 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 34 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 35 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 36 | self.D_paths = sorted(make_dataset(self.dir_D, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 37 | self.A_size = len(self.A_paths) # get the size of dataset A 38 | self.B_size = len(self.B_paths) # get the size of dataset B 39 | self.C_size = len(self.C_paths) # get the size of dataset C 40 | self.D_size = len(self.D_paths) # get the size of dataset B 41 | 42 | def __getitem__(self, index): 43 | """Return a data point and its metadata information. 44 | 45 | Parameters: 46 | index (int) -- a random integer for data indexing 47 | 48 | Returns a dictionary that contains A, B, A_paths and B_paths 49 | A (tensor) -- an image in the input domain 50 | B (tensor) -- its corresponding image in the target domain 51 | A_paths (str) -- image paths 52 | B_paths (str) -- image paths 53 | """ 54 | A_path = self.A_paths[index % self.A_size] # make sure index is within then range 55 | if self.opt.serial_batches: # make sure index is within then range 56 | index_B = index % self.B_size 57 | index_C = index % self.C_size 58 | index_D = index % self.D_size 59 | else: # randomize the index for domain B to avoid fixed pairs. 
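            # Note (added): with --serial_batches the k-th images of every domain would
            # always be drawn together, silently pairing them; sampling index_B/C/D
            # uniformly at random keeps the four domains unpaired.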
60 | index_B = random.randint(0, self.B_size - 1) 61 | index_C = random.randint(0, self.C_size - 1) 62 | index_D = random.randint(0, self.D_size - 1) 63 | 64 | B_path = self.B_paths[index_B] 65 | C_path = self.C_paths[index_C] 66 | D_path = self.D_paths[index_D] 67 | A_img = Image.open(A_path).convert('RGB') 68 | B_img = Image.open(B_path).convert('RGB') 69 | C_img = Image.open(C_path).convert('RGB') 70 | D_img = Image.open(D_path).convert('RGB') 71 | # Apply image transformation 72 | # For FastCUT mode, if in finetuning phase (learning rate is decaying), 73 | # do not perform resize-crop data augmentation of CycleGAN. 74 | # print('current_epoch', self.current_epoch) 75 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 76 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 77 | transform = get_transform(modified_opt) 78 | A = transform(A_img) 79 | B = transform(B_img) 80 | C = transform(C_img) 81 | D = transform(D_img) 82 | 83 | return {'A': A, 'B': B, 'C': C, 'D': D, 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path, 'D_paths': D_path} 84 | 85 | def __len__(self): 86 | """Return the total number of images in the dataset. 87 | 88 | As we have two datasets with potentially different number of images, 89 | we take a maximum of 90 | """ 91 | return max(self.A_size, self.B_size, self.C_size, self.D_size) 92 | -------------------------------------------------------------------------------- /data/jointremoval_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform, get_params 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class JointremovalDataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 
16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_A2 = os.path.join(opt.dataroot, opt.phase + 'A2') 23 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 24 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') 25 | self.dir_D1 = os.path.join(opt.dataroot, opt.phase + 'D1') 26 | self.dir_D2 = os.path.join(opt.dataroot, opt.phase + 'D2') 27 | 28 | 29 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 30 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 31 | self.dir_A = os.path.join(opt.dataroot, "valA") 32 | self.dir_A2 = os.path.join(opt.dataroot, "valA2") 33 | self.dir_B = os.path.join(opt.dataroot, "valB") 34 | self.dir_C = os.path.join(opt.dataroot, "valC") 35 | self.dir_D1 = os.path.join(opt.dataroot, "valD1") 36 | self.dir_D2 = os.path.join(opt.dataroot, "valD2") 37 | 38 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) 39 | self.A2_paths = sorted(make_dataset(self.dir_A2, opt.max_dataset_size)) 40 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) 41 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 42 | self.D1_paths = sorted(make_dataset(self.dir_D1, opt.max_dataset_size)) 43 | self.D2_paths = sorted(make_dataset(self.dir_D2, opt.max_dataset_size)) 44 | 45 | self.A_size = len(self.A_paths) 46 | self.B_size = len(self.B_paths) 47 | self.C_size = len(self.C_paths) 48 | self.D_size = len(self.D1_paths) 49 | 50 | def __getitem__(self, index): 51 | """Return a data point and its metadata information. 52 | 53 | Parameters: 54 | index (int) -- a random integer for data indexing 55 | 56 | Returns a dictionary that contains A, B, A_paths and B_paths 57 | A (tensor) -- an image in the input domain 58 | B (tensor) -- its corresponding image in the target domain 59 | A_paths (str) -- image paths 60 | B_paths (str) -- image paths 61 | """ 62 | index_A = index % self.A_size 63 | A_path = self.A_paths[index_A] # make sure index is within then range 64 | if self.opt.serial_batches: # make sure index is within then range 65 | index_C = index % self.C_size 66 | index_D = index % self.D_size 67 | else: # randomize the index for domain B to avoid fixed pairs. 
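            # Note (added): only C and the D1/D2 pair (via a shared index_D) are sampled
            # randomly here; B and A2 reuse index_A below, so they stay aligned with A
            # as paired data.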
68 | index_C = random.randint(0, self.C_size - 1) 69 | index_D = random.randint(0, self.D_size - 1) 70 | 71 | B_path = self.B_paths[index_A] 72 | C_path = self.C_paths[index_C] 73 | D1_path = self.D1_paths[index_D] 74 | D2_path = self.D2_paths[index_D] 75 | A2_path = self.A2_paths[index_A] 76 | A_img = Image.open(A_path).convert('RGB') 77 | B_img = Image.open(B_path).convert('RGB') 78 | C_img = Image.open(C_path).convert('RGB') 79 | D1_img = Image.open(D1_path).convert('RGB') 80 | D2_img = Image.open(D2_path).convert('RGB') 81 | A2_img = Image.open(A2_path).convert('RGB') 82 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 83 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 84 | transform = get_transform(modified_opt) 85 | transform_params = get_params(self.opt, A_img.size) 86 | fix_transform = get_transform(self.opt, transform_params) 87 | 88 | A = fix_transform(A_img) 89 | B = fix_transform(B_img) 90 | C = transform(C_img) 91 | D1 = fix_transform(D1_img) 92 | D2 = fix_transform(D2_img) 93 | A2 = fix_transform(A2_img) 94 | 95 | return {'A': A, 'B': B, 'C': C, 'D1': D1, 'A2': A2, 'D2': D2, 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path, 'D1_paths': D1_path, 'A2_paths': A2_path, 'D2_paths': D2_path} 96 | 97 | def __len__(self): 98 | """Return the total number of images in the dataset. 99 | 100 | As we have two datasets with potentially different number of images, 101 | we take a maximum of 102 | """ 103 | return max(self.A_size, self.B_size, self.C_size, self.D_size) 104 | -------------------------------------------------------------------------------- /data/unaligned5_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class Unaligned5Dataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 
16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 23 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') # create a path '/path/to/data/trainC' 24 | self.dir_D = os.path.join(opt.dataroot, opt.phase + 'D') # create a path '/path/to/data/trainB' 25 | self.dir_E = os.path.join(opt.dataroot, opt.phase + 'E') # create a path '/path/to/data/trainC' 26 | 27 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 28 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 29 | self.dir_A = os.path.join(opt.dataroot, "valA") 30 | self.dir_B = os.path.join(opt.dataroot, "valB") 31 | self.dir_C = os.path.join(opt.dataroot, "valC") 32 | self.dir_D = os.path.join(opt.dataroot, "valD") 33 | self.dir_E = os.path.join(opt.dataroot, "valE") 34 | 35 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 36 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 37 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 38 | self.D_paths = sorted(make_dataset(self.dir_D, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 39 | self.E_paths = sorted(make_dataset(self.dir_E, opt.max_dataset_size)) 40 | self.A_size = len(self.A_paths) # get the size of dataset A 41 | self.B_size = len(self.B_paths) # get the size of dataset B 42 | self.C_size = len(self.C_paths) # get the size of dataset C 43 | self.D_size = len(self.D_paths) # get the size of dataset B 44 | self.E_size = len(self.E_paths) # get the size of dataset C 45 | 46 | def __getitem__(self, index): 47 | """Return a data point and its metadata information. 48 | 49 | Parameters: 50 | index (int) -- a random integer for data indexing 51 | 52 | Returns a dictionary that contains A, B, A_paths and B_paths 53 | A (tensor) -- an image in the input domain 54 | B (tensor) -- its corresponding image in the target domain 55 | A_paths (str) -- image paths 56 | B_paths (str) -- image paths 57 | """ 58 | A_path = self.A_paths[index % self.A_size] # make sure index is within then range 59 | if self.opt.serial_batches: # make sure index is within then range 60 | index_B = index % self.B_size 61 | index_C = index % self.C_size 62 | index_D = index % self.D_size 63 | index_E = index % self.E_size 64 | else: # randomize the index for domain B to avoid fixed pairs. 65 | index_B = random.randint(0, self.B_size - 1) 66 | index_C = random.randint(0, self.C_size - 1) 67 | index_D = random.randint(0, self.D_size - 1) 68 | index_E = random.randint(0, self.E_size - 1) 69 | 70 | B_path = self.B_paths[index_B] 71 | C_path = self.C_paths[index_C] 72 | D_path = self.D_paths[index_D] 73 | E_path = self.E_paths[index_E] 74 | A_img = Image.open(A_path).convert('RGB') 75 | B_img = Image.open(B_path).convert('RGB') 76 | C_img = Image.open(C_path).convert('RGB') 77 | D_img = Image.open(D_path).convert('RGB') 78 | E_img = Image.open(E_path).convert('RGB') 79 | # Apply image transformation 80 | # For FastCUT mode, if in finetuning phase (learning rate is decaying), 81 | # do not perform resize-crop data augmentation of CycleGAN. 
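        # Note (added): when is_finetuning is True, load_size is forced down to
        # crop_size in copyconf below, so get_transform only resizes the image
        # instead of resizing and then taking a random crop.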
82 | # print('current_epoch', self.current_epoch) 83 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 84 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 85 | transform = get_transform(modified_opt) 86 | A = transform(A_img) 87 | B = transform(B_img) 88 | C = transform(C_img) 89 | D = transform(D_img) 90 | E = transform(E_img) 91 | 92 | return {'A': A, 'B': B, 'C': C, 'D': D, 'E': E, 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path, 'D_paths': D_path, 'E_paths': E_path} 93 | 94 | def __len__(self): 95 | """Return the total number of images in the dataset. 96 | 97 | As we have two datasets with potentially different number of images, 98 | we take a maximum of 99 | """ 100 | return max(self.A_size, self.B_size, self.C_size, self.D_size, self.E_size) 101 | -------------------------------------------------------------------------------- /data/rainb_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform, get_params 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class RainbDataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | 22 | if opt.phase == "train": 23 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 24 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 25 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') 26 | self.dir_E1 = os.path.join(opt.dataroot, opt.phase + 'E1') 27 | self.dir_E2 = os.path.join(opt.dataroot, opt.phase + 'E2') 28 | 29 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 30 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) 31 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 32 | self.E1_paths = sorted(make_dataset(self.dir_E1, opt.max_dataset_size)) 33 | self.E2_paths = sorted(make_dataset(self.dir_E2, opt.max_dataset_size)) 34 | 35 | self.A_size = len(self.A_paths) # get the size of dataset A 36 | self.B_size = len(self.B_paths) 37 | self.C_size = len(self.C_paths) 38 | self.E_size = len(self.E1_paths) 39 | 40 | if opt.phase == "test": 41 | if opt.test_input == "A": 42 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') 43 | elif opt.test_input == "B": 44 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'B') 45 | elif opt.test_input == "C": 46 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'C') 47 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 48 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') 49 | # E1/E2 are useless, masks are not used in Task II.B testing 50 | self.dir_E1 = os.path.join(opt.dataroot, opt.phase + 'E1') 51 | self.dir_E2 = os.path.join(opt.dataroot, opt.phase + 'E2') 52 | 53 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) 54 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) 55 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 56 | self.E1_paths = sorted(make_dataset(self.dir_E1, 
opt.max_dataset_size)) 57 | self.E2_paths = sorted(make_dataset(self.dir_E2, opt.max_dataset_size)) 58 | 59 | self.A_size = len(self.A_paths) 60 | self.B_size = len(self.B_paths) 61 | self.C_size = len(self.C_paths) 62 | self.E_size = len(self.E1_paths) 63 | 64 | def __getitem__(self, index): 65 | """Return a data point and its metadata information. 66 | 67 | Parameters: 68 | index (int) -- a random integer for data indexing 69 | 70 | Returns a dictionary that contains A, B, A_paths and B_paths 71 | A (tensor) -- an image in the input domain 72 | B (tensor) -- its corresponding image in the target domain 73 | A_paths (str) -- image paths 74 | B_paths (str) -- image paths 75 | """ 76 | index_A = index % self.A_size 77 | A_path = self.A_paths[index_A] # make sure index is within then range 78 | if self.opt.serial_batches: # make sure index is within then range 79 | index_B = index % self.B_size 80 | index_C = index % self.C_size 81 | index_E = index % self.E_size 82 | else: 83 | index_B = random.randint(0, self.B_size - 1) 84 | index_C = random.randint(0, self.C_size - 1) 85 | index_E = random.randint(0, self.E_size - 1) 86 | 87 | B_path = self.B_paths[index_B] 88 | C_path = self.C_paths[index_C] 89 | E1_path = self.E1_paths[index_E] 90 | E2_path = self.E2_paths[index_E] 91 | 92 | 93 | A_img = Image.open(A_path).convert('RGB') 94 | B_img = Image.open(B_path).convert('RGB') 95 | C_img = Image.open(C_path).convert('RGB') 96 | E1_img = Image.open(E1_path).convert('RGB') 97 | E2_img = Image.open(E2_path).convert('RGB') 98 | 99 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 100 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 101 | transform = get_transform(modified_opt) 102 | transform_params = get_params(self.opt, A_img.size) 103 | fix_transform = get_transform(self.opt, transform_params) 104 | 105 | A = fix_transform(A_img) 106 | B = transform(B_img) 107 | C = transform(C_img) 108 | E1 = fix_transform(E1_img) 109 | E2 = fix_transform(E2_img) 110 | 111 | return {'A': A, 'B': B, 'C': C, 'E1': E1, 'E2': E2, 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path, 'E1_paths': E1_path, 'E2_paths': E2_path} 112 | 113 | def __len__(self): 114 | """Return the total number of images in the dataset. 
115 | 116 | As we have two datasets with potentially different number of images, 117 | we take a maximum of 118 | """ 119 | return max(self.A_size, self.B_size, self.C_size, self.E_size) 120 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import importlib 8 | import argparse 9 | from argparse import Namespace 10 | import torchvision 11 | 12 | 13 | def str2bool(v): 14 | if isinstance(v, bool): 15 | return v 16 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 17 | return True 18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 19 | return False 20 | else: 21 | raise argparse.ArgumentTypeError('Boolean value expected.') 22 | 23 | 24 | def copyconf(default_opt, **kwargs): 25 | conf = Namespace(**vars(default_opt)) 26 | for key in kwargs: 27 | setattr(conf, key, kwargs[key]) 28 | return conf 29 | 30 | 31 | def find_class_in_module(target_cls_name, module): 32 | target_cls_name = target_cls_name.replace('_', '').lower() 33 | clslib = importlib.import_module(module) 34 | cls = None 35 | for name, clsobj in clslib.__dict__.items(): 36 | if name.lower() == target_cls_name: 37 | cls = clsobj 38 | 39 | assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name) 40 | 41 | return cls 42 | 43 | 44 | def tensor2im(input_image, imtype=np.uint8): 45 | """"Converts a Tensor array into a numpy image array. 46 | 47 | Parameters: 48 | input_image (tensor) -- the input image tensor array 49 | imtype (type) -- the desired type of the converted numpy array 50 | """ 51 | if not isinstance(input_image, np.ndarray): 52 | if isinstance(input_image, torch.Tensor): # get the data from a variable 53 | image_tensor = input_image.data 54 | else: 55 | return input_image 56 | image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array 57 | if image_numpy.shape[0] == 1: # grayscale to RGB 58 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 59 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 60 | else: # if it is a numpy array, do nothing 61 | image_numpy = input_image 62 | return image_numpy.astype(imtype) 63 | 64 | 65 | def diagnose_network(net, name='network'): 66 | """Calculate and print the mean of average absolute(gradients) 67 | 68 | Parameters: 69 | net (torch network) -- Torch network 70 | name (str) -- the name of the network 71 | """ 72 | mean = 0.0 73 | count = 0 74 | for param in net.parameters(): 75 | if param.grad is not None: 76 | mean += torch.mean(torch.abs(param.grad.data)) 77 | count += 1 78 | if count > 0: 79 | mean = mean / count 80 | print(name) 81 | print(mean) 82 | 83 | 84 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 85 | """Save a numpy image to the disk 86 | 87 | Parameters: 88 | image_numpy (numpy array) -- input numpy array 89 | image_path (str) -- the path of the image 90 | """ 91 | 92 | image_pil = Image.fromarray(image_numpy) 93 | h, w, _ = image_numpy.shape 94 | 95 | if aspect_ratio is None: 96 | pass 97 | elif aspect_ratio > 1.0: 98 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 99 | elif aspect_ratio < 1.0: 100 | image_pil = image_pil.resize((int(h / 
aspect_ratio), w), Image.BICUBIC) 101 | image_pil.save(image_path) 102 | 103 | 104 | def print_numpy(x, val=True, shp=False): 105 | """Print the mean, min, max, median, std, and size of a numpy array 106 | 107 | Parameters: 108 | val (bool) -- if print the values of the numpy array 109 | shp (bool) -- if print the shape of the numpy array 110 | """ 111 | x = x.astype(np.float64) 112 | if shp: 113 | print('shape,', x.shape) 114 | if val: 115 | x = x.flatten() 116 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 117 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 118 | 119 | 120 | def mkdirs(paths): 121 | """create empty directories if they don't exist 122 | 123 | Parameters: 124 | paths (str list) -- a list of directory paths 125 | """ 126 | if isinstance(paths, list) and not isinstance(paths, str): 127 | for path in paths: 128 | mkdir(path) 129 | else: 130 | mkdir(paths) 131 | 132 | 133 | def mkdir(path): 134 | """create a single empty directory if it didn't exist 135 | 136 | Parameters: 137 | path (str) -- a single directory path 138 | """ 139 | if not os.path.exists(path): 140 | os.makedirs(path) 141 | 142 | 143 | def correct_resize_label(t, size): 144 | device = t.device 145 | t = t.detach().cpu() 146 | resized = [] 147 | for i in range(t.size(0)): 148 | one_t = t[i, :1] 149 | one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) 150 | one_np = one_np[:, :, 0] 151 | one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) 152 | resized_t = torch.from_numpy(np.array(one_image)).long() 153 | resized.append(resized_t) 154 | return torch.stack(resized, dim=0).to(device) 155 | 156 | 157 | def correct_resize(t, size, mode=Image.BICUBIC): 158 | device = t.device 159 | t = t.detach().cpu() 160 | resized = [] 161 | for i in range(t.size(0)): 162 | one_t = t[i:i + 1] 163 | one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) 164 | resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 165 | resized.append(resized_t) 166 | return torch.stack(resized, dim=0).to(device) 167 | -------------------------------------------------------------------------------- /data/raina_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform, get_params 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class RainaDataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 
16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 23 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') 24 | self.dir_D1 = os.path.join(opt.dataroot, opt.phase + 'D1') 25 | self.dir_D2 = os.path.join(opt.dataroot, opt.phase + 'D2') 26 | self.dir_D3 = os.path.join(opt.dataroot, opt.phase + 'D3') 27 | self.dir_E1 = os.path.join(opt.dataroot, opt.phase + 'E1') 28 | self.dir_E2 = os.path.join(opt.dataroot, opt.phase + 'E2') 29 | 30 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 31 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 32 | self.dir_A = os.path.join(opt.dataroot, "valA") 33 | self.dir_B = os.path.join(opt.dataroot, "valB") 34 | self.dir_C = os.path.join(opt.dataroot, "valC") 35 | self.dir_D1 = os.path.join(opt.dataroot, "valD1") 36 | self.dir_D2 = os.path.join(opt.dataroot, "valD2") 37 | self.dir_D3 = os.path.join(opt.dataroot, "valD3") 38 | self.dir_E1 = os.path.join(opt.dataroot, "valE1") 39 | self.dir_E2 = os.path.join(opt.dataroot, "valE2") 40 | 41 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 42 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) 43 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 44 | self.D1_paths = sorted(make_dataset(self.dir_D1, opt.max_dataset_size)) 45 | self.D2_paths = sorted(make_dataset(self.dir_D2, opt.max_dataset_size)) 46 | self.D3_paths = sorted(make_dataset(self.dir_D3, opt.max_dataset_size)) 47 | self.E1_paths = sorted(make_dataset(self.dir_E1, opt.max_dataset_size)) 48 | self.E2_paths = sorted(make_dataset(self.dir_E2, opt.max_dataset_size)) 49 | 50 | self.A_size = len(self.A_paths) # get the size of dataset A 51 | self.B_size = len(self.B_paths) 52 | self.C_size = len(self.C_paths) 53 | self.D_size = len(self.D1_paths) 54 | self.E_size = len(self.E1_paths) 55 | 56 | def __getitem__(self, index): 57 | """Return a data point and its metadata information. 
58 | 59 | Parameters: 60 | index (int) -- a random integer for data indexing 61 | 62 | Returns a dictionary that contains A, B, C, D1, D2, D3, E1, E2 and their corresponding paths 63 | A (tensor) -- an image in the input domain 64 | B (tensor) -- its corresponding image in the target domain 65 | A_paths (str) -- image paths 66 | B_paths (str) -- image paths 67 | """ 68 | index_A = index % self.A_size 69 | A_path = self.A_paths[index_A] # make sure index is within the range 70 | if self.opt.serial_batches: # make sure index is within the range 71 | index_B = index % self.B_size 72 | index_C = index % self.C_size 73 | index_E = index % self.E_size 74 | else: 75 | index_B = random.randint(0, self.B_size - 1) 76 | index_C = random.randint(0, self.C_size - 1) 77 | index_E = random.randint(0, self.E_size - 1) 78 | 79 | B_path = self.B_paths[index_B] 80 | C_path = self.C_paths[index_C] 81 | D1_path = self.D1_paths[index_A] 82 | D2_path = self.D2_paths[index_A] 83 | D3_path = self.D3_paths[index_A] 84 | E1_path = self.E1_paths[index_E] 85 | E2_path = self.E2_paths[index_E] 86 | 87 | 88 | A_img = Image.open(A_path).convert('RGB') 89 | B_img = Image.open(B_path).convert('RGB') 90 | C_img = Image.open(C_path).convert('RGB') 91 | D1_img = Image.open(D1_path).convert('RGB') 92 | D2_img = Image.open(D2_path).convert('RGB') 93 | D3_img = Image.open(D3_path).convert('RGB') 94 | E1_img = Image.open(E1_path).convert('RGB') 95 | E2_img = Image.open(E2_path).convert('RGB') 96 | 97 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 98 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 99 | transform = get_transform(modified_opt) 100 | transform_params = get_params(self.opt, A_img.size) 101 | fix_transform = get_transform(self.opt, transform_params) 102 | 103 | A = fix_transform(A_img) 104 | B = transform(B_img) 105 | C = transform(C_img) 106 | D1 = fix_transform(D1_img) 107 | D2 = fix_transform(D2_img) 108 | D3 = fix_transform(D3_img) 109 | E1 = fix_transform(E1_img) 110 | E2 = fix_transform(E2_img) 111 | 112 | return {'A': A, 'B': B, 'C': C, 'D1': D1, 'D2': D2, 'D3': D3, 'E1': E1, 'E2': E2, 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path, 'D1_paths': D1_path, 'D2_paths': D2_path, 'D3_paths': D3_path, 'E1_paths': E1_path, 'E2_paths': E2_path} 113 | 114 | def __len__(self): 115 | """Return the total number of images in the dataset. 116 | 117 | As the sub-datasets may contain different numbers of images, 118 | we take the maximum of their sizes. 119 | """ 120 | return max(self.A_size, self.B_size, self.C_size, self.D_size, self.E_size) 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [arXiv](https://arxiv.org/abs/2108.11364), [Project page](https://junlinhan.github.io/projects/BID.html), [Paper](https://arxiv.org/pdf/2108.11364.pdf), [Video](https://youtu.be/wkyJDjUPCkg), [Slide](https://junlinhan.github.io/Files/BID_slide.pptx), [Poster](https://junlinhan.github.io/Files/BID_poster.pdf) 2 | 3 | # Blind Image Decomposition (BID) 4 | 5 | The BID task requires separating a superimposed image into its constituent underlying images in a blind setting, that is, both the source components involved in the mixing and the mixing mechanism itself are unknown.
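As a toy illustration of the setting (a sketch of ours, not the paper's actual data pipeline; the function and the linear-mixing assumption are made up for illustration), a superimposed input can be thought of as a blend of unknown source components, which must then be separated without knowing the sources or the blend:

```python
import numpy as np

def superimpose(components, weights):
    """Toy linear mixing: blend float images in [0, 1] with shape (H, W, 3) into one input."""
    mixed = np.zeros_like(components[0])
    for img, w in zip(components, weights):
        mixed += w * img  # neither the components nor the weights are known to the solver
    return np.clip(mixed, 0.0, 1.0)
```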
6 | 7 | We invite our community to explore the novel BID task, including discovering interesting areas of application, developing novel methods, extending the BID setting, and constructing benchmark datasets. 8 | 9 | [Blind Image Decomposition](https://arxiv.org/pdf/2108.11364.pdf)
10 | [Junlin Han](https://junlinhan.github.io/), Weihao Li, Pengfei Fang, Chunyi Sun, Jie Hong, Ali Armin, [Lars Petersson](https://people.csiro.au/P/L/Lars-Petersson), [Hongdong Li](http://users.cecs.anu.edu.au/~hongdong/)
11 | DATA61-CSIRO and Australian National University
12 | European Conference on Computer Vision (ECCV), 2022 13 | 14 | BID demo: 15 | ![BID demo](imgs/BID.gif) 16 | 17 | # BIDeN (Blind Image Decomposition Network) 18 | ![BIDeN architecture](imgs/network.png) 19 | 20 | ## Applications of BID 21 | 22 | **Deraining (rain streak, snow, haze, raindrop):** 23 | 24 | ![Task II deraining results](imgs/taskII.png)
25 | Rows 1-6 present six cases of the same scene. The six cases are (1) rain streak, (2) rain streak + snow, (3) rain streak + light haze, (4) rain streak + heavy haze, (5) rain streak + moderate haze + raindrop, and (6) rain streak + snow + moderate haze + raindrop. 26 |
27 | 28 | **Joint shadow/reflection/watermark removal:** 29 | ![Task III joint removal results](imgs/taskIII.png) 30 | 31 | ## Prerequisites 32 | Python 3.7 or above. 33 | 34 | For packages, see requirements.txt. 35 | 36 | ### Getting started 37 | 38 | - Clone this repo: 39 | ```bash 40 | git clone https://github.com/JunlinHan/BID.git 41 | ``` 42 | 43 | - Install PyTorch 1.7 or above and other dependencies (e.g., torchvision, visdom, dominate, gputil). 44 | 45 | For pip users, please type the command `pip install -r requirements.txt`. 46 | 47 | For Conda users, you can create a new Conda environment using `conda env create -f environment.yml`. (recommended) 48 | 49 | We tested our code on both Windows and Ubuntu. 50 | 51 | ### BID Datasets 52 | 53 | - Download the BID datasets: https://drive.google.com/drive/folders/1wUUKTiRAGVvelarhsjmZZ_1iBdBaM6Ka?usp=sharing 54 | 55 | Unzip the downloaded datasets and put them inside `./datasets/`. 56 | 57 | - To use our dataset in your method/project, please refer to ./models for detailed usage (biden2-8_model for Task I, the raina/rainb models for Task II, jointremoval_model for Task III). The code can be easily transferred. 58 | 59 | ### BID Train/Test 60 | - Detailed instructions are provided at `./models/`. 61 | - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. 62 | 63 | **Task I: Mixed image decomposition across multiple domains:** 64 | 65 | Train (biden n, where n is the maximum number of source components): 66 | ```bash 67 | python train.py --dataroot ./datasets/image_decom --name biden2 --model biden2 --dataset_mode unaligned2 68 | python train.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 69 | ... 70 | python train.py --dataroot ./datasets/image_decom --name biden8 --model biden8 --dataset_mode unaligned8 71 | ``` 72 | 73 | Test a single case (use n = 3 as an example): 74 | ```bash 75 | # --test_input selects the mixture case: 76 | python test.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 --test_input A 77 | python test.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 --test_input AB 78 | ``` 79 | ... and other cases. 80 | Change --test_input to the case you want. 81 | 82 | Test all cases: 83 | ```bash 84 | python test2.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 85 | ``` 86 | 87 | **Task II.A: Real-scenario deraining in driving:** 88 | 89 | Train: 90 | ```bash 91 | python train.py --dataroot ./datasets/raina --name task2a --model raina --dataset_mode raina 92 | ``` 93 | 94 | **Task II.B: Real-scenario deraining in general:** 95 | 96 | Train: 97 | ```bash 98 | python train.py --dataroot ./datasets/rainb --name task2b --model rainb --dataset_mode rainb 99 | ``` 100 | 101 | **Task III: Joint shadow/reflection/watermark removal:** 102 | 103 | Train: 104 | ```bash 105 | python train.py --dataroot ./datasets/jointremoval_v1 --name task3_v1 --model jointremoval --dataset_mode jointremoval 106 | # or 107 | python train.py --dataroot ./datasets/jointremoval_v2 --name task3_v2 --model jointremoval --dataset_mode jointremoval 108 | ``` 109 | 110 | The test results will be saved to an HTML file under `./results/`. 111 | 112 | ### Apply a pre-trained BIDeN model 113 | We provide our pre-trained BIDeN models at: https://drive.google.com/drive/folders/1UBmdKZXYewJVXHT4dRaat4g8xZ61OyDF?usp=sharing 114 | 115 | Download the pre-trained model, unzip it, and put it inside ./checkpoints.
116 | 117 | Example usage: download the dataset of Task II.A (rain in driving) and the pre-trained model of Task II.A, then test the rain streak case. 118 | ```bash 119 | python test.py --dataroot ./datasets/raina --name task2a --model raina --dataset_mode raina --test_input B 120 | ``` 121 | 122 | ### Evaluation 123 | For FID score, use [pytorch-fid](https://github.com/mseitzer/pytorch-fid). 124 | 125 | For PSNR/SSIM/RMSE/NIQE/BRISQUE, see `./metrics/`. 126 | 127 | ### Raindrop effect 128 | See `./raindrop/`. 129 | 130 | ### Citation 131 | If you use our code or our results, please consider citing our paper. Thanks in advance! 132 | ``` 133 | @inproceedings{han2022bid, 134 | title={Blind Image Decomposition}, 135 | author={Junlin Han and Weihao Li and Pengfei Fang and Chunyi Sun and Jie Hong and Mohammad Ali Armin and Lars Petersson and Hongdong Li}, 136 | booktitle={European Conference on Computer Vision (ECCV)}, 137 | year={2022} 138 | } 139 | ``` 140 | 141 | ### Contact 142 | junlin.han@data61.csiro.au or junlinhcv@gmail.com 143 | 144 | ### Acknowledgments 145 | Our code is developed based on [DCLGAN](https://github.com/junlinhan/DCLGAN) and [CUT](http://taesung.me/ContrastiveUnpairedTranslation/). 146 | We thank the authors of [MPRNet](https://github.com/swz30/MPRNet), [perceptual-reflection-removal](https://github.com/ceciliavision/perceptual-reflection-removal), [Double-DIP](https://github.com/yossigandelsman/DoubleDIP), [Deep-adversarial-decomposition](https://github.com/jiupinjia/Deep-adversarial-decomposition) for sharing their source code. 147 | We thank [exposure-fusion-shadow-removal](https://github.com/tsingqguo/exposure-fusion-shadow-removal) and [ghost-free-shadow-removal](https://github.com/vinthony/ghost-free-shadow-removal) for providing the source code and results. 148 | We thank [pytorch-fid](https://github.com/mseitzer/pytorch-fid) for FID computation. 149 | 150 | -------------------------------------------------------------------------------- /data/unaligned8_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import util.util as util 7 | 8 | 9 | class Unaligned8Dataset(BaseDataset): 10 | """ 11 | This dataset class can load unaligned/unpaired datasets. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class.
16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 22 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 23 | self.dir_C = os.path.join(opt.dataroot, opt.phase + 'C') # create a path '/path/to/data/trainC' 24 | self.dir_D = os.path.join(opt.dataroot, opt.phase + 'D') # create a path '/path/to/data/trainD' 25 | self.dir_E = os.path.join(opt.dataroot, opt.phase + 'E') # create a path '/path/to/data/trainE' 26 | self.dir_F = os.path.join(opt.dataroot, opt.phase + 'F') # create a path '/path/to/data/trainF' 27 | self.dir_G = os.path.join(opt.dataroot, opt.phase + 'G') # create a path '/path/to/data/trainG' 28 | self.dir_H = os.path.join(opt.dataroot, opt.phase + 'H') # create a path '/path/to/data/trainH' 29 | 30 | if opt.phase == "test" and not os.path.exists(self.dir_A) \ 31 | and os.path.exists(os.path.join(opt.dataroot, "valA")): 32 | self.dir_A = os.path.join(opt.dataroot, "valA") 33 | self.dir_B = os.path.join(opt.dataroot, "valB") 34 | self.dir_C = os.path.join(opt.dataroot, "valC") 35 | self.dir_D = os.path.join(opt.dataroot, "valD") 36 | self.dir_E = os.path.join(opt.dataroot, "valE") 37 | self.dir_F = os.path.join(opt.dataroot, "valF") 38 | self.dir_G = os.path.join(opt.dataroot, "valG") 39 | self.dir_H = os.path.join(opt.dataroot, "valH") 40 | 41 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 42 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 43 | self.C_paths = sorted(make_dataset(self.dir_C, opt.max_dataset_size)) 44 | self.D_paths = sorted(make_dataset(self.dir_D, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 45 | self.E_paths = sorted(make_dataset(self.dir_E, opt.max_dataset_size)) 46 | self.F_paths = sorted(make_dataset(self.dir_F, opt.max_dataset_size)) 47 | self.G_paths = sorted(make_dataset(self.dir_G, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 48 | self.H_paths = sorted(make_dataset(self.dir_H, opt.max_dataset_size)) 49 | self.A_size = len(self.A_paths) # get the size of dataset A 50 | self.B_size = len(self.B_paths) # get the size of dataset B 51 | self.C_size = len(self.C_paths) # get the size of dataset C 52 | self.D_size = len(self.D_paths) # get the size of dataset B 53 | self.E_size = len(self.E_paths) # get the size of dataset C 54 | self.F_size = len(self.F_paths) # get the size of dataset C 55 | self.G_size = len(self.G_paths) # get the size of dataset B 56 | self.H_size = len(self.H_paths) # get the size of dataset C 57 | 58 | def __getitem__(self, index): 59 | """Return a data point and its metadata information. 
60 | 61 | Parameters: 62 | index (int) -- a random integer for data indexing 63 | 64 | Returns a dictionary that contains A, B, A_paths and B_paths 65 | A (tensor) -- an image in the input domain 66 | B (tensor) -- its corresponding image in the target domain 67 | A_paths (str) -- image paths 68 | B_paths (str) -- image paths 69 | """ 70 | A_path = self.A_paths[index % self.A_size] # make sure index is within then range 71 | if self.opt.serial_batches: # make sure index is within then range 72 | index_B = index % self.B_size 73 | index_C = index % self.C_size 74 | index_D = index % self.D_size 75 | index_E = index % self.E_size 76 | index_F = index % self.F_size 77 | index_G = index % self.G_size 78 | index_H = index % self.H_size 79 | else: # randomize the index for domain B to avoid fixed pairs. 80 | index_B = random.randint(0, self.B_size - 1) 81 | index_C = random.randint(0, self.C_size - 1) 82 | index_D = random.randint(0, self.D_size - 1) 83 | index_E = random.randint(0, self.E_size - 1) 84 | index_F = random.randint(0, self.F_size - 1) 85 | index_G = random.randint(0, self.G_size - 1) 86 | index_H = random.randint(0, self.H_size - 1) 87 | 88 | B_path = self.B_paths[index_B] 89 | C_path = self.C_paths[index_C] 90 | D_path = self.D_paths[index_D] 91 | E_path = self.E_paths[index_E] 92 | F_path = self.F_paths[index_F] 93 | G_path = self.G_paths[index_G] 94 | H_path = self.H_paths[index_H] 95 | 96 | A_img = Image.open(A_path).convert('RGB') 97 | B_img = Image.open(B_path).convert('RGB') 98 | C_img = Image.open(C_path).convert('RGB') 99 | D_img = Image.open(D_path).convert('RGB') 100 | E_img = Image.open(E_path).convert('RGB') 101 | F_img = Image.open(F_path).convert('RGB') 102 | G_img = Image.open(G_path).convert('RGB') 103 | H_img = Image.open(H_path).convert('RGB') 104 | # Apply image transformation 105 | # For FastCUT mode, if in finetuning phase (learning rate is decaying), 106 | # do not perform resize-crop data augmentation of CycleGAN. 107 | # print('current_epoch', self.current_epoch) 108 | is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs 109 | modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size) 110 | transform = get_transform(modified_opt) 111 | A = transform(A_img) 112 | B = transform(B_img) 113 | C = transform(C_img) 114 | D = transform(D_img) 115 | E = transform(E_img) 116 | F = transform(F_img) 117 | G = transform(G_img) 118 | H = transform(H_img) 119 | 120 | return {'A': A, 'B': B, 'C': C, 'D': D, 'E': E, 'F': F, 'G': G, 'H': H, 121 | 'A_paths': A_path, 'B_paths': B_path, 'C_paths': C_path, 'D_paths': D_path, 'E_paths': E_path, 'F_paths': F_path, 'G_paths': G_path, 'H_paths': H_path} 122 | 123 | def __len__(self): 124 | """Return the total number of images in the dataset. 
125 | 126 | As the sub-datasets may contain different numbers of images, 127 | we take the maximum of their sizes. 128 | """ 129 | return max(self.A_size, self.B_size, self.C_size, self.D_size, self.E_size, self.F_size, self.G_size, self.H_size) 130 | -------------------------------------------------------------------------------- /train_fid.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import os 4 | from options.train_options import TrainOptions 5 | from options.test_options import TestOptions 6 | from data import create_dataset 7 | from models import create_model 8 | from util import html 9 | from util.visualizer import Visualizer, save_images 10 | from pytorch_fid.fid_score import calculate_fid_given_paths 11 | """ 12 | This script measures FID during training. You need to create validation sets. 13 | We keep it here, but it is not recommended: since ground truth is available, metrics like PSNR/SSIM are better. 14 | """ 15 | 16 | if __name__ == '__main__': 17 | opt = TrainOptions().parse() # get training options 18 | val_opts = TestOptions().parse() 19 | val_opts.phase = 'val' 20 | val_opts.num_threads = 0 # test code only supports num_threads = 0 21 | val_opts.batch_size = 1 # test code only supports batch_size = 1 22 | val_opts.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 23 | val_opts.no_flip = True # no flip; comment this line if results on flipped images are needed. 24 | val_opts.display_id = -1 25 | val_opts.aspect_ratio = 1.0 26 | opt.val_metric_freq = 1 27 | 28 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 29 | 30 | val_dataset = create_dataset(val_opts) # create a dataset given opt.dataset_mode and other options 31 | web_dir = os.path.join(val_opts.results_dir, val_opts.name, 32 | '{}_{}'.format(val_opts.phase, val_opts.epoch)) # define the website directory 33 | print('creating web directory', web_dir) 34 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 35 | dataset_size = len(dataset) # get the number of images in the dataset.
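    # Note (added): the training loop below mirrors train.py; in addition, every
    # opt.val_metric_freq epochs the model is run on the validation split and FID
    # is computed between the generated images saved under
    # ./results/<name>/val_latest/images/ and the real valA/valB folder in opt.dataroot.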
36 | 37 | model = create_model(opt) # create a model given opt.model and other options 38 | print('The number of training images = %d' % dataset_size) 39 | 40 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots 41 | opt.visualizer = visualizer 42 | total_iters = 0 # the total number of training iterations 43 | 44 | optimize_time = 0.1 45 | times = [] 46 | for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by , + 47 | epoch_start_time = time.time() # timer for entire epoch 48 | iter_data_time = time.time() # timer for data loading per iteration 49 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 50 | visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch 51 | 52 | dataset.set_epoch(epoch) 53 | for i, data in enumerate(dataset): # inner loop within one epoch 54 | iter_start_time = time.time() # timer for computation per iteration 55 | if total_iters % opt.print_freq == 0: 56 | t_data = iter_start_time - iter_data_time 57 | 58 | batch_size = data["A"].size(0) 59 | total_iters += batch_size 60 | epoch_iter += batch_size 61 | if len(opt.gpu_ids) > 0: 62 | torch.cuda.synchronize() 63 | optimize_start_time = time.time() 64 | if epoch == opt.epoch_count and i == 0: 65 | model.data_dependent_initialize(data) 66 | model.setup(opt) # regular setup: load and print networks; create schedulers 67 | model.parallelize() 68 | model.set_input(data) # unpack data from dataset and apply preprocessing 69 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 70 | if len(opt.gpu_ids) > 0: 71 | torch.cuda.synchronize() 72 | optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time 73 | 74 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file 75 | save_result = total_iters % opt.update_html_freq == 0 76 | model.compute_visuals() 77 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 78 | 79 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 80 | losses = model.get_current_losses() 81 | visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data) 82 | if opt.display_id is None or opt.display_id > 0: 83 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 84 | 85 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every iterations 86 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 87 | print(opt.name) # it's useful to occasionally show the experiment name on console 88 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 89 | model.save_networks(save_suffix) 90 | 91 | iter_data_time = time.time() 92 | 93 | if epoch % opt.save_epoch_freq == 0: # cache our model every epochs 94 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 95 | model.save_networks('latest') 96 | model.save_networks(epoch) 97 | 98 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 99 | if epoch % opt.val_metric_freq == 0: 100 | print('Evaluating FID for validation set at epoch %d, iters %d, at dataset %s' % ( 101 | epoch, total_iters, opt.name)) 102 | model.eval() 103 | for i, data in 
enumerate(val_dataset): 104 | model.set_input(data) # unpack data from data loader 105 | model.test() # run inference 106 | 107 | visuals = model.get_current_visuals() # get image results 108 | if opt.direction == 'BtoA': 109 | visuals = {'fake_A': visuals['fake_A']} 110 | suffix1 = 'fake_A' 111 | suffix2 = 'valA' 112 | else: 113 | visuals = {'fake_B': visuals['fake_B']} 114 | suffix1 = 'fake_B' 115 | suffix2 = 'valB' 116 | 117 | img_path = model.get_image_paths() # get image paths 118 | if i % 50 == 0: # save images to an HTML file 119 | print('processing (%04d)-th image... %s' % (i, img_path)) 120 | save_images(webpage, visuals, img_path, aspect_ratio=val_opts.aspect_ratio, 121 | width=val_opts.display_winsize) 122 | fid_value = calculate_fid_given_paths( 123 | paths=(('./results/{d}/val_latest/images/'+suffix1).format(d=opt.name), ('{d}/'+suffix2).format(d=opt.dataroot)), 124 | batch_size=50, cuda='0', dims=2048) 125 | visualizer.print_current_fid(epoch, fid_value) 126 | visualizer.plot_current_fid(epoch, fid_value) 127 | 128 | print('End of epoch %d / %d \t Time Taken: %d sec' % ( 129 | epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 130 | model.update_learning_rate() # update learning rates at the end of every epoch. 131 | -------------------------------------------------------------------------------- /raindrop/gen_raindrop/gen_raindrop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import random 5 | import numba as nb 6 | 7 | # Set the parameters, this is the default parameters used in our paper. 8 | WIDTH = 1024 9 | HEIGHT = 512 10 | CORES = os.cpu_count() 11 | DEFAULT_N_BALLS = 150 12 | num_balls_min = 100 13 | num_balls_max = 300 14 | alpha = 1 15 | rain_min_size = 2 16 | rain_max_size = 7 17 | connected_metaball_num = 3 18 | 19 | 20 | def update_balls(balls: np.ndarray, dt: float): 21 | # remove the ball outside the screen 22 | remove_index_list = [] 23 | for b in range(balls.shape[0]): 24 | # move 25 | if((balls[b].pos[0]>WIDTH) or (balls[b].pos[1]>HEIGHT)): 26 | remove_index_list.append(b) 27 | 28 | balls = np.delete(balls,remove_index_list) 29 | 30 | # update position for large raindrops 31 | for b in range(balls.shape[0]): 32 | # move 33 | if(balls[b].radius>6): 34 | balls[b].pos += balls[b].vel * dt*0.2 35 | 36 | return balls 37 | 38 | 39 | def gen_metaball(): 40 | alpha = np.zeros((20, 20, 3)) 41 | for i in range(20): 42 | for j in range(20): 43 | dx = (20 / 2 - i) / 10 44 | dy = (20 / 2 - j) / 10 45 | if ((dx != 0) or (dy != 0)): 46 | alpha[i, j, :] = 1 / (dx ** 2 + dy ** 2) 47 | else: 48 | alpha[i, j, :] = 4 49 | 50 | alpha[alpha > 4] = 4 51 | alpha = alpha / 4 52 | alpha[alpha < 0.25] = 0 53 | return alpha 54 | 55 | # draw the textures image and alpha image for raindrop 56 | def draw_textures(screen:np.ndarray,balls:np.ndarray,texture): 57 | w, h = screen.shape[0], screen.shape[1] 58 | b_count = balls.shape[0] 59 | 60 | # to use all cores 61 | for start in nb.prange(CORES): 62 | # for each pixel on screen 63 | for x in range(start, w, CORES): 64 | for y in range(h): 65 | screen[x, y].fill(0) # clear pixel 66 | # for each ball 67 | 68 | # create texture screen 69 | texture_screen = screen.copy() 70 | 71 | for b in range(b_count): 72 | # print(b,balls[b].radius) 73 | # calculate value 74 | texture_w = min(w,int(balls[b].pos[0]+balls[b].radius*2))-max(0,int(balls[b].pos[0]-balls[b].radius*2)) 75 | texture_h = min(h,int(balls[b].pos[1] + balls[b].radius * 
2))-max(0,int(balls[b].pos[1] - balls[b].radius * 2)) 76 | 77 | if((texture_w>0) and (texture_h>0)): 78 | 79 | alpha = gen_metaball() 80 | 81 | # resize to fit size 82 | alpha = cv2.resize(alpha, (texture_h, texture_w)) 83 | texture = cv2.resize(texture,(texture_h,texture_w)) 84 | 85 | 86 | screen[max(0, int(balls[b].pos[0] - balls[b].radius * 2)):min(w, int( 87 | balls[b].pos[0] + balls[b].radius * 2)), 88 | max(0, int(balls[b].pos[1] - balls[b].radius * 2)):min(h, int( 89 | balls[b].pos[1] + balls[b].radius * 2))] += \ 90 | (alpha*255).astype(np.int32) 91 | 92 | 93 | texture_screen[max(0,int(balls[b].pos[0]-balls[b].radius*2)):min(w,int(balls[b].pos[0]+balls[b].radius*2)), 94 | max(0,int(balls[b].pos[1] - balls[b].radius * 2)):min(h,int(balls[b].pos[1] + balls[b].radius * 2))] += \ 95 | (texture * screen[max(0,int(balls[b].pos[0]-balls[b].radius*2)):min(w,int(balls[b].pos[0]+balls[b].radius*2)), 96 | max(0,int(balls[b].pos[1] - balls[b].radius * 2)):min(h,int(balls[b].pos[1] + balls[b].radius * 2))] / 255 ).astype(np.int32) 97 | 98 | texture_screen[max(0,int(balls[b].pos[0]-balls[b].radius*2)):min(w,int(balls[b].pos[0]+balls[b].radius*2)), 99 | max(0,int(balls[b].pos[1] - balls[b].radius * 2)):min(h,int(balls[b].pos[1] + balls[b].radius * 2)),0] = balls[b].thickness*255 100 | 101 | 102 | texture_screen[texture_screen>255] = 255 103 | screen[texture_screen > 255] = 255 104 | 105 | return texture_screen,screen.copy() 106 | 107 | 108 | def create_balls(n_balls): 109 | """make random balls""" 110 | balls = np.recarray( 111 | (n_balls,), dtype=[("pos", ("=n_balls): 117 | break 118 | # generate ball 119 | balls[i].radius = random.randint(rain_min_size, rain_max_size) 120 | balls[i].pos = ( 121 | np.random.randint(balls[i].radius, WIDTH - balls[i].radius), 122 | np.random.randint(balls[i].radius, HEIGHT - balls[i].radius), 123 | ) 124 | # set thinckness 125 | balls[i].thickness = random.random() 126 | # set velocity, correlate with radius 127 | balls[i].vel = (0,balls[i].radius**2) 128 | # generate connected metaball as raindrops 129 | num = random.randint(1, connected_metaball_num) 130 | i = i + 1 131 | for j in range(1,num): 132 | if (i >= n_balls): 133 | break 134 | balls[i].radius = random.randint(rain_min_size, rain_max_size) 135 | balls[i].pos = ( 136 | random.choice([ 137 | np.random.randint(int(balls[i-j].pos[0]-1.8*balls[i-j].radius),int(balls[i-j].pos[0]-0.5*balls[i-j].radius)), 138 | np.random.randint(int(balls[i-j].pos[0]+0.5*balls[i-j].radius),int(balls[i-j].pos[0]+1.8*balls[i-j].radius))]), 139 | random.choice([ 140 | np.random.randint(int(balls[i-j].pos[1] - 1.8 * balls[i-j].radius), 141 | int(balls[i-j].pos[1] - 0.5 * balls[i-j].radius)), 142 | np.random.randint(int(balls[i-j].pos[1] + 0.5 * balls[i-j].radius), 143 | int(balls[i-j].pos[1] + 1.8 * balls[i-j].radius))]) 144 | ) 145 | balls[i].thickness = random.random() 146 | # set the velocity same as the parent 147 | balls[i].vel = (0, balls[i].radius ** 2) 148 | i = i+1 149 | 150 | 151 | return balls 152 | 153 | # add new balls 154 | def add_balls(old_balls,n): 155 | 156 | balls = np.recarray( 157 | (n,), dtype=[("pos", (": initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 
21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | self.current_epoch = 0 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 60 | """ 61 | pass 62 | 63 | 64 | def get_params(opt, size): 65 | w, h = size 66 | new_h = h 67 | new_w = w 68 | if opt.preprocess == 'resize_and_crop': 69 | new_h = new_w = opt.load_size 70 | elif opt.preprocess == 'scale_width_and_crop': 71 | new_w = opt.load_size 72 | new_h = opt.load_size * h // w 73 | 74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 76 | 77 | flip = random.random() > 0.5 78 | 79 | return {'crop_pos': (x, y), 'flip': flip} 80 | 81 | 82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 83 | transform_list = [] 84 | if grayscale: 85 | transform_list.append(transforms.Grayscale(1)) 86 | if 'fixsize' in opt.preprocess: 87 | transform_list.append(transforms.Resize(params["size"], method)) 88 | if 'resize' in opt.preprocess: 89 | osize = [opt.load_size, opt.load_size] 90 | if "gta2cityscapes" in opt.dataroot: 91 | osize[0] = opt.load_size // 2 92 | transform_list.append(transforms.Resize(osize, method)) 93 | elif 'scale_width' in opt.preprocess: 94 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 95 | elif 'scale_shortside' in opt.preprocess: 96 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method))) 97 | 98 | if 'zoom' in opt.preprocess: 99 | if params is None: 100 | transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method))) 101 | else: 102 | transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"]))) 103 | 104 | if 'crop' in opt.preprocess: 105 | if params is None or 'crop_pos' not in params: 106 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 107 | else: 108 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 109 | 110 | if 'patch' in opt.preprocess: 111 | transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size))) 112 | 113 | if 'trim' in opt.preprocess: 114 | transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size))) 115 | 116 | # if opt.preprocess == 'none': 117 | 
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 118 | if not opt.no_flip: 119 | if params is None or 'flip' not in params: 120 | transform_list.append(transforms.RandomHorizontalFlip()) 121 | elif 'flip' in params: 122 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 123 | 124 | if convert: 125 | transform_list += [transforms.ToTensor()] 126 | if grayscale: 127 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 128 | else: 129 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 130 | return transforms.Compose(transform_list) 131 | 132 | 133 | def __make_power_2(img, base, method=Image.BICUBIC): 134 | ow, oh = img.size 135 | h = int(round(oh / base) * base) 136 | w = int(round(ow / base) * base) 137 | if h == oh and w == ow: 138 | return img 139 | 140 | return img.resize((w, h), method) 141 | 142 | 143 | def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None): 144 | if factor is None: 145 | zoom_level = np.random.uniform(0.8, 1.0, size=[2]) 146 | else: 147 | zoom_level = (factor[0], factor[1]) 148 | iw, ih = img.size 149 | zoomw = max(crop_width, iw * zoom_level[0]) 150 | zoomh = max(crop_width, ih * zoom_level[1]) 151 | img = img.resize((int(round(zoomw)), int(round(zoomh))), method) 152 | return img 153 | 154 | 155 | def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC): 156 | ow, oh = img.size 157 | shortside = min(ow, oh) 158 | if shortside >= target_width: 159 | return img 160 | else: 161 | scale = target_width / shortside 162 | return img.resize((round(ow * scale), round(oh * scale)), method) 163 | 164 | 165 | def __trim(img, trim_width): 166 | ow, oh = img.size 167 | if ow > trim_width: 168 | xstart = np.random.randint(ow - trim_width) 169 | xend = xstart + trim_width 170 | else: 171 | xstart = 0 172 | xend = ow 173 | if oh > trim_width: 174 | ystart = np.random.randint(oh - trim_width) 175 | yend = ystart + trim_width 176 | else: 177 | ystart = 0 178 | yend = oh 179 | return img.crop((xstart, ystart, xend, yend)) 180 | 181 | 182 | def __scale_width(img, target_width, crop_width, method=Image.BICUBIC): 183 | ow, oh = img.size 184 | if ow == target_width and oh >= crop_width: 185 | return img 186 | w = target_width 187 | h = int(max(target_width * oh / ow, crop_width)) 188 | return img.resize((w, h), method) 189 | 190 | 191 | def __crop(img, pos, size): 192 | ow, oh = img.size 193 | x1, y1 = pos 194 | tw = th = size 195 | if (ow > tw or oh > th): 196 | return img.crop((x1, y1, x1 + tw, y1 + th)) 197 | return img 198 | 199 | 200 | def __patch(img, index, size): 201 | ow, oh = img.size 202 | nw, nh = ow // size, oh // size 203 | roomx = ow - nw * size 204 | roomy = oh - nh * size 205 | startx = np.random.randint(int(roomx) + 1) 206 | starty = np.random.randint(int(roomy) + 1) 207 | 208 | index = index % (nw * nh) 209 | ix = index // nh 210 | iy = index % nh 211 | gridx = startx + ix * size 212 | gridy = starty + iy * size 213 | return img.crop((gridx, gridy, gridx + size, gridy + size)) 214 | 215 | 216 | def __flip(img, flip): 217 | if flip: 218 | return img.transpose(Image.FLIP_LEFT_RIGHT) 219 | return img 220 | 221 | 222 | def __print_size_warning(ow, oh, w, h): 223 | """Print warning information about image size(only print once)""" 224 | if not hasattr(__print_size_warning, 'has_printed'): 225 | print("The image size needs to be a multiple of 4. 
" 226 | "The loaded image size was (%d, %d), so it was adjusted to " 227 | "(%d, %d). This adjustment will be done to all images " 228 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 229 | __print_size_warning.has_printed = True 230 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in functions in both dataset class and model class. 14 | """ 15 | 16 | def __init__(self, cmd_line=None): 17 | """Reset the class; indicates the class hasn't been initailized""" 18 | self.initialized = False 19 | self.cmd_line = None 20 | if cmd_line is not None: 21 | self.cmd_line = cmd_line.split() 22 | 23 | def initialize(self, parser): 24 | """Define the common options that are used in both training and test.""" 25 | # basic parameters 26 | parser.add_argument('--dataroot', default='placeholder', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 27 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 28 | parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name') 29 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 30 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 31 | # model parameters 32 | parser.add_argument('--model', type=str, default='BIDeN', help='chooses which model to use.') 33 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 34 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 35 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 36 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 37 | parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') 38 | parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture') 39 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 40 | parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G') 41 | parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D') 42 | parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization') 43 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 44 | parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True, 45 | help='no dropout for the generator') 46 | parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)') 47 | parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]') 48 | # dataset parameters 49 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') 50 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 51 | parser.add_argument('--num_threads', default=8, type=int, help='# threads for loading data, set this to 0 if you are using Windows OS.') 52 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 53 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') 54 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 55 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 56 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 57 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 58 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 59 | # additional parameters 60 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 61 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 62 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 63 | self.initialized = True 64 | return parser 65 | 66 | def gather_options(self): 67 | """Initialize our parser with basic options(only once). 68 | Add additional model-specific and dataset-specific options. 
69 | These options are defined in the function 70 | in model and dataset classes. 71 | """ 72 | if not self.initialized: # check if it has been initialized 73 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 74 | parser = self.initialize(parser) 75 | 76 | # get the basic options 77 | if self.cmd_line is None: 78 | opt, _ = parser.parse_known_args() 79 | else: 80 | opt, _ = parser.parse_known_args(self.cmd_line) 81 | 82 | # modify model-related parser options 83 | model_name = opt.model 84 | model_option_setter = models.get_option_setter(model_name) 85 | parser = model_option_setter(parser, self.isTrain) 86 | if self.cmd_line is None: 87 | opt, _ = parser.parse_known_args() # parse again with new defaults 88 | else: 89 | opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults 90 | 91 | # modify dataset-related parser options 92 | dataset_name = opt.dataset_mode 93 | dataset_option_setter = data.get_option_setter(dataset_name) 94 | parser = dataset_option_setter(parser, self.isTrain) 95 | 96 | # save and return the parser 97 | self.parser = parser 98 | if self.cmd_line is None: 99 | return parser.parse_args() 100 | else: 101 | return parser.parse_args(self.cmd_line) 102 | 103 | def print_options(self, opt): 104 | """Print and save options 105 | 106 | It will print both current options and default values(if different). 107 | It will save options into a text file / [checkpoints_dir] / opt.txt 108 | """ 109 | message = '' 110 | message += '----------------- Options ---------------\n' 111 | for k, v in sorted(vars(opt).items()): 112 | comment = '' 113 | default = self.parser.get_default(k) 114 | if v != default: 115 | comment = '\t[default: %s]' % str(default) 116 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 117 | message += '----------------- End -------------------' 118 | print(message) 119 | 120 | # save to the disk 121 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 122 | util.mkdirs(expr_dir) 123 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 124 | try: 125 | with open(file_name, 'wt') as opt_file: 126 | opt_file.write(message) 127 | opt_file.write('\n') 128 | except PermissionError as error: 129 | print("permission error {}".format(error)) 130 | pass 131 | 132 | def parse(self): 133 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 134 | opt = self.gather_options() 135 | opt.isTrain = self.isTrain # train or test 136 | 137 | # process opt.suffix 138 | if opt.suffix: 139 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 140 | opt.name = opt.name + suffix 141 | 142 | self.print_options(opt) 143 | 144 | # set gpu ids 145 | str_ids = opt.gpu_ids.split(',') 146 | opt.gpu_ids = [] 147 | for str_id in str_ids: 148 | id = int(str_id) 149 | if id >= 0: 150 | opt.gpu_ids.append(id) 151 | if len(opt.gpu_ids) > 0: 152 | torch.cuda.set_device(opt.gpu_ids[0]) 153 | 154 | self.opt = opt 155 | return self.opt 156 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 
10 |     To create a subclass, you need to implement the following five functions:
11 |         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
12 |         -- <set_input>: unpack data from dataset and apply preprocessing.
13 |         -- <forward>: produce intermediate results.
14 |         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
15 |         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
16 |     """
17 | 
18 |     def __init__(self, opt):
19 |         """Initialize the BaseModel class.
20 | 
21 |         Parameters:
22 |             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
23 | 
24 |         When creating your custom class, you need to implement your own initialization.
25 |         In this function, you should first call <BaseModel.__init__(self, opt)>
26 |         Then, you need to define four lists:
27 |             -- self.loss_names (str list): specify the training losses that you want to plot and save.
28 |             -- self.model_names (str list): define the networks used in our training.
29 |             -- self.visual_names (str list): specify the images that you want to display and save.
30 |             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
31 |         """
32 |         self.opt = opt
33 |         self.gpu_ids = opt.gpu_ids
34 |         self.isTrain = opt.isTrain
35 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
36 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
37 |         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 |             torch.backends.cudnn.benchmark = True
39 |         self.loss_names = []
40 |         self.model_names = []
41 |         self.visual_names = []
42 |         self.optimizers = []
43 |         self.image_paths = []
44 |         self.metric = 0  # used for learning rate policy 'plateau'
45 | 
46 |     @staticmethod
47 |     def dict_grad_hook_factory(add_func=lambda x: x):
48 |         saved_dict = dict()
49 | 
50 |         def hook_gen(name):
51 |             def grad_hook(grad):
52 |                 saved_vals = add_func(grad)
53 |                 saved_dict[name] = saved_vals
54 |             return grad_hook
55 |         return hook_gen, saved_dict
56 | 
57 |     @staticmethod
58 |     def modify_commandline_options(parser, is_train):
59 |         """Add new model-specific options, and rewrite default values for existing options.
60 | 
61 |         Parameters:
62 |             parser -- original option parser
63 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
64 | 
65 |         Returns:
66 |             the modified parser.
67 |         """
68 |         return parser
69 | 
70 |     @abstractmethod
71 |     def set_input(self, input):
72 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
73 | 
74 |         Parameters:
75 |             input (dict): includes the data itself and its metadata information.
76 | """ 77 | pass 78 | 79 | @abstractmethod 80 | def forward(self): 81 | """Run forward pass; called by both functions and .""" 82 | pass 83 | 84 | @abstractmethod 85 | def forward_test(self): 86 | """Run forward pass; called by both functions and .""" 87 | pass 88 | 89 | @abstractmethod 90 | def optimize_parameters(self): 91 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 92 | pass 93 | 94 | def setup(self, opt): 95 | """Load and print networks; create schedulers 96 | 97 | Parameters: 98 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 99 | """ 100 | if self.isTrain: 101 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 102 | if not self.isTrain or opt.continue_train: 103 | load_suffix = opt.epoch 104 | self.load_networks(load_suffix) 105 | 106 | self.print_networks(opt.verbose) 107 | 108 | def parallelize(self): 109 | for name in self.model_names: 110 | if isinstance(name, str): 111 | net = getattr(self, 'net' + name) 112 | setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids)) 113 | 114 | def data_dependent_initialize(self, data): 115 | pass 116 | 117 | def eval(self): 118 | """Make models eval mode during test time""" 119 | for name in self.model_names: 120 | if isinstance(name, str): 121 | net = getattr(self, 'net' + name) 122 | net.eval() 123 | 124 | def train(self): 125 | for name in self.model_names: 126 | if isinstance(name, str): 127 | net = getattr(self, 'net' + name) 128 | net.train() 129 | 130 | def test(self): 131 | """Forward function used in test time. 132 | 133 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 134 | It also calls to produce additional visualization results 135 | """ 136 | with torch.no_grad(): 137 | self.forward_test() 138 | self.compute_visuals() 139 | 140 | def compute_visuals(self): 141 | """Calculate additional output images for visdom and HTML visualization""" 142 | pass 143 | 144 | def get_image_paths(self): 145 | """ Return image paths that are used to load current data""" 146 | return self.image_paths 147 | 148 | def update_learning_rate(self): 149 | """Update learning rates for all the networks; called at the end of every epoch""" 150 | for scheduler in self.schedulers: 151 | if self.opt.lr_policy == 'plateau': 152 | scheduler.step(self.metric) 153 | else: 154 | scheduler.step() 155 | 156 | lr = self.optimizers[0].param_groups[0]['lr'] 157 | print('learning rate = %.7f' % lr) 158 | 159 | def get_current_visuals(self): 160 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 161 | visual_ret = OrderedDict() 162 | for name in self.visual_names: 163 | if isinstance(name, str): 164 | visual_ret[name] = getattr(self, name) 165 | return visual_ret 166 | 167 | def get_current_losses(self): 168 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 169 | errors_ret = OrderedDict() 170 | for name in self.loss_names: 171 | if isinstance(name, str): 172 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 173 | return errors_ret 174 | 175 | def save_networks(self, epoch): 176 | """Save all the networks to the disk. 
177 | 178 | Parameters: 179 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 180 | """ 181 | for name in self.model_names: 182 | if isinstance(name, str): 183 | save_filename = '%s_net_%s.pth' % (epoch, name) 184 | save_path = os.path.join(self.save_dir, save_filename) 185 | net = getattr(self, 'net' + name) 186 | 187 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 188 | torch.save(net.module.cpu().state_dict(), save_path) 189 | net.cuda(self.gpu_ids[0]) 190 | else: 191 | torch.save(net.cpu().state_dict(), save_path) 192 | 193 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 194 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 195 | key = keys[i] 196 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 197 | if module.__class__.__name__.startswith('InstanceNorm') and \ 198 | (key == 'running_mean' or key == 'running_var'): 199 | if getattr(module, key) is None: 200 | state_dict.pop('.'.join(keys)) 201 | if module.__class__.__name__.startswith('InstanceNorm') and \ 202 | (key == 'num_batches_tracked'): 203 | state_dict.pop('.'.join(keys)) 204 | else: 205 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 206 | 207 | def load_networks(self, epoch): 208 | """Load all the networks from the disk. 209 | 210 | Parameters: 211 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 212 | """ 213 | for name in self.model_names: 214 | if isinstance(name, str): 215 | load_filename = '%s_net_%s.pth' % (epoch, name) 216 | if self.opt.isTrain and self.opt.pretrained_name is not None: 217 | load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name) 218 | else: 219 | load_dir = self.save_dir 220 | 221 | load_path = os.path.join(load_dir, load_filename) 222 | net = getattr(self, 'net' + name) 223 | if isinstance(net, torch.nn.DataParallel): 224 | net = net.module 225 | print('loading the model from %s' % load_path) 226 | # if you are using PyTorch newer than 0.4 (e.g., built from 227 | # GitHub source), you can remove str() on self.device 228 | state_dict = torch.load(load_path, map_location=str(self.device)) 229 | if hasattr(state_dict, '_metadata'): 230 | del state_dict._metadata 231 | 232 | # patch InstanceNorm checkpoints prior to 0.4 233 | # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 234 | # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 235 | net.load_state_dict(state_dict) 236 | 237 | def print_networks(self, verbose): 238 | """Print the total number of parameters in the network and (if verbose) network architecture 239 | 240 | Parameters: 241 | verbose (bool) -- if verbose: print the network architecture 242 | """ 243 | print('---------- Networks initialized -------------') 244 | for name in self.model_names: 245 | if isinstance(name, str): 246 | net = getattr(self, 'net' + name) 247 | num_params = 0 248 | for param in net.parameters(): 249 | num_params += param.numel() 250 | if verbose: 251 | print(net) 252 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 253 | print('-----------------------------------------------') 254 | 255 | def set_requires_grad(self, nets, requires_grad=False): 256 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 257 | Parameters: 258 | nets (network list) -- a list of networks 259 | requires_grad (bool) -- whether the networks require gradients 
or not
260 |         """
261 |         if not isinstance(nets, list):
262 |             nets = [nets]
263 |         for net in nets:
264 |             if net is not None:
265 |                 for param in net.parameters():
266 |                     param.requires_grad = requires_grad
267 | 
268 |     def generate_visuals_for_evaluation(self, data, mode):
269 |         return {}
270 | 
--------------------------------------------------------------------------------
/models/biden2_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks
4 | from .losses import VGGLoss
5 | from numpy import *
6 | import itertools
7 | 
8 | """
9 | BIDeN model for Task I, mixed image decomposition across multiple domains. Maximum number of domains = 2.
10 | 
11 | Sample usage:
12 | Optional visualization:
13 | python -m visdom.server
14 | 
15 | Train:
16 | python train.py --dataroot ./datasets/image_decom --name biden2 --model biden2 --dataset_mode unaligned2
17 | 
18 | Test:
19 | Test a single case:
20 | python test.py --dataroot ./datasets/image_decom --name biden2 --model biden2 --dataset_mode unaligned2 --test_input A
21 | python test.py --dataroot ./datasets/image_decom --name biden2 --model biden2 --dataset_mode unaligned2 --test_input AB
22 | ... and other cases.
23 | Change test_input to the case you want.
24 | 
25 | Test all cases:
26 | python test2.py --dataroot ./datasets/image_decom --name biden2 --model biden2 --dataset_mode unaligned2
27 | """
28 | 
29 | class BIDEN2Model(BaseModel):
30 |     @staticmethod
31 |     def modify_commandline_options(parser, is_train=True):
32 |         """ Configures options specific to the BIDeN model
33 |         """
34 |         parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss')
35 |         parser.add_argument('--lambda_Ln', type=float, default=30.0, help='weight for L1/L2 loss')
36 |         parser.add_argument('--lambda_VGG', type=float, default=10.0, help='weight for VGG loss')
37 |         parser.add_argument('--lambda_BCE', type=float, default=1.0, help='weight for BCE loss')
38 |         parser.add_argument('--test_input', type=str, default='A', help='test mixed images.')
39 |         parser.add_argument('--max_domain', type=int, default=2, help='max number of source components.')
40 |         parser.add_argument('--prob', type=float, default=0.9, help='probability of adding a component')
41 |         parser.add_argument('--test_choice', type=int, default=1, help='choice for test mode, 1 for one case,'
42 |                                                                        ' 0 for all cases. Will be set automatically.')
43 |         opt, _ = parser.parse_known_args()
44 |         return parser
45 | 
46 |     def __init__(self, opt):
47 |         BaseModel.__init__(self, opt)
48 | 
49 |         # specify the training losses you want to print out.
50 |         # The training/test scripts will call <BaseModel.get_current_losses>
51 |         self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'Ln', 'VGG', 'BCE']
52 |         self.visual_names = ['fake_A', 'fake_B', 'real_A', 'real_B', 'real_input']
53 |         self.model_names = ['D', 'E', 'H1', 'H2']
54 | 
55 |         # Define networks (both generator and discriminator)
56 |         # Define Encoder E.
57 |         self.netE = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'encoder', opt.normG,
58 |                                       not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias,
59 |                                       opt.no_antialias_up, self.gpu_ids, opt)
60 |         # Define Heads H.
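        # How the pieces connect (as used later in this file): the shared encoder E maps
        # the mixed input to a single feature map, and each head H_i decodes that feature
        # map into the reconstruction of its own source domain. The discriminator is
        # called in two modes below: netD(0, img) scores real vs. fake for the separation
        # branch, and netD(1, img) predicts which source domains are present in the mix
        # (trained with BCE against self.label).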
61 | self.netH1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'head', opt.normG, 62 | not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, 63 | opt.no_antialias_up, self.gpu_ids, opt) 64 | self.netH2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'head', opt.normG, 65 | not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, 66 | opt.no_antialias_up, self.gpu_ids, opt) 67 | 68 | self.label = torch.zeros(self.opt.max_domain).to(self.device) 69 | self.netD = networks.define_D(opt.output_nc, self.opt.max_domain, 'BIDeN_D', opt.n_layers_D, opt.normD, 70 | opt.init_type, 71 | opt.init_gain, opt.no_antialias, self.gpu_ids, opt) 72 | self.correct = 0 73 | self.all_count = 0 74 | self.psnr_count = [0] * self.opt.max_domain 75 | self.acc_all = 0 76 | self.test_time = 0 77 | self.criterionL2 = torch.nn.MSELoss().to(self.device) 78 | 79 | if self.isTrain: 80 | # define loss functions 81 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 82 | self.criterionVGG = VGGLoss(opt).to(self.device) 83 | self.criterionL1 = torch.nn.L1Loss().to(self.device) 84 | self.criterionBCE = torch.nn.BCEWithLogitsLoss().to(self.device) 85 | self.optimizer_G = torch.optim.Adam( 86 | itertools.chain(self.netE.parameters(), self.netH1.parameters(), 87 | self.netH2.parameters()), 88 | lr=opt.lr, betas=(opt.beta1, opt.beta2)) 89 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) 90 | self.optimizers.append(self.optimizer_G) 91 | self.optimizers.append(self.optimizer_D) 92 | 93 | 94 | def data_dependent_initialize(self, data): 95 | self.set_input(data) 96 | bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1) 97 | self.real_A = self.real_A[:bs_per_gpu] 98 | self.real_B = self.real_B[:bs_per_gpu] 99 | self.forward() 100 | if self.opt.isTrain: 101 | self.compute_D_loss().backward() # calculate gradients for D 102 | self.compute_G_loss().backward() # calculate graidents for G 103 | 104 | def optimize_parameters(self): 105 | # forward 106 | self.forward() 107 | 108 | # update D 109 | self.set_requires_grad(self.netD, True) 110 | self.optimizer_D.zero_grad() 111 | self.loss_D = self.compute_D_loss() 112 | self.loss_D.backward() 113 | self.optimizer_D.step() 114 | 115 | # update G 116 | self.set_requires_grad(self.netD, False) 117 | self.optimizer_G.zero_grad() 118 | self.loss_G = self.compute_G_loss() 119 | self.loss_G.backward() 120 | self.optimizer_G.step() 121 | 122 | 123 | def set_input(self, input): 124 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 125 | Parameters: 126 | input (dict): include the data itself and its metadata information. 127 | """ 128 | self.real_A = input['A'].to(self.device) 129 | self.real_B = input['B'].to(self.device) 130 | self.image_paths = input['A_paths'] 131 | 132 | def forward(self): 133 | """ 134 | Run forward pass; called by both functions . 135 | We have another version of forward (forward_test) used in testing. 136 | """ 137 | # If no images are used in mixing, run again. 
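        # Concretely: each of the max_domain sources is kept independently with
        # probability opt.prob, and the mixed input is the mean of the kept sources
        # (the division by label_sum below). With the defaults max_domain = 2 and
        # prob = 0.9, after resampling the all-zero case, the input is A alone with
        # probability 0.09 / 0.99 (about 9%), B alone about 9%, and (A + B) / 2
        # about 82% of the time.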
138 | label_sum = 0 139 | while label_sum < 1: 140 | p = torch.rand(self.opt.max_domain) 141 | for i in range(self.opt.max_domain): 142 | if p[i] < self.opt.prob: 143 | self.label[i] = 1 144 | else: 145 | self.label[i] = 0 146 | label_sum = torch.sum(self.label, 0) 147 | 148 | self.real_input = (self.real_A * self.label[0] + self.real_B * self.label[1]) / label_sum 149 | self.fake_all = self.netE(self.real_input) 150 | self.fake_A = self.netH1(self.fake_all) 151 | self.fake_B = self.netH2(self.fake_all) 152 | self.loss_sum = label_sum 153 | 154 | 155 | def compute_D_loss(self): 156 | """Calculate GAN loss and BCE loss for the discriminator""" 157 | fake1 = self.fake_A.detach() 158 | fake2 = self.fake_B.detach() 159 | pred_fake1 = self.netD(0,fake1) 160 | pred_fake2 = self.netD(0,fake2) 161 | self.loss_D_fake = self.criterionGAN(pred_fake1, False) * self.label[0] \ 162 | + self.criterionGAN(pred_fake2, False) * self.label[1] 163 | 164 | # Real 165 | self.pred_real1 = self.netD(0,self.real_A) 166 | self.pred_real2 = self.netD(0,self.real_B) 167 | 168 | self.loss_D_real = self.criterionGAN(self.pred_real1, True) * self.label[0] \ 169 | + self.criterionGAN(self.pred_real2,True) * self.label[1] 170 | 171 | # BCE loss, netD(1) for the source prediction branch. 172 | self.predict_label = self.netD(1,self.real_input).view(self.opt.max_domain) 173 | self.loss_BCE = self.criterionBCE(self.predict_label, self.label) 174 | # combine loss and calculate gradients 175 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 * self.opt.lambda_GAN + self.loss_BCE * self.opt.lambda_BCE 176 | return self.loss_D 177 | 178 | def compute_G_loss(self): 179 | """Calculate GAN loss, Ln loss, VGG loss for the generator""" 180 | # netD(0) for the separation branch. 181 | pred_fake1 = self.netD(0,self.fake_A) 182 | pred_fake2 = self.netD(0,self.fake_B) 183 | 184 | self.loss_G_GAN = self.criterionGAN(pred_fake1, True) * self.label[0] \ 185 | + self.criterionGAN(pred_fake2, True) * self.label[1] 186 | 187 | self.loss_Ln = self.criterionL1(self.real_A, self.fake_A) * self.label[0] \ 188 | + self.criterionL1(self.real_B,self.fake_B) * self.label[1] 189 | 190 | self.loss_VGG = self.criterionVGG(self.fake_A, self.real_A) * self.label[0] \ 191 | + self.criterionVGG(self.fake_B, self.real_B) * self.label[1] 192 | 193 | self.loss_G = self.loss_G_GAN * self.opt.lambda_GAN + self.loss_Ln * self.opt.lambda_Ln \ 194 | + self.loss_VGG * self.opt.lambda_VGG 195 | 196 | return self.loss_G 197 | 198 | def forward_test(self): 199 | # Test case 1, test a single case only, write the output images. 200 | if self.opt.test_choice == 1: 201 | gt_label = [0] * self.opt.max_domain 202 | if 'A' in self.opt.test_input: 203 | gt_label[0] = 1 204 | if 'B' in self.opt.test_input: 205 | gt_label[1] = 1 206 | self.real_input = (self.real_A * gt_label[0] + self.real_B * gt_label[1]) / sum(gt_label) 207 | self.fake_all = self.netE(self.real_input) 208 | self.predict_label = self.netD(1, self.real_input).view(self.opt.max_domain) 209 | predict_label = torch.where(self.predict_label > 0.0, 1, 0) 210 | self.fake_A = self.netH1(self.fake_all) 211 | self.fake_B = self.netH2(self.fake_all) 212 | if predict_label.tolist() == gt_label: 213 | self.correct = self.correct + 1 214 | self.all_count = self.all_count + 1 215 | if self.all_count == 300: 216 | print(self.correct / self.all_count) 217 | else: 218 | # Test case 0, test all cases, do not write output images. 
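            # PSNR bookkeeping for this branch: outputs and targets are rescaled from
            # [-1, 1] to [0, 1] below, so the peak signal value is 1 and
            # PSNR = 10 * log10(1 / MSE); e.g. an MSE of 0.001 gives 30 dB. Per-domain
            # sums are averaged and reset every 300 test images.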
219 | gt_label = [0] * self.opt.max_domain 220 | if 'A' in self.opt.test_input: 221 | gt_label[0] = 1 222 | if 'B' in self.opt.test_input: 223 | gt_label[1] = 1 224 | self.real_input = (self.real_A * gt_label[0] + self.real_B * gt_label[1])/ sum(gt_label) 225 | self.predict_label = self.netD(1, self.real_input).view(self.opt.max_domain) 226 | predict_label = torch.where(self.predict_label > 0.0, 1, 0) 227 | self.fake_all = self.netE(self.real_input) 228 | self.fake_A = self.netH1(self.fake_all) 229 | self.fake_B = self.netH2(self.fake_all) 230 | 231 | # Normalize to 0-1 for PSNR calculation. 232 | self.fake_A = (self.fake_A + 1)/2 233 | self.real_A = (self.real_A + 1)/2 234 | self.fake_B = (self.fake_B + 1)/2 235 | self.real_B = (self.real_B + 1)/2 236 | 237 | if gt_label[0] == 1: 238 | mse = self.criterionL2(self.fake_A, self.real_A) 239 | psnr = 10 * log10(1 / mse.item()) 240 | self.psnr_count[0] += psnr 241 | if gt_label[1] == 1: 242 | mse = self.criterionL2(self.fake_B, self.real_B) 243 | psnr = 10 * log10(1 / mse.item()) 244 | self.psnr_count[1] += psnr 245 | if predict_label.tolist() == gt_label: 246 | self.correct = self.correct + 1 247 | self.all_count = self.all_count + 1 248 | if self.all_count % 300 == 0: 249 | acc = self.correct / self.all_count 250 | print("Accuracy for current: ",acc) 251 | if gt_label[0] == 1: 252 | print("PSNR_A: ", self.psnr_count[0]/ self.all_count) 253 | if gt_label[1] == 1: 254 | print("PSNR_B: ", self.psnr_count[1]/ self.all_count) 255 | self.all_count = 0 256 | self.correct = 0 257 | self.psnr_count = [0] * self.opt.max_domain 258 | self.acc_all += acc 259 | self.test_time +=1 260 | if mean(gt_label) == 1: 261 | print("Overall Accuracy:", self.acc_all/self.test_time) 262 | 263 | 264 | 265 | 266 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | 9 | if sys.version_info[0] == 2: 10 | VisdomExceptionBase = Exception 11 | else: 12 | VisdomExceptionBase = ConnectionError 13 | 14 | 15 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 16 | """Save images to the disk. 17 | 18 | Parameters: 19 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 20 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 21 | image_path (str) -- the string is used to create image paths 22 | aspect_ratio (float) -- the aspect ratio of saved images 23 | width (int) -- the images will be resized to width x width 24 | 25 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
26 | """ 27 | image_dir = webpage.get_image_dir() 28 | short_path = ntpath.basename(image_path[0]) 29 | name = os.path.splitext(short_path)[0] 30 | 31 | webpage.add_header(name) 32 | ims, txts, links = [], [], [] 33 | 34 | for label, im_data in visuals.items(): 35 | im = util.tensor2im(im_data) 36 | image_name = '%s/%s.png' % (label, name) 37 | os.makedirs(os.path.join(image_dir, label), exist_ok=True) 38 | save_path = os.path.join(image_dir, image_name) 39 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 40 | ims.append(image_name) 41 | txts.append(label) 42 | links.append(image_name) 43 | webpage.add_images(ims, txts, links, width=width) 44 | 45 | 46 | class Visualizer(): 47 | """This class includes several functions that can display/save images and print/save logging information. 48 | 49 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 50 | """ 51 | 52 | def __init__(self, opt): 53 | """Initialize the Visualizer class 54 | 55 | Parameters: 56 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 57 | Step 1: Cache the training/test options 58 | Step 2: connect to a visdom server 59 | Step 3: create an HTML object for saveing HTML filters 60 | Step 4: create a logging file to store training losses 61 | """ 62 | self.opt = opt # cache the option 63 | if opt.display_id is None: 64 | self.display_id = np.random.randint(100000) * 10 # just a random display id 65 | else: 66 | self.display_id = opt.display_id 67 | self.use_html = opt.isTrain and not opt.no_html 68 | self.win_size = opt.display_winsize 69 | self.name = opt.name 70 | self.port = opt.display_port 71 | self.saved = False 72 | if self.display_id > 0: # connect to a visdom server given and 73 | import visdom 74 | self.plot_data = {} 75 | self.ncols = opt.display_ncols 76 | if "tensorboard_base_url" not in os.environ: 77 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 78 | else: 79 | self.vis = visdom.Visdom(port=2004, 80 | base_url=os.environ['tensorboard_base_url'] + '/visdom') 81 | if not self.vis.check_connection(): 82 | self.create_visdom_connections() 83 | 84 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ 85 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 86 | self.img_dir = os.path.join(self.web_dir, 'images') 87 | print('create web directory %s...' % self.web_dir) 88 | util.mkdirs([self.web_dir, self.img_dir]) 89 | # create a logging file to store training losses 90 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 91 | self.fid_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'fid_log.txt') 92 | 93 | with open(self.log_name, "a") as log_file: 94 | now = time.strftime("%c") 95 | log_file.write('================ Training Loss (%s) ================\n' % now) 96 | with open(self.fid_log_name, "a") as log_file: 97 | now = time.strftime("%c") 98 | log_file.write('================ Validation FID (%s) ================\n' % now) 99 | 100 | def reset(self): 101 | """Reset the self.saved status""" 102 | self.saved = False 103 | 104 | def create_visdom_connections(self): 105 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 106 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 107 | print('\n\nCould not connect to Visdom server. 
\n Trying to start a server....')
108 |         print('Command: %s' % cmd)
109 |         Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
110 | 
111 |     def display_current_results(self, visuals, epoch, save_result):
112 |         """Display current results on visdom; save current results to an HTML file.
113 | 
114 |         Parameters:
115 |             visuals (OrderedDict) - - dictionary of images to display or save
116 |             epoch (int) - - the current epoch
117 |             save_result (bool) - - if save the current results to an HTML file
118 |         """
119 |         if self.display_id > 0:  # show images in the browser using visdom
120 |             ncols = self.ncols
121 |             if ncols > 0:        # show all the images in one visdom panel
122 |                 ncols = min(ncols, len(visuals))
123 |                 h, w = next(iter(visuals.values())).shape[:2]
124 |                 table_css = """<style>
125 |                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
126 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
127 |                         </style>""" % (w, h)  # create a table css
128 |                 # create a table of images.
129 |                 title = self.name
130 |                 label_html = ''
131 |                 label_html_row = ''
132 |                 images = []
133 |                 idx = 0
134 |                 for label, image in visuals.items():
135 |                     image_numpy = util.tensor2im(image)
136 |                     label_html_row += '<td>%s</td>' % label
137 |                     images.append(image_numpy.transpose([2, 0, 1]))
138 |                     idx += 1
139 |                     if idx % ncols == 0:
140 |                         label_html += '<tr>%s</tr>' % label_html_row
141 |                         label_html_row = ''
142 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
143 |                 while idx % ncols != 0:
144 |                     images.append(white_image)
145 |                     label_html_row += '<td></td>'
146 |                     idx += 1
147 |                 if label_html_row != '':
148 |                     label_html += '<tr>%s</tr>' % label_html_row
149 |                 try:
150 |                     self.vis.images(images, ncols, 2, self.display_id + 1,
151 |                                     None, dict(title=title + ' images'))
152 |                     label_html = '<table>%s</table>
' % label_html 153 | self.vis.text(table_css + label_html, win=self.display_id + 2, 154 | opts=dict(title=title + ' labels')) 155 | except VisdomExceptionBase: 156 | self.create_visdom_connections() 157 | 158 | else: # show each image in a separate visdom panel; 159 | idx = 1 160 | try: 161 | for label, image in visuals.items(): 162 | image_numpy = util.tensor2im(image) 163 | self.vis.image( 164 | image_numpy.transpose([2, 0, 1]), 165 | self.display_id + idx, 166 | None, 167 | dict(title=label) 168 | ) 169 | idx += 1 170 | except VisdomExceptionBase: 171 | self.create_visdom_connections() 172 | 173 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 174 | self.saved = True 175 | # save images to the disk 176 | for label, image in visuals.items(): 177 | image_numpy = util.tensor2im(image) 178 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 179 | util.save_image(image_numpy, img_path) 180 | 181 | # update website 182 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) 183 | for n in range(epoch, 0, -1): 184 | webpage.add_header('epoch [%d]' % n) 185 | ims, txts, links = [], [], [] 186 | 187 | for label, image_numpy in visuals.items(): 188 | image_numpy = util.tensor2im(image) 189 | img_path = 'epoch%.3d_%s.png' % (n, label) 190 | ims.append(img_path) 191 | txts.append(label) 192 | links.append(img_path) 193 | webpage.add_images(ims, txts, links, width=self.win_size) 194 | webpage.save() 195 | 196 | 197 | def plot_current_fid(self, epoch, fid): 198 | """display the current fid on visdom display 199 | Parameters: 200 | epoch (int) -- current epoch 201 | fid (float) -- validation fid 202 | """ 203 | if not hasattr(self, 'fid_plot_data'): 204 | self.fid_plot_data = {'X': [], 'Y': []} 205 | self.fid_plot_data['X'].append(epoch) 206 | self.fid_plot_data['Y'].append(fid) 207 | try: 208 | self.vis.line( 209 | X=np.array(self.fid_plot_data['X']), 210 | Y=np.array(self.fid_plot_data['Y']), 211 | opts={ 212 | 'title': self.name + ' fid over time', 213 | 'xlabel': 'epoch', 214 | 'ylabel': 'fid'}, 215 | win=self.display_id + 4) 216 | except VisdomExceptionBase: 217 | self.create_visdom_connections() 218 | 219 | 220 | 221 | def plot_current_losses(self, epoch, counter_ratio, losses): 222 | """display the current losses on visdom display: dictionary of error labels and values 223 | 224 | Parameters: 225 | epoch (int) -- current epoch 226 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 227 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 228 | """ 229 | if len(losses) == 0: 230 | return 231 | 232 | plot_name = '_'.join(list(losses.keys())) 233 | 234 | if plot_name not in self.plot_data: 235 | self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())} 236 | 237 | plot_data = self.plot_data[plot_name] 238 | plot_id = list(self.plot_data.keys()).index(plot_name) 239 | 240 | plot_data['X'].append(epoch + counter_ratio) 241 | plot_data['Y'].append([losses[k] for k in plot_data['legend']]) 242 | try: 243 | self.vis.line( 244 | X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1), 245 | Y=np.array(plot_data['Y']), 246 | opts={ 247 | 'title': self.name, 248 | 'legend': plot_data['legend'], 249 | 'xlabel': 'epoch', 250 | 'ylabel': 'loss'}, 251 | win=self.display_id - plot_id) 252 | except VisdomExceptionBase: 253 | self.create_visdom_connections() 254 | 255 | # losses: same format as 
|losses| of plot_current_losses 256 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 257 | """print current losses on console; also save the losses to the disk 258 | 259 | Parameters: 260 | epoch (int) -- current epoch 261 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 262 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 263 | t_comp (float) -- computational time per data point (normalized by batch_size) 264 | t_data (float) -- data loading time per data point (normalized by batch_size) 265 | """ 266 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 267 | for k, v in losses.items(): 268 | message += '%s: %.3f ' % (k, v) 269 | 270 | print(message) # print the message 271 | with open(self.log_name, "a") as log_file: 272 | log_file.write('%s\n' % message) # save the message 273 | 274 | def print_current_fid(self, epoch, fid): 275 | """print current fid on console; also save the fid to the disk 276 | Parameters: 277 | epoch (int) -- current epoch 278 | fid (float) - fid metric 279 | """ 280 | message = '(epoch: %d, fid: %.3f) ' % (epoch, fid) 281 | 282 | print(message) # print the message 283 | with open(self.fid_log_name, "a") as log_file: 284 | log_file.write('%s\n' % message) # save the message -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Junlin Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | --------------------------- LICENSE FOR DCLGAN ------------------------------- 23 | MIT License 24 | 25 | Copyright (c) 2021 Junlin Han 26 | 27 | Permission is hereby granted, free of charge, to any person obtaining a copy 28 | of this software and associated documentation files (the "Software"), to deal 29 | in the Software without restriction, including without limitation the rights 30 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 31 | copies of the Software, and to permit persons to whom the Software is 32 | furnished to do so, subject to the following conditions: 33 | 34 | The above copyright notice and this permission notice shall be included in all 35 | copies or substantial portions of the Software. 
36 | 37 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 38 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 39 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 40 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 41 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 42 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 43 | SOFTWARE. 44 | 45 | --------------------------- LICENSE FOR CUT ------------------------------- 46 | Copyright (c) 2020, Taesung Park and Jun-Yan Zhu 47 | All rights reserved. 48 | 49 | Redistribution and use in source and binary forms, with or without 50 | modification, are permitted provided that the following conditions are met: 51 | 52 | * Redistributions of source code must retain the above copyright notice, this 53 | list of conditions and the following disclaimer. 54 | 55 | * Redistributions in binary form must reproduce the above copyright notice, 56 | this list of conditions and the following disclaimer in the documentation 57 | and/or other materials provided with the distribution. 58 | 59 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 60 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 61 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 62 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 63 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 64 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 65 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 66 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 67 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 68 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 69 | 70 | --------------------------- LICENSE FOR CycleGAN ------------------------------- 71 | -------------------https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix------ 72 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 73 | All rights reserved. 74 | 75 | Redistribution and use in source and binary forms, with or without 76 | modification, are permitted provided that the following conditions are met: 77 | 78 | * Redistributions of source code must retain the above copyright notice, this 79 | list of conditions and the following disclaimer. 80 | 81 | * Redistributions in binary form must reproduce the above copyright notice, 82 | this list of conditions and the following disclaimer in the documentation 83 | and/or other materials provided with the distribution. 84 | 85 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 86 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 87 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 88 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 89 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 90 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 91 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 92 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 93 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 94 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 95 | 96 | --------------------------- LICENSE FOR stylegan2-pytorch ---------------------- 97 | ----------------https://github.com/rosinality/stylegan2-pytorch/---------------- 98 | MIT License 99 | 100 | Copyright (c) 2019 Kim Seonghyeon 101 | 102 | Permission is hereby granted, free of charge, to any person obtaining a copy 103 | of this software and associated documentation files (the "Software"), to deal 104 | in the Software without restriction, including without limitation the rights 105 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 106 | copies of the Software, and to permit persons to whom the Software is 107 | furnished to do so, subject to the following conditions: 108 | 109 | The above copyright notice and this permission notice shall be included in all 110 | copies or substantial portions of the Software. 111 | 112 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 113 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 114 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 115 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 116 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 117 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 118 | SOFTWARE. 119 | 120 | 121 | --------------------------- LICENSE FOR pix2pix -------------------------------- 122 | BSD License 123 | 124 | For pix2pix software 125 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 126 | All rights reserved. 127 | 128 | Redistribution and use in source and binary forms, with or without 129 | modification, are permitted provided that the following conditions are met: 130 | 131 | * Redistributions of source code must retain the above copyright notice, this 132 | list of conditions and the following disclaimer. 133 | 134 | * Redistributions in binary form must reproduce the above copyright notice, 135 | this list of conditions and the following disclaimer in the documentation 136 | and/or other materials provided with the distribution. 137 | 138 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 139 | BSD License 140 | 141 | For dcgan.torch software 142 | 143 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 144 | 145 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 146 | 147 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 148 | 149 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
150 | 151 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 152 | 153 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 154 | 155 | --------------------------- LICENSE FOR StyleGAN2 ------------------------------ 156 | --------------------------- Inherited from stylegan2-pytorch ------------------- 157 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 158 | 159 | 160 | Nvidia Source Code License-NC 161 | 162 | ======================================================================= 163 | 164 | 1. Definitions 165 | 166 | "Licensor" means any person or entity that distributes its Work. 167 | 168 | "Software" means the original work of authorship made available under 169 | this License. 170 | 171 | "Work" means the Software and any additions to or derivative works of 172 | the Software that are made available under this License. 173 | 174 | "Nvidia Processors" means any central processing unit (CPU), graphics 175 | processing unit (GPU), field-programmable gate array (FPGA), 176 | application-specific integrated circuit (ASIC) or any combination 177 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 178 | 179 | The terms "reproduce," "reproduction," "derivative works," and 180 | "distribution" have the meaning as provided under U.S. copyright law; 181 | provided, however, that for the purposes of this License, derivative 182 | works shall not include works that remain separable from, or merely 183 | link (or bind by name) to the interfaces of, the Work. 184 | 185 | Works, including the Software, are "made available" under this License 186 | by including in or with the Work either (a) a copyright notice 187 | referencing the applicability of this License to the Work, or (b) a 188 | copy of this License. 189 | 190 | 2. License Grants 191 | 192 | 2.1 Copyright Grant. Subject to the terms and conditions of this 193 | License, each Licensor grants to you a perpetual, worldwide, 194 | non-exclusive, royalty-free, copyright license to reproduce, 195 | prepare derivative works of, publicly display, publicly perform, 196 | sublicense and distribute its Work and any resulting derivative 197 | works in any form. 198 | 199 | 3. Limitations 200 | 201 | 3.1 Redistribution. You may reproduce or distribute the Work only 202 | if (a) you do so under this License, (b) you include a complete 203 | copy of this License with your distribution, and (c) you retain 204 | without modification any copyright, patent, trademark, or 205 | attribution notices that are present in the Work. 206 | 207 | 3.2 Derivative Works. 
You may specify that additional or different 208 | terms apply to the use, reproduction, and distribution of your 209 | derivative works of the Work ("Your Terms") only if (a) Your Terms 210 | provide that the use limitation in Section 3.3 applies to your 211 | derivative works, and (b) you identify the specific derivative 212 | works that are subject to Your Terms. Notwithstanding Your Terms, 213 | this License (including the redistribution requirements in Section 214 | 3.1) will continue to apply to the Work itself. 215 | 216 | 3.3 Use Limitation. The Work and any derivative works thereof only 217 | may be used or intended for use non-commercially. The Work or 218 | derivative works thereof may be used or intended for use by Nvidia 219 | or its affiliates commercially or non-commercially. As used herein, 220 | "non-commercially" means for research or evaluation purposes only. 221 | 222 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 223 | against any Licensor (including any claim, cross-claim or 224 | counterclaim in a lawsuit) to enforce any patents that you allege 225 | are infringed by any Work, then your rights under this License from 226 | such Licensor (including the grants in Sections 2.1 and 2.2) will 227 | terminate immediately. 228 | 229 | 3.5 Trademarks. This License does not grant any rights to use any 230 | Licensor's or its affiliates' names, logos, or trademarks, except 231 | as necessary to reproduce the notices described in this License. 232 | 233 | 3.6 Termination. If you violate any term of this License, then your 234 | rights under this License (including the grants in Sections 2.1 and 235 | 2.2) will terminate immediately. 236 | 237 | 4. Disclaimer of Warranty. 238 | 239 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 240 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 241 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 242 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 243 | THIS LICENSE. 244 | 245 | 5. Limitation of Liability. 246 | 247 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 248 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 249 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 250 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 251 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 252 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 253 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 254 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 255 | THE POSSIBILITY OF SUCH DAMAGES. 256 | 257 | ======================================================================= 258 | 259 | -------------------------------------------------------------------------------- /models/biden3_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | from .losses import VGGLoss 5 | from numpy import log10, mean 6 | import itertools 7 | 8 | """ 9 | BIDeN model for Task I, Mixed image decomposition across multiple domains. Maximum number of domains = 3. 
10 | 11 | Sample usage: 12 | Optional visualization: 13 | python -m visdom.server 14 | 15 | Train: 16 | python train.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 17 | 18 | Test: 19 | Test a single case: 20 | python test.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 --test_input A 21 | python test.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 --test_input AB 22 | ... and other cases. 23 | Change test_input to the case you want. 24 | 25 | Test all cases: 26 | python test2.py --dataroot ./datasets/image_decom --name biden3 --model biden3 --dataset_mode unaligned3 27 | """ 28 | 29 | class BIDEN3Model(BaseModel): 30 | @staticmethod 31 | def modify_commandline_options(parser, is_train=True): 32 | """ Configures options specific to the BIDeN model 33 | """ 34 | parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss') 35 | parser.add_argument('--lambda_Ln', type=float, default=30.0, help='weight for L1/L2 loss') 36 | parser.add_argument('--lambda_VGG', type=float, default=10.0, help='weight for VGG loss') 37 | parser.add_argument('--lambda_BCE', type=float, default=1.0, help='weight for BCE loss') 38 | parser.add_argument('--test_input', type=str, default='AB', help='components of the mixed test input, e.g. A, AB, ABC.') 39 | parser.add_argument('--max_domain', type=int, default=3, help='max number of source components.') 40 | parser.add_argument('--prob', type=float, default=0.8, help='probability of adding a component') 41 | parser.add_argument('--test_choice', type=int, default=1, help='choice for test mode, 1 for one case,' 42 | ' 0 for all cases. Will be set automatically.') 43 | opt, _ = parser.parse_known_args() 44 | return parser 45 | 46 | def __init__(self, opt): 47 | BaseModel.__init__(self, opt) 48 | 49 | # specify the training losses you want to print out. 50 | # The training/test scripts will call <BaseModel.get_current_losses>. 51 | self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'Ln', 'VGG', 'BCE'] 52 | self.visual_names = ['fake_A', 'fake_B', 'fake_C', 'real_A', 'real_B', 'real_C', 'real_input'] 53 | self.model_names = ['D', 'E', 'H1', 'H2', 'H3'] 54 | 55 | # Define networks (both generator and discriminator) 56 | # Define Encoder E. 57 | self.netE = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'encoder', opt.normG, 58 | not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, 59 | opt.no_antialias_up, self.gpu_ids, opt) 60 | # Define Heads H. 
61 | self.netH1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'head', opt.normG, 62 | not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, 63 | opt.no_antialias_up, self.gpu_ids, opt) 64 | self.netH2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'head', opt.normG, 65 | not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, 66 | opt.no_antialias_up, self.gpu_ids, opt) 67 | self.netH3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'head', opt.normG, 68 | not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, 69 | opt.no_antialias_up, self.gpu_ids, opt) 70 | 71 | self.label = torch.zeros(self.opt.max_domain).to(self.device) 72 | self.netD = networks.define_D(opt.output_nc, self.opt.max_domain, 'BIDeN_D', opt.n_layers_D, opt.normD, 73 | opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) 74 | self.correct = 0 75 | self.all_count = 0 76 | self.psnr_count = [0] * self.opt.max_domain 77 | self.acc_all = 0 78 | self.test_time = 0 79 | self.criterionL2 = torch.nn.MSELoss().to(self.device) 80 | 81 | if self.isTrain: 82 | # define loss functions 83 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 84 | self.criterionVGG = VGGLoss(opt).to(self.device) 85 | self.criterionL1 = torch.nn.L1Loss().to(self.device) 86 | self.criterionBCE = torch.nn.BCEWithLogitsLoss().to(self.device) 87 | self.optimizer_G = torch.optim.Adam( 88 | itertools.chain(self.netE.parameters(), self.netH1.parameters(), 89 | self.netH2.parameters(), self.netH3.parameters()), 90 | lr=opt.lr, betas=(opt.beta1, opt.beta2)) 91 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) 92 | self.optimizers.append(self.optimizer_G) 93 | self.optimizers.append(self.optimizer_D) 94 | 95 | 96 | def data_dependent_initialize(self, data): 97 | self.set_input(data) 98 | bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1) 99 | self.real_A = self.real_A[:bs_per_gpu] 100 | self.real_B = self.real_B[:bs_per_gpu] 101 | self.real_C = self.real_C[:bs_per_gpu] 102 | self.forward() 103 | if self.opt.isTrain: 104 | self.compute_D_loss().backward() # calculate gradients for D 105 | self.compute_G_loss().backward() # calculate gradients for G 106 | 107 | def optimize_parameters(self): 108 | # forward 109 | self.forward() 110 | 111 | # update D 112 | self.set_requires_grad(self.netD, True) 113 | self.optimizer_D.zero_grad() 114 | self.loss_D = self.compute_D_loss() 115 | self.loss_D.backward() 116 | self.optimizer_D.step() 117 | 118 | # update G 119 | self.set_requires_grad(self.netD, False) 120 | self.optimizer_G.zero_grad() 121 | self.loss_G = self.compute_G_loss() 122 | self.loss_G.backward() 123 | self.optimizer_G.step() 124 | 125 | 126 | def set_input(self, input): 127 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 128 | Parameters: 129 | input (dict): include the data itself and its metadata information. 130 | """ 131 | self.real_A = input['A'].to(self.device) 132 | self.real_B = input['B'].to(self.device) 133 | self.real_C = input['C'].to(self.device) 134 | self.image_paths = input['A_paths'] 135 | 136 | def forward(self): 137 | """ 138 | Run forward pass; called by both <data_dependent_initialize> and <optimize_parameters>. 139 | We have another version of forward (forward_test) used in testing. 140 | """ 141 | # If no images are selected for mixing, sample the labels again. 
142 | label_sum = 0 143 | while label_sum < 1: 144 | p = torch.rand(self.opt.max_domain) 145 | for i in range(self.opt.max_domain): 146 | if p[i] < self.opt.prob: 147 | self.label[i] = 1 148 | else: 149 | self.label[i] = 0 150 | label_sum = torch.sum(self.label, 0) 151 | 152 | self.real_input = (self.real_A * self.label[0] + self.real_B * self.label[1] 153 | + self.real_C * self.label[2]) / label_sum 154 | self.fake_all = self.netE(self.real_input) 155 | self.fake_A = self.netH1(self.fake_all) 156 | self.fake_B = self.netH2(self.fake_all) 157 | self.fake_C = self.netH3(self.fake_all) 158 | self.loss_sum = label_sum 159 | 160 | 161 | def compute_D_loss(self): 162 | """Calculate GAN loss and BCE loss for the discriminator""" 163 | fake1 = self.fake_A.detach() 164 | fake2 = self.fake_B.detach() 165 | fake3 = self.fake_C.detach() 166 | pred_fake1 = self.netD(0,fake1) 167 | pred_fake2 = self.netD(0,fake2) 168 | pred_fake3 = self.netD(0,fake3) 169 | self.loss_D_fake = self.criterionGAN(pred_fake1, False) * self.label[0] \ 170 | + self.criterionGAN(pred_fake2, False) * self.label[1] \ 171 | + self.criterionGAN(pred_fake3, False) * self.label[2] 172 | 173 | # Real 174 | self.pred_real1 = self.netD(0,self.real_A) 175 | self.pred_real2 = self.netD(0,self.real_B) 176 | self.pred_real3 = self.netD(0,self.real_C) 177 | 178 | self.loss_D_real = self.criterionGAN(self.pred_real1, True) * self.label[0] + \ 179 | self.criterionGAN(self.pred_real2, True) * self.label[1] + \ 180 | self.criterionGAN(self.pred_real3, True) * self.label[2] 181 | 182 | # BCE loss, netD(1) for the source prediction branch. 183 | self.predict_label = self.netD(1,self.real_input).view(self.opt.max_domain) 184 | self.loss_BCE = self.criterionBCE(self.predict_label, self.label) 185 | # combine loss and calculate gradients 186 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 * self.opt.lambda_GAN + self.loss_BCE * self.opt.lambda_BCE 187 | return self.loss_D 188 | 189 | def compute_G_loss(self): 190 | """Calculate GAN loss, Ln loss, VGG loss for the generator""" 191 | # netD(0) for the separation branch. 192 | pred_fake1 = self.netD(0,self.fake_A) 193 | pred_fake2 = self.netD(0,self.fake_B) 194 | pred_fake3 = self.netD(0,self.fake_C) 195 | 196 | self.loss_G_GAN = self.criterionGAN(pred_fake1, True) * self.label[0] \ 197 | + self.criterionGAN(pred_fake2, True) * self.label[1] \ 198 | + self.criterionGAN(pred_fake3, True) * self.label[2] 199 | 200 | self.loss_Ln = self.criterionL1(self.real_A, self.fake_A) * self.label[0] \ 201 | + self.criterionL1(self.real_B,self.fake_B) * self.label[1] \ 202 | + self.criterionL1(self.real_C,self.fake_C) * self.label[2] 203 | 204 | self.loss_VGG = self.criterionVGG(self.fake_A, self.real_A) * self.label[0] \ 205 | + self.criterionVGG(self.fake_B, self.real_B) * self.label[1] \ 206 | + self.criterionVGG(self.fake_C, self.real_C) * self.label[2] 207 | 208 | self.loss_G = self.loss_G_GAN * self.opt.lambda_GAN + self.loss_Ln * self.opt.lambda_Ln \ 209 | + self.loss_VGG * self.opt.lambda_VGG 210 | 211 | return self.loss_G 212 | 213 | def forward_test(self): 214 | # Test case 1, test a single case only, write the output images. 
215 | if self.opt.test_choice == 1: 216 | gt_label = [0] * self.opt.max_domain 217 | if 'A' in self.opt.test_input: 218 | gt_label[0] = 1 219 | if 'B' in self.opt.test_input: 220 | gt_label[1] = 1 221 | if 'C' in self.opt.test_input: 222 | gt_label[2] = 1 223 | self.real_input = (self.real_A * gt_label[0] + self.real_B * gt_label[1] 224 | + self.real_C * gt_label[2]) / sum(gt_label) 225 | self.predict_label = self.netD(1, self.real_input).view(self.opt.max_domain) 226 | predict_label = torch.where(self.predict_label > 0.0, 1, 0) 227 | self.fake_all = self.netE(self.real_input) 228 | self.fake_A = self.netH1(self.fake_all) 229 | self.fake_B = self.netH2(self.fake_all) 230 | self.fake_C = self.netH3(self.fake_all) 231 | if predict_label.tolist() == gt_label: 232 | self.correct = self.correct + 1 233 | self.all_count = self.all_count + 1 234 | if self.all_count == 300: 235 | print(self.correct / self.all_count) 236 | else: 237 | # Test case 0, test all cases, do not write output images. 238 | gt_label = [0] * self.opt.max_domain 239 | if 'A' in self.opt.test_input: 240 | gt_label[0] = 1 241 | if 'B' in self.opt.test_input: 242 | gt_label[1] = 1 243 | if 'C' in self.opt.test_input: 244 | gt_label[2] = 1 245 | self.real_input = (self.real_A * gt_label[0] + self.real_B * gt_label[1] 246 | + self.real_C * gt_label[2]) / sum(gt_label) 247 | self.predict_label = self.netD(1, self.real_input).view(self.opt.max_domain) 248 | predict_label = torch.where(self.predict_label > 0.0, 1, 0) 249 | self.fake_all = self.netE(self.real_input) 250 | self.fake_A = self.netH1(self.fake_all) 251 | self.fake_B = self.netH2(self.fake_all) 252 | self.fake_C = self.netH3(self.fake_all) 253 | 254 | # Normalize to 0-1 for PSNR calculation. 255 | self.fake_A = (self.fake_A + 1)/2 256 | self.real_A = (self.real_A + 1)/2 257 | self.fake_B = (self.fake_B + 1)/2 258 | self.real_B = (self.real_B + 1)/2 259 | self.fake_C = (self.fake_C + 1)/2 260 | self.real_C = (self.real_C + 1)/2 261 | 262 | if gt_label[0] == 1: 263 | mse = self.criterionL2(self.fake_A, self.real_A) 264 | psnr = 10 * log10(1 / mse.item()) 265 | self.psnr_count[0] += psnr 266 | if gt_label[1] == 1: 267 | mse = self.criterionL2(self.fake_B, self.real_B) 268 | psnr = 10 * log10(1 / mse.item()) 269 | self.psnr_count[1] += psnr 270 | if gt_label[2] == 1: 271 | mse = self.criterionL2(self.fake_C, self.real_C) 272 | psnr = 10 * log10(1 / mse.item()) 273 | self.psnr_count[2] += psnr 274 | if predict_label.tolist() == gt_label: 275 | self.correct = self.correct + 1 276 | self.all_count = self.all_count + 1 277 | if self.all_count % 300 == 0: 278 | acc = self.correct / self.all_count 279 | print("Accuracy for current: ",acc) 280 | if gt_label[0] == 1: 281 | print("PSNR_A: ", self.psnr_count[0]/ self.all_count) 282 | if gt_label[1] == 1: 283 | print("PSNR_B: ", self.psnr_count[1]/ self.all_count) 284 | if gt_label[2] == 1: 285 | print("PSNR_C: ", self.psnr_count[2]/ self.all_count) 286 | self.all_count = 0 287 | self.correct = 0 288 | self.psnr_count = [0] * self.opt.max_domain 289 | self.acc_all += acc 290 | self.test_time += 1 291 | if( mean(gt_label) == 1): 292 | print("Overall Accuracy:", self.acc_all/self.test_time) 293 | 294 | 295 | 296 | --------------------------------------------------------------------------------
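For reference, the evaluation logic in forward_test above reduces to two small pieces: PSNR is computed on images mapped back from [-1, 1] to [0, 1], so the peak signal value is 1 and PSNR = 10 * log10(1 / MSE); and the source-prediction branch netD(1, x) returns raw logits (training uses BCEWithLogitsLoss), so thresholding them at 0.0 corresponds to a sigmoid probability of 0.5. A minimal, self-contained sketch of both steps; the helper names psnr_from_pair and predicted_sources are illustrative and not part of the repository:

    import torch
    import torch.nn.functional as F

    def psnr_from_pair(fake, real):
        # Mirrors forward_test: map images from [-1, 1] back to [0, 1],
        # so the peak value is 1 and PSNR = 10 * log10(1 / MSE) in dB.
        fake01 = (fake + 1) / 2
        real01 = (real + 1) / 2
        mse = F.mse_loss(fake01, real01)
        return (10 * torch.log10(1.0 / mse)).item()

    def predicted_sources(logits):
        # Binarize the source-prediction logits, as in forward_test:
        # a logit above 0.0 is equivalent to a sigmoid probability above 0.5.
        return torch.where(logits > 0.0, 1, 0).tolist()

With these helpers, the per-domain accumulation in forward_test is simply psnr_count[i] += psnr_from_pair(fake_i, real_i) for every domain i present in gt_label, and a source prediction counts as correct only when predicted_sources(...) matches gt_label exactly.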