├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── data_loader.py
│   ├── image_folder.py
│   └── unaligned_dataset.py
├── img
│   ├── Discr.png
│   ├── Matches.png
│   └── Res.png
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── combogan_model.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── scripts
│   ├── continue_combogan.sh
│   ├── test_combogan.sh
│   └── train_combogan.sh
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── html.py
    ├── image_pool.py
    ├── png.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/
2 | checkpoints/
3 | results/
4 | matlab/
5 | */**/__pycache__
6 | */*.pyc
7 | */**/*.pyc
8 | */**/**/*.pyc
9 | */**/**/**/*.pyc
10 | */**/**/**/**/*.pyc
11 | */*.so*
12 | */**/*.so*
13 | */**/*.dylib*
14 | *~
15 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017, Asha Anoosheh
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # ToDayGAN
3 |
4 | This is our PyTorch implementation for ToDayGAN.
5 | Code was written by [Asha Anoosheh](https://github.com/aanoosheh) (built upon [ComboGAN](https://github.com/AAnoosheh/ComboGAN))
6 |
7 | #### [[ToDayGAN Paper]](https://arxiv.org/pdf/1809.09767.pdf)
8 | #### [[ComboGAN Paper]](https://arxiv.org/pdf/1712.06909.pdf)
9 |
10 |
11 | If you use this code for your research, please cite:
12 |
13 | Night-to-Day Image Translation for Retrieval-based Localization
14 | [Asha Anoosheh](http://ashaanoosheh.com), [Torsten Sattler](http://people.inf.ethz.ch/sattlert/), [Radu Timofte](http://www.vision.ee.ethz.ch/~timofter/), [Marc Pollefeys](https://www.microsoft.com/en-us/research/people/mapoll/), [Luc van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1)
15 | In Arxiv, 2018.
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | ## Prerequisites
26 | - Linux or macOS
27 | - Python 3
28 | - CPU or NVIDIA GPU + CUDA CuDNN
29 |
30 | ## Getting Started
31 | ### Installation
32 | - Install requisite Python libraries.
33 | ```bash
34 | pip install torch
35 | pip install torchvision
36 | pip install visdom
37 | pip install dominate
38 | ```
39 | - Clone this repo:
40 | ```bash
41 | git clone https://github.com/AAnoosheh/ToDayGAN.git
42 | ```
43 |
44 | ### Training
45 | Example running scripts can be found in the `scripts` directory.
46 |
47 | One of our pretrained models for the Oxford RobotCar dataset can be found [HERE](https://www.dropbox.com/s/mwqfbs19cptrej6/2DayGAN_Checkpoint150.zip?dl=0). Place it under `./checkpoints/robotcar_2day` and test using the instructions below, with args `--name robotcar_2day --dataroot ./datasets/ --n_domains 2 --which_epoch 150 --loadSize 512`
48 |
49 | Because of sensitivity to intrinsic camera characteristics, testing should ideally be on the same Oxford dataset photos (and same Grasshopper camera), which are available preprocessed and ready to use [HERE](https://www.visuallocalization.net/datasets/).
50 |
51 | If using this pretrained model, `` should contain two subfolders `test0` and `test1`, containing the Day and Night images to test, respectively (the model was trained with this ordering). `test0` can be empty if you do not need Day images translated to Night, but it must exist so that the code does not break.
52 |
53 | - Train a model:
54 | ```
55 | python train.py --name --dataroot ./datasets/ --n_domains --niter --niter_decay
56 | ```
57 | Checkpoints will be saved by default to `./checkpoints//`
58 | - Fine-tuning/Resume training:
59 | ```
60 | python train.py --continue_train --which_epoch --name --dataroot ./datasets/ --n_domains --niter --niter_decay
61 | ```
62 | - Test the model:
63 | ```
64 | python test.py --phase test --serial_test --name --dataroot ./datasets/ --n_domains --which_epoch
65 | ```
66 | The test results will be saved to an HTML file here: `./results///index.html`.
67 |
68 |
69 |
70 | ## Training/Testing Details
71 | - Flags: see `options/train_options.py` for training-specific flags; see `options/test_options.py` for test-specific flags; and see `options/base_options.py` for all common flags.
72 | - Dataset format: The desired data directory (provided by `--dataroot`) should contain subfolders of the form `train*/` and `test*/`, and they are loaded in alphabetical order. (Note that a folder named train10 would be loaded before train2, and thus all checkpoints and results would be ordered accordingly.) Test directories should match alphabetical ordering of the training ones.
73 | - CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
74 | - Visualization: during training, the current results and loss plots can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id 0`. Secondly, the intermediate results are also saved to `./checkpoints//web/index.html`. To avoid this, set the `--no_html` flag.
75 | - Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image such that its smaller side becomes `opt.loadSize` (this is how `transforms.Resize` treats an integer size) and then, during training, takes a random crop of size `(opt.fineSize, opt.fineSize)`. The other options are `resize` or `crop` on their own.
76 |
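The alphabetical loading order described under "Dataset format" can be checked with a minimal sketch. The folder names below are hypothetical, and zero-padding is only one possible workaround; the repository itself does not enforce any naming scheme beyond the `train*`/`test*` prefix.

```python
# Hypothetical domain folders; the dataset code sorts folder paths as plain strings.
dirs = ["train2", "train10", "train0", "train1"]

# Plain string sorting places "train10" before "train2".
alphabetical = sorted(dirs)

# Assumption: zero-padding the numeric suffix makes string order match numeric order.
padded = sorted("train%02d" % int(d[len("train"):]) for d in dirs)

print(alphabetical)
print(padded)
```

With two-digit padding, the string order and the numeric order agree for up to 100 domains.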
77 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/data/__init__.py
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 | class BaseDataset(data.Dataset):
6 |     def __init__(self):
7 |         super(BaseDataset, self).__init__()
8 |
9 |     def name(self):
10 |         return 'BaseDataset'
11 |
12 |     def initialize(self, opt):
13 |         pass
14 |
15 | def get_transform(opt):
16 |     transform_list = []
17 |     if 'resize' in opt.resize_or_crop:
18 |         transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC))
19 |
20 |     if opt.isTrain:
21 |         if 'crop' in opt.resize_or_crop:
22 |             transform_list.append(transforms.RandomCrop(opt.fineSize))
23 |         if not opt.no_flip:
24 |             transform_list.append(transforms.RandomHorizontalFlip())
25 |
26 |     transform_list += [transforms.ToTensor(),
27 |                        transforms.Normalize((0.5, 0.5, 0.5),
28 |                                             (0.5, 0.5, 0.5))]
29 |     return transforms.Compose(transform_list)
30 |
--------------------------------------------------------------------------------
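As a rough illustration of what the `transforms.Resize(opt.loadSize, ...)` step in `get_transform` does when `loadSize` is an integer, the sketch below approximates the resulting image shape in plain Python. The helper name `resized_shape` is hypothetical, and the exact rounding may differ between torchvision versions; treat it as an approximation rather than the library's implementation.

```python
def resized_shape(width, height, load_size):
    """Approximate (width, height) after resizing so the smaller side equals load_size."""
    short, long = (width, height) if width <= height else (height, width)
    # The smaller side is matched to load_size; the other side keeps the aspect ratio.
    new_short, new_long = load_size, int(load_size * long / short)
    return (new_short, new_long) if width <= height else (new_long, new_short)


# Example: a 1024x768 image with load_size=512 keeps its 4:3 aspect ratio.
print(resized_shape(1024, 768, 512))
```

Because only the smaller side is fixed, non-square inputs stay non-square after this step; the square shape during training comes from the later random crop.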
/data/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.unaligned_dataset import UnalignedDataset
3 |
4 |
5 | class DataLoader():
6 |     def name(self):
7 |         return 'DataLoader'
8 |
9 |     def __init__(self, opt):
10 |         self.opt = opt
11 |         self.dataset = UnalignedDataset(opt)
12 |         self.dataloader = torch.utils.data.DataLoader(
13 |             self.dataset,
14 |             batch_size=opt.batchSize,
15 |             num_workers=int(opt.nThreads))
16 |
17 |     def __len__(self):
18 |         return min(len(self.dataset), self.opt.max_dataset_size)
19 |
20 |     def __iter__(self):
21 |         for i, data in enumerate(self.dataloader):
22 |             if i >= self.opt.max_dataset_size:
23 |                 break
24 |             yield data
25 |
26 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
16 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 |     images = []
26 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 |     for root, _, fnames in sorted(os.walk(dir)):
29 |         for fname in fnames:
30 |             if is_image_file(fname):
31 |                 path = os.path.join(root, fname)
32 |                 images.append(path)
33 |
34 |     return images
35 |
36 |
37 | def default_loader(path):
38 |     return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 |     def __init__(self, root, transform=None, return_paths=False,
44 |                  loader=default_loader):
45 |         imgs = make_dataset(root)
46 |         if len(imgs) == 0:
47 |             raise(RuntimeError("Found 0 images in: " + root + "\n"
48 |                                "Supported image extensions are: " +
49 |                                ",".join(IMG_EXTENSIONS)))
50 |
51 |         self.root = root
52 |         self.imgs = imgs
53 |         self.transform = transform
54 |         self.return_paths = return_paths
55 |         self.loader = loader
56 |
57 |     def __getitem__(self, index):
58 |         path = self.imgs[index]
59 |         img = self.loader(path)
60 |         if self.transform is not None:
61 |             img = self.transform(img)
62 |         if self.return_paths:
63 |             return img, path
64 |         else:
65 |             return img
66 |
67 |     def __len__(self):
68 |         return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/data/unaligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path, glob
2 | import torchvision.transforms as transforms
3 | from data.base_dataset import BaseDataset, get_transform
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import random
7 |
8 | class UnalignedDataset(BaseDataset):
9 |     def __init__(self, opt):
10 |         super(UnalignedDataset, self).__init__()
11 |         self.opt = opt
12 |         self.transform = get_transform(opt)
13 |
14 |         datapath = os.path.join(opt.dataroot, opt.phase + '*')
15 |         self.dirs = sorted(glob.glob(datapath))
16 |
17 |         self.paths = [sorted(make_dataset(d)) for d in self.dirs]
18 |         self.sizes = [len(p) for p in self.paths]
19 |
20 |     def load_image(self, dom, idx):
21 |         path = self.paths[dom][idx]
22 |         img = Image.open(path).convert('RGB')
23 |         img = self.transform(img)
24 |         return img, path
25 |
26 |     def __getitem__(self, index):
27 |         if not self.opt.isTrain:
28 |             if self.opt.serial_test:
29 |                 for d,s in enumerate(self.sizes):
30 |                     if index < s:
31 |                         DA = d; break
32 |                     index -= s
33 |                 index_A = index
34 |             else:
35 |                 DA = index % len(self.dirs)
36 |                 index_A = random.randint(0, self.sizes[DA] - 1)
37 |         else:
38 |             # Choose two of our domains to perform a pass on
39 |             DA, DB = random.sample(range(len(self.dirs)), 2)
40 |             index_A = random.randint(0, self.sizes[DA] - 1)
41 |
42 |         A_img, A_path = self.load_image(DA, index_A)
43 |         bundle = {'A': A_img, 'DA': DA, 'path': A_path}
44 |
45 |         if self.opt.isTrain:
46 |             index_B = random.randint(0, self.sizes[DB] - 1)
47 |             B_img, _ = self.load_image(DB, index_B)
48 |             bundle.update( {'B': B_img, 'DB': DB} )
49 |
50 |         return bundle
51 |
52 |     def __len__(self):
53 |         if self.opt.isTrain:
54 |             return max(self.sizes)
55 |         return sum(self.sizes)
56 |
57 |     def name(self):
58 |         return 'UnalignedDataset'
59 |
--------------------------------------------------------------------------------
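The `--serial_test` branch of `UnalignedDataset.__getitem__` walks the per-domain sizes to turn one flat index into a domain and a position within that domain. A minimal standalone sketch of that mapping follows; the function name `locate` and the example sizes are hypothetical, and the explicit `IndexError` is an addition that the original loop does not have.

```python
def locate(index, sizes):
    """Map a flat index to (domain, index within that domain), visiting domains in order."""
    for domain, size in enumerate(sizes):
        if index < size:
            return domain, index
        # Skip past this domain's images and keep searching in the next one.
        index -= size
    raise IndexError("index exceeds the total number of images")


# Hypothetical test set: 3 images in test0 and 2 images in test1.
sizes = [3, 2]
print([locate(i, sizes) for i in range(sum(sizes))])
```

This is why the test-time dataset length is the sum of the domain sizes: every image in every domain is visited exactly once.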
/img/Discr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/img/Discr.png
--------------------------------------------------------------------------------
/img/Matches.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/img/Matches.png
--------------------------------------------------------------------------------
/img/Res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/img/Res.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/models/__init__.py
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | class BaseModel():
6 |     def name(self):
7 |         return 'BaseModel'
8 |
9 |     def __init__(self, opt):
10 |         self.opt = opt
11 |         self.gpu_ids = opt.gpu_ids
12 |         self.isTrain = opt.isTrain
13 |         self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
14 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
15 |
16 |     def set_input(self, input):
17 |         self.input = input
18 |
19 |     def forward(self):
20 |         pass
21 |
22 |     # used in test time, no backprop
23 |     def test(self):
24 |         pass
25 |
26 |     def get_image_paths(self):
27 |         pass
28 |
29 |     def optimize_parameters(self):
30 |         pass
31 |
32 |     def get_current_visuals(self):
33 |         return self.input
34 |
35 |     def get_current_errors(self):
36 |         return {}
37 |
38 |     def save(self, label):
39 |         pass
40 |
41 |     # helper saving function that can be used by subclasses
42 |     def save_network(self, network, network_label, epoch, gpu_ids):
43 |         save_filename = '%d_net_%s' % (epoch, network_label)
44 |         save_path = os.path.join(self.save_dir, save_filename)
45 |         network.save(save_path)
46 |         if gpu_ids and torch.cuda.is_available():
47 |             network.cuda(gpu_ids[0])
48 |
49 |     # helper loading function that can be used by subclasses
50 |     def load_network(self, network, network_label, epoch):
51 |         save_filename = '%d_net_%s' % (epoch, network_label)
52 |         save_path = os.path.join(self.save_dir, save_filename)
53 |         network.load(save_path)
54 |
55 |     def update_learning_rate(self):
56 |         pass
57 |
--------------------------------------------------------------------------------
/models/combogan_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from collections import OrderedDict
4 | import util.util as util
5 | from util.image_pool import ImagePool
6 | from .base_model import BaseModel
7 | from . import networks
8 |
9 |
10 | class ComboGANModel(BaseModel):
11 |     def name(self):
12 |         return 'ComboGANModel'
13 |
14 |     def __init__(self, opt):
15 |         super(ComboGANModel, self).__init__(opt)
16 |
17 |         self.n_domains = opt.n_domains
18 |         self.DA, self.DB = None, None
19 |
20 |         self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
21 |         self.real_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
22 |
23 |         # load/define networks
24 |         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
25 |                                       opt.netG_n_blocks, opt.netG_n_shared,
26 |                                       self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids)
27 |         if self.isTrain:
28 |             self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD_n_layers,
29 |                                           self.n_domains, self.Tensor, opt.norm, self.gpu_ids)
30 |
31 |         if not self.isTrain or opt.continue_train:
32 |             which_epoch = opt.which_epoch
33 |             self.load_network(self.netG, 'G', which_epoch)
34 |             if self.isTrain:
35 |                 self.load_network(self.netD, 'D', which_epoch)
36 |
37 |         if self.isTrain:
38 |             self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)]
39 |             # define loss functions
40 |             self.L1 = torch.nn.SmoothL1Loss()
41 |             self.downsample = torch.nn.AvgPool2d(3, stride=2)
42 |             self.criterionCycle = self.L1
43 |             self.criterionIdt = lambda y,t : self.L1(self.downsample(y), self.downsample(t))
44 |             self.criterionLatent = lambda y,t : self.L1(y, t.detach())
45 |             self.criterionGAN = lambda r,f,v : (networks.GANLoss(r[0],f[0],v) + \
46 |                                                 networks.GANLoss(r[1],f[1],v) + \
47 |                                                 networks.GANLoss(r[2],f[2],v)) / 3
48 |             # initialize optimizers
49 |             self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
50 |             self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))
51 |             # initialize loss storage
52 |             self.loss_D, self.loss_G = [0]*self.n_domains, [0]*self.n_domains
53 |             self.loss_cycle = [0]*self.n_domains
54 |             # initialize loss multipliers
55 |             self.lambda_cyc, self.lambda_enc = opt.lambda_cycle, (0 * opt.lambda_latent)
56 |             self.lambda_idt, self.lambda_fwd = opt.lambda_identity, opt.lambda_forward
57 |
58 |         print('---------- Networks initialized -------------')
59 |         print(self.netG)
60 |         if self.isTrain:
61 |             print(self.netD)
62 |         print('-----------------------------------------------')
63 |
64 |     def set_input(self, input):
65 |         input_A = input['A']
66 |         self.real_A.resize_(input_A.size()).copy_(input_A)
67 |         self.DA = input['DA'][0]
68 |         if self.isTrain:
69 |             input_B = input['B']
70 |             self.real_B.resize_(input_B.size()).copy_(input_B)
71 |             self.DB = input['DB'][0]
72 |         self.image_paths = input['path']
73 |
74 |     def test(self):
75 |         with torch.no_grad():
76 |             self.visuals = [self.real_A]
77 |             self.labels = ['real_%d' % self.DA]
78 |
79 |             # cache encoding to not repeat it everytime
80 |             encoded = self.netG.encode(self.real_A, self.DA)
81 |             for d in range(self.n_domains):
82 |                 if d == self.DA and not self.opt.autoencode:
83 |                     continue
84 |                 fake = self.netG.decode(encoded, d)
85 |                 self.visuals.append( fake )
86 |                 self.labels.append( 'fake_%d' % d )
87 |                 if self.opt.reconstruct:
88 |                     rec = self.netG.forward(fake, d, self.DA)
89 |                     self.visuals.append( rec )
90 |                     self.labels.append( 'rec_%d' % d )
91 |
92 |     def get_image_paths(self):
93 |         return self.image_paths
94 |
95 |     def backward_D_basic(self, pred_real, fake, domain):
96 |         pred_fake = self.netD.forward(fake.detach(), domain)
97 |         loss_D = self.criterionGAN(pred_real, pred_fake, True) * 0.5
98 |         loss_D.backward()
99 |         return loss_D
100 |
101 |     def backward_D(self):
102 |         #D_A
103 |         fake_B = self.fake_pools[self.DB].query(self.fake_B)
104 |         self.loss_D[self.DA] = self.backward_D_basic(self.pred_real_B, fake_B, self.DB)
105 |         #D_B
106 |         fake_A = self.fake_pools[self.DA].query(self.fake_A)
107 |         self.loss_D[self.DB] = self.backward_D_basic(self.pred_real_A, fake_A, self.DA)
108 |
109 |     def backward_G(self):
110 |         encoded_A = self.netG.encode(self.real_A, self.DA)
111 |         encoded_B = self.netG.encode(self.real_B, self.DB)
112 |
113 |         # Optional identity "autoencode" loss
114 |         if self.lambda_idt > 0:
115 |             # Same encoder and decoder should recreate image
116 |             idt_A = self.netG.decode(encoded_A, self.DA)
117 |             loss_idt_A = self.criterionIdt(idt_A, self.real_A)
118 |             idt_B = self.netG.decode(encoded_B, self.DB)
119 |             loss_idt_B = self.criterionIdt(idt_B, self.real_B)
120 |         else:
121 |             loss_idt_A, loss_idt_B = 0, 0
122 |
123 |         # GAN loss
124 |         # D_A(G_A(A))
125 |         self.fake_B = self.netG.decode(encoded_A, self.DB)
126 |         pred_fake = self.netD.forward(self.fake_B, self.DB)
127 |         self.loss_G[self.DA] = self.criterionGAN(self.pred_real_B, pred_fake, False)
128 |         # D_B(G_B(B))
129 |         self.fake_A = self.netG.decode(encoded_B, self.DA)
130 |         pred_fake = self.netD.forward(self.fake_A, self.DA)
131 |         self.loss_G[self.DB] = self.criterionGAN(self.pred_real_A, pred_fake, False)
132 |         # Forward cycle loss
133 |         rec_encoded_A = self.netG.encode(self.fake_B, self.DB)
134 |         self.rec_A = self.netG.decode(rec_encoded_A, self.DA)
135 |         self.loss_cycle[self.DA] = self.criterionCycle(self.rec_A, self.real_A)
136 |         # Backward cycle loss
137 |         rec_encoded_B = self.netG.encode(self.fake_A, self.DA)
138 |         self.rec_B = self.netG.decode(rec_encoded_B, self.DB)
139 |         self.loss_cycle[self.DB] = self.criterionCycle(self.rec_B, self.real_B)
140 |
141 |         # Optional cycle loss on encoding space
142 |         if self.lambda_enc > 0:
143 |             loss_enc_A = self.criterionLatent(rec_encoded_A, encoded_A)
144 |             loss_enc_B = self.criterionLatent(rec_encoded_B, encoded_B)
145 |         else:
146 |             loss_enc_A, loss_enc_B = 0, 0
147 |
148 |         # Optional loss on downsampled image before and after
149 |         if self.lambda_fwd > 0:
150 |             loss_fwd_A = self.criterionIdt(self.fake_B, self.real_A)
151 |             loss_fwd_B = self.criterionIdt(self.fake_A, self.real_B)
152 |         else:
153 |             loss_fwd_A, loss_fwd_B = 0, 0
154 |
155 |         # combined loss
156 |         loss_G = self.loss_G[self.DA] + self.loss_G[self.DB] + \
157 |                  (self.loss_cycle[self.DA] + self.loss_cycle[self.DB]) * self.lambda_cyc + \
158 |                  (loss_idt_A + loss_idt_B) * self.lambda_idt + \
159 |                  (loss_enc_A + loss_enc_B) * self.lambda_enc + \
160 |                  (loss_fwd_A + loss_fwd_B) * self.lambda_fwd
161 |         loss_G.backward()
162 |
163 |     def optimize_parameters(self):
164 |         self.pred_real_A = self.netD.forward(self.real_A, self.DA)
165 |         self.pred_real_B = self.netD.forward(self.real_B, self.DB)
166 |         # G_A and G_B
167 |         self.netG.zero_grads(self.DA, self.DB)
168 |         self.backward_G()
169 |         self.netG.step_grads(self.DA, self.DB)
170 |         # D_A and D_B
171 |         self.netD.zero_grads(self.DA, self.DB)
172 |         self.backward_D()
173 |         self.netD.step_grads(self.DA, self.DB)
174 |
175 |     def get_current_errors(self):
176 |         extract = lambda l: [(i if type(i) is int or type(i) is float else i.item()) for i in l]
177 |         D_losses, G_losses, cyc_losses = extract(self.loss_D), extract(self.loss_G), extract(self.loss_cycle)
178 |         return OrderedDict([('D', D_losses), ('G', G_losses), ('Cyc', cyc_losses)])
179 |
180 |     def get_current_visuals(self, testing=False):
181 |         if not testing:
182 |             self.visuals = [self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B]
183 |             self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']
184 |         images = [util.tensor2im(v.data) for v in self.visuals]
185 |         return OrderedDict(zip(self.labels, images))
186 |
187 |     def save(self, label):
188 |         self.save_network(self.netG, 'G', label, self.gpu_ids)
189 |         self.save_network(self.netD, 'D', label, self.gpu_ids)
190 |
191 |     def update_hyperparams(self, curr_iter):
192 |         if curr_iter > self.opt.niter:
193 |             decay_frac = (curr_iter - self.opt.niter) / self.opt.niter_decay
194 |             new_lr = self.opt.lr * (1 - decay_frac)
195 |             self.netG.update_lr(new_lr)
196 |             self.netD.update_lr(new_lr)
197 |             print('updated learning rate: %f' % new_lr)
198 |
199 |         if self.opt.lambda_latent > 0:
200 |             decay_frac = curr_iter / (self.opt.niter + self.opt.niter_decay)
201 |             self.lambda_enc = self.opt.lambda_latent * decay_frac
202 |
--------------------------------------------------------------------------------
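The learning-rate schedule in `ComboGANModel.update_hyperparams` is constant for the first `niter` iterations of the counter it receives and then decays linearly over `niter_decay` more. The sketch below reproduces only that arithmetic as a pure function. The name `decayed_lr` and the example values are hypothetical, and returning `base_lr` before decay starts is an assumption made here for completeness (the original method simply leaves the optimizers untouched in that case).

```python
def decayed_lr(base_lr, curr_iter, niter, niter_decay):
    """Learning rate after curr_iter: constant up to niter, then linear decay."""
    if curr_iter <= niter:
        # Assumption: before decay starts, the rate is still the initial value.
        return base_lr
    decay_frac = (curr_iter - niter) / niter_decay
    return base_lr * (1 - decay_frac)


# Hypothetical settings: lr=0.0002, niter=100, niter_decay=100.
for it in (50, 100, 150, 200):
    print(it, decayed_lr(0.0002, it, 100, 100))
```

Under these settings the rate reaches zero exactly at `niter + niter_decay`, so training past that point would have no effect on the weights.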
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools, itertools
5 | import numpy as np
6 | from util.util import gkern_2d
7 |
8 |
9 |
10 |
11 | def weights_init(m):
12 |     classname = m.__class__.__name__
13 |     if classname.find('Conv') != -1:
14 |         m.weight.data.normal_(0.0, 0.02)
15 |         if hasattr(m.bias, 'data'):
16 |             m.bias.data.fill_(0)
17 |     elif classname.find('BatchNorm2d') != -1:
18 |         m.weight.data.normal_(1.0, 0.02)
19 |         m.bias.data.fill_(0)
20 |
21 |
22 | def get_norm_layer(norm_type='instance'):
23 |     if norm_type == 'batch':
24 |         return functools.partial(nn.BatchNorm2d, affine=True)
25 |     elif norm_type == 'instance':
26 |         return functools.partial(nn.InstanceNorm2d, affine=False)
27 |     else:
28 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
29 |
30 |
31 | def define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_domains, norm='batch', use_dropout=False, gpu_ids=[]):
32 |     norm_layer = get_norm_layer(norm_type=norm)
33 |     if type(norm_layer) == functools.partial:
34 |         use_bias = norm_layer.func == nn.InstanceNorm2d
35 |     else:
36 |         use_bias = norm_layer == nn.InstanceNorm2d
37 |
38 |     n_blocks -= n_blocks_shared
39 |     n_blocks_enc = n_blocks // 2
40 |     n_blocks_dec = n_blocks - n_blocks_enc
41 |
42 |     dup_args = (ngf, norm_layer, use_dropout, gpu_ids, use_bias)
43 |     enc_args = (input_nc, n_blocks_enc) + dup_args
44 |     dec_args = (output_nc, n_blocks_dec) + dup_args
45 |
46 |     if n_blocks_shared > 0:
47 |         n_blocks_shdec = n_blocks_shared // 2
48 |         n_blocks_shenc = n_blocks_shared - n_blocks_shdec
49 |         shenc_args = (n_domains, n_blocks_shenc) + dup_args
50 |         shdec_args = (n_domains, n_blocks_shdec) + dup_args
51 |         plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args, ResnetGenShared, shenc_args, shdec_args)
52 |     else:
53 |         plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args)
54 |
55 |     if len(gpu_ids) > 0:
56 |         assert(torch.cuda.is_available())
57 |         plex_netG.cuda(gpu_ids[0])
58 |
59 |     plex_netG.apply(weights_init)
60 |     return plex_netG
61 |
62 |
63 | def define_D(input_nc, ndf, netD_n_layers, n_domains, tensor, norm='batch', gpu_ids=[]):
64 |     norm_layer = get_norm_layer(norm_type=norm)
65 |
66 |     model_args = (input_nc, ndf, netD_n_layers, tensor, norm_layer, gpu_ids)
67 |     plex_netD = D_Plexer(n_domains, NLayerDiscriminator, model_args)
68 |
69 |     if len(gpu_ids) > 0:
70 |         assert(torch.cuda.is_available())
71 |         plex_netD.cuda(gpu_ids[0])
72 |
73 |     plex_netD.apply(weights_init)
74 |     return plex_netD
75 |
76 |
77 | ##############################################################################
78 | # Classes
79 | ##############################################################################
80 |
81 |
82 | # Defines the GAN loss which uses the Relativistic LSGAN
83 | def GANLoss(inputs_real, inputs_fake, is_discr):
84 |     if is_discr:
85 |         y = -1
86 |     else:
87 |         y = 1
88 |         inputs_real = [i.detach() for i in inputs_real]
89 |     loss = lambda r,f : torch.mean((r-f+y)**2)
90 |     losses = [loss(r,f) for r,f in zip(inputs_real, inputs_fake)]
91 |     multipliers = list(range(1, len(inputs_real)+1)); multipliers[-1] += 1
92 |     losses = [m*l for m,l in zip(multipliers, losses)]
93 |     return sum(losses) / (sum(multipliers) * len(losses))
94 |
95 |
96 | # Defines the generator that consists of Resnet blocks between a few
97 | # downsampling/upsampling operations.
98 | # Code and idea originally from Justin Johnson's architecture.
99 | # https://github.com/jcjohnson/fast-neural-style/
100 | class ResnetGenEncoder(nn.Module):
101 |     def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNorm2d,
102 |                  use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
103 |         assert(n_blocks >= 0)
104 |         super(ResnetGenEncoder, self).__init__()
105 |         self.gpu_ids = gpu_ids
106 |
107 |         model = [nn.ReflectionPad2d(3),
108 |                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
109 |                            bias=use_bias),
110 |                  norm_layer(ngf),
111 |                  nn.PReLU()]
112 |
113 |         n_downsampling = 2
114 |         for i in range(n_downsampling):
115 |             mult = 2**i
116 |             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
117 |                                 stride=2, padding=1, bias=use_bias),
118 |                       norm_layer(ngf * mult * 2),
119 |                       nn.PReLU()]
120 |
121 |         mult = 2**n_downsampling
122 |         for _ in range(n_blocks):
123 |             model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
124 |                                   use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]
125 |
126 |         self.model = nn.Sequential(*model)
127 |
128 |     def forward(self, input):
129 |         if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
130 |             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
131 |         return self.model(input)
132 |
133 | class ResnetGenShared(nn.Module):
134 |     def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchNorm2d,
135 |                  use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
136 |         assert(n_blocks >= 0)
137 |         super(ResnetGenShared, self).__init__()
138 |         self.gpu_ids = gpu_ids
139 |
140 |         model = []
141 |         n_downsampling = 2
142 |         mult = 2**n_downsampling
143 |
144 |         for _ in range(n_blocks):
145 |             model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, n_domains=n_domains,
146 |                                   use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]
147 |
148 |         self.model = SequentialContext(n_domains, *model)
149 |
150 |     def forward(self, input, domain):
151 |         if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
152 |             return nn.parallel.data_parallel(self.model, (input, domain), self.gpu_ids)
153 |         return self.model(input, domain)
154 |
155 | class ResnetGenDecoder(nn.Module):
156 |     def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchNorm2d,
157 |                  use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):
158 |         assert(n_blocks >= 0)
159 |         super(ResnetGenDecoder, self).__init__()
160 |         self.gpu_ids = gpu_ids
161 |
162 |         model = []
163 |         n_downsampling = 2
164 |         mult = 2**n_downsampling
165 |
166 |         for _ in range(n_blocks):
167 |             model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,
168 |                                   use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]
169 |
170 |         for i in range(n_downsampling):
171 |             mult = 2**(n_downsampling - i)
172 |             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
173 |                                          kernel_size=4, stride=2,
174 |                                          padding=1, output_padding=0,
175 |                                          bias=use_bias),
176 |                       norm_layer(int(ngf * mult / 2)),
177 |                       nn.PReLU()]
178 |
179 |         model += [nn.ReflectionPad2d(3),
180 |                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
181 |                   nn.Tanh()]
182 |
183 |         self.model = nn.Sequential(*model)
184 |
185 |     def forward(self, input):
186 |         if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
187 |             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
188 |         return self.model(input)
189 |
190 |
191 | # Define a resnet block
192 | class ResnetBlock(nn.Module):
193 |     def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_type='reflect', n_domains=0):
194 |         super(ResnetBlock, self).__init__()
195 |
196 |         conv_block = []
197 |         p = 0
198 |         if padding_type == 'reflect':
199 |             conv_block += [nn.ReflectionPad2d(1)]
200 |         elif padding_type == 'replicate':
201 |             conv_block += [nn.ReplicationPad2d(1)]
202 |         elif padding_type == 'zero':
203 |             p = 1
204 |         else:
205 |             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
206 |
207 |         conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),
208 |                        norm_layer(dim),
209 |                        nn.PReLU()]
210 |         if use_dropout:
211 |             conv_block += [nn.Dropout(0.5)]
212 |
213 |         p = 0
214 |         if padding_type == 'reflect':
215 |             conv_block += [nn.ReflectionPad2d(1)]
216 |         elif padding_type == 'replicate':
217 |             conv_block += [nn.ReplicationPad2d(1)]
218 |         elif padding_type == 'zero':
219 |             p = 1
220 |         else:
221 |             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
222 |         conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),
223 |                        norm_layer(dim)]
224 |
225 |         self.conv_block = SequentialContext(n_domains, *conv_block)
226 |
227 |     def forward(self, input):
228 |         if isinstance(input, tuple):
229 |             return input[0] + self.conv_block(*input)
230 |         return input + self.conv_block(input)
231 |
232 |
233 | # Defines the PatchGAN discriminator with the specified arguments.
234 | class NLayerDiscriminator(nn.Module):
235 |     def __init__(self, input_nc, ndf=64, n_layers=3, tensor=torch.FloatTensor, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
236 |         super(NLayerDiscriminator, self).__init__()
237 |         self.gpu_ids = gpu_ids
238 |         self.grad_filter = tensor([0,0,0,-1,0,1,0,0,0]).view(1,1,3,3)
239 |         self.dsamp_filter = tensor([1]).view(1,1,1,1)
240 |         self.blur_filter = tensor(gkern_2d())
241 |
242 |         self.model_rgb = self.model(input_nc, ndf, n_layers, norm_layer)
243 |         self.model_gray = self.model(1, ndf, n_layers, norm_layer)
244 | self.model_grad = self.model(2, ndf, n_layers-1, norm_layer)
245 |
246 | def model(self, input_nc, ndf, n_layers, norm_layer):
247 | if type(norm_layer) == functools.partial:
248 | use_bias = norm_layer.func == nn.InstanceNorm2d
249 | else:
250 | use_bias = norm_layer == nn.InstanceNorm2d
251 |
252 | kw = 4
253 | padw = int(np.ceil((kw-1)/2))
254 | sequences = [[
255 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
256 | nn.PReLU()
257 | ]]
258 |
259 | nf_mult = 1
260 | nf_mult_prev = 1
261 | for n in range(1, n_layers):
262 | nf_mult_prev = nf_mult
263 | nf_mult = min(2**n, 8)
264 | sequences += [[
265 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult + 1,
266 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
267 | norm_layer(ndf * nf_mult + 1),
268 | nn.PReLU()
269 | ]]
270 |
271 | nf_mult_prev = nf_mult
272 | nf_mult = min(2**n_layers, 8)
273 | sequences += [[
274 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
275 | kernel_size=kw, stride=1, padding=padw, bias=use_bias),
276 | norm_layer(ndf * nf_mult),
277 | nn.PReLU(),
278 | \
279 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
280 | ]]
281 |
282 | return SequentialOutput(*sequences)
283 |
284 | def forward(self, input):
285 | blurred = torch.nn.functional.conv2d(input, self.blur_filter, groups=3, padding=2)
286 | gray = (.299*input[:,0,:,:] + .587*input[:,1,:,:] + .114*input[:,2,:,:]).unsqueeze_(1)
287 |
288 | gray_dsamp = nn.functional.conv2d(gray, self.dsamp_filter, stride=2)
289 | dx = nn.functional.conv2d(gray_dsamp, self.grad_filter)
290 | dy = nn.functional.conv2d(gray_dsamp, self.grad_filter.transpose(-2,-1))
291 | gradient = torch.cat([dx,dy], 1)
292 |
293 | if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
294 | outs1 = nn.parallel.data_parallel(self.model_rgb, blurred, self.gpu_ids)
295 | outs2 = nn.parallel.data_parallel(self.model_gray, gray, self.gpu_ids)
296 | outs3 = nn.parallel.data_parallel(self.model_grad, gradient, self.gpu_ids)
297 | else:
298 | outs1 = self.model_rgb(blurred)
299 | outs2 = self.model_gray(gray)
300 | outs3 = self.model_grad(gradient)
301 | return outs1, outs2, outs3
302 |
303 |
304 | class Plexer(nn.Module):
305 | def __init__(self):
306 | super(Plexer, self).__init__()
307 |
308 | def apply(self, func):
309 | for net in self.networks:
310 | net.apply(func)
311 |
312 | def cuda(self, device_id):
313 | for net in self.networks:
314 | net.cuda(device_id)
315 |
316 | def init_optimizers(self, opt, lr, betas):
317 | self.optimizers = [opt(net.parameters(), lr=lr, betas=betas) \
318 | for net in self.networks]
319 |
320 | def zero_grads(self, dom_a, dom_b):
321 | self.optimizers[dom_a].zero_grad()
322 | self.optimizers[dom_b].zero_grad()
323 |
324 | def step_grads(self, dom_a, dom_b):
325 | self.optimizers[dom_a].step()
326 | self.optimizers[dom_b].step()
327 |
328 | def update_lr(self, new_lr):
329 | for opt in self.optimizers:
330 | for param_group in opt.param_groups:
331 | param_group['lr'] = new_lr
332 |
333 | def save(self, save_path):
334 | for i, net in enumerate(self.networks):
335 | filename = save_path + ('%d.pth' % i)
336 | torch.save(net.cpu().state_dict(), filename)
337 |
338 | def load(self, save_path):
339 | for i, net in enumerate(self.networks):
340 | filename = save_path + ('%d.pth' % i)
341 | net.load_state_dict(torch.load(filename))
342 |
343 | class G_Plexer(Plexer):
344 | def __init__(self, n_domains, encoder, enc_args, decoder, dec_args,
345 | block=None, shenc_args=None, shdec_args=None):
346 | super(G_Plexer, self).__init__()
347 | self.encoders = [encoder(*enc_args) for _ in range(n_domains)]
348 | self.decoders = [decoder(*dec_args) for _ in range(n_domains)]
349 |
350 | self.sharing = block is not None
351 | if self.sharing:
352 | self.shared_encoder = block(*shenc_args)
353 | self.shared_decoder = block(*shdec_args)
354 | self.encoders.append( self.shared_encoder )
355 | self.decoders.append( self.shared_decoder )
356 | self.networks = self.encoders + self.decoders
357 |
358 | def init_optimizers(self, opt, lr, betas):
359 | self.optimizers = []
360 | for enc, dec in zip(self.encoders, self.decoders):
361 | params = itertools.chain(enc.parameters(), dec.parameters())
362 | self.optimizers.append( opt(params, lr=lr, betas=betas) )
363 |
364 | def forward(self, input, in_domain, out_domain):
365 | encoded = self.encode(input, in_domain)
366 | return self.decode(encoded, out_domain)
367 |
368 | def encode(self, input, domain):
369 | output = self.encoders[domain].forward(input)
370 | if self.sharing:
371 | return self.shared_encoder.forward(output, domain)
372 | return output
373 |
374 | def decode(self, input, domain):
375 | if self.sharing:
376 | input = self.shared_decoder.forward(input, domain)
377 | return self.decoders[domain].forward(input)
378 |
379 | def zero_grads(self, dom_a, dom_b):
380 | self.optimizers[dom_a].zero_grad()
381 | if self.sharing:
382 | self.optimizers[-1].zero_grad()
383 | self.optimizers[dom_b].zero_grad()
384 |
385 | def step_grads(self, dom_a, dom_b):
386 | self.optimizers[dom_a].step()
387 | if self.sharing:
388 | self.optimizers[-1].step()
389 | self.optimizers[dom_b].step()
390 |
391 | def __repr__(self):
392 | e, d = self.encoders[0], self.decoders[0]
393 | e_params = sum([p.numel() for p in e.parameters()])
394 | d_params = sum([p.numel() for p in d.parameters()])
395 | return repr(e) +'\n'+ repr(d) +'\n'+ \
396 | 'Created %d Encoder-Decoder pairs' % len(self.encoders) +'\n'+ \
397 | 'Number of parameters per Encoder: %d' % e_params +'\n'+ \
398 | 'Number of parameters per Decoder: %d' % d_params
399 |
400 | class D_Plexer(Plexer):
401 | def __init__(self, n_domains, model, model_args):
402 | super(D_Plexer, self).__init__()
403 | self.networks = [model(*model_args) for _ in range(n_domains)]
404 |
405 | def forward(self, input, domain):
406 | discriminator = self.networks[domain]
407 | return discriminator.forward(input)
408 |
409 | def __repr__(self):
410 | t = self.networks[0]
411 | t_params = sum([p.numel() for p in t.parameters()])
412 | return repr(t) +'\n'+ \
413 | 'Created %d Discriminators' % len(self.networks) +'\n'+ \
414 | 'Number of parameters per Discriminator: %d' % t_params
415 |
416 |
417 | class SequentialContext(nn.Sequential):
418 | def __init__(self, n_classes, *args):
419 | super(SequentialContext, self).__init__(*args)
420 | self.n_classes = n_classes
421 | self.context_var = None
422 |
423 | def prepare_context(self, input, domain):
424 | if self.context_var is None or self.context_var.size()[-2:] != input.size()[-2:]:
425 | tensor = torch.cuda.FloatTensor if isinstance(input.data, torch.cuda.FloatTensor) \
426 | else torch.FloatTensor
427 | self.context_var = tensor(*((1, self.n_classes) + input.size()[-2:]))
428 |
429 | self.context_var.data.fill_(-1.0)
430 | self.context_var.data[:,domain,:,:] = 1.0
431 | return self.context_var
432 |
433 | def forward(self, *input):
434 | if self.n_classes < 2 or len(input) < 2:
435 | return super(SequentialContext, self).forward(input[0])
436 | x, domain = input
437 |
438 | for module in self._modules.values():
439 | if 'Conv' in module.__class__.__name__:
440 | context_var = self.prepare_context(x, domain)
441 | x = torch.cat([x, context_var], dim=1)
442 | elif 'Block' in module.__class__.__name__:
443 | x = (x,) + input[1:]
444 | x = module(x)
445 | return x
446 |
447 | class SequentialOutput(nn.Sequential):
448 | def __init__(self, *args):
449 | args = [nn.Sequential(*arg) for arg in args]
450 | super(SequentialOutput, self).__init__(*args)
451 |
452 | def forward(self, input):
453 | predictions = []
454 | layers = self._modules.values()
455 | for i, module in enumerate(layers):
456 | output = module(input)
457 | if i == 0:
458 | input = output; continue
459 | predictions.append( output[:,-1,:,:] )
460 | if i != len(layers) - 1:
461 | input = output[:,:-1,:,:]
462 | return predictions
463 |
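The multi-scale behaviour of `SequentialOutput` is easier to follow by tracking spatial sizes. The sketch below is not part of the repository. It assumes the convolution settings used in `NLayerDiscriminator.model` (`kw=4`, `padw=2`, stride 2 everywhere except the last sequence), and the helper names `conv_out` and `prediction_sizes` are hypothetical.

```python
def conv_out(size, kernel=4, stride=1, padding=2):
    """Spatial output size of a 2D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def prediction_sizes(size, n_layers):
    """Side length of each prediction map returned by SequentialOutput.

    Mirrors NLayerDiscriminator.model: the first sequence only feeds the
    next one, and every later sequence contributes one prediction map.
    """
    size = conv_out(size, stride=2)      # first sequence, no prediction
    sizes = []
    for _ in range(1, n_layers):         # intermediate stride-2 sequences
        size = conv_out(size, stride=2)
        sizes.append(size)
    size = conv_out(size, stride=1)      # stride-1 conv of the last sequence
    size = conv_out(size, stride=1)      # final single-channel conv
    sizes.append(size)
    return sizes


if __name__ == "__main__":
    print(prediction_sizes(256, n_layers=4))
```

Under these assumptions each discriminator branch scores the image at several resolutions: the extra `+ 1` output channel of every intermediate sequence is peeled off as a prediction, and the remaining channels are passed on.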
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self.parser = argparse.ArgumentParser()
9 | self.initialized = False
10 |
11 | def initialize(self):
12 | self.parser.add_argument('--name', required=True, type=str, help='name of the experiment. It decides where to store samples and models')
13 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
14 |
15 | self.parser.add_argument('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
16 | self.parser.add_argument('--n_domains', required=True, type=int, help='Number of domains to transfer among')
17 |
18 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
19 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [none|resize|resize_and_crop|crop]')
20 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
21 |
22 | self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
23 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
24 |
25 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
26 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
27 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
28 |
29 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
30 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
31 | self.parser.add_argument('--netG_n_blocks', type=int, default=9, help='number of residual blocks to use for netG')
32 | self.parser.add_argument('--netG_n_shared', type=int, default=0, help='number of blocks to use for netG shared center module')
33 | self.parser.add_argument('--netD_n_layers', type=int, default=4, help='number of layers to use for netD')
34 |
35 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
36 | self.parser.add_argument('--use_dropout', action='store_true', help='insert dropout for the generator')
37 |
38 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='comma-separated gpu ids, e.g. "0", "0,1,2" or "0,2"; use -1 for CPU')
39 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
40 |
41 | self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (set > 0 to use visdom)')
42 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
43 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
44 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
45 |
46 | self.initialized = True
47 |
48 | def parse(self):
49 | if not self.initialized:
50 | self.initialize()
51 | self.opt = self.parser.parse_args()
52 | self.opt.isTrain = self.isTrain # train or test
53 |
54 | str_ids = self.opt.gpu_ids.split(',')
55 | self.opt.gpu_ids = []
56 | for str_id in str_ids:
57 | id = int(str_id)
58 | if id >= 0:
59 | self.opt.gpu_ids.append(id)
60 |
61 | # set gpu ids
62 | if len(self.opt.gpu_ids) > 0:
63 | torch.cuda.set_device(self.opt.gpu_ids[0])
64 |
65 | args = vars(self.opt)
66 |
67 | print('------------ Options -------------')
68 | for k, v in sorted(args.items()):
69 | print('%s: %s' % (str(k), str(v)))
70 | print('-------------- End ----------------')
71 |
72 | # save to the disk
73 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
74 | util.mkdirs(expr_dir)
75 | file_name = os.path.join(expr_dir, 'opt.txt')
76 | with open(file_name, 'wt') as opt_file:
77 | opt_file.write('------------ Options -------------\n')
78 | for k, v in sorted(args.items()):
79 | opt_file.write('%s: %s\n' % (str(k), str(v)))
80 | opt_file.write('-------------- End ----------------\n')
81 | return self.opt
82 |
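The `--gpu_ids` handling in `parse` can be isolated into a small function for testing. This is a minimal sketch under that assumption; `parse_gpu_ids` is a hypothetical name and does not exist in the repository.

```python
def parse_gpu_ids(spec):
    """Turn a string such as '0,1,2' into a list of non-negative ints."""
    ids = []
    for token in spec.split(','):
        gpu_id = int(token)
        if gpu_id >= 0:          # negative ids (e.g. -1) mean "no GPU"
            ids.append(gpu_id)
    return ids


if __name__ == "__main__":
    print(parse_gpu_ids('0,2'))
    print(parse_gpu_ids('-1'))
```

An empty result corresponds to CPU mode, which is why `parse` only calls `torch.cuda.set_device` when the list is non-empty.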
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.isTrain = False
8 |
9 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
10 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
11 |
12 | self.parser.add_argument('--which_epoch', required=True, type=int, help='which epoch to load for inference?')
13 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc (determines name of folder to load from)')
14 |
15 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run (if serial_test not enabled)')
16 | self.parser.add_argument('--serial_test', action='store_true', help='read each image once from folders in sequential order')
17 |
18 | self.parser.add_argument('--autoencode', action='store_true', help='translate images back into their own domain')
19 | self.parser.add_argument('--reconstruct', action='store_true', help='do reconstructions of images during testing')
20 |
21 | self.parser.add_argument('--show_matrix', action='store_true', help='visualize images in a matrix format as well')
22 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.isTrain = True
8 |
9 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
10 | self.parser.add_argument('--which_epoch', type=int, default=0, help='which epoch to load if continuing training')
11 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc (determines name of folder to load from)')
12 |
13 | self.parser.add_argument('--niter', required=True, type=int, help='# of epochs at starting learning rate (try 50*n_domains)')
14 | self.parser.add_argument('--niter_decay', required=True, type=int, help='# of epochs to linearly decay learning rate to zero (try 50*n_domains)')
15 |
16 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for ADAM')
17 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of ADAM')
18 |
19 | self.parser.add_argument('--lambda_cycle', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
20 | self.parser.add_argument('--lambda_identity', type=float, default=0.0, help='weight for identity "autoencode" mapping (A -> A)')
21 | self.parser.add_argument('--lambda_latent', type=float, default=0.0, help='weight for latent-space loss (A -> z -> B -> z)')
22 | self.parser.add_argument('--lambda_forward', type=float, default=0.0, help='weight for forward loss (A -> B; try 0.2)')
23 |
24 | self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
26 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
27 |
28 | self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
29 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
30 |
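The help strings for `--niter` and `--niter_decay` describe a constant learning rate followed by a linear decay to zero. The actual update lives in `ComboGANModel.update_hyperparams`, which is not shown here, so the following is only a sketch of the schedule as described. The function name `learning_rate` is hypothetical, and the defaults are taken from `--lr` and `scripts/train_combogan.sh`.

```python
def learning_rate(epoch, base_lr=0.0002, niter=75, niter_decay=75):
    """Learning rate after `epoch` completed epochs (assumed schedule)."""
    if epoch <= niter:
        return base_lr                       # constant phase
    decay_frac = (epoch - niter) / float(niter_decay)
    return base_lr * max(0.0, 1.0 - decay_frac)   # linear decay, floored at 0


if __name__ == "__main__":
    for epoch in (1, 75, 90, 150):
        print(epoch, learning_rate(epoch))
```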
--------------------------------------------------------------------------------
/scripts/continue_combogan.sh:
--------------------------------------------------------------------------------
1 | python train.py \
2 | --continue_train --which_epoch 30 \
3 | --dataroot ./datasets/robotcar \
4 | --name robotcar_night2day \
5 | --n_domains 2 \
6 | --niter 25 --niter_decay 25 \
7 | --loadSize 512 --fineSize 384
8 |
--------------------------------------------------------------------------------
/scripts/test_combogan.sh:
--------------------------------------------------------------------------------
1 | python test.py \
2 | --phase test --which_epoch 50 \
3 | --serial_test \
4 | --dataroot ./datasets/robotcar \
5 | --name robotcar_night2day \
6 | --n_domains 2 \
7 | --loadSize 512
8 |
--------------------------------------------------------------------------------
/scripts/train_combogan.sh:
--------------------------------------------------------------------------------
1 | python train.py \
2 | --dataroot ./datasets/robotcar \
3 | --name robotcar_night2day \
4 | --n_domains 2 \
5 | --niter 75 --niter_decay 75 \
6 | --loadSize 512 --fineSize 384
7 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | from options.test_options import TestOptions
4 | from data.data_loader import DataLoader
5 | from models.combogan_model import ComboGANModel
6 | from util.visualizer import Visualizer
7 | from util import html
8 |
9 |
10 | opt = TestOptions().parse()
11 | opt.nThreads = 1 # test code only supports nThreads = 1
12 | opt.batchSize = 1 # test code only supports batchSize = 1
13 |
14 | dataset = DataLoader(opt)
15 | model = ComboGANModel(opt)
16 | visualizer = Visualizer(opt)
17 | # create website
18 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%d' % (opt.phase, opt.which_epoch))
19 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %d' % (opt.name, opt.phase, opt.which_epoch))
20 | # store images for matrix visualization
21 | vis_buffer = []
22 |
23 | # test
24 | for i, data in enumerate(dataset):
25 | if not opt.serial_test and i >= opt.how_many:
26 | break
27 | model.set_input(data)
28 | model.test()
29 | visuals = model.get_current_visuals(testing=True)
30 | img_path = model.get_image_paths()
31 | print('process image... %s' % img_path)
32 | visualizer.save_images(webpage, visuals, img_path)
33 |
34 | if opt.show_matrix:
35 | vis_buffer.append(visuals)
36 | if (i+1) % opt.n_domains == 0:
37 | save_path = os.path.join(web_dir, 'mat_%d.png' % (i//opt.n_domains))
38 | visualizer.save_image_matrix(vis_buffer, save_path)
39 | vis_buffer.clear()
40 |
41 | webpage.save()
42 |
43 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import DataLoader
4 | from models.combogan_model import ComboGANModel
5 | from util.visualizer import Visualizer
6 |
7 |
8 | opt = TrainOptions().parse()
9 | dataset = DataLoader(opt)
10 | print('# training images = %d' % len(dataset))
11 | model = ComboGANModel(opt)
12 | visualizer = Visualizer(opt)
13 | total_steps = 0
14 |
15 | # Update initially if continuing
16 | if opt.which_epoch > 0:
17 | model.update_hyperparams(opt.which_epoch)
18 |
19 | for epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1):
20 | epoch_start_time = time.time()
21 | epoch_iter = 0
22 | for i, data in enumerate(dataset):
23 | iter_start_time = time.time()
24 | total_steps += opt.batchSize
25 | epoch_iter += opt.batchSize
26 | model.set_input(data)
27 | model.optimize_parameters()
28 |
29 | if total_steps % opt.display_freq == 0:
30 | visualizer.display_current_results(model.get_current_visuals(), epoch)
31 |
32 | if total_steps % opt.print_freq == 0:
33 | errors = model.get_current_errors()
34 | t = (time.time() - iter_start_time) / opt.batchSize
35 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
36 | if opt.display_id > 0:
37 | visualizer.plot_current_errors(epoch, float(epoch_iter)/len(dataset), opt, errors)
38 |
39 | if epoch % opt.save_epoch_freq == 0:
40 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
41 | model.save(epoch)
42 |
43 | print('End of epoch %d / %d \t Time Taken: %d sec' %
44 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
45 |
46 | model.update_hyperparams(epoch)
47 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AAnoosheh/ToDayGAN/12de3af4f209cd227bdc54160fc3164f97f6732d/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 | if reflesh > 0:
19 | with self.doc.head:
20 | meta(http_equiv="refresh", content=str(reflesh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | class ImagePool():
6 | def __init__(self, pool_size):
7 | self.pool_size = pool_size
8 | if self.pool_size > 0:
9 | self.num_imgs = 0
10 | self.images = []
11 |
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 | return_images = []
16 | for image in images.data:
17 | image = torch.unsqueeze(image, 0)
18 | if self.num_imgs < self.pool_size:
19 | self.num_imgs = self.num_imgs + 1
20 | self.images.append(image)
21 | return_images.append(image)
22 | else:
23 | p = random.uniform(0, 1)
24 | if p > 0.5:
25 | random_id = random.randint(0, self.pool_size-1)
26 | tmp = self.images[random_id].clone()
27 | self.images[random_id] = image
28 | return_images.append(tmp)
29 | else:
30 | return_images.append(image)
31 | return_images = Variable(torch.cat(return_images, 0))
32 | return return_images
33 |
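The buffering policy of `ImagePool.query` does not depend on tensors, so it can be illustrated without PyTorch. The sketch below is a framework-free restatement under that assumption. `HistoryPool` and its `rng` parameter are hypothetical and only there to make the behaviour reproducible.

```python
import random


class HistoryPool:
    """Framework-free sketch of the replay buffer idea in ImagePool."""

    def __init__(self, pool_size, rng=None):
        self.pool_size = pool_size
        self.items = []
        self.rng = rng or random.Random()

    def query(self, batch):
        if self.pool_size == 0:              # pooling disabled
            return list(batch)
        out = []
        for item in batch:
            if len(self.items) < self.pool_size:
                self.items.append(item)      # still filling the buffer
                out.append(item)
            elif self.rng.random() > 0.5:
                idx = self.rng.randrange(self.pool_size)
                out.append(self.items[idx])  # hand back an older item
                self.items[idx] = item       # and store the new one
            else:
                out.append(item)             # pass the new item through
        return out
```

Once the buffer is full, roughly half of the queried items are swapped for older ones, so the discriminator also sees generator outputs from earlier iterations.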
--------------------------------------------------------------------------------
/util/png.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import zlib
3 |
4 | def encode(buf, width, height):
5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
6 | assert (width * height * 3 == len(buf))
7 | bpp = 3
8 |
9 | def raw_data():
10 | # reverse the vertical line order and add null bytes at the start
11 | row_bytes = width * bpp
12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
13 | yield b'\x00'
14 | yield buf[row_start:row_start + row_bytes]
15 |
16 | def chunk(tag, data):
17 | return [
18 | struct.pack("!I", len(data)),
19 | tag,
20 | data,
21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
22 | ]
23 |
24 | SIGNATURE = b'\x89PNG\r\n\x1a\n'
25 | COLOR_TYPE_RGB = 2
26 | COLOR_TYPE_RGBA = 6
27 | bit_depth = 8
28 | return b''.join(
29 | [ SIGNATURE ] +
30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
32 | chunk(b'IEND', b'')
33 | )
34 |
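Each PNG chunk written by `encode` has the layout length, tag, data, CRC, where the CRC covers the tag and the data. The sketch below rebuilds that layout and adds a reader that checks the CRC. It is an illustration only, and `make_chunk` and `read_chunk` are hypothetical helpers, not part of `util/png.py`.

```python
import struct
import zlib


def make_chunk(tag, data):
    """Serialize one PNG chunk: length, tag, data, CRC over tag + data."""
    crc = zlib.crc32(tag + data) & 0xFFFFFFFF
    return struct.pack("!I", len(data)) + tag + data + struct.pack("!I", crc)


def read_chunk(blob, offset=0):
    """Parse one chunk at `offset`; return (tag, data, next_offset)."""
    (length,) = struct.unpack_from("!I", blob, offset)
    tag = blob[offset + 4:offset + 8]
    data = blob[offset + 8:offset + 8 + length]
    (crc,) = struct.unpack_from("!I", blob, offset + 8 + length)
    if crc != zlib.crc32(tag + data) & 0xFFFFFFFF:
        raise ValueError("CRC mismatch in chunk %r" % tag)
    return tag, data, offset + 12 + length


if __name__ == "__main__":
    # IHDR payload for a 2x3 RGB image with 8 bits per channel.
    ihdr = struct.pack("!2I5B", 2, 3, 8, 2, 0, 0, 0)
    print(read_chunk(make_chunk(b'IHDR', ihdr)))
```

Computing `zlib.crc32(tag + data)` in one call gives the same value as the chained form `zlib.crc32(data, zlib.crc32(tag))` used in `encode`.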
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from scipy.ndimage import gaussian_filter
5 | from PIL import Image
6 | import inspect, re
7 | import os
8 | import collections
9 |
10 | # Converts a Tensor into a Numpy array
11 | # |imtype|: the desired type of the converted numpy array
12 | def tensor2im(image_tensor, imtype=np.uint8):
13 | image_numpy = image_tensor[0].cpu().float().numpy()
14 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
15 | if image_numpy.shape[2] < 3:
16 | image_numpy = np.dstack([image_numpy]*3)
17 | return image_numpy.astype(imtype)
18 |
19 | def gkern_2d(size=5, sigma=3):
20 | # Create 2D gaussian kernel
21 | dirac = np.zeros((size, size))
22 | dirac[size//2, size//2] = 1
23 | mask = gaussian_filter(dirac, sigma)
24 | # Adjust dimensions for torch conv2d
25 | return np.stack([np.expand_dims(mask, axis=0)] * 3)
26 |
27 |
28 | def diagnose_network(net, name='network'):
29 | mean = 0.0
30 | count = 0
31 | for param in net.parameters():
32 | if param.grad is not None:
33 | mean += torch.mean(torch.abs(param.grad.data))
34 | count += 1
35 | if count > 0:
36 | mean = mean / count
37 | print(name)
38 | print(mean)
39 |
40 |
41 | def save_image(image_numpy, image_path):
42 | image_pil = Image.fromarray(image_numpy)
43 | image_pil.save(image_path)
44 |
45 | def info(object, spacing=10, collapse=1):
46 | """Print methods and doc strings.
47 | Takes module, class, list, dictionary, or string."""
48 | methodList = [e for e in dir(object) if callable(getattr(object, e))]
49 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
50 | print( "\n".join(["%s %s" %
51 | (method.ljust(spacing),
52 | processFunc(str(getattr(object, method).__doc__)))
53 | for method in methodList]) )
54 |
55 | def varname(p):
56 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
57 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
58 | if m:
59 | return m.group(1)
60 |
61 | def print_numpy(x, val=True, shp=False):
62 | x = x.astype(np.float64)
63 | if shp:
64 | print('shape,', x.shape)
65 | if val:
66 | x = x.flatten()
67 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
68 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
69 |
70 |
71 | def mkdirs(paths):
72 | if isinstance(paths, list) and not isinstance(paths, str):
73 | for path in paths:
74 | mkdir(path)
75 | else:
76 | mkdir(paths)
77 |
78 |
79 | def mkdir(path):
80 | if not os.path.exists(path):
81 | os.makedirs(path)
82 |
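On Python 3.2 and later, the check-then-create pattern in `mkdir` can be replaced by `os.makedirs(..., exist_ok=True)`, which also avoids a race when two processes create the same directory. The following is a sketch of that alternative, not the repository's code.

```python
import os
import tempfile


def mkdirs(paths):
    """Create one directory or a list of directories, ignoring existing ones."""
    if isinstance(paths, str):
        paths = [paths]
    for path in paths:
        os.makedirs(path, exist_ok=True)   # no error if it already exists


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        target = os.path.join(root, 'checkpoints', 'demo')
        mkdirs(target)
        mkdirs([target])                   # second call is a no-op
        print(os.path.isdir(target))
```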
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 |
8 | class Visualizer():
9 |     def __init__(self, opt):
10 |         # self.opt = opt
11 |         self.display_id = opt.display_id
12 |         self.use_html = opt.isTrain and not opt.no_html
13 |         self.win_size = opt.display_winsize
14 |         self.name = opt.name
15 |         if self.display_id > 0:
16 |             import visdom
17 |             self.vis = visdom.Visdom(port = opt.display_port)
18 |             self.display_single_pane_ncols = opt.display_single_pane_ncols
19 |
20 |         if self.use_html:
21 |             self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
22 |             self.img_dir = os.path.join(self.web_dir, 'images')
23 |             print('create web directory %s...' % self.web_dir)
24 |             util.mkdirs([self.web_dir, self.img_dir])
25 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
26 |         with open(self.log_name, "a") as log_file:
27 |             now = time.strftime("%c")
28 |             log_file.write('================ Training Loss (%s) ================\n' % now)
29 |
30 |     # |visuals|: dictionary of images to display or save
31 |     def display_current_results(self, visuals, epoch):
32 |         if self.display_id > 0: # show images in the browser
33 |             if self.display_single_pane_ncols > 0:
34 |                 h, w = next(iter(visuals.values())).shape[:2]
35 |                 table_css = """<style>
36 |     table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
37 |     table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
38 | </style>""" % (w, h)
39 |                 ncols = self.display_single_pane_ncols
40 |                 title = self.name
41 |                 label_html = ''
42 |                 label_html_row = ''
43 |                 nrows = int(np.ceil(len(visuals.items()) / ncols))
44 |                 images = []
45 |                 idx = 0
46 |                 for label, image_numpy in visuals.items():
47 |                     label_html_row += '<td>%s</td>' % label
48 |                     images.append(image_numpy.transpose([2, 0, 1]))
49 |                     idx += 1
50 |                     if idx % ncols == 0:
51 |                         label_html += '<tr>%s</tr>' % label_html_row
52 |                         label_html_row = ''
53 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
54 |                 while idx % ncols != 0:
55 |                     images.append(white_image)
56 |                     label_html_row += '<td></td>'
57 |                     idx += 1
58 |                 if label_html_row != '':
59 |                     label_html += '<tr>%s</tr>' % label_html_row
60 |                 # pane col = image row
61 |                 self.vis.images(images, nrow=ncols, win=self.display_id + 1,
62 |                                 padding=2, opts=dict(title=title + ' images'))
63 |                 label_html = '<table>%s</table>' % label_html
64 |                 self.vis.text(table_css + label_html, win = self.display_id + 2,
65 |                               opts=dict(title=title + ' labels'))
66 |             else:
67 |                 idx = 1
68 |                 for label, image_numpy in visuals.items():
69 |                     #image_numpy = np.flipud(image_numpy)
70 |                     self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
71 |                                    win=self.display_id + idx)
72 |                     idx += 1
73 |
74 |         if self.use_html: # save images to a html file
75 |             for label, image_numpy in visuals.items():
76 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
77 |                 util.save_image(image_numpy, img_path)
78 |             # update website
79 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
80 |             for n in range(epoch, 0, -1):
81 |                 webpage.add_header('epoch [%d]' % n)
82 |                 ims = []
83 |                 txts = []
84 |                 links = []
85 |
86 |                 for label, image_numpy in visuals.items():
87 |                     img_path = 'epoch%.3d_%s.png' % (n, label)
88 |                     ims.append(img_path)
89 |                     txts.append(label)
90 |                     links.append(img_path)
91 |                 webpage.add_images(ims, txts, links, width=self.win_size)
92 |             webpage.save()
93 |
94 |     # errors: dictionary of error labels and values
95 |     def plot_current_errors(self, epoch, counter_ratio, opt, errors):
96 |         if not hasattr(self, 'plot_data'):
97 |             self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
98 |         self.plot_data['X'].append(epoch + counter_ratio)
99 |         self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
100 |         self.vis.line(
101 |             X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
102 |             Y=np.array(self.plot_data['Y']),
103 |             opts={
104 |                 'title': self.name + ' loss over time',
105 |                 'legend': self.plot_data['legend'],
106 |                 'xlabel': 'epoch',
107 |                 'ylabel': 'loss'},
108 |             win=self.display_id)
109 |
110 |     # errors: same format as |errors| of plot_current_errors
111 |     def print_current_errors(self, epoch, i, errors, t):
112 |         message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
113 |         for k, v in errors.items():
114 |             v = ['%.3f' % iv for iv in v]
115 |             message += k + ': ' + ', '.join(v) + ' | '
116 |
117 |         print(message)
118 |         with open(self.log_name, "a") as log_file:
119 |             log_file.write('%s\n' % message)
120 |
121 |     # save images to disk and add them to the webpage
122 |     def save_images(self, webpage, visuals, image_path):
123 |         image_dir = webpage.get_image_dir()
124 |         short_path = ntpath.basename(image_path[0])
125 |         name = os.path.splitext(short_path)[0]
126 |
127 |         webpage.add_header(name)
128 |         ims = []
129 |         txts = []
130 |         links = []
131 |
132 |         for label, image_numpy in visuals.items():
133 |             image_name = '%s_%s.png' % (name, label)
134 |             save_path = os.path.join(image_dir, image_name)
135 |             util.save_image(image_numpy, save_path)
136 |
137 |             ims.append(image_name)
138 |             txts.append(label)
139 |             links.append(image_name)
140 |         webpage.add_images(ims, txts, links, width=self.win_size)
141 |
142 |     def save_image_matrix(self, visuals_list, save_path):
143 |         images_list = []
144 |         get_domain = lambda x: x.split('_')[-1]
145 |
146 |         for visuals in visuals_list:
147 |             pairs = list(visuals.items())
148 |             real_label, real_img = pairs[0]
149 |             real_dom = get_domain(real_label)
150 |
151 |             for label, img in pairs:
152 |                 if 'fake' not in label:
153 |                     continue
154 |                 if get_domain(label) == real_dom:
155 |                     images_list.append(real_img)
156 |                 else:
157 |                     images_list.append(img)
158 |
159 |         immat = self.stack_images(images_list)
160 |         util.save_image(immat, save_path)
161 |
162 |     # reshape a list of images into a square matrix of them
163 |     def stack_images(self, list_np_images):
164 |         n = int(np.ceil(np.sqrt(len(list_np_images))))
165 |
166 |         # add padding between images
167 |         for i, im in enumerate(list_np_images):
168 |             val = 255 if i%n == i//n else 0
169 |             r_pad = np.pad(im[:,:,0], (3,3), mode='constant', constant_values=0)
170 |             g_pad = np.pad(im[:,:,1], (3,3), mode='constant', constant_values=val)
171 |             b_pad = np.pad(im[:,:,2], (3,3), mode='constant', constant_values=0)
172 |             list_np_images[i] = np.stack([r_pad,g_pad,b_pad], axis=2)
173 |
174 |         data = np.array(list_np_images)
175 |         data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
176 |         data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
177 |         return data
178 |
--------------------------------------------------------------------------------
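The reshape/transpose step at the end of stack_images can be hard to follow, so here is a small standalone sketch of the same tiling idea with the border padding left out. The name tile_square and the explicit perfect-square check are assumptions added for illustration; they are not part of the repository.

```python
import numpy as np


def tile_square(images):
    """Tile n*n equally sized HxWxC images into one (n*H, n*W, C) array."""
    data = np.array(images)                       # (n*n, H, W, C)
    n = int(np.ceil(np.sqrt(data.shape[0])))
    if n * n != data.shape[0]:
        # Hypothetical guard: the reshape below needs exactly n*n images.
        raise ValueError('expected a perfect-square number of images')
    data = data.reshape((n, n) + data.shape[1:])  # (row, col, H, W, C)
    data = data.transpose(0, 2, 1, 3, 4)          # (row, H, col, W, C)
    # Merging (row, H) and (col, W) lays the images out row-major in a grid.
    return data.reshape(n * data.shape[1], n * data.shape[3], data.shape[4])


# Four 2x3 RGB tiles, each filled with its own index.
tiles = [np.full((2, 3, 3), i, dtype=np.uint8) for i in range(4)]
grid = tile_square(tiles)
print(grid.shape)
```

Image i lands at grid row i // n and grid column i % n, which is why stack_images marks the tiles with i % n == i // n (the diagonal) with a green border.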