├── README.md ├── data ├── __init__.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── data_loader.py ├── image_folder.py └── keypoint.py ├── head_img3_00.png ├── losses ├── CX_style_loss.py ├── L1_plus_perceptualLoss.py ├── __init__.py ├── gan.py └── lpips │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── lpips.cpython-36.pyc │ ├── networks.cpython-36.pyc │ └── utils.cpython-36.pyc │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── models ├── CASD.py ├── __init__.py ├── adgan.py ├── base_model.py ├── models.py ├── networks.py ├── test_model.py ├── vgg.py └── vgg_SC.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── test.py ├── tool ├── generate_fashion_datasets.py ├── generate_pose_map_fashion.py └── resize_fashion.py ├── train.py └── util ├── __init__.py ├── get_data.py ├── html.py ├── image_pool.py ├── png.py ├── pose_utils.py ├── util.py └── visualizer.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Cross-Attention-Based-Style-Distribution 3 | 4 | The source code for our paper "[Cross Attention Based Style Distribution for Controllable Person Image Synthesis](https://arxiv.org/abs/2208.00712)" (**ECCV2022**). 5 | 6 |

7 | 8 |

9 | 
10 | 
11 | ## Installation
12 | 
13 | #### Requirements
14 | 
15 | - Python 3
16 | - PyTorch 1.7.0
17 | - CUDA 10.2
18 | 
19 | #### Conda Installation
20 | 
21 | ```bash
22 | # 1. Create a conda virtual environment.
23 | conda create -n CASD python=3.6
24 | conda activate CASD
25 | conda install -c pytorch pytorch=1.7.0 torchvision=0.8.0 cudatoolkit=10.2
26 | 
27 | # 2. Install other dependencies.
28 | pip install -r requirements.txt
29 | ```
30 | 
31 | 
32 | ### Data Preparation
33 | 
34 | The recommended dataset structure is:
35 | ```
36 | +—dataset
37 | | +—fashion
38 | | +--train (person images in 'train.lst')
39 | | +--test (person images in 'test.lst')
40 | | +--train_resize (resized person images in 'train.lst')
41 | | +--test_resize (resized person images in 'test.lst')
42 | | +--trainK (keypoints of person images)
43 | | +--testK (keypoints of person images)
44 | | +—semantic_merge3 (semantic masks of person images)
45 | | +—fashion-resize-pairs-train.csv
46 | | +—fashion-resize-pairs-test.csv
47 | | +—fasion-resize-annotation-pairs-train.csv
48 | | +—fasion-resize-annotation-pairs-test.csv
49 | | +—train.lst
50 | | +—test.lst
51 | | +—vgg19-dcbb9e9d.pth
52 | | +—vgg_conv.pth
53 | | +—vgg.pth
54 | ...
55 | ```
56 | 
57 | 
58 | 1. Person images
59 | 
60 | - Download `img_highres.zip` of the DeepFashion Dataset from [In-shop Clothes Retrieval Benchmark](https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00).
61 | 
62 | - Unzip `img_highres.zip`. You will need to ask for the password from the [dataset maintainers](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html). Then put the obtained folder **img_highres** under the `./dataset/fashion` directory.
63 | 
64 | - Download the train/test keypoint annotations and the train/test pairs from [Google Drive](https://drive.google.com/drive/folders/1qGRZUJY7QipLRDNQ0lhCubDPsJxmX2jK?usp=sharing), including **fashion-resize-pairs-train.csv**, **fashion-resize-pairs-test.csv**, **fashion-resize-annotation-train.csv**, **fashion-resize-annotation-test.csv**, **train.lst**, **test.lst**. Put these files under the `./dataset/fashion` directory.
65 | 
66 | - Run the following code to split the train/test dataset.
67 | 
68 | ```bash
69 | python tool/generate_fashion_datasets.py
70 | ```
71 | 
72 | - Run the following code to resize the train/test dataset.
73 | 
74 | ```bash
75 | python tool/resize_fashion.py
76 | ```
77 | 
78 | 
79 | 2. Keypoint files
80 | 
81 | - Generate the pose heatmaps. Launch
82 | ```bash
83 | python tool/generate_pose_map_fashion.py
84 | ```
85 | 
86 | 3. Segmentation files
87 | - Extract human segmentation results from an existing human parser (e.g. LIP_JPPNet). Our segmentation results ‘semantic_merge3’ are provided in [Google Drive](https://drive.google.com/drive/folders/1qGRZUJY7QipLRDNQ0lhCubDPsJxmX2jK?usp=sharing). Put it under the `./dataset/fashion` directory. A sketch of how these parsing maps are merged into the model's 8 input channels is shown below.
88 | 
89 | 
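For reference, the 20-class human parsing maps in `semantic_merge3` are collapsed into the 8 semantic channels the model consumes (`SP_input_nc = 8`). Below is a minimal sketch of that merge, mirroring the channel grouping in `data/keypoint.py`; the human-readable class names in the comments are assumptions based on the usual LIP label order, and `merge_parsing` is an illustrative helper rather than part of this repository.

```python
import numpy as np
from PIL import Image

# Grouping of the 20 LIP parsing labels into the 8 semantic channels used by
# the model (SP_input_nc = 8), mirroring data/keypoint.py. The names in the
# comments are assumptions about the LIP label order, not defined in this repo.
MERGE_GROUPS = [
    [0],                   # 0: background
    [9, 12],               # 1: pants / skirt
    [1, 2],                # 2: hat / hair
    [3],                   # 3: gloves
    [4, 13],               # 4: sunglasses / face
    [5, 6, 7, 10, 11],     # 5: upper clothes, dress, coat, jumpsuit, scarf
    [14, 15],              # 6: arms
    [8, 16, 17, 18, 19],   # 7: socks, legs, shoes
]


def merge_parsing(png_path):
    """Load a 20-class parsing map and return an (8, H, W) float32 mask."""
    parsing = np.array(Image.open(png_path))  # (H, W) integer label map
    sp = np.zeros((len(MERGE_GROUPS),) + parsing.shape, dtype='float32')
    for channel, labels in enumerate(MERGE_GROUPS):
        for label in labels:
            sp[channel] += (parsing == label).astype('float32')
    return sp
```

The data loader performs this grouping on the fly for both the source and target parsing maps, so the helper above is only needed if you want to inspect or visualize the merged masks yourself.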
90 | ### Training
91 | 
92 | ```bash
93 | python train.py --dataroot ./dataset/fashion --dirSem ./dataset/fashion --pairLst ./dataset/fashion/fashion-resize-pairs-train.csv --name CASD_test --batchSize 16 --gpu_ids 0,1 --which_model_netG CASD --checkpoints_dir ./checkpoints
94 | ```
95 | The models are saved in `./checkpoints`.
96 | 
97 | ### Testing
98 | Download our pretrained model from [Google Drive](https://drive.google.com/drive/folders/1qGRZUJY7QipLRDNQ0lhCubDPsJxmX2jK?usp=sharing). Put the obtained checkpoints under `./checkpoints/CASD_test`. Modify your data path and launch
99 | ```bash
100 | python test.py --dataroot ./dataset/fashion --dirSem ./dataset/fashion --pairLst ./dataset/fashion/fashion-resize-pairs-test.csv --checkpoints_dir ./checkpoints --results_dir ./results --name CASD_test --phase test --batchSize 1 --gpu_ids 0,0 --which_model_netG CASD --which_epoch 1000
101 | ```
102 | The result images are saved in `./results`.
103 | 
104 | ## Citation
105 | If you use this code for your research, please cite
106 | ```
107 | @article{zhou2022casd,
108 |   title={Cross Attention Based Style Distribution for Controllable Person Image Synthesis},
109 |   author={Zhou, Xinyue and Yin, Mingyu and Chen, Xinyuan and Sun, Li and Gao, Changxin and Li, Qingli},
110 |   journal={arXiv preprint arXiv:2208.00712},
111 |   year={2022}
112 | }
113 | ```
114 | 
115 | 
116 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/data/__init__.py
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | 
2 | class BaseDataLoader():
3 |     def __init__(self):
4 |         pass
5 | 
6 |     def initialize(self, opt):
7 |         self.opt = opt
8 |         pass
9 | 
10 |     def load_data(self):
11 |         return None
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | 
5 | class BaseDataset(data.Dataset):
6 |     def __init__(self):
7 |         super(BaseDataset, self).__init__()
8 | 
9 |     def name(self):
10 |         return 'BaseDataset'
11 | 
12 |     def initialize(self, opt):
13 |         pass
14 | 
15 | 
16 | def get_transform(opt):
17 |     transform_list = []
18 |     if opt.resize_or_crop == 'resize_and_crop':
19 |         osize = [opt.loadSize, opt.loadSize]
20 |         transform_list.append(transforms.Scale(osize, Image.BICUBIC))
21 |         transform_list.append(transforms.RandomCrop(opt.fineSize))
22 |     elif opt.resize_or_crop == 'crop':
23 |         transform_list.append(transforms.RandomCrop(opt.fineSize))
24 |     elif opt.resize_or_crop == 'scale_width':
25 |         transform_list.append(transforms.Lambda(
26 |             lambda img: __scale_width(img, opt.fineSize)))
27 |     elif opt.resize_or_crop == 'scale_width_and_crop':
28 |         transform_list.append(transforms.Lambda(
29 |             lambda img: __scale_width(img, opt.loadSize)))
30 |         transform_list.append(transforms.RandomCrop(opt.fineSize))
31 | 
32 |     transform_list += [transforms.ToTensor(),
33 |                        transforms.Normalize((0.5, 0.5, 0.5),
34 |                                             (0.5, 0.5, 0.5))]
35 |     return transforms.Compose(transform_list)
36 | 
37 | def __scale_width(img, target_width):
38 |     ow, oh = img.size
39 |     if (ow == target_width):
40 |         return img
41 |     w = target_width
42 |     h = int(target_width * oh / ow)
43 |     return img.resize((w, h), Image.BICUBIC)
44 | 
--------------------------------------------------------------------------------
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 | 
4 | 
5 | def CreateDataset(opt):
6 |     dataset = None
7 | 
8 |     if opt.dataset_mode == 'keypoint':
9 |         from data.keypoint import KeyDataset
10 |         dataset =
KeyDataset() 11 | elif opt.dataset_mode == 'keypoint_mix': 12 | from data.keypoint_mix import KeyDataset 13 | dataset = KeyDataset() 14 | else: 15 | raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) 16 | 17 | print("dataset [%s] was created" % (dataset.name())) 18 | dataset.initialize(opt) 19 | return dataset 20 | 21 | 22 | class CustomDatasetDataLoader(BaseDataLoader): 23 | def name(self): 24 | return 'CustomDatasetDataLoader' 25 | 26 | def initialize(self, opt): 27 | BaseDataLoader.initialize(self, opt) 28 | self.dataset = CreateDataset(opt) 29 | self.dataloader = torch.utils.data.DataLoader( 30 | self.dataset, 31 | batch_size=opt.batchSize, 32 | shuffle=not opt.serial_batches, 33 | num_workers=int(opt.nThreads)) 34 | 35 | def load_data(self): 36 | return self 37 | 38 | def __len__(self): 39 | return min(len(self.dataset), self.opt.max_dataset_size) 40 | 41 | def __iter__(self): 42 | for i, data in enumerate(self.dataloader): 43 | if i >= self.opt.max_dataset_size: 44 | break 45 | yield data 46 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | 
return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/keypoint.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from PIL import Image 4 | import random 5 | import pandas as pd 6 | import torch 7 | import util.util as util 8 | import numpy as np 9 | import torchvision.transforms.functional as F 10 | 11 | class KeyDataset(BaseDataset): 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_P = os.path.join(opt.dataroot, opt.phase + '_resize') 16 | self.dir_K = os.path.join(opt.dataroot, opt.phase + 'K') 17 | self.dir_SP = opt.dirSem 18 | self.SP_input_nc = opt.SP_input_nc 19 | 20 | self.init_categories(opt.pairLst) 21 | self.transform = get_transform(opt) 22 | self.use_BPD = self.opt.use_BPD 23 | 24 | self.finesize = opt.fineSize 25 | 26 | def init_categories(self, pairLst): 27 | pairs_file_train = pd.read_csv(pairLst) 28 | self.size = len(pairs_file_train) 29 | self.pairs = [] 30 | print('Loading data pairs ...') 31 | for i in range(self.size): 32 | pair = [pairs_file_train.iloc[i]['from'], pairs_file_train.iloc[i]['to']] 33 | self.pairs.append(pair) 34 | 35 | print('Loading data pairs finished ...') 36 | 37 | def __getitem__(self, index): 38 | if self.opt.phase == 'train': 39 | index = random.randint(0, self.size - 1) 40 | 41 | P1_name, P2_name = self.pairs[index] 42 | P1_path = os.path.join(self.dir_P, P1_name) 43 | BP1_path = os.path.join(self.dir_K, P1_name + '.npy') 44 | 45 | P2_path = os.path.join(self.dir_P, P2_name) 46 | BP2_path = os.path.join(self.dir_K, P2_name + '.npy') 47 | 48 | P1_img = Image.open(P1_path).convert('RGB') 49 | P2_img = Image.open(P2_path).convert('RGB') 50 | 51 | BP1_img = np.load(BP1_path) 52 | BP2_img = np.load(BP2_path) 53 | 54 | if self.use_BPD: 55 | BPD1_img = util.draw_dis_from_map(BP1_img)[0] 56 | BPD2_img = util.draw_dis_from_map(BP2_img)[0] 57 | 58 | # use flip 59 | if self.opt.phase == 'train' and self.opt.use_flip: 60 | # print ('use_flip ...') 61 | flip_random = random.uniform(0, 1) 62 | 63 | if flip_random > 0.5: 64 | # print('fliped ...') 65 | P1_img = P1_img.transpose(Image.FLIP_LEFT_RIGHT) 66 | P2_img = P2_img.transpose(Image.FLIP_LEFT_RIGHT) 67 | 68 | BP1_img = np.array(BP1_img[:, ::-1, :]) 69 | BP2_img = np.array(BP2_img[:, ::-1, :]) 70 | 71 | BP1 = torch.from_numpy(BP1_img).float() 72 | BP1 = BP1.transpose(2, 0) 73 | BP1 = BP1.transpose(2, 1) 74 | 75 | BP2 = torch.from_numpy(BP2_img).float() 76 | BP2 = BP2.transpose(2, 0) 77 | BP2 = BP2.transpose(2, 1) 78 | 79 | P1 = self.transform(P1_img) 80 | P2 = self.transform(P2_img) 81 | else: 82 | BP1 = torch.from_numpy(BP1_img).float() 83 | BP1 = BP1.transpose(2, 0) 84 | BP1 = BP1.transpose(2, 1) 85 | 86 | BP2 = torch.from_numpy(BP2_img).float() 87 | BP2 = BP2.transpose(2, 0) 88 | BP2 = BP2.transpose(2, 1) 89 | 90 | P1 = self.transform(P1_img) 91 | P2 = self.transform(P2_img) 92 | if self.use_BPD: 93 | BPD1 = torch.from_numpy(BPD1_img).float() 94 | BPD1 = BPD1.transpose(2, 0) 95 | BPD1 = BPD1.transpose(2, 1) 96 | 97 | BPD2 = torch.from_numpy(BPD2_img).float() 98 | BPD2 = BPD2.transpose(2, 0) 99 | BPD2 = BPD2.transpose(2, 1) 100 | 101 | 102 | SP1_name = self.split_name_sementic3(P1_name, 'semantic_merge3') 103 | SP2_name = self.split_name_sementic3(P2_name, 'semantic_merge3') 104 | SP1_path = os.path.join(self.dir_SP, SP1_name) 105 | SP1_path = SP1_path[:-4] + '.png' 106 | SP1_data = 
Image.open(SP1_path) 107 | SP1_data = np.array(SP1_data) 108 | SP2_path = os.path.join(self.dir_SP, SP2_name) 109 | SP2_path = SP2_path[:-4] + '.png' 110 | SP2_data = Image.open(SP2_path) 111 | SP2_data = np.array(SP2_data) 112 | SP1 = np.zeros((self.SP_input_nc, self.finesize[0], self.finesize[1]), dtype='float32') 113 | SP2 = np.zeros((self.SP_input_nc, self.finesize[0], self.finesize[1]), dtype='float32') 114 | SP1_20 = np.zeros((20, self.finesize[0], self.finesize[1]), dtype='float32') 115 | SP2_20 = np.zeros((20, self.finesize[0], self.finesize[1]), dtype='float32') 116 | nc = 20 117 | for id in range(nc): 118 | SP1_20[id] = (SP1_data == id).astype('float32') 119 | SP2_20[id] = (SP2_data == id).astype('float32') 120 | SP1[0] = SP1_20[0] 121 | SP1[1] = SP1_20[9] + SP1_20[12] 122 | SP1[2] = SP1_20[2] + SP1_20[1] 123 | SP1[3] = SP1_20[3] 124 | SP1[4] = SP1_20[13] + SP1_20[4] 125 | SP1[5] = SP1_20[5] + SP1_20[6] + SP1_20[7] + SP1_20[10] + SP1_20[11] 126 | SP1[6] = SP1_20[14] + SP1_20[15] 127 | SP1[7] = SP1_20[8] + SP1_20[16] + SP1_20[17] + SP1_20[18] + SP1_20[19] 128 | 129 | SP2[0] = SP2_20[0] 130 | SP2[1] = SP2_20[9] + SP2_20[12] 131 | SP2[2] = SP2_20[2] + SP2_20[1] 132 | SP2[3] = SP2_20[3] 133 | SP2[4] = SP2_20[13] + SP2_20[4] 134 | SP2[5] = SP2_20[5] + SP2_20[6] + SP2_20[7] + SP2_20[10] + SP2_20[11] 135 | SP2[6] = SP2_20[14] + SP2_20[15] 136 | SP2[7] = SP2_20[8] + SP2_20[16] + SP2_20[17] + SP2_20[18] + SP2_20[19] 137 | 138 | 139 | if self.use_BPD: 140 | return {'P1': P1, 'BP1': BP1, 'SP1': SP1, 'BPD1': BPD1, 141 | 'P2': P2, 'BP2': BP2, 'SP2': SP2, 'BPD2': BPD2, 142 | 'P1_path': P1_name, 'P2_path': P2_name} 143 | else: 144 | return {'P1': P1, 'BP1': BP1, 'SP1': SP1, 145 | 'P2': P2, 'BP2': BP2, 'SP2': SP2, 146 | 'P1_path': P1_name, 'P2_path': P2_name} 147 | 148 | def __len__(self): 149 | if self.opt.phase == 'train': 150 | return 4000 151 | elif self.opt.phase == 'test': 152 | return self.size 153 | 154 | def name(self): 155 | return 'KeyDataset' 156 | 157 | 158 | def split_name_sementic3(self, str, type): 159 | list = [] 160 | list.append(type) 161 | list.append(str) 162 | 163 | head = '' 164 | for path in list: 165 | head = os.path.join(head, path) 166 | return head 167 | 168 | -------------------------------------------------------------------------------- /head_img3_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/head_img3_00.png -------------------------------------------------------------------------------- /losses/CX_style_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | class CXLoss(nn.Module): 8 | 9 | def __init__(self, sigma=0.1, b=1.0, similarity="consine"): 10 | super(CXLoss, self).__init__() 11 | self.similarity = similarity 12 | self.sigma = sigma 13 | self.b = b 14 | 15 | def center_by_T(self, featureI, featureT): 16 | # Calculate mean channel vector for feature map. 
17 | meanT = featureT.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 18 | return featureI - meanT, featureT - meanT 19 | 20 | def l2_normalize_channelwise(self, features): 21 | # Normalize on channel dimension (axis=1) 22 | norms = features.norm(p=2, dim=1, keepdim=True) 23 | features = features.div(norms) 24 | return features 25 | 26 | def patch_decomposition(self, features): 27 | N, C, H, W = features.shape 28 | assert N == 1 29 | P = H * W 30 | # NCHW --> 1x1xCxHW --> HWxCx1x1 31 | patches = features.view(1, 1, C, P).permute((3, 2, 0, 1)) 32 | return patches 33 | 34 | def calc_relative_distances(self, raw_dist, axis=1): 35 | epsilon = 1e-5 36 | div = torch.min(raw_dist, dim=axis, keepdim=True)[0] 37 | relative_dist = raw_dist / (div + epsilon) 38 | return relative_dist 39 | 40 | def calc_CX(self, dist, axis=1): 41 | W = torch.exp((self.b - dist) / self.sigma) 42 | W_sum = W.sum(dim=axis, keepdim=True) 43 | return W.div(W_sum) 44 | 45 | def forward(self, featureT, featureI): 46 | ''' 47 | :param featureT: target 48 | :param featureI: inference 49 | :return: 50 | ''' 51 | # NCHW 52 | # print(featureI.shape) 53 | 54 | featureI, featureT = self.center_by_T(featureI, featureT) 55 | 56 | featureI = self.l2_normalize_channelwise(featureI) 57 | featureT = self.l2_normalize_channelwise(featureT) 58 | 59 | dist = [] 60 | N = featureT.size()[0] 61 | for i in range(N): 62 | # NCHW 63 | featureT_i = featureT[i, :, :, :].unsqueeze(0) 64 | # NCHW 65 | featureI_i = featureI[i, :, :, :].unsqueeze(0) 66 | featureT_patch = self.patch_decomposition(featureT_i) 67 | # Calculate cosine similarity 68 | # See the torch document for functional.conv2d 69 | dist_i = F.conv2d(featureI_i, featureT_patch) 70 | dist.append(dist_i) 71 | 72 | # NCHW 73 | dist = torch.cat(dist, dim=0) 74 | 75 | raw_dist = (1. - dist) / 2. 
76 | 77 | relative_dist = self.calc_relative_distances(raw_dist) 78 | 79 | CX = self.calc_CX(relative_dist) 80 | 81 | CX = CX.max(dim=3)[0].max(dim=2)[0] 82 | CX = CX.mean(1) 83 | CX = -torch.log(CX) 84 | CX = torch.mean(CX) 85 | return CX 86 | 87 | -------------------------------------------------------------------------------- /losses/L1_plus_perceptualLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.nn.functional as F 8 | import torchvision.models as models 9 | 10 | class L1_plus_perceptualLoss(nn.Module): 11 | def __init__(self, lambda_L1, lambda_perceptual, perceptual_layers, gpu_ids, percep_is_l1): 12 | super(L1_plus_perceptualLoss, self).__init__() 13 | 14 | self.lambda_L1 = lambda_L1 15 | self.lambda_perceptual = lambda_perceptual 16 | self.gpu_ids = gpu_ids 17 | 18 | self.percep_is_l1 = percep_is_l1 19 | 20 | # vgg = models.vgg19(pretrained=True).features 21 | vgg19 = models.vgg19(pretrained=False) 22 | vgg19.load_state_dict(torch.load('/home/haihuam/CASD-main/dataset/fashion/vgg19-dcbb9e9d.pth')) 23 | vgg = vgg19.features 24 | 25 | 26 | self.vgg_submodel = nn.Sequential() 27 | for i,layer in enumerate(list(vgg)): 28 | self.vgg_submodel.add_module(str(i),layer) 29 | if i == perceptual_layers: 30 | break 31 | self.vgg_submodel = self.vgg_submodel.cuda() 32 | #self.vgg_submodel = torch.nn.DataParallel(self.vgg_submodel, device_ids=gpu_ids).cuda() 33 | 34 | print(self.vgg_submodel) 35 | 36 | def forward(self, inputs, targets): 37 | if self.lambda_L1 == 0 and self.lambda_perceptual == 0: 38 | return Variable(torch.zeros(1)).cuda(), Variable(torch.zeros(1)), Variable(torch.zeros(1)) 39 | # normal L1 40 | loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1 41 | 42 | # perceptual L1 43 | mean = torch.FloatTensor(3) 44 | mean[0] = 0.485 45 | mean[1] = 0.456 46 | mean[2] = 0.406 47 | mean = Variable(mean) 48 | mean = mean.resize(1, 3, 1, 1).cuda() 49 | 50 | std = torch.FloatTensor(3) 51 | std[0] = 0.229 52 | std[1] = 0.224 53 | std[2] = 0.225 54 | std = Variable(std) 55 | std = std.resize(1, 3, 1, 1).cuda() 56 | 57 | fake_p2_norm = (inputs + 1)/2 # [-1, 1] => [0, 1] 58 | fake_p2_norm = (fake_p2_norm - mean)/std 59 | 60 | input_p2_norm = (targets + 1)/2 # [-1, 1] => [0, 1] 61 | input_p2_norm = (input_p2_norm - mean)/std 62 | 63 | 64 | fake_p2_norm = self.vgg_submodel(fake_p2_norm) 65 | input_p2_norm = self.vgg_submodel(input_p2_norm) 66 | input_p2_norm_no_grad = input_p2_norm.detach() 67 | 68 | if self.percep_is_l1 == 1: 69 | # use l1 for perceptual loss 70 | loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual 71 | else: 72 | # use l2 for perceptual loss 73 | loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual 74 | 75 | loss = loss_l1 + loss_perceptual 76 | 77 | return loss, loss_l1, loss_perceptual 78 | 79 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/__init__.py -------------------------------------------------------------------------------- /losses/gan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | import torch.nn.functional as F 4 | 5 | from util.distributed import master_only_print as print 6 | 7 | 8 | @torch.jit.script 9 | def fuse_math_min_mean_pos(x): 10 | r"""Fuse operation min mean for hinge loss computation of positive 11 | samples""" 12 | minval = torch.min(x - 1, x * 0) 13 | loss = -torch.mean(minval) 14 | return loss 15 | 16 | 17 | @torch.jit.script 18 | def fuse_math_min_mean_neg(x): 19 | r"""Fuse operation min mean for hinge loss computation of negative 20 | samples""" 21 | minval = torch.min(-x - 1, x * 0) 22 | loss = -torch.mean(minval) 23 | return loss 24 | 25 | 26 | class GANLoss(nn.Module): 27 | r"""GAN loss constructor. 28 | 29 | Args: 30 | gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``, 31 | ``'non_saturated'``, ``'wasserstein'``. 32 | target_real_label (float): The desired output label for real images. 33 | target_fake_label (float): The desired output label for fake images. 34 | """ 35 | 36 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 37 | super(GANLoss, self).__init__() 38 | self.real_label = target_real_label 39 | self.fake_label = target_fake_label 40 | self.real_label_tensor = None 41 | self.fake_label_tensor = None 42 | self.gan_mode = gan_mode 43 | print('GAN mode: %s' % gan_mode) 44 | 45 | def forward(self, dis_output, t_real, dis_update=True): 46 | r"""GAN loss computation. 47 | 48 | Args: 49 | dis_output (tensor or list of tensors): Discriminator outputs. 50 | t_real (bool): If ``True``, uses the real label as target, otherwise 51 | uses the fake label as target. 52 | dis_update (bool): If ``True``, the loss will be used to update the 53 | discriminator, otherwise the generator. 54 | Returns: 55 | loss (tensor): Loss value. 56 | """ 57 | if isinstance(dis_output, list): 58 | # For multi-scale discriminators. 59 | # In this implementation, the loss is first averaged for each scale 60 | # (batch size and number of locations) then averaged across scales, 61 | # so that the gradient is not dominated by the discriminator that 62 | # has the most output values (highest resolution). 63 | loss = 0 64 | for dis_output_i in dis_output: 65 | assert isinstance(dis_output_i, torch.Tensor) 66 | loss += self.loss(dis_output_i, t_real, dis_update) 67 | return loss / len(dis_output) 68 | else: 69 | return self.loss(dis_output, t_real, dis_update) 70 | 71 | def loss(self, dis_output, t_real, dis_update=True): 72 | r"""GAN loss computation. 73 | 74 | Args: 75 | dis_output (tensor): Discriminator outputs. 76 | t_real (bool): If ``True``, uses the real label as target, otherwise 77 | uses the fake label as target. 78 | dis_update (bool): Updating the discriminator or the generator. 79 | Returns: 80 | loss (tensor): Loss value. 81 | """ 82 | if not dis_update: 83 | assert t_real, \ 84 | "The target should be real when updating the generator." 
85 | 86 | if self.gan_mode == 'non_saturated': 87 | target_tensor = self.get_target_tensor(dis_output, t_real) 88 | loss = F.binary_cross_entropy_with_logits(dis_output, 89 | target_tensor) 90 | elif self.gan_mode == 'least_square': 91 | target_tensor = self.get_target_tensor(dis_output, t_real) 92 | loss = 0.5 * F.mse_loss(dis_output, target_tensor) 93 | elif self.gan_mode == 'hinge': 94 | if dis_update: 95 | if t_real: 96 | loss = fuse_math_min_mean_pos(dis_output) 97 | else: 98 | loss = fuse_math_min_mean_neg(dis_output) 99 | else: 100 | loss = -torch.mean(dis_output) 101 | elif self.gan_mode == 'wasserstein': 102 | if t_real: 103 | loss = -torch.mean(dis_output) 104 | else: 105 | loss = torch.mean(dis_output) 106 | elif self.gan_mode == 'style_gan2': 107 | if t_real: 108 | loss = F.softplus(-dis_output).mean() 109 | else: 110 | loss = F.softplus(dis_output).mean() 111 | else: 112 | raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode)) 113 | return loss 114 | 115 | 116 | def get_target_tensor(self, dis_output, t_real): 117 | r"""Return the target vector for the binary cross entropy loss 118 | computation. 119 | 120 | Args: 121 | dis_output (tensor): Discriminator outputs. 122 | t_real (bool): If ``True``, uses the real label as target, otherwise 123 | uses the fake label as target. 124 | Returns: 125 | target (tensor): Target tensor vector. 126 | """ 127 | if t_real: 128 | if self.real_label_tensor is None: 129 | self.real_label_tensor = dis_output.new_tensor(self.real_label) 130 | return self.real_label_tensor.expand_as(dis_output) 131 | else: 132 | if self.fake_label_tensor is None: 133 | self.fake_label_tensor = dis_output.new_tensor(self.fake_label) 134 | return self.fake_label_tensor.expand_as(dis_output) 135 | -------------------------------------------------------------------------------- /losses/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__init__.py -------------------------------------------------------------------------------- /losses/lpips/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /losses/lpips/__pycache__/lpips.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/lpips.cpython-36.pyc -------------------------------------------------------------------------------- /losses/lpips/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /losses/lpips/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /losses/lpips/lpips.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from losses.lpips.networks import get_network, LinLayers 5 | from losses.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /losses/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from losses.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class 
AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /losses/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /models/CASD.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import functools 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import os 7 | import torchvision.models.vgg as models 8 | from torch.nn.parameter import Parameter 9 | 10 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 11 | import functools 12 | 13 | 14 | # Moddfied with AdINGen 15 | class ADGen(nn.Module): 16 | # AdaIN auto-encoder architecture 17 | def __init__(self, input_dim, dim, style_dim, n_downsample, n_res, mlp_dim, activ='relu', pad_type='reflect'): 18 | super(ADGen, self).__init__() 19 | 20 | # style encoder 21 | input_dim = 3 22 | self.SP_input_nc = 8 23 | self.enc_style = VggStyleEncoder(3, input_dim, dim, int(style_dim / self.SP_input_nc), norm='none', activ=activ, 24 | pad_type=pad_type) 25 | 26 | # content encoder 27 | self.enc_content = ContentEncoder(layers=2, ngf=64, img_f=512) 28 | 29 | input_dim = 3 30 | self.dec = Decoder(style_dim, mlp_dim, n_downsample, n_res, 256, input_dim, 31 | self.SP_input_nc, res_norm='adain', activ=activ, pad_type=pad_type) 32 | 33 | def forward(self, img_A, img_B, sem_B): 34 | content = self.enc_content(img_A) 35 | style = self.enc_style(img_B, sem_B) 36 | images_recon = self.dec(content, style) 37 | return images_recon 38 | 39 | 40 | def calc_mean_std(feat, eps=1e-5): 41 | # eps is a small value added to the variance to avoid divide-by-zero. 
42 | size = feat.size() 43 | assert (len(size) == 4) 44 | N, C = size[:2] 45 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 46 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 47 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 48 | return feat_mean, feat_std 49 | 50 | 51 | class VggStyleEncoder(nn.Module): 52 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): 53 | super(VggStyleEncoder, self).__init__() 54 | # self.vgg = models.vgg19(pretrained=True).features 55 | vgg19 = models.vgg19(pretrained=False) 56 | vgg19.load_state_dict(torch.load('/home/haihuam/CASD-main/dataset/fashion/vgg19-dcbb9e9d.pth')) 57 | self.vgg = vgg19.features 58 | 59 | for param in self.vgg.parameters(): 60 | param.requires_grad_(False) 61 | 62 | self.conv1 = Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type) # 3->64 63 | dim = dim * 2 64 | self.conv2 = Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 128->128 65 | dim = dim * 2 66 | self.conv3 = Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 256->256 67 | dim = dim * 2 68 | self.conv4 = Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 512->512 69 | dim = dim * 2 70 | 71 | self.model0 = [] 72 | self.model0 += [nn.Conv2d(dim, style_dim, 1, 1, 0)] 73 | self.model0 = nn.Sequential(*self.model0) 74 | 75 | self.AP = [] 76 | self.AP += [nn.AdaptiveAvgPool2d(1)] 77 | self.AP = nn.Sequential(*self.AP) 78 | self.output_dim = dim 79 | 80 | def get_features(self, image, model, layers=None): 81 | if layers is None: 82 | layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'} 83 | features = {} 84 | x = image 85 | # model._modules is a dictionary holding each module in the model 86 | for name, layer in model._modules.items(): 87 | x = layer(x) 88 | if name in layers: 89 | features[layers[name]] = x 90 | return features 91 | 92 | def texture_enc(self, x): 93 | sty_fea = self.get_features(x, self.vgg) 94 | x = self.conv1(x) 95 | x = torch.cat([x, sty_fea['conv1_1']], dim=1) 96 | x = self.conv2(x) 97 | x = torch.cat([x, sty_fea['conv2_1']], dim=1) 98 | x = self.conv3(x) 99 | x = torch.cat([x, sty_fea['conv3_1']], dim=1) 100 | x = self.conv4(x) 101 | x = torch.cat([x, sty_fea['conv4_1']], dim=1) 102 | x0 = self.model0(x) 103 | return x0 104 | 105 | def forward(self, x, sem): 106 | 107 | codes = self.texture_enc(x) 108 | segmap = F.interpolate(sem, size=codes.size()[2:], mode='nearest') 109 | 110 | bs = codes.shape[0] 111 | hs = codes.shape[2] 112 | ws = codes.shape[3] 113 | cs = codes.shape[1] 114 | f_size = cs 115 | 116 | s_size = segmap.shape[1] 117 | codes_vector = torch.zeros((bs, s_size, cs), dtype=codes.dtype, device=codes.device) 118 | 119 | for i in range(bs): 120 | for j in range(s_size): 121 | component_mask_area = torch.sum(segmap.bool()[i, j]) 122 | if component_mask_area > 0: 123 | codes_component_feature = codes[i].masked_select(segmap.bool()[i, j]).reshape(f_size, 124 | component_mask_area).mean(1) 125 | codes_vector[i][j] = codes_component_feature 126 | else: 127 | tmpmean, tmpstd = calc_mean_std( 128 | codes[i].reshape(1, codes[i].shape[0], codes[i].shape[1], codes[i].shape[2])) 129 | codes_vector[i][j] = tmpmean.squeeze() 130 | 131 | 132 | return codes_vector.view(bs, -1).unsqueeze(2).unsqueeze(3) 133 | 134 | 135 | class ContentEncoder(nn.Module): 136 | def __init__(self, layers=2, ngf=64, img_f=512, use_spect = False, use_coord = False): 137 | super(ContentEncoder, 
self).__init__() 138 | 139 | self.layers = layers 140 | norm_layer = get_norm_layer(norm_type='instance') 141 | nonlinearity = get_nonlinearity_layer(activation_type='LeakyReLU') 142 | self.ngf = ngf 143 | self.img_f = img_f 144 | self.block0 = EncoderBlock(30, ngf, norm_layer, 145 | nonlinearity, use_spect, use_coord) 146 | mult = 1 147 | for i in range(self.layers-1): 148 | mult_prev = mult 149 | mult = min(2 ** (i + 1), self.img_f//self.ngf) 150 | block = EncoderBlock(self.ngf*mult_prev, self.ngf*mult, norm_layer, 151 | nonlinearity, use_spect, use_coord) 152 | setattr(self, 'encoder' + str(i), block) 153 | 154 | self.model0 = [] 155 | self.model0 += [norm_layer(128)] 156 | self.model0 += [nonlinearity] 157 | self.model0 += [nn.Conv2d(128, 256, 1, 1, 0)] 158 | self.model0 = nn.Sequential(*self.model0) 159 | 160 | def forward(self, x): 161 | out = self.block0(x) 162 | for i in range(self.layers-1): 163 | model = getattr(self, 'encoder' + str(i)) 164 | out = model(out) 165 | out = self.model0(out) 166 | return out 167 | 168 | 169 | class FFN(nn.Module): 170 | def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.): 171 | super().__init__() 172 | out_features = out_features or in_features 173 | hidden_features = hidden_features or in_features 174 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 175 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 176 | self.drop = nn.Dropout(drop) 177 | 178 | def forward(self, x): 179 | b, c, h, w = x.size() 180 | x = self.fc1(x) 181 | x = F.gelu(x) 182 | x = self.drop(x) 183 | x = self.fc2(x) 184 | x = self.drop(x) 185 | x = torch.reshape(x, (b, c, h, w)) 186 | return x 187 | 188 | 189 | 190 | class Decoder(nn.Module): 191 | def __init__(self, style_dim, mlp_dim, n_upsample, n_res, dim, output_dim, SP_input_nc, res_norm='adain', 192 | activ='relu', pad_type='zero'): 193 | super(Decoder, self).__init__() 194 | self.softmax = nn.Softmax(dim=1) 195 | self.softmax_style = nn.Softmax(dim=2) 196 | self.SP_input_nc = SP_input_nc 197 | self.model0 = [] 198 | self.model1 = [] 199 | self.model2 = [] 200 | self.n_res = n_res 201 | 202 | self.mlp = MLP(style_dim, n_res * dim * 4, mlp_dim, 3, norm='none', activ=activ) 203 | self.fc = LinearBlock(style_dim, style_dim, norm='none', activation=activ) 204 | 205 | # AdaIN residual blocks 206 | self.model0_0 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)] 207 | self.model0_0 = nn.Sequential(*self.model0_0) 208 | self.model0_1 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)] 209 | self.model0_1 = nn.Sequential(*self.model0_1) 210 | self.model0_2 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)] 211 | self.model0_2 = nn.Sequential(*self.model0_2) 212 | self.model0_3 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)] 213 | self.model0_3 = nn.Sequential(*self.model0_3) 214 | self.model0_4 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)] 215 | self.model0_4 = nn.Sequential(*self.model0_4) 216 | self.model0_5 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)] 217 | self.model0_5 = nn.Sequential(*self.model0_5) 218 | self.model0_6 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)] 219 | self.model0_6 = nn.Sequential(*self.model0_6) 220 | self.model0_7 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)] 221 | self.model0_7 = nn.Sequential(*self.model0_7) 222 | # upsampling blocks 223 | for i in range(n_upsample): 224 | self.model1 += [nn.Upsample(scale_factor=2), 225 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', 
activation=activ, pad_type=pad_type)] 226 | dim //= 2 227 | self.model1 = nn.Sequential(*self.model1) 228 | # use reflection padding in the last conv layer 229 | self.model2 += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] 230 | self.model2 = nn.Sequential(*self.model2) 231 | # attention parameter 232 | 233 | self.gamma3_1 = nn.Parameter(torch.zeros(1)) 234 | self.gamma3_2 = nn.Parameter(torch.zeros(1)) 235 | self.gamma3_3 = nn.Parameter(torch.zeros(1)) 236 | self.gamma3_style_sa = nn.Parameter(torch.zeros(1)) 237 | in_dim = int(style_dim / self.SP_input_nc) 238 | self.value3_conv_sa = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 239 | self.LN_3_style = ILNKVT(256) 240 | self.LN_3_pose = ILNQT(256) 241 | self.LN_3_pose_0 = ILNQT(256) 242 | self.query3_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 243 | self.key3_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 244 | self.value3_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 245 | self.query3_conv_0 = nn.Conv2d(in_channels=in_dim, out_channels=self.SP_input_nc, kernel_size=1) 246 | 247 | self.gamma4_1 = nn.Parameter(torch.zeros(1)) 248 | self.gamma4_2 = nn.Parameter(torch.zeros(1)) 249 | self.gamma4_3 = nn.Parameter(torch.zeros(1)) 250 | self.gamma4_style_sa = nn.Parameter(torch.zeros(1)) 251 | in_dim = int(style_dim / self.SP_input_nc) 252 | self.value4_conv_sa = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 253 | self.LN_4_style = ILNKVT(256) 254 | self.LN_4_pose = ILNQT(256) 255 | self.LN_4_pose_0 = ILNQT(256) 256 | self.query4_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 257 | self.key4_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 258 | self.value4_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 259 | self.query4_conv_0 = nn.Conv2d(in_channels=in_dim, out_channels=self.SP_input_nc, kernel_size=1) 260 | 261 | self.FFN3_1 = FFN(256) 262 | self.FFN4_1 = FFN(256) 263 | self.up = nn.Upsample(scale_factor=2) 264 | 265 | def forward(self, x, style): 266 | # fusion module 267 | style_fusion = self.fc(style.view(style.size(0), -1)) 268 | adain_params = self.mlp(style_fusion) 269 | adain_params = torch.split(adain_params, int(adain_params.shape[1] / self.n_res), 1) 270 | 271 | x_0 = x 272 | x = self.model0_0([x, adain_params[0]]) 273 | x = self.model0_1([x, adain_params[1]]) 274 | x = self.model0_2([x, adain_params[2]]) 275 | x = self.model0_3([x, adain_params[3]]) 276 | 277 | x3, enerrgy_sum3 = self.styleatt(x, x_0, style, self.gamma3_1, self.gamma3_2, self.gamma3_3, \ 278 | self.gamma3_style_sa, self.value3_conv_sa, \ 279 | self.LN_3_style, self.LN_3_pose, self.LN_3_pose_0, \ 280 | self.query3_conv, self.key3_conv, self.value3_conv, self.query3_conv_0, \ 281 | self.FFN3_1) 282 | 283 | x_, enerrgy_sum4 = self.styleatt(x3, x_0, style, self.gamma4_1, self.gamma4_2, self.gamma4_3, \ 284 | self.gamma4_style_sa, self.value4_conv_sa, \ 285 | self.LN_4_style, self.LN_4_pose, self.LN_4_pose_0, \ 286 | self.query4_conv, self.key4_conv, self.value4_conv, self.query4_conv_0, \ 287 | self.FFN4_1) 288 | 289 | x = self.model0_4([x_0, x_]) 290 | x = self.model0_5([x, x_]) 291 | x = self.model0_6([x, x_]) 292 | x = self.model0_7([x, x_]) 293 | x = self.model1(x) 294 | return self.model2(x), [enerrgy_sum3, enerrgy_sum4] 295 | 296 | def styleatt(self, x, x_0, style, gamma1, gamma2, gamma3, gamma_style_sa, value_conv_sa, 
ln_style, ln_pose, 297 | ln_pose_0, query_conv, key_conv, value_conv, query_conv_0, ffn1): 298 | B, C, H, W = x.size() 299 | B, Cs, _, _ = style.size() 300 | K = self.SP_input_nc 301 | style = style.view((B, K, int(Cs / K))) # [B,K,C] 302 | 303 | x = ln_pose(x) # [B,C,H,W] 304 | style = ln_style(style.permute(0, 2, 1)) # [B,C,K] 305 | x_0 = ln_pose_0(x_0) 306 | 307 | style = style.permute(0, 2, 1) # [B,K,C] 308 | style_sa_value = torch.squeeze(value_conv_sa(torch.unsqueeze(style.permute(0, 2, 1), 3)), 3) # [B,C,K] 309 | self_att = self.softmax(torch.bmm(style, style.permute(0, 2, 1))) + 1e-8 # [B,K,K] 310 | self_att = self_att / torch.sum(self_att, dim=2, keepdim=True) 311 | style_ = torch.bmm(self_att, style_sa_value.permute(0, 2, 1)) 312 | style = style + gamma_style_sa * style_ # [B,K,C] 313 | 314 | style = style.permute(0, 2, 1) #[B,C,K] 315 | x_query = query_conv(x) 316 | style_key = torch.squeeze(key_conv(torch.unsqueeze(style, 3)).permute(0, 2, 1, 3), 3) 317 | style_value = torch.squeeze(value_conv(torch.unsqueeze(style, 3)), 3) 318 | 319 | energy_0 = query_conv_0(x_0).view((B, K, H * W)) 320 | energy = torch.bmm(style_key.detach(), x_query.view(B, C, -1)) 321 | enerrgy_sum = energy_0 + energy 322 | attention = self.softmax_style(enerrgy_sum) + 1e-8 323 | attention = attention / torch.sum(attention, dim=1, keepdim=True) 324 | 325 | out = torch.bmm(style_value, attention) 326 | out = out.view(B, C, H, W) 327 | out = gamma1 * out + x 328 | out = out + gamma3 * ffn1(out) 329 | 330 | return out, torch.reshape(enerrgy_sum, (B, K, H, W)) 331 | 332 | 333 | class ILNKVT(nn.Module): 334 | def __init__(self, num_features, eps=1e-5): 335 | super().__init__() 336 | self.eps = eps 337 | self.rho = Parameter(torch.Tensor(1, num_features, 1)) 338 | self.gamma = Parameter(torch.Tensor(1, num_features, 1)) 339 | self.beta = Parameter(torch.Tensor(1, num_features, 1)) 340 | self.rho.data.fill_(0.0) 341 | self.gamma.data.fill_(1.0) 342 | self.beta.data.fill_(0.0) 343 | 344 | def forward(self, input): 345 | in_mean, in_var = torch.mean(input, dim=[2], keepdim=True), torch.var(input, dim=[2], keepdim=True) 346 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 347 | ln_mean, ln_var = torch.mean(input, dim=[1], keepdim=True), torch.var(input, dim=[1], keepdim=True) 348 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 349 | out = self.rho.expand(input.shape[0], -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1)) * out_ln 350 | out = out * self.gamma.expand(input.shape[0], -1, -1) + self.beta.expand(input.shape[0], -1, -1) 351 | 352 | return out 353 | 354 | class ILNQT(nn.Module): 355 | def __init__(self, num_features, eps=1e-5): 356 | super().__init__() 357 | self.eps = eps 358 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 359 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1)) 360 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1)) 361 | self.rho.data.fill_(0.0) 362 | self.gamma.data.fill_(1.0) 363 | self.beta.data.fill_(0.0) 364 | 365 | def forward(self, input): 366 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) 367 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 368 | ln_mean, ln_var = torch.mean(input, dim=[1], keepdim=True), torch.var(input, dim=[1], keepdim=True) 369 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 370 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 371 | out = 
out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1) 372 | 373 | return out 374 | 375 | 376 | ################################################################################## 377 | # Sequential Models 378 | ################################################################################## 379 | class ResBlocks(nn.Module): 380 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'): 381 | super(ResBlocks, self).__init__() 382 | self.model = [] 383 | for i in range(num_blocks): 384 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)] 385 | self.model = nn.Sequential(*self.model) 386 | 387 | def forward(self, x): 388 | return self.model(x) 389 | 390 | 391 | class ResBlock_myDFNM(nn.Module): 392 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 393 | super(ResBlock_myDFNM, self).__init__() 394 | 395 | model1 = [] 396 | model2 = [] 397 | model1 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 398 | model2 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 399 | models1 = [] 400 | models1 += [Conv2dBlock(dim, dim, 3, 1, 1, norm='in', activation='relu', pad_type=pad_type)] 401 | models1 += [Conv2dBlock(dim, 2 * dim, 3, 1, 1, norm='none', activation='none', pad_type=pad_type)] 402 | models2 = [] 403 | models2 += [Conv2dBlock(dim, dim, 3, 1, 1, norm='in', activation='relu', pad_type=pad_type)] 404 | models2 += [Conv2dBlock(dim, 2 * dim, 3, 1, 1, norm='none', activation='none', pad_type=pad_type)] 405 | self.model1 = nn.Sequential(*model1) 406 | self.model2 = nn.Sequential(*model2) 407 | self.models1 = nn.Sequential(*models1) 408 | self.models2 = nn.Sequential(*models2) 409 | 410 | def forward(self, x): 411 | style = x[1] 412 | style1 = self.models1(style) 413 | style2 = self.models2(style) 414 | residual = x[0] 415 | out = self.model1([x[0], style1]) 416 | out = self.model2([out, style2]) 417 | out += residual 418 | 419 | return out 420 | 421 | 422 | class ResBlock_my(nn.Module): 423 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 424 | super(ResBlock_my, self).__init__() 425 | 426 | model1 = [] 427 | model2 = [] 428 | model1 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 429 | model2 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 430 | self.model1 = nn.Sequential(*model1) 431 | self.model2 = nn.Sequential(*model2) 432 | 433 | def forward(self, x): 434 | style = x[1] 435 | style1, style2 = torch.split(style, int(style.shape[1] / 2), 1) 436 | residual = x[0] 437 | out = self.model1([x[0], style1]) 438 | out = self.model2([out, style2]) 439 | out += residual 440 | return out 441 | 442 | 443 | class MLP(nn.Module): 444 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): 445 | 446 | super(MLP, self).__init__() 447 | self.model = [] 448 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] 449 | for i in range(n_blk - 2): 450 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] 451 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations 452 | self.model = nn.Sequential(*self.model) 453 | 454 | def forward(self, x): 455 | return self.model(x) 456 | 457 | 458 | ################################################################################## 459 | # Basic Blocks 460 | 
################################################################################## 461 | class ResBlock(nn.Module): 462 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 463 | super(ResBlock, self).__init__() 464 | 465 | model = [] 466 | model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 467 | model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 468 | self.model = nn.Sequential(*model) 469 | 470 | def forward(self, x): 471 | residual = x 472 | out = self.model(x) 473 | out += residual 474 | return out 475 | 476 | 477 | class Conv2dBlock_my(nn.Module): 478 | def __init__(self, input_dim, output_dim, kernel_size, stride, 479 | padding=0, norm='none', activation='relu', pad_type='zero'): 480 | super(Conv2dBlock_my, self).__init__() 481 | self.use_bias = True 482 | # initialize padding 483 | if pad_type == 'reflect': 484 | self.pad = nn.ReflectionPad2d(padding) 485 | elif pad_type == 'replicate': 486 | self.pad = nn.ReplicationPad2d(padding) 487 | elif pad_type == 'zero': 488 | self.pad = nn.ZeroPad2d(padding) 489 | else: 490 | assert 0, "Unsupported padding type: {}".format(pad_type) 491 | 492 | # initialize normalization 493 | norm_dim = output_dim 494 | if norm == 'bn': 495 | self.norm = nn.BatchNorm2d(norm_dim) 496 | elif norm == 'in': 497 | # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) 498 | self.norm = nn.InstanceNorm2d(norm_dim) 499 | elif norm == 'ln': 500 | self.norm = LayerNorm(norm_dim) 501 | elif norm == 'adain': 502 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 503 | elif norm == 'spade': 504 | self.norm = SPADE() 505 | elif norm == 'none' or norm == 'sn': 506 | self.norm = None 507 | else: 508 | assert 0, "Unsupported normalization: {}".format(norm) 509 | 510 | # initialize activation 511 | if activation == 'relu': 512 | self.activation = nn.ReLU(inplace=True) 513 | elif activation == 'lrelu': 514 | self.activation = nn.LeakyReLU(0.2, inplace=True) 515 | elif activation == 'prelu': 516 | self.activation = nn.PReLU() 517 | elif activation == 'selu': 518 | self.activation = nn.SELU(inplace=True) 519 | elif activation == 'tanh': 520 | self.activation = nn.Tanh() 521 | elif activation == 'none': 522 | self.activation = None 523 | else: 524 | assert 0, "Unsupported activation: {}".format(activation) 525 | 526 | # initialize convolution 527 | if norm == 'sn': 528 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)) 529 | else: 530 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 531 | 532 | def forward(self, x): 533 | style = x[1] 534 | x = x[0] 535 | x = self.conv(self.pad(x)) 536 | if self.norm: 537 | x = self.norm([x, style]) 538 | if self.activation: 539 | x = self.activation(x) 540 | return x 541 | 542 | 543 | class Conv2dBlock(nn.Module): 544 | def __init__(self, input_dim, output_dim, kernel_size, stride, 545 | padding=0, norm='none', activation='relu', pad_type='zero'): 546 | super(Conv2dBlock, self).__init__() 547 | self.use_bias = True 548 | # initialize padding 549 | if pad_type == 'reflect': 550 | self.pad = nn.ReflectionPad2d(padding) 551 | elif pad_type == 'replicate': 552 | self.pad = nn.ReplicationPad2d(padding) 553 | elif pad_type == 'zero': 554 | self.pad = nn.ZeroPad2d(padding) 555 | else: 556 | assert 0, "Unsupported padding type: {}".format(pad_type) 557 | 558 | # initialize normalization 559 | norm_dim = output_dim 560 | if norm == 'bn': 561 | 
self.norm = nn.BatchNorm2d(norm_dim) 562 | elif norm == 'in': 563 | # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) 564 | self.norm = nn.InstanceNorm2d(norm_dim) 565 | elif norm == 'ln': 566 | self.norm = LayerNorm(norm_dim) 567 | elif norm == 'adain': 568 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 569 | elif norm == 'none' or norm == 'sn': 570 | self.norm = None 571 | else: 572 | assert 0, "Unsupported normalization: {}".format(norm) 573 | 574 | # initialize activation 575 | if activation == 'relu': 576 | self.activation = nn.ReLU(inplace=True) 577 | elif activation == 'lrelu': 578 | self.activation = nn.LeakyReLU(0.2, inplace=True) 579 | elif activation == 'prelu': 580 | self.activation = nn.PReLU() 581 | elif activation == 'selu': 582 | self.activation = nn.SELU(inplace=True) 583 | elif activation == 'tanh': 584 | self.activation = nn.Tanh() 585 | elif activation == 'none': 586 | self.activation = None 587 | else: 588 | assert 0, "Unsupported activation: {}".format(activation) 589 | 590 | # initialize convolution 591 | if norm == 'sn': 592 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)) 593 | else: 594 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 595 | 596 | def forward(self, x): 597 | x = self.conv(self.pad(x)) 598 | if self.norm: 599 | x = self.norm(x) 600 | if self.activation: 601 | x = self.activation(x) 602 | return x 603 | 604 | 605 | class LinearBlock(nn.Module): 606 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 607 | super(LinearBlock, self).__init__() 608 | use_bias = True 609 | # initialize fully connected layer 610 | if norm == 'sn': 611 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) 612 | else: 613 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 614 | 615 | # initialize normalization 616 | norm_dim = output_dim 617 | if norm == 'bn': 618 | self.norm = nn.BatchNorm1d(norm_dim) 619 | elif norm == 'in': 620 | self.norm = nn.InstanceNorm1d(norm_dim) 621 | elif norm == 'ln': 622 | self.norm = LayerNorm(norm_dim) 623 | elif norm == 'none' or norm == 'sn': 624 | self.norm = None 625 | else: 626 | assert 0, "Unsupported normalization: {}".format(norm) 627 | 628 | # initialize activation 629 | if activation == 'relu': 630 | self.activation = nn.ReLU(inplace=True) 631 | elif activation == 'lrelu': 632 | self.activation = nn.LeakyReLU(0.2, inplace=True) 633 | elif activation == 'prelu': 634 | self.activation = nn.PReLU() 635 | elif activation == 'selu': 636 | self.activation = nn.SELU(inplace=True) 637 | elif activation == 'tanh': 638 | self.activation = nn.Tanh() 639 | elif activation == 'none': 640 | self.activation = None 641 | else: 642 | assert 0, "Unsupported activation: {}".format(activation) 643 | 644 | def forward(self, x): 645 | out = self.fc(x) 646 | if self.norm: 647 | out = self.norm(out) 648 | if self.activation: 649 | out = self.activation(out) 650 | return out 651 | 652 | 653 | ################################################################################## 654 | # Normalization layers 655 | ################################################################################## 656 | class SPADE(nn.Module): 657 | def __init__(self): 658 | super().__init__() 659 | 660 | def forward(self, x): 661 | style = x[1] 662 | x = x[0] 663 | # Part 1. 
generate parameter-free normalized activations 664 | x_mean = torch.mean(x, (0, 2, 3), keepdim=True) 665 | x_var = torch.var(x, (0, 2, 3), keepdim=True) 666 | normalized = (x - x_mean) / (x_var + 1e-6) 667 | 668 | # Part 2. produce scaling and bias conditioned on semantic map 669 | gamma, beta = torch.split(style, int(style.size(1) / 2), 1) 670 | # apply scale and bias 671 | out = normalized * (1 + gamma) + beta 672 | 673 | return out 674 | 675 | 676 | class AdaptiveInstanceNorm2d(nn.Module): 677 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 678 | super(AdaptiveInstanceNorm2d, self).__init__() 679 | self.num_features = num_features 680 | self.eps = eps 681 | self.momentum = momentum 682 | # weight and bias are dynamically assigned 683 | self.weight = None 684 | self.bias = None 685 | # just dummy buffers, not used 686 | self.register_buffer('running_mean', torch.zeros(num_features)) 687 | self.register_buffer('running_var', torch.ones(num_features)) 688 | 689 | def forward(self, x): 690 | style = x[1] 691 | self.weight, self.bias = torch.split(style, int(style.shape[1] / 2), 1) 692 | x = x[0] 693 | b, c = x.size(0), x.size(1) 694 | running_mean = self.running_mean.repeat(b) 695 | running_var = self.running_var.repeat(b) 696 | 697 | # Apply instance norm 698 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 699 | 700 | out = F.batch_norm( 701 | x_reshaped, running_mean, running_var, self.weight, self.bias, 702 | True, self.momentum, self.eps) 703 | 704 | return out.view(b, c, *x.size()[2:]) 705 | 706 | def __repr__(self): 707 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 708 | 709 | 710 | class LayerNorm(nn.Module): 711 | def __init__(self, num_features, eps=1e-5, affine=True): 712 | super(LayerNorm, self).__init__() 713 | self.num_features = num_features 714 | self.affine = affine 715 | self.eps = eps 716 | 717 | if self.affine: 718 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 719 | self.beta = nn.Parameter(torch.zeros(num_features)) 720 | 721 | def forward(self, x): 722 | shape = [-1] + [1] * (x.dim() - 1) 723 | # print(x.size()) 724 | if x.size(0) == 1: 725 | # These two lines run much faster in pytorch 0.4 than the two lines listed below. 
726 | mean = x.view(-1).mean().view(*shape) 727 | std = x.view(-1).std().view(*shape) 728 | else: 729 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 730 | std = x.view(x.size(0), -1).std(1).view(*shape) 731 | 732 | x = (x - mean) / (std + self.eps) 733 | 734 | if self.affine: 735 | shape = [1, -1] + [1] * (x.dim() - 2) 736 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 737 | return x 738 | 739 | 740 | def l2normalize(v, eps=1e-12): 741 | return v / (v.norm() + eps) 742 | 743 | 744 | class SpectralNorm(nn.Module): 745 | """ 746 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida 747 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan 748 | """ 749 | 750 | def __init__(self, module, name='weight', power_iterations=1): 751 | super(SpectralNorm, self).__init__() 752 | self.module = module 753 | self.name = name 754 | self.power_iterations = power_iterations 755 | if not self._made_params(): 756 | self._make_params() 757 | 758 | def _update_u_v(self): 759 | u = getattr(self.module, self.name + "_u") 760 | v = getattr(self.module, self.name + "_v") 761 | w = getattr(self.module, self.name + "_bar") 762 | 763 | height = w.data.shape[0] 764 | for _ in range(self.power_iterations): 765 | v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) 766 | u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) 767 | 768 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 769 | sigma = u.dot(w.view(height, -1).mv(v)) 770 | setattr(self.module, self.name, w / sigma.expand_as(w)) 771 | 772 | def _made_params(self): 773 | try: 774 | u = getattr(self.module, self.name + "_u") 775 | v = getattr(self.module, self.name + "_v") 776 | w = getattr(self.module, self.name + "_bar") 777 | return True 778 | except AttributeError: 779 | return False 780 | 781 | def _make_params(self): 782 | w = getattr(self.module, self.name) 783 | 784 | height = w.data.shape[0] 785 | width = w.view(height, -1).data.shape[1] 786 | 787 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 788 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 789 | u.data = l2normalize(u.data) 790 | v.data = l2normalize(v.data) 791 | w_bar = nn.Parameter(w.data) 792 | 793 | del self.module._parameters[self.name] 794 | 795 | self.module.register_parameter(self.name + "_u", u) 796 | self.module.register_parameter(self.name + "_v", v) 797 | self.module.register_parameter(self.name + "_bar", w_bar) 798 | 799 | def forward(self, *args): 800 | self._update_u_v() 801 | return self.module.forward(*args) 802 | 803 | 804 | def get_norm_layer(norm_type='batch'): 805 | """Get the normalization layer for the networks""" 806 | if norm_type == 'batch': 807 | norm_layer = functools.partial(nn.BatchNorm2d, momentum=0.1, affine=True) 808 | elif norm_type == 'instance': 809 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=True) 810 | elif norm_type == 'adain': 811 | norm_layer = functools.partial(ADAIN) 812 | elif norm_type == 'spade': 813 | norm_layer = functools.partial(SPADE, config_text='spadeinstance3x3') 814 | elif norm_type == 'none': 815 | norm_layer = None 816 | else: 817 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 818 | 819 | if norm_type != 'none': 820 | norm_layer.__name__ = norm_type 821 | 822 | return norm_layer 823 | 824 | def 
get_nonlinearity_layer(activation_type='PReLU'): 825 | """Get the activation layer for the networks""" 826 | if activation_type == 'ReLU': 827 | nonlinearity_layer = nn.ReLU() 828 | elif activation_type == 'SELU': 829 | nonlinearity_layer = nn.SELU() 830 | elif activation_type == 'LeakyReLU': 831 | nonlinearity_layer = nn.LeakyReLU(0.1) 832 | elif activation_type == 'PReLU': 833 | nonlinearity_layer = nn.PReLU() 834 | else: 835 | raise NotImplementedError('activation layer [%s] is not found' % activation_type) 836 | return nonlinearity_layer 837 | 838 | 839 | class AddCoords(nn.Module): 840 | """ 841 | Add Coords to a tensor 842 | """ 843 | def __init__(self, with_r=False): 844 | super(AddCoords, self).__init__() 845 | self.with_r = with_r 846 | 847 | def forward(self, x): 848 | """ 849 | :param x: shape (batch, channel, x_dim, y_dim) 850 | :return: shape (batch, channel+2, x_dim, y_dim) 851 | """ 852 | B, _, x_dim, y_dim = x.size() 853 | 854 | # coord calculate 855 | xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x) 856 | yy_cahnnel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x) 857 | # normalization 858 | xx_channel = xx_channel.float() / (x_dim-1) 859 | yy_cahnnel = yy_cahnnel.float() / (y_dim-1) 860 | xx_channel = xx_channel * 2 - 1 861 | yy_cahnnel = yy_cahnnel * 2 - 1 862 | 863 | ret = torch.cat([x, xx_channel, yy_cahnnel], dim=1) 864 | 865 | if self.with_r: 866 | rr = torch.sqrt(xx_channel ** 2 + yy_cahnnel ** 2) 867 | ret = torch.cat([ret, rr], dim=1) 868 | 869 | return ret 870 | 871 | 872 | def spectral_norm(module, use_spect=True): 873 | """use spectral normal layer to stable the training process""" 874 | if use_spect: 875 | return SpectralNorm(module) 876 | else: 877 | return module 878 | 879 | 880 | 881 | class CoordConv(nn.Module): 882 | """ 883 | CoordConv operation 884 | """ 885 | def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs): 886 | super(CoordConv, self).__init__() 887 | self.addcoords = AddCoords(with_r=with_r) 888 | input_nc = input_nc + 2 889 | if with_r: 890 | input_nc = input_nc + 1 891 | self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) 892 | 893 | def forward(self, x): 894 | ret = self.addcoords(x) 895 | ret = self.conv(ret) 896 | 897 | return ret 898 | 899 | 900 | def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, **kwargs): 901 | """use coord convolution layer to add position information""" 902 | if use_coord: 903 | return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs) 904 | else: 905 | return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) 906 | 907 | 908 | class EncoderBlock(nn.Module): 909 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), 910 | use_spect=False, use_coord=False): 911 | super(EncoderBlock, self).__init__() 912 | 913 | 914 | kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1} 915 | kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1} 916 | 917 | conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_down) 918 | conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs_fine) 919 | 920 | if type(norm_layer) == type(None): 921 | self.model = nn.Sequential(nonlinearity, conv1, nonlinearity, conv2,) 922 | else: 923 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1, 924 | norm_layer(output_nc), nonlinearity, conv2,) 925 | 926 | def forward(self, x): 927 | out = 
self.model(x) 928 | return out 929 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/models/__init__.py -------------------------------------------------------------------------------- /models/adgan.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import numpy as np 3 | import torch 4 | import os 5 | from collections import OrderedDict 6 | import util.util as util 7 | from util.image_pool import ImagePool 8 | from .base_model import BaseModel 9 | from . import networks 10 | # losses 11 | from losses.L1_plus_perceptualLoss import L1_plus_perceptualLoss 12 | from losses.CX_style_loss import CXLoss 13 | from .vgg_SC import VGG, VGGLoss 14 | from losses.lpips.lpips import LPIPS 15 | 16 | 17 | 18 | class TransferModel(BaseModel): 19 | def name(self): 20 | return 'TransferModel' 21 | 22 | def initialize(self, opt): 23 | BaseModel.initialize(self, opt) 24 | nb = opt.batchSize 25 | size = opt.fineSize 26 | self.use_AMCE = opt.use_AMCE 27 | self.use_BPD = opt.use_BPD 28 | self.SP_input_nc = opt.SP_input_nc 29 | self.input_P1_set = self.Tensor(nb, opt.P_input_nc, size[0], size[1]) 30 | self.input_BP1_set = self.Tensor(nb, opt.BP_input_nc, size[0], size[1]) 31 | self.input_P2_set = self.Tensor(nb, opt.P_input_nc, size[0], size[1]) 32 | self.input_BP2_set = self.Tensor(nb, opt.BP_input_nc, size[0], size[1]) 33 | self.input_SP1_set = self.Tensor(nb, opt.SP_input_nc, size[0], size[1]) 34 | self.input_SP2_set = self.Tensor(nb, opt.SP_input_nc, size[0], size[1]) 35 | if self.use_BPD: 36 | self.input_BPD1_set = self.Tensor(nb, opt.BPD_input_nc, size[0], size[1]) 37 | self.input_BPD2_set = self.Tensor(nb, opt.BPD_input_nc, size[0], size[1]) 38 | 39 | 40 | input_nc = [opt.P_input_nc, opt.BP_input_nc+opt.BP_input_nc + (opt.BPD_input_nc+opt.BPD_input_nc if self.use_BPD else 0)] 41 | self.netG = networks.define_G(input_nc, opt.P_input_nc, 42 | opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, 43 | n_downsampling=opt.G_n_downsampling) 44 | 45 | if self.isTrain: 46 | use_sigmoid = opt.no_lsgan 47 | if opt.with_D_PB: 48 | self.netD_PB = networks.define_D(opt.P_input_nc+opt.BP_input_nc + (opt.BPD_input_nc if self.use_BPD else 0), opt.ndf, 49 | opt.which_model_netD, 50 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, 51 | not opt.no_dropout_D, 52 | n_downsampling = opt.D_n_downsampling) 53 | 54 | if opt.with_D_PP: 55 | self.netD_PP = networks.define_D(opt.P_input_nc+opt.P_input_nc, opt.ndf, 56 | opt.which_model_netD, 57 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids, 58 | not opt.no_dropout_D, 59 | n_downsampling = opt.D_n_downsampling) 60 | 61 | if len(opt.gpu_ids) > 1: 62 | self.load_VGG(self.netG.module.enc_style.vgg) 63 | else: 64 | self.load_VGG(self.netG.enc_style.vgg) 65 | 66 | if not self.isTrain or opt.continue_train: 67 | which_epoch = opt.which_epoch 68 | self.load_network(self.netG, 'netG', which_epoch) 69 | if self.isTrain: 70 | if opt.with_D_PB: 71 | self.load_network(self.netD_PB, 'netD_PB', which_epoch) 72 | if opt.with_D_PP: 73 | self.load_network(self.netD_PP, 'netD_PP', which_epoch) 74 | 75 | if self.isTrain: 76 | self.old_lr = opt.lr 77 | self.fake_PP_pool = ImagePool(opt.pool_size) 78 | self.fake_PB_pool = 
ImagePool(opt.pool_size) 79 | # define loss functions 80 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) 81 | 82 | 83 | if opt.L1_type == 'origin': 84 | self.criterionL1 = torch.nn.L1Loss() 85 | elif opt.L1_type == 'l1_plus_perL1': 86 | self.criterionL1 = L1_plus_perceptualLoss(opt.lambda_A, opt.lambda_B, opt.perceptual_layers, self.gpu_ids, opt.percep_is_l1) 87 | else: 88 | raise Exception('Unsupported type of L1!') 89 | 90 | if opt.use_cxloss: 91 | self.CX_loss = CXLoss(sigma=0.5) 92 | if torch.cuda.is_available(): 93 | self.CX_loss.cuda() 94 | self.vgg = VGG() 95 | self.vgg.load_state_dict(torch.load(os.path.abspath(opt.dataroot) + '/vgg_conv.pth')) 96 | for param in self.vgg.parameters(): 97 | param.requires_grad = False 98 | if torch.cuda.is_available(): 99 | self.vgg.cuda() 100 | 101 | if opt.use_lpips: 102 | self.lpips_loss = LPIPS(net_type='vgg').cuda().eval() 103 | 104 | if self.use_AMCE: 105 | self.AM_CE_loss = torch.nn.CrossEntropyLoss() 106 | if torch.cuda.is_available(): 107 | self.AM_CE_loss.cuda() 108 | 109 | 110 | self.Vggloss = VGGLoss().cuda().eval() 111 | 112 | 113 | # initialize optimizers 114 | self.optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) 115 | 116 | if opt.with_D_PB: 117 | self.optimizer_D_PB = torch.optim.Adam(self.netD_PB.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 118 | if opt.with_D_PP: 119 | self.optimizer_D_PP = torch.optim.Adam(self.netD_PP.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 120 | 121 | self.optimizers = [] 122 | self.schedulers = [] 123 | self.optimizers.append(self.optimizer_G) 124 | if opt.with_D_PB: 125 | self.optimizers.append(self.optimizer_D_PB) 126 | if opt.with_D_PP: 127 | self.optimizers.append(self.optimizer_D_PP) 128 | for optimizer in self.optimizers: 129 | self.schedulers.append(networks.get_scheduler(optimizer, opt)) 130 | 131 | print('---------- Networks initialized -------------') 132 | networks.print_network(self.netG) 133 | if self.isTrain: 134 | if opt.with_D_PB: 135 | networks.print_network(self.netD_PB) 136 | if opt.with_D_PP: 137 | networks.print_network(self.netD_PP) 138 | print('-----------------------------------------------') 139 | 140 | 141 | def set_input(self, input): 142 | input_P1, input_BP1 = input['P1'], input['BP1'] 143 | input_P2, input_BP2 = input['P2'], input['BP2'] 144 | 145 | self.input_P1_set.resize_(input_P1.size()).copy_(input_P1) 146 | self.input_BP1_set.resize_(input_BP1.size()).copy_(input_BP1) 147 | self.input_P2_set.resize_(input_P2.size()).copy_(input_P2) 148 | self.input_BP2_set.resize_(input_BP2.size()).copy_(input_BP2) 149 | 150 | if self.use_BPD: 151 | input_BPD1, input_BPD2 = input['BPD1'], input['BPD2'] 152 | self.input_BPD1_set.resize_(input_BPD1.size()).copy_(input_BPD1) 153 | self.input_BPD2_set.resize_(input_BPD2.size()).copy_(input_BPD2) 154 | 155 | input_SP1 = input['SP1'] 156 | self.input_SP1_set.resize_(input_SP1.size()).copy_(input_SP1) 157 | if self.use_AMCE: 158 | input_SP2 = input['SP2'] 159 | self.input_SP2_set.resize_(input_SP2.size()).copy_(input_SP2) 160 | 161 | self.image_paths = input['P1_path'][0] + '___' + input['P2_path'][0] 162 | self.person_paths = input['P1_path'][0] 163 | 164 | 165 | def forward(self): 166 | 167 | self.input_P1 = Variable(self.input_P1_set) 168 | self.input_BP1 = Variable(self.input_BP1_set) 169 | 170 | self.input_P2 = Variable(self.input_P2_set) 171 | self.input_BP2 = Variable(self.input_BP2_set) 172 | 173 | if 
self.use_BPD: 174 | self.input_BPD1 = Variable(self.input_BPD1_set) 175 | self.input_BPD2 = Variable(self.input_BPD2_set) 176 | 177 | self.input_SP1 = Variable(self.input_SP1_set) 178 | self.input_SP2 = Variable(self.input_SP2_set) 179 | 180 | if self.use_BPD: 181 | self.fake_p2, self.fake_sp2 = self.netG(torch.cat([self.input_BP2, self.input_BPD2], 1), self.input_P1, self.input_SP1) 182 | else: 183 | self.fake_p2, self.fake_sp2 = self.netG(self.input_BP2, self.input_P1, self.input_SP1) 184 | 185 | 186 | def test(self): 187 | self.input_P1 = Variable(self.input_P1_set) 188 | self.input_BP1 = Variable(self.input_BP1_set) 189 | 190 | self.input_P2 = Variable(self.input_P2_set) 191 | self.input_BP2 = Variable(self.input_BP2_set) 192 | 193 | if self.use_BPD: 194 | self.input_BPD1 = Variable(self.input_BPD1_set) 195 | self.input_BPD2 = Variable(self.input_BPD2_set) 196 | 197 | self.input_SP1 = Variable(self.input_SP1_set) 198 | self.input_SP2 = Variable(self.input_SP2_set) 199 | 200 | 201 | if self.use_BPD: 202 | self.fake_p2, self.fake_sp2 = self.netG(torch.cat([self.input_BP2, self.input_BPD2], 1), self.input_P1, self.input_SP1) 203 | else: 204 | self.fake_p2, self.fake_sp2 = self.netG(self.input_BP2, self.input_P1, self.input_SP1) 205 | 206 | 207 | # get image paths 208 | def get_image_paths(self): 209 | return self.image_paths 210 | 211 | def get_person_paths(self): 212 | return self.person_paths 213 | 214 | 215 | def backward_G(self): 216 | if self.opt.with_D_PB: 217 | if self.use_BPD: 218 | pred_fake_PB = self.netD_PB(torch.cat((self.fake_p2, self.input_BP2, self.input_BPD2), 1)) 219 | else: 220 | pred_fake_PB = self.netD_PB(torch.cat((self.fake_p2, self.input_BP2), 1)) 221 | self.loss_G_GAN_PB = self.criterionGAN(pred_fake_PB, True) 222 | 223 | if self.opt.with_D_PP: 224 | pred_fake_PP = self.netD_PP(torch.cat((self.fake_p2, self.input_P1), 1)) 225 | self.loss_G_GAN_PP = self.criterionGAN(pred_fake_PP, True) 226 | 227 | # CX loss 228 | if self.opt.use_cxloss: 229 | style_layer = ['r32', 'r42'] 230 | vgg_style = self.vgg(self.input_P2, style_layer) 231 | vgg_fake = self.vgg(self.fake_p2, style_layer) 232 | cx_style_loss = 0 233 | 234 | for i, val in enumerate(vgg_fake): 235 | cx_style_loss += self.CX_loss(vgg_style[i], vgg_fake[i]) 236 | cx_style_loss *= self.opt.lambda_cx 237 | 238 | pair_cxloss = cx_style_loss 239 | 240 | if self.opt.use_lpips: 241 | lpips_loss = self.lpips_loss(self.fake_p2, self.input_P2) 242 | lpips_loss *= self.opt.lambda_lpips 243 | pair_lpips_loss = lpips_loss 244 | 245 | # Attention Map Cross Entropy loss 246 | if self.use_AMCE: 247 | up_ = torch.nn.Upsample(scale_factor=4, mode='bilinear') 248 | if isinstance(self.fake_sp2,list): 249 | AMCE_loss = 0 250 | B, C, H, W = self.input_SP2.shape 251 | for i in range(len(self.fake_sp2)): 252 | logits = up_(self.fake_sp2[i]) 253 | logits = torch.reshape(logits.permute(0,2,3,1), (B*H*W, C)) 254 | labels = torch.argmax(torch.reshape(self.input_SP2.permute(0,2,3,1), (B*H*W, C)), 1) 255 | AMCE_loss += self.AM_CE_loss(logits, labels) 256 | 257 | AMCE_loss *= self.opt.lambda_AMCE 258 | pair_AMCE_loss = AMCE_loss 259 | else: 260 | logits = up_(self.fake_sp2) 261 | B, C, H, W = self.input_SP2.shape 262 | logits = torch.reshape(logits.permute(0,2,3,1), (B*H*W, C)) 263 | labels = torch.argmax(torch.reshape(self.input_SP2.permute(0,2,3,1), (B*H*W, C)), 1) 264 | AMCE_loss = self.AM_CE_loss(logits, labels) 265 | AMCE_loss *= self.opt.lambda_AMCE 266 | pair_AMCE_loss = AMCE_loss 267 | 268 | self.opt.lambda_style = 200 269 | 
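# Note: the weights for the VGG perceptual terms are hard-coded at this point
# (lambda_style = 200 above, lambda_content = 0.5 below) and override whatever
# values were passed on the command line.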
self.opt.lambda_content = 0.5 270 | loss_content_gen, loss_style_gen = self.Vggloss(self.fake_p2, self.input_P2) 271 | pair_style_loss = loss_style_gen*self.opt.lambda_style 272 | pair_content_loss = loss_content_gen*self.opt.lambda_content 273 | 274 | 275 | 276 | # L1 loss 277 | if self.opt.L1_type == 'l1_plus_perL1' : 278 | losses = self.criterionL1(self.fake_p2, self.input_P2) 279 | self.loss_G_L1 = losses[0] 280 | self.loss_originL1 = losses[1].data 281 | self.loss_perceptual = losses[2].data 282 | 283 | else: 284 | self.loss_G_L1 = self.criterionL1(self.fake_p2, self.input_P2) * self.opt.lambda_A 285 | 286 | pair_L1loss = self.loss_G_L1 287 | 288 | if self.opt.with_D_PB: 289 | pair_GANloss = self.loss_G_GAN_PB * self.opt.lambda_GAN 290 | if self.opt.with_D_PP: 291 | pair_GANloss += self.loss_G_GAN_PP * self.opt.lambda_GAN 292 | pair_GANloss = pair_GANloss / 2 293 | else: 294 | if self.opt.with_D_PP: 295 | pair_GANloss = self.loss_G_GAN_PP * self.opt.lambda_GAN 296 | 297 | 298 | if self.opt.with_D_PB or self.opt.with_D_PP: 299 | pair_loss = pair_L1loss + pair_GANloss 300 | else: 301 | pair_loss = pair_L1loss 302 | 303 | if self.opt.use_cxloss: 304 | pair_loss = pair_loss + pair_cxloss 305 | if self.opt.use_AMCE: 306 | pair_loss = pair_loss + pair_AMCE_loss 307 | if self.opt.use_lpips: 308 | pair_loss = pair_loss + pair_lpips_loss 309 | 310 | pair_loss = pair_loss + pair_content_loss 311 | pair_loss = pair_loss + pair_style_loss 312 | 313 | pair_loss.backward() 314 | 315 | self.pair_L1loss = pair_L1loss.data 316 | if self.opt.with_D_PB or self.opt.with_D_PP: 317 | self.pair_GANloss = pair_GANloss.data 318 | 319 | if self.opt.use_cxloss: 320 | self.pair_cxloss = pair_cxloss.data 321 | 322 | if self.opt.use_lpips: 323 | self.pair_lpips_loss = pair_lpips_loss.data 324 | if self.opt.use_AMCE: 325 | self.pair_AMCE_loss = pair_AMCE_loss.data 326 | 327 | self.pair_content_loss = pair_content_loss.data 328 | self.pair_style_loss = pair_style_loss.data 329 | 330 | 331 | def backward_D_basic(self, netD, real, fake): 332 | # Real 333 | pred_real = netD(real) 334 | loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN 335 | # Fake 336 | pred_fake = netD(fake.detach()) 337 | loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN 338 | # Combined loss 339 | loss_D = (loss_D_real + loss_D_fake) * 0.5 340 | # backward 341 | loss_D.backward() 342 | return loss_D 343 | 344 | # D: take(P, B) as input 345 | def backward_D_PB(self): 346 | if self.use_BPD: 347 | real_PB = torch.cat((self.input_P2, self.input_BP2, self.input_BPD2), 1) 348 | fake_PB = self.fake_PB_pool.query( torch.cat((self.fake_p2, self.input_BP2, self.input_BPD2), 1).data ) 349 | else: 350 | real_PB = torch.cat((self.input_P2, self.input_BP2), 1) 351 | fake_PB = self.fake_PB_pool.query( torch.cat((self.fake_p2, self.input_BP2), 1).data ) 352 | loss_D_PB = self.backward_D_basic(self.netD_PB, real_PB, fake_PB) 353 | 354 | self.loss_D_PB = loss_D_PB.data 355 | 356 | # D: take(P, P') as input 357 | def backward_D_PP(self): 358 | real_PP = torch.cat((self.input_P2, self.input_P1), 1) 359 | fake_PP = self.fake_PP_pool.query( torch.cat((self.fake_p2, self.input_P1), 1).data ) 360 | loss_D_PP = self.backward_D_basic(self.netD_PP, real_PP, fake_PP) 361 | 362 | self.loss_D_PP = loss_D_PP.data 363 | 364 | 365 | def optimize_parameters(self): 366 | # forward 367 | self.forward() 368 | 369 | self.optimizer_G.zero_grad() 370 | self.backward_G() 371 | self.optimizer_G.step() 372 | 373 | # D_P 374 | if self.opt.with_D_PP: 
375 | for i in range(self.opt.DG_ratio): 376 | self.optimizer_D_PP.zero_grad() 377 | self.backward_D_PP() 378 | self.optimizer_D_PP.step() 379 | 380 | # D_BP 381 | if self.opt.with_D_PB: 382 | for i in range(self.opt.DG_ratio): 383 | self.optimizer_D_PB.zero_grad() 384 | self.backward_D_PB() 385 | self.optimizer_D_PB.step() 386 | 387 | def get_current_errors(self): 388 | ret_errors = OrderedDict([ ('pair_L1loss', self.pair_L1loss)]) 389 | if self.opt.with_D_PP: 390 | ret_errors['D_PP'] = self.loss_D_PP 391 | if self.opt.with_D_PB: 392 | ret_errors['D_PB'] = self.loss_D_PB 393 | if self.opt.with_D_PB or self.opt.with_D_PP or self.opt.with_D_PS: 394 | ret_errors['pair_GANloss'] = self.pair_GANloss 395 | 396 | if self.opt.L1_type == 'l1_plus_perL1': 397 | ret_errors['origin_L1'] = self.loss_originL1 398 | ret_errors['perceptual'] = self.loss_perceptual 399 | 400 | if self.opt.use_cxloss: 401 | ret_errors['CXLoss'] = self.pair_cxloss 402 | if self.opt.use_lpips: 403 | ret_errors['lpips'] = self.pair_lpips_loss 404 | if self.opt.use_AMCE: 405 | ret_errors['AMCE'] = self.pair_AMCE_loss 406 | 407 | ret_errors['content'] = self.pair_content_loss 408 | ret_errors['style'] = self.pair_style_loss 409 | 410 | return ret_errors 411 | 412 | def get_current_visuals(self): 413 | height, width = self.input_P1.size(2), self.input_P1.size(3) 414 | input_P1 = util.tensor2im(self.input_P1.data) 415 | input_P2 = util.tensor2im(self.input_P2.data) 416 | 417 | input_BP1 = util.draw_pose_from_map(self.input_BP1.data)[0] 418 | input_BP2 = util.draw_pose_from_map(self.input_BP2.data)[0] 419 | 420 | 421 | if self.use_BPD: 422 | input_BPD1 = util.draw_dis_from_map(self.input_BP1.data)[1] 423 | input_BPD1 = (np.repeat(np.expand_dims(input_BPD1, -1), 3, -1)*255).astype('uint8') 424 | input_BPD2 = util.draw_dis_from_map(self.input_BP2.data)[1] 425 | input_BPD2 = (np.repeat(np.expand_dims(input_BPD2, -1), 3, -1)*255).astype('uint8') 426 | 427 | 428 | fake_p2 = util.tensor2im(self.fake_p2.data) 429 | 430 | if self.use_BPD: 431 | vis = np.zeros((height, width*7, 3)).astype(np.uint8) #h, w, c 432 | vis[:, :width, :] = input_P1 433 | vis[:, width:width*2, :] = input_BP1 434 | vis[:, width*2:width*3, :] = input_BPD1 435 | vis[:, width*3:width*4, :] = input_P2 436 | vis[:, width*4:width*5, :] = input_BP2 437 | vis[:, width*5:width*6, :] = input_BPD2 438 | vis[:, width*6:width*7, :] = fake_p2 439 | else: 440 | vis = np.zeros((height, width*5, 3)).astype(np.uint8) #h, w, c 441 | vis[:, :width, :] = input_P1 442 | vis[:, width:width*2, :] = input_BP1 443 | vis[:, width*2:width*3, :] = input_P2 444 | vis[:, width*3:width*4, :] = input_BP2 445 | vis[:, width*4:, :] = fake_p2 446 | 447 | ret_visuals = OrderedDict([('vis', vis)]) 448 | 449 | return ret_visuals 450 | 451 | 452 | def save(self, label): 453 | self.save_network(self.netG, 'netG', label, self.gpu_ids) 454 | if self.opt.with_D_PB: 455 | self.save_network(self.netD_PB, 'netD_PB', label, self.gpu_ids) 456 | if self.opt.with_D_PP: 457 | self.save_network(self.netD_PP, 'netD_PP', label, self.gpu_ids) 458 | 459 | 460 | 461 | 462 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models.vgg as models 5 | 6 | class BaseModel(nn.Module): 7 | 8 | def __init__(self): 9 | super(BaseModel, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def 
initialize(self, opt): 15 | self.opt = opt 16 | self.gpu_ids = opt.gpu_ids 17 | self.isTrain = opt.isTrain 18 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 19 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 20 | self.vgg_path = os.path.join(os.path.abspath(opt.dataroot), 'vgg19-dcbb9e9d.pth') 21 | 22 | def set_input(self, input): 23 | self.input = input 24 | 25 | def forward(self): 26 | pass 27 | 28 | # used in test time, no backprop 29 | def test(self): 30 | pass 31 | 32 | def get_image_paths(self): 33 | pass 34 | 35 | def optimize_parameters(self): 36 | pass 37 | 38 | def get_current_visuals(self): 39 | return self.input 40 | 41 | def get_current_errors(self): 42 | return {} 43 | 44 | def save(self, label): 45 | pass 46 | 47 | # helper saving function that can be used by subclasses 48 | def save_network(self, network, network_label, epoch_label, gpu_ids): 49 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 50 | save_path = os.path.join(self.save_dir, save_filename) 51 | torch.save(network.cpu().state_dict(), save_path) 52 | if len(gpu_ids) and torch.cuda.is_available(): 53 | network.cuda(gpu_ids[0]) 54 | 55 | # helper loading function that can be used by subclasses 56 | def load_network(self, network, network_label, epoch_label): 57 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 58 | save_path = os.path.join(self.save_dir, save_filename) 59 | # network.load_state_dict(torch.load(save_path)) 60 | 61 | model_dict = torch.load(save_path) 62 | model_dict_clone = model_dict.copy() # We can't mutate while iterating 63 | for key, value in model_dict_clone.items(): 64 | if key.endswith(('running_mean', 'running_var')): 65 | del model_dict[key] 66 | ### Next cell 67 | network.load_state_dict(model_dict, False) 68 | 69 | def load_VGG(self, network): 70 | # pretrained_dict = torch.load(self.vgg_path) 71 | 72 | # pretrained_model = models.vgg19(pretrained=True).features 73 | vgg19 = models.vgg19(pretrained=False) 74 | vgg19.load_state_dict(torch.load(self.vgg_path)) 75 | pretrained_model = vgg19.features 76 | 77 | pretrained_dict = pretrained_model.state_dict() 78 | 79 | model_dict = network.state_dict() 80 | 81 | # filter out unnecessary keys 82 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 83 | # overwrite entries in the existing state dict 84 | model_dict.update(pretrained_dict) 85 | # load the new state dict 86 | network.load_state_dict(model_dict) 87 | 88 | # update learning rate (called once every epoch) 89 | def update_learning_rate(self): 90 | for scheduler in self.schedulers: 91 | scheduler.step() 92 | lr = self.optimizers[0].param_groups[0]['lr'] 93 | print('learning rate = %.7f' % lr) 94 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | 2 | def create_model(opt): 3 | model = None 4 | print(opt.model) 5 | if opt.model == 'adgan': 6 | assert opt.dataset_mode == 'keypoint' 7 | from .adgan import TransferModel 8 | model = TransferModel() 9 | elif opt.model == 'adgan_mix': 10 | assert opt.dataset_mode == 'keypoint_mix' 11 | from .adgan_mix import TransferModel 12 | model = TransferModel() 13 | else: 14 | raise ValueError("Model [%s] not recognized." 
% opt.model) 15 | model.initialize(opt) 16 | print("model [%s] was created" % (model.name())) 17 | return model 18 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.autograd import Variable 6 | from torch.optim import lr_scheduler 7 | 8 | import math 9 | 10 | # added 11 | def weights_init_ada(init_type='gaussian'): 12 | def init_fun(m): 13 | classname = m.__class__.__name__ 14 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 15 | # print m.__class__.__name__ 16 | if init_type == 'gaussian': 17 | init.normal_(m.weight.data, 0.0, 0.02) 18 | elif init_type == 'xavier': 19 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 20 | elif init_type == 'kaiming': 21 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 22 | elif init_type == 'orthogonal': 23 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 24 | elif init_type == 'default': 25 | pass 26 | else: 27 | assert 0, "Unsupported initialization: {}".format(init_type) 28 | if hasattr(m, 'bias') and m.bias is not None: 29 | init.constant_(m.bias.data, 0.0) 30 | return init_fun 31 | 32 | 33 | def weights_init_normal(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Conv') != -1 and hasattr(m, 'weight'): 36 | init.normal(m.weight.data, 0.0, 0.02) 37 | elif classname.find('Linear') != -1 and hasattr(m, 'weight'): 38 | init.normal(m.weight.data, 0.0, 0.02) 39 | elif classname.find('BatchNorm2d') != -1: 40 | init.normal(m.weight.data, 1.0, 0.02) 41 | init.constant(m.bias.data, 0.0) 42 | 43 | 44 | def weights_init_xavier(m): 45 | classname = m.__class__.__name__ 46 | # print(classname) 47 | if classname.find('Conv') != -1: 48 | init.xavier_normal(m.weight.data, gain=0.02) 49 | elif classname.find('Linear') != -1: 50 | init.xavier_normal(m.weight.data, gain=0.02) 51 | elif classname.find('BatchNorm2d') != -1: 52 | init.normal(m.weight.data, 1.0, 0.02) 53 | init.constant(m.bias.data, 0.0) 54 | 55 | 56 | def weights_init_kaiming(m): 57 | classname = m.__class__.__name__ 58 | # print(classname) 59 | if classname.find('Conv') != -1: 60 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 61 | elif classname.find('Linear') != -1: 62 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 63 | elif classname.find('BatchNorm2d') != -1: 64 | init.normal(m.weight.data, 1.0, 0.02) 65 | init.constant(m.bias.data, 0.0) 66 | 67 | 68 | def weights_init_orthogonal(m): 69 | classname = m.__class__.__name__ 70 | print(classname) 71 | if classname.find('Conv') != -1: 72 | init.orthogonal(m.weight.data, gain=1) 73 | elif classname.find('Linear') != -1: 74 | init.orthogonal(m.weight.data, gain=1) 75 | elif classname.find('BatchNorm2d') != -1: 76 | init.normal(m.weight.data, 1.0, 0.02) 77 | init.constant(m.bias.data, 0.0) 78 | 79 | 80 | def init_weights(net, init_type='normal'): 81 | print('initialization method [%s]' % init_type) 82 | if init_type == 'normal': 83 | net.apply(weights_init_normal) 84 | elif init_type == 'xavier': 85 | net.apply(weights_init_xavier) 86 | elif init_type == 'kaiming': 87 | net.apply(weights_init_kaiming) 88 | elif init_type == 'orthogonal': 89 | net.apply(weights_init_orthogonal) 90 | else: 91 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 92 | 93 | 94 | def 
get_norm_layer(norm_type='instance'): 95 | if norm_type == 'batch': 96 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 97 | elif norm_type == 'batch_sync': 98 | norm_layer = BatchNorm2d 99 | elif norm_type == 'instance': 100 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 101 | elif norm_type == 'none': 102 | norm_layer = None 103 | else: 104 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 105 | return norm_layer 106 | 107 | 108 | def get_scheduler(optimizer, opt): 109 | if opt.lr_policy == 'lambda': 110 | def lambda_rule(epoch): 111 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 112 | return lr_l 113 | 114 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 115 | elif opt.lr_policy == 'step': 116 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 117 | elif opt.lr_policy == 'plateau': 118 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 119 | else: 120 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 121 | return scheduler 122 | 123 | 124 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', 125 | gpu_ids=[], n_downsampling=2): 126 | netG = None 127 | use_gpu = len(gpu_ids) > 0 128 | norm_layer = get_norm_layer(norm_type=norm) 129 | 130 | if use_gpu: 131 | assert (torch.cuda.is_available()) 132 | 133 | if which_model_netG == 'CASD': 134 | style_dim = 2048 135 | n_res = 8 136 | mlp_dim = 256 137 | from models.CASD import ADGen 138 | netG = ADGen(input_nc, ngf, style_dim, n_downsampling, n_res, mlp_dim) 139 | else: 140 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) 141 | if len(gpu_ids) > 1: 142 | netG = torch.nn.DataParallel(netG, device_ids=gpu_ids) 143 | netG.cuda() 144 | init_weights(netG, init_type=init_type) 145 | return netG 146 | 147 | 148 | class AttrDict(dict): 149 | 150 | def __init__(self,*args,**kwargs): 151 | super().__init__(*args,**kwargs) 152 | 153 | def operation_list(self,value): 154 | new_value = [] 155 | for v in value: 156 | if isinstance(v, dict): 157 | new_value.append(AttrDict(v)) 158 | elif isinstance(v,list): 159 | new_value.append(self.operation_list(v)) 160 | else: 161 | new_value.append(v) 162 | return new_value 163 | 164 | def __getattr__(self, item): 165 | value=self[item] 166 | if isinstance(value,dict): 167 | value=AttrDict(value) 168 | elif isinstance(value,list): 169 | value=self.operation_list(value) 170 | return value 171 | 172 | 173 | def define_D(input_nc, ndf, which_model_netD, 174 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[], use_dropout=False, 175 | n_downsampling=2): 176 | netD = None 177 | use_gpu = len(gpu_ids) > 0 178 | norm_layer = get_norm_layer(norm_type=norm) 179 | 180 | if use_gpu: 181 | assert (torch.cuda.is_available()) 182 | 183 | if which_model_netD == 'resnet': 184 | netD = ResnetDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=n_layers_D, 185 | gpu_ids=[], padding_type='reflect', use_sigmoid=use_sigmoid, 186 | n_downsampling=n_downsampling) 187 | else: 188 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % 189 | which_model_netD) 190 | if len(gpu_ids) > 1: 191 | netD = torch.nn.DataParallel(netD, device_ids=gpu_ids) 192 | netD.cuda() 193 | return netD 194 | 195 | 196 | 
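# Illustrative usage sketch of the factory functions above: TransferModel in adgan.py
# builds its networks roughly as follows. The channel counts (3 RGB channels, 18 keypoint
# heatmaps) and the ngf/ndf value of 64 are assumptions based on the usual option defaults.
#
#   netG = define_G([3, 18 + 18], 3, 64, 'CASD', norm='instance',
#                   use_dropout=True, init_type='normal', gpu_ids=[0], n_downsampling=2)
#   netD_PB = define_D(3 + 18, 64, 'resnet', n_layers_D=3, norm='batch',
#                      use_sigmoid=False, init_type='normal', gpu_ids=[0], n_downsampling=2)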
def print_network(net): 197 | num_params = 0 198 | for param in net.parameters(): 199 | num_params += param.numel() 200 | print(net) 201 | print('Total number of parameters: %d' % num_params) 202 | 203 | 204 | ############################################################################## 205 | # Classes 206 | ############################################################################## 207 | 208 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 209 | # When LSGAN is used, it is basically same as MSELoss, 210 | # but it abstracts away the need to create the target label tensor 211 | # that has the same size as the input 212 | class GANLoss(nn.Module): 213 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 214 | tensor=torch.FloatTensor): 215 | super(GANLoss, self).__init__() 216 | self.real_label = target_real_label 217 | self.fake_label = target_fake_label 218 | self.real_label_var = None 219 | self.fake_label_var = None 220 | self.Tensor = tensor 221 | if use_lsgan: 222 | self.loss = nn.MSELoss() 223 | else: 224 | self.loss = nn.BCELoss() 225 | 226 | def get_target_tensor(self, input, target_is_real): 227 | target_tensor = None 228 | if target_is_real: 229 | create_label = ((self.real_label_var is None) or 230 | (self.real_label_var.numel() != input.numel())) 231 | if create_label: 232 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 233 | self.real_label_var = Variable(real_tensor, requires_grad=False) 234 | target_tensor = self.real_label_var 235 | else: 236 | create_label = ((self.fake_label_var is None) or 237 | (self.fake_label_var.numel() != input.numel())) 238 | if create_label: 239 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 240 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 241 | target_tensor = self.fake_label_var 242 | return target_tensor 243 | 244 | def __call__(self, input, target_is_real): 245 | target_tensor = self.get_target_tensor(input, target_is_real) 246 | return self.loss(input, target_tensor) 247 | 248 | 249 | 250 | # Define a resnet block 251 | class ResnetBlock(nn.Module): 252 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 253 | super(ResnetBlock, self).__init__() 254 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 255 | 256 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 257 | conv_block = [] 258 | p = 0 259 | if padding_type == 'reflect': 260 | conv_block += [nn.ReflectionPad2d(1)] 261 | elif padding_type == 'replicate': 262 | conv_block += [nn.ReplicationPad2d(1)] 263 | elif padding_type == 'zero': 264 | p = 1 265 | else: 266 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 267 | 268 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 269 | norm_layer(dim), 270 | nn.ReLU(True)] 271 | if use_dropout: 272 | conv_block += [nn.Dropout(0.5)] 273 | 274 | p = 0 275 | if padding_type == 'reflect': 276 | conv_block += [nn.ReflectionPad2d(1)] 277 | elif padding_type == 'replicate': 278 | conv_block += [nn.ReplicationPad2d(1)] 279 | elif padding_type == 'zero': 280 | p = 1 281 | else: 282 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 283 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 284 | norm_layer(dim)] 285 | 286 | return nn.Sequential(*conv_block) 287 | 288 | def forward(self, x): 289 | out = x + self.conv_block(x) 290 | return out 
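# ResnetBlock above applies two 3x3 convolutions (with the chosen padding type, norm layer,
# and optional dropout) and adds the input back as a skip connection; ResnetDiscriminator
# below stacks n_blocks of these blocks after its downsampling stem.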
291 | 292 | class ResnetDiscriminator(nn.Module): 293 | def __init__(self, input_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], 294 | padding_type='reflect', use_sigmoid=False, n_downsampling=2): 295 | assert (n_blocks >= 0) 296 | super(ResnetDiscriminator, self).__init__() 297 | self.input_nc = input_nc 298 | self.ngf = ngf 299 | self.gpu_ids = gpu_ids 300 | if type(norm_layer) == functools.partial: 301 | use_bias = norm_layer.func == nn.InstanceNorm2d 302 | else: 303 | use_bias = norm_layer == nn.InstanceNorm2d 304 | 305 | model = [nn.ReflectionPad2d(3), 306 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 307 | bias=use_bias), 308 | norm_layer(ngf), 309 | nn.ReLU(True)] 310 | 311 | # n_downsampling = 2 312 | if n_downsampling <= 2: 313 | for i in range(n_downsampling): 314 | mult = 2 ** i 315 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 316 | stride=2, padding=1, bias=use_bias), 317 | norm_layer(ngf * mult * 2), 318 | nn.ReLU(True)] 319 | elif n_downsampling == 3: 320 | mult = 2 ** 0 321 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 322 | stride=2, padding=1, bias=use_bias), 323 | norm_layer(ngf * mult * 2), 324 | nn.ReLU(True)] 325 | mult = 2 ** 1 326 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 327 | stride=2, padding=1, bias=use_bias), 328 | norm_layer(ngf * mult * 2), 329 | nn.ReLU(True)] 330 | mult = 2 ** 2 331 | model += [nn.Conv2d(ngf * mult, ngf * mult, kernel_size=3, 332 | stride=2, padding=1, bias=use_bias), 333 | norm_layer(ngf * mult), 334 | nn.ReLU(True)] 335 | 336 | if n_downsampling <= 2: 337 | mult = 2 ** n_downsampling 338 | else: 339 | mult = 4 340 | for i in range(n_blocks): 341 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, 342 | use_bias=use_bias)] 343 | 344 | if use_sigmoid: 345 | model += [nn.Sigmoid()] 346 | 347 | self.model = nn.Sequential(*model) 348 | 349 | def forward(self, input): 350 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): 351 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 352 | else: 353 | return self.model(input) 354 | 355 | 356 | -------------------------------------------------------------------------------- /models/test_model.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from collections import OrderedDict 3 | import util.util as util 4 | from .base_model import BaseModel 5 | from . 
import networks 6 | 7 | 8 | class TestModel(BaseModel): 9 | def name(self): 10 | return 'TestModel' 11 | 12 | def initialize(self, opt): 13 | assert(not opt.isTrain) 14 | BaseModel.initialize(self, opt) 15 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 16 | 17 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, 18 | opt.ngf, opt.which_model_netG, 19 | opt.norm, not opt.no_dropout, 20 | opt.init_type, 21 | self.gpu_ids) 22 | which_epoch = opt.which_epoch 23 | self.load_network(self.netG, 'G', which_epoch) 24 | 25 | print('---------- Networks initialized -------------') 26 | networks.print_network(self.netG) 27 | print('-----------------------------------------------') 28 | 29 | def set_input(self, input): 30 | # we need to use single_dataset mode 31 | input_A = input['A'] 32 | self.input_A.resize_(input_A.size()).copy_(input_A) 33 | self.image_paths = input['A_paths'] 34 | 35 | def test(self): 36 | self.real_A = Variable(self.input_A) 37 | self.fake_B = self.netG(self.real_A) 38 | 39 | # get image paths 40 | def get_image_paths(self): 41 | return self.image_paths 42 | 43 | def get_current_visuals(self): 44 | real_A = util.tensor2im(self.real_A.data) 45 | fake_B = util.tensor2im(self.fake_B.data) 46 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) 47 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # gram matrix and loss 7 | class GramMatrix(nn.Module): 8 | def forward(self, input): 9 | b, c, h, w = input.size() 10 | F = input.view(b, c, h * w) 11 | G = torch.bmm(F, F.transpose(1, 2)) 12 | G.div_(h * w) 13 | return G 14 | 15 | 16 | class GramMSELoss(nn.Module): 17 | def forward(self, input, target): 18 | out = nn.MSELoss()(GramMatrix()(input), target) 19 | return (out) 20 | 21 | 22 | # vgg definition that conveniently let's you grab the outputs from any layer 23 | class VGG(nn.Module): 24 | def __init__(self, pool='max'): 25 | super(VGG, self).__init__() 26 | # vgg modules 27 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 28 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 29 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 30 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 34 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 38 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 42 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | if pool == 'max': 44 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 45 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 46 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 47 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 48 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 49 | elif 
pool == 'avg': 50 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 51 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 52 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 53 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 54 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 55 | 56 | def forward(self, x, out_keys): 57 | out = {} 58 | out['r11'] = F.relu(self.conv1_1(x)) 59 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 60 | out['p1'] = self.pool1(out['r12']) 61 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 62 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 63 | out['p2'] = self.pool2(out['r22']) 64 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 65 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 66 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 67 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 68 | out['p3'] = self.pool3(out['r34']) 69 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 70 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 71 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 72 | out['r44'] = F.relu(self.conv4_4(out['r43'])) 73 | out['p4'] = self.pool4(out['r44']) 74 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 75 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 76 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 77 | out['r54'] = F.relu(self.conv5_4(out['r53'])) 78 | out['p5'] = self.pool5(out['r54']) 79 | return [out[key] for key in out_keys] 80 | 81 | -------------------------------------------------------------------------------- /models/vgg_SC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.models.vgg as models 6 | 7 | 8 | 9 | # gram matrix and loss 10 | class GramMatrix(nn.Module): 11 | def forward(self, input): 12 | b, c, h, w = input.size() 13 | F = input.view(b, c, h * w) 14 | G = torch.bmm(F, F.transpose(1, 2)) 15 | G.div_(h * w) 16 | return G 17 | 18 | 19 | class GramMSELoss(nn.Module): 20 | def forward(self, input, target): 21 | out = nn.MSELoss()(GramMatrix()(input), target) 22 | return (out) 23 | 24 | 25 | # vgg definition that conveniently let's you grab the outputs from any layer 26 | class VGG(nn.Module): 27 | def __init__(self, pool='max'): 28 | super(VGG, self).__init__() 29 | # vgg modules 30 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 31 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 32 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 33 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 34 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 35 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 36 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 37 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 38 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 39 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 41 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 42 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 44 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 45 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 46 | if pool == 'max': 47 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 48 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 49 | 
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 50 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 51 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 52 | elif pool == 'avg': 53 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 54 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 55 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 56 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 57 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 58 | 59 | def forward(self, x, out_keys): 60 | out = {} 61 | out['r11'] = F.relu(self.conv1_1(x)) 62 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 63 | out['p1'] = self.pool1(out['r12']) 64 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 65 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 66 | out['p2'] = self.pool2(out['r22']) 67 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 68 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 69 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 70 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 71 | out['p3'] = self.pool3(out['r34']) 72 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 73 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 74 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 75 | out['r44'] = F.relu(self.conv4_4(out['r43'])) 76 | out['p4'] = self.pool4(out['r44']) 77 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 78 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 79 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 80 | out['r54'] = F.relu(self.conv5_4(out['r53'])) 81 | out['p5'] = self.pool5(out['r54']) 82 | return [out[key] for key in out_keys] 83 | 84 | 85 | class VGGLoss(nn.Module): 86 | r""" 87 | Perceptual loss, VGG-based 88 | https://arxiv.org/abs/1603.08155 89 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 90 | """ 91 | 92 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 93 | super(VGGLoss, self).__init__() 94 | self.add_module('vgg', VGG19()) 95 | self.criterion = torch.nn.L1Loss() 96 | self.weights = weights 97 | 98 | def compute_gram(self, x): 99 | b, ch, h, w = x.size() 100 | f = x.view(b, ch, w * h) 101 | f_T = f.transpose(1, 2) 102 | G = f.bmm(f_T) / (h * w * ch) 103 | return G 104 | 105 | def __call__(self, x, y): 106 | # Compute features 107 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 108 | 109 | content_loss = 0.0 110 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 111 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 112 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 113 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 114 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 115 | 116 | # Compute loss 117 | style_loss = 0.0 118 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 119 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) 120 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) 121 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 122 | 123 | return content_loss, style_loss 124 | 125 | 126 | class StyleLoss(nn.Module): 127 | r""" 128 | Perceptual loss, VGG-based 129 | https://arxiv.org/abs/1603.08155 130 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 131 | """ 132 | 133 | def __init__(self): 134 | super(StyleLoss, self).__init__() 
135 | self.add_module('vgg', VGG19()) 136 | self.criterion = torch.nn.L1Loss() 137 | 138 | def compute_gram(self, x): 139 | b, ch, h, w = x.size() 140 | f = x.view(b, ch, w * h) 141 | f_T = f.transpose(1, 2) 142 | G = f.bmm(f_T) / (h * w * ch) 143 | 144 | return G 145 | 146 | def __call__(self, x, y): 147 | # Compute features 148 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 149 | 150 | # Compute loss 151 | style_loss = 0.0 152 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 153 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) 154 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) 155 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 156 | 157 | return style_loss 158 | 159 | 160 | class PerceptualLoss(nn.Module): 161 | r""" 162 | Perceptual loss, VGG-based 163 | https://arxiv.org/abs/1603.08155 164 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 165 | """ 166 | 167 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 168 | super(PerceptualLoss, self).__init__() 169 | self.add_module('vgg', VGG19()) 170 | self.criterion = torch.nn.L1Loss() 171 | self.weights = weights 172 | 173 | def __call__(self, x, y): 174 | # Compute features 175 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 176 | content_loss = 0.0 177 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1']) 178 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1']) 179 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1']) 180 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1']) 181 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1']) 182 | 183 | return content_loss 184 | 185 | class VGG19(torch.nn.Module): 186 | def __init__(self): 187 | super(VGG19, self).__init__() 188 | # features = models.vgg19(pretrained=True).features 189 | 190 | vgg19 = models.vgg19(pretrained=False) 191 | vgg19.load_state_dict(torch.load('./dataset/fashion/vgg19-dcbb9e9d.pth')) 192 | self.vgg = vgg19.features 193 | features = vgg19.features 194 | 195 | for param in self.vgg.parameters(): 196 | param.requires_grad_(False) 197 | 198 | 199 | self.relu1_1 = torch.nn.Sequential() 200 | self.relu1_2 = torch.nn.Sequential() 201 | 202 | self.relu2_1 = torch.nn.Sequential() 203 | self.relu2_2 = torch.nn.Sequential() 204 | 205 | self.relu3_1 = torch.nn.Sequential() 206 | self.relu3_2 = torch.nn.Sequential() 207 | self.relu3_3 = torch.nn.Sequential() 208 | self.relu3_4 = torch.nn.Sequential() 209 | 210 | self.relu4_1 = torch.nn.Sequential() 211 | self.relu4_2 = torch.nn.Sequential() 212 | self.relu4_3 = torch.nn.Sequential() 213 | self.relu4_4 = torch.nn.Sequential() 214 | 215 | self.relu5_1 = torch.nn.Sequential() 216 | self.relu5_2 = torch.nn.Sequential() 217 | self.relu5_3 = torch.nn.Sequential() 218 | self.relu5_4 = torch.nn.Sequential() 219 | 220 | for x in range(2): 221 | self.relu1_1.add_module(str(x), features[x]) 222 | 223 | for x in range(2, 4): 224 | self.relu1_2.add_module(str(x), features[x]) 225 | 226 | for x in range(4, 7): 227 | self.relu2_1.add_module(str(x), features[x]) 228 | 229 | for x in range(7, 9): 230 | self.relu2_2.add_module(str(x), features[x]) 231 | 232 | for x in range(9, 12): 233 | self.relu3_1.add_module(str(x), features[x]) 234
| 235 | for x in range(12, 14): 236 | self.relu3_2.add_module(str(x), features[x]) 237 | 238 | for x in range(14, 16): 239 | self.relu3_3.add_module(str(x), features[x]) 240 | 241 | for x in range(16, 18): 242 | self.relu3_4.add_module(str(x), features[x]) 243 | 244 | for x in range(18, 21): 245 | self.relu4_1.add_module(str(x), features[x]) 246 | 247 | for x in range(21, 23): 248 | self.relu4_2.add_module(str(x), features[x]) 249 | 250 | for x in range(23, 25): 251 | self.relu4_3.add_module(str(x), features[x]) 252 | 253 | for x in range(25, 27): 254 | self.relu4_4.add_module(str(x), features[x]) 255 | 256 | for x in range(27, 30): 257 | self.relu5_1.add_module(str(x), features[x]) 258 | 259 | for x in range(30, 32): 260 | self.relu5_2.add_module(str(x), features[x]) 261 | 262 | for x in range(32, 34): 263 | self.relu5_3.add_module(str(x), features[x]) 264 | 265 | for x in range(34, 36): 266 | self.relu5_4.add_module(str(x), features[x]) 267 | 268 | # don't need the gradients, just want the features 269 | for param in self.parameters(): 270 | param.requires_grad = False 271 | 272 | def forward(self, x): 273 | relu1_1 = self.relu1_1(x) 274 | relu1_2 = self.relu1_2(relu1_1) 275 | 276 | relu2_1 = self.relu2_1(relu1_2) 277 | relu2_2 = self.relu2_2(relu2_1) 278 | 279 | relu3_1 = self.relu3_1(relu2_2) 280 | relu3_2 = self.relu3_2(relu3_1) 281 | relu3_3 = self.relu3_3(relu3_2) 282 | relu3_4 = self.relu3_4(relu3_3) 283 | 284 | relu4_1 = self.relu4_1(relu3_4) 285 | relu4_2 = self.relu4_2(relu4_1) 286 | relu4_3 = self.relu4_3(relu4_2) 287 | relu4_4 = self.relu4_4(relu4_3) 288 | 289 | relu5_1 = self.relu5_1(relu4_4) 290 | relu5_2 = self.relu5_2(relu5_1) 291 | relu5_3 = self.relu5_3(relu5_2) 292 | relu5_4 = self.relu5_4(relu5_3) 293 | 294 | out = { 295 | 'relu1_1': relu1_1, 296 | 'relu1_2': relu1_2, 297 | 298 | 'relu2_1': relu2_1, 299 | 'relu2_2': relu2_2, 300 | 301 | 'relu3_1': relu3_1, 302 | 'relu3_2': relu3_2, 303 | 'relu3_3': relu3_3, 304 | 'relu3_4': relu3_4, 305 | 306 | 'relu4_1': relu4_1, 307 | 'relu4_2': relu4_2, 308 | 'relu4_3': relu4_3, 309 | 'relu4_4': relu4_4, 310 | 311 | 'relu5_1': relu5_1, 312 | 'relu5_2': relu5_2, 313 | 'relu5_3': relu5_3, 314 | 'relu5_4': relu5_4, 315 | } 316 | return out -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | 7 | class BaseOptions(): 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialized = False 11 | 12 | def initialize(self): 13 | self.parser.add_argument('--dataroot', default='./dataset/fashion',\ 14 | help='path to images ') 15 | self.parser.add_argument('--dirSem', default='./dataset/fashion',\ 16 | help='path to semantic images') 17 | 18 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 19 | self.parser.add_argument('--which_model_netG', type=str, default='CASD', help='selects model to use for netG') 20 | self.parser.add_argument('--name', type=str, 21 | default='CASD_test', 22 | help='name of the experiment.
It decides where to store samples and models') 23 | self.parser.add_argument('--fineSize', type=int, default=[256,256], help='input image size') 24 | self.parser.add_argument('--pairLst', type=str, default='./dataset/fashion/fashion-resize-pairs-train.csv', help='fashion pairs') 25 | # self.parser.add_argument('--pairLst', type=str, default='./dataset/fashion/fashion-resize-pairs-test.csv', help='fashion pairs') 26 | 27 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 28 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 29 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 30 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 31 | self.parser.add_argument('--which_model_netD', type=str, default='resnet', help='selects model to use for netD') 32 | self.parser.add_argument('--n_layers_D', type=int, default=0, help='blocks used in D') 33 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 34 | self.parser.add_argument('--dataset_mode', type=str, default='keypoint', help='chooses how datasets are loaded. [unaligned | aligned | single]') 35 | self.parser.add_argument('--model', type=str, default='adgan', 36 | help='chooses which model to use. cycle_gan, pix2pix, test') 37 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 39 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 40 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 41 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 42 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 43 | self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display') 44 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 45 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 46 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 47 | self.parser.add_argument('--resize_or_crop', type=str, default='no', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 48 | self.parser.add_argument('--no_flip', default=False, help='if true, do not flip the images for data augmentation') 49 | self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 50 | 51 | self.parser.add_argument('--P_input_nc', type=int, default=3, help='# of input image channels') 52 | self.parser.add_argument('--BP_input_nc', type=int, default=18, help='# of pose keypoint heatmap channels') 53 | self.parser.add_argument('--BPD_input_nc', type=int, default=12, help='# of pose bone-distance map channels') 54 | self.parser.add_argument('--SP_input_nc', type=int, default=8, help='# of semantic parsing channels') 55 | self.parser.add_argument('--with_D_PP', type=int, default=1, help='use D to judge P and P is pair or not') 56 | self.parser.add_argument('--with_D_PB', type=int, default=1, help='use D to judge P and B is pair or not') 57 | self.parser.add_argument('--without_concat_SBP', type=int, default=0, help='do not concat source BP') 58 | self.parser.add_argument('--use_flip', type=int, default=0, help='flip or not') 59 | 60 | # down-sampling times 61 | self.parser.add_argument('--G_n_downsampling', type=int, default=2, help='down-sampling blocks for generator') 62 | self.parser.add_argument('--D_n_downsampling', type=int, default=2, help='down-sampling blocks for discriminator') 63 | self.parser.add_argument('--use_AMCE', type=int, default=1, help='use AMCE loss or not') 64 | self.parser.add_argument('--use_BPD', type=int, default=1, help='use BPD pose-distance maps or not') 65 | self.parser.add_argument('--use_lpips', type=int, default=1, help='use LPIPS loss or not') 66 | 67 | self.initialized = True 68 | 69 | def parse(self): 70 | if not self.initialized: 71 | self.initialize() 72 | self.opt = self.parser.parse_args() 73 | self.opt.isTrain = self.isTrain # train or test 74 | 75 | str_ids = self.opt.gpu_ids.split(',') 76 | self.opt.gpu_ids = [] 77 | for str_id in str_ids: 78 | id = int(str_id) 79 | if id >= 0: 80 | self.opt.gpu_ids.append(id) 81 | 82 | # set gpu ids 83 | # num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__() 84 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3' 85 | if len(self.opt.gpu_ids) > 0: 86 | torch.cuda.set_device(self.opt.gpu_ids[0]) 87 | 88 | args = vars(self.opt) 89 | 90 | print('------------ Options -------------') 91 | for k, v in sorted(args.items()): 92 | print('%s: %s' % (str(k), str(v))) 93 | print('-------------- End ----------------') 94 | 95 | # save to the disk 96 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 97 | util.mkdirs(expr_dir) 98 | file_name = os.path.join(expr_dir, 'opt.txt') 99 | with open(file_name, 'wt') as opt_file: 100 | opt_file.write('------------ Options -------------\n') 101 | for k, v in sorted(args.items()): 102 | opt_file.write('%s: %s\n' % (str(k), str(v))) 103 | opt_file.write('-------------- End ----------------\n') 104 | return self.opt 105 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--ntest',
type=int, default=float("inf"), help='# of test examples.') 8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | self.parser.add_argument('--which_epoch', type=str, default='1000', help='which epoch to load? set to latest to use latest cached model') 12 | self.parser.add_argument('--how_many', type=int, default=200, help='how many test images to run') 13 | 14 | self.isTrain = False 15 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 8 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 9 | self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 12 | self.parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs') 13 | self.parser.add_argument('--continue_train', default=False, help='continue training: load the latest model') 14 | self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 15 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 16 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load?
set to latest to use latest cached model') 17 | self.parser.add_argument('--niter', type=int, default=500, help='# of iter at starting learning rate') 18 | self.parser.add_argument('--niter_decay', type=int, default=500, help='# of iter to linearly decay learning rate to zero') 19 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 20 | self.parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') 21 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 22 | 23 | self.parser.add_argument('--lambda_A', type=float, default=1.0, help='weight for L1 loss') 24 | self.parser.add_argument('--lambda_B', type=float, default=1.0, help='weight for perceptual L1 loss') 25 | self.parser.add_argument('--lambda_GAN', type=float, default=5.0, help='weight of GAN loss') 26 | self.parser.add_argument('--lambda_cx', type=float, default=0.1, help='weight of CX loss') 27 | self.parser.add_argument('--lambda_AMCE', type=float, default=0.1, help='weight of AMCE loss') 28 | self.parser.add_argument('--lambda_lpips', type=float, default=1.0, help='weight of LPIPS loss') 29 | 30 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 31 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 32 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 33 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 34 | self.parser.add_argument('--L1_type', type=str, default='l1_plus_perL1', help='use which kind of L1 loss.
(origin|l1_plus_perL1)') 35 | self.parser.add_argument('--perceptual_layers', type=int, default=3, help='index of vgg layer for extracting perceptual features.') 36 | self.parser.add_argument('--percep_is_l1', type=int, default=1, help='type of perceptual loss: l1 or l2') 37 | self.parser.add_argument('--no_dropout_D', action='store_true', help='no dropout for the discriminator') 38 | self.parser.add_argument('--DG_ratio', type=int, default=1, help='how many times for D training after training G once') 39 | self.parser.add_argument('--use_cxloss', type=int, default=1, help='use cxloss or not') 40 | 41 | 42 | self.isTrain = True 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dominate 2 | scikit-image 3 | pandas 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data.data_loader import CreateDataLoader 4 | from models.models import create_model 5 | from util.visualizer import Visualizer 6 | from util import html 7 | import time 8 | 9 | opt = TestOptions().parse() 10 | opt.nThreads = 1 # test code only supports nThreads = 1 11 | opt.batchSize = 1 # test code only supports batchSize = 1 12 | opt.serial_batches = True # no shuffle 13 | opt.no_flip = True # no flip 14 | 15 | data_loader = CreateDataLoader(opt) 16 | dataset = data_loader.load_data() 17 | model = create_model(opt) 18 | visualizer = Visualizer(opt) 19 | # create website 20 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 21 | 22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 23 | 24 | print(opt.how_many) 25 | print(len(dataset)) 26 | 27 | model = model.eval() 28 | print(model.training) 29 | 30 | opt.how_many = 999999 31 | # test 32 | for i, data in enumerate(dataset): 33 | print(' process %d/%d img ..'%(i,opt.how_many)) 34 | if i >= opt.how_many: 35 | break 36 | model.set_input(data) 37 | startTime = time.time() 38 | model.test() 39 | endTime = time.time() 40 | print(endTime-startTime) 41 | visuals = model.get_current_visuals() 42 | img_path = model.get_image_paths() 43 | img_path = [img_path] 44 | print(img_path) 45 | visualizer.save_images(webpage, visuals, img_path) 46 | 47 | webpage.save() 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /tool/generate_fashion_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from PIL import Image 4 | 5 | IMG_EXTENSIONS = [ 6 | '.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 8 | ] 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 12 | 13 | def make_dataset(dir): 14 | images = [] 15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 16 | # new_root = './fashion' 17 | # if not os.path.exists(new_root): 18 | # os.mkdir(new_root) 19 | 20 | train_root = os.path.join(dir, 'train') 21 | if not os.path.exists(train_root): 22 | os.mkdir(train_root) 23 | 24 | test_root = os.path.join(dir, 'test') 25 | if not os.path.exists(test_root): 26 | os.mkdir(test_root) 27 | 28 | train_images = [] 29 | train_f = 
open(os.path.join(dir, 'train.lst'), 'r') 30 | for lines in train_f: 31 | lines = lines.strip() 32 | if lines.endswith('.jpg'): 33 | train_images.append(lines) 34 | 35 | test_images = [] 36 | test_f = open(os.path.join(dir, 'test.lst'), 'r') 37 | for lines in test_f: 38 | lines = lines.strip() 39 | if lines.endswith('.jpg'): 40 | test_images.append(lines) 41 | 42 | # print(train_images, test_images) 43 | 44 | for root, _, fnames in sorted(os.walk(os.path.join(dir, 'img_highres'))): 45 | for fname in fnames: 46 | if is_image_file(fname): 47 | path = os.path.join(root, fname) 48 | path_names = path.split('/') 49 | print(path_names) 50 | 51 | path_names = path_names[2:] 52 | del path_names[1] 53 | path_names[3] = path_names[3].replace('_', '') 54 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:]) 55 | path_names = "".join(path_names) 56 | # img = Image.open(path) 57 | if path_names in train_images: 58 | shutil.copy(path, os.path.join(train_root, path_names)) 59 | print(os.path.join(train_root, path_names)) 60 | # pass 61 | elif path_names in test_images: 62 | shutil.copy(path, os.path.join(test_root, path_names)) 63 | print(os.path.join(test_root, path_names)) 64 | # pass 65 | 66 | make_dataset('../dataset/fashion/') -------------------------------------------------------------------------------- /tool/generate_pose_map_fashion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import json 4 | import os 5 | 6 | MISSING_VALUE = -1 7 | # fix PATH 8 | img_dir = '../dataset/fashion/' 9 | annotations_file = os.path.join(img_dir, 'fashion-resize-annotation-test.csv') #pose annotation path 10 | save_path = os.path.join(img_dir, 'testK') 11 | if not os.path.exists(save_path): 12 | os.makedirs(save_path) 13 | 14 | def load_pose_cords_from_strings(y_str, x_str): 15 | y_cords = json.loads(y_str) 16 | x_cords = json.loads(x_str) 17 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 18 | 19 | def cords_to_map(cords, img_size, sigma=6): 20 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') 21 | for i, point in enumerate(cords): 22 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 23 | continue 24 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 25 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2)) 26 | # result[..., i] = np.where(((yy - point[0]) ** 2 + (xx - point[1]) ** 2) < (sigma ** 2), 1, 0) 27 | return result 28 | 29 | def compute_pose(image_dir, annotations_file, savePath, sigma): 30 | annotations_file = pd.read_csv(annotations_file, sep=':') 31 | annotations_file = annotations_file.set_index('name') 32 | image_size = (256, 256) 33 | cnt = len(annotations_file) 34 | for i in range(cnt): 35 | print('processing %d / %d ...'
%(i, cnt)) 36 | row = annotations_file.iloc[i] 37 | name = row.name 38 | print(savePath, name) 39 | file_name = os.path.join(savePath, name + '.npy') 40 | kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x) 41 | pose = cords_to_map(kp_array, image_size, sigma) 42 | np.save(file_name, pose) 43 | # input() 44 | 45 | compute_pose(img_dir, annotations_file, save_path, 6) 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /tool/resize_fashion.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from PIL import ImageFile 4 | 5 | ImageFile.LOAD_TRUNCATED_IMAGES = True 6 | def resize_dataset(folder, new_folder, new_size = (256, 256), crop_bord=0): 7 | if not os.path.exists(new_folder): 8 | os.makedirs(new_folder) 9 | for name in os.listdir(folder): 10 | old_name = os.path.join(folder, name) 11 | new_name = os.path.join(new_folder, name) 12 | 13 | 14 | img = Image.open(old_name) 15 | w, h = img.size 16 | if crop_bord == 0: 17 | pass 18 | else: 19 | img = img.crop((crop_bord, 0, w-crop_bord, h)) 20 | img = img.resize([new_size[1], new_size[0]]) 21 | img.save(new_name) 22 | print('resize %s successfully' % old_name) 23 | 24 | 25 | old_dir = '../dataset/fashion/train' 26 | root_dir = '../dataset/fashion/train_resize' 27 | resize_dataset(old_dir, root_dir) 28 | 29 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data.data_loader import CreateDataLoader 4 | from models.models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | 8 | opt = TrainOptions().parse() 9 | data_loader = CreateDataLoader(opt) 10 | dataset = data_loader.load_data() 11 | dataset_size = len(data_loader) 12 | print('#training images = %d' % dataset_size) 13 | 14 | model = create_model(opt) 15 | 16 | 17 | visualizer = Visualizer(opt) 18 | total_steps = 0 19 | 20 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 21 | epoch_start_time = time.time() 22 | epoch_iter = 0 23 | 24 | 25 | for i, data in enumerate(dataset): 26 | iter_start_time = time.time() 27 | visualizer.reset() 28 | total_steps += opt.batchSize 29 | epoch_iter += opt.batchSize 30 | model.set_input(data) 31 | 32 | # model.optimize_parameters() 33 | model.optimize_parameters() 34 | 35 | if total_steps % opt.display_freq == 0: 36 | save_result = total_steps % opt.update_html_freq == 0 37 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 38 | 39 | if total_steps % opt.print_freq == 0: 40 | errors = model.get_current_errors() 41 | t = (time.time() - iter_start_time) / opt.batchSize 42 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 43 | if opt.display_id > 0: 44 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) 45 | 46 | if total_steps % opt.save_latest_freq == 0: 47 | print('saving the latest model (epoch %d, total_steps %d)' % 48 | (epoch, total_steps)) 49 | model.save('latest') 50 | 51 | if epoch % opt.save_epoch_freq == 0: 52 | print('saving the model at the end of epoch %d, iters %d' % 53 | (epoch, total_steps)) 54 | model.save('latest') 55 | model.save(epoch) 56 | 57 | print('End of epoch %d / %d \t Time Taken: %d sec' % 58 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 59 | 
model.update_learning_rate() 60 | 61 | 62 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/util/__init__.py -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """ 13 | 14 | Download CycleGAN or Pix2Pix Data. 15 | 16 | Args: 17 | technique : str 18 | One of: 'cyclegan' or 'pix2pix'. 19 | verbose : bool 20 | If True, print additional information. 21 | 22 | Examples: 23 | >>> from util.get_data import GetData 24 | >>> gd = GetData(technique='cyclegan') 25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 26 | 27 | """ 28 | 29 | def __init__(self, technique='cyclegan', verbose=True): 30 | url_dict = { 31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, 'lxml') 44 | options = [h.text for h in soup.find_all('a', href=True) 45 | if h.text.endswith(('.zip', 'tar.gz'))] 46 | return options 47 | 48 | def _present_options(self): 49 | r = requests.get(self.url) 50 | options = self._get_options(r) 51 | print('Options:\n') 52 | for i, o in enumerate(options): 53 | print("{0}: {1}".format(i, o)) 54 | choice = input("\nPlease enter the number of the " 55 | "dataset above you wish to download:") 56 | return options[int(choice)] 57 | 58 | def _download_data(self, dataset_url, save_path): 59 | if not isdir(save_path): 60 | os.makedirs(save_path) 61 | 62 | base = basename(dataset_url) 63 | temp_save_path = join(save_path, base) 64 | 65 | with open(temp_save_path, "wb") as f: 66 | r = requests.get(dataset_url) 67 | f.write(r.content) 68 | 69 | if base.endswith('.tar.gz'): 70 | obj = tarfile.open(temp_save_path) 71 | elif base.endswith('.zip'): 72 | obj = ZipFile(temp_save_path, 'r') 73 | else: 74 | raise ValueError("Unknown File Type: {0}.".format(base)) 75 | 76 | self._print("Unpacking Data...") 77 | obj.extractall(save_path) 78 | obj.close() 79 | os.remove(temp_save_path) 80 | 81 | def get(self, save_path, dataset=None): 82 | """ 83 | 84 | Download a dataset. 85 | 86 | Args: 87 | save_path : str 88 | A directory to save the data to. 89 | dataset : str, optional 90 | A specific dataset to download. 91 | Note: this must include the file extension. 92 | If None, options will be presented for you 93 | to choose from. 94 | 95 | Returns: 96 | save_path_full : str 97 | The absolute path to the downloaded data. 98 | 99 | """ 100 | if dataset is None: 101 | selected_dataset = self._present_options() 102 | else: 103 | selected_dataset = dataset 104 | 105 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 106 | 107 | if isdir(save_path_full): 108 | warn("\n'{0}' already exists. 
Voiding Download.".format( 109 | save_path_full)) 110 | else: 111 | self._print('Downloading Data...') 112 | url = "{0}/{1}".format(self.url, selected_dataset) 113 | self._download_data(url, save_path=save_path) 114 | 115 | return abspath(save_path_full) 116 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="reflesh", content=str(reflesh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | class ImagePool(): 8 | def __init__(self, pool_size): 9 | self.pool_size = pool_size 10 | if self.pool_size > 0: 11 | self.num_imgs = 0 12 | self.images = [] 13 | 14 | def query(self, images): 15 | if self.pool_size == 0: 16 | return Variable(images) 17 | return_images = [] 18 | for image in images: 19 | image = torch.unsqueeze(image, 0) 20 | if self.num_imgs < self.pool_size: 21 | self.num_imgs = self.num_imgs + 1 22 | self.images.append(image) 23 | return_images.append(image) 24 | else: 25 | p = random.uniform(0, 1) 26 | if p > 0.5: 27 | random_id = random.randint(0, self.pool_size-1) 28 | tmp = self.images[random_id].clone() 29 | self.images[random_id] = image 30 | return_images.append(tmp) 31 | else: 32 | return_images.append(image) 33 | return_images = Variable(torch.cat(return_images, 0)) 34 | return return_images 35 | -------------------------------------------------------------------------------- /util/png.py: -------------------------------------------------------------------------------- 1 | import struct 2 | import zlib 3 | 4 | def encode(buf, width, 
height): 5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """ 6 | assert (width * height * 3 == len(buf)) 7 | bpp = 3 8 | 9 | def raw_data(): 10 | # reverse the vertical line order and add null bytes at the start 11 | row_bytes = width * bpp 12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes): 13 | yield b'\x00' 14 | yield buf[row_start:row_start + row_bytes] 15 | 16 | def chunk(tag, data): 17 | return [ 18 | struct.pack("!I", len(data)), 19 | tag, 20 | data, 21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) 22 | ] 23 | 24 | SIGNATURE = b'\x89PNG\r\n\x1a\n' 25 | COLOR_TYPE_RGB = 2 26 | COLOR_TYPE_RGBA = 6 27 | bit_depth = 8 28 | return b''.join( 29 | [ SIGNATURE ] + 30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + 31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + 32 | chunk(b'IEND', b'') 33 | ) 34 | -------------------------------------------------------------------------------- /util/pose_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.draw import circle, line_aa, polygon 3 | import json 4 | 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import matplotlib.patches as mpatches 9 | 10 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9], 11 | [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16], 12 | [0,15], [15,17], [2,16], [5,17]] 13 | 14 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 15 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], 16 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 17 | 18 | 19 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 20 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear'] 21 | 22 | MISSING_VALUE = -1 23 | 24 | 25 | def map_to_cord(pose_map, threshold=0.1): 26 | all_peaks = [[] for i in range(18)] 27 | pose_map = pose_map[..., :18] 28 | 29 | y, x, z = np.where(np.logical_and(pose_map == pose_map.max(axis = (0, 1)), 30 | pose_map > threshold)) 31 | for x_i, y_i, z_i in zip(x, y, z): 32 | all_peaks[z_i].append([x_i, y_i]) 33 | 34 | x_values = [] 35 | y_values = [] 36 | 37 | for i in range(18): 38 | if len(all_peaks[i]) != 0: 39 | x_values.append(all_peaks[i][0][0]) 40 | y_values.append(all_peaks[i][0][1]) 41 | else: 42 | x_values.append(MISSING_VALUE) 43 | y_values.append(MISSING_VALUE) 44 | 45 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1) 46 | 47 | 48 | def cords_to_map(cords, img_size, old_size=None, affine_matrix=None, sigma=6): 49 | old_size = img_size if old_size is None else old_size 50 | cords = cords.astype(float) 51 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') 52 | for i, point in enumerate(cords): 53 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 54 | continue 55 | point[0] = point[0]/old_size[0] * img_size[0] 56 | point[1] = point[1]/old_size[1] * img_size[1] 57 | if affine_matrix is not None: 58 | point_ =np.dot(affine_matrix, np.matrix([point[1], point[0], 1]).reshape(3,1)) 59 | point_0 = int(point_[1]) 60 | point_1 = int(point_[0]) 61 | else: 62 | point_0 = int(point[0]) 63 | point_1 = int(point[1]) 64 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 65 | result[..., i] = 
np.exp(-((yy - point_0) ** 2 + (xx - point_1) ** 2) / (2 * sigma ** 2)) 66 | return result 67 | 68 | 69 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True): 70 | colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8) 71 | mask = np.zeros(shape=img_size, dtype=bool) 72 | 73 | if draw_joints: 74 | for f, t in LIMB_SEQ: 75 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE 76 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE 77 | if from_missing or to_missing: 78 | continue 79 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1]) 80 | colors[yy, xx] = np.expand_dims(val, 1) * 255 81 | mask[yy, xx] = True 82 | 83 | for i, joint in enumerate(pose_joints): 84 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE: 85 | continue 86 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size) 87 | colors[yy, xx] = COLORS[i] 88 | mask[yy, xx] = True 89 | 90 | return colors, mask 91 | 92 | 93 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs): 94 | cords = map_to_cord(pose_map, threshold=threshold) 95 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs) 96 | 97 | 98 | def load_pose_cords_from_strings(y_str, x_str): 99 | y_cords = json.loads(y_str) 100 | x_cords = json.loads(x_str) 101 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 102 | 103 | def mean_inputation(X): 104 | X = X.copy() 105 | for i in range(X.shape[1]): 106 | for j in range(X.shape[2]): 107 | val = np.mean(X[:, i, j][X[:, i, j] != -1]) 108 | X[:, i, j][X[:, i, j] == -1] = val 109 | return X 110 | 111 | def draw_legend(): 112 | handles = [mpatches.Patch(color=np.array(color) / 255.0, label=name) for color, name in zip(COLORS, LABELS)] 113 | plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
114 | 115 | def produce_ma_mask(kp_array, img_size, point_radius=4): 116 | from skimage.morphology import dilation, erosion, square 117 | mask = np.zeros(shape=img_size, dtype=bool) 118 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], 119 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], 120 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]] 121 | limbs = np.array(limbs) - 1 122 | for f, t in limbs: 123 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE 124 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE 125 | if from_missing or to_missing: 126 | continue 127 | 128 | norm_vec = kp_array[f] - kp_array[t] 129 | norm_vec = np.array([-norm_vec[1], norm_vec[0]]) 130 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec) 131 | 132 | 133 | vetexes = np.array([ 134 | kp_array[f] + norm_vec, 135 | kp_array[f] - norm_vec, 136 | kp_array[t] - norm_vec, 137 | kp_array[t] + norm_vec 138 | ]) 139 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size) 140 | mask[yy, xx] = True 141 | 142 | for i, joint in enumerate(kp_array): 143 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE: 144 | continue 145 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size) 146 | mask[yy, xx] = True 147 | 148 | mask = dilation(mask, square(5)) 149 | mask = erosion(mask, square(5)) 150 | return mask 151 | 152 | if __name__ == "__main__": 153 | import pandas as pd 154 | from skimage.io import imread 155 | import pylab as plt 156 | import os 157 | i = 5 158 | df = pd.read_csv('data/market-annotation-train.csv', sep=':') 159 | 160 | for index, row in df.iterrows(): 161 | pose_cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x']) 162 | 163 | colors, mask = draw_pose_from_cords(pose_cords, (128, 64)) 164 | 165 | mmm = produce_ma_mask(pose_cords, (128, 64)).astype(float)[..., np.newaxis].repeat(3, axis=-1) 166 | print(mmm.shape) 167 | img = imread('data/market-dataset/train/' + row['name']) 168 | 169 | mmm[mask] = colors[mask] 170 | 171 | print (mmm) 172 | plt.subplot(1, 1, 1) 173 | plt.imshow(mmm) 174 | plt.show() 175 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import inspect, re 6 | import os 7 | import collections 8 | from skimage.draw import circle, line_aa 9 | 10 | 11 | 12 | # Converts a Tensor into a Numpy array 13 | # |imtype|: the desired type of the converted numpy array 14 | def tensor2im(image_tensor, imtype=np.uint8): 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | return image_numpy.astype(imtype) 20 | 21 | 22 | LIMB_SEQ = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9], 23 | [9, 10], [1, 11], [11, 12], [12, 13], [1, 0], [0, 14], [14, 16], 24 | [0, 15], [15, 17]] 25 | 26 | # draw dis img 27 | LIMB_SEQ_DIS = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9], 28 | [9, 10], [1, 11], [11, 12], [12, 13]] 29 | 30 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], 31 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 
85, 255], [0, 0, 255], [85, 0, 255], 32 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 33 | 34 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 35 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear'] 36 | 37 | MISSING_VALUE = -1 38 | 39 | 40 | def map_to_cord(pose_map, threshold=0.1): 41 | all_peaks = [[] for i in range(18)] 42 | pose_map = pose_map[..., :18] 43 | 44 | if torch.is_tensor(pose_map): 45 | pose_map = pose_map.cpu() 46 | try: 47 | y, x, z = np.where(np.logical_and(pose_map == 1.0, pose_map > threshold)) 48 | except: 49 | print(np.where(np.logical_and(pose_map == 1.0, pose_map > threshold))) 50 | print(pose_map.shape) 51 | for x_i, y_i, z_i in zip(x, y, z): 52 | all_peaks[z_i].append([x_i, y_i]) 53 | 54 | x_values = [] 55 | y_values = [] 56 | 57 | for i in range(18): 58 | if len(all_peaks[i]) != 0: 59 | x_values.append(all_peaks[i][0][0]) 60 | y_values.append(all_peaks[i][0][1]) 61 | else: 62 | x_values.append(MISSING_VALUE) 63 | y_values.append(MISSING_VALUE) 64 | 65 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1) 66 | 67 | 68 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs): 69 | # CHW -> HCW -> HWC 70 | pose_map = pose_map[0].cpu().transpose(1, 0).transpose(2, 1).numpy() 71 | 72 | cords = map_to_cord(pose_map, threshold=threshold) 73 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs) 74 | 75 | 76 | def draw_dis_from_map(pose_map, threshold=0.1, **kwargs): 77 | # CHW -> HCW -> HWC 78 | # print(pose_map.shape) 79 | if torch.is_tensor(pose_map): 80 | pose_map = pose_map[0].cpu().transpose(1, 0).transpose(2, 1).numpy() 81 | print(pose_map.shape) 82 | cords = map_to_cord(pose_map, threshold=threshold) 83 | return draw_dis_from_cords(cords, pose_map.shape[:2], **kwargs) 84 | 85 | 86 | 87 | # draw pose from map 88 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True): 89 | colors = np.zeros(shape=img_size + (3,), dtype=np.uint8) 90 | mask = np.zeros(shape=img_size, dtype=bool) 91 | 92 | if draw_joints: 93 | for f, t in LIMB_SEQ: 94 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE 95 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE 96 | if from_missing or to_missing: 97 | continue 98 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1]) 99 | colors[yy, xx] = np.expand_dims(val, 1) * 255 100 | mask[yy, xx] = True 101 | 102 | for i, joint in enumerate(pose_joints): 103 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE: 104 | continue 105 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size) 106 | colors[yy, xx] = COLORS[i] 107 | mask[yy, xx] = True 108 | 109 | return colors, mask 110 | 111 | 112 | # point to line distance 113 | def get_distance_from_point_to_line(point, line_point1, line_point2): 114 | 115 | if line_point1 == line_point2: 116 | point_array = np.array(point) 117 | point1_array = np.array(line_point1) 118 | aa = np.expand_dims(np.expand_dims(point1_array, -1), -1) 119 | aa = np.repeat(aa, point.shape[1], 1) 120 | aa = np.repeat(aa, point.shape[2], 2) 121 | return np.linalg.norm(point_array - aa) 122 | A = line_point2[0] - line_point1[0] 123 | B = line_point1[1] - line_point2[1] 124 | C = (line_point1[0] - line_point2[0]) * line_point1[1] + \ 125 | (line_point2[1] - line_point1[1]) * line_point1[0] 126 | distance = np.abs(A * point[1] + B * 
point[0] + C) / (np.sqrt(A ** 2 + B ** 2)) 127 | distance = np.exp(-0.1 * distance) 128 | return distance 129 | 130 | 131 | 132 | 133 | # draw dis from map 134 | def draw_dis_from_cords(pose_joints, img_size, radius=2, draw_joints=True): 135 | dis = np.zeros(shape=img_size + (12,), dtype=np.float64) 136 | y = np.linspace(0, img_size[0] - 1, img_size[0]) 137 | x = np.linspace(0, img_size[1] - 1, img_size[1]) 138 | xv, yv = np.meshgrid(x, y) 139 | point = np.concatenate([np.expand_dims(yv, 0), np.expand_dims(xv, 0)], 0) 140 | 141 | for i, (f, t) in enumerate(LIMB_SEQ_DIS): 142 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE 143 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE 144 | if from_missing or to_missing: 145 | continue 146 | dis[:, :, i] = get_distance_from_point_to_line(point, [pose_joints[f][0], pose_joints[f][1]], 147 | [pose_joints[t][0], pose_joints[t][1]]) 148 | return dis, np.mean(dis, -1) 149 | 150 | 151 | 152 | def diagnose_network(net, name='network'): 153 | mean = 0.0 154 | count = 0 155 | for param in net.parameters(): 156 | if param.grad is not None: 157 | mean += torch.mean(torch.abs(param.grad.data)) 158 | count += 1 159 | if count > 0: 160 | mean = mean / count 161 | print(name) 162 | print(mean) 163 | 164 | 165 | def save_image(image_numpy, image_path): 166 | image_pil = Image.fromarray(image_numpy) 167 | image_pil.save(image_path) 168 | 169 | 170 | def info(object, spacing=10, collapse=1): 171 | """Print methods and doc strings. 172 | Takes module, class, list, dictionary, or string.""" 173 | methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] 174 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) 175 | print("\n".join(["%s %s" % 176 | (method.ljust(spacing), 177 | processFunc(str(getattr(object, method).__doc__))) 178 | for method in methodList])) 179 | 180 | 181 | def varname(p): 182 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: 183 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) 184 | if m: 185 | return m.group(1) 186 | 187 | 188 | def print_numpy(x, val=True, shp=False): 189 | x = x.astype(np.float64) 190 | if shp: 191 | print('shape,', x.shape) 192 | if val: 193 | x = x.flatten() 194 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 195 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 196 | 197 | 198 | def mkdirs(paths): 199 | if isinstance(paths, list) and not isinstance(paths, str): 200 | for path in paths: 201 | mkdir(path) 202 | else: 203 | mkdir(paths) 204 | 205 | 206 | def mkdir(path): 207 | if not os.path.exists(path): 208 | os.makedirs(path) 209 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | #from . 
import html 7 | 8 | 9 | class Visualizer(): 10 | def __init__(self, opt): 11 | # self.opt = opt 12 | self.display_id = opt.display_id 13 | self.use_html = opt.isTrain and not opt.no_html 14 | self.win_size = opt.display_winsize 15 | self.name = opt.name 16 | self.opt = opt 17 | self.saved = False 18 | if self.display_id > 0: 19 | import visdom 20 | self.vis = visdom.Visdom(port=opt.display_port) 21 | 22 | if self.use_html: 23 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 24 | self.img_dir = os.path.join(self.web_dir, 'images') 25 | print('create web directory %s...' % self.web_dir) 26 | util.mkdirs([self.web_dir, self.img_dir]) 27 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 28 | with open(self.log_name, "a") as log_file: 29 | now = time.strftime("%c") 30 | log_file.write('================ Training Loss (%s) ================\n' % now) 31 | 32 | def reset(self): 33 | self.saved = False 34 | 35 | # |visuals|: dictionary of images to display or save 36 | def display_current_results(self, visuals, epoch, save_result): 37 | if self.display_id > 0: # show images in the browser 38 | ncols = self.opt.display_single_pane_ncols 39 | if ncols > 0: 40 | h, w = next(iter(visuals.values())).shape[:2] 41 | table_css = """<style> 42 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 43 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 44 | </style>""" % (w, h) 45 | title = self.name 46 | label_html = '' 47 | label_html_row = '' 48 | nrows = int(np.ceil(len(visuals.items()) / ncols)) 49 | images = [] 50 | idx = 0 51 | for label, image_numpy in visuals.items(): 52 | label_html_row += '<td>%s</td>' % label 53 | images.append(image_numpy.transpose([2, 0, 1])) 54 | idx += 1 55 | if idx % ncols == 0: 56 | label_html += '<tr>%s</tr>' % label_html_row 57 | label_html_row = '' 58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 59 | while idx % ncols != 0: 60 | images.append(white_image) 61 | label_html_row += '<td></td>' 62 | idx += 1 63 | if label_html_row != '': 64 | label_html += '<tr>%s</tr>' % label_html_row 65 | # pane col = image row 66 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 67 | padding=2, opts=dict(title=title + ' images')) 68 | label_html = '<table>%s</table>
' % label_html 69 | self.vis.text(table_css + label_html, win=self.display_id + 2, 70 | opts=dict(title=title + ' labels')) 71 | else: 72 | idx = 1 73 | for label, image_numpy in visuals.items(): 74 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 75 | win=self.display_id + idx) 76 | idx += 1 77 | 78 | if self.use_html and (save_result or not self.saved): # save images to a html file 79 | self.saved = True 80 | for label, image_numpy in visuals.items(): 81 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 82 | util.save_image(image_numpy, img_path) 83 | # update website 84 | # webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 85 | # for n in range(epoch, 0, -1): 86 | # webpage.add_header('epoch [%d]' % n) 87 | # ims = [] 88 | # txts = [] 89 | # links = [] 90 | 91 | # for label, image_numpy in visuals.items(): 92 | # img_path = 'epoch%.3d_%s.png' % (n, label) 93 | # ims.append(img_path) 94 | # txts.append(label) 95 | # links.append(img_path) 96 | # webpage.add_images(ims, txts, links, width=self.win_size) 97 | # webpage.save() 98 | 99 | # errors: dictionary of error labels and values 100 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): 101 | if not hasattr(self, 'plot_data'): 102 | self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())} 103 | self.plot_data['X'].append(epoch + counter_ratio) 104 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 105 | self.vis.line( 106 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 107 | Y=np.array(self.plot_data['Y']), 108 | opts={ 109 | 'title': self.name + ' loss over time', 110 | 'legend': self.plot_data['legend'], 111 | 'xlabel': 'epoch', 112 | 'ylabel': 'loss'}, 113 | win=self.display_id) 114 | 115 | # errors: same format as |errors| of plotCurrentErrors 116 | def print_current_errors(self, epoch, i, errors, t): 117 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 118 | for k, v in errors.items(): 119 | message += '%s: %.3f ' % (k, v) 120 | 121 | print(message) 122 | with open(self.log_name, "a") as log_file: 123 | log_file.write('%s\n' % message) 124 | 125 | # save image to the disk 126 | def save_images(self, webpage, visuals, image_path): 127 | image_dir = webpage.get_image_dir() 128 | short_path = ntpath.basename(image_path[0]) 129 | name = os.path.splitext(short_path)[0] 130 | 131 | webpage.add_header(name) 132 | ims = [] 133 | txts = [] 134 | links = [] 135 | 136 | for label, image_numpy in visuals.items(): 137 | image_name = '%s_%s.jpg' % (image_path[0], label) 138 | save_path = os.path.join(image_dir, image_name) 139 | print(save_path) 140 | util.save_image(image_numpy, save_path) 141 | 142 | ims.append(image_name) 143 | txts.append(label) 144 | links.append(image_name) 145 | webpage.add_images(ims, txts, links, width=self.win_size) 146 | --------------------------------------------------------------------------------