├── Images └── pebbles.jpg ├── DeepImageSynthesis ├── __init__.py ├── ImageSyn.py ├── LossFunctions.py └── Misc.py ├── README.md └── Models └── VGG_ave_pool_deploy.prototxt /Images/pebbles.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leongatys/DeepTextures/HEAD/Images/pebbles.jpg -------------------------------------------------------------------------------- /DeepImageSynthesis/__init__.py: -------------------------------------------------------------------------------- 1 | import LossFunctions 2 | from ImageSyn import ImageSyn 3 | from .Misc import * 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepTextures 2 | Code to synthesise textures using convolutional neural networks as described in the paper "Texture Synthesis Using Convolutional Neural Networks" (Gatys et al., NIPS 2015) (http://arxiv.org/abs/1505.07376). 3 | More examples of synthesised textures can be found at http://bethgelab.org/deeptextures/. 4 | 5 | The IPythonNotebook Example.ipynb contains the code to synthesise the pebble texture shown in Figure 3A (177k parameters) of the revised version of the paper. In the notebook I additionally match the pixel histograms in each colorchannel of the synthesised and original texture, which is not done in the figures in the paper. 6 | #Prerequisites 7 | * To run the code you need a recent version of the [Caffe](https://github.com/BVLC/caffe) deep learning framework and its dependencies (tested with master branch at commit 20c474fe40fe43dee68545dc80809f30ccdbf99b). 8 | * The images in the paper were generated using a normalised version of the [19-layer VGG-Network](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) 9 | described in the work by [Simonyan and Zisserman](http://arxiv.org/abs/1409.1556). The weights in the normalised network are scaled 10 | such that the mean activation of each filter over images and positions is equal to 1. 11 | **The normalised network can be downloaded [here](http://bethgelab.org/media/uploads/deeptextures/vgg_normalised.caffemodel) and has to be copied into the Models/ folder.** 12 | 13 | # Disclaimer 14 | This software is published for academic and non-commercial use only. 15 | -------------------------------------------------------------------------------- /DeepImageSynthesis/ImageSyn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from scipy.optimize import minimize 4 | from Misc import * 5 | 6 | def ImageSyn(net, constraints, init=None, bounds=None, callback=None, minimize_options=None, gradient_free_region=None): 7 | ''' 8 | This function generates the image by performing gradient descent on the pixels to match the constraints. 9 | 10 | :param net: caffe.Classifier object that defines the network used to generate the image 11 | :param constraints: dictionary object that contains the constraints on each layer used for the image generation 12 | :param init: the initial image to start the gradient descent from. Defaults to gaussian white noise 13 | :param bounds: the optimisation bounds passed to the optimiser 14 | :param callback: the callback function passed to the optimiser 15 | :param minimize_options: the options passed to the optimiser 16 | :param gradient_free_region: a binary mask that defines all pixels that should be ignored in the in the gradient descent 17 | :return: result object from the L-BFGS optimisation 18 | ''' 19 | 20 | if init==None: 21 | init = np.random.randn(*net.blobs['data'].data.shape) 22 | 23 | #get indices for gradient 24 | layers, indices = get_indices(net, constraints) 25 | 26 | #function to minimise 27 | def f(x): 28 | x = x.reshape(*net.blobs['data'].data.shape) 29 | net.forward(data=x, end=layers[min(len(layers)-1, indices[0]+1)]) 30 | f_val = 0 31 | #clear gradient in all layers 32 | for index in indices: 33 | net.blobs[layers[index]].diff[...] = np.zeros_like(net.blobs[layers[index]].diff) 34 | 35 | for i,index in enumerate(indices): 36 | layer = layers[index] 37 | for l,loss_function in enumerate(constraints[layer].loss_functions): 38 | constraints[layer].parameter_lists[l].update({'activations': net.blobs[layer].data.copy()}) 39 | val, grad = loss_function(**constraints[layer].parameter_lists[l]) 40 | f_val += val 41 | net.blobs[layer].diff[:] += grad 42 | #gradient wrt inactive units is 0 43 | net.blobs[layer].diff[(net.blobs[layer].data == 0)] = 0. 44 | if index == indices[-1]: 45 | f_grad = net.backward(start=layer)['data'].copy() 46 | else: 47 | net.backward(start=layer, end=layers[indices[i+1]]) 48 | 49 | if gradient_free_region!=None: 50 | f_grad[gradient_free_region==1] = 0 51 | 52 | return [f_val, np.array(f_grad.ravel(), dtype=float)] 53 | 54 | result = minimize(f, init, 55 | method='L-BFGS-B', 56 | jac=True, 57 | bounds=bounds, 58 | callback=callback, 59 | options=minimize_options) 60 | return result 61 | 62 | 63 | -------------------------------------------------------------------------------- /DeepImageSynthesis/LossFunctions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | 4 | def gram_mse_loss(activations, target_gram_matrix, weight=1., linear_transform=None): 5 | ''' 6 | This function computes an elementwise mean squared distance between the gram matrices of the source and the generated image. 7 | 8 | :param activations: the network activations in response to the image that is generated 9 | :param target_gram_matrix: gram matrix in response to the source image 10 | :param weight: scaling factor for the loss function 11 | :param linear_transform: linear transform that is applied to the feature vector at all positions before gram matrix computation 12 | :return: mean squared distance between normalised gram matrices and gradient wrt activations 13 | ''' 14 | 15 | N = activations.shape[1] 16 | fm_size = np.array(activations.shape[2:]) 17 | M = np.prod(fm_size) 18 | G_target = target_gram_matrix 19 | if linear_transform == None: 20 | F = activations.reshape(N,-1) 21 | G = np.dot(F,F.T) / M 22 | loss = float(weight)/4 * ((G - G_target)**2).sum() / N**2 23 | gradient = (weight * np.dot(F.T, (G - G_target)).T / (M * N**2)).reshape(1, N, fm_size[0], fm_size[1]) 24 | else: 25 | F = np.dot(linear_transform, activations.reshape(N,-1)) 26 | G = np.dot(F,F.T) / M 27 | loss = float(weight)/4 * ((G - G_target)**2).sum() / N**2 28 | gradient = (weight * np.dot(linear_transform.T, np.dot(F.T, (G - G_target)).T) / (M * N**2)).reshape(1, N, fm_size[0], fm_size[1]) 29 | 30 | return [loss, gradient] 31 | 32 | def meanfm_mse_loss(activations, target_activations, weight=1., linear_transform=None): 33 | ''' 34 | This function computes an elementwise mean squared distance between the mean feature maps of the source and the generated image. 35 | 36 | :param activations: the network activations in response to the image that is generated 37 | :param target_activations: the network activations in response to the source image 38 | :param weight: scaling factor for the loss function 39 | :param linear_transform: linear transform that is applied to the feature vector at all positions before gram matrix computation 40 | :return: mean squared distance between mean feature maps and gradient wrt activations 41 | ''' 42 | 43 | N = activations.shape[1] 44 | fm_size = np.array(activations.shape[2:]) 45 | M = np.prod(fm_size) 46 | 47 | target_fm_size = np.array(target_activations.shape[2:]) 48 | M_target = np.prod(target_fm_size) 49 | if linear_transform==None: 50 | target_mean_fm = target_activations.reshape(N,-1).sum(1) / M_target 51 | mean_fm = activations.reshape(N,-1).sum(1) / M 52 | f_val = float(weight)/2 * ((mean_fm - target_mean_fm)**2).sum() / N 53 | f_grad = weight * (np.tile((mean_fm - target_mean_fm)[:,None],(1,M)) / (M * N)).reshape(1,N,fm_size[0],fm_size[1]) 54 | else: 55 | target_mean_fm = np.dot(linear_transform, target_activations.reshape(N,-1)).sum(1) / M_target 56 | mean_fm = np.dot(linear_transform, activations.reshape(N,-1)).sum(1) / M 57 | f_val = float(weight)/2 * ((mean_fm - target_mean_fm)**2).sum() / N 58 | f_grad = weight * (np.dot(linear_transform.T ,np.tile((mean_fm - target_mean_fm)[:,None],(1,M))) / (M * N)).reshape(1,N,fm_size[0],fm_size[1]) 59 | return [f_val,f_grad] 60 | 61 | -------------------------------------------------------------------------------- /Models/VGG_ave_pool_deploy.prototxt: -------------------------------------------------------------------------------- 1 | name: "VGG_ILSVRC_19_layers" 2 | input: "data" 3 | input_dim: 1 4 | input_dim: 3 5 | input_dim: 256 6 | input_dim: 256 7 | force_backward: true 8 | layers { 9 | bottom: "data" 10 | top: "conv1_1" 11 | name: "conv1_1" 12 | type: CONVOLUTION 13 | convolution_param { 14 | num_output: 64 15 | pad: 1 16 | kernel_size: 3 17 | } 18 | } 19 | layers { 20 | bottom: "conv1_1" 21 | top: "conv1_1" 22 | name: "relu1_1" 23 | type: RELU 24 | } 25 | layers { 26 | bottom: "conv1_1" 27 | top: "conv1_2" 28 | name: "conv1_2" 29 | type: CONVOLUTION 30 | convolution_param { 31 | num_output: 64 32 | pad: 1 33 | kernel_size: 3 34 | } 35 | } 36 | layers { 37 | bottom: "conv1_2" 38 | top: "conv1_2" 39 | name: "relu1_2" 40 | type: RELU 41 | } 42 | layers { 43 | bottom: "conv1_2" 44 | top: "pool1" 45 | name: "pool1" 46 | type: POOLING 47 | pooling_param { 48 | pool: AVE 49 | kernel_size: 2 50 | stride: 2 51 | } 52 | } 53 | layers { 54 | bottom: "pool1" 55 | top: "conv2_1" 56 | name: "conv2_1" 57 | type: CONVOLUTION 58 | convolution_param { 59 | num_output: 128 60 | pad: 1 61 | kernel_size: 3 62 | } 63 | } 64 | layers { 65 | bottom: "conv2_1" 66 | top: "conv2_1" 67 | name: "relu2_1" 68 | type: RELU 69 | } 70 | layers { 71 | bottom: "conv2_1" 72 | top: "conv2_2" 73 | name: "conv2_2" 74 | type: CONVOLUTION 75 | convolution_param { 76 | num_output: 128 77 | pad: 1 78 | kernel_size: 3 79 | } 80 | } 81 | layers { 82 | bottom: "conv2_2" 83 | top: "conv2_2" 84 | name: "relu2_2" 85 | type: RELU 86 | } 87 | layers { 88 | bottom: "conv2_2" 89 | top: "pool2" 90 | name: "pool2" 91 | type: POOLING 92 | pooling_param { 93 | pool: AVE 94 | kernel_size: 2 95 | stride: 2 96 | } 97 | } 98 | layers { 99 | bottom: "pool2" 100 | top: "conv3_1" 101 | name: "conv3_1" 102 | type: CONVOLUTION 103 | convolution_param { 104 | num_output: 256 105 | pad: 1 106 | kernel_size: 3 107 | } 108 | } 109 | layers { 110 | bottom: "conv3_1" 111 | top: "conv3_1" 112 | name: "relu3_1" 113 | type: RELU 114 | } 115 | layers { 116 | bottom: "conv3_1" 117 | top: "conv3_2" 118 | name: "conv3_2" 119 | type: CONVOLUTION 120 | convolution_param { 121 | num_output: 256 122 | pad: 1 123 | kernel_size: 3 124 | } 125 | } 126 | layers { 127 | bottom: "conv3_2" 128 | top: "conv3_2" 129 | name: "relu3_2" 130 | type: RELU 131 | } 132 | layers { 133 | bottom: "conv3_2" 134 | top: "conv3_3" 135 | name: "conv3_3" 136 | type: CONVOLUTION 137 | convolution_param { 138 | num_output: 256 139 | pad: 1 140 | kernel_size: 3 141 | } 142 | } 143 | layers { 144 | bottom: "conv3_3" 145 | top: "conv3_3" 146 | name: "relu3_3" 147 | type: RELU 148 | } 149 | layers { 150 | bottom: "conv3_3" 151 | top: "conv3_4" 152 | name: "conv3_4" 153 | type: CONVOLUTION 154 | convolution_param { 155 | num_output: 256 156 | pad: 1 157 | kernel_size: 3 158 | } 159 | } 160 | layers { 161 | bottom: "conv3_4" 162 | top: "conv3_4" 163 | name: "relu3_4" 164 | type: RELU 165 | } 166 | layers { 167 | bottom: "conv3_4" 168 | top: "pool3" 169 | name: "pool3" 170 | type: POOLING 171 | pooling_param { 172 | pool: AVE 173 | kernel_size: 2 174 | stride: 2 175 | } 176 | } 177 | layers { 178 | bottom: "pool3" 179 | top: "conv4_1" 180 | name: "conv4_1" 181 | type: CONVOLUTION 182 | convolution_param { 183 | num_output: 512 184 | pad: 1 185 | kernel_size: 3 186 | } 187 | } 188 | layers { 189 | bottom: "conv4_1" 190 | top: "conv4_1" 191 | name: "relu4_1" 192 | type: RELU 193 | } 194 | layers { 195 | bottom: "conv4_1" 196 | top: "conv4_2" 197 | name: "conv4_2" 198 | type: CONVOLUTION 199 | convolution_param { 200 | num_output: 512 201 | pad: 1 202 | kernel_size: 3 203 | } 204 | } 205 | layers { 206 | bottom: "conv4_2" 207 | top: "conv4_2" 208 | name: "relu4_2" 209 | type: RELU 210 | } 211 | layers { 212 | bottom: "conv4_2" 213 | top: "conv4_3" 214 | name: "conv4_3" 215 | type: CONVOLUTION 216 | convolution_param { 217 | num_output: 512 218 | pad: 1 219 | kernel_size: 3 220 | } 221 | } 222 | layers { 223 | bottom: "conv4_3" 224 | top: "conv4_3" 225 | name: "relu4_3" 226 | type: RELU 227 | } 228 | layers { 229 | bottom: "conv4_3" 230 | top: "conv4_4" 231 | name: "conv4_4" 232 | type: CONVOLUTION 233 | convolution_param { 234 | num_output: 512 235 | pad: 1 236 | kernel_size: 3 237 | } 238 | } 239 | layers { 240 | bottom: "conv4_4" 241 | top: "conv4_4" 242 | name: "relu4_4" 243 | type: RELU 244 | } 245 | layers { 246 | bottom: "conv4_4" 247 | top: "pool4" 248 | name: "pool4" 249 | type: POOLING 250 | pooling_param { 251 | pool: AVE 252 | kernel_size: 2 253 | stride: 2 254 | } 255 | } 256 | layers { 257 | bottom: "pool4" 258 | top: "conv5_1" 259 | name: "conv5_1" 260 | type: CONVOLUTION 261 | convolution_param { 262 | num_output: 512 263 | pad: 1 264 | kernel_size: 3 265 | } 266 | } 267 | layers { 268 | bottom: "conv5_1" 269 | top: "conv5_1" 270 | name: "relu5_1" 271 | type: RELU 272 | } 273 | layers { 274 | bottom: "conv5_1" 275 | top: "conv5_2" 276 | name: "conv5_2" 277 | type: CONVOLUTION 278 | convolution_param { 279 | num_output: 512 280 | pad: 1 281 | kernel_size: 3 282 | } 283 | } 284 | layers { 285 | bottom: "conv5_2" 286 | top: "conv5_2" 287 | name: "relu5_2" 288 | type: RELU 289 | } 290 | layers { 291 | bottom: "conv5_2" 292 | top: "conv5_3" 293 | name: "conv5_3" 294 | type: CONVOLUTION 295 | convolution_param { 296 | num_output: 512 297 | pad: 1 298 | kernel_size: 3 299 | } 300 | } 301 | layers { 302 | bottom: "conv5_3" 303 | top: "conv5_3" 304 | name: "relu5_3" 305 | type: RELU 306 | } 307 | layers { 308 | bottom: "conv5_3" 309 | top: "conv5_4" 310 | name: "conv5_4" 311 | type: CONVOLUTION 312 | convolution_param { 313 | num_output: 512 314 | pad: 1 315 | kernel_size: 3 316 | } 317 | } 318 | layers { 319 | bottom: "conv5_4" 320 | top: "conv5_4" 321 | name: "relu5_4" 322 | type: RELU 323 | } 324 | layers { 325 | bottom: "conv5_4" 326 | top: "pool5" 327 | name: "pool5" 328 | type: POOLING 329 | pooling_param { 330 | pool: AVE 331 | kernel_size: 2 332 | stride: 2 333 | } 334 | } 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | -------------------------------------------------------------------------------- /DeepImageSynthesis/Misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import caffe 4 | import matplotlib.pyplot as plt 5 | from IPython.display import display,clear_output 6 | 7 | class constraint(object): 8 | ''' 9 | Object that contains the constraints on a particular layer for the image synthesis. 10 | ''' 11 | 12 | def __init__(self, loss_functions, parameter_lists): 13 | self.loss_functions = loss_functions 14 | self.parameter_lists = parameter_lists 15 | 16 | def get_indices(net, constraints): 17 | ''' 18 | Helper function to pick the indices of the layers included in the loss function from all layers of the network. 19 | 20 | :param net: caffe.Classifier object defining the network 21 | :param contraints: dictionary where each key is a layer and the corresponding entry is a constraint object 22 | :return: list of layers in the network and list of indices of the loss layers in descending order 23 | ''' 24 | 25 | indices = [ndx for ndx,layer in enumerate(net.blobs.keys()) if layer in constraints.keys()] 26 | return net.blobs.keys(),indices[::-1] 27 | 28 | def show_progress(x, net, title=None, handle=False): 29 | ''' 30 | Helper function to show intermediate results during the gradient descent. 31 | 32 | :param x: vectorised image on which the gradient descent is performed 33 | :param net: caffe.Classifier object defining the network 34 | :param title: optional title of figuer 35 | :param handle: obtional return of figure handle 36 | :return: figure handle (optional) 37 | ''' 38 | 39 | disp_image = (x.reshape(*net.blobs['data'].data.shape)[0].transpose(1,2,0)[:,:,::-1]-x.min())/(x.max()-x.min()) 40 | clear_output() 41 | plt.imshow(disp_image) 42 | if title != None: 43 | ax = plt.gca() 44 | ax.set_title(title) 45 | f = plt.gcf() 46 | display() 47 | plt.show() 48 | if handle: 49 | return f 50 | 51 | def get_bounds(images, im_size): 52 | ''' 53 | Helper function to get optimisation bounds from source image. 54 | 55 | :param images: a list of images 56 | :param im_size: image size (height, width) for the generated image 57 | :return: list of bounds on each pixel for the optimisation 58 | ''' 59 | 60 | lowerbound = np.min([im.min() for im in images]) 61 | upperbound = np.max([im.max() for im in images]) 62 | bounds = list() 63 | for b in range(im_size[0]*im_size[1] * 3): 64 | bounds.append((lowerbound,upperbound)) 65 | return bounds 66 | 67 | def test_gradient(function, parameters, eps=1e-6): 68 | ''' 69 | Simple gradient test for any loss function defined on layer output 70 | 71 | :param function: function to be tested, must return function value and gradient 72 | :param parameters: input arguments to function passed as keyword arguments 73 | :param eps: step size for numerical gradient evaluation 74 | :return: numerical gradient and gradient from function 75 | ''' 76 | 77 | i,j,k,l = [np.random.randint(s) for s in parameters['activations'].shape] 78 | f1,_ = function(**parameters) 79 | parameters['activations'][i,j,k,l] += eps 80 | f2,g = function(**parameters) 81 | 82 | return [(f2-f1)/eps,g[i,j,k,l]] 83 | 84 | def gram_matrix(activations): 85 | ''' 86 | Gives the gram matrix for feature map activations in caffe format with batchsize 1. Normalises by spatial dimensions. 87 | 88 | :param activations: feature map activations to compute gram matrix from 89 | :return: normalised gram matrix 90 | ''' 91 | 92 | N = activations.shape[1] 93 | F = activations.reshape(N,-1) 94 | M = F.shape[1] 95 | G = np.dot(F,F.T) / M 96 | return G 97 | 98 | def disp_img(img): 99 | ''' 100 | Returns rescaled image for display with imshow 101 | ''' 102 | disp_img = (img - img.min())/(img.max()-img.min()) 103 | return disp_img 104 | 105 | def uniform_hist(X): 106 | ''' 107 | Maps data distribution onto uniform histogram 108 | 109 | :param X: data vector 110 | :return: data vector with uniform histogram 111 | ''' 112 | 113 | Z = [(x, i) for i, x in enumerate(X)] 114 | Z.sort() 115 | n = len(Z) 116 | Rx = [0]*n 117 | start = 0 # starting mark 118 | for i in range(1, n): 119 | if Z[i][0] != Z[i-1][0]: 120 | for j in range(start, i): 121 | Rx[Z[j][1]] = float(start+1+i)/2.0; 122 | start = i 123 | for j in range(start, n): 124 | Rx[Z[j][1]] = float(start+1+n)/2.0; 125 | return np.asarray(Rx) / float(len(Rx)) 126 | 127 | def histogram_matching(org_image, match_image, grey=False, n_bins=100): 128 | ''' 129 | Matches histogram of each color channel of org_image with histogram of match_image 130 | 131 | :param org_image: image whose distribution should be remapped 132 | :param match_image: image whose distribution should be matched 133 | :param grey: True if images are greyscale 134 | :param n_bins: number of bins used for histogram calculation 135 | :return: org_image with same histogram as match_image 136 | ''' 137 | 138 | if grey: 139 | hist, bin_edges = np.histogram(match_image.ravel(), bins=n_bins, density=True) 140 | cum_values = np.zeros(bin_edges.shape) 141 | cum_values[1:] = np.cumsum(hist*np.diff(bin_edges)) 142 | inv_cdf = scipy.interpolate.interp1d(cum_values, bin_edges,bounds_error=True) 143 | r = np.asarray(uniform_hist(org_image.ravel())) 144 | r[r>cum_values.max()] = cum_values.max() 145 | matched_image = inv_cdf(r).reshape(org_image.shape) 146 | else: 147 | matched_image = np.zeros_like(org_image) 148 | for i in range(3): 149 | hist, bin_edges = np.histogram(match_image[:,:,i].ravel(), bins=n_bins, density=True) 150 | cum_values = np.zeros(bin_edges.shape) 151 | cum_values[1:] = np.cumsum(hist*np.diff(bin_edges)) 152 | inv_cdf = scipy.interpolate.interp1d(cum_values, bin_edges,bounds_error=True) 153 | r = np.asarray(uniform_hist(org_image[:,:,i].ravel())) 154 | r[r>cum_values.max()] = cum_values.max() 155 | matched_image[:,:,i] = inv_cdf(r).reshape(org_image[:,:,i].shape) 156 | 157 | return matched_image 158 | 159 | def load_image(file_name, im_size, net_model, net_weights, mean, show_img=False): 160 | ''' 161 | Loads and preprocesses image into caffe format by constructing and using the appropriate network. 162 | 163 | :param file_name: file name of the image to be loaded 164 | :param im_size: size of the image after preprocessing if float that the original image is rescaled to contain im_size**2 pixels 165 | :param net_model: file name of the prototxt file defining the network model 166 | :param net_weights: file name of caffemodel file defining the network weights 167 | :param mean: mean values for each color channel (bgr) which are subtracted during preprocessing 168 | :param show_img: if True shows the loaded image before preprocessing 169 | :return: preprocessed image and caffe.Classifier object defining the network 170 | ''' 171 | 172 | img = caffe.io.load_image(file_name) 173 | if show_img: 174 | plt.imshow(img) 175 | if isinstance(im_size,float): 176 | im_scale = np.sqrt(im_size**2 /np.prod(np.asarray(img.shape[:2]))) 177 | im_size = im_scale * np.asarray(img.shape[:2]) 178 | batchSize = 1 179 | with open(net_model,'r+') as f: 180 | data = f.readlines() 181 | data[2] = "input_dim: %i\n" %(batchSize) 182 | data[4] = "input_dim: %i\n" %(im_size[0]) 183 | data[5] = "input_dim: %i\n" %(im_size[1]) 184 | with open(net_model,'r+') as f: 185 | f.writelines(data) 186 | net_mean = np.tile(mean[:,None,None],(1,) + tuple(im_size.astype(int))) 187 | #load pretrained network 188 | net = caffe.Classifier( 189 | net_model, net_weights, 190 | mean = net_mean, 191 | channel_swap=(2,1,0), 192 | input_scale=255,) 193 | img_pp = net.transformer.preprocess('data',img)[None,:] 194 | return[img_pp, net] 195 | --------------------------------------------------------------------------------