├── README.md ├── VGGdecoder.py ├── content ├── avril.jpg ├── blonde_girl.jpg ├── brad_pitt.jpg ├── cornell.jpg ├── golden_gate.jpg ├── in1.jpg ├── lenna.jpg ├── modern.jpg ├── neko.jpg ├── sailboat.jpg └── woman_side_portrait.jpg ├── dataset.py ├── feature_transfer.py ├── model.py ├── normalisedVGG.py ├── result ├── avril_876_demo.jpg ├── avril_antimonocromatismo_demo.jpg ├── avril_asheville_demo.jpg ├── avril_brushstrokers_demo.jpg ├── avril_candy_demo.jpg ├── avril_contrast_of_forms_demo.jpg ├── avril_picasso_self_portrait_demo.jpg └── avril_scene_de_rue_demo.jpg ├── style ├── 088.jpg ├── 101308.jpg ├── 27.jpg ├── 876.jpg ├── antimonocromatismo.jpg ├── asheville.jpg ├── brick1.jpg ├── brushstrokers.jpg ├── candy.jpg ├── en_campo_gris.jpg ├── in2.jpg ├── la_muse.jpg ├── mondrian.jpg ├── mosaic.jpg ├── news1.jpg ├── picasso_seated_nude_hr.jpg ├── picasso_self_portrait.jpg ├── plum_flower.jpg ├── rain-princess.jpg ├── scene_de_rue.jpg ├── seated-nude.jpg ├── sketch.png ├── trial.jpg ├── woman_in_peasant_dress_cropped.jpg └── woman_with_hat_matisse.jpg ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch_MST 2 | Unofficial PyTorch (1.0+) implementation of the ICCV 2019 paper ["Multimodal Style Transfer via Graph Cuts"](https://arxiv.org/abs/1904.04443). 3 | 4 | The author's original TensorFlow implementation can be found [here](https://github.com/yulunzhang/MST). 5 | 6 | This repository provides a pre-trained model you can use to generate your own stylized image from a content image and a style image. You can also download the training dataset, or prepare your own, and train the model from scratch. 7 | 8 | If you have any questions, please feel free to contact me. (English, Japanese, or Chinese are all fine!) 9 | 10 | ## Notice 11 | I also propose Structure-emphasized Multimodal Style Transfer (SEMST); feel free to use it [here](https://github.com/irasin/Structure-emphasized-Multimodal-Style-Transfer). 12 | 13 | ## Requirements 14 | 15 | - Python 3.7+ 16 | - PyTorch 1.0+ 17 | - TorchVision 18 | - Pillow 19 | - PyMaxflow 20 | 21 | An Anaconda environment is recommended! 22 | 23 | (optional) 24 | 25 | - GPU environment 26 | 27 | 28 | 29 | ## test 30 | 31 | 1. Clone this repository 32 | 33 | ```bash 34 | git clone https://github.com/irasin/Pytorch_MST 35 | cd Pytorch_MST 36 | ``` 37 | 38 | 2. Prepare your content and style images. Some samples are provided in the `content` and `style` directories, so you can try those first. 39 | 40 | 3. Download the pretrained model [here](https://drive.google.com/file/d/16mhOUIo8HKDv9NhlI1GyKvpqST8P9fGw/view?usp=sharing). 41 | 42 | 4. Generate the output image. A stylized image with and without the style image attached, plus an NST-demo-like comparison image, will be generated. 43 | 44 | ```bash 45 | python test.py -c content_image_path -s style_image_path 46 | ``` 47 | 48 | ``` 49 | usage: test.py [-h] [--content CONTENT] [--style STYLE] 50 | [--output_name OUTPUT_NAME] [--n_cluster N_CLUSTER] 51 | [--alpha ALPHA] [--lam LAM] [--max_cycles MAX_CYCLES] 52 | [--gpu GPU] [--model_state_path MODEL_STATE_PATH] 53 | ``` 54 | 55 | If `output_name` is not given, the combination of the content and style image names will be used. 56 | 57 | 58 | ------ 59 | 60 | ## train 61 | 62 | 1. Download [COCO](http://cocodataset.org/#download) (as the content dataset) and [WikiArt](https://www.kaggle.com/c/painter-by-numbers) (as the style dataset), unzip them, and rename the directories `content` and `style` respectively (recommended). 63 | 64 | 2. 
Modify the arguments in `train.py`, such as the directory paths, number of epochs, and learning rate, or add your own training code. 65 | 66 | 3. Train the model (using a GPU is strongly recommended). 67 | 68 | 4. ```bash 69 | python train.py 70 | ``` 71 | 72 | ``` 73 | usage: train.py [-h] [--batch_size BATCH_SIZE] [--epoch EPOCH] [--gpu GPU] 74 | [--learning_rate LEARNING_RATE] 75 | [--snapshot_interval SNAPSHOT_INTERVAL] 76 | [--n_cluster N_CLUSTER] [--alpha ALPHA] [--lam LAM] 77 | [--max_cycles MAX_CYCLES] [--gamma GAMMA] 78 | [--train_content_dir TRAIN_CONTENT_DIR] 79 | [--train_style_dir TRAIN_STYLE_DIR] 80 | [--test_content_dir TEST_CONTENT_DIR] 81 | [--test_style_dir TEST_STYLE_DIR] [--save_dir SAVE_DIR] 82 | [--reuse REUSE] 83 | ``` 84 | 85 | 86 | 87 | # Result 88 | 89 | Some results for the `avril` content image with different styles are shown below. 90 | 91 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_contrast_of_forms_demo.jpg) 92 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_scene_de_rue_demo.jpg) 93 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_picasso_self_portrait_demo.jpg) 94 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_candy_demo.jpg) 95 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_brushstrokers_demo.jpg) 96 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_asheville_demo.jpg) 97 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_antimonocromatismo_demo.jpg) 98 | ![image](https://github.com/irasin/Pytorch_MST/blob/master/result/avril_876_demo.jpg) 99 | 100 | -------------------------------------------------------------------------------- /VGGdecoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Interpolate(nn.Module): 8 | def __init__(self, scale_factor=2): 9 | super().__init__() 10 | self.scale_factor = scale_factor 11 | 12 | def forward(self, x): 13 | x = F.interpolate(x, scale_factor=self.scale_factor) 14 | return x 15 | 16 | 17 | vgg_decoder_relu5_1 = nn.Sequential( 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(512, 512, 3), 20 | nn.ReLU(), 21 | Interpolate(2), 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(512, 512, 3), 24 | nn.ReLU(), 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(512, 512, 3), 27 | nn.ReLU(), 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(512, 512, 3), 30 | nn.ReLU(), 31 | nn.ReflectionPad2d((1, 1, 1, 1)), 32 | nn.Conv2d(512, 256, 3), 33 | nn.ReLU(), 34 | Interpolate(2), 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(256, 256, 3), 37 | nn.ReLU(), 38 | nn.ReflectionPad2d((1, 1, 1, 1)), 39 | nn.Conv2d(256, 256, 3), 40 | nn.ReLU(), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(256, 256, 3), 43 | nn.ReLU(), 44 | nn.ReflectionPad2d((1, 1, 1, 1)), 45 | nn.Conv2d(256, 128, 3), 46 | nn.ReLU(), 47 | Interpolate(2), 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(128, 128, 3), 50 | nn.ReLU(), 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(128, 64, 3), 53 | nn.ReLU(), 54 | Interpolate(2), 55 | nn.ReflectionPad2d((1, 1, 1, 1)), 56 | nn.Conv2d(64, 64, 3), 57 | nn.ReLU(), 58 | nn.ReflectionPad2d((1, 1, 1, 1)), 59 | nn.Conv2d(64, 3, 3) 60 | ) 61 | 62 | 63 | class Decoder(nn.Module): 64 | def __init__(self, level, pretrained_path=None): 65 | super().__init__() 66 | if level == 1: 67 | self.net = 
nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-2:])) 68 | elif level == 2: 69 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-9:])) 70 | elif level == 3: 71 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-16:])) 72 | elif level == 4: 73 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children())[-29:])) 74 | elif level == 5: 75 | self.net = nn.Sequential(*copy.deepcopy(list(vgg_decoder_relu5_1.children()))) 76 | else: 77 | raise ValueError('level should be between 1~5') 78 | 79 | if pretrained_path is not None: 80 | self.net.load_state_dict(torch.load(pretrained_path, map_location=lambda storage, loc: storage)) 81 | 82 | def forward(self, x): 83 | return self.net(x) 84 | -------------------------------------------------------------------------------- /content/avril.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/avril.jpg -------------------------------------------------------------------------------- /content/blonde_girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/blonde_girl.jpg -------------------------------------------------------------------------------- /content/brad_pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/brad_pitt.jpg -------------------------------------------------------------------------------- /content/cornell.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/cornell.jpg -------------------------------------------------------------------------------- /content/golden_gate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/golden_gate.jpg -------------------------------------------------------------------------------- /content/in1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/in1.jpg -------------------------------------------------------------------------------- /content/lenna.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/lenna.jpg -------------------------------------------------------------------------------- /content/modern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/modern.jpg -------------------------------------------------------------------------------- /content/neko.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/neko.jpg -------------------------------------------------------------------------------- 
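A quick note on the level slicing in the `Decoder` class of `VGGdecoder.py` above: each level is meant to invert the corresponding VGG `reluX_1` feature map back to an RGB image, so the number of `Interpolate(2)` stages in the slice matches the number of pooling steps in the encoder. The snippet below is a minimal, hypothetical shape check (it is not part of the repository); it assumes it is run from the repository root with an untrained `Decoder(4)`.

```python
# Hypothetical shape check (not part of the repository), run from the repo root.
# Decoder(4) should invert 512-channel relu4_1 features (1/8 resolution) back to
# a 3-channel image: its slice contains three Interpolate(2) stages, so a 32x32
# feature map should come out as a 256x256 image.
import torch
from VGGdecoder import Decoder

decoder4 = Decoder(level=4)            # randomly initialised, no pretrained weights
feat = torch.randn(1, 512, 32, 32)     # stand-in for a relu4_1 feature map
with torch.no_grad():
    out = decoder4(feat)
print(out.shape)                       # expected: torch.Size([1, 3, 256, 256])
```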
/content/sailboat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/sailboat.jpg -------------------------------------------------------------------------------- /content/woman_side_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/content/woman_side_portrait.jpg -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from skimage import io, transform 9 | from PIL import Image 10 | 11 | 12 | trans = transforms.Compose([transforms.RandomCrop(256), 13 | transforms.ToTensor()]) 14 | 15 | 16 | class PreprocessDataset(Dataset): 17 | def __init__(self, content_dir, style_dir, transforms=trans): 18 | content_dir_resized = content_dir + '_resized' 19 | style_dir_resized = style_dir + '_resized' 20 | if not (os.path.exists(content_dir_resized) and 21 | os.path.exists(style_dir_resized)): 22 | os.mkdir(content_dir_resized) 23 | os.mkdir(style_dir_resized) 24 | self._resize(content_dir, content_dir_resized) 25 | self._resize(style_dir, style_dir_resized) 26 | content_images = glob.glob((content_dir_resized + '/*')) 27 | np.random.shuffle(content_images) 28 | style_images = glob.glob(style_dir_resized + '/*') 29 | np.random.shuffle(style_images) 30 | self.images_pairs = list(zip(content_images, style_images)) 31 | self.transforms = transforms 32 | 33 | @staticmethod 34 | def _resize(source_dir, target_dir): 35 | print(f'Start resizing {source_dir} ') 36 | for i in tqdm(os.listdir(source_dir)): 37 | filename = os.path.basename(i) 38 | try: 39 | image = io.imread(os.path.join(source_dir, i)) 40 | if len(image.shape) == 3 and image.shape[-1] == 3: 41 | H, W, _ = image.shape 42 | if H < W: 43 | ratio = W / H 44 | H = 512 45 | W = int(ratio * H) 46 | else: 47 | ratio = H / W 48 | W = 512 49 | H = int(ratio * W) 50 | image = transform.resize(image, (H, W), mode='reflect', anti_aliasing=True) 51 | io.imsave(os.path.join(target_dir, filename), image) 52 | except: 53 | continue 54 | 55 | def __len__(self): 56 | return len(self.images_pairs) 57 | 58 | def __getitem__(self, index): 59 | content_image, style_image = self.images_pairs[index] 60 | content_image = Image.open(content_image) 61 | style_image = Image.open(style_image) 62 | 63 | if self.transforms: 64 | content_image = self.transforms(content_image) 65 | style_image = self.transforms(style_image) 66 | return content_image, style_image 67 | -------------------------------------------------------------------------------- /feature_transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from maxflow.fastmin import aexpansion_grid 4 | from sklearn.cluster import KMeans 5 | 6 | 7 | def data_term(content_feature, cluster_centers): 8 | c = content_feature.permute(1, 2, 0) 9 | d = torch.matmul(c, cluster_centers) 10 | c_norm = torch.norm(c, dim=2, keepdim=True) 11 | s_norm = torch.norm(cluster_centers, dim=0, keepdim=True) 12 | norm = torch.matmul(c_norm, s_norm) 13 | d = 1 - d.div(norm) 14 | return d 15 | 16 | 17 | 
def pairwise_term(cluster_centers, lam): 18 | _, k = cluster_centers.shape 19 | v = torch.ones((k, k)) - torch.eye(k) 20 | v = lam * v.to(cluster_centers.device) 21 | return v 22 | 23 | 24 | def labeled_whiten_and_color(f_c, f_s, alpha, label): 25 | try: 26 | c, h, w = f_c.shape 27 | cf = (f_c * label).reshape(c, -1) 28 | c_mean = torch.mean(cf, 1).reshape(c, 1, 1) * label 29 | 30 | cf = cf.reshape(c, h, w) - c_mean 31 | cf = cf.reshape(c, -1) 32 | c_cov = torch.mm(cf, cf.t()).div(torch.sum(label).item() / c - 1) 33 | c_u, c_e, c_v = torch.svd(c_cov) 34 | 35 | # if necessary, use k-th largest eig-value 36 | k_c = c 37 | # for i in range(c): 38 | # if c_e[i] < 0.00001: 39 | # k_c = i 40 | # break 41 | c_d = c_e[:k_c].pow(-0.5) 42 | 43 | w_step1 = torch.mm(c_v[:, :k_c], torch.diag(c_d)) 44 | w_step2 = torch.mm(w_step1, (c_v[:, :k_c].t())) 45 | whitened = torch.mm(w_step2, cf) 46 | 47 | sf = f_s.t() 48 | c, k = sf.shape 49 | s_mean = torch.mean(sf, 1, keepdim=True) 50 | sf = sf - s_mean 51 | s_cov = torch.mm(sf, sf.t()).div(k - 1) 52 | s_u, s_e, s_v = torch.svd(s_cov) 53 | 54 | # if necessary, use k-th largest eig-value 55 | k_s = c 56 | # for i in range(c): 57 | # if s_e[i] < 0.00001: 58 | # k_s = i 59 | # break 60 | s_d = s_e[:k_s].pow(0.5) 61 | 62 | c_step1 = torch.mm(s_v[:, :k_s], torch.diag(s_d)) 63 | c_step2 = torch.mm(c_step1, s_v[:, :k_s].t()) 64 | colored = torch.mm(c_step2, whitened).reshape(c, h, w) 65 | s_mean = s_mean.reshape(c, 1, 1) * label 66 | colored = colored + s_mean 67 | colored_feature = alpha * colored + (1 - alpha) * (f_c * label) 68 | except: 69 | # Need fix 70 | # RuntimeError: MAGMA gesdd : the updating process of SBDSDC did not converge 71 | colored_feature = f_c * label 72 | 73 | return colored_feature 74 | 75 | 76 | class MultimodalStyleTransfer: 77 | def __init__(self, n_cluster, alpha, device='cpu', lam=0.1, max_cycles=None): 78 | self.k = n_cluster 79 | self.k_means_estimator = KMeans(n_cluster) 80 | if (type(alpha) is int or type(alpha) is float) and 0 <= alpha <= 1: 81 | self.alpha = [alpha] * n_cluster 82 | elif type(alpha) is list and len(alpha) == n_cluster: 83 | self.alpha = alpha 84 | else: 85 | raise ValueError('Error for alpha') 86 | 87 | self.device = device 88 | self.lam = lam 89 | self.max_cycles = max_cycles 90 | 91 | def style_feature_clustering(self, style_feature): 92 | C, _, _ = style_feature.shape 93 | s = style_feature.reshape(C, -1).transpose(0, 1) 94 | 95 | self.k_means_estimator.fit(s.to('cpu')) 96 | labels = torch.Tensor(self.k_means_estimator.labels_).to(self.device) 97 | cluster_centers = torch.Tensor(self.k_means_estimator.cluster_centers_).to(self.device).transpose(0, 1) 98 | 99 | s = s.to(self.device) 100 | clusters = [s[labels == i] for i in range(self.k)] 101 | 102 | return cluster_centers, clusters 103 | 104 | def graph_based_style_matching(self, content_feature, style_feature): 105 | cluster_centers, s_clusters = self.style_feature_clustering(style_feature) 106 | 107 | D = data_term(content_feature, cluster_centers).to('cpu').numpy().astype(np.double) 108 | V = pairwise_term(cluster_centers, lam=self.lam).to('cpu').numpy().astype(np.double) 109 | labels = torch.Tensor(aexpansion_grid(D, V, max_cycles=self.max_cycles)).to(self.device) 110 | return labels, s_clusters 111 | 112 | def transfer(self, content_feature, style_feature): 113 | labels, s_clusters = self.graph_based_style_matching(content_feature, style_feature) 114 | f_cs = torch.zeros_like(content_feature) 115 | for f_s, a, k in zip(s_clusters, self.alpha, range(self.k)): 
116 | label = (labels == k).unsqueeze(dim=0).expand_as(content_feature) 117 | if (label > 0).any(): 118 | label = label.to(torch.float) 119 | f_cs += labeled_whiten_and_color(content_feature, f_s, a, label) 120 | 121 | return f_cs 122 | 123 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from feature_transfer import MultimodalStyleTransfer 6 | from normalisedVGG import NormalisedVGG 7 | from VGGdecoder import Decoder 8 | from utils import download_file_from_google_drive 9 | 10 | 11 | def calc_mean_std(features): 12 | batch_size, c = features.size()[:2] 13 | features_mean = features.reshape(batch_size, c, -1).mean(dim=2).reshape(batch_size, c, 1, 1) 14 | features_std = features.reshape(batch_size, c, -1).std(dim=2).reshape(batch_size, c, 1, 1) 15 | return features_mean, features_std 16 | 17 | 18 | class VGGEncoder(nn.Module): 19 | def __init__(self, pretrained_path=None): 20 | super().__init__() 21 | vgg = NormalisedVGG(pretrained_path=pretrained_path).net 22 | self.block1 = vgg[: 4] 23 | self.block2 = vgg[4: 11] 24 | self.block3 = vgg[11: 18] 25 | self.block4 = vgg[18: 31] 26 | 27 | for p in self.parameters(): 28 | p.requires_grad = False 29 | 30 | def forward(self, images, output_last_feature=True): 31 | h1 = self.block1(images) 32 | h2 = self.block2(h1) 33 | h3 = self.block3(h2) 34 | h4 = self.block4(h3) 35 | if output_last_feature: 36 | return h4 37 | else: 38 | return h1, h2, h3, h4 39 | 40 | 41 | class Model(nn.Module): 42 | def __init__(self, 43 | n_cluster=3, 44 | alpha=1, 45 | device='cpu', 46 | lam=0.1, 47 | pre_train=False, 48 | max_cycles=None): 49 | super().__init__() 50 | self.n_cluster = n_cluster 51 | self.alpha = alpha 52 | self.device = device 53 | self.lam = lam 54 | self.max_cycles = max_cycles 55 | if pre_train: 56 | if not os.path.exists('vgg_normalised_conv5_1.pth'): 57 | download_file_from_google_drive('1IAOFF5rDkVei035228Qp35hcTnliyMol', 58 | 'vgg_normalised_conv5_1.pth') 59 | if not os.path.exists('decoder_relu4_1.pth'): 60 | download_file_from_google_drive('1kkoyNwRup9y5GT1mPbsZ_7WPQO9qB7ZZ', 61 | 'decoder_relu4_1.pth') 62 | self.vgg_encoder = VGGEncoder('vgg_normalised_conv5_1.pth') 63 | self.decoder = Decoder(4, 'decoder_relu4_1.pth') 64 | else: 65 | self.vgg_encoder = VGGEncoder() 66 | self.decoder = Decoder(4) 67 | 68 | self.multimodal_style_feature_transfer = MultimodalStyleTransfer(n_cluster, 69 | alpha, 70 | device, 71 | lam, 72 | max_cycles) 73 | 74 | def generate(self, 75 | content_images, 76 | style_images, 77 | n_cluster=None, 78 | alpha=None, 79 | device=None, 80 | lam=None, 81 | max_cycles=None): 82 | 83 | n_cluster = self.n_cluster if n_cluster is None else n_cluster 84 | alpha = self.alpha if alpha is None else alpha 85 | device = self.device if device is None else device 86 | lam = self.lam if lam is None else lam 87 | max_cycles = self.max_cycles if max_cycles is None else max_cycles 88 | 89 | multimodal_style_feature_transfer = MultimodalStyleTransfer(n_cluster, 90 | alpha, 91 | device, 92 | lam, 93 | max_cycles) 94 | 95 | content_features = self.vgg_encoder(content_images, output_last_feature=True) 96 | style_features = self.vgg_encoder(style_images, output_last_feature=True) 97 | cs = [] 98 | 99 | for c, s in zip(content_features, style_features): 100 | cs.append(multimodal_style_feature_transfer.transfer(c, 
s).unsqueeze(dim=0)) 101 | cs = torch.cat(cs, dim=0) 102 | 103 | out = self.decoder(cs) 104 | return out 105 | 106 | @staticmethod 107 | def calc_content_loss(out_features, content_features): 108 | return F.mse_loss(out_features, content_features) 109 | 110 | @staticmethod 111 | def calc_style_loss(out_middle_features, style_middle_features): 112 | loss = 0 113 | for c, s in zip(out_middle_features, style_middle_features): 114 | c_mean, c_std = calc_mean_std(c) 115 | s_mean, s_std = calc_mean_std(s) 116 | loss += F.mse_loss(c_mean, s_mean) + F.mse_loss(c_std, s_std) 117 | return loss 118 | 119 | def forward(self, content_images, style_images, gamma=1): 120 | content_features = self.vgg_encoder(content_images, output_last_feature=True) 121 | style_features = self.vgg_encoder(style_images, output_last_feature=True) 122 | 123 | cs = [] 124 | for c, s in zip(content_features, style_features): 125 | cs.append(self.multimodal_style_feature_transfer.transfer(c, s).unsqueeze(dim=0)) 126 | cs = torch.cat(cs, dim=0) 127 | 128 | out = self.decoder(cs) 129 | 130 | output_features = self.vgg_encoder(out, output_last_feature=True) 131 | output_middle_features = self.vgg_encoder(out, output_last_feature=False) 132 | style_middle_features = self.vgg_encoder(style_images, output_last_feature=False) 133 | 134 | loss_c = self.calc_content_loss(output_features, content_features) 135 | loss_s = self.calc_style_loss(output_middle_features, style_middle_features) 136 | loss = loss_c + gamma * loss_s 137 | # print('loss: ', loss_c.item(), gamma*loss_s.item()) 138 | return loss 139 | -------------------------------------------------------------------------------- /normalisedVGG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | normalised_vgg_relu5_1 = nn.Sequential( 5 | nn.Conv2d(3, 3, 1), 6 | nn.ReflectionPad2d((1, 1, 1, 1)), 7 | nn.Conv2d(3, 64, 3), 8 | nn.ReLU(), 9 | nn.ReflectionPad2d((1, 1, 1, 1)), 10 | nn.Conv2d(64, 64, 3), 11 | nn.ReLU(), 12 | nn.MaxPool2d(2, ceil_mode=True), 13 | nn.ReflectionPad2d((1, 1, 1, 1)), 14 | nn.Conv2d(64, 128, 3), 15 | nn.ReLU(), 16 | nn.ReflectionPad2d((1, 1, 1, 1)), 17 | nn.Conv2d(128, 128, 3), 18 | nn.ReLU(), 19 | nn.MaxPool2d(2, ceil_mode=True), 20 | nn.ReflectionPad2d((1, 1, 1, 1)), 21 | nn.Conv2d(128, 256, 3), 22 | nn.ReLU(), 23 | nn.ReflectionPad2d((1, 1, 1, 1)), 24 | nn.Conv2d(256, 256, 3), 25 | nn.ReLU(), 26 | nn.ReflectionPad2d((1, 1, 1, 1)), 27 | nn.Conv2d(256, 256, 3), 28 | nn.ReLU(), 29 | nn.ReflectionPad2d((1, 1, 1, 1)), 30 | nn.Conv2d(256, 256, 3), 31 | nn.ReLU(), 32 | nn.MaxPool2d(2, ceil_mode=True), 33 | nn.ReflectionPad2d((1, 1, 1, 1)), 34 | nn.Conv2d(256, 512, 3), 35 | nn.ReLU(), 36 | nn.ReflectionPad2d((1, 1, 1, 1)), 37 | nn.Conv2d(512, 512, 3), 38 | nn.ReLU(), 39 | nn.ReflectionPad2d((1, 1, 1, 1)), 40 | nn.Conv2d(512, 512, 3), 41 | nn.ReLU(), 42 | nn.ReflectionPad2d((1, 1, 1, 1)), 43 | nn.Conv2d(512, 512, 3), 44 | nn.ReLU(), 45 | nn.MaxPool2d(2, ceil_mode=True), 46 | nn.ReflectionPad2d((1, 1, 1, 1)), 47 | nn.Conv2d(512, 512, 3), 48 | nn.ReLU() 49 | ) 50 | 51 | 52 | class NormalisedVGG(nn.Module): 53 | """ 54 | VGG reluX_1(X = 1, 2, 3, 4, 5) can be obtained by slicing the follow vgg5_1 model. 
55 | 56 | Sequential( 57 | (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1)) 58 | (1): ReflectionPad2d((1, 1, 1, 1)) 59 | (2): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1)) 60 | (3): ReLU() # relu1_1 61 | (4): ReflectionPad2d((1, 1, 1, 1)) 62 | (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)) 63 | (6): ReLU() 64 | (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 65 | (8): ReflectionPad2d((1, 1, 1, 1)) 66 | (9): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1)) 67 | (10): ReLU() # relu2_1 68 | (11): ReflectionPad2d((1, 1, 1, 1)) 69 | (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1)) 70 | (13): ReLU() 71 | (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 72 | (15): ReflectionPad2d((1, 1, 1, 1)) 73 | (16): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1)) 74 | (17): ReLU() # relu3_1 75 | (18): ReflectionPad2d((1, 1, 1, 1)) 76 | (19): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) 77 | (20): ReLU() 78 | (21): ReflectionPad2d((1, 1, 1, 1)) 79 | (22): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) 80 | (23): ReLU() 81 | (24): ReflectionPad2d((1, 1, 1, 1)) 82 | (25): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1)) 83 | (26): ReLU() 84 | (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 85 | (28): ReflectionPad2d((1, 1, 1, 1)) 86 | (29): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1)) 87 | (30): ReLU()# relu4_1 88 | (31): ReflectionPad2d((1, 1, 1, 1)) 89 | (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 90 | (33): ReLU() 91 | (34): ReflectionPad2d((1, 1, 1, 1)) 92 | (35): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 93 | (36): ReLU() 94 | (37): ReflectionPad2d((1, 1, 1, 1)) 95 | (38): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 96 | (39): ReLU() 97 | (40): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True) 98 | (41): ReflectionPad2d((1, 1, 1, 1)) 99 | (42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1)) 100 | (43): ReLU() # relu5_1 101 | ) 102 | """ 103 | def __init__(self, pretrained_path='vgg_normalised_conv5_1.pth'): 104 | super().__init__() 105 | self.net = normalised_vgg_relu5_1 106 | if pretrained_path is not None: 107 | self.net.load_state_dict(torch.load(pretrained_path, map_location=lambda storage, loc: storage)) 108 | 109 | def forward(self, x, target, output_last_feature=False): 110 | if target == 'relu1_1': 111 | return self.net[:4](x) 112 | elif target == 'relu2_1': 113 | return self.net[:11](x) 114 | elif target == 'relu3_1': 115 | return self.net[:18](x) 116 | elif target == 'relu4_1': 117 | return self.net[:31](x) 118 | elif target == 'relu5_1': 119 | return self.net(x) 120 | else: 121 | raise ValueError(f'target should be in ["relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1"] but not {target}') 122 | -------------------------------------------------------------------------------- /result/avril_876_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_876_demo.jpg -------------------------------------------------------------------------------- /result/avril_antimonocromatismo_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_antimonocromatismo_demo.jpg 
-------------------------------------------------------------------------------- /result/avril_asheville_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_asheville_demo.jpg -------------------------------------------------------------------------------- /result/avril_brushstrokers_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_brushstrokers_demo.jpg -------------------------------------------------------------------------------- /result/avril_candy_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_candy_demo.jpg -------------------------------------------------------------------------------- /result/avril_contrast_of_forms_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_contrast_of_forms_demo.jpg -------------------------------------------------------------------------------- /result/avril_picasso_self_portrait_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_picasso_self_portrait_demo.jpg -------------------------------------------------------------------------------- /result/avril_scene_de_rue_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/result/avril_scene_de_rue_demo.jpg -------------------------------------------------------------------------------- /style/088.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/088.jpg -------------------------------------------------------------------------------- /style/101308.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/101308.jpg -------------------------------------------------------------------------------- /style/27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/27.jpg -------------------------------------------------------------------------------- /style/876.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/876.jpg -------------------------------------------------------------------------------- /style/antimonocromatismo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/antimonocromatismo.jpg -------------------------------------------------------------------------------- /style/asheville.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/asheville.jpg -------------------------------------------------------------------------------- /style/brick1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/brick1.jpg -------------------------------------------------------------------------------- /style/brushstrokers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/brushstrokers.jpg -------------------------------------------------------------------------------- /style/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/candy.jpg -------------------------------------------------------------------------------- /style/en_campo_gris.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/en_campo_gris.jpg -------------------------------------------------------------------------------- /style/in2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/in2.jpg -------------------------------------------------------------------------------- /style/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/la_muse.jpg -------------------------------------------------------------------------------- /style/mondrian.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/mondrian.jpg -------------------------------------------------------------------------------- /style/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/mosaic.jpg -------------------------------------------------------------------------------- /style/news1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/news1.jpg -------------------------------------------------------------------------------- /style/picasso_seated_nude_hr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/picasso_seated_nude_hr.jpg -------------------------------------------------------------------------------- /style/picasso_self_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/picasso_self_portrait.jpg 
-------------------------------------------------------------------------------- /style/plum_flower.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/plum_flower.jpg -------------------------------------------------------------------------------- /style/rain-princess.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/rain-princess.jpg -------------------------------------------------------------------------------- /style/scene_de_rue.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/scene_de_rue.jpg -------------------------------------------------------------------------------- /style/seated-nude.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/seated-nude.jpg -------------------------------------------------------------------------------- /style/sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/sketch.png -------------------------------------------------------------------------------- /style/trial.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/trial.jpg -------------------------------------------------------------------------------- /style/woman_in_peasant_dress_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/woman_in_peasant_dress_cropped.jpg -------------------------------------------------------------------------------- /style/woman_with_hat_matisse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/irasin/Pytorch_MST/0ad55484748603c423a538ea9667a087cfba4c87/style/woman_with_hat_matisse.jpg -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from PIL import Image 4 | import torch 5 | from torchvision import transforms 6 | from torchvision.utils import save_image 7 | from model import Model 8 | 9 | trans = transforms.Compose([transforms.ToTensor()]) 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description='Multimodal Style Transfer by Pytorch') 14 | parser.add_argument('--content', '-c', type=str, default=None, 15 | help='Content image path e.g. content.jpg') 16 | parser.add_argument('--style', '-s', type=str, default=None, 17 | help='Style image path e.g. image.jpg') 18 | parser.add_argument('--output_name', '-o', type=str, default=None, 19 | help='Output path for generated image, no need to add ext, e.g. 
out') 20 | parser.add_argument('--n_cluster', type=int, default=3, 21 | help='number of clusters of k-means ') 22 | parser.add_argument('--alpha', default=1, 23 | help='fusion degree, should be a float or a list which length is n_cluster') 24 | parser.add_argument('--lam', type=float, default=0.1, 25 | help='weight of pairwise term in alpha-expansion') 26 | parser.add_argument('--max_cycles', default=None, 27 | help='max_cycles of alpha-expansion') 28 | parser.add_argument('--gpu', '-g', type=int, default=0, 29 | help='GPU ID(negative value indicate CPU)') 30 | parser.add_argument('--model_state_path', type=str, default='model_state.pth', 31 | help='pretrained model state') 32 | 33 | args = parser.parse_args() 34 | 35 | # set device on GPU if available, else CPU 36 | if torch.cuda.is_available() and args.gpu >= 0: 37 | device = torch.device(f'cuda:{args.gpu}') 38 | print(f'# CUDA available: {torch.cuda.get_device_name(0)}') 39 | else: 40 | device = 'cpu' 41 | 42 | # set model 43 | model = Model(n_cluster=args.n_cluster, 44 | alpha=args.alpha, 45 | device=device, 46 | lam=args.lam, 47 | max_cycles=args.max_cycles) 48 | if args.model_state_path is not None: 49 | model.load_state_dict(torch.load(args.model_state_path, map_location=lambda storage, loc: storage)) 50 | print(f'{args.model_state_path} loaded') 51 | model = model.to(device) 52 | 53 | c = Image.open(args.content) 54 | s = Image.open(args.style) 55 | c_tensor = trans(c).unsqueeze(0).to(device) 56 | s_tensor = trans(s).unsqueeze(0).to(device) 57 | 58 | with torch.no_grad(): 59 | out = model.generate(c_tensor, s_tensor).to('cpu') 60 | 61 | if args.output_name is None: 62 | c_name = os.path.splitext(os.path.basename(args.content))[0] 63 | s_name = os.path.splitext(os.path.basename(args.style))[0] 64 | args.output_name = f'{c_name}_{s_name}' 65 | 66 | save_image(out, f'{args.output_name}.jpg', nrow=1) 67 | o = Image.open(f'{args.output_name}.jpg') 68 | 69 | demo = Image.new('RGB', (c.width * 2, c.height)) 70 | o = o.resize(c.size) 71 | s = s.resize((i // 4 for i in c.size)) 72 | 73 | demo.paste(c, (0, 0)) 74 | demo.paste(o, (c.width, 0)) 75 | demo.paste(s, (c.width, c.height - s.height)) 76 | demo.save(f'{args.output_name}_style_transfer_demo.jpg', quality=95) 77 | 78 | o.paste(s, (0, o.height - s.height)) 79 | o.save(f'{args.output_name}_with_style_image.jpg', quality=95) 80 | 81 | print(f'result saved into files starting with {args.output_name}') 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.simplefilter("ignore", UserWarning) 3 | import os 4 | import copy 5 | import argparse 6 | import matplotlib as mpl 7 | mpl.use('Agg') 8 | import matplotlib.pyplot as plt 9 | from tqdm import tqdm 10 | import torch 11 | from torch.optim import Adam 12 | from torch.utils.data import DataLoader 13 | from torchvision.utils import save_image 14 | from dataset import PreprocessDataset 15 | from model import Model 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser(description='Mulitmodal Style Transfer by Pytorch') 20 | parser.add_argument('--batch_size', '-b', type=int, default=16, 21 | help='number of images in each mini-batch') 22 | parser.add_argument('--epoch', '-e', type=int, default=1, 23 | help='number of sweeps over the dataset to train') 24 | parser.add_argument('--gpu', '-g', type=int, default=0, 25 | help='GPU 
ID(nagative value indicate CPU)') 26 | parser.add_argument('--learning_rate', '-lr', type=int, default=1e-5, 27 | help='learning rate for Adam') 28 | parser.add_argument('--snapshot_interval', type=int, default=1000, 29 | help='Interval of snapshot to generate image') 30 | parser.add_argument('--n_cluster', type=int, default=3, 31 | help='number of clusters of k-means ') 32 | parser.add_argument('--alpha', default=1, 33 | help='fusion degree, should be a float or a list which length is n_cluster') 34 | parser.add_argument('--lam', type=float, default=0.1, 35 | help='weight of pairwise term in alpha-expansion') 36 | parser.add_argument('--max_cycles', default=None, 37 | help='max_cycles of alpha-expansion') 38 | parser.add_argument('--gamma', type=float, default=1, 39 | help='weight of style loss') 40 | parser.add_argument('--train_content_dir', type=str, default='/data/chen/content', 41 | help='content images directory for train') 42 | parser.add_argument('--train_style_dir', type=str, default='/data/chen/style', 43 | help='style images directory for train') 44 | parser.add_argument('--test_content_dir', type=str, default='/data/chen/content', 45 | help='content images directory for test') 46 | parser.add_argument('--test_style_dir', type=str, default='/data/chen/style', 47 | help='style images directory for test') 48 | parser.add_argument('--save_dir', type=str, default='result', 49 | help='save directory for result and loss') 50 | parser.add_argument('--reuse', default=None, 51 | help='model state path to load for reuse') 52 | 53 | args = parser.parse_args() 54 | 55 | # create directory to save 56 | if not os.path.exists(args.save_dir): 57 | os.mkdir(args.save_dir) 58 | 59 | loss_dir = f'{args.save_dir}/loss' 60 | model_state_dir = f'{args.save_dir}/model_state' 61 | image_dir = f'{args.save_dir}/image' 62 | 63 | if not os.path.exists(loss_dir): 64 | os.mkdir(loss_dir) 65 | os.mkdir(model_state_dir) 66 | os.mkdir(image_dir) 67 | 68 | # set device on GPU if available, else CPU 69 | if torch.cuda.is_available() and args.gpu >= 0: 70 | device = torch.device(f'cuda:{args.gpu}') 71 | print(f'# CUDA available: {torch.cuda.get_device_name(0)}') 72 | else: 73 | device = 'cpu' 74 | 75 | print(f'# Minibatch-size: {args.batch_size}') 76 | print(f'# epoch: {args.epoch}') 77 | print('') 78 | 79 | # prepare dataset and dataLoader 80 | train_dataset = PreprocessDataset(args.train_content_dir, args.train_style_dir) 81 | test_dataset = PreprocessDataset(args.test_content_dir, args.test_style_dir) 82 | iters = len(train_dataset) 83 | print(f'Length of train image pairs: {iters}') 84 | 85 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 86 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 87 | test_iter = iter(test_loader) 88 | 89 | # set model and optimizer 90 | model = Model(n_cluster=args.n_cluster, 91 | alpha=args.alpha, 92 | device=device, 93 | lam=args.lam, 94 | pre_train=True, 95 | max_cycles=args.max_cycles).to(device) 96 | if args.reuse is not None: 97 | model.load_state_dict(torch.load(args.reuse, map_location=lambda storage, loc: storage)) 98 | print(f'{args.reuse} loaded') 99 | optimizer = Adam(model.parameters(), lr=args.learning_rate) 100 | 101 | prev_model = copy.deepcopy(model) 102 | prev_optimizer = copy.deepcopy(optimizer) 103 | 104 | # start training 105 | loss_list = [] 106 | for e in range(1, args.epoch + 1): 107 | print(f'Start {e} epoch') 108 | for i, (content, style) in tqdm(enumerate(train_loader, 1)): 109 | 
content = content.to(device) 110 | style = style.to(device) 111 | loss = model(content, style, args.gamma) 112 | 113 | if torch.isnan(loss): 114 | model = prev_model 115 | optimizer = torch.optim.Adam(model.parameters()) 116 | optimizer.load_state_dict(prev_optimizer.state_dict()) 117 | else: 118 | prev_model = copy.deepcopy(model) 119 | prev_optimizer = copy.deepcopy(optimizer) 120 | 121 | optimizer.zero_grad() 122 | loss.backward() 123 | optimizer.step() 124 | 125 | loss_list.append(loss.item()) 126 | 127 | print(f'[{e}/total {args.epoch} epoch],[{i} /' 128 | f'total {round(iters/args.batch_size)} iteration]: {loss.item()}') 129 | 130 | if i % args.snapshot_interval == 0: 131 | content, style = next(test_iter) 132 | content = content.to(device) 133 | style = style.to(device) 134 | with torch.no_grad(): 135 | out = model.generate(content, style) 136 | res = torch.cat([content, style, out], dim=0) 137 | res = res.to('cpu') 138 | save_image(res, f'{image_dir}/{e}_epoch_{i}_iteration.png', nrow=args.batch_size) 139 | # if i % 1000 == 0: 140 | torch.save(model.state_dict(), f'{model_state_dir}/{e}_epoch_{i}_iteration.pth') 141 | torch.save(model.state_dict(), f'{model_state_dir}/{e}_epoch.pth') 142 | plt.plot(range(len(loss_list)), loss_list) 143 | plt.xlabel('iteration') 144 | plt.ylabel('loss') 145 | plt.title('train loss') 146 | plt.savefig(f'{loss_dir}/train_loss.png') 147 | with open(f'{loss_dir}/loss_log.txt', 'w') as f: 148 | for l in loss_list: 149 | f.write(f'{l}\n') 150 | print(f'Loss saved in {loss_dir}') 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | CHUNK_SIZE = 32768 4 | URL = 'https://docs.google.com/uc?export=download' 5 | 6 | 7 | def download_file_from_google_drive(id, destination): 8 | print(f'# Downloading {destination}', end=' => ') 9 | session = requests.Session() 10 | 11 | response = session.get(URL, params={'id': id}, stream=True) 12 | token = get_confirm_token(response) 13 | 14 | if token: 15 | params = {'id': id, 'confirm': token} 16 | response = session.get(URL, params=params, stream=True) 17 | 18 | save_response_content(response, destination) 19 | print('Saved') 20 | 21 | 22 | def get_confirm_token(response): 23 | for key, value in response.cookies.items(): 24 | if key.startswith('download_warning'): 25 | return value 26 | 27 | return None 28 | 29 | 30 | def save_response_content(response, destination): 31 | with open(destination, 'wb') as f: 32 | for chunk in response.iter_content(CHUNK_SIZE): 33 | if chunk: 34 | f.write(chunk) 35 | --------------------------------------------------------------------------------
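For programmatic use outside the `test.py` CLI, the following is a minimal inference sketch that mirrors the same steps. It is not part of the repository; it assumes the pretrained `model_state.pth` from the README link has been placed in the repository root and uses the bundled `content/avril.jpg` and `style/candy.jpg` samples.

```python
# Minimal programmatic inference sketch (not part of the repository); mirrors test.py.
# Assumes model_state.pth (the pretrained weights linked in the README) is in the
# working directory and uses the bundled sample images.
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from model import Model

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Model(n_cluster=3, alpha=1, device=device, lam=0.1)
model.load_state_dict(torch.load('model_state.pth', map_location='cpu'))
model = model.to(device)

to_tensor = transforms.ToTensor()
c = to_tensor(Image.open('content/avril.jpg')).unsqueeze(0).to(device)
s = to_tensor(Image.open('style/candy.jpg')).unsqueeze(0).to(device)

with torch.no_grad():
    out = model.generate(c, s).to('cpu')    # (1, 3, H, W) stylised image

save_image(out, 'avril_candy.jpg', nrow=1)
```

Note that `MultimodalStyleTransfer` runs the k-means clustering (scikit-learn) and the alpha-expansion graph cut (PyMaxflow) on the CPU, so generation is noticeably slower than plain feed-forward style transfer even when a GPU is available.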