├── README.md ├── checkpoints └── pretrained │ ├── loss_log.txt │ ├── test_opt.txt │ └── web │ ├── TestImages │ └── epoch_latest │ │ ├── 1.png │ │ ├── 2.png │ │ ├── 3.png │ │ ├── 4.png │ │ ├── 5.png │ │ ├── 6.png │ │ ├── 7.png │ │ └── 8.png │ └── index.html ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── base_dataset.cpython-39.pyc │ └── cocoart_dataset.cpython-39.pyc ├── base_dataset.py └── cocoart_dataset.py ├── examples ├── comp │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ └── 8.jpg ├── mask │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ ├── 7.png │ └── 8.png └── style │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ └── 8.jpg ├── figures └── result.jpg ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── base_model.cpython-39.pyc │ ├── networks.cpython-39.pyc │ └── vgg19hrnet_model.cpython-39.pyc ├── base_model.py ├── networks.py └── vgg19hrnet_model.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-39.pyc │ ├── base_options.cpython-39.pyc │ └── test_options.cpython-39.pyc ├── base_options.py ├── test_options.py └── train_options.py ├── output ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg └── 8.jpg ├── requirements.txt ├── scripts ├── test.sh └── train.sh ├── test.py ├── train.py └── util ├── __init__.py ├── __pycache__ ├── __init__.cpython-39.pyc ├── html.cpython-39.pyc ├── util.cpython-39.pyc └── visualizer.cpython-39.pyc ├── config.py ├── html.py ├── image_pool.py ├── spectral_norm.py ├── util.py └── visualizer.py /README.md: -------------------------------------------------------------------------------- 1 | # ProPIH-Painterly-Image-Harmonization 2 | 3 | 4 | We release the code used in the following paper: 5 | > **Progressive Painterly Image Harmonization from Low-level Styles to High-level Styles** [[arXiv]](https://arxiv.org/pdf/2312.10264.pdf)
6 | > 7 | > Li Niu, Yan Hong, Junyan Cao, Liqing Zhang 8 | > 9 | > Accepted by AAAI 2024 10 | 11 | Our method can harmonize a composite image from low-level styles to high-level styles. The results harmonized to the highest style level have sufficiently stylized foregrounds, but they also risk content distortion and artifacts. Users can select the result harmonized to the style level that suits their needs. 12 | 13 | 14 | 15 | ## Prerequisites 16 | - Linux 17 | - Python 3.9 18 | - PyTorch 1.10 19 | - NVIDIA GPU + CUDA 20 | 21 | ## Getting Started 22 | ### Installation 23 | - Clone this repo: 24 | 25 | ```bash 26 | git clone https://github.com/bcmi/ProPIH-Painterly-Image-Harmonization.git 27 | ``` 28 | 29 | - Prepare the datasets as in [PHDNet](https://github.com/bcmi/PHDNet-Painterly-Image-Harmonization/). 30 | 31 | - Install PyTorch and dependencies: 32 | 33 | ```bash 34 | conda create -n ProPIH python=3.9 35 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 36 | ``` 37 | 38 | - Install Python requirements: 39 | 40 | ```bash 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | - Download the pre-trained VGG19 from [Baidu Cloud](https://pan.baidu.com/s/1HljOE-4Q2yUeeWmteu0nNA) (access code: pc9y) or [Dropbox](https://www.dropbox.com/scl/fi/xe7pqh840enc16wq5f00r/vgg_normalised.pth?rlkey=c09ynaraeir01b1xsrxwmmfqg&st=tsvcikom&dl=0). Put it in `.//pretrained` 45 | 46 | ### ProPIH train/test 47 | - Train ProPIH: 48 | 49 | In `train.sh`, set `content_dir` and `style_dir` to the path of the corresponding dataset. 50 | 51 | ```bash 52 | cd scripts 53 | bash train.sh 54 | ``` 55 | 56 | The trained model is saved in `.///`. To load a model and resume training, add `--continue_train` and set `--epoch XX` in `train.sh`. Training then loads the model `.///_net_G.pth`. 
57 | For example, if the model is saved in `./AA/BB/latest_net_G.pth`, set `checkpoints_dir` to `../AA/` (the script is run from inside `scripts`, hence the `../`), `name` to `BB`, and `epoch` to `latest`. 58 | 59 | - Test ProPIH: 60 | 61 | 62 | 63 | Our pre-trained model is available on [Baidu Cloud](https://pan.baidu.com/s/1CDSnqzlcLKZGD7fzIFp5Qg) (access code: azir) or [Dropbox](https://www.dropbox.com/scl/fi/hs3cpdrjsojpy9s8z0nb2/latest_net_G.pth?rlkey=gyd9c3tlj5k7e6qufcxxdvzq8&st=bigeqodb&dl=0). Put it in `.//pretrained`. We provide some test examples in `./examples`. 64 | 65 | ```bash 66 | cd scripts 67 | bash test.sh 68 | ``` 69 | The results are saved in `./output`. Some results are shown below. From stage 1 to stage 4, the composite images are harmonized progressively from low-level styles (color, simple texture) to high-level styles (complex texture). 70 | 71 |
72 | harmonization_results 73 |
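The checkpoint naming used in the training notes above can be sketched in a few lines. This is only an illustration inferred from the `./AA/BB/latest_net_G.pth` example; the helper name `generator_checkpoint_path` is hypothetical and is not part of this repository.

```python
import os


def generator_checkpoint_path(checkpoints_dir, name, epoch):
    """Join the three options into <checkpoints_dir>/<name>/<epoch>_net_G.pth.

    Assumption: this layout is inferred from the README example,
    not taken from the repository's loading code.
    """
    return os.path.join(checkpoints_dir, name, f"{epoch}_net_G.pth")


if __name__ == "__main__":
    # README example: options are given relative to the scripts directory.
    print(generator_checkpoint_path("../AA/", "BB", "latest"))
```

Checking that the printed path exists before launching `train.sh` with `--continue_train` is a quick way to catch a wrong `checkpoints_dir`, `name`, or `epoch`.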
74 | 75 | ## Other Resources 76 | 77 | + [Awesome-Image-Harmonization](https://github.com/bcmi/Awesome-Image-Harmonization) 78 | + [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion) 79 | -------------------------------------------------------------------------------- /checkpoints/pretrained/loss_log.txt: -------------------------------------------------------------------------------- 1 | ================ Training Loss (Sun Feb 11 15:02:56 2024) ================ 2 | ================ Training Loss (Sun Feb 11 15:09:57 2024) ================ 3 | ================ Training Loss (Sun Feb 11 15:11:04 2024) ================ 4 | ================ Training Loss (Sun Feb 11 15:36:27 2024) ================ 5 | ================ Training Loss (Sun Feb 11 15:37:09 2024) ================ 6 | ================ Training Loss (Sun Feb 11 15:38:05 2024) ================ 7 | ================ Training Loss (Sun Feb 11 15:38:56 2024) ================ 8 | ================ Training Loss (Sun Feb 11 15:40:29 2024) ================ 9 | -------------------------------------------------------------------------------- /checkpoints/pretrained/test_opt.txt: -------------------------------------------------------------------------------- 1 | ----------------- Options --------------- 2 | batch_size: 1 3 | beta1: 0.5 4 | checkpoints_dir: ../checkpoints/ [default: ./checkpoints] 5 | content_dir: ../examples/ [default: /MS-COCO/] 6 | continue_train: False 7 | crop_size: 256 8 | d_lr_ratio: 1 9 | dataset_mode: cocoart [default: iharmony4] 10 | dataset_root: here 11 | display_env: main 12 | display_freq: 1 [default: 400] 13 | display_id: 0 [default: 1] 14 | display_ncols: 4 15 | display_port: 8097 16 | display_server: http://localhost 17 | display_winsize: 256 18 | epoch: latest 19 | epoch_count: 1 20 | g_lr_ratio: 1 21 | gan_mode: vanilla 22 | gpu_ids: 0 23 | init_gain: 0.02 24 | init_type: normal 25 | input_nc: 3 [default: 4] 26 | isTrain: True [default: None] 27 | is_train: 
True 28 | load_iter: 0 29 | load_size: 256 30 | lr: 0.0002 31 | lr_decay_iters: 50 32 | lr_policy: linear 33 | max_dataset_size: inf 34 | model: vgg19hrnet [default: cycle_gan] 35 | n_layers_D: 3 36 | name: pretrained [default: experiment_name] 37 | ndf: 64 38 | netD: conv [default: basic] 39 | netG: vgg19hrnet [default: resnet_9blocks] 40 | ngf: 64 41 | niter: 100 42 | niter_decay: 100 43 | no_dropout: False 44 | no_html: False 45 | normD: batch [default: instance] 46 | normG: batch [default: RAN_Method1] 47 | num_threads: 6 [default: 4] 48 | output_nc: 3 49 | patch_number: 4 50 | phase: test 51 | pool_size: 0 52 | preprocess: none [default: resize_and_crop] 53 | print_freq: 1000 [default: 300] 54 | save_by_iter: False 55 | save_epoch_freq: 1 56 | save_latest_freq: 1000 [default: 5000] 57 | serial_batches: False 58 | style_dir: ../examples/ [default: /wikiart/] 59 | suffix: 60 | update_html_freq: 500 61 | verbose: False 62 | vgg: ../checkpoints/pretrained/vgg_normalised.pth 63 | ----------------- End ------------------- 64 | -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/1.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/2.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/3.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/4.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/5.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/6.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/7.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/TestImages/epoch_latest/8.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/checkpoints/pretrained/web/TestImages/epoch_latest/8.png -------------------------------------------------------------------------------- /checkpoints/pretrained/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Experiment name = 5 | 6 | 7 | 8 |

[index.html body; markup lost in extraction. It contains eight sections headed epoch [8] down to epoch [1], each holding one table whose only surviving label is `pretrained`.]
112 | 113 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | '''dataset loader''' 2 | import torch.utils.data 3 | from data.base_dataset import BaseDataset 4 | from data.cocoart_dataset import COCOARTDataset 5 | import numpy as np 6 | import random 7 | 8 | ''' 9 | def setup_seed(seed): 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | np.random.seed(seed) 13 | random.seed(seed) 14 | torch.backends.cudnn.deterministic = True 15 | 16 | setup_seed(1) 17 | ''' 18 | 19 | class CustomDataset(object): 20 | """User-defined dataset 21 | 22 | Example usage: 23 | >>> from data import CustomDataset 24 | >>> dataset = CustomDataset(opt, is_for_train) 25 | >>> dataloader = dataset.load_data() 26 | """ 27 | def __init__(self, opt, is_for_train,vgg=None): 28 | self.opt = opt 29 | if opt.dataset_mode.lower() == 'cocoart': 30 | self.dataset = COCOARTDataset(opt,is_for_train) 31 | print("dataset [%s] was created" % type(self.dataset).__name__) 32 | else: 33 | raise ValueError(opt.dataset_mode, "not implemented.") 34 | 35 | print(len(self.dataset)) 36 | self.dataloader = torch.utils.data.DataLoader( 37 | self.dataset, 38 | batch_size=opt.batch_size, 39 | shuffle=is_for_train, 40 | # shuffle=False, 41 | num_workers=int(opt.num_threads), 42 | drop_last=False) 43 | 44 | def load_data(self): 45 | return self.dataloader 46 | 47 | def __len__(self): 48 | return len(self.dataset) -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- 
/data/__pycache__/base_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/data/__pycache__/base_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/cocoart_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/data/__pycache__/cocoart_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 3 | """ 4 | 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 
21 | """ 22 | 23 | def __init__(self, opt,vgg=None): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataset_root #mia 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index -- a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 
59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 90 | if 'crop' in opt.preprocess: 91 | if params is None: 92 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 93 | else: 94 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 95 | if opt.preprocess == 'none': 96 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 97 | 98 | if convert: 99 | transform_list += [transforms.ToTensor()] 100 | if grayscale: 101 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 102 | else: 103 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 104 | return transforms.Compose(transform_list) 105 | 106 | 107 | def __make_power_2(img, base, method=Image.BICUBIC): 108 | ow, oh = img.size 109 | h = int(round(oh / base) * base) 110 | w = int(round(ow / base) * base) 111 | if (h == oh) and (w == ow): 112 | return img 113 | 114 | __print_size_warning(ow, oh, w, h) 115 | return img.resize((w, h), method) 116 
| 117 | 118 | def __scale_width(img, target_width, method=Image.BICUBIC): 119 | ow, oh = img.size 120 | if (ow == target_width): 121 | return img 122 | w = target_width 123 | h = int(target_width * oh / ow) 124 | return img.resize((w, h), method) 125 | 126 | 127 | def __crop(img, pos, size): 128 | ow, oh = img.size 129 | x1, y1 = pos 130 | tw = th = size 131 | if (ow > tw or oh > th): 132 | return img.crop((x1, y1, x1 + tw, y1 + th)) 133 | return img 134 | 135 | 136 | def __flip(img, flip): 137 | if flip: 138 | return img.transpose(Image.FLIP_LEFT_RIGHT) 139 | return img 140 | 141 | 142 | def __print_size_warning(ow, oh, w, h): 143 | """Print warning information about image size(only print once)""" 144 | if not hasattr(__print_size_warning, 'has_printed'): 145 | print("The image size needs to be a multiple of 4. " 146 | "The loaded image size was (%d, %d), so it was adjusted to " 147 | "(%d, %d). This adjustment will be done to all images " 148 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 149 | __print_size_warning.has_printed = True 150 | -------------------------------------------------------------------------------- /data/cocoart_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import random 4 | import torchvision.transforms.functional as tf 5 | from data.base_dataset import BaseDataset, get_transform 6 | import PIL 7 | 8 | from PIL import Image 9 | import numpy as np 10 | import torchvision.transforms as transforms 11 | from pathlib import Path 12 | 13 | # ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | from PIL import Image, ImageFile 15 | Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | import time 18 | import random 19 | 20 | import cv2 21 | random.seed(1) 22 | #random.seed(2) 23 | 24 | 25 | def mask_bboxregion(mask): 26 | w,h = np.shape(mask)[:2] 27 | valid_index = np.argwhere(mask==255) # [length,2] 28 | 
if np.shape(valid_index)[0] < 1: 29 | x_left = 0 30 | x_right = 0 31 | y_bottom = 0 32 | y_top = 0 33 | else: 34 | x_left = np.min(valid_index[:,0]) 35 | x_right = np.max(valid_index[:,0]) 36 | y_bottom = np.min(valid_index[:,1]) 37 | y_top = np.max(valid_index[:,1]) 38 | region = mask[x_left:x_right,y_bottom:y_top] 39 | return region 40 | # return [x_left, y_top, x_right, y_bottom] 41 | 42 | def findContours(im): 43 | """ 44 | Wraps cv2.findContours to maintain compatibility between versions 45 | 3 and 4 46 | Returns: 47 | contours, hierarchy 48 | """ 49 | img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) # PIL to cv2 50 | imgray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 51 | ret, thresh = cv2.threshold(imgray, 127, 255, 0) 52 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 53 | # if cv2.__version__.startswith('4'): 54 | # contours, hierarchy = cv2.findContours(*args, **kwargs) 55 | # elif cv2.__version__.startswith('3'): 56 | # _, contours, hierarchy = cv2.findContours(*args, **kwargs) 57 | # else: 58 | # raise AssertionError( 59 | # 'cv2 must be either version 3 or 4 to call this method') 60 | return contours, hierarchy 61 | 62 | 63 | class COCOARTDataset(BaseDataset): 64 | """A template dataset class for you to implement custom datasets.""" 65 | def __init__(self, opt, is_for_train): 66 | """Initialize this dataset class. 67 | Parameters: 68 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 69 | A few things can be done here. 70 | - save the options (already done in BaseDataset) 71 | - get image paths and meta information of the dataset. 72 | - define the image transformation. 
73 | """ 74 | # save the option and dataset root 75 | BaseDataset.__init__(self, opt) 76 | ## content and mask 77 | self.path_img = [] 78 | self.path_mask = [] 79 | ## style 80 | self.path_style = [] 81 | self.isTrain = is_for_train 82 | self.opt = opt 83 | self._load_images_paths() 84 | self.transform = get_transform(opt) 85 | 86 | 87 | 88 | def _load_images_paths(self,): 89 | if self.isTrain: 90 | path = self.opt.content_dir + '/train2014/' 91 | else: 92 | path = self.opt.content_dir + '/val2014/' 93 | 94 | self.paths_all = list(Path(path).glob('*')) 95 | 96 | if self.isTrain: 97 | print('loading training set') 98 | for path in self.paths_all: 99 | path_mask = str(path).replace('train2014','SegmentationClass_select',1).replace('jpg','png') 100 | if not os.path.exists(path_mask): 101 | # print('train not exist',path_mask) 102 | continue 103 | else: 104 | self.path_img.append(str(path)) 105 | self.path_mask.append(path_mask) 106 | ## style 107 | for f in open(self.opt.style_dir + 'WikiArt_Split/style_train.csv'): 108 | self.path_style.append(self.opt.style_dir + f.split(',')[0]) 109 | random.shuffle(self.path_style) 110 | 111 | else: 112 | print('loading testing set') 113 | for img_name in os.listdir(os.path.join(self.opt.content_dir, 'comp')): 114 | self.path_img.append(os.path.join(self.opt.content_dir, 'comp', img_name)) 115 | self.path_mask.append(os.path.join(self.opt.content_dir, 'mask', img_name.replace('.jpg', '.png'))) 116 | ## style 117 | for img_name in os.listdir(os.path.join(self.opt.style_dir, 'style')): 118 | self.path_style.append(os.path.join(self.opt.style_dir, 'style', img_name)) 119 | 120 | print('foreground number',len(self.path_img)) 121 | print('background number',len(self.path_style)) 122 | 123 | def select_mask(self,index): 124 | """select one foreground to generate the composite image""" 125 | if self.isTrain: 126 | mask_all = 
Image.open(self.path_img[index].replace('train2014','SegmentationClass_select',1).replace('jpg','png')).convert('L') 127 | else: 128 | mask_all = Image.open(self.path_img[index].replace('val2014','SegmentationClass_select',1).replace('jpg','png')).convert('L') 129 | 130 | mask_array = np.array(mask_all) 131 | mask_value = np.unique(np.sort(mask_array[mask_array>0])) 132 | object_num = len(mask_value) 133 | whole_area = np.ones(np.shape(mask_array)) 134 | 135 | if self.isTrain: 136 | random_pixel = random.choice(mask_value) 137 | else: 138 | random_pixel = mask_value[0] 139 | 140 | if random_pixel!=255: 141 | mask_array[mask_array==255] = 0 142 | mask_array[mask_array==random_pixel]=255 143 | mask_array[mask_array!=255]=0 144 | return mask_array 145 | 146 | def get_small_scale_mask(self, mask, number): 147 | """generate n*n patch to supervise discriminator""" 148 | mask = np.asarray(mask) 149 | mask = np.uint8(mask / 255.) 150 | mask_patch = np.zeros([number, number],dtype=np.float32) 151 | split_size = self.opt.load_size // number 152 | for i in range(number): 153 | for j in range(number): 154 | mask_split = mask[i*split_size: (i+1)*split_size, j*split_size: (j+1)*split_size] 155 | mask_patch[i, j] = (np.sum(mask_split) > 100) * 255 156 | #mask_patch[i, j] = (np.sum(mask_split) / (split_size * split_size)) * 255 157 | mask_patch = np.uint8(mask_patch) 158 | return Image.fromarray(mask_patch,mode='L') 159 | 160 | 161 | def __getitem__(self, index): 162 | if self.isTrain: 163 | style = Image.open(self.path_style[index]).convert('RGB') 164 | c_index = index % len(self.path_img) 165 | content = Image.open(self.path_img[c_index]).convert('RGB') 166 | select_mask = self.select_mask(c_index) 167 | np_mask = np.uint8(select_mask) 168 | else: 169 | style = Image.open(self.path_style[index]).convert('RGB') 170 | content = Image.open(self.path_img[index]).convert('RGB') 171 | np_mask = cv2.imread(self.path_mask[index], 0) 172 | 173 | soft_mask = cv2.blur(np.array(np_mask), 
(7, 7)) 174 | mask = Image.fromarray(np_mask,mode='L') 175 | soft_mask = Image.fromarray(soft_mask,mode='L') 176 | 177 | content = tf.resize(content, [self.opt.load_size, self.opt.load_size]) 178 | style = tf.resize(style, [self.opt.load_size, self.opt.load_size]) 179 | mask = tf.resize(mask, [self.opt.load_size, self.opt.load_size]) 180 | soft_mask = tf.resize(soft_mask, [self.opt.load_size, self.opt.load_size]) 181 | mask_patch = self.get_small_scale_mask(mask, self.opt.patch_number) 182 | 183 | 184 | #apply the same transform to composite and real images 185 | content = self.transform(content) 186 | style = self.transform(style) 187 | 188 | #content = tf.to_tensor(content) 189 | #style = tf.to_tensor(style) 190 | mask = tf.to_tensor(mask) 191 | mask = mask*2 -1 192 | soft_mask = tf.to_tensor(soft_mask) 193 | soft_mask = soft_mask*2 -1 194 | mask_patch = tf.to_tensor(mask_patch) 195 | comp = self._compose(content, mask, style) 196 | 197 | return {'comp': comp, 'mask': mask, 'mask_patch': mask_patch, 'soft_mask': soft_mask, 'style': style,'content':content, 'img_path':self.path_style[index]} 198 | 199 | def __len__(self): 200 | return len(self.path_style) 201 | # return 10 202 | # return 100 203 | 204 | def _compose(self, foreground_img, foreground_mask, background_img): 205 | foreground_img = foreground_img/2 + 0.5 206 | background_img = background_img/2 + 0.5 207 | foreground_mask = foreground_mask/2 + 0.5 208 | comp = foreground_img * foreground_mask + background_img * (1 - foreground_mask) 209 | comp = comp*2-1 210 | return comp 211 | -------------------------------------------------------------------------------- /examples/comp/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/1.jpg -------------------------------------------------------------------------------- /examples/comp/2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/2.jpg -------------------------------------------------------------------------------- /examples/comp/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/3.jpg -------------------------------------------------------------------------------- /examples/comp/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/4.jpg -------------------------------------------------------------------------------- /examples/comp/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/5.jpg -------------------------------------------------------------------------------- /examples/comp/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/6.jpg -------------------------------------------------------------------------------- /examples/comp/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/7.jpg -------------------------------------------------------------------------------- /examples/comp/8.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/comp/8.jpg -------------------------------------------------------------------------------- /examples/mask/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/1.png -------------------------------------------------------------------------------- /examples/mask/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/2.png -------------------------------------------------------------------------------- /examples/mask/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/3.png -------------------------------------------------------------------------------- /examples/mask/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/4.png -------------------------------------------------------------------------------- /examples/mask/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/5.png -------------------------------------------------------------------------------- /examples/mask/6.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/6.png -------------------------------------------------------------------------------- /examples/mask/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/7.png -------------------------------------------------------------------------------- /examples/mask/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/mask/8.png -------------------------------------------------------------------------------- /examples/style/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/1.jpg -------------------------------------------------------------------------------- /examples/style/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/2.jpg -------------------------------------------------------------------------------- /examples/style/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/3.jpg -------------------------------------------------------------------------------- /examples/style/4.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/4.jpg -------------------------------------------------------------------------------- /examples/style/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/5.jpg -------------------------------------------------------------------------------- /examples/style/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/6.jpg -------------------------------------------------------------------------------- /examples/style/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/7.jpg -------------------------------------------------------------------------------- /examples/style/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/examples/style/8.jpg -------------------------------------------------------------------------------- /figures/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/figures/result.jpg -------------------------------------------------------------------------------- /models/__init__.py: 
-------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- <set_input>: unpack data from dataset and apply preprocessing. 7 | -- <forward>: produce intermediate results. 8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights. 9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class whose name matches [model_name]Model 29 | (ignoring underscores and case) is looked up and returned. It has to be 30 | a subclass of BaseModel. 31 | """ 32 | model_filename = "models."
+ model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/base_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/models/__pycache__/base_model.cpython-39.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/models/__pycache__/networks.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg19hrnet_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/models/__pycache__/vgg19hrnet_model.cpython-39.pyc -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | import time 7 | import numpy as np 8 | from util import util 9 | 10 | 11 | 12 | class BaseModel(ABC): 13 | """This class is an abstract base class (ABC) for models. 14 | To create a subclass, you need to implement the following five functions: 15 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 16 | -- <set_input>: unpack data from dataset and apply preprocessing. 17 | -- <forward>: produce intermediate results. 18 | -- <optimize_parameters>: calculate losses, gradients, and update network weights. 19 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 20 | """ 21 | 22 | def __init__(self, opt): 23 | """Initialize the BaseModel class. 24 | 25 | Parameters: 26 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 27 | 28 | When creating your custom class, you need to implement your own initialization.
29 | In this function, you should first call <BaseModel.__init__(self, opt)>. 30 | Then, you need to define four lists: 31 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 32 | -- self.model_names (str list): define networks used in our training. 33 | -- self.visual_names (str list): specify the images that you want to display and save. 34 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 35 | """ 36 | self.opt = opt 37 | self.gpu_ids = opt.gpu_ids 38 | self.isTrain = opt.isTrain 39 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 40 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 41 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 42 | torch.backends.cudnn.benchmark = True 43 | self.loss_names = [] 44 | self.model_names = [] 45 | self.visual_names = [] 46 | self.optimizers = [] 47 | self.image_paths = [] 48 | self.metric = 0 # used for learning rate policy 'plateau' 49 | 50 | @staticmethod 51 | def modify_commandline_options(parser, is_train): 52 | """Add new model-specific options, and rewrite default values for existing options. 53 | 54 | Parameters: 55 | parser -- original option parser 56 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 57 | 58 | Returns: 59 | the modified parser. 60 | """ 61 | return parser 62 | 63 | @abstractmethod 64 | def set_input(self, input): 65 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
66 | 67 | Parameters: 68 | input (dict): includes the data itself and its metadata information. 69 | """ 70 | pass 71 | 72 | @abstractmethod 73 | def forward(self): 74 | """Run forward pass; called by both functions <optimize_parameters> and <test>.""" 75 | pass 76 | 77 | @abstractmethod 78 | def optimize_parameters(self): 79 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 80 | pass 81 | def setup(self, opt): 82 | if self.isTrain: 83 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 84 | if opt.continue_train: 85 | self.load_networks(opt.epoch) 86 | # print("LOADING %s"%(self.name)) 87 | else: 88 | if opt.phase == 'test': 89 | self.load_networks(opt.epoch) 90 | # print("LOADING %s"%(self.name)) 91 | # self.test() 92 | self.eval() 93 | # self.test() 94 | self.print_networks(opt.verbose) 95 | 96 | 97 | # def setup(self, opt): 98 | # """Load and print networks; create schedulers 99 | 100 | # Parameters: 101 | # opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 102 | # """ 103 | # if self.isTrain: 104 | # self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 105 | # if not self.isTrain or opt.continue_train: 106 | # load_suffix = '%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 107 | # self.load_networks(load_suffix) 108 | # if opt.phase == 'test': 109 | # self.load_networks() 110 | # self.print_networks(opt.verbose) 111 | 112 | def eval(self): 113 | """Make models eval mode during test time""" 114 | for name in self.model_names: 115 | if isinstance(name, str): 116 | net = getattr(self, 'net' + name) 117 | net.eval() 118 | 119 | def test(self): 120 | """Forward function used in test time.
121 | 122 | This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop 123 | It also calls <compute_visuals> to produce additional visualization results 124 | """ 125 | with torch.no_grad(): 126 | self.forward() 127 | self.compute_visuals() 128 | 129 | def compute_visuals(self): 130 | """Calculate additional output images for visdom and HTML visualization""" 131 | pass 132 | 133 | def get_image_paths(self): 134 | """ Return image paths that are used to load current data""" 135 | return self.image_paths 136 | 137 | def update_learning_rate(self): 138 | """Update learning rates for all the networks; called at the end of every epoch""" 139 | for scheduler in self.schedulers: 140 | if self.opt.lr_policy == 'plateau': 141 | scheduler.step(self.metric) 142 | else: 143 | scheduler.step() 144 | 145 | #lr = self.optimizers[0].param_groups[0]['lr'] 146 | #print('learning rate = %.7f' % lr) 147 | 148 | def get_current_visuals(self): 149 | """Return visualization images. train.py will display these images with visdom, and save the images to an HTML file""" 150 | visual_ret = OrderedDict() 151 | for name in self.visual_names: 152 | if isinstance(name, str): 153 | visual_ret[name] = getattr(self, name) 154 | return visual_ret 155 | 156 | 157 | 158 | 159 | def get_current_losses(self): 160 | """Return training losses / errors. train.py will print out these errors on console, and save them to a file""" 161 | errors_ret = OrderedDict() 162 | for name in self.loss_names: 163 | if isinstance(name, str): 164 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 165 | return errors_ret 166 | 167 | def save_networks(self, epoch): 168 | """Save all the networks to the disk.
169 | 170 | Parameters: 171 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 172 | """ 173 | for name in self.model_names: 174 | if isinstance(name, str): 175 | save_filename = '%s_net_%s.pth' % (epoch, name) 176 | save_path = os.path.join(self.save_dir, save_filename) 177 | net = getattr(self, 'net' + name) 178 | 179 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 180 | torch.save(net.module.cpu().state_dict(), save_path) 181 | net.cuda(self.gpu_ids[0]) 182 | else: 183 | torch.save(net.cpu().state_dict(), save_path) 184 | 185 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 186 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 187 | key = keys[i] 188 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 189 | if module.__class__.__name__.startswith('InstanceNorm') and \ 190 | (key == 'running_mean' or key == 'running_var'): 191 | if getattr(module, key) is None: 192 | state_dict.pop('.'.join(keys)) 193 | if module.__class__.__name__.startswith('InstanceNorm') and \ 194 | (key == 'num_batches_tracked'): 195 | state_dict.pop('.'.join(keys)) 196 | else: 197 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 198 | 199 | def load_networks(self, epoch): 200 | """Load all the networks from the disk. 
201 | 202 | Parameters: 203 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 204 | """ 205 | for name in self.model_names: 206 | #if self.opt.phase == 'test': 207 | # if name == 'D': 208 | # continue 209 | if isinstance(name, str): 210 | load_filename = '%s_net_%s.pth' % (epoch, name) 211 | load_path = os.path.join(self.save_dir, load_filename) 212 | # load_path = os.path.join('/data1/cctraithong/PainterlyHarmonization/RainNet-main/0610_result/adain_Gadain_Content1_Style50_Ggan0_Lgan0_lr0.0002_TV10e-6_Matting0_batch8-bgnoise/adain_Gadain_Content1_Style50_Ggan0_Lgan0_lr0.0002_TV10e-6_Matting0_batch8/', load_filename) 213 | # load_path = os.path.join('/data1/cctraithong/PainterlyHarmonization/RainNet-main/0610_result/adain_Gadain_Content1_Style10e3_Ggan0_Lgan0_lr0.0002_TV10e-3_Matting0_batch8/adain_Gadain_Content1_Style10e3_Ggan0_Lgan0_lr0.0002_TV10e-3_Matting0_batch8/', load_filename) 214 | # load_path = os.path.join('/data1/cctraithong/PainterlyHarmonization/RainNet-main/0614_fostyle/adain_Gadain_Content1_Style1e3_Ggan1_Lgan0_lr0.0002_TV10e-3_Matting1_batch8/adain_Gadain_Content1_Style1e3_Ggan1_Lgan0_lr0.0002_TV10e-3_Matting1_batch8/', load_filename) 215 | # load_path = os.path.join('/data1/cctraithong/PainterlyHarmonization/RainNet-main/aaai_0616/RegRainNet_GReg_adain_Content1_Style1e1_Ggan0_Lmask1_lr0.0002_TV10e-6_Matting1_batch8_patchsize8/', load_filename) 216 | 217 | if os.path.exists(load_path): 218 | net = getattr(self, 'net' + name) 219 | if isinstance(net, torch.nn.DataParallel): 220 | net = net.module 221 | print('loading the model from %s' % load_path) 222 | # if you are using PyTorch newer than 0.4 (e.g., built from 223 | # GitHub source), you can remove str() on self.device 224 | state_dict = torch.load(load_path, map_location=str(self.device)) 225 | if hasattr(state_dict, '_metadata'): 226 | del state_dict._metadata 227 | 228 | # net_dict = net.state_dict() 229 | # # 1. 
filter out unnecessary keys 230 | # pretrained_dict = {k: v for k, v in state_dict.items() if ('fuseblk' not in k and 'final_conv' not in k)} 231 | # # 2. overwrite entries in the existing state dict 232 | # net_dict.update(pretrained_dict) 233 | # net.load_state_dict(net_dict, strict=True) 234 | 235 | # patch InstanceNorm checkpoints prior to 0.4 236 | # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 237 | # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 238 | net.load_state_dict(state_dict, strict=True) 239 | else: 240 | print('failed loading model from {}'.format(load_path)) 241 | 242 | def print_networks(self, verbose): 243 | """Print the total number of parameters in the network and (if verbose) network architecture 244 | 245 | Parameters: 246 | verbose (bool) -- if verbose: print the network architecture 247 | """ 248 | print('---------- Networks initialized -------------') 249 | for name in self.model_names: 250 | if isinstance(name, str): 251 | net = getattr(self, 'net' + name) 252 | num_params = 0 253 | for param in net.parameters(): 254 | num_params += param.numel() 255 | if verbose: 256 | print(net) 257 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 258 | print('-----------------------------------------------') 259 | 260 | def set_requires_grad(self, nets, requires_grad=False): 261 | """Set requires_grad=False for all the networks to avoid unnecessary computations 262 | Parameters: 263 | nets (network list) -- a list of networks 264 | requires_grad (bool) -- whether the networks require gradients or not 265 | """ 266 | if not isinstance(nets, list): 267 | nets = [nets] 268 | for net in nets: 269 | if net is not None: 270 | for param in net.parameters(): 271 | param.requires_grad = requires_grad 272 | 273 | -------------------------------------------------------------------------------- /models/networks.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.fft 6 | import functools 7 | from torch.nn import init 8 | from torch.optim import lr_scheduler 9 | 10 | import cv2 11 | import os 12 | import time 13 | 14 | ## modules for painterly image harmonization 15 | 16 | class TVLoss(nn.Module): 17 | def __init__(self, strength): 18 | super(TVLoss, self).__init__() 19 | self.strength = strength 20 | self.x_diff = torch.Tensor() 21 | self.y_diff = torch.Tensor() 22 | 23 | def forward(self, input): 24 | self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :] 25 | self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1] 26 | self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff))) 27 | # return input 28 | return self.loss 29 | 30 | def calc_mean_std(feat, eps=1e-5): 31 | # eps is a small value added to the variance to avoid divide-by-zero. 
32 | size = feat.size() 33 | assert (len(size) == 4) 34 | N, C = size[:2] 35 | feat_var = feat.contiguous().view(N, C, -1).var(dim=2) + eps 36 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 37 | feat_mean = feat.contiguous().view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 38 | return feat_mean, feat_std 39 | 40 | def adain_fg(comp_feat, style_feat, mask): 41 | #assert (content_feat.size()[:2] == style_feat.size()[:2]) 42 | size = comp_feat.size() 43 | #style_mean, style_std = calc_mean_std(style_feat) # the style features 44 | downsample_mask_style = torch.ones(mask.size()).to(style_feat.device) 45 | # downsample_mask_style = 1 - mask 46 | style_mean, style_std = get_foreground_mean_std(style_feat, downsample_mask_style) # the style features 47 | fore_mean, fore_std = get_foreground_mean_std(comp_feat, mask) # the foreground features 48 | 49 | normalized_feat = (comp_feat - fore_mean.expand(size)) / fore_std.expand(size) 50 | return (normalized_feat * style_std.expand(size) + style_mean.expand(size)) * mask + (comp_feat * (1 - mask)) 51 | 52 | decoder = nn.Sequential( 53 | nn.ReflectionPad2d((1, 1, 1, 1)), 54 | nn.Conv2d(512, 256, (3, 3)), 55 | nn.ReLU(), # relu1-1 56 | nn.Upsample(scale_factor=2, mode='nearest'), 57 | nn.ReflectionPad2d((1, 1, 1, 1)), 58 | nn.Conv2d(256, 256, (3, 3)), 59 | nn.ReLU(), 60 | nn.ReflectionPad2d((1, 1, 1, 1)), 61 | nn.Conv2d(256, 256, (3, 3)), 62 | nn.ReLU(), 63 | nn.ReflectionPad2d((1, 1, 1, 1)), 64 | nn.Conv2d(256, 256, (3, 3)), 65 | nn.ReLU(), 66 | nn.ReflectionPad2d((1, 1, 1, 1)), 67 | nn.Conv2d(256, 128, (3, 3)), 68 | nn.ReLU(), # relu2-1 69 | nn.Upsample(scale_factor=2, mode='nearest'), 70 | nn.ReflectionPad2d((1, 1, 1, 1)), 71 | nn.Conv2d(128, 128, (3, 3)), 72 | nn.ReLU(), 73 | nn.ReflectionPad2d((1, 1, 1, 1)), 74 | nn.Conv2d(128, 64, (3, 3)), 75 | nn.ReLU(), 76 | nn.Upsample(scale_factor=2, mode='nearest'), 77 | nn.ReflectionPad2d((1, 1, 1, 1)), 78 | nn.Conv2d(64, 64, (3, 3)), 79 | nn.ReLU(), 80 | # nn.Conv2d(64, 1, (3, 
3),padding=0,stride=1), ##matting layer 81 | nn.Conv2d(65, 1, (1, 1),padding=0,stride=1), ##matting layer 82 | nn.ReflectionPad2d((1, 1, 1, 1)), # 24 83 | # nn.ReflectionPad2d((1, 1, 1, 1)), ##matting layer 84 | nn.Conv2d(64, 3, (3, 3)), 85 | ) 86 | 87 | 88 | vgg = nn.Sequential( 89 | nn.Conv2d(3, 3, (1, 1)), 90 | nn.ReflectionPad2d((1, 1, 1, 1)), 91 | nn.Conv2d(3, 64, (3, 3)), 92 | nn.ReLU(), # relu1-1 93 | nn.ReflectionPad2d((1, 1, 1, 1)), 94 | nn.Conv2d(64, 64, (3, 3)), 95 | nn.ReLU(), # relu1-2 96 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 97 | nn.ReflectionPad2d((1, 1, 1, 1)), 98 | nn.Conv2d(64, 128, (3, 3)), 99 | nn.ReLU(), # relu2-1 100 | nn.ReflectionPad2d((1, 1, 1, 1)), 101 | nn.Conv2d(128, 128, (3, 3)), 102 | nn.ReLU(), # relu2-2 103 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 104 | nn.ReflectionPad2d((1, 1, 1, 1)), 105 | nn.Conv2d(128, 256, (3, 3)), 106 | nn.ReLU(), # relu3-1 107 | nn.ReflectionPad2d((1, 1, 1, 1)), 108 | nn.Conv2d(256, 256, (3, 3)), 109 | nn.ReLU(), # relu3-2 110 | nn.ReflectionPad2d((1, 1, 1, 1)), 111 | nn.Conv2d(256, 256, (3, 3)), 112 | nn.ReLU(), # relu3-3 113 | nn.ReflectionPad2d((1, 1, 1, 1)), 114 | nn.Conv2d(256, 256, (3, 3)), 115 | nn.ReLU(), # relu3-4 116 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 117 | nn.ReflectionPad2d((1, 1, 1, 1)), 118 | nn.Conv2d(256, 512, (3, 3)), 119 | nn.ReLU(), # relu4-1, this is the last layer used 120 | nn.ReflectionPad2d((1, 1, 1, 1)), 121 | nn.Conv2d(512, 512, (3, 3)), 122 | nn.ReLU(), # relu4-2 123 | nn.ReflectionPad2d((1, 1, 1, 1)), 124 | nn.Conv2d(512, 512, (3, 3)), 125 | nn.ReLU(), # relu4-3 126 | nn.ReflectionPad2d((1, 1, 1, 1)), 127 | nn.Conv2d(512, 512, (3, 3)), 128 | nn.ReLU(), # relu4-4 129 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 130 | nn.ReflectionPad2d((1, 1, 1, 1)), 131 | nn.Conv2d(512, 512, (3, 3)), 132 | nn.ReLU(), # relu5-1 133 | nn.ReflectionPad2d((1, 1, 1, 1)), 134 | nn.Conv2d(512, 512, (3, 3)), 135 | nn.ReLU(), # relu5-2 136 | 
nn.ReflectionPad2d((1, 1, 1, 1)), 137 | nn.Conv2d(512, 512, (3, 3)), 138 | nn.ReLU(), # relu5-3 139 | nn.ReflectionPad2d((1, 1, 1, 1)), 140 | nn.Conv2d(512, 512, (3, 3)), 141 | nn.ReLU() # relu5-4 142 | ) 143 | 144 | 145 | encoder = nn.Sequential( 146 | nn.Conv2d(4, 3, (1, 1)), 147 | nn.ReflectionPad2d((1, 1, 1, 1)), 148 | nn.Conv2d(3, 64, (3, 3)), 149 | nn.ReLU(), # relu1-1 150 | nn.ReflectionPad2d((1, 1, 1, 1)), 151 | nn.Conv2d(64, 64, (3, 3)), 152 | nn.ReLU(), # relu1-2 153 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 154 | nn.ReflectionPad2d((1, 1, 1, 1)), 155 | nn.Conv2d(64, 128, (3, 3)), 156 | nn.ReLU(), # relu2-1 157 | nn.ReflectionPad2d((1, 1, 1, 1)), 158 | nn.Conv2d(128, 128, (3, 3)), 159 | nn.ReLU(), # relu2-2 160 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 161 | nn.ReflectionPad2d((1, 1, 1, 1)), 162 | nn.Conv2d(128, 256, (3, 3)), 163 | nn.ReLU(), # relu3-1 164 | nn.ReflectionPad2d((1, 1, 1, 1)), 165 | nn.Conv2d(256, 256, (3, 3)), 166 | nn.ReLU(), # relu3-2 167 | nn.ReflectionPad2d((1, 1, 1, 1)), 168 | nn.Conv2d(256, 256, (3, 3)), 169 | nn.ReLU(), # relu3-3 170 | nn.ReflectionPad2d((1, 1, 1, 1)), 171 | nn.Conv2d(256, 256, (3, 3)), 172 | nn.ReLU(), # relu3-4 173 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 174 | nn.ReflectionPad2d((1, 1, 1, 1)), 175 | nn.Conv2d(256, 512, (3, 3)), 176 | nn.ReLU(), # relu4-1, this is the last layer used 177 | ) 178 | 179 | 180 | 181 | class Identity(nn.Module): 182 | def forward(self, x): 183 | return x 184 | 185 | def get_norm_layer(norm_type='instance'): 186 | """Return a normalization layer 187 | 188 | Parameters: 189 | norm_type (str) -- the name of the normalization layer: batch | instance | none 190 | 191 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 192 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
193 | """ 194 | norm_type = norm_type.lower() 195 | if norm_type == 'batch': 196 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 197 | elif norm_type == 'instance': 198 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 199 | elif norm_type == 'none': 200 | norm_layer = lambda x: Identity() 201 | elif norm_type.startswith('rain'): 202 | norm_layer = RAIN 203 | else: 204 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 205 | return norm_layer 206 | 207 | 208 | def define_G(input_nc, output_nc, ngf, netG,opt, norm='batch', use_dropout=False, 209 | init_type='normal', init_gain=0.02, gpu_ids=[],encoder=None,decoder=None): 210 | """load a generator 211 | 212 | Parameters: 213 | input_nc (int) -- the number of channels in input images 214 | output_nc (int) -- the number of channels in output images 215 | ngf (int) -- the number of filters in the last conv layer 216 | netG (str) -- the architecture's name: rainnet 217 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 218 | use_dropout (bool) -- if use dropout layers. 219 | init_type (str) -- the name of our initialization method. 220 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 
221 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 222 | """ 223 | norm_layer = get_norm_layer(norm_type=norm) 224 | 225 | if netG == 'rainnet': 226 | net = RainNet(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, use_attention=True) 227 | elif netG == 'adain': 228 | print('loading vgg from {}'.format(opt.vgg)) 229 | encoder.load_state_dict(torch.load(opt.vgg)) 230 | encoder = nn.Sequential(*list(encoder.children())[:31]) 231 | net = AdainNet(encoder, decoder) 232 | elif netG == 'RegRainNet': 233 | print('loading vgg from {}'.format(opt.vgg)) 234 | encoder.load_state_dict(torch.load(opt.vgg)) 235 | encoder = nn.Sequential(*list(encoder.children())[:31]) 236 | net = RegRainNet(encoder, decoder) 237 | else: 238 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 239 | return init_net(net, init_type, init_gain, gpu_ids) 240 | 241 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): 242 | """Create a discriminator 243 | 244 | Parameters: 245 | input_nc (int) -- the number of channels in input images 246 | ndf (int) -- the number of filters in the first conv layer 247 | netD (str) -- the architecture's name: basic | n_layers | pixel 248 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 249 | norm (str) -- the type of normalization layers used in the network. 250 | init_type (str) -- the name of the initialization method. 251 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 
252 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 253 | """ 254 | norm_layer = get_norm_layer(norm_type=norm) 255 | 256 | if netD == 'basic': # default PatchGAN classifier 257 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 258 | elif netD == 'n_layers': # more options 259 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 260 | elif netD == 'pixel': # classify if each pixel is real or fake 261 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 262 | else: 263 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 264 | return init_net(net, init_type, init_gain, gpu_ids) 265 | 266 | 267 | def get_scheduler(optimizer, opt): 268 | """Return a learning rate scheduler 269 | 270 | Parameters: 271 | optimizer -- the optimizer of the network 272 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 273 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 274 | 275 | For 'linear', we keep the same learning rate for the first <opt.niter> epochs 276 | and linearly decay the rate to zero over the next <opt.niter_decay> epochs. 277 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 278 | See https://pytorch.org/docs/stable/optim.html for more details.
279 | """ 280 | if opt.lr_policy == 'linear': 281 | def lambda_rule(epoch): 282 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 283 | return lr_l 284 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 285 | elif opt.lr_policy == 'step': 286 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 287 | elif opt.lr_policy == 'plateau': 288 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 289 | elif opt.lr_policy == 'cosine': 290 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 291 | else: 292 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 293 | return scheduler 294 | 295 | def init_weights(net, init_type='normal', init_gain=0.02): 296 | """Initialize network weights. 297 | 298 | Parameters: 299 | net (network) -- network to be initialized 300 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 301 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 302 | 303 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 304 | work better for some applications. Feel free to try yourself.
305 | """ 306 | def init_func(m): # define the initialization function 307 | classname = m.__class__.__name__ 308 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 309 | if init_type == 'normal': 310 | init.normal_(m.weight.data, 0.0, init_gain) 311 | elif init_type == 'xavier': 312 | init.xavier_normal_(m.weight.data, gain=init_gain) 313 | elif init_type == 'kaiming': 314 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 315 | elif init_type == 'orthogonal': 316 | init.orthogonal_(m.weight.data, gain=init_gain) 317 | else: 318 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 319 | if hasattr(m, 'bias') and m.bias is not None: 320 | init.constant_(m.bias.data, 0.0) 321 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 322 | init.normal_(m.weight.data, 1.0, init_gain) 323 | init.constant_(m.bias.data, 0.0) 324 | 325 | print('initialize network with %s' % init_type) 326 | net.apply(init_func) # apply the initialization function 327 | 328 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 329 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 330 | Parameters: 331 | net (network) -- the network to be initialized 332 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 333 | gain (float) -- scaling factor for normal, xavier and orthogonal. 334 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 335 | 336 | Return an initialized network. 337 | """ 338 | if len(gpu_ids) > 0: 339 | assert(torch.cuda.is_available()) 340 | net.to(gpu_ids[0]) 341 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 342 | init_weights(net, init_type, init_gain=init_gain) 343 | return net 344 | 345 | 346 | class GANLoss(nn.Module): 347 | """Define different GAN objectives. 

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label of a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
            self.relu = nn.ReLU()
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with ground truth label, and with the size of the input
        """

        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given Discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean() # self.relu(1-prediction.mean())
            else:
                loss = prediction.mean() # self.relu(1+prediction.mean())
        return loss


def get_foreground_mean_std(features, mask, eps=1e-5):
    # channel-wise mean and std computed only over the pixels selected by mask
    region = features * mask
    region_sum = torch.sum(region, dim=[2, 3]) # (B, C)
    num = torch.sum(mask, dim=[2, 3]) # (B, C)
    mu = region_sum / (num + eps)
    mean = mu[:, :, None, None]
    # outside the mask the term (region + (1 - mask)*mean - mean) is zero, so only masked pixels contribute
    var = torch.sum((region + (1 - mask)*mean - mean) ** 2, dim=[2, 3]) / (num + eps)
    var = var[:, :, None, None]
    std = torch.sqrt(var+eps)
    return mean, std


class UpsampleDecoder(nn.Module):
    def __init__(self, k, dim):
        super().__init__()
        self.conv1 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(dim, dim//2, (3, 3)),
                                   nn.ReLU())
        upconvs = []
        for _ in range(k):
            dim = dim//2
            upconv = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                                   nn.ReflectionPad2d((1, 1, 1, 1)),
                                   nn.Conv2d(dim, dim//2, (3, 3)),
                                   nn.ReLU())
            upconvs.append(upconv)

        self.upconvs = nn.Sequential(*upconvs)

    def forward(self, x):
        output = self.conv1(x)
        output = self.upconvs(output)
        return output

class FuseBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.fuse_block1 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)),
                                         nn.Conv2d(64, 32, (3, 3)),
                                         nn.ReLU(),
nn.ReflectionPad2d((1, 1, 1, 1)), 457 | nn.Conv2d(32, 32, (3, 3)), 458 | nn.ReLU()) 459 | 460 | self.fuse_block2 = nn.Sequential(nn.ReflectionPad2d((1, 1, 1, 1)), 461 | nn.Conv2d(32, 32, (3, 3)), 462 | nn.ReLU(), 463 | nn.ReflectionPad2d((1, 1, 1, 1)), 464 | nn.Conv2d(32, 32, (3, 3)), 465 | nn.ReLU()) 466 | 467 | def forward(self, x1, x2): 468 | output = torch.cat((x1,x2), dim=1) 469 | output = self.fuse_block1(output) 470 | output = self.fuse_block2(output) 471 | return output 472 | 473 | class VGG19HRNet(nn.Module): 474 | def __init__(self,vgg): 475 | super().__init__() 476 | # load the pretrained VGG encoder 477 | vgg_layers = list(vgg.children()) 478 | self.enc_1 = nn.Sequential(*vgg_layers[:4]) # input -> relu1_1 479 | self.enc_2 = nn.Sequential(*vgg_layers[4:11]) # relu1_1 -> relu2_1 480 | self.enc_3 = nn.Sequential(*vgg_layers[11:18]) # relu2_1 -> relu3_1 481 | self.enc_4 = nn.Sequential(*vgg_layers[18:31]) # relu3_1 -> relu4_1 482 | 483 | # fix the VGG encoder 484 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 485 | for param in getattr(self, name).parameters(): 486 | param.requires_grad = False 487 | 488 | self.updec_1 = UpsampleDecoder(0,64) 489 | self.updec_2 = UpsampleDecoder(1,128) 490 | self.updec_3 = UpsampleDecoder(2,256) 491 | self.updec_4 = UpsampleDecoder(3,512) 492 | 493 | self.final_conv_1 = nn.Conv2d(32, 3, (1, 1)) 494 | self.fuseblk_2 = FuseBlock() 495 | self.final_conv_2 = nn.Conv2d(32, 3, (1, 1)) 496 | self.fuseblk_3 = FuseBlock() 497 | self.final_conv_3 = nn.Conv2d(32, 3, (1, 1)) 498 | self.fuseblk_4 = FuseBlock() 499 | self.final_conv_4 = nn.Conv2d(32, 3, (1, 1)) 500 | 501 | 502 | self.mse_loss = nn.MSELoss() 503 | 504 | # extract relu4_1 from input image 505 | def encode(self, input): 506 | for i in range(4): 507 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 508 | return input 509 | 510 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 511 | def encode_with_intermediate(self, input): 512 | results = [input] 
513 | for i in range(4): 514 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 515 | results.append(func(results[-1])) 516 | return results[1:] 517 | 518 | def calc_content_loss(self, gen, comb): 519 | loss = self.mse_loss(gen, comb) 520 | return loss 521 | 522 | 523 | def downsample(self, image_tensor, width, height): 524 | image_upsample_tensor = torch.nn.functional.interpolate(image_tensor, size=[width, height]) 525 | image_upsample_tensor = image_upsample_tensor.clamp(0, 1) 526 | return image_upsample_tensor 527 | 528 | 529 | def calc_style_loss_mulitple_fg_layer(self, combs, styles, mask, layer): 530 | loss = torch.zeros(1).to(mask.device) 531 | for i in range(0, layer): 532 | width = height = combs[i].size(-1) 533 | downsample_mask = self.downsample(mask, width, height) 534 | downsample_mask_style = torch.ones(downsample_mask.size()).to(mask.device) #1-downsample_mask 535 | 536 | mu_cs,sigma_cs = get_foreground_mean_std(combs[i], downsample_mask) 537 | mu_target,sigma_target = get_foreground_mean_std(styles[i], downsample_mask_style) 538 | loss_i = self.mse_loss(mu_cs, mu_target) + self.mse_loss(sigma_cs, sigma_target) 539 | loss += loss_i 540 | return loss 541 | 542 | def cal_loss(self,final_output,comb_feats,style_feats,mask, layer): 543 | fine_feats = self.encode_with_intermediate(final_output) 544 | # calculate content loss 545 | loss_c = self.calc_content_loss(fine_feats[-1], comb_feats[-1]) 546 | # calculate style loss 547 | loss_s = self.calc_style_loss_mulitple_fg_layer(fine_feats, style_feats, mask,layer) 548 | return loss_c,loss_s 549 | 550 | 551 | def forward(self, comp, style, mask): 552 | style_feats = self.encode_with_intermediate(style) 553 | comb_feats = self.encode_with_intermediate(comp) 554 | 555 | output1 = self.enc_1(comp) 556 | output2 = self.enc_2(output1) 557 | output3 = self.enc_3(output2) 558 | output4 = self.enc_4(output3) 559 | 560 | width = height = output1.size(-1) 561 | downsample_mask1 = self.downsample(mask, width, height) 562 
| t1 = adain_fg(output1, style_feats[0], downsample_mask1) 563 | 564 | width = height = output2.size(-1) 565 | downsample_mask2 = self.downsample(mask, width, height) 566 | t2 = adain_fg(output2, style_feats[1], downsample_mask2) 567 | 568 | width = height = output3.size(-1) 569 | downsample_mask3 = self.downsample(mask, width, height) 570 | t3 = adain_fg(output3, style_feats[2], downsample_mask3) 571 | 572 | width = height = output4.size(-1) 573 | downsample_mask4 = self.downsample(mask, width, height) 574 | t4 = adain_fg(output4, style_feats[3], downsample_mask4) 575 | 576 | output1 = self.updec_1(t1) 577 | output2 = self.updec_2(t2) 578 | output3 = self.updec_3(t3) 579 | output4 = self.updec_4(t4) 580 | 581 | 582 | coarse_output1 = self.final_conv_1(output1) 583 | # final_output1 = coarse_output1 584 | blend_mask1 = mask 585 | final_output1 = coarse_output1 * blend_mask1 + style * (1 - blend_mask1) 586 | 587 | output2 = self.fuseblk_2(output1.detach(),output2) 588 | coarse_output2 = self.final_conv_2(output2) 589 | # final_output2 = coarse_output2 590 | blend_mask2 = mask 591 | final_output2 = coarse_output2 * blend_mask2 + style * (1 - blend_mask2) 592 | 593 | output3 = self.fuseblk_3(output2.detach(),output3) 594 | coarse_output3 = self.final_conv_3(output3) 595 | # final_output3 = coarse_output3 596 | blend_mask3 = mask 597 | final_output3 = coarse_output3 * blend_mask3 + style * (1 - blend_mask3) 598 | 599 | output4 = self.fuseblk_4(output3.detach(),output4) 600 | coarse_output4 = self.final_conv_4(output4) 601 | # final_output4 = coarse_output4 602 | blend_mask4 = mask 603 | final_output4 = coarse_output4 * blend_mask4 + style * (1 - blend_mask4) 604 | 605 | 606 | loss_c_1,loss_s_1 = self.cal_loss(coarse_output1,comb_feats,style_feats,mask, 1) 607 | loss_c_2,loss_s_2 = self.cal_loss(coarse_output2,comb_feats,style_feats,mask, 2) 608 | loss_c_3,loss_s_3 = self.cal_loss(coarse_output3,comb_feats,style_feats,mask, 3) 609 | loss_c_4,loss_s_4 = 
self.cal_loss(coarse_output4,comb_feats,style_feats,mask, 4) 610 | 611 | loss_c = loss_c_1 + loss_c_2 + loss_c_3 + loss_c_4 612 | loss_s = loss_s_1 + loss_s_2 + loss_s_3 + loss_s_4 613 | 614 | 615 | loss_c_1,loss_s_1 = self.cal_loss(final_output1,comb_feats,style_feats,mask, 1) 616 | loss_c_2,loss_s_2 = self.cal_loss(final_output2,comb_feats,style_feats,mask, 2) 617 | loss_c_3,loss_s_3 = self.cal_loss(final_output3,comb_feats,style_feats,mask, 3) 618 | loss_c_4,loss_s_4 = self.cal_loss(final_output4,comb_feats,style_feats,mask, 4) 619 | 620 | loss_c += loss_c_1 + loss_c_2 + loss_c_3 + loss_c_4 621 | loss_s += loss_s_1 + loss_s_2 + loss_s_3 + loss_s_4 622 | 623 | 624 | # print(output1.shape, output2.shape, output3.shape, output4.shape) 625 | return final_output1, final_output2, final_output3, final_output4, \ 626 | coarse_output1, coarse_output2, coarse_output3, coarse_output4, \ 627 | blend_mask1*2-1, blend_mask2*2-1, blend_mask3*2-1, blend_mask4*2-1, \ 628 | loss_c, loss_s 629 | 630 | 631 | """ 632 | discriminator 633 | """ 634 | 635 | class ConvBlock_D(nn.Module): 636 | def __init__( 637 | self, 638 | in_channels, out_channels, 639 | kernel_size=4, stride=2, padding=1, 640 | norm_layer=nn.BatchNorm2d, activation=nn.ELU, 641 | bias=True, 642 | ): 643 | super(ConvBlock_D, self).__init__() 644 | 645 | self.block = nn.Sequential( 646 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), 647 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 648 | nn.LeakyReLU(0.2, True), 649 | ) 650 | def forward(self, x): 651 | return self.block(x) 652 | 653 | 654 | class ConvEncoder_D(nn.Module): 655 | def __init__( 656 | self, 657 | depth, ch, patch_number, 658 | norm_layer, batchnorm_from, max_channels 659 | ): 660 | super(ConvEncoder_D, self).__init__() 661 | self.depth = depth 662 | self.patch_number = patch_number 663 | 664 | in_channels = 3 665 | out_channels = ch 666 | 667 | self.block0 = 
ConvBlock_D(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None) 668 | self.block1 = ConvBlock_D(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None) 669 | self.blocks_enc = nn.ModuleDict() 670 | for block_i in range(2, depth-2): 671 | if block_i % 2: 672 | in_channels = out_channels 673 | else: 674 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 675 | self.blocks_enc[f'block{block_i}'] = ConvBlock_D( 676 | in_channels, out_channels, 677 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, 678 | padding=1 #int(block_i < depth - 1) 679 | ) 680 | 681 | self.blocks_connected = nn.ModuleDict() 682 | 683 | for block_i in range(depth - 2, depth): 684 | if block_i % 2: 685 | in_channels = out_channels 686 | else: 687 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 688 | self.blocks_connected[f'block{block_i}'] = ConvBlock_D( 689 | in_channels, out_channels, 690 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, 691 | kernel_size=3, stride=1, padding=int(block_i < depth - 1) 692 | ) 693 | self.inner_channels = out_channels 694 | 695 | def forward(self, x): 696 | x = self.block0(x) 697 | x = self.block1(x) 698 | 699 | for block_i in range(2, self.depth - 2): 700 | block = self.blocks_enc[f'block{block_i}'] 701 | x = block(x) 702 | 703 | output = x 704 | 705 | for block_i in range(self.depth - 2, self.depth): 706 | block = self.blocks_connected[f'block{block_i}'] 707 | output = block(output) 708 | 709 | return output 710 | 711 | 712 | class DeconvDecoder_D(nn.Module): 713 | def __init__(self, depth, encoder_innner_channels, norm_layer): 714 | super(DeconvDecoder_D, self).__init__() 715 | self.deconv_blocks = nn.ModuleList() 716 | 717 | in_channels = encoder_innner_channels 718 | self.deconv_block0 = nn.Sequential( 719 | nn.UpsamplingNearest2d(scale_factor=2), 720 | #nn.ReflectionPad2d(1), 721 | 
nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1), 722 | norm_layer(in_channels // 2) if norm_layer is not None else nn.Identity(), 723 | nn.ReLU(True), 724 | ) 725 | self.deconv_block1 = nn.Sequential( 726 | #nn.UpsamplingNearest2d(scale_factor=2), 727 | #nn.ReflectionPad2d(1), 728 | nn.Conv2d(in_channels // 2, in_channels // 4, kernel_size=3, stride=1, padding=1), 729 | norm_layer(in_channels // 4) if norm_layer is not None else nn.Identity(), 730 | nn.ReLU(True), 731 | ) 732 | 733 | self.to_binary = nn.Conv2d(in_channels // 4, 1, kernel_size=1) 734 | 735 | def forward(self, encoder_outputs, image): 736 | 737 | output = self.deconv_block0(encoder_outputs) 738 | output = self.deconv_block1(output) 739 | 740 | output = self.to_binary(output) 741 | 742 | return output 743 | 744 | 745 | class ConvDiscriminator(nn.Module): 746 | def __init__( 747 | self, 748 | depth, patch_number, 749 | norm_layer=nn.BatchNorm2d, batchnorm_from=0, 750 | ch=64, max_channels=512 751 | ): 752 | super(ConvDiscriminator, self).__init__() 753 | self.depth = depth 754 | self.patch_number = patch_number 755 | self.encoder = ConvEncoder_D(depth, ch, patch_number, norm_layer, batchnorm_from, max_channels) 756 | self.decoder = DeconvDecoder_D(2, self.encoder.inner_channels, norm_layer) 757 | 758 | def forward(self, image): 759 | intermediates = self.encoder(image) 760 | output = self.decoder(intermediates, image) 761 | return output -------------------------------------------------------------------------------- /models/vgg19hrnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from collections import OrderedDict 4 | from . 
import networks
import torch.nn.functional as F
from torch import nn, cuda
from torch.autograd import Variable
import time
import numpy as np
from util import util
import os, cv2
import itertools
from PIL import Image


class VGG19HRNetModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.opt = opt
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G', 'c', 's', 'G_GAN', 'D', 'D_fake', 'D_real']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        self.visual_names = ['comp', 'mask_vis', 'style','final_output_1_supervised','final_output_2_supervised','final_output_3_supervised','final_output_4_supervised','final_output_1','final_output_2','final_output_3','final_output_4']

        if self.isTrain:
            self.model_names = ['G', 'D']
        else:
            self.model_names = ['G']
        # define networks (both generator and discriminator)
        self.netvgg = networks.vgg
        self.netvgg.load_state_dict(torch.load(opt.vgg))
        self.netvgg = nn.Sequential(*list(self.netvgg.children())[:31])


        if opt.netG == 'vgg19hrnet':
            self.netG = networks.VGG19HRNet(self.netvgg)
        else:
            # report the unrecognized generator name (the message previously printed opt.netD)
            raise NotImplementedError(f'{opt.netG} not implemented')

        if len(self.gpu_ids) > 0:
            assert(torch.cuda.is_available())
            self.netG.to(self.gpu_ids[0])
            self.netG = torch.nn.DataParallel(self.netG, self.gpu_ids)

        if self.isTrain:
            # define loss functions
            self.criterionGAN = nn.MSELoss().to(self.device)
            if opt.netD == 'conv':
                netD = networks.ConvDiscriminator(depth=8, patch_number=opt.patch_number, batchnorm_from=0)
            else:
                raise NotImplementedError(f'{opt.netD} not implemented')
            self.netD = networks.init_net(netD, opt.init_type, opt.init_gain, self.gpu_ids)
            # initialize optimizers; schedulers will be automatically
            # created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()), lr=opt.lr*opt.g_lr_ratio, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr*opt.d_lr_ratio, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)


    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.
        """
        self.content = input['content'].to(self.device)
        self.style = input['style'].to(self.device)
        self.comp = input['comp'].to(self.device)
        self.mask_vis = input['mask'].to(self.device)
        # map the mask from [-1, 1] to [0, 1]
        self.mask = self.mask_vis/2+0.5
        if self.isTrain:
            self.mask_patch = input['mask_patch'].to(self.device)

    def forward(self):
        """Employ generator to generate the output, and calculate the losses for generator"""

        self.final_output_1, self.final_output_2, self.final_output_3,self.final_output_4, \
        self.coarse_output_1, self.coarse_output_2, self.coarse_output_3,self.coarse_output_4, \
        self.blend_mask1, self.blend_mask2, self.blend_mask3, self.blend_mask4, \
        self.loss_c, self.loss_s = self.netG(self.comp, self.style, self.mask)

        self.output = self.final_output_4

    def backward_D_1(self):
        """Calculate GAN loss for the discriminator"""
        # Fake;
        fake_AB = self.final_output_1
        pred_fake = self.netD(fake_AB.detach())
        pred_comp = self.netD(self.comp)
        output_fake = self.criterionGAN(pred_fake, self.mask_patch)
        composite_fake = self.criterionGAN(pred_comp, self.mask_patch)
        loss_D_fake = output_fake + composite_fake

        # Real
        real_AB = self.style
        pred_real = self.netD(real_AB)
        loss_D_real = self.criterionGAN(pred_real,
torch.zeros(self.mask_patch.size()).cuda()) 98 | 99 | # combine loss and calculate gradients 100 | self.loss_D_1 = loss_D_fake + loss_D_real 101 | self.loss_D_1.backward(retain_graph=True) 102 | 103 | def backward_D_2(self): 104 | """Calculate GAN loss for the discriminator""" 105 | # Fake; 106 | fake_AB = self.final_output_2 107 | pred_fake = self.netD(fake_AB.detach()) 108 | pred_comp = self.netD(self.comp) 109 | output_fake = self.criterionGAN(pred_fake, self.mask_patch) 110 | composite_fake = self.criterionGAN(pred_comp, self.mask_patch) 111 | loss_D_fake = output_fake + composite_fake 112 | 113 | # Real 114 | real_AB = self.style 115 | pred_real = self.netD(real_AB) 116 | loss_D_real = self.criterionGAN(pred_real, torch.zeros(self.mask_patch.size()).cuda()) 117 | 118 | # combine loss and calculate gradients 119 | self.loss_D_2 = loss_D_fake + loss_D_real 120 | self.loss_D_2.backward(retain_graph=True) 121 | 122 | def backward_D_3(self): 123 | """Calculate GAN loss for the discriminator""" 124 | # Fake; 125 | fake_AB = self.final_output_3 126 | pred_fake = self.netD(fake_AB.detach()) 127 | pred_comp = self.netD(self.comp) 128 | output_fake = self.criterionGAN(pred_fake, self.mask_patch) 129 | composite_fake = self.criterionGAN(pred_comp, self.mask_patch) 130 | loss_D_fake = output_fake + composite_fake 131 | 132 | # Real 133 | real_AB = self.style 134 | pred_real = self.netD(real_AB) 135 | loss_D_real = self.criterionGAN(pred_real, torch.zeros(self.mask_patch.size()).cuda()) 136 | 137 | # combine loss and calculate gradients 138 | self.loss_D_3 = loss_D_fake + loss_D_real 139 | self.loss_D_3.backward(retain_graph=True) 140 | 141 | def backward_D(self): 142 | """Calculate GAN loss for the discriminator""" 143 | # Fake; 144 | fake_AB = self.output 145 | self.pred_fake = self.netD(fake_AB.detach()) 146 | self.pred_comp = self.netD(self.comp) 147 | output_fake = self.criterionGAN(self.pred_fake, self.mask_patch) 148 | composite_fake = 
            self.criterionGAN(self.pred_comp, self.mask_patch)
        self.loss_D_fake = output_fake + composite_fake

        # Real
        real_AB = self.style
        self.pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, torch.zeros(self.mask_patch.size()).cuda())

        # combine loss and calculate gradients
        self.loss_D = self.loss_D_fake + self.loss_D_real
        self.loss_D.backward(retain_graph=True)



    def backward_G(self):
        """Calculate GAN and other losses for the generator"""
        # GAN loss
        # fake_AB = self.output
        self.pred_fake_G = self.netD(self.final_output_1) + self.netD(self.final_output_2) + self.netD(self.final_output_3) + self.netD(self.final_output_4)
        self.loss_G_GAN = self.criterionGAN(self.pred_fake_G, torch.zeros(self.mask_patch.size()).cuda())

        self.loss_G = self.opt.lambda_content * self.loss_c + self.opt.lambda_style * self.loss_s + self.opt.lambda_g * self.loss_G_GAN
        print(f'g {self.loss_G.item()},c {self.loss_c.item()}, s {self.loss_s.item()}, gan {self.loss_G_GAN.item()}')

        self.loss_G.backward(retain_graph=True)

    def optimize_parameters(self):
        """Optimize both G and D; only run this in the training phase"""
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)  # enable backprop for D
        self.optimizer_D.zero_grad()     # set D's gradients to zero
        self.backward_D()                # calculate gradients for D on final_output_4
        self.backward_D_1()              # accumulate gradients for D on final_output_1
        self.backward_D_2()              # accumulate gradients for D on final_output_2
        self.backward_D_3()              # accumulate gradients for D on final_output_3
        self.optimizer_D.step()          # update D's weights
        # update G
        self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
        self.optimizer_G.zero_grad()        # set G's gradients to zero
        self.backward_G()                   # calculate gradients for G
        self.optimizer_G.step()             # update G's weights


    def get_current_visuals(self):
        num =
self.style.size(0) 194 | visual_ret = OrderedDict() 195 | all =[] 196 | for i in range(0,num): 197 | row=[] 198 | for name in self.visual_names: 199 | if isinstance(name, str): 200 | if hasattr(self,name): 201 | im = util.tensor2im(getattr(self, name).data[i:i+1,:,:,:]) 202 | row.append(im) 203 | row=tuple(row) 204 | all.append(np.hstack(row)) 205 | all = tuple(all) 206 | 207 | allim = np.vstack(all) 208 | return OrderedDict([(self.opt.name,allim)]) 209 | 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/options/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/options/__pycache__/base_options.cpython-39.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/options/__pycache__/test_options.cpython-39.pyc 
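The `options` package docstring above describes three layers: basic options shared by both phases, plus training-only and test-only options. As a rough illustration of that layering, here is a minimal sketch built on `argparse`. The class names `SketchBaseOptions`, `SketchTrainOptions`, and `SketchTestOptions`, and the `--lr` and `--results_dir` flags, are hypothetical stand-ins chosen for this example; they are assumptions, not the repository's actual definitions.

```python
import argparse


class SketchBaseOptions:
    """Hypothetical base class holding options shared by training and test."""

    is_train = None  # each subclass sets the phase

    def initialize(self, parser):
        # options common to both phases
        parser.add_argument('--name', type=str, default='experiment_name')
        parser.add_argument('--batch_size', type=int, default=1)
        return parser

    def parse(self, argv=None):
        parser = argparse.ArgumentParser()
        parser = self.initialize(parser)
        opt = parser.parse_args(argv)
        opt.isTrain = self.is_train  # record the phase on the parsed namespace
        return opt


class SketchTrainOptions(SketchBaseOptions):
    is_train = True

    def initialize(self, parser):
        parser = super().initialize(parser)
        # assumed training-only flag, for illustration
        parser.add_argument('--lr', type=float, default=0.0002)
        return parser


class SketchTestOptions(SketchBaseOptions):
    is_train = False

    def initialize(self, parser):
        parser = super().initialize(parser)
        # assumed test-only flag, for illustration
        parser.add_argument('--results_dir', type=str, default='./results/')
        return parser


train_opt = SketchTrainOptions().parse(['--name', 'demo', '--batch_size', '4'])
test_opt = SketchTestOptions().parse([])
```

Each subclass extends the shared parser instead of redefining it, so a flag added to the base class is available in both phases while phase-specific flags stay isolated.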
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
import argparse
import os
from util import util
import torch
import models
import data


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataset_root',type=str, default='here', help='path to iHarmony4 dataset') #mia
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # note: argparse's type=bool treats any non-empty string (even 'False') as True
        parser.add_argument('--is_train', type=bool, default=True, help='train mode')
        # model parameters
        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use.
[cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=4, help='# of input image channels: 4 for concatenated comp and mask') #mia
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--normD', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--normG', type=str, default='RAN_Method1', help='Regional Adaptive Normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        parser.add_argument('--patch_number', type=int, default=4, help='number of patches in discriminator output')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='iharmony4', help='load iHarmony4 dataset') #mia
        parser.add_argument('--content_dir', type=str, default='/MS-COCO/',
help='Directory path to a batch of content images') 46 | parser.add_argument('--style_dir', type=str, default='/wikiart/', help='Directory path to a batch of style images') 47 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 48 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 49 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 50 | parser.add_argument('--load_size', type=int, default=256, help='scale images to this size') 51 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 52 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 53 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 54 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 55 | # additional parameters 56 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 57 | parser.add_argument('--load_iter', type=float, default=0, help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 58 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 59 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 60 | parser.add_argument('--vgg', type=str, default='../checkpoints/pretrained/vgg_normalised.pth') 61 | parser.add_argument('--display_ncols', type=int, default=4, 62 | help='if positive, display all images in a single visdom web panel with certain number of images per row.') 63 | parser.add_argument('--g_lr_ratio', type=float, default=1, help='a ratio for changing learning rate of generator') #mia 64 | parser.add_argument('--d_lr_ratio', type=float, default=1, help='a ratio for changing learning rate of discriminator') #mia 65 | 66 | self.initialized = True 67 | return parser 68 | 69 | def gather_options(self): 70 | """Initialize our parser with basic options(only once). 71 | Add additional model-specific and dataset-specific options. 72 | These options are defined in the function 73 | in model and dataset classes. 
74 | """ 75 | if not self.initialized: # check if it has been initialized 76 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 77 | parser = self.initialize(parser) 78 | 79 | # get the basic options 80 | opt, _ = parser.parse_known_args() 81 | 82 | # modify model-related parser options 83 | model_name = opt.model 84 | model_option_setter = models.get_option_setter(model_name) 85 | parser = model_option_setter(parser, self.isTrain) 86 | opt, _ = parser.parse_known_args() # parse again with new defaults 87 | 88 | # save and return the parser 89 | self.parser = parser 90 | return parser.parse_args() 91 | 92 | def print_options(self, opt): 93 | """Print and save options 94 | 95 | It will print both current options and default values(if different). 96 | It will save options into a text file / [checkpoints_dir] / opt.txt 97 | """ 98 | message = '' 99 | message += '----------------- Options ---------------\n' 100 | for k, v in sorted(vars(opt).items()): 101 | comment = '' 102 | default = self.parser.get_default(k) 103 | if v != default: 104 | comment = '\t[default: %s]' % str(default) 105 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 106 | message += '----------------- End -------------------' 107 | print(message) 108 | 109 | # save to the disk 110 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 111 | util.mkdirs(expr_dir) 112 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 113 | with open(file_name, 'wt') as opt_file: 114 | opt_file.write(message) 115 | opt_file.write('\n') 116 | 117 | def parse(self): 118 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 119 | opt = self.gather_options() 120 | opt.isTrain = self.isTrain # train or test 121 | 122 | # process opt.suffix 123 | if opt.suffix: 124 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 125 | opt.name = opt.name + suffix 126 | 127 | self.print_options(opt) 128 | 
129 | # set gpu ids 130 | str_ids = opt.gpu_ids.split(',') 131 | opt.gpu_ids = [] 132 | for str_id in str_ids: 133 | id = int(str_id) 134 | if id >= 0: 135 | opt.gpu_ids.append(id) 136 | if len(opt.gpu_ids) > 0: 137 | torch.cuda.set_device(opt.gpu_ids[0]) 138 | 139 | self.opt = opt 140 | return self.opt 141 | 142 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 15 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 16 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 17 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 18 | parser.add_argument('--update_html_freq', type=int, default=500, help='frequency of saving training results to html') 19 | parser.add_argument('--print_freq', type=int, default=300, help='frequency of showing training results on console') 20 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 21 | # network saving and loading parameters 22 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 23 | parser.add_argument('--save_epoch_freq', type=int, 
default=1, help='frequency of saving checkpoints at the end of epochs') 24 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 25 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 26 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 27 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 28 | # training parameters 29 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 30 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 31 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 32 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 33 | 34 | parser.add_argument('--gan_mode', type=str, default='wgangp', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 35 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 36 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') 37 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 38 | 39 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 40 | 41 | 42 | self.isTrain = True 43 | return parser 44 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 15 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 16 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 17 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 18 | parser.add_argument('--update_html_freq', type=int, default=500, help='frequency of saving training results to html') 19 | parser.add_argument('--print_freq', type=int, default=300, help='frequency of showing training results on console') 20 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 21 | # network saving and loading parameters 22 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 23 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of 
saving checkpoints at the end of epochs') 24 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 25 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 26 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 27 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 28 | # training parameters 29 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 30 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 31 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 32 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 33 | 34 | parser.add_argument('--gan_mode', type=str, default='wgangp', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 35 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 36 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') 37 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 38 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 39 | # hyper-parameters 40 | parser.add_argument('--lambda_g', type=float, default=1.0, help='weight for adversarial loss') 41 | parser.add_argument('--lambda_gf', type=float, default=1.0, help='weight for adversarial loss') 42 | parser.add_argument('--lambda_style', type=float, default=1.0, help='weight for style loss') 43 | parser.add_argument('--lambda_content', type=float, default=1.0, help='weight for content loss') 44 | parser.add_argument('--lambda_tv', type=float, default=1e-5, help='weight for smooth loss') 45 | 46 | self.isTrain = True 47 | return parser 48 | -------------------------------------------------------------------------------- /output/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/1.jpg -------------------------------------------------------------------------------- /output/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/2.jpg -------------------------------------------------------------------------------- /output/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/3.jpg -------------------------------------------------------------------------------- /output/4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/4.jpg -------------------------------------------------------------------------------- /output/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/5.jpg -------------------------------------------------------------------------------- /output/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/6.jpg -------------------------------------------------------------------------------- /output/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/7.jpg -------------------------------------------------------------------------------- /output/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/output/8.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dominate==2.6.0 2 | numpy==1.19.5 3 | opencv_python==4.9.0.80 4 | Pillow==10.2.0 5 | PyYAML==6.0.1 6 | scikit-image==0.17.2 7 | tensorboardX==2.6.2.2 8 | tqdm==4.64.1 9 | visdom==0.1.8.9 10 | yacs==0.1.8 11 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DISPLAY_PORT=8097 3 | 4 | 
G='vgg19hrnet' 5 | D='conv' 6 | model_name=vgg19hrnet 7 | loadSize=256 8 | 9 | #hyper-parameters 10 | gpu_id=0 11 | batchs=1 12 | load_iter=0 13 | 14 | # network design 15 | patch_num=4 16 | 17 | test_epoch=latest 18 | datasetmode=cocoart 19 | content_dir="../examples/" 20 | style_dir="../examples/" 21 | 22 | NAME="pretrained" 23 | 24 | checkpoint="../checkpoints/" 25 | 26 | 27 | CMD="python ../test.py \ 28 | --name $NAME \ 29 | --checkpoints_dir $checkpoint \ 30 | --model $model_name \ 31 | --netG $G \ 32 | --netD $D \ 33 | --dataset_mode $datasetmode \ 34 | --content_dir $content_dir \ 35 | --style_dir $style_dir \ 36 | --is_train 0 \ 37 | --display_id 0 \ 38 | --normD batch \ 39 | --normG batch \ 40 | --preprocess none \ 41 | --niter 100 \ 42 | --niter_decay 100 \ 43 | --input_nc 3 \ 44 | --batch_size $batchs \ 45 | --num_threads 6 \ 46 | --print_freq 1000 \ 47 | --display_freq 1 \ 48 | --save_latest_freq 1000 \ 49 | --gpu_ids $gpu_id \ 50 | --epoch $test_epoch \ 51 | " 52 | echo $CMD 53 | eval $CMD 54 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DISPLAY_PORT=8097 3 | 4 | G='vgg19hrnet' 5 | D='conv' 6 | model_name=vgg19hrnet 7 | loadSize=256 8 | 9 | #hyper-parameters 10 | gpu_id=0 11 | L_S=1 12 | L_C=1 13 | L_GAN=1 14 | L_tv=1e-5 15 | lr=2e-4 16 | batchs=4 17 | load_iter=0 18 | 19 | 20 | datasetmode=cocoart 21 | content_dir="../datasets/painterly/MS-COCO/" 22 | style_dir="../datasets/painterly/wikiart/" 23 | 24 | NAME="${G}_Content${L_C}_Style${L_S}_Ggan${L_GAN}_lr${lr}_batch${batchs}" 25 | 26 | checkpoint="../checkpoints/" 27 | 28 | 29 | CMD="python ../train.py \ 30 | --name $NAME \ 31 | --checkpoints_dir $checkpoint \ 32 | --model $model_name \ 33 | --netG $G \ 34 | --netD $D \ 35 | --dataset_mode $datasetmode \ 36 | --content_dir $content_dir \ 37 | --style_dir $style_dir \ 38 | --is_train 1 \ 39 
| --display_id 0 \ 40 | --normD batch \ 41 | --normG batch \ 42 | --preprocess none \ 43 | --niter 20 \ 44 | --niter_decay 10 \ 45 | --input_nc 3 \ 46 | --batch_size $batchs \ 47 | --num_threads 6 \ 48 | --print_freq 500 \ 49 | --display_freq 500 \ 50 | --save_latest_freq 1000 \ 51 | --patch_number 4 \ 52 | --gpu_ids $gpu_id \ 53 | --lambda_g $L_GAN \ 54 | --lambda_style $L_S \ 55 | --lambda_content $L_C \ 56 | --lambda_tv $L_tv \ 57 | --lr $lr \ 58 | --load_iter $load_iter \ 59 | --continue_train \ 60 | --epoch latest \ 61 | 62 | 63 | " 64 | echo $CMD 65 | eval $CMD 66 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import realpath 3 | from options.test_options import TestOptions 4 | import torch 5 | import numpy as np 6 | from util import util 7 | from util.visualizer import Visualizer 8 | from PIL import Image, ImageFile 9 | from data import CustomDataset 10 | from models import create_model 11 | import os 12 | from tqdm import tqdm 13 | import time 14 | 15 | 16 | opt = TestOptions().parse() # parse test options 17 | opt.isTrain = False 18 | 19 | visualizer = Visualizer(opt,'test') 20 | test_dataset = CustomDataset(opt, is_for_train=False) 21 | test_dataset_size = len(test_dataset) 22 | print('The number of testing images = %d' % test_dataset_size) 23 | test_dataloader = test_dataset.load_data() 24 | 25 | model = create_model(opt) # create a model given opt.model and other options 26 | model.setup(opt) # regular setup: load and print networks; create schedulers 27 | model.netG.eval() # switch the generator to evaluation mode for inference 28 | 29 | total_iters = 0 30 | iter_data_time = time.time() 31 | for i, data in enumerate(tqdm(test_dataloader)): # iterate over the test set 32 | img_path = data['img_path'][0] 33 | 34 | 35 | save_dir = '../output/' 36 | img_name = img_path.split('/')[-1] 37 | 38 | model.set_input(data) # unpack data from dataset 39 |
model.forward() # run the forward pass only (inference: no loss computation or weight update) 40 | visual_dict = model.get_current_visuals() 41 | print('saving iteration {}'.format(i)) 42 | 43 | #visualizer.display_current_results(visual_dict, total_iters) 44 | visualizer.save_images(visual_dict, save_dir, img_name) 45 | 46 | 47 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import CustomDataset 4 | from models import create_model 5 | # from torch.utils.tensorboard import SummaryWriter 6 | from tensorboardX import SummaryWriter 7 | 8 | import os 9 | from util import util 10 | import numpy as np 11 | import torch 12 | from skimage.metrics import mean_squared_error 13 | from skimage.metrics import peak_signal_noise_ratio 14 | from tqdm import tqdm 15 | 16 | from util.visualizer import Visualizer 17 | from PIL import Image, ImageFile 18 | 19 | 20 | Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError 21 | # Disable OSError: image file is truncated 22 | ImageFile.LOAD_TRUNCATED_IMAGES = True 23 | 24 | def setup_seed(seed): 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | 29 | 30 | if __name__ == '__main__': 31 | # setup_seed(6) 32 | 33 | opt = TrainOptions().parse() # parse training options 34 | visualizer = Visualizer(opt,'train') 35 | visualizer_test = Visualizer(opt,'test') 36 | 37 | train_dataset = CustomDataset(opt, is_for_train=True) 38 | # test_dataset = CustomDataset(opt, is_for_train=False) 39 | 40 | train_dataset_size = len(train_dataset) # get the number of images in the dataset. 
41 | # test_dataset_size = len(test_dataset) 42 | print('The number of training images = %d' % train_dataset_size) 43 | # print('The number of testing images = %d' % test_dataset_size) 44 | 45 | train_dataloader = train_dataset.load_data() 46 | # test_dataloader = test_dataset.load_data() 47 | print('The total batches of training images = %d' % len(train_dataset.dataloader)) 48 | 49 | model = create_model(opt) # create a model given opt.model and other options 50 | model.setup(opt) # regular setup: load and print networks; create schedulers 51 | total_iters = 0 # the total number of training iterations 52 | writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name)) 53 | 54 | for epoch in tqdm(range(int(opt.load_iter)+1, opt.niter + opt.niter_decay + 1)): 55 | epoch_start_time = time.time() # timer for entire epoch 56 | iter_data_time = time.time() # timer for data loading per iteration 57 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 58 | for i, data in enumerate(tqdm(train_dataloader)): # inner loop within one epoch 59 | if i > len(train_dataset.dataloader) -2: 60 | continue 61 | # print('epoch {} iter {}'.format(epoch, i)) 62 | iter_start_time = time.time() # timer for computation per iteration 63 | if total_iters % opt.print_freq == 0: 64 | t_data = iter_start_time - iter_data_time 65 | total_iters += 1 66 | epoch_iter += 1 67 | model.set_input(data) # unpack data from dataset 68 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 69 | 70 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 71 | losses = model.get_current_losses() 72 | t_comp = (time.time() - iter_start_time) / opt.batch_size 73 | writer.add_scalar('./loss/loss_content', losses['c'], i + 1) 74 | writer.add_scalar('./loss/loss_style', losses['s'], i + 1) 75 | # writer.add_scalar('./loss/loss_tv', losses['tv'], i + 1) 76 | 77 | if total_iters % 
opt.display_freq == 0: 78 | visual_dict = model.get_current_visuals() 79 | visualizer.display_current_results(visual_dict, epoch) 80 | 81 | 82 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every opt.save_latest_freq iterations 83 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 84 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 85 | model.save_networks(save_suffix) 86 | 87 | iter_data_time = time.time() 88 | 89 | 90 | torch.cuda.empty_cache() 91 | if epoch % opt.save_epoch_freq == 0: # cache our model every opt.save_epoch_freq epochs 92 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 93 | model.save_networks('latest') 94 | model.save_networks('%d' % epoch) 95 | 96 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 97 | model.update_learning_rate() # update learning rates at the end of every epoch. 98 | for scheduler in model.schedulers: 99 | print('Current learning rate: {}'.format(scheduler.get_last_lr())) 100 | 101 | writer.close() 102 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/util/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /util/__pycache__/html.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/util/__pycache__/html.cpython-39.pyc -------------------------------------------------------------------------------- /util/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/util/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /util/__pycache__/visualizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/ProPIH-Painterly-Image-Harmonization/e53f85c9bf67fad4c0ab463f8dc340536a083d59/util/__pycache__/visualizer.cpython-39.pyc -------------------------------------------------------------------------------- /util/config.py: -------------------------------------------------------------------------------- 1 | # This config file is inspired by https://github.com/Yadiraf/DECA/decalib/utils/config.py 2 | # Testing config for our RainNet 3 | 4 | from yacs.config import CfgNode as CN 5 | import argparse 6 | import yaml 7 | import os 8 | 9 | cfg = CN() 10 | # ------ dataloader ------------- 11 | cfg.dataset_root = '../dataset/iHarmony4' 12 | cfg.dataset_mode = 'iharmony4' 13 | cfg.batch_size = 10 14 | cfg.beta1 = 0.5 15 | cfg.checkpoints_dir = './checkpoints' 16 | cfg.crop_size = 256 17 | cfg.load_size = 256 18 | cfg.num_threads = 11 19 | cfg.preprocess = 'none' # 20 | # ------ model ------------- 21 | cfg.gan_mode = 'wgangp' 22 | cfg.model = 'rainnet' 23 | cfg.netG = 'rainnet' 24 | cfg.normD = 'instance' 25 | cfg.normG = 'RAIN' 26 | cfg.is_train = False 27 | cfg.input_nc = 3 28 | cfg.output_nc = 3 29 | cfg.ngf = 64 30 | cfg.no_dropout = False 31 | # ------ training ------------- 32 | cfg.name = 'experiment_train' 33 | cfg.gpu_ids = 0 34 
| cfg.lambda_L1 = 100 35 | cfg.print_freq = 400 36 | cfg.continue_train = False 37 | cfg.load_iter = 0 38 | cfg.niter = 100 39 | cfg.niter_decay = 0 40 | 41 | 42 | def get_cfg_defaults(): 43 | """Get a yacs CfgNode object with default values for my_project.""" 44 | # Return a clone so that the defaults will not be altered 45 | # This is for the "local variable" use pattern 46 | return cfg.clone() 47 | 48 | def update_cfg(cfg, cfg_file): 49 | cfg.merge_from_file(cfg_file) 50 | return cfg.clone() 51 | 52 | def parse_args(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--cfg', type=str, help='cfg file path') 55 | 56 | args = parser.parse_args() 57 | print(args, end='\n\n') 58 | 59 | cfg = get_cfg_defaults() 60 | if args.cfg is not None: 61 | cfg_file = args.cfg 62 | cfg = update_cfg(cfg, args.cfg) 63 | cfg.cfg_file = cfg_file 64 | 65 | return cfg 66 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(reflesh)) # the HTML attribute value must be "refresh"; the `reflesh` parameter name is kept for existing callers 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | 
with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | By 50/100, the buffer will return input images. 
32 | By 50/100, the buffer will return images previously stored in the buffer, 33 | and insert the current images to the buffer. 34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /util/spectral_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Spectral Normalization from https://arxiv.org/abs/1802.05957 3 | """ 4 | import torch 5 | from torch.nn.functional import normalize 6 | 7 | 8 | class SpectralNorm(object): 9 | # Invariant before and after each forward call: 10 | # u = normalize(W @ v) 11 | # NB: At initialization, this invariant is not enforced 12 | 13 | _version = 1 14 | # At version 1: 15 | # made `W` not a buffer, 16 | # added `v` as a buffer, and 17 | # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. 
18 | 19 | def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): 20 | self.name = name 21 | self.dim = dim 22 | if n_power_iterations <= 0: 23 | raise ValueError('Expected n_power_iterations to be positive, but ' 24 | 'got n_power_iterations={}'.format(n_power_iterations)) 25 | self.n_power_iterations = n_power_iterations 26 | self.eps = eps 27 | 28 | def reshape_weight_to_matrix(self, weight): 29 | weight_mat = weight 30 | if self.dim != 0: 31 | # permute dim to front 32 | weight_mat = weight_mat.permute(self.dim, 33 | *[d for d in range(weight_mat.dim()) if d != self.dim]) 34 | height = weight_mat.size(0) 35 | return weight_mat.reshape(height, -1) 36 | 37 | def compute_weight(self, module, do_power_iteration): 38 | # NB: If `do_power_iteration` is set, the `u` and `v` vectors are 39 | # updated in power iteration **in-place**. This is very important 40 | # because in `DataParallel` forward, the vectors (being buffers) are 41 | # broadcast from the parallelized module to each module replica, 42 | # which is a new module object created on the fly. And each replica 43 | # runs its own spectral norm power iteration. So simply assigning 44 | # the updated vectors to the module this function runs on will cause 45 | # the update to be lost forever. And the next time the parallelized 46 | # module is replicated, the same randomly initialized vectors are 47 | # broadcast and used! 48 | # 49 | # Therefore, to make the change propagate back, we rely on two 50 | # important behaviors (also enforced via tests): 51 | # 1. `DataParallel` doesn't clone storage if the broadcast tensor 52 | # is already on correct device; and it makes sure that the 53 | # parallelized module is already on `device[0]`. 54 | # 2. If the out tensor in `out=` kwarg has correct shape, it will 55 | # just fill in the values. 
56 | # Therefore, since the same power iteration is performed on all 57 | # devices, simply updating the tensors in-place will make sure that 58 | # the module replica on `device[0]` will update the _u vector on the 59 | # parallelized module (by shared storage). 60 | # 61 | # However, after we update `u` and `v` in-place, we need to **clone** 62 | # them before using them to normalize the weight. This is to support 63 | # backpropagating through two forward passes, e.g., the common pattern in 64 | # GAN training: loss = D(real) - D(fake). Otherwise, the autograd engine will 65 | # complain that variables needed to do backward for the first forward 66 | # (i.e., the `u` and `v` vectors) are changed in the second forward. 67 | weight = getattr(module, self.name + '_orig') 68 | u = getattr(module, self.name + '_u') 69 | v = getattr(module, self.name + '_v') 70 | weight_mat = self.reshape_weight_to_matrix(weight) 71 | 72 | if do_power_iteration: 73 | with torch.no_grad(): 74 | for _ in range(self.n_power_iterations): 75 | # Spectral norm of weight equals `u^T W v`, where `u` and `v` 76 | # are the first left and right singular vectors. 77 | # This power iteration produces approximations of `u` and `v`. 
78 | v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v) 79 | u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) 80 | if self.n_power_iterations > 0: 81 | # See above on why we need to clone 82 | u = u.clone() 83 | v = v.clone() 84 | 85 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 86 | weight = weight / sigma 87 | return weight 88 | 89 | def remove(self, module): 90 | with torch.no_grad(): 91 | weight = self.compute_weight(module, do_power_iteration=False) 92 | delattr(module, self.name) 93 | delattr(module, self.name + '_u') 94 | delattr(module, self.name + '_v') 95 | delattr(module, self.name + '_orig') 96 | module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) 97 | 98 | def __call__(self, module, inputs): 99 | setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training)) 100 | 101 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma): 102 | # Tries to return a vector `v` s.t. `u = normalize(W @ v)` 103 | # (the invariant at top of this class) and `u @ W @ v = sigma`. 104 | # This uses pinverse in case W^T W is not invertible. 
105 | v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) 106 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) 107 | 108 | @staticmethod 109 | def apply(module, name, n_power_iterations, dim, eps): 110 | for k, hook in module._forward_pre_hooks.items(): 111 | if isinstance(hook, SpectralNorm) and hook.name == name: 112 | raise RuntimeError("Cannot register two spectral_norm hooks on " 113 | "the same parameter {}".format(name)) 114 | 115 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 116 | weight = module._parameters[name] 117 | 118 | with torch.no_grad(): 119 | weight_mat = fn.reshape_weight_to_matrix(weight) 120 | 121 | h, w = weight_mat.size() 122 | # randomly initialize `u` and `v` 123 | u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) 124 | v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) 125 | 126 | delattr(module, fn.name) 127 | module.register_parameter(fn.name + "_orig", weight) 128 | # We still need to assign weight back as fn.name because all sorts of 129 | # things may assume that it exists, e.g., when initializing weights. 130 | # However, we can't directly assign as it could be an nn.Parameter and 131 | # gets added as a parameter. Instead, we register weight.data as a plain 132 | # attribute. 133 | setattr(module, fn.name, weight.data) 134 | module.register_buffer(fn.name + "_u", u) 135 | module.register_buffer(fn.name + "_v", v) 136 | 137 | module.register_forward_pre_hook(fn) 138 | module._register_state_dict_hook(SpectralNormStateDictHook(fn)) 139 | module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) 140 | return fn 141 | 142 | 143 | # This is a top level class because Py2 pickle doesn't like inner class nor an 144 | # instancemethod. 145 | class SpectralNormLoadStateDictPreHook(object): 146 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 
147 | def __init__(self, fn): 148 | self.fn = fn 149 | 150 | # For state_dict with version None, (assuming that it has gone through at 151 | # least one training forward), we have 152 | # 153 | # u = normalize(W_orig @ v) 154 | # W = W_orig / sigma, where sigma = u @ W_orig @ v 155 | # 156 | # To compute `v`, we solve `W_orig @ x = u`, and let 157 | # v = x / (u @ W_orig @ x) * (W / W_orig). 158 | def __call__(self, state_dict, prefix, local_metadata, strict, 159 | missing_keys, unexpected_keys, error_msgs): 160 | fn = self.fn 161 | version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None) 162 | if version is None or version < 1: 163 | weight_key = prefix + fn.name 164 | if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \ 165 | weight_key not in state_dict: 166 | # Detect if it is the updated state dict and just missing metadata. 167 | # This could happen if the users are crafting a state dict themselves, 168 | # so we just pretend that this is the newest. 169 | return 170 | has_missing_keys = False 171 | for suffix in ('_orig', '', '_u'): 172 | key = weight_key + suffix 173 | if key not in state_dict: 174 | has_missing_keys = True 175 | if strict: 176 | missing_keys.append(key) 177 | if has_missing_keys: 178 | return 179 | with torch.no_grad(): 180 | weight_orig = state_dict[weight_key + '_orig'] 181 | weight = state_dict.pop(weight_key) 182 | sigma = (weight_orig / weight).mean() 183 | weight_mat = fn.reshape_weight_to_matrix(weight_orig) 184 | u = state_dict[weight_key + '_u'] 185 | v = fn._solve_v_and_rescale(weight_mat, u, sigma) 186 | state_dict[weight_key + '_v'] = v 187 | 188 | 189 | # This is a top level class because Py2 pickle doesn't like inner class nor an 190 | # instancemethod. 191 | class SpectralNormStateDictHook(object): 192 | # See docstring of SpectralNorm._version on the changes to spectral_norm. 
193 | def __init__(self, fn): 194 | self.fn = fn 195 | 196 | def __call__(self, module, state_dict, prefix, local_metadata): 197 | if 'spectral_norm' not in local_metadata: 198 | local_metadata['spectral_norm'] = {} 199 | key = self.fn.name + '.version' 200 | if key in local_metadata['spectral_norm']: 201 | raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key)) 202 | local_metadata['spectral_norm'][key] = self.fn._version 203 | 204 | 205 | def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): 206 | r"""Applies spectral normalization to a parameter in the given module. 207 | .. math:: 208 | \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, 209 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 210 | Spectral normalization stabilizes the training of discriminators (critics) 211 | in Generative Adversarial Networks (GANs) by rescaling the weight tensor 212 | with spectral norm :math:`\sigma` of the weight matrix calculated using 213 | power iteration method. If the dimension of the weight tensor is greater 214 | than 2, it is reshaped to 2D in power iteration method to get spectral 215 | norm. This is implemented via a hook that calculates spectral norm and 216 | rescales weight before every :meth:`~Module.forward` call. 217 | See `Spectral Normalization for Generative Adversarial Networks`_ . 218 | .. 
_`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 219 | Args: 220 | module (nn.Module): containing module 221 | name (str, optional): name of weight parameter 222 | n_power_iterations (int, optional): number of power iterations to 223 | calculate spectral norm 224 | eps (float, optional): epsilon for numerical stability in 225 | calculating norms 226 | dim (int, optional): dimension corresponding to number of outputs, 227 | the default is ``0``, except for modules that are instances of 228 | ConvTranspose{1,2,3}d, when it is ``1`` 229 | Returns: 230 | The original module with the spectral norm hook 231 | Example:: 232 | >>> m = spectral_norm(nn.Linear(20, 40)) 233 | >>> m 234 | Linear(in_features=20, out_features=40, bias=True) 235 | >>> m.weight_u.size() 236 | torch.Size([40]) 237 | """ 238 | if dim is None: 239 | if isinstance(module, (torch.nn.ConvTranspose1d, 240 | torch.nn.ConvTranspose2d, 241 | torch.nn.ConvTranspose3d)): 242 | dim = 1 243 | else: 244 | dim = 0 245 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 246 | return module 247 | 248 | 249 | def remove_spectral_norm(module, name='weight'): 250 | r"""Removes the spectral normalization reparameterization from a module. 
251 | Args: 252 | module (Module): containing module 253 | name (str, optional): name of weight parameter 254 | Example: 255 | >>> m = spectral_norm(nn.Linear(40, 10)) 256 | >>> remove_spectral_norm(m) 257 | """ 258 | for k, hook in module._forward_pre_hooks.items(): 259 | if isinstance(hook, SpectralNorm) and hook.name == name: 260 | hook.remove(module) 261 | del module._forward_pre_hooks[k] 262 | break 263 | else: 264 | raise ValueError("spectral_norm of '{}' not found in {}".format( 265 | name, module)) 266 | 267 | for k, hook in module._state_dict_hooks.items(): 268 | if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name: 269 | del module._state_dict_hooks[k] 270 | break 271 | 272 | for k, hook in module._load_state_dict_pre_hooks.items(): 273 | if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name: 274 | del module._load_state_dict_pre_hooks[k] 275 | break 276 | 277 | return module -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions.""" 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """Converts a Tensor array into a numpy image array. 
11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | #print(image_numpy.shape) 24 | #print(image_numpy) 25 | #image_numpy = image_numpy > 0.5 26 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 27 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling, tensor range: [-1, 1] 28 | #image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling, tensor range: [-1, 1] 29 | else: # if it is a numpy array, do nothing 30 | image_numpy = input_image 31 | image_numpy = np.clip(image_numpy, 0, 255) 32 | return image_numpy.astype(imtype) 33 | 34 | def tensor2mask(input_image, imtype=np.uint8): 35 | """Converts a Tensor mask into a numpy image array (scaled by 255, without clipping). 
36 | 37 | Parameters: 38 | input_image (tensor) -- the input image tensor array 39 | imtype (type) -- the desired type of the converted numpy array 40 | """ 41 | if not isinstance(input_image, np.ndarray): 42 | if isinstance(input_image, torch.Tensor): # get the data from a variable 43 | image_tensor = input_image.data 44 | else: 45 | return input_image 46 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 47 | if image_numpy.shape[0] == 1: # grayscale to RGB 48 | #print(image_numpy.shape) 49 | #print(image_numpy) 50 | #image_numpy = image_numpy > 0.5 51 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 52 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling 53 | else: # if it is a numpy array, do nothing 54 | image_numpy = input_image 55 | return image_numpy.astype(imtype) 56 | 57 | 58 | def diagnose_network(net, name='network'): 59 | """Calculate and print the mean of the average absolute gradients 60 | 61 | Parameters: 62 | net (torch network) -- Torch network 63 | name (str) -- the name of the network 64 | """ 65 | mean = 0.0 66 | count = 0 67 | for param in net.parameters(): 68 | if param.grad is not None: 69 | mean += torch.mean(torch.abs(param.grad.data)) 70 | count += 1 71 | if count > 0: 72 | mean = mean / count 73 | print(name) 74 | print(mean) 75 | 76 | 77 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 78 | """Save a numpy image to the disk 79 | 80 | Parameters: 81 | image_numpy (numpy array) -- input numpy array 82 | image_path (str) -- the path of the image 83 | """ 84 | 85 | image_pil = Image.fromarray(image_numpy) 86 | h, w, _ = image_numpy.shape 87 | 88 | if aspect_ratio > 1.0: 89 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 90 | if aspect_ratio < 1.0: 91 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 92 | image_pil.save(image_path,quality=100) #added by Mia (quality) 93 | 94 | 95 | 96 | def 
print_numpy(x, val=True, shp=False): 97 | """Print the mean, min, max, median, std, and size of a numpy array 98 | 99 | Parameters: 100 | val (bool) -- if print the values of the numpy array 101 | shp (bool) -- if print the shape of the numpy array 102 | """ 103 | x = x.astype(np.float64) 104 | if shp: 105 | print('shape,', x.shape) 106 | if val: 107 | x = x.flatten() 108 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 109 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 110 | 111 | 112 | def mkdirs(paths): 113 | """create empty directories if they don't exist 114 | 115 | Parameters: 116 | paths (str list) -- a list of directory paths 117 | """ 118 | if isinstance(paths, list) and not isinstance(paths, str): 119 | for path in paths: 120 | mkdir(path) 121 | else: 122 | mkdir(paths) 123 | 124 | 125 | def mkdir(path): 126 | """create a single empty directory if it didn't exist 127 | 128 | Parameters: 129 | path (str) -- a single directory path 130 | """ 131 | if not os.path.exists(path): 132 | os.makedirs(path) 133 | 134 | def copy_state_dict(cur_state_dict, pre_state_dict, prefix='', load_name=None): 135 | def _get_params(key): 136 | key = prefix + key 137 | if key in pre_state_dict: 138 | return pre_state_dict[key] 139 | return None 140 | for k in cur_state_dict.keys(): 141 | if load_name is not None: 142 | if load_name not in k: 143 | continue 144 | v = _get_params(k) 145 | try: 146 | if v is None: 147 | # print('parameter {} not found'.format(k)) 148 | continue 149 | cur_state_dict[k].copy_(v) 150 | except: 151 | # print('copy param {} failed'.format(k)) 152 | continue -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . 
import html 7 | from pdb import set_trace as st 8 | class Visualizer(): 9 | def __init__(self, opt, phase): 10 | # self.opt = opt 11 | self.display_id = opt.display_id 12 | 13 | self.use_html = True 14 | 15 | self.isTrain = opt.isTrain 16 | 17 | self.phase = phase 18 | 19 | 20 | self.win_size = opt.display_winsize 21 | # self.name = opt.name 22 | self.name = opt.checkpoints_dir.split('/')[-1] 23 | if self.display_id > 0: 24 | import visdom 25 | self.vis = visdom.Visdom(server = opt.display_server,port = opt.display_port) 26 | #self.ncols = opt.ncols 27 | self.ncols = opt.display_ncols 28 | if self.use_html: 29 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 30 | if self.phase == 'train': 31 | self.img_dir = os.path.join(self.web_dir, 'TrainImages') 32 | elif self.phase == 'test': 33 | self.img_dir = os.path.join(self.web_dir, 'TestImages/', 'epoch_' + opt.epoch) 34 | else: 35 | self.img_dir = os.path.join(self.web_dir, 'RealCompositeImages') 36 | 37 | print('images are stored in {}'.format(self.img_dir)) 38 | 39 | 40 | 41 | print('create web directory %s...' 
% self.web_dir) 42 | util.mkdirs([self.web_dir, self.img_dir]) 43 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 44 | self.log_name_test = os.path.join(opt.checkpoints_dir, opt.name, 'test_log.txt') 45 | with open(self.log_name, "a") as log_file: 46 | now = time.strftime("%c") 47 | log_file.write('================ Training Loss (%s) ================\n' % now) 48 | 49 | # |visuals|: dictionary of images to display or save 50 | def display_current_results(self, visuals, epoch): 51 | if self.display_id > 0: # show images in the browser 52 | ncols = self.ncols 53 | if self.ncols > 0: 54 | h, w = next(iter(visuals.values())).shape[:2] 55 | table_css = """<style> 56 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 57 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 58 | </style>""" % (w, h) 59 | ncols = self.ncols 60 | title = self.name 61 | label_html = '' 62 | label_html_row = '' 63 | nrows = int(np.ceil(len(visuals.items()) / ncols)) 64 | images = [] 65 | idx = 0 66 | for label, image_numpy in visuals.items(): 67 | label_html_row += '<td>%s</td>' % label 68 | images.append(image_numpy.transpose([2, 0, 1])) 69 | idx += 1 70 | if idx % ncols == 0: 71 | label_html += '<tr>%s</tr>' % label_html_row 72 | label_html_row = '' 73 | ''' 74 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 75 | while idx % ncols != 0: 76 | images.append(white_image) 77 | label_html_row += '<td></td>' 78 | idx += 1 79 | ''' 80 | if label_html_row != '': 81 | label_html += '<tr>%s</tr>' % label_html_row 82 | # pane col = image row 83 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 84 | padding=2, opts=dict(title=title + ' images')) 85 | #label_html = '<table>%s</table>' % label_html
86 | #self.vis.text(table_css + label_html, win = self.display_id + 2, 87 | # opts=dict(title=title + ' labels')) 88 | else: 89 | idx = 1 90 | for label, image_numpy in visuals.items(): 91 | self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), 92 | win=self.display_id + idx) 93 | idx += 1 94 | 95 | if self.use_html: # save images to an HTML file 96 | for label, image_numpy in visuals.items(): 97 | if self.isTrain: 98 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 99 | else: 100 | img_path = os.path.join(self.img_dir, '%d.png' % (epoch)) 101 | util.save_image(image_numpy, img_path) 102 | # update website 103 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 104 | for n in range(epoch, 0, -1): 105 | webpage.add_header('epoch [%d]' % n) 106 | ims = [] 107 | txts = [] 108 | links = [] 109 | 110 | for label, image_numpy in visuals.items(): 111 | img_path = 'epoch%.3d_%s.png' % (n, label) 112 | ims.append(img_path) 113 | txts.append(label) 114 | links.append(img_path) 115 | webpage.add_images(ims, txts, links, width=self.win_size) 116 | webpage.save() 117 | 118 | # errors: dictionary of error labels and values 119 | def plot_current_losses(self, epoch, counter_ratio, opt, errors): 120 | if not hasattr(self, 'plot_data_train'): 121 | self.plot_data_train = {'X':[],'Y':[], 'legend':list(errors.keys())} 122 | self.plot_data_train['X'].append(epoch + counter_ratio) 123 | self.plot_data_train['Y'].append([errors[k] for k in self.plot_data_train['legend']]) 124 | self.vis.line( 125 | X=np.stack([np.array(self.plot_data_train['X'])]*len(self.plot_data_train['legend']),1), 126 | Y=np.array(self.plot_data_train['Y']), 127 | opts={ 128 | 'title': self.name + ' loss over time', 129 | 'legend': self.plot_data_train['legend'], 130 | 'xlabel': 'epoch', 131 | 'ylabel': 'loss'}, 132 | win=self.display_id) 133 | 134 | def plot_test_errors(self, epoch, counter_ratio, opt, errors): 135 | if not 
hasattr(self, 'plot_data'): 136 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 137 | self.plot_data['X'].append(epoch + counter_ratio) 138 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 139 | self.vis.line( 140 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), 141 | Y=np.array(self.plot_data['Y']), 142 | opts={ 143 | 'title': self.name + ' loss over time', 144 | 'legend': self.plot_data['legend'], 145 | 'xlabel': 'epoch', 146 | 'ylabel': 'loss'}, 147 | win=self.display_id+10) 148 | message = '(epoch: %d)' %(epoch) 149 | for k, v in errors.items(): 150 | message += '%s: %.3f ' % (k, v) 151 | 152 | print(message) 153 | with open(self.log_name_test, "a") as log_file: 154 | log_file.write('%s\n' % message) 155 | 156 | # errors: same format as |errors| of plotCurrentErrors 157 | def print_current_errors(self, epoch, i, errors, t): 158 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 159 | for k, v in errors.items(): 160 | message += '%s: %.3f ' % (k, v) 161 | 162 | print(message) 163 | with open(self.log_name, "a") as log_file: 164 | log_file.write('%s\n' % message) 165 | 166 | # save image to the disk 167 | def save_images(self, visuals, image_path, image_name): 168 | print(image_path, image_name) 169 | if not os.path.exists(image_path): 170 | os.makedirs(image_path, exist_ok=True) 171 | for label, image_numpy in visuals.items(): 172 | util.save_image(image_numpy, os.path.join(image_path, image_name)) 173 | 174 | 175 | --------------------------------------------------------------------------------
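`Visualizer.print_current_errors` in `util/visualizer.py` appends one plain-text line per call to `loss_log.txt`, in the form `(epoch: E, iters: I, time: T) name: value ...`. To plot or compare losses afterwards, these lines have to be parsed back into numbers. The following is a minimal sketch of such a parser, not part of the repository: `parse_loss_line` is a hypothetical helper name, and it assumes every loss was written with the `%.3f` format used above (values logged as `nan` or `inf` are skipped).

```python
import re

# Hypothetical helper (not in the repository): parses one line written by
# Visualizer.print_current_errors back into its header fields and loss values.
_HEADER = re.compile(r"\(epoch: (\d+), iters: (\d+), time: ([\d.]+)\)")
_LOSS = re.compile(r"(\S+): (-?\d+\.\d+)")


def parse_loss_line(line):
    """Return a dict for a loss line, or None for banner or blank lines."""
    header = _HEADER.match(line)
    if header is None:  # e.g. the "==== Training Loss (...) ====" banner
        return None
    epoch, iters, seconds = header.groups()
    # Everything after the header is a run of "name: value " pairs.
    losses = {name: float(value)
              for name, value in _LOSS.findall(line[header.end():])}
    return {"epoch": int(epoch), "iters": int(iters),
            "time": float(seconds), "losses": losses}


if __name__ == "__main__":
    # "G" and "D" are placeholder loss names chosen for illustration.
    print(parse_loss_line("(epoch: 1, iters: 100, time: 0.123) G: 1.500 D: 0.250 "))
```

Applying the helper to every line of `loss_log.txt` and dropping the `None` results gives a list of records that can be fed to any plotting tool.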