├── README.md
├── Style_Transfer.ipynb
├── download.py
├── images
│   ├── 02_convolution.png
│   ├── 02_convolution.svg
│   ├── 02_network_flowchart.png
│   ├── 02_network_flowchart.svg
│   ├── 06_network_flowchart.png
│   ├── 06_network_flowchart.svg
│   ├── 07_inception_flowchart.png
│   ├── 08_transfer_learning_flowchart.png
│   ├── 08_transfer_learning_flowchart.svg
│   ├── 09_transfer_learning_flowchart.png
│   ├── 09_transfer_learning_flowchart.svg
│   ├── 11_adversarial_examples_flowchart.png
│   ├── 11_adversarial_examples_flowchart.svg
│   ├── 12_adversarial_noise_flowchart.png
│   ├── 12_adversarial_noise_flowchart.svg
│   ├── 13_visual_analysis_flowchart.png
│   ├── 13_visual_analysis_flowchart.svg
│   ├── 14_deepdream_flowchart.png
│   ├── 14_deepdream_flowchart.svg
│   ├── 14_deepdream_recursive_flowchart.png
│   ├── 14_deepdream_recursive_flowchart.svg
│   ├── 15_style_transfer_flowchart.png
│   ├── 15_style_transfer_flowchart.svg
│   ├── elon_musk.jpg
│   ├── elon_musk_100x100.jpg
│   ├── escher_planefilling2.jpg
│   ├── giger.jpg
│   ├── hulk.jpg
│   ├── parrot.jpg
│   ├── parrot_cropped1.jpg
│   ├── parrot_cropped2.jpg
│   ├── parrot_cropped3.jpg
│   ├── parrot_padded.jpg
│   ├── style1.jpg
│   ├── style2.jpg
│   ├── style3.jpg
│   ├── style4.jpg
│   ├── style5.jpg
│   ├── style6.jpg
│   ├── style7.jpg
│   ├── style8.jpg
│   ├── style9.jpg
│   ├── willy_wonka_new.jpg
│   └── willy_wonka_old.jpg
└── vgg16.py
/README.md:
--------------------------------------------------------------------------------
1 | # How_to_do_style_transfer_in_tensorflow
2 | 
3 | ## Overview
4 | 
5 | This is the code for [this](https://youtu.be/Oex0eWoU7AQ) video on YouTube by Siraj Raval, part of the Intro to Deep Learning Nanodegree with Udacity. We're going to re-purpose VGG16, a pre-trained convolutional network that was a top performer in the 2014 ImageNet image-classification competition, to transfer the style of one image onto another. [This](https://arxiv.org/abs/1508.06576) is the original paper on the topic.
6 | 
7 | 
8 | ## Dependencies
9 | 
10 | * tensorflow
11 | * numpy
12 | * matplotlib
13 | * jupyter
14 | * python 3
15 | 
16 | Use [pip](https://pip.pypa.io/en/stable/installing/) to install any missing dependencies.
17 | 
18 | ## Usage
19 | 
20 | Run `jupyter notebook` in the top-level directory, then open `Style_Transfer.ipynb` when it appears in your browser.
21 | 
22 | 
23 | ## Credits
24 | 
25 | The credits for this code go to [Aniruddha-Tapas](https://github.com/Aniruddha-Tapas). I've merely created a wrapper to get people started.
26 | 
27 | 
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | ########################################################################
2 | #
3 | # Functions for downloading and extracting data-files from the internet.
4 | #
5 | # Implemented in Python 3.5
6 | #
7 | ########################################################################
8 | #
9 | # This file is part of the TensorFlow Tutorials available at:
10 | #
11 | # https://github.com/Hvass-Labs/TensorFlow-Tutorials
12 | #
13 | # Published under the MIT License. See the file LICENSE for details.
14 | #
15 | # Copyright 2016 by Magnus Erik Hvass Pedersen
16 | #
17 | ########################################################################
18 | 
19 | import sys
20 | import os
21 | import urllib.request
22 | import tarfile
23 | import zipfile
24 | 
25 | ########################################################################
26 | 
27 | 
28 | def _print_download_progress(count, block_size, total_size):
29 |     """
30 |     Function used for printing the download progress.
31 |     Used as a call-back function in maybe_download_and_extract().
32 | """ 33 | 34 | # Percentage completion. 35 | pct_complete = float(count * block_size) / total_size 36 | 37 | # Status-message. Note the \r which means the line should overwrite itself. 38 | msg = "\r- Download progress: {0:.1%}".format(pct_complete) 39 | 40 | # Print it. 41 | sys.stdout.write(msg) 42 | sys.stdout.flush() 43 | 44 | 45 | ######################################################################## 46 | 47 | 48 | def maybe_download_and_extract(url, download_dir): 49 | """ 50 | Download and extract the data if it doesn't already exist. 51 | Assumes the url is a tar-ball file. 52 | 53 | :param url: 54 | Internet URL for the tar-file to download. 55 | Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 56 | 57 | :param download_dir: 58 | Directory where the downloaded file is saved. 59 | Example: "data/CIFAR-10/" 60 | 61 | :return: 62 | Nothing. 63 | """ 64 | 65 | # Filename for saving the file downloaded from the internet. 66 | # Use the filename from the URL and add it to the download_dir. 67 | filename = url.split('/')[-1] 68 | file_path = os.path.join(download_dir, filename) 69 | 70 | # Check if the file already exists. 71 | # If it exists then we assume it has also been extracted, 72 | # otherwise we need to download and extract it now. 73 | if not os.path.exists(file_path): 74 | # Check if the download directory exists, otherwise create it. 75 | if not os.path.exists(download_dir): 76 | os.makedirs(download_dir) 77 | 78 | # Download the file from the internet. 79 | file_path, _ = urllib.request.urlretrieve(url=url, 80 | filename=file_path, 81 | reporthook=_print_download_progress) 82 | 83 | print() 84 | print("Download finished. Extracting files.") 85 | 86 | if file_path.endswith(".zip"): 87 | # Unpack the zip-file. 88 | zipfile.ZipFile(file=file_path, mode="r").extractall(download_dir) 89 | elif file_path.endswith((".tar.gz", ".tgz")): 90 | # Unpack the tar-ball. 
91 |             tarfile.open(name=file_path, mode="r:gz").extractall(download_dir)
92 | 
93 |         print("Done.")
94 |     else:
95 |         print("Data has apparently already been downloaded and unpacked.")
96 | 
97 | 
98 | ########################################################################
99 | 
--------------------------------------------------------------------------------
/images/02_convolution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/02_convolution.png
--------------------------------------------------------------------------------
/images/02_convolution.svg:
--------------------------------------------------------------------------------
[SVG source omitted; vector version of 02_convolution.png with the labels "Input Image with Filter Overlaid (4 copies for clarity)" and "Result of Convolution".]
--------------------------------------------------------------------------------
/images/02_network_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/02_network_flowchart.png
--------------------------------------------------------------------------------
/images/06_network_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/06_network_flowchart.png
--------------------------------------------------------------------------------
/images/07_inception_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/07_inception_flowchart.png
--------------------------------------------------------------------------------
/images/08_transfer_learning_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/08_transfer_learning_flowchart.png
--------------------------------------------------------------------------------
/images/09_transfer_learning_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/09_transfer_learning_flowchart.png
--------------------------------------------------------------------------------
/images/11_adversarial_examples_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/11_adversarial_examples_flowchart.png
--------------------------------------------------------------------------------
/images/12_adversarial_noise_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/12_adversarial_noise_flowchart.png
--------------------------------------------------------------------------------
/images/12_adversarial_noise_flowchart.svg:
--------------------------------------------------------------------------------
[SVG source omitted; vector version of 12_adversarial_noise_flowchart.png. Flowchart labels: "Convolution 5x5 filters 16 outputs", "Convolution 5x5 filters 36 outputs", "Fully-Connected 128 features", "Fully-Connected Part of SoftMax/Loss 10 features", "SoftMax Classifier", "Loss Function Cross Entropy", "L2-Loss Regularization", "Combined Loss", "Use gradient of Combined Loss to update adversarial noise", "Use gradient of Loss Function to update network variables", "True Class 7", "Adversarial Class 3", "Predicted Class".]
--------------------------------------------------------------------------------
/images/13_visual_analysis_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/13_visual_analysis_flowchart.png
--------------------------------------------------------------------------------
/images/14_deepdream_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/14_deepdream_flowchart.png
--------------------------------------------------------------------------------
/images/14_deepdream_recursive_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/14_deepdream_recursive_flowchart.png
--------------------------------------------------------------------------------
/images/15_style_transfer_flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/15_style_transfer_flowchart.png
--------------------------------------------------------------------------------
/images/elon_musk.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/elon_musk.jpg
--------------------------------------------------------------------------------
/images/elon_musk_100x100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/elon_musk_100x100.jpg
--------------------------------------------------------------------------------
/images/escher_planefilling2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/escher_planefilling2.jpg
--------------------------------------------------------------------------------
/images/giger.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/giger.jpg
--------------------------------------------------------------------------------
/images/hulk.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/hulk.jpg
--------------------------------------------------------------------------------
/images/parrot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/parrot.jpg
--------------------------------------------------------------------------------
/images/parrot_cropped1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/parrot_cropped1.jpg
--------------------------------------------------------------------------------
/images/parrot_cropped2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/parrot_cropped2.jpg
--------------------------------------------------------------------------------
/images/parrot_cropped3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/parrot_cropped3.jpg
--------------------------------------------------------------------------------
/images/parrot_padded.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/parrot_padded.jpg
--------------------------------------------------------------------------------
/images/style1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style1.jpg
--------------------------------------------------------------------------------
/images/style2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style2.jpg
--------------------------------------------------------------------------------
/images/style3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style3.jpg
--------------------------------------------------------------------------------
/images/style4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style4.jpg
--------------------------------------------------------------------------------
/images/style5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style5.jpg
--------------------------------------------------------------------------------
/images/style6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style6.jpg
--------------------------------------------------------------------------------
/images/style7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style7.jpg
--------------------------------------------------------------------------------
/images/style8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style8.jpg
--------------------------------------------------------------------------------
/images/style9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/style9.jpg
--------------------------------------------------------------------------------
/images/willy_wonka_new.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/willy_wonka_new.jpg
--------------------------------------------------------------------------------
/images/willy_wonka_old.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llSourcell/How_to_do_style_transfer_in_tensorflow/29cbea8211f06dccf37789e6beddceef7c5afca4/images/willy_wonka_old.jpg
--------------------------------------------------------------------------------
/vgg16.py:
--------------------------------------------------------------------------------
1 | # The pre-trained VGG16 Model for TensorFlow.
2 | #
3 | # This model seems to produce better-looking images in Style Transfer
4 | # than the Inception 5h model that otherwise works well for DeepDream.
5 | #
6 | # Implemented in Python 3.5 with TensorFlow v0.12.0rc1
7 | 
8 | import numpy as np
9 | import tensorflow as tf
10 | import download
11 | import os
12 | 
13 | # The pre-trained VGG16 model is taken from this tutorial:
14 | # https://github.com/pkmital/CADL/blob/master/session-4/libs/vgg16.py
15 | 
16 | # The class-names are available at the following URL:
17 | # https://s3.amazonaws.com/cadl/models/synset.txt
18 | 
19 | # Internet URL for the file with the VGG16 model.
20 | # Note that this might change in the future and will need to be updated.
21 | data_url = "https://s3.amazonaws.com/cadl/models/vgg16.tfmodel"
22 | 
23 | # Directory to store the downloaded data.
24 | data_dir = "vgg16/"
25 | 
26 | # File containing the TensorFlow graph definition. (Downloaded)
27 | path_graph_def = "vgg16.tfmodel"
28 | 
29 | 
30 | 
31 | def maybe_download():
32 |     """
33 |     Download the VGG16 model from the internet if it does not already
34 |     exist in the data_dir. The file is about 550 MB.
35 |     """
36 | 
37 |     print("Downloading VGG16 Model ...")
38 | 
39 |     # The file on the internet is not stored in a compressed format.
40 |     # This function should not extract the file when it does not have
41 |     # a relevant filename-extension such as .zip or .tar.gz.
42 |     download.maybe_download_and_extract(url=data_url, download_dir=data_dir)
43 | 
44 | 
45 | class VGG16:
46 |     """
47 |     The VGG16 model is a Deep Neural Network which has already been
48 |     trained for classifying images into 1000 different categories.
49 | 
50 |     When you create a new instance of this class, the VGG16 model
51 |     will be loaded and can be used immediately without training.
52 |     """
53 | 
54 |     # Name of the tensor for feeding the input image.
55 |     tensor_name_input_image = "images:0"
56 | 
57 |     # Names of the tensors for the dropout random-values.
58 |     tensor_name_dropout = 'dropout/random_uniform:0'
59 |     tensor_name_dropout1 = 'dropout_1/random_uniform:0'
60 | 
61 |     # Names for the convolutional layers in the model for use in Style Transfer.
62 |     layer_names = ['conv1_1/conv1_1', 'conv1_2/conv1_2',
63 |                    'conv2_1/conv2_1', 'conv2_2/conv2_2',
64 |                    'conv3_1/conv3_1', 'conv3_2/conv3_2', 'conv3_3/conv3_3',
65 |                    'conv4_1/conv4_1', 'conv4_2/conv4_2', 'conv4_3/conv4_3',
66 |                    'conv5_1/conv5_1', 'conv5_2/conv5_2', 'conv5_3/conv5_3']
67 | 
68 |     def __init__(self):
69 |         # Now load the model from file. The way TensorFlow
70 |         # does this is confusing and requires several steps.
71 | 
72 |         # Create a new TensorFlow computational graph.
73 |         self.graph = tf.Graph()
74 | 
75 |         # Set the new graph as the default.
76 |         with self.graph.as_default():
77 | 
78 |             # TensorFlow graphs are saved to disk as so-called Protocol Buffers
79 |             # aka. proto-bufs which is a file-format that works on multiple
80 |             # platforms. In this case it is saved as a binary file.
81 | 
82 |             # Open the graph-def file for binary reading.
83 |             path = os.path.join(data_dir, path_graph_def)
84 |             with tf.gfile.FastGFile(path, 'rb') as file:
85 |                 # The graph-def is a saved copy of a TensorFlow graph.
86 |                 # First we need to create an empty graph-def.
87 |                 graph_def = tf.GraphDef()
88 | 
89 |                 # Then we load the proto-buf file into the graph-def.
90 |                 graph_def.ParseFromString(file.read())
91 | 
92 |                 # Finally we import the graph-def to the default TensorFlow graph.
93 |                 tf.import_graph_def(graph_def, name='')
94 | 
95 |             # Now self.graph holds the VGG16 model from the proto-buf file.
96 | 
97 |             # Get a reference to the tensor for inputting images to the graph.
98 |             self.input = self.graph.get_tensor_by_name(self.tensor_name_input_image)
99 | 
100 |             # Get references to the tensors for the commonly used layers.
101 |             self.layer_tensors = [self.graph.get_tensor_by_name(name + ":0") for name in self.layer_names]
102 | 
103 |     def get_layer_tensors(self, layer_ids):
104 |         """
105 |         Return a list of references to the tensors for the layers with the given id's.
106 |         """
107 | 
108 |         return [self.layer_tensors[idx] for idx in layer_ids]
109 | 
110 |     def get_layer_names(self, layer_ids):
111 |         """
112 |         Return a list of names for the layers with the given id's.
113 |         """
114 | 
115 |         return [self.layer_names[idx] for idx in layer_ids]
116 | 
117 |     def get_all_layer_names(self, startswith=None):
118 |         """
119 |         Return a list of all the layers (operations) in the graph.
120 |         The list can be filtered for names that start with the given string.
121 |         """
122 | 
123 |         # Get a list of the names for all layers (operations) in the graph.
124 |         names = [op.name for op in self.graph.get_operations()]
125 | 
126 |         # Filter the list of names so we only get those starting with
127 |         # the given string.
128 |         if startswith is not None:
129 |             names = [name for name in names if name.startswith(startswith)]
130 | 
131 |         return names
132 | 
133 |     def create_feed_dict(self, image):
134 |         """
135 |         Create and return a feed-dict with an image.
136 | 
137 |         :param image:
138 |             The input image is a 3-dim array which is already decoded.
139 |             The pixels MUST be values between 0 and 255 (float or int).
140 | 
141 |         :return:
142 |             Dict for feeding to the graph in TensorFlow.
143 |         """
144 | 
145 |         # Expand 3-dim array to 4-dim by prepending an 'empty' dimension.
146 |         # This is because we are only feeding a single image, but the
147 |         # VGG16 model was built to take multiple images as input.
148 |         image = np.expand_dims(image, axis=0)
149 | 
150 |         if False:
151 |             # In the original code using this VGG16 model, the random values
152 |             # for the dropout were fixed to 1.0.
153 |             # Experiments suggest this does not matter for Style Transfer,
154 |             # and feeding them causes an error on a GPU, so it is disabled.
155 |             dropout_fix = 1.0
156 | 
157 |             # Create feed-dict for inputting data to TensorFlow.
158 |             feed_dict = {self.tensor_name_input_image: image,
159 |                          self.tensor_name_dropout: [[dropout_fix]],
160 |                          self.tensor_name_dropout1: [[dropout_fix]]}
161 |         else:
162 |             # Create feed-dict for inputting data to TensorFlow.
163 |             feed_dict = {self.tensor_name_input_image: image}
164 | 
165 |         return feed_dict
166 | 
--------------------------------------------------------------------------------
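`Style_Transfer.ipynb` is what ties these two modules together. As a rough sketch of how they are meant to be used (this is not code from the repository; the `layer_ids` values and the dummy input image are illustrative, and a TensorFlow 0.12/1.x-era API such as `tf.Session` is assumed), downloading the model and inspecting a few layer activations might look like this:

```python
import numpy as np
import tensorflow as tf

import vgg16

# Fetch vgg16.tfmodel (about 550 MB) into the "vgg16/" directory if it is missing.
vgg16.maybe_download()

# Load the pre-trained graph; no training is required.
model = vgg16.VGG16()

# Pick a few of the 13 convolutional layers by index (0-12); these ids are only an example.
layer_ids = [4, 7, 11]
print(model.get_layer_names(layer_ids))

# A dummy 100x100 RGB image with pixel values in [0, 255], as create_feed_dict() expects.
image = np.random.uniform(0.0, 255.0, size=(100, 100, 3)).astype(np.float32)
feed_dict = model.create_feed_dict(image=image)

# Run the chosen layers and print the shape of each activation tensor.
with tf.Session(graph=model.graph) as session:
    layer_tensors = model.get_layer_tensors(layer_ids)
    activations = session.run(layer_tensors, feed_dict=feed_dict)
    for name, act in zip(model.get_layer_names(layer_ids), activations):
        print(name, act.shape)
```

Style transfer itself, as described in the paper linked in the README, then amounts to optimizing a mixed image so that its activations on selected layers match the content image while the Gram matrices of its activations on other layers match the style image.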