├── .gitignore
├── README.md
├── model.py
├── finetune.py
├── log
│   ├── finetune.log
│   └── scratch.log
├── dataset.py
├── network.py
└── assemble_data.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.npy
3 | *.swp
4 | .idea
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # tensorflow-finetune-flickr-style
2 | 
3 | In this project, we use the Flickr Style dataset to demonstrate fine-tuning in TensorFlow.
4 | For details, please refer to the example on the [Caffe website](http://caffe.berkeleyvision.org/gathered/examples/finetune_flickr_style.html).
5 | 
6 | Thanks to @ethereon and @sergeyk for their code. We modified network.py from [caffe-tensorflow](https://github.com/ethereon/caffe-tensorflow) and flickr.py from [vislab](https://github.com/sergeyk/vislab) for our use.
7 | 
8 | 
9 | ### Download the Flickr Style dataset
10 | 
11 | ```sh
12 | # Download dataset
13 | $ python assemble_data.py images train.txt test.txt 500
14 | ```
15 | 
16 | ### Download the pre-trained model
17 | 
18 | Download link: [here](https://drive.google.com/open?id=0B1TxGXQOCIQQME1peW9USXBDME0)
19 | 
20 | Or follow the tutorial and extract bvlc_alexnet.npy from https://github.com/ethereon/caffe-tensorflow
21 | 
22 | 
23 | ### Launch the finetune process
24 | 
25 | ```sh
26 | $ python finetune.py train.txt test.txt bvlc_alexnet.npy
27 | ```
28 | 
29 | ### Finetune result
30 | 
31 | ```sh
32 | # Fine-tuning result
33 | Iter 1280: Testing Accuracy = 0.3250
34 | # From-scratch result
35 | Iter 1280: Testing Accuracy = 0.1655
36 | ```
37 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import sys
4 | from network import *
5 | 
6 | class Model:
7 |     @staticmethod
8 |     def alexnet(net_input, keep_rate):
9 |         # TODO: weight decay loss term
10 |         # Layer 1 (conv-relu-pool-lrn)
11 |         conv1 = conv(net_input, 11, 11, 96, 4, 4, padding='VALID', name='conv1')
12 |         conv1 = max_pool(conv1, 3, 3, 2, 2, padding='VALID', name='pool1')
13 |         norm1 = lrn(conv1, 2, 2e-05, 0.75, name='norm1')
14 |         # Layer 2 (conv-relu-pool-lrn)
15 |         conv2 = conv(norm1, 5, 5, 256, 1, 1, group=2, name='conv2')
16 |         conv2 = max_pool(conv2, 3, 3, 2, 2, padding='VALID', name='pool2')
17 |         norm2 = lrn(conv2, 2, 2e-05, 0.75, name='norm2')
18 |         # Layer 3 (conv-relu)
19 |         conv3 = conv(norm2, 3, 3, 384, 1, 1, name='conv3')
20 |         # Layer 4 (conv-relu)
21 |         conv4 = conv(conv3, 3, 3, 384, 1, 1, group=2, name='conv4')
22 |         # Layer 5 (conv-relu-pool)
23 |         conv5 = conv(conv4, 3, 3, 256, 1, 1, group=2, name='conv5')
24 |         pool5 = max_pool(conv5, 3, 3, 2, 2, padding='VALID', name='pool5')
25 |         # Layer 6 (fc-relu-drop)
26 |         fc6 = tf.reshape(pool5, [-1, 6*6*256])
27 |         fc6 = fc(fc6, 6*6*256, 4096, name='fc6')
28 |         fc6 = dropout(fc6, keep_rate)
29 |         # Layer 7 (fc-relu-drop)
30 |         fc7 = fc(fc6, 4096, 4096, name='fc7')
31 |         fc7 = dropout(fc7, keep_rate)
32 |         # Layer 8 (fc, 20-class logits)
33 |         fc8 = fc(fc7, 4096, 20, relu=False, name='fc8')
34 |         return fc8
35 | 
36 | 
--------------------------------------------------------------------------------
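Before wiring the model into the training script below, it can be handy to check that the graph builds and that the classifier head has the expected shape. The snippet below is a small sketch added for illustration (it is not a file in the repository); it assumes TensorFlow 1.x and that model.py and network.py are importable from the working directory, and the batch size of 1 is arbitrary.

```python
# Sketch: build the AlexNet graph from model.py and inspect the output shape.
# Assumes TensorFlow 1.x; not part of the original repository.
import tensorflow as tf

from model import Model

images = tf.placeholder(tf.float32, [1, 227, 227, 3])  # 227x227 BGR crops, as produced by dataset.py
keep_prob = tf.placeholder(tf.float32)                  # dropout keep probability

logits = Model.alexnet(images, keep_prob)
print(logits.get_shape())  # expected: (1, 20) -- one logit per Flickr style class
```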
/finetune.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import sys
3 | import tensorflow as tf
4 | 
5 | from model import Model
6 | from dataset import Dataset
7 | from network import *
8 | 
9 | 
10 | def main():
11 |     if len(sys.argv) != 4:
12 |         print('Usage: python finetune.py train_file test_file weight_file')
13 |         return
14 | 
15 |     train_file = sys.argv[1]
16 |     test_file = sys.argv[2]
17 |     weight_file = sys.argv[3]
18 | 
19 |     # Learning params
20 |     learning_rate = 0.001
21 |     training_iters = 12800
22 |     batch_size = 50
23 |     display_step = 20
24 |     test_step = 640
25 | 
26 |     # Network params
27 |     n_classes = 20
28 |     keep_rate = 0.5
29 | 
30 |     # Graph input
31 |     x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
32 |     y = tf.placeholder(tf.float32, [None, n_classes])
33 |     keep_var = tf.placeholder(tf.float32)
34 | 
35 |     # Model
36 |     pred = Model.alexnet(x, keep_var)
37 | 
38 |     # Loss and optimizer
39 |     loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
40 |     optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(loss)
41 | 
42 |     # Evaluation
43 |     correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
44 |     accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
45 | 
46 |     # Init
47 |     init = tf.global_variables_initializer()
48 | 
49 |     # Load dataset
50 |     dataset = Dataset(train_file, test_file)
51 | 
52 |     # Launch the graph
53 |     with tf.Session() as sess:
54 |         print('Init variable')
55 |         sess.run(init)
56 | 
57 |         print('Load pre-trained model: {}'.format(weight_file))
58 |         load_with_skip(weight_file, sess, ['fc8'])  # Skip weights from fc8
59 | 
60 |         print('Start training')
61 |         step = 1
62 |         while step < training_iters:
63 |             batch_xs, batch_ys = dataset.next_batch(batch_size, 'train')
64 |             sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys, keep_var: keep_rate})
65 | 
66 |             # Display testing status
67 |             if step % test_step == 0:
68 |                 test_acc = 0.
69 | test_count = 0 70 | for _ in range(dataset.test_size // batch_size): 71 | batch_tx, batch_ty = dataset.next_batch(batch_size, 'test') 72 | acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_var: 1.}) 73 | test_acc += acc 74 | test_count += 1 75 | test_acc /= test_count 76 | print('{} Iter {}: Testing Accuracy = {:.4f}'.format(datetime.now(), step, test_acc), file=sys.stderr) 77 | 78 | # Display training status 79 | if step % display_step == 0: 80 | acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys, keep_var: 1.}) 81 | batch_loss = sess.run(loss, feed_dict={x: batch_xs, y: batch_ys, keep_var: 1.}) 82 | print('{} Iter {}: Training Loss = {:.4f}, Accuracy = {:.4f}'.format(datetime.now(), step, batch_loss, acc), file=sys.stderr) 83 | 84 | step += 1 85 | print('Finish!') 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /log/finetune.log: -------------------------------------------------------------------------------- 1 | I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 32 2 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:88] Found device 0 with properties: 3 | name: Tesla K40m 4 | major: 3 minor: 5 memoryClockRate (GHz) 0.745 5 | pciBusID 0000:05:00.0 6 | Total memory: 11.25GiB 7 | Free memory: 11.15GiB 8 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:112] DMA: 0 9 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:122] 0: Y 10 | I tensorflow/core/common_runtime/gpu/gpu_device.cc:643] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla K40m, pci bus id: 0000:05:00.0) 11 | I tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc:47] Setting region size to 11377705370 12 | I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 32 13 | 2016-02-10 19:56:18.485604 Iter 20: Training Loss = 2.9940, Accuracy = 0.0800 14 | 2016-02-10 19:56:34.004613 Iter 40: Training Loss = 2.7346, Accuracy = 0.1400 15 | 2016-02-10 19:56:49.646841 Iter 60: Training Loss = 2.6816, Accuracy = 0.1600 16 | 2016-02-10 19:57:05.170348 Iter 80: Training Loss = 2.3965, Accuracy = 0.3400 17 | 2016-02-10 19:57:20.537985 Iter 100: Training Loss = 2.5446, Accuracy = 0.1200 18 | 2016-02-10 19:57:35.900998 Iter 120: Training Loss = 2.5917, Accuracy = 0.2200 19 | 2016-02-10 19:57:51.001218 Iter 140: Training Loss = 2.3370, Accuracy = 0.3600 20 | 2016-02-10 19:58:06.071278 Iter 160: Training Loss = 2.4684, Accuracy = 0.2800 21 | 2016-02-10 19:58:21.307901 Iter 180: Training Loss = 2.4611, Accuracy = 0.1600 22 | 2016-02-10 19:58:36.631471 Iter 200: Training Loss = 2.5977, Accuracy = 0.2200 23 | 2016-02-10 19:58:52.102553 Iter 220: Training Loss = 2.5604, Accuracy = 0.3200 24 | 2016-02-10 19:59:07.817598 Iter 240: Training Loss = 2.4314, Accuracy = 0.2200 25 | 2016-02-10 19:59:23.271027 Iter 260: Training Loss = 2.4652, Accuracy = 0.3400 26 | 2016-02-10 19:59:39.007032 Iter 280: Training Loss = 2.4593, Accuracy = 0.3400 27 | 2016-02-10 19:59:54.603981 Iter 300: Training Loss = 2.4865, Accuracy = 0.2800 28 | 2016-02-10 20:00:09.922842 Iter 320: Training Loss = 2.3478, Accuracy = 0.3400 29 | 2016-02-10 20:00:25.330387 Iter 340: Training Loss = 2.3171, Accuracy = 0.3400 30 | 2016-02-10 20:00:41.054655 Iter 360: Training Loss = 2.3333, Accuracy = 0.3400 31 | 2016-02-10 20:00:56.684913 Iter 380: Training Loss = 2.3568, Accuracy = 0.2600 32 | 2016-02-10 20:01:12.378246 Iter 400: Training Loss = 2.3520, Accuracy = 
0.3600 33 | 2016-02-10 20:01:27.915051 Iter 420: Training Loss = 2.2507, Accuracy = 0.2800 34 | 2016-02-10 20:01:43.365520 Iter 440: Training Loss = 2.3765, Accuracy = 0.3000 35 | 2016-02-10 20:01:58.856385 Iter 460: Training Loss = 2.1458, Accuracy = 0.2800 36 | 2016-02-10 20:02:14.400480 Iter 480: Training Loss = 2.1913, Accuracy = 0.3800 37 | 2016-02-10 20:02:30.051589 Iter 500: Training Loss = 2.3197, Accuracy = 0.3000 38 | 2016-02-10 20:02:45.587557 Iter 520: Training Loss = 2.2722, Accuracy = 0.2400 39 | 2016-02-10 20:03:01.480268 Iter 540: Training Loss = 2.0654, Accuracy = 0.4600 40 | 2016-02-10 20:03:17.133404 Iter 560: Training Loss = 2.1883, Accuracy = 0.3600 41 | 2016-02-10 20:03:32.878469 Iter 580: Training Loss = 2.3559, Accuracy = 0.2800 42 | 2016-02-10 20:03:48.026672 Iter 600: Training Loss = 2.3969, Accuracy = 0.2200 43 | 2016-02-10 20:04:03.195702 Iter 620: Training Loss = 2.1030, Accuracy = 0.4400 44 | 2016-02-10 20:06:59.260866 Iter 640: Testing Accuracy = 0.2801 45 | -------------------------------------------------------------------------------- /log/scratch.log: -------------------------------------------------------------------------------- 1 | I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 32 2 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:88] Found device 0 with properties: 3 | name: Tesla K40m 4 | major: 3 minor: 5 memoryClockRate (GHz) 0.745 5 | pciBusID 0000:05:00.0 6 | Total memory: 11.25GiB 7 | Free memory: 11.15GiB 8 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:112] DMA: 0 9 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:122] 0: Y 10 | I tensorflow/core/common_runtime/gpu/gpu_device.cc:643] Creating TensorFlow device (/gpu:0) -> (device: 0, name: Tesla K40m, pci bus id: 0000:05:00.0) 11 | I tensorflow/core/common_runtime/gpu/gpu_region_allocator.cc:47] Setting region size to 11377705370 12 | I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 32 13 | 2016-02-10 19:44:22.296913 Iter 20: Training Loss = 2.8392, Accuracy = 0.1000 14 | 2016-02-10 19:44:37.397859 Iter 40: Training Loss = 2.8207, Accuracy = 0.1000 15 | 2016-02-10 19:44:52.477623 Iter 60: Training Loss = 2.8653, Accuracy = 0.1400 16 | 2016-02-10 19:45:07.455741 Iter 80: Training Loss = 2.7354, Accuracy = 0.1400 17 | 2016-02-10 19:45:22.321581 Iter 100: Training Loss = 2.7453, Accuracy = 0.1200 18 | 2016-02-10 19:45:37.366398 Iter 120: Training Loss = 2.7889, Accuracy = 0.2200 19 | 2016-02-10 19:45:52.198182 Iter 140: Training Loss = 2.6468, Accuracy = 0.2000 20 | 2016-02-10 19:46:07.138030 Iter 160: Training Loss = 2.7515, Accuracy = 0.1400 21 | 2016-02-10 19:46:21.972835 Iter 180: Training Loss = 2.8239, Accuracy = 0.1600 22 | 2016-02-10 19:46:36.879264 Iter 200: Training Loss = 2.7449, Accuracy = 0.1600 23 | 2016-02-10 19:46:52.029270 Iter 220: Training Loss = 2.7727, Accuracy = 0.1800 24 | 2016-02-10 19:47:07.171615 Iter 240: Training Loss = 2.7483, Accuracy = 0.1400 25 | 2016-02-10 19:47:22.190170 Iter 260: Training Loss = 2.6880, Accuracy = 0.2800 26 | 2016-02-10 19:47:37.241129 Iter 280: Training Loss = 2.7259, Accuracy = 0.1600 27 | 2016-02-10 19:47:52.341997 Iter 300: Training Loss = 2.7181, Accuracy = 0.1800 28 | 2016-02-10 19:48:07.603788 Iter 320: Training Loss = 2.6874, Accuracy = 0.2000 29 | 2016-02-10 19:48:22.671804 Iter 340: Training Loss = 2.7065, Accuracy = 0.1800 30 | 2016-02-10 19:48:37.584380 Iter 360: Training Loss = 2.6199, Accuracy = 0.2000 31 | 2016-02-10 
19:48:52.728980 Iter 380: Training Loss = 2.6831, Accuracy = 0.1600 32 | 2016-02-10 19:49:07.917627 Iter 400: Training Loss = 2.7135, Accuracy = 0.1400 33 | 2016-02-10 19:49:22.804866 Iter 420: Training Loss = 2.6694, Accuracy = 0.2000 34 | 2016-02-10 19:49:37.877277 Iter 440: Training Loss = 2.7154, Accuracy = 0.1400 35 | 2016-02-10 19:49:52.975109 Iter 460: Training Loss = 2.6224, Accuracy = 0.1600 36 | 2016-02-10 19:50:08.013710 Iter 480: Training Loss = 2.7775, Accuracy = 0.1400 37 | 2016-02-10 19:50:23.045544 Iter 500: Training Loss = 2.6879, Accuracy = 0.1600 38 | 2016-02-10 19:50:38.058165 Iter 520: Training Loss = 2.6325, Accuracy = 0.1800 39 | 2016-02-10 19:50:53.186669 Iter 540: Training Loss = 2.4613, Accuracy = 0.3000 40 | 2016-02-10 19:51:08.081020 Iter 560: Training Loss = 2.6298, Accuracy = 0.2400 41 | 2016-02-10 19:51:23.067064 Iter 580: Training Loss = 2.7325, Accuracy = 0.1800 42 | 2016-02-10 19:51:38.210447 Iter 600: Training Loss = 2.7366, Accuracy = 0.2000 43 | 2016-02-10 19:51:53.403916 Iter 620: Training Loss = 2.5829, Accuracy = 0.2400 44 | 2016-02-10 19:54:46.389659 Iter 640: Testing Accuracy = 0.1588 45 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | class Dataset: 5 | def __init__(self, train_list, test_list): 6 | # Load training images (path) and labels 7 | with open(train_list) as f: 8 | lines = f.readlines() 9 | self.train_image = [] 10 | self.train_label = [] 11 | for l in lines: 12 | items = l.split() 13 | self.train_image.append(items[0]) 14 | self.train_label.append(int(items[1])) 15 | 16 | # Load testing images (path) and labels 17 | with open(test_list) as f: 18 | lines = f.readlines() 19 | self.test_image = [] 20 | self.test_label = [] 21 | for l in lines: 22 | items = l.split() 23 | self.test_image.append(items[0]) 24 | self.test_label.append(int(items[1])) 25 | 26 | # Init params 27 | self.train_ptr = 0 28 | self.test_ptr = 0 29 | self.train_size = len(self.train_label) 30 | self.test_size = len(self.test_label) 31 | self.crop_size = 227 32 | self.scale_size = 256 33 | self.mean = np.array([104., 117., 124.]) 34 | self.n_classes = 20 35 | 36 | def next_batch(self, batch_size, phase): 37 | # Get next batch of image (path) and labels 38 | if phase == 'train': 39 | if self.train_ptr + batch_size < self.train_size: 40 | paths = self.train_image[self.train_ptr:self.train_ptr + batch_size] 41 | labels = self.train_label[self.train_ptr:self.train_ptr + batch_size] 42 | self.train_ptr += batch_size 43 | else: 44 | new_ptr = (self.train_ptr + batch_size)%self.train_size 45 | paths = self.train_image[self.train_ptr:] + self.train_image[:new_ptr] 46 | labels = self.train_label[self.train_ptr:] + self.train_label[:new_ptr] 47 | self.train_ptr = new_ptr 48 | elif phase == 'test': 49 | if self.test_ptr + batch_size < self.test_size: 50 | paths = self.test_image[self.test_ptr:self.test_ptr + batch_size] 51 | labels = self.test_label[self.test_ptr:self.test_ptr + batch_size] 52 | self.test_ptr += batch_size 53 | else: 54 | new_ptr = (self.test_ptr + batch_size)%self.test_size 55 | paths = self.test_image[self.test_ptr:] + self.test_image[:new_ptr] 56 | labels = self.test_label[self.test_ptr:] + self.test_label[:new_ptr] 57 | self.test_ptr = new_ptr 58 | else: 59 | return None, None 60 | 61 | # Read images 62 | images = np.ndarray([batch_size, self.crop_size, self.crop_size, 3]) 63 | for i, path 
in enumerate(paths):
64 |             img = cv2.imread(path)
65 |             h, w, c = img.shape
66 |             assert c == 3
67 |             img = cv2.resize(img, (self.scale_size, self.scale_size))
68 |             img = img.astype(np.float32)
69 |             img -= self.mean
70 |             shift = (self.scale_size - self.crop_size) // 2
71 |             img_crop = img[shift: shift + self.crop_size, shift: shift + self.crop_size, :]
72 |             images[i] = img_crop
73 | 
74 |         # Expand labels
75 |         one_hot_labels = np.zeros((batch_size, self.n_classes))
76 |         for i, label in enumerate(labels):
77 |             one_hot_labels[i][label] = 1
78 |         return images, one_hot_labels
79 | 
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | DEFAULT_PADDING = 'SAME'
5 | 
6 | 
7 | def load(data_path, session):
8 |     data_dict = np.load(data_path).item()  # dict: layer name -> [weights, biases]
9 |     for key in data_dict:
10 |         with tf.variable_scope(key, reuse=True):
11 |             for subkey, data in zip(('weights', 'biases'), data_dict[key]):
12 |                 session.run(tf.get_variable(subkey).assign(data))
13 | 
14 | 
15 | def load_with_skip(data_path, session, skip_layer):
16 |     data_dict = np.load(data_path).item()
17 |     for key in data_dict:
18 |         if key not in skip_layer:
19 |             with tf.variable_scope(key, reuse=True):
20 |                 for subkey, data in zip(('weights', 'biases'), data_dict[key]):
21 |                     session.run(tf.get_variable(subkey).assign(data))
22 | 
23 | 
24 | def make_var(name, shape):
25 |     return tf.get_variable(name, shape)
26 | 
27 | 
28 | def conv(input, k_h, k_w, c_o, s_h, s_w, name, relu=True, padding=DEFAULT_PADDING, group=1):
29 |     c_i = int(input.get_shape()[-1])
30 |     assert c_i % group == 0
31 |     assert c_o % group == 0
32 |     convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
33 |     with tf.variable_scope(name) as scope:
34 |         kernel = make_var('weights', shape=[k_h, k_w, c_i // group, c_o])  # integer division keeps the dim an int
35 |         biases = make_var('biases', [c_o])
36 |         if group == 1:
37 |             conv = convolve(input, kernel)
38 |         else:
39 |             input_groups = tf.split(input, group, 3)
40 |             kernel_groups = tf.split(kernel, group, 3)
41 |             output_groups = [convolve(i, k) for i, k in zip(input_groups, kernel_groups)]
42 |             conv = tf.concat(output_groups, 3)
43 |         if relu:
44 |             bias = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list())
45 |             return tf.nn.relu(bias, name=scope.name)
46 |         return tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape().as_list(), name=scope.name)
47 | 
48 | 
49 | def relu(input, name):
50 |     return tf.nn.relu(input, name=name)
51 | 
52 | 
53 | def max_pool(input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
54 |     return tf.nn.max_pool(input,
55 |                           ksize=[1, k_h, k_w, 1],
56 |                           strides=[1, s_h, s_w, 1],
57 |                           padding=padding,
58 |                           name=name)
59 | 
60 | 
61 | def avg_pool(input, k_h, k_w, s_h, s_w, name, padding=DEFAULT_PADDING):
62 |     return tf.nn.avg_pool(input,
63 |                           ksize=[1, k_h, k_w, 1],
64 |                           strides=[1, s_h, s_w, 1],
65 |                           padding=padding,
66 |                           name=name)
67 | 
68 | 
69 | def lrn(input, radius, alpha, beta, name, bias=1.0):
70 |     return tf.nn.local_response_normalization(input,
71 |                                               depth_radius=radius,
72 |                                               alpha=alpha,
73 |                                               beta=beta,
74 |                                               bias=bias,
75 |                                               name=name)
76 | 
77 | 
78 | def concat(inputs, axis, name):
79 |     return tf.concat(inputs, axis, name=name)
80 | 
81 | 
82 | def fc(input, num_in, num_out, name, relu=True):
83 |     with tf.variable_scope(name) as scope:
84 |         weights = make_var('weights', shape=[num_in, num_out])
85 |         biases = make_var('biases', [num_out])
86 |         op = tf.nn.relu_layer
if relu else tf.nn.xw_plus_b 87 | fc = op(input, weights, biases, name=scope.name) 88 | return fc 89 | 90 | 91 | def softmax(input, name): 92 | return tf.nn.softmax(input, name) 93 | 94 | 95 | def dropout(input, keep_prob): 96 | return tf.nn.dropout(input, keep_prob) 97 | -------------------------------------------------------------------------------- /assemble_data.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import glob 3 | import hashlib 4 | import os 5 | import random 6 | import requests 7 | import shutil 8 | import sys 9 | import time 10 | 11 | import cv2 12 | from skimage import io 13 | 14 | # Mapping: (class, Name, groups) 15 | STYLE_MAPPING = [ 16 | (0, 'Bokeh', ['1543486@N25']), 17 | (1, 'Bright', ['799643@N24']), 18 | (2, 'Depth_of_Field', ['75418467@N00', '407825@N20']), 19 | (3, 'Detailed', ['1670588@N24', '1131378@N23']), 20 | (4, 'Ethereal', ['907784@N22']), 21 | (5, 'Geometric_Composition', ['46353124@N00']), 22 | (6, 'Hazy', ['38694591@N00']), 23 | (7, 'HDR', ['99275357@N00']), 24 | (8, 'Horror', ['29561404@N00']), 25 | (9, 'Long_Exposure', ['52240257802@N01']), 26 | (10, 'Macro', ['52241335207@N01']), 27 | (11, 'Melancholy', ['70495179@N00']), 28 | (12, 'Minimal', ['42097308@N00']), 29 | (13, 'Noir', ['42109523@N00']), 30 | (14, 'Romantic', ['54284561@N00']), 31 | (15, 'Serene', ['1081625@N25']), 32 | (16, 'Pastel', ['1055565@N24', '1371818@N25']), 33 | (17, 'Sunny', ['1242213@N23']), 34 | (18, 'Texture', ['70176273@N00']), 35 | (19, 'Vintage', ['1222306@N25', "1176551@N24"]), 36 | ] 37 | 38 | 39 | def main(): 40 | if len(sys.argv) != 5: 41 | print('Usage: python assemble_data.py image_path train_file test_file images_per_style') 42 | return 43 | 44 | image_path = os.path.abspath(sys.argv[1]) 45 | train_file = sys.argv[2] 46 | test_file = sys.argv[3] 47 | images_per_style = int(sys.argv[4]) 48 | 49 | url_file = os.path.join(os.path.dirname(__file__), 'flickr_style_url.txt') 50 | img_info_file = os.path.join(os.path.dirname(__file__), 'flickr_style_img_info.txt') 51 | 52 | collect_image_style_url(url_file, images_per_style) 53 | fetch_images(url_file, img_info_file, image_path) 54 | generate_train_test_dataset(img_info_file, train_file, test_file, train_ratio=0.8) 55 | 56 | 57 | def collect_image_style_url(url_file, photos_per_style): 58 | if os.path.exists(url_file): 59 | print('[Skip] Url file exists: {}'.format(url_file)) 60 | return 61 | 62 | with open(url_file, 'w') as f: 63 | for class_id, style, groups in STYLE_MAPPING: 64 | print('Get_photos_for_style: {}'.format(style)) 65 | urls = get_image_url_from_group(groups, photos_per_style) 66 | for url in urls: 67 | print('{} {}'.format(url, class_id), file=f) 68 | 69 | print('[Done] Url file saves to: {}'.format(url_file)) 70 | 71 | 72 | def get_image_url_from_group(groups, num_images): 73 | params = { 74 | 'api_key': "d31c7cb60c57aa7483c5c80919df5371", 75 | 'per_page': 500, # 500 is the maximum allowed 76 | 'content_type': 1, # only photos 77 | } 78 | 79 | image_urls = [] 80 | for page in range(10): 81 | params['page'] = page 82 | 83 | for group in groups: 84 | params['group_id'] = group 85 | 86 | url = ('https://api.flickr.com/services/rest/?' 87 | 'method=flickr.photos.search&format=json&nojsoncallback=1' 88 | '&api_key={api_key}&content_type={content_type}' 89 | '&group_id={group_id}&page={page}&per_page={per_page}') 90 | url = url.format(**params) 91 | 92 | # Make the request and ensure it succeeds. 
93 |             try:
94 |                 page_data = requests.get(url).json()
95 |             except:
96 |                 print(requests.get(url))
97 |                 raise
98 |             if page_data['stat'] != 'ok':
99 |                 raise Exception("Something is wrong: API returned {}".format(page_data['stat']))
100 | 
101 |             for photo_item in page_data['photos']['photo']:
102 |                 image_urls.append(_get_image_url(photo_item))
103 | 
104 |             if len(image_urls) >= num_images:
105 |                 return image_urls[:num_images]
106 | 
107 |     raise Exception('Not enough images, only found {}'.format(len(image_urls)))
108 | 
109 | 
110 | def _get_image_url(photo_item, size_flag=''):
111 |     """
112 |     size_flag: string ['']
113 |         See http://www.flickr.com/services/api/misc.urls.html for options.
114 |         '': 500 px on longest side
115 |         '_m': 240px on longest side
116 |     """
117 |     url = "http://farm{farm}.staticflickr.com/{server}/{id}_{secret}{size}.jpg"
118 |     return url.format(size=size_flag, **photo_item)
119 | 
120 | 
121 | def fetch_images(url_file, img_info_file, image_folder):
122 |     if os.path.exists(img_info_file):
123 |         print('[Skip] Image info file exists: {}'.format(img_info_file))
124 |         return
125 | 
126 |     os.makedirs(image_folder, exist_ok=True)
127 | 
128 |     with open(url_file, 'r') as f:
129 |         lines = [line.strip() for line in f]
130 | 
131 |     image_info = []
132 |     for line in lines:
133 |         url, class_id = line.strip().split()
134 |         image_name = _get_image_name(url, class_id)
135 |         image_file = os.path.join(image_folder, image_name)
136 | 
137 |         # Download and verify
138 |         if not os.path.exists(image_file):
139 |             res = download_image(url, image_file)
140 |         res = verify_image(image_file)
141 | 
142 |         if not res:
143 |             print('[FAILURE] {}'.format(url))
144 |         else:
145 |             image_info.append((image_file, class_id))
146 |             print('[SUCCESS] {}'.format(url))
147 | 
148 |     with open(img_info_file, 'w') as f:
149 |         for image_file, class_id in image_info:
150 |             print('{} {}'.format(image_file, class_id), file=f)
151 | 
152 |     print('Success: {}, Failure: {}'.format(len(image_info), len(lines) - len(image_info)))
153 |     print('[Done] Image info file saves to: {}'.format(img_info_file))
154 | 
155 | 
156 | def _get_image_name(url, class_id):
157 |     return '{}_{}.jpg'.format(hashlib.sha1(url.encode()).hexdigest(), class_id)
158 | 
159 | 
160 | def download_image(url, file):
161 |     try:
162 |         if os.path.exists(file):
163 |             return True
164 | 
165 |         r = requests.get(url, stream=True)
166 |         if r.status_code == 200:
167 |             with open(file, 'wb') as f:
168 |                 r.raw.decode_content = True
169 |                 shutil.copyfileobj(r.raw, f)
170 |             return True
171 |         else:
172 |             return False
173 |     except KeyboardInterrupt:
174 |         raise Exception()  # multiprocessing doesn't catch keyboard exceptions
175 |     except:
176 |         return False
177 | 
178 | 
179 | def verify_image(img_file):
180 |     try:
181 |         img = io.imread(img_file)
182 |     except:
183 |         return False
184 |     return True
185 | 
186 | 
187 | def generate_train_test_dataset(img_info_file, train_file, test_file, train_ratio=0.8):
188 |     class_to_images = defaultdict(list)
189 |     with open(img_info_file, 'r') as f:
190 |         lines = [line.strip() for line in f]
191 | 
192 |     random.seed(1211)
193 |     random.shuffle(lines)
194 |     train_size = int(len(lines) * train_ratio)
195 | 
196 |     with open(train_file, 'w') as f:
197 |         for line in lines[:train_size]:
198 |             print(line, file=f)
199 | 
200 |     with open(test_file, 'w') as f:
201 |         for line in lines[train_size:]:
202 |             print(line, file=f)
203 | 
204 | 
205 |     print('[Done] Train file (size={}) saves to: {}'.format(train_size, train_file))
206 |     print('[Done] Test file (size={}) saves to: {}'.format(len(lines) - train_size, test_file))
207 | 
208 | if __name__ == '__main__':
209 |     main()
210 | 
--------------------------------------------------------------------------------
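A final note on the pre-trained weights: network.py expects bvlc_alexnet.npy to be a pickled dict mapping each layer name to a [weights, biases] pair. The snippet below is a small sketch (not part of the repository) for sanity-checking the file before running finetune.py; the encoding/allow_pickle arguments are only needed on Python 3 with newer NumPy, and the file path is an assumption about where you saved the download.

```python
# Sketch: verify the structure of bvlc_alexnet.npy before fine-tuning.
import numpy as np

# encoding='latin1' lets Python 3 read the Python 2 pickle inside the file;
# allow_pickle=True is required by NumPy >= 1.16.3.
data_dict = np.load('bvlc_alexnet.npy', encoding='latin1', allow_pickle=True).item()

for layer, (weights, biases) in sorted(data_dict.items()):
    print('{:6s} weights {}  biases {}'.format(layer, weights.shape, biases.shape))

# Expected keys: conv1-conv5 and fc6-fc8. finetune.py loads everything except fc8
# (via load_with_skip), since the 1000-way ImageNet classifier is replaced by a
# freshly initialized 20-way Flickr style classifier.
```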