├── interface.jpg
├── README.md
└── style_transfer_GUI.py

--------------------------------------------------------------------------------
/interface.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spot92/Python_Style_Transfer_GUI/HEAD/interface.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Python_Style_Transfer_GUI
Took the neural style transfer code for Python and made it way more user friendly.

## Background
This was made with Python 3.7.4 on a Windows 10 machine.

I take the style transfer code from this repo and turn it into a more user-friendly experience:

https://github.com/spot92/neural-style-pt

That repo is just a fork of the following, with some minor changes to the image sizing:

https://github.com/ProGamerGov/neural-style-pt

And that, in turn, is a Python implementation of the Lua code here:

https://github.com/jcjohnson/neural-style

## Requirements
For the requirements, such as torch, scipy, etc., I will refer you to the ProGamerGov repo, as it covers basically everything you will need.

In addition to all of those Python packages, you will need to run the following:

pip install gooey

You will need a CUDA-capable graphics card (which I believe means NVIDIA only) in order to run this with any sort of reasonable speed. I have never run this on CPU myself, but my understanding is that it is significantly slower.

And again, to be clear, this repo contains only the GUI file. If you need the CaffeLoader.py file or the models folder, I will direct you to https://github.com/spot92/neural-style-pt or https://github.com/ProGamerGov/neural-style-pt (the ProGamerGov one is much more actively maintained, and I would defer to it for these files).

NOTE: YOU DO NOT NEED THE neural_style.py OR neural_style_my_edits.py FILES FOR THIS TO WORK.

Once you have gotten set up with the necessary files and models from the other repos, come back here.
Make sure that all the files and models are in the same root folder. For example, my folder is D:/Neural Style Python and I have every other necessary folder/file in there. I find it helpful to have separate folders for content, styles, and output, but this is not required.

## Running It
Running it should be as easy as double-clicking the file; that is all I do. Fill in the fields as desired (I set up defaults on most of them). If you would like to change a default, find the corresponding argument near the top of the code and change its default value (see the example at the end of this README). Please note that I have set the default of the output file to output/ since I save mine to the output folder. You can also change the size of the interface by editing the default_size=(1280, 720) line.

This is what it should look like when run successfully:

![Image of Interface](https://github.com/spot92/Python_Style_Transfer_GUI/blob/master/interface.jpg)

I think this covers it. Please let me know if there are any questions.
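## Example: Changing the Defaults
As a concrete illustration of the two edits mentioned above (the window size and a default value), here is roughly what those lines in style_transfer_GUI.py would look like with different values substituted in. The numbers shown are placeholders for illustration only, not recommendations:

```python
# Example edits only -- pick whatever values suit your screen and GPU memory.
@Gooey(default_size=(1600, 900))   # GUI window size; shipped default is (1280, 720)

# ... inside gpu(), the image_size argument ...
parser.add_argument("image_size", help="Maximum height/width of generated image.",
                    type=int, default=512)   # shipped default is 384
```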
--------------------------------------------------------------------------------
/style_transfer_GUI.py:
--------------------------------------------------------------------------------
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
# Changes from the original neural-style-pt are marked between the short and long rows of ####
# Lines removed from the original are kept as comments starting with "Deleted:"
from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel

from gooey import Gooey, GooeyParser


@Gooey(default_size=(1280, 720))
def gpu():
    parser = GooeyParser(description="Style Transfer GUI")
    # Basic options
    parser.add_argument("content_image", help="Desired content image.", widget="FileChooser")
    parser.add_argument("style_image", help="Desired style image.", widget="MultiFileChooser")
    parser.add_argument("image_size", help="Maximum height/width of generated image.", type=int, default=384)
    parser.add_argument("output_image", help="The name you wish to save to; make sure to include the .jpg extension.", default='output/')
    parser.add_argument("-cudnn_autotune", help="Do you want to use cudnn_autotune?", action='store_true', default=True)
    parser.add_argument("-init_image", help="Desired initialization image.", widget="FileChooser", default=None)
    parser.add_argument("original_colors", help="Do you want to use the content's original colors?", type=int, choices=[0, 1], default=0)
    parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set gpu = c.", default=0)

    # Optimization options
    parser.add_argument("content_weight", help="The weight of the content image.", type=float, default=5)
    parser.add_argument("style_weight", help="The weight of the style image.", type=float, default=1000)
    parser.add_argument("-style_blend_weights", help="The weights when using multiple styles.", default=None)
    parser.add_argument("tv_weight", help="Setting this low will make the final image more crisp.", type=float, default=0.00085)
    parser.add_argument("num_iterations", help="Number of iterations to run.", type=int, default=1000)
    parser.add_argument("-init", help="Initialize from an image or from random noise?", choices=['random', 'image'], default='image')
    parser.add_argument("-optimizer", help="L-BFGS is better, but uses more memory than Adam.", choices=['lbfgs', 'adam'], default='lbfgs')
    parser.add_argument("-learning_rate", help="The learning rate when using Adam.", type=float, default=1e0)
    parser.add_argument("-lbfgs_num_correction", help="History size (number of corrections) kept by the L-BFGS optimizer.", type=int, default=100)

    # Output options
    parser.add_argument("print_iter", help="How often do you want to be updated on the progress?", type=int, default=50)
    parser.add_argument("save_iter", help="How often do you want to save the created image?", type=int, default=0)

    # Other options
    parser.add_argument("-style_scale", help="At what scale to get features from the style image.", type=float, default=1.0)
    parser.add_argument("-pooling", help="You can change this, but I prefer max over avg.", choices=['avg', 'max'], default='max')
    parser.add_argument("-model_file", help="Path to the .pth VGG Caffe model.", type=str, default='models/vgg19-d01eb7cb.pth')
    parser.add_argument("-disable_check", help="Disable the model check in loadCaffemodel.", action='store_true')
    parser.add_argument("-backend", help="I like cudnn, but I believe that for CPU you have to use nn; I have never tested that myself.", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='cudnn')
    parser.add_argument("-normalize_weights", help="Not really sure; I've tried it and did not like the results.", action='store_true')
    parser.add_argument("-seed", help="Random seed.", type=int, default=100)
    parser.add_argument("-content_layers", help="Layers to use from the content image. The defaults I have set work really well.", default='relu1_1,relu2_1,relu3_1,relu4_1,relu4_2,relu5_1')
    parser.add_argument("-style_layers", help="Layers to use from the style image. The defaults I have set work really well.", default='relu3_1,relu4_1,relu4_2,relu5_1')

    parser.add_argument("-multidevice_strategy", help="If you're rich and have more than one GPU. No idea how to use this option.", default='4,7,29')
    params = parser.parse_args()
    return params


Image.MAX_IMAGE_PIXELS = 1000000000  # Support gigapixel images


def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)

    #####################################################
    # The content tensor is NCHW, so dimension 2 is the height and dimension 3 is the width.
    Ch = content_image.size(2)
    Cw = content_image.size(3)
    #################################################################

    style_image_input = params.style_image.split(';')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        #################################################
        image_path = image
        print(image_path)
        im_sizing = Image.open(image_path)
        print(im_sizing)
        # PIL's Image.size is (width, height).
        Sw = im_sizing.size[0]
        Sh = im_sizing.size[1]
        style_size = 0
        resizeStyle = 1  # set to 0 below when the style image is already small enough
        Cr = Cw / Ch  # content aspect ratio (width / height)
        Sr = Sw / Sh  # style aspect ratio (width / height)

        if Cr >= Sr:
            if Sr >= 1:
                style_size = Cw * params.style_scale
            else:
                style_size = params.style_scale * Cw * Sh / Sw

            if style_size > Sw:
                style_size = Sw
                resizeStyle = 0
        else:
            if Sr >= 1:
                style_size = params.style_scale * Ch * Sw / Sh
            else:
                style_size = Ch * params.style_scale

            if style_size > Sh:
                style_size = Sh
                resizeStyle = 0

        #############################################################
        # Deleted: style_size = int(params.image_size * params.style_scale)

        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)

    if multidevice:
        net = setup_multi_device(net)

    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net, multidevice)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'

    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i+1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image is not None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img)

    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / " + str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print("  Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print("  Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            print("  Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img.clone())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)

            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)


# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal


def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device


def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices minus 1, must be equal to the number of -gpu devices."

    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net


# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')


# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "
    def n():
        return " (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + ((ks).replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")


# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())


# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input


class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())


# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input


class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input


if __name__ == "__main__":
    params = gpu()
    main()

--------------------------------------------------------------------------------