├── README.md
├── data.py
├── dataset.py
├── img_augm.py
├── mask_transfer.py
├── model.py
├── networks.py
├── pre_process.py
├── prepare_dataset.py
├── process_mask.py
├── scripts
│   ├── inter.sh
│   ├── test.sh
│   └── train.sh
├── test.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# StyleFormer
Official PyTorch implementation of the paper:

[StyleFormer: Real-Time Arbitrary Style Transfer via Parametric Style Composition](https://openaccess.thecvf.com/content/ICCV2021/papers/Wu_StyleFormer_Real-Time_Arbitrary_Style_Transfer_via_Parametric_Style_Composition_ICCV_2021_paper.pdf)

## Overview

This is our overall framework.
![image](https://user-images.githubusercontent.com/53161080/146502829-e6cbfd3d-47f1-48ad-9de1-1ce54a9ccbbc.png)

## Examples

![image](https://user-images.githubusercontent.com/53161080/146366097-1c314181-1d6e-4eb7-af5a-d6b17eece7a8.png)

## Introduction

This is the official code release of our ICCV 2021 [paper](https://openaccess.thecvf.com/content/ICCV2021/papers/Wu_StyleFormer_Real-Time_Arbitrary_Style_Transfer_via_Parametric_Style_Composition_ICCV_2021_paper.pdf) ***StyleFormer: Real-Time Arbitrary Style Transfer via Parametric Style Composition***.

**Authors**: Xiaolei Wu, Zhihao Hu, Lu Sheng\*, Dong Xu (\*corresponding author)

## Update
* 2021.12.17: Upload PyTorch implementation of [StyleFormer](https://openaccess.thecvf.com/content/ICCV2021/papers/Wu_StyleFormer_Real-Time_Arbitrary_Style_Transfer_via_Parametric_Style_Composition_ICCV_2021_paper.pdf).

## Dependencies

* CUDA 10.1
* Python 3.7.7
* PyTorch 1.3.1

## Datasets

### MS-COCO

Please download the [MS-COCO](http://msvocds.blob.core.windows.net/coco2014/train2014.zip) dataset.

### WikiArt

Please download the WikiArt dataset from [Kaggle](https://www.kaggle.com/c/painter-by-numbers).

## Download Trained Models

We provide the trained models of StyleFormer and the VGG network.

- StyleFormer
  - [google drive](https://drive.google.com/drive/folders/1l53CJxbMiaU7c17laAT9d8Q_a4arxI28)
  - [BaiduNetdisk](https://pan.baidu.com/s/1gGHYyIwrtoRxZWQLNWHD1w) (Extraction Code: kc44)
- VGG
  - [google drive](https://drive.google.com/drive/folders/1l53CJxbMiaU7c17laAT9d8Q_a4arxI28)
  - [BaiduNetdisk](https://pan.baidu.com/s/1jIxAlTK9LfgPhgd-rGcuew) (Extraction Code: n47y)

## Training
```
cd ./scripts
sh train.sh {GPU_ID}
```

## Test
```
git clone https://github.com/Wxl-stars/PytorchStyleFormer.git
cd PytorchStyleFormer

CUDA_VISIBLE_DEVICES={GPU_ID} python test.py \
    --trained_network={PRE-TRAINED_STYLEFORMER_MODEL} \
    --path={VGG_PATH} \
    --input_path={CONTENT_PATH} \
    --style_path={STYLE_PATH} \
    --results_path={RESULTS_PATH}
```

## Citation
If you find our work useful in your research, please consider citing:
```
@inproceedings{wu2021styleformer,
  title={StyleFormer: Real-Time Arbitrary Style Transfer via Parametric Style Composition},
  author={Wu, Xiaolei and Hu, Zhihao and Sheng, Lu and Xu, Dong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={14618--14627},
  year={2021}
}
```

## Contact
If you have any questions or suggestions about this paper, feel free to contact:
```
Xiaolei Wu: wuxiaolei@buaa.edu.cn
Zhihao Hu: huzhihao@buaa.edu.cn
```

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
import torchvision.transforms as transform

from PIL import Image
import os
import os.path
import json

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


Image.MAX_IMAGE_PIXELS = 1e9
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def image_to_tensor_PIL(path, image_size):
    try:
        image = Image.open(path).convert('RGB')
    except Exception:
        print(f'corrupt image: {path}')
        raise
    # Upscale until the short side reaches 256, then downscale very large,
    # roughly square images; finally center-crop and convert the PIL image
    # (HWC, [0, 255]) to a normalized torch.Tensor (CHW).
    while min(image.size[0], image.size[1]) < 256:
        image = image.resize((image.size[0] * 2, image.size[1] * 2))

    alpha = max(image.size[0], image.size[1]) / min(image.size[0], image.size[1])
    while max(image.size[0], image.size[1]) > 1800 and alpha < 3.5:
        image = image.resize((image.size[0] // 2, image.size[1] // 2))
    data_transform = transform.Compose([transform.CenterCrop((image_size, image_size)),
                                        transform.ToTensor(),
                                        transform.Normalize(mean, std)])
    image_tensor = data_transform(image)

    return image_tensor


def make_dataset(dir):
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images


def make_dataset2(json_dir):
    with open(json_dir, 'r', encoding='utf_8') as fp:
        image_paths = json.load(fp)
    return image_paths


def default_loader(path, flag, image_size):
    if flag:
        return
Image.open(path).convert('RGB').resize([image_size,image_size]) 62 | else: 63 | return Image.open(path).convert('RGB') 64 | 65 | 66 | class ImageFolder(data.Dataset): 67 | 68 | def __init__(self, path, transform=None, return_paths=False, 69 | loader=default_loader, resize=False): 70 | imgs = sorted(make_dataset(path)) 71 | if len(imgs) == 0: 72 | raise (RuntimeError("Found 0 images in: " + path + "\n" 73 | "Supported image extensions are: " + 74 | ",".join(IMG_EXTENSIONS))) 75 | 76 | 77 | # self.root = root 78 | self.imgs = imgs 79 | self.transform = transform 80 | self.return_paths = return_paths 81 | self.loader = loader 82 | self.flag = resize 83 | self.style_size = 256 84 | #self.content_size = 85 | 86 | def __getitem__(self, index): 87 | path = self.imgs[index] 88 | print(path) 89 | if self.flag: 90 | img = Image.open(path).convert('RGB') 91 | #img = img.resize((img.size[0]//2, img.size[1]//2)) 92 | img = img.resize((256,256)) 93 | else: 94 | img = Image.open(path).convert('RGB') 95 | #img = self.loader(path, self.flag) 96 | if self.transform is not None: 97 | img = self.transform(img) 98 | else: 99 | img = image_to_tensor_PIL(path, self.style_size) 100 | return img, path 101 | 102 | def __len__(self): 103 | return len(self.imgs) 104 | 105 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import random 3 | 4 | import cv2 5 | import shutil 6 | import os 7 | import numpy as np 8 | import torch 9 | 10 | import warnings 11 | 12 | warnings.filterwarnings("error", category=UserWarning) 13 | 14 | import torchvision.transforms as transform 15 | 16 | from PIL import Image 17 | from utils import normalize_vgg, image_check 18 | 19 | Image.MAX_IMAGE_PIXELS = 1e9 20 | mean = [0.485, 0.456, 0.406] 21 | std = [0.229, 0.224, 0.225] 22 | 23 | Image.MAX_IMAGE_PIXELS = 1e9 24 | 25 | 26 | def image_to_tensor_PIL(path, image_size): 27 | try: 28 | image = Image.open(path).convert('RGB') 29 | except: 30 | print(f'cropt image:{path}, then remove') 31 | # shutil.move(path, './') 32 | # convert a PIL image to tensor (HWC) in range [0,255] to a torch.Tensor(CHW) in the range [0.0,1.0] 33 | # data_transform = transform.Compose([transform.Resize((image_size, image_size)), transform.ToTensor(), 34 | # transform.Normalize(mean, std)]) 35 | while(min(image.size[0], image.size[1]))<256: 36 | image = image.resize((image.size[0]*2, image.size[1]*2)) 37 | 38 | alpha=max(image.size[0], image.size[1])/min(image.size[0], image.size[1]) 39 | while(max(image.size[0], image.size[1])>1800 and alpha <3.5): 40 | image = image.resize((image.size[0]//2, image.size[1]//2)) 41 | # data_transform = transform.Compose([transform.Resize((512, 512)), transform.CenterCrop((image_size, image_size)), transform.ToTensor(), 42 | # transform.Normalize(mean, std)]) 43 | data_transform = transform.Compose([transform.RandomCrop((image_size, image_size)), transform.ToTensor(), 44 | transform.Normalize(mean, std)]) 45 | image_tensor = data_transform(image) 46 | 47 | return image_tensor 48 | 49 | 50 | def read_image_PIL(path, image_size): 51 | try: 52 | image = Image.open(path).convert('RGB') 53 | except: 54 | print(f'cropt image:{path}, then remove') 55 | #shutil.move(path, './') 56 | #image = image.resize((image_size, image_size)) 57 | data_transform = transform.Compose([transform.RandomCrop((image_size, image_size))]) 58 | image = data_transform(image) 59 | return np.array(image) 60 | 61 | 62 | def 
read_image_path(file_path, root_path): 63 | with open(file_path, 'r', encoding='gb18030') as rfile: 64 | reader_path = csv.DictReader(rfile) 65 | style_image_paths = [] 66 | for row in reader_path: 67 | style_image_paths.append(os.path.join(root_path, row['filename'])) 68 | return style_image_paths 69 | 70 | 71 | IMG_EXTENSIONS = [ 72 | '.jpg', '.JPG', '.jpeg', '.JPEG', 73 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 74 | ] 75 | 76 | 77 | def is_image_file(filename): 78 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 79 | 80 | 81 | def make_datast(dir): 82 | images = [] 83 | for root, _, fnames in sorted(os.walk(dir)): 84 | for fname in fnames: 85 | if is_image_file(fname): 86 | path = os.path.join(root, fname) 87 | images.append(path) 88 | return images 89 | 90 | 91 | def read_image(image_path, image_size): 92 | image = cv2.imread(image_path) 93 | image_tensor = cv2.cvtColor(cv2.resize(image, dsize=(image_size, image_size)), cv2.COLOR_BGR2RGB) 94 | # image_tensor = torch.from_numpy(image).permute(2, 0, 1) 95 | return image_tensor 96 | 97 | 98 | # def image_to_tensor(image): 99 | # image = np.ascontiguousarray(image, dtype=np.float32) 100 | # image_tensor = torch.from_numpy(image).permute(2, 0, 1) # C,H,W 101 | # # image_tensor = torch.from_numpy(image).permute(2, 0, 1) 102 | # return image_tensor 103 | 104 | 105 | def image_to_tensor(image, image_size): 106 | image_tensor = torch.from_numpy( 107 | cv2.cvtColor(cv2.resize(image, dsize=(image_size, image_size)), cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # C,H,W 108 | # image_tensor = torch.from_numpy(image).permute(2, 0, 1) 109 | return image_tensor 110 | 111 | 112 | class ArtDataset(): 113 | def __init__(self, opts, augmentor): 114 | self.style_paths = read_image_path(opts.info_path, opts.style_data_path) 115 | self.content_paths = make_datast(opts.content_data_path) 116 | self.image_size = opts.image_size 117 | # self.root = opts.style_data_path 118 | self.aug = augmentor 119 | 120 | self.image_size = opts.image_size 121 | 122 | def __len__(self): 123 | return len(self.style_paths) 124 | 125 | def __getitem__(self, item): 126 | style_image = None 127 | content_index = random.randint(0, len(self.content_paths) - 1) 128 | content_path = self.content_paths[content_index] 129 | 130 | style_index = random.randint(0, len(self.style_paths) - 1) 131 | style_path = self.style_paths[style_index] 132 | 133 | # content_image = read_image_PIL(content_path, self.image_size) 134 | # try: 135 | # #style_image = cv2.imread(style_path) 136 | # style_image = read_image_PIL(style_path, self.image_size) 137 | # except: 138 | # print(style_path) 139 | # # content_image = image_check(content_image, content_path, augmentor=self.aug) 140 | # # style_image = image_check(style_image, style_path, augmentor=self.aug) 141 | 142 | # # content_tensor = normalize_vgg(image_to_tensor(content_image, self.image_size)) 143 | # style_tensdor = normalize_vgg(image_to_tensor(style_image, self.image_size).float()) 144 | content_tensor = image_to_tensor_PIL(content_path, self.image_size) 145 | style_tensor = image_to_tensor_PIL(style_path, self.image_size) 146 | 147 | return content_tensor, style_tensor 148 | -------------------------------------------------------------------------------- /img_augm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import cv2 4 | from PIL import Image 5 | 6 | 7 | class Augmentor(): 8 | def __init__(self, 9 | crop_size=(256, 256), 10 | 
scale_augm_prb=0, scale_augm_range=0, 11 | rotation_augm_prb=0, rotation_augm_range=0.15, 12 | hsv_augm_prb=0, 13 | hue_augm_shift=0.05, 14 | saturation_augm_shift=0.05, saturation_augm_scale=0.05, 15 | value_augm_shift=0.05, value_augm_scale=0.05, 16 | affine_trnsfm_prb=0, affine_trnsfm_range=0.05, 17 | horizontal_flip_prb=0, 18 | vertical_flip_prb=0): 19 | 20 | self.crop_size = crop_size 21 | 22 | self.scale_augm_prb = scale_augm_prb 23 | self.scale_augm_range = scale_augm_range 24 | 25 | self.rotation_augm_prb = rotation_augm_prb 26 | self.rotation_augm_range = rotation_augm_range 27 | 28 | self.hsv_augm_prb = hsv_augm_prb 29 | self.hue_augm_shift = hue_augm_shift 30 | self.saturation_augm_scale = saturation_augm_scale 31 | self.saturation_augm_shift = saturation_augm_shift 32 | self.value_augm_scale = value_augm_scale 33 | self.value_augm_shift = value_augm_shift 34 | 35 | self.affine_trnsfm_prb = affine_trnsfm_prb 36 | self.affine_trnsfm_range = affine_trnsfm_range 37 | 38 | self.horizontal_flip_prb = horizontal_flip_prb 39 | self.vertical_flip_prb = vertical_flip_prb 40 | 41 | def __call__(self, image, is_inference=False): 42 | if is_inference: 43 | return cv2.resize(image, None, fx=self.crop_size[0], fy=self.crop_size[1], interpolation=cv2.INTER_CUBIC) 44 | 45 | # If not inference stage apply the pipeline of augmentations. 46 | if self.scale_augm_prb > np.random.uniform(): 47 | image = self.scale(image=image, 48 | scale_x=1. + np.random.uniform(low=-self.scale_augm_range, high=-self.scale_augm_range), 49 | scale_y=1. + np.random.uniform(low=-self.scale_augm_range, high=-self.scale_augm_range) 50 | ) 51 | 52 | 53 | rows, cols, ch = image.shape 54 | image = np.pad(array=image, pad_width=[[rows // 4, rows // 4], [cols // 4, cols // 4], [0, 0]], mode='reflect') 55 | if self.rotation_augm_prb > np.random.uniform(): 56 | image = self.rotate(image=image, 57 | angle=np.random.uniform(low=-self.rotation_augm_range*90., 58 | high=self.rotation_augm_range*90.) 59 | ) 60 | 61 | if self.affine_trnsfm_prb > np.random.uniform(): 62 | image = self.affine(image=image, 63 | rng=self.affine_trnsfm_range 64 | ) 65 | image = image[(rows // 4):-(rows // 4), (cols // 4):-(cols // 4), :] 66 | 67 | # Crop out patch of desired size. 68 | image = self.crop(image=image, 69 | crop_size=self.crop_size 70 | ) 71 | 72 | if self.hsv_augm_prb > np.random.uniform(): 73 | image = self.hsv_transform(image=image, 74 | hue_shift=self.hue_augm_shift, 75 | saturation_shift=self.saturation_augm_shift, 76 | saturation_scale=self.saturation_augm_scale, 77 | value_shift=self.value_augm_shift, 78 | value_scale=self.value_augm_scale) 79 | 80 | if self.horizontal_flip_prb > np.random.uniform(): 81 | image = self.horizontal_flip(image) 82 | 83 | if self.vertical_flip_prb > np.random.uniform(): 84 | image = self.vertical_flip(image) 85 | 86 | return image 87 | 88 | def scale(self, image, scale_x, scale_y): 89 | """ 90 | Args: 91 | image: 92 | scale_x: float positive value. New horizontal scale 93 | scale_y: float positive value. 
New vertical scale 94 | Returns: 95 | """ 96 | image = cv2.resize(image, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_CUBIC) 97 | return image 98 | 99 | def rotate(self, image, angle): 100 | """ 101 | Args: 102 | image: input image 103 | angle: angle of rotation in degrees 104 | Returns: 105 | """ 106 | rows, cols, ch = image.shape 107 | 108 | rot_M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1) 109 | image = cv2.warpAffine(image, rot_M, (cols, rows)) 110 | return image 111 | 112 | def crop(self, image, crop_size=(256, 256)): 113 | rows, cols, chs = image.shape 114 | x = int(np.random.uniform(low=0, high=max(0, rows - crop_size[0]))) 115 | y = int(np.random.uniform(low=0, high=max(0, cols - crop_size[1]))) 116 | 117 | image = image[x:x+crop_size[0], y:y+crop_size[1], :] 118 | # If the input image was too small to comprise patch of size crop_size, 119 | # resize obtained patch to desired size. 120 | if image.shape[0] < crop_size[0] or image.shape[1] < crop_size[1]: 121 | image = scipy.misc.imresize(arr=image, size=crop_size) 122 | return image 123 | 124 | def hsv_transform(self, image, 125 | hue_shift=0.2, 126 | saturation_shift=0.2, saturation_scale=0.2, 127 | value_shift=0.2, value_scale=0.2, 128 | ): 129 | 130 | image = Image.fromarray(image) 131 | hsv = np.array(image.convert("HSV"), 'float64') 132 | 133 | # scale the values to fit between 0 and 1 134 | hsv /= 255. 135 | 136 | # do the scalings & shiftings 137 | hsv[..., 0] += np.random.uniform(-hue_shift, hue_shift) 138 | hsv[..., 1] *= np.random.uniform(1. / (1. + saturation_scale), 1. + saturation_scale) 139 | hsv[..., 1] += np.random.uniform(-saturation_shift, saturation_shift) 140 | hsv[..., 2] *= np.random.uniform(1. / (1. + value_scale), 1. + value_scale) 141 | hsv[..., 2] += np.random.uniform(-value_shift, value_shift) 142 | 143 | # cut off invalid values 144 | hsv.clip(0.01, 0.99, hsv) 145 | 146 | # round to full numbers 147 | hsv = np.uint8(np.round(hsv * 254.)) 148 | 149 | # convert back to rgb image 150 | return np.asarray(Image.fromarray(hsv, "HSV").convert("RGB")) 151 | 152 | 153 | def affine(self, image, rng): 154 | rows, cols, ch = image.shape 155 | pts1 = np.float32([[0., 0.], [0., 1.], [1., 0.]]) 156 | [x0, y0] = [0. + np.random.uniform(low=-rng, high=rng), 0. + np.random.uniform(low=-rng, high=rng)] 157 | [x1, y1] = [0. + np.random.uniform(low=-rng, high=rng), 1. + np.random.uniform(low=-rng, high=rng)] 158 | [x2, y2] = [1. + np.random.uniform(low=-rng, high=rng), 0. 
+ np.random.uniform(low=-rng, high=rng)] 159 | pts2 = np.float32([[x0, y0], [x1, y1], [x2, y2]]) 160 | affine_M = cv2.getAffineTransform(pts1, pts2) 161 | image = cv2.warpAffine(image, affine_M, (cols, rows)) 162 | 163 | return image 164 | 165 | def horizontal_flip(self, image): 166 | return image[:, ::-1, :] 167 | 168 | def vertical_flip(self, image): 169 | return image[::-1, :, :] 170 | 171 | -------------------------------------------------------------------------------- /mask_transfer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from pathlib import Path 4 | import cv2 5 | 6 | import numpy 7 | from PIL import Image 8 | import datetime 9 | 10 | import tensorboardX 11 | import data 12 | from utils import normalize_arr_of_imgs, denormalize_vgg, write_2images, write_images 13 | import argparse 14 | from model import Grid 15 | import torch 16 | from collections import namedtuple 17 | from torchvision import transforms 18 | from torch.utils.data import DataLoader 19 | import torchvision 20 | import matplotlib 21 | matplotlib.use('Agg') 22 | import matplotlib.pyplot as plt 23 | 24 | from PIL import Image 25 | 26 | 27 | 28 | def save_conv_img(conv_img, sub_filename): 29 | root = './feature_map' 30 | if not os.path.exists(root): 31 | os.mkdir(root) 32 | sub_file = root+'/'+ sub_filename 33 | if not os.path.exists(sub_file): 34 | os.mkdir(sub_file) 35 | conv_img = conv_img.detach().cpu() 36 | feature_maps = conv_img.squeeze(0) 37 | img_num = feature_maps.shape[0] # 38 | all_feature_maps = [] 39 | for i in range(0, img_num): 40 | single_feature_map = feature_maps[i, :, :] 41 | all_feature_maps.append(single_feature_map) 42 | plt.imshow(single_feature_map) 43 | plt.savefig(sub_file + '/feature_{}'.format(i)) 44 | 45 | sum_feature_map = sum(feature_map for feature_map in all_feature_maps) 46 | plt.imshow(sum_feature_map) 47 | plt.savefig(sub_file +"/feature_map_sum.png") 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('--gf_dim', type=int, default=64) 52 | parser.add_argument('--df_dim', type=int, default=64) 53 | parser.add_argument('--dim', type=int, default=3) 54 | parser.add_argument('--resume', action='store_true') 55 | parser.add_argument('--init', type=str, default='kaiming') 56 | parser.add_argument('--path', type=str, default='/data3/wuxiaolei/models/vgg16-397923af.pth') 57 | 58 | # dataset 59 | parser.add_argument('--root', type=str, default='/data/dataset/') 60 | parser.add_argument('--input_path', type=str, default='/data2/wuxiaolei/project/content') 61 | parser.add_argument('--style1_path', type=str, default='/data2/wuxiaolei/project/style') 62 | parser.add_argument('--style2_path', type=str, default='/data2/wuxiaolei/project/style') 63 | parser.add_argument('--mask_path', type=str, default='/data2/wuxiaolei/project/style') 64 | parser.add_argument('--trained_network', type=str, help="path to the trained network file") 65 | parser.add_argument('--results_path', type=str, default='./results', help='outputs path') 66 | 67 | # data 68 | parser.add_argument('--image_size', type=int, default=256) 69 | 70 | # ouptut 71 | parser.add_argument('--output_path', type=str, default='./output', help='outputs path') 72 | 73 | # train 74 | parser.add_argument('--epoch', type=int, default=100) 75 | parser.add_argument('--lr', type=float, default=0.0001) 76 | parser.add_argument('--lr_policy', type=str, default='constant', help='step/constant') 77 | parser.add_argument('--step_size', 
type=int, default=200000) 78 | parser.add_argument('--gamma', type=float, default=0.5, help='How much to decay learning rate') 79 | parser.add_argument('--update_D', type=int, default=5) 80 | parser.add_argument('--batch_size', type=int, default=1) 81 | parser.add_argument('--save_freq', type=int, default=10000) 82 | 83 | # loss weight 84 | parser.add_argument('--clw', type=float, default=1, help='content_weight') 85 | parser.add_argument('--slw', type=float, default=10, help='style_weight') 86 | parser.add_argument('--alpha', type=float, default=0.8, help='style_weight') 87 | 88 | # bilateral grid 89 | parser.add_argument('--luma_bins', type=int, default=8) 90 | parser.add_argument('--channel_multiplier', type=int, default=1) 91 | parser.add_argument('--spatial_bin', type=int, default=8) 92 | parser.add_argument('--n_input_channel', type=int, default=256) 93 | parser.add_argument('--n_input_size', type=int, default=64) 94 | parser.add_argument('--group_num', type=int, default=16) 95 | 96 | # test 97 | parser.add_argument('--selection', type=str, default='Ax+b') 98 | parser.add_argument('--inter_selection', type=str, default='A1x+b2') 99 | # seed 100 | parser.add_argument('--seed', type=int, default=123) 101 | opts = parser.parse_args() 102 | 103 | # fix the seed 104 | numpy.random.seed(opts.seed) 105 | torch.manual_seed(opts.seed) 106 | torch.cuda.manual_seed_all(opts.seed) 107 | 108 | # opts = parser.parse_args() 109 | options = parser.parse_args() 110 | if not os.path.exists(options.output_path): 111 | os.mkdir(options.output_path) 112 | if not os.path.exists(options.results_path): 113 | os.mkdir(options.results_path) 114 | 115 | gpu_num = torch.cuda.device_count() 116 | 117 | #train_writer = tensorboardX.SummaryWriter(opts.results_path+'/grid') 118 | myNet = Grid(options, gpu_num).cuda() 119 | 120 | initial_step = myNet.resume_eval(options.trained_network) 121 | torch.backends.cudnn.benchmark = True 122 | # transform_style = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(), 123 | # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 124 | transform_style = transforms.Compose([transforms.ToTensor(), 125 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 126 | 127 | # transform_content = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(), 128 | # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 129 | transform_content = transforms.Compose([transforms.ToTensor(), 130 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 131 | # dataset = data.ImageFolder(options.input_path, transform=transform_content, return_paths=True, resize=False) 132 | # dataset_style =data.ImageFolder(options.style_path, transform=transform_style, return_paths=True, resize=False) 133 | # loader = DataLoader(dataset=dataset, batch_size=1, num_workers=0) 134 | # loader_style = DataLoader(dataset=dataset_style, batch_size=1, num_workers=0) 135 | 136 | # prepare imge 137 | content_img = Image.open(options.input_path).convert('RGB') 138 | style1_img = Image.open(options.style1_path).convert('RGB') 139 | style2_img = Image.open(options.style2_path).convert('RGB').resize((256, 256)) 140 | mask_img = Image.open(options.mask_path).convert('RGB') 141 | 142 | content = transform_content(content_img) 143 | style1 = transform_style(style1_img) 144 | style2 = transform_style(style2_img) 145 | mask = transform_style(mask_img) 146 | # print(mask) 147 | # exit(0) 148 | 149 | 
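    # Note: sample_mask (defined in model.py) stylizes the content image twice,
    # once with each style, binarizes the mask with (mask > 0), and blends the
    # results as mask * output1 + (1 - mask) * output2, so non-zero mask
    # regions take style1 and the rest takes style2.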
contentPath = Path(options.input_path) 150 | stylePath1 = Path(options.style1_path) 151 | stylePath2 = Path(options.style2_path) 152 | 153 | step = 1 154 | total_time = datetime.datetime.now() - datetime.datetime.now() 155 | 156 | 157 | # for it, images in enumerate(loader): 158 | # for it2, styles in enumerate(loader_style): 159 | # step += 1 160 | # content = images[0] 161 | # # from skimage.transform import resize 162 | # # content = resize(content, (content.shape[2]//2, content.shape[3]//2), anti_aliasing=True, preserve_range=True) 163 | # style = styles[0] 164 | # content_path = images[1][0] 165 | # style_path = styles[1][0] 166 | # content_name = os.path.split(content_path)[1] 167 | # style_name = os.path.split(style_path)[1] 168 | content_name = contentPath.stem 169 | style1_name = stylePath1.stem 170 | style2_name = stylePath2.stem 171 | style1 = style1.unsqueeze(0).cuda() 172 | style2 = style2.unsqueeze(0).cuda() 173 | content = content.unsqueeze(0).cuda() 174 | t0 = datetime.datetime.now() 175 | with torch.no_grad(): 176 | samp = myNet.sample_mask(content, style1, style2, mask) 177 | t1 = datetime.datetime.now() 178 | time = t1 - t0 179 | total_time += time 180 | print("time:%.8s",time.seconds+1e-6*time.microseconds) 181 | print('step:{}'.format(step)) 182 | image_outputs = [denormalize_vgg(samp).clamp_(0., 255.)] 183 | 184 | write_2images(image_outputs, 1, options.results_path, f'{style1_name[:-4]}+{style2_name[:-4]}+{content_name[:-4]}+base') 185 | 186 | # # visualization 187 | # torchvision.utils.save_image(score_map, f'./output/score_map_{step}.png', normalize=True) 188 | #save_conv_img(((hw_map.permute(0,3, 1, 2)[:1]+1)/2)*255., 'hw_map') 189 | #hw = hw_map.permute(0,3, 1, 2)[:1] 190 | #torchvision.utils.save_image(hw, f'./output/hw_map_{step}.jpg', normalize=True) 191 | #torchvision.utils.save_image(hw[:,:2,:,:], f'./output/hw_map_{step}.jpg', normalize=True) 192 | 193 | 194 | # content = Image.open(content_path) 195 | # c = Image.open(content_path) 196 | # style = Image.open(style_path) 197 | # result = Image.open(f'{options.output_path}/{style_name[:-4]}+{content_name[:-4]}+base.png') 198 | # c = c.resize(i for i in result.size) 199 | # s = style.resize((i//4 for i in result.size)) 200 | # # c = content.resize((i//4 for i in s.size)) 201 | 202 | # new = Image.new(result.mode, (result.width*2, result.height)) 203 | # new.paste(c, box=(0,0)) 204 | # new.paste(result, box=(c.width, 0)) 205 | # box = ((c.width//4)*3, c.height-s.height) 206 | # new.paste(s, box) 207 | # new.save(f'{options.results_path}/{style_name[:-4]}+{content_name[:-4]}.png', quality=95) 208 | avag = total_time/step 209 | print("avarge time:%.8s", avag.seconds+1e-6*avag.microseconds) 210 | 211 | 212 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from mistune import preprocessing 7 | 8 | from networks import VGG, decoder, StyleFormer 9 | from utils import get_scheduler, get_model_list, gram_matrix, \ 10 | calc_mean_std, adaptive_instance_normalization, put_tensor_cuda, TVloss, content_loss 11 | 12 | 13 | class Grid(nn.Module): 14 | def __init__(self, options, gpu_num): 15 | super(Grid, self).__init__() 16 | # build model 17 | print('-' * 8 + 'init Encoder' + '-' * 8) 18 | self.vgg = nn.DataParallel(VGG(options), list(range(gpu_num))) 19 | self.model = 
nn.DataParallel(StyleFormer(options), list(range(gpu_num))) 20 | print('-' * 8 + 'init Decoder' + '-' * 8) 21 | self.decoder = nn.DataParallel(decoder(options), list(range(gpu_num))) 22 | 23 | # Setup the optimizer 24 | gen_params = list(self.model.parameters()) + list(self.decoder.parameters()) 25 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], lr=options.lr, betas=(0.8, 0.999), 26 | weight_decay=0.0001, amsgrad=True) 27 | 28 | self.gen_scheduler = get_scheduler(self.gen_opt, options) 29 | 30 | # Loss criteria 31 | self.mse = nn.MSELoss(reduction='mean') 32 | self.abs = nn.L1Loss(reduction='mean') 33 | self.cos = nn.CosineSimilarity() 34 | 35 | self.gram_loss = torch.tensor(0.) 36 | self.per_loss = torch.tensor(0.) 37 | self.tv_loss = torch.tensor(0.) 38 | 39 | # image display 40 | self.input = None 41 | self.output = None 42 | self.content_style = None 43 | 44 | def gen_update(self): 45 | self.gener_loss.backward() 46 | self.gen_opt.step() 47 | 48 | def update(self, content, style, options): 49 | # input: content, style 50 | 51 | # zero gradient 52 | self.gen_opt.zero_grad() 53 | 54 | self.input = input 55 | self.content_style = torch.cat((content, style), dim=3) 56 | 57 | content_feats = self.vgg(content) 58 | style_feats = self.vgg(style) 59 | 60 | stylized_feature = self.model(style_feats[-2], content_feats[-2]) # relu3_1 61 | 62 | # stylied photo 63 | output = self.decoder(stylized_feature) 64 | self.output = output 65 | 66 | # loss 67 | 68 | # styel loss 69 | output_feats = self.vgg(output) 70 | 71 | self.gram_loss = self.get_mean_std_diff(output_feats, style_feats) 72 | 73 | # per loss 74 | self.per_loss = options.clw * content_loss(output_feats[-1], content_feats[-1]) 75 | # self.per_loss = options.clw * self.mse(output_feats[-1], content_feats[-1]) 76 | 77 | # # tv loss 78 | self.tv_loss = TVloss(output, options.tvw) 79 | 80 | # generator total loss 81 | self.gener_loss = options.slw * self.gram_loss + options.clw * self.per_loss + self.tv_loss 82 | self.gram_loss = options.slw * self.gram_loss 83 | self.gen_update() 84 | return None 85 | 86 | def get_output(self): 87 | return self.output 88 | 89 | def get_content_style(self): 90 | return self.content_style 91 | 92 | def get_mean_std_diff(self, feature1, feature2): 93 | diff = torch.tensor(0.).cuda() 94 | for i in range(len(feature1)): 95 | feat1 = feature1[i] 96 | feat2 = feature2[i] 97 | feat1_mean, feat1_std = calc_mean_std(feat1) 98 | feat2_mean, feat2_std = calc_mean_std(feat2) 99 | diff += self.mse(feat1_mean, feat2_mean) + self.mse(feat1_std, feat2_std) 100 | return diff 101 | 102 | def update_learning_rate(self): 103 | if self.gen_scheduler is not None: 104 | self.gen_scheduler.step() 105 | 106 | def resume(self, checkpoint_dir, options): 107 | # Load generators 108 | last_model_name = get_model_list(checkpoint_dir, "gen") 109 | if last_model_name == None: 110 | return 0 111 | state_dict = torch.load(last_model_name) 112 | self.model.load_state_dict(state_dict['a']) 113 | self.decoder.load_state_dict(state_dict['b']) 114 | iterations = int(last_model_name[-11:-3]) 115 | 116 | # Load optimizers 117 | last_model_name = get_model_list(checkpoint_dir, "opt") 118 | state_dict = torch.load(last_model_name) 119 | self.gen_opt.load_state_dict(state_dict['a']) 120 | 121 | # Reinitilize schedulers 122 | self.gen_scheduler = get_scheduler(self.gen_opt, options, iterations) 123 | print('Resume from iteration %d' % iterations) 124 | return iterations 125 | 126 | def resume_eval(self, trained_generator): # 
在test的时候都要用什么。。 127 | state_dict = torch.load(trained_generator) 128 | self.model.load_state_dict(state_dict['a']) 129 | self.decoder.load_state_dict(state_dict['b']) 130 | return 0 131 | 132 | def save(self, snapshot_dir, iterations): 133 | # Save generators, discriminators, and optimizers 134 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 135 | opt_name = os.path.join(snapshot_dir, 'opt_%08d.pt' % (iterations + 1)) 136 | torch.save({'a': self.model.state_dict(), 'b': self.decoder.state_dict()}, 137 | gen_name) 138 | torch.save({'a': self.gen_opt.state_dict()}, opt_name) 139 | 140 | def sample(self, content, style): 141 | self.eval() 142 | with torch.no_grad(): 143 | # print('content: ', content.shape) 144 | # print('style: ', style.shape) 145 | content_feat = self.vgg(content) 146 | style_feat = self.vgg(style) 147 | 148 | stylized_feature = self.model(style_feat[-2], content_feat[-2]) 149 | output = self.decoder(stylized_feature) 150 | self.train() 151 | return output 152 | 153 | def sample_inter(self, content, style1, style2): 154 | self.eval() 155 | with torch.no_grad(): 156 | print('content: ', content.shape) 157 | print('style1: ', style1.shape) 158 | print('style2: ', style2.shape) 159 | content_feat = self.vgg(content) 160 | style1_feat = self.vgg(style1) 161 | style2_feat = self.vgg(style2) 162 | stylized_feature = self.model.module.interpolation(style1_feat[-2], style2_feat[-2], content_feat[-2]) 163 | output = self.decoder(stylized_feature) 164 | self.train() 165 | return output 166 | 167 | def sample_mask(self, content, style1, style2, mask): 168 | self.eval() 169 | with torch.no_grad(): 170 | print('content: ', content.shape) 171 | print('style1: ', style1.shape) 172 | print('style2: ', style2.shape) 173 | print('mask: ', mask.shape) 174 | content_feat = self.vgg(content) 175 | style1_feat = self.vgg(style1) 176 | style2_feat = self.vgg(style2) 177 | stylized_feature = self.model(style1_feat[-2], content_feat[-2]) 178 | output1 = self.decoder(stylized_feature) 179 | stylized_feature = self.model(style2_feat[-2], content_feat[-2]) 180 | output2 = self.decoder(stylized_feature) 181 | mask = (mask > 0).to(output1.device).float() 182 | print('mask: ', mask.shape) 183 | output = mask * output1 + (1 - mask) * output2 184 | self.train() 185 | return output 186 | 187 | def test(self, content, style): 188 | self.eval() 189 | with torch.no_grad(): 190 | content_feat = self.vgg(content) 191 | style_feat = self.vgg(style) 192 | stylized_feature, score_map = self.model(style_feat[-2], content_feat[-2]) 193 | output = self.decoder(stylized_feature) 194 | self.train() 195 | return output, score_map 196 | 197 | def feats_crop(self, feats, reference): 198 | new_feats = [] 199 | for i in range(len(feats)): 200 | pad = (feats[i].shape[3] - reference[i].shape[3]) // 2 201 | new_feats.append(feats[i][:, :, pad:-pad, pad:-pad].contiguous()) 202 | return new_feats 203 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import torchvision 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import models 7 | import numpy as np 8 | 9 | 10 | VggNet = nn.Sequential( 11 | nn.Conv2d(3, 3, (1, 1)), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(3, 64, (3, 3)), 14 | nn.ReLU(), # relu1-1 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(64, 64, (3, 3)), 17 | nn.ReLU(), # relu1-2 18 
| nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 19 | nn.ReflectionPad2d((1, 1, 1, 1)), 20 | nn.Conv2d(64, 128, (3, 3)), 21 | nn.ReLU(), # relu2-1 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(128, 128, (3, 3)), 24 | nn.ReLU(), # relu2-2 25 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 26 | nn.ReflectionPad2d((1, 1, 1, 1)), 27 | nn.Conv2d(128, 256, (3, 3)), 28 | nn.ReLU(), # relu3-1 29 | nn.ReflectionPad2d((1, 1, 1, 1)), 30 | nn.Conv2d(256, 256, (3, 3)), 31 | nn.ReLU(), # relu3-2 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(256, 256, (3, 3)), 34 | nn.ReLU(), # relu3-3 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(256, 256, (3, 3)), 37 | nn.ReLU(), # relu3-4 38 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 39 | nn.ReflectionPad2d((1, 1, 1, 1)), 40 | nn.Conv2d(256, 512, (3, 3)), 41 | nn.ReLU(), # relu4-1, this is the last layer used 42 | nn.ReflectionPad2d((1, 1, 1, 1)), 43 | nn.Conv2d(512, 512, (3, 3)), 44 | nn.ReLU(), # relu4-2 45 | nn.ReflectionPad2d((1, 1, 1, 1)), 46 | nn.Conv2d(512, 512, (3, 3)), 47 | nn.ReLU(), # relu4-3 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(512, 512, (3, 3)), 50 | nn.ReLU(), # relu4-4 51 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 52 | nn.ReflectionPad2d((1, 1, 1, 1)), 53 | nn.Conv2d(512, 512, (3, 3)), 54 | nn.ReLU(), # relu5-1 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(512, 512, (3, 3)), 57 | nn.ReLU(), # relu5-2 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(512, 512, (3, 3)), 60 | nn.ReLU(), # relu5-3 61 | nn.ReflectionPad2d((1, 1, 1, 1)), 62 | nn.Conv2d(512, 512, (3, 3)), 63 | nn.ReLU() # relu5-4 64 | ) 65 | 66 | 67 | vgg16 = nn.Sequential( 68 | nn.ReflectionPad2d((1, 1, 1, 1)), 69 | nn.Conv2d(3, 64, kernel_size=3, stride=1), 70 | nn.ReLU(inplace=True), # relu1_1 71 | nn.ReflectionPad2d((1, 1, 1, 1)), 72 | nn.Conv2d(64, 64, kernel_size=3, stride=1), 73 | nn.ReLU(inplace=True), 74 | nn.MaxPool2d(kernel_size=2, stride=2,padding=0, dilation=1, ceil_mode=False), 75 | nn.ReflectionPad2d((1, 1, 1, 1)), 76 | nn.Conv2d(64, 128, kernel_size=3, stride=1), 77 | nn.ReLU(inplace=True), # relu2_1 78 | nn.ReflectionPad2d((1, 1, 1, 1)), 79 | nn.Conv2d(128,128, kernel_size=3, stride=1), 80 | nn.ReLU(inplace=True), 81 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), 82 | nn.ReflectionPad2d((1, 1, 1, 1)), 83 | nn.Conv2d(128, 256, kernel_size=3,stride=1), 84 | nn.ReLU(inplace=True), # relu3_1 85 | nn.ReflectionPad2d((1, 1, 1, 1)), 86 | nn.Conv2d(256, 256, kernel_size=3, stride=1), 87 | nn.ReLU(inplace=True), 88 | nn.ReflectionPad2d((1, 1, 1, 1)), 89 | nn.Conv2d(256, 256, kernel_size=3, stride=1), 90 | nn.ReLU(inplace=True), 91 | nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False), 92 | nn.ReflectionPad2d((1, 1, 1, 1)), 93 | nn.Conv2d(256, 512, kernel_size=3, stride=1), 94 | nn.ReLU(inplace=True), # relu4_1 95 | ) 96 | 97 | vgg_conv_list = [1,4,8,11,15,18,21,25] 98 | vgg_model_conv_list = [0, 2, 5, 7, 10, 12, 14, 17] 99 | 100 | class VGG(nn.Module): 101 | def __init__(self, options): 102 | super(VGG, self).__init__() 103 | # vgg_pad 104 | vgg_model = models.vgg16(pretrained=False) 105 | vgg_model.load_state_dict(torch.load(options.path)) 106 | vgg_model = vgg_model.features 107 | vgg = vgg16 108 | 109 | for i in range(8): 110 | vgg[vgg_conv_list[i]].weight = vgg_model[vgg_model_conv_list[i]].weight 111 | vgg[vgg_conv_list[i]].bias = vgg_model[vgg_model_conv_list[i]].bias 112 | self.test = vgg[vgg_conv_list[7]].weight 113 | 114 | 
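        # vgg_conv_list indexes the Conv2d layers of the reflection-padded
        # vgg16 Sequential defined above (through relu4_1), while
        # vgg_model_conv_list gives the matching Conv2d indices in
        # torchvision's vgg16.features; the loop above copies the pretrained
        # weights layer by layer before every parameter is frozen below.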
for p in self.parameters(): 115 | p.requires_grad = False 116 | self.slice1 = vgg[:3] # relu1_1 117 | self.slice2 = vgg[3:10] # relu2_1 118 | self.slice3 = vgg[10:17] # relu3_1 119 | self.slice4 = vgg[17:27] # relu4_1 120 | for p in self.parameters(): 121 | p.requires_grad = False 122 | 123 | def forward(self, x): 124 | out = [] 125 | x = self.slice1(x) 126 | out.append(x) 127 | x = self.slice2(x) 128 | out.append(x) 129 | x = self.slice3(x) 130 | out.append(x) 131 | x = self.slice4(x) 132 | out.append(x) 133 | return out 134 | 135 | 136 | class decoder(nn.Module): 137 | def __init__(self, options): 138 | super(decoder, self).__init__() 139 | self.model = nn.Sequential( 140 | nn.ReflectionPad2d((1, 1, 1, 1)), 141 | nn.Conv2d(256, 256, (3, 3)), 142 | nn.ReLU(), 143 | nn.ReflectionPad2d((1, 1, 1, 1)), 144 | nn.Conv2d(256, 256, (3, 3)), 145 | nn.ReLU(), 146 | nn.ReflectionPad2d((1, 1, 1, 1)), 147 | nn.Conv2d(256, 256, (3, 3)), 148 | nn.ReLU(), 149 | nn.ReflectionPad2d((1, 1, 1, 1)), 150 | nn.Conv2d(256, 128, (3, 3)), 151 | nn.ReLU(), 152 | nn.Upsample(scale_factor=2, mode='nearest'), 153 | nn.ReflectionPad2d((1, 1, 1, 1)), 154 | nn.Conv2d(128, 128, (3, 3)), 155 | nn.ReLU(), 156 | nn.ReflectionPad2d((1, 1, 1, 1)), 157 | nn.Conv2d(128, 64, (3, 3)), 158 | nn.ReLU(), 159 | nn.Upsample(scale_factor=2, mode='nearest'), 160 | nn.ReflectionPad2d((1, 1, 1, 1)), 161 | nn.Conv2d(64, 64, (3, 3)), 162 | nn.ReLU(), 163 | nn.ReflectionPad2d((1, 1, 1, 1)), 164 | nn.Conv2d(64, 3, (3, 3)), 165 | ) 166 | 167 | def forward(self, stylized_feature): 168 | x = self.model(stylized_feature) 169 | return x 170 | 171 | 172 | class ConvBlock(nn.Module): 173 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, use_bias=True, activation=nn.ReLU, 174 | batch_norm=False): 175 | super(ConvBlock, self).__init__() 176 | self.conv = nn.Conv2d(int(inc), int(outc), kernel_size, stride=stride, bias=use_bias) 177 | # self.activation = activation() if activation else None 178 | if activation == 'sigmoid': 179 | self.activation = nn.Sigmoid() 180 | elif activation is None: 181 | self.activation = None 182 | else: 183 | self.activation = activation() 184 | self.bn = nn.BatchNorm2d(outc) if batch_norm else None 185 | self.p = padding 186 | 187 | def forward(self, x): 188 | x = F.pad(x, (self.p, self.p, self.p, self.p), mode='reflect') 189 | x = self.conv(x) 190 | if self.bn: 191 | x = self.bn(x) 192 | if self.activation: 193 | x = self.activation(x) 194 | return x 195 | 196 | 197 | class Coeffs(nn.Module): 198 | 199 | def __init__(self, nin=16, nout=17, options=None): 200 | super(Coeffs, self).__init__() 201 | self.nin = nin 202 | self.nout = nout 203 | 204 | self.lb = options.luma_bins #8#params['luma_bins'] 205 | self.cm = options.channel_multiplier #1#params['channel_multiplier'] 206 | self.sb = options.spatial_bin #8#params['spatial_bin'] 207 | self.G = options.group_num #16 208 | bn = False 209 | nsize = options.n_input_size #64#params['net_input_size'] 210 | nchannel = options.n_input_channel #256 211 | self.relu = nn.ReLU() 212 | 213 | n_layers_splat = int(np.log2(nsize / self.sb)) 214 | self.splat_features = nn.ModuleList() 215 | prev_ch = nchannel #nin 216 | for i in range(n_layers_splat): 217 | use_bn = False #bn if i > 0 else False 218 | self.splat_features.append(ConvBlock(prev_ch, nchannel, 3, stride=2, batch_norm=use_bn)) 219 | prev_ch = nchannel 220 | 221 | # local features 222 | self.local_features = nn.ModuleList() 223 | self.local_features.append(ConvBlock(nchannel, 32 * self.cm * self.lb, 3, stride=1, 
batch_norm=bn)) 224 | self.local_features.append(ConvBlock(32 * self.cm * self.lb, 32 * self.cm * self.lb, 3, stride=1, activation=None, use_bias=False)) 225 | 226 | # predicton 227 | self.conv_out = ConvBlock(32 * self.cm * self.lb, self.G * self.lb * nout * nin, 1, padding=0, activation=None) 228 | 229 | def forward(self, lowres_input): 230 | bs = lowres_input.shape[0] 231 | 232 | x = lowres_input 233 | for layer in self.splat_features: 234 | x = layer(x) 235 | splat_features = x 236 | 237 | x = splat_features 238 | for layer in self.local_features: 239 | x = layer(x) 240 | local_features = x 241 | 242 | fusion_grid = local_features 243 | fusion = self.relu(fusion_grid) 244 | 245 | x = self.conv_out(fusion) 246 | s = x.shape 247 | x = x.view(bs * self.G, self.nin * self.nout, self.lb, s[2], s[3]) # B x Coefs x Luma x Spatial x Spatial 248 | return x 249 | 250 | 251 | class GuideNN(nn.Module): 252 | def __init__(self, options=None): 253 | super(GuideNN, self).__init__() 254 | self.conv1 = ConvBlock(16, 4, kernel_size=1, padding=0) 255 | self.conv2 = ConvBlock(4, 1, kernel_size=1, padding=0, activation='sigmoid') 256 | self.G = options.group_num 257 | 258 | def forward(self, x): 259 | x = self.conv2(self.conv1(x)) 260 | return x 261 | 262 | 263 | class Slice(nn.Module): 264 | def __init__(self): 265 | super(Slice, self).__init__() 266 | 267 | def forward(self, affine_transformation, guidemap): 268 | device = affine_transformation.get_device() 269 | 270 | N, _, H, W = guidemap.shape 271 | hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW 272 | if device >= 0: 273 | hg = hg.to(device) 274 | wg = wg.to(device) 275 | hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) # norm to [0,1] NxHxWx1 276 | wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) # norm to [0,1] NxHxWx1 277 | hg, wg = hg*2-1, wg*2-1 278 | guidemap = guidemap.permute(0, 2, 3, 1).contiguous() 279 | guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3 280 | coeff = F.grid_sample(affine_transformation, guidemap_guide, mode='bilinear', padding_mode='reflection', align_corners=True) 281 | return coeff.squeeze(2) 282 | 283 | 284 | class ApplyCoeffs(nn.Module): 285 | def __init__(self, options=None): 286 | super(ApplyCoeffs, self).__init__() 287 | self.degree = 3 288 | self.G = options.group_num 289 | self.alpha = options.alpha 290 | self.sect = options.selection 291 | self.inter_sect = options.inter_selection 292 | # self.sect = 'Ax+b' 293 | 294 | def forward(self, coeff, full_res_input): 295 | N, C, H, W = full_res_input.shape 296 | CG = C // self.G 297 | output = [] 298 | for i in range(self.G): 299 | # print(full_res_input.shape) 300 | # print(coeff.shape) 301 | # print(self.sect) 302 | if self.sect == 'Ax+b': 303 | x = torch.sum(full_res_input * coeff[:, i*(CG+1):(i+1)*(CG+1)-1, :, :], dim=1, keepdim=True) + coeff[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :, :] 304 | if self.sect == 'Ax': 305 | x = torch.sum(full_res_input * coeff[:, i*(CG+1):(i+1)*(CG+1)-1, :, :], dim=1, keepdim=True) # no bias 306 | if self.sect == 'x+b': 307 | x = torch.sum(full_res_input, dim=1, keepdim=True) + coeff[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :, :] 308 | if self.sect == 'b': 309 | x = coeff[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :, :] 310 | if self.sect == 'aAx+b': 311 | x = torch.sum(full_res_input * self.alpha*coeff[:, i*(CG+1):(i+1)*(CG+1)-1, :, :], dim=1, keepdim=True) + coeff[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :, :] 312 | output.append(x) 313 | return torch.cat(output, dim=1) 314 | 315 | def mix(self, coeff1, 
coeff2, full_res_input): 316 | N, C, H, W = full_res_input.shape 317 | CG = C // self.G 318 | output = [] 319 | a = 0.9 320 | b = 0.1 321 | for i in range(self.G): 322 | # print(self.inter_sect) 323 | if self.inter_sect == 'A1x+b2': 324 | x = torch.sum(full_res_input * coeff1[:, i*(CG+1):(i+1)*(CG+1)-1, :, :], dim=1, keepdim=True) + coeff2[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :, :] # A1xx+b2 325 | if self.inter_sect == 'A2x+b2': 326 | x = torch.sum(full_res_input * coeff2[:, i*(CG+1):(i+1)*(CG+1)-1, :, :], dim=1, keepdim=True) + coeff2[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :, :] # A2x+b2 327 | if self.inter_sect == 'A2x+b1': 328 | x = torch.sum(full_res_input * coeff2[:, i*(CG+1):(i+1)*(CG+1)-1, :, :], dim=1, keepdim=True) + coeff1[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :] 329 | if self.inter_sect == '(a1A1+a2A2)x+b1': 330 | x = torch.sum(full_res_input * (a*coeff1[:, i*(CG+1):(i+1)*(CG+1)-1, :, :]+b*coeff2[:, i*(CG+1):(i+1)*(CG+1)-1, :, :]), dim=1, keepdim=True) + coeff1[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :] 331 | if self.inter_sect == '(a1A1+a2A2)x+b2': 332 | x = torch.sum(full_res_input * (a*coeff1[:, i*(CG+1):(i+1)*(CG+1)-1, :, :]+b*coeff2[:, i*(CG+1):(i+1)*(CG+1)-1, :, :]), dim=1, keepdim=True) + coeff2[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :] 333 | if self.inter_sect == '(a1A1+a2A2)x+a1*b1+a2*b2 ': 334 | x = torch.sum(full_res_input * (a*coeff1[:, i*(CG+1):(i+1)*(CG+1)-1, :, :]+b*coeff2[:, i*(CG+1):(i+1)*(CG+1)-1, :, :]), dim=1, keepdim=True) + a*coeff1[:, (i+1)*(CG+1)-1:(i+1)*(CG+1), :] + b*coeff2[:, (i+1)*(CG+1)-1:(i+1)*(CG+1),:] 335 | output.append(x) 336 | return torch.cat(output, dim=1) 337 | 338 | 339 | class StyleFormer(nn.Module): 340 | def __init__(self, options): 341 | super(StyleFormer, self).__init__() 342 | self.coeffs = Coeffs(options=options) 343 | self.att = AttModule(options) 344 | self.guide = GuideNN(options=options) 345 | self.slice = Slice() 346 | self.apply_coeffs = ApplyCoeffs(options=options) 347 | self.G = options.group_num 348 | self.fcc = nn.Conv2d(256, 256, 1, 1) 349 | self.fcs = nn.Conv2d(256, 256, 1, 1) 350 | 351 | def forward(self, style_feat, content_feat): 352 | coeffs = self.coeffs(style_feat) 353 | content_feat = self.fcc(content_feat) 354 | style_feat = self.fcs(style_feat) 355 | content_feat = F.group_norm(content_feat, num_groups=self.G) 356 | style_norm = F.group_norm(style_feat, num_groups=self.G) 357 | N, C, H, W = content_feat.shape 358 | content_feat = content_feat.view(N*self.G, C//self.G, H, W) 359 | N, C, Hs, Ws = style_feat.shape 360 | style_feat = style_feat.view(N*self.G, -1, Hs, Ws) 361 | style_norm = style_norm.view(N*self.G, -1, Hs, Ws) 362 | 363 | # grid with attention 364 | att_coeffs = self.att(content_feat, style_norm, coeffs) 365 | guide = self.guide(content_feat) 366 | 367 | slice_coeffs = self.slice(att_coeffs, guide) 368 | out = self.apply_coeffs(slice_coeffs, content_feat) 369 | out = out.view(N, C, H, W) 370 | return out 371 | 372 | def interpolation(self, style1_feat, style2_feat, content_feat): 373 | coeff1 = self.coeffs(style1_feat) 374 | coeff2 = self.coeffs(style2_feat) 375 | 376 | content_feat = self.fcc(content_feat) 377 | style1_feat = self.fcs(style1_feat) 378 | style2_feat = self.fcs(style2_feat) 379 | 380 | content_feat = F.group_norm(content_feat, num_groups=self.G) 381 | style1_norm = F.group_norm(style1_feat, num_groups=self.G) 382 | style2_norm = F.group_norm(style2_feat, num_groups=self.G) 383 | N, C, H, W = content_feat.shape 384 | content_feat = content_feat.view(N*self.G, C//self.G, H, W) 385 | N, C, Hs1, Ws1 = style1_feat.shape 
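        # The (N, C, H, W) content features were viewed as (N*G, C//G, H, W)
        # above, so the attention and grid slicing below operate on each of
        # the G groups as an independent sample.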
386 | N, C, Hs2, Ws2 = style2_feat.shape 387 | style1_feat = style1_feat.view(N*self.G, -1, Hs1, Ws1) 388 | style1_norm = style1_feat.view(N*self.G, -1, Hs1, Ws1) 389 | style2_feat = style2_norm.view(N*self.G, -1, Hs2, Ws2) 390 | style2_norm = style2_norm.view(N*self.G, -1, Hs2, Ws2) 391 | 392 | # grid with attention 393 | att_coeffs1 = self.att(content_feat, style1_norm, coeff1) 394 | att_coeffs2 = self.att(content_feat, style2_norm, coeff2) 395 | guide = self.guide(content_feat) 396 | 397 | # interpolation 398 | slice_coeffs1 = self.slice(att_coeffs1, guide) 399 | slice_coeffs2 = self.slice(att_coeffs2, guide) 400 | 401 | # style mix 402 | out = self.apply_coeffs.mix(slice_coeffs1, slice_coeffs2, content_feat) 403 | out = out.view(N, C, H, W) 404 | return out 405 | 406 | def mask_mix(self, style1_feat, style2_feat, content_feat): 407 | coeff1 = self.coeffs(style1_feat) 408 | coeff2 = self.coeffs(style2_feat) 409 | 410 | content_feat = self.fcc(content_feat) 411 | style1_feat = self.fcs(style1_feat) 412 | style2_feat = self.fcs(style2_feat) 413 | 414 | content_feat = F.group_norm(content_feat, num_groups=self.G) 415 | style1_norm = F.group_norm(style1_feat, num_groups=self.G) 416 | style2_norm = F.group_norm(style2_feat, num_groups=self.G) 417 | N, C, H, W = content_feat.shape 418 | content_feat = content_feat.view(N*self.G, C//self.G, H, W) 419 | N, C, Hs1, Ws1 = style1_feat.shape 420 | N, C, Hs2, Ws2 = style2_feat.shape 421 | style1_feat = style1_feat.view(N*self.G, -1, Hs1, Ws1) 422 | style1_norm = style1_feat.view(N*self.G, -1, Hs1, Ws1) 423 | style2_feat = style2_norm.view(N*self.G, -1, Hs2, Ws2) 424 | style2_norm = style2_norm.view(N*self.G, -1, Hs2, Ws2) 425 | 426 | # grid with attention 427 | att_coeffs1 = self.att(content_feat, style1_norm, coeff1) 428 | att_coeffs2 = self.att(content_feat, style2_norm, coeff2) 429 | guide = self.guide(content_feat) 430 | 431 | # interpolation 432 | slice_coeffs1 = self.slice(att_coeffs1, guide) 433 | slice_coeffs2 = self.slice(att_coeffs2, guide) 434 | 435 | # style mix 436 | out = self.apply_coeffs.mix(slice_coeffs1, slice_coeffs2, content_feat) 437 | out = out.view(N, C, H, W) 438 | return out 439 | 440 | 441 | class AttModule(nn.Module): 442 | def __init__(self, options): 443 | super(AttModule, self).__init__() 444 | self.convc1 = ConvBlock(16, 16, stride=2) 445 | self.convc2 = ConvBlock(16, 16, stride=2, activation=None) 446 | 447 | self.convs1 = ConvBlock(16, 16, stride=2) 448 | self.convs2 = ConvBlock(16, 16, stride=2, activation=None) 449 | 450 | self.grid_channel = options.luma_bins*16*17 451 | self.convsr = ConvBlock(self.grid_channel, self.grid_channel, activation=None) 452 | self.G = options.n_input_channel//options.group_num 453 | self.cpg = options.group_num # channels per group 454 | 455 | self.sp = options.spatial_bin 456 | 457 | def forward(self, c, s, grid): 458 | Ng, Cg, Lg, Hg, Wg = grid.shape 459 | Ng, C, Hs, Ws = s.shape 460 | 461 | c = self.convc1(c) 462 | c1 = self.convc2(c) 463 | Ng, C, H, W = c1.shape 464 | c1 = c1.view(Ng, 16, -1) 465 | 466 | s = self.convs1(s) 467 | s1 = self.convs2(s).view(Ng, 16, -1) 468 | 469 | cs = torch.bmm(c1.permute(0, 2, 1), s1) # attention map 470 | 471 | # ############# visualization 472 | # for i in range(16): 473 | # display = cs.view(16, H, W, 16, 16)[i, :, :, 14, 13] 474 | # mx = torch.max(display) 475 | # mn = torch.min(display) 476 | # display = (display - mn) / (mx - mn) 477 | # # display = F.interpolate(display, (64, 64)) 478 | # heatmap = display.cpu().detach().numpy() 479 | # 
heatmap = heatmap*255 480 | # heatmap=heatmap.astype(np.uint8) 481 | # heatmap = cv2.resize(heatmap, (W*16, H*16)) 482 | # #cv2.imwrite('map.png', heatmap) 483 | # heatmap=cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) 484 | # cv2.imwrite(f'attention_results/heatmap{i+1}.png', heatmap) 485 | # exit(0) 486 | ########### visualization end 487 | 488 | cs = F.softmax(cs, dim=2) 489 | grid = grid.view(Ng, -1, Hg, Wg) 490 | sr = self.convsr(grid).view(Ng, self.grid_channel, -1) 491 | rs = torch.bmm(sr, cs.permute(0,2,1)) 492 | return rs.view(Ng, Cg, Lg, H, W) 493 | 494 | 495 | 496 | -------------------------------------------------------------------------------- /pre_process.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | 4 | 5 | # def preprocess(content_path): 6 | def preprocess(img): 7 | # img = Image.open(content_path).convert('RGB') 8 | H, W = img.shape 9 | 10 | if max(H, W) >= 256 and max(H, W) < 512: 11 | long_side = 256 12 | if max(H, W) >= 512 and max(H, W) < 1024: 13 | long_side = 512 14 | if max(H, W) > 1024: 15 | long_side = 1024 16 | 17 | if H > W: 18 | H = long_side 19 | W = int(long_side / H * W) 20 | return H, W 21 | 22 | -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import imghdr 3 | 4 | import numpy as np 5 | import os 6 | import cv2 7 | 8 | import pandas 9 | import torch 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import scipy.misc 13 | import random 14 | import torchvision.transforms as transform 15 | 16 | # from utils import image_check, get_one_hot_encoded_vector 17 | from utils import image_check, normalize_vgg 18 | 19 | 20 | def get_batch_tensor(dataset, batch_size=2): 21 | batch_tensor = None 22 | for i in range(batch_size): 23 | index = random.randint(1, len(dataset)) 24 | image_tensor = dataset.getitem(index) 25 | image_tensor.unsqeeze_(0) 26 | if batch_tensor is None: 27 | batch_tensor = image_tensor 28 | else: 29 | batch_tensor = torch.cat((batch_tensor, image_tensor), 0) 30 | return batch_tensor 31 | 32 | 33 | def image_to_tensor(image, image_size): 34 | image_tensor = torch.from_numpy( 35 | cv2.cvtColor(cv2.resize(image, dsize=(image_size, image_size)), cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # C,H,W 36 | # image_tensor = torch.from_numpy(image).permute(2, 0, 1) 37 | return image_tensor 38 | 39 | 40 | def image_to_tensor_PIL(path, image_size): 41 | image = Image.open(path).convert('RGB') 42 | # image = Image.open(path).convert('RGB').resize(image_size, image_size) 43 | # convert a PIL image to tensor (HWC) in range [0,255] to a torch.Tensor(CHW) in the range [0.0,1.0] 44 | data_transform = transform.Compose([transform.Resize((image_size, image_size)), transform.ToTensor(), 45 | transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 46 | image_tensor = data_transform(image) 47 | return image_tensor 48 | 49 | 50 | class ArtDataset(): 51 | def __init__(self, opts): 52 | self.path_to_art_dataset = opts.art_data_path 53 | self.style_image_paths = [] 54 | 55 | self.image_size = opts.image_size 56 | # self.dataset = [os.path.join(self.path_to_art_dataset, x) for x in os.listdir(self.path_to_art_dataset)] 57 | print("Art dataset contains %d images." 
% len(self.style_image_paths)) 58 | 59 | def __len__(self): 60 | return len(self.style_image_paths) // 2 61 | 62 | def __getitem__(self, item): 63 | path, pathb = self.style_image_paths[item * 2], self.style_image_paths[item * 2 + 1] 64 | image_tensor1 = image_to_tensor(path, self.image_size) 65 | image_tensor2 = image_to_tensor(pathb, self.image_size) 66 | # label1 = 67 | return image_tensor1, image_tensor2 68 | 69 | 70 | def find_index(list, data): 71 | for i, item in enumerate(list): 72 | if item == data: 73 | return i 74 | return None 75 | 76 | 77 | def preprocess(path_to_art_dataset, info_path): 78 | style_image_paths = [] 79 | style_str = [] 80 | style_num = [] 81 | label = [] 82 | with open(info_path, 'r') as rfile: 83 | reader = csv.DictReader(rfile) 84 | for row in reader: 85 | if row['style'] is not '': 86 | style_image_paths.append(os.path.join(path_to_art_dataset, row['filename'])) 87 | style_str.append(row['style']) 88 | label = list(set(style_str)) 89 | for i, data in enumerate(style_str): 90 | style_num.append(find_index(label, data)) 91 | 92 | 93 | class StyleDataset(): 94 | def __init__(self, opts): 95 | self.image = [] 96 | self.style = [] 97 | self.root = opts.art_data_path 98 | self.image_size = opts.image_size 99 | with open(opts.csv_path, 'r') as rfile: 100 | reader = csv.DictReader(rfile) 101 | self.image = [row['filename'] for row in reader] 102 | self.style = [row['style'] for row in reader] 103 | label = list(set(self.style)) 104 | for i, data in enumerate(self.style): 105 | self.style[i] = find_index(label, data) 106 | 107 | def __len__(self): 108 | return len(self.image) 109 | 110 | def __getitem__(self, item): 111 | path = os.path.join(self.root, self.image[item]) 112 | image_tensor = image_to_tensor(path, self.image_size) 113 | label = self.style[item] 114 | return image_tensor, label 115 | 116 | 117 | class PlacesDataset(): 118 | categories_names = ['/a/alley', '/a/apartment_building/outdoor', 119 | '/a/aqueduct', '/a/arcade', '/a/arch', 120 | '/a/atrium/public', '/a/auto_showroom', 121 | '/a/amphitheater', 122 | '/b/balcony/exterior', '/b/balcony/interior', '/b/badlands', 123 | '/b/ballroom', '/b/banquet_hall', 124 | '/b/bar', '/b/barn', 125 | '/b/bazaar/outdoor', '/b/beach', '/b/beach_house', 126 | '/b/bedroom', '/b/beer_hall', 127 | '/b/boat_deck', '/b/bookstore', '/b/botanical_garden', 128 | '/b/bridge', '/b/bullring', 129 | '/b/building_facade', '/b/butte', 130 | '/c/cabin/outdoor', '/c/campsite', '/c/campus', '/c/canal/natural', 131 | '/c/canyon', '/c/canal/urban', 132 | '/c/carrousel', '/c/castle', '/c/chalet', 133 | '/c/church/indoor', '/c/church/outdoor', 134 | '/c/cliff', '/c/crevasse', '/c/crosswalk', 135 | '/c/coast', '/c/coffee_shop', 136 | '/c/corn_field', '/c/corral', 137 | '/c/courthouse', '/c/courtyard', 138 | '/d/desert/sand', '/d/desert_road' 139 | '/d/doorway/outdoor', '/d/downtown', 140 | '/d/dressing_room', 141 | '/e/embassy', '/e/entrance_hall', 142 | '/f/field/cultivated', '/f/field/wild', 143 | '/f/field_road', 144 | '/f/formal_garden', '/f/florist_shop/indoor', 145 | '/f/fountain', '/g/gazebo/exterior', 146 | '/g/general_store/outdoor', '/g/glacier', 147 | '/g/grotto', 148 | '/h/harbor', '/h/hayfield', 149 | '/h/hotel/outdoor', 150 | '/h/house', '/h/hunting_lodge/outdoor', '/i/ice_floe', 151 | '/i/iceberg', '/i/igloo', 152 | '/i/inn/outdoor', '/i/islet', '/j/junkyard', '/k/kasbah', 153 | '/l/lagoon', 154 | '/l/lake/natural', '/l/lawn', 155 | '/l/legislative_chamber', '/l/library/outdoor', '/l/lighthouse', 156 | '/l/lobby', 
'/m/mansion', 157 | '/m/marsh', '/m/mausoleum', 158 | '/m/moat/water', '/m/mosque/outdoor', 159 | '/m/mountain_path', '/m/mountain_snowy', '/m/museum/outdoor', 160 | '/o/oast_house', '/o/ocean', '/o/orchestra_pit', '/p/pagoda', 161 | '/p/palace', 162 | '/p/pasture', '/p/phone_booth', 163 | '/p/picnic_area', '/p/pizzeria', 164 | '/p/plaza', '/p/pond', 165 | '/r/racecourse', '/r/restaurant_patio', '/r/rice_paddy', '/r/river', 166 | '/r/ruin', 167 | '/s/schoolhouse', 168 | '/s/shopfront', '/s/shopping_mall/indoor', 169 | '/s/ski_resort', '/s/sky', '/s/street', 170 | '/s/stable', '/s/swimming_hole', '/s/synagogue/outdoor', '/t/temple/asia', 171 | '/t/throne_room', '/t/tower', 172 | '/t/tree_house', '/t/tundra', '/v/valley', 173 | '/v/viaduct', '/v/village', '/v/volcano', 174 | '/w/water_park', '/w/waterfall', 175 | '/w/wave', '/w/wheat_field', '/w/wind_farm', 176 | '/w/windmill', '/y/yard', 177 | ] 178 | categories_names = [x[1:] for x in categories_names] 179 | 180 | def __init__(self, opts, augmentor): 181 | self.path_to_dataset = opts.content_data_path 182 | self.content_image_size = opts.image_size 183 | self.content_image_paths = [] 184 | self.categories = [] 185 | self.aug = augmentor 186 | for category_idx, category_name in enumerate(tqdm(self.categories_names, ncols=100, mininterval=.5)): 187 | # print(category_name, category_idx) 188 | if os.path.exists(os.path.join(self.path_to_dataset, category_name)): 189 | for file_name in os.listdir(os.path.join(self.path_to_dataset, category_name)): 190 | self.content_image_paths.append(os.path.join(self.path_to_dataset, category_name, file_name)) 191 | self.categories.append(category_name) 192 | else: 193 | pass 194 | # print("Category %s can't be found in path %s. Skip it." % 195 | # (category_name, os.path.join(path_to_dataset, category_name))) 196 | 197 | print("Finished. Constructed Places2 dataset of %d images." % len(self.content_image_paths)) 198 | 199 | self.path_to_art_dataset = opts.art_data_path # train_image 200 | self.artist_list = os.listdir(self.path_to_art_dataset) 201 | self.style_image_paths = [] 202 | self.artist_slugs = [] 203 | self.artist_slugs_oneshot = [] 204 | self.classes = [] 205 | 206 | self.image_size = opts.image_size 207 | 208 | for category_idx, artist_slug in enumerate(tqdm(self.artist_list)): 209 | for file_name in tqdm(os.listdir(os.path.join(self.path_to_art_dataset, artist_slug))): 210 | self.style_image_paths.append(os.path.join(self.path_to_art_dataset, artist_slug, file_name)) 211 | self.artist_slugs.append(artist_slug) 212 | print("Art dataset contains %d images." 
% len(self.style_image_paths)) 213 | 214 | def __len__(self): 215 | return min(len(self.content_image_paths) // 2, len(self.style_image_paths) // 2) 216 | 217 | def __getitem__(self, item): 218 | content_index = random.randint(0, len(self.content_image_paths) - 1) 219 | content_path = self.content_image_paths[content_index] 220 | 221 | style_index = random.randint(0, len(self.style_image_paths) - 1) 222 | style_path = self.style_image_paths[style_index] 223 | content_image = cv2.imread(content_path) 224 | style_image = cv2.imread(style_path) 225 | content_image = image_check(content_image, augmentor=self.aug) 226 | style_image = image_check(style_image, augmentor=self.aug) 227 | 228 | content_tensor = normalize_vgg(image_to_tensor(content_image, self.image_size)) 229 | style_tensdor = normalize_vgg(image_to_tensor(style_image, self.image_size)) 230 | 231 | return content_tensor, style_tensdor 232 | 233 | def get_label_num(self): 234 | return len(self.artist_list) 235 | 236 | def get_image_num(self): 237 | return len(self.content_image_paths), len(self.style_image_paths) 238 | 239 | def get_label_lenth(self): 240 | return len(self.artist_list), len(self.categories) 241 | 242 | 243 | -------------------------------------------------------------------------------- /process_mask.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | path = '/data2/wuxiaolei/compare-model-pair/adain/input/mask/mask.png' 4 | img = cv2.imread(path) 5 | img2 = 255 - img 6 | cv2.imwrite('mask2.png',img2) 7 | print(img2) 8 | -------------------------------------------------------------------------------- /scripts/inter.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python interpolation.py \ 2 | --trained_network=./output/transformer_final_tv0_a1_fixdata/checkpoints/gen_00794001.pt \ 3 | --input_path=/data2/wuxiaolei/compare-model-pair/finalresults/add/4/content.png \ 4 | --style1_path=/data2/wuxiaolei/compare-model-pair/finalresults/add/4/style.png \ 5 | --style2_path=/data2/wuxiaolei/compare-model-pair/filteredresults/new9/style22+jiejing91/style.png \ 6 | --output_path=./output_full \ 7 | --spatial_bin=16 \ 8 | --alpha=0.2 \ 9 | --inter_selection=4 \ 10 | --results_path=./xiantiao/A2A1x+b1b282_new 11 | #--style_path=/data1/wuxiaolei/project/grid_exp/grid_cross_final/style 12 | # --selection=aAx+b \ 13 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python test.py \ 2 | --trained_network=./output/StyleFormer_sigmoid/checkpoints/gen_00794001.pt \ 3 | --input_path=/data/dataset/wuxiaolei/test_dataset/new9/content/ \ 4 | --style_path=/data/dataset/wuxiaolei/test_dataset/new9/style/ \ 5 | --output_path=newnewnew9 \ 6 | --spatial_bin=16 \ 7 | --alpha=0.8 \ 8 | --selection=Ax+b \ 9 | --luma_bins=4 \ 10 | --results_path=./tv1 11 | #--style_path=/data1/wuxiaolei/project/grid_exp/grid_cross_final/style 12 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python train.py \ 2 | --batch_size=10 \ 3 | --n_input_channel=256 \ 4 | --n_input_size=64 \ 5 | --epoch=100 \ 6 | --tb_file=/data3/wuxiaolei/tb/final_cross \ 7 | --sub_file=/StyleFormer_sigmoid \ 8 | --resume \ 9 | --clw=60 \ 10 | --tvw=0 \ 
11 | --slw=1 \ 12 | --alpha=1 \ 13 | --spatial_bin=16 \ 14 | --luma_bins=4 \ 15 | --selection=Ax+b 16 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | 4 | import numpy 5 | from PIL import Image 6 | import datetime 7 | 8 | import tensorboardX 9 | import data 10 | from utils import normalize_arr_of_imgs, denormalize_vgg, write_2images, write_images 11 | import argparse 12 | from model import Grid 13 | import torch 14 | from collections import namedtuple 15 | from torchvision import transforms 16 | from torch.utils.data import DataLoader 17 | import torchvision 18 | import matplotlib 19 | 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | 23 | from PIL import Image 24 | 25 | 26 | def save_conv_img(conv_img, sub_filename): 27 | root = './feature_map' 28 | if not os.path.exists(root): 29 | os.mkdir(root) 30 | sub_file = root + '/' + sub_filename 31 | if not os.path.exists(sub_file): 32 | os.mkdir(sub_file) 33 | conv_img = conv_img.detach().cpu() 34 | feature_maps = conv_img.squeeze(0) 35 | img_num = feature_maps.shape[0] # 36 | all_feature_maps = [] 37 | for i in range(0, img_num): 38 | single_feature_map = feature_maps[i, :, :] 39 | all_feature_maps.append(single_feature_map) 40 | plt.imshow(single_feature_map) 41 | plt.savefig(sub_file + '/feature_{}'.format(i)) 42 | 43 | sum_feature_map = sum(feature_map for feature_map in all_feature_maps) 44 | plt.imshow(sum_feature_map) 45 | plt.savefig(sub_file + "/feature_map_sum.png") 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('--gf_dim', type=int, default=64) 51 | parser.add_argument('--df_dim', type=int, default=64) 52 | parser.add_argument('--dim', type=int, default=3) 53 | parser.add_argument('--resume', action='store_true') 54 | parser.add_argument('--init', type=str, default='kaiming') 55 | parser.add_argument('--path', type=str, default='/data3/wuxiaolei/models/vgg16-397923af.pth') 56 | 57 | # dataset 58 | parser.add_argument('--root', type=str, default='/data/dataset/') 59 | parser.add_argument('--input_path', type=str, default='/data2/wuxiaolei/project/content') 60 | parser.add_argument('--style_path', type=str, default='/data2/wuxiaolei/project/style') 61 | parser.add_argument('--trained_network', type=str, help="path to the trained network file") 62 | parser.add_argument('--results_path', type=str, default='./results', help='outputs path') 63 | 64 | # data 65 | parser.add_argument('--image_size', type=int, default=256) 66 | 67 | # ouptut 68 | parser.add_argument('--output_path', type=str, default='./output', help='outputs path') 69 | 70 | # train 71 | parser.add_argument('--epoch', type=int, default=100) 72 | parser.add_argument('--lr', type=float, default=0.0001) 73 | parser.add_argument('--lr_policy', type=str, default='constant', help='step/constant') 74 | parser.add_argument('--step_size', type=int, default=200000) 75 | parser.add_argument('--gamma', type=float, default=0.5, help='How much to decay learning rate') 76 | parser.add_argument('--update_D', type=int, default=5) 77 | parser.add_argument('--batch_size', type=int, default=1) 78 | parser.add_argument('--save_freq', type=int, default=10000) 79 | 80 | # loss weight 81 | parser.add_argument('--clw', type=float, default=1, help='content_weight') 82 | parser.add_argument('--slw', type=float, default=10, help='style_weight') 83 | 
parser.add_argument('--alpha', type=float, default=1, help='style_weight') 84 | 85 | # bilateral grid 86 | parser.add_argument('--luma_bins', type=int, default=8) 87 | parser.add_argument('--channel_multiplier', type=int, default=1) 88 | parser.add_argument('--spatial_bin', type=int, default=8) 89 | parser.add_argument('--n_input_channel', type=int, default=256) 90 | parser.add_argument('--n_input_size', type=int, default=64) 91 | parser.add_argument('--group_num', type=int, default=16) 92 | 93 | # test 94 | parser.add_argument('--selection', type=str, default='Ax+b') 95 | parser.add_argument('--inter_selection', type=str, default='A1x+b2') 96 | # seed 97 | parser.add_argument('--seed', type=int, default=123) 98 | opts = parser.parse_args() 99 | 100 | # fix the seed 101 | numpy.random.seed(opts.seed) 102 | torch.manual_seed(opts.seed) 103 | torch.cuda.manual_seed_all(opts.seed) 104 | 105 | # opts = parser.parse_args() 106 | options = parser.parse_args() 107 | if not os.path.exists(options.output_path): 108 | os.mkdir(options.output_path) 109 | if not os.path.exists(options.results_path): 110 | os.mkdir(options.results_path) 111 | 112 | gpu_num = torch.cuda.device_count() 113 | 114 | myNet = Grid(options, gpu_num).cuda() 115 | 116 | initial_step = myNet.resume_eval(options.trained_network) 117 | torch.backends.cudnn.benchmark = True 118 | transform_style = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(), 119 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 120 | 121 | transform_content = transforms.Compose([transforms.ToTensor(), 122 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 123 | std=[0.229, 0.224, 0.225])]) 124 | dataset = data.ImageFolder(options.input_path, transform=transform_content, return_paths=True, resize=False) 125 | dataset_style = data.ImageFolder(options.style_path, transform=transform_style, return_paths=True, resize=False) 126 | loader = DataLoader(dataset=dataset, batch_size=1, num_workers=0) 127 | loader_style = DataLoader(dataset=dataset_style, batch_size=1, num_workers=0) 128 | 129 | step = 0 130 | total_time = datetime.datetime.now() - datetime.datetime.now() 131 | 132 | for it, images in enumerate(loader): 133 | for it2, styles in enumerate(loader_style): 134 | step += 1 135 | content = images[0] 136 | 137 | style = styles[0] 138 | content_path = images[1][0] 139 | style_path = styles[1][0] 140 | content_name = os.path.split(content_path)[1] 141 | style_name = os.path.split(style_path)[1] 142 | style = style.cuda() 143 | content = content.cuda() 144 | with torch.no_grad(): 145 | t0 = datetime.datetime.now() 146 | samp = myNet.sample(content, style) 147 | t1 = datetime.datetime.now() 148 | time = t1 - t0 149 | if step != 1: 150 | total_time += time 151 | print("time:%.8s", time.seconds + 1e-6 * time.microseconds) 152 | print('step:{}'.format(step)) 153 | image_outputs = [denormalize_vgg(samp).clamp_(0., 255.)] 154 | 155 | write_2images(image_outputs, 1, options.output_path, f'{style_name[:-4]}+{content_name[:-4]}+base') 156 | 157 | # write the results 158 | c = Image.open(content_path) 159 | style = Image.open(style_path) 160 | result = Image.open(f'{options.output_path}/{style_name[:-4]}+{content_name[:-4]}+base.png') 161 | c = c.resize(i for i in result.size) 162 | s = style.resize((i // 4 for i in result.size)) 163 | 164 | new = Image.new(result.mode, (result.width * 2, result.height)) 165 | new.paste(c, box=(0, 0)) 166 | new.paste(result, box=(c.width, 0)) 167 | box = ((c.width // 4) * 3, c.height - 
s.height) 168 | new.paste(s, box) 169 | new.save(f'{options.results_path}/{style_name[:-4]}+{content_name[:-4]}.png', quality=95) 170 | avag = total_time / (step - 1) 171 | print("avarge time:%.8s", avag.seconds + 1e-6 * avag.microseconds) 172 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import parser 5 | 6 | import numpy 7 | import tensorboardX 8 | import torch 9 | import torchvision 10 | from torch.utils import data 11 | 12 | import img_augm 13 | #import prepare_dataset 14 | import dataset 15 | from model import Grid 16 | from utils import prepare_sub_folder, write_loss, write_images, denormalize_vgg_adain, put_tensor_cuda 17 | 18 | if __name__ == '__main__': 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--gf_dim', type=int, default=64) 21 | parser.add_argument('--df_dim', type=int, default=64) 22 | parser.add_argument('--dim', type=int, default=3) 23 | parser.add_argument('--resume', action='store_true') 24 | parser.add_argument('--init', type=str, default='kaiming') 25 | parser.add_argument('--path', type=str, default='/data3/wuxiaolei/models/vgg16-397923af.pth') 26 | parser.add_argument('--tb_file', type=str, default='/data2/wuxiaolei/project/grid_exp/tensorboard') 27 | parser.add_argument('--sub_file', type=str, default='new') 28 | 29 | # dataset 30 | #parser.add_argument('--content_data_path', type=str, default='/data/dataset/Places/data_large') 31 | parser.add_argument("--content_data_path", type=str, default="/data/dataset/wuxiaolei/COCO2014/train2014") 32 | parser.add_argument('--style_data_path', type=str, default='/data/dataset/wuxiaolei/WikiArt/train') 33 | parser.add_argument('--info_path', type=str, default='/data/dataset/wuxiaolei/WikiArt/train_info.csv') 34 | 35 | # data 36 | parser.add_argument('--image_size', type=int, default=256) 37 | 38 | # ouptut 39 | parser.add_argument('--output_path', type=str, default='./output', help='outputs path') 40 | 41 | # train 42 | parser.add_argument('--epoch', type=int, default=100) 43 | parser.add_argument('--lr', type=float, default=0.0001) 44 | parser.add_argument('--lr_policy', type=str, default='constant', help='step/constant') 45 | parser.add_argument('--step_size', type=int, default=200000) 46 | parser.add_argument('--gamma', type=float, default=0.5, help='How much to decay learning rate') 47 | parser.add_argument('--update_D', type=int, default=5) 48 | parser.add_argument('--batch_size', type=int, default=1) 49 | parser.add_argument('--save_freq', type=int, default=50000) 50 | 51 | # loss weight 52 | parser.add_argument('--clw', type=float, default=1, help='content_weight') 53 | parser.add_argument('--slw', type=float, default=10, help='style_weight') 54 | parser.add_argument('--tvw', type=float, default=0.00001, help='tv_loss_weight') 55 | 56 | # bilateral grid 57 | parser.add_argument('--luma_bins', type=int, default=8) 58 | parser.add_argument('--channel_multiplier', type=int, default=1) 59 | parser.add_argument('--spatial_bin', type=int, default=8) 60 | parser.add_argument('--n_input_channel', type=int, default=256) 61 | parser.add_argument('--n_input_size', type=int, default=64) 62 | parser.add_argument('--group_num', type=int, default=16) 63 | parser.add_argument('--alpha', type=float, default=0.8) 64 | parser.add_argument('--selection', type=str, default='Ax+b') 65 | 66 | # seed 67 | parser.add_argument('--seed', 
type=int, default=123) 68 | opts = parser.parse_args() 69 | 70 | # fix the seed 71 | numpy.random.seed(opts.seed) 72 | torch.manual_seed(opts.seed) 73 | torch.cuda.manual_seed_all(opts.seed) 74 | if not os.path.exists(opts.tb_file): 75 | os.mkdir(opts.tb_file) 76 | 77 | tb_file = opts.tb_file 78 | sub_file = opts.sub_file 79 | # Setup logger and output folders 80 | if not os.path.exists(opts.output_path): 81 | os.mkdir(opts.output_path) 82 | output_directory = opts.output_path + sub_file 83 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 84 | print(checkpoint_directory) 85 | train_writer = tensorboardX.SummaryWriter(tb_file+sub_file) 86 | 87 | gpu_num = torch.cuda.device_count() 88 | 89 | # prepare dataset 90 | augmentor = img_augm.Augmentor(crop_size=[opts.image_size, opts.image_size]) 91 | #content_style_dataset = prepare_dataset.PlacesDataset(opts, augmentor) 92 | content_style_dataset = dataset.ArtDataset(opts, augmentor) 93 | total_images = len(content_style_dataset) 94 | print(f"There are total {total_images} image pairs!") 95 | dataloader = data.DataLoader(content_style_dataset, batch_size=opts.batch_size, num_workers=opts.batch_size, shuffle=True) 96 | 97 | # prepare model 98 | trainer = Grid(opts, gpu_num).cuda() 99 | torch.backends.cudnn.benchmark = True 100 | 101 | # start training 102 | print('-' * 8 + 'Start training' + '-' * 8) 103 | initial_step = trainer.resume(checkpoint_directory, opts) if opts.resume else 0 104 | total_step = total_images // opts.batch_size 105 | step = initial_step 106 | for iteration in range(opts.epoch): 107 | for i, data in enumerate(dataloader): 108 | t0 = datetime.datetime.now() 109 | step +=1 110 | input = data 111 | content_cuda = put_tensor_cuda(input[0]) 112 | style_cuda = put_tensor_cuda(input[1]) 113 | trainer.update_learning_rate() 114 | 115 | # training update 116 | trainer.update(content_cuda, style_cuda, opts) 117 | batch_output = trainer.get_output() 118 | batch_content_style = trainer.get_content_style() 119 | display = torch.cat([batch_content_style[:1], batch_output[:1]], 3) 120 | if step % 1000 == 0: 121 | write_loss(step, trainer, train_writer) 122 | if step % 1000 == 0: 123 | write_images('content_style_output', display, train_writer, step) 124 | if step % 1000 == 0: 125 | result = torchvision.utils.make_grid(denormalize_vgg_adain(display).cpu()) 126 | torchvision.utils.save_image(result, os.path.join(image_directory, 'test_%08d.jpg' % (total_step + 1))) 127 | if step % opts.save_freq == 0: 128 | trainer.save(checkpoint_directory, step) 129 | t1 = datetime.datetime.now() 130 | time = t1 - t0 131 | if step % 50 == 0: 132 | print("Epoch: %08d/%08d, iteration: %08d/%08d time: %.8s gloss = %.8s tvloss = %.8f" % ( 133 | iteration + 1, opts.epoch, step, total_step, 134 | time.seconds + 1e-6 * time.microseconds, trainer.gener_loss.item(), trainer.tv_loss.item(), 135 | )) 136 | # if step == 160000: 137 | # # trainer.save(checkpoint_directory, step) 138 | # print('This iteration takes :{}'.format((time.seconds + 1e-6 * time.microseconds))) 139 | # if step == opts.total_steps: 140 | # break 141 | # if step == opts.total_steps: 142 | # break 143 | trainer.save(checkpoint_directory, step) 144 | print("Training is finished.") 145 | print("Done.") 146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | # 
torch.set_printoptions(profile="full") 5 | import torchvision 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from torch.optim import lr_scheduler 10 | import cv2 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | 14 | 15 | 16 | def normalize_arr_of_imgs(arr): 17 | """ 18 | Normalizes an array so that the result lies in [-1; 1]. 19 | Args: 20 | arr: numpy array of arbitrary shape and dimensions. 21 | Returns: 22 | """ 23 | return arr / 127.5 - 1. 24 | 25 | 26 | def denormalize_arr_of_imgs(arr): 27 | """ 28 | Inverse of the normalize_arr_of_imgs function. 29 | Args: 30 | arr: numpy array of arbitrary shape and dimensions. 31 | Returns: 32 | """ 33 | return (arr + 1.) * 127.5 34 | 35 | def normalize_vgg_adain(arr): 36 | """ 37 | Normalizeds an arry so that the result lies in [0,1] 38 | """ 39 | return (arr/255.) 40 | 41 | def denormalize_vgg_adain(arr): 42 | return (arr*255.) 43 | 44 | def put_tensor_cuda(tensor): 45 | results = Variable(tensor).cuda() 46 | return (results) 47 | 48 | 49 | def prepare_sub_folder(output_directory): 50 | if not os.path.exists(output_directory): 51 | os.mkdir(output_directory) 52 | image_directory = os.path.join(output_directory, 'images') 53 | if not os.path.exists(image_directory): 54 | print("Creating directory: {}".format(image_directory)) 55 | os.makedirs(image_directory) 56 | checkpoint_directory = os.path.join(output_directory, 'checkpoints') 57 | if not os.path.exists(checkpoint_directory): 58 | print("Creating directory: {}".format(checkpoint_directory)) 59 | os.makedirs(checkpoint_directory) 60 | return checkpoint_directory, image_directory 61 | 62 | 63 | def get_scheduler(optimizer, options, iterations=-1): 64 | if options.lr_policy == 'step': 65 | scheduler = lr_scheduler.StepLR(optimizer, step_size=options.step_size, 66 | gamma=options.gamma, last_epoch=iterations) 67 | else: 68 | scheduler = None # constant scheduler 69 | return scheduler 70 | 71 | 72 | def write_images(str, images_tensor, trainer_writer, step): 73 | imges_tensor_de = denormalize_vgg(images_tensor) 74 | imges_tensor_de = imges_tensor_de.clamp(0., 255.) 75 | image_grid = torchvision.utils.make_grid(imges_tensor_de.cpu(), normalize=True) 76 | trainer_writer.add_image(str, image_grid, step + 1) 77 | 78 | 79 | def write_loss(iterations, trainer, train_writer): 80 | members = [attr for attr in dir(trainer) \ 81 | if not callable(getattr(trainer, attr)) and not attr.startswith("__") and ( 82 | 'loss' in attr or 'grad' in attr or 'nwd' in attr)] 83 | for m in members: 84 | train_writer.add_scalar(m, getattr(trainer, m), iterations + 1) 85 | 86 | 87 | # Get model list for resume 88 | def get_model_list(dirname, key): 89 | if os.path.exists(dirname) is False: 90 | return None 91 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 92 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] 93 | if len(gen_models) == 0: 94 | return None 95 | gen_models.sort() 96 | last_model_name = gen_models[-1] 97 | return last_model_name 98 | 99 | 100 | def image_check(image, style_path,augmentor=None): 101 | try: 102 | h, w = image.shape[0], image.shape[1] 103 | except: 104 | print(style_path) 105 | if max(image.shape) > 500.: 106 | scale = 500 / max(image.shape) 107 | image = cv2.resize(image, dsize=(int(scale * w), int(scale * h))) 108 | if max(image.shape) < 300: 109 | # Resize the smallest side of the image to 800px 110 | alpha = 300. 
/ float(min(image.shape)) 111 | if alpha < 4.: 112 | image = cv2.resize(image, dsize=(int(alpha * h), int(alpha * w))) 113 | else: 114 | image = cv2.resize(image, dsize=(300, 300)) 115 | 116 | if augmentor is not None: 117 | image = augmentor(image).astype(np.float32) 118 | return image 119 | 120 | 121 | def gram_matrix(tensor): 122 | # Unwrapping the tensor dimensions into respective variables i.e. batch size, distance, height and width 123 | _, d, h, w = tensor.size() 124 | # Reshaping data into a two dimensional of array or two dimensional of tensor 125 | tensor = tensor.view(d, h * w) 126 | # Multiplying the original tensor with its own transpose using torch.mm 127 | # tensor.t() will return the transpose of original tensor 128 | gram = torch.mm(tensor, tensor.t()) 129 | # Returning gram matrix 130 | return gram.div(tensor.nelement()) 131 | 132 | 133 | # normalize image to satisfy the vgg 134 | def normalize_vgg(im): 135 | im /= 255. 136 | im[0, :, :] -= 0.485 137 | im[1, :, :] -= 0.456 138 | im[2, :, :] -= 0.406 139 | im[0, :, :] /= 0.229 140 | im[1, :, :] /= 0.224 141 | im[2, :, :] /= 0.225 142 | return im 143 | 144 | 145 | def denormalize_vgg(im): 146 | im[:, 0, :, :] *= 0.229 147 | im[:, 1, :, :] *= 0.224 148 | im[:, 2, :, :] *= 0.225 149 | im[:, 0, :, :] += 0.485 150 | im[:, 1, :, :] += 0.456 151 | im[:, 2, :, :] += 0.406 152 | im *= 255. 153 | return im 154 | 155 | 156 | def calc_mean_std(feat, eps=1e-5): 157 | # eps is a small value added to the variance to avoid divide-by-zero. 158 | size = feat.size() 159 | assert (len(size) == 4) 160 | N, C = size[:2] 161 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 162 | feat_std_vector = feat_var.sqrt() 163 | feat_mean_vector = feat.view(N, C, -1).mean(dim=2) 164 | return feat_mean_vector, feat_std_vector 165 | 166 | 167 | class LocalGroupNorm(nn.Module): 168 | def __init__(self, input_channel, G=32, window_size=16): 169 | super(LocalGroupNorm, self).__init__() 170 | self.G = G 171 | self.window_size = window_size 172 | 173 | temp = torch.zeros([input_channel // G, input_channel // G, 1, 1]) 174 | temp_view = temp.view(temp.shape[0], temp.shape[1]) 175 | nn.init.eye_(temp_view) 176 | temp = temp_view.view_as(temp) 177 | self.c1_weight = torch.ones([input_channel // G, 1, window_size, window_size]) / ( 178 | window_size * window_size) 179 | self.c2_weight = temp 180 | 181 | def forward(self, input): 182 | epsilon = 1e-5 183 | G = self.G 184 | N, C, H, W = input.shape[0], input.shape[1], input.shape[2], input.shape[3] 185 | depth = input.shape[1] 186 | self.c1_weight = self.c1_weight.cuda(device=input.device) 187 | self.c2_weight = self.c2_weight.cuda(device=input.device) 188 | 189 | input_reshaped = input.view(N, C // G, G, H, W) 190 | 191 | means = torch.mean(input_reshaped, dim=2) # N, C//G, H W 1, 8, 64, 64 192 | means = F.conv2d(F.conv2d(means, weight=self.c1_weight, stride=self.window_size, groups=C // G), 193 | weight=self.c2_weight, stride=1) 194 | 195 | means = F.interpolate(means, (input.shape[2], input.shape[3]), mode='bilinear') 196 | means = means.unsqueeze(2).repeat(1, 1, G, 1, 1) 197 | 198 | stds = (input_reshaped - means).pow(2) 199 | stds = torch.sqrt(torch.mean(stds, dim=2)) 200 | stds = F.conv2d(F.conv2d(stds, weight=self.c1_weight, stride=self.window_size, groups=C // G), 201 | weight=self.c2_weight, stride=1) 202 | stds = F.interpolate(stds, (input.shape[2], input.shape[3]), mode='bilinear') 203 | stds = stds.unsqueeze(2).repeat(1, 1, G, 1, 1) 204 | 205 | input = (input_reshaped - means) / torch.sqrt(torch.abs(stds) 
+ epsilon) 206 | input = input.view(N, C, H, W) 207 | means = means.view(N, C, H, W) 208 | stds = stds.view(N, C, H, W) 209 | return input, means, stds 210 | 211 | 212 | def adaptive_instance_normalization(content_feat, style_feat): 213 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 214 | size = content_feat.size() 215 | style_mean, style_std = calc_mean_std(style_feat) 216 | content_mean, content_std = calc_mean_std(content_feat) 217 | 218 | style_mean = style_mean.view(size[0], size[1], 1, 1) 219 | style_std = style_std.view(size[0], size[1], 1, 1) 220 | content_mean = content_mean.view(size[0], size[1], 1, 1) 221 | content_std = content_std.view(size[0], size[1], 1, 1) 222 | 223 | normalized_feat = (content_feat - content_mean.expand( 224 | size)) / content_std.expand(size) 225 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 226 | 227 | 228 | def __write_images(image_outputs, display_image_num, file_name): 229 | # image_outputs = [images.expand(-1, 3, -1, -1) for images in 230 | # image_outputs] # expand gray-scale images to 3 channels 231 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) 232 | image_grid = torchvision.utils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True) 233 | torchvision.utils.save_image(image_grid, file_name, nrow=1) 234 | 235 | 236 | def write_2images(image_outputs, display_image_num, image_directory, postfix): 237 | n = len(image_outputs) 238 | __write_images(image_outputs, display_image_num, '%s/%s.png' % (image_directory, postfix)) 239 | 240 | def get_hw_2(content_feat, style_feat, g_num): 241 | device = content_feat.get_device() 242 | G = g_num 243 | N, C, H, W = content_feat.shape 244 | content_vec = content_feat.view(N*G, C//G, -1).permute(0, 2, 1) 245 | style_vec = style_feat.view(N*G, C//G, -1).permute(0, 2, 1) 246 | c_numpy = content_vec.detach().cpu().numpy() 247 | s_numpy = style_vec.detach().cpu().numpy() 248 | spatial_maps = [] 249 | for i in range(N*G): 250 | distances = cdist(c_numpy[i], s_numpy[i], metric='cosine') # 64*64 64*64 251 | closest = np.argmax(distances, axis=1) 252 | closest = Variable(torch.from_numpy(closest)).to(device) 253 | index_i = (closest // 64).view(1, H, W, 1).float()/(H - 1) * 2 - 1 254 | index_j = (closest % 64).view(1, H, W, 1).float()/(W - 1) * 2 - 1 255 | spatial_map = torch.cat((index_j, index_i), dim=3) # H, W,2 256 | spatial_maps.append(spatial_map) 257 | hw_map = torch.cat(spatial_maps, dim=0) # N, H, W 2 258 | return hw_map 259 | 260 | def get_hw(content_feat, g_num): 261 | content_feat = content_feat.detach() 262 | device = content_feat.get_device() 263 | 264 | N, _, H, W = content_feat.shape 265 | hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW 266 | hg = hg.to(device) 267 | wg = wg.to(device) 268 | hg = hg.float().repeat(N*g_num, 1, 1).unsqueeze(3) / (H-1) # norm to [0,1] NxHxW 269 | wg = wg.float().repeat(N*g_num, 1, 1).unsqueeze(3) / (W-1) # norm to [0,1] NxHxW 270 | hg, wg = hg*2-1, wg*2-1 271 | hw_map = torch.cat([wg,hg],dim=3) # NG, H, W,2 272 | return hw_map 273 | 274 | 275 | def get_hw3(content_feat, style_feat, g_num): 276 | content_feat = content_feat.detach() 277 | style_feat = style_feat.detach() 278 | device = content_feat.get_device() 279 | G = g_num 280 | N, C, H, W = content_feat.shape 281 | NG = N*G 282 | content_vec = content_feat.view(NG, C//G, -1) # NG, C, HW 283 | style_vec = style_feat.view(NG, C//G, -1) 284 | multi = content_vec*content_vec 285 | 
content_norm = content_vec / (torch.sqrt(torch.sum(content_vec*content_vec, dim=1, keepdim=True))) # unit vector NG, C, HW 286 | style_norm = style_vec / torch.sqrt(torch.sum(style_vec*style_vec, dim=1, keepdim=True)) # NG, C, HW 287 | closests = [] 288 | for i in range(NG): 289 | #cosine_dist = 1. - torch.bmm(content_norm[i:i+1, :, :].permute(0,2,1), style_norm[i:i+1, :, :]) # NG, hw, hw 290 | cosine_dist = torch.bmm(content_norm[i:i+1, :, :].permute(0,2,1), style_norm[i:i+1, :, :]) # NG, hw, hw 291 | closest = torch.argmax(cosine_dist, dim=2) # NG, HW 292 | closests.append(closest) 293 | closest = torch.cat(closests, dim=0) 294 | #cosine_dists = 1. - torch.bmm(content_norm.permute(0,2,1), style_norm) # NG, hw, hw 295 | 296 | # calculate i,j based on index 297 | index_i = (closest // 64).view(NG, H, W, 1).float() / (H - 1) * 2 - 1 298 | index_j = (closest % 64).view(NG, H, W, 1).float() / (W - 1) * 2 - 1 299 | 300 | # ###### 301 | # index_i = (closest // 64).view(NG, H, W) 302 | # index_j = (closest % 64).view(NG, H, W) 303 | # print(index_i[1]) 304 | # print(index_i[0].shape) 305 | # print(index_j[1]) 306 | # exit(0) 307 | hw_map = torch.cat((index_j, index_i), dim=3) # H, W, 2 308 | # hw_map_crop = hw_map[:,10:-10, 10:-10, :].permute(0, 3, 1, 2).contiguous() 309 | # hw_map = F.pad(hw_map_crop, (10,10,10,10), mode='reflect').permute(0,2,3,1).contiguous() 310 | return hw_map 311 | 312 | def TVloss(img, tv_weight): 313 | """ 314 | Compute total variation loss. 315 | Inputs: 316 | - img: PyTorch Variable of shape (1, 3, H, W) holding an input image. 317 | - tv_weight: Scalar giving the weight w_t to use for the TV loss. 318 | Returns: 319 | - loss: PyTorch Variable holding a scalar giving the total variation loss 320 | for img weighted by tv_weight. 321 | """ 322 | w_variance = torch.mean(torch.pow(img[:, :, :, :-1] - img[:, :, :, 1:], 2)) 323 | h_variance = torch.mean(torch.pow(img[:, :, :-1, :] - img[:, :, 1:, :], 2)) 324 | loss = tv_weight * (h_variance + w_variance) 325 | return loss 326 | 327 | def content_loss(x, y): 328 | N,C, _, _ = x.shape 329 | x_vec = x.view(N,C,-1) 330 | y_vec = y.view(N,C,-1) 331 | D_X = pairwise_distances_cos(x_vec, x_vec) 332 | D_X = D_X/D_X.sum(1, keepdim=True) 333 | D_Y = pairwise_distances_cos(y_vec, y_vec) 334 | D_Y = D_Y/D_Y.sum(1, keepdim=True) 335 | 336 | d = torch.abs(D_X-D_Y).mean() 337 | return d 338 | 339 | def pairwise_distances_cos(x, y): 340 | # x, y: (N, C, HW) feature vectors 341 | x_norm = x/torch.sqrt((x**2).sum(1, keepdim=True)) # unit vectors along C 342 | x_t = x_norm.permute(0, 2, 1) 343 | # y_t = y.transpose(1,2) 344 | y_norm = y/torch.sqrt((y**2).sum(1, keepdim=True)) 345 | 346 | mul = torch.bmm(x_t, y_norm) 347 | 348 | dist = 1. - mul # (N, HW, HW) cosine distances 349 | 350 | return dist 351 | --------------------------------------------------------------------------------
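As a quick sanity check of the helpers in utils.py, the snippet below is a minimal sketch (not part of the repository; it assumes utils.py is on the Python path and the listed dependencies are installed) that exercises normalize_vgg, denormalize_vgg, gram_matrix and TVloss on random tensors with the shapes their comments assume:

```
import torch

from utils import TVloss, denormalize_vgg, gram_matrix, normalize_vgg

chw = torch.rand(3, 256, 256) * 255.        # fake CHW image with values in [0, 255]
vgg_in = normalize_vgg(chw.clone())         # per-channel VGG normalization (modifies its input in place)
batch = vgg_in.unsqueeze(0)                 # NCHW batch for the batched helpers

tv = TVloss(batch, tv_weight=1e-5)          # weighted total-variation regularizer
gram = gram_matrix(batch)                   # 3x3 Gram matrix over the three channels
restored = denormalize_vgg(batch.clone())   # map the normalized batch back to [0, 255]

print(tv.item(), gram.shape, float(restored.min()), float(restored.max()))
```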