├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base_dataset.py ├── data_loader.py ├── image_folder.py └── unaligned_dataset.py ├── img ├── Discr.png ├── Matches.png └── Res.png ├── models ├── __init__.py ├── base_model.py ├── combogan_model.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── scripts ├── continue_combogan.sh ├── test_combogan.sh └── train_combogan.sh ├── test.py ├── train.py └── util ├── __init__.py ├── html.py ├── image_pool.py ├── png.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | checkpoints/ 3 | results/ 4 | matlab/ 5 | */**/__pycache__ 6 | */*.pyc 7 | */**/*.pyc 8 | */**/**/*.pyc 9 | */**/**/**/*.pyc 10 | */**/**/**/**/*.pyc 11 | */*.so* 12 | */**/*.so* 13 | */**/*.dylib* 14 | *~ 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Asha Anoosheh 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # ToDayGAN 3 | 4 | This is our PyTorch implementation for ToDayGAN. 5 | Code was written by [Asha Anoosheh](https://github.com/aanoosheh) (built upon [ComboGAN](https://github.com/AAnoosheh/ComboGAN)) 6 | 7 | #### [[ToDayGAN Paper]](https://arxiv.org/pdf/1809.09767.pdf) 8 | #### [[ComboGAN Paper]](https://arxiv.org/pdf/1712.06909.pdf) 9 | 10 | 11 | If you use this code for your research, please cite: 12 | 13 | Night-to-Day Image Translation for Retrieval-based Localization 14 | [Asha Anoosheh](http://ashaanoosheh.com), [Torsten Sattler](http://people.inf.ethz.ch/sattlert/), [Radu Timofte](http://www.vision.ee.ethz.ch/~timofter/), [Marc Pollefeys](https://www.microsoft.com/en-us/research/people/mapoll/), [Luc van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1) 15 | In Arxiv, 2018. 16 | 17 | 18 | 19 |
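A BibTeX entry along these lines should work for citing the paper (a sketch assembled from the arXiv listing above; double-check the fields against your bibliography style):

```
@article{anoosheh2018night,
  title   = {Night-to-Day Image Translation for Retrieval-based Localization},
  author  = {Anoosheh, Asha and Sattler, Torsten and Timofte, Radu and Pollefeys, Marc and Van Gool, Luc},
  journal = {arXiv preprint arXiv:1809.09767},
  year    = {2018}
}
```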

20 | 21 |

22 |
23 |
24 |
25 | ## Prerequisites
26 | - Linux or macOS
27 | - Python 3
28 | - CPU or NVIDIA GPU + CUDA CuDNN
29 |
30 | ## Getting Started
31 | ### Installation
32 | - Install requisite Python libraries.
33 | ```bash
34 | pip install torch
35 | pip install torchvision
36 | pip install visdom
37 | pip install dominate
38 | ```
39 | - Clone this repo:
40 | ```bash
41 | git clone https://github.com/AAnoosheh/ToDayGAN.git
42 | ```
43 |
44 | ### Training
45 | Example running scripts can be found in the `scripts` directory.
46 |
47 | One of our pretrained models for the Oxford RobotCar dataset can be found [HERE](https://www.dropbox.com/s/mwqfbs19cptrej6/2DayGAN_Checkpoint150.zip?dl=0). Place it under `./checkpoints/robotcar_2day` and test it using the instructions below, with the args `--name robotcar_2day --dataroot ./datasets/<your_dataset> --n_domains 2 --which_epoch 150 --loadSize 512`
48 |
49 | Because of sensitivity to intrinsic camera characteristics, testing should ideally be done on photos from the same Oxford dataset (and the same Grasshopper camera), found conveniently preprocessed and ready to use [HERE](https://www.visuallocalization.net/datasets/).
50 |
51 | If using this pretrained model, `<your_dataset>` should contain two subfolders, `test0` & `test1`, containing Day and Night images to test, respectively (as mine was trained with this ordering). `test0` can be empty if you do not care about Day images translated to Night, but it needs to exist so the code does not break.
52 |
53 | - Train a model:
54 | ```
55 | python train.py --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --niter <num_epochs> --niter_decay <num_decay_epochs>
56 | ```
57 | Checkpoints will be saved by default to `./checkpoints/<experiment_name>/`
58 | - Fine-tuning/resume training:
59 | ```
60 | python train.py --continue_train --which_epoch <checkpoint_number> --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --niter <num_epochs> --niter_decay <num_decay_epochs>
61 | ```
62 | - Test the model:
63 | ```
64 | python test.py --phase test --serial_test --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --which_epoch <checkpoint_number>
65 | ```
66 | The test results will be saved to an HTML file here: `./results/<experiment_name>/<phase>_<which_epoch>/index.html`.
67 |
68 |
69 |
70 | ## Training/Testing Details
71 | - Flags: see `options/train_options.py` for training-specific flags; see `options/test_options.py` for test-specific flags; and see `options/base_options.py` for all common flags.
72 | - Dataset format: the data directory (provided by `--dataroot`) should contain subfolders of the form `train*/` and `test*/`, which are loaded in alphabetical order. (Note that a folder named `train10` would be loaded before `train2`, and thus all checkpoints and results would be ordered accordingly.) Test directories should match the alphabetical ordering of the training ones.
73 | - CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
74 | - Visualization: during training, the current results and loss plots can be viewed in two ways. First, if you set `--display_id > 0`, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). For this you should have `visdom` installed and a server running via the command `python -m visdom.server`. The default server URL is `http://localhost:8097`, and `display_id` corresponds to the window ID that is displayed on the `visdom` server. Note that the `visdom` display is turned off by default (`--display_id 0`), which avoids the extra overhead of communicating with `visdom`. Second, the intermediate results are also saved to `./checkpoints/<experiment_name>/web/index.html`.
To avoid this, set the `--no_html` flag. 75 | - Preprocessing: images can be resized and cropped in different ways using `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image such that the largest side becomes `opt.loadSize` and then does a random crop of size `(opt.fineSize, opt.fineSize)`. Other options are either just `resize` or `crop` on their own. 76 | 77 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/data/__init__.py -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | class BaseDataset(data.Dataset): 6 | def __init__(self): 7 | super(BaseDataset, self).__init__() 8 | 9 | def name(self): 10 | return 'BaseDataset' 11 | 12 | def initialize(self, opt): 13 | pass 14 | 15 | def get_transform(opt): 16 | transform_list = [] 17 | if 'resize' in opt.resize_or_crop: 18 | transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC)) 19 | 20 | if opt.isTrain: 21 | if 'crop' in opt.resize_or_crop: 22 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 23 | if not opt.no_flip: 24 | transform_list.append(transforms.RandomHorizontalFlip()) 25 | 26 | transform_list += [transforms.ToTensor(), 27 | transforms.Normalize((0.5, 0.5, 0.5), 28 | (0.5, 0.5, 0.5))] 29 | return transforms.Compose(transform_list) 30 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.unaligned_dataset import UnalignedDataset 3 | 4 | 5 | class DataLoader(): 6 | def name(self): 7 | return 'DataLoader' 8 | 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.dataset = UnalignedDataset(opt) 12 | self.dataloader = torch.utils.data.DataLoader( 13 | self.dataset, 14 | batch_size=opt.batchSize, 15 | num_workers=int(opt.nThreads)) 16 | 17 | def __len__(self): 18 | return min(len(self.dataset), self.opt.max_dataset_size) 19 | 20 | def __iter__(self): 21 | for i, data in enumerate(self.dataloader): 22 | if i >= self.opt.max_dataset_size: 23 | break 24 | yield data 25 | 26 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 
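        # Validate the directory, then walk it recursively, collecting the
        # path of every file whose extension appears in IMG_EXTENSIONS above.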
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/unaligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path, glob 2 | import torchvision.transforms as transforms 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import random 7 | 8 | class UnalignedDataset(BaseDataset): 9 | def __init__(self, opt): 10 | super(UnalignedDataset, self).__init__() 11 | self.opt = opt 12 | self.transform = get_transform(opt) 13 | 14 | datapath = os.path.join(opt.dataroot, opt.phase + '*') 15 | self.dirs = sorted(glob.glob(datapath)) 16 | 17 | self.paths = [sorted(make_dataset(d)) for d in self.dirs] 18 | self.sizes = [len(p) for p in self.paths] 19 | 20 | def load_image(self, dom, idx): 21 | path = self.paths[dom][idx] 22 | img = Image.open(path).convert('RGB') 23 | img = self.transform(img) 24 | return img, path 25 | 26 | def __getitem__(self, index): 27 | if not self.opt.isTrain: 28 | if self.opt.serial_test: 29 | for d,s in enumerate(self.sizes): 30 | if index < s: 31 | DA = d; break 32 | index -= s 33 | index_A = index 34 | else: 35 | DA = index % len(self.dirs) 36 | index_A = random.randint(0, self.sizes[DA] - 1) 37 | else: 38 | # Choose two of our domains to perform a pass on 39 | DA, DB = random.sample(range(len(self.dirs)), 2) 40 | index_A = random.randint(0, self.sizes[DA] - 1) 41 | 42 | A_img, A_path = self.load_image(DA, index_A) 43 | bundle = {'A': A_img, 'DA': DA, 'path': A_path} 44 | 45 | if self.opt.isTrain: 46 | index_B = random.randint(0, self.sizes[DB] - 1) 47 | B_img, _ = self.load_image(DB, index_B) 48 | bundle.update( {'B': B_img, 'DB': DB} ) 49 | 50 | return bundle 51 | 52 | def __len__(self): 53 | if self.opt.isTrain: 54 | return max(self.sizes) 55 | return sum(self.sizes) 56 | 57 | def name(self): 58 | return 'UnalignedDataset' 59 | -------------------------------------------------------------------------------- /img/Discr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/img/Discr.png -------------------------------------------------------------------------------- /img/Matches.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/img/Matches.png
--------------------------------------------------------------------------------
/img/Res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/img/Res.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/models/__init__.py
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | class BaseModel():
6 |     def name(self):
7 |         return 'BaseModel'
8 |
9 |     def __init__(self, opt):
10 |         self.opt = opt
11 |         self.gpu_ids = opt.gpu_ids
12 |         self.isTrain = opt.isTrain
13 |         self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
14 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
15 |
16 |     def set_input(self, input):
17 |         self.input = input
18 |
19 |     def forward(self):
20 |         pass
21 |
22 |     # used at test time, no backprop
23 |     def test(self):
24 |         pass
25 |
26 |     def get_image_paths(self):
27 |         pass
28 |
29 |     def optimize_parameters(self):
30 |         pass
31 |
32 |     def get_current_visuals(self):
33 |         return self.input
34 |
35 |     def get_current_errors(self):
36 |         return {}
37 |
38 |     def save(self, label):
39 |         pass
40 |
41 |     # helper saving function that can be used by subclasses
42 |     def save_network(self, network, network_label, epoch, gpu_ids):
43 |         save_filename = '%d_net_%s' % (epoch, network_label)
44 |         save_path = os.path.join(self.save_dir, save_filename)
45 |         network.save(save_path)
46 |         if gpu_ids and torch.cuda.is_available():
47 |             network.cuda(gpu_ids[0])
48 |
49 |     # helper loading function that can be used by subclasses
50 |     def load_network(self, network, network_label, epoch):
51 |         save_filename = '%d_net_%s' % (epoch, network_label)
52 |         save_path = os.path.join(self.save_dir, save_filename)
53 |         network.load(save_path)
54 |
55 |     def update_learning_rate(self):
56 |         pass
57 |
--------------------------------------------------------------------------------
/models/combogan_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from collections import OrderedDict
4 | import util.util as util
5 | from util.image_pool import ImagePool
6 | from .base_model import BaseModel
7 | from . 
import networks 8 | 9 | 10 | class ComboGANModel(BaseModel): 11 | def name(self): 12 | return 'ComboGANModel' 13 | 14 | def __init__(self, opt): 15 | super(ComboGANModel, self).__init__(opt) 16 | 17 | self.n_domains = opt.n_domains 18 | self.DA, self.DB = None, None 19 | 20 | self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 21 | self.real_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize) 22 | 23 | # load/define networks 24 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 25 | opt.netG_n_blocks, opt.netG_n_shared, 26 | self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids) 27 | if self.isTrain: 28 | self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD_n_layers, 29 | self.n_domains, self.Tensor, opt.norm, self.gpu_ids) 30 | 31 | if not self.isTrain or opt.continue_train: 32 | which_epoch = opt.which_epoch 33 | self.load_network(self.netG, 'G', which_epoch) 34 | if self.isTrain: 35 | self.load_network(self.netD, 'D', which_epoch) 36 | 37 | if self.isTrain: 38 | self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)] 39 | # define loss functions 40 | self.L1 = torch.nn.SmoothL1Loss() 41 | self.downsample = torch.nn.AvgPool2d(3, stride=2) 42 | self.criterionCycle = self.L1 43 | self.criterionIdt = lambda y,t : self.L1(self.downsample(y), self.downsample(t)) 44 | self.criterionLatent = lambda y,t : self.L1(y, t.detach()) 45 | self.criterionGAN = lambda r,f,v : (networks.GANLoss(r[0],f[0],v) + \ 46 | networks.GANLoss(r[1],f[1],v) + \ 47 | networks.GANLoss(r[2],f[2],v)) / 3 48 | # initialize optimizers 49 | self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999)) 50 | self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999)) 51 | # initialize loss storage 52 | self.loss_D, self.loss_G = [0]*self.n_domains, [0]*self.n_domains 53 | self.loss_cycle = [0]*self.n_domains 54 | # initialize loss multipliers 55 | self.lambda_cyc, self.lambda_enc = opt.lambda_cycle, (0 * opt.lambda_latent) 56 | self.lambda_idt, self.lambda_fwd = opt.lambda_identity, opt.lambda_forward 57 | 58 | print('---------- Networks initialized -------------') 59 | print(self.netG) 60 | if self.isTrain: 61 | print(self.netD) 62 | print('-----------------------------------------------') 63 | 64 | def set_input(self, input): 65 | input_A = input['A'] 66 | self.real_A.resize_(input_A.size()).copy_(input_A) 67 | self.DA = input['DA'][0] 68 | if self.isTrain: 69 | input_B = input['B'] 70 | self.real_B.resize_(input_B.size()).copy_(input_B) 71 | self.DB = input['DB'][0] 72 | self.image_paths = input['path'] 73 | 74 | def test(self): 75 | with torch.no_grad(): 76 | self.visuals = [self.real_A] 77 | self.labels = ['real_%d' % self.DA] 78 | 79 | # cache encoding to not repeat it everytime 80 | encoded = self.netG.encode(self.real_A, self.DA) 81 | for d in range(self.n_domains): 82 | if d == self.DA and not self.opt.autoencode: 83 | continue 84 | fake = self.netG.decode(encoded, d) 85 | self.visuals.append( fake ) 86 | self.labels.append( 'fake_%d' % d ) 87 | if self.opt.reconstruct: 88 | rec = self.netG.forward(fake, d, self.DA) 89 | self.visuals.append( rec ) 90 | self.labels.append( 'rec_%d' % d ) 91 | 92 | def get_image_paths(self): 93 | return self.image_paths 94 | 95 | def backward_D_basic(self, pred_real, fake, domain): 96 | pred_fake = self.netD.forward(fake.detach(), domain) 97 | loss_D = self.criterionGAN(pred_real, pred_fake, True) * 0.5 98 | loss_D.backward() 99 | return loss_D 100 | 101 | def 
backward_D(self): 102 | #D_A 103 | fake_B = self.fake_pools[self.DB].query(self.fake_B) 104 | self.loss_D[self.DA] = self.backward_D_basic(self.pred_real_B, fake_B, self.DB) 105 | #D_B 106 | fake_A = self.fake_pools[self.DA].query(self.fake_A) 107 | self.loss_D[self.DB] = self.backward_D_basic(self.pred_real_A, fake_A, self.DA) 108 | 109 | def backward_G(self): 110 | encoded_A = self.netG.encode(self.real_A, self.DA) 111 | encoded_B = self.netG.encode(self.real_B, self.DB) 112 | 113 | # Optional identity "autoencode" loss 114 | if self.lambda_idt > 0: 115 | # Same encoder and decoder should recreate image 116 | idt_A = self.netG.decode(encoded_A, self.DA) 117 | loss_idt_A = self.criterionIdt(idt_A, self.real_A) 118 | idt_B = self.netG.decode(encoded_B, self.DB) 119 | loss_idt_B = self.criterionIdt(idt_B, self.real_B) 120 | else: 121 | loss_idt_A, loss_idt_B = 0, 0 122 | 123 | # GAN loss 124 | # D_A(G_A(A)) 125 | self.fake_B = self.netG.decode(encoded_A, self.DB) 126 | pred_fake = self.netD.forward(self.fake_B, self.DB) 127 | self.loss_G[self.DA] = self.criterionGAN(self.pred_real_B, pred_fake, False) 128 | # D_B(G_B(B)) 129 | self.fake_A = self.netG.decode(encoded_B, self.DA) 130 | pred_fake = self.netD.forward(self.fake_A, self.DA) 131 | self.loss_G[self.DB] = self.criterionGAN(self.pred_real_A, pred_fake, False) 132 | # Forward cycle loss 133 | rec_encoded_A = self.netG.encode(self.fake_B, self.DB) 134 | self.rec_A = self.netG.decode(rec_encoded_A, self.DA) 135 | self.loss_cycle[self.DA] = self.criterionCycle(self.rec_A, self.real_A) 136 | # Backward cycle loss 137 | rec_encoded_B = self.netG.encode(self.fake_A, self.DA) 138 | self.rec_B = self.netG.decode(rec_encoded_B, self.DB) 139 | self.loss_cycle[self.DB] = self.criterionCycle(self.rec_B, self.real_B) 140 | 141 | # Optional cycle loss on encoding space 142 | if self.lambda_enc > 0: 143 | loss_enc_A = self.criterionLatent(rec_encoded_A, encoded_A) 144 | loss_enc_B = self.criterionLatent(rec_encoded_B, encoded_B) 145 | else: 146 | loss_enc_A, loss_enc_B = 0, 0 147 | 148 | # Optional loss on downsampled image before and after 149 | if self.lambda_fwd > 0: 150 | loss_fwd_A = self.criterionIdt(self.fake_B, self.real_A) 151 | loss_fwd_B = self.criterionIdt(self.fake_A, self.real_B) 152 | else: 153 | loss_fwd_A, loss_fwd_B = 0, 0 154 | 155 | # combined loss 156 | loss_G = self.loss_G[self.DA] + self.loss_G[self.DB] + \ 157 | (self.loss_cycle[self.DA] + self.loss_cycle[self.DB]) * self.lambda_cyc + \ 158 | (loss_idt_A + loss_idt_B) * self.lambda_idt + \ 159 | (loss_enc_A + loss_enc_B) * self.lambda_enc + \ 160 | (loss_fwd_A + loss_fwd_B) * self.lambda_fwd 161 | loss_G.backward() 162 | 163 | def optimize_parameters(self): 164 | self.pred_real_A = self.netD.forward(self.real_A, self.DA) 165 | self.pred_real_B = self.netD.forward(self.real_B, self.DB) 166 | # G_A and G_B 167 | self.netG.zero_grads(self.DA, self.DB) 168 | self.backward_G() 169 | self.netG.step_grads(self.DA, self.DB) 170 | # D_A and D_B 171 | self.netD.zero_grads(self.DA, self.DB) 172 | self.backward_D() 173 | self.netD.step_grads(self.DA, self.DB) 174 | 175 | def get_current_errors(self): 176 | extract = lambda l: [(i if type(i) is int or type(i) is float else i.item()) for i in l] 177 | D_losses, G_losses, cyc_losses = extract(self.loss_D), extract(self.loss_G), extract(self.loss_cycle) 178 | return OrderedDict([('D', D_losses), ('G', G_losses), ('Cyc', cyc_losses)]) 179 | 180 | def get_current_visuals(self, testing=False): 181 | if not testing: 182 | self.visuals = 
[self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B] 183 | self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B'] 184 | images = [util.tensor2im(v.data) for v in self.visuals] 185 | return OrderedDict(zip(self.labels, images)) 186 | 187 | def save(self, label): 188 | self.save_network(self.netG, 'G', label, self.gpu_ids) 189 | self.save_network(self.netD, 'D', label, self.gpu_ids) 190 | 191 | def update_hyperparams(self, curr_iter): 192 | if curr_iter > self.opt.niter: 193 | decay_frac = (curr_iter - self.opt.niter) / self.opt.niter_decay 194 | new_lr = self.opt.lr * (1 - decay_frac) 195 | self.netG.update_lr(new_lr) 196 | self.netD.update_lr(new_lr) 197 | print('updated learning rate: %f' % new_lr) 198 | 199 | if self.opt.lambda_latent > 0: 200 | decay_frac = curr_iter / (self.opt.niter + self.opt.niter_decay) 201 | self.lambda_enc = self.opt.lambda_latent * decay_frac 202 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools, itertools 5 | import numpy as np 6 | from util.util import gkern_2d 7 | 8 | 9 | 10 | 11 | def weights_init(m): 12 | classname = m.__class__.__name__ 13 | if classname.find('Conv') != -1: 14 | m.weight.data.normal_(0.0, 0.02) 15 | if hasattr(m.bias, 'data'): 16 | m.bias.data.fill_(0) 17 | elif classname.find('BatchNorm2d') != -1: 18 | m.weight.data.normal_(1.0, 0.02) 19 | m.bias.data.fill_(0) 20 | 21 | 22 | def get_norm_layer(norm_type='instance'): 23 | if norm_type == 'batch': 24 | return functools.partial(nn.BatchNorm2d, affine=True) 25 | elif norm_type == 'instance': 26 | return functools.partial(nn.InstanceNorm2d, affine=False) 27 | else: 28 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 29 | 30 | 31 | def define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_domains, norm='batch', use_dropout=False, gpu_ids=[]): 32 | norm_layer = get_norm_layer(norm_type=norm) 33 | if type(norm_layer) == functools.partial: 34 | use_bias = norm_layer.func == nn.InstanceNorm2d 35 | else: 36 | use_bias = norm_layer == nn.InstanceNorm2d 37 | 38 | n_blocks -= n_blocks_shared 39 | n_blocks_enc = n_blocks // 2 40 | n_blocks_dec = n_blocks - n_blocks_enc 41 | 42 | dup_args = (ngf, norm_layer, use_dropout, gpu_ids, use_bias) 43 | enc_args = (input_nc, n_blocks_enc) + dup_args 44 | dec_args = (output_nc, n_blocks_dec) + dup_args 45 | 46 | if n_blocks_shared > 0: 47 | n_blocks_shdec = n_blocks_shared // 2 48 | n_blocks_shenc = n_blocks_shared - n_blocks_shdec 49 | shenc_args = (n_domains, n_blocks_shenc) + dup_args 50 | shdec_args = (n_domains, n_blocks_shdec) + dup_args 51 | plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args, ResnetGenShared, shenc_args, shdec_args) 52 | else: 53 | plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args) 54 | 55 | if len(gpu_ids) > 0: 56 | assert(torch.cuda.is_available()) 57 | plex_netG.cuda(gpu_ids[0]) 58 | 59 | plex_netG.apply(weights_init) 60 | return plex_netG 61 | 62 | 63 | def define_D(input_nc, ndf, netD_n_layers, n_domains, tensor, norm='batch', gpu_ids=[]): 64 | norm_layer = get_norm_layer(norm_type=norm) 65 | 66 | model_args = (input_nc, ndf, netD_n_layers, tensor, norm_layer, gpu_ids) 67 | plex_netD = D_Plexer(n_domains, NLayerDiscriminator, model_args) 68 | 69 | if 
len(gpu_ids) > 0: 70 | assert(torch.cuda.is_available()) 71 | plex_netD.cuda(gpu_ids[0]) 72 | 73 | plex_netD.apply(weights_init) 74 | return plex_netD 75 | 76 | 77 | ############################################################################## 78 | # Classes 79 | ############################################################################## 80 | 81 | 82 | # Defines the GAN loss which uses the Relativistic LSGAN 83 | def GANLoss(inputs_real, inputs_fake, is_discr): 84 | if is_discr: 85 | y = -1 86 | else: 87 | y = 1 88 | inputs_real = [i.detach() for i in inputs_real] 89 | loss = lambda r,f : torch.mean((r-f+y)**2) 90 | losses = [loss(r,f) for r,f in zip(inputs_real, inputs_fake)] 91 | multipliers = list(range(1, len(inputs_real)+1)); multipliers[-1] += 1 92 | losses = [m*l for m,l in zip(multipliers, losses)] 93 | return sum(losses) / (sum(multipliers) * len(losses)) 94 | 95 | 96 | # Defines the generator that consists of Resnet blocks between a few 97 | # downsampling/upsampling operations. 98 | # Code and idea originally from Justin Johnson's architecture. 99 | # https://github.com/jcjohnson/fast-neural-style/ 100 | class ResnetGenEncoder(nn.Module): 101 | def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNorm2d, 102 | use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'): 103 | assert(n_blocks >= 0) 104 | super(ResnetGenEncoder, self).__init__() 105 | self.gpu_ids = gpu_ids 106 | 107 | model = [nn.ReflectionPad2d(3), 108 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 109 | bias=use_bias), 110 | norm_layer(ngf), 111 | nn.PReLU()] 112 | 113 | n_downsampling = 2 114 | for i in range(n_downsampling): 115 | mult = 2**i 116 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 117 | stride=2, padding=1, bias=use_bias), 118 | norm_layer(ngf * mult * 2), 119 | nn.PReLU()] 120 | 121 | mult = 2**n_downsampling 122 | for _ in range(n_blocks): 123 | model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, 124 | use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)] 125 | 126 | self.model = nn.Sequential(*model) 127 | 128 | def forward(self, input): 129 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): 130 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 131 | return self.model(input) 132 | 133 | class ResnetGenShared(nn.Module): 134 | def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchNorm2d, 135 | use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'): 136 | assert(n_blocks >= 0) 137 | super(ResnetGenShared, self).__init__() 138 | self.gpu_ids = gpu_ids 139 | 140 | model = [] 141 | n_downsampling = 2 142 | mult = 2**n_downsampling 143 | 144 | for _ in range(n_blocks): 145 | model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, n_domains=n_domains, 146 | use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)] 147 | 148 | self.model = SequentialContext(n_domains, *model) 149 | 150 | def forward(self, input, domain): 151 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): 152 | return nn.parallel.data_parallel(self.model, (input, domain), self.gpu_ids) 153 | return self.model(input, domain) 154 | 155 | class ResnetGenDecoder(nn.Module): 156 | def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchNorm2d, 157 | use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'): 158 | assert(n_blocks >= 0) 159 | super(ResnetGenDecoder, self).__init__() 160 | self.gpu_ids = gpu_ids 161 | 162 | model = [] 163 
| n_downsampling = 2 164 | mult = 2**n_downsampling 165 | 166 | for _ in range(n_blocks): 167 | model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, 168 | use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)] 169 | 170 | for i in range(n_downsampling): 171 | mult = 2**(n_downsampling - i) 172 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 173 | kernel_size=4, stride=2, 174 | padding=1, output_padding=0, 175 | bias=use_bias), 176 | norm_layer(int(ngf * mult / 2)), 177 | nn.PReLU()] 178 | 179 | model += [nn.ReflectionPad2d(3), 180 | nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), 181 | nn.Tanh()] 182 | 183 | self.model = nn.Sequential(*model) 184 | 185 | def forward(self, input): 186 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): 187 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 188 | return self.model(input) 189 | 190 | 191 | # Define a resnet block 192 | class ResnetBlock(nn.Module): 193 | def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_type='reflect', n_domains=0): 194 | super(ResnetBlock, self).__init__() 195 | 196 | conv_block = [] 197 | p = 0 198 | if padding_type == 'reflect': 199 | conv_block += [nn.ReflectionPad2d(1)] 200 | elif padding_type == 'replicate': 201 | conv_block += [nn.ReplicationPad2d(1)] 202 | elif padding_type == 'zero': 203 | p = 1 204 | else: 205 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 206 | 207 | conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias), 208 | norm_layer(dim), 209 | nn.PReLU()] 210 | if use_dropout: 211 | conv_block += [nn.Dropout(0.5)] 212 | 213 | p = 0 214 | if padding_type == 'reflect': 215 | conv_block += [nn.ReflectionPad2d(1)] 216 | elif padding_type == 'replicate': 217 | conv_block += [nn.ReplicationPad2d(1)] 218 | elif padding_type == 'zero': 219 | p = 1 220 | else: 221 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 222 | conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias), 223 | norm_layer(dim)] 224 | 225 | self.conv_block = SequentialContext(n_domains, *conv_block) 226 | 227 | def forward(self, input): 228 | if isinstance(input, tuple): 229 | return input[0] + self.conv_block(*input) 230 | return input + self.conv_block(input) 231 | 232 | 233 | # Defines the PatchGAN discriminator with the specified arguments. 
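# Note on the design, as implemented below: each domain's discriminator is
# really three PatchGAN streams run in parallel -- model_rgb on a
# Gaussian-blurred RGB image, model_gray on the luminance channel, and
# model_grad on the x/y gradients of a 2x-downsampled grayscale image.
# forward() returns all three prediction lists, and criterionGAN in
# ComboGANModel averages a GANLoss term over the three streams.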
234 | class NLayerDiscriminator(nn.Module): 235 | def __init__(self, input_nc, ndf=64, n_layers=3, tensor=torch.FloatTensor, norm_layer=nn.BatchNorm2d, gpu_ids=[]): 236 | super(NLayerDiscriminator, self).__init__() 237 | self.gpu_ids = gpu_ids 238 | self.grad_filter = tensor([0,0,0,-1,0,1,0,0,0]).view(1,1,3,3) 239 | self.dsamp_filter = tensor([1]).view(1,1,1,1) 240 | self.blur_filter = tensor(gkern_2d()) 241 | 242 | self.model_rgb = self.model(input_nc, ndf, n_layers, norm_layer) 243 | self.model_gray = self.model(1, ndf, n_layers, norm_layer) 244 | self.model_grad = self.model(2, ndf, n_layers-1, norm_layer) 245 | 246 | def model(self, input_nc, ndf, n_layers, norm_layer): 247 | if type(norm_layer) == functools.partial: 248 | use_bias = norm_layer.func == nn.InstanceNorm2d 249 | else: 250 | use_bias = norm_layer == nn.InstanceNorm2d 251 | 252 | kw = 4 253 | padw = int(np.ceil((kw-1)/2)) 254 | sequences = [[ 255 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 256 | nn.PReLU() 257 | ]] 258 | 259 | nf_mult = 1 260 | nf_mult_prev = 1 261 | for n in range(1, n_layers): 262 | nf_mult_prev = nf_mult 263 | nf_mult = min(2**n, 8) 264 | sequences += [[ 265 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult + 1, 266 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 267 | norm_layer(ndf * nf_mult + 1), 268 | nn.PReLU() 269 | ]] 270 | 271 | nf_mult_prev = nf_mult 272 | nf_mult = min(2**n_layers, 8) 273 | sequences += [[ 274 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 275 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 276 | norm_layer(ndf * nf_mult), 277 | nn.PReLU(), 278 | \ 279 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 280 | ]] 281 | 282 | return SequentialOutput(*sequences) 283 | 284 | def forward(self, input): 285 | blurred = torch.nn.functional.conv2d(input, self.blur_filter, groups=3, padding=2) 286 | gray = (.299*input[:,0,:,:] + .587*input[:,1,:,:] + .114*input[:,2,:,:]).unsqueeze_(1) 287 | 288 | gray_dsamp = nn.functional.conv2d(gray, self.dsamp_filter, stride=2) 289 | dx = nn.functional.conv2d(gray_dsamp, self.grad_filter) 290 | dy = nn.functional.conv2d(gray_dsamp, self.grad_filter.transpose(-2,-1)) 291 | gradient = torch.cat([dx,dy], 1) 292 | 293 | if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): 294 | outs1 = nn.parallel.data_parallel(self.model_rgb, blurred, self.gpu_ids) 295 | outs2 = nn.parallel.data_parallel(self.model_gray, gray, self.gpu_ids) 296 | outs3 = nn.parallel.data_parallel(self.model_grad, gradient, self.gpu_ids) 297 | else: 298 | outs1 = self.model_rgb(blurred) 299 | outs2 = self.model_gray(gray) 300 | outs3 = self.model_grad(gradient) 301 | return outs1, outs2, outs3 302 | 303 | 304 | class Plexer(nn.Module): 305 | def __init__(self): 306 | super(Plexer, self).__init__() 307 | 308 | def apply(self, func): 309 | for net in self.networks: 310 | net.apply(func) 311 | 312 | def cuda(self, device_id): 313 | for net in self.networks: 314 | net.cuda(device_id) 315 | 316 | def init_optimizers(self, opt, lr, betas): 317 | self.optimizers = [opt(net.parameters(), lr=lr, betas=betas) \ 318 | for net in self.networks] 319 | 320 | def zero_grads(self, dom_a, dom_b): 321 | self.optimizers[dom_a].zero_grad() 322 | self.optimizers[dom_b].zero_grad() 323 | 324 | def step_grads(self, dom_a, dom_b): 325 | self.optimizers[dom_a].step() 326 | self.optimizers[dom_b].step() 327 | 328 | def update_lr(self, new_lr): 329 | for opt in self.optimizers: 330 | for param_group in opt.param_groups: 331 | 
param_group['lr'] = new_lr
332 |
333 |     def save(self, save_path):
334 |         for i, net in enumerate(self.networks):
335 |             filename = save_path + ('%d.pth' % i)
336 |             torch.save(net.cpu().state_dict(), filename)
337 |
338 |     def load(self, save_path):
339 |         for i, net in enumerate(self.networks):
340 |             filename = save_path + ('%d.pth' % i)
341 |             net.load_state_dict(torch.load(filename))
342 |
343 | class G_Plexer(Plexer):
344 |     def __init__(self, n_domains, encoder, enc_args, decoder, dec_args,
345 |                  block=None, shenc_args=None, shdec_args=None):
346 |         super(G_Plexer, self).__init__()
347 |         self.encoders = [encoder(*enc_args) for _ in range(n_domains)]
348 |         self.decoders = [decoder(*dec_args) for _ in range(n_domains)]
349 |
350 |         self.sharing = block is not None
351 |         if self.sharing:
352 |             self.shared_encoder = block(*shenc_args)
353 |             self.shared_decoder = block(*shdec_args)
354 |             self.encoders.append( self.shared_encoder )
355 |             self.decoders.append( self.shared_decoder )
356 |         self.networks = self.encoders + self.decoders
357 |
358 |     def init_optimizers(self, opt, lr, betas):
359 |         self.optimizers = []
360 |         for enc, dec in zip(self.encoders, self.decoders):
361 |             params = itertools.chain(enc.parameters(), dec.parameters())
362 |             self.optimizers.append( opt(params, lr=lr, betas=betas) )
363 |
364 |     def forward(self, input, in_domain, out_domain):
365 |         encoded = self.encode(input, in_domain)
366 |         return self.decode(encoded, out_domain)
367 |
368 |     def encode(self, input, domain):
369 |         output = self.encoders[domain].forward(input)
370 |         if self.sharing:
371 |             return self.shared_encoder.forward(output, domain)
372 |         return output
373 |
374 |     def decode(self, input, domain):
375 |         if self.sharing:
376 |             input = self.shared_decoder.forward(input, domain)
377 |         return self.decoders[domain].forward(input)
378 |
379 |     def zero_grads(self, dom_a, dom_b):
380 |         self.optimizers[dom_a].zero_grad()
381 |         if self.sharing:
382 |             self.optimizers[-1].zero_grad()
383 |         self.optimizers[dom_b].zero_grad()
384 |
385 |     def step_grads(self, dom_a, dom_b):
386 |         self.optimizers[dom_a].step()
387 |         if self.sharing:
388 |             self.optimizers[-1].step()
389 |         self.optimizers[dom_b].step()
390 |
391 |     def __repr__(self):
392 |         e, d = self.encoders[0], self.decoders[0]
393 |         e_params = sum([p.numel() for p in e.parameters()])
394 |         d_params = sum([p.numel() for p in d.parameters()])
395 |         return repr(e) +'\n'+ repr(d) +'\n'+ \
396 |             'Created %d Encoder-Decoder pairs' % len(self.encoders) +'\n'+ \
397 |             'Number of parameters per Encoder: %d' % e_params +'\n'+ \
398 |             'Number of parameters per Decoder: %d' % d_params
399 |
400 | class D_Plexer(Plexer):
401 |     def __init__(self, n_domains, model, model_args):
402 |         super(D_Plexer, self).__init__()
403 |         self.networks = [model(*model_args) for _ in range(n_domains)]
404 |
405 |     def forward(self, input, domain):
406 |         discriminator = self.networks[domain]
407 |         return discriminator.forward(input)
408 |
409 |     def __repr__(self):
410 |         t = self.networks[0]
411 |         t_params = sum([p.numel() for p in t.parameters()])
412 |         return repr(t) +'\n'+ \
413 |             'Created %d Discriminators' % len(self.networks) +'\n'+ \
414 |             'Number of parameters per Discriminator: %d' % t_params
415 |
416 |
417 | class SequentialContext(nn.Sequential):
418 |     def __init__(self, n_classes, *args):
419 |         super(SequentialContext, self).__init__(*args)
420 |         self.n_classes = n_classes
421 |         self.context_var = None
422 |
423 |     def prepare_context(self, input, domain):
424 |         if self.context_var is None or 
self.context_var.size()[-2:] != input.size()[-2:]: 425 | tensor = torch.cuda.FloatTensor if isinstance(input.data, torch.cuda.FloatTensor) \ 426 | else torch.FloatTensor 427 | self.context_var = tensor(*((1, self.n_classes) + input.size()[-2:])) 428 | 429 | self.context_var.data.fill_(-1.0) 430 | self.context_var.data[:,domain,:,:] = 1.0 431 | return self.context_var 432 | 433 | def forward(self, *input): 434 | if self.n_classes < 2 or len(input) < 2: 435 | return super(SequentialContext, self).forward(input[0]) 436 | x, domain = input 437 | 438 | for module in self._modules.values(): 439 | if 'Conv' in module.__class__.__name__: 440 | context_var = self.prepare_context(x, domain) 441 | x = torch.cat([x, context_var], dim=1) 442 | elif 'Block' in module.__class__.__name__: 443 | x = (x,) + input[1:] 444 | x = module(x) 445 | return x 446 | 447 | class SequentialOutput(nn.Sequential): 448 | def __init__(self, *args): 449 | args = [nn.Sequential(*arg) for arg in args] 450 | super(SequentialOutput, self).__init__(*args) 451 | 452 | def forward(self, input): 453 | predictions = [] 454 | layers = self._modules.values() 455 | for i, module in enumerate(layers): 456 | output = module(input) 457 | if i == 0: 458 | input = output; continue 459 | predictions.append( output[:,-1,:,:] ) 460 | if i != len(layers) - 1: 461 | input = output[:,:-1,:,:] 462 | return predictions 463 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | class BaseOptions(): 7 | def __init__(self): 8 | self.parser = argparse.ArgumentParser() 9 | self.initialized = False 10 | 11 | def initialize(self): 12 | self.parser.add_argument('--name', required=True, type=str, help='name of the experiment. It decides where to store samples and models') 13 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 14 | 15 | self.parser.add_argument('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 16 | self.parser.add_argument('--n_domains', required=True, type=int, help='Number of domains to transfer among') 17 | 18 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 19 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [none|resize|resize_and_crop|crop]') 20 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 21 | 22 | self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') 23 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 24 | 25 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 26 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 27 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 28 | 29 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 30 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 31 | self.parser.add_argument('--netG_n_blocks', type=int, default=9, help='number of residual blocks to use for netG') 32 | self.parser.add_argument('--netG_n_shared', type=int, default=0, help='number of blocks to use for netG shared center module') 33 | self.parser.add_argument('--netD_n_layers', type=int, default=4, help='number of layers to use for netD') 34 | 35 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 36 | self.parser.add_argument('--use_dropout', action='store_true', help='insert dropout for the generator') 37 | 38 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') 39 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 40 | 41 | self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (set >1 to use visdom)') 42 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 43 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 44 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 45 | 46 | self.initialized = True 47 | 48 | def parse(self): 49 | if not self.initialized: 50 | self.initialize() 51 | self.opt = self.parser.parse_args() 52 | self.opt.isTrain = self.isTrain # train or test 53 | 54 | str_ids = self.opt.gpu_ids.split(',') 55 | self.opt.gpu_ids = [] 56 | for str_id in str_ids: 57 | id = int(str_id) 58 | if id >= 0: 59 | self.opt.gpu_ids.append(id) 60 | 61 | # set gpu ids 62 | if len(self.opt.gpu_ids) > 0: 63 | torch.cuda.set_device(self.opt.gpu_ids[0]) 64 | 65 | args = vars(self.opt) 66 | 67 | print('------------ Options -------------') 68 | for k, v in sorted(args.items()): 69 | print('%s: %s' % (str(k), str(v))) 70 | print('-------------- End ----------------') 71 | 72 | # save to the disk 73 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 74 | util.mkdirs(expr_dir) 75 | file_name = os.path.join(expr_dir, 'opt.txt') 76 | with open(file_name, 'wt') as opt_file: 77 | opt_file.write('------------ Options -------------\n') 78 | for k, v in sorted(args.items()): 79 | opt_file.write('%s: %s\n' % (str(k), str(v))) 80 | opt_file.write('-------------- End ----------------\n') 81 | return self.opt 82 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.isTrain = False 8 | 9 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 10 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 11 | 12 | self.parser.add_argument('--which_epoch', required=True, type=int, help='which epoch to load for inference?') 13 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc (determines name of folder to load from)') 14 | 15 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run (if serial_test not enabled)') 16 | self.parser.add_argument('--serial_test', action='store_true', help='read each image once from folders in sequential order') 17 | 18 | self.parser.add_argument('--autoencode', action='store_true', help='translate images back into its own domain') 19 | self.parser.add_argument('--reconstruct', action='store_true', help='do reconstructions of images during testing') 20 | 21 | self.parser.add_argument('--show_matrix', action='store_true', help='visualize images in a matrix format as well') 22 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class 
TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.isTrain = True 8 | 9 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 10 | self.parser.add_argument('--which_epoch', type=int, default=0, help='which epoch to load if continuing training') 11 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc (determines name of folder to load from)') 12 | 13 | self.parser.add_argument('--niter', required=True, type=int, help='# of epochs at starting learning rate (try 50*n_domains)') 14 | self.parser.add_argument('--niter_decay', required=True, type=int, help='# of epochs to linearly decay learning rate to zero (try 50*n_domains)') 15 | 16 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for ADAM') 17 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of ADAM') 18 | 19 | self.parser.add_argument('--lambda_cycle', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') 20 | self.parser.add_argument('--lambda_identity', type=float, default=0.0, help='weight for identity "autoencode" mapping (A -> A)') 21 | self.parser.add_argument('--lambda_latent', type=float, default=0.0, help='weight for latent-space loss (A -> z -> B -> z)') 22 | self.parser.add_argument('--lambda_forward', type=float, default=0.0, help='weight for forward loss (A -> B; try 0.2)') 23 | 24 | self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 25 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 26 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 27 | 28 | self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 29 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 30 | -------------------------------------------------------------------------------- /scripts/continue_combogan.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --continue_train --which_epoch 30 \ 3 | --dataroot ./datasets/robotcar \ 4 | --name robotcar_night2day \ 5 | --n_domains 2 \ 6 | --niter 25 --niter_decay 25 \ 7 | --loadSize 512 --fineSize 384 8 | -------------------------------------------------------------------------------- /scripts/test_combogan.sh: -------------------------------------------------------------------------------- 1 | python test.py \ 2 | --phase test --which_epoch 50 \ 3 | --serial_test \ 4 | --dataroot ./datasets/robotcar \ 5 | --name robotcar_night2day \ 6 | --n_domains 2 \ 7 | --loadSize 512 8 | -------------------------------------------------------------------------------- /scripts/train_combogan.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --dataroot ./datasets/robotcar \ 3 | --name robotcar_night2day \ 4 | --n_domains 2 \ 5 | --niter 75 --niter_decay 75 \ 6 | --loadSize 512 --fineSize 384 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 
import os
3 | from options.test_options import TestOptions
4 | from data.data_loader import DataLoader
5 | from models.combogan_model import ComboGANModel
6 | from util.visualizer import Visualizer
7 | from util import html
8 |
9 |
10 | opt = TestOptions().parse()
11 | opt.nThreads = 1   # test code only supports nThreads = 1
12 | opt.batchSize = 1  # test code only supports batchSize = 1
13 |
14 | dataset = DataLoader(opt)
15 | model = ComboGANModel(opt)
16 | visualizer = Visualizer(opt)
17 | # create website
18 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%d' % (opt.phase, opt.which_epoch))
19 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %d' % (opt.name, opt.phase, opt.which_epoch))
20 | # store images for matrix visualization
21 | vis_buffer = []
22 |
23 | # test
24 | for i, data in enumerate(dataset):
25 |     if not opt.serial_test and i >= opt.how_many:
26 |         break
27 |     model.set_input(data)
28 |     model.test()
29 |     visuals = model.get_current_visuals(testing=True)
30 |     img_path = model.get_image_paths()
31 |     print('process image... %s' % img_path)
32 |     visualizer.save_images(webpage, visuals, img_path)
33 |
34 |     if opt.show_matrix:
35 |         vis_buffer.append(visuals)
36 |         if (i+1) % opt.n_domains == 0:
37 |             save_path = os.path.join(web_dir, 'mat_%d.png' % (i//opt.n_domains))
38 |             visualizer.save_image_matrix(vis_buffer, save_path)
39 |             vis_buffer.clear()
40 |
41 | webpage.save()
42 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import DataLoader
4 | from models.combogan_model import ComboGANModel
5 | from util.visualizer import Visualizer
6 |
7 |
8 | opt = TrainOptions().parse()
9 | dataset = DataLoader(opt)
10 | print('# training images = %d' % len(dataset))
11 | model = ComboGANModel(opt)
12 | visualizer = Visualizer(opt)
13 | total_steps = 0
14 |
15 | # Update initially if continuing
16 | if opt.which_epoch > 0:
17 |     model.update_hyperparams(opt.which_epoch)
18 |
19 | for epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1):
20 |     epoch_start_time = time.time()
21 |     epoch_iter = 0
22 |     for i, data in enumerate(dataset):
23 |         iter_start_time = time.time()
24 |         total_steps += opt.batchSize
25 |         epoch_iter += opt.batchSize
26 |         model.set_input(data)
27 |         model.optimize_parameters()
28 |
29 |         if total_steps % opt.display_freq == 0:
30 |             visualizer.display_current_results(model.get_current_visuals(), epoch)
31 |
32 |         if total_steps % opt.print_freq == 0:
33 |             errors = model.get_current_errors()
34 |             t = (time.time() - iter_start_time) / opt.batchSize
35 |             visualizer.print_current_errors(epoch, epoch_iter, errors, t)
36 |             if opt.display_id > 0:
37 |                 visualizer.plot_current_errors(epoch, float(epoch_iter)/len(dataset), opt, errors)
38 |
39 |     if epoch % opt.save_epoch_freq == 0:
40 |         print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
41 |         model.save(epoch)
42 |
43 |     print('End of epoch %d / %d \t Time Taken: %d sec' %
44 |           (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
45 |
46 |     model.update_hyperparams(epoch)
47 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="reflesh", content=str(reflesh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | class ImagePool(): 6 | def __init__(self, pool_size): 7 | self.pool_size = pool_size 8 | if self.pool_size > 0: 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | return_images = [] 16 | for image in images.data: 17 | image = torch.unsqueeze(image, 0) 18 | if self.num_imgs < self.pool_size: 19 | self.num_imgs = self.num_imgs + 1 20 | self.images.append(image) 21 | return_images.append(image) 22 | else: 23 | p = random.uniform(0, 1) 24 | if p > 0.5: 25 | random_id = random.randint(0, self.pool_size-1) 26 | tmp = self.images[random_id].clone() 27 | self.images[random_id] = image 28 | return_images.append(tmp) 29 | else: 30 | return_images.append(image) 31 | return_images = Variable(torch.cat(return_images, 0)) 32 | return return_images 33 | -------------------------------------------------------------------------------- /util/png.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import zlib 3 | 4 | def encode(buf, width, height): 5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... 
""" 6 | assert (width * height * 3 == len(buf)) 7 | bpp = 3 8 | 9 | def raw_data(): 10 | # reverse the vertical line order and add null bytes at the start 11 | row_bytes = width * bpp 12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes): 13 | yield b'\x00' 14 | yield buf[row_start:row_start + row_bytes] 15 | 16 | def chunk(tag, data): 17 | return [ 18 | struct.pack("!I", len(data)), 19 | tag, 20 | data, 21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) 22 | ] 23 | 24 | SIGNATURE = b'\x89PNG\r\n\x1a\n' 25 | COLOR_TYPE_RGB = 2 26 | COLOR_TYPE_RGBA = 6 27 | bit_depth = 8 28 | return b''.join( 29 | [ SIGNATURE ] + 30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + 31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + 32 | chunk(b'IEND', b'') 33 | ) 34 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from scipy.ndimage.filters import gaussian_filter 5 | from PIL import Image 6 | import inspect, re 7 | import os 8 | import collections 9 | 10 | # Converts a Tensor into a Numpy array 11 | # |imtype|: the desired type of the converted numpy array 12 | def tensor2im(image_tensor, imtype=np.uint8): 13 | image_numpy = image_tensor[0].cpu().float().numpy() 14 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 15 | if image_numpy.shape[2] < 3: 16 | image_numpy = np.dstack([image_numpy]*3) 17 | return image_numpy.astype(imtype) 18 | 19 | def gkern_2d(size=5, sigma=3): 20 | # Create 2D gaussian kernel 21 | dirac = np.zeros((size, size)) 22 | dirac[size//2, size//2] = 1 23 | mask = gaussian_filter(dirac, sigma) 24 | # Adjust dimensions for torch conv2d 25 | return np.stack([np.expand_dims(mask, axis=0)] * 3) 26 | 27 | 28 | def diagnose_network(net, name='network'): 29 | mean = 0.0 30 | count = 0 31 | for param in net.parameters(): 32 | if param.grad is not None: 33 | mean += torch.mean(torch.abs(param.grad.data)) 34 | count += 1 35 | if count > 0: 36 | mean = mean / count 37 | print(name) 38 | print(mean) 39 | 40 | 41 | def save_image(image_numpy, image_path): 42 | image_pil = Image.fromarray(image_numpy) 43 | image_pil.save(image_path) 44 | 45 | def info(object, spacing=10, collapse=1): 46 | """Print methods and doc strings. 
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from scipy.ndimage import gaussian_filter
5 | from PIL import Image
6 | import inspect, re
7 | import os
8 | import collections.abc
9 | 
10 | # Converts a Tensor into a Numpy array
11 | # |imtype|: the desired type of the converted numpy array
12 | def tensor2im(image_tensor, imtype=np.uint8):
13 |     image_numpy = image_tensor[0].cpu().float().numpy()
14 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
15 |     if image_numpy.shape[2] < 3:
16 |         image_numpy = np.dstack([image_numpy] * 3)
17 |     return image_numpy.astype(imtype)
18 | 
19 | def gkern_2d(size=5, sigma=3):
20 |     # Create 2D gaussian kernel
21 |     dirac = np.zeros((size, size))
22 |     dirac[size // 2, size // 2] = 1
23 |     mask = gaussian_filter(dirac, sigma)
24 |     # Adjust dimensions for torch conv2d
25 |     return np.stack([np.expand_dims(mask, axis=0)] * 3)
26 | 
27 | 
28 | def diagnose_network(net, name='network'):
29 |     mean = 0.0
30 |     count = 0
31 |     for param in net.parameters():
32 |         if param.grad is not None:
33 |             mean += torch.mean(torch.abs(param.grad.data))
34 |             count += 1
35 |     if count > 0:
36 |         mean = mean / count
37 |     print(name)
38 |     print(mean)
39 | 
40 | 
41 | def save_image(image_numpy, image_path):
42 |     image_pil = Image.fromarray(image_numpy)
43 |     image_pil.save(image_path)
44 | 
45 | def info(object, spacing=10, collapse=1):
46 |     """Print methods and doc strings.
47 |     Takes module, class, list, dictionary, or string."""
48 |     methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.abc.Callable)]
49 |     processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
50 |     print("\n".join(["%s %s" %
51 |                      (method.ljust(spacing),
52 |                       processFunc(str(getattr(object, method).__doc__)))
53 |                      for method in methodList]))
54 | 
55 | def varname(p):
56 |     for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
57 |         m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
58 |         if m:
59 |             return m.group(1)
60 | 
61 | def print_numpy(x, val=True, shp=False):
62 |     x = x.astype(np.float64)
63 |     if shp:
64 |         print('shape,', x.shape)
65 |     if val:
66 |         x = x.flatten()
67 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
68 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
69 | 
70 | 
71 | def mkdirs(paths):
72 |     if isinstance(paths, list) and not isinstance(paths, str):
73 |         for path in paths:
74 |             mkdir(path)
75 |     else:
76 |         mkdir(paths)
77 | 
78 | 
79 | def mkdir(path):
80 |     if not os.path.exists(path):
81 |         os.makedirs(path)
82 | 
--------------------------------------------------------------------------------
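The two helpers used most by the rest of the code are `tensor2im` and `save_image`: network outputs live in `[-1, 1]` as NCHW tensors, and `tensor2im` takes the first image of the batch, transposes it to HWC, and rescales it to `[0, 255]`. A minimal sketch (the `tanh` of random noise is a stand-in for a real generator output):

```python
import torch
from util.util import tensor2im, save_image

fake_batch = torch.tanh(torch.randn(1, 3, 256, 256))  # NCHW, values in [-1, 1]
save_image(tensor2im(fake_batch), 'sample.png')        # uint8 HWC image on disk
```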
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | 
8 | class Visualizer():
9 |     def __init__(self, opt):
10 |         # self.opt = opt
11 |         self.display_id = opt.display_id
12 |         self.use_html = opt.isTrain and not opt.no_html
13 |         self.win_size = opt.display_winsize
14 |         self.name = opt.name
15 |         if self.display_id > 0:
16 |             import visdom
17 |             self.vis = visdom.Visdom(port=opt.display_port)
18 |             self.display_single_pane_ncols = opt.display_single_pane_ncols
19 | 
20 |         if self.use_html:
21 |             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
22 |             self.img_dir = os.path.join(self.web_dir, 'images')
23 |             print('create web directory %s...' % self.web_dir)
24 |             util.mkdirs([self.web_dir, self.img_dir])
25 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
26 |         with open(self.log_name, "a") as log_file:
27 |             now = time.strftime("%c")
28 |             log_file.write('================ Training Loss (%s) ================\n' % now)
29 | 
30 |     # |visuals|: dictionary of images to display or save
31 |     def display_current_results(self, visuals, epoch):
32 |         if self.display_id > 0:  # show images in the browser
33 |             if self.display_single_pane_ncols > 0:
34 |                 h, w = next(iter(visuals.values())).shape[:2]
35 |                 table_css = """<style>
36 |                         table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
37 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
38 |                         </style>""" % (w, h)
39 |                 ncols = self.display_single_pane_ncols
40 |                 title = self.name
41 |                 label_html = ''
42 |                 label_html_row = ''
43 |                 nrows = int(np.ceil(len(visuals.items()) / ncols))
44 |                 images = []
45 |                 idx = 0
46 |                 for label, image_numpy in visuals.items():
47 |                     label_html_row += '<td>%s</td>' % label
48 |                     images.append(image_numpy.transpose([2, 0, 1]))
49 |                     idx += 1
50 |                     if idx % ncols == 0:
51 |                         label_html += '<tr>%s</tr>' % label_html_row
52 |                         label_html_row = ''
53 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
54 |                 while idx % ncols != 0:
55 |                     images.append(white_image)
56 |                     label_html_row += '<td></td>'
57 |                     idx += 1
58 |                 if label_html_row != '':
59 |                     label_html += '<tr>%s</tr>' % label_html_row
60 |                 # pane col = image row
61 |                 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
62 |                                 padding=2, opts=dict(title=title + ' images'))
63 |                 label_html = '<table>%s</table>' % label_html
64 |                 self.vis.text(table_css + label_html, win=self.display_id + 2,
65 |                               opts=dict(title=title + ' labels'))
66 |             else:
67 |                 idx = 1
68 |                 for label, image_numpy in visuals.items():
69 |                     #image_numpy = np.flipud(image_numpy)
70 |                     self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
71 |                                    win=self.display_id + idx)
72 |                     idx += 1
73 | 
74 |         if self.use_html:  # save images to a html file
75 |             for label, image_numpy in visuals.items():
76 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
77 |                 util.save_image(image_numpy, img_path)
78 |             # update website
79 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
80 |             for n in range(epoch, 0, -1):
81 |                 webpage.add_header('epoch [%d]' % n)
82 |                 ims = []
83 |                 txts = []
84 |                 links = []
85 | 
86 |                 for label, image_numpy in visuals.items():
87 |                     img_path = 'epoch%.3d_%s.png' % (n, label)
88 |                     ims.append(img_path)
89 |                     txts.append(label)
90 |                     links.append(img_path)
91 |                 webpage.add_images(ims, txts, links, width=self.win_size)
92 |             webpage.save()
93 | 
94 |     # errors: dictionary of error labels and values
95 |     def plot_current_errors(self, epoch, counter_ratio, opt, errors):
96 |         if not hasattr(self, 'plot_data'):
97 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
98 |         self.plot_data['X'].append(epoch + counter_ratio)
99 |         self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
100 |         self.vis.line(
101 |             X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
102 |             Y=np.array(self.plot_data['Y']),
103 |             opts={
104 |                 'title': self.name + ' loss over time',
105 |                 'legend': self.plot_data['legend'],
106 |                 'xlabel': 'epoch',
107 |                 'ylabel': 'loss'},
108 |             win=self.display_id)
109 | 
110 |     # errors: same format as |errors| of plot_current_errors
111 |     def print_current_errors(self, epoch, i, errors, t):
112 |         message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
113 |         for k, v in errors.items():
114 |             v = ['%.3f' % iv for iv in v]
115 |             message += k + ': ' + ', '.join(v) + ' | '
116 | 
117 |         print(message)
118 |         with open(self.log_name, "a") as log_file:
119 |             log_file.write('%s\n' % message)
120 | 
121 |     # save image to the disk
122 |     def save_images(self, webpage, visuals, image_path):
123 |         image_dir = webpage.get_image_dir()
124 |         short_path = ntpath.basename(image_path[0])
125 |         name = os.path.splitext(short_path)[0]
126 | 
127 |         webpage.add_header(name)
128 |         ims = []
129 |         txts = []
130 |         links = []
131 | 
132 |         for label, image_numpy in visuals.items():
133 |             image_name = '%s_%s.png' % (name, label)
134 |             save_path = os.path.join(image_dir, image_name)
135 |             util.save_image(image_numpy, save_path)
136 | 
137 |             ims.append(image_name)
138 |             txts.append(label)
139 |             links.append(image_name)
140 |         webpage.add_images(ims, txts, links, width=self.win_size)
141 | 
142 |     def save_image_matrix(self, visuals_list, save_path):
143 |         images_list = []
144 |         get_domain = lambda x: x.split('_')[-1]
145 | 
146 |         for visuals in visuals_list:
147 |             pairs = list(visuals.items())
148 |             real_label, real_img = pairs[0]
149 |             real_dom = get_domain(real_label)
150 | 
151 |             for label, img in pairs:
152 |                 if 'fake' not in label:
153 |                     continue
154 |                 if get_domain(label) == real_dom:
155 |                     # the "translation" to the source domain is the real image itself
156 |                     images_list.append(real_img)
157 |                 else:
158 |                     images_list.append(img)
159 | 
160 |         immat = self.stack_images(images_list)
161 |         util.save_image(immat, save_path)
162 | 
163 |     # reshape a list of images into a square matrix of them
164 |     # (assumes len(list_np_images) is a perfect square, e.g. n_domains**2)
165 |     def stack_images(self, list_np_images):
166 |         n = int(np.ceil(np.sqrt(len(list_np_images))))
167 | 
168 |         # add padding between images; diagonal tiles get a green border
169 |         for i, im in enumerate(list_np_images):
170 |             val = 255 if i % n == i // n else 0
171 |             r_pad = np.pad(im[:, :, 0], (3, 3), mode='constant', constant_values=0)
172 |             g_pad = np.pad(im[:, :, 1], (3, 3), mode='constant', constant_values=val)
173 |             b_pad = np.pad(im[:, :, 2], (3, 3), mode='constant', constant_values=0)
174 |             list_np_images[i] = np.stack([r_pad, g_pad, b_pad], axis=2)
175 | 
176 |         data = np.array(list_np_images)
177 |         data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
178 |         data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
179 |         return data
180 | 
--------------------------------------------------------------------------------
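To make the tiling in `stack_images` concrete, here is a standalone sketch of the same reshape/transpose trick on dummy data: k padded images of shape (H, W, 3) are laid out as an n-by-n grid. It assumes k is a perfect square, which holds when `save_image_matrix` is fed one result set per domain:

```python
import numpy as np

n = 2
imgs = [np.full((8, 8, 3), 60 * i, dtype=np.uint8) for i in range(n * n)]
padded = [np.pad(im, ((3, 3), (3, 3), (0, 0)), mode='constant') for im in imgs]
data = np.array(padded)                       # (4, 14, 14, 3)
data = data.reshape((n, n) + data.shape[1:])  # (2, 2, 14, 14, 3)
grid = data.transpose(0, 2, 1, 3, 4).reshape(n * 14, n * 14, 3)
assert grid.shape == (28, 28, 3)              # one 2x2 image matrix
```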