├── examples ├── 00.jpg ├── 01.jpg ├── 02.jpg ├── 03.jpg ├── 04.jpg ├── 05.jpg ├── 06.jpg ├── 07.jpg ├── 08.jpg ├── 09.jpg ├── 10.jpg └── 11.jpg ├── images ├── sample.png └── architecture.png ├── util ├── __init__.py ├── image_pool.py ├── html.py ├── get_data.py ├── util.py └── visualizer.py ├── options ├── __init__.py ├── test_options.py ├── train_options.py └── base_options.py ├── requirements.txt ├── .gitignore ├── test_seq_style3.py ├── function.py ├── data ├── image_folder.py ├── single_dataset.py ├── __init__.py ├── base_dataset.py └── unaligned_mask_stylecls_dataset.py ├── readme.md ├── models ├── __init__.py ├── test_model.py ├── pretrained_networks.py ├── base_model.py ├── cycle_gan_cls_model.py └── networks.py ├── QMUPD.ipynb ├── test.py └── train.py /examples/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/00.jpg -------------------------------------------------------------------------------- /examples/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/01.jpg -------------------------------------------------------------------------------- /examples/02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/02.jpg -------------------------------------------------------------------------------- /examples/03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/03.jpg -------------------------------------------------------------------------------- /examples/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/04.jpg 
-------------------------------------------------------------------------------- /examples/05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/05.jpg -------------------------------------------------------------------------------- /examples/06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/06.jpg -------------------------------------------------------------------------------- /examples/07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/07.jpg -------------------------------------------------------------------------------- /examples/08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/08.jpg -------------------------------------------------------------------------------- /examples/09.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/09.jpg -------------------------------------------------------------------------------- /examples/10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/10.jpg -------------------------------------------------------------------------------- /examples/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/examples/11.jpg -------------------------------------------------------------------------------- /images/sample.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cedro3/QMUPD/master/images/sample.png -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cedro3/QMUPD/master/images/architecture.png -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.2.0 2 | torchvision==0.4.0 3 | dominate==2.4.0 4 | visdom==0.1.8.9 5 | scipy==1.1.0 6 | numpy==1.16.4 7 | #Pillow==6.2.1 8 | opencv-python==4.1.0.25 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | debug* 3 | datasets/ 4 | checkpoints/ 5 | results/ 6 | build/ 7 | dist/ 8 | *.png 9 | torch.egg-info/ 10 | */**/__pycache__ 11 | torch/version.py 12 | torch/csrc/generic/TensorMethods.cpp 13 | torch/lib/*.so* 14 | torch/lib/*.dylib* 15 | torch/lib/*.h 16 | torch/lib/build 17 | torch/lib/tmp_install 18 | torch/lib/include 19 | torch/lib/torch_shm_manager 20 | torch/csrc/cudnn/cuDNN.cpp 21 | torch/csrc/nn/THNN.cwrap 22 | torch/csrc/nn/THNN.cpp 23 | torch/csrc/nn/THCUNN.cwrap 24 | torch/csrc/nn/THCUNN.cpp 25 | 
torch/csrc/nn/THNN_generic.cwrap 26 | torch/csrc/nn/THNN_generic.cpp 27 | torch/csrc/nn/THNN_generic.h 28 | docs/src/**/* 29 | test/data/legacy_modules.t7 30 | test/data/gpu_tensors.pt 31 | test/htmlcov 32 | test/.coverage 33 | */*.pyc 34 | */**/*.pyc 35 | */**/**/*.pyc 36 | */**/**/**/*.pyc 37 | */**/**/**/**/*.pyc 38 | */*.so* 39 | */**/*.so* 40 | */**/*.dylib* 41 | test/data/legacy_serialized.pt 42 | *~ 43 | .idea 44 | txt_output/* 45 | vo/* 46 | *.xlsx 47 | -------------------------------------------------------------------------------- /test_seq_style3.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | 3 | #================== settings ================== 4 | exp = 'QMUPD_model';epoch='200' 5 | dataroot = 'examples' 6 | gpu_id = '-1' 7 | 8 | netga = 'resnet_style2_9blocks' 9 | model0_res = 0 10 | model1_res = 0 11 | imgsize = 512 12 | extraflag = ' --netga %s --model0_res %d --model1_res %d' % (netga, model0_res, model1_res) 13 | 14 | #==================== command ================== 15 | for vec in [[1,0,0],[0,1,0],[0,0,1]]: 16 | svec = '%d,%d,%d' % (vec[0],vec[1],vec[2]) 17 | img1 = 'imagesstyle%d-%d-%d'%(vec[0],vec[1],vec[2]) 18 | print('results/%s/test_%s/index%s.html'%(exp,epoch,img1[6:])) 19 | command = 'python test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout --model_suffix _A %s --num_test 1000 --epoch %s --style_control 1 --imagefolder %s --sinput svec --svec %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,extraflag,epoch,img1,svec,imgsize,imgsize,gpu_id) 20 | os.system(command) 21 | 22 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | # Dropout and Batchnorm has different behavioir during training and test. 17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 19 | parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images') 20 | # rewrite devalue values 21 | parser.set_defaults(model='test') 22 | # To avoid cropping, the load_size should be the same as crop_size 23 | parser.set_defaults(load_size=parser.get_default('crop_size')) 24 | self.isTrain = False 25 | return parser 26 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | # --- display_mp4 --- 2 | from IPython.display import display, HTML 3 | from IPython.display import HTML 4 | 5 | def display_mp4(path): 6 | from base64 import b64encode 7 | mp4 = open(path,'rb').read() 8 | data_url = "data:video/mp4;base64," + b64encode(mp4).decode() 9 | display(HTML(""" 10 | 13 | """ % data_url)) 14 | #print('Display finished.') ### 15 | 16 | 17 | # --- display_pic --- 18 | import matplotlib.pyplot as plt 19 | from PIL import Image 20 | import numpy as np 21 | import os 22 | 23 | def display_pic(folder): 24 | fig = plt.figure(figsize=(30, 60)) 25 | files = os.listdir(folder) 26 | files.sort() 27 | for i, file in enumerate(files): 28 | if 
file=='.ipynb_checkpoints': 29 | continue 30 | if file=='.DS_Store': 31 | continue 32 | img = Image.open(folder+'/'+file) 33 | images = np.asarray(img) 34 | ax = fig.add_subplot(10, 6, i+1, xticks=[], yticks=[]) 35 | image_plt = np.array(images) 36 | ax.imshow(image_plt) 37 | #name = os.path.splitext(file) 38 | ax.set_xlabel(file, fontsize=20) 39 | plt.show() 40 | plt.close() 41 | 42 | 43 | # --- reset_folder --- 44 | import shutil 45 | 46 | def reset_folder(path): 47 | if os.path.isdir(path): 48 | shutil.rmtree(path) 49 | os.makedirs(path,exist_ok=True) 50 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | # Quality Metric Guided Portrait Line Drawing Generation from Unpaired Training Data 3 | 4 | We provide PyTorch implementations for our TPAMI paper "Quality Metric Guided Portrait Line Drawing 
Generation from Unpaired Training Data". [paper](https://ieeexplore.ieee.org/document/9699090) 5 | 6 | Our method can (1) learn to generate high quality portrait drawings in multiple styles using a single network and (2) generate portrait drawings in a “new style” unseen in the training data. 7 | 8 | 9 | ## Our Proposed Framework 10 | 11 | 12 | 13 | ## Sample Results 14 | 15 | 16 | 17 | ## Prerequisites 18 | - Linux or macOS 19 | - Python 3 20 | - CPU or NVIDIA GPU + CUDA CuDNN 21 | 22 | ## Installation 23 | - To install the dependencies, run 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Quick Test (apply a pretrained model, generate high quality portrait drawings in multiple styles using a single network) 29 | 30 | - 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1eY60A1z2k9gTr9ryDxvMxQ)(extract code:g8is) or [GoogleDrive](https://drive.google.com/drive/folders/1R4mjaiXN3fISp0Lc4rP6DiTbQMCeTC9A?usp=sharing) and rename the folder to `checkpoints/`. 31 | 32 | - 2. Test for example photos: generate artistic portrait drawings for example photos in the folder `./examples` using 33 | ``` bash 34 | python test_seq_style3.py 35 | ``` 36 | The test results will be saved to html files here: `./results/QMUPD_model/test_200/indexstyle*.html`. 37 | The result images are saved in `./results/QMUPD_model/test_200/imagesstyle*`, 38 | where `real`, `fake`, correspond to input face photo, synthesized drawing of a certain style, respectively. 39 | 40 | You can contact email ranyi@sjtu.edu.cn for any questions. 41 | 42 | 43 | ## Citation 44 | If you use this code for your research, please cite our paper. 
45 | 46 | ``` 47 | @article{YiLLR22, 48 | title = {Quality Metric Guided Portrait Line Drawing Generation from Unpaired Training Data}, 49 | author = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L}, 50 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, 51 | year = {DOI (identifier) 10.1109/TPAMI.2022.3147570, 2022}, 52 | } 53 | ``` 54 | 55 | ## Acknowledgments 56 | Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 57 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | By 50/100, the buffer will return input images. 32 | By 50/100, the buffer will return images previously stored in the buffer, 33 | and insert the current images to the buffer. 
34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 
13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 
58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask 2 | from data.image_folder import make_dataset 3 | from PIL import Image 4 | import torch 5 | import os, glob 6 | 7 | 8 | class SingleDataset(BaseDataset): 9 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 10 | 11 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 12 | """ 13 | 14 | def __init__(self, opt): 15 | """Initialize this dataset class. 16 | 17 | Parameters: 18 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | BaseDataset.__init__(self, opt) 21 | #self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 22 | imglistA = './datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot) 23 | if os.path.exists(imglistA): 24 | self.A_paths = sorted(open(imglistA, 'r').read().splitlines()) 25 | else: 26 | self.A_paths = sorted(glob.glob(opt.dataroot + '/*.*')) 27 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 28 | #self.transform = get_transform(opt, grayscale=(input_nc == 1)) 29 | 30 | def __getitem__(self, index): 31 | """Return a data point and its metadata information. 
32 | 33 | Parameters: 34 | index - - a random integer for data indexing 35 | 36 | Returns a dictionary that contains A and A_paths 37 | A(tensor) - - an image in one domain 38 | A_paths(str) - - the path of the image 39 | """ 40 | A_path = self.A_paths[index] 41 | A_img = Image.open(A_path).convert('RGB') 42 | transform_params_A = get_params(self.opt, A_img.size) 43 | A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img) 44 | item = {'A': A, 'A_paths': A_path} 45 | 46 | if self.opt.model == 'test_r1': 47 | basenA = os.path.basename(A_path) 48 | A_addchan_img = Image.open(os.path.join('./datasets/list/mask/A_all',basenA)) 49 | A_addchan = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_addchan_img) 50 | item['A_addchan'] = A_addchan 51 | 52 | if self.opt.style_control: 53 | if self.opt.sinput == 'sind': 54 | B_style = torch.Tensor([0.,0.,0.]) 55 | B_style[self.opt.sind] = 1. 56 | elif self.opt.sinput == 'svec': 57 | if self.opt.svec[0] == '~': 58 | self.opt.svec = '-'+self.opt.svec[1:] 59 | ss = self.opt.svec.split(',') 60 | B_style = torch.Tensor([float(ss[0]),float(ss[1]),float(ss[2])]) 61 | elif self.opt.sinput == 'simg': 62 | self.featureloc = os.path.join('style_features/styles2/', self.opt.sfeature_mode) 63 | B_style = np.load(self.featureloc, self.opt.simg[:-4]+'.npy') 64 | 65 | B_style = B_style.view(3, 1, 1) 66 | B_style = B_style.repeat(1, 128, 128) 67 | item['B_style'] = B_style 68 | 69 | return item 70 | 71 | def __len__(self): 72 | """Return the total number of images in the dataset.""" 73 | return len(self.A_paths) 74 | -------------------------------------------------------------------------------- /QMUPD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 
| "execution_count": null, 16 | "metadata": { 17 | "id": "_7IMTJEdpSMJ" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "#@title セットアップ\n", 22 | "\n", 23 | "# githubからコードを取得\n", 24 | "! git clone https://github.com/cedro3/QMUPD.git\n", 25 | "%cd QMUPD\n", 26 | "\n", 27 | "# ライブラリ・インストール\n", 28 | "! pip install -r requirements.txt\n", 29 | "! pip install pretrainedmodels\n", 30 | "\n", 31 | "# 学習済みパラメータ・ダウンロード\n", 32 | "! pip install --upgrade gdown\n", 33 | "import gdown\n", 34 | "gdown.download('https://drive.google.com/uc?id=1QpuCQ0LrrlsHCs3Vh6xC0uIBlWrDrGo1', 'checkpoints.zip', quiet=False)\n", 35 | "! unzip checkpoints.zip\n", 36 | "\n", 37 | "# 関数インポート\n", 38 | "from function import *" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": { 45 | "id": "cUz-cTvFIoAu" 46 | }, 47 | "outputs": [], 48 | "source": [ 49 | "#@title サンプル画像の表示\n", 50 | "display_pic('examples')" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "id": "hwtglANBqHUQ" 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "#@title 線画の作成\n", 62 | "reset_folder('results')\n", 63 | "! 
python test_seq_style3.py" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "id": "g3I50o1kwCEo" 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "#@title スタイル1表示\n", 75 | "display_pic('results/QMUPD_model/test_200/imagesstyle0-0-1')" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": { 82 | "id": "w-kSR4bixOh6" 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "#@title スタイル2表示\n", 87 | "display_pic('results/QMUPD_model/test_200/imagesstyle0-1-0')" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "id": "ti_sOg8fzMMg" 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "#@title スタイル3表示\n", 99 | "display_pic('results/QMUPD_model/test_200/imagesstyle1-0-0')" 100 | ] 101 | } 102 | ], 103 | "metadata": { 104 | "accelerator": "GPU", 105 | "colab": { 106 | "collapsed_sections": [], 107 | "name": "QMUPD", 108 | "provenance": [], 109 | "include_colab_link": true 110 | }, 111 | "kernelspec": { 112 | "display_name": "Python 3", 113 | "name": "python3" 114 | }, 115 | "language_info": { 116 | "name": "python" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 0 121 | } -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 28 | parser.add_argument('--phase', type=str, default='train', 
help='train, val, test, etc') 29 | # training parameters 30 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 31 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 32 | parser.add_argument('--niter_end', type=int, default=200, help='# of iter to end') 33 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 35 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 36 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 37 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') 38 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 39 | 40 | self.isTrain = True 41 | return parser 42 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 
9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 
51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads)) 80 | 81 | def load_data(self): 82 | return self 83 | 84 | def __len__(self): 85 | """Return the number of data in the dataset""" 86 | return min(len(self.dataset), self.opt.max_dataset_size) 87 | 88 | def __iter__(self): 89 | """Return a batch of data""" 90 | for i, data in enumerate(self.dataloader): 91 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 92 | break 93 | yield data 94 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0, folder='images'): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 34 | with self.doc.head: 35 | meta(http_equiv="refresh", content=str(refresh)) 36 | 37 | def get_image_dir(self): 38 | """Return the directory that stores images""" 39 | return self.img_dir 40 | 41 | def add_header(self, text): 42 | """Insert a header to the HTML file 43 | 44 | Parameters: 45 | text (str) -- the header text 46 | """ 47 | with self.doc: 48 | h3(text) 49 | 50 | def add_images(self, ims, txts, links, width=400): 51 | """add images to the HTML file 52 | 53 | Parameters: 54 | ims (str list) -- a list of image paths 55 | txts (str list) -- a list of image names shown on the website 56 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 57 | """ 58 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 59 | self.doc.add(self.t) 60 | with self.t: 61 | with tr(): 62 | for im, txt, link in zip(ims, txts, links): 63 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 64 | with p(): 65 | with a(href=os.path.join('images', link)): 66 | #img(style="width:%dpx" % width, src=os.path.join('images', im)) 67 | img(style="width:%dpx" % width, src=os.path.join(self.folder, im)) 68 | br() 69 | p(txt) 70 | 71 | def save(self): 72 | """save the current content to the HMTL file""" 73 | #html_file = '%s/index.html' % self.web_dir 74 | name = self.folder[6:] if self.folder[:6] == 'images' else self.folder 75 | html_file = '%s/index%s.html' % (self.web_dir, name) 76 | f = open(html_file, 'wt') 77 | f.write(self.doc.render()) 78 | f.close() 79 | 80 | 81 | if __name__ == '__main__': # we show an 
example usage here. 82 | html = HTML('web/', 'test_html') 83 | html.add_header('hello world') 84 | 85 | ims, txts, links = [], [], [] 86 | for n in range(4): 87 | ims.append('image_%d.png' % n) 88 | txts.append('text_%d' % n) 89 | links.append('image_%d.png' % n) 90 | html.add_images(ims, txts, links) 91 | html.save() 92 | -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """A Python script for downloading CycleGAN or pix2pix datasets. 13 | 14 | Parameters: 15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 
25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset. 83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 
90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """General-purpose test script for image-to-image translation. 2 | 3 | Once you have trained your model with train.py, you can use this script to test the model. 4 | It will load a saved model from --checkpoints_dir and save the results to --results_dir. 5 | 6 | It first creates model and dataset given the option. It will hard-code some parameters. 7 | It then runs inference for --num_test images and save results to an HTML file. 8 | 9 | Example (You need to train models first or download pre-trained models from our website): 10 | Test a CycleGAN model (both sides): 11 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan 12 | 13 | Test a CycleGAN model (one side only): 14 | python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout 15 | 16 | The option '--model test' is used for generating CycleGAN results only for one side. 17 | This option will automatically set '--dataset_mode single', which only loads the images from one set. 18 | On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, 19 | which is sometimes unnecessary. 
The results will be saved at ./results/. 20 | Use '--results_dir ' to specify the results directory. 21 | 22 | Test a pix2pix model: 23 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA 24 | 25 | See options/base_options.py and options/test_options.py for more test options. 26 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md 27 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md 28 | """ 29 | import os 30 | from options.test_options import TestOptions 31 | from data import create_dataset 32 | from models import create_model 33 | from util.visualizer import save_images 34 | from util import html 35 | 36 | 37 | if __name__ == '__main__': 38 | opt = TestOptions().parse() # get test options 39 | # hard-code some parameters for test 40 | opt.num_threads = 0 # test code only supports num_threads = 1 41 | opt.batch_size = 1 # test code only supports batch_size = 1 42 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 43 | opt.no_flip = True # no flip; comment this line if results on flipped images are needed. 44 | opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
45 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 46 | model = create_model(opt) # create a model given opt.model and other options 47 | model.setup(opt) # regular setup: load and print networks; create schedulers 48 | # create a website 49 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory 50 | #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 51 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder) 52 | # test with eval mode. This only affects layers like batchnorm and dropout. 53 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 54 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 55 | if opt.eval: 56 | model.eval() 57 | for name in model.model_names: 58 | if isinstance(name, str): 59 | print(getattr(model, 'net' + name).training) 60 | for i, data in enumerate(dataset): 61 | if i >= opt.num_test: # only apply our model to opt.num_test images. 62 | break 63 | model.set_input(data) # unpack data from data loader 64 | model.test() # run inference 65 | visuals = model.get_current_visuals() # get image results 66 | img_path = model.get_image_paths() # get image paths 67 | if i % 5 == 0: # save images to an HTML file 68 | print('processing (%04d)-th image... 
%s' % (i, img_path)) 69 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 70 | webpage.save() # save the HTML 71 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import pdb 8 | from scipy.io import savemat 9 | 10 | 11 | def tensor2im(input_image, imtype=np.uint8): 12 | """"Converts a Tensor array into a numpy image array. 13 | 14 | Parameters: 15 | input_image (tensor) -- the input image tensor array 16 | imtype (type) -- the desired type of the converted numpy array 17 | """ 18 | if not isinstance(input_image, np.ndarray): 19 | if isinstance(input_image, torch.Tensor): # get the data from a variable 20 | image_tensor = input_image.data 21 | else: 22 | return input_image 23 | #pdb.set_trace() 24 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 25 | if image_numpy.shape[0] == 1: # grayscale to RGB 26 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 27 | elif image_numpy.shape[0] == 2: 28 | image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0) 29 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 30 | else: # if it is a numpy array, do nothing 31 | image_numpy = input_image 32 | return image_numpy.astype(imtype) 33 | #return np.round(image_numpy).astype(imtype),image_numpy 34 | 35 | def tensor2im2(input_image, imtype=np.uint8): 36 | """"Converts a Tensor array into a numpy image array. 
37 | 38 | Parameters: 39 | input_image (tensor) -- the input image tensor array 40 | imtype (type) -- the desired type of the converted numpy array 41 | """ 42 | if not isinstance(input_image, np.ndarray): 43 | if isinstance(input_image, torch.Tensor): # get the data from a variable 44 | image_tensor = input_image.data 45 | else: 46 | return input_image 47 | #pdb.set_trace() 48 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 49 | if image_numpy.shape[0] == 1: # grayscale to RGB 50 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 51 | elif image_numpy.shape[0] == 2: 52 | image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0) 53 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) 54 | image_numpy[:,:,0] = image_numpy[:,:,0] * 0.229 + 0.485 55 | image_numpy[:,:,1] = image_numpy[:,:,1] * 0.224 + 0.456 56 | image_numpy[:,:,2] = image_numpy[:,:,2] * 0.225 + 0.406 57 | image_numpy = image_numpy * 255.0 # post-processing: tranpose and scaling 58 | else: # if it is a numpy array, do nothing 59 | image_numpy = input_image 60 | return image_numpy.astype(imtype) 61 | 62 | 63 | def diagnose_network(net, name='network'): 64 | """Calculate and print the mean of average absolute(gradients) 65 | 66 | Parameters: 67 | net (torch network) -- Torch network 68 | name (str) -- the name of the network 69 | """ 70 | mean = 0.0 71 | count = 0 72 | for param in net.parameters(): 73 | if param.grad is not None: 74 | mean += torch.mean(torch.abs(param.grad.data)) 75 | count += 1 76 | if count > 0: 77 | mean = mean / count 78 | print(name) 79 | print(mean) 80 | 81 | 82 | def save_image(image_numpy, image_path): 83 | """Save a numpy image to the disk 84 | 85 | Parameters: 86 | image_numpy (numpy array) -- input numpy array 87 | image_path (str) -- the path of the image 88 | """ 89 | image_pil = Image.fromarray(image_numpy) 90 | #pdb.set_trace() 91 | image_pil.save(image_path) 92 | 93 | 94 | def print_numpy(x, val=True, shp=False): 95 | 
"""Print the mean, min, max, median, std, and size of a numpy array 96 | 97 | Parameters: 98 | val (bool) -- if print the values of the numpy array 99 | shp (bool) -- if print the shape of the numpy array 100 | """ 101 | x = x.astype(np.float64) 102 | if shp: 103 | print('shape,', x.shape) 104 | if val: 105 | x = x.flatten() 106 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 107 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 108 | 109 | 110 | def mkdirs(paths): 111 | """create empty directories if they don't exist 112 | 113 | Parameters: 114 | paths (str list) -- a list of directory paths 115 | """ 116 | if isinstance(paths, list) and not isinstance(paths, str): 117 | for path in paths: 118 | mkdir(path) 119 | else: 120 | mkdir(paths) 121 | 122 | 123 | def mkdir(path): 124 | """create a single empty directory if it didn't exist 125 | 126 | Parameters: 127 | path (str) -- a single directory path 128 | """ 129 | if not os.path.exists(path): 130 | os.makedirs(path) 131 | 132 | def normalize_tensor(in_feat,eps=1e-10): 133 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 134 | return in_feat/(norm_factor+eps) -------------------------------------------------------------------------------- /models/test_model.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from . import networks 3 | import torch 4 | import pdb 5 | 6 | class TestModel(BaseModel): 7 | """ This TesteModel can be used to generate CycleGAN results for only one direction. 8 | This model will automatically set '--dataset_mode single', which only loads the images from one collection. 9 | 10 | See the test instruction for more details. 11 | """ 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train=True): 14 | """Add new dataset-specific options, and rewrite default values for existing options. 
15 | 16 | Parameters: 17 | parser -- original option parser 18 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 19 | 20 | Returns: 21 | the modified parser. 22 | 23 | The model can only be used during test time. It requires '--dataset_mode single'. 24 | You need to specify the network using the option '--model_suffix'. 25 | """ 26 | assert not is_train, 'TestModel cannot be used during training time' 27 | parser.set_defaults(dataset_mode='single') 28 | parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.') 29 | parser.add_argument('--style_control', type=int, default=0, help='use style_control') 30 | parser.add_argument('--sfeature_mode', type=str, default='vgg19_softmax', help='vgg19 softmax as feature') 31 | parser.add_argument('--sinput', type=str, default='sind', help='use which one for style input') 32 | parser.add_argument('--sind', type=int, default=0, help='one hot for sfeature') 33 | parser.add_argument('--svec', type=str, default='1,0,0', help='3-dim vec') 34 | parser.add_argument('--simg', type=str, default='Yann_Legendre-053', help='drawing example for style') 35 | parser.add_argument('--netga', type=str, default='resnet_style_9blocks', help='net arch for netG_A') 36 | parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0') 37 | parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)') 38 | 39 | return parser 40 | 41 | def __init__(self, opt): 42 | """Initialize the pix2pix class. 43 | 44 | Parameters: 45 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 46 | """ 47 | assert(not opt.isTrain) 48 | BaseModel.__init__(self, opt) 49 | # specify the training losses you want to print out. 
The training/test scripts will call 50 | self.loss_names = [] 51 | # specify the images you want to save/display. The training/test scripts will call 52 | #self.visual_names = ['real', 'fake', 'rec', 'fake_B'] 53 | self.visual_names = ['real', 'fake'] 54 | # specify the models you want to save to the disk. The training/test scripts will call and 55 | self.model_names = ['G' + opt.model_suffix, 'G_B'] # only generator is needed. 56 | if not self.opt.style_control: 57 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, 58 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 59 | else: 60 | print(opt.netga) 61 | print('model0_res', opt.model0_res) 62 | print('model1_res', opt.model1_res) 63 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm, 64 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res) 65 | 66 | self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, 67 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 68 | # assigns the model to self.netG_[suffix] so that it can be loaded 69 | # please see 70 | setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self. 71 | setattr(self, 'netG_B', self.netGB) # store netGB in self. 72 | 73 | def set_input(self, input): 74 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 75 | 76 | Parameters: 77 | input: a dictionary that contains the data itself and its metadata information. 78 | 79 | We need to use 'single_dataset' dataset mode. It only load images from one domain. 
80 | """ 81 | self.real = input['A'].to(self.device) 82 | self.image_paths = input['A_paths'] 83 | if self.opt.style_control: 84 | self.style = input['B_style'] 85 | 86 | def forward(self): 87 | """Run forward pass.""" 88 | if not self.opt.style_control: 89 | self.fake = self.netG(self.real) # G(real) 90 | else: 91 | #print(torch.mean(self.style,(2,3)),'style_control') 92 | self.fake = self.netG(self.real, self.style) 93 | 94 | def optimize_parameters(self): 95 | """No optimization for test model.""" 96 | pass 97 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """General-purpose training script for image-to-image translation. 2 | 3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and 4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization). 5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). 6 | 7 | It first creates model, dataset, and visualizer given the option. 8 | It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models. 9 | The script supports continue/resume training. Use '--continue_train' to resume your previous training. 10 | 11 | Example: 12 | Train a CycleGAN model: 13 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan 14 | Train a pix2pix model: 15 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA 16 | 17 | See options/base_options.py and options/train_options.py for more training options. 
18 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md 19 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md 20 | """ 21 | import time 22 | from options.train_options import TrainOptions 23 | from data import create_dataset 24 | from models import create_model 25 | from util.visualizer import Visualizer 26 | import pdb 27 | 28 | if __name__ == '__main__': 29 | start = time.time() 30 | opt = TrainOptions().parse() # get training options 31 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 32 | dataset_size = len(dataset) # get the number of images in the dataset. 33 | print('The number of training images = %d' % dataset_size) 34 | 35 | model = create_model(opt) # create a model given opt.model and other options 36 | model.setup(opt) # regular setup: load and print networks; create schedulers 37 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots 38 | total_iters = 0 # the total number of training iterations 39 | 40 | #for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by , + 41 | for epoch in range(opt.epoch_count, opt.niter_end + 1): 42 | epoch_start_time = time.time() # timer for entire epoch 43 | iter_data_time = time.time() # timer for data loading per iteration 44 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 45 | model.update_process(epoch) 46 | 47 | for i, data in enumerate(dataset): # inner loop within one epoch 48 | iter_start_time = time.time() # timer for computation per iteration 49 | if total_iters % opt.print_freq == 0: 50 | t_data = iter_start_time - iter_data_time 51 | visualizer.reset() 52 | total_iters += opt.batch_size 53 | epoch_iter += opt.batch_size 54 | model.set_input(data) # unpack data from dataset and apply preprocessing 
55 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 56 | 57 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file 58 | save_result = total_iters % opt.update_html_freq == 0 59 | model.compute_visuals() 60 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 61 | 62 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 63 | losses = model.get_current_losses() 64 | t_comp = (time.time() - iter_start_time) / opt.batch_size 65 | if opt.model == 'cycle_gan': 66 | processes = [model.process] + model.lambda_As 67 | visualizer.print_current_losses_process(epoch, epoch_iter, losses, t_comp, t_data, processes) 68 | else: 69 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) 70 | if opt.display_id > 0: 71 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 72 | 73 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every iterations 74 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 75 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 76 | model.save_networks(save_suffix) 77 | 78 | iter_data_time = time.time() 79 | if epoch % opt.save_epoch_freq == 0: # cache our model every epochs 80 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 81 | model.save_networks('latest') 82 | model.save_networks(epoch) 83 | 84 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 85 | model.update_learning_rate() # update learning rates at the end of every epoch. 
86 | 87 | print('Total Time Taken: %d sec' % (time.time() - start)) -------------------------------------------------------------------------------- /models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = models.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", 
['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = 
torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = models.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = models.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = models.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = models.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = models.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def 
forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABCMeta, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset): 14 | __metaclass__ = ABCMeta 15 | """This class is an abstract base class (ABC) for datasets. 16 | 17 | To create a subclass, you need to implement the following four functions: 18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 19 | -- <__len__>: return the size of dataset. 20 | -- <__getitem__>: get a data point. 21 | -- : (optionally) add dataset-specific options and set default options. 
22 | """ 23 | 24 | def __init__(self, opt): 25 | """Initialize the class; save the options in the class 26 | 27 | Parameters: 28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 29 | """ 30 | self.opt = opt 31 | self.root = opt.dataroot 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
60 | """ 61 | pass 62 | 63 | 64 | def get_params(opt, size): 65 | w, h = size 66 | new_h = h 67 | new_w = w 68 | if opt.preprocess == 'resize_and_crop': 69 | new_h = new_w = opt.load_size 70 | elif opt.preprocess == 'scale_width_and_crop': 71 | new_w = opt.load_size 72 | new_h = opt.load_size * h // w 73 | 74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 76 | 77 | flip = random.random() > 0.5 78 | 79 | return {'crop_pos': (x, y), 'flip': flip} 80 | 81 | 82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 83 | transform_list = [] 84 | if grayscale: 85 | transform_list.append(transforms.Grayscale(1)) 86 | if 'resize' in opt.preprocess: 87 | osize = [opt.load_size, opt.load_size] 88 | transform_list.append(transforms.Resize(osize, method)) 89 | elif 'scale_width' in opt.preprocess: 90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 91 | 92 | if 'crop' in opt.preprocess: 93 | if params is None: 94 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 95 | else: 96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 97 | 98 | if opt.preprocess == 'none': 99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 100 | 101 | if not opt.no_flip: 102 | if params is None: 103 | transform_list.append(transforms.RandomHorizontalFlip()) 104 | elif params['flip']: 105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 106 | 107 | if convert: 108 | transform_list += [transforms.ToTensor()] 109 | if grayscale: 110 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 111 | else: 112 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 113 | return transforms.Compose(transform_list) 114 | 115 | def get_transform_mask(opt, params=None, 
grayscale=False, method=Image.BICUBIC, convert=True): 116 | transform_list = [] 117 | if grayscale: 118 | transform_list.append(transforms.Grayscale(1)) 119 | if 'resize' in opt.preprocess: 120 | osize = [opt.load_size, opt.load_size] 121 | transform_list.append(transforms.Resize(osize, method)) 122 | elif 'scale_width' in opt.preprocess: 123 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 124 | 125 | if 'crop' in opt.preprocess: 126 | if params is None: 127 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 128 | else: 129 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 130 | 131 | if opt.preprocess == 'none': 132 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 133 | 134 | if not opt.no_flip: 135 | if params is None: 136 | transform_list.append(transforms.RandomHorizontalFlip()) 137 | elif params['flip']: 138 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 139 | 140 | if convert: 141 | transform_list += [transforms.ToTensor()] 142 | return transforms.Compose(transform_list) 143 | 144 | def __make_power_2(img, base, method=Image.BICUBIC): 145 | ow, oh = img.size 146 | h = int(round(oh / base) * base) 147 | w = int(round(ow / base) * base) 148 | if (h == oh) and (w == ow): 149 | return img 150 | 151 | __print_size_warning(ow, oh, w, h) 152 | return img.resize((w, h), method) 153 | 154 | 155 | def __scale_width(img, target_width, method=Image.BICUBIC): 156 | ow, oh = img.size 157 | if (ow == target_width): 158 | return img 159 | w = target_width 160 | h = int(target_width * oh / ow) 161 | return img.resize((w, h), method) 162 | 163 | 164 | def __crop(img, pos, size): 165 | ow, oh = img.size 166 | x1, y1 = pos 167 | tw = th = size 168 | if (ow > tw or oh > th): 169 | return img.crop((x1, y1, x1 + tw, y1 + th)) 170 | return img 171 | 172 | 173 | def __flip(img, flip): 
174 | if flip: 175 | return img.transpose(Image.FLIP_LEFT_RIGHT) 176 | return img 177 | 178 | 179 | def __print_size_warning(ow, oh, w, h): 180 | """Print warning information about image size(only print once)""" 181 | if not hasattr(__print_size_warning, 'has_printed'): 182 | print("The image size needs to be a multiple of 4. " 183 | "The loaded image size was (%d, %d), so it was adjusted to " 184 | "(%d, %d). This adjustment will be done to all images " 185 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 186 | __print_size_warning.has_printed = True 187 | -------------------------------------------------------------------------------- /data/unaligned_mask_stylecls_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import torch 7 | import torchvision.transforms as transforms 8 | import numpy as np 9 | 10 | 11 | class UnalignedMaskStyleClsDataset(BaseDataset): 12 | """ 13 | This dataset class can load unaligned/unpaired datasets. 14 | 15 | It requires two directories to host training images from domain A '/path/to/data/trainA' 16 | and from domain B '/path/to/data/trainB' respectively. 17 | You can train the model with the dataset flag '--dataroot /path/to/data'. 18 | Similarly, you need to prepare two directories: 19 | '/path/to/data/testA' and '/path/to/data/testB' during test time. 20 | """ 21 | 22 | def __init__(self, opt): 23 | """Initialize this dataset class. 
24 | 25 | Parameters: 26 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 27 | """ 28 | BaseDataset.__init__(self, opt) 29 | 30 | imglistA = './datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot) 31 | imglistB = './datasets/list/%s/%s.txt' % (opt.phase+'B', opt.dataroot) 32 | 33 | if not os.path.exists(imglistA) or not os.path.exists(imglistB): 34 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA' 35 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB' 36 | 37 | self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 38 | self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 39 | else: 40 | self.A_paths = sorted(open(imglistA, 'r').read().splitlines()) 41 | self.B_paths = sorted(open(imglistB, 'r').read().splitlines()) 42 | 43 | self.A_size = len(self.A_paths) # get the size of dataset A 44 | self.B_size = len(self.B_paths) # get the size of dataset B 45 | print("A size:", self.A_size) 46 | print("B size:", self.B_size) 47 | btoA = self.opt.direction == 'BtoA' 48 | self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image 49 | self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image 50 | 51 | if opt.dataroot == '190613-4s': 52 | self.softmaxloc = os.path.join('style_features/styles2/', '1vgg19_softmax') 53 | elif opt.dataroot == '190613-4sn5': 54 | self.softmaxloc = os.path.join('style_features/styles2_sn_equal/', '1vgg19_softmax') 55 | elif '190613-4sn' in self.opt.dataroot: 56 | self.softmaxloc = os.path.join('style_features/styles2_sn/', '1vgg19_softmax') 57 | 58 | 59 | def __getitem__(self, index): 60 | """Return a data point and its metadata information. 
61 | 62 | Parameters: 63 | index (int) -- a random integer for data indexing 64 | 65 | Returns a dictionary that contains A, B, A_paths and B_paths 66 | A (tensor) -- an image in the input domain 67 | B (tensor) -- its corresponding image in the target domain 68 | A_paths (str) -- image paths 69 | B_paths (str) -- image paths 70 | """ 71 | A_path = self.A_paths[index % self.A_size] # make sure index is within then range 72 | if self.opt.serial_batches: # make sure index is within then range 73 | index_B = index % self.B_size 74 | else: # randomize the index for domain B to avoid fixed pairs. 75 | index_B = random.randint(0, self.B_size - 1) 76 | B_path = self.B_paths[index_B] 77 | A_img = Image.open(A_path).convert('RGB') 78 | B_img = Image.open(B_path).convert('RGB') 79 | 80 | basenA = os.path.basename(A_path) 81 | A_mask_img = Image.open(os.path.join('./datasets/list/mask/A',basenA)) 82 | basenB = os.path.basename(B_path) 83 | basenB2 = basenB.replace('_fake.png','.png') 84 | # for added synthetic drawing 85 | basenB2 = basenB2.replace('_style1.png','.png') 86 | basenB2 = basenB2.replace('_style2.png','.png') 87 | basenB2 = basenB2.replace('_style1single.png','.png') 88 | basenB2 = basenB2.replace('_style2single.png','.png') 89 | B_mask_img = Image.open(os.path.join('./datasets/list/mask/B',basenB2)) 90 | if self.opt.use_eye_mask: 91 | A_maske_img = Image.open(os.path.join('./datasets/list/mask/A_eyes',basenA)) 92 | B_maske_img = Image.open(os.path.join('./datasets/list/mask/B_eyes',basenB2)) 93 | if self.opt.use_lip_mask: 94 | A_maskl_img = Image.open(os.path.join('./datasets/list/mask/A_lips',basenA)) 95 | B_maskl_img = Image.open(os.path.join('./datasets/list/mask/B_lips',basenB2)) 96 | if self.opt.metric_inmask: 97 | A_maskfg_img = Image.open(os.path.join('./datasets/list/mask/A_fg',basenA)) 98 | 99 | # apply image transformation 100 | transform_params_A = get_params(self.opt, A_img.size) 101 | transform_params_B = get_params(self.opt, B_img.size) 102 | A = 
get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img) 103 | B = get_transform(self.opt, transform_params_B, grayscale=(self.output_nc == 1))(B_img) 104 | A_mask = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_mask_img) 105 | B_mask = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_mask_img) 106 | if self.opt.use_eye_mask: 107 | A_maske = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maske_img) 108 | B_maske = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maske_img) 109 | if self.opt.use_lip_mask: 110 | A_maskl = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskl_img) 111 | B_maskl = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maskl_img) 112 | if self.opt.metric_inmask: 113 | A_maskfg = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskfg_img) 114 | 115 | item = {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_mask': A_mask, 'B_mask': B_mask} 116 | if self.opt.use_eye_mask: 117 | item['A_maske'] = A_maske 118 | item['B_maske'] = B_maske 119 | if self.opt.use_lip_mask: 120 | item['A_maskl'] = A_maskl 121 | item['B_maskl'] = B_maskl 122 | if self.opt.metric_inmask: 123 | item['A_maskfg'] = A_maskfg 124 | 125 | 126 | softmax = np.load(os.path.join(self.softmaxloc,basenB[:-4]+'.npy')) 127 | softmax = torch.Tensor(softmax) 128 | [maxv,index] = torch.max(softmax,0) 129 | B_label = index 130 | if len(self.opt.sfeature_mode) >= 8 and self.opt.sfeature_mode[-8:] == '_softmax': 131 | if self.opt.one_hot: 132 | B_style = torch.Tensor([0.,0.,0.]) 133 | B_style[index] = 1. 
134 | else: 135 | B_style = softmax 136 | B_style = B_style.view(3, 1, 1) 137 | B_style = B_style.repeat(1, 128, 128) 138 | #print(index, index_B, torch.mean(B_style,(1,2))) 139 | elif self.opt.sfeature_mode == 'domain': 140 | B_style = B_label 141 | item['B_style'] = B_style 142 | item['B_label'] = B_label 143 | if self.opt.isTrain and self.opt.style_loss_with_weight: 144 | item['B_style0'] = softmax 145 | if self.opt.isTrain and self.opt.metricvec: 146 | vec = softmax 147 | vec = vec.view(3, 1, 1) 148 | vec = vec.repeat(1, 299, 299) 149 | item['vec'] = vec 150 | 151 | return item 152 | 153 | def __len__(self): 154 | """Return the total number of images in the dataset. 155 | 156 | As we have two datasets with potentially different number of images, 157 | we take a maximum of 158 | """ 159 | return max(self.A_size, self.B_size) 160 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in functions in both dataset class and model class. 14 | """ 15 | 16 | def __init__(self): 17 | """Reset the class; indicates the class hasn't been initailized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # basic parameters 23 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 24 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. 
It decides where to store samples and models') 25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU') 27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 28 | # model parameters 29 | parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') 30 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 31 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 32 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 33 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 34 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. 
n_layers allows you to specify the layers in the discriminator') 35 | parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') 36 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 37 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 38 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 39 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 41 | # dataset parameters 42 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') 43 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 44 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 45 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 46 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 47 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') 48 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 49 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 50 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 51 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 52 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 53 | # additional parameters 54 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 55 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 56 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 57 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 58 | self.initialized = True 59 | return parser 60 | 61 | def gather_options(self): 62 | """Initialize our parser with basic options(only once). 63 | Add additional model-specific and dataset-specific options. 64 | These options are defined in the function 65 | in model and dataset classes. 
66 | """ 67 | if not self.initialized: # check if it has been initialized 68 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | parser = self.initialize(parser) 70 | 71 | # get the basic options 72 | opt, _ = parser.parse_known_args() 73 | 74 | # modify model-related parser options 75 | model_name = opt.model 76 | model_option_setter = models.get_option_setter(model_name) 77 | parser = model_option_setter(parser, self.isTrain) 78 | opt, _ = parser.parse_known_args() # parse again with new defaults 79 | 80 | # modify dataset-related parser options 81 | dataset_name = opt.dataset_mode 82 | dataset_option_setter = data.get_option_setter(dataset_name) 83 | parser = dataset_option_setter(parser, self.isTrain) 84 | 85 | # save and return the parser 86 | self.parser = parser 87 | return parser.parse_args() 88 | 89 | def print_options(self, opt): 90 | """Print and save options 91 | 92 | It will print both current options and default values(if different). 
93 | It will save options into a text file / [checkpoints_dir] / opt.txt 94 | """ 95 | message = '' 96 | message += '----------------- Options ---------------\n' 97 | for k, v in sorted(vars(opt).items()): 98 | comment = '' 99 | default = self.parser.get_default(k) 100 | if v != default: 101 | comment = '\t[default: %s]' % str(default) 102 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 103 | message += '----------------- End -------------------' 104 | print(message) 105 | 106 | # save to the disk 107 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 108 | util.mkdirs(expr_dir) 109 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 110 | with open(file_name, 'wt') as opt_file: 111 | opt_file.write(message) 112 | opt_file.write('\n') 113 | 114 | def parse(self): 115 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 116 | opt = self.gather_options() 117 | opt.isTrain = self.isTrain # train or test 118 | 119 | # process opt.suffix 120 | if opt.suffix: 121 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 122 | opt.name = opt.name + suffix 123 | 124 | self.print_options(opt) 125 | 126 | # set gpu ids 127 | str_ids = opt.gpu_ids.split(',') 128 | opt.gpu_ids = [] 129 | for str_id in str_ids: 130 | id = int(str_id) 131 | if id >= 0: 132 | opt.gpu_ids.append(id) 133 | if len(opt.gpu_ids) > 0: 134 | torch.cuda.set_device(opt.gpu_ids[0]) 135 | 136 | # set gpu ids 137 | str_ids = opt.gpu_ids_p.split(',') 138 | opt.gpu_ids_p = [] 139 | for str_id in str_ids: 140 | id = int(str_id) 141 | if id >= 0: 142 | opt.gpu_ids_p.append(id) 143 | 144 | self.opt = opt 145 | return self.opt 146 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABCMeta, 
abstractmethod 5 | from . import networks 6 | import pdb 7 | 8 | 9 | class BaseModel(): 10 | __metaclass__ = ABCMeta 11 | """This class is an abstract base class (ABC) for models. 12 | To create a subclass, you need to implement the following five functions: 13 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 14 | -- : unpack data from dataset and apply preprocessing. 15 | -- : produce intermediate results. 16 | -- : calculate losses, gradients, and update network weights. 17 | -- : (optionally) add model-specific options and set default options. 18 | """ 19 | 20 | def __init__(self, opt): 21 | """Initialize the BaseModel class. 22 | 23 | Parameters: 24 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 25 | 26 | When creating your custom class, you need to implement your own initialization. 27 | In this fucntion, you should first call 28 | Then, you need to define four lists: 29 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 30 | -- self.model_names (str list): specify the images that you want to display and save. 31 | -- self.visual_names (str list): define networks used in our training. 32 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 33 | """ 34 | self.opt = opt 35 | self.gpu_ids = opt.gpu_ids 36 | self.isTrain = opt.isTrain 37 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 38 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 39 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 
40 | torch.backends.cudnn.benchmark = True 41 | self.loss_names = [] 42 | self.model_names = [] 43 | self.visual_names = [] 44 | self.optimizers = [] 45 | self.image_paths = [] 46 | self.metric = 0 # used for learning rate policy 'plateau' 47 | 48 | @staticmethod 49 | def modify_commandline_options(parser, is_train): 50 | """Add new model-specific options, and rewrite default values for existing options. 51 | 52 | Parameters: 53 | parser -- original option parser 54 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 55 | 56 | Returns: 57 | the modified parser. 58 | """ 59 | return parser 60 | 61 | @abstractmethod 62 | def set_input(self, input): 63 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 64 | 65 | Parameters: 66 | input (dict): includes the data itself and its metadata information. 67 | """ 68 | pass 69 | 70 | @abstractmethod 71 | def forward(self): 72 | """Run forward pass; called by both functions and .""" 73 | pass 74 | 75 | @abstractmethod 76 | def optimize_parameters(self): 77 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 78 | pass 79 | 80 | def setup(self, opt): 81 | """Load and print networks; create schedulers 82 | 83 | Parameters: 84 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 85 | """ 86 | if self.isTrain: 87 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 88 | if not self.isTrain or opt.continue_train: 89 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 90 | self.load_networks(load_suffix) 91 | self.print_networks(opt.verbose) 92 | 93 | def eval(self): 94 | """Make models eval mode during test time""" 95 | for name in self.model_names: 96 | if isinstance(name, str): 97 | net = getattr(self, 'net' + name) 98 | net.eval() 99 | 100 | def test(self): 101 | 
"""Forward function used in test time. 102 | 103 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 104 | It also calls to produce additional visualization results 105 | """ 106 | with torch.no_grad(): 107 | self.forward() 108 | self.compute_visuals() 109 | 110 | def compute_visuals(self): 111 | """Calculate additional output images for visdom and HTML visualization""" 112 | pass 113 | 114 | def get_image_paths(self): 115 | """ Return image paths that are used to load current data""" 116 | return self.image_paths 117 | 118 | def update_learning_rate(self): 119 | """Update learning rates for all the networks; called at the end of every epoch""" 120 | for scheduler in self.schedulers: 121 | if self.opt.lr_policy == 'plateau': 122 | scheduler.step(self.metric) 123 | else: 124 | scheduler.step() 125 | 126 | lr = self.optimizers[0].param_groups[0]['lr'] 127 | print('learning rate = %.7f' % lr) 128 | 129 | def get_current_visuals(self): 130 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 131 | visual_ret = OrderedDict() 132 | for name in self.visual_names: 133 | if isinstance(name, str): 134 | visual_ret[name] = getattr(self, name) 135 | return visual_ret 136 | 137 | def get_current_losses(self): 138 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 139 | errors_ret = OrderedDict() 140 | for name in self.loss_names: 141 | if isinstance(name, str): 142 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 143 | return errors_ret 144 | 145 | def save_networks(self, epoch): 146 | """Save all the networks to the disk. 
147 | 148 | Parameters: 149 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 150 | """ 151 | for name in self.model_names: 152 | if isinstance(name, str): 153 | save_filename = '%s_net_%s.pth' % (epoch, name) 154 | save_path = os.path.join(self.save_dir, save_filename) 155 | net = getattr(self, 'net' + name) 156 | 157 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 158 | torch.save(net.module.cpu().state_dict(), save_path) 159 | net.cuda(self.gpu_ids[0]) 160 | else: 161 | torch.save(net.cpu().state_dict(), save_path) 162 | 163 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 164 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 165 | key = keys[i] 166 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 167 | if module.__class__.__name__.startswith('InstanceNorm') and \ 168 | (key == 'running_mean' or key == 'running_var'): 169 | if getattr(module, key) is None: 170 | state_dict.pop('.'.join(keys)) 171 | if module.__class__.__name__.startswith('InstanceNorm') and \ 172 | (key == 'num_batches_tracked'): 173 | state_dict.pop('.'.join(keys)) 174 | else: 175 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 176 | 177 | def load_networks(self, epoch): 178 | """Load all the networks from the disk. 
179 | 180 | Parameters: 181 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 182 | """ 183 | for name in self.model_names: 184 | if isinstance(name, str): 185 | load_filename = '%s_net_%s.pth' % (epoch, name) 186 | load_path = os.path.join(self.save_dir, load_filename) 187 | net = getattr(self, 'net' + name) 188 | if isinstance(net, torch.nn.DataParallel): 189 | net = net.module 190 | print('loading the model from %s' % load_path) 191 | # if you are using PyTorch newer than 0.4 (e.g., built from 192 | # GitHub source), you can remove str() on self.device 193 | state_dict = torch.load(load_path, map_location=str(self.device)) 194 | if hasattr(state_dict, '_metadata'): 195 | del state_dict._metadata 196 | 197 | # patch InstanceNorm checkpoints prior to 0.4 198 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 199 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 200 | net.load_state_dict(state_dict) 201 | #param1 = {} 202 | #for name, parameters in net.named_parameters(): 203 | # print(name,',',parameters.size()) 204 | # param1[name] = parameters.detach().cpu().numpy() 205 | #pdb.set_trace() 206 | 207 | def print_networks(self, verbose): 208 | """Print the total number of parameters in the network and (if verbose) network architecture 209 | 210 | Parameters: 211 | verbose (bool) -- if verbose: print the network architecture 212 | """ 213 | print('---------- Networks initialized -------------') 214 | for name in self.model_names: 215 | if isinstance(name, str): 216 | net = getattr(self, 'net' + name) 217 | num_params = 0 218 | for param in net.parameters(): 219 | num_params += param.numel() 220 | if verbose: 221 | print(net) 222 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 223 | print('-----------------------------------------------') 224 | 225 | def set_requires_grad(self, nets, requires_grad=False): 226 | """Set 
requies_grad=Fasle for all the networks to avoid unnecessary computations 227 | Parameters: 228 | nets (network list) -- a list of networks 229 | requires_grad (bool) -- whether the networks require gradients or not 230 | """ 231 | if not isinstance(nets, list): 232 | nets = [nets] 233 | for net in nets: 234 | if net is not None: 235 | for param in net.parameters(): 236 | param.requires_grad = requires_grad 237 | 238 | # =========================================================================================================== 239 | def masked(self, A,mask): 240 | if self.opt.mask_type == 0: 241 | return (A/2+0.5)*mask*2-1 242 | elif self.opt.mask_type == 1: 243 | return ((A/2+0.5)*mask+1-mask)*2-1 244 | elif self.opt.mask_type == 2: 245 | return torch.cat((A, mask), 1) 246 | elif self.opt.mask_type == 3: 247 | masked = ((A/2+0.5)*mask+1-mask)*2-1 248 | return torch.cat((masked, mask), 1) -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | #from scipy.misc import imresize 9 | from PIL import Image 10 | import pdb 11 | #from scipy.io import savemat 12 | 13 | if sys.version_info[0] == 2: 14 | VisdomExceptionBase = Exception 15 | else: 16 | VisdomExceptionBase = ConnectionError 17 | 18 | 19 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 20 | """Save images to the disk. 
21 | 22 | Parameters: 23 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 24 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 25 | image_path (str) -- the string is used to create image paths 26 | aspect_ratio (float) -- the aspect ratio of saved images 27 | width (int) -- the images will be resized to width x width 28 | 29 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 30 | """ 31 | image_dir = webpage.get_image_dir() 32 | short_path = ntpath.basename(image_path[0]) 33 | name = os.path.splitext(short_path)[0] 34 | 35 | webpage.add_header(name) 36 | ims, txts, links = [], [], [] 37 | 38 | for label, im_data in visuals.items(): 39 | ## tensor to im 40 | im = util.tensor2im(im_data) 41 | #im = util.tensor2im2(im_data) 42 | ## save mat 43 | #im,imo = util.tensor2im(im_data) 44 | #matname = os.path.join(image_dir, '%s_%s.mat' % (name, label)) 45 | #savemat(matname,{'imo':imo}) 46 | image_name = '%s_%s.png' % (name, label) 47 | save_path = os.path.join(image_dir, image_name) 48 | h, w, _ = im.shape 49 | if aspect_ratio > 1.0: 50 | #im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 51 | im = np.array(Image.fromarray(im).resize((int(w * aspect_ratio), h), Image.BICUBIC)) 52 | if aspect_ratio < 1.0: 53 | #im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 54 | im = np.array(Image.fromarray(im).resize((w, int(h / aspect_ratio)), Image.BICUBIC)) 55 | util.save_image(im, save_path) 56 | 57 | ims.append(image_name) 58 | txts.append(label) 59 | links.append(image_name) 60 | webpage.add_images(ims, txts, links, width=width) 61 | 62 | 63 | class Visualizer(): 64 | """This class includes several functions that can display/save images and print/save logging information. 
65 | 66 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 67 | """ 68 | 69 | def __init__(self, opt): 70 | """Initialize the Visualizer class 71 | 72 | Parameters: 73 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 74 | Step 1: Cache the training/test options 75 | Step 2: connect to a visdom server 76 | Step 3: create an HTML object for saveing HTML filters 77 | Step 4: create a logging file to store training losses 78 | """ 79 | self.opt = opt # cache the option 80 | self.display_id = opt.display_id 81 | self.use_html = opt.isTrain and not opt.no_html 82 | self.win_size = opt.display_winsize 83 | self.name = opt.name 84 | self.port = opt.display_port 85 | self.saved = False 86 | if self.display_id > 0: # connect to a visdom server given and 87 | import visdom 88 | self.ncols = opt.display_ncols 89 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 90 | if not self.vis.check_connection(): 91 | self.create_visdom_connections() 92 | 93 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ 94 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 95 | self.img_dir = os.path.join(self.web_dir, 'images') 96 | print('create web directory %s...' 
% self.web_dir) 97 | util.mkdirs([self.web_dir, self.img_dir]) 98 | # create a logging file to store training losses 99 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 100 | with open(self.log_name, "a") as log_file: 101 | now = time.strftime("%c") 102 | log_file.write('================ Training Loss (%s) ================\n' % now) 103 | 104 | def reset(self): 105 | """Reset the self.saved status""" 106 | self.saved = False 107 | 108 | def create_visdom_connections(self): 109 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 110 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 111 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 112 | print('Command: %s' % cmd) 113 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 114 | 115 | def display_current_results(self, visuals, epoch, save_result): 116 | """Display current results on visdom; save current results to an HTML file. 117 | 118 | Parameters: 119 | visuals (OrderedDict) - - dictionary of images to display or save 120 | epoch (int) - - the current epoch 121 | save_result (bool) - - if save the current results to an HTML file 122 | """ 123 | if self.display_id > 0: # show images in the browser using visdom 124 | ncols = self.ncols 125 | if ncols > 0: # show all the images in one visdom panel 126 | ncols = min(ncols, len(visuals)) 127 | h, w = next(iter(visuals.values())).shape[:2] 128 | table_css = """""" % (w, h) # create a table css 132 | # create a table of images. 
133 | title = self.name 134 | label_html = '' 135 | label_html_row = '' 136 | images = [] 137 | idx = 0 138 | for label, image in visuals.items(): 139 | #image_numpy = util.tensor2im(image) 140 | image_numpy = util.tensor2im2(image) 141 | label_html_row += '%s' % label 142 | #pdb.set_trace() 143 | images.append(image_numpy.transpose([2, 0, 1])) 144 | idx += 1 145 | if idx % ncols == 0: 146 | label_html += '%s' % label_html_row 147 | label_html_row = '' 148 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 149 | while idx % ncols != 0: 150 | images.append(white_image) 151 | label_html_row += '' 152 | idx += 1 153 | if label_html_row != '': 154 | label_html += '%s' % label_html_row 155 | try: 156 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 157 | padding=2, opts=dict(title=title + ' images')) 158 | label_html = '%s
' % label_html 159 | self.vis.text(table_css + label_html, win=self.display_id + 2, 160 | opts=dict(title=title + ' labels')) 161 | except VisdomExceptionBase: 162 | self.create_visdom_connections() 163 | 164 | else: # show each image in a separate visdom panel; 165 | idx = 1 166 | try: 167 | for label, image in visuals.items(): 168 | image_numpy = util.tensor2im(image) 169 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 170 | win=self.display_id + idx) 171 | idx += 1 172 | except VisdomExceptionBase: 173 | self.create_visdom_connections() 174 | 175 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 176 | self.saved = True 177 | # save images to the disk 178 | for label, image in visuals.items(): 179 | image_numpy = util.tensor2im(image) 180 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 181 | util.save_image(image_numpy, img_path) 182 | 183 | # update website 184 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 185 | for n in range(epoch, 0, -1): 186 | webpage.add_header('epoch [%d]' % n) 187 | ims, txts, links = [], [], [] 188 | 189 | for label, image_numpy in visuals.items(): 190 | image_numpy = util.tensor2im(image) 191 | img_path = 'epoch%.3d_%s.png' % (n, label) 192 | ims.append(img_path) 193 | txts.append(label) 194 | links.append(img_path) 195 | webpage.add_images(ims, txts, links, width=self.win_size) 196 | webpage.save() 197 | 198 | def plot_current_losses(self, epoch, counter_ratio, losses): 199 | """display the current losses on visdom display: dictionary of error labels and values 200 | 201 | Parameters: 202 | epoch (int) -- current epoch 203 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 204 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 205 | """ 206 | if not hasattr(self, 'plot_data'): 207 | self.plot_data = {'X': [], 
'Y': [], 'legend': list(losses.keys())} 208 | self.plot_data['X'].append(epoch + counter_ratio) 209 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 210 | #X = np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1) 211 | #Y = np.array(self.plot_data['Y']) 212 | #pdb.set_trace() 213 | try: 214 | self.vis.line( 215 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 216 | Y=np.array(self.plot_data['Y']), 217 | opts={ 218 | 'title': self.name + ' loss over time', 219 | 'legend': self.plot_data['legend'], 220 | 'xlabel': 'epoch', 221 | 'ylabel': 'loss'}, 222 | win=self.display_id) 223 | except VisdomExceptionBase: 224 | self.create_visdom_connections() 225 | 226 | # losses: same format as |losses| of plot_current_losses 227 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 228 | """print current losses on console; also save the losses to the disk 229 | 230 | Parameters: 231 | epoch (int) -- current epoch 232 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 233 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 234 | t_comp (float) -- computational time per data point (normalized by batch_size) 235 | t_data (float) -- data loading time per data point (normalized by batch_size) 236 | """ 237 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 238 | for k, v in losses.items(): 239 | message += '%s: %.3f ' % (k, v) 240 | 241 | print(message) # print the message 242 | with open(self.log_name, "a") as log_file: 243 | log_file.write('%s\n' % message) # save the message 244 | 245 | # losses: same format as |losses| of plot_current_losses 246 | def print_current_losses_process(self, epoch, iters, losses, t_comp, t_data, processes): 247 | """print current losses on console; also save the losses to the disk 248 | 249 | Parameters: 250 | epoch (int) -- 
current epoch 251 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 252 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 253 | t_comp (float) -- computational time per data point (normalized by batch_size) 254 | t_data (float) -- data loading time per data point (normalized by batch_size) 255 | """ 256 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 257 | message += '[process: %.3f, non_trunc: %.3f, trunc: %.3f] ' % (processes[0], processes[1], processes[2]) 258 | for k, v in losses.items(): 259 | message += '%s: %.3f ' % (k, v) 260 | 261 | print(message) # print the message 262 | with open(self.log_name, "a") as log_file: 263 | log_file.write('%s\n' % message) # save the message 264 | -------------------------------------------------------------------------------- /models/cycle_gan_cls_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from util.image_pool import ImagePool 4 | from .base_model import BaseModel 5 | from . import networks 6 | import models.dist_model as dm # numpy==1.14.3 7 | import torchvision.transforms as transforms 8 | import os 9 | from util.util import tensor2im, tensor2im2, save_image 10 | 11 | def truncate(fake_B,a=127.5):#[-1,1] 12 | #return torch.round((fake_B+1)*a)/a-1 13 | return ((fake_B+1)*a).int().float()/a-1 14 | 15 | class CycleGANClsModel(BaseModel): 16 | """ 17 | This class implements the CycleGAN model, for learning image-to-image translation without paired data. 18 | 19 | The model training requires '--dataset_mode unaligned' dataset. 20 | By default, it uses a '--netG resnet_9blocks' ResNet generator, 21 | a '--netD basic' discriminator (PatchGAN introduced by pix2pix), 22 | and a least-square GANs objective ('--gan_mode lsgan'). 
23 | 24 | CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf 25 | """ 26 | @staticmethod 27 | def modify_commandline_options(parser, is_train=True): 28 | """Add new dataset-specific options, and rewrite default values for existing options. 29 | 30 | Parameters: 31 | parser -- original option parser 32 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 33 | 34 | Returns: 35 | the modified parser. 36 | 37 | For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses. 38 | A (source domain), B (target domain). 39 | Generators: G_A: A -> B; G_B: B -> A. 40 | Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A. 41 | Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper) 42 | Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper) 43 | Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper) 44 | Dropout is not used in the original CycleGAN paper. 45 | """ 46 | parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout 47 | parser.set_defaults(dataset_mode='unaligned_mask_stylecls') 48 | parser.add_argument('--netda', type=str, default='basic_cls') # discriminator has two branches 49 | parser.add_argument('--truncate', type=float, default=0.0, help='whether truncate in forward') 50 | if is_train: 51 | parser.add_argument('--lambda_A', type=float, default=5.0, help='weight for cycle loss (A -> B -> A)') 52 | parser.add_argument('--lambda_B', type=float, default=5.0, help='weight for cycle loss (B -> A -> B)') 53 | parser.add_argument('--lambda_identity', type=float, default=0, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. 
For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1') 54 | parser.add_argument('--perceptual_cycle', type=int, default=6, help='whether use perceptual similarity for cycle loss') 55 | parser.add_argument('--use_hed', type=int, default=1, help='whether use hed processing for cycle loss') 56 | parser.add_argument('--ntrunc_trunc', type=int, default=1, help='whether use both non-trunc version and trunc version') 57 | parser.add_argument('--trunc_a', type=float, default=31.875, help='multiply which value to round when trunc') 58 | parser.add_argument('--lambda_A_trunc', type=float, default=5.0, help='weight for cycle loss for trunc') 59 | parser.add_argument('--hed_pretrained_mode', type=str, default='./checkpoints/network-bsds500.pytorch', help='path to the pretrained hed model') 60 | parser.add_argument('--vgg_pretrained_mode', type=str, default='./checkpoints/vgg19.pth', help='path to the pretrained vgg model') 61 | parser.add_argument('--lambda_G_A_l', type=float, default=0.5, help='weight for local GAN loss in G') 62 | parser.add_argument('--style_loss_with_weight', type=int, default=0, help='whether multiply prob in style loss') 63 | parser.add_argument('--metric', action='store_true', help='whether use metric loss for fakeB') 64 | parser.add_argument('--metric_model_path', type=str, default='3/30_net_Regressor.pth', help='metric model path') 65 | parser.add_argument('--lambda_metric', type=float, default=0.5, help='weight for metric loss') 66 | parser.add_argument('--metricvec', action='store_true', help='whether use metric model with vec input') 67 | parser.add_argument('--metric_resnext', action='store_true', help='whether use resnext as metric model') 68 | parser.add_argument('--metric_resnet', action='store_true', help='whether use resnet as metric model') 69 | parser.add_argument('--metric_inception', action='store_true', help='whether use inception as metric 
model')# the inception of transform_input=False 70 | parser.add_argument('--metric_inmask', action='store_true', help='whether use inmask in metric model') 71 | else: 72 | parser.add_argument('--check_D', action='store_true', help='whether use check Ds outputs') 73 | # for masks 74 | parser.add_argument('--use_mask', type=int, default=1, help='whether use mask for special face region') 75 | parser.add_argument('--use_eye_mask', type=int, default=1, help='whether use mask for special face region') 76 | parser.add_argument('--use_lip_mask', type=int, default=1, help='whether use mask for special face region') 77 | parser.add_argument('--mask_type', type=int, default=3, help='use mask type, 0 outside black, 1 outside white') 78 | # for style control 79 | parser.add_argument('--style_control', type=int, default=1, help='use style_control') 80 | parser.add_argument('--sfeature_mode', type=str, default='1vgg19_softmax', help='vgg19 softmax as feature') 81 | parser.add_argument('--netga', type=str, default='resnet_style_9blocks', help='net arch for netG_A') 82 | parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0 (before insert style)') 83 | parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)') 84 | parser.add_argument('--one_hot', type=int, default=0, help='use one-hot for style code') 85 | 86 | return parser 87 | 88 | def __init__(self, opt): 89 | """Initialize the CycleGAN class. 90 | 91 | Parameters: 92 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 93 | """ 94 | BaseModel.__init__(self, opt) 95 | # specify the training losses you want to print out. The training/test scripts will call 96 | self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] 97 | # specify the images you want to save/display. 
The training/test scripts will call 98 | visual_names_A = ['real_A', 'fake_B', 'rec_A'] 99 | visual_names_B = ['real_B', 'fake_A', 'rec_B'] 100 | if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) 101 | visual_names_A.append('idt_B') 102 | visual_names_B.append('idt_A') 103 | if self.isTrain and self.opt.use_hed: 104 | visual_names_A.append('real_A_hed') 105 | visual_names_A.append('rec_A_hed') 106 | if self.isTrain and self.opt.ntrunc_trunc: 107 | visual_names_A.append('rec_At') 108 | if self.opt.use_hed: 109 | visual_names_A.append('rec_At_hed') 110 | self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G'] 111 | if self.isTrain and self.opt.use_mask: 112 | visual_names_A.append('fake_B_l') 113 | visual_names_A.append('real_B_l') 114 | self.loss_names += ['D_A_l', 'G_A_l'] 115 | if self.isTrain and self.opt.use_eye_mask: 116 | visual_names_A.append('fake_B_le') 117 | visual_names_A.append('real_B_le') 118 | self.loss_names += ['D_A_le', 'G_A_le'] 119 | if self.isTrain and self.opt.use_lip_mask: 120 | visual_names_A.append('fake_B_ll') 121 | visual_names_A.append('real_B_ll') 122 | self.loss_names += ['D_A_ll', 'G_A_ll'] 123 | if self.isTrain and self.opt.metric: 124 | self.loss_names += ['metric'] 125 | #visual_names_B += ['fake_B2'] 126 | if not self.isTrain and self.opt.use_mask: 127 | visual_names_A.append('fake_B_l') 128 | visual_names_A.append('real_B_l') 129 | if not self.isTrain and self.opt.use_eye_mask: 130 | visual_names_A.append('fake_B_le') 131 | visual_names_A.append('real_B_le') 132 | if not self.isTrain and self.opt.use_lip_mask: 133 | visual_names_A.append('fake_B_ll') 134 | visual_names_A.append('real_B_ll') 135 | self.loss_names += ['D_A_cls','G_A_cls'] 136 | 137 | self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B 138 | print(self.visual_names) 139 | # specify the models you want to 
save to the disk. The training/test scripts will call and . 140 | if self.isTrain: 141 | self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] 142 | if self.opt.use_mask: 143 | self.model_names += ['D_A_l'] 144 | if self.opt.use_eye_mask: 145 | self.model_names += ['D_A_le'] 146 | if self.opt.use_lip_mask: 147 | self.model_names += ['D_A_ll'] 148 | else: # during test time, only load Gs 149 | self.model_names = ['G_A', 'G_B'] 150 | if self.opt.check_D: 151 | self.model_names += ['D_A', 'D_B'] 152 | 153 | # define networks (both Generators and discriminators) 154 | # The naming is different from those used in the paper. 155 | # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) 156 | if not self.opt.style_control: 157 | self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 158 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 159 | else: 160 | print(opt.netga) 161 | print('model0_res', opt.model0_res) 162 | print('model1_res', opt.model1_res) 163 | self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm, 164 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res) 165 | self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, 166 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 167 | 168 | #if self.isTrain: # define discriminators 169 | if self.isTrain or self.opt.check_D: # define discriminators 170 | self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda, 171 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3) 172 | self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, 173 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 174 | if self.opt.use_mask: 175 | if self.opt.mask_type in [2, 3]: 176 | output_nc = opt.output_nc + 1 177 | else: 178 | output_nc = opt.output_nc 179 | self.netD_A_l = networks.define_D(output_nc, 
opt.ndf, opt.netD, 180 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 181 | if self.opt.use_eye_mask: 182 | if self.opt.mask_type in [2, 3]: 183 | output_nc = opt.output_nc + 1 184 | else: 185 | output_nc = opt.output_nc 186 | self.netD_A_le = networks.define_D(output_nc, opt.ndf, opt.netD, 187 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 188 | if self.opt.use_lip_mask: 189 | if self.opt.mask_type in [2, 3]: 190 | output_nc = opt.output_nc + 1 191 | else: 192 | output_nc = opt.output_nc 193 | self.netD_A_ll = networks.define_D(output_nc, opt.ndf, opt.netD, 194 | opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) 195 | 196 | if self.isTrain and self.opt.metric: 197 | if not opt.metric_resnext and not opt.metric_resnet and not opt.metric_inception: 198 | self.metric = networks.define_inception_v3a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec) 199 | elif opt.metric_resnext: 200 | self.metric = networks.define_resnext101a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec) 201 | elif opt.metric_resnet: 202 | self.metric = networks.define_resnet101a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec) 203 | elif opt.metric_inception: 204 | self.metric = networks.define_inception3a(init_weights_='./checkpoints/metric/'+self.opt.metric_model_path,gpu_ids_ = self.gpu_ids,vec=self.opt.metricvec) 205 | self.metric.eval() 206 | self.set_requires_grad(self.metric, False) 207 | 208 | if not self.isTrain and self.opt.check_D: 209 | self.criterionGAN = networks.GANLoss('lsgan').to(self.device) 210 | 211 | if self.isTrain: 212 | if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels 213 | assert(opt.input_nc == opt.output_nc) 214 | self.fake_A_pool = ImagePool(opt.pool_size) # 
create image buffer to store previously generated images 215 | self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images 216 | # define loss functions 217 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss. 218 | self.criterionCycle = torch.nn.L1Loss() 219 | self.criterionIdt = torch.nn.L1Loss() 220 | self.criterionCls = torch.nn.CrossEntropyLoss() 221 | self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none') 222 | # initialize optimizers; schedulers will be automatically created by function . 223 | self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 224 | if not self.opt.use_mask: 225 | self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 226 | elif not self.opt.use_eye_mask: 227 | D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) 228 | self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999)) 229 | elif not self.opt.use_lip_mask: 230 | D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) 231 | self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999)) 232 | else: 233 | D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) + list(self.netD_A_ll.parameters()) 234 | self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999)) 235 | self.optimizers.append(self.optimizer_G) 236 | self.optimizers.append(self.optimizer_D) 237 | 238 | if self.opt.perceptual_cycle: 239 | if self.opt.perceptual_cycle in [1,2,3,6]: 240 | self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True) 241 | 
elif self.opt.perceptual_cycle in [4,5,8]: 242 | self.vgg = networks.define_VGG(init_weights_=opt.vgg_pretrained_mode, feature_mode_=True, gpu_ids_=self.gpu_ids) # using conv4_4 layer 243 | 244 | if self.opt.use_hed: 245 | #self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.gpu_ids) 246 | self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p) 247 | self.set_requires_grad(self.hed, False) 248 | 249 | 250 | def set_input(self, input): 251 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 252 | 253 | Parameters: 254 | input (dict): include the data itself and its metadata information. 255 | 256 | The option 'direction' can be used to swap domain A and domain B. 257 | """ 258 | AtoB = self.opt.direction == 'AtoB' 259 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 260 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 261 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 262 | if self.opt.use_mask: 263 | self.A_mask = input['A_mask'].to(self.device) 264 | self.B_mask = input['B_mask'].to(self.device) 265 | if self.opt.use_eye_mask: 266 | self.A_maske = input['A_maske'].to(self.device) 267 | self.B_maske = input['B_maske'].to(self.device) 268 | if self.opt.use_lip_mask: 269 | self.A_maskl = input['A_maskl'].to(self.device) 270 | self.B_maskl = input['B_maskl'].to(self.device) 271 | if self.opt.style_control: 272 | self.real_B_style = input['B_style'].to(self.device) 273 | self.real_B_label = input['B_label'].to(self.device) 274 | if self.opt.isTrain and self.opt.style_loss_with_weight: 275 | self.real_B_style0 = input['B_style0'].to(self.device) 276 | self.zero = torch.zeros(self.real_B_label.size(),dtype=torch.int64).to(self.device) 277 | self.one = torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device) 278 | self.two = 2*torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device) 279 | if 
self.opt.isTrain and self.opt.metricvec: 280 | self.vec = input['vec'].to(self.device) 281 | if self.opt.isTrain and self.opt.metric_inmask: 282 | self.A_maskfg = input['A_maskfg'].to(self.device) 283 | 284 | def forward(self): 285 | """Run forward pass; called by both functions and .""" 286 | if not self.opt.style_control: 287 | self.fake_B = self.netG_A(self.real_A) # G_A(A) 288 | else: 289 | #print(torch.mean(self.real_B_style,(2,3)),'style_control') 290 | #print(self.real_B_style,'style_control') 291 | self.fake_B = self.netG_A(self.real_A, self.real_B_style) 292 | self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) 293 | self.fake_A = self.netG_B(self.real_B) # G_B(B) 294 | if not self.opt.style_control: 295 | self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) 296 | else: 297 | #print(torch.mean(self.real_B_style,(2,3)),'style_control') 298 | self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss 299 | 300 | if self.opt.use_mask: 301 | self.fake_B_l = self.masked(self.fake_B,self.A_mask) 302 | self.real_B_l = self.masked(self.real_B,self.B_mask) 303 | if self.opt.use_eye_mask: 304 | self.fake_B_le = self.masked(self.fake_B,self.A_maske) 305 | self.real_B_le = self.masked(self.real_B,self.B_maske) 306 | if self.opt.use_lip_mask: 307 | self.fake_B_ll = self.masked(self.fake_B,self.A_maskl) 308 | self.real_B_ll = self.masked(self.real_B,self.B_maskl) 309 | 310 | def backward_D_basic(self, netD, real, fake): 311 | """Calculate GAN loss for the discriminator 312 | 313 | Parameters: 314 | netD (network) -- the discriminator D 315 | real (tensor array) -- real images 316 | fake (tensor array) -- images generated by a generator 317 | 318 | Return the discriminator loss. 319 | We also call loss_D.backward() to calculate the gradients. 
320 | """ 321 | # Real 322 | pred_real = netD(real) 323 | loss_D_real = self.criterionGAN(pred_real, True) 324 | # Fake 325 | pred_fake = netD(fake.detach()) 326 | loss_D_fake = self.criterionGAN(pred_fake, False) 327 | # Combined loss and calculate gradients 328 | loss_D = (loss_D_real + loss_D_fake) * 0.5 329 | loss_D.backward() 330 | return loss_D 331 | 332 | def backward_D_basic_cls(self, netD, real, fake): 333 | # Real 334 | pred_real, pred_real_cls = netD(real) 335 | loss_D_real = self.criterionGAN(pred_real, True) 336 | if not self.opt.style_loss_with_weight: 337 | loss_D_real_cls = self.criterionCls(pred_real_cls, self.real_B_label) 338 | else: 339 | loss_D_real_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_real_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_real_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_real_cls, self.two)) 340 | # Fake 341 | pred_fake, pred_fake_cls = netD(fake.detach()) 342 | loss_D_fake = self.criterionGAN(pred_fake, False) 343 | if not self.opt.style_loss_with_weight: 344 | loss_D_fake_cls = self.criterionCls(pred_fake_cls, self.real_B_label) 345 | else: 346 | loss_D_fake_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two)) 347 | # Combined loss and calculate gradients 348 | loss_D = (loss_D_real + loss_D_fake) * 0.5 349 | loss_D_cls = (loss_D_real_cls + loss_D_fake_cls) * 0.5 350 | loss_D_total = loss_D + loss_D_cls 351 | loss_D_total.backward() 352 | return loss_D, loss_D_cls 353 | 354 | def backward_D_A(self): 355 | """Calculate GAN loss for discriminator D_A""" 356 | fake_B = self.fake_B_pool.query(self.fake_B) 357 | self.loss_D_A, self.loss_D_A_cls = self.backward_D_basic_cls(self.netD_A, self.real_B, fake_B) 358 | 359 | def backward_D_A_l(self): 360 | """Calculate GAN loss for 
discriminator D_A_l""" 361 | fake_B = self.fake_B_pool.query(self.fake_B) 362 | self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, self.masked(self.real_B,self.B_mask), self.masked(fake_B,self.A_mask)) 363 | 364 | def backward_D_A_le(self): 365 | """Calculate GAN loss for discriminator D_A_le""" 366 | fake_B = self.fake_B_pool.query(self.fake_B) 367 | self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, self.masked(self.real_B,self.B_maske), self.masked(fake_B,self.A_maske)) 368 | 369 | def backward_D_A_ll(self): 370 | """Calculate GAN loss for discriminator D_A_ll""" 371 | fake_B = self.fake_B_pool.query(self.fake_B) 372 | self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, self.masked(self.real_B,self.B_maskl), self.masked(fake_B,self.A_maskl)) 373 | 374 | def backward_D_B(self): 375 | """Calculate GAN loss for discriminator D_B""" 376 | fake_A = self.fake_A_pool.query(self.fake_A) 377 | self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) 378 | 379 | def update_process(self, epoch): 380 | self.process = (epoch - 1) / float(self.opt.niter_decay + self.opt.niter) 381 | 382 | def backward_G(self): 383 | """Calculate the loss for generators G_A and G_B""" 384 | lambda_idt = self.opt.lambda_identity 385 | lambda_G_A_l = self.opt.lambda_G_A_l 386 | lambda_A = self.opt.lambda_A 387 | lambda_B = self.opt.lambda_B 388 | lambda_A_trunc = self.opt.lambda_A_trunc 389 | if self.opt.ntrunc_trunc: 390 | lambda_A = lambda_A * (1 - self.process * 0.9) 391 | lambda_A_trunc = lambda_A_trunc * self.process * 0.9 392 | self.lambda_As = [lambda_A, lambda_A_trunc] 393 | # Identity loss 394 | if lambda_idt > 0: 395 | # G_A should be identity if real_B is fed: ||G_A(B) - B|| 396 | self.idt_A = self.netG_A(self.real_B) 397 | self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt 398 | # G_B should be identity if real_A is fed: ||G_B(A) - A|| 399 | self.idt_B = self.netG_B(self.real_A) 400 | self.loss_idt_B = 
self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt 401 | else: 402 | self.loss_idt_A = 0 403 | self.loss_idt_B = 0 404 | 405 | # GAN loss D_A(G_A(A)) 406 | pred_fake, pred_fake_cls = self.netD_A(self.fake_B) 407 | self.loss_G_A = self.criterionGAN(pred_fake, True) 408 | if not self.opt.style_loss_with_weight: 409 | self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label) 410 | else: 411 | self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two)) 412 | if self.opt.use_mask: 413 | self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l 414 | if self.opt.use_eye_mask: 415 | self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l 416 | if self.opt.use_lip_mask: 417 | self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l 418 | # GAN loss D_B(G_B(B)) 419 | self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) 420 | # Forward cycle loss || G_B(G_A(A)) - A|| 421 | if self.opt.perceptual_cycle == 0: 422 | self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A 423 | if self.opt.ntrunc_trunc: 424 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a)) 425 | self.loss_cycle_A2 = self.criterionCycle(self.rec_At, self.real_A) * lambda_A_trunc 426 | else: 427 | if self.opt.perceptual_cycle == 1: 428 | self.loss_cycle_A = self.lpips.forward_pair(self.rec_A, self.real_A).mean() * lambda_A 429 | if self.opt.ntrunc_trunc: 430 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a)) 431 | self.loss_cycle_A2 = self.lpips.forward_pair(self.rec_At, self.real_A).mean() * lambda_A_trunc 432 | elif self.opt.perceptual_cycle == 2: 433 | ts = self.real_A.shape 434 | rec_A = 
(self.rec_A[:,0,:,:]*0.299+self.rec_A[:,1,:,:]*0.587+self.rec_A[:,2,:,:]*0.114).unsqueeze(0) 435 | real_A = (self.real_A[:,0,:,:]*0.299+self.real_A[:,1,:,:]*0.587+self.real_A[:,2,:,:]*0.114).unsqueeze(0) 436 | self.loss_cycle_A = self.lpips.forward_pair(rec_A.expand(ts), real_A.expand(ts)).mean() * lambda_A 437 | elif self.opt.perceptual_cycle == 3 and self.opt.use_hed: 438 | ts = self.real_A.shape 439 | #[-1,1]->[0,1]->[-1,1] 440 | rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2 441 | real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2 442 | self.loss_cycle_A = self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean() * lambda_A 443 | self.rec_A_hed = rec_A_hed 444 | self.real_A_hed = real_A_hed 445 | print(lambda_A) 446 | elif self.opt.perceptual_cycle == 4: 447 | x_a_feature = self.vgg(self.real_A) 448 | g_a_feature = self.vgg(self.rec_A) 449 | self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A 450 | elif self.opt.perceptual_cycle == 5 and self.opt.use_hed: 451 | ts = self.real_A.shape 452 | rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2 453 | real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2 454 | x_a_feature = self.vgg(real_A_hed.expand(ts)) 455 | g_a_feature = self.vgg(rec_A_hed.expand(ts)) 456 | self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A 457 | self.rec_A_hed = rec_A_hed 458 | self.real_A_hed = real_A_hed 459 | elif self.opt.perceptual_cycle == 6 and self.opt.use_hed and self.opt.ntrunc_trunc: 460 | ts = self.real_A.shape 461 | gpu_p = self.opt.gpu_ids_p[0] 462 | gpu = self.opt.gpu_ids[0] 463 | rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2 464 | real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2 465 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a)) 466 | rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2 467 | self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * 
lambda_A 468 | self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc 469 | self.rec_A_hed = rec_A_hed 470 | self.real_A_hed = real_A_hed 471 | self.rec_At_hed = rec_At_hed 472 | elif self.opt.perceptual_cycle == 8 and self.opt.use_hed and self.opt.ntrunc_trunc: 473 | ts = self.real_A.shape 474 | rec_A_hed = (self.hed(self.rec_A/2+0.5)-0.5)*2 475 | real_A_hed = (self.hed(self.real_A/2+0.5)-0.5)*2 476 | self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a)) 477 | rec_At_hed = (self.hed(self.rec_At/2+0.5)-0.5)*2 478 | x_a_feature = self.vgg(real_A_hed.expand(ts)) 479 | g_a_feature = self.vgg(rec_A_hed.expand(ts)) 480 | gt_a_feature = self.vgg(rec_At_hed.expand(ts)) 481 | self.loss_cycle_A = self.criterionCycle(g_a_feature, x_a_feature.detach()) * lambda_A 482 | self.loss_cycle_A2 = self.criterionCycle(gt_a_feature, x_a_feature.detach()) * lambda_A_trunc 483 | self.rec_A_hed = rec_A_hed 484 | self.real_A_hed = real_A_hed 485 | self.rec_At_hed = rec_At_hed 486 | 487 | # Backward cycle loss || G_A(G_B(B)) - B|| 488 | self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B 489 | 490 | # Metric loss, metric higher better 491 | if self.opt.metric: 492 | self.fake_B2 = self.fake_B.clone() 493 | if self.opt.metric_inmask: 494 | # background black 495 | #self.fake_B2 = (self.fake_B2/2+0.5)*self.A_maskfg*2-1 496 | # background white 497 | self.fake_B2 = ((self.fake_B2/2+0.5)*self.A_maskfg+1-self.A_maskfg)*2-1 498 | if not self.opt.metric_resnext and not self.opt.metric_resnet: # for two version of inception (during training input is [-1,1]) 499 | self.fake_B2 = torch.nn.functional.interpolate(input=self.fake_B2, size=(299, 299), mode='bilinear', align_corners=False) 500 | self.fake_B2 = self.fake_B2.repeat(1,3,1,1) 501 | else: # for resnet and resnext 502 | self.fake_B2 = torch.nn.functional.interpolate(input=self.fake_B2, size=(224, 224), mode='bilinear', 
align_corners=False) 503 | x = self.fake_B2.repeat(1,3,1,1) 504 | # [-1,1] -> [0,1] -> mean [0.485,0.456,0.406], std [0.229,0.224,0.225] 505 | x_ch0 = (torch.unsqueeze(x[:, 0],1)*0.5+0.5-0.485)/0.229 506 | x_ch1 = (torch.unsqueeze(x[:, 1],1)*0.5+0.5-0.456)/0.224 507 | x_ch2 = (torch.unsqueeze(x[:, 2],1)*0.5+0.5-0.406)/0.225 508 | self.fake_B2 = torch.cat((x_ch0, x_ch1, x_ch2, x[:, 3:]), 1) 509 | 510 | 511 | if not self.opt.metricvec: 512 | pred = self.metric(self.fake_B2) 513 | else: 514 | pred = self.metric(torch.cat((self.fake_B2, self.vec),1)) 515 | self.loss_metric = torch.mean((1-pred)) * self.opt.lambda_metric 516 | 517 | # combined loss and calculate gradients 518 | self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B 519 | if getattr(self,'loss_cycle_A2',-1) != -1: 520 | self.loss_G = self.loss_G + self.loss_cycle_A2 521 | if getattr(self,'loss_G_A_l',-1) != -1: 522 | self.loss_G = self.loss_G + self.loss_G_A_l 523 | if getattr(self,'loss_G_A_le',-1) != -1: 524 | self.loss_G = self.loss_G + self.loss_G_A_le 525 | if getattr(self,'loss_G_A_ll',-1) != -1: 526 | self.loss_G = self.loss_G + self.loss_G_A_ll 527 | if getattr(self,'loss_G_A_cls',-1) != -1: 528 | self.loss_G = self.loss_G + self.loss_G_A_cls 529 | if getattr(self,'loss_metric',-1) != -1: 530 | self.loss_G = self.loss_G + self.loss_metric 531 | self.loss_G.backward() 532 | 533 | def optimize_parameters(self): 534 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 535 | # forward 536 | self.forward() # compute fake images and reconstruction images. 
537 | # G_A and G_B 538 | self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs 539 | if self.opt.use_mask: 540 | self.set_requires_grad([self.netD_A_l], False) 541 | if self.opt.use_eye_mask: 542 | self.set_requires_grad([self.netD_A_le], False) 543 | if self.opt.use_lip_mask: 544 | self.set_requires_grad([self.netD_A_ll], False) 545 | self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero 546 | self.backward_G() # calculate gradients for G_A and G_B 547 | self.optimizer_G.step() # update G_A and G_B's weights 548 | # D_A and D_B 549 | self.set_requires_grad([self.netD_A, self.netD_B], True) 550 | if self.opt.use_mask: 551 | self.set_requires_grad([self.netD_A_l], True) 552 | if self.opt.use_eye_mask: 553 | self.set_requires_grad([self.netD_A_le], True) 554 | if self.opt.use_lip_mask: 555 | self.set_requires_grad([self.netD_A_ll], True) 556 | self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero 557 | self.backward_D_A() # calculate gradients for D_A 558 | if self.opt.use_mask: 559 | self.backward_D_A_l()# calculate gradients for D_A_l 560 | if self.opt.use_eye_mask: 561 | self.backward_D_A_le()# calculate gradients for D_A_le 562 | if self.opt.use_lip_mask: 563 | self.backward_D_A_ll()# calculate gradients for D_A_ll 564 | self.backward_D_B() # calculate graidents for D_B 565 | self.optimizer_D.step() # update D_A and D_B's weights 566 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import functools 6 | from torch.optim import lr_scheduler 7 | import pdb 8 | 9 | 10 | ############################################################################### 11 | # Helper Functions 12 | ############################################################################### 13 | 14 | 
class Identity(nn.Module):
    """Pass-through module, used when no normalization layer is wanted."""

    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer factory.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    For 'none', the factory ignores its argument and returns an Identity module.

    Raises:
        NotImplementedError -- for any other norm_type.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        def norm_layer(x):
            return Identity()
        return norm_layer
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def get_scheduler(optimizer, opt):
    """Return a learning rate scheduler

    Parameters:
        optimizer          -- the optimizer of the network
        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

    For 'linear', we keep the same learning rate for the first opt.niter epochs
    and linearly decay the rate towards zero over the next opt.niter_decay epochs.
    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if opt.lr_policy is not one of the policies above.
    """
    if opt.lr_policy == 'linear':
        def lambda_rule(epoch):
            # Multiplicative LR factor: 1.0 during the first opt.niter epochs
            # (offset by opt.epoch_count), then a linear ramp down.
            return 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # Raise (not return) so a bad policy fails here instead of handing the
        # caller an exception object to call .step() on.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.

    Conv/Linear weights use the chosen method and their biases are zeroed;
    BatchNorm2d weights are drawn from N(1, init_gain) and biases zeroed.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network (wrapped in DataParallel when gpu_ids is non-empty).
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3):
    """Create a generator

    Parameters:
        input_nc (int)      -- the number of channels in input images
        output_nc (int)     -- the number of channels in output images
        ngf (int)           -- the number of filters in the last conv layer
        netG (str)          -- the architecture's name: resnet_9blocks | resnet_8blocks | resnet_6blocks |
                               resnet_style_9blocks | resnet_style2_{8,9,10}blocks | resnet_style3decoder_9blocks |
                               resnet_style2mc_9blocks | resnet_style2mc2_9blocks | unet_128 | unet_256
        norm (str)          -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool)  -- if use dropout layers.
        init_type (str)     -- the name of our initialization method.
        init_gain (float)   -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list)  -- which GPUs the network runs on: e.g., 0,1,2
        model0_res (int)    -- forwarded to the style2 / style3decoder / style2mc(2) generators
        model1_res (int)    -- forwarded to the style2mc2 generator
        extra_channel (int) -- forwarded to the style / style2 / style2mc(2) generators

    Returns a generator

    Our current implementation provides two families of generators:
        U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
        The original U-Net paper: https://arxiv.org/abs/1505.04597

        Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks),
        plus the style-conditioned variants listed above.
        Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
        We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

    The generator has been initialized by init_net. It uses RELU for non-linearity.

    Raises:
        NotImplementedError -- if netG is not a recognized name.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'resnet_9blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif netG == 'resnet_8blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=8)
    elif netG == 'resnet_style_9blocks':
        net = ResnetStyleGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, extra_channel=extra_channel)
    elif netG == 'resnet_style2_9blocks':
        net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
    elif netG == 'resnet_style2_8blocks':
        net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=8, model0_res=model0_res, extra_channel=extra_channel)
    elif netG == 'resnet_style2_10blocks':
        net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=10, model0_res=model0_res, extra_channel=extra_channel)
    elif netG == 'resnet_style3decoder_9blocks':
        net = ResnetStyle3DecoderGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res)
    elif netG == 'resnet_style2mc_9blocks':
        net = ResnetStyle2MCGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
    elif netG == 'resnet_style2mc2_9blocks':
        net = ResnetStyle2MC2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, model1_res=model1_res, extra_channel=extra_channel)
    elif netG == 'resnet_6blocks':
        net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif netG == 'unet_128':
        net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif netG == 'unet_256':
        net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: basic | basic_cls | n_layers | pixel
        n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
        n_class (int)      -- number of style classes; effective when netD=='basic_cls'

    Returns a discriminator

    Our current implementation provides these types of discriminators:
        [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
        It can classify whether 70x70 overlapping patches are real or fake.
        Such a patch-level discriminator architecture has fewer parameters
        than a full-image discriminator and can work on arbitrarily-sized images
        in a fully convolutional fashion.

        [basic_cls]: NLayerDiscriminatorCls with 3 conv layers and n_class classes.

        [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
        with the parameter n_layers_D (default=3 as used in [basic] (PatchGAN).)

        [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
        It encourages greater color diversity but has no effect on spatial statistics.

    The discriminator has been initialized by init_net. It uses Leaky RELU for non-linearity.

    Raises:
        NotImplementedError -- if netD is not a recognized name.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'basic':  # default PatchGAN classifier
        net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
    elif netD == 'basic_cls':
        # n_class was previously ignored (hard-coded 3); the default keeps that behaviour.
        net = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=n_class, norm_layer=norm_layer)
    elif netD == 'n_layers':  # more options
        net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
    elif netD == 'pixel':  # classify if each pixel is real or fake
        net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    else:
        # report the requested name (the old message formatted `net`, which is always None here)
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)


def define_HED(init_weights_, gpu_ids_=[]):
    """Build an HED network, move it to gpu_ids_ (DataParallel), and optionally load weights.

    init_weights_ is a checkpoint path or None; weights are loaded into the
    unwrapped module with map_location set to the first GPU (or CPU).
    """
    net = HED()

    if len(gpu_ids_) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

    if init_weights_ is not None:
        device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
        print('Loading model from: %s' % init_weights_)
        state_dict = torch.load(init_weights_, map_location=str(device))
        if isinstance(net, torch.nn.DataParallel):
            net.module.load_state_dict(state_dict)
        else:
            net.load_state_dict(state_dict)
        print('load the weights successfully')

    return net


def define_VGG(init_weights_, feature_mode_, batch_norm_=False, num_classes_=1000, gpu_ids_=[]):
    """Build the project's VGG19, move it to gpu_ids_ (DataParallel), and optionally load weights.

    init_weights_ is also forwarded to the VGG19 constructor; when it is not
    None the checkpoint at that path is loaded into the unwrapped module.
    """
    net = VGG19(init_weights=init_weights_, feature_mode=feature_mode_, batch_norm=batch_norm_, num_classes=num_classes_)
    # set the GPU
    if len(gpu_ids_) > 0:
        assert(torch.cuda.is_available())
        net.cuda(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

    if init_weights_ is not None:
        device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
        print('Loading model from: %s' % init_weights_)
        state_dict = torch.load(init_weights_, map_location=str(device))
        if isinstance(net, torch.nn.DataParallel):
            net.module.load_state_dict(state_dict)
        else:
            net.load_state_dict(state_dict)
        print('load the weights successfully')
    return net

###################################################################################################################
from torchvision.models import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn


def _vgg_regressor(vgg_ctor, gpu_ids_):
    """Build an ImageNet-pretrained torchvision VGG whose last classifier layer outputs one scalar.

    The scalar head has no sigmoid (used with an LSGAN / MSE-style objective).
    When gpu_ids_ is non-empty the net is moved to gpu_ids_[0] and wrapped in DataParallel.
    """
    net = vgg_ctor(pretrained=True)
    net.classifier[6] = nn.Linear(4096, 1)  # LSGAN needs no sigmoid, LSGAN-nn.MSELoss()
    if len(gpu_ids_) > 0:
        assert(torch.cuda.is_available())
        net.cuda(gpu_ids_[0])
        net = torch.nn.DataParallel(net, gpu_ids_)
    return net


def define_vgg11_bn(gpu_ids_=[], vec=0):
    """Pretrained VGG11-BN with a single-output head (vec is accepted but unused)."""
    return _vgg_regressor(vgg11_bn, gpu_ids_)


def define_vgg19_bn(gpu_ids_=[], vec=0):
    """Pretrained VGG19-BN with a single-output head (vec is accepted but unused)."""
    return _vgg_regressor(vgg19_bn, gpu_ids_)


def define_vgg19(gpu_ids_=[], vec=0):
    """Pretrained VGG19 with a single-output head (vec is accepted but unused)."""
    return _vgg_regressor(vgg19, gpu_ids_)

###################################################################################################################
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
def define_resnet101(gpu_ids_=[],vec=0): 294 | net = resnet101(pretrained=True) 295 | num_ftrs = net.fc.in_features 296 | net.fc = nn.Linear(num_ftrs, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss() 297 | if len(gpu_ids_) > 0: 298 | assert(torch.cuda.is_available()) 299 | net.cuda(gpu_ids_[0]) 300 | net = torch.nn.DataParallel(net, gpu_ids_) 301 | return net 302 | def define_resnet101a(init_weights_,gpu_ids_=[],vec=0): 303 | net = resnet101(pretrained=True) 304 | num_ftrs = net.fc.in_features 305 | net.fc = nn.Linear(num_ftrs, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss() 306 | if not init_weights_ == None: 307 | print('Loading model from: %s'%init_weights_) 308 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu'))) 309 | if isinstance(net, torch.nn.DataParallel): 310 | net.module.load_state_dict(state_dict) 311 | else: 312 | net.load_state_dict(state_dict) 313 | print('load the weights successfully') 314 | if len(gpu_ids_) > 0: 315 | assert(torch.cuda.is_available()) 316 | net.cuda(gpu_ids_[0]) 317 | net = torch.nn.DataParallel(net, gpu_ids_) 318 | return net 319 | ################################################################################################################### 320 | import pretrainedmodels.models.resnext as resnext 321 | def define_resnext101(gpu_ids_=[],vec=0): 322 | net = resnext.resnext101_64x4d(num_classes=1000,pretrained='imagenet') 323 | net.last_linear = nn.Linear(2048, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss() 324 | if len(gpu_ids_) > 0: 325 | assert(torch.cuda.is_available()) 326 | net.cuda(gpu_ids_[0]) 327 | net = torch.nn.DataParallel(net, gpu_ids_) 328 | return net 329 | def define_resnext101a(init_weights_,gpu_ids_=[],vec=0): 330 | net = resnext.resnext101_64x4d(num_classes=1000,pretrained='imagenet') 331 | net.last_linear = nn.Linear(2048, 1) #LSGAN needs no sigmoid, LSGAN-nn.MSELoss() 332 | if not init_weights_ == None: 333 | print('Loading model from: %s'%init_weights_) 334 | state_dict = 
torch.load(init_weights_, map_location=str(torch.device('cpu'))) 335 | if isinstance(net, torch.nn.DataParallel): 336 | net.module.load_state_dict(state_dict) 337 | else: 338 | net.load_state_dict(state_dict) 339 | print('load the weights successfully') 340 | if len(gpu_ids_) > 0: 341 | assert(torch.cuda.is_available()) 342 | net.cuda(gpu_ids_[0]) 343 | net = torch.nn.DataParallel(net, gpu_ids_) 344 | return net 345 | ################################################################################################################### 346 | from torchvision.models import Inception3, inception_v3 347 | def define_inception3(gpu_ids_=[],vec=0): 348 | net = inception_v3(pretrained=True) 349 | net.transform_input = False # assume [-1,1] input 350 | net.fc = nn.Linear(2048, 1) 351 | net.aux_logits = False 352 | if len(gpu_ids_) > 0: 353 | assert(torch.cuda.is_available()) 354 | net.cuda(gpu_ids_[0]) 355 | net = torch.nn.DataParallel(net, gpu_ids_) 356 | return net 357 | def define_inception3a(init_weights_,gpu_ids_=[],vec=0): 358 | net = inception_v3(pretrained=True) 359 | net.transform_input = False # assume [-1,1] input 360 | net.fc = nn.Linear(2048, 1) 361 | net.aux_logits = False 362 | if not init_weights_ == None: 363 | print('Loading model from: ', init_weights_) 364 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu'))) 365 | if isinstance(net, torch.nn.DataParallel): 366 | net.module.load_state_dict(state_dict) 367 | else: 368 | net.load_state_dict(state_dict) 369 | print('load the weights successfully') 370 | if len(gpu_ids_) > 0: 371 | assert(torch.cuda.is_available()) 372 | net.cuda(gpu_ids_[0]) 373 | net = torch.nn.DataParallel(net, gpu_ids_) 374 | return net 375 | ################################################################################################################### 376 | from torchvision.models.inception import BasicConv2d 377 | def define_inception_v3(init_weights_,gpu_ids_=[],vec=0): 378 | 379 | ## pretrained = True 380 
| kwargs = {} 381 | if 'transform_input' not in kwargs: 382 | kwargs['transform_input'] = True 383 | if 'aux_logits' in kwargs: 384 | original_aux_logits = kwargs['aux_logits'] 385 | kwargs['aux_logits'] = True 386 | else: 387 | original_aux_logits = True 388 | print(kwargs) 389 | net = Inception3(**kwargs) 390 | 391 | if not init_weights_ == None: 392 | print('Loading model from: %s'%init_weights_) 393 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu'))) 394 | if isinstance(net, torch.nn.DataParallel): 395 | net.module.load_state_dict(state_dict) 396 | else: 397 | net.load_state_dict(state_dict) 398 | print('load the weights successfully') 399 | 400 | if not original_aux_logits: 401 | net.aux_logits = False 402 | del net.AuxLogits 403 | 404 | net.fc = nn.Linear(2048, 1) 405 | if vec == 1: 406 | net.Conv2d_1a_3x3 = BasicConv2d(6, 32, kernel_size=3, stride=2) 407 | net.aux_logits = False 408 | 409 | if len(gpu_ids_) > 0: 410 | assert(torch.cuda.is_available()) 411 | net.cuda(gpu_ids_[0]) 412 | net = torch.nn.DataParallel(net, gpu_ids_) 413 | 414 | return net 415 | 416 | def define_inception_v3a(init_weights_,gpu_ids_=[],vec=0): 417 | 418 | kwargs = {} 419 | if 'transform_input' not in kwargs: 420 | kwargs['transform_input'] = True 421 | if 'aux_logits' in kwargs: 422 | original_aux_logits = kwargs['aux_logits'] 423 | kwargs['aux_logits'] = True 424 | else: 425 | original_aux_logits = True 426 | print(kwargs) 427 | net = Inception3(**kwargs) 428 | 429 | if not original_aux_logits: 430 | net.aux_logits = False 431 | del net.AuxLogits 432 | 433 | net.fc = nn.Linear(2048, 1) 434 | if vec == 1: 435 | net.Conv2d_1a_3x3 = BasicConv2d(6, 32, kernel_size=3, stride=2) 436 | net.aux_logits = False 437 | 438 | if not init_weights_ == None: 439 | print('Loading model from: %s'%init_weights_) 440 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu'))) 441 | if isinstance(net, torch.nn.DataParallel): 442 | 
net.module.load_state_dict(state_dict) 443 | else: 444 | net.load_state_dict(state_dict) 445 | print('load the weights successfully') 446 | 447 | if len(gpu_ids_) > 0: 448 | assert(torch.cuda.is_available()) 449 | net.cuda(gpu_ids_[0]) 450 | net = torch.nn.DataParallel(net, gpu_ids_) 451 | 452 | return net 453 | 454 | def define_inception_ori(init_weights_,transform_input=False,gpu_ids_=[]): 455 | 456 | ## pretrained = True 457 | kwargs = {} 458 | kwargs['transform_input'] = transform_input 459 | 460 | if 'aux_logits' in kwargs: 461 | original_aux_logits = kwargs['aux_logits'] 462 | kwargs['aux_logits'] = True 463 | else: 464 | original_aux_logits = True 465 | print(kwargs) 466 | net = Inception3(**kwargs) 467 | 468 | 469 | if not init_weights_ == None: 470 | print('Loading model from: %s'%init_weights_) 471 | state_dict = torch.load(init_weights_, map_location=str(torch.device('cpu'))) 472 | if isinstance(net, torch.nn.DataParallel): 473 | net.module.load_state_dict(state_dict) 474 | else: 475 | net.load_state_dict(state_dict) 476 | print('load the weights successfully') 477 | #for e in list(net.modules()): 478 | # print(e) 479 | 480 | if not original_aux_logits: 481 | net.aux_logits = False 482 | del net.AuxLogits 483 | 484 | 485 | if len(gpu_ids_) > 0: 486 | assert(torch.cuda.is_available()) 487 | net.cuda(gpu_ids_[0]) 488 | 489 | return net 490 | ################################################################################################################### 491 | 492 | def define_DT(init_weights_, input_nc_, output_nc_, ngf_, netG_, norm_, use_dropout_, init_type_, init_gain_, gpu_ids_): 493 | net = define_G(input_nc_, output_nc_, ngf_, netG_, norm_, use_dropout_, init_type_, init_gain_, gpu_ids_) 494 | 495 | if not init_weights_ == None: 496 | device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu') 497 | print('Loading model from: %s'%init_weights_) 498 | state_dict = torch.load(init_weights_, map_location=str(device)) 499 
| if isinstance(net, torch.nn.DataParallel): 500 | net.module.load_state_dict(state_dict) 501 | else: 502 | net.load_state_dict(state_dict) 503 | print('load the weights successfully') 504 | return net 505 | 506 | def define_C(input_nc, classes, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], h=512, w=512, nnG=3, dim=4096): 507 | net = None 508 | norm_layer = get_norm_layer(norm_type=norm) 509 | if netG == 'classifier': 510 | net = Classifier(input_nc, classes, ngf, num_downs=nnG, norm_layer=norm_layer, use_dropout=use_dropout, h=h, w=w, dim=dim) 511 | elif netG == 'vgg': 512 | net = VGG19(init_weights=None, feature_mode=False, batch_norm=True, num_classes=classes) 513 | return init_net(net, init_type, init_gain, gpu_ids) 514 | 515 | ############################################################################## 516 | # Classes 517 | ############################################################################## 518 | class GANLoss(nn.Module): 519 | """Define different GAN objectives. 520 | 521 | The GANLoss class abstracts away the need to create the target label tensor 522 | that has the same size as the input. 523 | """ 524 | 525 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 526 | """ Initialize the GANLoss class. 527 | 528 | Parameters: 529 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 530 | target_real_label (bool) - - label for a real image 531 | target_fake_label (bool) - - label of a fake image 532 | 533 | Note: Do not use sigmoid as the last layer of Discriminator. 534 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 
535 | """ 536 | super(GANLoss, self).__init__() 537 | self.register_buffer('real_label', torch.tensor(target_real_label)) 538 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 539 | self.gan_mode = gan_mode 540 | if gan_mode == 'lsgan':#cyclegan 541 | self.loss = nn.MSELoss() 542 | elif gan_mode == 'vanilla': 543 | self.loss = nn.BCEWithLogitsLoss() 544 | elif gan_mode in ['wgangp']: 545 | self.loss = None 546 | else: 547 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 548 | 549 | def get_target_tensor(self, prediction, target_is_real): 550 | """Create label tensors with the same size as the input. 551 | 552 | Parameters: 553 | prediction (tensor) - - tpyically the prediction from a discriminator 554 | target_is_real (bool) - - if the ground truth label is for real images or fake images 555 | 556 | Returns: 557 | A label tensor filled with ground truth label, and with the size of the input 558 | """ 559 | 560 | if target_is_real: 561 | target_tensor = self.real_label 562 | else: 563 | target_tensor = self.fake_label 564 | return target_tensor.expand_as(prediction) 565 | 566 | def __call__(self, prediction, target_is_real): 567 | """Calculate loss given Discriminator's output and grount truth labels. 568 | 569 | Parameters: 570 | prediction (tensor) - - tpyically the prediction output from a discriminator 571 | target_is_real (bool) - - if the ground truth label is for real images or fake images 572 | 573 | Returns: 574 | the calculated loss. 
575 | """ 576 | if self.gan_mode in ['lsgan', 'vanilla']: 577 | target_tensor = self.get_target_tensor(prediction, target_is_real) 578 | loss = self.loss(prediction, target_tensor) 579 | elif self.gan_mode == 'wgangp': 580 | if target_is_real: 581 | loss = -prediction.mean() 582 | else: 583 | loss = prediction.mean() 584 | return loss 585 | 586 | 587 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 588 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 589 | 590 | Arguments: 591 | netD (network) -- discriminator network 592 | real_data (tensor array) -- real images 593 | fake_data (tensor array) -- generated images from the generator 594 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 595 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 596 | constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 597 | lambda_gp (float) -- weight for this loss 598 | 599 | Returns the gradient penalty loss 600 | """ 601 | if lambda_gp > 0.0: 602 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 
603 | interpolatesv = real_data 604 | elif type == 'fake': 605 | interpolatesv = fake_data 606 | elif type == 'mixed': 607 | alpha = torch.rand(real_data.shape[0], 1, device=device) 608 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 609 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 610 | else: 611 | raise NotImplementedError('{} not implemented'.format(type)) 612 | interpolatesv.requires_grad_(True) 613 | disc_interpolates = netD(interpolatesv) 614 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 615 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 616 | create_graph=True, retain_graph=True, only_inputs=True) 617 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 618 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 619 | return gradient_penalty, gradients 620 | else: 621 | return 0.0, None 622 | 623 | 624 | class ResnetGenerator(nn.Module): 625 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
626 | 627 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 628 | """ 629 | 630 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 631 | """Construct a Resnet-based generator 632 | 633 | Parameters: 634 | input_nc (int) -- the number of channels in input images 635 | output_nc (int) -- the number of channels in output images 636 | ngf (int) -- the number of filters in the last conv layer 637 | norm_layer -- normalization layer 638 | use_dropout (bool) -- if use dropout layers 639 | n_blocks (int) -- the number of ResNet blocks 640 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 641 | """ 642 | assert(n_blocks >= 0) 643 | super(ResnetGenerator, self).__init__() 644 | if type(norm_layer) == functools.partial: 645 | use_bias = norm_layer.func == nn.InstanceNorm2d 646 | else: 647 | use_bias = norm_layer == nn.InstanceNorm2d 648 | 649 | model = [nn.ReflectionPad2d(3), 650 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 651 | norm_layer(ngf), 652 | nn.ReLU(True)] 653 | 654 | n_downsampling = 2 655 | for i in range(n_downsampling): # add downsampling layers 656 | mult = 2 ** i 657 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 658 | norm_layer(ngf * mult * 2), 659 | nn.ReLU(True)] 660 | 661 | mult = 2 ** n_downsampling 662 | for i in range(n_blocks): # add ResNet blocks 663 | 664 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 665 | 666 | for i in range(n_downsampling): # add upsampling layers 667 | mult = 2 ** (n_downsampling - i) 668 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 669 | kernel_size=3, stride=2, 670 | padding=1, output_padding=1, 671 | bias=use_bias), 672 | norm_layer(int(ngf * 
mult / 2)), 673 | nn.ReLU(True)] 674 | model += [nn.ReflectionPad2d(3)] 675 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 676 | model += [nn.Tanh()] 677 | 678 | self.model = nn.Sequential(*model) 679 | 680 | def forward(self, input, feature_mode = False): 681 | """Standard forward""" 682 | if not feature_mode: 683 | return self.model(input) 684 | else: 685 | module_list = list(self.model.modules()) 686 | x = input.clone() 687 | indexes = list(range(1,11))+[11,20,29,38,47,56,65,74,83]+list(range(92,101)) 688 | for i in indexes: 689 | x = module_list[i](x) 690 | if i == 3: 691 | x1 = x.clone() 692 | elif i == 6: 693 | x2 = x.clone() 694 | elif i == 9: 695 | x3 = x.clone() 696 | elif i == 47: 697 | y7 = x.clone() 698 | elif i == 83: 699 | y4 = x.clone() 700 | elif i == 93: 701 | y3 = x.clone() 702 | elif i == 96: 703 | y2 = x.clone() 704 | #y = self.model(input) 705 | #pdb.set_trace() 706 | return x,x1,x2,x3,y4,y3,y2,y7 707 | 708 | class ResnetStyleGenerator(nn.Module): 709 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
710 | 711 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 712 | """ 713 | 714 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 715 | """Construct a Resnet-based generator 716 | 717 | Parameters: 718 | input_nc (int) -- the number of channels in input images 719 | output_nc (int) -- the number of channels in output images 720 | ngf (int) -- the number of filters in the last conv layer 721 | norm_layer -- normalization layer 722 | use_dropout (bool) -- if use dropout layers 723 | n_blocks (int) -- the number of ResNet blocks 724 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 725 | """ 726 | assert(n_blocks >= 0) 727 | super(ResnetStyleGenerator, self).__init__() 728 | if type(norm_layer) == functools.partial: 729 | use_bias = norm_layer.func == nn.InstanceNorm2d 730 | else: 731 | use_bias = norm_layer == nn.InstanceNorm2d 732 | 733 | model0 = [nn.ReflectionPad2d(3), 734 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 735 | norm_layer(ngf), 736 | nn.ReLU(True)] 737 | 738 | n_downsampling = 2 739 | for i in range(n_downsampling): # add downsampling layers 740 | mult = 2 ** i 741 | model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 742 | norm_layer(ngf * mult * 2), 743 | nn.ReLU(True)] 744 | 745 | mult = 2 ** n_downsampling 746 | model1 = [nn.Conv2d(3, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), 747 | norm_layer(ngf * mult), 748 | nn.ReLU(True)] 749 | 750 | model = [] 751 | model += [nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), 752 | norm_layer(ngf * mult), 753 | nn.ReLU(True)] 754 | for i in range(n_blocks): # add ResNet blocks 755 | 756 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, 
use_dropout=use_dropout, use_bias=use_bias)] 757 | 758 | for i in range(n_downsampling): # add upsampling layers 759 | mult = 2 ** (n_downsampling - i) 760 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 761 | kernel_size=3, stride=2, 762 | padding=1, output_padding=1, 763 | bias=use_bias), 764 | norm_layer(int(ngf * mult / 2)), 765 | nn.ReLU(True)] 766 | model += [nn.ReflectionPad2d(3)] 767 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 768 | model += [nn.Tanh()] 769 | 770 | self.model0 = nn.Sequential(*model0) 771 | self.model1 = nn.Sequential(*model1) 772 | self.model = nn.Sequential(*model) 773 | 774 | def forward(self, input1, input2): 775 | """Standard forward""" 776 | f1 = self.model0(input1) 777 | f2 = self.model1(input2) 778 | #pdb.set_trace() 779 | f1 = torch.cat((f1,f2), 1) 780 | return self.model(f1) 781 | 782 | 783 | class ResnetStyle2Generator(nn.Module): 784 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
785 | 786 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 787 | """ 788 | 789 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0): 790 | """Construct a Resnet-based generator 791 | 792 | Parameters: 793 | input_nc (int) -- the number of channels in input images 794 | output_nc (int) -- the number of channels in output images 795 | ngf (int) -- the number of filters in the last conv layer 796 | norm_layer -- normalization layer 797 | use_dropout (bool) -- if use dropout layers 798 | n_blocks (int) -- the number of ResNet blocks 799 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 800 | """ 801 | assert(n_blocks >= 0) 802 | super(ResnetStyle2Generator, self).__init__() 803 | self.n_blocks = n_blocks 804 | if type(norm_layer) == functools.partial: 805 | use_bias = norm_layer.func == nn.InstanceNorm2d 806 | else: 807 | use_bias = norm_layer == nn.InstanceNorm2d 808 | 809 | model0 = [nn.ReflectionPad2d(3), 810 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 811 | norm_layer(ngf), 812 | nn.ReLU(True)] 813 | 814 | n_downsampling = 2 815 | for i in range(n_downsampling): # add downsampling layers 816 | mult = 2 ** i 817 | model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 818 | norm_layer(ngf * mult * 2), 819 | nn.ReLU(True)] 820 | 821 | mult = 2 ** n_downsampling 822 | for i in range(model0_res): # add ResNet blocks 823 | model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 824 | 825 | model = [] 826 | model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), 827 | norm_layer(ngf * mult), 828 | nn.ReLU(True)] 829 | 830 | for i in 
range(n_blocks-model0_res): # add ResNet blocks 831 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 832 | 833 | for i in range(n_downsampling): # add upsampling layers 834 | mult = 2 ** (n_downsampling - i) 835 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 836 | kernel_size=3, stride=2, 837 | padding=1, output_padding=1, 838 | bias=use_bias), 839 | norm_layer(int(ngf * mult / 2)), 840 | nn.ReLU(True)] 841 | model += [nn.ReflectionPad2d(3)] 842 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 843 | model += [nn.Tanh()] 844 | 845 | self.model0 = nn.Sequential(*model0) 846 | self.model = nn.Sequential(*model) 847 | #print(list(self.modules())) 848 | 849 | def forward(self, input1, input2, feature_mode=False, ablate_res=-1): 850 | """Standard forward""" 851 | if not feature_mode: 852 | if ablate_res == -1: 853 | f1 = self.model0(input1) 854 | y1 = torch.cat([f1, input2], 1) 855 | return self.model(y1) 856 | else: 857 | f1 = self.model0(input1) 858 | y = torch.cat([f1, input2], 1) 859 | module_list = list(self.model.modules()) 860 | for i in range(1, 4):#merge module 861 | y = module_list[i](y) 862 | for k in range(self.n_blocks):#resblocks 863 | if k+1 == ablate_res: 864 | print('skip resblock'+str(k+1)) 865 | continue 866 | y1 = y.clone() 867 | for i in range(6+9*k,13+9*k): 868 | y = module_list[i](y) 869 | y = y1 + y 870 | for i in range(4+9*self.n_blocks,13+9*self.n_blocks):#up convs 871 | y = module_list[i](y) 872 | return y 873 | else: 874 | module_list0 = list(self.model0.modules()) 875 | x = input1.clone() 876 | for i in range(1,11): 877 | x = module_list0[i](x) 878 | if i == 3: 879 | x1 = x.clone()#[1,64,512,512] 880 | elif i == 6: 881 | x2 = x.clone()#[1,128,256,256] 882 | elif i == 9: 883 | x3 = x.clone()#[1,256,128,128] 884 | #f1 = self.model0(input1)#[1,256,128,128] 885 | #pdb.set_trace() 886 | y1 = torch.cat([x, input2], 1)#[1,259,128,128] 
887 | module_list = list(self.model.modules()) 888 | indexes = list(range(1,4))+[4,13,22,31,40,49,58,67,76]+list(range(85,94)) 889 | y = y1.clone() 890 | for i in indexes: 891 | y = module_list[i](y) 892 | if i == 76: 893 | y4 = y.clone()#[1,256,128,128] 894 | elif i == 86: 895 | y3 = y.clone()#[1,128,256,256] 896 | elif i == 89: 897 | y2 = y.clone()#[1,64,512,512] 898 | elif i == 40: 899 | y7 = y.clone() 900 | #out = self.model(y1) 901 | #pdb.set_trace() 902 | return y,x1,x2,x3,y4,y3,y2,y7 903 | 904 | class ResnetStyle3DecoderGenerator(nn.Module): 905 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 906 | 907 | We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) 908 | """ 909 | 910 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', model0_res=0): 911 | """Construct a Resnet-based generator 912 | 913 | Parameters: 914 | input_nc (int) -- the number of channels in input images 915 | output_nc (int) -- the number of channels in output images 916 | ngf (int) -- the number of filters in the last conv layer 917 | norm_layer -- normalization layer 918 | use_dropout (bool) -- if use dropout layers 919 | n_blocks (int) -- the number of ResNet blocks 920 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 921 | """ 922 | assert(n_blocks >= 0) 923 | super(ResnetStyle3DecoderGenerator, self).__init__() 924 | if type(norm_layer) == functools.partial: 925 | use_bias = norm_layer.func == nn.InstanceNorm2d 926 | else: 927 | use_bias = norm_layer == nn.InstanceNorm2d 928 | 929 | model0 = [nn.ReflectionPad2d(3), 930 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 931 | norm_layer(ngf), 932 | nn.ReLU(True)] 933 | 934 | n_downsampling = 2 935 | for i in range(n_downsampling): # add downsampling layers 936 | 
mult = 2 ** i 937 | model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 938 | norm_layer(ngf * mult * 2), 939 | nn.ReLU(True)] 940 | 941 | mult = 2 ** n_downsampling 942 | for i in range(model0_res): # add ResNet blocks 943 | model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 944 | 945 | model1 = [] 946 | model2 = [] 947 | model3 = [] 948 | for i in range(n_blocks-model0_res): # add ResNet blocks 949 | model1 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 950 | model2 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 951 | model3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 952 | 953 | for i in range(n_downsampling): # add upsampling layers 954 | mult = 2 ** (n_downsampling - i) 955 | model1 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 956 | kernel_size=3, stride=2, 957 | padding=1, output_padding=1, 958 | bias=use_bias), 959 | norm_layer(int(ngf * mult / 2)), 960 | nn.ReLU(True)] 961 | model2 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 962 | kernel_size=3, stride=2, 963 | padding=1, output_padding=1, 964 | bias=use_bias), 965 | norm_layer(int(ngf * mult / 2)), 966 | nn.ReLU(True)] 967 | model3 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 968 | kernel_size=3, stride=2, 969 | padding=1, output_padding=1, 970 | bias=use_bias), 971 | norm_layer(int(ngf * mult / 2)), 972 | nn.ReLU(True)] 973 | model1 += [nn.ReflectionPad2d(3)] 974 | model1 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 975 | model1 += [nn.Tanh()] 976 | model2 += [nn.ReflectionPad2d(3)] 977 | model2 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 978 | model2 += [nn.Tanh()] 979 | model3 += 
[nn.ReflectionPad2d(3)] 980 | model3 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 981 | model3 += [nn.Tanh()] 982 | 983 | self.model0 = nn.Sequential(*model0) 984 | self.model1 = nn.Sequential(*model1) 985 | self.model2 = nn.Sequential(*model2) 986 | self.model3 = nn.Sequential(*model3) 987 | print(list(self.modules())) 988 | 989 | def forward(self, input, domain): 990 | """Standard forward""" 991 | f1 = self.model0(input) 992 | if domain == 0: 993 | y = self.model1(f1) 994 | elif domain == 1: 995 | y = self.model2(f1) 996 | elif domain == 2: 997 | y = self.model3(f1) 998 | return y 999 | 1000 | class ResnetStyle2MCGenerator(nn.Module): 1001 | # multi-column 1002 | 1003 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0): 1004 | """Construct a Resnet-based generator 1005 | 1006 | Parameters: 1007 | input_nc (int) -- the number of channels in input images 1008 | output_nc (int) -- the number of channels in output images 1009 | ngf (int) -- the number of filters in the last conv layer 1010 | norm_layer -- normalization layer 1011 | use_dropout (bool) -- if use dropout layers 1012 | n_blocks (int) -- the number of ResNet blocks 1013 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 1014 | """ 1015 | assert(n_blocks >= 0) 1016 | super(ResnetStyle2MCGenerator, self).__init__() 1017 | if type(norm_layer) == functools.partial: 1018 | use_bias = norm_layer.func == nn.InstanceNorm2d 1019 | else: 1020 | use_bias = norm_layer == nn.InstanceNorm2d 1021 | 1022 | model0 = [nn.ReflectionPad2d(3), 1023 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 1024 | norm_layer(ngf), 1025 | nn.ReLU(True)] 1026 | 1027 | n_downsampling = 2 1028 | model1_3 = [] 1029 | model1_5 = [] 1030 | for i in range(n_downsampling): # add downsampling layers 1031 | mult = 2 ** i 1032 | model1_3 += [nn.Conv2d(ngf * mult, ngf 
* mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 1033 | norm_layer(ngf * mult * 2), 1034 | nn.ReLU(True)] 1035 | model1_5 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5, stride=2, padding=2, bias=use_bias), 1036 | norm_layer(ngf * mult * 2), 1037 | nn.ReLU(True)] 1038 | 1039 | mult = 2 ** n_downsampling 1040 | for i in range(model0_res): # add ResNet blocks 1041 | model1_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1042 | model1_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)] 1043 | 1044 | model = [] 1045 | model += [nn.Conv2d(ngf * mult * 2 + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), 1046 | norm_layer(ngf * mult), 1047 | nn.ReLU(True)] 1048 | 1049 | for i in range(n_blocks-model0_res): # add ResNet blocks 1050 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1051 | 1052 | for i in range(n_downsampling): # add upsampling layers 1053 | mult = 2 ** (n_downsampling - i) 1054 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 1055 | kernel_size=3, stride=2, 1056 | padding=1, output_padding=1, 1057 | bias=use_bias), 1058 | norm_layer(int(ngf * mult / 2)), 1059 | nn.ReLU(True)] 1060 | model += [nn.ReflectionPad2d(3)] 1061 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 1062 | model += [nn.Tanh()] 1063 | 1064 | self.model0 = nn.Sequential(*model0) 1065 | self.model1_3 = nn.Sequential(*model1_3) 1066 | self.model1_5 = nn.Sequential(*model1_5) 1067 | self.model = nn.Sequential(*model) 1068 | print(list(self.modules())) 1069 | 1070 | def forward(self, input1, input2): 1071 | """Standard forward""" 1072 | f0 = self.model0(input1) 1073 | f1 = self.model1_3(f0) 1074 | f2 = self.model1_5(f0) 1075 | y1 = torch.cat([f1, f2, input2], 1) 1076 | 
return self.model(y1) 1077 | 1078 | class ResnetStyle2MC2Generator(nn.Module): 1079 | # multi-column, need to insert style early 1080 | 1081 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0, model1_res=0): 1082 | """Construct a Resnet-based generator 1083 | 1084 | Parameters: 1085 | input_nc (int) -- the number of channels in input images 1086 | output_nc (int) -- the number of channels in output images 1087 | ngf (int) -- the number of filters in the last conv layer 1088 | norm_layer -- normalization layer 1089 | use_dropout (bool) -- if use dropout layers 1090 | n_blocks (int) -- the number of ResNet blocks 1091 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 1092 | """ 1093 | assert(n_blocks >= 0) 1094 | super(ResnetStyle2MC2Generator, self).__init__() 1095 | if type(norm_layer) == functools.partial: 1096 | use_bias = norm_layer.func == nn.InstanceNorm2d 1097 | else: 1098 | use_bias = norm_layer == nn.InstanceNorm2d 1099 | 1100 | model0 = [nn.ReflectionPad2d(3), 1101 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 1102 | norm_layer(ngf), 1103 | nn.ReLU(True)] 1104 | 1105 | n_downsampling = 2 1106 | model1_3 = [] 1107 | model1_5 = [] 1108 | for i in range(n_downsampling): # add downsampling layers 1109 | mult = 2 ** i 1110 | model1_3 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 1111 | norm_layer(ngf * mult * 2), 1112 | nn.ReLU(True)] 1113 | model1_5 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=5, stride=2, padding=2, bias=use_bias), 1114 | norm_layer(ngf * mult * 2), 1115 | nn.ReLU(True)] 1116 | 1117 | mult = 2 ** n_downsampling 1118 | for i in range(model0_res): # add ResNet blocks 1119 | model1_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1120 | 
model1_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)] 1121 | 1122 | model2_3 = [] 1123 | model2_5 = [] 1124 | model2_3 += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), 1125 | norm_layer(ngf * mult), 1126 | nn.ReLU(True)] 1127 | model2_5 += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=5, stride=1, padding=2, bias=use_bias), 1128 | norm_layer(ngf * mult), 1129 | nn.ReLU(True)] 1130 | 1131 | for i in range(model1_res): # add ResNet blocks 1132 | model2_3 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1133 | model2_5 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, kernel=5)] 1134 | 1135 | model = [] 1136 | model += [nn.Conv2d(ngf * mult * 2, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), 1137 | norm_layer(ngf * mult), 1138 | nn.ReLU(True)] 1139 | for i in range(n_blocks-model0_res-model1_res): # add ResNet blocks 1140 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 1141 | 1142 | for i in range(n_downsampling): # add upsampling layers 1143 | mult = 2 ** (n_downsampling - i) 1144 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 1145 | kernel_size=3, stride=2, 1146 | padding=1, output_padding=1, 1147 | bias=use_bias), 1148 | norm_layer(int(ngf * mult / 2)), 1149 | nn.ReLU(True)] 1150 | model += [nn.ReflectionPad2d(3)] 1151 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 1152 | model += [nn.Tanh()] 1153 | 1154 | self.model0 = nn.Sequential(*model0) 1155 | self.model1_3 = nn.Sequential(*model1_3) 1156 | self.model1_5 = nn.Sequential(*model1_5) 1157 | self.model2_3 = nn.Sequential(*model2_3) 1158 | self.model2_5 = nn.Sequential(*model2_5) 
1159 | self.model = nn.Sequential(*model) 1160 | print(list(self.modules())) 1161 | 1162 | def forward(self, input1, input2): 1163 | """Standard forward""" 1164 | f0 = self.model0(input1) 1165 | f1 = self.model1_3(f0) 1166 | f2 = self.model1_5(f0) 1167 | f3 = self.model2_3(torch.cat([f1,input2],1)) 1168 | f4 = self.model2_5(torch.cat([f2,input2],1)) 1169 | #pdb.set_trace() 1170 | return self.model(torch.cat([f3,f4],1)) 1171 | 1172 | class ResnetBlock(nn.Module): 1173 | """Define a Resnet block""" 1174 | 1175 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3): 1176 | """Initialize the Resnet block 1177 | 1178 | A resnet block is a conv block with skip connections 1179 | We construct a conv block with build_conv_block function, 1180 | and implement skip connections in function. 1181 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 1182 | """ 1183 | super(ResnetBlock, self).__init__() 1184 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel) 1185 | 1186 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3): 1187 | """Construct a convolutional block. 1188 | 1189 | Parameters: 1190 | dim (int) -- the number of channels in the conv layer. 1191 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 1192 | norm_layer -- normalization layer 1193 | use_dropout (bool) -- if use dropout layers. 
1194 | use_bias (bool) -- if the conv layer uses bias or not 1195 | 1196 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 1197 | """ 1198 | conv_block = [] 1199 | p = 0 1200 | pad = int((kernel-1)/2) 1201 | if padding_type == 'reflect':#by default 1202 | conv_block += [nn.ReflectionPad2d(pad)] 1203 | elif padding_type == 'replicate': 1204 | conv_block += [nn.ReplicationPad2d(pad)] 1205 | elif padding_type == 'zero': 1206 | p = pad 1207 | else: 1208 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 1209 | 1210 | conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 1211 | if use_dropout: 1212 | conv_block += [nn.Dropout(0.5)] 1213 | 1214 | p = 0 1215 | if padding_type == 'reflect': 1216 | conv_block += [nn.ReflectionPad2d(pad)] 1217 | elif padding_type == 'replicate': 1218 | conv_block += [nn.ReplicationPad2d(pad)] 1219 | elif padding_type == 'zero': 1220 | p = pad 1221 | else: 1222 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 1223 | conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim)] 1224 | 1225 | return nn.Sequential(*conv_block) 1226 | 1227 | def forward(self, x): 1228 | """Forward function (with skip connections)""" 1229 | out = x + self.conv_block(x) # add skip connections 1230 | return out 1231 | 1232 | 1233 | class UnetGenerator(nn.Module): 1234 | """Create a Unet-based generator""" 1235 | 1236 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 1237 | """Construct a Unet generator 1238 | Parameters: 1239 | input_nc (int) -- the number of channels in input images 1240 | output_nc (int) -- the number of channels in output images 1241 | num_downs (int) -- the number of downsamplings in UNet. 
For example, # if |num_downs| == 7, 1242 | image of size 128x128 will become of size 1x1 # at the bottleneck 1243 | ngf (int) -- the number of filters in the last conv layer 1244 | norm_layer -- normalization layer 1245 | 1246 | We construct the U-Net from the innermost layer to the outermost layer. 1247 | It is a recursive process. 1248 | """ 1249 | super(UnetGenerator, self).__init__() 1250 | # construct unet structure 1251 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 1252 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 1253 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 1254 | # gradually reduce the number of filters from ngf * 8 to ngf 1255 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 1256 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 1257 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 1258 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 1259 | 1260 | def forward(self, input): 1261 | """Standard forward""" 1262 | return self.model(input) 1263 | 1264 | 1265 | class UnetSkipConnectionBlock(nn.Module): 1266 | """Defines the Unet submodule with skip connection. 1267 | X -------------------identity---------------------- 1268 | |-- downsampling -- |submodule| -- upsampling --| 1269 | """ 1270 | 1271 | def __init__(self, outer_nc, inner_nc, input_nc=None, 1272 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 1273 | """Construct a Unet submodule with skip connections. 
1274 | 1275 | Parameters: 1276 | outer_nc (int) -- the number of filters in the outer conv layer 1277 | inner_nc (int) -- the number of filters in the inner conv layer 1278 | input_nc (int) -- the number of channels in input images/features 1279 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 1280 | outermost (bool) -- if this module is the outermost module 1281 | innermost (bool) -- if this module is the innermost module 1282 | norm_layer -- normalization layer 1283 | user_dropout (bool) -- if use dropout layers. 1284 | """ 1285 | super(UnetSkipConnectionBlock, self).__init__() 1286 | self.outermost = outermost 1287 | if type(norm_layer) == functools.partial: 1288 | use_bias = norm_layer.func == nn.InstanceNorm2d 1289 | else: 1290 | use_bias = norm_layer == nn.InstanceNorm2d 1291 | if input_nc is None: 1292 | input_nc = outer_nc 1293 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 1294 | stride=2, padding=1, bias=use_bias) 1295 | downrelu = nn.LeakyReLU(0.2, True) 1296 | downnorm = norm_layer(inner_nc) 1297 | uprelu = nn.ReLU(True) 1298 | upnorm = norm_layer(outer_nc) 1299 | 1300 | if outermost: 1301 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 1302 | kernel_size=4, stride=2, 1303 | padding=1) 1304 | down = [downconv] 1305 | up = [uprelu, upconv, nn.Tanh()] 1306 | model = down + [submodule] + up 1307 | elif innermost: 1308 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 1309 | kernel_size=4, stride=2, 1310 | padding=1, bias=use_bias) 1311 | down = [downrelu, downconv] 1312 | up = [uprelu, upconv, upnorm] 1313 | model = down + up 1314 | else: 1315 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 1316 | kernel_size=4, stride=2, 1317 | padding=1, bias=use_bias) 1318 | down = [downrelu, downconv, downnorm] 1319 | up = [uprelu, upconv, upnorm] 1320 | 1321 | if use_dropout: 1322 | model = down + [submodule] + up + [nn.Dropout(0.5)] 1323 | else: 1324 | model = down + [submodule] + up 1325 | 1326 | self.model = 
nn.Sequential(*model) 1327 | 1328 | def forward(self, x): 1329 | if self.outermost: 1330 | return self.model(x) 1331 | else: # add skip connections 1332 | return torch.cat([x, self.model(x)], 1) 1333 | 1334 | 1335 | class NLayerDiscriminator(nn.Module): 1336 | """Defines a PatchGAN discriminator""" 1337 | 1338 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 1339 | """Construct a PatchGAN discriminator 1340 | 1341 | Parameters: 1342 | input_nc (int) -- the number of channels in input images 1343 | ndf (int) -- the number of filters in the last conv layer 1344 | n_layers (int) -- the number of conv layers in the discriminator 1345 | norm_layer -- normalization layer 1346 | """ 1347 | super(NLayerDiscriminator, self).__init__() 1348 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 1349 | use_bias = norm_layer.func != nn.BatchNorm2d 1350 | else: 1351 | use_bias = norm_layer != nn.BatchNorm2d 1352 | 1353 | kw = 4 1354 | padw = 1 1355 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 1356 | nf_mult = 1 1357 | nf_mult_prev = 1 1358 | for n in range(1, n_layers): # gradually increase the number of filters 1359 | nf_mult_prev = nf_mult 1360 | nf_mult = min(2 ** n, 8) 1361 | sequence += [ 1362 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 1363 | norm_layer(ndf * nf_mult), 1364 | nn.LeakyReLU(0.2, True) 1365 | ] 1366 | 1367 | nf_mult_prev = nf_mult 1368 | nf_mult = min(2 ** n_layers, 8) 1369 | sequence += [ 1370 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 1371 | norm_layer(ndf * nf_mult), 1372 | nn.LeakyReLU(0.2, True) 1373 | ] 1374 | 1375 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 1376 | self.model = nn.Sequential(*sequence) 1377 | 1378 | def 
forward(self, input): 1379 | """Standard forward.""" 1380 | return self.model(input) 1381 | 1382 | 1383 | class NLayerDiscriminatorCls(nn.Module): 1384 | """Defines a PatchGAN discriminator""" 1385 | 1386 | def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d): 1387 | """Construct a PatchGAN discriminator 1388 | 1389 | Parameters: 1390 | input_nc (int) -- the number of channels in input images 1391 | ndf (int) -- the number of filters in the last conv layer 1392 | n_layers (int) -- the number of conv layers in the discriminator 1393 | norm_layer -- normalization layer 1394 | """ 1395 | super(NLayerDiscriminatorCls, self).__init__() 1396 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 1397 | use_bias = norm_layer.func != nn.BatchNorm2d 1398 | else: 1399 | use_bias = norm_layer != nn.BatchNorm2d 1400 | 1401 | kw = 4 1402 | padw = 1 1403 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 1404 | nf_mult = 1 1405 | nf_mult_prev = 1 1406 | for n in range(1, n_layers): # gradually increase the number of filters 1407 | nf_mult_prev = nf_mult 1408 | nf_mult = min(2 ** n, 8) 1409 | sequence += [ 1410 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 1411 | norm_layer(ndf * nf_mult), 1412 | nn.LeakyReLU(0.2, True) 1413 | ] 1414 | 1415 | nf_mult_prev = nf_mult 1416 | nf_mult = min(2 ** n_layers, 8) 1417 | sequence1 = [ 1418 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 1419 | norm_layer(ndf * nf_mult), 1420 | nn.LeakyReLU(0.2, True) 1421 | ] 1422 | sequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 1423 | 1424 | sequence2 = [ 1425 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 1426 | norm_layer(ndf * nf_mult), 
1427 | nn.LeakyReLU(0.2, True) 1428 | ] 1429 | sequence2 += [ 1430 | nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 1431 | norm_layer(ndf * nf_mult), 1432 | nn.LeakyReLU(0.2, True) 1433 | ] 1434 | sequence2 += [ 1435 | nn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)] 1436 | 1437 | 1438 | self.model0 = nn.Sequential(*sequence) 1439 | self.model1 = nn.Sequential(*sequence1) 1440 | self.model2 = nn.Sequential(*sequence2) 1441 | print(list(self.modules())) 1442 | 1443 | def forward(self, input): 1444 | """Standard forward.""" 1445 | feat = self.model0(input) 1446 | # patchGAN output (1 * 62 * 62) 1447 | patch = self.model1(feat) 1448 | # class output (3 * 1 * 1) 1449 | classl = self.model2(feat) 1450 | return patch, classl.view(classl.size(0), -1) 1451 | 1452 | 1453 | class PixelDiscriminator(nn.Module): 1454 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 1455 | 1456 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 1457 | """Construct a 1x1 PatchGAN discriminator 1458 | 1459 | Parameters: 1460 | input_nc (int) -- the number of channels in input images 1461 | ndf (int) -- the number of filters in the last conv layer 1462 | norm_layer -- normalization layer 1463 | """ 1464 | super(PixelDiscriminator, self).__init__() 1465 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 1466 | use_bias = norm_layer.func != nn.InstanceNorm2d 1467 | else: 1468 | use_bias = norm_layer != nn.InstanceNorm2d 1469 | 1470 | self.net = [ 1471 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 1472 | nn.LeakyReLU(0.2, True), 1473 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 1474 | norm_layer(ndf * 2), 1475 | nn.LeakyReLU(0.2, True), 1476 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 1477 | 1478 | self.net = nn.Sequential(*self.net) 1479 | 1480 | def 
forward(self, input): 1481 | """Standard forward.""" 1482 | return self.net(input) 1483 | 1484 | 1485 | class HED(nn.Module): 1486 | def __init__(self): 1487 | super(HED, self).__init__() 1488 | 1489 | self.moduleVggOne = nn.Sequential( 1490 | nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), 1491 | nn.ReLU(inplace=False), 1492 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 1493 | nn.ReLU(inplace=False) 1494 | ) 1495 | 1496 | self.moduleVggTwo = nn.Sequential( 1497 | nn.MaxPool2d(kernel_size=2, stride=2), 1498 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), 1499 | nn.ReLU(inplace=False), 1500 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 1501 | nn.ReLU(inplace=False) 1502 | ) 1503 | 1504 | self.moduleVggThr = nn.Sequential( 1505 | nn.MaxPool2d(kernel_size=2, stride=2), 1506 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), 1507 | nn.ReLU(inplace=False), 1508 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 1509 | nn.ReLU(inplace=False), 1510 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 1511 | nn.ReLU(inplace=False) 1512 | ) 1513 | 1514 | self.moduleVggFou = nn.Sequential( 1515 | nn.MaxPool2d(kernel_size=2, stride=2), 1516 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), 1517 | nn.ReLU(inplace=False), 1518 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 1519 | nn.ReLU(inplace=False), 1520 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 1521 | nn.ReLU(inplace=False) 1522 | ) 1523 | 1524 | self.moduleVggFiv = nn.Sequential( 1525 | nn.MaxPool2d(kernel_size=2, stride=2), 1526 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 1527 | nn.ReLU(inplace=False), 1528 | nn.Conv2d(in_channels=512, 
out_channels=512, kernel_size=3, stride=1, padding=1), 1529 | nn.ReLU(inplace=False), 1530 | nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), 1531 | nn.ReLU(inplace=False) 1532 | ) 1533 | 1534 | self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0) 1535 | self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0) 1536 | self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0) 1537 | self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) 1538 | self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) 1539 | 1540 | self.moduleCombine = nn.Sequential( 1541 | nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0), 1542 | nn.Sigmoid() 1543 | ) 1544 | 1545 | def forward(self, tensorInput): 1546 | tensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793 1547 | tensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762 1548 | tensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434 1549 | 1550 | tensorInput = torch.cat([ tensorBlue, tensorGreen, tensorRed ], 1) 1551 | 1552 | tensorVggOne = self.moduleVggOne(tensorInput) 1553 | tensorVggTwo = self.moduleVggTwo(tensorVggOne) 1554 | tensorVggThr = self.moduleVggThr(tensorVggTwo) 1555 | tensorVggFou = self.moduleVggFou(tensorVggThr) 1556 | tensorVggFiv = self.moduleVggFiv(tensorVggFou) 1557 | 1558 | tensorScoreOne = self.moduleScoreOne(tensorVggOne) 1559 | tensorScoreTwo = self.moduleScoreTwo(tensorVggTwo) 1560 | tensorScoreThr = self.moduleScoreThr(tensorVggThr) 1561 | tensorScoreFou = self.moduleScoreFou(tensorVggFou) 1562 | tensorScoreFiv = self.moduleScoreFiv(tensorVggFiv) 1563 | 1564 | tensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', 
align_corners=False) 1565 | tensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) 1566 | tensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) 1567 | tensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) 1568 | tensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) 1569 | 1570 | return self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv ], 1)) 1571 | 1572 | # class for VGG19 modle 1573 | # borrows largely from torchvision vgg 1574 | class VGG19(nn.Module): 1575 | def __init__(self, init_weights=None, feature_mode=False, batch_norm=False, num_classes=1000): 1576 | super(VGG19, self).__init__() 1577 | self.cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] 1578 | self.init_weights = init_weights 1579 | self.feature_mode = feature_mode 1580 | self.batch_norm = batch_norm 1581 | self.num_clases = num_classes 1582 | self.features = self.make_layers(self.cfg, batch_norm) 1583 | self.classifier = nn.Sequential( 1584 | nn.Linear(512 * 7 * 7, 4096), 1585 | nn.ReLU(True), 1586 | nn.Dropout(), 1587 | nn.Linear(4096, 4096), 1588 | nn.ReLU(True), 1589 | nn.Dropout(), 1590 | nn.Linear(4096, num_classes), 1591 | ) 1592 | # print('----------load the pretrained vgg net---------') 1593 | # if not init_weights == None: 1594 | # print('load the weights') 1595 | # self.load_state_dict(torch.load(init_weights)) 1596 | 1597 | 1598 | def make_layers(self, cfg, batch_norm=False): 1599 | layers = [] 1600 | in_channels = 3 1601 | for v in cfg: 1602 | if v == 'M': 1603 | layers += 
[nn.MaxPool2d(kernel_size=2, stride=2)] 1604 | else: 1605 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 1606 | if batch_norm: 1607 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 1608 | else: 1609 | layers += [conv2d, nn.ReLU(inplace=True)] 1610 | in_channels = v 1611 | return nn.Sequential(*layers) 1612 | 1613 | def forward(self, x): 1614 | if self.feature_mode: 1615 | module_list = list(self.features.modules()) 1616 | for l in module_list[1:27]: # conv4_4 1617 | x = l(x) 1618 | if not self.feature_mode: 1619 | x = self.features(x) 1620 | x = x.view(x.size(0), -1) 1621 | x = self.classifier(x) 1622 | 1623 | return x 1624 | 1625 | class Classifier(nn.Module): 1626 | def __init__(self, input_nc, classes, ngf=64, num_downs=3, norm_layer=nn.BatchNorm2d, use_dropout=False, h=512, w=512, dim=4096): 1627 | super(Classifier, self).__init__() 1628 | self.input_nc = input_nc 1629 | self.ngf = ngf 1630 | if type(norm_layer) == functools.partial: 1631 | use_bias = norm_layer.func == nn.InstanceNorm2d 1632 | else: 1633 | use_bias = norm_layer == nn.InstanceNorm2d 1634 | 1635 | model = [nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1, bias=use_bias), nn.LeakyReLU(0.2, True)] 1636 | nf_mult = 1 1637 | nf_mult_prev = 1 1638 | for n in range(1, num_downs): 1639 | nf_mult_prev = nf_mult 1640 | nf_mult = min(2 ** n, 8) 1641 | model += [ 1642 | nn.Conv2d(int(ngf * nf_mult_prev), int(ngf * nf_mult), kernel_size=4, stride=2, padding=1, bias=use_bias), 1643 | norm_layer(int(ngf * nf_mult)), 1644 | nn.LeakyReLU(0.2, True) 1645 | ] 1646 | nf_mult_prev = nf_mult 1647 | nf_mult = min(2 ** num_downs, 8) 1648 | model += [ 1649 | nn.Conv2d(ngf * nf_mult_prev, ngf * nf_mult, kernel_size=4, stride=1, padding=1, bias=use_bias), 1650 | norm_layer(ngf * nf_mult), 1651 | nn.LeakyReLU(0.2, True) 1652 | ] 1653 | self.encoder = nn.Sequential(*model) 1654 | 1655 | self.classifier = nn.Sequential( 1656 | nn.Linear(512 * 7 * 7, dim), 1657 | nn.ReLU(True), 1658 | 
nn.Dropout(), 1659 | nn.Linear(dim, dim), 1660 | nn.ReLU(True), 1661 | nn.Dropout(), 1662 | nn.Linear(dim, classes), 1663 | ) 1664 | 1665 | def forward(self, x): 1666 | ax = self.encoder(x) 1667 | #print('ax',ax.shape) # (8, 512, 7, 7) 1668 | ax = ax.view(ax.size(0), -1) # view -- reshape 1669 | return self.classifier(ax) 1670 | --------------------------------------------------------------------------------