├── examples
│   ├── 00.jpg
│   ├── 01.jpg
│   ├── 02.jpg
│   ├── 03.jpg
│   ├── 04.jpg
│   ├── 05.jpg
│   ├── 06.jpg
│   ├── 07.jpg
│   ├── 08.jpg
│   ├── 09.jpg
│   ├── 10.jpg
│   └── 11.jpg
├── images
│   ├── sample.png
│   └── architecture.png
├── util
│   ├── __init__.py
│   ├── image_pool.py
│   ├── html.py
│   ├── get_data.py
│   ├── util.py
│   └── visualizer.py
├── options
│   ├── __init__.py
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── requirements.txt
├── .gitignore
├── test_seq_style3.py
├── function.py
├── data
│   ├── image_folder.py
│   ├── single_dataset.py
│   ├── __init__.py
│   ├── base_dataset.py
│   └── unaligned_mask_stylecls_dataset.py
├── readme.md
├── models
│   ├── __init__.py
│   ├── test_model.py
│   ├── pretrained_networks.py
│   ├── base_model.py
│   ├── cycle_gan_cls_model.py
│   └── networks.py
├── QMUPD.ipynb
├── test.py
└── train.py
/examples/00.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/00.jpg
--------------------------------------------------------------------------------
/examples/01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/01.jpg
--------------------------------------------------------------------------------
/examples/02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/02.jpg
--------------------------------------------------------------------------------
/examples/03.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/03.jpg
--------------------------------------------------------------------------------
/examples/04.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/04.jpg
--------------------------------------------------------------------------------
/examples/05.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/05.jpg
--------------------------------------------------------------------------------
/examples/06.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/06.jpg
--------------------------------------------------------------------------------
/examples/07.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/07.jpg
--------------------------------------------------------------------------------
/examples/08.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/08.jpg
--------------------------------------------------------------------------------
/examples/09.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/09.jpg
--------------------------------------------------------------------------------
/examples/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/10.jpg
--------------------------------------------------------------------------------
/examples/11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/11.jpg
--------------------------------------------------------------------------------
/images/sample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/images/sample.png
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cedro3/QMUPD/master/images/architecture.png
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.2.0
2 | torchvision==0.4.0
3 | dominate==2.4.0
4 | visdom==0.1.8.9
5 | scipy==1.1.0
6 | numpy==1.16.4
7 | #Pillow==6.2.1
8 | opencv-python==4.1.0.25
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | debug*
3 | datasets/
4 | checkpoints/
5 | results/
6 | build/
7 | dist/
8 | *.png
9 | torch.egg-info/
10 | */**/__pycache__
11 | torch/version.py
12 | torch/csrc/generic/TensorMethods.cpp
13 | torch/lib/*.so*
14 | torch/lib/*.dylib*
15 | torch/lib/*.h
16 | torch/lib/build
17 | torch/lib/tmp_install
18 | torch/lib/include
19 | torch/lib/torch_shm_manager
20 | torch/csrc/cudnn/cuDNN.cpp
21 | torch/csrc/nn/THNN.cwrap
22 | torch/csrc/nn/THNN.cpp
23 | torch/csrc/nn/THCUNN.cwrap
24 | torch/csrc/nn/THCUNN.cpp
25 | torch/csrc/nn/THNN_generic.cwrap
26 | torch/csrc/nn/THNN_generic.cpp
27 | torch/csrc/nn/THNN_generic.h
28 | docs/src/**/*
29 | test/data/legacy_modules.t7
30 | test/data/gpu_tensors.pt
31 | test/htmlcov
32 | test/.coverage
33 | */*.pyc
34 | */**/*.pyc
35 | */**/**/*.pyc
36 | */**/**/**/*.pyc
37 | */**/**/**/**/*.pyc
38 | */*.so*
39 | */**/*.so*
40 | */**/*.dylib*
41 | test/data/legacy_serialized.pt
42 | *~
43 | .idea
44 | txt_output/*
45 | vo/*
46 | *.xlsx
47 |
--------------------------------------------------------------------------------
/test_seq_style3.py:
--------------------------------------------------------------------------------
1 | import os, glob
2 |
3 | #================== settings ==================
4 | exp = 'QMUPD_model';epoch='200'
5 | dataroot = 'examples'
6 | gpu_id = '-1'
7 |
8 | netga = 'resnet_style2_9blocks'
9 | model0_res = 0
10 | model1_res = 0
11 | imgsize = 512
12 | extraflag = ' --netga %s --model0_res %d --model1_res %d' % (netga, model0_res, model1_res)
13 |
14 | #==================== command ==================
15 | for vec in [[1,0,0],[0,1,0],[0,0,1]]:
16 | svec = '%d,%d,%d' % (vec[0],vec[1],vec[2])
17 | img1 = 'imagesstyle%d-%d-%d'%(vec[0],vec[1],vec[2])
18 | print('results/%s/test_%s/index%s.html'%(exp,epoch,img1[6:]))
19 | command = 'python test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout --model_suffix _A %s --num_test 1000 --epoch %s --style_control 1 --imagefolder %s --sinput svec --svec %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,extraflag,epoch,img1,svec,imgsize,imgsize,gpu_id)
20 | os.system(command)
21 |
22 |
--------------------------------------------------------------------------------
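The script above loops over three one-hot style vectors and shells out to `test.py` once per style. For reference, a minimal sketch of a single-style invocation assembled from the same settings (the experiment name, epoch, folders, and flags are copied from the script; only the chosen vector is fixed):

```python
# Sketch: run test.py for a single style vector instead of all three.
# All flag values below are taken from test_seq_style3.py; nothing here is new API.
import os

dataroot, exp, epoch, gpu_id = 'examples', 'QMUPD_model', '200', '-1'
svec = '0,0,1'                             # one-hot style vector (third style)
imagefolder = 'imagesstyle0-0-1'           # output subfolder, mirroring the naming above
command = ('python test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout '
           '--model_suffix _A --netga resnet_style2_9blocks --model0_res 0 --model1_res 0 '
           '--num_test 1000 --epoch %s --style_control 1 --imagefolder %s '
           '--sinput svec --svec %s --crop_size 512 --load_size 512 --gpu_ids %s'
           % (dataroot, exp, epoch, imagefolder, svec, gpu_id))
os.system(command)
```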
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | # Dropout and Batchnorm have different behaviour during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 | parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images')
20 | # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 |
--------------------------------------------------------------------------------
/function.py:
--------------------------------------------------------------------------------
1 | # --- display_mp4 ---
2 | from IPython.display import display, HTML
3 |
4 |
5 | def display_mp4(path):
6 | from base64 import b64encode
7 | mp4 = open(path,'rb').read()
8 | data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
9 | display(HTML("""
10 | <video width=400 controls>
11 | <source src="%s" type="video/mp4">
12 | </video>
13 | """ % data_url))
14 | #print('Display finished.') ###
15 |
16 |
17 | # --- display_pic ---
18 | import matplotlib.pyplot as plt
19 | from PIL import Image
20 | import numpy as np
21 | import os
22 |
23 | def display_pic(folder):
24 | fig = plt.figure(figsize=(30, 60))
25 | files = os.listdir(folder)
26 | files.sort()
27 | for i, file in enumerate(files):
28 | if file=='.ipynb_checkpoints':
29 | continue
30 | if file=='.DS_Store':
31 | continue
32 | img = Image.open(folder+'/'+file)
33 | images = np.asarray(img)
34 | ax = fig.add_subplot(10, 6, i+1, xticks=[], yticks=[])
35 | image_plt = np.array(images)
36 | ax.imshow(image_plt)
37 | #name = os.path.splitext(file)
38 | ax.set_xlabel(file, fontsize=20)
39 | plt.show()
40 | plt.close()
41 |
42 |
43 | # --- reset_folder ---
44 | import shutil
45 |
46 | def reset_folder(path):
47 | if os.path.isdir(path):
48 | shutil.rmtree(path)
49 | os.makedirs(path,exist_ok=True)
50 |
--------------------------------------------------------------------------------
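These helpers are used from the Colab notebook (`QMUPD.ipynb`). A short usage sketch, assuming the test script has already been run to populate `results/`:

```python
# Usage sketch mirroring the notebook cells; paths match the defaults used elsewhere in this repo.
from function import reset_folder, display_pic

reset_folder('results')      # remove any previous results folder and recreate it empty
# ... run `python test_seq_style3.py` here to populate results/ ...
display_pic('results/QMUPD_model/test_200/imagesstyle0-0-1')   # grid-plot one style's outputs
```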
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 |
35 | def default_loader(path):
36 | return Image.open(path).convert('RGB')
37 |
38 |
39 | class ImageFolder(data.Dataset):
40 |
41 | def __init__(self, root, transform=None, return_paths=False,
42 | loader=default_loader):
43 | imgs = make_dataset(root)
44 | if len(imgs) == 0:
45 | raise(RuntimeError("Found 0 images in: " + root + "\n"
46 | "Supported image extensions are: " +
47 | ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
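A small sketch of using this class directly; the folder and transform below are illustrative and not part of any script in this repository:

```python
# Sketch: load the bundled example photos with ImageFolder and a basic transform.
import torchvision.transforms as transforms
from data.image_folder import ImageFolder

transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
dataset = ImageFolder('examples', transform=transform, return_paths=True)
img, path = dataset[0]            # (tensor, path) because return_paths=True
print(len(dataset), path, img.shape)
```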
/readme.md:
--------------------------------------------------------------------------------
1 |
2 | # Quality Metric Guided Portrait Line Drawing Generation from Unpaired Training Data
3 |
4 | We provide PyTorch implementations for our TPAMI paper "Quality Metric Guided Portrait Line Drawing Generation from Unpaired Training Data". [paper](https://ieeexplore.ieee.org/document/9699090)
5 |
6 | Our method can (1) learn to generate high quality portrait drawings in multiple styles using a single network and (2) generate portrait drawings in a “new style” unseen in the training data.
7 |
8 |
9 | ## Our Proposed Framework
10 |
11 | <img src='images/architecture.png'/>
12 |
13 | ## Sample Results
14 |
15 | <img src='images/sample.png'/>
16 |
17 | ## Prerequisites
18 | - Linux or macOS
19 | - Python 3
20 | - CPU or NVIDIA GPU + CUDA CuDNN
21 |
22 | ## Installation
23 | - To install the dependencies, run
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Quick Test (apply a pretrained model, generate high quality portrait drawings in multiple styles using a single network)
29 |
30 | - 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1eY60A1z2k9gTr9ryDxvMxQ)(extract code:g8is) or [GoogleDrive](https://drive.google.com/drive/folders/1R4mjaiXN3fISp0Lc4rP6DiTbQMCeTC9A?usp=sharing) and rename the folder to `checkpoints/`.
31 |
32 | - 2. Test for example photos: generate artistic portrait drawings for example photos in the folder `./examples` using
33 | ``` bash
34 | python test_seq_style3.py
35 | ```
36 | The test results will be saved to html files here: `./results/QMUPD_model/test_200/indexstyle*.html`.
37 | The result images are saved in `./results/QMUPD_model/test_200/imagesstyle*`,
38 | where `real` and `fake` correspond to the input face photo and the synthesized drawing in a certain style, respectively.
39 |
40 | You can contact ranyi@sjtu.edu.cn with any questions.
41 |
42 |
43 | ## Citation
44 | If you use this code for your research, please cite our paper.
45 |
46 | ```
47 | @article{YiLLR22,
48 | title = {Quality Metric Guided Portrait Line Drawing Generation from Unpaired Training Data},
49 | author = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},
50 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence},
51 | year = {2022}, doi = {10.1109/TPAMI.2022.3147570},
52 | }
53 | ```
54 |
55 | ## Acknowledgments
56 | Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
57 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 | """This class implements an image buffer that stores previously generated images.
7 |
8 | This buffer enables us to update discriminators using a history of generated images
9 | rather than the ones produced by the latest generators.
10 | """
11 |
12 | def __init__(self, pool_size):
13 | """Initialize the ImagePool class
14 |
15 | Parameters:
16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17 | """
18 | self.pool_size = pool_size
19 | if self.pool_size > 0: # create an empty pool
20 | self.num_imgs = 0
21 | self.images = []
22 |
23 | def query(self, images):
24 | """Return an image from the pool.
25 |
26 | Parameters:
27 | images: the latest generated images from the generator
28 |
29 | Returns images from the buffer.
30 |
31 | With probability 0.5, the buffer returns the input images.
32 | With probability 0.5, the buffer returns images previously stored in the buffer
33 | and inserts the current images into the buffer.
34 | """
35 | if self.pool_size == 0: # if the buffer size is 0, do nothing
36 | return images
37 | return_images = []
38 | for image in images:
39 | image = torch.unsqueeze(image.data, 0)
40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
41 | self.num_imgs = self.num_imgs + 1
42 | self.images.append(image)
43 | return_images.append(image)
44 | else:
45 | p = random.uniform(0, 1)
46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48 | tmp = self.images[random_id].clone()
49 | self.images[random_id] = image
50 | return_images.append(tmp)
51 | else: # by another 50% chance, the buffer will return the current image
52 | return_images.append(image)
53 | return_images = torch.cat(return_images, 0) # collect all the images and return
54 | return return_images
55 |
--------------------------------------------------------------------------------
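A self-contained sketch of the pool's behaviour using dummy tensors (during training, `query` is typically called on detached generator outputs before they are fed to the discriminator):

```python
# Sketch: once the pool is full, query() returns either the new image or a stored one (50/50).
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=2)
for step in range(4):
    fake = torch.randn(1, 3, 8, 8)        # stands in for a freshly generated image
    out = pool.query(fake)
    print(step, torch.equal(out, fake))   # always True while the pool fills; random afterwards
```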
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called [ModelName]Model will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function instantiates the model class specified by opt.model.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
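The package docstring above describes how to register a new model. Below is a skeleton of the hypothetical `models/dummy_model.py` it refers to, written against the interfaces visible in this repository (`BaseModel`, `networks.define_G`); the file does not exist in the repo and the reconstruction loss is purely illustrative:

```python
# Sketch of models/dummy_model.py (not part of the repository), selectable via '--model dummy'.
import torch
from .base_model import BaseModel
from . import networks


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser                                   # no model-specific options

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G']                         # plotted/saved by the training script
        self.visual_names = ['real', 'fake']            # images displayed/saved
        self.model_names = ['G']                        # networks saved to disk
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.norm, not opt.no_dropout, opt.init_type,
                                      opt.init_gain, self.gpu_ids)
        if opt.isTrain:
            self.criterion = torch.nn.L1Loss()          # trivial objective, for illustration only
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = [self.optimizer_G]

    def set_input(self, input):
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        self.forward()
        self.loss_G = self.criterion(self.fake, self.real)
        self.optimizer_G.zero_grad()
        self.loss_G.backward()
        self.optimizer_G.step()
```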
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask
2 | from data.image_folder import make_dataset
3 | from PIL import Image
4 | import torch
5 | import os, glob
6 | import numpy as np
7 |
8 | class SingleDataset(BaseDataset):
9 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
10 |
11 | It can be used for generating CycleGAN results only for one side with the model option '--model test'.
12 | """
13 |
14 | def __init__(self, opt):
15 | """Initialize this dataset class.
16 |
17 | Parameters:
18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
19 | """
20 | BaseDataset.__init__(self, opt)
21 | #self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
22 | imglistA = './datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot)
23 | if os.path.exists(imglistA):
24 | self.A_paths = sorted(open(imglistA, 'r').read().splitlines())
25 | else:
26 | self.A_paths = sorted(glob.glob(opt.dataroot + '/*.*'))
27 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
28 | #self.transform = get_transform(opt, grayscale=(input_nc == 1))
29 |
30 | def __getitem__(self, index):
31 | """Return a data point and its metadata information.
32 |
33 | Parameters:
34 | index - - a random integer for data indexing
35 |
36 | Returns a dictionary that contains A and A_paths
37 | A(tensor) - - an image in one domain
38 | A_paths(str) - - the path of the image
39 | """
40 | A_path = self.A_paths[index]
41 | A_img = Image.open(A_path).convert('RGB')
42 | transform_params_A = get_params(self.opt, A_img.size)
43 | A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
44 | item = {'A': A, 'A_paths': A_path}
45 |
46 | if self.opt.model == 'test_r1':
47 | basenA = os.path.basename(A_path)
48 | A_addchan_img = Image.open(os.path.join('./datasets/list/mask/A_all',basenA))
49 | A_addchan = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_addchan_img)
50 | item['A_addchan'] = A_addchan
51 |
52 | if self.opt.style_control:
53 | if self.opt.sinput == 'sind':
54 | B_style = torch.Tensor([0.,0.,0.])
55 | B_style[self.opt.sind] = 1.
56 | elif self.opt.sinput == 'svec':
57 | if self.opt.svec[0] == '~':
58 | self.opt.svec = '-'+self.opt.svec[1:]
59 | ss = self.opt.svec.split(',')
60 | B_style = torch.Tensor([float(ss[0]),float(ss[1]),float(ss[2])])
61 | elif self.opt.sinput == 'simg':
62 | self.featureloc = os.path.join('style_features/styles2/', self.opt.sfeature_mode)
63 | B_style = torch.Tensor(np.load(os.path.join(self.featureloc, self.opt.simg[:-4]+'.npy')))
64 |
65 | B_style = B_style.view(3, 1, 1)
66 | B_style = B_style.repeat(1, 128, 128)
67 | item['B_style'] = B_style
68 |
69 | return item
70 |
71 | def __len__(self):
72 | """Return the total number of images in the dataset."""
73 | return len(self.A_paths)
74 |
--------------------------------------------------------------------------------
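For the `svec` branch above, a worked example of how a `--svec 0,1,0` string becomes the `B_style` tensor attached to each item:

```python
# Worked example of the style-vector expansion performed in __getitem__ above.
import torch

ss = '0,1,0'.split(',')
B_style = torch.Tensor([float(ss[0]), float(ss[1]), float(ss[2])])
B_style = B_style.view(3, 1, 1).repeat(1, 128, 128)   # broadcast to a 3x128x128 style map
print(B_style.shape)                                   # torch.Size([3, 128, 128])
```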
/QMUPD.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "id": "_7IMTJEdpSMJ"
18 | },
19 | "outputs": [],
20 | "source": [
21 | "#@title セットアップ\n",
22 | "\n",
23 | "# githubからコードを取得\n",
24 | "! git clone https://github.com/cedro3/QMUPD.git\n",
25 | "%cd QMUPD\n",
26 | "\n",
27 | "# ライブラリ・インストール\n",
28 | "! pip install -r requirements.txt\n",
29 | "! pip install pretrainedmodels\n",
30 | "\n",
31 | "# 学習済みパラメータ・ダウンロード\n",
32 | "! pip install --upgrade gdown\n",
33 | "import gdown\n",
34 | "gdown.download('https://drive.google.com/uc?id=1QpuCQ0LrrlsHCs3Vh6xC0uIBlWrDrGo1', 'checkpoints.zip', quiet=False)\n",
35 | "! unzip checkpoints.zip\n",
36 | "\n",
37 | "# 関数インポート\n",
38 | "from function import *"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {
45 | "id": "cUz-cTvFIoAu"
46 | },
47 | "outputs": [],
48 | "source": [
49 | "#@title サンプル画像の表示\n",
50 | "display_pic('examples')"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "id": "hwtglANBqHUQ"
58 | },
59 | "outputs": [],
60 | "source": [
61 | "#@title 線画の作成\n",
62 | "reset_folder('results')\n",
63 | "! python test_seq_style3.py"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "id": "g3I50o1kwCEo"
71 | },
72 | "outputs": [],
73 | "source": [
74 | "#@title スタイル1表示\n",
75 | "display_pic('results/QMUPD_model/test_200/imagesstyle0-0-1')"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {
82 | "id": "w-kSR4bixOh6"
83 | },
84 | "outputs": [],
85 | "source": [
86 | "#@title スタイル2表示\n",
87 | "display_pic('results/QMUPD_model/test_200/imagesstyle0-1-0')"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {
94 | "id": "ti_sOg8fzMMg"
95 | },
96 | "outputs": [],
97 | "source": [
98 | "#@title スタイル3表示\n",
99 | "display_pic('results/QMUPD_model/test_200/imagesstyle1-0-0')"
100 | ]
101 | }
102 | ],
103 | "metadata": {
104 | "accelerator": "GPU",
105 | "colab": {
106 | "collapsed_sections": [],
107 | "name": "QMUPD",
108 | "provenance": [],
109 | "include_colab_link": true
110 | },
111 | "kernelspec": {
112 | "display_name": "Python 3",
113 | "name": "python3"
114 | },
115 | "language_info": {
116 | "name": "python"
117 | }
118 | },
119 | "nbformat": 4,
120 | "nbformat_minor": 0
121 | }
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 | # training parameters
30 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
31 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
32 | parser.add_argument('--niter_end', type=int, default=200, help='# of iter to end')
33 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
35 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
36 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
37 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
38 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
39 |
40 | self.isTrain = True
41 | return parser
42 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 | In the file, the class called DatasetNameDataset() will
22 | be instantiated. It has to be a subclass of BaseDataset,
23 | and it is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | self.dataloader = torch.utils.data.DataLoader(
76 | self.dataset,
77 | batch_size=opt.batch_size,
78 | shuffle=not opt.serial_batches,
79 | num_workers=int(opt.num_threads))
80 |
81 | def load_data(self):
82 | return self
83 |
84 | def __len__(self):
85 | """Return the number of data in the dataset"""
86 | return min(len(self.dataset), self.opt.max_dataset_size)
87 |
88 | def __iter__(self):
89 | """Return a batch of data"""
90 | for i, data in enumerate(self.dataloader):
91 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
92 | break
93 | yield data
94 |
--------------------------------------------------------------------------------
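The package docstring above describes how to register a new dataset. Below is a skeleton of the hypothetical `data/dummy_dataset.py` it refers to, modelled on `single_dataset.py` (which uses `get_params`/`get_transform` from `base_dataset`); this file does not exist in the repository:

```python
# Sketch of data/dummy_dataset.py (not part of the repository), selectable via '--dataset_mode dummy'.
from data.base_dataset import BaseDataset, get_transform, get_params
from data.image_folder import make_dataset
from PIL import Image


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser                                   # no dataset-specific options

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path).convert('RGB')
        params = get_params(self.opt, img.size)
        A = get_transform(self.opt, params, grayscale=(self.opt.input_nc == 1))(img)
        return {'A': A, 'A_paths': path}

    def __len__(self):
        return len(self.paths)
```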
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as <add_header> (add a text header to the HTML file),
10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0, folder='images'):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/<folder>/
19 | title (str)   -- the webpage name
20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 | folder (str)  -- the subfolder that stores result images
22 | """
23 | self.title = title
24 | self.web_dir = web_dir
25 | self.img_dir = os.path.join(self.web_dir, folder)
26 | self.folder = folder
27 | if not os.path.exists(self.web_dir):
28 | os.makedirs(self.web_dir)
29 | if not os.path.exists(self.img_dir):
30 | os.makedirs(self.img_dir)
31 |
32 | self.doc = dominate.document(title=title)
33 | if refresh > 0:
34 | with self.doc.head:
35 | meta(http_equiv="refresh", content=str(refresh))
36 |
37 | def get_image_dir(self):
38 | """Return the directory that stores images"""
39 | return self.img_dir
40 |
41 | def add_header(self, text):
42 | """Insert a header to the HTML file
43 |
44 | Parameters:
45 | text (str) -- the header text
46 | """
47 | with self.doc:
48 | h3(text)
49 |
50 | def add_images(self, ims, txts, links, width=400):
51 | """add images to the HTML file
52 |
53 | Parameters:
54 | ims (str list) -- a list of image paths
55 | txts (str list) -- a list of image names shown on the website
56 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
57 | """
58 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
59 | self.doc.add(self.t)
60 | with self.t:
61 | with tr():
62 | for im, txt, link in zip(ims, txts, links):
63 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
64 | with p():
65 | with a(href=os.path.join('images', link)):
66 | #img(style="width:%dpx" % width, src=os.path.join('images', im))
67 | img(style="width:%dpx" % width, src=os.path.join(self.folder, im))
68 | br()
69 | p(txt)
70 |
71 | def save(self):
72 | """save the current content to the HMTL file"""
73 | #html_file = '%s/index.html' % self.web_dir
74 | name = self.folder[6:] if self.folder[:6] == 'images' else self.folder
75 | html_file = '%s/index%s.html' % (self.web_dir, name)
76 | f = open(html_file, 'wt')
77 | f.write(self.doc.render())
78 | f.close()
79 |
80 |
81 | if __name__ == '__main__': # we show an example usage here.
82 | html = HTML('web/', 'test_html')
83 | html.add_header('hello world')
84 |
85 | ims, txts, links = [], [], []
86 | for n in range(4):
87 | ims.append('image_%d.png' % n)
88 | txts.append('text_%d' % n)
89 | links.append('image_%d.png' % n)
90 | html.add_images(ims, txts, links)
91 | html.save()
92 |
--------------------------------------------------------------------------------
/util/get_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import tarfile
4 | import requests
5 | from warnings import warn
6 | from zipfile import ZipFile
7 | from bs4 import BeautifulSoup
8 | from os.path import abspath, isdir, join, basename
9 |
10 |
11 | class GetData(object):
12 | """A Python script for downloading CycleGAN or pix2pix datasets.
13 |
14 | Parameters:
15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'.
16 | verbose (bool) -- If True, print additional information.
17 |
18 | Examples:
19 | >>> from util.get_data import GetData
20 | >>> gd = GetData(technique='cyclegan')
21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
22 |
23 | Alternatively, you can use bash scripts: 'scripts/download_pix2pix_model.sh'
24 | and 'scripts/download_cyclegan_model.sh'.
25 | """
26 |
27 | def __init__(self, technique='cyclegan', verbose=True):
28 | url_dict = {
29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
31 | }
32 | self.url = url_dict.get(technique.lower())
33 | self._verbose = verbose
34 |
35 | def _print(self, text):
36 | if self._verbose:
37 | print(text)
38 |
39 | @staticmethod
40 | def _get_options(r):
41 | soup = BeautifulSoup(r.text, 'lxml')
42 | options = [h.text for h in soup.find_all('a', href=True)
43 | if h.text.endswith(('.zip', 'tar.gz'))]
44 | return options
45 |
46 | def _present_options(self):
47 | r = requests.get(self.url)
48 | options = self._get_options(r)
49 | print('Options:\n')
50 | for i, o in enumerate(options):
51 | print("{0}: {1}".format(i, o))
52 | choice = input("\nPlease enter the number of the "
53 | "dataset above you wish to download:")
54 | return options[int(choice)]
55 |
56 | def _download_data(self, dataset_url, save_path):
57 | if not isdir(save_path):
58 | os.makedirs(save_path)
59 |
60 | base = basename(dataset_url)
61 | temp_save_path = join(save_path, base)
62 |
63 | with open(temp_save_path, "wb") as f:
64 | r = requests.get(dataset_url)
65 | f.write(r.content)
66 |
67 | if base.endswith('.tar.gz'):
68 | obj = tarfile.open(temp_save_path)
69 | elif base.endswith('.zip'):
70 | obj = ZipFile(temp_save_path, 'r')
71 | else:
72 | raise ValueError("Unknown File Type: {0}.".format(base))
73 |
74 | self._print("Unpacking Data...")
75 | obj.extractall(save_path)
76 | obj.close()
77 | os.remove(temp_save_path)
78 |
79 | def get(self, save_path, dataset=None):
80 | """
81 |
82 | Download a dataset.
83 |
84 | Parameters:
85 | save_path (str) -- A directory to save the data to.
86 | dataset (str) -- (optional). A specific dataset to download.
87 | Note: this must include the file extension.
88 | If None, options will be presented for you
89 | to choose from.
90 |
91 | Returns:
92 | save_path_full (str) -- the absolute path to the downloaded data.
93 |
94 | """
95 | if dataset is None:
96 | selected_dataset = self._present_options()
97 | else:
98 | selected_dataset = dataset
99 |
100 | save_path_full = join(save_path, selected_dataset.split('.')[0])
101 |
102 | if isdir(save_path_full):
103 | warn("\n'{0}' already exists. Voiding Download.".format(
104 | save_path_full))
105 | else:
106 | self._print('Downloading Data...')
107 | url = "{0}/{1}".format(self.url, selected_dataset)
108 | self._download_data(url, save_path=save_path)
109 |
110 | return abspath(save_path_full)
111 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """General-purpose test script for image-to-image translation.
2 |
3 | Once you have trained your model with train.py, you can use this script to test the model.
4 | It will load a saved model from --checkpoints_dir and save the results to --results_dir.
5 |
6 | It first creates model and dataset given the option. It will hard-code some parameters.
7 | It then runs inference for --num_test images and saves the results to an HTML file.
8 |
9 | Example (You need to train models first or download pre-trained models from our website):
10 | Test a CycleGAN model (both sides):
11 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
12 |
13 | Test a CycleGAN model (one side only):
14 | python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
15 |
16 | The option '--model test' is used for generating CycleGAN results only for one side.
17 | This option will automatically set '--dataset_mode single', which only loads the images from one set.
18 | On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
19 | which is sometimes unnecessary. The results will be saved at ./results/.
20 | Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
21 |
22 | Test a pix2pix model:
23 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
24 |
25 | See options/base_options.py and options/test_options.py for more test options.
26 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
27 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
28 | """
29 | import os
30 | from options.test_options import TestOptions
31 | from data import create_dataset
32 | from models import create_model
33 | from util.visualizer import save_images
34 | from util import html
35 |
36 |
37 | if __name__ == '__main__':
38 | opt = TestOptions().parse() # get test options
39 | # hard-code some parameters for test
40 | opt.num_threads = 0 # test code only supports num_threads = 0
41 | opt.batch_size = 1 # test code only supports batch_size = 1
42 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
43 | opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
44 | opt.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
45 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
46 | model = create_model(opt) # create a model given opt.model and other options
47 | model.setup(opt) # regular setup: load and print networks; create schedulers
48 | # create a website
49 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory
50 | #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
51 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder)
52 | # test with eval mode. This only affects layers like batchnorm and dropout.
53 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
54 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
55 | if opt.eval:
56 | model.eval()
57 | for name in model.model_names:
58 | if isinstance(name, str):
59 | print(getattr(model, 'net' + name).training)
60 | for i, data in enumerate(dataset):
61 | if i >= opt.num_test: # only apply our model to opt.num_test images.
62 | break
63 | model.set_input(data) # unpack data from data loader
64 | model.test() # run inference
65 | visuals = model.get_current_visuals() # get image results
66 | img_path = model.get_image_paths() # get image paths
67 | if i % 5 == 0: # print a progress message every 5 images
68 | print('processing (%04d)-th image... %s' % (i, img_path))
69 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
70 | webpage.save() # save the HTML
71 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 | import pdb
8 | from scipy.io import savemat
9 |
10 |
11 | def tensor2im(input_image, imtype=np.uint8):
12 | """"Converts a Tensor array into a numpy image array.
13 |
14 | Parameters:
15 | input_image (tensor) -- the input image tensor array
16 | imtype (type) -- the desired type of the converted numpy array
17 | """
18 | if not isinstance(input_image, np.ndarray):
19 | if isinstance(input_image, torch.Tensor): # get the data from a variable
20 | image_tensor = input_image.data
21 | else:
22 | return input_image
23 | #pdb.set_trace()
24 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
25 | if image_numpy.shape[0] == 1: # grayscale to RGB
26 | image_numpy = np.tile(image_numpy, (3, 1, 1))
27 | elif image_numpy.shape[0] == 2:
28 | image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0)
29 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
30 | else: # if it is a numpy array, do nothing
31 | image_numpy = input_image
32 | return image_numpy.astype(imtype)
33 | #return np.round(image_numpy).astype(imtype),image_numpy
34 |
35 | def tensor2im2(input_image, imtype=np.uint8):
36 | """"Converts a Tensor array into a numpy image array.
37 |
38 | Parameters:
39 | input_image (tensor) -- the input image tensor array
40 | imtype (type) -- the desired type of the converted numpy array
41 | """
42 | if not isinstance(input_image, np.ndarray):
43 | if isinstance(input_image, torch.Tensor): # get the data from a variable
44 | image_tensor = input_image.data
45 | else:
46 | return input_image
47 | #pdb.set_trace()
48 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
49 | if image_numpy.shape[0] == 1: # grayscale to RGB
50 | image_numpy = np.tile(image_numpy, (3, 1, 1))
51 | elif image_numpy.shape[0] == 2:
52 | image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0)
53 | image_numpy = np.transpose(image_numpy, (1, 2, 0))
54 | image_numpy[:,:,0] = image_numpy[:,:,0] * 0.229 + 0.485
55 | image_numpy[:,:,1] = image_numpy[:,:,1] * 0.224 + 0.456
56 | image_numpy[:,:,2] = image_numpy[:,:,2] * 0.225 + 0.406
57 | image_numpy = image_numpy * 255.0 # post-processing: scaling to [0, 255]
58 | else: # if it is a numpy array, do nothing
59 | image_numpy = input_image
60 | return image_numpy.astype(imtype)
61 |
62 |
63 | def diagnose_network(net, name='network'):
64 | """Calculate and print the mean of average absolute(gradients)
65 |
66 | Parameters:
67 | net (torch network) -- Torch network
68 | name (str) -- the name of the network
69 | """
70 | mean = 0.0
71 | count = 0
72 | for param in net.parameters():
73 | if param.grad is not None:
74 | mean += torch.mean(torch.abs(param.grad.data))
75 | count += 1
76 | if count > 0:
77 | mean = mean / count
78 | print(name)
79 | print(mean)
80 |
81 |
82 | def save_image(image_numpy, image_path):
83 | """Save a numpy image to the disk
84 |
85 | Parameters:
86 | image_numpy (numpy array) -- input numpy array
87 | image_path (str) -- the path of the image
88 | """
89 | image_pil = Image.fromarray(image_numpy)
90 | #pdb.set_trace()
91 | image_pil.save(image_path)
92 |
93 |
94 | def print_numpy(x, val=True, shp=False):
95 | """Print the mean, min, max, median, std, and size of a numpy array
96 |
97 | Parameters:
98 | val (bool) -- if print the values of the numpy array
99 | shp (bool) -- if print the shape of the numpy array
100 | """
101 | x = x.astype(np.float64)
102 | if shp:
103 | print('shape,', x.shape)
104 | if val:
105 | x = x.flatten()
106 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
107 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
108 |
109 |
110 | def mkdirs(paths):
111 | """create empty directories if they don't exist
112 |
113 | Parameters:
114 | paths (str list) -- a list of directory paths
115 | """
116 | if isinstance(paths, list) and not isinstance(paths, str):
117 | for path in paths:
118 | mkdir(path)
119 | else:
120 | mkdir(paths)
121 |
122 |
123 | def mkdir(path):
124 | """create a single empty directory if it didn't exist
125 |
126 | Parameters:
127 | path (str) -- a single directory path
128 | """
129 | if not os.path.exists(path):
130 | os.makedirs(path)
131 |
132 | def normalize_tensor(in_feat,eps=1e-10):
133 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
134 | return in_feat/(norm_factor+eps)
--------------------------------------------------------------------------------
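A minimal sketch chaining `tensor2im` and `save_image`; the tensor is random dummy data standing in for a generator output in [-1, 1], and the output path is a placeholder:

```python
# Sketch: convert a fake network output to a uint8 image and write it to disk.
import torch
from util.util import tensor2im, save_image

fake = torch.rand(1, 1, 64, 64) * 2 - 1   # NCHW tensor in [-1, 1]; a single channel is tiled to RGB
img = tensor2im(fake)                     # HxWx3 uint8 numpy array
save_image(img, 'dummy_fake.png')         # hypothetical output path
```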
/models/test_model.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from . import networks
3 | import torch
4 | import pdb
5 |
6 | class TestModel(BaseModel):
7 | """ This TesteModel can be used to generate CycleGAN results for only one direction.
8 | This model will automatically set '--dataset_mode single', which only loads the images from one collection.
9 |
10 | See the test instruction for more details.
11 | """
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 | """Add new dataset-specific options, and rewrite default values for existing options.
15 |
16 | Parameters:
17 | parser -- original option parser
18 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
19 |
20 | Returns:
21 | the modified parser.
22 |
23 | The model can only be used during test time. It requires '--dataset_mode single'.
24 | You need to specify the network using the option '--model_suffix'.
25 | """
26 | assert not is_train, 'TestModel cannot be used during training time'
27 | parser.set_defaults(dataset_mode='single')
28 | parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
29 | parser.add_argument('--style_control', type=int, default=0, help='use style_control')
30 | parser.add_argument('--sfeature_mode', type=str, default='vgg19_softmax', help='vgg19 softmax as feature')
31 | parser.add_argument('--sinput', type=str, default='sind', help='use which one for style input')
32 | parser.add_argument('--sind', type=int, default=0, help='one hot for sfeature')
33 | parser.add_argument('--svec', type=str, default='1,0,0', help='3-dim vec')
34 | parser.add_argument('--simg', type=str, default='Yann_Legendre-053', help='drawing example for style')
35 | parser.add_argument('--netga', type=str, default='resnet_style_9blocks', help='net arch for netG_A')
36 | parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0')
37 | parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
38 |
39 | return parser
40 |
41 | def __init__(self, opt):
42 | """Initialize the pix2pix class.
43 |
44 | Parameters:
45 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
46 | """
47 | assert(not opt.isTrain)
48 | BaseModel.__init__(self, opt)
49 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
50 | self.loss_names = []
51 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
52 | #self.visual_names = ['real', 'fake', 'rec', 'fake_B']
53 | self.visual_names = ['real', 'fake']
54 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
55 | self.model_names = ['G' + opt.model_suffix, 'G_B'] # only generator is needed.
56 | if not self.opt.style_control:
57 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
58 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
59 | else:
60 | print(opt.netga)
61 | print('model0_res', opt.model0_res)
62 | print('model1_res', opt.model1_res)
63 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
64 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
65 |
66 | self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG,
67 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
68 | # assigns the model to self.netG_[suffix] so that it can be loaded
69 | # please see <BaseModel.load_networks>
70 | setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
71 | setattr(self, 'netG_B', self.netGB) # store netGB in self.
72 |
73 | def set_input(self, input):
74 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
75 |
76 | Parameters:
77 | input: a dictionary that contains the data itself and its metadata information.
78 |
79 | We need to use 'single_dataset' dataset mode. It only loads images from one domain.
80 | """
81 | self.real = input['A'].to(self.device)
82 | self.image_paths = input['A_paths']
83 | if self.opt.style_control:
84 | self.style = input['B_style']
85 |
86 | def forward(self):
87 | """Run forward pass."""
88 | if not self.opt.style_control:
89 | self.fake = self.netG(self.real) # G(real)
90 | else:
91 | #print(torch.mean(self.style,(2,3)),'style_control')
92 | self.fake = self.netG(self.real, self.style)
93 |
94 | def optimize_parameters(self):
95 | """No optimization for test model."""
96 | pass
97 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 |
3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
6 |
7 | It first creates model, dataset, and visualizer given the option.
8 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves models.
9 | The script supports continue/resume training. Use '--continue_train' to resume your previous training.
10 |
11 | Example:
12 | Train a CycleGAN model:
13 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
14 | Train a pix2pix model:
15 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
16 |
17 | See options/base_options.py and options/train_options.py for more training options.
18 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
19 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
20 | """
21 | import time
22 | from options.train_options import TrainOptions
23 | from data import create_dataset
24 | from models import create_model
25 | from util.visualizer import Visualizer
26 | import pdb
27 |
28 | if __name__ == '__main__':
29 | start = time.time()
30 | opt = TrainOptions().parse() # get training options
31 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
32 | dataset_size = len(dataset) # get the number of images in the dataset.
33 | print('The number of training images = %d' % dataset_size)
34 |
35 | model = create_model(opt) # create a model given opt.model and other options
36 | model.setup(opt) # regular setup: load and print networks; create schedulers
37 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
38 | total_iters = 0 # the total number of training iterations
39 |
40 | #for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
41 | for epoch in range(opt.epoch_count, opt.niter_end + 1):
42 | epoch_start_time = time.time() # timer for entire epoch
43 | iter_data_time = time.time() # timer for data loading per iteration
44 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
45 | model.update_process(epoch)
46 |
47 | for i, data in enumerate(dataset): # inner loop within one epoch
48 | iter_start_time = time.time() # timer for computation per iteration
49 | if total_iters % opt.print_freq == 0:
50 | t_data = iter_start_time - iter_data_time
51 | visualizer.reset()
52 | total_iters += opt.batch_size
53 | epoch_iter += opt.batch_size
54 | model.set_input(data) # unpack data from dataset and apply preprocessing
55 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights
56 |
57 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
58 | save_result = total_iters % opt.update_html_freq == 0
59 | model.compute_visuals()
60 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
61 |
62 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
63 | losses = model.get_current_losses()
64 | t_comp = (time.time() - iter_start_time) / opt.batch_size
65 | if opt.model == 'cycle_gan':
66 | processes = [model.process] + model.lambda_As
67 | visualizer.print_current_losses_process(epoch, epoch_iter, losses, t_comp, t_data, processes)
68 | else:
69 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
70 | if opt.display_id > 0:
71 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
72 |
73 |             if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
74 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
75 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
76 | model.save_networks(save_suffix)
77 |
78 | iter_data_time = time.time()
79 |         if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
80 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
81 | model.save_networks('latest')
82 | model.save_networks(epoch)
83 |
84 |         print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter_end, time.time() - epoch_start_time))
85 | model.update_learning_rate() # update learning rates at the end of every epoch.
86 |
87 | print('Total Time Taken: %d sec' % (time.time() - start))
--------------------------------------------------------------------------------
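As the docstring notes, an interrupted run can be resumed with --continue_train; a hedged sketch (the dataset, experiment name, and starting epoch are placeholders, and --epoch_count is assumed to be defined in options/train_options.py):

    python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --continue_train --epoch_count 51

Note that in this repository the epoch loop runs up to opt.niter_end rather than the upstream opt.niter + opt.niter_decay.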
/models/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models
4 | from IPython import embed
5 |
6 | class squeezenet(torch.nn.Module):
7 | def __init__(self, requires_grad=False, pretrained=True):
8 | super(squeezenet, self).__init__()
9 | pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
10 | self.slice1 = torch.nn.Sequential()
11 | self.slice2 = torch.nn.Sequential()
12 | self.slice3 = torch.nn.Sequential()
13 | self.slice4 = torch.nn.Sequential()
14 | self.slice5 = torch.nn.Sequential()
15 | self.slice6 = torch.nn.Sequential()
16 | self.slice7 = torch.nn.Sequential()
17 | self.N_slices = 7
18 | for x in range(2):
19 | self.slice1.add_module(str(x), pretrained_features[x])
20 | for x in range(2,5):
21 | self.slice2.add_module(str(x), pretrained_features[x])
22 | for x in range(5, 8):
23 | self.slice3.add_module(str(x), pretrained_features[x])
24 | for x in range(8, 10):
25 | self.slice4.add_module(str(x), pretrained_features[x])
26 | for x in range(10, 11):
27 | self.slice5.add_module(str(x), pretrained_features[x])
28 | for x in range(11, 12):
29 | self.slice6.add_module(str(x), pretrained_features[x])
30 | for x in range(12, 13):
31 | self.slice7.add_module(str(x), pretrained_features[x])
32 | if not requires_grad:
33 | for param in self.parameters():
34 | param.requires_grad = False
35 |
36 | def forward(self, X):
37 | h = self.slice1(X)
38 | h_relu1 = h
39 | h = self.slice2(h)
40 | h_relu2 = h
41 | h = self.slice3(h)
42 | h_relu3 = h
43 | h = self.slice4(h)
44 | h_relu4 = h
45 | h = self.slice5(h)
46 | h_relu5 = h
47 | h = self.slice6(h)
48 | h_relu6 = h
49 | h = self.slice7(h)
50 | h_relu7 = h
51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53 |
54 | return out
55 |
56 |
57 | class alexnet(torch.nn.Module):
58 | def __init__(self, requires_grad=False, pretrained=True):
59 | super(alexnet, self).__init__()
60 | alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
61 | self.slice1 = torch.nn.Sequential()
62 | self.slice2 = torch.nn.Sequential()
63 | self.slice3 = torch.nn.Sequential()
64 | self.slice4 = torch.nn.Sequential()
65 | self.slice5 = torch.nn.Sequential()
66 | self.N_slices = 5
67 | for x in range(2):
68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69 | for x in range(2, 5):
70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71 | for x in range(5, 8):
72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73 | for x in range(8, 10):
74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75 | for x in range(10, 12):
76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77 | if not requires_grad:
78 | for param in self.parameters():
79 | param.requires_grad = False
80 |
81 | def forward(self, X):
82 | h = self.slice1(X)
83 | h_relu1 = h
84 | h = self.slice2(h)
85 | h_relu2 = h
86 | h = self.slice3(h)
87 | h_relu3 = h
88 | h = self.slice4(h)
89 | h_relu4 = h
90 | h = self.slice5(h)
91 | h_relu5 = h
92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94 |
95 | return out
96 |
97 | class vgg16(torch.nn.Module):
98 | def __init__(self, requires_grad=False, pretrained=True):
99 | super(vgg16, self).__init__()
100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
101 | self.slice1 = torch.nn.Sequential()
102 | self.slice2 = torch.nn.Sequential()
103 | self.slice3 = torch.nn.Sequential()
104 | self.slice4 = torch.nn.Sequential()
105 | self.slice5 = torch.nn.Sequential()
106 | self.N_slices = 5
107 | for x in range(4):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(4, 9):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(9, 16):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(16, 23):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(23, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | if not requires_grad:
118 | for param in self.parameters():
119 | param.requires_grad = False
120 |
121 | def forward(self, X):
122 | h = self.slice1(X)
123 | h_relu1_2 = h
124 | h = self.slice2(h)
125 | h_relu2_2 = h
126 | h = self.slice3(h)
127 | h_relu3_3 = h
128 | h = self.slice4(h)
129 | h_relu4_3 = h
130 | h = self.slice5(h)
131 | h_relu5_3 = h
132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134 |
135 | return out
136 |
137 |
138 |
139 | class resnet(torch.nn.Module):
140 | def __init__(self, requires_grad=False, pretrained=True, num=18):
141 | super(resnet, self).__init__()
142 | if(num==18):
143 | self.net = models.resnet18(pretrained=pretrained)
144 | elif(num==34):
145 | self.net = models.resnet34(pretrained=pretrained)
146 | elif(num==50):
147 | self.net = models.resnet50(pretrained=pretrained)
148 | elif(num==101):
149 | self.net = models.resnet101(pretrained=pretrained)
150 | elif(num==152):
151 | self.net = models.resnet152(pretrained=pretrained)
152 | self.N_slices = 5
153 |
154 | self.conv1 = self.net.conv1
155 | self.bn1 = self.net.bn1
156 | self.relu = self.net.relu
157 | self.maxpool = self.net.maxpool
158 | self.layer1 = self.net.layer1
159 | self.layer2 = self.net.layer2
160 | self.layer3 = self.net.layer3
161 | self.layer4 = self.net.layer4
162 |
163 | def forward(self, X):
164 | h = self.conv1(X)
165 | h = self.bn1(h)
166 | h = self.relu(h)
167 | h_relu1 = h
168 | h = self.maxpool(h)
169 | h = self.layer1(h)
170 | h_conv2 = h
171 | h = self.layer2(h)
172 | h_conv3 = h
173 | h = self.layer3(h)
174 | h_conv4 = h
175 | h = self.layer4(h)
176 | h_conv5 = h
177 |
178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180 |
181 | return out
182 |
--------------------------------------------------------------------------------
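The classes above only expose intermediate activations; below is a minimal sketch (not part of the repo) of how such a frozen backbone is typically used, e.g. to build a perceptual feature distance. The inputs are assumed to already be normalized (N, 3, H, W) batches, and pretrained=True downloads the torchvision weights.

    import torch
    import torch.nn.functional as F
    from models.pretrained_networks import vgg16

    net = vgg16(requires_grad=False, pretrained=True).eval()   # frozen VGG-16 feature slices
    x = torch.randn(1, 3, 256, 256)                            # placeholder image batch
    y = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        fx, fy = net(x), net(y)                                # namedtuples relu1_2 ... relu5_3
    dist = sum(F.mse_loss(a, b) for a, b in zip(fx, fy))       # simple per-layer feature distance
    print(dist)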
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABCMeta, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset):
14 | __metaclass__ = ABCMeta
15 | """This class is an abstract base class (ABC) for datasets.
16 |
17 | To create a subclass, you need to implement the following four functions:
18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
19 | -- <__len__>: return the size of dataset.
20 | -- <__getitem__>: get a data point.
21 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
22 | """
23 |
24 | def __init__(self, opt):
25 | """Initialize the class; save the options in the class
26 |
27 | Parameters:
28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
29 | """
30 | self.opt = opt
31 | self.root = opt.dataroot
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 | index - - a random integer for data indexing
57 |
58 | Returns:
59 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_params(opt, size):
65 | w, h = size
66 | new_h = h
67 | new_w = w
68 | if opt.preprocess == 'resize_and_crop':
69 | new_h = new_w = opt.load_size
70 | elif opt.preprocess == 'scale_width_and_crop':
71 | new_w = opt.load_size
72 | new_h = opt.load_size * h // w
73 |
74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 |
77 | flip = random.random() > 0.5
78 |
79 | return {'crop_pos': (x, y), 'flip': flip}
80 |
81 |
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | transform_list.append(transforms.Grayscale(1))
86 | if 'resize' in opt.preprocess:
87 | osize = [opt.load_size, opt.load_size]
88 | transform_list.append(transforms.Resize(osize, method))
89 | elif 'scale_width' in opt.preprocess:
90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 |
92 | if 'crop' in opt.preprocess:
93 | if params is None:
94 | transform_list.append(transforms.RandomCrop(opt.crop_size))
95 | else:
96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 |
98 | if opt.preprocess == 'none':
99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 |
101 | if not opt.no_flip:
102 | if params is None:
103 | transform_list.append(transforms.RandomHorizontalFlip())
104 | elif params['flip']:
105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 |
107 | if convert:
108 | transform_list += [transforms.ToTensor()]
109 | if grayscale:
110 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 | else:
112 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 | return transforms.Compose(transform_list)
114 |
115 | def get_transform_mask(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
116 | transform_list = []
117 | if grayscale:
118 | transform_list.append(transforms.Grayscale(1))
119 | if 'resize' in opt.preprocess:
120 | osize = [opt.load_size, opt.load_size]
121 | transform_list.append(transforms.Resize(osize, method))
122 | elif 'scale_width' in opt.preprocess:
123 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
124 |
125 | if 'crop' in opt.preprocess:
126 | if params is None:
127 | transform_list.append(transforms.RandomCrop(opt.crop_size))
128 | else:
129 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
130 |
131 | if opt.preprocess == 'none':
132 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
133 |
134 | if not opt.no_flip:
135 | if params is None:
136 | transform_list.append(transforms.RandomHorizontalFlip())
137 | elif params['flip']:
138 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
139 |
140 | if convert:
141 | transform_list += [transforms.ToTensor()]
142 | return transforms.Compose(transform_list)
143 |
144 | def __make_power_2(img, base, method=Image.BICUBIC):
145 | ow, oh = img.size
146 | h = int(round(oh / base) * base)
147 | w = int(round(ow / base) * base)
148 | if (h == oh) and (w == ow):
149 | return img
150 |
151 | __print_size_warning(ow, oh, w, h)
152 | return img.resize((w, h), method)
153 |
154 |
155 | def __scale_width(img, target_width, method=Image.BICUBIC):
156 | ow, oh = img.size
157 | if (ow == target_width):
158 | return img
159 | w = target_width
160 | h = int(target_width * oh / ow)
161 | return img.resize((w, h), method)
162 |
163 |
164 | def __crop(img, pos, size):
165 | ow, oh = img.size
166 | x1, y1 = pos
167 | tw = th = size
168 | if (ow > tw or oh > th):
169 | return img.crop((x1, y1, x1 + tw, y1 + th))
170 | return img
171 |
172 |
173 | def __flip(img, flip):
174 | if flip:
175 | return img.transpose(Image.FLIP_LEFT_RIGHT)
176 | return img
177 |
178 |
179 | def __print_size_warning(ow, oh, w, h):
180 | """Print warning information about image size(only print once)"""
181 | if not hasattr(__print_size_warning, 'has_printed'):
182 | print("The image size needs to be a multiple of 4. "
183 | "The loaded image size was (%d, %d), so it was adjusted to "
184 | "(%d, %d). This adjustment will be done to all images "
185 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
186 | __print_size_warning.has_printed = True
187 |
--------------------------------------------------------------------------------
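A minimal sketch (not from the repo) of the intended usage pattern: get_params() draws one set of crop/flip parameters so that get_transform() and get_transform_mask() apply identical geometry to an image and its mask. The opt namespace below only fakes the fields these helpers actually read.

    from types import SimpleNamespace
    from PIL import Image
    from data.base_dataset import get_params, get_transform, get_transform_mask

    opt = SimpleNamespace(preprocess='resize_and_crop', load_size=286,
                          crop_size=256, no_flip=False)
    img = Image.new('RGB', (300, 300))                 # placeholder image
    mask = Image.new('L', (300, 300))                  # placeholder mask
    params = get_params(opt, img.size)                 # shared crop position / flip decision
    A = get_transform(opt, params)(img)                # (3, 256, 256), normalized to [-1, 1]
    A_mask = get_transform_mask(opt, params, grayscale=1)(mask)   # (1, 256, 256), kept in [0, 1]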
/data/unaligned_mask_stylecls_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 | import random
6 | import torch
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 |
10 |
11 | class UnalignedMaskStyleClsDataset(BaseDataset):
12 | """
13 | This dataset class can load unaligned/unpaired datasets.
14 |
15 | It requires two directories to host training images from domain A '/path/to/data/trainA'
16 | and from domain B '/path/to/data/trainB' respectively.
17 | You can train the model with the dataset flag '--dataroot /path/to/data'.
18 | Similarly, you need to prepare two directories:
19 | '/path/to/data/testA' and '/path/to/data/testB' during test time.
20 | """
21 |
22 | def __init__(self, opt):
23 | """Initialize this dataset class.
24 |
25 | Parameters:
26 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
27 | """
28 | BaseDataset.__init__(self, opt)
29 |
30 | imglistA = './datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot)
31 | imglistB = './datasets/list/%s/%s.txt' % (opt.phase+'B', opt.dataroot)
32 |
33 | if not os.path.exists(imglistA) or not os.path.exists(imglistB):
34 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
35 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
36 |
37 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
38 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
39 | else:
40 | self.A_paths = sorted(open(imglistA, 'r').read().splitlines())
41 | self.B_paths = sorted(open(imglistB, 'r').read().splitlines())
42 |
43 | self.A_size = len(self.A_paths) # get the size of dataset A
44 | self.B_size = len(self.B_paths) # get the size of dataset B
45 | print("A size:", self.A_size)
46 | print("B size:", self.B_size)
47 | btoA = self.opt.direction == 'BtoA'
48 | self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
49 | self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
50 |
51 | if opt.dataroot == '190613-4s':
52 | self.softmaxloc = os.path.join('style_features/styles2/', '1vgg19_softmax')
53 | elif opt.dataroot == '190613-4sn5':
54 | self.softmaxloc = os.path.join('style_features/styles2_sn_equal/', '1vgg19_softmax')
55 | elif '190613-4sn' in self.opt.dataroot:
56 | self.softmaxloc = os.path.join('style_features/styles2_sn/', '1vgg19_softmax')
57 |
58 |
59 | def __getitem__(self, index):
60 | """Return a data point and its metadata information.
61 |
62 | Parameters:
63 | index (int) -- a random integer for data indexing
64 |
65 | Returns a dictionary that contains A, B, A_paths and B_paths
66 | A (tensor) -- an image in the input domain
67 | B (tensor) -- its corresponding image in the target domain
68 | A_paths (str) -- image paths
69 | B_paths (str) -- image paths
70 | """
71 |         A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
72 |         if self.opt.serial_batches:   # make sure index_B is within the range
73 | index_B = index % self.B_size
74 | else: # randomize the index for domain B to avoid fixed pairs.
75 | index_B = random.randint(0, self.B_size - 1)
76 | B_path = self.B_paths[index_B]
77 | A_img = Image.open(A_path).convert('RGB')
78 | B_img = Image.open(B_path).convert('RGB')
79 |
80 | basenA = os.path.basename(A_path)
81 | A_mask_img = Image.open(os.path.join('./datasets/list/mask/A',basenA))
82 | basenB = os.path.basename(B_path)
83 | basenB2 = basenB.replace('_fake.png','.png')
84 | # for added synthetic drawing
85 | basenB2 = basenB2.replace('_style1.png','.png')
86 | basenB2 = basenB2.replace('_style2.png','.png')
87 | basenB2 = basenB2.replace('_style1single.png','.png')
88 | basenB2 = basenB2.replace('_style2single.png','.png')
89 | B_mask_img = Image.open(os.path.join('./datasets/list/mask/B',basenB2))
90 | if self.opt.use_eye_mask:
91 | A_maske_img = Image.open(os.path.join('./datasets/list/mask/A_eyes',basenA))
92 | B_maske_img = Image.open(os.path.join('./datasets/list/mask/B_eyes',basenB2))
93 | if self.opt.use_lip_mask:
94 | A_maskl_img = Image.open(os.path.join('./datasets/list/mask/A_lips',basenA))
95 | B_maskl_img = Image.open(os.path.join('./datasets/list/mask/B_lips',basenB2))
96 | if self.opt.metric_inmask:
97 | A_maskfg_img = Image.open(os.path.join('./datasets/list/mask/A_fg',basenA))
98 |
99 | # apply image transformation
100 | transform_params_A = get_params(self.opt, A_img.size)
101 | transform_params_B = get_params(self.opt, B_img.size)
102 | A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
103 | B = get_transform(self.opt, transform_params_B, grayscale=(self.output_nc == 1))(B_img)
104 | A_mask = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_mask_img)
105 | B_mask = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_mask_img)
106 | if self.opt.use_eye_mask:
107 | A_maske = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maske_img)
108 | B_maske = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maske_img)
109 | if self.opt.use_lip_mask:
110 | A_maskl = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskl_img)
111 | B_maskl = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maskl_img)
112 | if self.opt.metric_inmask:
113 | A_maskfg = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskfg_img)
114 |
115 | item = {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_mask': A_mask, 'B_mask': B_mask}
116 | if self.opt.use_eye_mask:
117 | item['A_maske'] = A_maske
118 | item['B_maske'] = B_maske
119 | if self.opt.use_lip_mask:
120 | item['A_maskl'] = A_maskl
121 | item['B_maskl'] = B_maskl
122 | if self.opt.metric_inmask:
123 | item['A_maskfg'] = A_maskfg
124 |
125 |
126 | softmax = np.load(os.path.join(self.softmaxloc,basenB[:-4]+'.npy'))
127 | softmax = torch.Tensor(softmax)
128 | [maxv,index] = torch.max(softmax,0)
129 | B_label = index
130 | if len(self.opt.sfeature_mode) >= 8 and self.opt.sfeature_mode[-8:] == '_softmax':
131 | if self.opt.one_hot:
132 | B_style = torch.Tensor([0.,0.,0.])
133 | B_style[index] = 1.
134 | else:
135 | B_style = softmax
136 | B_style = B_style.view(3, 1, 1)
137 | B_style = B_style.repeat(1, 128, 128)
138 | #print(index, index_B, torch.mean(B_style,(1,2)))
139 | elif self.opt.sfeature_mode == 'domain':
140 | B_style = B_label
141 | item['B_style'] = B_style
142 | item['B_label'] = B_label
143 | if self.opt.isTrain and self.opt.style_loss_with_weight:
144 | item['B_style0'] = softmax
145 | if self.opt.isTrain and self.opt.metricvec:
146 | vec = softmax
147 | vec = vec.view(3, 1, 1)
148 | vec = vec.repeat(1, 299, 299)
149 | item['vec'] = vec
150 |
151 | return item
152 |
153 | def __len__(self):
154 | """Return the total number of images in the dataset.
155 |
156 |         As we have two datasets with a potentially different number of images,
157 |         we take the maximum of the two.
158 | """
159 | return max(self.A_size, self.B_size)
160 |
--------------------------------------------------------------------------------
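The style handling at the end of __getitem__ is the part specific to this repository: a 3-way style softmax (loaded from the precomputed .npy files) is either used directly or replaced by its one-hot argmax, then broadcast into a (3, 128, 128) map consumed by the style-conditioned generator as 'B_style'. A small sketch of that expansion, with made-up probabilities:

    import torch

    softmax = torch.tensor([0.1, 0.7, 0.2])       # hypothetical per-style probabilities
    B_label = softmax.argmax()                     # dominant style index

    use_one_hot = False                            # corresponds to --one_hot
    code = torch.zeros(3)
    code[B_label] = 1.
    code = code if use_one_hot else softmax

    B_style = code.view(3, 1, 1).repeat(1, 128, 128)
    print(B_label.item(), B_style.shape)           # e.g. 1 torch.Size([3, 128, 128])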
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
9 | class BaseOptions():
10 | """This class defines options used during both training and test time.
11 |
12 | It also implements several helper functions such as parsing, printing, and saving the options.
13 | It also gathers additional options defined in functions in both dataset class and model class.
14 | """
15 |
16 | def __init__(self):
17 | """Reset the class; indicates the class hasn't been initailized"""
18 | self.initialized = False
19 |
20 | def initialize(self, parser):
21 | """Define the common options that are used in both training and test."""
22 | # basic parameters
23 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
24 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
26 | parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU')
27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
28 | # model parameters
29 | parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
30 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
31 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
32 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
33 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
34 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
35 | parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
36 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
37 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
38 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
39 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
41 | # dataset parameters
42 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
43 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
44 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
45 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
46 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
47 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
48 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
49 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
50 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
51 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
52 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
53 | # additional parameters
54 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
55 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
56 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
57 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
58 | self.initialized = True
59 | return parser
60 |
61 | def gather_options(self):
62 | """Initialize our parser with basic options(only once).
63 | Add additional model-specific and dataset-specific options.
64 | These options are defined in the function
65 | in model and dataset classes.
66 | """
67 | if not self.initialized: # check if it has been initialized
68 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
69 | parser = self.initialize(parser)
70 |
71 | # get the basic options
72 | opt, _ = parser.parse_known_args()
73 |
74 | # modify model-related parser options
75 | model_name = opt.model
76 | model_option_setter = models.get_option_setter(model_name)
77 | parser = model_option_setter(parser, self.isTrain)
78 | opt, _ = parser.parse_known_args() # parse again with new defaults
79 |
80 | # modify dataset-related parser options
81 | dataset_name = opt.dataset_mode
82 | dataset_option_setter = data.get_option_setter(dataset_name)
83 | parser = dataset_option_setter(parser, self.isTrain)
84 |
85 | # save and return the parser
86 | self.parser = parser
87 | return parser.parse_args()
88 |
89 | def print_options(self, opt):
90 | """Print and save options
91 |
92 |         It will print both current options and default values (if different).
93 |         It will save options to a text file: [checkpoints_dir] / [name] / [phase]_opt.txt
94 | """
95 | message = ''
96 | message += '----------------- Options ---------------\n'
97 | for k, v in sorted(vars(opt).items()):
98 | comment = ''
99 | default = self.parser.get_default(k)
100 | if v != default:
101 | comment = '\t[default: %s]' % str(default)
102 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
103 | message += '----------------- End -------------------'
104 | print(message)
105 |
106 | # save to the disk
107 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
108 | util.mkdirs(expr_dir)
109 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
110 | with open(file_name, 'wt') as opt_file:
111 | opt_file.write(message)
112 | opt_file.write('\n')
113 |
114 | def parse(self):
115 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
116 | opt = self.gather_options()
117 | opt.isTrain = self.isTrain # train or test
118 |
119 | # process opt.suffix
120 | if opt.suffix:
121 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
122 | opt.name = opt.name + suffix
123 |
124 | self.print_options(opt)
125 |
126 | # set gpu ids
127 | str_ids = opt.gpu_ids.split(',')
128 | opt.gpu_ids = []
129 | for str_id in str_ids:
130 | id = int(str_id)
131 | if id >= 0:
132 | opt.gpu_ids.append(id)
133 | if len(opt.gpu_ids) > 0:
134 | torch.cuda.set_device(opt.gpu_ids[0])
135 |
136 |         # set gpu ids for pretrained auxiliary models
137 | str_ids = opt.gpu_ids_p.split(',')
138 | opt.gpu_ids_p = []
139 | for str_id in str_ids:
140 | id = int(str_id)
141 | if id >= 0:
142 | opt.gpu_ids_p.append(id)
143 |
144 | self.opt = opt
145 | return self.opt
146 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABCMeta, abstractmethod
5 | from . import networks
6 | import pdb
7 |
8 |
9 | class BaseModel():
10 | __metaclass__ = ABCMeta
11 | """This class is an abstract base class (ABC) for models.
12 | To create a subclass, you need to implement the following five functions:
13 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
14 |     -- <set_input>: unpack data from dataset and apply preprocessing.
15 |     -- <forward>: produce intermediate results.
16 |     -- <optimize_parameters>: calculate losses, gradients, and update network weights.
17 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
18 | """
19 |
20 | def __init__(self, opt):
21 | """Initialize the BaseModel class.
22 |
23 | Parameters:
24 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
25 |
26 | When creating your custom class, you need to implement your own initialization.
27 |         In this function, you should first call <BaseModel.__init__(self, opt)>.
28 |         Then, you need to define four lists:
29 |             -- self.loss_names (str list): specify the training losses that you want to plot and save.
30 |             -- self.model_names (str list): define the networks used in our training.
31 |             -- self.visual_names (str list): specify the images that you want to display and save.
32 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
33 | """
34 | self.opt = opt
35 | self.gpu_ids = opt.gpu_ids
36 | self.isTrain = opt.isTrain
37 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
38 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
39 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
40 | torch.backends.cudnn.benchmark = True
41 | self.loss_names = []
42 | self.model_names = []
43 | self.visual_names = []
44 | self.optimizers = []
45 | self.image_paths = []
46 | self.metric = 0 # used for learning rate policy 'plateau'
47 |
48 | @staticmethod
49 | def modify_commandline_options(parser, is_train):
50 | """Add new model-specific options, and rewrite default values for existing options.
51 |
52 | Parameters:
53 | parser -- original option parser
54 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
55 |
56 | Returns:
57 | the modified parser.
58 | """
59 | return parser
60 |
61 | @abstractmethod
62 | def set_input(self, input):
63 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
64 |
65 | Parameters:
66 | input (dict): includes the data itself and its metadata information.
67 | """
68 | pass
69 |
70 | @abstractmethod
71 | def forward(self):
72 | """Run forward pass; called by both functions and ."""
73 | pass
74 |
75 | @abstractmethod
76 | def optimize_parameters(self):
77 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
78 | pass
79 |
80 | def setup(self, opt):
81 | """Load and print networks; create schedulers
82 |
83 | Parameters:
84 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
85 | """
86 | if self.isTrain:
87 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
88 | if not self.isTrain or opt.continue_train:
89 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
90 | self.load_networks(load_suffix)
91 | self.print_networks(opt.verbose)
92 |
93 | def eval(self):
94 | """Make models eval mode during test time"""
95 | for name in self.model_names:
96 | if isinstance(name, str):
97 | net = getattr(self, 'net' + name)
98 | net.eval()
99 |
100 | def test(self):
101 | """Forward function used in test time.
102 |
103 |         This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
104 |         It also calls <compute_visuals> to produce additional visualization results.
105 | """
106 | with torch.no_grad():
107 | self.forward()
108 | self.compute_visuals()
109 |
110 | def compute_visuals(self):
111 | """Calculate additional output images for visdom and HTML visualization"""
112 | pass
113 |
114 | def get_image_paths(self):
115 | """ Return image paths that are used to load current data"""
116 | return self.image_paths
117 |
118 | def update_learning_rate(self):
119 | """Update learning rates for all the networks; called at the end of every epoch"""
120 | for scheduler in self.schedulers:
121 | if self.opt.lr_policy == 'plateau':
122 | scheduler.step(self.metric)
123 | else:
124 | scheduler.step()
125 |
126 | lr = self.optimizers[0].param_groups[0]['lr']
127 | print('learning rate = %.7f' % lr)
128 |
129 | def get_current_visuals(self):
130 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
131 | visual_ret = OrderedDict()
132 | for name in self.visual_names:
133 | if isinstance(name, str):
134 | visual_ret[name] = getattr(self, name)
135 | return visual_ret
136 |
137 | def get_current_losses(self):
138 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
139 | errors_ret = OrderedDict()
140 | for name in self.loss_names:
141 | if isinstance(name, str):
142 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
143 | return errors_ret
144 |
145 | def save_networks(self, epoch):
146 | """Save all the networks to the disk.
147 |
148 | Parameters:
149 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
150 | """
151 | for name in self.model_names:
152 | if isinstance(name, str):
153 | save_filename = '%s_net_%s.pth' % (epoch, name)
154 | save_path = os.path.join(self.save_dir, save_filename)
155 | net = getattr(self, 'net' + name)
156 |
157 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
158 | torch.save(net.module.cpu().state_dict(), save_path)
159 | net.cuda(self.gpu_ids[0])
160 | else:
161 | torch.save(net.cpu().state_dict(), save_path)
162 |
163 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
164 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
165 | key = keys[i]
166 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
167 | if module.__class__.__name__.startswith('InstanceNorm') and \
168 | (key == 'running_mean' or key == 'running_var'):
169 | if getattr(module, key) is None:
170 | state_dict.pop('.'.join(keys))
171 | if module.__class__.__name__.startswith('InstanceNorm') and \
172 | (key == 'num_batches_tracked'):
173 | state_dict.pop('.'.join(keys))
174 | else:
175 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
176 |
177 | def load_networks(self, epoch):
178 | """Load all the networks from the disk.
179 |
180 | Parameters:
181 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
182 | """
183 | for name in self.model_names:
184 | if isinstance(name, str):
185 | load_filename = '%s_net_%s.pth' % (epoch, name)
186 | load_path = os.path.join(self.save_dir, load_filename)
187 | net = getattr(self, 'net' + name)
188 | if isinstance(net, torch.nn.DataParallel):
189 | net = net.module
190 | print('loading the model from %s' % load_path)
191 | # if you are using PyTorch newer than 0.4 (e.g., built from
192 | # GitHub source), you can remove str() on self.device
193 | state_dict = torch.load(load_path, map_location=str(self.device))
194 | if hasattr(state_dict, '_metadata'):
195 | del state_dict._metadata
196 |
197 | # patch InstanceNorm checkpoints prior to 0.4
198 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
199 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
200 | net.load_state_dict(state_dict)
201 | #param1 = {}
202 | #for name, parameters in net.named_parameters():
203 | # print(name,',',parameters.size())
204 | # param1[name] = parameters.detach().cpu().numpy()
205 | #pdb.set_trace()
206 |
207 | def print_networks(self, verbose):
208 | """Print the total number of parameters in the network and (if verbose) network architecture
209 |
210 | Parameters:
211 | verbose (bool) -- if verbose: print the network architecture
212 | """
213 | print('---------- Networks initialized -------------')
214 | for name in self.model_names:
215 | if isinstance(name, str):
216 | net = getattr(self, 'net' + name)
217 | num_params = 0
218 | for param in net.parameters():
219 | num_params += param.numel()
220 | if verbose:
221 | print(net)
222 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
223 | print('-----------------------------------------------')
224 |
225 | def set_requires_grad(self, nets, requires_grad=False):
226 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
227 | Parameters:
228 | nets (network list) -- a list of networks
229 | requires_grad (bool) -- whether the networks require gradients or not
230 | """
231 | if not isinstance(nets, list):
232 | nets = [nets]
233 | for net in nets:
234 | if net is not None:
235 | for param in net.parameters():
236 | param.requires_grad = requires_grad
237 |
238 | # ===========================================================================================================
239 | def masked(self, A,mask):
240 | if self.opt.mask_type == 0:
241 | return (A/2+0.5)*mask*2-1
242 | elif self.opt.mask_type == 1:
243 | return ((A/2+0.5)*mask+1-mask)*2-1
244 | elif self.opt.mask_type == 2:
245 | return torch.cat((A, mask), 1)
246 | elif self.opt.mask_type == 3:
247 | masked = ((A/2+0.5)*mask+1-mask)*2-1
248 | return torch.cat((masked, mask), 1)
--------------------------------------------------------------------------------
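The masked() helper at the bottom is what the --mask_type option (see cycle_gan_cls_model.py below) selects between. A small sketch (not from the repo) of the four variants, assuming A is an image tensor in [-1, 1] and mask is a binary map:

    import torch

    A = torch.rand(1, 3, 4, 4) * 2 - 1                         # toy image in [-1, 1]
    mask = (torch.rand(1, 1, 4, 4) > 0.5).float()              # toy binary mask

    outside_black = (A / 2 + 0.5) * mask * 2 - 1               # mask_type 0: outside region -> -1 (black)
    outside_white = ((A / 2 + 0.5) * mask + 1 - mask) * 2 - 1  # mask_type 1: outside region -> +1 (white)
    with_channel = torch.cat((A, mask), 1)                     # mask_type 2: append mask as extra channel
    white_plus_channel = torch.cat((outside_white, mask), 1)   # mask_type 3: combine 1 and 2
    print(with_channel.shape)                                  # torch.Size([1, 4, 4, 4])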
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import sys
4 | import ntpath
5 | import time
6 | from . import util, html
7 | from subprocess import Popen, PIPE
8 | #from scipy.misc import imresize
9 | from PIL import Image
10 | import pdb
11 | #from scipy.io import savemat
12 |
13 | if sys.version_info[0] == 2:
14 | VisdomExceptionBase = Exception
15 | else:
16 | VisdomExceptionBase = ConnectionError
17 |
18 |
19 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
20 | """Save images to the disk.
21 |
22 | Parameters:
23 |         webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
24 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
25 | image_path (str) -- the string is used to create image paths
26 | aspect_ratio (float) -- the aspect ratio of saved images
27 | width (int) -- the images will be resized to width x width
28 |
29 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
30 | """
31 | image_dir = webpage.get_image_dir()
32 | short_path = ntpath.basename(image_path[0])
33 | name = os.path.splitext(short_path)[0]
34 |
35 | webpage.add_header(name)
36 | ims, txts, links = [], [], []
37 |
38 | for label, im_data in visuals.items():
39 | ## tensor to im
40 | im = util.tensor2im(im_data)
41 | #im = util.tensor2im2(im_data)
42 | ## save mat
43 | #im,imo = util.tensor2im(im_data)
44 | #matname = os.path.join(image_dir, '%s_%s.mat' % (name, label))
45 | #savemat(matname,{'imo':imo})
46 | image_name = '%s_%s.png' % (name, label)
47 | save_path = os.path.join(image_dir, image_name)
48 | h, w, _ = im.shape
49 | if aspect_ratio > 1.0:
50 | #im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
51 | im = np.array(Image.fromarray(im).resize((int(w * aspect_ratio), h), Image.BICUBIC))
52 | if aspect_ratio < 1.0:
53 | #im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
54 | im = np.array(Image.fromarray(im).resize((w, int(h / aspect_ratio)), Image.BICUBIC))
55 | util.save_image(im, save_path)
56 |
57 | ims.append(image_name)
58 | txts.append(label)
59 | links.append(image_name)
60 | webpage.add_images(ims, txts, links, width=width)
61 |
62 |
63 | class Visualizer():
64 | """This class includes several functions that can display/save images and print/save logging information.
65 |
66 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
67 | """
68 |
69 | def __init__(self, opt):
70 | """Initialize the Visualizer class
71 |
72 | Parameters:
73 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
74 | Step 1: Cache the training/test options
75 | Step 2: connect to a visdom server
76 |         Step 3: create an HTML object for saving HTML files
77 | Step 4: create a logging file to store training losses
78 | """
79 | self.opt = opt # cache the option
80 | self.display_id = opt.display_id
81 | self.use_html = opt.isTrain and not opt.no_html
82 | self.win_size = opt.display_winsize
83 | self.name = opt.name
84 | self.port = opt.display_port
85 | self.saved = False
86 |         if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
87 | import visdom
88 | self.ncols = opt.display_ncols
89 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
90 | if not self.vis.check_connection():
91 | self.create_visdom_connections()
92 |
93 |         if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
94 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
95 | self.img_dir = os.path.join(self.web_dir, 'images')
96 | print('create web directory %s...' % self.web_dir)
97 | util.mkdirs([self.web_dir, self.img_dir])
98 | # create a logging file to store training losses
99 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
100 | with open(self.log_name, "a") as log_file:
101 | now = time.strftime("%c")
102 | log_file.write('================ Training Loss (%s) ================\n' % now)
103 |
104 | def reset(self):
105 | """Reset the self.saved status"""
106 | self.saved = False
107 |
108 | def create_visdom_connections(self):
109 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
110 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
111 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
112 | print('Command: %s' % cmd)
113 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
114 |
115 | def display_current_results(self, visuals, epoch, save_result):
116 | """Display current results on visdom; save current results to an HTML file.
117 |
118 | Parameters:
119 | visuals (OrderedDict) - - dictionary of images to display or save
120 | epoch (int) - - the current epoch
121 |             save_result (bool) - - whether to save the current results to an HTML file
122 | """
123 | if self.display_id > 0: # show images in the browser using visdom
124 | ncols = self.ncols
125 | if ncols > 0: # show all the images in one visdom panel
126 | ncols = min(ncols, len(visuals))
127 | h, w = next(iter(visuals.values())).shape[:2]
128 |                 table_css = """<style>
129 |                         table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
130 |                         table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
131 |                         </style>""" % (w, h)  # create a table css
132 | # create a table of images.
133 | title = self.name
134 | label_html = ''
135 | label_html_row = ''
136 | images = []
137 | idx = 0
138 | for label, image in visuals.items():
139 | #image_numpy = util.tensor2im(image)
140 | image_numpy = util.tensor2im2(image)
141 |                     label_html_row += '<td>%s</td>' % label
142 | #pdb.set_trace()
143 | images.append(image_numpy.transpose([2, 0, 1]))
144 | idx += 1
145 | if idx % ncols == 0:
146 |                         label_html += '<tr>%s</tr>' % label_html_row
147 | label_html_row = ''
148 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
149 | while idx % ncols != 0:
150 | images.append(white_image)
151 |                     label_html_row += '<td></td>'
152 | idx += 1
153 | if label_html_row != '':
154 |                     label_html += '<tr>%s</tr>' % label_html_row
155 | try:
156 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
157 | padding=2, opts=dict(title=title + ' images'))
158 |                     label_html = '<table>%s</table>' % label_html
159 | self.vis.text(table_css + label_html, win=self.display_id + 2,
160 | opts=dict(title=title + ' labels'))
161 | except VisdomExceptionBase:
162 | self.create_visdom_connections()
163 |
164 | else: # show each image in a separate visdom panel;
165 | idx = 1
166 | try:
167 | for label, image in visuals.items():
168 | image_numpy = util.tensor2im(image)
169 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
170 | win=self.display_id + idx)
171 | idx += 1
172 | except VisdomExceptionBase:
173 | self.create_visdom_connections()
174 |
175 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
176 | self.saved = True
177 | # save images to the disk
178 | for label, image in visuals.items():
179 | image_numpy = util.tensor2im(image)
180 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
181 | util.save_image(image_numpy, img_path)
182 |
183 | # update website
184 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
185 | for n in range(epoch, 0, -1):
186 | webpage.add_header('epoch [%d]' % n)
187 | ims, txts, links = [], [], []
188 |
189 | for label, image_numpy in visuals.items():
190 |                     image_numpy = util.tensor2im(image_numpy)
191 | img_path = 'epoch%.3d_%s.png' % (n, label)
192 | ims.append(img_path)
193 | txts.append(label)
194 | links.append(img_path)
195 | webpage.add_images(ims, txts, links, width=self.win_size)
196 | webpage.save()
197 |
198 | def plot_current_losses(self, epoch, counter_ratio, losses):
199 | """display the current losses on visdom display: dictionary of error labels and values
200 |
201 | Parameters:
202 | epoch (int) -- current epoch
203 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
204 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
205 | """
206 | if not hasattr(self, 'plot_data'):
207 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
208 | self.plot_data['X'].append(epoch + counter_ratio)
209 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
210 | #X = np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1)
211 | #Y = np.array(self.plot_data['Y'])
212 | #pdb.set_trace()
213 | try:
214 | self.vis.line(
215 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
216 | Y=np.array(self.plot_data['Y']),
217 | opts={
218 | 'title': self.name + ' loss over time',
219 | 'legend': self.plot_data['legend'],
220 | 'xlabel': 'epoch',
221 | 'ylabel': 'loss'},
222 | win=self.display_id)
223 | except VisdomExceptionBase:
224 | self.create_visdom_connections()
225 |
226 | # losses: same format as |losses| of plot_current_losses
227 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
228 | """print current losses on console; also save the losses to the disk
229 |
230 | Parameters:
231 | epoch (int) -- current epoch
232 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
233 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
234 | t_comp (float) -- computational time per data point (normalized by batch_size)
235 | t_data (float) -- data loading time per data point (normalized by batch_size)
236 | """
237 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
238 | for k, v in losses.items():
239 | message += '%s: %.3f ' % (k, v)
240 |
241 | print(message) # print the message
242 | with open(self.log_name, "a") as log_file:
243 | log_file.write('%s\n' % message) # save the message
244 |
245 | # losses: same format as |losses| of plot_current_losses
246 | def print_current_losses_process(self, epoch, iters, losses, t_comp, t_data, processes):
247 | """print current losses on console; also save the losses to the disk
248 |
249 | Parameters:
250 | epoch (int) -- current epoch
251 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
252 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
253 | t_comp (float) -- computational time per data point (normalized by batch_size)
254 | t_data (float) -- data loading time per data point (normalized by batch_size)
255 | """
256 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
257 | message += '[process: %.3f, non_trunc: %.3f, trunc: %.3f] ' % (processes[0], processes[1], processes[2])
258 | for k, v in losses.items():
259 | message += '%s: %.3f ' % (k, v)
260 |
261 | print(message) # print the message
262 | with open(self.log_name, "a") as log_file:
263 | log_file.write('%s\n' % message) # save the message
264 |
--------------------------------------------------------------------------------
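When display_id is positive and no server answers, create_visdom_connections() launches one with the command it prints; the same server can be started by hand before training. A hedged example mirroring that command (8097 is assumed here to be the default --display_port):

    python -m visdom.server -p 8097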
/models/cycle_gan_cls_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | from util.image_pool import ImagePool
4 | from .base_model import BaseModel
5 | from . import networks
6 | import models.dist_model as dm # numpy==1.14.3
7 | import torchvision.transforms as transforms
8 | import os
9 | from util.util import tensor2im, tensor2im2, save_image
10 |
11 | def truncate(fake_B,a=127.5):#[-1,1]
12 | #return torch.round((fake_B+1)*a)/a-1
13 | return ((fake_B+1)*a).int().float()/a-1
14 |
15 | class CycleGANClsModel(BaseModel):
16 | """
17 | This class implements the CycleGAN model, for learning image-to-image translation without paired data.
18 |
19 | The model training requires '--dataset_mode unaligned' dataset.
20 | By default, it uses a '--netG resnet_9blocks' ResNet generator,
21 | a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
22 | and a least-square GANs objective ('--gan_mode lsgan').
23 |
24 | CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
25 | """
26 | @staticmethod
27 | def modify_commandline_options(parser, is_train=True):
28 | """Add new dataset-specific options, and rewrite default values for existing options.
29 |
30 | Parameters:
31 | parser -- original option parser
32 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
33 |
34 | Returns:
35 | the modified parser.
36 |
37 | For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
38 | A (source domain), B (target domain).
39 | Generators: G_A: A -> B; G_B: B -> A.
40 | Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
41 | Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
42 | Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
43 | Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
44 | Dropout is not used in the original CycleGAN paper.
45 | """
46 | parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout
47 | parser.set_defaults(dataset_mode='unaligned_mask_stylecls')
48 | parser.add_argument('--netda', type=str, default='basic_cls') # discriminator has two branches
49 |         parser.add_argument('--truncate', type=float, default=0.0, help='whether to truncate in the forward pass')
50 | if is_train:
51 | parser.add_argument('--lambda_A', type=float, default=5.0, help='weight for cycle loss (A -> B -> A)')
52 | parser.add_argument('--lambda_B', type=float, default=5.0, help='weight for cycle loss (B -> A -> B)')
53 | parser.add_argument('--lambda_identity', type=float, default=0, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
54 |             parser.add_argument('--perceptual_cycle', type=int, default=6, help='whether to use perceptual similarity for the cycle loss')
55 |             parser.add_argument('--use_hed', type=int, default=1, help='whether to use HED processing for the cycle loss')
56 |             parser.add_argument('--ntrunc_trunc', type=int, default=1, help='whether to use both the non-truncated and truncated versions')
57 |             parser.add_argument('--trunc_a', type=float, default=31.875, help='value to multiply by when rounding during truncation')
58 | parser.add_argument('--lambda_A_trunc', type=float, default=5.0, help='weight for cycle loss for trunc')
59 | parser.add_argument('--hed_pretrained_mode', type=str, default='./checkpoints/network-bsds500.pytorch', help='path to the pretrained hed model')
60 | parser.add_argument('--vgg_pretrained_mode', type=str, default='./checkpoints/vgg19.pth', help='path to the pretrained vgg model')
61 | parser.add_argument('--lambda_G_A_l', type=float, default=0.5, help='weight for local GAN loss in G')
62 | parser.add_argument('--style_loss_with_weight', type=int, default=0, help='whether multiply prob in style loss')
63 | parser.add_argument('--metric', action='store_true', help='whether use metric loss for fakeB')
64 | parser.add_argument('--metric_model_path', type=str, default='3/30_net_Regressor.pth', help='metric model path')
65 | parser.add_argument('--lambda_metric', type=float, default=0.5, help='weight for metric loss')
66 | parser.add_argument('--metricvec', action='store_true', help='whether use metric model with vec input')
67 | parser.add_argument('--metric_resnext', action='store_true', help='whether use resnext as metric model')
68 | parser.add_argument('--metric_resnet', action='store_true', help='whether use resnet as metric model')
69 | parser.add_argument('--metric_inception', action='store_true', help='whether use inception as metric model')# the inception of transform_input=False
70 | parser.add_argument('--metric_inmask', action='store_true', help='whether use inmask in metric model')
71 | else:
72 | parser.add_argument('--check_D', action='store_true', help='whether to check the D outputs')
73 | # for masks
74 | parser.add_argument('--use_mask', type=int, default=1, help='whether to use a mask for the face region')
75 | parser.add_argument('--use_eye_mask', type=int, default=1, help='whether to use a mask for the eye region')
76 | parser.add_argument('--use_lip_mask', type=int, default=1, help='whether to use a mask for the lip region')
77 | parser.add_argument('--mask_type', type=int, default=3, help='mask type: 0 = outside black, 1 = outside white')
78 | # for style control
79 | parser.add_argument('--style_control', type=int, default=1, help='use style_control')
80 | parser.add_argument('--sfeature_mode', type=str, default='1vgg19_softmax', help='style feature mode; by default the VGG-19 softmax output is used as the feature')
81 | parser.add_argument('--netga', type=str, default='resnet_style_9blocks', help='net arch for netG_A')
82 | parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0 (before insert style)')
83 | parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
84 | parser.add_argument('--one_hot', type=int, default=0, help='use one-hot for style code')
85 |
86 | return parser
87 |
88 | def __init__(self, opt):
89 | """Initialize the CycleGAN class.
90 |
91 | Parameters:
92 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
93 | """
94 | BaseModel.__init__(self, opt)
95 | # specify the training losses you want to print out. The training/test scripts will call BaseModel.get_current_losses
96 | self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
97 | # specify the images you want to save/display. The training/test scripts will call BaseModel.get_current_visuals
98 | visual_names_A = ['real_A', 'fake_B', 'rec_A']
99 | visual_names_B = ['real_B', 'fake_A', 'rec_B']
100 | if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
101 | visual_names_A.append('idt_B')
102 | visual_names_B.append('idt_A')
103 | if self.isTrain and self.opt.use_hed:
104 | visual_names_A.append('real_A_hed')
105 | visual_names_A.append('rec_A_hed')
106 | if self.isTrain and self.opt.ntrunc_trunc:
107 | visual_names_A.append('rec_At')
108 | if self.opt.use_hed:
109 | visual_names_A.append('rec_At_hed')
110 | self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G']
111 | if self.isTrain and self.opt.use_mask:
112 | visual_names_A.append('fake_B_l')
113 | visual_names_A.append('real_B_l')
114 | self.loss_names += ['D_A_l', 'G_A_l']
115 | if self.isTrain and self.opt.use_eye_mask:
116 | visual_names_A.append('fake_B_le')
117 | visual_names_A.append('real_B_le')
118 | self.loss_names += ['D_A_le', 'G_A_le']
119 | if self.isTrain and self.opt.use_lip_mask:
120 | visual_names_A.append('fake_B_ll')
121 | visual_names_A.append('real_B_ll')
122 | self.loss_names += ['D_A_ll', 'G_A_ll']
123 | if self.isTrain and self.opt.metric:
124 | self.loss_names += ['metric']
125 | #visual_names_B += ['fake_B2']
126 | if not self.isTrain and self.opt.use_mask:
127 | visual_names_A.append('fake_B_l')
128 | visual_names_A.append('real_B_l')
129 | if not self.isTrain and self.opt.use_eye_mask:
130 | visual_names_A.append('fake_B_le')
131 | visual_names_A.append('real_B_le')
132 | if not self.isTrain and self.opt.use_lip_mask:
133 | visual_names_A.append('fake_B_ll')
134 | visual_names_A.append('real_B_ll')
135 | self.loss_names += ['D_A_cls','G_A_cls']
136 |
137 | self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
138 | print(self.visual_names)
139 | # specify the models you want to save to the disk. The training/test scripts will call BaseModel.save_networks and BaseModel.load_networks
140 | if self.isTrain:
141 | self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
142 | if self.opt.use_mask:
143 | self.model_names += ['D_A_l']
144 | if self.opt.use_eye_mask:
145 | self.model_names += ['D_A_le']
146 | if self.opt.use_lip_mask:
147 | self.model_names += ['D_A_ll']
148 | else: # during test time, only load Gs
149 | self.model_names = ['G_A', 'G_B']
150 | if self.opt.check_D:
151 | self.model_names += ['D_A', 'D_B']
152 |
153 | # define networks (both Generators and discriminators)
154 | # The naming is different from that used in the paper.
155 | # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
156 | if not self.opt.style_control:
157 | self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
158 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
159 | else:
160 | print(opt.netga)
161 | print('model0_res', opt.model0_res)
162 | print('model1_res', opt.model1_res)
163 | self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
164 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
165 | self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
166 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
167 |
168 | #if self.isTrain: # define discriminators
169 | if self.isTrain or self.opt.check_D: # define discriminators
170 | self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda,
171 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3)
172 | self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
173 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
174 | if self.opt.use_mask:
175 | if self.opt.mask_type in [2, 3]:
176 | output_nc = opt.output_nc + 1
177 | else:
178 | output_nc = opt.output_nc
179 | self.netD_A_l = networks.define_D(output_nc, opt.ndf, opt.netD,
180 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
181 | if self.opt.use_eye_mask:
182 | if self.opt.mask_type in [2, 3]:
183 | output_nc = opt.output_nc + 1
184 | else:
185 | output_nc = opt.output_nc
186 | self.netD_A_le = networks.define_D(output_nc, opt.ndf, opt.netD,
187 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
188 | if self.opt.use_lip_mask:
189 | if self.opt.mask_type in [2, 3]:
190 | output_nc = opt.output_nc + 1
191 | else:
192 | output_nc = opt.output_nc
193 | self.netD_A_ll = networks.define_D(output_nc, opt.ndf, opt.netD,
194 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
195 |
196 | if self.isTrain and self.opt.metric:
197 | if not opt.metric_resnext and not opt.metric_resnet and not opt.metric_inception:
198 | self.metric = networks.define_inception_v3a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
199 | elif opt.metric_resnext:
200 | self.metric = networks.define_resnext101a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
201 | elif opt.metric_resnet:
202 | self.metric = networks.define_resnet101a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
203 | elif opt.metric_inception:
204 | self.metric = networks.define_inception3a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec)
205 | self.metric.eval()
206 | self.set_requires_grad(self.metric, False)
207 |
208 | if not self.isTrain and self.opt.check_D:
209 | self.criterionGAN = networks.GANLoss('lsgan').to(self.device)
210 |
211 | if self.isTrain:
212 | if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
213 | assert(opt.input_nc == opt.output_nc)
214 | self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
215 | self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
216 | # define loss functions
217 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
218 | self.criterionCycle = torch.nn.L1Loss()
219 | self.criterionIdt = torch.nn.L1Loss()
220 | self.criterionCls = torch.nn.CrossEntropyLoss()
221 | self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none')
222 | # initialize optimizers; schedulers will be automatically created by function BaseModel.setup.
223 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
224 | if not self.opt.use_mask:
225 | self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
226 | elif not self.opt.use_eye_mask:
227 | D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters())
228 | self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
229 | elif not self.opt.use_lip_mask:
230 | D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters())
231 | self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
232 | else:
233 | D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) + list(self.netD_A_ll.parameters())
234 | self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
235 | self.optimizers.append(self.optimizer_G)
236 | self.optimizers.append(self.optimizer_D)
237 |
238 | if self.opt.perceptual_cycle:
239 | if self.opt.perceptual_cycle in [1,2,3,6]:
240 | self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True)
241 | elif self.opt.perceptual_cycle in [4,5,8]:
242 | self.vgg = networks.define_VGG(init_weights_=opt.vgg_pretrained_mode, feature_mode_=True, gpu_ids_=self.gpu_ids) # using conv4_4 layer
243 |
244 | if self.opt.use_hed:
245 | #self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.gpu_ids)
246 | self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p)
247 | self.set_requires_grad(self.hed, False)
248 |
249 |
250 | def set_input(self, input):
251 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
252 |
253 | Parameters:
254 | input (dict): include the data itself and its metadata information.
255 |
256 | The option 'direction' can be used to swap domain A and domain B.
257 | """
258 | AtoB = self.opt.direction == 'AtoB'
259 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
260 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
261 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
262 | if self.opt.use_mask:
263 | self.A_mask = input['A_mask'].to(self.device)
264 | self.B_mask = input['B_mask'].to(self.device)
265 | if self.opt.use_eye_mask:
266 | self.A_maske = input['A_maske'].to(self.device)
267 | self.B_maske = input['B_maske'].to(self.device)
268 | if self.opt.use_lip_mask:
269 | self.A_maskl = input['A_maskl'].to(self.device)
270 | self.B_maskl = input['B_maskl'].to(self.device)
271 | if self.opt.style_control:
272 | self.real_B_style = input['B_style'].to(self.device)
273 | self.real_B_label = input['B_label'].to(self.device)
274 | if self.opt.isTrain and self.opt.style_loss_with_weight:
275 | self.real_B_style0 = input['B_style0'].to(self.device)
276 | self.zero = torch.zeros(self.real_B_label.size(),dtype=torch.int64).to(self.device)
277 | self.one = torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)
278 | self.two = 2*torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)
279 | if self.opt.isTrain and self.opt.metricvec:
280 | self.vec = input['vec'].to(self.device)
281 | if self.opt.isTrain and self.opt.metric_inmask:
282 | self.A_maskfg = input['A_maskfg'].to(self.device)
283 |
284 | def forward(self):
285 | """Run forward pass; called by both functions and ."""
286 | if not self.opt.style_control:
287 | self.fake_B = self.netG_A(self.real_A) # G_A(A)
288 | else:
289 | #print(torch.mean(self.real_B_style,(2,3)),'style_control')
290 | #print(self.real_B_style,'style_control')
291 | self.fake_B = self.netG_A(self.real_A, self.real_B_style)
292 | self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
293 | self.fake_A = self.netG_B(self.real_B) # G_B(B)
294 | if not self.opt.style_control:
295 | self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
296 | else:
297 | #print(torch.mean(self.real_B_style,(2,3)),'style_control')
298 | self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss
299 |
300 | if self.opt.use_mask:
301 | self.fake_B_l = self.masked(self.fake_B,self.A_mask)
302 | self.real_B_l = self.masked(self.real_B,self.B_mask)
303 | if self.opt.use_eye_mask:
304 | self.fake_B_le = self.masked(self.fake_B,self.A_maske)
305 | self.real_B_le = self.masked(self.real_B,self.B_maske)
306 | if self.opt.use_lip_mask:
307 | self.fake_B_ll = self.masked(self.fake_B,self.A_maskl)
308 | self.real_B_ll = self.masked(self.real_B,self.B_maskl)
309 |
310 | def backward_D_basic(self, netD, real, fake):
311 | """Calculate GAN loss for the discriminator
312 |
313 | Parameters:
314 | netD (network) -- the discriminator D
315 | real (tensor array) -- real images
316 | fake (tensor array) -- images generated by a generator
317 |
318 | Return the discriminator loss.
319 | We also call loss_D.backward() to calculate the gradients.
320 | """
321 | # Real
322 | pred_real = netD(real)
323 | loss_D_real = self.criterionGAN(pred_real, True)
324 | # Fake
325 | pred_fake = netD(fake.detach())
326 | loss_D_fake = self.criterionGAN(pred_fake, False)
327 | # Combined loss and calculate gradients
328 | loss_D = (loss_D_real + loss_D_fake) * 0.5
329 | loss_D.backward()
330 | return loss_D
331 |
332 | def backward_D_basic_cls(self, netD, real, fake):
333 | # Real
334 | pred_real, pred_real_cls = netD(real)
335 | loss_D_real = self.criterionGAN(pred_real, True)
336 | if not self.opt.style_loss_with_weight:
337 | loss_D_real_cls = self.criterionCls(pred_real_cls, self.real_B_label)
338 | else:
339 | loss_D_real_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_real_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_real_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_real_cls, self.two))
340 | # Fake
341 | pred_fake, pred_fake_cls = netD(fake.detach())
342 | loss_D_fake = self.criterionGAN(pred_fake, False)
343 | if not self.opt.style_loss_with_weight:
344 | loss_D_fake_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
345 | else:
346 | loss_D_fake_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
347 | # Combined loss and calculate gradients
348 | loss_D = (loss_D_real + loss_D_fake) * 0.5
349 | loss_D_cls = (loss_D_real_cls + loss_D_fake_cls) * 0.5
350 | loss_D_total = loss_D + loss_D_cls
351 | loss_D_total.backward()
352 | return loss_D, loss_D_cls
353 |
354 | def backward_D_A(self):
355 | """Calculate GAN loss for discriminator D_A"""
356 | fake_B = self.fake_B_pool.query(self.fake_B)
357 | self.loss_D_A, self.loss_D_A_cls = self.backward_D_basic_cls(self.netD_A, self.real_B, fake_B)
358 |
359 | def backward_D_A_l(self):
360 | """Calculate GAN loss for discriminator D_A_l"""
361 | fake_B = self.fake_B_pool.query(self.fake_B)
362 | self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, self.masked(self.real_B,self.B_mask), self.masked(fake_B,self.A_mask))
363 |
364 | def backward_D_A_le(self):
365 | """Calculate GAN loss for discriminator D_A_le"""
366 | fake_B = self.fake_B_pool.query(self.fake_B)
367 | self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, self.masked(self.real_B,self.B_maske), self.masked(fake_B,self.A_maske))
368 |
369 | def backward_D_A_ll(self):
370 | """Calculate GAN loss for discriminator D_A_ll"""
371 | fake_B = self.fake_B_pool.query(self.fake_B)
372 | self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, self.masked(self.real_B,self.B_maskl), self.masked(fake_B,self.A_maskl))
373 |
374 | def backward_D_B(self):
375 | """Calculate GAN loss for discriminator D_B"""
376 | fake_A = self.fake_A_pool.query(self.fake_A)
377 | self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
378 |
379 | def update_process(self, epoch):
380 | self.process = (epoch - 1) / float(self.opt.niter_decay + self.opt.niter)
381 |
382 | def backward_G(self):
383 | """Calculate the loss for generators G_A and G_B"""
384 | lambda_idt = self.opt.lambda_identity
385 | lambda_G_A_l = self.opt.lambda_G_A_l
386 | lambda_A = self.opt.lambda_A
387 | lambda_B = self.opt.lambda_B
388 | lambda_A_trunc = self.opt.lambda_A_trunc
389 | if self.opt.ntrunc_trunc:
390 | lambda_A = lambda_A * (1 - self.process * 0.9)
391 | lambda_A_trunc = lambda_A_trunc * self.process * 0.9
392 | self.lambda_As = [lambda_A, lambda_A_trunc]
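# Annotation (not in the original source): with ntrunc_trunc enabled, self.process rises
# from 0 towards 1 over training (see update_process), so lambda_A is annealed from its
# full value down to 10% of it while lambda_A_trunc ramps up from 0 to 90% of its value,
# gradually shifting the cycle constraint onto the truncated reconstruction rec_At.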
393 | # Identity loss
394 | if lambda_idt > 0:
395 | # G_A should be identity if real_B is fed: ||G_A(B) - B||
396 | self.idt_A = self.netG_A(self.real_B)
397 | self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
398 | # G_B should be identity if real_A is fed: ||G_B(A) - A||
399 | self.idt_B = self.netG_B(self.real_A)
400 | self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
401 | else:
402 | self.loss_idt_A = 0
403 | self.loss_idt_B = 0
404 |
405 | # GAN loss D_A(G_A(A))
406 | pred_fake, pred_fake_cls = self.netD_A(self.fake_B)
407 | self.loss_G_A = self.criterionGAN(pred_fake, True)
408 | if not self.opt.style_loss_with_weight:
409 | self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
410 | else:
411 | self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
412 | if self.opt.use_mask:
413 | self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l
414 | if self.opt.use_eye_mask:
415 | self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l
416 | if self.opt.use_lip_mask:
417 | self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l
418 | # GAN loss D_B(G_B(B))
419 | self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
420 | # Forward cycle loss || G_B(G_A(A)) - A||
421 | if self.opt.perceptual_cycle == 0:
422 | self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
423 | if self.opt.ntrunc_trunc:
424 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
425 | self.loss_cycle_A2 = self.criterionCycle(self.rec_At, self.real_A) * lambda_A_trunc
426 | else:
427 | if self.opt.perceptual_cycle == 1:
428 | self.loss_cycle_A = self.lpips.forward_pair(self.rec_A, self.real_A).mean() * lambda_A
429 | if self.opt.ntrunc_trunc:
430 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
431 | self.loss_cycle_A2 = self.lpips.forward_pair(self.rec_At, self.real_A).mean() * lambda_A_trunc
432 | elif self.opt.perceptual_cycle == 2:
433 | ts = self.real_A.shape
434 | rec_A = (self.rec_A[:,0,:,:]*0.299+self.rec_A[:,1,:,:]*0.587+self.rec_A[:,2,:,:]*0.114).unsqueeze(0)
435 | real_A = (self.real_A[:,0,:,:]*0.299+self.real_A[:,1,:,:]*0.587+self.real_A[:,2,:,:]*0.114).unsqueeze(0)
436 | self.loss_cycle_A = self.lpips.forward_pair(rec_A.expand(ts), real_A.expand(ts)).mean() * lambda_A
437 | elif self.opt.perceptual_cycle == 3 and self.opt.use_hed:
438 | ts = self.real_A.shape
439 | #[-1,1]->[0,1]->[-1,1]
440 | rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2
441 | real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2
442 | self.loss_cycle_A = self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean() * lambda_A
443 | self.rec_A_hed = rec_A_hed
444 | self.real_A_hed = real_A_hed
445 | print(lambda_A)
446 | elif self.opt.perceptual_cycle == 4:
447 | x_a_feature = self.vgg(self.real_A)
448 | g_a_feature = self.vgg(self.rec_A)
449 | self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A
450 | elif self.opt.perceptual_cycle == 5 and self.opt.use_hed:
451 | ts = self.real_A.shape
452 | rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2
453 | real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2
454 | x_a_feature = self.vgg(real_A_hed.expand(ts))
455 | g_a_feature = self.vgg(rec_A_hed.expand(ts))
456 | self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A
457 | self.rec_A_hed = rec_A_hed
458 | self.real_A_hed = real_A_hed
459 | elif self.opt.perceptual_cycle == 6 and self.opt.use_hed and self.opt.ntrunc_trunc:
460 | ts = self.real_A.shape
461 | gpu_p = self.opt.gpu_ids_p[0]
462 | gpu = self.opt.gpu_ids[0]
463 | rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2
464 | real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2
465 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
466 | rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2
467 | self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A
468 | self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc
469 | self.rec_A_hed = rec_A_hed
470 | self.real_A_hed = real_A_hed
471 | self.rec_At_hed = rec_At_hed
472 | elif self.opt.perceptual_cycle == 8 and self.opt.use_hed and self.opt.ntrunc_trunc:
473 | ts = self.real_A.shape
474 | rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2
475 | real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2
476 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
477 | rec_At_hed = (self.hed(self.rec_At/2+0.5)-0.5)*2
478 | x_a_feature = self.vgg(real_A_hed.expand(ts))
479 | g_a_feature = self.vgg(rec_A_hed.expand(ts))
480 | gt_a_feature = self.vgg(rec_At_hed.expand(ts))
481 | self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A
482 | self.loss_cycle_A2 = self.criterionCycle(gt_a_feature, x_a_feature.detach()) * lambda_A_trunc
483 | self.rec_A_hed = rec_A_hed
484 | self.real_A_hed = real_A_hed
485 | self.rec_At_hed = rec_At_hed
486 |
487 | # Backward cycle loss || G_A(G_B(B)) - B||
488 | self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
489 |
490 | # Metric loss (a higher metric score is better)
491 | if self.opt.metric:
492 | self.fake_B2 = self.fake_B.clone()
493 | if self.opt.metric_inmask:
494 | # background black
495 | #self.fake_B2 = (self.fake_B2/2+0.5)*self.A_maskfg*2-1
496 | # background white
497 | self.fake_B2 = ((self.fake_B2/2+0.5)*self.A_maskfg+1-self.A_maskfg)*2-1
498 | if not self.opt.metric_resnext and not self.opt.metric_resnet: # for the two versions of Inception (during training the input is [-1,1])
499 | self.fake_B2 = torch.nn.functional.interpolate(input=self.fake_B2, size=(299, 299), mode='bilinear', align_corners=False)
500 | self.fake_B2 = self.fake_B2.repeat(1,3,1,1)
501 | else: # for resnet and resnext
502 | self.fake_B2 = torch.nn.functional.interpolate(input=self.fake_B2, size=(224, 224), mode='bilinear', align_corners=False)
503 | x = self.fake_B2.repeat(1,3,1,1)
504 | # [-1,1] -> [0,1] -> mean [0.485,0.456,0.406], std [0.229,0.224,0.225]
505 | x_ch0 = (torch.unsqueeze(x[:, 0],1)*0.5+0.5-0.485)/0.229
506 | x_ch1 = (torch.unsqueeze(x[:, 1],1)*0.5+0.5-0.456)/0.224
507 | x_ch2 = (torch.unsqueeze(x[:, 2],1)*0.5+0.5-0.406)/0.225
508 | self.fake_B2 = torch.cat((x_ch0, x_ch1, x_ch2, x[:, 3:]), 1)
509 |
510 |
511 | if not self.opt.metricvec:
512 | pred = self.metric(self.fake_B2)
513 | else:
514 | pred = self.metric(torch.cat((self.fake_B2, self.vec),1))
515 | self.loss_metric = torch.mean((1-pred)) * self.opt.lambda_metric
516 |
517 | # combined loss and calculate gradients
518 | self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
519 | if getattr(self,'loss_cycle_A2',-1) != -1:
520 | self.loss_G = self.loss_G + self.loss_cycle_A2
521 | if getattr(self,'loss_G_A_l',-1) != -1:
522 | self.loss_G = self.loss_G + self.loss_G_A_l
523 | if getattr(self,'loss_G_A_le',-1) != -1:
524 | self.loss_G = self.loss_G + self.loss_G_A_le
525 | if getattr(self,'loss_G_A_ll',-1) != -1:
526 | self.loss_G = self.loss_G + self.loss_G_A_ll
527 | if getattr(self,'loss_G_A_cls',-1) != -1:
528 | self.loss_G = self.loss_G + self.loss_G_A_cls
529 | if getattr(self,'loss_metric',-1) != -1:
530 | self.loss_G = self.loss_G + self.loss_metric
531 | self.loss_G.backward()
532 |
533 | def optimize_parameters(self):
534 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
535 | # forward
536 | self.forward() # compute fake images and reconstruction images.
537 | # G_A and G_B
538 | self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
539 | if self.opt.use_mask:
540 | self.set_requires_grad([self.netD_A_l], False)
541 | if self.opt.use_eye_mask:
542 | self.set_requires_grad([self.netD_A_le], False)
543 | if self.opt.use_lip_mask:
544 | self.set_requires_grad([self.netD_A_ll], False)
545 | self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
546 | self.backward_G() # calculate gradients for G_A and G_B
547 | self.optimizer_G.step() # update G_A and G_B's weights
548 | # D_A and D_B
549 | self.set_requires_grad([self.netD_A, self.netD_B], True)
550 | if self.opt.use_mask:
551 | self.set_requires_grad([self.netD_A_l], True)
552 | if self.opt.use_eye_mask:
553 | self.set_requires_grad([self.netD_A_le], True)
554 | if self.opt.use_lip_mask:
555 | self.set_requires_grad([self.netD_A_ll], True)
556 | self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
557 | self.backward_D_A() # calculate gradients for D_A
558 | if self.opt.use_mask:
559 | self.backward_D_A_l()# calculate gradients for D_A_l
560 | if self.opt.use_eye_mask:
561 | self.backward_D_A_le()# calculate gradients for D_A_le
562 | if self.opt.use_lip_mask:
563 | self.backward_D_A_ll()# calculate gradients for D_A_ll
564 | self.backward_D_B() # calculate gradients for D_B
565 | self.optimizer_D.step() # update D_A and D_B's weights
566 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import functools
6 | from torch.optim import lr_scheduler
7 | import pdb
8 |
9 |
10 | ###############################################################################
11 | # Helper Functions
12 | ###############################################################################
13 |
14 |
15 | class Identity(nn.Module):
16 | def forward(self, x):
17 | return x
18 |
19 |
20 | def get_norm_layer(norm_type='instance'):
21 | """Return a normalization layer
22 |
23 | Parameters:
24 | norm_type (str) -- the name of the normalization layer: batch | instance | none
25 |
26 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
27 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
28 | """
29 | if norm_type == 'batch':
30 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
31 | elif norm_type == 'instance':
32 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
33 | elif norm_type == 'none':
34 | norm_layer = lambda x: Identity()
35 | else:
36 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
37 | return norm_layer
38 |
39 |
40 | def get_scheduler(optimizer, opt):
41 | """Return a learning rate scheduler
42 |
43 | Parameters:
44 | optimizer -- the optimizer of the network
45 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
46 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
47 |
48 | For 'linear', we keep the same learning rate for the first opt.niter epochs
49 | and linearly decay the rate to zero over the next opt.niter_decay epochs.
50 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
51 | See https://pytorch.org/docs/stable/optim.html for more details.
52 | """
53 | if opt.lr_policy == 'linear':
54 | def lambda_rule(epoch):
55 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
56 | return lr_l
57 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
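# Illustration (not in the original source): with, say, opt.niter=100, opt.niter_decay=100
# and opt.epoch_count=1, lambda_rule keeps the multiplier at 1.0 for roughly the first
# 100 epochs and then decays it linearly towards 0 over the next 100 epochs.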
58 | elif opt.lr_policy == 'step':
59 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
60 | elif opt.lr_policy == 'plateau':
61 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
62 | elif opt.lr_policy == 'cosine':
63 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
64 | else:
65 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
66 | return scheduler
67 |
68 |
69 | def init_weights(net, init_type='normal', init_gain=0.02):
70 | """Initialize network weights.
71 |
72 | Parameters:
73 | net (network) -- network to be initialized
74 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
75 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
76 |
77 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
78 | work better for some applications. Feel free to try yourself.
79 | """
80 | def init_func(m): # define the initialization function
81 | classname = m.__class__.__name__
82 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
83 | if init_type == 'normal':
84 | init.normal_(m.weight.data, 0.0, init_gain)
85 | elif init_type == 'xavier':
86 | init.xavier_normal_(m.weight.data, gain=init_gain)
87 | elif init_type == 'kaiming':
88 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
89 | elif init_type == 'orthogonal':
90 | init.orthogonal_(m.weight.data, gain=init_gain)
91 | else:
92 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
93 | if hasattr(m, 'bias') and m.bias is not None:
94 | init.constant_(m.bias.data, 0.0)
95 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
96 | init.normal_(m.weight.data, 1.0, init_gain)
97 | init.constant_(m.bias.data, 0.0)
98 |
99 | print('initialize network with %s' % init_type)
100 | net.apply(init_func) # apply the initialization function
101 |
102 |
103 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
104 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
105 | Parameters:
106 | net (network) -- the network to be initialized
107 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
108 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
109 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
110 |
111 | Return an initialized network.
112 | """
113 | if len(gpu_ids) > 0:
114 | assert(torch.cuda.is_available())
115 | net.to(gpu_ids[0])
116 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
117 | init_weights(net, init_type, init_gain=init_gain)
118 | return net
119 |
120 |
121 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3):
122 | """Create a generator
123 |
124 | Parameters:
125 | input_nc (int) -- the number of channels in input images
126 | output_nc (int) -- the number of channels in output images
127 | ngf (int) -- the number of filters in the last conv layer
128 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
129 | norm (str) -- the name of normalization layers used in the network: batch | instance | none
130 | use_dropout (bool) -- if use dropout layers.
131 | init_type (str) -- the name of our initialization method.
132 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
133 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
134 |
135 | Returns a generator
136 |
137 | Our current implementation provides two types of generators:
138 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
139 | The original U-Net paper: https://arxiv.org/abs/1505.04597
140 |
141 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
142 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
143 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
144 |
145 |
146 | The generator has been initialized by init_net. It uses ReLU for non-linearity.
147 | """
148 | net = None
149 | norm_layer = get_norm_layer(norm_type=norm)
150 |
151 | if netG == 'resnet_9blocks':
152 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
153 | elif netG == 'resnet_8blocks':
154 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=8)
155 | elif netG == 'resnet_style_9blocks':
156 | net = ResnetStyleGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, extra_channel=extra_channel)
157 | elif netG == 'resnet_style2_9blocks':
158 | net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
159 | elif netG == 'resnet_style2_8blocks':
160 | net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=8, model0_res=model0_res, extra_channel=extra_channel)
161 | elif netG == 'resnet_style2_10blocks':
162 | net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=10, model0_res=model0_res, extra_channel=extra_channel)
163 | elif netG == 'resnet_style3decoder_9blocks':
164 | net = ResnetStyle3DecoderGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res)
165 | elif netG == 'resnet_style2mc_9blocks':
166 | net = ResnetStyle2MCGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
167 | elif netG == 'resnet_style2mc2_9blocks':
168 | net = ResnetStyle2MC2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, model1_res=model1_res, extra_channel=extra_channel)
169 | elif netG == 'resnet_6blocks':
170 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
171 | elif netG == 'unet_128':
172 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
173 | elif netG == 'unet_256':
174 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
175 | else:
176 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
177 | return init_net(net, init_type, init_gain, gpu_ids)
178 |
179 |
180 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3):
181 | """Create a discriminator
182 |
183 | Parameters:
184 | input_nc (int) -- the number of channels in input images
185 | ndf (int) -- the number of filters in the first conv layer
186 | netD (str) -- the architecture's name: basic | n_layers | pixel
187 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
188 | norm (str) -- the type of normalization layers used in the network.
189 | init_type (str) -- the name of the initialization method.
190 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
191 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
192 |
193 | Returns a discriminator
194 |
195 | Our current implementation provides three types of discriminators:
196 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
197 | It can classify whether 70×70 overlapping patches are real or fake.
198 | Such a patch-level discriminator architecture has fewer parameters
199 | than a full-image discriminator and can work on arbitrarily-sized images
200 | in a fully convolutional fashion.
201 |
202 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
203 | with the parameter n_layers_D (default=3 as used in [basic] (PatchGAN).)
204 |
205 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
206 | It encourages greater color diversity but has no effect on spatial statistics.
207 |
208 | The discriminator has been initialized by init_net. It uses Leaky ReLU for non-linearity.
209 | """
210 | net = None
211 | norm_layer = get_norm_layer(norm_type=norm)
212 |
213 | if netD == 'basic': # default PatchGAN classifier
214 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
215 | elif netD == 'basic_cls':
216 | net = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=3, norm_layer=norm_layer)
217 | elif netD == 'n_layers': # more options
218 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
219 | elif netD == 'pixel': # classify if each pixel is real or fake
220 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
221 | else:
222 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
223 | return init_net(net, init_type, init_gain, gpu_ids)
224 |
225 |
226 | def define_HED(init_weights_, gpu_ids_=[]):
227 | net = HED()
228 |
229 | if len(gpu_ids_) > 0:
230 | assert(torch.cuda.is_available())
231 | net.to(gpu_ids_[0])
232 | net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs
233 |
234 | if not init_weights_ == None:
235 | device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
236 | print('Loading model from: %s'%init_weights_)
237 | state_dict = torch.load(init_weights_, map_location=str(device))
238 | if isinstance(net, torch.nn.DataParallel):
239 | net.module.load_state_dict(state_dict)
240 | else:
241 | net.load_state_dict(state_dict)
242 | print('load the weights successfully')
243 |
244 | return net
245 |
246 | def define_VGG(init_weights_, feature_mode_, batch_norm_=False, num_classes_=1000, gpu_ids_=[]):
247 | net = VGG19(init_weights=init_weights_, feature_mode=feature_mode_, batch_norm=batch_norm_, num_classes=num_classes_)
248 | # set the GPU
249 | if len(gpu_ids_) > 0:
250 | assert(torch.cuda.is_available())
251 | net.cuda(gpu_ids_[0])
252 | net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs
253 |
254 | if not init_weights_ == None:
255 | device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
256 | print('Loading model from: %s'%init_weights_)
257 | state_dict = torch.load(init_weights_, map_location=str(device))
258 | if isinstance(net, torch.nn.DataParallel):
259 | net.module.load_state_dict(state_dict)
260 | else:
261 | net.load_state_dict(state_dict)
262 | print('load the weights successfully')
263 | return net
264 |
265 | ###################################################################################################################
266 | from torchvision.models import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
267 | def define_vgg11_bn(gpu_ids_=[],vec=0):
268 | net = vgg11_bn(pretrained=True)
269 | net.classifier[6] = nn.Linear(4096, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
270 | if len(gpu_ids_) > 0:
271 | assert(torch.cuda.is_available())
272 | net.cuda(gpu_ids_[0])
273 | net = torch.nn.DataParallel(net, gpu_ids_)
274 | return net
275 | def define_vgg19_bn(gpu_ids_=[],vec=0):
276 | net = vgg19_bn(pretrained=True)
277 | net.classifier[6] = nn.Linear(4096, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
278 | if len(gpu_ids_) > 0:
279 | assert(torch.cuda.is_available())
280 | net.cuda(gpu_ids_[0])
281 | net = torch.nn.DataParallel(net, gpu_ids_)
282 | return net
283 | def define_vgg19(gpu_ids_=[],vec=0):
284 | net = vgg19(pretrained=True)
285 | net.classifier[6] = nn.Linear(4096, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
286 | if len(gpu_ids_) > 0:
287 | assert(torch.cuda.is_available())
288 | net.cuda(gpu_ids_[0])
289 | net = torch.nn.DataParallel(net, gpu_ids_)
290 | return net
291 | ###################################################################################################################
292 | from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
293 | def define_resnet101(gpu_ids_=[],vec=0):
294 | net = resnet101(pretrained=True)
295 | num_ftrs = net.fc.in_features
296 | net.fc = nn.Linear(num_ftrs, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
297 | if len(gpu_ids_) > 0:
298 | assert(torch.cuda.is_available())
299 | net.cuda(gpu_ids_[0])
300 | net = torch.nn.DataParallel(net, gpu_ids_)
301 | return net
302 | def define_resnet101a(init_weights_,gpu_ids_=[],vec=0):
303 | net = resnet101(pretrained=True)
304 | num_ftrs = net.fc.in_features
305 | net.fc = nn.Linear(num_ftrs, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
306 | if not init_weights_ == None:
307 | print('Loading model from: %s'%init_weights_)
308 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
309 | if isinstance(net, torch.nn.DataParallel):
310 | net.module.load_state_dict(state_dict)
311 | else:
312 | net.load_state_dict(state_dict)
313 | print('load the weights successfully')
314 | if len(gpu_ids_) > 0:
315 | assert(torch.cuda.is_available())
316 | net.cuda(gpu_ids_[0])
317 | net = torch.nn.DataParallel(net, gpu_ids_)
318 | return net
319 | ###################################################################################################################
320 | import pretrainedmodels.models.resnext as resnext
321 | def define_resnext101(gpu_ids_=[],vec=0):
322 | net = resnext.resnext101_64x4d(num_classes=1000,pretrained='imagenet')
323 | net.last_linear = nn.Linear(2048, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
324 | if len(gpu_ids_) > 0:
325 | assert(torch.cuda.is_available())
326 | net.cuda(gpu_ids_[0])
327 | net = torch.nn.DataParallel(net, gpu_ids_)
328 | return net
329 | def define_resnext101a(init_weights_,gpu_ids_=[],vec=0):
330 | net = resnext.resnext101_64x4d(num_classes=1000,pretrained='imagenet')
331 | net.last_linear = nn.Linear(2048, 1) # LSGAN needs no sigmoid; LSGAN uses nn.MSELoss()
332 | if not init_weights_ == None:
333 | print('Loading model from: %s'%init_weights_)
334 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
335 | if isinstance(net, torch.nn.DataParallel):
336 | net.module.load_state_dict(state_dict)
337 | else:
338 | net.load_state_dict(state_dict)
339 | print('load the weights successfully')
340 | if len(gpu_ids_) > 0:
341 | assert(torch.cuda.is_available())
342 | net.cuda(gpu_ids_[0])
343 | net = torch.nn.DataParallel(net, gpu_ids_)
344 | return net
345 | ###################################################################################################################
346 | from torchvision.models import Inception3, inception_v3
347 | def define_inception3(gpu_ids_=[],vec=0):
348 | net = inception_v3(pretrained=True)
349 | net.transform_input = False # assume [-1,1] input
350 | net.fc = nn.Linear(2048, 1)
351 | net.aux_logits = False
352 | if len(gpu_ids_) > 0:
353 | assert(torch.cuda.is_available())
354 | net.cuda(gpu_ids_[0])
355 | net = torch.nn.DataParallel(net, gpu_ids_)
356 | return net
357 | def define_inception3a(init_weights_,gpu_ids_=[],vec=0):
358 | net = inception_v3(pretrained=True)
359 | net.transform_input = False # assume [-1,1] input
360 | net.fc = nn.Linear(2048, 1)
361 | net.aux_logits = False
362 | if not init_weights_ == None:
363 | print('Loading model from: ', init_weights_)
364 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
365 | if isinstance(net, torch.nn.DataParallel):
366 | net.module.load_state_dict(state_dict)
367 | else:
368 | net.load_state_dict(state_dict)
369 | print('load the weights successfully')
370 | if len(gpu_ids_) > 0:
371 | assert(torch.cuda.is_available())
372 | net.cuda(gpu_ids_[0])
373 | net = torch.nn.DataParallel(net, gpu_ids_)
374 | return net
375 | ###################################################################################################################
376 | from torchvision.models.inception import BasicConv2d
377 | def define_inception_v3(init_weights_,gpu_ids_=[],vec=0):
378 |
379 | ## pretrained = True
380 | kwargs = {}
381 | if 'transform_input' not in kwargs:
382 | kwargs['transform_input'] = True
383 | if 'aux_logits' in kwargs:
384 | original_aux_logits = kwargs['aux_logits']
385 | kwargs['aux_logits'] = True
386 | else:
387 | original_aux_logits = True
388 | print(kwargs)
389 | net = Inception3(**kwargs)
390 |
391 | if not init_weights_ == None:
392 | print('Loading model from: %s'%init_weights_)
393 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
394 | if isinstance(net, torch.nn.DataParallel):
395 | net.module.load_state_dict(state_dict)
396 | else:
397 | net.load_state_dict(state_dict)
398 | print('load the weights successfully')
399 |
400 | if not original_aux_logits:
401 | net.aux_logits = False
402 | del net.AuxLogits
403 |
404 | net.fc = nn.Linear(2048, 1)
405 | if vec == 1:
406 | net.Conv2d_1a_3x3 = BasicConv2d(6, 32, kernel_size=3, stride=2)
407 | net.aux_logits = False
408 |
409 | if len(gpu_ids_) > 0:
410 | assert(torch.cuda.is_available())
411 | net.cuda(gpu_ids_[0])
412 | net = torch.nn.DataParallel(net, gpu_ids_)
413 |
414 | return net
415 |
416 | def define_inception_v3a(init_weights_,gpu_ids_=[],vec=0):
417 |
418 | kwargs = {}
419 | if 'transform_input' not in kwargs:
420 | kwargs['transform_input'] = True
421 | if 'aux_logits' in kwargs:
422 | original_aux_logits = kwargs['aux_logits']
423 | kwargs['aux_logits'] = True
424 | else:
425 | original_aux_logits = True
426 | print(kwargs)
427 | net = Inception3(**kwargs)
428 |
429 | if not original_aux_logits:
430 | net.aux_logits = False
431 | del net.AuxLogits
432 |
433 | net.fc = nn.Linear(2048, 1)
434 | if vec == 1:
435 | net.Conv2d_1a_3x3 = BasicConv2d(6, 32, kernel_size=3, stride=2)
436 | net.aux_logits = False
437 |
438 | if not init_weights_ == None:
439 | print('Loading model from: %s'%init_weights_)
440 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
441 | if isinstance(net, torch.nn.DataParallel):
442 | net.module.load_state_dict(state_dict)
443 | else:
444 | net.load_state_dict(state_dict)
445 | print('load the weights successfully')
446 |
447 | if len(gpu_ids_) > 0:
448 | assert(torch.cuda.is_available())
449 | net.cuda(gpu_ids_[0])
450 | net = torch.nn.DataParallel(net, gpu_ids_)
451 |
452 | return net
453 |
454 | def define_inception_ori(init_weights_,transform_input=False,gpu_ids_=[]):
455 |
456 | ## pretrained = True
457 | kwargs = {}
458 | kwargs['transform_input'] = transform_input
459 |
460 | if 'aux_logits' in kwargs:
461 | original_aux_logits = kwargs['aux_logits']
462 | kwargs['aux_logits'] = True
463 | else:
464 | original_aux_logits = True
465 | print(kwargs)
466 | net = Inception3(**kwargs)
467 |
468 |
469 | if not init_weights_ == None:
470 | print('Loading model from: %s'%init_weights_)
471 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu')))
472 | if isinstance(net, torch.nn.DataParallel):
473 | net.module.load_state_dict(state_dict)
474 | else:
475 | net.load_state_dict(state_dict)
476 | print('load the weights successfully')
477 | #for e in list(net.modules()):
478 | # print(e)
479 |
480 | if not original_aux_logits:
481 | net.aux_logits = False
482 | del net.AuxLogits
483 |
484 |
485 | if len(gpu_ids_) > 0:
486 | assert(torch.cuda.is_available())
487 | net.cuda(gpu_ids_[0])
488 |
489 | return net
490 | ###################################################################################################################
491 |
492 | def define_DT(init_weights_, input_nc_, output_nc_, ngf_, netG_, norm_, use_dropout_, init_type_, init_gain_, gpu_ids_):
493 | net = define_G(input_nc_, output_nc_, ngf_, netG_, norm_, use_dropout_, init_type_, init_gain_, gpu_ids_)
494 |
495 | if not init_weights_ == None:
496 | device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
497 | print('Loading model from: %s'%init_weights_)
498 | state_dict = torch.load(init_weights_, map_location=str(device))
499 | if isinstance(net, torch.nn.DataParallel):
500 | net.module.load_state_dict(state_dict)
501 | else:
502 | net.load_state_dict(state_dict)
503 | print('load the weights successfully')
504 | return net
505 |
506 | def define_C(input_nc, classes, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], h=512, w=512, nnG=3, dim=4096):
507 | net = None
508 | norm_layer = get_norm_layer(norm_type=norm)
509 | if netG == 'classifier':
510 | net = Classifier(input_nc, classes, ngf, num_downs=nnG, norm_layer=norm_layer, use_dropout=use_dropout, h=h, w=w, dim=dim)
511 | elif netG == 'vgg':
512 | net = VGG19(init_weights=None, feature_mode=False, batch_norm=True, num_classes=classes)
513 | return init_net(net, init_type, init_gain, gpu_ids)
514 |
515 | ##############################################################################
516 | # Classes
517 | ##############################################################################
518 | class GANLoss(nn.Module):
519 | """Define different GAN objectives.
520 |
521 | The GANLoss class abstracts away the need to create the target label tensor
522 | that has the same size as the input.
523 | """
524 |
525 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
526 | """ Initialize the GANLoss class.
527 |
528 | Parameters:
529 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
530 | target_real_label (float) - - label for a real image
531 | target_fake_label (float) - - label for a fake image
532 |
533 | Note: Do not use sigmoid as the last layer of Discriminator.
534 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
535 | """
536 | super(GANLoss, self).__init__()
537 | self.register_buffer('real_label', torch.tensor(target_real_label))
538 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
539 | self.gan_mode = gan_mode
540 | if gan_mode == 'lsgan':#cyclegan
541 | self.loss = nn.MSELoss()
542 | elif gan_mode == 'vanilla':
543 | self.loss = nn.BCEWithLogitsLoss()
544 | elif gan_mode in ['wgangp']:
545 | self.loss = None
546 | else:
547 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
548 |
549 | def get_target_tensor(self, prediction, target_is_real):
550 | """Create label tensors with the same size as the input.
551 |
552 | Parameters:
553 | prediction (tensor) - - typically the prediction from a discriminator
554 | target_is_real (bool) - - if the ground truth label is for real images or fake images
555 |
556 | Returns:
557 | A label tensor filled with ground truth label, and with the size of the input
558 | """
559 |
560 | if target_is_real:
561 | target_tensor = self.real_label
562 | else:
563 | target_tensor = self.fake_label
564 | return target_tensor.expand_as(prediction)
565 |
566 | def __call__(self, prediction, target_is_real):
567 | """Calculate loss given Discriminator's output and grount truth labels.
568 |
569 | Parameters:
570 | prediction (tensor) - - typically the prediction output from a discriminator
571 | target_is_real (bool) - - if the ground truth label is for real images or fake images
572 |
573 | Returns:
574 | the calculated loss.
575 | """
576 | if self.gan_mode in ['lsgan', 'vanilla']:
577 | target_tensor = self.get_target_tensor(prediction, target_is_real)
578 | loss = self.loss(prediction, target_tensor)
579 | elif self.gan_mode == 'wgangp':
580 | if target_is_real:
581 | loss = -prediction.mean()
582 | else:
583 | loss = prediction.mean()
584 | return loss
585 |
586 |
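# Usage sketch (illustrative, not in the original source):
#   criterion = GANLoss('lsgan').to(device)
#   loss_D_real = criterion(netD(real), True)            # MSE against a tensor of 1.0s
#   loss_D_fake = criterion(netD(fake.detach()), False)  # MSE against a tensor of 0.0s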
587 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
588 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
589 |
590 | Arguments:
591 | netD (network) -- discriminator network
592 | real_data (tensor array) -- real images
593 | fake_data (tensor array) -- generated images from the generator
594 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
595 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
596 | constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
597 | lambda_gp (float) -- weight for this loss
598 |
599 | Returns the gradient penalty loss
600 | """
601 | if lambda_gp > 0.0:
602 | if type == 'real': # either use real images, fake images, or a linear interpolation of two.
603 | interpolatesv = real_data
604 | elif type == 'fake':
605 | interpolatesv = fake_data
606 | elif type == 'mixed':
607 | alpha = torch.rand(real_data.shape[0], 1, device=device)
608 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
609 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
610 | else:
611 | raise NotImplementedError('{} not implemented'.format(type))
612 | interpolatesv.requires_grad_(True)
613 | disc_interpolates = netD(interpolatesv)
614 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
615 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
616 | create_graph=True, retain_graph=True, only_inputs=True)
617 | gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
618 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
619 | return gradient_penalty, gradients
620 | else:
621 | return 0.0, None
622 |
623 |
624 | class ResnetGenerator(nn.Module):
625 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
626 |
627 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
628 | """
629 |
630 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
631 | """Construct a Resnet-based generator
632 |
633 | Parameters:
634 | input_nc (int) -- the number of channels in input images
635 | output_nc (int) -- the number of channels in output images
636 | ngf (int) -- the number of filters in the last conv layer
637 | norm_layer -- normalization layer
638 | use_dropout (bool) -- if use dropout layers
639 | n_blocks (int) -- the number of ResNet blocks
640 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
641 | """
642 | assert(n_blocks >= 0)
643 | super(ResnetGenerator, self).__init__()
644 | if type(norm_layer) == functools.partial:
645 | use_bias = norm_layer.func == nn.InstanceNorm2d
646 | else:
647 | use_bias = norm_layer == nn.InstanceNorm2d
648 |
649 | model = [nn.ReflectionPad2d(3),
650 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
651 | norm_layer(ngf),
652 | nn.ReLU(True)]
653 |
654 | n_downsampling = 2
655 | for i in range(n_downsampling): # add downsampling layers
656 | mult = 2 ** i
657 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
658 | norm_layer(ngf * mult * 2),
659 | nn.ReLU(True)]
660 |
661 | mult = 2 ** n_downsampling
662 | for i in range(n_blocks): # add ResNet blocks
663 |
664 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
665 |
666 | for i in range(n_downsampling): # add upsampling layers
667 | mult = 2 ** (n_downsampling - i)
668 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
669 | kernel_size=3, stride=2,
670 | padding=1, output_padding=1,
671 | bias=use_bias),
672 | norm_layer(int(ngf * mult / 2)),
673 | nn.ReLU(True)]
674 | model += [nn.ReflectionPad2d(3)]
675 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
676 | model += [nn.Tanh()]
677 |
678 | self.model = nn.Sequential(*model)
679 |
680 | def forward(self, input, feature_mode = False):
681 | """Standard forward"""
682 | if not feature_mode:
683 | return self.model(input)
684 | else:
685 | module_list = list(self.model.modules())
686 | x = input.clone()
687 | indexes = list(range(1,11))+[11,20,29,38,47,56,65,74,83]+list(range(92,101))
688 | for i in indexes:
689 | x = module_list[i](x)
690 | if i == 3:
691 | x1 = x.clone()
692 | elif i == 6:
693 | x2 = x.clone()
694 | elif i == 9:
695 | x3 = x.clone()
696 | elif i == 47:
697 | y7 = x.clone()
698 | elif i == 83:
699 | y4 = x.clone()
700 | elif i == 93:
701 | y3 = x.clone()
702 | elif i == 96:
703 | y2 = x.clone()
704 | #y = self.model(input)
705 | #pdb.set_trace()
706 | return x,x1,x2,x3,y4,y3,y2,y7
707 |
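# --- Illustrative shape check (not part of the original file): ResnetGenerator is
# fully convolutional, so any input whose height/width are divisible by 4 keeps its
# spatial size. The 256x256 size and channel counts below are example values; note
# that the hard-coded module indexes used by feature_mode above assume the
# 9-block, reflect-padded, no-dropout configuration.
def _resnet_generator_demo():
    import torch
    net = ResnetGenerator(input_nc=3, output_nc=1, ngf=64, n_blocks=9).eval()
    x = torch.randn(1, 3, 256, 256)
    y = net(x)
    assert y.shape == (1, 1, 256, 256)   # same spatial size, output_nc channels, Tanh range (-1, 1)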
708 | class ResnetStyleGenerator(nn.Module):
709 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
710 |
711 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
712 | """
713 |
714 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
715 | """Construct a Resnet-based generator
716 |
717 | Parameters:
718 | input_nc (int) -- the number of channels in input images
719 | output_nc (int) -- the number of channels in output images
720 | ngf (int) -- the number of filters in the last conv layer
721 | norm_layer -- normalization layer
722 | use_dropout (bool) -- if use dropout layers
723 | n_blocks (int) -- the number of ResNet blocks
724 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
725 | """
726 | assert(n_blocks >= 0)
727 | super(ResnetStyleGenerator, self).__init__()
728 | if type(norm_layer) == functools.partial:
729 | use_bias = norm_layer.func == nn.InstanceNorm2d
730 | else:
731 | use_bias = norm_layer == nn.InstanceNorm2d
732 |
733 | model0 = [nn.ReflectionPad2d(3),
734 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
735 | norm_layer(ngf),
736 | nn.ReLU(True)]
737 |
738 | n_downsampling = 2
739 | for i in range(n_downsampling): # add downsampling layers
740 | mult = 2 ** i
741 | model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
742 | norm_layer(ngf * mult * 2),
743 | nn.ReLU(True)]
744 |
745 | mult = 2 ** n_downsampling
746 | model1 = [nn.Conv2d(3, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
747 | norm_layer(ngf * mult),
748 | nn.ReLU(True)]
749 |
750 | model = []
751 | model += [nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
752 | norm_layer(ngf * mult),
753 | nn.ReLU(True)]
754 | for i in range(n_blocks): # add ResNet blocks
755 |
756 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
757 |
758 | for i in range(n_downsampling): # add upsampling layers
759 | mult = 2 ** (n_downsampling - i)
760 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
761 | kernel_size=3, stride=2,
762 | padding=1, output_padding=1,
763 | bias=use_bias),
764 | norm_layer(int(ngf * mult / 2)),
765 | nn.ReLU(True)]
766 | model += [nn.ReflectionPad2d(3)]
767 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
768 | model += [nn.Tanh()]
769 |
770 | self.model0 = nn.Sequential(*model0)
771 | self.model1 = nn.Sequential(*model1)
772 | self.model = nn.Sequential(*model)
773 |
774 | def forward(self, input1, input2):
775 | """Standard forward"""
776 | f1 = self.model0(input1)
777 | f2 = self.model1(input2)
778 | #pdb.set_trace()
779 | f1 = torch.cat((f1,f2), 1)
780 | return self.model(f1)
781 |
782 |
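# --- Illustrative shape check (not part of the original file): ResnetStyleGenerator
# expects a second 3-channel input at the bottleneck resolution (H/4 x W/4), which is
# concatenated with the encoded content features. Sizes below are example values.
def _resnet_style_generator_demo():
    import torch
    net = ResnetStyleGenerator(input_nc=3, output_nc=1, ngf=64, n_blocks=9).eval()
    content = torch.randn(1, 3, 256, 256)
    style = torch.randn(1, 3, 64, 64)     # must match the 256/4 = 64 bottleneck resolution
    y = net(content, style)
    assert y.shape == (1, 1, 256, 256)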
783 | class ResnetStyle2Generator(nn.Module):
784 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
785 |
786 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
787 | """
788 |
789 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
790 | """Construct a Resnet-based generator
791 |
792 | Parameters:
793 | input_nc (int) -- the number of channels in input images
794 | output_nc (int) -- the number of channels in output images
795 | ngf (int) -- the number of filters in the last conv layer
796 | norm_layer -- normalization layer
797 | use_dropout (bool) -- if use dropout layers
798 | n_blocks (int) -- the number of ResNet blocks
799 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
800 | """
801 | assert(n_blocks >= 0)
802 | super(ResnetStyle2Generator, self).__init__()
803 | self.n_blocks = n_blocks
804 | if type(norm_layer) == functools.partial:
805 | use_bias = norm_layer.func == nn.InstanceNorm2d
806 | else:
807 | use_bias = norm_layer == nn.InstanceNorm2d
808 |
809 | model0 = [nn.ReflectionPad2d(3),
810 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
811 | norm_layer(ngf),
812 | nn.ReLU(True)]
813 |
814 | n_downsampling = 2
815 | for i in range(n_downsampling): # add downsampling layers
816 | mult = 2 ** i
817 | model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
818 | norm_layer(ngf * mult * 2),
819 | nn.ReLU(True)]
820 |
821 | mult = 2 ** n_downsampling
822 | for i in range(model0_res): # add ResNet blocks
823 | model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
824 |
825 | model = []
826 | model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
827 | norm_layer(ngf * mult),
828 | nn.ReLU(True)]
829 |
830 | for i in range(n_blocks-model0_res): # add ResNet blocks
831 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
832 |
833 | for i in range(n_downsampling): # add upsampling layers
834 | mult = 2 ** (n_downsampling - i)
835 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
836 | kernel_size=3, stride=2,
837 | padding=1, output_padding=1,
838 | bias=use_bias),
839 | norm_layer(int(ngf * mult / 2)),
840 | nn.ReLU(True)]
841 | model += [nn.ReflectionPad2d(3)]
842 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
843 | model += [nn.Tanh()]
844 |
845 | self.model0 = nn.Sequential(*model0)
846 | self.model = nn.Sequential(*model)
847 | #print(list(self.modules()))
848 |
849 | def forward(self, input1, input2, feature_mode=False, ablate_res=-1):
850 | """Standard forward"""
851 | if not feature_mode:
852 | if ablate_res == -1:
853 | f1 = self.model0(input1)
854 | y1 = torch.cat([f1, input2], 1)
855 | return self.model(y1)
856 | else:
857 | f1 = self.model0(input1)
858 | y = torch.cat([f1, input2], 1)
859 | module_list = list(self.model.modules())
860 | for i in range(1, 4):#merge module
861 | y = module_list[i](y)
862 | for k in range(self.n_blocks):#resblocks
863 | if k+1 == ablate_res:
864 | print('skip resblock'+str(k+1))
865 | continue
866 | y1 = y.clone()
867 | for i in range(6+9*k,13+9*k):
868 | y = module_list[i](y)
869 | y = y1 + y
870 | for i in range(4+9*self.n_blocks,13+9*self.n_blocks):#up convs
871 | y = module_list[i](y)
872 | return y
873 | else:
874 | module_list0 = list(self.model0.modules())
875 | x = input1.clone()
876 | for i in range(1,11):
877 | x = module_list0[i](x)
878 | if i == 3:
879 | x1 = x.clone()#[1,64,512,512]
880 | elif i == 6:
881 | x2 = x.clone()#[1,128,256,256]
882 | elif i == 9:
883 | x3 = x.clone()#[1,256,128,128]
884 | #f1 = self.model0(input1)#[1,256,128,128]
885 | #pdb.set_trace()
886 | y1 = torch.cat([x, input2], 1)#[1,259,128,128]
887 | module_list = list(self.model.modules())
888 | indexes = list(range(1,4))+[4,13,22,31,40,49,58,67,76]+list(range(85,94))
889 | y = y1.clone()
890 | for i in indexes:
891 | y = module_list[i](y)
892 | if i == 76:
893 | y4 = y.clone()#[1,256,128,128]
894 | elif i == 86:
895 | y3 = y.clone()#[1,128,256,256]
896 | elif i == 89:
897 | y2 = y.clone()#[1,64,512,512]
898 | elif i == 40:
899 | y7 = y.clone()
900 | #out = self.model(y1)
901 | #pdb.set_trace()
902 | return y,x1,x2,x3,y4,y3,y2,y7
903 |
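# --- Illustrative usage sketch (not part of the original file): ResnetStyle2Generator
# takes an extra_channel-channel style map at the bottleneck resolution. Broadcasting a
# one-hot style vector spatially, as below, is one way to build such a map (an
# assumption for this example); all sizes are example values.
def _resnet_style2_generator_demo():
    import torch
    net = ResnetStyle2Generator(input_nc=3, output_nc=1, ngf=64, n_blocks=9,
                                extra_channel=3, model0_res=0).eval()
    content = torch.randn(1, 3, 256, 256)
    onehot = torch.tensor([1.0, 0.0, 0.0])                     # pick style 0 of 3
    style_map = onehot.view(1, 3, 1, 1).expand(1, 3, 64, 64)   # bottleneck is 256/4 = 64
    y = net(content, style_map)
    assert y.shape == (1, 1, 256, 256)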
904 | class ResnetStyle3DecoderGenerator(nn.Module):
905 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
906 |
907 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
908 | """
909 |
910 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', model0_res=0):
911 | """Construct a Resnet-based generator
912 |
913 | Parameters:
914 | input_nc (int) -- the number of channels in input images
915 | output_nc (int) -- the number of channels in output images
916 | ngf (int) -- the number of filters in the last conv layer
917 | norm_layer -- normalization layer
918 | use_dropout (bool) -- if use dropout layers
919 | n_blocks (int) -- the number of ResNet blocks
920 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
921 | """
922 | assert(n_blocks >= 0)
923 | super(ResnetStyle3DecoderGenerator, self).__init__()
924 | if type(norm_layer) == functools.partial:
925 | use_bias = norm_layer.func == nn.InstanceNorm2d
926 | else:
927 | use_bias = norm_layer == nn.InstanceNorm2d
928 |
929 | model0 = [nn.ReflectionPad2d(3),
930 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
931 | norm_layer(ngf),
932 | nn.ReLU(True)]
933 |
934 | n_downsampling = 2
935 | for i in range(n_downsampling): # add downsampling layers
936 | mult = 2 ** i
937 | model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
938 | norm_layer(ngf * mult * 2),
939 | nn.ReLU(True)]
940 |
941 | mult = 2 ** n_downsampling
942 | for i in range(model0_res): # add ResNet blocks
943 | model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
944 |
945 | model1 = []
946 | model2 = []
947 | model3 = []
948 | for i in range(n_blocks-model0_res): # add ResNet blocks
949 | model1 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
950 | model2 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
951 | model3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
952 |
953 | for i in range(n_downsampling): # add upsampling layers
954 | mult = 2 ** (n_downsampling - i)
955 | model1 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
956 | kernel_size=3, stride=2,
957 | padding=1, output_padding=1,
958 | bias=use_bias),
959 | norm_layer(int(ngf * mult / 2)),
960 | nn.ReLU(True)]
961 | model2 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
962 | kernel_size=3, stride=2,
963 | padding=1, output_padding=1,
964 | bias=use_bias),
965 | norm_layer(int(ngf * mult / 2)),
966 | nn.ReLU(True)]
967 | model3 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
968 | kernel_size=3, stride=2,
969 | padding=1, output_padding=1,
970 | bias=use_bias),
971 | norm_layer(int(ngf * mult / 2)),
972 | nn.ReLU(True)]
973 | model1 += [nn.ReflectionPad2d(3)]
974 | model1 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
975 | model1 += [nn.Tanh()]
976 | model2 += [nn.ReflectionPad2d(3)]
977 | model2 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
978 | model2 += [nn.Tanh()]
979 | model3 += [nn.ReflectionPad2d(3)]
980 | model3 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
981 | model3 += [nn.Tanh()]
982 |
983 | self.model0 = nn.Sequential(*model0)
984 | self.model1 = nn.Sequential(*model1)
985 | self.model2 = nn.Sequential(*model2)
986 | self.model3 = nn.Sequential(*model3)
987 | print(list(self.modules()))
988 |
989 | def forward(self, input, domain):
990 | """Standard forward"""
991 | f1 = self.model0(input)
992 | if domain == 0:
993 | y = self.model1(f1)
994 | elif domain == 1:
995 | y = self.model2(f1)
996 | elif domain == 2:
997 | y = self.model3(f1)
998 | return y
999 |
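# --- Illustrative usage sketch (not part of the original file): this generator shares
# one encoder and routes its features through one of three decoders chosen by the
# integer domain argument. Arguments and sizes are example values; note the
# constructor prints the full module list.
def _resnet_style3_decoder_demo():
    import torch
    net = ResnetStyle3DecoderGenerator(input_nc=3, output_nc=1, ngf=64,
                                       n_blocks=9, model0_res=0).eval()
    x = torch.randn(1, 3, 256, 256)
    outs = [net(x, domain=d) for d in range(3)]   # one output per style decoder
    assert all(o.shape == (1, 1, 256, 256) for o in outs)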
1000 | class ResnetStyle2MCGenerator(nn.Module):
1001 | # multi-column generator: two parallel downsampling columns (3x3 and 5x5 kernels) whose outputs are concatenated with the style map before the shared ResNet blocks and decoder
1002 |
1003 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
1004 | """Construct a Resnet-based generator
1005 |
1006 | Parameters:
1007 | input_nc (int) -- the number of channels in input images
1008 | output_nc (int) -- the number of channels in output images
1009 | ngf (int) -- the number of filters in the last conv layer
1010 | norm_layer -- normalization layer
1011 | use_dropout (bool) -- if use dropout layers
1012 | n_blocks (int) -- the number of ResNet blocks
1013 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1014 | """
1015 | assert(n_blocks >= 0)
1016 | super(ResnetStyle2MCGenerator, self).__init__()
1017 | if type(norm_layer) == functools.partial:
1018 | use_bias = norm_layer.func == nn.InstanceNorm2d
1019 | else:
1020 | use_bias = norm_layer == nn.InstanceNorm2d
1021 |
1022 | model0 = [nn.ReflectionPad2d(3),
1023 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
1024 | norm_layer(ngf),
1025 | nn.ReLU(True)]
1026 |
1027 | n_downsampling = 2
1028 | model1_3 = []
1029 | model1_5 = []
1030 | for i in range(n_downsampling): # add downsampling layers
1031 | mult = 2 ** i
1032 | model1_3 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
1033 | norm_layer(ngf * mult * 2),
1034 | nn.ReLU(True)]
1035 | model1_5 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5, stride=2, padding=2, bias=use_bias),
1036 | norm_layer(ngf * mult * 2),
1037 | nn.ReLU(True)]
1038 |
1039 | mult = 2 ** n_downsampling
1040 | for i in range(model0_res): # add ResNet blocks
1041 | model1_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1042 | model1_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)]
1043 |
1044 | model = []
1045 | model += [nn.Conv2d(ngf * mult * 2 + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
1046 | norm_layer(ngf * mult),
1047 | nn.ReLU(True)]
1048 |
1049 | for i in range(n_blocks-model0_res): # add ResNet blocks
1050 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1051 |
1052 | for i in range(n_downsampling): # add upsampling layers
1053 | mult = 2 ** (n_downsampling - i)
1054 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
1055 | kernel_size=3, stride=2,
1056 | padding=1, output_padding=1,
1057 | bias=use_bias),
1058 | norm_layer(int(ngf * mult / 2)),
1059 | nn.ReLU(True)]
1060 | model += [nn.ReflectionPad2d(3)]
1061 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
1062 | model += [nn.Tanh()]
1063 |
1064 | self.model0 = nn.Sequential(*model0)
1065 | self.model1_3 = nn.Sequential(*model1_3)
1066 | self.model1_5 = nn.Sequential(*model1_5)
1067 | self.model = nn.Sequential(*model)
1068 | print(list(self.modules()))
1069 |
1070 | def forward(self, input1, input2):
1071 | """Standard forward"""
1072 | f0 = self.model0(input1)
1073 | f1 = self.model1_3(f0)
1074 | f2 = self.model1_5(f0)
1075 | y1 = torch.cat([f1, f2, input2], 1)
1076 | return self.model(y1)
1077 |
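# --- Illustrative shape check (not part of the original file): the multi-column
# generator runs two downsampling columns (3x3 and 5x5) and concatenates both feature
# maps with the style map before the shared decoder. Sizes are example values.
def _resnet_style2_mc_demo():
    import torch
    net = ResnetStyle2MCGenerator(input_nc=3, output_nc=1, ngf=64, n_blocks=9,
                                  extra_channel=3, model0_res=0).eval()
    content = torch.randn(1, 3, 256, 256)
    style_map = torch.zeros(1, 3, 64, 64)
    style_map[:, 0] = 1.0                 # one-hot style, broadcast spatially (example assumption)
    y = net(content, style_map)
    assert y.shape == (1, 1, 256, 256)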
1078 | class ResnetStyle2MC2Generator(nn.Module):
1079 | # multi-column variant that injects the style map earlier: the style is concatenated with each column's features and passed through per-column ResNet blocks before the two columns are fused
1080 |
1081 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0, model1_res=0):
1082 | """Construct a Resnet-based generator
1083 |
1084 | Parameters:
1085 | input_nc (int) -- the number of channels in input images
1086 | output_nc (int) -- the number of channels in output images
1087 | ngf (int) -- the number of filters in the last conv layer
1088 | norm_layer -- normalization layer
1089 | use_dropout (bool) -- if use dropout layers
1090 | n_blocks (int) -- the number of ResNet blocks
1091 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1092 | """
1093 | assert(n_blocks >= 0)
1094 | super(ResnetStyle2MC2Generator, self).__init__()
1095 | if type(norm_layer) == functools.partial:
1096 | use_bias = norm_layer.func == nn.InstanceNorm2d
1097 | else:
1098 | use_bias = norm_layer == nn.InstanceNorm2d
1099 |
1100 | model0 = [nn.ReflectionPad2d(3),
1101 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
1102 | norm_layer(ngf),
1103 | nn.ReLU(True)]
1104 |
1105 | n_downsampling = 2
1106 | model1_3 = []
1107 | model1_5 = []
1108 | for i in range(n_downsampling): # add downsampling layers
1109 | mult = 2 ** i
1110 | model1_3 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
1111 | norm_layer(ngf * mult * 2),
1112 | nn.ReLU(True)]
1113 | model1_5 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5, stride=2, padding=2, bias=use_bias),
1114 | norm_layer(ngf * mult * 2),
1115 | nn.ReLU(True)]
1116 |
1117 | mult = 2 ** n_downsampling
1118 | for i in range(model0_res): # add ResNet blocks
1119 | model1_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1120 | model1_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)]
1121 |
1122 | model2_3 = []
1123 | model2_5 = []
1124 | model2_3 += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
1125 | norm_layer(ngf * mult),
1126 | nn.ReLU(True)]
1127 | model2_5 += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=5, stride=1, padding=2, bias=use_bias),
1128 | norm_layer(ngf * mult),
1129 | nn.ReLU(True)]
1130 |
1131 | for i in range(model1_res): # add ResNet blocks
1132 | model2_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1133 | model2_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)]
1134 |
1135 | model = []
1136 | model += [nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
1137 | norm_layer(ngf * mult),
1138 | nn.ReLU(True)]
1139 | for i in range(n_blocks-model0_res-model1_res): # add ResNet blocks
1140 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1141 |
1142 | for i in range(n_downsampling): # add upsampling layers
1143 | mult = 2 ** (n_downsampling - i)
1144 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
1145 | kernel_size=3, stride=2,
1146 | padding=1, output_padding=1,
1147 | bias=use_bias),
1148 | norm_layer(int(ngf * mult / 2)),
1149 | nn.ReLU(True)]
1150 | model += [nn.ReflectionPad2d(3)]
1151 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
1152 | model += [nn.Tanh()]
1153 |
1154 | self.model0 = nn.Sequential(*model0)
1155 | self.model1_3 = nn.Sequential(*model1_3)
1156 | self.model1_5 = nn.Sequential(*model1_5)
1157 | self.model2_3 = nn.Sequential(*model2_3)
1158 | self.model2_5 = nn.Sequential(*model2_5)
1159 | self.model = nn.Sequential(*model)
1160 | print(list(self.modules()))
1161 |
1162 | def forward(self, input1, input2):
1163 | """Standard forward"""
1164 | f0 = self.model0(input1)
1165 | f1 = self.model1_3(f0)
1166 | f2 = self.model1_5(f0)
1167 | f3 = self.model2_3(torch.cat([f1,input2],1))
1168 | f4 = self.model2_5(torch.cat([f2,input2],1))
1169 | #pdb.set_trace()
1170 | return self.model(torch.cat([f3,f4],1))
1171 |
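# --- Illustrative shape check (not part of the original file): unlike the class above,
# this variant concatenates the style map with each column's features before its
# model1_res per-column ResNet blocks, then fuses the two columns. Sizes are examples.
def _resnet_style2_mc2_demo():
    import torch
    net = ResnetStyle2MC2Generator(input_nc=3, output_nc=1, ngf=64, n_blocks=9,
                                   extra_channel=3, model0_res=0, model1_res=2).eval()
    content = torch.randn(1, 3, 256, 256)
    style_map = torch.zeros(1, 3, 64, 64)
    style_map[:, 1] = 1.0
    y = net(content, style_map)
    assert y.shape == (1, 1, 256, 256)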
1172 | class ResnetBlock(nn.Module):
1173 | """Define a Resnet block"""
1174 |
1175 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
1176 | """Initialize the Resnet block
1177 |
1178 | A resnet block is a conv block with skip connections
1179 | We construct a conv block with build_conv_block function,
1180 | and implement skip connections in the forward function.
1181 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
1182 | """
1183 | super(ResnetBlock, self).__init__()
1184 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel)
1185 |
1186 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
1187 | """Construct a convolutional block.
1188 |
1189 | Parameters:
1190 | dim (int) -- the number of channels in the conv layer.
1191 | padding_type (str) -- the name of padding layer: reflect | replicate | zero
1192 | norm_layer -- normalization layer
1193 | use_dropout (bool) -- if use dropout layers.
1194 | use_bias (bool) -- if the conv layer uses bias or not
1195 |
1196 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
1197 | """
1198 | conv_block = []
1199 | p = 0
1200 | pad = int((kernel-1)/2)
1201 | if padding_type == 'reflect':#by default
1202 | conv_block += [nn.ReflectionPad2d(pad)]
1203 | elif padding_type == 'replicate':
1204 | conv_block += [nn.ReplicationPad2d(pad)]
1205 | elif padding_type == 'zero':
1206 | p = pad
1207 | else:
1208 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1209 |
1210 | conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
1211 | if use_dropout:
1212 | conv_block += [nn.Dropout(0.5)]
1213 |
1214 | p = 0
1215 | if padding_type == 'reflect':
1216 | conv_block += [nn.ReflectionPad2d(pad)]
1217 | elif padding_type == 'replicate':
1218 | conv_block += [nn.ReplicationPad2d(pad)]
1219 | elif padding_type == 'zero':
1220 | p = pad
1221 | else:
1222 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1223 | conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim)]
1224 |
1225 | return nn.Sequential(*conv_block)
1226 |
1227 | def forward(self, x):
1228 | """Forward function (with skip connections)"""
1229 | out = x + self.conv_block(x) # add skip connections
1230 | return out
1231 |
1232 |
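# --- Illustrative shape check (not part of the original file): a ResnetBlock keeps the
# channel count and spatial size, so the skip connection x + conv_block(x) is well
# defined. Arguments below are example values.
def _resnet_block_demo():
    import torch
    import torch.nn as nn
    block = ResnetBlock(dim=256, padding_type='reflect', norm_layer=nn.BatchNorm2d,
                        use_dropout=False, use_bias=False).eval()
    x = torch.randn(1, 256, 64, 64)
    assert block(x).shape == x.shape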
1233 | class UnetGenerator(nn.Module):
1234 | """Create a Unet-based generator"""
1235 |
1236 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
1237 | """Construct a Unet generator
1238 | Parameters:
1239 | input_nc (int) -- the number of channels in input images
1240 | output_nc (int) -- the number of channels in output images
1241 | num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
1242 | an image of size 128x128 will become of size 1x1 at the bottleneck
1243 | ngf (int) -- the number of filters in the last conv layer
1244 | norm_layer -- normalization layer
1245 |
1246 | We construct the U-Net from the innermost layer to the outermost layer.
1247 | It is a recursive process.
1248 | """
1249 | super(UnetGenerator, self).__init__()
1250 | # construct unet structure
1251 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
1252 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
1253 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
1254 | # gradually reduce the number of filters from ngf * 8 to ngf
1255 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1256 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1257 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1258 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
1259 |
1260 | def forward(self, input):
1261 | """Standard forward"""
1262 | return self.model(input)
1263 |
1264 |
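# --- Illustrative shape check (not part of the original file): with num_downs
# downsamplings the input height/width must be divisible by 2**num_downs; for
# num_downs == 7 a 128x128 image reaches 1x1 at the bottleneck, as the docstring notes.
# eval() is used because BatchNorm cannot train on a 1x1 bottleneck with batch size 1.
def _unet_generator_demo():
    import torch
    net = UnetGenerator(input_nc=3, output_nc=3, num_downs=7).eval()
    x = torch.randn(1, 3, 128, 128)
    assert net(x).shape == (1, 3, 128, 128)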
1265 | class UnetSkipConnectionBlock(nn.Module):
1266 | """Defines the Unet submodule with skip connection.
1267 | X -------------------identity----------------------
1268 | |-- downsampling -- |submodule| -- upsampling --|
1269 | """
1270 |
1271 | def __init__(self, outer_nc, inner_nc, input_nc=None,
1272 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
1273 | """Construct a Unet submodule with skip connections.
1274 |
1275 | Parameters:
1276 | outer_nc (int) -- the number of filters in the outer conv layer
1277 | inner_nc (int) -- the number of filters in the inner conv layer
1278 | input_nc (int) -- the number of channels in input images/features
1279 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
1280 | outermost (bool) -- if this module is the outermost module
1281 | innermost (bool) -- if this module is the innermost module
1282 | norm_layer -- normalization layer
1283 | use_dropout (bool) -- if use dropout layers.
1284 | """
1285 | super(UnetSkipConnectionBlock, self).__init__()
1286 | self.outermost = outermost
1287 | if type(norm_layer) == functools.partial:
1288 | use_bias = norm_layer.func == nn.InstanceNorm2d
1289 | else:
1290 | use_bias = norm_layer == nn.InstanceNorm2d
1291 | if input_nc is None:
1292 | input_nc = outer_nc
1293 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
1294 | stride=2, padding=1, bias=use_bias)
1295 | downrelu = nn.LeakyReLU(0.2, True)
1296 | downnorm = norm_layer(inner_nc)
1297 | uprelu = nn.ReLU(True)
1298 | upnorm = norm_layer(outer_nc)
1299 |
1300 | if outermost:
1301 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1302 | kernel_size=4, stride=2,
1303 | padding=1)
1304 | down = [downconv]
1305 | up = [uprelu, upconv, nn.Tanh()]
1306 | model = down + [submodule] + up
1307 | elif innermost:
1308 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
1309 | kernel_size=4, stride=2,
1310 | padding=1, bias=use_bias)
1311 | down = [downrelu, downconv]
1312 | up = [uprelu, upconv, upnorm]
1313 | model = down + up
1314 | else:
1315 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1316 | kernel_size=4, stride=2,
1317 | padding=1, bias=use_bias)
1318 | down = [downrelu, downconv, downnorm]
1319 | up = [uprelu, upconv, upnorm]
1320 |
1321 | if use_dropout:
1322 | model = down + [submodule] + up + [nn.Dropout(0.5)]
1323 | else:
1324 | model = down + [submodule] + up
1325 |
1326 | self.model = nn.Sequential(*model)
1327 |
1328 | def forward(self, x):
1329 | if self.outermost:
1330 | return self.model(x)
1331 | else: # add skip connections
1332 | return torch.cat([x, self.model(x)], 1)
1333 |
1334 |
1335 | class NLayerDiscriminator(nn.Module):
1336 | """Defines a PatchGAN discriminator"""
1337 |
1338 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
1339 | """Construct a PatchGAN discriminator
1340 |
1341 | Parameters:
1342 | input_nc (int) -- the number of channels in input images
1343 | ndf (int) -- the number of filters in the last conv layer
1344 | n_layers (int) -- the number of conv layers in the discriminator
1345 | norm_layer -- normalization layer
1346 | """
1347 | super(NLayerDiscriminator, self).__init__()
1348 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1349 | use_bias = norm_layer.func != nn.BatchNorm2d
1350 | else:
1351 | use_bias = norm_layer != nn.BatchNorm2d
1352 |
1353 | kw = 4
1354 | padw = 1
1355 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
1356 | nf_mult = 1
1357 | nf_mult_prev = 1
1358 | for n in range(1, n_layers): # gradually increase the number of filters
1359 | nf_mult_prev = nf_mult
1360 | nf_mult = min(2 ** n, 8)
1361 | sequence += [
1362 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1363 | norm_layer(ndf * nf_mult),
1364 | nn.LeakyReLU(0.2, True)
1365 | ]
1366 |
1367 | nf_mult_prev = nf_mult
1368 | nf_mult = min(2 ** n_layers, 8)
1369 | sequence += [
1370 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1371 | norm_layer(ndf * nf_mult),
1372 | nn.LeakyReLU(0.2, True)
1373 | ]
1374 |
1375 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
1376 | self.model = nn.Sequential(*sequence)
1377 |
1378 | def forward(self, input):
1379 | """Standard forward."""
1380 | return self.model(input)
1381 |
1382 |
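# --- Illustrative shape check (not part of the original file): with the default
# n_layers=3 this is the usual 70x70 PatchGAN, so a 256x256 input yields a 30x30
# one-channel map of real/fake scores (one per overlapping patch). Sizes are examples.
def _nlayer_discriminator_demo():
    import torch
    netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).eval()
    x = torch.randn(1, 3, 256, 256)
    assert netD(x).shape == (1, 1, 30, 30)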
1383 | class NLayerDiscriminatorCls(nn.Module):
1384 | """Defines a PatchGAN discriminator"""
1385 |
1386 | def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d):
1387 | """Construct a PatchGAN discriminator
1388 |
1389 | Parameters:
1390 | input_nc (int) -- the number of channels in input images
1391 | ndf (int) -- the number of filters in the last conv layer
1392 | n_layers (int) -- the number of conv layers in the discriminator
1393 | norm_layer -- normalization layer
1394 | """
1395 | super(NLayerDiscriminatorCls, self).__init__()
1396 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1397 | use_bias = norm_layer.func != nn.BatchNorm2d
1398 | else:
1399 | use_bias = norm_layer != nn.BatchNorm2d
1400 |
1401 | kw = 4
1402 | padw = 1
1403 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
1404 | nf_mult = 1
1405 | nf_mult_prev = 1
1406 | for n in range(1, n_layers): # gradually increase the number of filters
1407 | nf_mult_prev = nf_mult
1408 | nf_mult = min(2 ** n, 8)
1409 | sequence += [
1410 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1411 | norm_layer(ndf * nf_mult),
1412 | nn.LeakyReLU(0.2, True)
1413 | ]
1414 |
1415 | nf_mult_prev = nf_mult
1416 | nf_mult = min(2 ** n_layers, 8)
1417 | sequence1 = [
1418 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
1419 | norm_layer(ndf * nf_mult),
1420 | nn.LeakyReLU(0.2, True)
1421 | ]
1422 | sequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
1423 |
1424 | sequence2 = [
1425 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1426 | norm_layer(ndf * nf_mult),
1427 | nn.LeakyReLU(0.2, True)
1428 | ]
1429 | sequence2 += [
1430 | nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1431 | norm_layer(ndf * nf_mult),
1432 | nn.LeakyReLU(0.2, True)
1433 | ]
1434 | sequence2 += [
1435 | nn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)]
1436 |
1437 |
1438 | self.model0 = nn.Sequential(*sequence)
1439 | self.model1 = nn.Sequential(*sequence1)
1440 | self.model2 = nn.Sequential(*sequence2)
1441 | print(list(self.modules()))
1442 |
1443 | def forward(self, input):
1444 | """Standard forward."""
1445 | feat = self.model0(input)
1446 | # patchGAN output (1 * 62 * 62)
1447 | patch = self.model1(feat)
1448 | # class output (3 * 1 * 1)
1449 | classl = self.model2(feat)
1450 | return patch, classl.view(classl.size(0), -1)
1451 |
1452 |
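# --- Illustrative shape check (not part of the original file): the shared trunk feeds
# two heads, a PatchGAN map and an n_class style classifier. The 16x16 kernel in the
# class head requires a large enough input; 512x512 gives the 62x62 patch map and
# 1x1 class logits noted in the comments above. Sizes are example values.
def _nlayer_discriminator_cls_demo():
    import torch
    netD = NLayerDiscriminatorCls(input_nc=3, ndf=64, n_layers=3, n_class=3).eval()
    x = torch.randn(1, 3, 512, 512)
    patch, logits = netD(x)
    assert patch.shape == (1, 1, 62, 62) and logits.shape == (1, 3)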
1453 | class PixelDiscriminator(nn.Module):
1454 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
1455 |
1456 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
1457 | """Construct a 1x1 PatchGAN discriminator
1458 |
1459 | Parameters:
1460 | input_nc (int) -- the number of channels in input images
1461 | ndf (int) -- the number of filters in the last conv layer
1462 | norm_layer -- normalization layer
1463 | """
1464 | super(PixelDiscriminator, self).__init__()
1465 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1466 | use_bias = norm_layer.func != nn.BatchNorm2d
1467 | else:
1468 | use_bias = norm_layer != nn.BatchNorm2d
1469 |
1470 | self.net = [
1471 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
1472 | nn.LeakyReLU(0.2, True),
1473 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
1474 | norm_layer(ndf * 2),
1475 | nn.LeakyReLU(0.2, True),
1476 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
1477 |
1478 | self.net = nn.Sequential(*self.net)
1479 |
1480 | def forward(self, input):
1481 | """Standard forward."""
1482 | return self.net(input)
1483 |
1484 |
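# --- Illustrative shape check (not part of the original file): the 1x1 PixelGAN
# classifies every pixel independently, so the score map has the same spatial size as
# the input. Sizes are example values.
def _pixel_discriminator_demo():
    import torch
    netD = PixelDiscriminator(input_nc=3, ndf=64).eval()
    x = torch.randn(1, 3, 256, 256)
    assert netD(x).shape == (1, 1, 256, 256)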
1485 | class HED(nn.Module):
1486 | def __init__(self):
1487 | super(HED, self).__init__()
1488 |
1489 | self.moduleVggOne = nn.Sequential(
1490 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
1491 | nn.ReLU(inplace=False),
1492 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
1493 | nn.ReLU(inplace=False)
1494 | )
1495 |
1496 | self.moduleVggTwo = nn.Sequential(
1497 | nn.MaxPool2d(kernel_size=2, stride=2),
1498 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
1499 | nn.ReLU(inplace=False),
1500 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
1501 | nn.ReLU(inplace=False)
1502 | )
1503 |
1504 | self.moduleVggThr = nn.Sequential(
1505 | nn.MaxPool2d(kernel_size=2, stride=2),
1506 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
1507 | nn.ReLU(inplace=False),
1508 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
1509 | nn.ReLU(inplace=False),
1510 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
1511 | nn.ReLU(inplace=False)
1512 | )
1513 |
1514 | self.moduleVggFou = nn.Sequential(
1515 | nn.MaxPool2d(kernel_size=2, stride=2),
1516 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
1517 | nn.ReLU(inplace=False),
1518 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
1519 | nn.ReLU(inplace=False),
1520 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
1521 | nn.ReLU(inplace=False)
1522 | )
1523 |
1524 | self.moduleVggFiv = nn.Sequential(
1525 | nn.MaxPool2d(kernel_size=2, stride=2),
1526 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
1527 | nn.ReLU(inplace=False),
1528 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
1529 | nn.ReLU(inplace=False),
1530 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
1531 | nn.ReLU(inplace=False)
1532 | )
1533 |
1534 | self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
1535 | self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
1536 | self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
1537 | self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
1538 | self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
1539 |
1540 | self.moduleCombine = nn.Sequential(
1541 | nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
1542 | nn.Sigmoid()
1543 | )
1544 |
1545 | def forward(self, tensorInput):
1546 | tensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793
1547 | tensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762
1548 | tensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434
1549 |
1550 | tensorInput = torch.cat([ tensorBlue, tensorGreen, tensorRed ], 1)
1551 |
1552 | tensorVggOne = self.moduleVggOne(tensorInput)
1553 | tensorVggTwo = self.moduleVggTwo(tensorVggOne)
1554 | tensorVggThr = self.moduleVggThr(tensorVggTwo)
1555 | tensorVggFou = self.moduleVggFou(tensorVggThr)
1556 | tensorVggFiv = self.moduleVggFiv(tensorVggFou)
1557 |
1558 | tensorScoreOne = self.moduleScoreOne(tensorVggOne)
1559 | tensorScoreTwo = self.moduleScoreTwo(tensorVggTwo)
1560 | tensorScoreThr = self.moduleScoreThr(tensorVggThr)
1561 | tensorScoreFou = self.moduleScoreFou(tensorVggFou)
1562 | tensorScoreFiv = self.moduleScoreFiv(tensorVggFiv)
1563 |
1564 | tensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
1565 | tensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
1566 | tensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
1567 | tensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
1568 | tensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
1569 |
1570 | return self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv ], 1))
1571 |
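# --- Illustrative usage sketch (not part of the original file): HED expects an RGB
# image scaled to [0, 1] (the forward pass converts it to mean-subtracted BGR) and
# returns a single-channel edge-probability map of the same size. The weights here are
# random; in practice they would be loaded from a pretrained HED checkpoint.
def _hed_demo():
    import torch
    hed = HED().eval()
    img = torch.rand(1, 3, 256, 256)          # RGB in [0, 1]
    edges = hed(img)
    assert edges.shape == (1, 1, 256, 256)    # sigmoid output in (0, 1)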
1572 | # class for the VGG19 model
1573 | # borrows largely from torchvision vgg
1574 | class VGG19(nn.Module):
1575 | def __init__(self, init_weights=None, feature_mode=False, batch_norm=False, num_classes=1000):
1576 | super(VGG19, self).__init__()
1577 | self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
1578 | self.init_weights = init_weights
1579 | self.feature_mode = feature_mode
1580 | self.batch_norm = batch_norm
1581 | self.num_classes = num_classes
1582 | self.features = self.make_layers(self.cfg, batch_norm)
1583 | self.classifier = nn.Sequential(
1584 | nn.Linear(512 * 7 * 7, 4096),
1585 | nn.ReLU(True),
1586 | nn.Dropout(),
1587 | nn.Linear(4096, 4096),
1588 | nn.ReLU(True),
1589 | nn.Dropout(),
1590 | nn.Linear(4096, num_classes),
1591 | )
1592 | # print('----------load the pretrained vgg net---------')
1593 | # if not init_weights == None:
1594 | # print('load the weights')
1595 | # self.load_state_dict(torch.load(init_weights))
1596 |
1597 |
1598 | def make_layers(self, cfg, batch_norm=False):
1599 | layers = []
1600 | in_channels = 3
1601 | for v in cfg:
1602 | if v == 'M':
1603 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
1604 | else:
1605 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
1606 | if batch_norm:
1607 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
1608 | else:
1609 | layers += [conv2d, nn.ReLU(inplace=True)]
1610 | in_channels = v
1611 | return nn.Sequential(*layers)
1612 |
1613 | def forward(self, x):
1614 | if self.feature_mode:
1615 | module_list = list(self.features.modules())
1616 | for l in module_list[1:27]: # conv4_4
1617 | x = l(x)
1618 | if not self.feature_mode:
1619 | x = self.features(x)
1620 | x = x.view(x.size(0), -1)
1621 | x = self.classifier(x)
1622 |
1623 | return x
1624 |
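# --- Illustrative usage sketch (not part of the original file): with feature_mode=True
# the forward pass stops at conv4_4 and returns a feature map instead of class logits,
# which is how a perceptual loss is typically computed. Weights are random here; the
# pretrained-weight loading shown in the commented-out code above would normally be used.
def _vgg19_demo():
    import torch
    vgg = VGG19(feature_mode=True).eval()
    x = torch.randn(1, 3, 224, 224)
    feats = vgg(x)
    assert feats.shape == (1, 512, 28, 28)    # conv4_4 features: 512 channels at 1/8 resolution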
1625 | class Classifier(nn.Module):
1626 | def __init__(self, input_nc, classes, ngf=64, num_downs=3, norm_layer=nn.BatchNorm2d, use_dropout=False, h=512, w=512, dim=4096):
1627 | super(Classifier, self).__init__()
1628 | self.input_nc = input_nc
1629 | self.ngf = ngf
1630 | if type(norm_layer) == functools.partial:
1631 | use_bias = norm_layer.func == nn.InstanceNorm2d
1632 | else:
1633 | use_bias = norm_layer == nn.InstanceNorm2d
1634 |
1635 | model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(0.2, True)]
1636 | nf_mult = 1
1637 | nf_mult_prev = 1
1638 | for n in range(1, num_downs):
1639 | nf_mult_prev = nf_mult
1640 | nf_mult = min(2 ** n, 8)
1641 | model += [
1642 | nn.Conv2d(int(ngf * nf_mult_prev), int(ngf * nf_mult), kernel_size=4, stride=2, padding=1, bias=use_bias),
1643 | norm_layer(int(ngf * nf_mult)),
1644 | nn.LeakyReLU(0.2, True)
1645 | ]
1646 | nf_mult_prev = nf_mult
1647 | nf_mult = min(2 ** num_downs, 8)
1648 | model += [
1649 | nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=4, stride=1, padding=1, bias=use_bias),
1650 | norm_layer(ngf * nf_mult),
1651 | nn.LeakyReLU(0.2, True)
1652 | ]
1653 | self.encoder = nn.Sequential(*model)
1654 |
1655 | self.classifier = nn.Sequential(
1656 | nn.Linear(512 * 7 * 7, dim),
1657 | nn.ReLU(True),
1658 | nn.Dropout(),
1659 | nn.Linear(dim, dim),
1660 | nn.ReLU(True),
1661 | nn.Dropout(),
1662 | nn.Linear(dim, classes),
1663 | )
1664 |
1665 | def forward(self, x):
1666 | ax = self.encoder(x)
1667 | #print('ax',ax.shape) # (8, 512, 7, 7)
1668 | ax = ax.view(ax.size(0), -1) # view -- reshape
1669 | return self.classifier(ax)
1670 |
--------------------------------------------------------------------------------