├── 1.png ├── LICENSE ├── Metamer_Transform.py ├── Metamer_Transform_Update.png ├── README.md ├── download_models_and_stimuli.sh ├── function.py └── net.py /1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArturoDeza/NeuroFovea_PyTorch/cf8f3e41ccc08b9f631f5f59776c01f92d52e944/1.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Arturo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Metamer_Transform.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | import torch.nn as nn 6 | from PIL import Image 7 | from torchvision import transforms 8 | from torchvision.utils import save_image 9 | import numpy as np 10 | import math 11 | import time 12 | 13 | import net 14 | from function import adaptive_instance_normalization, coral 15 | 16 | # Example code: 17 | # python Metamer_Transform.py --image 4751.png --output output_Stimuli 18 | 19 | parser = argparse.ArgumentParser() 20 | # Basic options 21 | parser.add_argument('--image', type=str, help='File path to the content image') 22 | parser.add_argument('--image_dir', type=str, help='Directory path to a batch of content images') 23 | parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth') 24 | parser.add_argument('--decoder', type=str, default='models/decoder-content-similar.pth') # Notice that this decoder is different than the classically used "decoder.pth"! 
25 | 26 | # Additional options 27 | parser.add_argument('--image_size', type=int, default=512, help='New (minimum) size for the content image, keeping the original size if set to 0') 28 | parser.add_argument('--crop', action='store_true', help='do center crop to create squared image') 29 | parser.add_argument('--save_ext', default='.png', help='The extension name of the output image') 30 | parser.add_argument('--output', type=str, default='output', help='Directory to save the output image(s)') 31 | parser.add_argument('--scale',type=str,default='0.4',help='Rate of growth of the Log-Polar Receptive Fields') 32 | parser.add_argument('--verbose',type=int,default=0,help='Print several hyper-parameters as we run the rendering scheme. Default should be 0, should only be set to 1 for debugging.') 33 | parser.add_argument('--reference',type=int,default=0,help='Compute the reference image') 34 | 35 | args = parser.parse_args() 36 | 37 | # List of potentially different rate of growth of receptive fields 38 | # assuming a center fixation. 39 | scale_in = ['0.25','0.3','0.4','0.5','0.6','0.7'] 40 | scale_out = [377,301,187,126,103,91] 41 | 42 | Pooling_Region_Map = dict(zip(scale_in,scale_out)) 43 | 44 | verb = args.verbose 45 | 46 | resize_output = transforms.Compose([transforms.Resize((256,256))]) 47 | to_pil_image = transforms.ToPILImage() 48 | to_tensor = transforms.ToTensor() 49 | 50 | 51 | #function that loads receptive fields: 52 | def load_receptive_fields(): 53 | d = 1.281 # a value that was fitted via psychophysical experiments assuming 26 deg of visual angle maps to 512 pixels on a screen. 54 | mask_total = torch.zeros(Pooling_Region_Map[args.scale],64,64) 55 | alpha_matrix = torch.zeros(Pooling_Region_Map[args.scale]) 56 | for i in range(Pooling_Region_Map[args.scale]): 57 | i_str = str(i) 58 | #mask_str = './Receptive_Fields/MetaWindows_clean_s0.4/' + i_str + '.png' 59 | mask_str = './Receptive_Fields/MetaWindows_clean_s' + args.scale + '/' + i_str + '.png' 60 | mask_temp = mask_tf(Image.open(str(mask_str))) 61 | mask_total[i,:,:] = mask_temp 62 | mask_regular = mask_regular_tf(Image.open(str(mask_str))) 63 | mask_size = torch.sum(torch.sum(mask_regular>0.5)) 64 | recep_size = np.sqrt(mask_size/3.14)*26.0/512.0 65 | if i == 0: 66 | alpha_matrix[i] = 0 67 | else: 68 | alpha_matrix[i] = -1 + 2.0 / (1.0+math.exp(-recep_size*d)) 69 | if verb == 1: 70 | print(alpha_matrix[i]) 71 | return mask_total, alpha_matrix 72 | 73 | 74 | def test_transform(size, crop): 75 | transform_list = [] 76 | if size != 0: 77 | transform_list.append(transforms.Resize(size)) 78 | if crop: 79 | transform_list.append(transforms.CenterCrop(size)) 80 | transform_list.append(transforms.ToTensor()) 81 | transform = transforms.Compose(transform_list) 82 | return transform 83 | 84 | def mask_transform(): 85 | transform = transforms.Compose([transforms.Resize(64),transforms.Grayscale(1),transforms.ToTensor()]) 86 | return transform 87 | 88 | def mask_transform_regular(): 89 | transform = transforms.Compose([transforms.Resize(512),transforms.Grayscale(1),transforms.ToTensor()]) 90 | return transform 91 | 92 | def tile(a, dim, n_tile): 93 | init_dim = a.size(dim) 94 | repeat_idx = [1] * a.dim() 95 | repeat_idx[dim] = n_tile 96 | a = a.repeat(*(repeat_idx)) 97 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 98 | return torch.index_select(a, dim, order_index) 99 | 100 | def foveated_style_transfer(vgg,decoder, content, mask_total, alpha_matrix, reference): 101 | noise 
= torch.randn(3,512,512) 102 | noise = coral(noise,content) 103 | # Move images to GPU: 104 | style = content.to(device).unsqueeze(0) # Remember we are doing something like "Auto-Style Transfer" 105 | content = content.to(device).unsqueeze(0) 106 | noise = noise.to(device).unsqueeze(0) 107 | # Create Empty Foveated Feature Vector to which we will allocate the latent crowded feature vectors: 108 | foveated_f = torch.zeros(1,512,64,64).to(device) 109 | # Create Content Feature Vector (post VGG): 110 | content_f = vgg(content) 111 | # Create Style Feature Vector (post VGG): 112 | style_f = vgg(style) 113 | # Create Noise Feature Vector (post VGG): 114 | noise_f = vgg(noise) 115 | # assume alpha_i = 0.5 116 | if reference == 1: 117 | return decoder(content_f) 118 | else: 119 | for i in range(Pooling_Region_Map[args.scale]): # Loop over all the receptive fields (pooling regions) 120 | alpha_i = alpha_matrix[i] 121 | mask = mask_total[i,:,:] 122 | mask = mask.unsqueeze(0) 123 | mask = tile(mask,0,512) 124 | mask_binary = mask>0.001 125 | if verb == 1: 126 | print(np.shape(content_f)) 127 | print(np.shape(mask_binary[0,:,:])) 128 | content_f_mask = content_f[:,:,mask_binary[0,:,:]] # 0 was 0th prefix before 129 | style_f_mask = style_f[:,:,mask_binary[0,:,:]] 130 | noise_f_mask = noise_f[:,:,mask_binary[0,:,:]] 131 | # 132 | if verb == 1: 133 | print(np.shape(noise_f_mask.unsqueeze(3))) 134 | print(np.shape(style_f_mask.unsqueeze(3))) 135 | content_f_mask = content_f_mask.unsqueeze(3) 136 | noise_f_mask = noise_f_mask.unsqueeze(3) 137 | style_f_mask = style_f_mask.unsqueeze(3) 138 | if verb == 1: 139 | print(np.shape(content_f_mask)) 140 | # Perform the Crowding Operation and Localized Auto Style-Transfer 141 | texture_f_mask = adaptive_instance_normalization(noise_f_mask,style_f_mask) 142 | alpha_mixture = (1-alpha_i)*content_f_mask + (alpha_i)*texture_f_mask 143 | if verb == 1: 144 | print(np.shape(alpha_mixture)) 145 | foveated_f[:,:,mask_binary[0,:,:]] = alpha_mixture.squeeze(3) 146 | # Run the now foveated image in the latent space through the decoder to render the metamer 147 | return decoder(foveated_f) 148 | 149 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 150 | 151 | output_dir = Path(args.output) 152 | output_dir.mkdir(exist_ok=True, parents=True) 153 | 154 | # Either --image or --imageDir should be given. 
155 | assert (args.image or args.image_dir) 156 | if args.image: 157 |     image_paths = [Path(args.image)] 158 | else: 159 |     image_dir = Path(args.image_dir) 160 |     image_paths = [f for f in image_dir.glob('*')] 161 | 162 | decoder = net.decoder 163 | vgg = net.vgg 164 | 165 | decoder.eval() 166 | vgg.eval() 167 | 168 | decoder.load_state_dict(torch.load(args.decoder)) 169 | vgg.load_state_dict(torch.load(args.vgg)) 170 | vgg = nn.Sequential(*list(vgg.children())[:31]) 171 | 172 | reference = args.reference 173 | 174 | vgg.to(device) 175 | decoder.to(device) 176 | 177 | image_tf = test_transform(args.image_size, args.crop) 178 | 179 | # Define Masked Transforms (for the localized style transfer that is the basis 180 | # of the foveated style transform operation for each pooling region) 181 | mask_tf = mask_transform() 182 | mask_regular_tf = mask_transform_regular() 183 | 184 | if verb == 1: 185 |     print(image_paths) 186 |     print(image_paths[0]) 187 |     print(image_paths[0].stem) 188 | 189 | print(len(image_paths)) 190 | 191 | with torch.no_grad(): 192 |     mask_total, alpha_matrix = load_receptive_fields() 193 |     for z in range(len(image_paths)): 194 |         image_path = image_paths[z] 195 |         image = image_tf(Image.open(str(image_path)).convert('RGB')) 196 |         start_time = time.time() 197 |         output = foveated_style_transfer(vgg,decoder,image,mask_total,alpha_matrix,reference) 198 |         output = output.cpu() 199 |         output2 = to_pil_image(torch.clamp(output.squeeze(0),0,1)) 200 |         output = torch.clamp(to_tensor(resize_output(output2)),0,1) 201 |         end_time = time.time() 202 |         # Move from GPU to CPU 203 |         #output = output.cpu() 204 |         # Move Output 205 |         if reference == 0: 206 |             output_name = output_dir / '{:s}_s{:s}{:s}'.format(image_path.stem, args.scale, args.save_ext) 207 |         elif reference == 1: 208 |             output_name = output_dir / '{:s}_Reference{:s}'.format(image_path.stem, args.save_ext) 209 |         # Save Image 210 |         save_image(output, str(output_name)) 211 |         # Display Compute Time 212 |         print('Total Rendering time: ' + str(end_time-start_time) + ' seconds') 213 |         if z % 50 == 1: 214 |             time.sleep(10) 215 | -------------------------------------------------------------------------------- /Metamer_Transform_Update.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArturoDeza/NeuroFovea_PyTorch/cf8f3e41ccc08b9f631f5f59776c01f92d52e944/Metamer_Transform_Update.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # NeuroFovea_PyTorch 6 | An adapted version of the Metamer Foveation Transform code from Deza et al. (ICLR 2019) 7 | 8 | 9 | To complete the installation, please run: 10 | 11 | ``` 12 | $ bash download_models_and_stimuli.sh 13 | ``` 14 | 15 | ### Example code: 16 | 17 | Generate a foveated image for the `512x512` image `1.png` with a center fixation, using a receptive-field rate of growth of `s=0.4`. Note: rendering a metamer should take roughly 1 second on a GPU.
18 | 19 | ``` 20 | $ python Metamer_Transform.py --image 1.png --scale 0.4 --reference 0 21 | ``` 22 | 23 | The paper "Emergent Properties of Foveated Perceptual Systems" by Deza & Konkle (2020/2021, https://arxiv.org/abs/2006.07991), which uses a foveated transform (with an exaggerated distortion given the scaling factor set to `s=0.4`), was run with the Lua code available at https://github.com/ArturoDeza/NeuroFovea, but current and future follow-up work has transitioned to this PyTorch version. After finally vetting the code (and making sure both the Lua and PyTorch versions produce the same outputs), we've decided to release it to accelerate work on spatially-adaptive (foveated) texture computation in humans and machines. 24 | 25 | The Foveated Texture Transform essentially computes log-polar + localized Adaptive Instance Normalization (see Huang & Belongie, ICCV 2017); this code is thus an extension of https://github.com/naoto0804/pytorch-AdaIN. 26 | 27 | Please read our paper to learn more about visual metamerism: https://openreview.net/forum?id=BJzbG20cFQ 28 | 29 | We hope this code and our paper can help researchers, scientists and engineers improve the use and design of metamer models that have potentially exciting applications in both computer vision and visual neuroscience. 30 | 31 | This code is free to use for research purposes; if you use or modify it in any way, please consider citing: 32 | 33 | ``` 34 | @inproceedings{ 35 | deza2018towards, 36 | title={Towards Metamerism via Foveated Style Transfer}, 37 | author={Arturo Deza and Aditya Jonnalagadda and Miguel P. Eckstein}, 38 | booktitle={International Conference on Learning Representations}, 39 | year={2019}, 40 | url={https://openreview.net/forum?id=BJzbG20cFQ}, 41 | } 42 | ``` 43 | 44 | Other inquiries: deza@mit.edu 45 | -------------------------------------------------------------------------------- /download_models_and_stimuli.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/ovuyfqumr5xymds/Full_Metamer_Stimuli.zip 2 | unzip Full_Metamer_Stimuli.zip 3 | wget https://www.dropbox.com/s/vktpuc16pj5stac/modelsPyTorch.zip 4 | unzip modelsPyTorch.zip 5 | mv modelsPyTorch models 6 | wget https://www.dropbox.com/s/rtovjv8y1xbyso6/Receptive_Fields.zip 7 | unzip Receptive_Fields.zip 8 | wget https://www.dropbox.com/s/5semltqozxywsvz/Refinement.zip 9 | unzip Refinement.zip 10 | rm Full_Metamer_Stimuli.zip 11 | rm modelsPyTorch.zip 12 | rm Receptive_Fields.zip 13 | rm Refinement.zip 14 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 |     # eps is a small value added to the variance to avoid divide-by-zero.
6 | size = feat.size() 7 | assert (len(size) == 4) 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def adaptive_instance_normalization(content_feat, style_feat): 16 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 17 | size = content_feat.size() 18 | style_mean, style_std = calc_mean_std(style_feat) 19 | content_mean, content_std = calc_mean_std(content_feat) 20 | 21 | normalized_feat = (content_feat - content_mean.expand( 22 | size)) / content_std.expand(size) 23 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 24 | 25 | 26 | def _calc_feat_flatten_mean_std(feat): 27 | # takes 3D feat (C, H, W), return mean and std of array within channels 28 | assert (feat.size()[0] == 3) 29 | assert (isinstance(feat, torch.FloatTensor)) 30 | feat_flatten = feat.view(3, -1) 31 | mean = feat_flatten.mean(dim=-1, keepdim=True) 32 | std = feat_flatten.std(dim=-1, keepdim=True) 33 | return feat_flatten, mean, std 34 | 35 | 36 | def _mat_sqrt(x): 37 | U, D, V = torch.svd(x) 38 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 39 | 40 | 41 | def coral(source, target): 42 | # assume both source and target are 3D array (C, H, W) 43 | # Note: flatten -> f 44 | 45 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 46 | source_f_norm = (source_f - source_f_mean.expand_as( 47 | source_f)) / source_f_std.expand_as(source_f) 48 | source_f_cov_eye = \ 49 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 50 | 51 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 52 | target_f_norm = (target_f - target_f_mean.expand_as( 53 | target_f)) / target_f_std.expand_as(target_f) 54 | target_f_cov_eye = \ 55 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 56 | 57 | source_f_norm_transfer = torch.mm( 58 | _mat_sqrt(target_f_cov_eye), 59 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), 60 | source_f_norm) 61 | ) 62 | 63 | source_f_transfer = source_f_norm_transfer * \ 64 | target_f_std.expand_as(source_f_norm) + \ 65 | target_f_mean.expand_as(source_f_norm) 66 | 67 | return source_f_transfer.view(source.size()) 68 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from function import adaptive_instance_normalization as adain 4 | from function import calc_mean_std 5 | 6 | decoder = nn.Sequential( 7 | nn.ReflectionPad2d((1, 1, 1, 1)), 8 | nn.Conv2d(512, 256, (3, 3)), 9 | nn.ReLU(), 10 | nn.Upsample(scale_factor=2, mode='nearest'), 11 | nn.ReflectionPad2d((1, 1, 1, 1)), 12 | nn.Conv2d(256, 256, (3, 3)), 13 | nn.ReLU(), 14 | nn.ReflectionPad2d((1, 1, 1, 1)), 15 | nn.Conv2d(256, 256, (3, 3)), 16 | nn.ReLU(), 17 | nn.ReflectionPad2d((1, 1, 1, 1)), 18 | nn.Conv2d(256, 256, (3, 3)), 19 | nn.ReLU(), 20 | nn.ReflectionPad2d((1, 1, 1, 1)), 21 | nn.Conv2d(256, 128, (3, 3)), 22 | nn.ReLU(), 23 | nn.Upsample(scale_factor=2, mode='nearest'), 24 | nn.ReflectionPad2d((1, 1, 1, 1)), 25 | nn.Conv2d(128, 128, (3, 3)), 26 | nn.ReLU(), 27 | nn.ReflectionPad2d((1, 1, 1, 1)), 28 | nn.Conv2d(128, 64, (3, 3)), 29 | nn.ReLU(), 30 | nn.Upsample(scale_factor=2, mode='nearest'), 31 | nn.ReflectionPad2d((1, 1, 1, 1)), 32 | nn.Conv2d(64, 64, (3, 3)), 33 | nn.ReLU(), 34 | nn.ReflectionPad2d((1, 1, 1, 1)), 
35 | nn.Conv2d(64, 3, (3, 3)), 36 | ) 37 | 38 | vgg = nn.Sequential( 39 | nn.Conv2d(3, 3, (1, 1)), 40 | nn.ReflectionPad2d((1, 1, 1, 1)), 41 | nn.Conv2d(3, 64, (3, 3)), 42 | nn.ReLU(), # relu1-1 43 | nn.ReflectionPad2d((1, 1, 1, 1)), 44 | nn.Conv2d(64, 64, (3, 3)), 45 | nn.ReLU(), # relu1-2 46 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 47 | nn.ReflectionPad2d((1, 1, 1, 1)), 48 | nn.Conv2d(64, 128, (3, 3)), 49 | nn.ReLU(), # relu2-1 50 | nn.ReflectionPad2d((1, 1, 1, 1)), 51 | nn.Conv2d(128, 128, (3, 3)), 52 | nn.ReLU(), # relu2-2 53 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 54 | nn.ReflectionPad2d((1, 1, 1, 1)), 55 | nn.Conv2d(128, 256, (3, 3)), 56 | nn.ReLU(), # relu3-1 57 | nn.ReflectionPad2d((1, 1, 1, 1)), 58 | nn.Conv2d(256, 256, (3, 3)), 59 | nn.ReLU(), # relu3-2 60 | nn.ReflectionPad2d((1, 1, 1, 1)), 61 | nn.Conv2d(256, 256, (3, 3)), 62 | nn.ReLU(), # relu3-3 63 | nn.ReflectionPad2d((1, 1, 1, 1)), 64 | nn.Conv2d(256, 256, (3, 3)), 65 | nn.ReLU(), # relu3-4 66 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 67 | nn.ReflectionPad2d((1, 1, 1, 1)), 68 | nn.Conv2d(256, 512, (3, 3)), 69 | nn.ReLU(), # relu4-1, this is the last layer used 70 | nn.ReflectionPad2d((1, 1, 1, 1)), 71 | nn.Conv2d(512, 512, (3, 3)), 72 | nn.ReLU(), # relu4-2 73 | nn.ReflectionPad2d((1, 1, 1, 1)), 74 | nn.Conv2d(512, 512, (3, 3)), 75 | nn.ReLU(), # relu4-3 76 | nn.ReflectionPad2d((1, 1, 1, 1)), 77 | nn.Conv2d(512, 512, (3, 3)), 78 | nn.ReLU(), # relu4-4 79 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 80 | nn.ReflectionPad2d((1, 1, 1, 1)), 81 | nn.Conv2d(512, 512, (3, 3)), 82 | nn.ReLU(), # relu5-1 83 | nn.ReflectionPad2d((1, 1, 1, 1)), 84 | nn.Conv2d(512, 512, (3, 3)), 85 | nn.ReLU(), # relu5-2 86 | nn.ReflectionPad2d((1, 1, 1, 1)), 87 | nn.Conv2d(512, 512, (3, 3)), 88 | nn.ReLU(), # relu5-3 89 | nn.ReflectionPad2d((1, 1, 1, 1)), 90 | nn.Conv2d(512, 512, (3, 3)), 91 | nn.ReLU() # relu5-4 92 | ) 93 | 94 | 95 | class Net(nn.Module): 96 | def __init__(self, encoder, decoder): 97 | super(Net, self).__init__() 98 | enc_layers = list(encoder.children()) 99 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 100 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 101 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 102 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 103 | self.decoder = decoder 104 | self.mse_loss = nn.MSELoss() 105 | 106 | # fix the encoder 107 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 108 | for param in getattr(self, name).parameters(): 109 | param.requires_grad = False 110 | 111 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 112 | def encode_with_intermediate(self, input): 113 | results = [input] 114 | for i in range(4): 115 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 116 | results.append(func(results[-1])) 117 | return results[1:] 118 | 119 | # extract relu4_1 from input image 120 | def encode(self, input): 121 | for i in range(4): 122 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 123 | return input 124 | 125 | def calc_content_loss(self, input, target): 126 | assert (input.size() == target.size()) 127 | assert (target.requires_grad is False) 128 | return self.mse_loss(input, target) 129 | 130 | def calc_style_loss(self, input, target): 131 | assert (input.size() == target.size()) 132 | assert (target.requires_grad is False) 133 | input_mean, input_std = calc_mean_std(input) 134 | target_mean, target_std = calc_mean_std(target) 135 | return 
self.mse_loss(input_mean, target_mean) + \ 136 | self.mse_loss(input_std, target_std) 137 | 138 | def forward(self, content, style, alpha=1.0): 139 | assert 0 <= alpha <= 1 140 | style_feats = self.encode_with_intermediate(style) 141 | content_feat = self.encode(content) 142 | t = adain(content_feat, style_feats[-1]) 143 | t = alpha * t + (1 - alpha) * content_feat 144 | 145 | g_t = self.decoder(t) 146 | g_t_feats = self.encode_with_intermediate(g_t) 147 | 148 | loss_c = self.calc_content_loss(g_t_feats[-1], t) 149 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 150 | for i in range(1, 4): 151 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 152 | return loss_c, loss_s 153 | --------------------------------------------------------------------------------
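A note on `net.py`: the `Net` module is the AdaIN training wrapper carried over from pytorch-AdaIN; `Metamer_Transform.py` only loads the bare `vgg` and `decoder` Sequentials at inference time. Below is a minimal sketch of how that wrapper could be driven to fine-tune a decoder with the content/style losses defined above, assuming the weights fetched by `download_models_and_stimuli.sh` sit in `models/`; the batch tensors, learning rate, and style-loss weight are illustrative placeholders, not values from the paper.

```
import torch
import net

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the provided weights (paths match the argparse defaults in Metamer_Transform.py).
net.vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
net.decoder.load_state_dict(torch.load('models/decoder-content-similar.pth'))

# Net slices the encoder internally (relu1_1 ... relu4_1) and freezes it.
model = net.Net(net.vgg, net.decoder).to(device).train()
optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-4)  # lr is an assumption

# Stand-in content/style batches; replace with real (N, 3, 256, 256) images in [0, 1].
content = torch.rand(4, 3, 256, 256, device=device)
style = torch.rand(4, 3, 256, 256, device=device)

loss_c, loss_s = model(content, style, alpha=1.0)
loss = loss_c + 10.0 * loss_s  # style-loss weight of 10 is an assumption
optimizer.zero_grad()
loss.backward()
optimizer.step()
```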